# AIEC-new/AIEC-RAG/retriver/langsmith/langsmith_retriever_stream.py
"""
流式检索包装器 - 为现有检索器添加状态回调支持
保持向后兼容,不影响原有业务逻辑
"""
import time
from typing import Dict, Any, Callable, Optional
import traceback
from exceptions import TaskCancelledException, WorkflowStoppedException
def stream_retrieve(
    retriever,
    query: str,
    mode: str = "0",
    status_callback: Optional[Callable] = None,
    task_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Retrieval wrapper with status callbacks.

    Adds status-callback support by wrapping the existing retriever,
    without modifying it.

    Args:
        retriever: the original retriever instance
        query: the query string
        mode: debug mode
        status_callback: status callback invoked as (status_type, data)
        task_id: optional task ID

    Returns:
        A dict with the retrieval result.
    """
    # Without a callback, delegate directly to the original method
    if not status_callback:
        return retriever.retrieve(query, mode)
    try:
        # Retrieval is starting
        status_callback("starting", {"message": "Starting query analysis"})

        # Build the enriched initial state
        from retriver.langgraph.graph_state import create_initial_state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=retriever.max_iterations,
            debug_mode=mode,
            task_id=task_id  # propagate the task ID
        )

        # Configure the workflow run
        config = {
            "recursion_limit": 50,
            "metadata": {
                "stream_mode": True,
                "callback_enabled": True
            },
            "run_name": f"Stream-Retrieval-{query[:20]}..."
        }

        # Run the workflow, firing callbacks at key points
        result_state = execute_workflow_with_callbacks(
            retriever.workflow,
            initial_state,
            config,
            status_callback
        )

        # Assemble the final result
        total_time = result_state.get("debug_info", {}).get("total_time", 0)

        # Build the result (reusing the original logic)
        result = build_result(retriever, result_state, total_time)

        # Persist the full result to a JSON file, if the retriever supports it
        if hasattr(retriever, '_save_full_result_to_file'):
            # Assemble the complete result for saving
            full_result = {
                **result,
                "all_documents": result_state.get('all_documents', []),
                "all_passages": result_state.get('all_passages', []),
                "passage_sources": result_state.get('passage_sources', []),
                "retrieval_path": result_state.get('retrieval_path', 'unknown'),
                "initial_retrieval_details": result_state.get('initial_retrieval_details', {}),
                "current_sub_queries": result_state.get('current_sub_queries', []),
                "event_triples": result_state.get('event_triples', [])  # event-triple info (new)
            }
            retriever._save_full_result_to_file(full_result)
            print("[INFO] Streaming retrieval result saved to JSON file")

        return result
    except (TaskCancelledException, WorkflowStoppedException) as e:
        print(f"[CANCELLED] Streaming retrieval cancelled: {e}")
        if status_callback:
            status_callback("cancelled", {"message": "Task was cancelled"})
        # Return a cancellation result
        return {
            "query": query,
            "answer": "Task was cancelled",
            "cancelled": True,
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "debug_info": {"cancelled": True, "message": str(e)}
        }
    except Exception as e:
        print(f"[ERROR] Streaming retrieval failed: {e}")
        traceback.print_exc()
        # Fall back to a minimal result on error
        return {
            "query": query,
            "answer": f"An error occurred during retrieval: {str(e)}",
            "error": str(e),
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "supporting_facts": [],
            "supporting_events": []
        }
def execute_workflow_with_callbacks(
    workflow,
    initial_state: Dict[str, Any],
    config: Dict[str, Any],
    status_callback: Callable
) -> Dict[str, Any]:
    """
    Run the workflow, firing status callbacks at key nodes.

    Uses LangGraph's stream() to pick up state updates in real time.
    """
    # Inject the callback into the state so the answer-generation node can use it
    initial_state['stream_callback'] = status_callback
    print("[DEBUG] Injected stream_callback into initial_state")
    # Monitor that tracks state changes between nodes
    class StateMonitor:
        def __init__(self, callback):
            self.callback = callback
            self.last_iteration = 0
            self.last_doc_count = 0
            self.last_sub_queries = []
            self.seen_nodes = set()
            self.last_sufficiency_check = None  # last sufficiency-check result seen

        def check_state(self, state, node_name=None):
            """Check for state changes and fire callbacks."""
            # Record the node visit
            if node_name and node_name not in self.seen_nodes:
                self.seen_nodes.add(node_name)

            # Sub-queries - send the complete data once the node finishes
            if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]:
                sub_queries = state["decomposed_sub_queries"]
                if sub_queries != self.last_sub_queries:
                    self.callback("sub_queries", sub_queries)
                    self.last_sub_queries = sub_queries

            # Document count - send the complete data once the node finishes
            if "all_passages" in state:
                doc_count = len(state.get("all_passages", []))
                if doc_count != self.last_doc_count:
                    # Extract source document names (prefer the evidence field)
                    sources = extract_sources_from_documents(
                        state.get("all_documents", []), state.get("passage_sources", []))

                    # Classify the retrieval type and reason
                    current_iteration = state.get("current_iteration", 0)
                    sub_queries = state.get("current_sub_queries", [])

                    if self.last_doc_count == 0:
                        # Initial retrieval
                        retrieval_type = "initial retrieval"
                        retrieval_reason = "original query"
                    elif sub_queries and len(sub_queries) > 0:
                        # Retrieval driven by sub-queries
                        retrieval_type = "sub-query retrieval"
                        retrieval_reason = f"based on {len(sub_queries)} new sub-queries"
                    else:
                        # Iterative retrieval
                        retrieval_type = f"iteration {current_iteration}"
                        retrieval_reason = "iterative deepening"

                    self.callback("documents", {
                        "count": doc_count,
                        "sources": sources,
                        "new_docs": doc_count - self.last_doc_count,
                        "retrieval_type": retrieval_type,
                        "retrieval_reason": retrieval_reason,
                        "is_incremental": self.last_doc_count > 0  # incremental retrieval?
                    })
                    self.last_doc_count = doc_count

            # Iteration round
            if "current_iteration" in state:
                current = state["current_iteration"]
                max_iter = state.get("max_iterations", 3)
                if current != self.last_iteration:
                    self.callback("iteration", {
                        "current": current,
                        "max": max_iter
                    })
                    self.last_iteration = current

            # Sufficiency check - should be reported after every retrieval round
            if "sufficiency_check" in state and state["sufficiency_check"]:
                check = state["sufficiency_check"]
                # Build a comparable string representation
                check_str = f"{check.get('is_sufficient')}_{check.get('confidence', 0):.2f}_{check.get('iteration', 0)}"
                # Fire when the check result is new, or when the document count
                # changed (i.e. a retrieval just finished)
                doc_count = len(state.get("all_passages", []))
                if (check and check_str != self.last_sufficiency_check) or (doc_count > 0 and doc_count != self.last_doc_count):
                    self.callback("sufficiency_check", {
                        "is_sufficient": check.get("is_sufficient", False),
                        "confidence": check.get("confidence", 0),
                        "reason": check.get("reason", "")
                    })
                    self.last_sufficiency_check = check_str

                    # If the context is sufficient, or the iteration cap is
                    # reached, the final answer is about to be generated
                    current_iteration = state.get("current_iteration", 0)
                    max_iterations = state.get("max_iterations", 2)
                    if check.get("is_sufficient", False) or current_iteration >= max_iterations - 1:
                        self.callback("generating", {"message": "Generating the final answer..."})

            return state
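
    # Taken together, a client driving status_callback sees events roughly in
    # this order (all emitted here or in stream_retrieve): "starting",
    # "sub_queries", "documents", "iteration", "sufficiency_check",
    # "generating", and finally the returned result (or "cancelled").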
    monitor = StateMonitor(status_callback)
    final_state = None

    # Prefer the stream method for real-time updates
    if hasattr(workflow, 'stream'):
        try:
            # Consume state updates step by step via stream()
            for chunk in workflow.stream(initial_state, config=config):
                # A chunk is normally a {node_name: state} dict
                if isinstance(chunk, dict):
                    for node_name, node_state in chunk.items():
                        # Check each node's state as it arrives
                        monitor.check_state(node_state, node_name)
                        final_state = node_state

                        # Per-node statuses sent when a node starts (at most once).
                        # Note: these must not duplicate the statuses sent after
                        # a node completes.
                        if "query_complexity_check" in node_name:
                            # complexity_check is sent after the node completes; skip here
                            pass
                        elif "query_decomposition" in node_name:
                            # Query-decomposition node - no interim status; the result follows quickly
                            pass
                        elif "initial_retrieval" in node_name:
                            # Initial-retrieval node - no interim status; the result follows quickly
                            pass
                        elif "sufficiency_check" in node_name:
                            # Sufficiency-check node started - skip to avoid duplicates
                            pass
                        elif "sub_query_generation" in node_name:
                            # Sub-query-generation node started - hidden; too fine-grained for users
                            pass
                            # status_callback("generating_subqueries", {"message": "Generating sub-queries..."})
                        elif "parallel_retrieval" in node_name:
                            # Parallel-retrieval node started - hidden; too fine-grained for users
                            pass
                            # status_callback("parallel_retrieving", {"message": "Running parallel retrieval..."})
                        elif "final" in node_name.lower() or "answer" in node_name.lower():
                            # Final-answer node - status already sent after the sufficiency check
                            pass
        except Exception:
            # Fall back to invoke() if streaming fails
            final_state = workflow.invoke(initial_state, config=config)
            monitor.check_state(final_state)
    else:
        # No stream support; use a plain invoke()
        final_state = workflow.invoke(initial_state, config=config)
        monitor.check_state(final_state)

    return final_state
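

# For reference: each chunk yielded by a LangGraph graph's stream() maps a node
# name to that node's state update, which is what the loop above relies on,
# e.g. (shape only, values illustrative):
#
#     {"initial_retrieval": {"all_passages": [...], "current_iteration": 0}}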
def extract_sources_from_documents(all_documents, passage_sources):
    """
    Extract the real document names from documents (prefer the evidence field).

    Args:
        all_documents: list of documents carrying an evidence field in their metadata
        passage_sources: list of passage sources (used as a fallback)

    Returns:
        A de-duplicated list of document names.
    """
    sources = set()

    # Prefer the evidence field from all_documents
    if all_documents:
        for doc in all_documents:
            if hasattr(doc, 'metadata'):
                # First choice: the evidence field (the real document name of a passage node)
                evidence = doc.metadata.get('evidence', '')
                if evidence:
                    sources.add(evidence)
                    continue
                # Second choice: source_text_id (the origin of an event node)
                source_text_id = doc.metadata.get('source_text_id', '')
                if source_text_id:
                    # Parse source_text_id when it is a stringified list
                    if source_text_id.startswith('[') and source_text_id.endswith(']'):
                        try:
                            import ast
                            text_ids = ast.literal_eval(source_text_id)
                            if isinstance(text_ids, list) and text_ids:
                                sources.add(text_ids[0])  # use the first source
                        except (ValueError, SyntaxError):
                            sources.add(source_text_id)
                    else:
                        sources.add(source_text_id)
                    continue

    # Fall back to the raw passage_sources when no evidence or source_text_id was found
    if not sources and passage_sources:
        for source in passage_sources:
            if isinstance(source, str):
                # Extract the file name
                filename = source.split("/")[-1] if "/" in source else source
                sources.add(filename)
            elif isinstance(source, dict):
                # For dicts, try the common file-name fields
                filename = source.get("filename") or source.get("source") or str(source)
                sources.add(filename)

    return list(sources)[:5]  # return at most 5
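

# Example (a minimal sketch; SimpleNamespace stands in for a real document
# object, which only needs a dict-like `metadata` attribute here):
#
#     from types import SimpleNamespace
#     doc = SimpleNamespace(metadata={"evidence": "report_2024.pdf"})
#     extract_sources_from_documents([doc], [])  # -> ["report_2024.pdf"]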
def extract_sources(passage_sources):
    """
    Extract document names from passage_sources (kept for backward compatibility).

    Args:
        passage_sources: list of passage sources

    Returns:
        A de-duplicated list of document names.
    """
    return extract_sources_from_documents([], passage_sources)
def build_result(retriever, final_state: Dict[str, Any], total_time: float) -> Dict[str, Any]:
    """
    Build the final result (reusing the original logic).
    """
    # Token-usage info, when available
    token_info = {}
    if hasattr(retriever, '_get_token_usage_info'):
        token_info = retriever._get_token_usage_info()

    result = {
        "query": final_state.get('original_query', ''),
        "answer": final_state.get('final_answer', '') or "No answer could be generated",
        # Query-complexity info
        "query_complexity": final_state.get('query_complexity', {}),
        "is_complex_query": final_state.get('is_complex_query', False),
        # Retrieval statistics
        "iterations": final_state.get('current_iteration', 0),
        "total_passages": len(final_state.get('all_passages', [])),
        "sub_queries": final_state.get('sub_queries', []),
        "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),
        # Sufficiency info
        "sufficiency_check": final_state.get('sufficiency_check', {}),
        "is_sufficient": final_state.get('is_sufficient', False),
        # Event-triple info (new)
        "event_triples": final_state.get('event_triples', []),
        "event_triples_count": len(final_state.get('event_triples', [])),
        # Supporting info
        "supporting_facts": extract_supporting_facts(final_state),
        "supporting_events": extract_supporting_events(final_state),
        # Debug info
        "debug_info": {
            "total_time": total_time,
            "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
            "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
            "token_usage": token_info
        }
    }
    return result
def extract_supporting_facts(state):
    """Extract supporting facts."""
    passages = state.get('all_passages', [])
    facts = []
    for i, passage in enumerate(passages[:10]):  # at most 10
        if isinstance(passage, str):
            facts.append([f"Fact_{i+1}", passage[:200]])
        elif isinstance(passage, dict):
            content = passage.get('content', '') or passage.get('text', '')
            facts.append([f"Fact_{i+1}", content[:200]])
    return facts
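

# Example (sketch): extract_supporting_facts({"all_passages": ["foo bar"]})
# returns [["Fact_1", "foo bar"]].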
def extract_supporting_events(state):
    """Extract supporting events."""
    # Pull event info out of the state
    events = []
    # Key events can be recovered from iteration_history
    iteration_history = state.get('iteration_history', [])
    for item in iteration_history:
        if item.get('action') == 'sub_query_generation':
            events.append([
                "sub-query generation",
                f"round {item.get('iteration', 0)}: generated {len(item.get('sub_queries', []))} sub-queries"
            ])
        elif item.get('action') == 'parallel_retrieval':
            events.append([
                "parallel retrieval",
                f"round {item.get('iteration', 0)}: retrieved {item.get('documents_retrieved', 0)} documents"
            ])
    return events[:5]  # at most 5 events