first commit
This commit is contained in:
699
AIEC-RAG/retriver/langgraph/concept_exploration_nodes.py
Normal file
699
AIEC-RAG/retriver/langgraph/concept_exploration_nodes.py
Normal file
@ -0,0 +1,699 @@
|
||||
"""
Concept exploration node implementations.

Performs knowledge-graph concept exploration after PageRank collection.
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from retriver.langgraph.graph_state import QueryState
|
||||
from retriver.langgraph.langchain_components import (
|
||||
OneAPILLM,
|
||||
ConceptExplorationParser,
|
||||
ConceptContinueParser,
|
||||
ConceptSufficiencyParser,
|
||||
ConceptSubQueryParser,
|
||||
CONCEPT_EXPLORATION_INIT_PROMPT,
|
||||
CONCEPT_EXPLORATION_CONTINUE_PROMPT,
|
||||
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
|
||||
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
|
||||
format_passages,
|
||||
format_mixed_passages,
|
||||
format_triplets,
|
||||
format_exploration_path
|
||||
)
|
||||
|
||||
|
||||
class ConceptExplorationNodes:
    """Node implementations for knowledge-graph concept exploration.

    The public ``*_node`` methods take the query state, update it in place and
    return it. The private helpers read the retriever's ``graph_data`` object,
    which is assumed to expose a networkx-style API (``nodes``, ``neighbors``,
    ``get_edge_data``) -- TODO confirm against the retriever implementation.
    """

    # Number of highest-scoring "Node"-typed candidates used as branch starts.
    _TOP_NODE_LIMIT = 3

    def __init__(
        self,
        retriever,  # LangChainHippoRAGRetriever
        llm: "OneAPILLM",
        max_exploration_steps: int = 3,
        max_parallel_explorations: int = 3
    ):
        """
        Initialise the concept exploration node handler.

        Args:
            retriever: HippoRAG retriever used to access the knowledge graph.
            llm: OneAPI LLM used for every selection / judgement prompt.
            max_exploration_steps: Maximum number of hops per branch.
            max_parallel_explorations: Maximum number of branches explored
                concurrently.
        """
        self.retriever = retriever
        self.llm = llm
        self.max_exploration_steps = max_exploration_steps
        self.max_parallel_explorations = max_parallel_explorations

        # Output parsers, one per prompt type.
        self.concept_exploration_parser = ConceptExplorationParser()
        self.concept_continue_parser = ConceptContinueParser()
        self.concept_sufficiency_parser = ConceptSufficiencyParser()
        self.concept_sub_query_parser = ConceptSubQueryParser()

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _extract_response_text(self, response):
        """Extract the text payload from the supported LLM response shapes."""
        if hasattr(response, 'generations') and response.generations:
            return response.generations[0][0].text
        elif hasattr(response, 'content'):
            return response.content
        elif isinstance(response, dict) and 'response' in response:
            return response['response']
        else:
            return str(response)

    def _graph(self):
        """Return the retriever's graph, or None when it is missing or empty."""
        graph = getattr(self.retriever, 'graph_data', None)
        return graph if graph else None

    @staticmethod
    def _node_type(node_data) -> str:
        """Return the lower-cased ``type`` value; missing or None becomes ''."""
        return str(node_data.get('type') or '').lower()

    @staticmethod
    def _node_label(node_data, fallback):
        """Return ``label``, else ``name``, else ``fallback``."""
        return node_data.get('label', node_data.get('name', fallback))

    @staticmethod
    def _exploration_store(state) -> Dict[str, Any]:
        """Return ``state['concept_exploration_results']``, creating it if absent."""
        store = state.get('concept_exploration_results')
        if not isinstance(store, dict):
            store = {}
            state['concept_exploration_results'] = store
        return store

    @staticmethod
    def _record_llm_call(state) -> None:
        """Increment the debug LLM-call counter without failing when it is absent.

        Bookkeeping must never change the outcome of a node, so a missing
        ``debug_info`` dict is simply ignored.
        """
        debug_info = state.get('debug_info')
        if isinstance(debug_info, dict):
            debug_info['llm_calls'] = debug_info.get('llm_calls', 0) + 1

    @staticmethod
    def _round_two_sub_query(state) -> Optional[str]:
        """Return the stored round-1 sub-query when the state is in round 2."""
        if state.get('exploration_round') != 2:
            return None
        results = state.get('concept_exploration_results') or {}
        round_1 = results.get('round_1_results') or {}
        sub_query_info = round_1.get('sub_query_info') or {}
        return sub_query_info.get('sub_query') or None

    # ------------------------------------------------------------------
    # Graph access
    # ------------------------------------------------------------------

    def _get_top_node_candidates(self, state: "QueryState") -> List[Dict[str, Any]]:
        """
        Pick the top-3 entries of type "Node" from the PageRank summary.

        Args:
            state: Current state; reads ``pagerank_summary.all_nodes_with_labels``.

        Returns:
            Up to three entries, highest score first. Entries are the original
            dicts (``node_id``, ``score``, ``label`` and ``type`` are read here).
        """
        print("[SEARCH] 从PageRank结果中提取TOP-3 Node节点")

        pagerank_summary = state.get('pagerank_summary', {})
        if not pagerank_summary:
            print("[ERROR] 没有找到PageRank汇总信息")
            return []

        all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels', [])
        if not all_nodes_with_labels:
            print("[ERROR] 没有找到带标签的节点信息")
            return []

        # Keep entries whose type contains "node" (case-insensitive).
        node_candidates = [
            node_info for node_info in all_nodes_with_labels
            if 'node' in self._node_type(node_info)
        ]

        # A None score is ranked as 0 instead of breaking the sort.
        node_candidates.sort(key=lambda x: x.get('score') or 0, reverse=True)
        top_nodes = node_candidates[:self._TOP_NODE_LIMIT]

        print(f"[OK] 找到 {len(node_candidates)} 个Node节点,选择TOP-3:")
        for i, node in enumerate(top_nodes):
            node_id = node.get('node_id', 'unknown')
            score = node.get('score') or 0
            label = node.get('label', 'unknown')
            print(f" {i+1}. {node_id} ('{label}') - Score: {score:.6f}")

        return top_nodes

    def _get_connected_concepts(self, node_id: str) -> List[str]:
        """
        Collect the labels of all Concept-typed neighbours of a node.

        Args:
            node_id: ID of the node whose neighbours are inspected.

        Returns:
            Labels of neighbours whose type contains "concept"; an empty list
            when the graph or the node is unavailable, or on any error.
        """
        try:
            graph = self._graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
                return []

            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []

            connected_concepts = []
            for neighbor_id in graph.neighbors(node_id):
                neighbor_data = graph.nodes.get(neighbor_id, {})
                if 'concept' in self._node_type(neighbor_data):
                    connected_concepts.append(self._node_label(neighbor_data, neighbor_id))

            print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(connected_concepts)}")
            return connected_concepts

        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
            return []

    def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """
        Build (source, relation, target) triplets for a node's neighbours.

        Args:
            node_id: ID of the current node.
            excluded_nodes: Node IDs to skip (already visited).

        Returns:
            One dict per remaining neighbour with ``source``, ``relation``,
            ``target``, ``source_id`` and ``target_id``; an empty list when the
            graph or node is unavailable, or on any error.
        """
        try:
            graph = self._graph()
            if graph is None:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
                return []

            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []

            excluded = set(excluded_nodes or ())
            # The source label does not depend on the neighbour; compute it once.
            source_label = self._node_label(graph.nodes.get(node_id, {}), node_id)

            triplets = []
            for neighbor_id in graph.neighbors(node_id):
                if neighbor_id in excluded:
                    continue

                edge_data = graph.get_edge_data(node_id, neighbor_id, {}) or {}
                relation = edge_data.get('relation', edge_data.get('label', 'related_to'))
                target_label = self._node_label(graph.nodes.get(neighbor_id, {}), neighbor_id)

                triplets.append({
                    'source': source_label,
                    'relation': relation,
                    'target': target_label,
                    'source_id': node_id,
                    'target_id': neighbor_id
                })

            print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
            return triplets

        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
            return []

    def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
        """Return the first Concept-typed node whose label or ID equals ``concept_label``."""
        try:
            graph = self._graph()
            if graph is None:
                return None

            for node_id, node_data in graph.nodes(data=True):
                node_label = self._node_label(node_data, '')
                if (node_label == concept_label or node_id == concept_label) \
                        and 'concept' in self._node_type(node_data):
                    return node_id

            return None
        except Exception as e:
            print(f"[ERROR] 查找概念节点ID失败: {e}")
            return None

    def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
        """Return the first node whose label or ID equals ``node_label``."""
        try:
            graph = self._graph()
            if graph is None:
                return None

            for node_id, node_data in graph.nodes(data=True):
                label = self._node_label(node_data, '')
                if label == node_label or node_id == node_label:
                    return node_id

            return None
        except Exception as e:
            print(f"[ERROR] 查找节点ID失败: {e}")
            return None

    def _resolve_connected_concept_id(self, node_id: str, concept_label: str) -> Optional[str]:
        """Resolve a chosen concept label to a node ID, preferring direct neighbours.

        Labels need not be unique in the graph, so a graph-wide scan could
        return a concept that is not connected to ``node_id``. The neighbours
        are checked first; the graph-wide lookup is only a fallback.
        """
        try:
            graph = self._graph()
            if graph is not None and node_id in graph.nodes:
                for neighbor_id in graph.neighbors(node_id):
                    neighbor_data = graph.nodes.get(neighbor_id, {})
                    if 'concept' not in self._node_type(neighbor_data):
                        continue
                    if concept_label in (self._node_label(neighbor_data, neighbor_id), neighbor_id):
                        return neighbor_id
        except Exception as e:
            print(f"[ERROR] 查找概念节点ID失败: {e}")
        return self._find_concept_node_id(concept_label)

    def _resolve_next_node_id(self, selected_node: str, neighbor_triplets: List[Dict[str, str]]) -> Optional[str]:
        """Resolve the chosen next node among the offered triplets first.

        Falls back to the graph-wide label lookup when the selection matches
        none of the triplets' ``target`` / ``target_id`` values.
        """
        for triplet in neighbor_triplets:
            if selected_node in (triplet.get('target'), triplet.get('target_id')):
                return triplet.get('target_id')
        return self._find_node_id_by_label(selected_node)

    # ------------------------------------------------------------------
    # Branch exploration
    # ------------------------------------------------------------------

    def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
        """
        Explore one branch starting from a Node-typed entry.

        Args:
            node_info: Start node entry (``node_id`` and ``label`` are read).
            user_query: Query guiding the exploration.
            insufficiency_reason: Why the earlier retrieval was judged insufficient.

        Returns:
            Dict with ``start_node``, ``exploration_path``, ``visited_nodes``,
            ``knowledge_gained``, ``success`` and ``error``. Exceptions are
            captured in ``error`` rather than raised.
        """
        node_id = node_info.get('node_id', '')
        node_label = node_info.get('label', '')

        print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")

        exploration_result = {
            'start_node': node_info,
            'exploration_path': [],
            'visited_nodes': [node_id],
            'knowledge_gained': [],
            'success': False,
            'error': None
        }

        try:
            # Step 1: gather connected concepts and let the LLM choose one.
            connected_concepts = self._get_connected_concepts(node_id)
            if not connected_concepts:
                print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
                exploration_result['error'] = "没有连接的Concept节点"
                return exploration_result

            concept_selection = self._select_best_concept(node_label, connected_concepts, user_query, insufficiency_reason)
            if not concept_selection or not concept_selection.selected_concept:
                print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
                exploration_result['error'] = "未能选择有效概念"
                return exploration_result

            selected_concept = concept_selection.selected_concept
            print(f"[POINT] 选择探索概念: {selected_concept}")

            # Record the initial decision.
            exploration_result['knowledge_gained'].append({
                'step': 0,
                'action': 'concept_selection',
                'selected_concept': selected_concept,
                'reason': concept_selection.exploration_reason,
                'expected_knowledge': concept_selection.expected_knowledge
            })

            # Map the chosen label back to a node, preferring actual neighbours.
            concept_node_id = self._resolve_connected_concept_id(node_id, selected_concept)
            if not concept_node_id:
                print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
                exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
                return exploration_result

            # Step 2: multi-hop exploration from the concept node.
            current_node_id = concept_node_id
            exploration_result['visited_nodes'].append(current_node_id)

            for step in range(1, self.max_exploration_steps + 1):
                print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")

                neighbor_triplets = self._get_neighbor_triplets(current_node_id, exploration_result['visited_nodes'])
                if not neighbor_triplets:
                    print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
                    break

                exploration_path_str = format_exploration_path(exploration_result['exploration_path'])
                continue_result = self._continue_exploration(current_node_id, neighbor_triplets, exploration_path_str, user_query)

                if not continue_result or not continue_result.selected_node:
                    print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
                    break

                next_node_id = self._resolve_next_node_id(continue_result.selected_node, neighbor_triplets)
                if not next_node_id or next_node_id in exploration_result['visited_nodes']:
                    print(f"[WARNING] 选择的节点无效或已访问: {continue_result.selected_node}")
                    break

                # Record the hop.
                exploration_result['exploration_path'].append({
                    'source': current_node_id,
                    'relation': continue_result.selected_relation,
                    'target': next_node_id,
                    'reason': continue_result.exploration_reason
                })

                # Record what the LLM expected to learn from it.
                exploration_result['knowledge_gained'].append({
                    'step': step,
                    'action': 'exploration_step',
                    'selected_node': continue_result.selected_node,
                    'relation': continue_result.selected_relation,
                    'reason': continue_result.exploration_reason,
                    'expected_knowledge': continue_result.expected_knowledge
                })

                current_node_id = next_node_id
                exploration_result['visited_nodes'].append(current_node_id)

                print(f"[OK] 步骤 {step} 完成: {continue_result.selected_node}")

            exploration_result['success'] = True
            print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])} 步")

        except Exception as e:
            print(f"[ERROR] 分支探索失败: {e}")
            exploration_result['error'] = str(e)

        return exploration_result

    def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
        """Ask the LLM which connected concept to explore; returns None on failure."""
        concepts_str = "\n".join(f"- {concept}" for concept in connected_concepts)

        prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
            node_name=node_name,
            connected_concepts=concepts_str,
            user_query=user_query,
            insufficiency_reason=insufficiency_reason
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_exploration_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 概念选择失败: {e}")
            return None

    def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
        """Ask the LLM which neighbour to visit next; returns None on failure."""
        triplets_str = format_triplets(neighbor_triplets)

        prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
            current_node=current_node,
            neighbor_triplets=triplets_str,
            exploration_path=exploration_path,
            user_query=user_query
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_continue_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 探索继续失败: {e}")
            return None

    # ------------------------------------------------------------------
    # Graph nodes
    # ------------------------------------------------------------------

    def concept_exploration_node(self, state: "QueryState") -> "QueryState":
        """
        Concept exploration node.

        1. Extract the top-3 Node-typed candidates.
        2. Explore the branches in parallel.
        3. Store the merged results under ``round_<n>_results``.

        Args:
            state: Current state.

        Returns:
            The updated state. Results of earlier rounds are preserved even
            when this round finds no candidate nodes.
        """
        current_round = (state.get('exploration_round') or 0) + 1
        print(f"[?] 开始概念探索 (轮次 {current_round})")

        state['exploration_round'] = current_round
        store = self._exploration_store(state)

        top_nodes = self._get_top_node_candidates(state)
        if not top_nodes:
            print("[ERROR] 未找到合适的Node节点进行探索")
            # Add the error marker without discarding previous rounds.
            store['error'] = '未找到合适的Node节点'
            return state

        insufficiency_reason = (state.get('sufficiency_check') or {}).get('reason', '信息不够充分')
        user_query = state['original_query']

        # The second round is driven by the generated sub-query, if any.
        sub_query = self._round_two_sub_query(state)
        if sub_query:
            user_query = sub_query
            print(f"[TARGET] 第二轮探索使用子查询: {user_query}")

        exploration_results = []
        # Guard against a configured limit of 0, which the executor rejects.
        worker_count = max(1, min(len(top_nodes), self.max_parallel_explorations))

        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            submitted = [
                (node_info, executor.submit(self._explore_single_branch, node_info, user_query, insufficiency_reason))
                for node_info in top_nodes
            ]

            # Collect in submission order so branch numbering is deterministic.
            for node_info, future in submitted:
                try:
                    exploration_results.append(future.result())
                    print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
                except Exception as e:
                    print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
                    exploration_results.append({
                        'start_node': node_info,
                        'success': False,
                        'error': str(e)
                    })

        successful_branches = len([r for r in exploration_results if r.get('success', False)])

        round_key = f'round_{current_round}_results'
        store[round_key] = {
            'exploration_results': exploration_results,
            'total_branches': len(top_nodes),
            'successful_branches': successful_branches,
            'query_used': user_query
        }

        print(f"[SUCCESS] 第{current_round}轮概念探索完成")
        print(f" 总分支: {len(top_nodes)}, 成功: {successful_branches}")

        return state

    def concept_exploration_sufficiency_check_node(self, state: "QueryState") -> "QueryState":
        """
        Sufficiency check after concept exploration.

        Asks the LLM whether the retrieved passages plus the exploration
        knowledge are enough to answer the query.

        Args:
            state: Current state.

        Returns:
            The updated state; ``is_sufficient`` is set to False when the
            LLM call or parsing fails.
        """
        print(f"[THINK] 执行概念探索充分性检查")

        # Prefer the mixed format when all_documents is present and non-empty.
        if state.get('all_documents'):
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])

        store = self._exploration_store(state)
        exploration_knowledge = self._format_exploration_knowledge(store)

        original_insufficiency = (state.get('sufficiency_check') or {}).get('reason', '信息不够充分')

        # Round 2 is evaluated against the sub-query when one was generated.
        current_query = state['original_query']
        sub_query = self._round_two_sub_query(state)
        if sub_query:
            current_query = sub_query

        prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
            user_query=current_query,
            all_passages=formatted_passages,
            exploration_knowledge=exploration_knowledge,
            insufficiency_reason=original_insufficiency
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            sufficiency_result = self.concept_sufficiency_parser.parse(response_text)

            state['is_sufficient'] = sufficiency_result.is_sufficient

            # Attach the verdict to the current round, if that round exists.
            round_key = f'round_{state.get("exploration_round")}_results'
            if round_key in store:
                store[round_key]['sufficiency_check'] = {
                    'is_sufficient': sufficiency_result.is_sufficient,
                    'confidence': sufficiency_result.confidence,
                    'reason': sufficiency_result.reason,
                    'missing_aspects': sufficiency_result.missing_aspects,
                    'coverage_score': sufficiency_result.coverage_score
                }

            self._record_llm_call(state)

            print(f"[INFO] 概念探索充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'}")
            print(f" 置信度: {sufficiency_result.confidence:.2f}, 覆盖度: {sufficiency_result.coverage_score:.2f}")
            if not sufficiency_result.is_sufficient:
                print(f" 缺失方面: {sufficiency_result.missing_aspects}")

            return state

        except Exception as e:
            print(f"[ERROR] 概念探索充分性检查失败: {e}")
            # A failed check is treated as "not sufficient".
            state['is_sufficient'] = False
            return state

    def concept_sub_query_generation_node(self, state: "QueryState") -> "QueryState":
        """
        Sub-query generation node.

        When round 1 is still insufficient, asks the LLM for a targeted
        sub-query for the second round and stores it under
        ``round_1_results.sub_query_info``.

        Args:
            state: Current state.

        Returns:
            The updated state; a default sub-query is stored when generation
            fails.
        """
        print(f"[TARGET] 生成概念探索子查询")

        store = self._exploration_store(state)
        round_1_results = store.get('round_1_results') or {}
        sufficiency_info = round_1_results.get('sufficiency_check', {})
        missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])

        exploration_summary = self._summarize_exploration_results(round_1_results)

        prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
            user_query=state['original_query'],
            missing_aspects=str(missing_aspects),
            exploration_results=exploration_summary
        )

        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            sub_query_result = self.concept_sub_query_parser.parse(response_text)

            # setdefault: round_1_results may be absent if round 1 aborted early.
            store.setdefault('round_1_results', {})['sub_query_info'] = {
                'sub_query': sub_query_result.sub_query,
                'focus_aspects': sub_query_result.focus_aspects,
                'expected_improvements': sub_query_result.expected_improvements,
                'confidence': sub_query_result.confidence
            }

            self._record_llm_call(state)

            print(f"[OK] 生成概念探索子查询: {sub_query_result.sub_query}")
            print(f" 关注方面: {sub_query_result.focus_aspects}")

            return state

        except Exception as e:
            print(f"[ERROR] 概念探索子查询生成失败: {e}")
            # Fall back to a generic sub-query; must not raise from the handler.
            default_sub_query = state['original_query'] + " 补充详细信息"
            store.setdefault('round_1_results', {})['sub_query_info'] = {
                'sub_query': default_sub_query,
                'focus_aspects': ['补充信息'],
                'expected_improvements': '获取更多相关信息',
                'confidence': 0.3
            }
            return state

    # ------------------------------------------------------------------
    # Formatting
    # ------------------------------------------------------------------

    def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
        """Render the knowledge gained in all rounds as prompt text.

        Only keys starting with ``round_`` are rendered, and only branches
        flagged as successful contribute lines.
        """
        if not exploration_results:
            return "无探索知识"

        knowledge_parts = []

        for round_key, round_data in exploration_results.items():
            if not round_key.startswith('round_'):
                continue

            round_num = round_key.split('_')[1]
            knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")

            for i, branch in enumerate(round_data.get('exploration_results', [])):
                if not branch.get('success', False):
                    continue

                start_node = branch.get('start_node', {})
                knowledge_parts.append(f"\n分支{i+1} (起点: {start_node.get('label', 'unknown')}):")

                for knowledge in branch.get('knowledge_gained', []):
                    action = knowledge.get('action', '')
                    reason = knowledge.get('reason', '')
                    if action == 'concept_selection':
                        concept = knowledge.get('selected_concept', '')
                        knowledge_parts.append(f"- 选择探索概念: {concept}")
                        knowledge_parts.append(f" 原因: {reason}")
                    elif action == 'exploration_step':
                        node = knowledge.get('selected_node', '')
                        relation = knowledge.get('relation', '')
                        knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
                        knowledge_parts.append(f" 原因: {reason}")

        return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"

    def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
        """Summarise one round: success ratio plus up to three explored concepts."""
        if not round_results:
            return "无探索结果"

        exploration_list = round_results.get('exploration_results', [])
        successful = [r for r in exploration_list if r.get('success', False)]

        summary_parts = [
            f"第一轮探索: {len(successful)}/{len(exploration_list)} 个分支成功",
        ]

        # Key findings: the concept chosen by each successful branch.
        key_findings = []
        for branch in successful:
            for knowledge in branch.get('knowledge_gained', []):
                if knowledge.get('action') == 'concept_selection':
                    concept = knowledge.get('selected_concept', '')
                    if concept:
                        key_findings.append(f"探索了概念: {concept}")

        if key_findings:
            summary_parts.append("主要发现:")
            summary_parts.extend(f"- {finding}" for finding in key_findings[:3])  # at most 3

        return "\n".join(summary_parts)
|
||||
Reference in New Issue
Block a user