first commit

闫旭隆
2025-10-17 09:31:28 +08:00
commit 4698145045
589 changed files with 196795 additions and 0 deletions


@@ -0,0 +1,440 @@
"""
流式检索包装器 - 为现有检索器添加状态回调支持
保持向后兼容,不影响原有业务逻辑
"""
import time
from typing import Dict, Any, Callable, Optional
import traceback
from exceptions import TaskCancelledException, WorkflowStoppedException
def stream_retrieve(
    retriever,
    query: str,
    mode: str = "0",
    status_callback: Optional[Callable] = None,
    task_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Retrieval wrapper with status callbacks.
    Adds callback support by wrapping the original retriever, without modifying it.

    Args:
        retriever: the original retriever instance
        query: the query string
        mode: debug mode
        status_callback: status callback, invoked as (status_type, data)
        task_id: task ID (optional)

    Returns:
        the retrieval result as a dict
    """
    # Without a callback, fall through to the original method.
    if not status_callback:
        return retriever.retrieve(query, mode)

    try:
        # Retrieval is starting.
        status_callback("starting", {"message": "Starting query analysis"})

        # Create the enriched initial state.
        from retriver.langgraph.graph_state import create_initial_state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=retriever.max_iterations,
            debug_mode=mode,
            task_id=task_id  # pass the task ID through
        )

        # Configure the workflow run.
        config = {
            "recursion_limit": 50,
            "metadata": {
                "stream_mode": True,
                "callback_enabled": True
            },
            "run_name": f"Stream-Retrieval-{query[:20]}..."
        }

        # Run the workflow, firing callbacks at key points.
        result_state = execute_workflow_with_callbacks(
            retriever.workflow,
            initial_state,
            config,
            status_callback
        )

        # Assemble the final result.
        total_time = result_state.get("debug_info", {}).get("total_time", 0)

        # Build the result (reusing the original logic).
        result = build_result(retriever, result_state, total_time)

        # Persist the full result to a JSON file (if the retriever supports it).
        if hasattr(retriever, '_save_full_result_to_file'):
            # Build the complete payload for saving.
            full_result = {
                **result,
                "all_documents": result_state.get('all_documents', []),
                "all_passages": result_state.get('all_passages', []),
                "passage_sources": result_state.get('passage_sources', []),
                "retrieval_path": result_state.get('retrieval_path', 'unknown'),
                "initial_retrieval_details": result_state.get('initial_retrieval_details', {}),
                "current_sub_queries": result_state.get('current_sub_queries', []),
                "event_triples": result_state.get('event_triples', [])  # new: event-triple info
            }
            retriever._save_full_result_to_file(full_result)
            print("[INFO] Streaming retrieval result saved to JSON file")

        return result

    except (TaskCancelledException, WorkflowStoppedException) as e:
        print(f"[CANCELLED] Streaming retrieval cancelled: {e}")
        if status_callback:
            status_callback("cancelled", {"message": "Task was cancelled"})
        # Return a cancellation result.
        return {
            "query": query,
            "answer": "Task was cancelled",
            "cancelled": True,
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "debug_info": {"cancelled": True, "message": str(e)}
        }
    except Exception as e:
        print(f"[ERROR] Streaming retrieval failed: {e}")
        traceback.print_exc()
        # On error, return a minimal result.
        return {
            "query": query,
            "answer": f"Error during retrieval: {str(e)}",
            "error": str(e),
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "supporting_facts": [],
            "supporting_events": []
        }
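
# Usage sketch (illustrative, not executed on import): wiring stream_retrieve
# to a simple callback. Any retriever exposing .retrieve(query, mode),
# .workflow and .max_iterations fits; `on_status`, the query string and the
# task ID below are made up for the example.
#
#     def on_status(status_type, data):
#         print(f"[{status_type}] {data}")   # e.g. forward to an SSE channel
#
#     result = stream_retrieve(retriever, "Who founded the company?",
#                              status_callback=on_status, task_id="task-123")
#     print(result["answer"])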
def execute_workflow_with_callbacks(
    workflow,
    initial_state: Dict[str, Any],
    config: Dict[str, Any],
    status_callback: Callable
) -> Dict[str, Any]:
    """
    Run the workflow, firing status callbacks at key nodes.
    Uses LangGraph's stream support to pick up state updates in real time.
    """
    # Inject the callback into the state so the answer-generation node can use it.
    initial_state['stream_callback'] = status_callback
    print("[DEBUG] Injected stream_callback into initial_state")

    # A small monitor that diffs states between nodes.
    class StateMonitor:
        def __init__(self, callback):
            self.callback = callback
            self.last_iteration = 0
            self.last_doc_count = 0
            self.last_sub_queries = []
            self.seen_nodes = set()
            self.last_sufficiency_check = None  # last sufficiency-check result

        def check_state(self, state, node_name=None):
            """Detect state changes and fire callbacks."""
            # Record node visits.
            if node_name and node_name not in self.seen_nodes:
                self.seen_nodes.add(node_name)

            # Sub-queries: emit the full list once a node has produced it.
            if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]:
                sub_queries = state["decomposed_sub_queries"]
                if sub_queries != self.last_sub_queries:
                    self.callback("sub_queries", sub_queries)
                    self.last_sub_queries = sub_queries

            # Document count: emit complete data once a node has finished.
            if "all_passages" in state:
                doc_count = len(state.get("all_passages", []))
                if doc_count != self.last_doc_count:
                    # Extract source document names (prefer the evidence field).
                    sources = extract_sources_from_documents(state.get("all_documents", []), state.get("passage_sources", []))
                    # Work out the retrieval type and reason.
                    current_iteration = state.get("current_iteration", 0)
                    sub_queries = state.get("current_sub_queries", [])
                    if self.last_doc_count == 0:
                        # Initial retrieval.
                        retrieval_type = "initial retrieval"
                        retrieval_reason = "original query"
                    elif sub_queries and len(sub_queries) > 0:
                        # Retrieval driven by sub-queries.
                        retrieval_type = "sub-query retrieval"
                        retrieval_reason = f"based on {len(sub_queries)} new sub-queries"
                    else:
                        # Iterative retrieval.
                        retrieval_type = f"iteration {current_iteration}"
                        retrieval_reason = "iterative deepening"
                    self.callback("documents", {
                        "count": doc_count,
                        "sources": sources,
                        "new_docs": doc_count - self.last_doc_count,
                        "retrieval_type": retrieval_type,
                        "retrieval_reason": retrieval_reason,
                        "is_incremental": self.last_doc_count > 0  # incremental retrieval?
                    })
                    self.last_doc_count = doc_count

            # Iteration round.
            if "current_iteration" in state:
                current = state["current_iteration"]
                max_iter = state.get("max_iterations", 3)
                if current != self.last_iteration:
                    self.callback("iteration", {
                        "current": current,
                        "max": max_iter
                    })
                    self.last_iteration = current

            # Sufficiency check: should surface after every retrieval round.
            if "sufficiency_check" in state and state["sufficiency_check"]:
                check = state["sufficiency_check"]
                # A comparable string key for the check result.
                check_str = f"{check.get('is_sufficient')}_{check.get('confidence', 0):.2f}_{check.get('iteration', 0)}"
                # Fire if the check is new, or the doc count changed (a retrieval just finished).
                doc_count = len(state.get("all_passages", []))
                if (check and check_str != self.last_sufficiency_check) or (doc_count > 0 and doc_count != self.last_doc_count):
                    self.callback("sufficiency_check", {
                        "is_sufficient": check.get("is_sufficient", False),
                        "confidence": check.get("confidence", 0),
                        "reason": check.get("reason", "")
                    })
                    self.last_sufficiency_check = check_str
                    # If sufficient, or at the iteration cap, the answer is about to be generated.
                    current_iteration = state.get("current_iteration", 0)
                    max_iterations = state.get("max_iterations", 2)
                    if check.get("is_sufficient", False) or current_iteration >= max_iterations - 1:
                        self.callback("generating", {"message": "Generating the final answer..."})
            return state

    monitor = StateMonitor(status_callback)
    final_state = None

    # Prefer the stream API for real-time updates.
    if hasattr(workflow, 'stream'):
        try:
            # Stream states node by node.
            for chunk in workflow.stream(initial_state, config=config):
                # Each chunk is normally a {node_name: state} dict.
                if isinstance(chunk, dict):
                    for node_name, node_state in chunk.items():
                        # Check every node's state as it arrives.
                        monitor.check_state(node_state, node_name)
                        final_state = node_state
                        # Node-start statuses (sent once per node).
                        # Note: these must not duplicate the statuses sent after a node completes.
                        if "query_complexity_check" in node_name:
                            # complexity_check is reported after the node completes; skip here.
                            pass
                        elif "query_decomposition" in node_name:
                            # Decomposition node - skip the intermediate status; results arrive quickly.
                            pass
                        elif "initial_retrieval" in node_name:
                            # Initial-retrieval node - skip the intermediate status; results arrive quickly.
                            pass
                        elif "sufficiency_check" in node_name:
                            # Sufficiency-check node start - skip to avoid duplicates.
                            pass
                        elif "sub_query_generation" in node_name:
                            # Sub-query generation start - hidden; users don't need this detail.
                            pass
                            # status_callback("generating_subqueries", {"message": "Generating sub-queries..."})
                        elif "parallel_retrieval" in node_name:
                            # Parallel-retrieval start - hidden; users don't need this detail.
                            pass
                            # status_callback("parallel_retrieving", {"message": "Retrieving in parallel..."})
                        elif "final" in node_name.lower() or "answer" in node_name.lower():
                            # Final-answer node - status already sent after the sufficiency check; skip.
                            pass
        except Exception:
            # If streaming fails, fall back to invoke.
            final_state = workflow.invoke(initial_state, config=config)
            monitor.check_state(final_state)
    else:
        # No stream support; plain invoke.
        final_state = workflow.invoke(initial_state, config=config)
        monitor.check_state(final_state)

    return final_state
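
# Shape assumed by the streaming loop above (per the {node_name: state} chunks
# this module expects from workflow.stream): each chunk maps the node that just
# ran to its state, e.g.
#
#     {"initial_retrieval": {"all_passages": [...], "current_iteration": 0, ...}}
#
# so monitor.check_state() diffs one node's state at a time. For a simple run
# the resulting callback order is roughly:
#     starting -> sub_queries -> documents -> sufficiency_check -> generating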
def extract_sources_from_documents(all_documents, passage_sources):
    """
    Extract real document names from documents (prefer the evidence field).

    Args:
        all_documents: document list; metadata may carry an evidence field
        passage_sources: passage-source list (used as a fallback)

    Returns:
        a de-duplicated list of document names
    """
    sources = set()

    # Prefer the evidence field on all_documents.
    if all_documents:
        for doc in all_documents:
            if hasattr(doc, 'metadata'):
                # First choice: the evidence field (the passage node's real document name).
                evidence = doc.metadata.get('evidence', '')
                if evidence:
                    sources.add(evidence)
                    continue
                # Second choice: source_text_id (the event node's origin).
                source_text_id = doc.metadata.get('source_text_id', '')
                if source_text_id:
                    # If source_text_id is a stringified list, parse it.
                    if source_text_id.startswith('[') and source_text_id.endswith(']'):
                        try:
                            import ast
                            text_ids = ast.literal_eval(source_text_id)
                            if isinstance(text_ids, list) and text_ids:
                                sources.add(text_ids[0])  # use the first source
                        except (ValueError, SyntaxError):
                            sources.add(source_text_id)
                    else:
                        sources.add(source_text_id)
                    continue

    # If neither evidence nor source_text_id was found, fall back to passage_sources.
    if not sources and passage_sources:
        for source in passage_sources:
            if isinstance(source, str):
                # Extract the file name.
                if "/" in source:
                    filename = source.split("/")[-1]
                else:
                    filename = source
                sources.add(filename)
            elif isinstance(source, dict):
                # For dicts, try the file-name fields.
                filename = source.get("filename") or source.get("source") or str(source)
                sources.add(filename)

    return list(sources)[:5]  # return at most 5
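
# Illustrative example of the metadata shapes handled above (the Doc stub is
# hypothetical; the field names match this module's conventions):
#
#     class Doc:
#         def __init__(self, metadata):
#             self.metadata = metadata
#
#     docs = [
#         Doc({"evidence": "report_2023.txt"}),                        # passage node
#         Doc({"source_text_id": "['news_01.txt', 'news_02.txt']"}),  # event node
#     ]
#     extract_sources_from_documents(docs, [])
#     # -> ['report_2023.txt', 'news_01.txt'] (set-based, order not guaranteed)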
def extract_sources(passage_sources):
    """
    Extract document names from passage_sources (kept for backward compatibility).

    Args:
        passage_sources: passage-source list

    Returns:
        a de-duplicated list of document names
    """
    return extract_sources_from_documents([], passage_sources)
def build_result(retriever, final_state: Dict[str, Any], total_time: float) -> Dict[str, Any]:
    """
    Build the final result (reuses the original logic).
    """
    # Token-usage info, if available.
    token_info = {}
    if hasattr(retriever, '_get_token_usage_info'):
        token_info = retriever._get_token_usage_info()

    result = {
        "query": final_state.get('original_query', ''),
        "answer": final_state.get('final_answer', '') or "No answer generated",
        # Query-complexity info
        "query_complexity": final_state.get('query_complexity', {}),
        "is_complex_query": final_state.get('is_complex_query', False),
        # Retrieval statistics
        "iterations": final_state.get('current_iteration', 0),
        "total_passages": len(final_state.get('all_passages', [])),
        "sub_queries": final_state.get('sub_queries', []),
        "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),
        # Sufficiency info
        "sufficiency_check": final_state.get('sufficiency_check', {}),
        "is_sufficient": final_state.get('is_sufficient', False),
        # Event-triple info (new)
        "event_triples": final_state.get('event_triples', []),
        "event_triples_count": len(final_state.get('event_triples', [])),
        # Supporting info
        "supporting_facts": extract_supporting_facts(final_state),
        "supporting_events": extract_supporting_events(final_state),
        # Debug info
        "debug_info": {
            "total_time": total_time,
            "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
            "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
            "token_usage": token_info
        }
    }
    return result
def extract_supporting_facts(state):
    """Extract supporting facts."""
    passages = state.get('all_passages', [])
    facts = []
    for i, passage in enumerate(passages[:10]):  # at most 10
        if isinstance(passage, str):
            facts.append([f"Fact_{i+1}", passage[:200]])
        elif isinstance(passage, dict):
            content = passage.get('content', '') or passage.get('text', '')
            facts.append([f"Fact_{i+1}", content[:200]])
    return facts
def extract_supporting_events(state):
    """Extract supporting events."""
    # Pull event info out of the state.
    events = []
    # Key events can be mined from iteration_history.
    iteration_history = state.get('iteration_history', [])
    for item in iteration_history:
        if item.get('action') == 'sub_query_generation':
            events.append([
                "sub-query generation",
                f"round {item.get('iteration', 0)}: generated {len(item.get('sub_queries', []))} sub-queries"
            ])
        elif item.get('action') == 'parallel_retrieval':
            events.append([
                "parallel retrieval",
                f"round {item.get('iteration', 0)}: retrieved {item.get('documents_retrieved', 0)} documents"
            ])
    return events[:5]  # at most 5 events
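
# Minimal smoke test (illustrative; the sample state below is made up) showing
# what the extractors return.
if __name__ == "__main__":
    sample_state = {
        "all_passages": [
            "First passage text...",
            {"content": "Second passage text..."},
        ],
        "iteration_history": [
            {"action": "sub_query_generation", "iteration": 1, "sub_queries": ["q1", "q2"]},
            {"action": "parallel_retrieval", "iteration": 1, "documents_retrieved": 8},
        ],
    }
    print(extract_supporting_facts(sample_state))
    # [['Fact_1', 'First passage text...'], ['Fact_2', 'Second passage text...']]
    print(extract_supporting_events(sample_state))
    # [['sub-query generation', 'round 1: generated 2 sub-queries'],
    #  ['parallel retrieval', 'round 1: retrieved 8 documents']]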