first commit
AIEC-RAG/retriver/langsmith/langsmith_retriever_stream.py (new file, 440 lines)
@@ -0,0 +1,440 @@
"""
|
||||
流式检索包装器 - 为现有检索器添加状态回调支持
|
||||
保持向后兼容,不影响原有业务逻辑
|
||||
"""
|
||||
|
||||
import time
from typing import Dict, Any, Callable, Optional
import traceback
from exceptions import TaskCancelledException, WorkflowStoppedException


def stream_retrieve(
    retriever,
    query: str,
    mode: str = "0",
    status_callback: Optional[Callable] = None,
    task_id: Optional[str] = None
) -> Dict[str, Any]:
"""
|
||||
带状态回调的检索包装器
|
||||
|
||||
在不修改原有检索器的情况下,通过包装添加状态回调功能
|
||||
|
||||
Args:
|
||||
retriever: 原始检索器实例
|
||||
query: 查询字符串
|
||||
mode: 调试模式
|
||||
status_callback: 状态回调函数,接收(status_type, data)参数
|
||||
task_id: 任务ID(可选)
|
||||
|
||||
Returns:
|
||||
检索结果字典
|
||||
"""
|
||||
|
||||
    # If no callback was provided, call the original method directly
    if not status_callback:
        return retriever.retrieve(query, mode)

    try:
        # Retrieval starts
        status_callback("starting", {"message": "Starting query analysis"})

        # Create the enriched initial state
        from retriver.langgraph.graph_state import create_initial_state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=retriever.max_iterations,
            debug_mode=mode,
            task_id=task_id  # pass the task ID through
        )

        # Configure the workflow run
        config = {
            "recursion_limit": 50,
            "metadata": {
                "stream_mode": True,
                "callback_enabled": True
            },
            "run_name": f"Stream-Retrieval-{query[:20]}..."
        }

        # Run the workflow, inserting callbacks at key points
        result_state = execute_workflow_with_callbacks(
            retriever.workflow,
            initial_state,
            config,
            status_callback
        )

        # Assemble the final result
        total_time = result_state.get("debug_info", {}).get("total_time", 0)

        # Build the result (reusing the original logic)
        result = build_result(retriever, result_state, total_time)

        # Save the full result to a JSON file (if the retriever supports it)
        if hasattr(retriever, '_save_full_result_to_file'):
            # Assemble the full result for saving
            full_result = {
                **result,
                "all_documents": result_state.get('all_documents', []),
                "all_passages": result_state.get('all_passages', []),
                "passage_sources": result_state.get('passage_sources', []),
                "retrieval_path": result_state.get('retrieval_path', 'unknown'),
                "initial_retrieval_details": result_state.get('initial_retrieval_details', {}),
                "current_sub_queries": result_state.get('current_sub_queries', []),
                "event_triples": result_state.get('event_triples', [])  # new: event-triple info
            }
            retriever._save_full_result_to_file(full_result)
            print("[INFO] Streaming retrieval result saved to JSON file")

        return result

    except (TaskCancelledException, WorkflowStoppedException) as e:
        print(f"[CANCELLED] Streaming retrieval cancelled: {e}")
        if status_callback:
            status_callback("cancelled", {"message": "Task was cancelled"})
        # Return a cancellation result
        return {
            "query": query,
            "answer": "Task was cancelled",
            "cancelled": True,
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "debug_info": {"cancelled": True, "message": str(e)}
        }

    except Exception as e:
        print(f"[ERROR] Streaming retrieval failed: {e}")
        traceback.print_exc()
        # On error, return a minimal fallback result
        return {
            "query": query,
            "answer": f"An error occurred during retrieval: {str(e)}",
            "error": str(e),
            "total_passages": 0,
            "iterations": 0,
            "sub_queries": [],
            "supporting_facts": [],
            "supporting_events": []
        }


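# Illustrative sketch (not part of the original API): a minimal status callback
# matching the (status_type, data) contract that stream_retrieve expects. The
# monitor below emits, e.g., ("sub_queries", [...]), ("documents", {"count": ...,
# "sources": [...]}), ("iteration", {"current": ..., "max": ...}) and
# ("generating", {...}); usage: stream_retrieve(retriever, query,
# status_callback=example_status_callback).
def example_status_callback(status_type: str, data: Any) -> None:
    # A real consumer would forward these as SSE/WebSocket events; this sketch
    # just prints them.
    print(f"[{status_type}] {data}")

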
def execute_workflow_with_callbacks(
    workflow,
    initial_state: Dict[str, Any],
    config: Dict[str, Any],
    status_callback: Callable
) -> Dict[str, Any]:
    """
    Run the workflow, firing the status callback at key nodes.

    Uses LangGraph's stream support to surface state updates in real time.
    """

    # Inject the callback into the state so the answer-generation node can use it
    initial_state['stream_callback'] = status_callback
    print("[DEBUG] Injected stream_callback into initial_state")

    # Create a state monitor
    class StateMonitor:
        def __init__(self, callback):
            self.callback = callback
            self.last_iteration = 0
            self.last_doc_count = 0
            self.last_sub_queries = []
            self.seen_nodes = set()
            self.last_sufficiency_check = None  # last sufficiency-check result seen

        def check_state(self, state, node_name=None):
            """Check for state changes and fire the matching callbacks."""

            # Record node visits
            if node_name and node_name not in self.seen_nodes:
                self.seen_nodes.add(node_name)

            # Check sub-queries - send the complete data once the node finishes
            if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]:
                sub_queries = state["decomposed_sub_queries"]
                if sub_queries != self.last_sub_queries:
                    self.callback("sub_queries", sub_queries)
                    self.last_sub_queries = sub_queries

            # Check the document count - send the complete data once the node finishes
            if "all_passages" in state:
                doc_count = len(state.get("all_passages", []))
                if doc_count != self.last_doc_count:
                    # Extract source document names, preferring the evidence field
                    sources = extract_sources_from_documents(state.get("all_documents", []), state.get("passage_sources", []))

                    # Determine the retrieval type and reason
                    current_iteration = state.get("current_iteration", 0)
                    sub_queries = state.get("current_sub_queries", [])

                    if self.last_doc_count == 0:
                        # Initial retrieval
                        retrieval_type = "initial retrieval"
                        retrieval_reason = "original query"
                    elif sub_queries:
                        # Retrieval driven by sub-queries
                        retrieval_type = "sub-query retrieval"
                        retrieval_reason = f"based on {len(sub_queries)} new sub-queries"
                    else:
                        # Iterative retrieval
                        retrieval_type = f"iteration round {current_iteration}"
                        retrieval_reason = "iterative deepening"

                    self.callback("documents", {
                        "count": doc_count,
                        "sources": sources,
                        "new_docs": doc_count - self.last_doc_count,
                        "retrieval_type": retrieval_type,
                        "retrieval_reason": retrieval_reason,
                        "is_incremental": self.last_doc_count > 0  # incremental retrieval?
                    })
                    self.last_doc_count = doc_count

            # Check the iteration round
            if "current_iteration" in state:
                current = state["current_iteration"]
                max_iter = state.get("max_iterations", 3)
                if current != self.last_iteration:
                    self.callback("iteration", {
                        "current": current,
                        "max": max_iter
                    })
                    self.last_iteration = current

            # Check the sufficiency verdict - surfaced after every retrieval round
            if "sufficiency_check" in state and state["sufficiency_check"]:
                check = state["sufficiency_check"]
                # Build a comparable string representation, e.g. "True_0.85_2"
                check_str = f"{check.get('is_sufficient')}_{check.get('confidence', 0):.2f}_{check.get('iteration', 0)}"

                # Fire if this is a new sufficiency result, or the document count
                # changed (meaning a retrieval round just finished)
                doc_count = len(state.get("all_passages", []))
                if (check and check_str != self.last_sufficiency_check) or (doc_count > 0 and doc_count != self.last_doc_count):
                    self.callback("sufficiency_check", {
                        "is_sufficient": check.get("is_sufficient", False),
                        "confidence": check.get("confidence", 0),
                        "reason": check.get("reason", "")
                    })
                    self.last_sufficiency_check = check_str

                    # If sufficient, or the iteration cap is reached, the final answer is next
                    current_iteration = state.get("current_iteration", 0)
                    max_iterations = state.get("max_iterations", 2)
                    if check.get("is_sufficient", False) or current_iteration >= max_iterations - 1:
                        self.callback("generating", {"message": "Generating the final answer..."})

            return state

    monitor = StateMonitor(status_callback)
    final_state = None

    # Prefer the stream method to get real-time updates
    if hasattr(workflow, 'stream'):
        try:
            # Step through the workflow states via stream
            for chunk in workflow.stream(initial_state, config=config):
                # chunk is usually a {node_name: state} dict
                if isinstance(chunk, dict):
                    for node_name, node_state in chunk.items():
                        # Check each node's state in real time
                        monitor.check_state(node_state, node_name)
                        final_state = node_state

                        # Send statuses keyed off the node name when it starts (once only).
                        # Note: these fire only when a node starts and must not duplicate
                        # the statuses sent after a node completes.
                        if "query_complexity_check" in node_name:
                            # complexity_check is sent after the node completes; nothing here
                            pass
                        elif "query_decomposition" in node_name:
                            # Query-decomposition node - no interim status; its result follows quickly
                            pass
                        elif "initial_retrieval" in node_name:
                            # Initial-retrieval node - no interim status; its result follows quickly
                            pass
                        elif "sufficiency_check" in node_name:
                            # Sufficiency-check node starting - skip to avoid duplicates
                            pass
                        elif "sub_query_generation" in node_name:
                            # Sub-query generation starting - hidden; users don't need this detail
                            pass
                            # status_callback("generating_subqueries", {"message": "Generating sub-queries..."})
                        elif "parallel_retrieval" in node_name:
                            # Parallel-retrieval node starting - hidden; users don't need this detail
                            pass
                            # status_callback("parallel_retrieving", {"message": "Running parallel retrieval..."})
                        elif "final" in node_name.lower() or "answer" in node_name.lower():
                            # Final-answer node - status already sent after the sufficiency check
                            pass

        except Exception:
            # If streaming fails, fall back to a plain invoke
            final_state = workflow.invoke(initial_state, config=config)
            monitor.check_state(final_state)
    else:
        # No stream support; use a plain invoke
        final_state = workflow.invoke(initial_state, config=config)
        monitor.check_state(final_state)

    return final_state


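# Illustrative sketch (hypothetical objects, not part of the original API): traces
# the callback sequence above without a real graph. _FakeWorkflow mimics the
# LangGraph contract this module relies on - a `stream` method yielding
# {node_name: state} chunks. Expected callback order here: "documents" (initial
# retrieval), then "sufficiency_check", then "generating".
def _example_stream_usage():  # pragma: no cover
    class _FakeWorkflow:
        def stream(self, state, config=None):
            yield {"initial_retrieval": {**state, "all_passages": ["p1", "p2"]}}
            yield {"sufficiency_check": {**state, "all_passages": ["p1", "p2"],
                                         "sufficiency_check": {"is_sufficient": True,
                                                               "confidence": 0.9,
                                                               "iteration": 0}}}

    execute_workflow_with_callbacks(_FakeWorkflow(), {}, {}, example_status_callback)

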
def extract_sources_from_documents(all_documents, passage_sources):
    """
    Extract real document names from documents (preferring the evidence field).

    Args:
        all_documents: document list (with an evidence field in metadata)
        passage_sources: passage-source list (used as a fallback)

    Returns:
        Deduplicated list of document names.
    """
    sources = set()

    # Prefer the evidence field from all_documents
    if all_documents:
        for doc in all_documents:
            if hasattr(doc, 'metadata'):
                # Prefer the evidence field (the real document name of a passage node)
                evidence = doc.metadata.get('evidence', '')
                if evidence:
                    sources.add(evidence)
                    continue

                # Fall back to source_text_id (the origin of an event node)
                source_text_id = doc.metadata.get('source_text_id', '')
                if source_text_id:
                    # If source_text_id is a list-formatted string, parse it
                    if source_text_id.startswith('[') and source_text_id.endswith(']'):
                        try:
                            import ast
                            text_ids = ast.literal_eval(source_text_id)
                            if isinstance(text_ids, list) and text_ids:
                                sources.add(text_ids[0])  # use the first source
                        except (ValueError, SyntaxError):
                            sources.add(source_text_id)
                    else:
                        sources.add(source_text_id)
                    continue

    # If neither evidence nor source_text_id was found, fall back to passage_sources
    if not sources and passage_sources:
        for source in passage_sources:
            if isinstance(source, str):
                # Extract the file name
                if "/" in source:
                    filename = source.split("/")[-1]
                else:
                    filename = source
                sources.add(filename)
            elif isinstance(source, dict):
                # For dicts, try the likely file-name fields
                filename = source.get("filename") or source.get("source") or str(source)
                sources.add(filename)

    return list(sources)[:5]  # return at most 5


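# Illustrative check (hypothetical data, not part of the original API): exercises
# the three source paths handled above - evidence, a list-formatted
# source_text_id, and the passage_sources fallback.
def _example_extract_sources():  # pragma: no cover
    class _Doc:
        def __init__(self, metadata):
            self.metadata = metadata

    assert extract_sources_from_documents([_Doc({'evidence': 'report.pdf'})], []) == ['report.pdf']
    assert extract_sources_from_documents([_Doc({'source_text_id': "['doc_A', 'doc_B']"})], []) == ['doc_A']
    assert extract_sources_from_documents([], ['corpus/en/file.txt']) == ['file.txt']

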
def extract_sources(passage_sources):
    """
    Extract a list of document names from passage_sources (kept for backward compatibility).

    Args:
        passage_sources: passage-source list

    Returns:
        Deduplicated list of document names.
    """
    return extract_sources_from_documents([], passage_sources)


def build_result(retriever, final_state: Dict[str, Any], total_time: float) -> Dict[str, Any]:
    """
    Build the final result (reusing the original logic).
    """
    # Get token-usage info, if available
    token_info = {}
    if hasattr(retriever, '_get_token_usage_info'):
        token_info = retriever._get_token_usage_info()

    result = {
        "query": final_state.get('original_query', ''),
        "answer": final_state.get('final_answer', '') or "No answer could be generated",

        # Query-complexity info
        "query_complexity": final_state.get('query_complexity', {}),
        "is_complex_query": final_state.get('is_complex_query', False),

        # Retrieval statistics
        "iterations": final_state.get('current_iteration', 0),
        "total_passages": len(final_state.get('all_passages', [])),
        "sub_queries": final_state.get('sub_queries', []),
        "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),

        # Sufficiency info
        "sufficiency_check": final_state.get('sufficiency_check', {}),
        "is_sufficient": final_state.get('is_sufficient', False),

        # Event-triple info (new)
        "event_triples": final_state.get('event_triples', []),
        "event_triples_count": len(final_state.get('event_triples', [])),

        # Supporting information
        "supporting_facts": extract_supporting_facts(final_state),
        "supporting_events": extract_supporting_events(final_state),

        # Debug info
        "debug_info": {
            "total_time": total_time,
            "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
            "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
            "token_usage": token_info
        }
    }

    return result


def extract_supporting_facts(state):
    """Extract supporting facts."""
    passages = state.get('all_passages', [])
    facts = []

    for i, passage in enumerate(passages[:10]):  # at most 10
        if isinstance(passage, str):
            facts.append([f"Fact_{i+1}", passage[:200]])
        elif isinstance(passage, dict):
            content = passage.get('content', '') or passage.get('text', '')
            facts.append([f"Fact_{i+1}", content[:200]])

    return facts


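# E.g. (illustrative data): extract_supporting_facts({'all_passages': ['Paris is
# the capital of France.']}) returns [['Fact_1', 'Paris is the capital of France.']].

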
def extract_supporting_events(state):
    """Extract supporting events."""
    # Pull event information from the state
    events = []

    # Key events can be extracted from iteration_history
    iteration_history = state.get('iteration_history', [])
    for item in iteration_history:
        if item.get('action') == 'sub_query_generation':
            events.append([
                "sub-query generation",
                f"round {item.get('iteration', 0)}: generated {len(item.get('sub_queries', []))} sub-queries"
            ])
        elif item.get('action') == 'parallel_retrieval':
            events.append([
                "parallel retrieval",
                f"round {item.get('iteration', 0)}: retrieved {item.get('documents_retrieved', 0)} documents"
            ])

    return events[:5]  # at most 5 events
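

# E.g. (illustrative data): an iteration_history entry
# {'action': 'parallel_retrieval', 'iteration': 1, 'documents_retrieved': 8}
# yields ['parallel retrieval', 'round 1: retrieved 8 documents'].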