699 lines
30 KiB
Python
699 lines
30 KiB
Python
"""
|
||
概念探索节点实现
|
||
在PageRank收集后进行知识图谱概念探索
|
||
"""
|
||
|
||
import json
|
||
import asyncio
|
||
from typing import Dict, Any, List, Tuple, Optional
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
from retriver.langgraph.graph_state import QueryState
|
||
from retriver.langgraph.langchain_components import (
|
||
OneAPILLM,
|
||
ConceptExplorationParser,
|
||
ConceptContinueParser,
|
||
ConceptSufficiencyParser,
|
||
ConceptSubQueryParser,
|
||
CONCEPT_EXPLORATION_INIT_PROMPT,
|
||
CONCEPT_EXPLORATION_CONTINUE_PROMPT,
|
||
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
|
||
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
|
||
format_passages,
|
||
format_mixed_passages,
|
||
format_triplets,
|
||
format_exploration_path
|
||
)
|
||
|
||
|
||
class ConceptExplorationNodes:
|
||
"""概念探索节点实现类"""
|
||
|
||
def __init__(
    self,
    retriever,  # LangChainHippoRAGRetriever
    llm: OneAPILLM,
    max_exploration_steps: int = 3,
    max_parallel_explorations: int = 3
):
    """
    Initialize the concept-exploration node handler.

    Args:
        retriever: HippoRAG retriever; the graph helpers of this class read
            the knowledge graph through its ``graph_data`` attribute.
        llm: OneAPI LLM invoked for every exploration prompt.
        max_exploration_steps: Maximum number of hops per exploration branch.
        max_parallel_explorations: Maximum number of branches explored in
            parallel.
    """
    self.retriever = retriever
    self.llm = llm
    self.max_exploration_steps = max_exploration_steps
    self.max_parallel_explorations = max_parallel_explorations

    # One output parser per prompt type used by this class.
    self.concept_exploration_parser = ConceptExplorationParser()
    self.concept_continue_parser = ConceptContinueParser()
    self.concept_sufficiency_parser = ConceptSufficiencyParser()
    self.concept_sub_query_parser = ConceptSubQueryParser()
|
||
|
||
def _extract_response_text(self, response):
    """
    Extract the text payload from an LLM response, whatever its shape.

    Checked in order:
      1. an object with a non-empty ``generations`` attribute ->
         ``generations[0][0].text``
      2. an object with a ``content`` attribute -> ``content``
      3. a dict with a ``'response'`` key -> that value
      4. anything else -> ``str(response)``

    Args:
        response: The value returned by the LLM call.

    Returns:
        The extracted response text.
    """
    if hasattr(response, 'generations') and response.generations:
        return response.generations[0][0].text
    if hasattr(response, 'content'):
        return response.content
    if isinstance(response, dict) and 'response' in response:
        return response['response']
    return str(response)
|
||
|
||
def _get_top_node_candidates(self, state: QueryState) -> List[Dict[str, Any]]:
    """
    Pick the three highest-scoring PageRank entries whose type marks them as Node.

    Reads ``state['pagerank_summary']['all_nodes_with_labels']``, keeps the
    entries whose ``type`` contains "node" (case-insensitive), sorts them by
    ``score`` in descending order and returns the first three.

    Args:
        state: Current query state.

    Returns:
        Up to three node-info dicts (the entries are returned as found in the
        summary). Empty when the summary or the labelled node list is missing.
    """
    print("[SEARCH] 从PageRank结果中提取TOP-3 Node节点")

    pagerank_summary = state.get('pagerank_summary', {})
    if not pagerank_summary:
        print("[ERROR] 没有找到PageRank汇总信息")
        return []

    all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels', [])
    if not all_nodes_with_labels:
        print("[ERROR] 没有找到带标签的节点信息")
        return []

    # A missing or None `type` simply means "not a Node"; it must not crash
    # the whole selection.
    node_candidates = [
        node_info for node_info in all_nodes_with_labels
        if 'node' in str(node_info.get('type') or '').lower()
    ]

    # A missing/None score sorts as 0 so it cannot break the comparison.
    node_candidates.sort(key=lambda info: info.get('score') or 0, reverse=True)
    top_3_nodes = node_candidates[:3]

    print(f"[OK] 找到 {len(node_candidates)} 个Node节点,选择TOP-3:")
    for rank, node in enumerate(top_3_nodes, start=1):
        node_id = node.get('node_id', 'unknown')
        score = node.get('score') or 0
        label = node.get('label', 'unknown')
        print(f" {rank}. {node_id} ('{label}') - Score: {score:.6f}")

    return top_3_nodes
|
||
|
||
def _get_connected_concepts(self, node_id: str) -> List[str]:
    """
    Return the labels of all Concept-typed neighbours of a node.

    A neighbour counts as a Concept when its ``type`` attribute contains
    "concept" (case-insensitive). Its label is taken from ``label``, then
    ``name``, then the neighbour id itself.

    Args:
        node_id: Id of the node whose neighbours are inspected.

    Returns:
        Labels of the connected Concept nodes. Empty when the graph is
        unavailable, the node is unknown, or any error occurs (best effort:
        errors are printed, never raised).
    """
    try:
        graph = getattr(self.retriever, 'graph_data', None)
        if not graph:
            print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
            return []

        if node_id not in graph.nodes:
            print(f"[ERROR] 节点 {node_id} 不存在于图中")
            return []

        connected_concepts = []
        for neighbor_id in graph.neighbors(node_id):
            neighbor_data = graph.nodes.get(neighbor_id, {})
            # Tolerate a missing/None type instead of aborting the whole scan.
            neighbor_type = str(neighbor_data.get('type') or '').lower()
            if 'concept' in neighbor_type:
                connected_concepts.append(
                    neighbor_data.get('label', neighbor_data.get('name', neighbor_id))
                )

        print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(connected_concepts)}")
        return connected_concepts

    except Exception as e:
        print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
        return []
|
||
|
||
def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
    """
    Build (source, relation, target) triplets for the neighbours of a node.

    Neighbours listed in ``excluded_nodes`` are skipped. The relation is read
    from the edge data (``relation``, then ``label``, defaulting to
    ``'related_to'``); node labels come from ``label``, then ``name``, then
    the node id.

    Args:
        node_id: Id of the current node.
        excluded_nodes: Node ids to leave out (e.g. already visited nodes).

    Returns:
        A list of dicts with keys ``source``, ``relation``, ``target``,
        ``source_id`` and ``target_id``. Empty when the graph is unavailable,
        the node is unknown, or any error occurs (errors are printed, never
        raised).
    """
    try:
        graph = getattr(self.retriever, 'graph_data', None)
        if not graph:
            print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
            return []

        if node_id not in graph.nodes:
            print(f"[ERROR] 节点 {node_id} 不存在于图中")
            return []

        # Built once so the per-neighbour membership test is O(1).
        excluded = set(excluded_nodes or ())

        # The source side is the same for every triplet; resolve it once.
        source_data = graph.nodes.get(node_id, {})
        source_label = source_data.get('label', source_data.get('name', node_id))

        triplets = []
        for neighbor_id in graph.neighbors(node_id):
            if neighbor_id in excluded:
                continue

            edge_data = graph.get_edge_data(node_id, neighbor_id, {})
            relation = edge_data.get('relation', edge_data.get('label', 'related_to'))

            target_data = graph.nodes.get(neighbor_id, {})
            target_label = target_data.get('label', target_data.get('name', neighbor_id))

            triplets.append({
                'source': source_label,
                'relation': relation,
                'target': target_label,
                'source_id': node_id,
                'target_id': neighbor_id,
            })

        print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
        return triplets

    except Exception as e:
        print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
        return []
|
||
|
||
def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
    """
    Explore one branch of the knowledge graph starting from a Node.

    Steps:
      1. Collect the Concept neighbours of the start node and let the LLM
         pick one (recorded as step 0 in ``knowledge_gained``).
      2. From that concept, hop up to ``self.max_exploration_steps`` times;
         at each hop the LLM chooses among the not-yet-visited neighbour
         triplets of the current node.

    Args:
        node_info: Start node info; ``node_id`` and ``label`` are read.
        user_query: Query that guides the exploration.
        insufficiency_reason: Why the retrieved information was judged
            insufficient (passed to the concept-selection prompt).

    Returns:
        A dict with ``start_node``, ``exploration_path``, ``visited_nodes``,
        ``knowledge_gained``, ``success`` and ``error``. ``success`` is True
        once a concept node was resolved and the hop loop ended without an
        exception (even after zero hops). Exceptions are captured in
        ``error`` and never raised.
    """
    node_id = node_info.get('node_id', '')
    node_label = node_info.get('label', '')

    print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")

    exploration_result = {
        'start_node': node_info,
        'exploration_path': [],
        'visited_nodes': [node_id],
        'knowledge_gained': [],
        'success': False,
        'error': None
    }
    visited_nodes = exploration_result['visited_nodes']

    try:
        # Step 1: gather connected Concept nodes and pick the best direction.
        connected_concepts = self._get_connected_concepts(node_id)
        if not connected_concepts:
            print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
            exploration_result['error'] = "没有连接的Concept节点"
            return exploration_result

        concept_selection = self._select_best_concept(node_label, connected_concepts, user_query, insufficiency_reason)
        if not concept_selection or not concept_selection.selected_concept:
            print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
            exploration_result['error'] = "未能选择有效概念"
            return exploration_result

        selected_concept = concept_selection.selected_concept
        print(f"[POINT] 选择探索概念: {selected_concept}")

        # Record the initial exploration decision.
        exploration_result['knowledge_gained'].append({
            'step': 0,
            'action': 'concept_selection',
            'selected_concept': selected_concept,
            'reason': concept_selection.exploration_reason,
            'expected_knowledge': concept_selection.expected_knowledge
        })

        concept_node_id = self._find_concept_node_id(selected_concept)
        if not concept_node_id:
            print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
            exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
            return exploration_result

        # Step 2: multi-hop exploration starting at the concept node.
        current_node_id = concept_node_id
        visited_nodes.append(current_node_id)

        for step in range(1, self.max_exploration_steps + 1):
            print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")

            # Candidate moves: neighbours of the current node, minus visited ones.
            neighbor_triplets = self._get_neighbor_triplets(current_node_id, visited_nodes)
            if not neighbor_triplets:
                print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
                break

            exploration_path_str = format_exploration_path(exploration_result['exploration_path'])
            continue_result = self._continue_exploration(current_node_id, neighbor_triplets, exploration_path_str, user_query)

            if not continue_result or not continue_result.selected_node:
                print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
                break

            # Resolve the chosen node among the candidates that were actually
            # offered to the LLM first: labels are not guaranteed unique, and a
            # graph-wide label scan can return an unrelated or already-visited
            # node that merely shares the label. Fall back to the graph-wide
            # scan only when no offered candidate matches.
            selected_node = continue_result.selected_node
            next_node_id = next(
                (
                    triplet['target_id']
                    for triplet in neighbor_triplets
                    if selected_node in (triplet['target'], triplet['target_id'])
                ),
                None,
            )
            if next_node_id is None:
                next_node_id = self._find_node_id_by_label(selected_node)

            if not next_node_id or next_node_id in visited_nodes:
                print(f"[WARNING] 选择的节点无效或已访问: {continue_result.selected_node}")
                break

            # Record the hop itself.
            exploration_result['exploration_path'].append({
                'source': current_node_id,
                'relation': continue_result.selected_relation,
                'target': next_node_id,
                'reason': continue_result.exploration_reason
            })

            # Record what the hop was expected to contribute.
            exploration_result['knowledge_gained'].append({
                'step': step,
                'action': 'exploration_step',
                'selected_node': continue_result.selected_node,
                'relation': continue_result.selected_relation,
                'reason': continue_result.exploration_reason,
                'expected_knowledge': continue_result.expected_knowledge
            })

            # Move on to the chosen node.
            current_node_id = next_node_id
            visited_nodes.append(current_node_id)

            print(f"[OK] 步骤 {step} 完成: {continue_result.selected_node}")

        exploration_result['success'] = True
        print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])} 步")

    except Exception as e:
        print(f"[ERROR] 分支探索失败: {e}")
        exploration_result['error'] = str(e)

    return exploration_result
|
||
|
||
def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
    """
    Ask the LLM which connected concept to explore first.

    Args:
        node_name: Label of the start node.
        connected_concepts: Labels of the candidate Concept nodes; rendered
            as a "- item" bullet list in the prompt.
        user_query: Query that guides the exploration.
        insufficiency_reason: Why the current information is insufficient.

    Returns:
        The result of ``self.concept_exploration_parser.parse`` on the LLM
        response, or None when the LLM call or the parsing fails.
    """
    concepts_str = "\n".join(f"- {concept}" for concept in connected_concepts)

    prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
        node_name=node_name,
        connected_concepts=concepts_str,
        user_query=user_query,
        insufficiency_reason=insufficiency_reason
    )

    try:
        response = self.llm.invoke(prompt)
        response_text = self._extract_response_text(response)
        return self.concept_exploration_parser.parse(response_text)
    except Exception as e:
        print(f"[ERROR] 概念选择失败: {e}")
        return None
|
||
|
||
def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
    """
    Ask the LLM which neighbour to explore next.

    Args:
        current_node: Identifier of the node the exploration currently sits on.
        neighbor_triplets: Candidate triplets, rendered with ``format_triplets``.
        exploration_path: Already formatted description of the path so far.
        user_query: Query that guides the exploration.

    Returns:
        The result of ``self.concept_continue_parser.parse`` on the LLM
        response, or None when the LLM call or the parsing fails.
    """
    triplets_str = format_triplets(neighbor_triplets)

    prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
        current_node=current_node,
        neighbor_triplets=triplets_str,
        exploration_path=exploration_path,
        user_query=user_query
    )

    try:
        response = self.llm.invoke(prompt)
        response_text = self._extract_response_text(response)
        return self.concept_continue_parser.parse(response_text)
    except Exception as e:
        print(f"[ERROR] 探索继续失败: {e}")
        return None
|
||
|
||
def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
    """
    Find the id of the first Concept node matching a label.

    A node matches when its ``type`` contains "concept" (case-insensitive)
    and either its label (``label``, then ``name``) or its id equals
    ``concept_label``.

    Args:
        concept_label: Concept label (or node id) to look up.

    Returns:
        The matching node id, or None when nothing matches, the graph is
        unavailable, or an error occurs.
    """
    try:
        graph = getattr(self.retriever, 'graph_data', None)
        if not graph:
            return None

        for node_id, node_data in graph.nodes(data=True):
            # A node with a missing/None type is skipped rather than allowed
            # to abort the whole scan.
            if 'concept' not in str(node_data.get('type') or '').lower():
                continue
            node_label = node_data.get('label', node_data.get('name', ''))
            if node_label == concept_label or node_id == concept_label:
                return node_id

        return None
    except Exception as e:
        print(f"[ERROR] 查找概念节点ID失败: {e}")
        return None
|
||
|
||
def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
    """
    Find the id of the first node whose label or id equals ``node_label``.

    The label is read from the node's ``label`` attribute, then ``name``.
    Nodes are scanned in the graph's iteration order, so with duplicate
    labels the first one wins.

    Args:
        node_label: Label (or node id) to look up.

    Returns:
        The matching node id, or None when nothing matches, the graph is
        unavailable, or an error occurs.
    """
    try:
        graph = getattr(self.retriever, 'graph_data', None)
        if not graph:
            return None

        return next(
            (
                node_id
                for node_id, node_data in graph.nodes(data=True)
                if node_data.get('label', node_data.get('name', '')) == node_label
                or node_id == node_label
            ),
            None,
        )
    except Exception as e:
        print(f"[ERROR] 查找节点ID失败: {e}")
        return None
|
||
|
||
def concept_exploration_node(self, state: QueryState) -> QueryState:
    """
    Concept exploration node.

    1. Pick the top Node candidates from the PageRank summary.
    2. Explore one branch per candidate in a thread pool.
    3. Store the merged results under
       ``state['concept_exploration_results']['round_<n>_results']``.

    In round 2 the sub-query stored under
    ``round_1_results['sub_query_info']['sub_query']`` (when present)
    replaces the original query.

    Args:
        state: Current query state. ``exploration_round`` is incremented.

    Returns:
        The updated state. When no candidate node is found, an ``error`` entry
        is added to ``concept_exploration_results`` and results of earlier
        rounds are kept.
    """
    current_round = state.get('exploration_round', 0) + 1
    print(f"[?] 开始概念探索 (轮次 {current_round})")

    # Advance the exploration round.
    state['exploration_round'] = current_round

    # Work on the results dict in place so earlier rounds are never lost.
    results_store = state.get('concept_exploration_results')
    if not isinstance(results_store, dict):
        results_store = {}
        state['concept_exploration_results'] = results_store

    top_nodes = self._get_top_node_candidates(state)
    if not top_nodes:
        print("[ERROR] 未找到合适的Node节点进行探索")
        results_store['error'] = '未找到合适的Node节点'
        return state

    insufficiency_reason = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
    user_query = state['original_query']

    # The second round explores with the generated sub-query, if there is one.
    if current_round == 2:
        sub_query_info = results_store.get('round_1_results', {}).get('sub_query_info', {})
        if sub_query_info and sub_query_info.get('sub_query'):
            user_query = sub_query_info['sub_query']
            print(f"[TARGET] 第二轮探索使用子查询: {user_query}")

    # Results are slotted by candidate index so the stored order follows the
    # PageRank ranking instead of thread completion order.
    exploration_results = [None] * len(top_nodes)

    # ThreadPoolExecutor rejects max_workers <= 0, so clamp to at least one.
    worker_count = max(1, min(len(top_nodes), self.max_parallel_explorations))
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        futures = {
            executor.submit(self._explore_single_branch, node_info, user_query, insufficiency_reason): index
            for index, node_info in enumerate(top_nodes)
        }

        for future in as_completed(futures):
            index = futures[future]
            node_info = top_nodes[index]
            try:
                exploration_results[index] = future.result()
                print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
            except Exception as e:
                print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
                # Same shape as a regular branch result so consumers need no
                # special case.
                exploration_results[index] = {
                    'start_node': node_info,
                    'exploration_path': [],
                    'visited_nodes': [node_info.get('node_id', '')],
                    'knowledge_gained': [],
                    'success': False,
                    'error': str(e)
                }

    successful_branches = sum(1 for result in exploration_results if result.get('success', False))

    round_key = f'round_{current_round}_results'
    results_store[round_key] = {
        'exploration_results': exploration_results,
        'total_branches': len(top_nodes),
        'successful_branches': successful_branches,
        'query_used': user_query
    }

    print(f"[SUCCESS] 第{current_round}轮概念探索完成")
    print(f" 总分支: {len(top_nodes)}, 成功: {successful_branches}")

    return state
|
||
|
||
def concept_exploration_sufficiency_check_node(self, state: QueryState) -> QueryState:
    """
    Sufficiency check after concept exploration.

    Asks the LLM whether the retrieved passages together with the knowledge
    gained from exploration are enough to answer the query. In round 2 the
    generated sub-query (when present) is used as the query.

    Args:
        state: Current query state.

    Returns:
        The updated state: ``is_sufficient`` is set, and the detailed verdict
        is stored under the current round's ``sufficiency_check`` entry when
        that round exists. If the LLM call or the parsing fails,
        ``is_sufficient`` is set to False.
    """
    print(f"[THINK] 执行概念探索充分性检查")

    # Use the mixed format when all_documents is available, otherwise the
    # traditional passage format.
    if 'all_documents' in state and state['all_documents']:
        formatted_passages = format_mixed_passages(state['all_documents'])
    else:
        formatted_passages = format_passages(state['all_passages'])

    exploration_knowledge = self._format_exploration_knowledge(state['concept_exploration_results'])
    original_insufficiency = state.get('sufficiency_check', {}).get('reason', '信息不够充分')

    # Query under evaluation (may be the sub-query in round 2).
    current_query = state['original_query']
    if state['exploration_round'] == 2:
        sub_query_info = state['concept_exploration_results'].get('round_1_results', {}).get('sub_query_info', {})
        if sub_query_info and sub_query_info.get('sub_query'):
            current_query = sub_query_info['sub_query']

    prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
        user_query=current_query,
        all_passages=formatted_passages,
        exploration_knowledge=exploration_knowledge,
        insufficiency_reason=original_insufficiency
    )

    # Set once the LLM verdict is in the state. A later bookkeeping or logging
    # failure must not turn a parsed "sufficient" into "insufficient".
    verdict_recorded = False
    try:
        response = self.llm.invoke(prompt)
        response_text = self._extract_response_text(response)
        sufficiency_result = self.concept_sufficiency_parser.parse(response_text)

        state['is_sufficient'] = sufficiency_result.is_sufficient
        verdict_recorded = True

        # Keep the detailed verdict next to the current round's results.
        round_key = f'round_{state["exploration_round"]}_results'
        if round_key in state['concept_exploration_results']:
            state['concept_exploration_results'][round_key]['sufficiency_check'] = {
                'is_sufficient': sufficiency_result.is_sufficient,
                'confidence': sufficiency_result.confidence,
                'reason': sufficiency_result.reason,
                'missing_aspects': sufficiency_result.missing_aspects,
                'coverage_score': sufficiency_result.coverage_score
            }

        state["debug_info"]["llm_calls"] += 1

        print(f"[INFO] 概念探索充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'}")
        print(f" 置信度: {sufficiency_result.confidence:.2f}, 覆盖度: {sufficiency_result.coverage_score:.2f}")
        if not sufficiency_result.is_sufficient:
            print(f" 缺失方面: {sufficiency_result.missing_aspects}")

    except Exception as e:
        print(f"[ERROR] 概念探索充分性检查失败: {e}")
        # Assume "insufficient" only when no verdict was obtained at all.
        if not verdict_recorded:
            state['is_sufficient'] = False

    return state
|
||
|
||
def concept_sub_query_generation_node(self, state: QueryState) -> QueryState:
    """
    Generate the sub-query for a second exploration round.

    Used when the first round was still insufficient: the LLM is given the
    original query, the missing aspects from the round-1 sufficiency check and
    a summary of the round-1 exploration, and returns a targeted sub-query.

    Args:
        state: Current query state.

    Returns:
        The updated state with
        ``concept_exploration_results['round_1_results']['sub_query_info']``
        filled in. If generation fails, a default sub-query (original query
        plus a "more details" suffix, confidence 0.3) is stored instead.
    """
    print(f"[TARGET] 生成概念探索子查询")

    # The round-1 entry may be absent (e.g. when exploration found no start
    # nodes); create it so the sub-query always has somewhere to be stored.
    round_1_results = state['concept_exploration_results'].setdefault('round_1_results', {})
    sufficiency_info = round_1_results.get('sufficiency_check', {})
    missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])

    exploration_summary = self._summarize_exploration_results(round_1_results)

    prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
        user_query=state['original_query'],
        missing_aspects=str(missing_aspects),
        exploration_results=exploration_summary
    )

    # Set once a generated sub-query is stored, so a later bookkeeping failure
    # does not replace it with the default.
    sub_query_stored = False
    try:
        response = self.llm.invoke(prompt)
        response_text = self._extract_response_text(response)
        sub_query_result = self.concept_sub_query_parser.parse(response_text)

        round_1_results['sub_query_info'] = {
            'sub_query': sub_query_result.sub_query,
            'focus_aspects': sub_query_result.focus_aspects,
            'expected_improvements': sub_query_result.expected_improvements,
            'confidence': sub_query_result.confidence
        }
        sub_query_stored = True

        state["debug_info"]["llm_calls"] += 1

        print(f"[OK] 生成概念探索子查询: {sub_query_result.sub_query}")
        print(f" 关注方面: {sub_query_result.focus_aspects}")

    except Exception as e:
        print(f"[ERROR] 概念探索子查询生成失败: {e}")
        if not sub_query_stored:
            # Generation failed: fall back to a default sub-query.
            default_sub_query = state['original_query'] + " 补充详细信息"
            round_1_results['sub_query_info'] = {
                'sub_query': default_sub_query,
                'focus_aspects': ['补充信息'],
                'expected_improvements': '获取更多相关信息',
                'confidence': 0.3
            }

    return state
|
||
|
||
def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
    """
    Render the knowledge gained during exploration as prompt text.

    Only keys starting with ``round_`` are rendered. Each round gets a header
    line; each successful branch gets a sub-header (numbered by its position
    in the round's result list, failed branches included in the count)
    followed by one entry per ``concept_selection`` / ``exploration_step``
    record.

    Args:
        exploration_results: The ``concept_exploration_results`` dict.

    Returns:
        The formatted text, or a fixed placeholder when the input is empty or
        contains no round entries.
    """
    if not exploration_results:
        return "无探索知识"

    knowledge_parts = []

    for round_key, round_data in exploration_results.items():
        if not round_key.startswith('round_'):
            continue

        round_num = round_key.split('_')[1]
        knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")

        for branch_no, branch in enumerate(round_data.get('exploration_results', []), start=1):
            if not branch.get('success', False):
                continue

            start_node = branch.get('start_node', {})
            knowledge_parts.append(f"\n分支{branch_no} (起点: {start_node.get('label', 'unknown')}):")

            for knowledge in branch.get('knowledge_gained', []):
                action = knowledge.get('action', '')
                reason = knowledge.get('reason', '')
                if action == 'concept_selection':
                    concept = knowledge.get('selected_concept', '')
                    knowledge_parts.append(f"- 选择探索概念: {concept}")
                    knowledge_parts.append(f" 原因: {reason}")
                elif action == 'exploration_step':
                    node = knowledge.get('selected_node', '')
                    relation = knowledge.get('relation', '')
                    knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
                    knowledge_parts.append(f" 原因: {reason}")

    return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"
|
||
|
||
def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
    """
    Summarize one round of exploration results as a short description.

    The summary states how many branches succeeded and lists up to three
    concepts selected by successful branches.

    Args:
        round_results: One round's results dict (its ``exploration_results``
            list is read).

    Returns:
        The summary text, or a fixed placeholder when ``round_results`` is
        empty.
    """
    if not round_results:
        return "无探索结果"

    exploration_list = round_results.get('exploration_results', [])
    successful_branches = [branch for branch in exploration_list if branch.get('success', False)]

    summary_parts = [
        f"第一轮探索: {len(successful_branches)}/{len(exploration_list)} 个分支成功",
    ]

    # Key findings: the concepts chosen by the successful branches.
    key_findings = []
    for branch in successful_branches:
        for knowledge in branch.get('knowledge_gained', []):
            if knowledge.get('action') != 'concept_selection':
                continue
            concept = knowledge.get('selected_concept', '')
            if concept:
                key_findings.append(f"探索了概念: {concept}")

    if key_findings:
        summary_parts.append("主要发现:")
        # At most three findings keep the prompt short.
        summary_parts.extend(f"- {finding}" for finding in key_findings[:3])

    return "\n".join(summary_parts)