"""
|
||
流式检索包装器 - 为现有检索器添加状态回调支持
|
||
保持向后兼容,不影响原有业务逻辑
|
||
"""

import time
import traceback
from typing import Dict, Any, Callable, Optional

from exceptions import TaskCancelledException, WorkflowStoppedException


def stream_retrieve(
    retriever,
    query: str,
    mode: str = "0",
    status_callback: Optional[Callable] = None,
    task_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Retrieval wrapper with status callbacks.

    Adds status-callback support by wrapping the retriever, without modifying the
    original retriever itself.

    Args:
        retriever: the original retriever instance
        query: the query string
        mode: debug mode
        status_callback: status callback invoked as (status_type, data)
        task_id: task ID (optional)

    Returns:
        A dict with the retrieval result.
    """

    # If no callback was provided, just call the original method directly.
    if not status_callback:
        return retriever.retrieve(query, mode)

    try:
        # Retrieval is starting
        status_callback("starting", {"message": "Starting query analysis"})

        # Create the enhanced initial state
        from retriver.langgraph.graph_state import create_initial_state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=retriever.max_iterations,
            debug_mode=mode,
            task_id=task_id  # pass the task ID through
        )

        # Configure the workflow run
        config = {
            "recursion_limit": 50,
            "metadata": {
                "stream_mode": True,
                "callback_enabled": True
            },
            "run_name": f"Stream-Retrieval-{query[:20]}..."
        }

        # Run the workflow, firing callbacks at key points
        result_state = execute_workflow_with_callbacks(
            retriever.workflow,
            initial_state,
            config,
            status_callback
        )

        # Total elapsed time from the workflow's debug info
        total_time = result_state.get("debug_info", {}).get("total_time", 0)

        # Build the result (reusing the original logic)
        result = build_result(retriever, result_state, total_time)

        # Save the full result to a JSON file (if the retriever supports it)
        if hasattr(retriever, '_save_full_result_to_file'):
            # Assemble the full result for saving
            full_result = {
                **result,
                "all_documents": result_state.get('all_documents', []),
                "all_passages": result_state.get('all_passages', []),
                "passage_sources": result_state.get('passage_sources', []),
                "retrieval_path": result_state.get('retrieval_path', 'unknown'),
                "initial_retrieval_details": result_state.get('initial_retrieval_details', {}),
                "current_sub_queries": result_state.get('current_sub_queries', []),
                "event_triples": result_state.get('event_triples', [])  # event-triple info
            }
            retriever._save_full_result_to_file(full_result)
            print("[INFO] Streaming retrieval result saved to JSON file")

        return result

    except (TaskCancelledException, WorkflowStoppedException) as e:
        print(f"[CANCELLED] Streaming retrieval was cancelled: {e}")
        if status_callback:
            status_callback("cancelled", {"message": "The task has been cancelled"})
        # Return a cancellation result
        return {
            "query": query,
            "answer": "The task has been cancelled",
            "cancelled": True,
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "debug_info": {"cancelled": True, "message": str(e)}
        }

    except Exception as e:
        print(f"[ERROR] Streaming retrieval failed: {e}")
        traceback.print_exc()
        # Return a basic result on error
        return {
            "query": query,
            "answer": f"An error occurred during retrieval: {str(e)}",
            "error": str(e),
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "supporting_facts": [],
            "supporting_events": []
        }


def execute_workflow_with_callbacks(
    workflow,
    initial_state: Dict[str, Any],
    config: Dict[str, Any],
    status_callback: Callable
) -> Dict[str, Any]:
    """
    Execute the workflow and fire status callbacks at key nodes.

    Uses LangGraph's stream support to pick up state updates in real time.
    """

    # Inject the callback into the state so the answer-generation node can use it
    initial_state['stream_callback'] = status_callback
    print("[DEBUG] Injected stream_callback into initial_state")

    # State monitor: diffs successive states and fires callbacks on changes
    class StateMonitor:
        def __init__(self, callback):
            self.callback = callback
            self.last_iteration = 0
            self.last_doc_count = 0
            self.last_sub_queries = []
            self.seen_nodes = set()
            self.last_sufficiency_check = None  # last sufficiency-check result seen

        def check_state(self, state, node_name=None):
            """Check for state changes and fire the corresponding callbacks."""

            # Record node visits
            if node_name and node_name not in self.seen_nodes:
                self.seen_nodes.add(node_name)

            # Check sub-queries - send the complete data once the node has finished
            if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]:
                sub_queries = state["decomposed_sub_queries"]
                if sub_queries != self.last_sub_queries:
                    self.callback("sub_queries", sub_queries)
                    self.last_sub_queries = sub_queries

            # Check the document count - send the complete data once the node has finished
            if "all_passages" in state:
                doc_count = len(state.get("all_passages", []))
                if doc_count != self.last_doc_count:
                    # Extract source document names, preferring the evidence field
                    sources = extract_sources_from_documents(
                        state.get("all_documents", []),
                        state.get("passage_sources", [])
                    )

                    # Determine the retrieval type and reason
                    current_iteration = state.get("current_iteration", 0)
                    sub_queries = state.get("current_sub_queries", [])

                    if self.last_doc_count == 0:
                        # Initial retrieval
                        retrieval_type = "initial retrieval"
                        retrieval_reason = "original query"
                    elif sub_queries and len(sub_queries) > 0:
                        # Retrieval driven by sub-queries
                        retrieval_type = "sub-query retrieval"
                        retrieval_reason = f"based on {len(sub_queries)} new sub-queries"
                    else:
                        # Iterative retrieval
                        retrieval_type = f"iteration round {current_iteration}"
                        retrieval_reason = "iterative deepening"

                    self.callback("documents", {
                        "count": doc_count,
                        "sources": sources,
                        "new_docs": doc_count - self.last_doc_count,
                        "retrieval_type": retrieval_type,
                        "retrieval_reason": retrieval_reason,
                        "is_incremental": self.last_doc_count > 0  # incremental retrieval?
                    })
                    self.last_doc_count = doc_count

            # Check the iteration round
            if "current_iteration" in state:
                current = state["current_iteration"]
                max_iter = state.get("max_iterations", 3)
                if current != self.last_iteration:
                    self.callback("iteration", {
                        "current": current,
                        "max": max_iter
                    })
                    self.last_iteration = current

            # Check the sufficiency judgement - it should be reported after every retrieval round
            if "sufficiency_check" in state and state["sufficiency_check"]:
                check = state["sufficiency_check"]
                # Build a comparable string representation
                check_str = f"{check.get('is_sufficient')}_{check.get('confidence', 0):.2f}_{check.get('iteration', 0)}"

                # New sufficiency result, or the document count changed (a retrieval just finished)
                doc_count = len(state.get("all_passages", []))
                if (check and check_str != self.last_sufficiency_check) or (doc_count > 0 and doc_count != self.last_doc_count):
                    self.callback("sufficiency_check", {
                        "is_sufficient": check.get("is_sufficient", False),
                        "confidence": check.get("confidence", 0),
                        "reason": check.get("reason", "")
                    })
                    self.last_sufficiency_check = check_str

                    # If sufficient, or the maximum number of iterations is reached,
                    # the final answer is about to be generated
                    current_iteration = state.get("current_iteration", 0)
                    max_iterations = state.get("max_iterations", 2)
                    if check.get("is_sufficient", False) or current_iteration >= max_iterations - 1:
                        self.callback("generating", {"message": "Generating the final answer..."})

            return state

    monitor = StateMonitor(status_callback)
    final_state = None

    # Prefer the stream method so updates arrive in real time
    if hasattr(workflow, 'stream'):
        try:
            # Use stream to pick up the state step by step
            for chunk in workflow.stream(initial_state, config=config):
                # A chunk is usually a {node_name: state} dict
                if isinstance(chunk, dict):
                    for node_name, node_state in chunk.items():
                        # Check every node's state in real time
                        monitor.check_state(node_state, node_name)
                        final_state = node_state

                        # Per-node "started" notifications (sent at most once per node).
                        # Note: these must not duplicate the updates sent after a node finishes.
                        if "query_complexity_check" in node_name:
                            # complexity_check is sent after the node finishes; nothing to send here
                            pass
                        elif "query_decomposition" in node_name:
                            # Query decomposition node - no intermediate status, the result follows shortly
                            pass
                        elif "initial_retrieval" in node_name:
                            # Initial retrieval node - no intermediate status, the result follows shortly
                            pass
                        elif "sufficiency_check" in node_name:
                            # Sufficiency check node started - not sent, to avoid duplicates
                            pass
                        elif "sub_query_generation" in node_name:
                            # Sub-query generation started - hidden, the user does not need this detail
                            pass
                            # status_callback("generating_subqueries", {"message": "Generating sub-queries..."})
                        elif "parallel_retrieval" in node_name:
                            # Parallel retrieval started - hidden, the user does not need this detail
                            pass
                            # status_callback("parallel_retrieving", {"message": "Running parallel retrieval..."})
                        elif "final" in node_name.lower() or "answer" in node_name.lower():
                            # Final answer node - the status was already sent after the sufficiency check
                            pass

        except Exception:
            # If streaming fails, fall back to invoke
            final_state = workflow.invoke(initial_state, config=config)
            monitor.check_state(final_state)
    else:
        # No stream support, use a plain invoke
        final_state = workflow.invoke(initial_state, config=config)
        monitor.check_state(final_state)

    return final_state


def extract_sources_from_documents(all_documents, passage_sources):
    """
    Extract the real document names from the documents (preferring the evidence field).

    Args:
        all_documents: list of documents (with an evidence field in their metadata)
        passage_sources: list of passage sources (used as a fallback)

    Returns:
        A de-duplicated list of document names.
    """
    sources = set()

    # Prefer the evidence field from all_documents
    if all_documents:
        for doc in all_documents:
            if hasattr(doc, 'metadata'):
                # Prefer the evidence field (the real document name of a passage node)
                evidence = doc.metadata.get('evidence', '')
                if evidence:
                    sources.add(evidence)
                    continue

                # Otherwise use source_text_id (the source of an event node)
                source_text_id = doc.metadata.get('source_text_id', '')
                if source_text_id:
                    # If source_text_id is a list-formatted string, parse it
                    if source_text_id.startswith('[') and source_text_id.endswith(']'):
                        try:
                            import ast
                            text_ids = ast.literal_eval(source_text_id)
                            if isinstance(text_ids, list) and text_ids:
                                sources.add(text_ids[0])  # use the first source
                        except Exception:
                            sources.add(source_text_id)
                    else:
                        sources.add(source_text_id)
                    continue

    # If neither evidence nor source_text_id was found, fall back to passage_sources
    if not sources and passage_sources:
        for source in passage_sources:
            if isinstance(source, str):
                # Extract the file name
                if "/" in source:
                    filename = source.split("/")[-1]
                else:
                    filename = source
                sources.add(filename)
            elif isinstance(source, dict):
                # For dicts, try the file-name fields
                filename = source.get("filename") or source.get("source") or str(source)
                sources.add(filename)

    return list(sources)[:5]  # return at most 5


def extract_sources(passage_sources):
    """
    Extract the list of document names from passage_sources (kept for backward compatibility).

    Args:
        passage_sources: list of passage sources

    Returns:
        A de-duplicated list of document names.
    """
    return extract_sources_from_documents([], passage_sources)


def build_result(retriever, final_state: Dict[str, Any], total_time: float) -> Dict[str, Any]:
    """
    Build the final result (reusing the original logic).
    """

    # Get token-usage info, if available
    token_info = {}
    if hasattr(retriever, '_get_token_usage_info'):
        token_info = retriever._get_token_usage_info()

    result = {
        "query": final_state.get('original_query', ''),
        "answer": final_state.get('final_answer', '') or "No answer could be generated",

        # Query complexity info
        "query_complexity": final_state.get('query_complexity', {}),
        "is_complex_query": final_state.get('is_complex_query', False),

        # Retrieval statistics
        "iterations": final_state.get('current_iteration', 0),
        "total_passages": len(final_state.get('all_passages', [])),
        "sub_queries": final_state.get('sub_queries', []),
        "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),

        # Sufficiency info
        "sufficiency_check": final_state.get('sufficiency_check', {}),
        "is_sufficient": final_state.get('is_sufficient', False),

        # Event-triple info
        "event_triples": final_state.get('event_triples', []),
        "event_triples_count": len(final_state.get('event_triples', [])),

        # Supporting info
        "supporting_facts": extract_supporting_facts(final_state),
        "supporting_events": extract_supporting_events(final_state),

        # Debug info
        "debug_info": {
            "total_time": total_time,
            "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
            "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
            "token_usage": token_info
        }
    }

    return result


def extract_supporting_facts(state):
    """Extract supporting facts."""
    passages = state.get('all_passages', [])
    facts = []

    for i, passage in enumerate(passages[:10]):  # at most 10
        if isinstance(passage, str):
            facts.append([f"Fact_{i+1}", passage[:200]])
        elif isinstance(passage, dict):
            content = passage.get('content', '') or passage.get('text', '')
            facts.append([f"Fact_{i+1}", content[:200]])

    return facts


def extract_supporting_events(state):
    """Extract supporting events."""
    # Pull event info out of the state
    events = []

    # Key events can be extracted from iteration_history
    iteration_history = state.get('iteration_history', [])
    for item in iteration_history:
        if item.get('action') == 'sub_query_generation':
            events.append([
                "Sub-query generation",
                f"Round {item.get('iteration', 0)}: generated {len(item.get('sub_queries', []))} sub-queries"
            ])
        elif item.get('action') == 'parallel_retrieval':
            events.append([
                "Parallel retrieval",
                f"Round {item.get('iteration', 0)}: retrieved {item.get('documents_retrieved', 0)} documents"
            ])

    return events[:5]  # at most 5 events