# File: AIEC-RAG/retriver/langgraph/concept_exploration_nodes.py
# Repository-viewer metadata: Python, 699 lines, 30 KiB; timestamp 2025-09-24 09:29:12 +08:00
"""
概念探索节点实现
在PageRank收集后进行知识图谱概念探索
"""
import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from retriver.langgraph.graph_state import QueryState
from retriver.langgraph.langchain_components import (
OneAPILLM,
ConceptExplorationParser,
ConceptContinueParser,
ConceptSufficiencyParser,
ConceptSubQueryParser,
CONCEPT_EXPLORATION_INIT_PROMPT,
CONCEPT_EXPLORATION_CONTINUE_PROMPT,
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
format_passages,
format_mixed_passages,
format_triplets,
format_exploration_path
)
class ConceptExplorationNodes:
    """Concept-exploration nodes that run after PageRank collection.

    Starting from the best-scoring ``Node`` entries of the PageRank summary,
    each branch hops through connected ``Concept`` nodes of the knowledge
    graph. The LLM picks every hop; the path and the reasons it gave are
    recorded in the query state for the later sufficiency check.
    """

    def __init__(
        self,
        retriever,  # LangChainHippoRAGRetriever
        llm: OneAPILLM,
        max_exploration_steps: int = 3,
        max_parallel_explorations: int = 3
    ):
        """Store collaborators and create the response parsers.

        Args:
            retriever: Retriever whose ``graph_data`` attribute exposes the
                knowledge graph.
            llm: LLM used for every exploration decision.
            max_exploration_steps: Maximum number of hops per branch.
            max_parallel_explorations: Maximum number of branches explored
                concurrently.
        """
        self.retriever = retriever
        self.llm = llm
        self.max_exploration_steps = max_exploration_steps
        self.max_parallel_explorations = max_parallel_explorations

        # One parser per prompt type.
        self.concept_exploration_parser = ConceptExplorationParser()
        self.concept_continue_parser = ConceptContinueParser()
        self.concept_sufficiency_parser = ConceptSufficiencyParser()
        self.concept_sub_query_parser = ConceptSubQueryParser()

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _extract_response_text(self, response):
        """Return the text of an LLM response, whatever shape it has.

        Checks, in order: a ``generations`` attribute, a ``content``
        attribute, a dict with a ``'response'`` key; anything else is
        converted with ``str``.
        """
        if hasattr(response, 'generations') and response.generations:
            return response.generations[0][0].text
        if hasattr(response, 'content'):
            return response.content
        if isinstance(response, dict) and 'response' in response:
            return response['response']
        return str(response)

    def _get_graph(self):
        """Return the retriever's graph, or ``None`` if it is missing or falsy."""
        graph = getattr(self.retriever, 'graph_data', None)
        return graph if graph else None

    @staticmethod
    def _type_of(node_data) -> str:
        """Return the lower-cased ``type`` of a node record ('' when absent or None)."""
        return str(node_data.get('type') or '').lower()

    @staticmethod
    def _label_of(node_data, fallback):
        """Return a node's ``label``, else its ``name``, else ``fallback``."""
        return node_data.get('label', node_data.get('name', fallback))

    @staticmethod
    def _ensure_exploration_results(state) -> Dict[str, Any]:
        """Return ``state['concept_exploration_results']``, creating it if needed."""
        results = state.get('concept_exploration_results')
        if not isinstance(results, dict):
            results = {}
            state['concept_exploration_results'] = results
        return results

    @staticmethod
    def _count_llm_call(state) -> None:
        """Increment ``debug_info['llm_calls']`` when debug info is present.

        Bookkeeping must never change the outcome of a node, so a missing
        ``debug_info`` is ignored instead of raising.
        """
        debug_info = state.get('debug_info')
        if isinstance(debug_info, dict):
            debug_info['llm_calls'] = (debug_info.get('llm_calls') or 0) + 1

    @staticmethod
    def _round_two_sub_query(state) -> Optional[str]:
        """Return the sub-query generated after round 1, only during round 2."""
        if state.get('exploration_round') != 2:
            return None
        results = state.get('concept_exploration_results') or {}
        sub_query_info = results.get('round_1_results', {}).get('sub_query_info', {})
        if sub_query_info and sub_query_info.get('sub_query'):
            return sub_query_info['sub_query']
        return None

    # ------------------------------------------------------------------
    # Graph access
    # ------------------------------------------------------------------

    def _get_top_node_candidates(self, state: QueryState, top_k: int = 3) -> List[Dict[str, Any]]:
        """Pick the best-scoring ``Node`` entries from the PageRank summary.

        Args:
            state: Current query state; reads
                ``pagerank_summary['all_nodes_with_labels']``.
            top_k: Number of entries to keep (defaults to 3).

        Returns:
            Up to ``top_k`` entries whose type contains ``'node'``
            (case-insensitive), ordered by descending ``score``.
        """
        print(f"[SEARCH] 从PageRank结果中提取TOP-{top_k} Node节点")

        pagerank_summary = state.get('pagerank_summary', {})
        if not pagerank_summary:
            print("[ERROR] 没有找到PageRank汇总信息")
            return []

        all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels', [])
        if not all_nodes_with_labels:
            print("[ERROR] 没有找到带标签的节点信息")
            return []

        node_candidates = [
            node_info for node_info in all_nodes_with_labels
            if 'node' in self._type_of(node_info)
        ]
        node_candidates.sort(key=lambda x: x.get('score', 0), reverse=True)
        # A negative top_k would otherwise slice from the end of the list.
        top_nodes = node_candidates[:max(top_k, 0)]

        print(f"[OK] 找到 {len(node_candidates)} 个Node节点选择TOP-{top_k}:")
        for i, node in enumerate(top_nodes):
            node_id = node.get('node_id', 'unknown')
            score = node.get('score', 0)
            label = node.get('label', 'unknown')
            print(f" {i+1}. {node_id} ('{label}') - Score: {score:.6f}")
        return top_nodes

    def _get_connected_concepts(self, node_id: str) -> List[str]:
        """Return the labels of all ``Concept`` neighbours of ``node_id``.

        Returns an empty list when the graph is unavailable, the node is
        unknown, or any error occurs while reading the graph.
        """
        try:
            graph = self._get_graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
                return []
            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []

            connected_concepts = []
            for neighbor_id in graph.neighbors(node_id):
                neighbor_data = graph.nodes.get(neighbor_id, {})
                if 'concept' in self._type_of(neighbor_data):
                    connected_concepts.append(self._label_of(neighbor_data, neighbor_id))

            print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(connected_concepts)}")
            return connected_concepts
        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
            return []

    def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """Return one triplet per neighbour of ``node_id`` not yet visited.

        Args:
            node_id: Node whose neighbours are listed.
            excluded_nodes: Node IDs to skip (already visited).

        Returns:
            Dicts with ``source``/``target`` labels, the edge ``relation``
            (edge ``relation``, else edge ``label``, else ``'related_to'``)
            and the ``source_id``/``target_id``. Empty on any failure.
        """
        try:
            graph = self._get_graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
                return []
            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []

            excluded = set(excluded_nodes or ())
            # The source label is the same for every triplet; compute it once.
            source_label = self._label_of(graph.nodes.get(node_id, {}), node_id)

            triplets = []
            for neighbor_id in graph.neighbors(node_id):
                if neighbor_id in excluded:
                    continue
                edge_data = graph.get_edge_data(node_id, neighbor_id, {})
                relation = edge_data.get('relation', edge_data.get('label', 'related_to'))
                triplets.append({
                    'source': source_label,
                    'relation': relation,
                    'target': self._label_of(graph.nodes.get(neighbor_id, {}), neighbor_id),
                    'source_id': node_id,
                    'target_id': neighbor_id
                })

            print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
            return triplets
        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
            return []

    def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
        """Return the first concept-typed node whose label or ID equals ``concept_label``."""
        try:
            graph = self._get_graph()
            if graph is None:
                return None
            for node_id, node_data in graph.nodes(data=True):
                if 'concept' not in self._type_of(node_data):
                    continue
                if concept_label in (self._label_of(node_data, ''), node_id):
                    return node_id
            return None
        except Exception as e:
            print(f"[ERROR] 查找概念节点ID失败: {e}")
            return None

    def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
        """Return the first node of any type whose label or ID equals ``node_label``."""
        try:
            graph = self._get_graph()
            if graph is None:
                return None
            for node_id, node_data in graph.nodes(data=True):
                if node_label in (self._label_of(node_data, ''), node_id):
                    return node_id
            return None
        except Exception as e:
            print(f"[ERROR] 查找节点ID失败: {e}")
            return None

    def _resolve_concept_node_id(self, start_node_id: str, concept_label: str) -> Optional[str]:
        """Resolve a chosen concept label to a node ID, preferring real neighbours.

        Labels need not be unique across the graph, so the concept neighbours
        of ``start_node_id`` are checked first; the graph-wide lookup is only
        a fallback.
        """
        graph = self._get_graph()
        if graph is not None and start_node_id in graph.nodes:
            for neighbor_id in graph.neighbors(start_node_id):
                neighbor_data = graph.nodes.get(neighbor_id, {})
                if 'concept' not in self._type_of(neighbor_data):
                    continue
                if concept_label in (self._label_of(neighbor_data, neighbor_id), neighbor_id):
                    return neighbor_id
        return self._find_concept_node_id(concept_label)

    def _resolve_selected_node_id(self, selected_node: str, neighbor_triplets: List[Dict[str, str]]) -> Optional[str]:
        """Resolve the LLM's chosen node to an ID, preferring the offered triplets.

        The LLM chooses among ``neighbor_triplets``, so a matching target
        there wins over a graph-wide label scan, which could return an
        unrelated node that merely shares the label.
        """
        for triplet in neighbor_triplets:
            if selected_node in (triplet.get('target'), triplet.get('target_id')):
                return triplet.get('target_id')
        return self._find_node_id_by_label(selected_node)

    # ------------------------------------------------------------------
    # Single-branch exploration
    # ------------------------------------------------------------------

    def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
        """Explore one branch: pick a concept, then hop up to ``max_exploration_steps`` times.

        Args:
            node_info: Starting ``Node`` entry (``node_id``, ``label``, ...).
            user_query: Query guiding the LLM's choices.
            insufficiency_reason: Why earlier retrieval was judged insufficient.

        Returns:
            Dict with ``start_node``, ``exploration_path``, ``visited_nodes``,
            ``knowledge_gained``, ``success`` and ``error``. ``success`` is
            True once a concept was selected and resolved, even if no
            further hop was possible.
        """
        node_id = node_info.get('node_id', '')
        node_label = node_info.get('label', '')
        print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")

        exploration_result = {
            'start_node': node_info,
            'exploration_path': [],
            'visited_nodes': [node_id],
            'knowledge_gained': [],
            'success': False,
            'error': None
        }
        visited = exploration_result['visited_nodes']
        path = exploration_result['exploration_path']
        knowledge = exploration_result['knowledge_gained']

        try:
            # Step 0: choose which connected concept to enter.
            connected_concepts = self._get_connected_concepts(node_id)
            if not connected_concepts:
                print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
                exploration_result['error'] = "没有连接的Concept节点"
                return exploration_result

            concept_selection = self._select_best_concept(
                node_label, connected_concepts, user_query, insufficiency_reason
            )
            if not concept_selection or not concept_selection.selected_concept:
                print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
                exploration_result['error'] = "未能选择有效概念"
                return exploration_result

            selected_concept = concept_selection.selected_concept
            print(f"[POINT] 选择探索概念: {selected_concept}")
            knowledge.append({
                'step': 0,
                'action': 'concept_selection',
                'selected_concept': selected_concept,
                'reason': concept_selection.exploration_reason,
                'expected_knowledge': concept_selection.expected_knowledge
            })

            concept_node_id = self._resolve_concept_node_id(node_id, selected_concept)
            if not concept_node_id:
                print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
                exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
                return exploration_result

            # Multi-hop walk starting from the chosen concept.
            current_node_id = concept_node_id
            visited.append(current_node_id)

            for step in range(1, self.max_exploration_steps + 1):
                print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")

                neighbor_triplets = self._get_neighbor_triplets(current_node_id, visited)
                if not neighbor_triplets:
                    print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
                    break

                exploration_path_str = format_exploration_path(path)
                continue_result = self._continue_exploration(
                    current_node_id, neighbor_triplets, exploration_path_str, user_query
                )
                if not continue_result or not continue_result.selected_node:
                    print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
                    break

                next_node_id = self._resolve_selected_node_id(
                    continue_result.selected_node, neighbor_triplets
                )
                if not next_node_id or next_node_id in visited:
                    print(f"[WARNING] 选择的节点无效或已访问: {continue_result.selected_node}")
                    break

                path.append({
                    'source': current_node_id,
                    'relation': continue_result.selected_relation,
                    'target': next_node_id,
                    'reason': continue_result.exploration_reason
                })
                knowledge.append({
                    'step': step,
                    'action': 'exploration_step',
                    'selected_node': continue_result.selected_node,
                    'relation': continue_result.selected_relation,
                    'reason': continue_result.exploration_reason,
                    'expected_knowledge': continue_result.expected_knowledge
                })

                current_node_id = next_node_id
                visited.append(current_node_id)
                print(f"[OK] 步骤 {step} 完成: {continue_result.selected_node}")

            exploration_result['success'] = True
            print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])}")
        except Exception as e:
            print(f"[ERROR] 分支探索失败: {e}")
            exploration_result['error'] = str(e)

        return exploration_result

    def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
        """Ask the LLM which connected concept to explore; ``None`` if the call or parse fails."""
        concepts_str = "\n".join(f"- {concept}" for concept in connected_concepts)
        prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
            node_name=node_name,
            connected_concepts=concepts_str,
            user_query=user_query,
            insufficiency_reason=insufficiency_reason
        )
        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_exploration_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 概念选择失败: {e}")
            return None

    def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
        """Ask the LLM which neighbour to visit next; ``None`` if the call or parse fails."""
        triplets_str = format_triplets(neighbor_triplets)
        prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
            current_node=current_node,
            neighbor_triplets=triplets_str,
            exploration_path=exploration_path,
            user_query=user_query
        )
        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_continue_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 探索继续失败: {e}")
            return None

    # ------------------------------------------------------------------
    # Graph-workflow nodes
    # ------------------------------------------------------------------

    def concept_exploration_node(self, state: QueryState) -> QueryState:
        """Run one exploration round over the top ``Node`` candidates.

        Increments ``exploration_round``, explores every candidate branch in
        a thread pool and stores the outcome under
        ``concept_exploration_results['round_<n>_results']``. During round 2
        the sub-query generated after round 1 replaces the original query.

        Args:
            state: Current query state (mutated in place).

        Returns:
            The same state object.
        """
        current_round = state.get('exploration_round', 0) + 1
        print(f"[?] 开始概念探索 (轮次 {current_round})")
        state['exploration_round'] = current_round
        results_by_round = self._ensure_exploration_results(state)

        top_nodes = self._get_top_node_candidates(state)
        if not top_nodes:
            print("[ERROR] 未找到合适的Node节点进行探索")
            # Record the error without discarding results of earlier rounds.
            results_by_round['error'] = '未找到合适的Node节点'
            return state

        insufficiency_reason = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
        user_query = state['original_query']
        sub_query = self._round_two_sub_query(state)
        if sub_query:
            user_query = sub_query
            print(f"[TARGET] 第二轮探索使用子查询: {user_query}")

        # Results are stored by submission index so branch numbering is
        # stable regardless of which thread finishes first.
        exploration_results: List[Any] = [None] * len(top_nodes)
        worker_count = max(1, min(len(top_nodes), self.max_parallel_explorations))
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            futures = {
                executor.submit(
                    self._explore_single_branch, node_info, user_query, insufficiency_reason
                ): index
                for index, node_info in enumerate(top_nodes)
            }
            for future in as_completed(futures):
                index = futures[future]
                node_info = top_nodes[index]
                try:
                    exploration_results[index] = future.result()
                    print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
                except Exception as e:
                    print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
                    exploration_results[index] = {
                        'start_node': node_info,
                        'success': False,
                        'error': str(e)
                    }

        successful_branches = sum(1 for r in exploration_results if r.get('success', False))
        results_by_round[f'round_{current_round}_results'] = {
            'exploration_results': exploration_results,
            'total_branches': len(top_nodes),
            'successful_branches': successful_branches,
            'query_used': user_query
        }

        print(f"[SUCCESS] 第{current_round}轮概念探索完成")
        print(f" 总分支: {len(top_nodes)}, 成功: {successful_branches}")
        return state

    def concept_exploration_sufficiency_check_node(self, state: QueryState) -> QueryState:
        """Ask the LLM whether passages plus explored knowledge answer the query.

        Sets ``state['is_sufficient']`` and, when the current round has a
        result entry, stores the verdict under its ``sufficiency_check`` key.
        If the LLM call or parsing fails the state is marked insufficient.

        Args:
            state: Current query state (mutated in place).

        Returns:
            The same state object.
        """
        print("[THINK] 执行概念探索充分性检查")
        results_by_round = self._ensure_exploration_results(state)

        # Prefer the mixed-document format when documents are available.
        if state.get('all_documents'):
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state.get('all_passages', []))

        exploration_knowledge = self._format_exploration_knowledge(results_by_round)
        original_insufficiency = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
        # Round 2 is evaluated against the sub-query it was explored with.
        current_query = self._round_two_sub_query(state) or state['original_query']

        prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
            user_query=current_query,
            all_passages=formatted_passages,
            exploration_knowledge=exploration_knowledge,
            insufficiency_reason=original_insufficiency
        )

        # Only the LLM call and result extraction may turn the verdict into
        # "insufficient"; bookkeeping below must not override a valid answer.
        try:
            response = self.llm.invoke(prompt)
            sufficiency_result = self.concept_sufficiency_parser.parse(
                self._extract_response_text(response)
            )
            verdict = {
                'is_sufficient': sufficiency_result.is_sufficient,
                'confidence': sufficiency_result.confidence,
                'reason': sufficiency_result.reason,
                'missing_aspects': sufficiency_result.missing_aspects,
                'coverage_score': sufficiency_result.coverage_score
            }
            score_line = (
                f" 置信度: {sufficiency_result.confidence:.2f}, "
                f"覆盖度: {sufficiency_result.coverage_score:.2f}"
            )
        except Exception as e:
            print(f"[ERROR] 概念探索充分性检查失败: {e}")
            state['is_sufficient'] = False
            return state

        state['is_sufficient'] = verdict['is_sufficient']
        round_key = f'round_{state["exploration_round"]}_results'
        if round_key in results_by_round:
            results_by_round[round_key]['sufficiency_check'] = verdict
        self._count_llm_call(state)

        print(f"[INFO] 概念探索充分性检查结果: {'充分' if verdict['is_sufficient'] else '不充分'}")
        print(score_line)
        if not verdict['is_sufficient']:
            print(f" 缺失方面: {verdict['missing_aspects']}")
        return state

    def concept_sub_query_generation_node(self, state: QueryState) -> QueryState:
        """Generate a focused sub-query for a second exploration round.

        The result is stored under
        ``concept_exploration_results['round_1_results']['sub_query_info']``.
        If the LLM call or parsing fails, a generic sub-query derived from
        the original query is stored instead.

        Args:
            state: Current query state (mutated in place).

        Returns:
            The same state object.
        """
        print("[TARGET] 生成概念探索子查询")
        results_by_round = self._ensure_exploration_results(state)
        # Created on demand so the sub-query can always be stored.
        round_1_results = results_by_round.setdefault('round_1_results', {})

        sufficiency_info = round_1_results.get('sufficiency_check', {})
        missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])
        exploration_summary = self._summarize_exploration_results(round_1_results)

        prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
            user_query=state['original_query'],
            missing_aspects=str(missing_aspects),
            exploration_results=exploration_summary
        )

        try:
            response = self.llm.invoke(prompt)
            sub_query_result = self.concept_sub_query_parser.parse(
                self._extract_response_text(response)
            )
            sub_query_info = {
                'sub_query': sub_query_result.sub_query,
                'focus_aspects': sub_query_result.focus_aspects,
                'expected_improvements': sub_query_result.expected_improvements,
                'confidence': sub_query_result.confidence
            }
        except Exception as e:
            print(f"[ERROR] 概念探索子查询生成失败: {e}")
            round_1_results['sub_query_info'] = {
                'sub_query': state['original_query'] + " 补充详细信息",
                'focus_aspects': ['补充信息'],
                'expected_improvements': '获取更多相关信息',
                'confidence': 0.3
            }
            return state

        round_1_results['sub_query_info'] = sub_query_info
        self._count_llm_call(state)
        print(f"[OK] 生成概念探索子查询: {sub_query_info['sub_query']}")
        print(f" 关注方面: {sub_query_info['focus_aspects']}")
        return state

    # ------------------------------------------------------------------
    # Formatting
    # ------------------------------------------------------------------

    def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
        """Render the knowledge gained in every ``round_*`` entry as text.

        Keys that do not start with ``round_`` are ignored, and only
        successful branches are listed. Branch numbers are positions in the
        stored result list, so failed branches leave gaps.
        """
        if not exploration_results:
            return "无探索知识"

        knowledge_parts = []
        for round_key, round_data in exploration_results.items():
            if not round_key.startswith('round_'):
                continue
            round_num = round_key.split('_')[1]
            knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")

            for i, branch in enumerate(round_data.get('exploration_results', [])):
                if not branch.get('success', False):
                    continue
                start_node = branch.get('start_node', {})
                knowledge_parts.append(f"\n分支{i+1} (起点: {start_node.get('label', 'unknown')}):")

                for knowledge in branch.get('knowledge_gained', []):
                    action = knowledge.get('action', '')
                    reason = knowledge.get('reason', '')
                    if action == 'concept_selection':
                        concept = knowledge.get('selected_concept', '')
                        knowledge_parts.append(f"- 选择探索概念: {concept}")
                        knowledge_parts.append(f" 原因: {reason}")
                    elif action == 'exploration_step':
                        node = knowledge.get('selected_node', '')
                        relation = knowledge.get('relation', '')
                        knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
                        knowledge_parts.append(f" 原因: {reason}")

        return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"

    def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
        """Summarise one round: success ratio plus up to three explored concepts."""
        if not round_results:
            return "无探索结果"

        exploration_list = round_results.get('exploration_results', [])
        successful = [r for r in exploration_list if r.get('success', False)]
        summary_parts = [
            f"第一轮探索: {len(successful)}/{len(exploration_list)} 个分支成功",
        ]

        key_findings = []
        for branch in successful:
            for knowledge in branch.get('knowledge_gained', []):
                if knowledge.get('action') != 'concept_selection':
                    continue
                concept = knowledge.get('selected_concept', '')
                if concept:
                    key_findings.append(f"探索了概念: {concept}")

        if key_findings:
            summary_parts.append("主要发现:")
            # At most three findings keep the prompt short.
            summary_parts.extend(f"- {finding}" for finding in key_findings[:3])
        return "\n".join(summary_parts)