Files
AIEC-new/AIEC-RAG/retriver/langgraph/concept_exploration_nodes.py
2025-10-17 09:31:28 +08:00

699 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
概念探索节点实现
在PageRank收集后进行知识图谱概念探索
"""
import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from retriver.langgraph.graph_state import QueryState
from retriver.langgraph.langchain_components import (
OneAPILLM,
ConceptExplorationParser,
ConceptContinueParser,
ConceptSufficiencyParser,
ConceptSubQueryParser,
CONCEPT_EXPLORATION_INIT_PROMPT,
CONCEPT_EXPLORATION_CONTINUE_PROMPT,
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
format_passages,
format_mixed_passages,
format_triplets,
format_exploration_path
)
class ConceptExplorationNodes:
    """LangGraph nodes that explore knowledge-graph concepts after PageRank collection.

    Flow driven by the public node methods:

    1. ``concept_exploration_node`` takes the TOP-3 "Node"-typed entries from
       the PageRank summary and explores one branch per entry in a thread pool.
    2. ``concept_exploration_sufficiency_check_node`` asks the LLM whether the
       retrieved passages plus the explored knowledge answer the query.
    3. ``concept_sub_query_generation_node`` produces a focused sub-query that
       replaces the original query in exploration round 2.

    Results are written into ``state['concept_exploration_results']`` under
    ``round_<n>_results`` keys.
    """

    def __init__(
        self,
        retriever,  # LangChainHippoRAGRetriever
        llm: "OneAPILLM",
        max_exploration_steps: int = 3,
        max_parallel_explorations: int = 3
    ):
        """Create the node handler.

        Args:
            retriever: HippoRAG retriever; its ``graph_data`` attribute is used
                to access the knowledge graph.
            llm: OneAPI LLM client used for every LLM decision in this class.
            max_exploration_steps: Maximum number of hops per branch after the
                initial concept has been chosen.
            max_parallel_explorations: Upper bound on concurrently explored
                branches.
        """
        self.retriever = retriever
        self.llm = llm
        self.max_exploration_steps = max_exploration_steps
        self.max_parallel_explorations = max_parallel_explorations
        # One output parser per prompt type.
        self.concept_exploration_parser = ConceptExplorationParser()
        self.concept_continue_parser = ConceptContinueParser()
        self.concept_sufficiency_parser = ConceptSufficiencyParser()
        self.concept_sub_query_parser = ConceptSubQueryParser()

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    def _get_graph(self):
        """Return ``self.retriever.graph_data``, or None when missing or empty.

        The graph is used networkx-style (``.nodes`` mapping that is also
        callable with ``data=True``, ``.neighbors()``, ``.get_edge_data()``)
        -- NOTE(review): confirm the concrete graph type on the retriever.
        """
        graph = getattr(self.retriever, 'graph_data', None)
        return graph if graph else None

    @staticmethod
    def _node_type(node_data: Dict[str, Any]) -> str:
        """Lower-cased ``type`` attribute of a node; '' when absent or None."""
        return str(node_data.get('type') or '').lower()

    @staticmethod
    def _node_label(node_id: Any, node_data: Dict[str, Any]) -> Any:
        """Display label of a node: ``label``, else ``name``, else the node id."""
        return node_data.get('label', node_data.get('name', node_id))

    @staticmethod
    def _score_of(node_info: Dict[str, Any]) -> float:
        """Numeric score of a PageRank entry; 0.0 when missing or non-numeric."""
        try:
            return float(node_info.get('score', 0) or 0)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _format_ratio(value: Any) -> str:
        """Format a number with two decimals; fall back to ``str`` for non-numbers."""
        try:
            return f"{float(value):.2f}"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _increment_llm_calls(state: "QueryState") -> None:
        """Bump ``state['debug_info']['llm_calls']`` if a debug dict is present."""
        debug_info = state.get('debug_info')
        if isinstance(debug_info, dict):
            debug_info['llm_calls'] = debug_info.get('llm_calls', 0) + 1

    @staticmethod
    def _ensure_results_dict(state: "QueryState") -> Dict[str, Any]:
        """Return ``state['concept_exploration_results']``, creating it if needed."""
        results = state.get('concept_exploration_results')
        if not isinstance(results, dict):
            results = {}
            state['concept_exploration_results'] = results
        return results

    @staticmethod
    def _get_sub_query(state: "QueryState") -> Optional[str]:
        """Return the round-1 sub-query when running round 2, otherwise None."""
        if state.get('exploration_round') != 2:
            return None
        results = state.get('concept_exploration_results') or {}
        round_1 = results.get('round_1_results') or {}
        sub_query_info = round_1.get('sub_query_info') or {}
        return sub_query_info.get('sub_query') or None

    def _extract_response_text(self, response):
        """Return the text payload of an LLM response.

        Tries, in order: objects exposing ``generations`` (LangChain
        ``LLMResult``-like, ``generations[0][0].text``), objects with
        ``.content``, dicts with a ``'response'`` key, and finally
        ``str(response)``.
        """
        generations = getattr(response, 'generations', None)
        # Also guard the inner list: ``generations=[[]]`` used to raise IndexError.
        if generations and generations[0]:
            return generations[0][0].text
        if hasattr(response, 'content'):
            return response.content
        if isinstance(response, dict) and 'response' in response:
            return response['response']
        return str(response)

    # ------------------------------------------------------------------
    # Graph access
    # ------------------------------------------------------------------
    def _get_top_node_candidates(self, state: "QueryState") -> List[Dict[str, Any]]:
        """Return the TOP-3 PageRank entries whose ``type`` contains "node".

        Reads ``state['pagerank_summary']['all_nodes_with_labels']``; each entry
        is expected to carry ``node_id``, ``score``, ``label`` and ``type``.

        Args:
            state: Current graph state.

        Returns:
            Up to three entries sorted by descending score (missing or
            non-numeric scores count as 0); an empty list when no summary or
            labelled nodes are available.
        """
        print("[SEARCH] 从PageRank结果中提取TOP-3 Node节点")
        pagerank_summary = state.get('pagerank_summary', {})
        if not pagerank_summary:
            print("[ERROR] 没有找到PageRank汇总信息")
            return []
        all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels', [])
        if not all_nodes_with_labels:
            print("[ERROR] 没有找到带标签的节点信息")
            return []
        # Substring match keeps the original lenient type filter; a None type
        # no longer crashes the whole extraction.
        node_candidates = [
            node_info for node_info in all_nodes_with_labels
            if 'node' in self._node_type(node_info)
        ]
        node_candidates.sort(key=self._score_of, reverse=True)
        top_3_nodes = node_candidates[:3]
        print(f"[OK] 找到 {len(node_candidates)} 个Node节点选择TOP-3:")
        for i, node in enumerate(top_3_nodes):
            node_id = node.get('node_id', 'unknown')
            score = self._score_of(node)
            label = node.get('label', 'unknown')
            print(f" {i+1}. {node_id} ('{label}') - Score: {score:.6f}")
        return top_3_nodes

    def _get_connected_concept_nodes(self, node_id: str) -> List[Tuple[Any, Any]]:
        """Return ``(concept_id, concept_label)`` pairs adjacent to *node_id*.

        Neighbours whose ``type`` contains "concept" (case-insensitive) are
        kept; concepts sharing a label are reported once (first one wins).
        Any error is logged and yields an empty list.
        """
        try:
            graph = self._get_graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
                return []
            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []
            pairs: List[Tuple[Any, Any]] = []
            seen_labels = set()
            for neighbor_id in graph.neighbors(node_id):
                neighbor_data = graph.nodes.get(neighbor_id, {})
                if 'concept' not in self._node_type(neighbor_data):
                    continue
                neighbor_label = self._node_label(neighbor_id, neighbor_data)
                if neighbor_label in seen_labels:
                    continue
                seen_labels.add(neighbor_label)
                pairs.append((neighbor_id, neighbor_label))
            print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(pairs)}")
            return pairs
        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
            return []

    def _get_connected_concepts(self, node_id: str) -> List[str]:
        """Return the labels of the Concept nodes adjacent to *node_id*.

        Args:
            node_id: Graph node id.

        Returns:
            Concept labels (deduplicated, graph neighbour order); empty on error.
        """
        return [label for _, label in self._get_connected_concept_nodes(node_id)]

    def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """Return ``(source, relation, target)`` triplets for the neighbours of *node_id*.

        Args:
            node_id: Current node id.
            excluded_nodes: Node ids to skip (e.g. already visited nodes).

        Returns:
            One dict per remaining neighbour with ``source``/``target`` labels,
            ``relation`` (edge ``relation``, else ``label``, else
            ``'related_to'``) and ``source_id``/``target_id``. Empty on error.
        """
        try:
            graph = self._get_graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
                return []
            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []
            excluded = set(excluded_nodes or ())
            # The source label is identical for every triplet; compute it once.
            source_label = self._node_label(node_id, graph.nodes.get(node_id, {}))
            triplets = []
            for neighbor_id in graph.neighbors(node_id):
                if neighbor_id in excluded:
                    continue
                edge_data = graph.get_edge_data(node_id, neighbor_id, {}) or {}
                relation = edge_data.get('relation', edge_data.get('label', 'related_to'))
                triplets.append({
                    'source': source_label,
                    'relation': relation,
                    'target': self._node_label(neighbor_id, graph.nodes.get(neighbor_id, {})),
                    'source_id': node_id,
                    'target_id': neighbor_id
                })
            print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
            return triplets
        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
            return []

    @staticmethod
    def _resolve_next_node_id(selected_node: Any, neighbor_triplets: List[Dict[str, Any]]) -> Optional[Any]:
        """Map the LLM's chosen node back to one of the offered neighbours.

        Only targets present in *neighbor_triplets* are accepted, so every hop
        follows an existing edge. Matches by target id or by target label
        (ignoring surrounding whitespace). Returns None when nothing matches.
        """
        wanted = str(selected_node).strip()
        for triplet in neighbor_triplets:
            target_id = triplet.get('target_id')
            if target_id == selected_node or str(triplet.get('target', '')).strip() == wanted:
                return target_id
        return None

    # ------------------------------------------------------------------
    # Branch exploration
    # ------------------------------------------------------------------
    def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
        """Explore one branch starting from a PageRank "Node" entry.

        Step 0 lets the LLM pick one adjacent concept. Then up to
        ``self.max_exploration_steps`` hops follow LLM-chosen edges to
        unvisited neighbours.

        Args:
            node_info: Start entry (``node_id``, ``label``, ...).
            user_query: Query guiding the LLM's choices.
            insufficiency_reason: Why the previous retrieval was insufficient.

        Returns:
            Dict with ``start_node``, ``exploration_path`` (id-based steps),
            ``visited_nodes``, ``knowledge_gained``, ``success`` and ``error``.
            ``success`` is True once a concept node was entered, even if no hop
            followed.
        """
        node_id = node_info.get('node_id', '')
        node_label = node_info.get('label', '')
        print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")
        exploration_result = {
            'start_node': node_info,
            'exploration_path': [],
            'visited_nodes': [node_id],
            'knowledge_gained': [],
            'success': False,
            'error': None
        }
        visited = exploration_result['visited_nodes']
        try:
            # Step 0: choose a concept adjacent to the start node.
            concept_nodes = self._get_connected_concept_nodes(node_id)
            if not concept_nodes:
                print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
                exploration_result['error'] = "没有连接的Concept节点"
                return exploration_result
            concept_ids_by_label = {label: cid for cid, label in concept_nodes}
            connected_concepts = [label for _, label in concept_nodes]
            concept_selection = self._select_best_concept(node_label, connected_concepts, user_query, insufficiency_reason)
            if not concept_selection or not concept_selection.selected_concept:
                print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
                exploration_result['error'] = "未能选择有效概念"
                return exploration_result
            selected_concept = concept_selection.selected_concept
            print(f"[POINT] 选择探索概念: {selected_concept}")
            exploration_result['knowledge_gained'].append({
                'step': 0,
                'action': 'concept_selection',
                'selected_concept': selected_concept,
                'reason': concept_selection.exploration_reason,
                'expected_knowledge': concept_selection.expected_knowledge
            })
            # Prefer the concept actually adjacent to the start node; only fall
            # back to a graph-wide lookup if the LLM named something else.
            concept_node_id = concept_ids_by_label.get(selected_concept)
            if concept_node_id is None:
                concept_node_id = self._find_concept_node_id(selected_concept)
            if concept_node_id is None:
                print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
                exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
                return exploration_result

            # Multi-hop walk from the concept node.
            current_node_id = concept_node_id
            visited.append(current_node_id)
            for step in range(1, self.max_exploration_steps + 1):
                print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")
                neighbor_triplets = self._get_neighbor_triplets(current_node_id, visited)
                if not neighbor_triplets:
                    print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
                    break
                exploration_path_str = format_exploration_path(exploration_result['exploration_path'])
                # Present the current node by the same label that appears as
                # ``source`` in the triplets the LLM is choosing from.
                current_label = neighbor_triplets[0]['source']
                continue_result = self._continue_exploration(current_label, neighbor_triplets, exploration_path_str, user_query)
                if not continue_result or not continue_result.selected_node:
                    print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
                    break
                next_node_id = self._resolve_next_node_id(continue_result.selected_node, neighbor_triplets)
                if next_node_id is None or next_node_id in visited:
                    print(f"[WARNING] 选择的节点无效或已访问: {continue_result.selected_node}")
                    break
                exploration_result['exploration_path'].append({
                    'source': current_node_id,
                    'relation': continue_result.selected_relation,
                    'target': next_node_id,
                    'reason': continue_result.exploration_reason
                })
                exploration_result['knowledge_gained'].append({
                    'step': step,
                    'action': 'exploration_step',
                    'selected_node': continue_result.selected_node,
                    'relation': continue_result.selected_relation,
                    'reason': continue_result.exploration_reason,
                    'expected_knowledge': continue_result.expected_knowledge
                })
                current_node_id = next_node_id
                visited.append(current_node_id)
                print(f"[OK] 步骤 {step} 完成: {continue_result.selected_node}")
            exploration_result['success'] = True
            print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])}")
        except Exception as e:
            print(f"[ERROR] 分支探索失败: {e}")
            exploration_result['error'] = str(e)
        return exploration_result

    def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
        """Ask the LLM which connected concept to explore first.

        Returns:
            The ``ConceptExplorationParser`` result, or None if prompt
            formatting, the LLM call or parsing fails.
        """
        try:
            concepts_str = "\n".join(f"- {concept}" for concept in connected_concepts)
            prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
                node_name=node_name,
                connected_concepts=concepts_str,
                user_query=user_query,
                insufficiency_reason=insufficiency_reason
            )
            response = self.llm.invoke(prompt)
            return self.concept_exploration_parser.parse(self._extract_response_text(response))
        except Exception as e:
            print(f"[ERROR] 概念选择失败: {e}")
            return None

    def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
        """Ask the LLM which neighbour triplet to follow next.

        Returns:
            The ``ConceptContinueParser`` result, or None if prompt formatting,
            the LLM call or parsing fails.
        """
        try:
            prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
                current_node=current_node,
                neighbor_triplets=format_triplets(neighbor_triplets),
                exploration_path=exploration_path,
                user_query=user_query
            )
            response = self.llm.invoke(prompt)
            return self.concept_continue_parser.parse(self._extract_response_text(response))
        except Exception as e:
            print(f"[ERROR] 探索继续失败: {e}")
            return None

    def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
        """Return the id of the first concept-typed node whose label or id equals *concept_label*."""
        try:
            graph = self._get_graph()
            if graph is None:
                return None
            for node_id, node_data in graph.nodes(data=True):
                if 'concept' not in self._node_type(node_data):
                    continue
                node_label = node_data.get('label', node_data.get('name', ''))
                if node_label == concept_label or node_id == concept_label:
                    return node_id
            return None
        except Exception as e:
            print(f"[ERROR] 查找概念节点ID失败: {e}")
            return None

    def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
        """Return the id of the first node (any type) whose label or id equals *node_label*."""
        try:
            graph = self._get_graph()
            if graph is None:
                return None
            for node_id, node_data in graph.nodes(data=True):
                label = node_data.get('label', node_data.get('name', ''))
                if label == node_label or node_id == node_label:
                    return node_id
            return None
        except Exception as e:
            print(f"[ERROR] 查找节点ID失败: {e}")
            return None

    # ------------------------------------------------------------------
    # LangGraph nodes
    # ------------------------------------------------------------------
    def concept_exploration_node(self, state: "QueryState") -> "QueryState":
        """Run one concept-exploration round.

        1. Increments ``state['exploration_round']``.
        2. Extracts the TOP-3 Node entries and explores them in parallel.
        3. Stores results under ``concept_exploration_results['round_<n>_results']``.

        In round 2 the round-1 sub-query (if any) replaces the original query.
        When no start nodes are found, an ``'error'`` key is recorded without
        discarding earlier rounds.

        Args:
            state: Current graph state (mutated in place).

        Returns:
            The same state object.
        """
        current_round = (state.get('exploration_round') or 0) + 1
        print(f"[?] 开始概念探索 (轮次 {current_round})")
        state['exploration_round'] = current_round
        results = self._ensure_results_dict(state)

        top_nodes = self._get_top_node_candidates(state)
        if not top_nodes:
            print("[ERROR] 未找到合适的Node节点进行探索")
            # Keep round_1_results (and its sub_query_info) intact.
            results['error'] = '未找到合适的Node节点'
            return state

        insufficiency_reason = (state.get('sufficiency_check') or {}).get('reason', '信息不够充分')
        user_query = state['original_query']
        sub_query = self._get_sub_query(state)
        if sub_query:
            user_query = sub_query
            print(f"[TARGET] 第二轮探索使用子查询: {user_query}")

        # Collect in submission (score) order so branch numbering downstream is
        # deterministic; max(1, ...) keeps the pool valid for a 0 setting.
        max_workers = max(1, min(len(top_nodes), self.max_parallel_explorations))
        exploration_results = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self._explore_single_branch, node_info, user_query, insufficiency_reason)
                for node_info in top_nodes
            ]
            for node_info, future in zip(top_nodes, futures):
                try:
                    exploration_results.append(future.result())
                    print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
                except Exception as e:
                    print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
                    exploration_results.append({
                        'start_node': node_info,
                        'success': False,
                        'error': str(e)
                    })

        successful_branches = sum(1 for r in exploration_results if r.get('success', False))
        round_key = f'round_{current_round}_results'
        results[round_key] = {
            'exploration_results': exploration_results,
            'total_branches': len(top_nodes),
            'successful_branches': successful_branches,
            'query_used': user_query
        }
        print(f"[SUCCESS] 第{current_round}轮概念探索完成")
        print(f" 总分支: {len(top_nodes)}, 成功: {successful_branches}")
        return state

    def concept_exploration_sufficiency_check_node(self, state: "QueryState") -> "QueryState":
        """Ask the LLM whether passages plus explored knowledge answer the query.

        Sets ``state['is_sufficient']`` and, when the current round's results
        exist, stores the parsed check under their ``'sufficiency_check'`` key.
        If prompt formatting, the LLM call or parsing fails,
        ``is_sufficient`` is set to False.

        Args:
            state: Current graph state (mutated in place).

        Returns:
            The same state object.
        """
        print(f"[THINK] 执行概念探索充分性检查")
        all_documents = state.get('all_documents')
        if all_documents:
            formatted_passages = format_mixed_passages(all_documents)
        else:
            formatted_passages = format_passages(state.get('all_passages', []))

        exploration_results = state.get('concept_exploration_results') or {}
        exploration_knowledge = self._format_exploration_knowledge(exploration_results)
        original_insufficiency = (state.get('sufficiency_check') or {}).get('reason', '信息不够充分')
        current_query = self._get_sub_query(state) or state['original_query']

        try:
            prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
                user_query=current_query,
                all_passages=formatted_passages,
                exploration_knowledge=exploration_knowledge,
                insufficiency_reason=original_insufficiency
            )
            response = self.llm.invoke(prompt)
            self._increment_llm_calls(state)
            sufficiency_result = self.concept_sufficiency_parser.parse(self._extract_response_text(response))
        except Exception as e:
            print(f"[ERROR] 概念探索充分性检查失败: {e}")
            state['is_sufficient'] = False
            return state

        # State updates and logging live outside the try so that a logging
        # problem can no longer overwrite a successful LLM verdict.
        state['is_sufficient'] = sufficiency_result.is_sufficient
        round_key = f'round_{state.get("exploration_round")}_results'
        if isinstance(exploration_results.get(round_key), dict):
            exploration_results[round_key]['sufficiency_check'] = {
                'is_sufficient': sufficiency_result.is_sufficient,
                'confidence': sufficiency_result.confidence,
                'reason': sufficiency_result.reason,
                'missing_aspects': sufficiency_result.missing_aspects,
                'coverage_score': sufficiency_result.coverage_score
            }
        print(f"[INFO] 概念探索充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'}")
        print(f" 置信度: {self._format_ratio(sufficiency_result.confidence)}, 覆盖度: {self._format_ratio(sufficiency_result.coverage_score)}")
        if not sufficiency_result.is_sufficient:
            print(f" 缺失方面: {sufficiency_result.missing_aspects}")
        return state

    def concept_sub_query_generation_node(self, state: "QueryState") -> "QueryState":
        """Generate a focused sub-query for the second exploration round.

        Uses round-1 missing aspects and a summary of round-1 results. The
        outcome is stored in ``round_1_results['sub_query_info']``; on failure
        a default sub-query (original query + " 补充详细信息") is stored
        instead. ``round_1_results`` is created if it does not exist.

        Args:
            state: Current graph state (mutated in place).

        Returns:
            The same state object.
        """
        print(f"[TARGET] 生成概念探索子查询")
        results = self._ensure_results_dict(state)
        round_1_results = results.get('round_1_results')
        if not isinstance(round_1_results, dict):
            # Previously both the success and the fallback path raised KeyError here.
            round_1_results = {}
            results['round_1_results'] = round_1_results
        sufficiency_info = round_1_results.get('sufficiency_check') or {}
        missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])
        exploration_summary = self._summarize_exploration_results(round_1_results)

        try:
            prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
                user_query=state['original_query'],
                missing_aspects=str(missing_aspects),
                exploration_results=exploration_summary
            )
            response = self.llm.invoke(prompt)
            self._increment_llm_calls(state)
            sub_query_result = self.concept_sub_query_parser.parse(self._extract_response_text(response))
        except Exception as e:
            print(f"[ERROR] 概念探索子查询生成失败: {e}")
            default_sub_query = state['original_query'] + " 补充详细信息"
            round_1_results['sub_query_info'] = {
                'sub_query': default_sub_query,
                'focus_aspects': ['补充信息'],
                'expected_improvements': '获取更多相关信息',
                'confidence': 0.3
            }
            return state

        round_1_results['sub_query_info'] = {
            'sub_query': sub_query_result.sub_query,
            'focus_aspects': sub_query_result.focus_aspects,
            'expected_improvements': sub_query_result.expected_improvements,
            'confidence': sub_query_result.confidence
        }
        print(f"[OK] 生成概念探索子查询: {sub_query_result.sub_query}")
        print(f" 关注方面: {sub_query_result.focus_aspects}")
        return state

    # ------------------------------------------------------------------
    # Result formatting
    # ------------------------------------------------------------------
    def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
        """Render every ``round_*`` entry's successful branches as prompt text.

        Branch numbers count all branches of a round (failed ones included).
        Returns "无探索知识" for empty input and "无有效探索知识" when nothing
        was rendered.
        """
        if not exploration_results:
            return "无探索知识"
        knowledge_parts: List[str] = []
        for round_key, round_data in exploration_results.items():
            if not round_key.startswith('round_') or not isinstance(round_data, dict):
                continue
            round_num = round_key.split('_')[1]
            knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")
            for i, branch in enumerate(round_data.get('exploration_results', [])):
                if not branch.get('success', False):
                    continue
                start_node = branch.get('start_node', {})
                knowledge_parts.append(f"\n分支{i+1} (起点: {start_node.get('label', 'unknown')}):")
                for knowledge in branch.get('knowledge_gained', []):
                    action = knowledge.get('action', '')
                    if action == 'concept_selection':
                        concept = knowledge.get('selected_concept', '')
                        reason = knowledge.get('reason', '')
                        knowledge_parts.append(f"- 选择探索概念: {concept}")
                        knowledge_parts.append(f" 原因: {reason}")
                    elif action == 'exploration_step':
                        node = knowledge.get('selected_node', '')
                        relation = knowledge.get('relation', '')
                        reason = knowledge.get('reason', '')
                        knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
                        knowledge_parts.append(f" 原因: {reason}")
        return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"

    def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
        """Summarise one round: success ratio plus up to three explored concepts."""
        if not round_results:
            return "无探索结果"
        exploration_list = round_results.get('exploration_results', [])
        successful_count = sum(1 for r in exploration_list if r.get('success', False))
        total_count = len(exploration_list)
        summary_parts = [
            f"第一轮探索: {successful_count}/{total_count} 个分支成功",
        ]
        key_findings = []
        for branch in exploration_list:
            if not branch.get('success', False):
                continue
            for knowledge in branch.get('knowledge_gained', []):
                if knowledge.get('action') == 'concept_selection':
                    concept = knowledge.get('selected_concept', '')
                    if concept:
                        key_findings.append(f"探索了概念: {concept}")
        if key_findings:
            summary_parts.append("主要发现:")
            summary_parts.extend([f"- {finding}" for finding in key_findings[:3]])  # at most 3
        return "\n".join(summary_parts)