"""
LangGraph workflow nodes.

Implements the concrete node logic for the retrieval workflow:
complexity check -> (simple vector path | complex decomposition/retrieval/
sufficiency-check loop) -> answer generation.
"""
import json
import asyncio  # NOTE(review): not used in this module's visible code — confirm before removing
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

# LangSmith automatically traces LangChain LLM calls.
from retriver.langgraph.graph_state import (
    QueryState, RetrievalResult, SufficiencyCheck,
    update_state_with_retrieval,
    update_state_with_sufficiency_check,
    update_state_with_event_triples,
    increment_iteration,
    finalize_state
)
from task_manager import get_task_manager
from exceptions import TaskCancelledException, WorkflowStoppedException
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever
from retriver.langgraph.langchain_components import (
    OneAPILLM, SufficiencyCheckParser, QueryComplexityParser,
    QUERY_COMPLEXITY_CHECK_PROMPT,
    SUFFICIENCY_CHECK_PROMPT,
    QUERY_DECOMPOSITION_PROMPT,
    SUB_QUERY_GENERATION_PROMPT,
    SIMPLE_ANSWER_PROMPT,
    FINAL_ANSWER_PROMPT,
    format_passages, format_mixed_passages, format_sub_queries, format_event_triples
)
from retriver.langgraph.es_vector_retriever import ESVectorRetriever
from prompt_loader import get_prompt_loader


class GraphNodes:
    """Workflow node implementations.

    Each public ``*_node`` method takes the shared ``QueryState`` dict,
    performs one workflow step (LLM call, retrieval, bookkeeping), and
    returns the (possibly replaced) state.
    """

    def __init__(
        self,
        retriever: LangChainHippoRAGRetriever,
        llm: OneAPILLM,
        keyword: str,
        max_parallel_retrievals: int = 2,
        simple_retrieval_top_k: int = 3,
        complexity_llm: Optional[OneAPILLM] = None,
        sufficiency_llm: Optional[OneAPILLM] = None,
        skip_llm_generation: bool = False
    ):
        """
        Initialize the node handler.

        Args:
            retriever: HippoRAG retriever used by the complex path.
            llm: OneAPI LLM used for answer generation (and as fallback).
            keyword: ES index keyword.
            max_parallel_retrievals: max number of concurrent retrievals.
            simple_retrieval_top_k: number of documents for the simple path.
            complexity_llm: dedicated LLM for complexity checks
                (falls back to ``llm`` when not provided).
            sufficiency_llm: dedicated LLM for sufficiency checks
                (falls back to ``llm`` when not provided).
            skip_llm_generation: when True, skip LLM answer generation and
                return only the retrieval results.
        """
        self.retriever = retriever
        self.llm = llm
        self.complexity_llm = complexity_llm or llm  # fall back to the main LLM
        self.sufficiency_llm = sufficiency_llm or llm  # fall back to the main LLM
        self.keyword = keyword
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation
        self.sufficiency_parser = SufficiencyCheckParser()
        self.complexity_parser = QueryComplexityParser()
        # ES vector retriever used by the simple-query path.
        self.es_vector_retriever = ESVectorRetriever(
            keyword=keyword,
            top_k=simple_retrieval_top_k
        )
        # Prompt loader (controls per-prompt skip_llm flags).
        self.prompt_loader = get_prompt_loader()

    def _extract_response_text(self, response) -> str:
        """Extract plain text from the various LLM response shapes
        (LLMResult-like, message-like, raw dict, or anything else)."""
        if hasattr(response, 'generations') and response.generations:
            return response.generations[0][0].text
        elif hasattr(response, 'content'):
            return response.content
        elif isinstance(response, dict) and 'response' in response:
            return response['response']
        else:
            return str(response)

    def _check_should_stop(self, state: QueryState, node_name: str = None) -> None:
        """
        Check whether execution should stop.

        Args:
            state: current state.
            node_name: name of the current node (used in messages).

        Raises:
            TaskCancelledException: if the task was cancelled via the task manager.
            WorkflowStoppedException: if the state carries a stop flag.
        """
        # If a task_id is present, consult the task manager.
        if state.get("task_id"):
            task_manager = get_task_manager()
            if task_manager.should_stop(state["task_id"]):
                print(f"[NODE] 节点 {node_name or '未知'} 检测到停止信号")
                raise TaskCancelledException(
                    task_id=state["task_id"],
                    message=f"任务在节点 {node_name} 被取消"
                )
        # Also honor the in-state stop flag.
        if state.get("should_stop", False):
            raise WorkflowStoppedException(
                node_name=node_name,
                message=f"工作流在节点 {node_name} 停止"
            )

    def query_complexity_check_node(self, state: QueryState) -> QueryState:
        """
        Query-complexity check node.

        Decides whether the user query needs complex knowledge-graph
        reasoning; on LLM failure it defaults to "complex".

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[?] 执行查询复杂度判断: {state['original_query']}")
        self._check_should_stop(state, "query_complexity_check")

        # Skip-LLM short-circuit: default to a simple query.
        # NOTE(review): this branch sets 'is_complex'/'complexity_level',
        # while the rest of the workflow reads 'is_complex_query' /
        # 'query_complexity' — confirm the intended keys.
        if self.prompt_loader.should_skip_llm('query_complexity_check'):
            print(f"[NEXT] 跳过复杂度检查LLM调用,默认为简单查询")
            state['is_complex'] = False
            state['complexity_level'] = 'simple'
            return state

        # Build the prompt.
        prompt = QUERY_COMPLEXITY_CHECK_PROMPT.format(
            query=state['original_query']
        )

        try:
            # Call the dedicated complexity LLM.
            response = self.complexity_llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Parse the response.
            complexity_result = self.complexity_parser.parse(response_text)

            # Update state.
            state['query_complexity'] = {
                "is_complex": complexity_result.is_complex,
                "complexity_level": complexity_result.complexity_level,
                "confidence": complexity_result.confidence,
                "reason": complexity_result.reason
            }
            state['is_complex_query'] = complexity_result.is_complex

            # Update debug counters.
            state["debug_info"]["llm_calls"] += 1

            print(f"[INFO] 复杂度判断结果: {'复杂' if complexity_result.is_complex else '简单'} "
                  f"(置信度: {complexity_result.confidence:.2f})")
            print(f"理由: {complexity_result.reason}")
            return state

        except Exception as e:
            print(f"[ERROR] 复杂度判断失败: {e}")
            # On failure, default to a complex query so the existing
            # (thorough) retrieval path is taken.
            state['query_complexity'] = {
                "is_complex": True,
                "complexity_level": "complex",
                "confidence": 0.5,
                "reason": f"复杂度判断失败: {str(e)},默认使用复杂检索"
            }
            state['is_complex_query'] = True
            return state

    def debug_mode_node(self, state: QueryState) -> QueryState:
        """
        Debug-mode node.

        Overrides the complexity decision according to the user-set
        ``debug_mode`` ('simple' | 'complex' | anything else = automatic).

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[?] 执行调试模式检查: mode={state['debug_mode']}")
        self._check_should_stop(state, "debug_mode")

        # Remember the original complexity decision before overriding.
        original_is_complex = state['is_complex_query']

        if state['debug_mode'] == 'simple':
            print("[?] 调试模式: 强制使用简单检索路径")
            # Force simple regardless of the original decision.
            state['is_complex_query'] = False
            # Keep the original decision, annotate the override.
            if 'debug_override' not in state['query_complexity']:
                state['query_complexity']['debug_override'] = {
                    'original_complexity': original_is_complex,
                    'debug_mode': 'simple',
                    'override_reason': '调试模式强制使用简单检索路径'
                }
        elif state['debug_mode'] == 'complex':
            print("[?] 调试模式: 强制使用复杂检索路径")
            # Force complex regardless of the original decision.
            state['is_complex_query'] = True
            if 'debug_override' not in state['query_complexity']:
                state['query_complexity']['debug_override'] = {
                    'original_complexity': original_is_complex,
                    'debug_mode': 'complex',
                    'override_reason': '调试模式强制使用复杂检索路径'
                }
        else:
            # debug_mode == "0" or any other value: keep the automatic result.
            print("[?] 调试模式: 使用自动复杂度判断结果")

        return state

    def query_decomposition_node(self, state: QueryState) -> QueryState:
        """
        Query-decomposition node.

        Splits the original query into 2 sub-queries suitable for
        independent retrieval; falls back to punctuation-based splitting
        and finally to duplicating the original query.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[TARGET] 执行查询分解: {state['original_query']}")
        self._check_should_stop(state, "query_decomposition")

        # Skip-LLM short-circuit: no decomposition, use the original query.
        if self.prompt_loader.should_skip_llm('query_decomposition'):
            print(f"[NEXT] 跳过查询分解LLM调用,使用原查询")
            state['decomposed_sub_queries'] = []
            state['sub_queries'] = [state['original_query']]
            return state

        # Build the prompt.
        prompt = QUERY_DECOMPOSITION_PROMPT.format(
            original_query=state['original_query']
        )

        try:
            # Ask the LLM for sub-queries.
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            print(f"[?] LLM原始响应: {response_text[:200]}...")

            # Clean the raw text and carve out the JSON payload.
            cleaned_text = response_text.strip()

            # If the response contains an "answer:" marker (full-width or
            # ASCII colon), keep only the text after it.
            if "答案:" in cleaned_text or "答案:" in cleaned_text:
                answer_markers = ["答案:", "答案:"]
                for marker in answer_markers:
                    if marker in cleaned_text:
                        cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                        break

            # Strip markdown code fences.
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # drop ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # drop ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # drop trailing ```
            cleaned_text = cleaned_text.strip()

            # Extract the JSON object (first '{' to last '}').
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                cleaned_text = cleaned_text[start_idx:end_idx]

            # Parse the sub-queries.
            try:
                data = json.loads(cleaned_text)
                raw_decomposed_queries = data.get('sub_queries', [])

                # Normalize the various item shapes (kept in sync with
                # sub_query_generation_node).
                decomposed_sub_queries = []
                for item in raw_decomposed_queries:
                    if isinstance(item, str):
                        decomposed_sub_queries.append(item)
                    elif isinstance(item, dict) and 'query' in item:
                        decomposed_sub_queries.append(item['query'])
                    elif isinstance(item, dict):
                        query_text = item.get('text', item.get('content', str(item)))
                        decomposed_sub_queries.append(query_text)
                    else:
                        decomposed_sub_queries.append(str(item))

                # Coerce to non-empty strings.
                decomposed_sub_queries = [str(query).strip() for query in decomposed_sub_queries if query]
                print(f"[OK] JSON解析成功,获得子查询: {decomposed_sub_queries}")

            except json.JSONDecodeError as e:
                print(f"[ERROR] JSON解析失败: {e}")
                print(f"尝试规则提取...")
                # Rule-based fallback: keep question-looking lines.
                lines = response_text.split('\n')
                decomposed_sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
                print(f"规则提取结果: {decomposed_sub_queries}")

            # Ensure we end up with sub-queries.
            if not decomposed_sub_queries:
                print(f"[WARNING] LLM未生成有效子查询,使用简单分解策略...")
                # Naive punctuation-based split.
                original = state['original_query']
                if '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
                    if len(parts) >= 2:
                        decomposed_sub_queries = [parts[0], parts[1]]
                        print(f"按中文问号分解: {decomposed_sub_queries}")
                elif '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
                    if len(parts) >= 2:
                        decomposed_sub_queries = [parts[0], parts[1]]
                        print(f"按英文问号分解: {decomposed_sub_queries}")
                # Still nothing: duplicate the original query.
                if not decomposed_sub_queries:
                    print(f"[WARNING] 无法自动分解,使用原查询作为两个子查询")
                    decomposed_sub_queries = [original, original]
            elif len(decomposed_sub_queries) == 1:
                print(f"[WARNING] 只获得1个子查询,补充第二个")
                # Pad with the original query as the second sub-query.
                decomposed_sub_queries.append(state['original_query'])

            # Cap at 2 sub-queries.
            decomposed_sub_queries = decomposed_sub_queries[:2]

            # Store the initial decomposition in the state.
            state['decomposed_sub_queries'] = decomposed_sub_queries
            state['sub_queries'].extend(decomposed_sub_queries)

            # Update debug counters.
            state["debug_info"]["llm_calls"] += 1

            print(f"[OK] 查询分解完成: {decomposed_sub_queries}")
            return state

        except Exception as e:
            print(f"[ERROR] 查询分解失败: {e}")
            # On failure, fall back to default sub-queries.
            default_sub_queries = [
                state['original_query'] + " 详细信息",
                state['original_query'] + " 相关内容"
            ]
            state['decomposed_sub_queries'] = default_sub_queries
            state['sub_queries'].extend(default_sub_queries)
            return state

    def simple_vector_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Simple vector-retrieval node.

        Matches the query directly against text passages in the ES
        vector store (no graph reasoning).

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[SEARCH] 执行简单向量检索: {state['original_query']}")

        try:
            # Retrieve relevant documents from ES.
            documents = self.es_vector_retriever.retrieve(state['original_query'])

            # Extract passage contents and source labels.
            passages = [doc.page_content for doc in documents]
            sources = [f"简单检索-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]

            # Package the retrieval result.
            retrieval_result = RetrievalResult(
                passages=passages,
                documents=documents,
                sources=sources,
                query=state['original_query'],
                iteration=state['current_iteration']
            )

            # Merge into the state.
            updated_state = update_state_with_retrieval(state, retrieval_result)
            print(f"[OK] 简单向量检索完成,获得 {len(passages)} 个段落")
            return updated_state

        except Exception as e:
            print(f"[ERROR] 简单向量检索失败: {e}")
            # On failure, merge an empty result so downstream nodes still run.
            retrieval_result = RetrievalResult(
                passages=[],
                documents=[],
                sources=[],
                query=state['original_query'],
                iteration=state['current_iteration']
            )
            return update_state_with_retrieval(state, retrieval_result)

    def simple_answer_generation_node(self, state: QueryState) -> QueryState:
        """
        Simple answer-generation node.

        Generates the answer from the simple-retrieval results; may skip
        the LLM entirely and return the formatted passages.

        Args:
            state: current state.

        Returns:
            Updated (finalized) state.
        """
        print(f"[NOTE] 生成简单查询答案")
        self._check_should_stop(state, "simple_answer_generation")

        # Format retrieved content (mixed format if documents are present,
        # otherwise the legacy passage format).
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Skip LLM generation (global flag or per-prompt config).
        if self.skip_llm_generation or self.prompt_loader.should_skip_llm('simple_answer'):
            if self.skip_llm_generation:
                print(f"[SEARCH] 跳过LLM生成(skip_llm_generation=true),直接返回检索结果")
            else:
                print(f"[NEXT] 跳过简单答案LLM调用(simple_answer.skip_llm=true),直接返回检索结果")
            final_answer = f"检索到的信息:\n{formatted_passages}"
            state['final_answer'] = final_answer
            return finalize_state(state, final_answer)

        # Build the prompt.
        prompt = SIMPLE_ANSWER_PROMPT.format(
            query=state['original_query'],
            passages=formatted_passages
        )

        try:
            # Generate the answer.
            response = self.llm.invoke(prompt)
            final_answer = self._extract_response_text(response)

            if not final_answer.strip():
                final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"

            # Finalize the state.
            updated_state = finalize_state(state, final_answer)
            print(f"[OK] 简单答案生成完成 (长度: {len(final_answer)} 字符)")
            return updated_state

        except Exception as e:
            print(f"[ERROR] 简单答案生成失败: {e}")
            error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
            return finalize_state(state, error_answer)

    def initial_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Parallel initial-retrieval node.

        Retrieves in parallel with the original query plus the (up to) 2
        decomposed sub-queries; each query returns mixed results (event
        nodes + passage nodes), which are then deduplicated.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        # Queries to run: original query + up to 2 sub-queries.
        original_query = state['original_query']
        sub_queries = state.get('decomposed_sub_queries', [])
        all_queries = [original_query] + sub_queries[:2]  # at most 3 queries

        print(f"[SEARCH] 执行并行初始检索 - {len(all_queries)} 个查询")
        self._check_should_stop(state, "initial_retrieval")
        for i, query in enumerate(all_queries):
            query_type = "原始查询" if i == 0 else f"子查询{i}"
            print(f" {query_type}: {query}")

        def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, List, str]:
            """Retrieve one query; returns (index, documents, passages, sources, label)."""
            import time
            # Label by query kind (index 0 is the original query).
            if index == 0:
                query_label = "原始查询"
            else:
                query_label = f"子查询{index}"
            start_time = time.time()
            print(f"[STARTING] {query_label} 开始检索 [{time.strftime('%H:%M:%S', time.localtime(start_time))}]")
            try:
                documents = self.retriever.invoke(query)
                # The retriever now returns mixed results (events + passages);
                # no count limit is applied here.
                top_documents = documents
                passages = [doc.page_content for doc in top_documents]
                # Source labels encode query kind + node type + id.
                sources = []
                for doc in top_documents:
                    # Prefer node_id, then passage_id.
                    doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id', 'unknown')
                    node_type = doc.metadata.get('node_type', 'unknown')
                    if index == 0:
                        sources.append(f"原始查询-{node_type}-{doc_id}")
                    else:
                        sources.append(f"子查询{index}-{node_type}-{doc_id}")
                end_time = time.time()
                duration = end_time - start_time
                print(f"[OK] {query_label} 检索完成 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒,获得 {len(passages)} 个内容(事件+段落)")
                return index, documents, passages, sources, query_label
            except Exception as e:
                end_time = time.time()
                duration = end_time - start_time
                print(f"[ERROR] {query_label} 检索失败 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒 - 错误: {e}")
                return index, [], [], [], query_label

        # Run the retrievals in parallel.
        all_documents = []
        all_passages = []
        all_sources = []
        retrieval_details = {}

        import time
        parallel_start_time = time.time()
        print(f"[FAST] 开始并行执行 {len(all_queries)} 个检索任务 [{time.strftime('%H:%M:%S', time.localtime(parallel_start_time))}]")

        with ThreadPoolExecutor(max_workers=min(3, len(all_queries))) as executor:
            # Submit one task per query.
            futures = {
                executor.submit(retrieve_single_query, query, i): (query, i)
                for i, query in enumerate(all_queries)
            }
            print(f"[?] 所有 {len(futures)} 个检索任务已提交到线程池")
            # Collect results as they complete.
            for future in as_completed(futures):
                query, query_index = futures[future]
                try:
                    index, documents, passages, sources, query_label = future.result()
                    # Record per-query retrieval details.
                    retrieval_details[query_label] = {
                        'query': query,
                        'passages_count': len(passages),
                        'documents_count': len(documents)
                    }
                    # Merge results (dedup happens below).
                    all_documents.extend(documents)
                    all_passages.extend(passages)
                    all_sources.extend(sources)
                except Exception as e:
                    print(f"[ERROR] 查询 {query_index+1} 处理失败: {e}")

        parallel_end_time = time.time()
        parallel_duration = parallel_end_time - parallel_start_time
        print(f"[TARGET] 并行检索全部完成 [{time.strftime('%H:%M:%S', time.localtime(parallel_end_time))}] - 总耗时: {parallel_duration:.2f}秒")

        # Deduplicate by document id, then by content hash as a backup.
        unique_documents = []
        unique_passages = []
        unique_sources = []
        seen_passage_ids = set()
        seen_content_hashes = set()

        for i, (doc, passage, source) in enumerate(zip(all_documents, all_passages, all_sources)):
            # Try node_id / passage_id first (supports mixed node types).
            doc_id = None
            if doc:
                doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id')
            if doc_id and doc_id in seen_passage_ids:
                continue
            # Content-hash dedup as a backup.
            content_hash = hash(passage.strip())
            if content_hash in seen_content_hashes:
                continue
            # Keep this entry.
            unique_documents.append(doc)
            unique_passages.append(passage)
            unique_sources.append(source)
            if doc_id:
                seen_passage_ids.add(doc_id)
            seen_content_hashes.add(content_hash)

        removed_count = len(all_passages) - len(unique_passages)
        if removed_count > 0:
            print(f"[SEARCH] 去重处理: 移除了 {removed_count} 个重复内容")

        # Package the combined retrieval result.
        query_description = f"并行检索: 原始查询 + {len(sub_queries)} 个子查询"
        retrieval_result = RetrievalResult(
            passages=unique_passages,
            documents=unique_documents,
            sources=unique_sources,
            query=query_description,
            iteration=state['current_iteration']
        )

        # Merge into the state.
        updated_state = update_state_with_retrieval(state, retrieval_result)

        # Keep per-query retrieval details for later analysis.
        updated_state['initial_retrieval_details'] = retrieval_details

        # PageRank score collection (intentionally disabled; loop only sets
        # a flag — see commented call below).
        for i, query in enumerate(all_queries):
            try:
                # Skipped to avoid shipping bulk PageRank data to LangSmith.
                # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                # Only flag that the data is available (handled inside HippoRAG).
                updated_state['pagerank_data_available'] = True
            except Exception as e:
                print(f"[WARNING] 收集查询{i+1}的PageRank分数失败: {e}")

        total_passages_before = len(all_passages)
        total_passages_after = len(unique_passages)
        print(f"[SUCCESS] 并行初始检索完成")
        print(f" 检索前总内容: {total_passages_before}, 去重后: {total_passages_after}")
        print(f" 原始查询: {retrieval_details.get('原始查询', {}).get('passages_count', 0)} 个内容(事件+段落)")
        for i in range(1, len(all_queries)):
            key = f"子查询{i}"
            count = retrieval_details.get(key, {}).get('passages_count', 0)
            print(f" 子查询{i}: {count} 个内容(事件+段落)")

        return updated_state

    def sufficiency_check_node(self, state: QueryState) -> QueryState:
        """
        Sufficiency-check node.

        Asks the LLM whether the retrieved information suffices to answer
        the original query; includes decomposed sub-queries and event
        triples in the prompt context.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[?] 执行充分性检查 (迭代 {state['current_iteration']})")
        self._check_should_stop(state, "sufficiency_check")

        # Skip-LLM short-circuit: assume sufficient.
        if self.prompt_loader.should_skip_llm('sufficiency_check'):
            print(f"[NEXT] 跳过充分性检查LLM调用,默认为充分")
            state['is_sufficient'] = True
            state['sufficiency_confidence'] = 1.0
            return state

        # Format retrieved content (mixed format when documents are present).
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Format the decomposed sub-queries.
        decomposed_sub_queries = state.get('decomposed_sub_queries', [])
        formatted_decomposed_queries = format_sub_queries(decomposed_sub_queries) if decomposed_sub_queries else "无"

        # Format event triples.
        event_triples = state.get('event_triples', [])
        formatted_event_triples = format_event_triples(event_triples)

        # Build the prompt with sub-query and event-triple context.
        prompt = SUFFICIENCY_CHECK_PROMPT.format(
            query=state['original_query'],
            passages=formatted_passages,
            decomposed_sub_queries=formatted_decomposed_queries,
            event_triples=formatted_event_triples
        )

        # Call the dedicated sufficiency-check LLM.
        try:
            response = self.sufficiency_llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Parse the response.
            sufficiency_result = self.sufficiency_parser.parse(response_text)

            # Build the sufficiency-check record.
            sufficiency_check = SufficiencyCheck(
                is_sufficient=sufficiency_result.is_sufficient,
                confidence=sufficiency_result.confidence,
                reason=sufficiency_result.reason,
                sub_queries=sufficiency_result.sub_queries
            )

            # Merge into the state.
            updated_state = update_state_with_sufficiency_check(state, sufficiency_check)

            # Update debug counters.
            updated_state["debug_info"]["llm_calls"] += 1

            print(f"[INFO] 充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'} "
                  f"(置信度: {sufficiency_result.confidence:.2f})")
            print(f" 基于 {len(state['all_passages'])} 个段落 (来自原始查询和{len(decomposed_sub_queries)}个子查询)")

            if not sufficiency_result.is_sufficient and sufficiency_result.sub_queries:
                print(f"[TARGET] 生成新的子查询: {sufficiency_result.sub_queries}")

            return updated_state

        except Exception as e:
            print(f"[ERROR] 充分性检查失败: {e}")
            # On failure: assume insufficient and propose a default sub-query.
            sufficiency_check = SufficiencyCheck(
                is_sufficient=False,
                confidence=0.5,
                reason=f"充分性检查失败: {str(e)}",
                sub_queries=[state['original_query'] + " 详细信息"]
            )
            return update_state_with_sufficiency_check(state, sufficiency_check)

    def sub_query_generation_node(self, state: QueryState) -> QueryState:
        """
        Sub-query generation node.

        When the sufficiency check fails, generates new sub-queries,
        taking previously generated ones into account.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[TARGET] 生成子查询")
        self._check_should_stop(state, "sub_query_generation")

        # Skip-LLM short-circuit: generate nothing, advance the iteration.
        if self.prompt_loader.should_skip_llm('sub_query_generation'):
            print(f"[NEXT] 跳过子查询生成LLM调用")
            state['current_iteration'] += 1
            return state

        # If sub-queries already exist (e.g. from the sufficiency check), reuse them.
        if state['current_sub_queries']:
            print(f"[OK] 使用现有子查询: {state['current_sub_queries']}")
            return state

        # Format existing retrieved content (mixed format when documents exist).
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        # Format all previously generated sub-queries (incl. decomposed ones).
        previous_sub_queries = state.get('sub_queries', [])
        formatted_previous_queries = format_sub_queries(previous_sub_queries) if previous_sub_queries else "无"

        # Format event triples.
        event_triples = state.get('event_triples', [])
        formatted_event_triples = format_event_triples(event_triples)

        # Pull the insufficiency reason from the sufficiency check.
        sufficiency_check = state.get('sufficiency_check', {})
        insufficiency_reason = sufficiency_check.get('reason', '信息不够充分,需要更多相关信息')

        # Build the prompt, feeding back the sufficiency result and triples.
        prompt = SUB_QUERY_GENERATION_PROMPT.format(
            original_query=state['original_query'],
            existing_passages=formatted_passages,
            previous_sub_queries=formatted_previous_queries,
            event_triples=formatted_event_triples,
            insufficiency_reason=insufficiency_reason
        )

        try:
            # Ask the LLM for sub-queries.
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)

            # Update debug counters.
            state["debug_info"]["llm_calls"] += 1

            # Clean the raw text and carve out the JSON payload.
            cleaned_text = response_text.strip()

            # If the response contains an "answer:" marker, keep the tail.
            if "答案:" in cleaned_text or "答案:" in cleaned_text:
                answer_markers = ["答案:", "答案:"]
                for marker in answer_markers:
                    if marker in cleaned_text:
                        cleaned_text = cleaned_text.split(marker, 1)[1].strip()
                        break

            # Strip markdown code fences.
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # drop ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # drop ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # drop trailing ```
            cleaned_text = cleaned_text.strip()

            # Extract the JSON object (first '{' to last '}').
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                cleaned_text = cleaned_text[start_idx:end_idx]

            # Parse the sub-queries.
            try:
                data = json.loads(cleaned_text)
                raw_sub_queries = data.get('sub_queries', [])

                # Normalize the various item shapes.
                sub_queries = []
                for item in raw_sub_queries:
                    if isinstance(item, str):
                        # Plain string: use as-is.
                        sub_queries.append(item)
                    elif isinstance(item, dict) and 'query' in item:
                        # Dict with a 'query' field.
                        sub_queries.append(item['query'])
                    elif isinstance(item, dict):
                        # Dict without 'query': try common text fields.
                        query_text = item.get('text', item.get('content', str(item)))
                        sub_queries.append(query_text)
                    else:
                        # Anything else: stringify.
                        sub_queries.append(str(item))
            except json.JSONDecodeError:
                # Rule-based fallback: keep question-looking lines.
                lines = response_text.split('\n')
                sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]

            # Ensure at least one sub-query, all strings.
            if not sub_queries:
                sub_queries = [state['original_query'] + " 更多细节"]
            sub_queries = [str(query).strip() for query in sub_queries if query]

            # Update state: at most 2 sub-queries.
            state['current_sub_queries'] = sub_queries[:2]
            state['sub_queries'].extend(state['current_sub_queries'])

            print(f"[OK] 生成子查询: {state['current_sub_queries']}")
            print(f" (避免与之前的子查询重复: {formatted_previous_queries})")
            return state

        except Exception as e:
            print(f"[ERROR] 子查询生成失败: {e}")
            # On failure, fall back to a default sub-query.
            default_sub_query = state['original_query'] + " 补充信息"
            state['current_sub_queries'] = [default_sub_query]
            state['sub_queries'].append(default_sub_query)
            return state

    def parallel_retrieval_node(self, state: QueryState) -> QueryState:
        """
        Parallel-retrieval node.

        Retrieves the current sub-queries concurrently and merges any
        non-empty results into the state; clears ``current_sub_queries``
        when done.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        sub_queries = state['current_sub_queries']
        if not sub_queries:
            print("[WARNING] 没有子查询,跳过并行检索")
            return state

        print(f"[?] 并行检索 {len(sub_queries)} 个子查询")
        self._check_should_stop(state, "parallel_retrieval")

        def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, str]:
            """Retrieve a single sub-query; returns (index, documents, passages, sources)."""
            try:
                documents = self.retriever.invoke(query)
                passages = [doc.page_content for doc in documents]
                sources = [f"子查询{index+1}-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
                return index, documents, passages, sources
            except Exception as e:
                print(f"[ERROR] 子查询 {index+1} 检索失败: {e}")
                return index, [], [], []

        # Run the retrievals in parallel.
        all_new_documents = []
        all_new_passages = []
        all_new_sources = []

        with ThreadPoolExecutor(max_workers=self.max_parallel_retrievals) as executor:
            # Submit one task per sub-query.
            futures = {
                executor.submit(retrieve_single_query, query, i): (query, i)
                for i, query in enumerate(sub_queries)
            }
            # Collect results as they complete.
            for future in as_completed(futures):
                query, query_index = futures[future]
                try:
                    index, documents, passages, sources = future.result()
                    if passages:  # only merge non-empty results
                        all_new_documents.extend(documents)
                        all_new_passages.extend(passages)
                        all_new_sources.extend(sources)
                        print(f"[OK] 子查询 {index+1} 完成,获得 {len(passages)} 个段落")
                    else:
                        print(f"[WARNING] 子查询 {index+1} 无结果")
                except Exception as e:
                    print(f"[ERROR] 子查询 {query_index+1} 处理失败: {e}")

        # Merge new passages into the state.
        if all_new_passages:
            retrieval_result = RetrievalResult(
                passages=all_new_passages,
                documents=all_new_documents,
                sources=all_new_sources,
                query=f"并行检索: {', '.join(sub_queries)}",
                iteration=state['current_iteration']
            )
            state = update_state_with_retrieval(state, retrieval_result)
            print(f"[SUCCESS] 并行检索完成,总共获得 {len(all_new_passages)} 个新段落")

            # PageRank score collection (intentionally disabled; only sets a flag).
            for i, query in enumerate(sub_queries):
                try:
                    # Skipped to avoid shipping bulk PageRank data to LangSmith.
                    # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                    state['pagerank_data_available'] = True
                except Exception as e:
                    print(f"[WARNING] 收集并行子查询{i+1}的PageRank分数失败: {e}")
        else:
            print("[WARNING] 并行检索无有效结果")

        # Clear the consumed sub-queries.
        state['current_sub_queries'] = []
        return state

    def final_answer_generation_node(self, state: QueryState) -> QueryState:
        """
        Final answer-generation node.

        Generates the final answer from all retrieved information; supports
        skipping the LLM (returning a retrieval summary) and streaming via
        a ``stream_callback`` in the state.

        Args:
            state: current state.

        Returns:
            Updated (finalized) state.
        """
        self._check_should_stop(state, "final_answer_generation")

        # Format all retrieved content (mixed format when documents exist).
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        formatted_sub_queries = format_sub_queries(state['sub_queries'])

        # Format event triples.
        event_triples = state.get('event_triples', [])
        formatted_event_triples = format_event_triples(event_triples)

        # Global skip: return a formatted retrieval summary instead.
        if self.skip_llm_generation:
            print(f"[SEARCH] 跳过LLM生成,直接返回检索结果")
            retrieval_summary = f"""【检索结果汇总】
查询问题:{state['original_query']}
检索到 {len(state['all_passages'])} 个相关段落:
{formatted_passages}
"""
            if state['sub_queries']:
                retrieval_summary += f"""
相关子查询:
{formatted_sub_queries}
"""
            retrieval_summary += f"""
检索统计:
- 查询复杂度:{state.get('query_complexity', 'unknown')}
- 是否复杂查询:{state.get('is_complex_query', False)}
- 迭代次数:{state.get('current_iteration', 0)}
- 信息充分性:{state.get('is_sufficient', False)}
"""
            # Finalize the state.
            updated_state = finalize_state(state, retrieval_summary)
            print(f"[OK] 检索结果返回完成 (长度: {len(retrieval_summary)} 字符)")
            return updated_state

        # Streaming callback, if any.
        stream_callback = state.get('stream_callback')

        # Debug: report whether a stream_callback is present.
        if stream_callback:
            print(f"[DEBUG] 检测到stream_callback,类型: {type(stream_callback)}")
        else:
            print(f"[DEBUG] 没有stream_callback,state keys: {list(state.keys())[:10]}")

        print(f"[NOTE] 生成最终答案" + (" (流式模式)" if stream_callback else ""))

        # Per-prompt skip: return the formatted passages directly.
        if self.prompt_loader.should_skip_llm('final_answer'):
            print(f"[NEXT] 跳过最终答案LLM调用,直接返回检索结果")
            final_answer = f"检索到的信息:\n{formatted_passages}"
            state['final_answer'] = final_answer
            return finalize_state(state, final_answer)

        # Build the prompt.
        prompt = FINAL_ANSWER_PROMPT.format(
            original_query=state['original_query'],
            all_passages=formatted_passages,
            sub_queries=formatted_sub_queries,
            event_triples=formatted_event_triples
        )

        try:
            if stream_callback:
                # Wrap the callback, tagging chunks with a type label.
                def answer_stream_callback(chunk):
                    stream_callback("answer_chunk", {"text": chunk})

                # Pass the callback via config metadata.
                config = {
                    'metadata': {
                        'stream_callback': answer_stream_callback
                    }
                }

                # Invoke the LLM (streams automatically with this config).
                # NOTE(review): this path uses the invoke() return value
                # directly as text, unlike the non-streaming path — confirm
                # OneAPILLM returns a str here.
                final_answer = self.llm.invoke(prompt, config=config)
            else:
                # Non-streaming call.
                response = self.llm.invoke(prompt)
                final_answer = self._extract_response_text(response)

            if not final_answer.strip():
                final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"

            # Finalize the state.
            updated_state = finalize_state(state, final_answer)
            print(f"[OK] 最终答案生成完成 (长度: {len(final_answer)} 字符)")
            return updated_state

        except Exception as e:
            print(f"[ERROR] 最终答案生成失败: {e}")
            error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
            return finalize_state(state, error_answer)

    def next_iteration_node(self, state: QueryState) -> QueryState:
        """
        Next-iteration node.

        Increments the iteration counter to prepare the next loop round.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[?] 进入迭代 {state['current_iteration'] + 1}")
        self._check_should_stop(state, "next_iteration")

        # Bump the iteration counter.
        updated_state = increment_iteration(state)
        return updated_state

    def event_triples_extraction_node(self, state: QueryState) -> QueryState:
        """
        Event-triples extraction node.

        Extracts event-to-event triples among the event nodes present in
        the current retrieval results.

        Args:
            state: current state.

        Returns:
            Updated state.
        """
        print(f"[SEARCH] 执行事件三元组提取")
        self._check_should_stop(state, "event_triples_extraction")

        # Collect event-node ids from the retrieved documents.
        event_node_ids = []
        for doc in state.get('all_documents', []):
            if doc.metadata.get('node_type') == 'event':
                node_id = doc.metadata.get('node_id')
                if node_id:
                    event_node_ids.append(node_id)

        print(f"[INFO] 发现 {len(event_node_ids)} 个事件节点")

        # No event nodes or no graph data: nothing to extract.
        if not event_node_ids or not self.retriever.graph_data:
            print("[WARNING] 没有事件节点或图数据不可用,跳过三元组提取")
            return update_state_with_event_triples(state, [])

        # Query the graph for event-event triples.
        event_triples = self._extract_event_event_triples(event_node_ids)

        # Merge into the state.
        updated_state = update_state_with_event_triples(state, event_triples)

        print(f"[OK] 事件三元组提取完成,获得 {len(event_triples)} 个三元组")

        # Debug output: show at most the first 5 triples.
        if event_triples:
            print("[INFO] 提取到的事件-事件三元组:")
            for i, triple in enumerate(event_triples[:5]):
                if isinstance(triple, dict) and 'source_entity' in triple:
                    print(f" {i+1}. [{triple['source_entity']}] --{triple['relation']}--> [{triple['target_entity']}]")
                else:
                    print(f" {i+1}. {str(triple)}")
            if len(event_triples) > 5:
                print(f" ... 还有 {len(event_triples) - 5} 个三元组")

        return updated_state

    def _extract_event_event_triples(self, event_node_ids: List[str]) -> List[Dict[str, str]]:
        """
        Extract directly connected event-event triples from the graph.

        Args:
            event_node_ids: ids of the event nodes to consider.

        Returns:
            List of triples shaped like
            ``{'source_entity': ..., 'relation': ..., 'target_entity': ...,
            'source_evidence': ..., 'target_evidence': ...}``.
        """
        if not self.retriever.graph_data:
            return []

        event_node_set = set(event_node_ids)
        triples = []

        print(f"[SEARCH] 在图中查询 {len(event_node_ids)} 个事件节点间的连接")

        try:
            # Walk every edge in the graph.
            # NOTE(review): graph_data appears to be a networkx-style graph
            # (edges(data=True), .nodes[...]) — confirm.
            for source, target, edge_data in self.retriever.graph_data.edges(data=True):
                # Both endpoints must be event nodes from the retrieval results.
                if source in event_node_set and target in event_node_set:
                    # Double-check node types are 'event'.
                    source_type = self.retriever.graph_data.nodes[source].get('type', '')
                    target_type = self.retriever.graph_data.nodes[target].get('type', '')

                    if source_type == 'event' and target_type == 'event':
                        source_node_data = self.retriever.graph_data.nodes[source]
                        target_node_data = self.retriever.graph_data.nodes[target]

                        # Prefer the 'name' attribute (human-readable name).
                        source_entity = source_node_data.get('name', source)
                        target_entity = target_node_data.get('name', target)

                        # Relation label from the edge data.
                        relation = edge_data.get('relation', edge_data.get('label', 'connected_to'))

                        # Evidence fields, if present.
                        source_evidence = source_node_data.get('evidence', '')
                        target_evidence = target_node_data.get('evidence', '')

                        # Build the enriched triple (with evidence).
                        triple = {
                            'source_entity': str(source_entity),
                            'relation': str(relation),
                            'target_entity': str(target_entity),
                            'source_evidence': str(source_evidence),
                            'target_evidence': str(target_evidence)
                        }
                        triples.append(triple)

            # Deduplicate (the same triple may appear multiple times).
            unique_triples = []
            seen_triples = set()
            for triple in triples:
                # Key on the core (source, relation, target).
                triple_key = (triple['source_entity'], triple['relation'], triple['target_entity'])
                if triple_key not in seen_triples:
                    unique_triples.append(triple)
                    seen_triples.add(triple_key)

            print(f"[OK] 从图中提取到 {len(unique_triples)} 个唯一的事件-事件三元组")
            return unique_triples

        except Exception as e:
            print(f"[ERROR] 事件三元组提取失败: {e}")
            return []