""" 流式检索包装器 - 为现有检索器添加状态回调支持 保持向后兼容,不影响原有业务逻辑 """ import time from typing import Dict, Any, Callable, Optional import traceback def stream_retrieve( retriever, query: str, mode: str = "0", status_callback: Optional[Callable] = None ) -> Dict[str, Any]: """ 带状态回调的检索包装器 在不修改原有检索器的情况下,通过包装添加状态回调功能 Args: retriever: 原始检索器实例 query: 查询字符串 mode: 调试模式 status_callback: 状态回调函数,接收(status_type, data)参数 Returns: 检索结果字典 """ # 如果没有提供回调,直接调用原始方法 if not status_callback: return retriever.retrieve(query, mode) try: # 开始检索 status_callback("starting", {"message": "开始分析查询"}) # 创建增强的初始状态 from retriver.langgraph.graph_state import create_initial_state initial_state = create_initial_state( original_query=query, max_iterations=retriever.max_iterations, debug_mode=mode ) # 配置工作流执行 config = { "recursion_limit": 50, "metadata": { "stream_mode": True, "callback_enabled": True }, "run_name": f"Stream-Retrieval-{query[:20]}..." } # 执行工作流,并在关键点插入回调 result_state = execute_workflow_with_callbacks( retriever.workflow, initial_state, config, status_callback ) # 构建最终结果 total_time = result_state.get("debug_info", {}).get("total_time", 0) # 构建结果(使用原有逻辑) result = build_result(retriever, result_state, total_time) # 保存完整结果到JSON文件(如果retriever有这个方法) if hasattr(retriever, '_save_full_result_to_file'): # 构建完整结果以便保存 full_result = { **result, "all_documents": result_state.get('all_documents', []), "all_passages": result_state.get('all_passages', []), "passage_sources": result_state.get('passage_sources', []), "retrieval_path": result_state.get('retrieval_path', 'unknown'), "initial_retrieval_details": result_state.get('initial_retrieval_details', {}), "current_sub_queries": result_state.get('current_sub_queries', []) } retriever._save_full_result_to_file(full_result) print("[INFO] 流式检索结果已保存到JSON文件") return result except Exception as e: print(f"[ERROR] 流式检索出错: {e}") traceback.print_exc() # 出错时返回基础结果 return { "query": query, "answer": f"检索过程中遇到错误: {str(e)}", "error": str(e), "total_passages": 0, "iterations": 0, "sub_queries": [], "supporting_facts": [], "supporting_events": [] } def execute_workflow_with_callbacks( workflow, initial_state: Dict[str, Any], config: Dict[str, Any], status_callback: Callable ) -> Dict[str, Any]: """ 执行工作流并在关键节点插入状态回调 使用LangGraph的stream功能实时获取状态更新 """ # 将回调函数注入到state中,供答案生成节点使用 initial_state['stream_callback'] = status_callback print(f"[DEBUG] 已将stream_callback注入到initial_state") # 创建一个状态监控器 class StateMonitor: def __init__(self, callback): self.callback = callback self.last_iteration = 0 self.last_doc_count = 0 self.last_sub_queries = [] self.seen_nodes = set() self.last_sufficiency_check = None # 记录上次的充分性检查结果 def check_state(self, state, node_name=None): """检查状态变化并触发回调""" # 记录节点访问 if node_name and node_name not in self.seen_nodes: self.seen_nodes.add(node_name) # 检查子查询 - 节点完成后发送完整数据 if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]: sub_queries = state["decomposed_sub_queries"] if sub_queries != self.last_sub_queries: self.callback("sub_queries", sub_queries) self.last_sub_queries = sub_queries # 检查文档数量 - 节点完成后发送完整数据 if "all_passages" in state: doc_count = len(state.get("all_passages", [])) if doc_count != self.last_doc_count: # 提取来源文档名,优先使用evidence字段 sources = extract_sources_from_documents(state.get("all_documents", []), state.get("passage_sources", [])) # 判断检索类型和原因 current_iteration = state.get("current_iteration", 0) sub_queries = state.get("current_sub_queries", []) # 确定检索原因 if self.last_doc_count == 0: # 初始检索 retrieval_type = "初始检索" retrieval_reason = "原始查询" elif sub_queries and 
def execute_workflow_with_callbacks(
    workflow,
    initial_state: Dict[str, Any],
    config: Dict[str, Any],
    status_callback: Callable
) -> Dict[str, Any]:
    """
    Execute the workflow and emit status callbacks at key nodes.

    Uses LangGraph's stream() to pick up state updates in real time.
    """
    # Inject the callback into the state so the answer-generation node can use it.
    initial_state['stream_callback'] = status_callback
    print("[DEBUG] 已将stream_callback注入到initial_state")

    # A monitor that tracks state changes between node updates.
    class StateMonitor:
        def __init__(self, callback):
            self.callback = callback
            self.last_iteration = 0
            self.last_doc_count = 0
            self.last_sub_queries = []
            self.seen_nodes = set()
            self.last_sufficiency_check = None  # last sufficiency-check result seen

        def check_state(self, state, node_name=None):
            """Detect state changes and trigger the corresponding callbacks."""
            # Record which nodes have been visited.
            if node_name and node_name not in self.seen_nodes:
                self.seen_nodes.add(node_name)

            # Sub-queries: send the complete data once the node has finished.
            if "decomposed_sub_queries" in state and state["decomposed_sub_queries"]:
                sub_queries = state["decomposed_sub_queries"]
                if sub_queries != self.last_sub_queries:
                    self.callback("sub_queries", sub_queries)
                    self.last_sub_queries = sub_queries

            # Document count: send the complete data once the node has finished.
            if "all_passages" in state:
                doc_count = len(state.get("all_passages", []))
                if doc_count != self.last_doc_count:
                    # Extract source document names, preferring the evidence field.
                    sources = extract_sources_from_documents(
                        state.get("all_documents", []),
                        state.get("passage_sources", [])
                    )

                    # Work out the retrieval type and reason.
                    current_iteration = state.get("current_iteration", 0)
                    sub_queries = state.get("current_sub_queries", [])

                    if self.last_doc_count == 0:
                        # Initial retrieval.
                        retrieval_type = "初始检索"
                        retrieval_reason = "原始查询"
                    elif sub_queries:
                        # Retrieval driven by sub-queries.
                        retrieval_type = "子查询检索"
                        retrieval_reason = f"基于{len(sub_queries)}个新子查询"
                    else:
                        # Iterative retrieval.
                        retrieval_type = f"第{current_iteration}轮迭代"
                        retrieval_reason = "迭代深化"

                    self.callback("documents", {
                        "count": doc_count,
                        "sources": sources,
                        "new_docs": doc_count - self.last_doc_count,
                        "retrieval_type": retrieval_type,
                        "retrieval_reason": retrieval_reason,
                        "is_incremental": self.last_doc_count > 0  # incremental retrieval?
                    })
                    self.last_doc_count = doc_count

            # Iteration round.
            if "current_iteration" in state:
                current = state["current_iteration"]
                max_iter = state.get("max_iterations", 3)
                if current != self.last_iteration:
                    self.callback("iteration", {
                        "current": current,
                        "max": max_iter
                    })
                    self.last_iteration = current

            # Sufficiency check: should be reported after every retrieval round.
            if "sufficiency_check" in state and state["sufficiency_check"]:
                check = state["sufficiency_check"]
                # Build a comparable string representation.
                check_str = f"{check.get('is_sufficient')}_{check.get('confidence', 0):.2f}_{check.get('iteration', 0)}"

                # New sufficiency result, or the document count changed (a retrieval just finished).
                doc_count = len(state.get("all_passages", []))
                if (check and check_str != self.last_sufficiency_check) or (doc_count > 0 and doc_count != self.last_doc_count):
                    self.callback("sufficiency_check", {
                        "is_sufficient": check.get("is_sufficient", False),
                        "confidence": check.get("confidence", 0),
                        "reason": check.get("reason", "")
                    })
                    self.last_sufficiency_check = check_str

                    # If sufficient, or the iteration limit is reached, an answer is about to be generated.
                    current_iteration = state.get("current_iteration", 0)
                    max_iterations = state.get("max_iterations", 2)
                    if check.get("is_sufficient", False) or current_iteration >= max_iterations - 1:
                        self.callback("generating", {"message": "正在生成最终答案..."})

            return state

    monitor = StateMonitor(status_callback)
    final_state = None

    # Prefer stream() so updates arrive in real time.
    if hasattr(workflow, 'stream'):
        try:
            # Consume the workflow step by step.
            for chunk in workflow.stream(initial_state, config=config):
                # Each chunk is normally a {node_name: state} dict.
                if isinstance(chunk, dict):
                    for node_name, node_state in chunk.items():
                        # Check every node's state as it arrives.
                        monitor.check_state(node_state, node_name)
                        final_state = node_state

                        # Status sent when a node starts (only once per node).
                        # Note: do not duplicate the status sent after the node completes.
                        if "query_complexity_check" in node_name:
                            # complexity_check is sent after the node completes, not here.
                            pass
                        elif "query_decomposition" in node_name:
                            # Query decomposition - skip intermediate status; results appear quickly.
                            pass
                        elif "initial_retrieval" in node_name:
                            # Initial retrieval - skip intermediate status; results appear quickly.
                            pass
                        elif "sufficiency_check" in node_name:
                            # Sufficiency check starting - skip to avoid duplicates.
                            pass
                        elif "sub_query_generation" in node_name:
                            # Sub-query generation starting - hidden; users don't need this detail.
                            pass
                            # status_callback("generating_subqueries", {"message": "正在生成子查询..."})
                        elif "parallel_retrieval" in node_name:
                            # Parallel retrieval starting - hidden; users don't need this detail.
                            pass
                            # status_callback("parallel_retrieving", {"message": "正在并行检索..."})
                        elif "final" in node_name.lower() or "answer" in node_name.lower():
                            # Final-answer node - status already sent after the sufficiency check.
                            pass

        except Exception:
            # If stream() fails, fall back to invoke().
            final_state = workflow.invoke(initial_state, config=config)
            monitor.check_state(final_state)
    else:
        # No stream() support; use a plain invoke().
        final_state = workflow.invoke(initial_state, config=config)
        monitor.check_state(final_state)

    return final_state
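
# Illustrative sketch: a minimal stand-in that mimics the {node_name: state}
# chunks LangGraph's stream() yields, so execute_workflow_with_callbacks can be
# exercised without a real retriever. The node name and state keys below are
# assumptions chosen to match what StateMonitor reads; the class is hypothetical.
class _FakeWorkflowForDemo:
    def stream(self, initial_state, config=None):
        # One node that produced passages plus a sufficiency verdict.
        yield {
            "initial_retrieval": {
                "all_passages": ["passage one", "passage two"],
                "all_documents": [],
                "passage_sources": ["corpus/doc_001.txt"],
                "current_iteration": 1,
                "max_iterations": 2,
                "sufficiency_check": {"is_sufficient": True, "confidence": 0.9, "reason": "demo"},
            }
        }

# Assumed usage:
#   execute_workflow_with_callbacks(
#       _FakeWorkflowForDemo(), {}, {"recursion_limit": 50},
#       lambda status_type, data: print(status_type, data))
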
def extract_sources_from_documents(all_documents, passage_sources):
    """
    Extract the real document names from the documents (preferring the evidence field).

    Args:
        all_documents: list of documents (with an evidence field in their metadata)
        passage_sources: list of passage sources (used as a fallback)

    Returns:
        A de-duplicated list of document names.
    """
    sources = set()

    # Prefer the evidence field from all_documents.
    if all_documents:
        for doc in all_documents:
            if hasattr(doc, 'metadata'):
                # Prefer the evidence field (the real document name of a passage node).
                evidence = doc.metadata.get('evidence', '')
                if evidence:
                    sources.add(evidence)
                    continue

                # Otherwise use source_text_id (the origin of an event node).
                source_text_id = doc.metadata.get('source_text_id', '')
                if source_text_id:
                    # If source_text_id is a list serialized as a string, parse it.
                    if source_text_id.startswith('[') and source_text_id.endswith(']'):
                        try:
                            text_ids = ast.literal_eval(source_text_id)
                            if isinstance(text_ids, list) and text_ids:
                                sources.add(text_ids[0])  # use the first source
                        except Exception:
                            sources.add(source_text_id)
                    else:
                        sources.add(source_text_id)
                    continue

    # If neither evidence nor source_text_id was found, fall back to passage_sources.
    if not sources and passage_sources:
        for source in passage_sources:
            if isinstance(source, str):
                # Extract the file name.
                if "/" in source:
                    filename = source.split("/")[-1]
                else:
                    filename = source
                sources.add(filename)
            elif isinstance(source, dict):
                # For dicts, try the usual file-name fields.
                filename = source.get("filename") or source.get("source") or str(source)
                sources.add(filename)

    return list(sources)[:5]  # return at most 5


def extract_sources(passage_sources):
    """
    Extract document names from passage_sources (kept for backward compatibility).

    Args:
        passage_sources: list of passage sources

    Returns:
        A de-duplicated list of document names.
    """
    return extract_sources_from_documents([], passage_sources)


def build_result(retriever, final_state: Dict[str, Any], total_time: float) -> Dict[str, Any]:
    """
    Build the final result (reuses the original logic).
    """
    # Collect token-usage info if the retriever exposes it.
    token_info = {}
    if hasattr(retriever, '_get_token_usage_info'):
        token_info = retriever._get_token_usage_info()

    result = {
        "query": final_state.get('original_query', ''),
        "answer": final_state.get('final_answer', '') or "未能生成答案",

        # Query complexity info.
        "query_complexity": final_state.get('query_complexity', {}),
        "is_complex_query": final_state.get('is_complex_query', False),

        # Retrieval statistics.
        "iterations": final_state.get('current_iteration', 0),
        "total_passages": len(final_state.get('all_passages', [])),
        "sub_queries": final_state.get('sub_queries', []),
        "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),

        # Sufficiency info.
        "sufficiency_check": final_state.get('sufficiency_check', {}),
        "is_sufficient": final_state.get('is_sufficient', False),

        # Supporting info.
        "supporting_facts": extract_supporting_facts(final_state),
        "supporting_events": extract_supporting_events(final_state),

        # Debug info.
        "debug_info": {
            "total_time": total_time,
            "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
            "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
            "token_usage": token_info
        }
    }

    return result


def extract_supporting_facts(state):
    """Extract supporting facts."""
    passages = state.get('all_passages', [])
    facts = []

    for i, passage in enumerate(passages[:10]):  # at most 10
        if isinstance(passage, str):
            facts.append([f"Fact_{i+1}", passage[:200]])
        elif isinstance(passage, dict):
            content = passage.get('content', '') or passage.get('text', '')
            facts.append([f"Fact_{i+1}", content[:200]])

    return facts


def extract_supporting_events(state):
    """Extract supporting events."""
    # Pull key events out of the state.
    events = []

    # Key events can be derived from iteration_history.
    iteration_history = state.get('iteration_history', [])
    for item in iteration_history:
        if item.get('action') == 'sub_query_generation':
            events.append([
                "子查询生成",
                f"第{item.get('iteration', 0)}轮: 生成{len(item.get('sub_queries', []))}个子查询"
            ])
        elif item.get('action') == 'parallel_retrieval':
            events.append([
                "并行检索",
                f"第{item.get('iteration', 0)}轮: 检索{item.get('documents_retrieved', 0)}篇文档"
            ])

    return events[:5]  # at most 5 events
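
# Illustrative sketch: a tiny smoke check of the helpers above using plain
# SimpleNamespace documents. The metadata values ("doc_001.txt", the passage
# text) are made-up sample data, not outputs of the real pipeline.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_docs = [SimpleNamespace(metadata={"evidence": "doc_001.txt"})]
    print(extract_sources_from_documents(demo_docs, ["corpus/fallback.txt"]))
    # -> ['doc_001.txt']

    demo_state = {"all_passages": ["Some passage text.", {"content": "Another passage."}]}
    print(extract_supporting_facts(demo_state))
    # -> [['Fact_1', 'Some passage text.'], ['Fact_2', 'Another passage.']]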