"""
|
||
LangGraph工作流节点
|
||
实现具体的工作流节点逻辑
|
||
"""
|
||
|
||
import json
|
||
import asyncio
|
||
from typing import Dict, Any, List, Tuple, Optional
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
# LangSmith会自动追踪LangChain的LLM调用
|
||
|
||
from retriver.langgraph.graph_state import (
|
||
QueryState,
|
||
RetrievalResult,
|
||
SufficiencyCheck,
|
||
update_state_with_retrieval,
|
||
update_state_with_sufficiency_check,
|
||
increment_iteration,
|
||
finalize_state
|
||
)
|
||
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever
|
||
from retriver.langgraph.langchain_components import (
|
||
OneAPILLM,
|
||
SufficiencyCheckParser,
|
||
QueryComplexityParser,
|
||
QUERY_COMPLEXITY_CHECK_PROMPT,
|
||
SUFFICIENCY_CHECK_PROMPT,
|
||
QUERY_DECOMPOSITION_PROMPT,
|
||
SUB_QUERY_GENERATION_PROMPT,
|
||
SIMPLE_ANSWER_PROMPT,
|
||
FINAL_ANSWER_PROMPT,
|
||
format_passages,
|
||
format_mixed_passages,
|
||
format_sub_queries
|
||
)
|
||
from retriver.langgraph.es_vector_retriever import ESVectorRetriever
|
||
from prompt_loader import get_prompt_loader
|
||
|
||
|
||
class GraphNodes:
    """Implementation of the workflow nodes."""

    def __init__(
        self,
        retriever: LangChainHippoRAGRetriever,
        llm: OneAPILLM,
        keyword: str,
        max_parallel_retrievals: int = 2,
        simple_retrieval_top_k: int = 3,
        complexity_llm: Optional[OneAPILLM] = None,
        sufficiency_llm: Optional[OneAPILLM] = None,
        skip_llm_generation: bool = False
    ):
        """
        Initialize the node handler.

        Args:
            retriever: HippoRAG retriever
            llm: OneAPI LLM (used for answer generation)
            keyword: ES index keyword
            max_parallel_retrievals: maximum number of parallel retrievals
            simple_retrieval_top_k: number of documents returned by simple retrieval
            complexity_llm: dedicated LLM for complexity checks (falls back to llm)
            sufficiency_llm: dedicated LLM for sufficiency checks (falls back to llm)
            skip_llm_generation: skip LLM answer generation (return retrieval results only)
        """
        self.retriever = retriever
        self.llm = llm
        self.complexity_llm = complexity_llm or llm  # fall back to the main LLM
        self.sufficiency_llm = sufficiency_llm or llm  # fall back to the main LLM
        self.keyword = keyword
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation
        self.sufficiency_parser = SufficiencyCheckParser()
        self.complexity_parser = QueryComplexityParser()

        # ES vector retriever used for the simple-query path
        self.es_vector_retriever = ESVectorRetriever(
            keyword=keyword,
            top_k=simple_retrieval_top_k
        )

        # Prompt loader
        self.prompt_loader = get_prompt_loader()

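    # Illustrative construction sketch (assumptions: the real retriever/LLM wiring
    # lives in the graph-builder code; the values below are placeholders only):
    #
    #     nodes = GraphNodes(
    #         retriever=my_hipporag_retriever,   # a LangChainHippoRAGRetriever
    #         llm=my_llm,                        # OneAPILLM used for answer generation
    #         keyword="my_es_index",             # ES index keyword (placeholder)
    #         complexity_llm=my_cheap_llm,       # optional smaller LLM for routing checks
    #     )
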
    def _extract_response_text(self, response):
        """Unified helper for extracting plain text from an LLM response."""
        if hasattr(response, 'generations') and response.generations:
            # LLMResult-style object
            return response.generations[0][0].text
        elif hasattr(response, 'content'):
            # Message-style object (e.g. AIMessage)
            return response.content
        elif isinstance(response, dict) and 'response' in response:
            # Raw dict payload
            return response['response']
        else:
            # Fallback: stringify whatever was returned
            return str(response)

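    # The two JSON-producing nodes below (query_decomposition_node and
    # sub_query_generation_node) duplicate the same response-cleaning steps.
    # The helper below is an illustrative consolidation sketch of that shared
    # logic; it is not called anywhere yet, and simply mirrors the marker and
    # fence handling used in those nodes.
    def _clean_llm_json_text(self, response_text: str) -> str:
        """Illustrative helper (sketch, not wired in): strip "答案:" prefixes,
        markdown code fences and surrounding prose, keeping the outermost
        {...} JSON block."""
        cleaned_text = response_text.strip()

        # Keep only the content after an "答案:" marker, if present
        for marker in ("答案:", "答案:"):
            if marker in cleaned_text:
                cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                break

        # Strip markdown code-fence markers
        if cleaned_text.startswith('```json'):
            cleaned_text = cleaned_text[7:]
        elif cleaned_text.startswith('```'):
            cleaned_text = cleaned_text[3:]
        if cleaned_text.endswith('```'):
            cleaned_text = cleaned_text[:-3]
        cleaned_text = cleaned_text.strip()

        # Keep only the outermost {...} span, if any
        if '{' in cleaned_text and '}' in cleaned_text:
            cleaned_text = cleaned_text[cleaned_text.find('{'):cleaned_text.rfind('}') + 1]
        return cleaned_text
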
    def query_complexity_check_node(self, state: QueryState) -> QueryState:
        """
        Query complexity check node.
        Decides whether the user query needs complex knowledge-graph reasoning.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[?] 执行查询复杂度判断: {state['original_query']}")

        # Skip the LLM call if configured to do so
        if self.prompt_loader.should_skip_llm('query_complexity_check'):
            print(f"[NEXT] 跳过复杂度检查LLM调用,默认为简单查询")
            # Use the same state keys as the LLM path below so that downstream
            # nodes (e.g. debug_mode_node) always see a consistent structure
            state['query_complexity'] = {
                "is_complex": False,
                "complexity_level": "simple",
                "confidence": 1.0,
                "reason": "跳过复杂度检查LLM调用,默认为简单查询"
            }
            state['is_complex_query'] = False
            return state

        # Build the prompt
        prompt = QUERY_COMPLEXITY_CHECK_PROMPT.format(
            query=state['original_query']
        )

        try:
            # Call the dedicated complexity-check LLM
            response = self.complexity_llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Parse the response
            complexity_result = self.complexity_parser.parse(response_text)

            # Update the state
            state['query_complexity'] = {
                "is_complex": complexity_result.is_complex,
                "complexity_level": complexity_result.complexity_level,
                "confidence": complexity_result.confidence,
                "reason": complexity_result.reason
            }
            state['is_complex_query'] = complexity_result.is_complex

            # Update debug info
            state["debug_info"]["llm_calls"] += 1

            print(f"[INFO] 复杂度判断结果: {'复杂' if complexity_result.is_complex else '简单'} "
                  f"(置信度: {complexity_result.confidence:.2f})")
            print(f"理由: {complexity_result.reason}")

            return state

        except Exception as e:
            print(f"[ERROR] 复杂度判断失败: {e}")
            # On failure, default to a complex query and follow the existing path
            state['query_complexity'] = {
                "is_complex": True,
                "complexity_level": "complex",
                "confidence": 0.5,
                "reason": f"复杂度判断失败: {str(e)},默认使用复杂检索"
            }
            state['is_complex_query'] = True
            return state

    def debug_mode_node(self, state: QueryState) -> QueryState:
        """
        Debug mode node.
        Depending on the user-supplied debug_mode flag, optionally overrides the
        result of the complexity check.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[?] 执行调试模式检查: mode={state['debug_mode']}")

        # Remember the original complexity decision
        original_is_complex = state['is_complex_query']

        if state['debug_mode'] == 'simple':
            print("[?] 调试模式: 强制使用简单检索路径")
            # Force the simple path regardless of the original decision
            state['is_complex_query'] = False
            # Keep the original decision, but record the override for debugging
            if 'debug_override' not in state['query_complexity']:
                state['query_complexity']['debug_override'] = {
                    'original_complexity': original_is_complex,
                    'debug_mode': 'simple',
                    'override_reason': '调试模式强制使用简单检索路径'
                }
        elif state['debug_mode'] == 'complex':
            print("[?] 调试模式: 强制使用复杂检索路径")
            # Force the complex path regardless of the original decision
            state['is_complex_query'] = True
            # Keep the original decision, but record the override for debugging
            if 'debug_override' not in state['query_complexity']:
                state['query_complexity']['debug_override'] = {
                    'original_complexity': original_is_complex,
                    'debug_mode': 'complex',
                    'override_reason': '调试模式强制使用复杂检索路径'
                }
        else:
            # debug_mode == "0" (or any other value): keep the automatic decision
            print("[?] 调试模式: 使用自动复杂度判断结果")

        return state

    def query_decomposition_node(self, state: QueryState) -> QueryState:
        """
        Query decomposition node.
        Splits the original user query into 2 sub-queries that can be retrieved
        in parallel.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[TARGET] 执行查询分解: {state['original_query']}")

        # Skip the LLM call if configured to do so
        if self.prompt_loader.should_skip_llm('query_decomposition'):
            print(f"[NEXT] 跳过查询分解LLM调用,使用原查询")
            # No decomposition: use the original query directly
            state['decomposed_sub_queries'] = []
            state['sub_queries'] = [state['original_query']]
            return state

        # Build the prompt
        prompt = QUERY_DECOMPOSITION_PROMPT.format(
            original_query=state['original_query']
        )

        try:
            # Call the LLM to generate sub-queries
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            print(f"[?] LLM原始响应: {response_text[:200]}...")

            # Clean the response and extract the JSON content
            cleaned_text = response_text.strip()

            # If the response contains an "答案:" marker, keep only what follows it
            if "答案:" in cleaned_text or "答案:" in cleaned_text:
                answer_markers = ["答案:", "答案:"]
                for marker in answer_markers:
                    if marker in cleaned_text:
                        cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                        break

            # Strip markdown code-fence markers
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # remove leading ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # remove leading ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # remove trailing ```
            cleaned_text = cleaned_text.strip()

            # Try to isolate the JSON part (from the first '{' to the last '}')
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                cleaned_text = cleaned_text[start_idx:end_idx]

            # Parse the sub-queries
            try:
                data = json.loads(cleaned_text)
                raw_decomposed_queries = data.get('sub_queries', [])

                # Handle the different sub-query formats (kept consistent with
                # sub_query_generation_node)
                decomposed_sub_queries = []
                for item in raw_decomposed_queries:
                    if isinstance(item, str):
                        decomposed_sub_queries.append(item)
                    elif isinstance(item, dict) and 'query' in item:
                        decomposed_sub_queries.append(item['query'])
                    elif isinstance(item, dict):
                        query_text = item.get('text', item.get('content', str(item)))
                        decomposed_sub_queries.append(query_text)
                    else:
                        decomposed_sub_queries.append(str(item))

                # Make sure every entry is a string
                decomposed_sub_queries = [str(query).strip() for query in decomposed_sub_queries if query]
                print(f"[OK] JSON解析成功,获得子查询: {decomposed_sub_queries}")

            except json.JSONDecodeError as e:
                print(f"[ERROR] JSON解析失败: {e}")
                print(f"尝试规则提取...")
                # Fall back to a simple rule-based extraction
                lines = response_text.split('\n')
                decomposed_sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
                print(f"规则提取结果: {decomposed_sub_queries}")

            # Make sure we have sub-queries
            if not decomposed_sub_queries:
                print(f"[WARNING] LLM未生成有效子查询,使用简单分解策略...")
                # Simple punctuation-based split
                original = state['original_query']
                if '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
                    if len(parts) >= 2:
                        decomposed_sub_queries = [parts[0], parts[1]]
                        print(f"按中文问号分解: {decomposed_sub_queries}")
                elif '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
                    if len(parts) >= 2:
                        decomposed_sub_queries = [parts[0], parts[1]]
                        print(f"按英文问号分解: {decomposed_sub_queries}")

                # Still nothing: fall back to the original query
                if not decomposed_sub_queries:
                    print(f"[WARNING] 无法自动分解,使用原查询作为两个子查询")
                    decomposed_sub_queries = [original, original]

            elif len(decomposed_sub_queries) == 1:
                print(f"[WARNING] 只获得1个子查询,补充第二个")
                # Only one sub-query: use the original query as the second one
                decomposed_sub_queries.append(state['original_query'])

            # Keep at most 2 sub-queries
            decomposed_sub_queries = decomposed_sub_queries[:2]

            # Update the state with the initially decomposed sub-queries
            state['decomposed_sub_queries'] = decomposed_sub_queries
            state['sub_queries'].extend(decomposed_sub_queries)

            # Update debug info
            state["debug_info"]["llm_calls"] += 1

            print(f"[OK] 查询分解完成: {decomposed_sub_queries}")
            return state

        except Exception as e:
            print(f"[ERROR] 查询分解失败: {e}")
            # On failure, fall back to default sub-queries
            default_sub_queries = [
                state['original_query'] + " 详细信息",
                state['original_query'] + " 相关内容"
            ]
            state['decomposed_sub_queries'] = default_sub_queries
            state['sub_queries'].extend(default_sub_queries)
            return state

    def simple_vector_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Simple vector retrieval node.
        Matches the query directly against the text passages in the ES vector store.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[SEARCH] 执行简单向量检索: {state['original_query']}")

        try:
            # Retrieve relevant documents with the ES vector retriever
            documents = self.es_vector_retriever.retrieve(state['original_query'])

            # Extract passage contents and source information
            passages = [doc.page_content for doc in documents]
            sources = [f"简单检索-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]

            # Build the retrieval result
            retrieval_result = RetrievalResult(
                passages=passages,
                documents=documents,
                sources=sources,
                query=state['original_query'],
                iteration=state['current_iteration']
            )

            # Update the state
            updated_state = update_state_with_retrieval(state, retrieval_result)

            print(f"[OK] 简单向量检索完成,获得 {len(passages)} 个段落")
            return updated_state

        except Exception as e:
            print(f"[ERROR] 简单向量检索失败: {e}")
            # On failure, return an empty result
            retrieval_result = RetrievalResult(
                passages=[],
                documents=[],
                sources=[],
                query=state['original_query'],
                iteration=state['current_iteration']
            )
            return update_state_with_retrieval(state, retrieval_result)

    def simple_answer_generation_node(self, state: QueryState) -> QueryState:
        """
        Simple answer generation node.
        Generates an answer from the results of the simple retrieval path.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[NOTE] 生成简单查询答案")

        # Format the retrieval results (mixed format if all_documents is available,
        # otherwise the legacy passage format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Skip LLM generation if requested (global flag or prompt-level config)
        if self.skip_llm_generation or self.prompt_loader.should_skip_llm('simple_answer'):
            if self.skip_llm_generation:
                print(f"[SEARCH] 跳过LLM生成(skip_llm_generation=true),直接返回检索结果")
            else:
                print(f"[NEXT] 跳过简单答案LLM调用(simple_answer.skip_llm=true),直接返回检索结果")
            final_answer = f"检索到的信息:\n{formatted_passages}"
            state['final_answer'] = final_answer
            return finalize_state(state, final_answer)

        # Build the prompt
        prompt = SIMPLE_ANSWER_PROMPT.format(
            query=state['original_query'],
            passages=formatted_passages
        )

        try:
            # Call the LLM to generate the answer
            response = self.llm.invoke(prompt)
            final_answer = self._extract_response_text(response)

            if not final_answer.strip():
                final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"

            # Finalize the state
            updated_state = finalize_state(state, final_answer)

            print(f"[OK] 简单答案生成完成 (长度: {len(final_answer)} 字符)")
            return updated_state

        except Exception as e:
            print(f"[ERROR] 简单答案生成失败: {e}")
            error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
            return finalize_state(state, error_answer)

    def initial_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Parallel initial retrieval node.
        Runs retrieval in parallel for the original query and the 2 decomposed
        sub-queries. Each query returns a mixed result set: TOP-10 event nodes
        plus TOP-3 passage nodes.

        Args:
            state: current state

        Returns:
            the updated state
        """
        # Queries to retrieve: the original query + the 2 decomposed sub-queries
        original_query = state['original_query']
        sub_queries = state.get('decomposed_sub_queries', [])

        all_queries = [original_query] + sub_queries[:2]  # at most 3 queries

        print(f"[SEARCH] 执行并行初始检索 - {len(all_queries)} 个查询")
        for i, query in enumerate(all_queries):
            query_type = "原始查询" if i == 0 else f"子查询{i}"
            print(f" {query_type}: {query}")

        def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, List, str]:
            """Retrieve a single query and return its documents, passages and sources."""
            import time

            # Label the query according to its type
            if index == 0:
                query_label = "原始查询"
            else:
                query_label = f"子查询{index}"

            start_time = time.time()
            print(f"[STARTING] {query_label} 开始检索 [{time.strftime('%H:%M:%S', time.localtime(start_time))}]")

            try:
                documents = self.retriever.invoke(query)

                # The retriever now returns mixed results (events + passages);
                # no extra truncation is applied here
                top_documents = documents  # use every retrieved document
                passages = [doc.page_content for doc in top_documents]

                # Build source identifiers per query type; supports mixed node types
                sources = []
                for doc in top_documents:
                    # Prefer node_id, fall back to passage_id
                    doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id', 'unknown')
                    node_type = doc.metadata.get('node_type', 'unknown')

                    if index == 0:
                        sources.append(f"原始查询-{node_type}-{doc_id}")
                    else:
                        sources.append(f"子查询{index}-{node_type}-{doc_id}")

                end_time = time.time()
                duration = end_time - start_time
                print(f"[OK] {query_label} 检索完成 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒,获得 {len(passages)} 个内容(事件+段落)")
                return index, documents, passages, sources, query_label

            except Exception as e:
                end_time = time.time()
                duration = end_time - start_time
                print(f"[ERROR] {query_label} 检索失败 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒 - 错误: {e}")
                return index, [], [], [], query_label

        # Run the retrievals in parallel
        all_documents = []
        all_passages = []
        all_sources = []
        retrieval_details = {}

        import time
        parallel_start_time = time.time()
        print(f"[FAST] 开始并行执行 {len(all_queries)} 个检索任务 [{time.strftime('%H:%M:%S', time.localtime(parallel_start_time))}]")

        with ThreadPoolExecutor(max_workers=min(3, len(all_queries))) as executor:
            # Submit the retrieval tasks
            futures = {
                executor.submit(retrieve_single_query, query, i): (query, i)
                for i, query in enumerate(all_queries)
            }
            print(f"[?] 所有 {len(futures)} 个检索任务已提交到线程池")

            # Collect the results
            for future in as_completed(futures):
                query, query_index = futures[future]
                try:
                    index, documents, passages, sources, query_label = future.result()

                    # Record per-query retrieval details
                    retrieval_details[query_label] = {
                        'query': query,
                        'passages_count': len(passages),
                        'documents_count': len(documents)
                    }

                    # Merge the results (deduplication happens afterwards)
                    all_documents.extend(documents)
                    all_passages.extend(passages)
                    all_sources.extend(sources)

                except Exception as e:
                    print(f"[ERROR] 查询 {query_index+1} 处理失败: {e}")

        parallel_end_time = time.time()
        parallel_duration = parallel_end_time - parallel_start_time
        print(f"[TARGET] 并行检索全部完成 [{time.strftime('%H:%M:%S', time.localtime(parallel_end_time))}] - 总耗时: {parallel_duration:.2f}秒")

        # Deduplicate by document ID, with a content hash as fallback
        unique_documents = []
        unique_passages = []
        unique_sources = []
        seen_passage_ids = set()
        seen_content_hashes = set()

        for i, (doc, passage, source) in enumerate(zip(all_documents, all_passages, all_sources)):
            # Prefer node_id or passage_id for deduplication; supports mixed node types
            doc_id = None
            if doc:
                doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id')

            if doc_id and doc_id in seen_passage_ids:
                continue

            # Content-hash deduplication as a backup
            content_hash = hash(passage.strip())
            if content_hash in seen_content_hashes:
                continue

            # Keep this entry
            unique_documents.append(doc)
            unique_passages.append(passage)
            unique_sources.append(source)

            if doc_id:
                seen_passage_ids.add(doc_id)
            seen_content_hashes.add(content_hash)

        removed_count = len(all_passages) - len(unique_passages)
        if removed_count > 0:
            print(f"[SEARCH] 去重处理: 移除了 {removed_count} 个重复内容")

        # Build the retrieval result
        query_description = f"并行检索: 原始查询 + {len(sub_queries)} 个子查询"
        retrieval_result = RetrievalResult(
            passages=unique_passages,
            documents=unique_documents,
            sources=unique_sources,
            query=query_description,
            iteration=state['current_iteration']
        )

        # Update the state
        updated_state = update_state_with_retrieval(state, retrieval_result)

        # Store retrieval details in the state for later analysis
        updated_state['initial_retrieval_details'] = retrieval_details

        # PageRank score collection (full per-query PageRank results)
        for i, query in enumerate(all_queries):
            try:
                # Skipped to avoid shipping large payloads through LangSmith
                # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                # Only flag that the data is available (it is handled inside HippoRAG)
                updated_state['pagerank_data_available'] = True
            except Exception as e:
                print(f"[WARNING] 收集查询{i+1}的PageRank分数失败: {e}")

        total_passages_before = len(all_passages)
        total_passages_after = len(unique_passages)
        print(f"[SUCCESS] 并行初始检索完成")
        print(f" 检索前总内容: {total_passages_before}, 去重后: {total_passages_after}")
        print(f" 原始查询: {retrieval_details.get('原始查询', {}).get('passages_count', 0)} 个内容(事件+段落)")
        for i in range(1, len(all_queries)):
            key = f"子查询{i}"
            count = retrieval_details.get(key, {}).get('passages_count', 0)
            print(f" 子查询{i}: {count} 个内容(事件+段落)")

        return updated_state

    def sufficiency_check_node(self, state: QueryState) -> QueryState:
        """
        Sufficiency check node.
        Decides whether the information retrieved so far is enough to answer the
        user query, taking the decomposed sub-queries into account.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[?] 执行充分性检查 (迭代 {state['current_iteration']})")

        # Skip the LLM call if configured to do so
        if self.prompt_loader.should_skip_llm('sufficiency_check'):
            print(f"[NEXT] 跳过充分性检查LLM调用,默认为充分")
            state['is_sufficient'] = True
            state['sufficiency_confidence'] = 1.0
            return state

        # Format the retrieval results (mixed format if all_documents is available,
        # otherwise the legacy passage format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Format the decomposed sub-queries
        decomposed_sub_queries = state.get('decomposed_sub_queries', [])
        formatted_decomposed_queries = format_sub_queries(decomposed_sub_queries) if decomposed_sub_queries else "无"

        # Build the prompt, including the decomposed sub-queries
        prompt = SUFFICIENCY_CHECK_PROMPT.format(
            query=state['original_query'],
            passages=formatted_passages,
            decomposed_sub_queries=formatted_decomposed_queries
        )

        # Call the dedicated sufficiency-check LLM
        try:
            response = self.sufficiency_llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Parse the response
            sufficiency_result = self.sufficiency_parser.parse(response_text)

            # Build the sufficiency check result
            sufficiency_check = SufficiencyCheck(
                is_sufficient=sufficiency_result.is_sufficient,
                confidence=sufficiency_result.confidence,
                reason=sufficiency_result.reason,
                sub_queries=sufficiency_result.sub_queries
            )

            # Update the state
            updated_state = update_state_with_sufficiency_check(state, sufficiency_check)

            # Update debug info
            updated_state["debug_info"]["llm_calls"] += 1

            print(f"[INFO] 充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'} "
                  f"(置信度: {sufficiency_result.confidence:.2f})")
            print(f" 基于 {len(state['all_passages'])} 个段落 (来自原始查询和{len(decomposed_sub_queries)}个子查询)")

            if not sufficiency_result.is_sufficient and sufficiency_result.sub_queries:
                print(f"[TARGET] 生成新的子查询: {sufficiency_result.sub_queries}")

            return updated_state

        except Exception as e:
            print(f"[ERROR] 充分性检查失败: {e}")
            # On failure, assume the information is insufficient and fall back to a
            # default sub-query
            sufficiency_check = SufficiencyCheck(
                is_sufficient=False,
                confidence=0.5,
                reason=f"充分性检查失败: {str(e)}",
                sub_queries=[state['original_query'] + " 详细信息"]
            )
            return update_state_with_sufficiency_check(state, sufficiency_check)

    def sub_query_generation_node(self, state: QueryState) -> QueryState:
        """
        Sub-query generation node.
        Generates sub-queries when the sufficiency check fails, taking into
        account the sub-queries produced by the earlier decomposition.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[TARGET] 生成子查询")

        # Skip the LLM call if configured to do so
        if self.prompt_loader.should_skip_llm('sub_query_generation'):
            print(f"[NEXT] 跳过子查询生成LLM调用")
            # Do not generate new queries; move on to the next iteration
            state['current_iteration'] += 1
            return state

        # If sub-queries already exist, return immediately
        if state['current_sub_queries']:
            print(f"[OK] 使用现有子查询: {state['current_sub_queries']}")
            return state

        # Format the existing retrieval results (mixed format if all_documents is
        # available, otherwise the legacy passage format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Format every previously generated sub-query (including decomposed ones)
        previous_sub_queries = state.get('sub_queries', [])
        formatted_previous_queries = format_sub_queries(previous_sub_queries) if previous_sub_queries else "无"

        # Reason why the sufficiency check judged the information insufficient
        sufficiency_check = state.get('sufficiency_check', {})
        insufficiency_reason = sufficiency_check.get('reason', '信息不够充分,需要更多相关信息')

        # Build the prompt, including the sufficiency-check feedback
        prompt = SUB_QUERY_GENERATION_PROMPT.format(
            original_query=state['original_query'],
            existing_passages=formatted_passages,
            previous_sub_queries=formatted_previous_queries,
            insufficiency_reason=insufficiency_reason
        )

        try:
            # Call the LLM to generate sub-queries
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Update debug info
            state["debug_info"]["llm_calls"] += 1

            # Clean the response and extract the JSON content
            cleaned_text = response_text.strip()

            # If the response contains an "答案:" marker, keep only what follows it
            if "答案:" in cleaned_text or "答案:" in cleaned_text:
                answer_markers = ["答案:", "答案:"]
                for marker in answer_markers:
                    if marker in cleaned_text:
                        cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                        break

            # Strip markdown code-fence markers
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # remove leading ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # remove leading ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # remove trailing ```
            cleaned_text = cleaned_text.strip()

            # Try to isolate the JSON part (from the first '{' to the last '}')
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                cleaned_text = cleaned_text[start_idx:end_idx]

            # Parse the sub-queries
            try:
                data = json.loads(cleaned_text)
                raw_sub_queries = data.get('sub_queries', [])

                # Handle the different sub-query formats
                sub_queries = []
                for item in raw_sub_queries:
                    if isinstance(item, str):
                        # Plain string: use as-is
                        sub_queries.append(item)
                    elif isinstance(item, dict) and 'query' in item:
                        # Dict with a query field: extract it
                        sub_queries.append(item['query'])
                    elif isinstance(item, dict):
                        # Dict without a query field: look for text-like content
                        query_text = item.get('text', item.get('content', str(item)))
                        sub_queries.append(query_text)
                    else:
                        # Anything else: stringify
                        sub_queries.append(str(item))

            except json.JSONDecodeError:
                # Fall back to a simple rule-based extraction
                lines = response_text.split('\n')
                sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]

            # Make sure there is at least one sub-query
            if not sub_queries:
                sub_queries = [state['original_query'] + " 更多细节"]

            # Make sure every sub-query is a string
            sub_queries = [str(query).strip() for query in sub_queries if query]

            # Update the state
            state['current_sub_queries'] = sub_queries[:2]  # at most 2 sub-queries
            state['sub_queries'].extend(state['current_sub_queries'])

            print(f"[OK] 生成子查询: {state['current_sub_queries']}")
            print(f" (避免与之前的子查询重复: {formatted_previous_queries})")
            return state

        except Exception as e:
            print(f"[ERROR] 子查询生成失败: {e}")
            # On failure, fall back to a default sub-query
            default_sub_query = state['original_query'] + " 补充信息"
            state['current_sub_queries'] = [default_sub_query]
            state['sub_queries'].append(default_sub_query)
            return state

    def parallel_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Parallel retrieval node.
        Runs retrieval in parallel over the current sub-queries.

        Args:
            state: current state

        Returns:
            the updated state
        """
        sub_queries = state['current_sub_queries']
        if not sub_queries:
            print("[WARNING] 没有子查询,跳过并行检索")
            return state

        print(f"[?] 并行检索 {len(sub_queries)} 个子查询")

        def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, List]:
            """Retrieve a single sub-query."""
            try:
                documents = self.retriever.invoke(query)
                passages = [doc.page_content for doc in documents]
                sources = [f"子查询{index+1}-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
                return index, documents, passages, sources
            except Exception as e:
                print(f"[ERROR] 子查询 {index+1} 检索失败: {e}")
                return index, [], [], []

        # Run the retrievals in parallel
        all_new_documents = []
        all_new_passages = []
        all_new_sources = []

        with ThreadPoolExecutor(max_workers=self.max_parallel_retrievals) as executor:
            # Submit the retrieval tasks
            futures = {
                executor.submit(retrieve_single_query, query, i): (query, i)
                for i, query in enumerate(sub_queries)
            }

            # Collect the results
            for future in as_completed(futures):
                query, query_index = futures[future]
                try:
                    index, documents, passages, sources = future.result()
                    if passages:  # only keep non-empty results
                        all_new_documents.extend(documents)
                        all_new_passages.extend(passages)
                        all_new_sources.extend(sources)
                        print(f"[OK] 子查询 {index+1} 完成,获得 {len(passages)} 个段落")
                    else:
                        print(f"[WARNING] 子查询 {index+1} 无结果")
                except Exception as e:
                    print(f"[ERROR] 子查询 {query_index+1} 处理失败: {e}")

        # Update the state
        if all_new_passages:
            retrieval_result = RetrievalResult(
                passages=all_new_passages,
                documents=all_new_documents,
                sources=all_new_sources,
                query=f"并行检索: {', '.join(sub_queries)}",
                iteration=state['current_iteration']
            )
            state = update_state_with_retrieval(state, retrieval_result)
            print(f"[SUCCESS] 并行检索完成,总共获得 {len(all_new_passages)} 个新段落")

            # PageRank score collection for the sub-queries
            for i, query in enumerate(sub_queries):
                try:
                    # Skipped to avoid shipping large payloads through LangSmith
                    # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                    # Only flag that the data is available (it is handled inside HippoRAG)
                    state['pagerank_data_available'] = True
                except Exception as e:
                    print(f"[WARNING] 收集并行子查询{i+1}的PageRank分数失败: {e}")
        else:
            print("[WARNING] 并行检索无有效结果")

        # Clear the current sub-queries
        state['current_sub_queries'] = []

        return state

    def final_answer_generation_node(self, state: QueryState) -> QueryState:
        """
        Final answer generation node.
        Generates the final answer from all retrieved information.

        Args:
            state: current state

        Returns:
            the updated state
        """
        # Format all retrieval results (mixed format if all_documents is available,
        # otherwise the legacy passage format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])
        formatted_sub_queries = format_sub_queries(state['sub_queries'])

        # Skip LLM generation entirely if requested
        if self.skip_llm_generation:
            print(f"[SEARCH] 跳过LLM生成,直接返回检索结果")

            # Use the formatted retrieval results as the final answer
            retrieval_summary = f"""【检索结果汇总】

查询问题:{state['original_query']}

检索到 {len(state['all_passages'])} 个相关段落:

{formatted_passages}

"""
            if state['sub_queries']:
                retrieval_summary += f"""
相关子查询:
{formatted_sub_queries}
"""

            retrieval_summary += f"""
检索统计:
- 查询复杂度:{state.get('query_complexity', 'unknown')}
- 是否复杂查询:{state.get('is_complex_query', False)}
- 迭代次数:{state.get('current_iteration', 0)}
- 信息充分性:{state.get('is_sufficient', False)}
"""

            # Finalize the state
            updated_state = finalize_state(state, retrieval_summary)
            print(f"[OK] 检索结果返回完成 (长度: {len(retrieval_summary)} 字符)")
            return updated_state

        # Check for a streaming callback
        stream_callback = state.get('stream_callback')

        # Debug: report whether a stream_callback is present in the state
        if stream_callback:
            print(f"[DEBUG] 检测到stream_callback,类型: {type(stream_callback)}")
        else:
            print(f"[DEBUG] 没有stream_callback,state keys: {list(state.keys())[:10]}")

        print(f"[NOTE] 生成最终答案" + (" (流式模式)" if stream_callback else ""))

        # Skip the LLM call if configured to do so
        if self.prompt_loader.should_skip_llm('final_answer'):
            print(f"[NEXT] 跳过最终答案LLM调用,直接返回检索结果")
            final_answer = f"检索到的信息:\n{formatted_passages}"
            state['final_answer'] = final_answer
            return finalize_state(state, final_answer)

        # Build the prompt
        prompt = FINAL_ANSWER_PROMPT.format(
            original_query=state['original_query'],
            all_passages=formatted_passages,
            sub_queries=formatted_sub_queries
        )

        try:
            if stream_callback:
                # Wrap the callback and tag the event type
                def answer_stream_callback(chunk):
                    stream_callback("answer_chunk", {"text": chunk})

                # Pass the callback through the config
                config = {
                    'metadata': {
                        'stream_callback': answer_stream_callback
                    }
                }

                # Invoke the LLM (streaming is used automatically); normalize the
                # return value to plain text in case a message object is returned
                final_answer = self._extract_response_text(
                    self.llm.invoke(prompt, config=config)
                )
            else:
                # Non-streaming call
                response = self.llm.invoke(prompt)
                final_answer = self._extract_response_text(response)

            if not final_answer.strip():
                final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"

            # Finalize the state
            updated_state = finalize_state(state, final_answer)

            print(f"[OK] 最终答案生成完成 (长度: {len(final_answer)} 字符)")
            return updated_state

        except Exception as e:
            print(f"[ERROR] 最终答案生成失败: {e}")
            error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
            return finalize_state(state, error_answer)

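    # Note on the streaming contract used above: `state['stream_callback']`, when
    # present, must be a callable accepting (event_type, payload). This node only
    # emits "answer_chunk" events whose payload is {"text": <partial answer text>}.
    # A hypothetical consumer (illustrative only, not part of this module) could be:
    #
    #     def on_event(event_type: str, payload: dict) -> None:
    #         if event_type == "answer_chunk":
    #             print(payload["text"], end="", flush=True)
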
    def next_iteration_node(self, state: QueryState) -> QueryState:
        """
        Next iteration node.
        Increments the iteration counter and prepares the next round.

        Args:
            state: current state

        Returns:
            the updated state
        """
        print(f"[?] 进入迭代 {state['current_iteration'] + 1}")

        # Increment the iteration counter
        updated_state = increment_iteration(state)

        return updated_state
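
# ---------------------------------------------------------------------------
# Illustrative wiring sketch (NOT the project's actual graph definition).
# It only demonstrates how the GraphNodes methods above could be registered in
# a LangGraph StateGraph; the node names, routing functions and entry point
# below are assumptions made for this example, and the real builder presumably
# also enforces an iteration cap before looping back to the sufficiency check.
# ---------------------------------------------------------------------------
def build_graph_sketch(nodes: GraphNodes):
    """Return a compiled LangGraph for the nodes above (illustrative only)."""
    from langgraph.graph import StateGraph, END

    graph = StateGraph(QueryState)
    graph.add_node("complexity_check", nodes.query_complexity_check_node)
    graph.add_node("debug_mode", nodes.debug_mode_node)
    graph.add_node("simple_retrieval", nodes.simple_vector_retrieval_node)
    graph.add_node("simple_answer", nodes.simple_answer_generation_node)
    graph.add_node("decompose", nodes.query_decomposition_node)
    graph.add_node("initial_retrieval", nodes.initial_retrieval_node)
    graph.add_node("sufficiency_check", nodes.sufficiency_check_node)
    graph.add_node("sub_query_generation", nodes.sub_query_generation_node)
    graph.add_node("parallel_retrieval", nodes.parallel_retrieval_node)
    graph.add_node("next_iteration", nodes.next_iteration_node)
    graph.add_node("final_answer", nodes.final_answer_generation_node)

    graph.set_entry_point("complexity_check")
    graph.add_edge("complexity_check", "debug_mode")
    # Simple queries go straight to vector retrieval; complex ones are decomposed
    graph.add_conditional_edges(
        "debug_mode",
        lambda s: "complex" if s["is_complex_query"] else "simple",
        {"simple": "simple_retrieval", "complex": "decompose"},
    )
    graph.add_edge("simple_retrieval", "simple_answer")
    graph.add_edge("simple_answer", END)
    graph.add_edge("decompose", "initial_retrieval")
    graph.add_edge("initial_retrieval", "sufficiency_check")
    # Loop until the retrieved information is judged sufficient
    graph.add_conditional_edges(
        "sufficiency_check",
        lambda s: "sufficient" if s.get("is_sufficient") else "insufficient",
        {"sufficient": "final_answer", "insufficient": "sub_query_generation"},
    )
    graph.add_edge("sub_query_generation", "parallel_retrieval")
    graph.add_edge("parallel_retrieval", "next_iteration")
    graph.add_edge("next_iteration", "sufficiency_check")
    graph.add_edge("final_answer", END)
    return graph.compile()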