"""
|
|||
|
|
流式检索包装器 - 为现有检索器添加状态回调支持
|
|||
|
|
保持向后兼容,不影响原有业务逻辑
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import time
|
|||
|
|
from typing import Dict, Any, Callable, Optional
|
|||
|
|
import traceback
|
|||
|
|
|
|||
|
|
|
|||
|
|
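
# Status types emitted through status_callback(status_type, data) in this
# module, collected from the call sites below:
#   "starting"           {"message": ...}
#   "sub_queries"        list of sub-query strings
#   "documents"          {"count", "sources", "new_docs", "retrieval_type",
#                         "retrieval_reason", "is_incremental"}
#   "iteration"          {"current", "max"}
#   "sufficiency_check"  {"is_sufficient", "confidence", "reason"}
#   "generating"         {"message": ...}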


def stream_retrieve(
    retriever,
    query: str,
    mode: str = "0",
    status_callback: Optional[Callable] = None
) -> Dict[str, Any]:
    """
    Retrieval wrapper with status callbacks.

    Adds status-callback support by wrapping the original retriever rather
    than modifying it.

    Args:
        retriever: the original retriever instance
        query: the query string
        mode: debug mode
        status_callback: status callback, invoked as (status_type, data)

    Returns:
        The retrieval result dict.
    """

    # Without a callback, just delegate to the original method.
    if not status_callback:
        return retriever.retrieve(query, mode)

    try:
        # Retrieval is starting.
        status_callback("starting", {"message": "Analyzing the query"})

        # Create the enriched initial state.
        from retriver.langgraph.graph_state import create_initial_state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=retriever.max_iterations,
            debug_mode=mode
        )

        # Configure the workflow run.
        config = {
            "recursion_limit": 50,
            "metadata": {
                "stream_mode": True,
                "callback_enabled": True
            },
            "run_name": f"Stream-Retrieval-{query[:20]}..."
        }

        # Run the workflow, firing callbacks at the key points.
        result_state = execute_workflow_with_callbacks(
            retriever.workflow,
            initial_state,
            config,
            status_callback
        )

        # Assemble the final result.
        total_time = result_state.get("debug_info", {}).get("total_time", 0)

        # Build the result (reusing the original logic).
        result = build_result(retriever, result_state, total_time)

        # Persist the full result to a JSON file (if the retriever supports it).
        if hasattr(retriever, '_save_full_result_to_file'):
            # Build the complete result for saving.
            full_result = {
                **result,
                "all_documents": result_state.get('all_documents', []),
                "all_passages": result_state.get('all_passages', []),
                "passage_sources": result_state.get('passage_sources', []),
                "retrieval_path": result_state.get('retrieval_path', 'unknown'),
                "initial_retrieval_details": result_state.get('initial_retrieval_details', {}),
                "current_sub_queries": result_state.get('current_sub_queries', [])
            }
            retriever._save_full_result_to_file(full_result)
            print("[INFO] Streaming retrieval result saved to JSON file")

        return result

    except Exception as e:
        print(f"[ERROR] Streaming retrieval failed: {e}")
        traceback.print_exc()
        # On failure, return a minimal result.
        return {
            "query": query,
            "answer": f"The retrieval ran into an error: {str(e)}",
            "error": str(e),
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "supporting_facts": [],
            "supporting_events": []
        }
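
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving stream_retrieve with a callback; the
# MyGraphRetriever name and its constructor are hypothetical, but any
# retriever exposing .retrieve(), .workflow, and .max_iterations fits.
#
#     def print_status(status_type, data):
#         # Could equally push SSE or WebSocket events to a frontend.
#         print(f"[{status_type}] {data}")
#
#     retriever = MyGraphRetriever(...)  # hypothetical retriever instance
#     result = stream_retrieve(retriever, "Who founded the company?",
#                              status_callback=print_status)
#     print(result["answer"])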


def execute_workflow_with_callbacks(
    workflow,
    initial_state: Dict[str, Any],
    config: Dict[str, Any],
    status_callback: Callable
) -> Dict[str, Any]:
    """
    Run the workflow and fire status callbacks at the key nodes.

    Uses LangGraph's stream support to pick up state updates in real time.
    """

    # Inject the callback into the state so the answer-generation node can use it.
    initial_state['stream_callback'] = status_callback
    print("[DEBUG] Injected stream_callback into initial_state")

    # Monitor that diffs successive states and fires callbacks on changes.
    class StateMonitor:
        def __init__(self, callback):
            self.callback = callback
            self.last_iteration = 0
            self.last_doc_count = 0
            self.last_sub_queries = []
            self.seen_nodes = set()
            self.last_sufficiency_check = None  # last sufficiency-check result seen

        def check_state(self, state, node_name=None):
            """Detect state changes and trigger callbacks."""

            # Record node visits.
            if node_name and node_name not in self.seen_nodes:
                self.seen_nodes.add(node_name)

            # Sub-queries: send the complete list once the node has finished.
            if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]:
                sub_queries = state["decomposed_sub_queries"]
                if sub_queries != self.last_sub_queries:
                    self.callback("sub_queries", sub_queries)
                    self.last_sub_queries = sub_queries

            # Document count: send the complete data once the node has finished.
            if "all_passages" in state:
                doc_count = len(state.get("all_passages", []))
                if doc_count != self.last_doc_count:
                    # Extract source document names, preferring the evidence field.
                    sources = extract_sources_from_documents(
                        state.get("all_documents", []),
                        state.get("passage_sources", [])
                    )

                    # Work out the retrieval type and the reason behind it.
                    current_iteration = state.get("current_iteration", 0)
                    sub_queries = state.get("current_sub_queries", [])

                    if self.last_doc_count == 0:
                        # First retrieval for the original query.
                        retrieval_type = "Initial retrieval"
                        retrieval_reason = "Original query"
                    elif sub_queries and len(sub_queries) > 0:
                        # Retrieval driven by freshly generated sub-queries.
                        retrieval_type = "Sub-query retrieval"
                        retrieval_reason = f"Based on {len(sub_queries)} new sub-queries"
                    else:
                        # Iterative deepening.
                        retrieval_type = f"Iteration {current_iteration}"
                        retrieval_reason = "Iterative deepening"

                    self.callback("documents", {
                        "count": doc_count,
                        "sources": sources,
                        "new_docs": doc_count - self.last_doc_count,
                        "retrieval_type": retrieval_type,
                        "retrieval_reason": retrieval_reason,
                        "is_incremental": self.last_doc_count > 0  # incremental retrieval?
                    })
                    self.last_doc_count = doc_count

            # Iteration round.
            if "current_iteration" in state:
                current = state["current_iteration"]
                max_iter = state.get("max_iterations", 3)
                if current != self.last_iteration:
                    self.callback("iteration", {
                        "current": current,
                        "max": max_iter
                    })
                    self.last_iteration = current

            # Sufficiency check - should be surfaced after every retrieval round.
            if "sufficiency_check" in state and state["sufficiency_check"]:
                check = state["sufficiency_check"]
                # Build a comparable string representation.
                check_str = f"{check.get('is_sufficient')}_{check.get('confidence', 0):.2f}_{check.get('iteration', 0)}"

                # Fire when the check result is new, or the document count just
                # changed (meaning a retrieval has just completed).
                doc_count = len(state.get("all_passages", []))
                if (check and check_str != self.last_sufficiency_check) or (doc_count > 0 and doc_count != self.last_doc_count):
                    self.callback("sufficiency_check", {
                        "is_sufficient": check.get("is_sufficient", False),
                        "confidence": check.get("confidence", 0),
                        "reason": check.get("reason", "")
                    })
                    self.last_sufficiency_check = check_str

                # If the evidence is sufficient, or the iteration limit is
                # reached, the final answer is about to be generated.
                current_iteration = state.get("current_iteration", 0)
                max_iterations = state.get("max_iterations", 2)
                if check.get("is_sufficient", False) or current_iteration >= max_iterations - 1:
                    self.callback("generating", {"message": "Generating the final answer..."})

            return state

    monitor = StateMonitor(status_callback)
    final_state = None

    # Prefer the stream API for real-time updates.
    if hasattr(workflow, 'stream'):
        try:
            # Step through the states as the workflow streams them.
            for chunk in workflow.stream(initial_state, config=config):
                # A chunk is usually a {node_name: state} dict.
                if isinstance(chunk, dict):
                    for node_name, node_state in chunk.items():
                        # Inspect each node's state as it arrives.
                        monitor.check_state(node_state, node_name)
                        final_state = node_state

                        # Node-start statuses are deliberately not emitted for
                        # any node (query_complexity_check, query_decomposition,
                        # initial_retrieval, sufficiency_check,
                        # sub_query_generation, parallel_retrieval, and the
                        # final/answer nodes): the monitor above reports each
                        # result once the node completes, so start events would
                        # either duplicate it or surface detail the user does
                        # not need to see.
                        # Disabled candidates, kept for reference:
                        # status_callback("generating_subqueries", {"message": "Generating sub-queries..."})
                        # status_callback("parallel_retrieving", {"message": "Running parallel retrieval..."})

        except Exception as e:
            # If streaming fails, fall back to a plain invoke.
            print(f"[WARN] stream failed, falling back to invoke: {e}")
            final_state = workflow.invoke(initial_state, config=config)
            monitor.check_state(final_state)
    else:
        # No stream support; use a plain invoke.
        final_state = workflow.invoke(initial_state, config=config)
        monitor.check_state(final_state)

    return final_state
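
# --- Stream-chunk shape (illustrative note) ---
# In LangGraph's "updates" stream mode, which is what the loop above expects,
# each chunk maps the node that just finished to the values it returned, e.g.
# (node names assumed from this workflow):
#
#     {"initial_retrieval": {"all_passages": [...], "current_iteration": 0}}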


def extract_sources_from_documents(all_documents, passage_sources):
    """
    Extract the real document names from documents (preferring the evidence field).

    Args:
        all_documents: list of documents (with an evidence field in their metadata)
        passage_sources: list of passage sources (used as a fallback)

    Returns:
        A de-duplicated list of document names.
    """
    sources = set()

    # Prefer the evidence field from all_documents.
    if all_documents:
        for doc in all_documents:
            if hasattr(doc, 'metadata'):
                # First choice: the evidence field (the passage node's real document name).
                evidence = doc.metadata.get('evidence', '')
                if evidence:
                    sources.add(evidence)
                    continue

                # Second choice: source_text_id (the event node's origin).
                source_text_id = doc.metadata.get('source_text_id', '')
                if source_text_id:
                    # If source_text_id is a stringified list, parse it.
                    if source_text_id.startswith('[') and source_text_id.endswith(']'):
                        try:
                            text_ids = ast.literal_eval(source_text_id)
                            if isinstance(text_ids, list) and text_ids:
                                sources.add(text_ids[0])  # use the first source
                        except (ValueError, SyntaxError):
                            sources.add(source_text_id)
                    else:
                        sources.add(source_text_id)
                    continue

    # If neither evidence nor source_text_id was found, fall back to passage_sources.
    if not sources and passage_sources:
        for source in passage_sources:
            if isinstance(source, str):
                # Extract the file name.
                if "/" in source:
                    filename = source.split("/")[-1]
                else:
                    filename = source
                sources.add(filename)
            elif isinstance(source, dict):
                # For dicts, try the common file-name fields.
                filename = source.get("filename") or source.get("source") or str(source)
                sources.add(filename)

    return list(sources)[:5]  # return at most 5
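
# --- Illustrative example (assumes LangChain-style documents) ---
# A document whose metadata carries an evidence field wins over the fallbacks:
#
#     doc = Document(page_content="...",
#                    metadata={"evidence": "report_2023.txt",
#                              "source_text_id": "['a.txt', 'b.txt']"})
#     extract_sources_from_documents([doc], ["ignored.txt"])
#     # -> ["report_2023.txt"]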


def extract_sources(passage_sources):
    """
    Extract document names from passage_sources (kept for backward compatibility).

    Args:
        passage_sources: list of passage sources

    Returns:
        A de-duplicated list of document names.
    """
    return extract_sources_from_documents([], passage_sources)


def build_result(retriever, final_state: Dict[str, Any], total_time: float) -> Dict[str, Any]:
    """
    Build the final result (reusing the original logic).
    """

    # Token-usage info, if the retriever tracks it.
    token_info = {}
    if hasattr(retriever, '_get_token_usage_info'):
        token_info = retriever._get_token_usage_info()

    result = {
        "query": final_state.get('original_query', ''),
        "answer": final_state.get('final_answer', '') or "No answer could be generated",

        # Query-complexity info
        "query_complexity": final_state.get('query_complexity', {}),
        "is_complex_query": final_state.get('is_complex_query', False),

        # Retrieval statistics
        "iterations": final_state.get('current_iteration', 0),
        "total_passages": len(final_state.get('all_passages', [])),
        "sub_queries": final_state.get('sub_queries', []),
        "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),

        # Sufficiency info
        "sufficiency_check": final_state.get('sufficiency_check', {}),
        "is_sufficient": final_state.get('is_sufficient', False),

        # Supporting information
        "supporting_facts": extract_supporting_facts(final_state),
        "supporting_events": extract_supporting_events(final_state),

        # Debug info
        "debug_info": {
            "total_time": total_time,
            "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
            "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
            "token_usage": token_info
        }
    }

    return result


def extract_supporting_facts(state):
    """Extract supporting facts."""
    passages = state.get('all_passages', [])
    facts = []

    for i, passage in enumerate(passages[:10]):  # at most 10
        if isinstance(passage, str):
            facts.append([f"Fact_{i+1}", passage[:200]])
        elif isinstance(passage, dict):
            content = passage.get('content', '') or passage.get('text', '')
            facts.append([f"Fact_{i+1}", content[:200]])

    return facts


def extract_supporting_events(state):
    """Extract supporting events."""
    # Pull event info out of the state.
    events = []

    # Key events can be recovered from iteration_history.
    iteration_history = state.get('iteration_history', [])
    for item in iteration_history:
        if item.get('action') == 'sub_query_generation':
            events.append([
                "Sub-query generation",
                f"Round {item.get('iteration', 0)}: generated {len(item.get('sub_queries', []))} sub-queries"
            ])
        elif item.get('action') == 'parallel_retrieval':
            events.append([
                "Parallel retrieval",
                f"Round {item.get('iteration', 0)}: retrieved {item.get('documents_retrieved', 0)} documents"
            ])

    return events[:5]  # at most 5 events