"""
Concept exploration nodes.

Runs knowledge-graph concept exploration after the PageRank collection step:
the highest-scoring "Node" entries are taken as branch roots, one connected
Concept node is chosen per branch by the LLM, and a short multi-hop walk is
then steered by the LLM through the graph.
"""
import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from retriver.langgraph.graph_state import QueryState
from retriver.langgraph.langchain_components import (
    OneAPILLM,
    ConceptExplorationParser,
    ConceptContinueParser,
    ConceptSufficiencyParser,
    ConceptSubQueryParser,
    CONCEPT_EXPLORATION_INIT_PROMPT,
    CONCEPT_EXPLORATION_CONTINUE_PROMPT,
    CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
    CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
    format_passages,
    format_mixed_passages,
    format_triplets,
    format_exploration_path,
)


class ConceptExplorationNodes:
    """Concept-exploration node implementations.

    The public ``*_node`` methods receive the shared ``QueryState``, update it
    in place and return it.
    """

    def __init__(
        self,
        retriever,  # LangChainHippoRAGRetriever
        llm: OneAPILLM,
        max_exploration_steps: int = 3,
        max_parallel_explorations: int = 3
    ):
        """
        Initialise the concept-exploration node handler.

        Args:
            retriever: HippoRAG retriever; its ``graph_data`` attribute is used
                to access the knowledge graph.
            llm: OneAPI LLM used for every exploration decision.
            max_exploration_steps: Maximum number of hops per branch.
            max_parallel_explorations: Maximum number of branches explored
                concurrently.
        """
        self.retriever = retriever
        self.llm = llm
        self.max_exploration_steps = max_exploration_steps
        self.max_parallel_explorations = max_parallel_explorations

        # One output parser per prompt type.
        self.concept_exploration_parser = ConceptExplorationParser()
        self.concept_continue_parser = ConceptContinueParser()
        self.concept_sufficiency_parser = ConceptSufficiencyParser()
        self.concept_sub_query_parser = ConceptSubQueryParser()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _extract_response_text(self, response):
        """Return the text of an LLM response regardless of its wrapper type.

        Handles objects exposing ``generations`` (first generation's text),
        objects exposing ``content``, dicts with a ``response`` key, and
        otherwise falls back to ``str(response)``.
        """
        if hasattr(response, 'generations') and response.generations:
            return response.generations[0][0].text
        elif hasattr(response, 'content'):
            return response.content
        elif isinstance(response, dict) and 'response' in response:
            return response['response']
        else:
            return str(response)

    def _get_graph(self):
        """Return ``retriever.graph_data``, or None when it is missing or falsy (e.g. empty)."""
        graph = getattr(self.retriever, 'graph_data', None)
        return graph if graph else None

    @staticmethod
    def _normalized_type(data: Dict[str, Any]) -> str:
        """Lower-cased ``type`` attribute of a node dict; '' when missing or None."""
        # A None/absent type must not raise, otherwise one malformed node would
        # abort a whole neighbour scan via the surrounding broad except.
        return str(data.get('type') or '').lower()

    @staticmethod
    def _format_number(value: Any, digits: int) -> str:
        """Format ``value`` with ``digits`` decimals, falling back to ``str`` for non-numbers."""
        try:
            return f"{value:.{digits}f}"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _get_results_container(state: QueryState) -> Dict[str, Any]:
        """Return ``state['concept_exploration_results']``, creating an empty dict if absent/invalid."""
        results = state.get('concept_exploration_results')
        if not isinstance(results, dict):
            results = {}
            state['concept_exploration_results'] = results
        return results

    @staticmethod
    def _get_round_sub_query(state: QueryState) -> Optional[str]:
        """Return the round-1 generated sub-query when in round 2, otherwise None."""
        if state.get('exploration_round') != 2:
            return None
        results = state.get('concept_exploration_results') or {}
        round_1_results = results.get('round_1_results') or {}
        sub_query_info = round_1_results.get('sub_query_info') or {}
        return sub_query_info.get('sub_query') or None

    @staticmethod
    def _bump_llm_calls(state: QueryState, count: int = 1) -> None:
        """Add ``count`` to ``state['debug_info']['llm_calls']`` without raising if keys are missing."""
        debug_info = state.setdefault('debug_info', {})
        if isinstance(debug_info, dict):
            debug_info['llm_calls'] = debug_info.get('llm_calls', 0) + count

    @staticmethod
    def _edge_attributes(graph, source_id, target_id) -> Dict[str, Any]:
        """Return the attribute dict of the edge ``source_id -> target_id`` ({} if none).

        ``default`` is passed by keyword: on a networkx MultiGraph the third
        positional parameter of ``get_edge_data`` is ``key``, not ``default``.
        A MultiGraph returns ``{edge_key: attrs}``; the first parallel edge's
        attributes are used in that case.
        """
        edge_data = graph.get_edge_data(source_id, target_id, default={}) or {}
        is_multigraph = getattr(graph, 'is_multigraph', None)
        if callable(is_multigraph) and is_multigraph() and edge_data:
            edge_data = next(iter(edge_data.values())) or {}
        return edge_data

    @staticmethod
    def _match_concept_neighbor(concept_nodes: List[Tuple[str, str]],
                                selected_concept: str) -> Optional[str]:
        """Return the id of the neighbouring concept whose label or id equals ``selected_concept``."""
        for concept_id, concept_label in concept_nodes:
            if concept_label == selected_concept or concept_id == selected_concept:
                return concept_id
        return None

    @staticmethod
    def _match_neighbor_node_id(neighbor_triplets: List[Dict[str, str]],
                                selected_node: str,
                                selected_relation: Optional[str] = None) -> Optional[str]:
        """Resolve the LLM-selected node against the candidate triplets.

        Prefers a triplet whose target and relation both match. Otherwise it
        returns the first triplet whose target label or id matches. Returns
        None when no candidate matches.
        """
        fallback = None
        for triplet in neighbor_triplets:
            if triplet.get('target') != selected_node and triplet.get('target_id') != selected_node:
                continue
            if selected_relation and triplet.get('relation') == selected_relation:
                return triplet.get('target_id')
            if fallback is None:
                fallback = triplet.get('target_id')
        return fallback

    # ------------------------------------------------------------------
    # Graph access
    # ------------------------------------------------------------------

    def _get_top_node_candidates(self, state: QueryState, top_k: int = 3) -> List[Dict[str, Any]]:
        """
        Extract the top-``top_k`` entries whose type contains "node" from the PageRank summary.

        Args:
            state: Current state; reads ``state['pagerank_summary']['all_nodes_with_labels']``.
            top_k: Number of candidates to return (default 3).

        Returns:
            Up to ``top_k`` node-info dicts sorted by ``score`` descending
            (missing/None scores count as 0). Each dict typically carries
            node_id, score, label and type.
        """
        print(f"[SEARCH] 从PageRank结果中提取TOP-{top_k} Node节点")

        pagerank_summary = state.get('pagerank_summary') or {}
        if not pagerank_summary:
            print("[ERROR] 没有找到PageRank汇总信息")
            return []

        all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels') or []
        if not all_nodes_with_labels:
            print("[ERROR] 没有找到带标签的节点信息")
            return []

        # Keep entries whose (lower-cased) type mentions "node".
        node_candidates = [
            node_info for node_info in all_nodes_with_labels
            if 'node' in self._normalized_type(node_info)
        ]

        # Stable sort, highest score first.
        node_candidates.sort(key=lambda x: x.get('score') or 0, reverse=True)
        top_nodes = node_candidates[:max(top_k, 0)]

        print(f"[OK] 找到 {len(node_candidates)} 个Node节点,选择TOP-{top_k}:")
        for i, node in enumerate(top_nodes):
            node_id = node.get('node_id', 'unknown')
            score = node.get('score', 0)
            label = node.get('label', 'unknown')
            print(f" {i+1}. {node_id} ('{label}') - Score: {self._format_number(score, 6)}")

        return top_nodes

    def _get_connected_concept_nodes(self, node_id: str) -> List[Tuple[str, str]]:
        """
        Return ``(concept_id, concept_label)`` for every neighbour of ``node_id`` whose type contains "concept".

        The label falls back to ``name`` and then to the node id. Returns []
        (after printing a message) when the graph is unavailable, the node is
        absent, or any error occurs.
        """
        try:
            graph = self._get_graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
                return []

            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []

            concept_nodes = []
            for neighbor_id in graph.neighbors(node_id):
                neighbor_data = graph.nodes.get(neighbor_id, {})
                if 'concept' in self._normalized_type(neighbor_data):
                    neighbor_label = neighbor_data.get('label', neighbor_data.get('name', neighbor_id))
                    concept_nodes.append((neighbor_id, neighbor_label))

            print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(concept_nodes)}")
            return concept_nodes

        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
            return []

    def _get_connected_concepts(self, node_id: str) -> List[str]:
        """
        Return the labels of all Concept nodes adjacent to ``node_id``.

        Args:
            node_id: Graph node id.

        Returns:
            Concept labels (falling back to ``name`` then to the node id); []
            on any failure.
        """
        return [label for _, label in self._get_connected_concept_nodes(node_id)]

    def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """
        Return all outgoing-neighbour triplets of ``node_id``, skipping excluded nodes.

        Args:
            node_id: Current node id.
            excluded_nodes: Node ids to skip (e.g. already visited ones).

        Returns:
            Dicts with ``source``/``relation``/``target`` labels plus
            ``source_id``/``target_id``. The relation comes from the edge's
            ``relation`` attribute, else ``label``, else ``'related_to'``.
            Returns [] on any failure.
        """
        try:
            graph = self._get_graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
                return []

            # Set for O(1) membership tests.
            excluded = set(excluded_nodes) if excluded_nodes else set()

            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []

            source_data = graph.nodes.get(node_id, {})
            source_label = source_data.get('label', source_data.get('name', node_id))

            triplets = []
            for neighbor_id in graph.neighbors(node_id):
                if neighbor_id in excluded:
                    continue

                edge_data = self._edge_attributes(graph, node_id, neighbor_id)
                relation = edge_data.get('relation', edge_data.get('label', 'related_to'))

                target_data = graph.nodes.get(neighbor_id, {})
                target_label = target_data.get('label', target_data.get('name', neighbor_id))

                triplets.append({
                    'source': source_label,
                    'relation': relation,
                    'target': target_label,
                    'source_id': node_id,
                    'target_id': neighbor_id
                })

            print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
            return triplets

        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
            return []

    def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
        """Return the id of the first graph node whose label/name or id equals ``concept_label`` and whose type contains "concept"; None otherwise."""
        try:
            graph = self._get_graph()
            if graph is None:
                return None

            for node_id, node_data in graph.nodes(data=True):
                node_label = node_data.get('label', node_data.get('name', ''))
                if (node_label == concept_label or node_id == concept_label) \
                        and 'concept' in self._normalized_type(node_data):
                    return node_id

            return None

        except Exception as e:
            print(f"[ERROR] 查找概念节点ID失败: {e}")
            return None

    def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
        """Return the id of the first graph node whose label/name or id equals ``node_label``; None otherwise."""
        try:
            graph = self._get_graph()
            if graph is None:
                return None

            for node_id, node_data in graph.nodes(data=True):
                label = node_data.get('label', node_data.get('name', ''))
                if label == node_label or node_id == node_label:
                    return node_id

            return None

        except Exception as e:
            print(f"[ERROR] 查找节点ID失败: {e}")
            return None

    # ------------------------------------------------------------------
    # Branch exploration
    # ------------------------------------------------------------------

    def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
        """
        Explore one branch starting from a PageRank "Node" entry.

        Step 0 lets the LLM pick one adjacent Concept node. Each following step
        (up to ``max_exploration_steps``) lets the LLM pick the next unvisited
        neighbour.

        Args:
            node_info: Start node info (``node_id``, ``label``, ...).
            user_query: Query used in the prompts.
            insufficiency_reason: Why the previous retrieval was insufficient.

        Returns:
            Dict with ``start_node``, ``exploration_path``, ``visited_nodes``,
            ``knowledge_gained``, ``success``, ``error`` and ``llm_calls``
            (number of LLM invocations attempted by this branch).
        """
        node_id = node_info.get('node_id', '')
        node_label = node_info.get('label', '')

        print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")

        exploration_result = {
            'start_node': node_info,
            'exploration_path': [],
            'visited_nodes': [node_id],
            'knowledge_gained': [],
            'success': False,
            'error': None,
            'llm_calls': 0
        }
        visited = exploration_result['visited_nodes']

        try:
            # Step 0: pick the most promising adjacent concept.
            concept_nodes = self._get_connected_concept_nodes(node_id)
            if not concept_nodes:
                print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
                exploration_result['error'] = "没有连接的Concept节点"
                return exploration_result

            connected_concepts = [label for _, label in concept_nodes]
            concept_selection = self._select_best_concept(node_label, connected_concepts, user_query, insufficiency_reason)
            exploration_result['llm_calls'] += 1
            if not concept_selection or not concept_selection.selected_concept:
                print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
                exploration_result['error'] = "未能选择有效概念"
                return exploration_result

            selected_concept = concept_selection.selected_concept
            print(f"[POINT] 选择探索概念: {selected_concept}")

            exploration_result['knowledge_gained'].append({
                'step': 0,
                'action': 'concept_selection',
                'selected_concept': selected_concept,
                'reason': concept_selection.exploration_reason,
                'expected_knowledge': concept_selection.expected_knowledge
            })

            # Resolve against the actual neighbours first so a same-labelled,
            # non-adjacent concept elsewhere in the graph is not picked.
            concept_node_id = self._match_concept_neighbor(concept_nodes, selected_concept)
            if concept_node_id is None:
                concept_node_id = self._find_concept_node_id(selected_concept)
            if concept_node_id is None:
                print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
                exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
                return exploration_result

            # Multi-hop walk.
            current_node_id = concept_node_id
            visited.append(current_node_id)

            for step in range(1, self.max_exploration_steps + 1):
                print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")

                neighbor_triplets = self._get_neighbor_triplets(current_node_id, visited)
                if not neighbor_triplets:
                    print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
                    break

                exploration_path_str = format_exploration_path(exploration_result['exploration_path'])
                continue_result = self._continue_exploration(current_node_id, neighbor_triplets, exploration_path_str, user_query)
                exploration_result['llm_calls'] += 1

                if not continue_result or not continue_result.selected_node:
                    print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
                    break

                selected_node = continue_result.selected_node
                selected_relation = continue_result.selected_relation

                # Prefer a candidate from the triplets actually offered to the
                # LLM; only fall back to a global label lookup.
                next_node_id = self._match_neighbor_node_id(neighbor_triplets, selected_node, selected_relation)
                if next_node_id is None:
                    next_node_id = self._find_node_id_by_label(selected_node)
                if next_node_id is None or next_node_id in visited:
                    print(f"[WARNING] 选择的节点无效或已访问: {selected_node}")
                    break

                exploration_result['exploration_path'].append({
                    'source': current_node_id,
                    'relation': selected_relation,
                    'target': next_node_id,
                    'reason': continue_result.exploration_reason
                })

                exploration_result['knowledge_gained'].append({
                    'step': step,
                    'action': 'exploration_step',
                    'selected_node': selected_node,
                    'relation': selected_relation,
                    'reason': continue_result.exploration_reason,
                    'expected_knowledge': continue_result.expected_knowledge
                })

                current_node_id = next_node_id
                visited.append(current_node_id)

                print(f"[OK] 步骤 {step} 完成: {selected_node}")

            exploration_result['success'] = True
            print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])} 步")

        except Exception as e:
            print(f"[ERROR] 分支探索失败: {e}")
            exploration_result['error'] = str(e)

        return exploration_result

    def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
        """Ask the LLM which connected concept to explore; returns the parsed result or None on LLM/parse failure."""
        concepts_str = "\n".join(f"- {concept}" for concept in connected_concepts)

        prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
            node_name=node_name,
            connected_concepts=concepts_str,
            user_query=user_query,
            insufficiency_reason=insufficiency_reason
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_exploration_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 概念选择失败: {e}")
            return None

    def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
        """Ask the LLM which neighbour to visit next; returns the parsed result or None on LLM/parse failure."""
        triplets_str = format_triplets(neighbor_triplets)

        prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
            current_node=current_node,
            neighbor_triplets=triplets_str,
            exploration_path=exploration_path,
            user_query=user_query
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_continue_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 探索继续失败: {e}")
            return None

    # ------------------------------------------------------------------
    # Graph nodes
    # ------------------------------------------------------------------

    def concept_exploration_node(self, state: QueryState) -> QueryState:
        """
        Concept exploration node.

        1. Increment ``exploration_round``.
        2. Take the top-3 PageRank "Node" entries.
        3. Explore one branch per entry in a thread pool.
        4. Store the results under ``concept_exploration_results['round_<n>_results']``.

        Earlier rounds' results are preserved even when this round finds no
        start nodes (only an ``'error'`` entry is added).

        Args:
            state: Current state.

        Returns:
            The updated state.
        """
        current_round = (state.get('exploration_round') or 0) + 1
        print(f"[?] 开始概念探索 (轮次 {current_round})")
        state['exploration_round'] = current_round

        top_nodes = self._get_top_node_candidates(state)
        if not top_nodes:
            print("[ERROR] 未找到合适的Node节点进行探索")
            self._get_results_container(state)['error'] = '未找到合适的Node节点'
            return state

        insufficiency_reason = (state.get('sufficiency_check') or {}).get('reason', '信息不够充分')

        # Round 2 uses the sub-query generated after round 1, when available.
        sub_query = self._get_round_sub_query(state)
        user_query = sub_query if sub_query else state['original_query']
        if sub_query:
            print(f"[TARGET] 第二轮探索使用子查询: {user_query}")

        # Results are slotted by rank index so their order (and the branch
        # numbering derived from it later) does not depend on thread timing.
        exploration_results: List[Optional[Dict[str, Any]]] = [None] * len(top_nodes)
        max_workers = max(1, min(len(top_nodes), self.max_parallel_explorations))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(self._explore_single_branch, node_info, user_query, insufficiency_reason): index
                for index, node_info in enumerate(top_nodes)
            }

            for future in as_completed(future_to_index):
                index = future_to_index[future]
                node_info = top_nodes[index]
                try:
                    exploration_results[index] = future.result()
                    print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
                except Exception as e:
                    print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
                    exploration_results[index] = {
                        'start_node': node_info,
                        'success': False,
                        'error': str(e)
                    }

        successful_branches = sum(1 for r in exploration_results if r.get('success', False))
        # Branch LLM calls were previously not reflected in the debug counter.
        self._bump_llm_calls(state, sum(r.get('llm_calls', 0) for r in exploration_results))

        round_key = f'round_{current_round}_results'
        self._get_results_container(state)[round_key] = {
            'exploration_results': exploration_results,
            'total_branches': len(top_nodes),
            'successful_branches': successful_branches,
            'query_used': user_query
        }

        print(f"[SUCCESS] 第{current_round}轮概念探索完成")
        print(f" 总分支: {len(top_nodes)}, 成功: {successful_branches}")

        return state

    def concept_exploration_sufficiency_check_node(self, state: QueryState) -> QueryState:
        """
        Sufficiency check after concept exploration.

        Asks the LLM whether the retrieved passages plus the exploration
        knowledge are enough to answer the (possibly sub-) query. It sets
        ``state['is_sufficient']`` and stores the details under the current
        round's results. Any LLM/parse failure sets ``is_sufficient`` to False.

        Args:
            state: Current state.

        Returns:
            The updated state.
        """
        print("[THINK] 执行概念探索充分性检查")

        # Mixed formatting when 'all_documents' is present, else passages only.
        all_documents = state.get('all_documents')
        if all_documents:
            formatted_passages = format_mixed_passages(all_documents)
        else:
            formatted_passages = format_passages(state.get('all_passages', []))

        exploration_results = state.get('concept_exploration_results') or {}
        exploration_knowledge = self._format_exploration_knowledge(exploration_results)

        original_insufficiency = (state.get('sufficiency_check') or {}).get('reason', '信息不够充分')

        sub_query = self._get_round_sub_query(state)
        current_query = sub_query if sub_query else state['original_query']

        prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
            user_query=current_query,
            all_passages=formatted_passages,
            exploration_knowledge=exploration_knowledge,
            insufficiency_reason=original_insufficiency
        )

        # Only the LLM call and parsing are guarded. Bookkeeping/logging
        # errors must not override an already determined result.
        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            sufficiency_result = self.concept_sufficiency_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 概念探索充分性检查失败: {e}")
            state['is_sufficient'] = False
            return state

        state['is_sufficient'] = sufficiency_result.is_sufficient

        round_key = f'round_{state.get("exploration_round")}_results'
        round_data = exploration_results.get(round_key)
        if isinstance(round_data, dict):
            round_data['sufficiency_check'] = {
                'is_sufficient': sufficiency_result.is_sufficient,
                'confidence': sufficiency_result.confidence,
                'reason': sufficiency_result.reason,
                'missing_aspects': sufficiency_result.missing_aspects,
                'coverage_score': sufficiency_result.coverage_score
            }

        self._bump_llm_calls(state)

        print(f"[INFO] 概念探索充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'}")
        print(f" 置信度: {self._format_number(sufficiency_result.confidence, 2)}, "
              f"覆盖度: {self._format_number(sufficiency_result.coverage_score, 2)}")
        if not sufficiency_result.is_sufficient:
            print(f" 缺失方面: {sufficiency_result.missing_aspects}")

        return state

    def concept_sub_query_generation_node(self, state: QueryState) -> QueryState:
        """
        Sub-query generation after an insufficient first round.

        Builds a targeted sub-query from round 1's missing aspects and stores
        it in ``concept_exploration_results['round_1_results']['sub_query_info']``.
        On LLM/parse failure it stores a default sub-query
        (original query + " 补充详细信息"). The round-1 container is created if
        absent, so the failure path cannot itself raise KeyError.

        Args:
            state: Current state.

        Returns:
            The updated state.
        """
        print("[TARGET] 生成概念探索子查询")

        results = self._get_results_container(state)
        round_1_results = results.get('round_1_results')
        if not isinstance(round_1_results, dict):
            round_1_results = {}
            results['round_1_results'] = round_1_results

        sufficiency_info = round_1_results.get('sufficiency_check', {})
        missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])

        exploration_summary = self._summarize_exploration_results(round_1_results)

        prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
            user_query=state['original_query'],
            missing_aspects=str(missing_aspects),
            exploration_results=exploration_summary
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            sub_query_result = self.concept_sub_query_parser.parse(response_text)
            sub_query_info = {
                'sub_query': sub_query_result.sub_query,
                'focus_aspects': sub_query_result.focus_aspects,
                'expected_improvements': sub_query_result.expected_improvements,
                'confidence': sub_query_result.confidence
            }
        except Exception as e:
            print(f"[ERROR] 概念探索子查询生成失败: {e}")
            round_1_results['sub_query_info'] = {
                'sub_query': state['original_query'] + " 补充详细信息",
                'focus_aspects': ['补充信息'],
                'expected_improvements': '获取更多相关信息',
                'confidence': 0.3
            }
            return state

        round_1_results['sub_query_info'] = sub_query_info
        self._bump_llm_calls(state)

        print(f"[OK] 生成概念探索子查询: {sub_query_info['sub_query']}")
        print(f" 关注方面: {sub_query_info['focus_aspects']}")

        return state

    # ------------------------------------------------------------------
    # Formatting
    # ------------------------------------------------------------------

    def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
        """Render every round's successful branches as text for the sufficiency prompt."""
        if not exploration_results:
            return "无探索知识"

        knowledge_parts: List[str] = []

        for round_key, round_data in exploration_results.items():
            # Only "round_<n>_results" entries carry knowledge; skip keys such as 'error'.
            if not round_key.startswith('round_') or not isinstance(round_data, dict):
                continue

            round_num = round_key.split('_')[1]
            knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")

            exploration_list = round_data.get('exploration_results', [])
            for i, branch in enumerate(exploration_list):
                if not isinstance(branch, dict) or not branch.get('success', False):
                    continue

                start_node = branch.get('start_node', {})
                knowledge_parts.append(f"\n分支{i+1} (起点: {start_node.get('label', 'unknown')}):")

                for knowledge in branch.get('knowledge_gained', []):
                    action = knowledge.get('action', '')
                    reason = knowledge.get('reason', '')
                    if action == 'concept_selection':
                        knowledge_parts.append(f"- 选择探索概念: {knowledge.get('selected_concept', '')}")
                        knowledge_parts.append(f" 原因: {reason}")
                    elif action == 'exploration_step':
                        node = knowledge.get('selected_node', '')
                        relation = knowledge.get('relation', '')
                        knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
                        knowledge_parts.append(f" 原因: {reason}")

        return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"

    def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
        """Summarise one round (success count plus up to 3 explored concepts) for the sub-query prompt."""
        if not round_results:
            return "无探索结果"

        exploration_list = round_results.get('exploration_results', [])
        successful_count = sum(1 for r in exploration_list if r.get('success', False))
        total_count = len(exploration_list)

        summary_parts = [
            f"第一轮探索: {successful_count}/{total_count} 个分支成功",
        ]

        key_findings = []
        for branch in exploration_list:
            if not branch.get('success', False):
                continue
            for knowledge in branch.get('knowledge_gained', []):
                if knowledge.get('action') == 'concept_selection':
                    concept = knowledge.get('selected_concept', '')
                    if concept:
                        key_findings.append(f"探索了概念: {concept}")

        if key_findings:
            summary_parts.append("主要发现:")
            summary_parts.extend(f"- {finding}" for finding in key_findings[:3])  # at most 3

        return "\n".join(summary_parts)