"""
LangGraph workflow nodes.

Implements the concrete logic of each workflow node.
"""

import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

# LangSmith automatically traces LangChain LLM calls.

from retriver.langgraph.graph_state import (
    QueryState,
    RetrievalResult,
    SufficiencyCheck,
    update_state_with_retrieval,
    update_state_with_sufficiency_check,
    update_state_with_event_triples,
    increment_iteration,
    finalize_state
)
from task_manager import get_task_manager
from exceptions import TaskCancelledException, WorkflowStoppedException
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever
from retriver.langgraph.langchain_components import (
    OneAPILLM,
    SufficiencyCheckParser,
    QueryComplexityParser,
    QUERY_COMPLEXITY_CHECK_PROMPT,
    SUFFICIENCY_CHECK_PROMPT,
    QUERY_DECOMPOSITION_PROMPT,
    SUB_QUERY_GENERATION_PROMPT,
    SIMPLE_ANSWER_PROMPT,
    FINAL_ANSWER_PROMPT,
    format_passages,
    format_mixed_passages,
    format_sub_queries,
    format_event_triples
)
from retriver.langgraph.es_vector_retriever import ESVectorRetriever
from prompt_loader import get_prompt_loader


class GraphNodes:
    """Implementation of the workflow nodes."""

    def __init__(
        self,
        retriever: LangChainHippoRAGRetriever,
        llm: OneAPILLM,
        keyword: str,
        max_parallel_retrievals: int = 2,
        simple_retrieval_top_k: int = 3,
        complexity_llm: Optional[OneAPILLM] = None,
        sufficiency_llm: Optional[OneAPILLM] = None,
        skip_llm_generation: bool = False
    ):
        """
        Initialize the node handler.

        Args:
            retriever: HippoRAG retriever
            llm: OneAPI LLM (used to generate answers)
            keyword: ES index keyword
            max_parallel_retrievals: maximum number of parallel retrievals
            simple_retrieval_top_k: number of documents returned by simple retrieval
            complexity_llm: dedicated LLM for complexity checks (falls back to llm)
            sufficiency_llm: dedicated LLM for sufficiency checks (falls back to llm)
            skip_llm_generation: skip LLM answer generation and return only retrieval results
        """
        self.retriever = retriever
        self.llm = llm
        self.complexity_llm = complexity_llm or llm  # fall back to the main LLM if not provided
        self.sufficiency_llm = sufficiency_llm or llm  # fall back to the main LLM if not provided
        self.keyword = keyword
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation
        self.sufficiency_parser = SufficiencyCheckParser()
        self.complexity_parser = QueryComplexityParser()

        # ES vector retriever used for simple queries
        self.es_vector_retriever = ESVectorRetriever(
            keyword=keyword,
            top_k=simple_retrieval_top_k
        )

        # Prompt loader
        self.prompt_loader = get_prompt_loader()

    def _extract_response_text(self, response):
        """Extract the response text from the various response shapes the LLMs return."""
        if hasattr(response, 'generations') and response.generations:
            return response.generations[0][0].text
        elif hasattr(response, 'content'):
            return response.content
        elif isinstance(response, dict) and 'response' in response:
            return response['response']
        else:
            return str(response)
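
    # Illustrative only (hypothetical response shapes, not taken from OneAPILLM's
    # documented API): the helper above is meant to cover LLMResult-style objects
    # (`.generations[0][0].text`), chat-message objects (`.content`) and plain
    # dicts such as {"response": "..."}; anything else falls back to str().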

    def _check_should_stop(self, state: QueryState, node_name: str = None):
        """
        Check whether execution should stop.

        Args:
            state: current state
            node_name: name of the current node

        Raises:
            TaskCancelledException: if the task has been cancelled
        """
        # If there is a task_id, ask the task manager
        if state.get("task_id"):
            task_manager = get_task_manager()
            if task_manager.should_stop(state["task_id"]):
                print(f"[NODE] Node {node_name or 'unknown'} received a stop signal")
                raise TaskCancelledException(
                    task_id=state["task_id"],
                    message=f"Task cancelled at node {node_name}"
                )

        # Check the stop flag in the state
        if state.get("should_stop", False):
            raise WorkflowStoppedException(
                node_name=node_name,
                message=f"Workflow stopped at node {node_name}"
            )
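
    # Typical usage (sketch): the workflow nodes below call
    # `self._check_should_stop(state, "<node_name>")` at the top of the node before
    # doing any work, so a cancellation raised here unwinds the LangGraph run for
    # that task.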

    def query_complexity_check_node(self, state: QueryState) -> QueryState:
        """
        Query complexity check node.

        Decides whether the user query needs complex knowledge-graph reasoning.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[?] Running query complexity check: {state['original_query']}")
        self._check_should_stop(state, "query_complexity_check")

        # Check whether the LLM call should be skipped
        if self.prompt_loader.should_skip_llm('query_complexity_check'):
            print(f"[NEXT] Skipping the complexity-check LLM call, defaulting to a simple query")
            state['is_complex'] = False
            state['complexity_level'] = 'simple'
            state['is_complex_query'] = False  # keep the key read by downstream routing consistent
            return state

        # Build the prompt
        prompt = QUERY_COMPLEXITY_CHECK_PROMPT.format(
            query=state['original_query']
        )

        try:
            # Call the dedicated complexity-check LLM
            response = self.complexity_llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Parse the response
            complexity_result = self.complexity_parser.parse(response_text)

            # Update the state
            state['query_complexity'] = {
                "is_complex": complexity_result.is_complex,
                "complexity_level": complexity_result.complexity_level,
                "confidence": complexity_result.confidence,
                "reason": complexity_result.reason
            }
            state['is_complex_query'] = complexity_result.is_complex

            # Update debug info
            state["debug_info"]["llm_calls"] += 1

            print(f"[INFO] Complexity check result: {'complex' if complexity_result.is_complex else 'simple'} "
                  f"(confidence: {complexity_result.confidence:.2f})")
            print(f"Reason: {complexity_result.reason}")

            return state

        except Exception as e:
            print(f"[ERROR] Complexity check failed: {e}")
            # On failure, default to a complex query and follow the existing path
            state['query_complexity'] = {
                "is_complex": True,
                "complexity_level": "complex",
                "confidence": 0.5,
                "reason": f"复杂度判断失败: {str(e)},默认使用复杂检索"
            }
            state['is_complex_query'] = True
            return state

    def debug_mode_node(self, state: QueryState) -> QueryState:
        """
        Debug mode node.

        Based on the user-supplied debug_mode parameter, decides whether to override
        the complexity check result.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[?] Running debug-mode check: mode={state['debug_mode']}")
        self._check_should_stop(state, "debug_mode")

        # Keep the original complexity check result
        original_is_complex = state['is_complex_query']

        if state['debug_mode'] == 'simple':
            print("[?] Debug mode: forcing the simple retrieval path")
            # Force a simple query regardless of the original complexity result
            state['is_complex_query'] = False
            # Keep the original complexity result but attach debug info
            if 'debug_override' not in state['query_complexity']:
                state['query_complexity']['debug_override'] = {
                    'original_complexity': original_is_complex,
                    'debug_mode': 'simple',
                    'override_reason': '调试模式强制使用简单检索路径'
                }
        elif state['debug_mode'] == 'complex':
            print("[?] Debug mode: forcing the complex retrieval path")
            # Force a complex query regardless of the original complexity result
            state['is_complex_query'] = True
            # Keep the original complexity result but attach debug info
            if 'debug_override' not in state['query_complexity']:
                state['query_complexity']['debug_override'] = {
                    'original_complexity': original_is_complex,
                    'debug_mode': 'complex',
                    'override_reason': '调试模式强制使用复杂检索路径'
                }
        else:
            # debug_mode == "0" or any other value: use the automatic complexity result
            print("[?] Debug mode: using the automatic complexity check result")

        return state

    def query_decomposition_node(self, state: QueryState) -> QueryState:
        """
        Query decomposition node.

        Decomposes the original user query into two sub-queries that can be
        retrieved separately.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[TARGET] Running query decomposition: {state['original_query']}")
        self._check_should_stop(state, "query_decomposition")

        # Check whether the LLM call should be skipped
        if self.prompt_loader.should_skip_llm('query_decomposition'):
            print(f"[NEXT] Skipping the query-decomposition LLM call, using the original query")
            # Do not decompose; use the original query directly
            state['decomposed_sub_queries'] = []
            state['sub_queries'] = [state['original_query']]
            return state

        # Build the prompt
        prompt = QUERY_DECOMPOSITION_PROMPT.format(
            original_query=state['original_query']
        )

        try:
            # Call the LLM to generate sub-queries
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            print(f"[?] Raw LLM response: {response_text[:200]}...")

            # Clean the response and extract the JSON content
            cleaned_text = response_text.strip()

            # If the response contains an answer marker ("答案:"), keep only what follows it
            if "答案:" in cleaned_text or "答案:" in cleaned_text:
                answer_markers = ["答案:", "答案:"]
                for marker in answer_markers:
                    if marker in cleaned_text:
                        cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                        break

            # Strip markdown code-fence markers
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # drop ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # drop ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # drop the trailing ```
            cleaned_text = cleaned_text.strip()

            # Try to extract the JSON part (from the first '{' to the last '}')
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                cleaned_text = cleaned_text[start_idx:end_idx]

            # Parse the sub-queries
            try:
                data = json.loads(cleaned_text)
                raw_decomposed_queries = data.get('sub_queries', [])

                # Handle the different sub-query formats (kept consistent with sub_query_generation_node)
                decomposed_sub_queries = []
                for item in raw_decomposed_queries:
                    if isinstance(item, str):
                        decomposed_sub_queries.append(item)
                    elif isinstance(item, dict) and 'query' in item:
                        decomposed_sub_queries.append(item['query'])
                    elif isinstance(item, dict):
                        query_text = item.get('text', item.get('content', str(item)))
                        decomposed_sub_queries.append(query_text)
                    else:
                        decomposed_sub_queries.append(str(item))

                # Make sure everything is a string
                decomposed_sub_queries = [str(query).strip() for query in decomposed_sub_queries if query]
                print(f"[OK] JSON parsed successfully, sub-queries: {decomposed_sub_queries}")

            except json.JSONDecodeError as e:
                print(f"[ERROR] JSON parsing failed: {e}")
                print(f"Falling back to rule-based extraction...")
                # If JSON parsing fails, use a simple rule-based extraction
                lines = response_text.split('\n')
                decomposed_sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
                print(f"Rule-based extraction result: {decomposed_sub_queries}")

            # Make sure we have sub-queries
            if not decomposed_sub_queries:
                print(f"[WARNING] The LLM produced no valid sub-queries, using a simple split strategy...")
                # Split on question marks
                original = state['original_query']
                if '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
                    if len(parts) >= 2:
                        decomposed_sub_queries = [parts[0], parts[1]]
                        print(f"Split on the Chinese question mark: {decomposed_sub_queries}")
                elif '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
                    if len(parts) >= 2:
                        decomposed_sub_queries = [parts[0], parts[1]]
                        print(f"Split on the English question mark: {decomposed_sub_queries}")

                # If there is still nothing, fall back to the original query
                if not decomposed_sub_queries:
                    print(f"[WARNING] Could not decompose automatically, using the original query as both sub-queries")
                    decomposed_sub_queries = [original, original]

            elif len(decomposed_sub_queries) == 1:
                print(f"[WARNING] Only one sub-query was produced, adding a second one")
                # With only one sub-query, use the original query as the second
                decomposed_sub_queries.append(state['original_query'])

            # Keep at most two sub-queries
            decomposed_sub_queries = decomposed_sub_queries[:2]

            # Update the state: store the initially decomposed sub-queries
            state['decomposed_sub_queries'] = decomposed_sub_queries
            state['sub_queries'].extend(decomposed_sub_queries)

            # Update debug info
            state["debug_info"]["llm_calls"] += 1

            print(f"[OK] Query decomposition finished: {decomposed_sub_queries}")
            return state

        except Exception as e:
            print(f"[ERROR] Query decomposition failed: {e}")
            # On failure, fall back to default sub-queries
            default_sub_queries = [
                state['original_query'] + " 详细信息",
                state['original_query'] + " 相关内容"
            ]
            state['decomposed_sub_queries'] = default_sub_queries
            state['sub_queries'].extend(default_sub_queries)
            return state
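
    # Sketch of the LLM output the parsing above expects (illustrative, not a
    # contract from QUERY_DECOMPOSITION_PROMPT itself): either a bare JSON object,
    # or one wrapped in a ```json fence / preceded by an "答案:" marker, e.g.
    #   {"sub_queries": ["sub-query 1", "sub-query 2"]}
    # or, with object items, {"sub_queries": [{"query": "sub-query 1"}, ...]}.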

    def simple_vector_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Simple vector retrieval node.

        Matches the query directly against the text passages in the ES vector store.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[SEARCH] Running simple vector retrieval: {state['original_query']}")

        try:
            # Retrieve relevant documents with the ES vector retriever
            documents = self.es_vector_retriever.retrieve(state['original_query'])

            # Extract passage contents and source information
            passages = [doc.page_content for doc in documents]
            sources = [f"简单检索-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]

            # Build the retrieval result
            retrieval_result = RetrievalResult(
                passages=passages,
                documents=documents,
                sources=sources,
                query=state['original_query'],
                iteration=state['current_iteration']
            )

            # Update the state
            updated_state = update_state_with_retrieval(state, retrieval_result)

            print(f"[OK] Simple vector retrieval finished, {len(passages)} passages")
            return updated_state

        except Exception as e:
            print(f"[ERROR] Simple vector retrieval failed: {e}")
            # On failure, return an empty result
            retrieval_result = RetrievalResult(
                passages=[],
                documents=[],
                sources=[],
                query=state['original_query'],
                iteration=state['current_iteration']
            )
            return update_state_with_retrieval(state, retrieval_result)
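
    # The node above assumes ESVectorRetriever.retrieve() returns LangChain-style
    # Document objects: `page_content` holds the passage text and `metadata` carries
    # a `passage_id`; only those two fields are read here.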

    def simple_answer_generation_node(self, state: QueryState) -> QueryState:
        """
        Simple answer generation node.

        Generates an answer from the simple retrieval results.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[NOTE] Generating the answer for a simple query")
        self._check_should_stop(state, "simple_answer_generation")

        # Format the retrieval results (mixed format if all_documents exists, otherwise the legacy format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Check whether LLM generation should be skipped (global or per-prompt configuration)
        if self.skip_llm_generation or self.prompt_loader.should_skip_llm('simple_answer'):
            if self.skip_llm_generation:
                print(f"[SEARCH] Skipping LLM generation (skip_llm_generation=true), returning the retrieval results directly")
            else:
                print(f"[NEXT] Skipping the simple-answer LLM call (simple_answer.skip_llm=true), returning the retrieval results directly")
            final_answer = f"检索到的信息:\n{formatted_passages}"
            state['final_answer'] = final_answer
            return finalize_state(state, final_answer)

        # Build the prompt
        prompt = SIMPLE_ANSWER_PROMPT.format(
            query=state['original_query'],
            passages=formatted_passages
        )

        try:
            # Call the LLM to generate the answer
            response = self.llm.invoke(prompt)
            final_answer = self._extract_response_text(response)

            if not final_answer.strip():
                final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"

            # Finalize the state
            updated_state = finalize_state(state, final_answer)

            print(f"[OK] Simple answer generated ({len(final_answer)} characters)")
            return updated_state

        except Exception as e:
            print(f"[ERROR] Simple answer generation failed: {e}")
            error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
            return finalize_state(state, error_answer)

    def initial_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Parallel initial retrieval node.

        Retrieves in parallel with the original query plus the two decomposed
        sub-queries. Each query returns mixed results: top-10 event nodes plus
        top-3 passage nodes.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        # Queries to retrieve: the original query plus two sub-queries
        original_query = state['original_query']
        sub_queries = state.get('decomposed_sub_queries', [])

        all_queries = [original_query] + sub_queries[:2]  # at most 3 queries

        print(f"[SEARCH] Running parallel initial retrieval - {len(all_queries)} queries")
        self._check_should_stop(state, "initial_retrieval")

        for i, query in enumerate(all_queries):
            query_type = "原始查询" if i == 0 else f"子查询{i}"
            print(f"  {query_type}: {query}")

        def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, List, str]:
            """Retrieve a single query and return its documents, passages and sources."""
            import time

            # Label the query by its type
            if index == 0:
                query_label = "原始查询"
            else:
                query_label = f"子查询{index}"

            start_time = time.time()
            print(f"[STARTING] {query_label} retrieval started [{time.strftime('%H:%M:%S', time.localtime(start_time))}]")

            try:
                documents = self.retriever.invoke(query)

                # The retriever now returns mixed results (events + passages); no extra limit here
                top_documents = documents  # use every retrieved document
                passages = [doc.page_content for doc in top_documents]

                # Build source identifiers by query type, supporting mixed node types
                sources = []
                for doc in top_documents:
                    # Prefer node_id, then passage_id
                    doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id', 'unknown')
                    node_type = doc.metadata.get('node_type', 'unknown')

                    if index == 0:
                        sources.append(f"原始查询-{node_type}-{doc_id}")
                    else:
                        sources.append(f"子查询{index}-{node_type}-{doc_id}")

                end_time = time.time()
                duration = end_time - start_time
                print(f"[OK] {query_label} retrieval finished [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - took {duration:.2f}s, {len(passages)} items (events + passages)")
                return index, documents, passages, sources, query_label

            except Exception as e:
                end_time = time.time()
                duration = end_time - start_time
                print(f"[ERROR] {query_label} retrieval failed [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - took {duration:.2f}s - error: {e}")
                return index, [], [], [], query_label

        # Run the retrievals in parallel
        all_documents = []
        all_passages = []
        all_sources = []
        retrieval_details = {}

        import time
        parallel_start_time = time.time()
        print(f"[FAST] Starting {len(all_queries)} retrieval tasks in parallel [{time.strftime('%H:%M:%S', time.localtime(parallel_start_time))}]")

        with ThreadPoolExecutor(max_workers=min(3, len(all_queries))) as executor:
            # Submit the retrieval tasks
            futures = {
                executor.submit(retrieve_single_query, query, i): (query, i)
                for i, query in enumerate(all_queries)
            }
            print(f"[?] All {len(futures)} retrieval tasks submitted to the thread pool")

            # Collect the results
            for future in as_completed(futures):
                query, query_index = futures[future]
                try:
                    index, documents, passages, sources, query_label = future.result()

                    # Record retrieval details
                    retrieval_details[query_label] = {
                        'query': query,
                        'passages_count': len(passages),
                        'documents_count': len(documents)
                    }

                    # Merge the results (deduplicated later in one pass)
                    all_documents.extend(documents)
                    all_passages.extend(passages)
                    all_sources.extend(sources)

                except Exception as e:
                    print(f"[ERROR] Query {query_index+1} failed while collecting results: {e}")

        parallel_end_time = time.time()
        parallel_duration = parallel_end_time - parallel_start_time
        print(f"[TARGET] Parallel retrieval finished [{time.strftime('%H:%M:%S', time.localtime(parallel_end_time))}] - total {parallel_duration:.2f}s")

        # Deduplicate by document ID or content
        unique_documents = []
        unique_passages = []
        unique_sources = []
        seen_passage_ids = set()
        seen_content_hashes = set()

        for i, (doc, passage, source) in enumerate(zip(all_documents, all_passages, all_sources)):
            # Try to deduplicate by node_id or passage_id, supporting mixed node types
            doc_id = None
            if doc:
                doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id')

            if doc_id and doc_id in seen_passage_ids:
                continue

            # Deduplicate by content hash as a backup
            content_hash = hash(passage.strip())
            if content_hash in seen_content_hashes:
                continue

            # Keep this entry
            unique_documents.append(doc)
            unique_passages.append(passage)
            unique_sources.append(source)

            if doc_id:
                seen_passage_ids.add(doc_id)
            seen_content_hashes.add(content_hash)

        removed_count = len(all_passages) - len(unique_passages)
        if removed_count > 0:
            print(f"[SEARCH] Deduplication removed {removed_count} duplicate items")

        # Build the retrieval result
        query_description = f"并行检索: 原始查询 + {len(sub_queries)} 个子查询"
        retrieval_result = RetrievalResult(
            passages=unique_passages,
            documents=unique_documents,
            sources=unique_sources,
            query=query_description,
            iteration=state['current_iteration']
        )

        # Update the state
        updated_state = update_state_with_retrieval(state, retrieval_result)

        # Store the retrieval details in the state for later analysis
        updated_state['initial_retrieval_details'] = retrieval_details

        # Collect PageRank score information (the full PageRank result for each query)
        for i, query in enumerate(all_queries):
            try:
                # Skip collecting PageRank data to avoid shipping large payloads to LangSmith
                # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                # Only set a flag; the actual data is handled inside HippoRAG
                updated_state['pagerank_data_available'] = True
            except Exception as e:
                print(f"[WARNING] Failed to collect PageRank scores for query {i+1}: {e}")

        total_passages_before = len(all_passages)
        total_passages_after = len(unique_passages)
        print(f"[SUCCESS] Parallel initial retrieval finished")
        print(f"  Items before dedup: {total_passages_before}, after dedup: {total_passages_after}")
        print(f"  原始查询: {retrieval_details.get('原始查询', {}).get('passages_count', 0)} items (events + passages)")
        for i in range(1, len(all_queries)):
            key = f"子查询{i}"
            count = retrieval_details.get(key, {}).get('passages_count', 0)
            print(f"  子查询{i}: {count} items (events + passages)")

        return updated_state
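
    # Deduplication sketch (illustrative values only): if the original query and
    # 子查询1 both return a document with node_id "event_42", only the first
    # occurrence is kept; a document without any id is still dropped when
    # hash(passage.strip()) has already been seen.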

    def sufficiency_check_node(self, state: QueryState) -> QueryState:
        """
        Sufficiency check node.

        Decides whether the information retrieved so far is enough to answer the
        user query, taking the decomposed sub-queries into account.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[?] Running sufficiency check (iteration {state['current_iteration']})")
        self._check_should_stop(state, "sufficiency_check")

        # Check whether the LLM call should be skipped
        if self.prompt_loader.should_skip_llm('sufficiency_check'):
            print(f"[NEXT] Skipping the sufficiency-check LLM call, defaulting to sufficient")
            state['is_sufficient'] = True
            state['sufficiency_confidence'] = 1.0
            return state

        # Format the retrieval results (mixed format if all_documents exists, otherwise the legacy format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Format the decomposed sub-queries
        decomposed_sub_queries = state.get('decomposed_sub_queries', [])
        formatted_decomposed_queries = format_sub_queries(decomposed_sub_queries) if decomposed_sub_queries else "无"

        # Format the event triples
        event_triples = state.get('event_triples', [])
        formatted_event_triples = format_event_triples(event_triples)

        # Build the prompt, including the decomposed sub-queries and event triples
        prompt = SUFFICIENCY_CHECK_PROMPT.format(
            query=state['original_query'],
            passages=formatted_passages,
            decomposed_sub_queries=formatted_decomposed_queries,
            event_triples=formatted_event_triples
        )

        # Call the dedicated sufficiency-check LLM
        try:
            response = self.sufficiency_llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Parse the response
            sufficiency_result = self.sufficiency_parser.parse(response_text)

            # Build the sufficiency check result
            sufficiency_check = SufficiencyCheck(
                is_sufficient=sufficiency_result.is_sufficient,
                confidence=sufficiency_result.confidence,
                reason=sufficiency_result.reason,
                sub_queries=sufficiency_result.sub_queries
            )

            # Update the state
            updated_state = update_state_with_sufficiency_check(state, sufficiency_check)

            # Update debug info
            updated_state["debug_info"]["llm_calls"] += 1

            print(f"[INFO] Sufficiency check result: {'sufficient' if sufficiency_result.is_sufficient else 'insufficient'} "
                  f"(confidence: {sufficiency_result.confidence:.2f})")
            print(f"  Based on {len(state['all_passages'])} passages (from the original query and {len(decomposed_sub_queries)} sub-queries)")

            if not sufficiency_result.is_sufficient and sufficiency_result.sub_queries:
                print(f"[TARGET] New sub-queries generated: {sufficiency_result.sub_queries}")

            return updated_state

        except Exception as e:
            print(f"[ERROR] Sufficiency check failed: {e}")
            # On failure, assume insufficiency and generate a default sub-query
            sufficiency_check = SufficiencyCheck(
                is_sufficient=False,
                confidence=0.5,
                reason=f"充分性检查失败: {str(e)}",
                sub_queries=[state['original_query'] + " 详细信息"]
            )
            return update_state_with_sufficiency_check(state, sufficiency_check)
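
    # The parser result consumed above is assumed to expose `is_sufficient` (bool),
    # `confidence` (float), `reason` (str) and `sub_queries` (list of str); those are
    # the only attributes of SufficiencyCheckParser's output that this node reads.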

    def sub_query_generation_node(self, state: QueryState) -> QueryState:
        """
        Sub-query generation node.

        Generates sub-queries when the sufficiency check fails, taking the
        previously generated (decomposed) sub-queries into account.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[TARGET] Generating sub-queries")
        self._check_should_stop(state, "sub_query_generation")

        # Check whether the LLM call should be skipped
        if self.prompt_loader.should_skip_llm('sub_query_generation'):
            print(f"[NEXT] Skipping the sub-query generation LLM call")
            # Do not generate new queries; move on to the next iteration
            state['current_iteration'] += 1
            return state

        # If there are already sub-queries, return them directly
        if state['current_sub_queries']:
            print(f"[OK] Using the existing sub-queries: {state['current_sub_queries']}")
            return state

        # Format the existing retrieval results (mixed format if all_documents exists, otherwise the legacy format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Format all previously generated sub-queries (including the decomposed ones)
        previous_sub_queries = state.get('sub_queries', [])
        formatted_previous_queries = format_sub_queries(previous_sub_queries) if previous_sub_queries else "无"

        # Format the event triples
        event_triples = state.get('event_triples', [])
        formatted_event_triples = format_event_triples(event_triples)

        # Get the insufficiency reason from the sufficiency check
        sufficiency_check = state.get('sufficiency_check', {})
        insufficiency_reason = sufficiency_check.get('reason', '信息不够充分,需要更多相关信息')

        # Build the prompt, including the sufficiency-check feedback and event triples
        prompt = SUB_QUERY_GENERATION_PROMPT.format(
            original_query=state['original_query'],
            existing_passages=formatted_passages,
            previous_sub_queries=formatted_previous_queries,
            event_triples=formatted_event_triples,
            insufficiency_reason=insufficiency_reason
        )

        try:
            # Call the LLM to generate sub-queries
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Update debug info
            state["debug_info"]["llm_calls"] += 1

            # Clean the response and extract the JSON content
            cleaned_text = response_text.strip()

            # If the response contains an answer marker ("答案:"), keep only what follows it
            if "答案:" in cleaned_text or "答案:" in cleaned_text:
                answer_markers = ["答案:", "答案:"]
                for marker in answer_markers:
                    if marker in cleaned_text:
                        cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                        break

            # Strip markdown code-fence markers
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # drop ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # drop ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # drop the trailing ```
            cleaned_text = cleaned_text.strip()

            # Try to extract the JSON part (from the first '{' to the last '}')
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                cleaned_text = cleaned_text[start_idx:end_idx]

            # Parse the sub-queries
            try:
                data = json.loads(cleaned_text)
                raw_sub_queries = data.get('sub_queries', [])

                # Handle the different sub-query formats
                sub_queries = []
                for item in raw_sub_queries:
                    if isinstance(item, str):
                        # Plain string: use it directly
                        sub_queries.append(item)
                    elif isinstance(item, dict) and 'query' in item:
                        # Dict with a query field: extract it
                        sub_queries.append(item['query'])
                    elif isinstance(item, dict):
                        # Dict without a query field: try to find the query text
                        query_text = item.get('text', item.get('content', str(item)))
                        sub_queries.append(query_text)
                    else:
                        # Anything else: convert to string
                        sub_queries.append(str(item))

            except json.JSONDecodeError:
                # If JSON parsing fails, use a simple rule-based extraction
                lines = response_text.split('\n')
                sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]

            # Make sure there are sub-queries
            if not sub_queries:
                sub_queries = [state['original_query'] + " 更多细节"]

            # Make sure every sub-query is a string
            sub_queries = [str(query).strip() for query in sub_queries if query]

            # Update the state
            state['current_sub_queries'] = sub_queries[:2]  # at most 2 sub-queries
            state['sub_queries'].extend(state['current_sub_queries'])

            print(f"[OK] Generated sub-queries: {state['current_sub_queries']}")
            print(f"  (avoiding overlap with previous sub-queries: {formatted_previous_queries})")
            return state

        except Exception as e:
            print(f"[ERROR] Sub-query generation failed: {e}")
            # On failure, fall back to a default sub-query
            default_sub_query = state['original_query'] + " 补充信息"
            state['current_sub_queries'] = [default_sub_query]
            state['sub_queries'].append(default_sub_query)
            return state

    def parallel_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Parallel retrieval node.

        Retrieves the current sub-queries in parallel.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        sub_queries = state['current_sub_queries']
        if not sub_queries:
            print("[WARNING] No sub-queries, skipping parallel retrieval")
            return state

        print(f"[?] Retrieving {len(sub_queries)} sub-queries in parallel")
        self._check_should_stop(state, "parallel_retrieval")

        def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, str]:
            """Retrieve a single query."""
            try:
                documents = self.retriever.invoke(query)
                passages = [doc.page_content for doc in documents]
                sources = [f"子查询{index+1}-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
                return index, documents, passages, sources
            except Exception as e:
                print(f"[ERROR] Sub-query {index+1} retrieval failed: {e}")
                return index, [], [], []

        # Run the retrievals in parallel
        all_new_documents = []
        all_new_passages = []
        all_new_sources = []

        with ThreadPoolExecutor(max_workers=self.max_parallel_retrievals) as executor:
            # Submit the retrieval tasks
            futures = {
                executor.submit(retrieve_single_query, query, i): (query, i)
                for i, query in enumerate(sub_queries)
            }

            # Collect the results
            for future in as_completed(futures):
                query, query_index = futures[future]
                try:
                    index, documents, passages, sources = future.result()
                    if passages:  # only keep non-empty results
                        all_new_documents.extend(documents)
                        all_new_passages.extend(passages)
                        all_new_sources.extend(sources)
                        print(f"[OK] Sub-query {index+1} finished, {len(passages)} passages")
                    else:
                        print(f"[WARNING] Sub-query {index+1} returned no results")
                except Exception as e:
                    print(f"[ERROR] Sub-query {query_index+1} failed while collecting results: {e}")

        # Update the state
        if all_new_passages:
            retrieval_result = RetrievalResult(
                passages=all_new_passages,
                documents=all_new_documents,
                sources=all_new_sources,
                query=f"并行检索: {', '.join(sub_queries)}",
                iteration=state['current_iteration']
            )
            state = update_state_with_retrieval(state, retrieval_result)
            print(f"[SUCCESS] Parallel retrieval finished, {len(all_new_passages)} new passages in total")

            # Collect PageRank score information for the sub-queries
            for i, query in enumerate(sub_queries):
                try:
                    # Skip collecting PageRank data to avoid shipping large payloads to LangSmith
                    # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                    # Only set a flag; the actual data is handled inside HippoRAG
                    state['pagerank_data_available'] = True
                except Exception as e:
                    print(f"[WARNING] Failed to collect PageRank scores for parallel sub-query {i+1}: {e}")
        else:
            print("[WARNING] Parallel retrieval produced no valid results")

        # Clear the current sub-queries
        state['current_sub_queries'] = []

        return state

    def final_answer_generation_node(self, state: QueryState) -> QueryState:
        """
        Final answer generation node.

        Generates the final answer from all retrieved information.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        self._check_should_stop(state, "final_answer_generation")

        # Format all retrieval results (mixed format if all_documents exists, otherwise the legacy format)
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])
        formatted_sub_queries = format_sub_queries(state['sub_queries'])

        # Format the event triples
        event_triples = state.get('event_triples', [])
        formatted_event_triples = format_event_triples(event_triples)

        # Check whether LLM generation should be skipped
        if self.skip_llm_generation:
            print(f"[SEARCH] Skipping LLM generation, returning the retrieval results directly")

            # Build a formatted retrieval summary as the final answer
            retrieval_summary = f"""【检索结果汇总】

查询问题:{state['original_query']}

检索到 {len(state['all_passages'])} 个相关段落:

{formatted_passages}

"""
            if state['sub_queries']:
                retrieval_summary += f"""
相关子查询:
{formatted_sub_queries}
"""

            retrieval_summary += f"""
检索统计:
- 查询复杂度:{state.get('query_complexity', 'unknown')}
- 是否复杂查询:{state.get('is_complex_query', False)}
- 迭代次数:{state.get('current_iteration', 0)}
- 信息充分性:{state.get('is_sufficient', False)}
"""

            # Finalize the state
            updated_state = finalize_state(state, retrieval_summary)
            print(f"[OK] Retrieval results returned ({len(retrieval_summary)} characters)")
            return updated_state

        # Check for a streaming callback
        stream_callback = state.get('stream_callback')

        # Debug: report whether the state carries a stream_callback
        if stream_callback:
            print(f"[DEBUG] stream_callback detected, type: {type(stream_callback)}")
        else:
            print(f"[DEBUG] No stream_callback, state keys: {list(state.keys())[:10]}")

        print(f"[NOTE] Generating the final answer" + (" (streaming mode)" if stream_callback else ""))

        # Check whether the LLM call should be skipped
        if self.prompt_loader.should_skip_llm('final_answer'):
            print(f"[NEXT] Skipping the final-answer LLM call, returning the retrieval results directly")
            final_answer = f"检索到的信息:\n{formatted_passages}"
            state['final_answer'] = final_answer
            return finalize_state(state, final_answer)

        # Build the prompt
        prompt = FINAL_ANSWER_PROMPT.format(
            original_query=state['original_query'],
            all_passages=formatted_passages,
            sub_queries=formatted_sub_queries,
            event_triples=formatted_event_triples
        )

        try:
            if stream_callback:
                # Wrap the callback and tag the chunk type
                def answer_stream_callback(chunk):
                    stream_callback("answer_chunk", {"text": chunk})

                # Pass the callback through the config
                config = {
                    'metadata': {
                        'stream_callback': answer_stream_callback
                    }
                }

                # Call the LLM (streaming is used automatically)
                final_answer = self.llm.invoke(prompt, config=config)
            else:
                # Non-streaming call
                response = self.llm.invoke(prompt)
                final_answer = self._extract_response_text(response)

            if not final_answer.strip():
                final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"

            # Finalize the state
            updated_state = finalize_state(state, final_answer)

            print(f"[OK] Final answer generated ({len(final_answer)} characters)")
            return updated_state

        except Exception as e:
            print(f"[ERROR] Final answer generation failed: {e}")
            error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
            return finalize_state(state, error_answer)
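
    # Streaming note (behaviour assumed from the call above, not verified against
    # OneAPILLM): when the state carries a `stream_callback`, each chunk is
    # delivered to it as ("answer_chunk", {"text": chunk}) via the wrapper passed in
    # config['metadata'], and `invoke` is expected to return the fully assembled
    # answer string in that branch.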

    def next_iteration_node(self, state: QueryState) -> QueryState:
        """
        Next-iteration node.

        Increments the iteration counter and prepares the next round.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[?] Entering iteration {state['current_iteration'] + 1}")
        self._check_should_stop(state, "next_iteration")

        # Increment the iteration counter
        updated_state = increment_iteration(state)

        return updated_state

    def event_triples_extraction_node(self, state: QueryState) -> QueryState:
        """
        Event-triple extraction node.

        Extracts event-to-event triples from the event nodes retrieved so far.

        Args:
            state: current state

        Returns:
            The updated state.
        """
        print(f"[SEARCH] Running event-triple extraction")
        self._check_should_stop(state, "event_triples_extraction")

        # Collect event node IDs from the retrieval results
        event_node_ids = []
        for doc in state.get('all_documents', []):
            if doc.metadata.get('node_type') == 'event':
                node_id = doc.metadata.get('node_id')
                if node_id:
                    event_node_ids.append(node_id)

        print(f"[INFO] Found {len(event_node_ids)} event nodes")

        # If there are no event nodes or the graph data is unavailable, return directly
        if not event_node_ids or not self.retriever.graph_data:
            print("[WARNING] No event nodes or graph data unavailable, skipping triple extraction")
            return update_state_with_event_triples(state, [])

        # Query the event-to-event triples
        event_triples = self._extract_event_event_triples(event_node_ids)

        # Update the state
        updated_state = update_state_with_event_triples(state, event_triples)

        print(f"[OK] Event-triple extraction finished, {len(event_triples)} triples")

        # Show the extracted triples (for debugging)
        if event_triples:
            print("[INFO] Extracted event-to-event triples:")
            for i, triple in enumerate(event_triples[:5]):  # show at most the first 5
                if isinstance(triple, dict) and 'source_entity' in triple:
                    print(f"  {i+1}. [{triple['source_entity']}] --{triple['relation']}--> [{triple['target_entity']}]")
                else:
                    print(f"  {i+1}. {str(triple)}")
            if len(event_triples) > 5:
                print(f"  ... and {len(event_triples) - 5} more triples")

        return updated_state

    def _extract_event_event_triples(self, event_node_ids: List[str]) -> List[Dict[str, str]]:
        """
        Extract triples for directly connected event nodes from the graph.

        Args:
            event_node_ids: list of event node IDs

        Returns:
            A list of triples, e.g.
            [{'source_entity': 'event A', 'relation': 'relation', 'target_entity': 'event B',
              'source_evidence': 'source', 'target_evidence': 'source'}, ...]
        """
        if not self.retriever.graph_data:
            return []

        event_node_set = set(event_node_ids)
        triples = []

        print(f"[SEARCH] Looking up connections between {len(event_node_ids)} event nodes in the graph")

        try:
            # Iterate over every edge in the graph
            for source, target, edge_data in self.retriever.graph_data.edges(data=True):
                # Both endpoints must be event nodes that appear in the retrieval results
                if source in event_node_set and target in event_node_set:
                    # Confirm the node types are events
                    source_type = self.retriever.graph_data.nodes[source].get('type', '')
                    target_type = self.retriever.graph_data.nodes[target].get('type', '')

                    if source_type == 'event' and target_type == 'event':
                        # Extract the triple information
                        source_node_data = self.retriever.graph_data.nodes[source]
                        target_node_data = self.retriever.graph_data.nodes[target]

                        # Get the event entity names (prefer the name field, which holds the Chinese name)
                        source_entity = source_node_data.get('name', source)
                        target_entity = target_node_data.get('name', target)

                        # Get the relation name from the edge data
                        relation = edge_data.get('relation', edge_data.get('label', 'connected_to'))

                        # Get the evidence fields
                        source_evidence = source_node_data.get('evidence', '')
                        target_evidence = target_node_data.get('evidence', '')

                        # Build the enriched triple, including the evidence information
                        triple = {
                            'source_entity': str(source_entity),
                            'relation': str(relation),
                            'target_entity': str(target_entity),
                            'source_evidence': str(source_evidence),
                            'target_evidence': str(target_evidence)
                        }
                        triples.append(triple)

            # Deduplicate (the same triple may appear more than once)
            unique_triples = []
            seen_triples = set()
            for triple in triples:
                # Use the core triple fields as the deduplication key
                triple_key = (triple['source_entity'], triple['relation'], triple['target_entity'])
                if triple_key not in seen_triples:
                    unique_triples.append(triple)
                    seen_triples.add(triple_key)

            print(f"[OK] Extracted {len(unique_triples)} unique event-to-event triples from the graph")
            return unique_triples

        except Exception as e:
            print(f"[ERROR] Event-triple extraction failed: {e}")
            return []
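
# A minimal sketch of the graph shape _extract_event_event_triples expects
# (assumption: retriever.graph_data behaves like a networkx graph; the real graph is
# built elsewhere in the project):
#
#   import networkx as nx
#   g = nx.DiGraph()
#   g.add_node("e1", type="event", name="事件A", evidence="doc-1")
#   g.add_node("e2", type="event", name="事件B", evidence="doc-2")
#   g.add_edge("e1", "e2", relation="导致")
#
# With event_node_ids = ["e1", "e2"], the helper would yield one triple:
#   {'source_entity': '事件A', 'relation': '导致', 'target_entity': '事件B',
#    'source_evidence': 'doc-1', 'target_evidence': 'doc-2'}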