first commit
This commit is contained in:
14
AIEC-RAG/retriver/langgraph/__init__.py
Normal file
14
AIEC-RAG/retriver/langgraph/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
智能检索系统
|
||||
基于LangChain + LangGraph实现的迭代检索框架
|
||||
"""
|
||||
|
||||
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever, create_langchain_hipporag_retriever
|
||||
from retriver.langgraph.iterative_retriever import IterativeRetriever, create_iterative_retriever
|
||||
|
||||
__all__ = [
|
||||
"LangChainHippoRAGRetriever",
|
||||
"IterativeRetriever",
|
||||
"create_langchain_hipporag_retriever",
|
||||
"create_iterative_retriever"
|
||||
]
|
||||
BIN
AIEC-RAG/retriver/langgraph/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
AIEC-RAG/retriver/langgraph/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
AIEC-RAG/retriver/langgraph/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
AIEC-RAG/retriver/langgraph/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
699
AIEC-RAG/retriver/langgraph/concept_exploration_nodes.py
Normal file
699
AIEC-RAG/retriver/langgraph/concept_exploration_nodes.py
Normal file
@ -0,0 +1,699 @@
|
||||
"""
|
||||
概念探索节点实现
|
||||
在PageRank收集后进行知识图谱概念探索
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Tuple, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from retriver.langgraph.graph_state import QueryState
|
||||
from retriver.langgraph.langchain_components import (
|
||||
OneAPILLM,
|
||||
ConceptExplorationParser,
|
||||
ConceptContinueParser,
|
||||
ConceptSufficiencyParser,
|
||||
ConceptSubQueryParser,
|
||||
CONCEPT_EXPLORATION_INIT_PROMPT,
|
||||
CONCEPT_EXPLORATION_CONTINUE_PROMPT,
|
||||
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
|
||||
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
|
||||
format_passages,
|
||||
format_mixed_passages,
|
||||
format_triplets,
|
||||
format_exploration_path
|
||||
)
|
||||
|
||||
|
||||
class ConceptExplorationNodes:
    """Graph-exploration node handlers for the LangGraph retrieval pipeline.

    Wraps a HippoRAG retriever and an LLM, and exposes node callables that
    explore the knowledge graph starting from high-PageRank nodes.
    """

    def __init__(
        self,
        retriever,  # LangChainHippoRAGRetriever; grants access to graph_data
        llm: OneAPILLM,
        max_exploration_steps: int = 3,
        max_parallel_explorations: int = 3
    ):
        """Set up the concept-exploration node handler.

        Args:
            retriever: HippoRAG retriever used to access the knowledge graph.
            llm: OneAPI LLM used for exploration decisions.
            max_exploration_steps: Maximum hops per exploration branch.
            max_parallel_explorations: Maximum branches explored in parallel.
        """
        self.retriever = retriever
        self.llm = llm
        self.max_exploration_steps = max_exploration_steps
        self.max_parallel_explorations = max_parallel_explorations

        # Output parsers for each of the exploration-related LLM prompts.
        self.concept_exploration_parser = ConceptExplorationParser()
        self.concept_continue_parser = ConceptContinueParser()
        self.concept_sufficiency_parser = ConceptSufficiencyParser()
        self.concept_sub_query_parser = ConceptSubQueryParser()
|
||||
|
||||
def _extract_response_text(self, response):
|
||||
"""统一的响应文本提取方法"""
|
||||
if hasattr(response, 'generations') and response.generations:
|
||||
return response.generations[0][0].text
|
||||
elif hasattr(response, 'content'):
|
||||
return response.content
|
||||
elif isinstance(response, dict) and 'response' in response:
|
||||
return response['response']
|
||||
else:
|
||||
return str(response)
|
||||
|
||||
def _get_top_node_candidates(self, state: QueryState) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从PageRank结果中提取TOP-3个Label为Node的节点
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
|
||||
Returns:
|
||||
TOP-3 Node节点列表,每个包含node_id, score, label, type信息
|
||||
"""
|
||||
print("[SEARCH] 从PageRank结果中提取TOP-3 Node节点")
|
||||
|
||||
pagerank_summary = state.get('pagerank_summary', {})
|
||||
if not pagerank_summary:
|
||||
print("[ERROR] 没有找到PageRank汇总信息")
|
||||
return []
|
||||
|
||||
# 获取所有带标签信息的节点
|
||||
all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels', [])
|
||||
if not all_nodes_with_labels:
|
||||
print("[ERROR] 没有找到带标签的节点信息")
|
||||
return []
|
||||
|
||||
# 筛选出Label为Node的节点
|
||||
node_candidates = []
|
||||
for node_info in all_nodes_with_labels:
|
||||
node_type = node_info.get('type', '').lower()
|
||||
label = node_info.get('label', '')
|
||||
|
||||
# 检查是否为Node类型的节点
|
||||
if 'node' in node_type.lower() or node_type == 'node':
|
||||
node_candidates.append(node_info)
|
||||
|
||||
# 按分数降序排列并取前3个
|
||||
node_candidates.sort(key=lambda x: x.get('score', 0), reverse=True)
|
||||
top_3_nodes = node_candidates[:3]
|
||||
|
||||
print(f"[OK] 找到 {len(node_candidates)} 个Node节点,选择TOP-3:")
|
||||
for i, node in enumerate(top_3_nodes):
|
||||
node_id = node.get('node_id', 'unknown')
|
||||
score = node.get('score', 0)
|
||||
label = node.get('label', 'unknown')
|
||||
print(f" {i+1}. {node_id} ('{label}') - Score: {score:.6f}")
|
||||
|
||||
return top_3_nodes
|
||||
|
||||
def _get_connected_concepts(self, node_id: str) -> List[str]:
|
||||
"""
|
||||
获取指定Node节点连接的所有Concept节点
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
|
||||
Returns:
|
||||
连接的Concept节点名称列表
|
||||
"""
|
||||
try:
|
||||
# 通过检索器访问图数据获取邻居
|
||||
if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
|
||||
print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
|
||||
return []
|
||||
|
||||
connected_concepts = []
|
||||
graph = self.retriever.graph_data
|
||||
|
||||
# 检查节点是否存在
|
||||
if node_id not in graph.nodes:
|
||||
print(f"[ERROR] 节点 {node_id} 不存在于图中")
|
||||
return []
|
||||
|
||||
# 获取所有邻居节点
|
||||
neighbors = list(graph.neighbors(node_id))
|
||||
|
||||
for neighbor_id in neighbors:
|
||||
neighbor_data = graph.nodes.get(neighbor_id, {})
|
||||
neighbor_type = neighbor_data.get('type', '').lower()
|
||||
neighbor_label = neighbor_data.get('label', neighbor_data.get('name', neighbor_id))
|
||||
|
||||
# 筛选Label为Concept的邻居
|
||||
if 'concept' in neighbor_type.lower() or neighbor_type == 'concept':
|
||||
connected_concepts.append(neighbor_label)
|
||||
|
||||
print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(connected_concepts)}")
|
||||
return connected_concepts
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
|
||||
return []
|
||||
|
||||
def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取指定节点的所有邻居三元组,排除已访问的节点
|
||||
|
||||
Args:
|
||||
node_id: 当前节点ID
|
||||
excluded_nodes: 要排除的节点ID列表
|
||||
|
||||
Returns:
|
||||
三元组列表,每个三元组包含source, relation, target信息
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
|
||||
print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
|
||||
return []
|
||||
|
||||
if excluded_nodes is None:
|
||||
excluded_nodes = []
|
||||
|
||||
triplets = []
|
||||
graph = self.retriever.graph_data
|
||||
|
||||
# 检查节点是否存在
|
||||
if node_id not in graph.nodes:
|
||||
print(f"[ERROR] 节点 {node_id} 不存在于图中")
|
||||
return []
|
||||
|
||||
# 获取所有邻居边
|
||||
for neighbor_id in graph.neighbors(node_id):
|
||||
# 跳过已访问的节点
|
||||
if neighbor_id in excluded_nodes:
|
||||
continue
|
||||
|
||||
# 获取边的关系信息
|
||||
edge_data = graph.get_edge_data(node_id, neighbor_id, {})
|
||||
relation = edge_data.get('relation', edge_data.get('label', 'related_to'))
|
||||
|
||||
# 获取节点标签
|
||||
source_label = graph.nodes.get(node_id, {}).get('label',
|
||||
graph.nodes.get(node_id, {}).get('name', node_id))
|
||||
target_label = graph.nodes.get(neighbor_id, {}).get('label',
|
||||
graph.nodes.get(neighbor_id, {}).get('name', neighbor_id))
|
||||
|
||||
triplets.append({
|
||||
'source': source_label,
|
||||
'relation': relation,
|
||||
'target': target_label,
|
||||
'source_id': node_id,
|
||||
'target_id': neighbor_id
|
||||
})
|
||||
|
||||
print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
|
||||
return triplets
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
|
||||
return []
|
||||
|
||||
def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
|
||||
"""
|
||||
探索单个分支的完整路径
|
||||
|
||||
Args:
|
||||
node_info: 起始Node节点信息
|
||||
user_query: 用户查询
|
||||
insufficiency_reason: 不充分原因
|
||||
|
||||
Returns:
|
||||
探索结果字典
|
||||
"""
|
||||
node_id = node_info.get('node_id', '')
|
||||
node_label = node_info.get('label', '')
|
||||
|
||||
print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")
|
||||
|
||||
exploration_result = {
|
||||
'start_node': node_info,
|
||||
'exploration_path': [],
|
||||
'visited_nodes': [node_id],
|
||||
'knowledge_gained': [],
|
||||
'success': False,
|
||||
'error': None
|
||||
}
|
||||
|
||||
try:
|
||||
# 步骤1: 获取连接的Concept节点并选择最佳探索方向
|
||||
connected_concepts = self._get_connected_concepts(node_id)
|
||||
if not connected_concepts:
|
||||
print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
|
||||
exploration_result['error'] = "没有连接的Concept节点"
|
||||
return exploration_result
|
||||
|
||||
# 使用LLM选择最佳概念
|
||||
concept_selection = self._select_best_concept(node_label, connected_concepts, user_query, insufficiency_reason)
|
||||
if not concept_selection or not concept_selection.selected_concept:
|
||||
print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
|
||||
exploration_result['error'] = "未能选择有效概念"
|
||||
return exploration_result
|
||||
|
||||
selected_concept = concept_selection.selected_concept
|
||||
print(f"[POINT] 选择探索概念: {selected_concept}")
|
||||
|
||||
# 记录初始探索决策
|
||||
exploration_result['knowledge_gained'].append({
|
||||
'step': 0,
|
||||
'action': 'concept_selection',
|
||||
'selected_concept': selected_concept,
|
||||
'reason': concept_selection.exploration_reason,
|
||||
'expected_knowledge': concept_selection.expected_knowledge
|
||||
})
|
||||
|
||||
# 找到对应的概念节点ID
|
||||
concept_node_id = self._find_concept_node_id(selected_concept)
|
||||
if not concept_node_id:
|
||||
print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
|
||||
exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
|
||||
return exploration_result
|
||||
|
||||
# 开始多跳探索
|
||||
current_node_id = concept_node_id
|
||||
exploration_result['visited_nodes'].append(current_node_id)
|
||||
|
||||
for step in range(1, self.max_exploration_steps + 1):
|
||||
print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")
|
||||
|
||||
# 获取当前节点的邻居三元组,排除已访问节点
|
||||
neighbor_triplets = self._get_neighbor_triplets(current_node_id, exploration_result['visited_nodes'])
|
||||
|
||||
if not neighbor_triplets:
|
||||
print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
|
||||
break
|
||||
|
||||
# 使用LLM选择下一个探索方向
|
||||
exploration_path_str = format_exploration_path(exploration_result['exploration_path'])
|
||||
continue_result = self._continue_exploration(current_node_id, neighbor_triplets, exploration_path_str, user_query)
|
||||
|
||||
if not continue_result or not continue_result.selected_node:
|
||||
print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
|
||||
break
|
||||
|
||||
# 找到选择节点的ID
|
||||
next_node_id = self._find_node_id_by_label(continue_result.selected_node)
|
||||
if not next_node_id or next_node_id in exploration_result['visited_nodes']:
|
||||
print(f"[WARNING] 选择的节点无效或已访问: {continue_result.selected_node}")
|
||||
break
|
||||
|
||||
# 记录探索步骤
|
||||
step_info = {
|
||||
'source': current_node_id,
|
||||
'relation': continue_result.selected_relation,
|
||||
'target': next_node_id,
|
||||
'reason': continue_result.exploration_reason
|
||||
}
|
||||
exploration_result['exploration_path'].append(step_info)
|
||||
|
||||
# 记录获得的知识
|
||||
exploration_result['knowledge_gained'].append({
|
||||
'step': step,
|
||||
'action': 'exploration_step',
|
||||
'selected_node': continue_result.selected_node,
|
||||
'relation': continue_result.selected_relation,
|
||||
'reason': continue_result.exploration_reason,
|
||||
'expected_knowledge': continue_result.expected_knowledge
|
||||
})
|
||||
|
||||
# 移动到下一个节点
|
||||
current_node_id = next_node_id
|
||||
exploration_result['visited_nodes'].append(current_node_id)
|
||||
|
||||
print(f"[OK] 步骤 {step} 完成: {continue_result.selected_node}")
|
||||
|
||||
exploration_result['success'] = True
|
||||
print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])} 步")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 分支探索失败: {e}")
|
||||
exploration_result['error'] = str(e)
|
||||
|
||||
return exploration_result
|
||||
|
||||
def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
|
||||
"""使用LLM选择最佳概念进行探索"""
|
||||
concepts_str = "\n".join([f"- {concept}" for concept in connected_concepts])
|
||||
|
||||
prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
|
||||
node_name=node_name,
|
||||
connected_concepts=concepts_str,
|
||||
user_query=user_query,
|
||||
insufficiency_reason=insufficiency_reason
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.llm.invoke(prompt)
|
||||
response_text = self._extract_response_text(response)
|
||||
return self.concept_exploration_parser.parse(response_text)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 概念选择失败: {e}")
|
||||
return None
|
||||
|
||||
def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
|
||||
"""使用LLM决定下一步探索方向"""
|
||||
triplets_str = format_triplets(neighbor_triplets)
|
||||
|
||||
prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
|
||||
current_node=current_node,
|
||||
neighbor_triplets=triplets_str,
|
||||
exploration_path=exploration_path,
|
||||
user_query=user_query
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.llm.invoke(prompt)
|
||||
response_text = self._extract_response_text(response)
|
||||
return self.concept_continue_parser.parse(response_text)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 探索继续失败: {e}")
|
||||
return None
|
||||
|
||||
def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
|
||||
"""根据概念标签找到对应的节点ID"""
|
||||
try:
|
||||
if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
|
||||
return None
|
||||
|
||||
graph = self.retriever.graph_data
|
||||
for node_id, node_data in graph.nodes(data=True):
|
||||
node_label = node_data.get('label', node_data.get('name', ''))
|
||||
node_type = node_data.get('type', '').lower()
|
||||
|
||||
if (node_label == concept_label or node_id == concept_label) and 'concept' in node_type:
|
||||
return node_id
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 查找概念节点ID失败: {e}")
|
||||
return None
|
||||
|
||||
def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
|
||||
"""根据节点标签找到对应的节点ID"""
|
||||
try:
|
||||
if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
|
||||
return None
|
||||
|
||||
graph = self.retriever.graph_data
|
||||
for node_id, node_data in graph.nodes(data=True):
|
||||
label = node_data.get('label', node_data.get('name', ''))
|
||||
|
||||
if label == node_label or node_id == node_label:
|
||||
return node_id
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 查找节点ID失败: {e}")
|
||||
return None
|
||||
|
||||
def concept_exploration_node(self, state: QueryState) -> QueryState:
|
||||
"""
|
||||
概念探索节点
|
||||
1. 提取TOP-3 Node节点
|
||||
2. 并行探索3个分支
|
||||
3. 合并探索结果
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
print(f"[?] 开始概念探索 (轮次 {state['exploration_round'] + 1})")
|
||||
|
||||
# 更新探索轮次
|
||||
state['exploration_round'] += 1
|
||||
|
||||
# 获取TOP-3 Node节点
|
||||
top_nodes = self._get_top_node_candidates(state)
|
||||
if not top_nodes:
|
||||
print("[ERROR] 未找到合适的Node节点进行探索")
|
||||
state['concept_exploration_results'] = {'error': '未找到合适的Node节点'}
|
||||
return state
|
||||
|
||||
# 获取不充分原因
|
||||
insufficiency_reason = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
|
||||
user_query = state['original_query']
|
||||
|
||||
# 如果是第二轮探索,使用生成的子查询
|
||||
if state['exploration_round'] == 2:
|
||||
sub_query_info = state['concept_exploration_results'].get('round_1_results', {}).get('sub_query_info', {})
|
||||
if sub_query_info and sub_query_info.get('sub_query'):
|
||||
user_query = sub_query_info['sub_query']
|
||||
print(f"[TARGET] 第二轮探索使用子查询: {user_query}")
|
||||
|
||||
# 并行探索多个分支
|
||||
exploration_results = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(len(top_nodes), self.max_parallel_explorations)) as executor:
|
||||
# 提交探索任务
|
||||
futures = {
|
||||
executor.submit(self._explore_single_branch, node_info, user_query, insufficiency_reason): node_info
|
||||
for node_info in top_nodes
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
for future in as_completed(futures):
|
||||
node_info = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
exploration_results.append(result)
|
||||
print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
|
||||
exploration_results.append({
|
||||
'start_node': node_info,
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
# 合并探索结果
|
||||
round_key = f'round_{state["exploration_round"]}_results'
|
||||
state['concept_exploration_results'][round_key] = {
|
||||
'exploration_results': exploration_results,
|
||||
'total_branches': len(top_nodes),
|
||||
'successful_branches': len([r for r in exploration_results if r.get('success', False)]),
|
||||
'query_used': user_query
|
||||
}
|
||||
|
||||
print(f"[SUCCESS] 第{state['exploration_round']}轮概念探索完成")
|
||||
print(f" 总分支: {len(top_nodes)}, 成功: {len([r for r in exploration_results if r.get('success', False)])}")
|
||||
|
||||
return state
|
||||
|
||||
def concept_exploration_sufficiency_check_node(self, state: QueryState) -> QueryState:
|
||||
"""
|
||||
概念探索的充分性检查节点
|
||||
综合评估文档段落和探索知识是否足够回答查询
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
print(f"[THINK] 执行概念探索充分性检查")
|
||||
|
||||
# 格式化现有检索结果(如果有all_documents则使用混合格式,否则使用传统格式)
|
||||
if 'all_documents' in state and state['all_documents']:
|
||||
formatted_passages = format_mixed_passages(state['all_documents'])
|
||||
else:
|
||||
formatted_passages = format_passages(state['all_passages'])
|
||||
|
||||
# 格式化探索获得的知识
|
||||
exploration_knowledge = self._format_exploration_knowledge(state['concept_exploration_results'])
|
||||
|
||||
# 获取原始不充分原因
|
||||
original_insufficiency = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
|
||||
|
||||
# 使用的查询(可能是子查询)
|
||||
current_query = state['original_query']
|
||||
if state['exploration_round'] == 2:
|
||||
sub_query_info = state['concept_exploration_results'].get('round_1_results', {}).get('sub_query_info', {})
|
||||
if sub_query_info and sub_query_info.get('sub_query'):
|
||||
current_query = sub_query_info['sub_query']
|
||||
|
||||
# 构建提示词
|
||||
prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
|
||||
user_query=current_query,
|
||||
all_passages=formatted_passages,
|
||||
exploration_knowledge=exploration_knowledge,
|
||||
insufficiency_reason=original_insufficiency
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用LLM进行充分性检查
|
||||
response = self.llm.invoke(prompt)
|
||||
response_text = self._extract_response_text(response)
|
||||
|
||||
# 解析结果
|
||||
sufficiency_result = self.concept_sufficiency_parser.parse(response_text)
|
||||
|
||||
# 更新状态
|
||||
state['is_sufficient'] = sufficiency_result.is_sufficient
|
||||
|
||||
# 存储概念探索的充分性检查结果
|
||||
round_key = f'round_{state["exploration_round"]}_results'
|
||||
if round_key in state['concept_exploration_results']:
|
||||
state['concept_exploration_results'][round_key]['sufficiency_check'] = {
|
||||
'is_sufficient': sufficiency_result.is_sufficient,
|
||||
'confidence': sufficiency_result.confidence,
|
||||
'reason': sufficiency_result.reason,
|
||||
'missing_aspects': sufficiency_result.missing_aspects,
|
||||
'coverage_score': sufficiency_result.coverage_score
|
||||
}
|
||||
|
||||
# 更新调试信息
|
||||
state["debug_info"]["llm_calls"] += 1
|
||||
|
||||
print(f"[INFO] 概念探索充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'}")
|
||||
print(f" 置信度: {sufficiency_result.confidence:.2f}, 覆盖度: {sufficiency_result.coverage_score:.2f}")
|
||||
if not sufficiency_result.is_sufficient:
|
||||
print(f" 缺失方面: {sufficiency_result.missing_aspects}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 概念探索充分性检查失败: {e}")
|
||||
# 如果检查失败,假设不充分
|
||||
state['is_sufficient'] = False
|
||||
return state
|
||||
|
||||
def concept_sub_query_generation_node(self, state: QueryState) -> QueryState:
|
||||
"""
|
||||
概念探索子查询生成节点
|
||||
如果第一轮探索后仍不充分,生成针对性子查询进行第二轮探索
|
||||
|
||||
Args:
|
||||
state: 当前状态
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
print(f"[TARGET] 生成概念探索子查询")
|
||||
|
||||
# 获取第一轮探索结果和缺失信息
|
||||
round_1_results = state['concept_exploration_results'].get('round_1_results', {})
|
||||
sufficiency_info = round_1_results.get('sufficiency_check', {})
|
||||
missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])
|
||||
|
||||
# 格式化第一轮探索结果
|
||||
exploration_summary = self._summarize_exploration_results(round_1_results)
|
||||
|
||||
# 构建提示词
|
||||
prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
|
||||
user_query=state['original_query'],
|
||||
missing_aspects=str(missing_aspects),
|
||||
exploration_results=exploration_summary
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用LLM生成子查询
|
||||
response = self.llm.invoke(prompt)
|
||||
response_text = self._extract_response_text(response)
|
||||
|
||||
# 解析结果
|
||||
sub_query_result = self.concept_sub_query_parser.parse(response_text)
|
||||
|
||||
# 存储子查询信息
|
||||
state['concept_exploration_results']['round_1_results']['sub_query_info'] = {
|
||||
'sub_query': sub_query_result.sub_query,
|
||||
'focus_aspects': sub_query_result.focus_aspects,
|
||||
'expected_improvements': sub_query_result.expected_improvements,
|
||||
'confidence': sub_query_result.confidence
|
||||
}
|
||||
|
||||
# 更新调试信息
|
||||
state["debug_info"]["llm_calls"] += 1
|
||||
|
||||
print(f"[OK] 生成概念探索子查询: {sub_query_result.sub_query}")
|
||||
print(f" 关注方面: {sub_query_result.focus_aspects}")
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 概念探索子查询生成失败: {e}")
|
||||
# 如果生成失败,使用默认子查询
|
||||
default_sub_query = state['original_query'] + " 补充详细信息"
|
||||
state['concept_exploration_results']['round_1_results']['sub_query_info'] = {
|
||||
'sub_query': default_sub_query,
|
||||
'focus_aspects': ['补充信息'],
|
||||
'expected_improvements': '获取更多相关信息',
|
||||
'confidence': 0.3
|
||||
}
|
||||
return state
|
||||
|
||||
def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
|
||||
"""格式化探索获得的知识为字符串"""
|
||||
if not exploration_results:
|
||||
return "无探索知识"
|
||||
|
||||
knowledge_parts = []
|
||||
|
||||
for round_key, round_data in exploration_results.items():
|
||||
if not round_key.startswith('round_'):
|
||||
continue
|
||||
|
||||
round_num = round_key.split('_')[1]
|
||||
knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")
|
||||
|
||||
exploration_list = round_data.get('exploration_results', [])
|
||||
for i, branch in enumerate(exploration_list):
|
||||
if not branch.get('success', False):
|
||||
continue
|
||||
|
||||
start_node = branch.get('start_node', {})
|
||||
knowledge_parts.append(f"\n分支{i+1} (起点: {start_node.get('label', 'unknown')}):")
|
||||
|
||||
knowledge_gained = branch.get('knowledge_gained', [])
|
||||
for knowledge in knowledge_gained:
|
||||
action = knowledge.get('action', '')
|
||||
if action == 'concept_selection':
|
||||
concept = knowledge.get('selected_concept', '')
|
||||
reason = knowledge.get('reason', '')
|
||||
knowledge_parts.append(f"- 选择探索概念: {concept}")
|
||||
knowledge_parts.append(f" 原因: {reason}")
|
||||
elif action == 'exploration_step':
|
||||
node = knowledge.get('selected_node', '')
|
||||
relation = knowledge.get('relation', '')
|
||||
reason = knowledge.get('reason', '')
|
||||
knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
|
||||
knowledge_parts.append(f" 原因: {reason}")
|
||||
|
||||
return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"
|
||||
|
||||
def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
|
||||
"""总结探索结果为简洁描述"""
|
||||
if not round_results:
|
||||
return "无探索结果"
|
||||
|
||||
exploration_list = round_results.get('exploration_results', [])
|
||||
successful_count = len([r for r in exploration_list if r.get('success', False)])
|
||||
total_count = len(exploration_list)
|
||||
|
||||
summary_parts = [
|
||||
f"第一轮探索: {successful_count}/{total_count} 个分支成功",
|
||||
]
|
||||
|
||||
# 提取主要发现
|
||||
key_findings = []
|
||||
for branch in exploration_list:
|
||||
if not branch.get('success', False):
|
||||
continue
|
||||
knowledge_gained = branch.get('knowledge_gained', [])
|
||||
for knowledge in knowledge_gained:
|
||||
if knowledge.get('action') == 'concept_selection':
|
||||
concept = knowledge.get('selected_concept', '')
|
||||
if concept:
|
||||
key_findings.append(f"探索了概念: {concept}")
|
||||
|
||||
if key_findings:
|
||||
summary_parts.append("主要发现:")
|
||||
summary_parts.extend([f"- {finding}" for finding in key_findings[:3]]) # 最多3个
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
269
AIEC-RAG/retriver/langgraph/dashscope_embedding.py
Normal file
269
AIEC-RAG/retriver/langgraph/dashscope_embedding.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""
|
||||
阿里云DashScope嵌入模型实现
|
||||
直接调用阿里云文本嵌入API,不通过OneAPI兼容层
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
import numpy as np
|
||||
from typing import List, Union, Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
|
||||
sys.path.append(project_root)
|
||||
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
|
||||
|
||||
|
||||
class DashScopeEmbeddingModel(BaseEmbeddingModel):
    """Embedding model backed directly by Alibaba Cloud DashScope.

    Calls the DashScope text-embedding REST API without going through the
    OneAPI compatibility layer.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """Configure the DashScope embedding client and probe the API once.

        Args:
            api_key: DashScope API key; falls back to the ONEAPI_KEY env var.
            model_name: Embedding model name (e.g. text-embedding-v3); falls
                back to the ONEAPI_MODEL_EMBED env var.
            max_retries: Maximum request attempts per batch.
            retry_delay: Base delay between retries in seconds (scaled per attempt).

        Raises:
            ValueError: If no API key can be resolved.
        """
        # The base class expects a sentence encoder; this implementation
        # bypasses it entirely and talks to the REST API instead.
        super().__init__(sentence_encoder=None)

        self.api_key = api_key or os.getenv('ONEAPI_KEY')  # reuse existing env var
        self.model_name = model_name or os.getenv('ONEAPI_MODEL_EMBED', 'text-embedding-v3')
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        if not self.api_key:
            raise ValueError("DashScope API Key未设置,请在.env文件中设置ONEAPI_KEY")

        # DashScope text-embedding REST endpoint.
        self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"

        self.headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }

        print(f"[OK] DashScope嵌入模型初始化完成: {self.model_name}")
        self._test_connection()
|
||||
|
||||
def encode(self, texts, **kwargs):
|
||||
"""
|
||||
实现基类的抽象方法
|
||||
对文本进行编码,返回嵌入向量
|
||||
|
||||
Args:
|
||||
texts: 文本或文本列表
|
||||
**kwargs: 其他参数(batch_size, show_progress_bar, query_type等)
|
||||
|
||||
Returns:
|
||||
嵌入向量或嵌入向量列表(numpy数组格式)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
if isinstance(texts, str):
|
||||
# 单个文本
|
||||
result = self.embed_query(texts)
|
||||
return np.array(result)
|
||||
else:
|
||||
# 文本列表
|
||||
results = self.embed_texts(texts)
|
||||
return np.array(results)
|
||||
|
||||
def _test_connection(self):
|
||||
"""测试DashScope嵌入API连接"""
|
||||
try:
|
||||
# 发送一个简单的测试请求
|
||||
test_payload = {
|
||||
"model": self.model_name,
|
||||
"input": {
|
||||
"texts": ["test"]
|
||||
}
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=self.headers,
|
||||
json=test_payload,
|
||||
timeout=(5, 15) # 短超时进行连接测试
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get('output') and result['output'].get('embeddings'):
|
||||
print("[OK] DashScope嵌入API连接测试成功")
|
||||
else:
|
||||
print(f"[WARNING] DashScope嵌入API响应格式异常: {result}")
|
||||
else:
|
||||
print(f"[WARNING] DashScope嵌入API连接异常,状态码: {response.status_code}")
|
||||
print(f" 响应: {response.text[:200]}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
print("[WARNING] DashScope嵌入API连接超时,但将在实际请求时重试")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("[WARNING] DashScope嵌入API连接失败,请检查网络状态")
|
||||
except Exception as e:
|
||||
print(f"[WARNING] DashScope嵌入API连接测试出错: {e}")
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
对文本列表进行嵌入
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
"""
|
||||
import numpy as np
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# DashScope API支持批量处理,但需要分批处理大量文本
|
||||
batch_size = 10 # 每批处理10个文本
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
batch_embeddings = self._embed_batch(batch_texts)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
def _embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""对一批文本进行嵌入"""
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": {
|
||||
"texts": texts
|
||||
}
|
||||
}
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
response = requests.post(
|
||||
self.api_url,
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
timeout=(30, 120) # 连接30秒,读取120秒
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
|
||||
# 检查是否有错误
|
||||
if result.get('code'):
|
||||
error_msg = result.get('message', f'API错误代码: {result["code"]}')
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"DashScope嵌入API错误: {error_msg}")
|
||||
else:
|
||||
print(f"API错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
continue
|
||||
|
||||
# 提取嵌入向量
|
||||
if result.get('output') and result['output'].get('embeddings'):
|
||||
embeddings_data = result['output']['embeddings']
|
||||
|
||||
# 提取向量数据
|
||||
embeddings = []
|
||||
for embedding_item in embeddings_data:
|
||||
if isinstance(embedding_item, dict) and 'embedding' in embedding_item:
|
||||
embeddings.append(embedding_item['embedding'])
|
||||
elif isinstance(embedding_item, list):
|
||||
embeddings.append(embedding_item)
|
||||
else:
|
||||
print(f"[WARNING] 未知的嵌入格式: {type(embedding_item)}")
|
||||
|
||||
return embeddings
|
||||
else:
|
||||
error_msg = f"API响应格式错误: {result}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(error_msg)
|
||||
else:
|
||||
print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
else:
|
||||
error_text = response.text[:500] if response.text else "无响应内容"
|
||||
error_msg = f"API请求失败,状态码: {response.status_code}, 响应: {error_text}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(error_msg)
|
||||
else:
|
||||
print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n[WARNING] 用户中断请求")
|
||||
raise KeyboardInterrupt("用户中断请求")
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
error_msg = f"请求超时: {str(e)}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}")
|
||||
else:
|
||||
print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
error_msg = f"连接错误: {str(e)}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}")
|
||||
else:
|
||||
print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
except requests.RequestException as e:
|
||||
error_msg = f"网络请求异常: {str(e)}"
|
||||
if attempt == self.max_retries - 1:
|
||||
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}")
|
||||
else:
|
||||
print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
|
||||
time.sleep(self.retry_delay * (attempt + 1))
|
||||
|
||||
raise RuntimeError("所有重试都失败了")
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
    """Embed a single query string.

    Args:
        text: The query text to embed.

    Returns:
        The embedding vector for *text*, or an empty list when the
        batch call produced no result.
    """
    result = self.embed_texts([text])
    if not result:
        return []
    return result[0]
|
||||
|
||||
|
||||
def create_dashscope_embedding_model(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeEmbeddingModel:
    """Convenience factory for :class:`DashScopeEmbeddingModel`.

    Args:
        api_key: Optional DashScope API key (defaults from environment).
        model_name: Optional embedding model name.
        **kwargs: Extra keyword arguments forwarded to the constructor.
    """
    return DashScopeEmbeddingModel(api_key=api_key, model_name=model_name, **kwargs)
|
||||
605
AIEC-RAG/retriver/langgraph/dashscope_llm.py
Normal file
605
AIEC-RAG/retriver/langgraph/dashscope_llm.py
Normal file
@ -0,0 +1,605 @@
|
||||
"""
|
||||
阿里云DashScope原生LLM实现
|
||||
直接调用阿里云通义千问API,不通过OneAPI兼容层
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
from typing import List, Dict, Any, Optional, Union, Iterator, Callable
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.outputs import LLMResult, Generation
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
|
||||
|
||||
|
||||
class DashScopeLLM(BaseLLM):
    """
    Native Aliyun DashScope LLM implementation.

    Calls the Tongyi Qianwen (qwen) text-generation HTTP endpoint
    directly instead of going through an OneAPI compatibility layer.
    """

    # Pydantic field declarations (BaseLLM is a pydantic model).
    api_key: str = ""
    model_name: str = "qwen-turbo"
    max_retries: int = 3
    retry_delay: float = 1.0
    api_url: str = ""
    headers: dict = {}
    # Token usage of the most recent call / cumulative usage across calls.
    last_token_usage: dict = {}
    total_token_usage: dict = {}

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        **kwargs
    ):
        """
        Initialize the DashScope LLM.

        Args:
            api_key: Aliyun DashScope API key; falls back to the ONEAPI_KEY
                environment variable (reused on purpose).
            model_name: Model name, e.g. qwen-turbo, qwen-plus, qwen-max.
            max_retries: Maximum number of attempts per request.
            retry_delay: Base delay (seconds) between retries; scaled linearly
                by the attempt number.

        Raises:
            ValueError: If no API key is available.
        """
        # Resolve field values before calling the pydantic base __init__.
        api_key_value = api_key or os.getenv('ONEAPI_KEY')  # reuse existing env var
        model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo')

        if not api_key_value:
            raise ValueError("DashScope API Key未设置,请在.env文件中设置ONEAPI_KEY")

        # DashScope text-generation endpoint.
        api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

        headers_value = {
            'Authorization': f'Bearer {api_key_value}',
            'Content-Type': 'application/json',
            'X-DashScope-SSE': 'disable'  # disable server-sent-event streaming
        }

        # Token-usage bookkeeping.
        last_token_usage_value = {}
        total_token_usage_value = {
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
            'call_count': 0
        }

        # Initialize the pydantic base class with all declared fields.
        super().__init__(
            api_key=api_key_value,
            model_name=model_name_value,
            max_retries=max_retries,
            retry_delay=retry_delay,
            api_url=api_url_value,
            headers=headers_value,
            last_token_usage=last_token_usage_value,
            total_token_usage=total_token_usage_value,
            **kwargs
        )

        print(f"[OK] DashScope LLM初始化完成: {self.model_name}")
        self._test_connection()

    def _test_connection(self):
        """Send a tiny probe request to verify DashScope is reachable.

        Failures are only logged (never raised): the real requests retry
        on their own, so an unreachable endpoint here is not fatal.
        """
        try:
            test_payload = {
                "model": self.model_name,
                "input": {
                    "messages": [
                        {"role": "user", "content": "hello"}
                    ]
                },
                "parameters": {
                    "max_tokens": 10,
                    "temperature": 0.1
                }
            }

            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=test_payload,
                timeout=(5, 15)  # short timeout: this is only a probe
            )

            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('text'):
                    print("[OK] DashScope连接测试成功")
                else:
                    print(f"[WARNING] DashScope响应格式异常: {result}")
            else:
                print(f"[WARNING] DashScope连接异常,状态码: {response.status_code}")
                print(f"   响应: {response.text[:200]}")

        except requests.exceptions.Timeout:
            print("[WARNING] DashScope连接超时,但将在实际请求时重试")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope连接失败,请检查网络状态")
        except Exception as e:
            print(f"[WARNING] DashScope连接测试出错: {e}")

    @property
    def _llm_type(self) -> str:
        """LangChain LLM type identifier."""
        return "dashscope"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate one response per prompt.

        Fix: the ``stop`` argument was previously accepted but silently
        ignored; it is now forwarded to the API call (without clobbering
        an explicit ``stop`` already present in ``kwargs``).
        """
        if stop is not None:
            kwargs.setdefault("stop", stop)

        generations = []
        for prompt in prompts:
            try:
                response_text = self._call_api(prompt, **kwargs)
                generations.append([Generation(text=response_text)])
            except Exception as e:
                # Surface the failure in-band so a batch keeps its shape.
                print(f"[ERROR] DashScope API调用失败: {e}")
                generations.append([Generation(text=f"API调用失败: {str(e)}")])

        return LLMResult(generations=generations)

    def _call_api(self, prompt: str, **kwargs) -> str:
        """Call the DashScope generation API with retries.

        Args:
            prompt: The user prompt to send.
            **kwargs: ``max_tokens``, ``temperature``, ``top_p``, ``stop``.

        Returns:
            The generated text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: After all retries are exhausted.
            KeyboardInterrupt: Re-raised immediately on user interrupt.
        """
        payload = {
            "model": self.model_name,
            "input": {
                "messages": [
                    {"role": "user", "content": prompt}
                ]
            },
            "parameters": {
                "max_tokens": kwargs.get("max_tokens", 2048),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.8),
            }
        }

        # Optional stop sequences.
        stop = kwargs.get("stop")
        if stop:
            payload["parameters"]["stop"] = stop

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=(30, 120)  # 30s connect, 120s read
                )

                if response.status_code == 200:
                    result = response.json()

                    # DashScope signals API-level errors via a 'code' field.
                    if result.get('code'):
                        error_msg = result.get('message', f'API错误代码: {result["code"]}')
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(f"DashScope API错误: {error_msg}")
                        else:
                            print(f"API错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
                            time.sleep(self.retry_delay * (attempt + 1))
                            continue

                    if result.get('output') and result['output'].get('text'):
                        content = result['output']['text']
                        # Consolidated token accounting (was duplicated inline).
                        self._update_token_usage(result.get('usage', {}))
                        return content.strip()
                    else:
                        error_msg = f"API响应格式错误: {result}"
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(error_msg)
                        else:
                            print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})")
                            time.sleep(self.retry_delay * (attempt + 1))
                else:
                    error_text = response.text[:500] if response.text else "无响应内容"
                    error_msg = f"API请求失败,状态码: {response.status_code}, 响应: {error_text}"
                    if attempt == self.max_retries - 1:
                        raise RuntimeError(error_msg)
                    else:
                        print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}")
                        time.sleep(self.retry_delay * (attempt + 1))

            except KeyboardInterrupt:
                # Never retry on user interrupt.
                print(f"\n[WARNING] 用户中断请求")
                raise KeyboardInterrupt("用户中断请求")

            except requests.exceptions.Timeout as e:
                error_msg = f"请求超时: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}")
                else:
                    print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))

            except requests.exceptions.ConnectionError as e:
                error_msg = f"连接错误: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}")
                else:
                    print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))

            except requests.RequestException as e:
                error_msg = f"网络请求异常: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}")
                else:
                    print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
                    time.sleep(self.retry_delay * (attempt + 1))

        raise RuntimeError("所有重试都失败了")

    def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs):
        """
        Generate a response; compatible with the legacy interface.

        Args:
            messages: A message list, or a list of message lists (batch).
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            response_format: Accepted for compatibility; not forwarded.
            **kwargs: Accepted for compatibility.

        Returns:
            Generated text, or a list of texts for batch input (failed
            batch items become empty strings).
        """
        # Batch input is a list of message lists.
        is_batch = isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list)

        if is_batch:
            results = []
            for msg_list in messages:
                try:
                    prompt = self._format_messages(msg_list)
                    result = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
                    results.append(result)
                except Exception as e:
                    print(f"批量处理中的单个请求失败: {str(e)}")
                    results.append("")
            return results
        else:
            prompt = self._format_messages(messages)
            return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)

    def _format_messages(self, messages):
        """Flatten a chat-message list into a single role-labelled prompt.

        NOTE(review): messages with roles other than system/user/assistant
        are silently dropped — confirm this is intended upstream.
        """
        formatted_parts = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted_parts.append(f"System: {content}")
            elif role == 'user':
                formatted_parts.append(f"User: {content}")
            elif role == 'assistant':
                formatted_parts.append(f"Assistant: {content}")
        return "\n\n".join(formatted_parts)

    def ner(self, text: str) -> str:
        """
        Named-entity recognition via the LLM.

        Args:
            text: Input text.

        Returns:
            Extracted entities as a comma-separated string; on failure
            the original text is returned unchanged (best-effort).
        """
        messages = [
            {
                "role": "system",
                "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."
            },
            {
                "role": "user",
                "content": f"Extract the named entities from: {text}"
            }
        ]

        try:
            return self.generate_response(messages, max_new_tokens=1024, temperature=0.7)
        except Exception as e:
            print(f"NER失败: {str(e)}")
            return text  # fall back to the raw text on failure

    def filter_triples_with_entity_event(self, question: str, triples: str) -> str:
        """
        Filter knowledge-graph triples by relevance to a question.

        Args:
            question: The query question.
            triples: JSON string of candidate triples.

        Returns:
            JSON string of the filtered triples (``{"fact": [...]}``);
            on total failure the input ``triples`` is returned unchanged.
        """
        import json
        import json_repair

        # Strict-subset filtering prompt (Chinese, runtime string).
        filter_messages = [
            {
                "role": "system",
                "content": """你是一个知识图谱事实过滤专家,擅长根据问题相关性筛选事实。

关键要求:
1. 你只能从提供的输入列表中选择事实 - 绝对不能创建或生成新的事实
2. 你的输出必须是输入事实的严格子集
3. 只包含与回答问题直接相关的事实
4. 如果不确定,宁可选择更少的事实,也不要选择更多
5. 保持每个事实的准确格式:[主语, 关系, 宾语]

过滤规则:
- 只选择包含与问题直接相关的实体或关系的事实
- 不能修改、改写或创建输入事实的变体
- 不能添加看起来相关但不在输入中的事实
- 输出事实数量必须 ≤ 输入事实数量

返回格式为包含"fact"键的JSON对象,值为选中的事实数组。

示例:
输入事实:[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]]
问题:A和B是什么关系?
正确输出:{"fact": [["A", "关系1", "B"]]}
错误做法:添加新事实或修改现有事实"""
            },
            {
                "role": "user",
                "content": f"""问题:{question}

待筛选的输入事实:
{triples}

从上述输入中仅选择最相关的事实来回答问题。记住:只能是严格的子集!"""
            }
        ]

        try:
            response = self.generate_response(
                filter_messages,
                max_new_tokens=4096,
                temperature=0.0
            )

            # Parse strictly first, then fall back to json_repair.
            try:
                parsed_response = json.loads(response)
                if 'fact' in parsed_response:
                    return json.dumps(parsed_response, ensure_ascii=False)
                else:
                    return json.dumps({"fact": []}, ensure_ascii=False)
            except json.JSONDecodeError:
                try:
                    parsed_response = json_repair.loads(response)
                    if 'fact' in parsed_response:
                        return json.dumps(parsed_response, ensure_ascii=False)
                    else:
                        return json.dumps({"fact": []}, ensure_ascii=False)
                except Exception:
                    # Fix: was a bare except that also swallowed
                    # KeyboardInterrupt/SystemExit.
                    return json.dumps({"fact": []}, ensure_ascii=False)

        except Exception as e:
            print(f"三元组过滤失败: {str(e)}")
            # On filter failure, fall back to the unfiltered input.
            return triples

    def generate_with_context(self,
                              question: str,
                              context: str,
                              max_new_tokens: int = 1024,
                              temperature: float = 0.7) -> str:
        """
        Answer a question grounded in the provided context.

        Args:
            question: The question.
            context: Supporting context text.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.

        Returns:
            The generated answer, or a fixed apology string on failure.
        """
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step."
            },
            {
                "role": "user",
                "content": f"{context}\n\n{question}\nThought:"
            }
        ]

        try:
            return self.generate_response(messages, max_new_tokens, temperature)
        except Exception as e:
            print(f"基于上下文生成失败: {str(e)}")
            return "抱歉,我无法基于提供的上下文回答这个问题。"

    def get_token_usage(self) -> Dict[str, Any]:
        """Return copies of the last-call and cumulative token statistics."""
        return {
            "last_usage": self.last_token_usage.copy(),
            "total_usage": self.total_token_usage.copy(),
            "model_name": self.model_name
        }

    def stream_generate(
        self,
        prompt: str,
        stream_callback: Optional[Callable[[str], None]] = None,
        **kwargs
    ) -> str:
        """
        Streaming generation via DashScope's OpenAI-compatible endpoint.

        Args:
            prompt: Input prompt.
            stream_callback: Called with each text chunk as it arrives.
            **kwargs: ``max_tokens``, ``temperature``, ``top_p``.

        Returns:
            The complete generated text. Falls back to the non-streaming
            API if the OpenAI SDK is missing or streaming fails.
        """
        try:
            # Use the OpenAI SDK against DashScope's compatible-mode URL.
            from openai import OpenAI

            client = OpenAI(
                api_key=self.api_key,
                base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            )

            stream = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                stream=True,
                max_tokens=kwargs.get("max_tokens", 2048),
                temperature=kwargs.get("temperature", 0.7),
                top_p=kwargs.get("top_p", 0.8),
                stream_options={"include_usage": True}
            )

            full_text = ""
            chunk_count = 0
            print(f"[DEBUG] 开始接收OpenAI兼容流式响应...")

            for chunk in stream:
                # Incremental content deltas.
                if chunk.choices and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if delta and delta.content:
                        text_chunk = delta.content
                        full_text += text_chunk
                        chunk_count += 1

                        if stream_callback:
                            stream_callback(text_chunk)

                # The final chunk may carry usage info (include_usage=True).
                if hasattr(chunk, 'usage') and chunk.usage:
                    self._update_token_usage({
                        'input_tokens': chunk.usage.prompt_tokens,
                        'output_tokens': chunk.usage.completion_tokens,
                        'total_tokens': chunk.usage.total_tokens
                    })

            print(f"[DEBUG] 流式结束,共收到 {chunk_count} 个chunk,总长度 {len(full_text)} 字符")
            return full_text.strip()

        except ImportError:
            print("[WARNING] OpenAI SDK未安装,降级到非流式")
            return self._call_api(prompt, **kwargs)
        except Exception as e:
            print(f"[WARNING] OpenAI兼容流式失败: {e},降级到非流式")
            return self._call_api(prompt, **kwargs)

    def _update_token_usage(self, usage_info: Dict[str, Any]):
        """Record per-call token usage and roll it into the running totals.

        Accepts DashScope-style keys (input_tokens/output_tokens/total_tokens).
        """
        self.last_token_usage = {
            'prompt_tokens': usage_info.get('input_tokens', 0),
            'completion_tokens': usage_info.get('output_tokens', 0),
            'total_tokens': usage_info.get('total_tokens', 0)
        }

        self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
        self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
        self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
        self.total_token_usage['call_count'] += 1

    def invoke(
        self,
        input: Union[str, List[str]],
        config: Optional[dict] = None,
        **kwargs
    ) -> Union[str, LLMResult]:
        """
        Unified entry point supporting both streaming and non-streaming.

        Args:
            input: A prompt string, or a list of prompts for batching.
            config: Optional config dict; ``config["metadata"]["stream_callback"]``
                enables streaming.
            **kwargs: Forwarded to the underlying call.

        Returns:
            Generated text for a single prompt, or an LLMResult for a batch.
        """
        stream_callback = None
        if config and config.get('metadata', {}).get('stream_callback'):
            stream_callback = config['metadata']['stream_callback']

        if isinstance(input, str):
            if stream_callback:
                return self.stream_generate(input, stream_callback, **kwargs)
            else:
                return self._call_api(input, **kwargs)

        # List input: delegate to the batch path.
        return self._generate(input, **kwargs)
|
||||
|
||||
|
||||
def create_dashscope_llm(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeLLM:
    """Convenience factory for :class:`DashScopeLLM`.

    Args:
        api_key: Optional DashScope API key (defaults from environment).
        model_name: Optional model name.
        **kwargs: Extra keyword arguments forwarded to the constructor.
    """
    return DashScopeLLM(api_key=api_key, model_name=model_name, **kwargs)
|
||||
169
AIEC-RAG/retriver/langgraph/es_vector_retriever.py
Normal file
169
AIEC-RAG/retriver/langgraph/es_vector_retriever.py
Normal file
@ -0,0 +1,169 @@
|
||||
"""
|
||||
ES向量检索器
|
||||
用于直接与ES向量库进行向量匹配检索
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.documents import Document
|
||||
|
||||
# 添加路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
|
||||
sys.path.append(project_root)
|
||||
|
||||
from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
|
||||
from elasticsearch_vectorization.es_client_wrapper import ESClientWrapper
|
||||
from elasticsearch_vectorization.config import ElasticsearchConfig
|
||||
|
||||
|
||||
class ESVectorRetriever:
    """ES vector retriever: direct vector-similarity search against an
    Elasticsearch passages index."""

    def __init__(
        self,
        keyword: str,
        top_k: int = 3,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        embed_model_name: Optional[str] = None
    ):
        """
        Initialize the ES vector retriever.

        Args:
            keyword: ES index keyword used to derive the passages index name.
            top_k: Number of documents to return per query.
            oneapi_key: OneAPI key for the embedding model.
            oneapi_base_url: Accepted for interface compatibility; currently unused.
            embed_model_name: Embedding model name.
        """
        self.keyword = keyword
        self.top_k = top_k

        # Embedding model used to vectorize queries.
        self.embedding_model = DashScopeEmbeddingModel(
            api_key=oneapi_key,
            model_name=embed_model_name
        )

        # ES client wrapper.
        self.es_client = ESClientWrapper()

        # Derive the passages index name from the keyword.
        self.passages_index = ElasticsearchConfig.get_passages_index_name(keyword)

        print(f"ES向量检索器初始化完成,目标索引: {self.passages_index}")

    def retrieve(self, query: str) -> List[Document]:
        """
        Retrieve documents relevant to *query* by vector similarity.

        Args:
            query: Query text.

        Returns:
            List of matching Documents; empty list on any failure
            (best-effort, never raises).
        """
        try:
            # Embed the query.
            query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)[0]

            # Normalize to a plain Python list for the ES payload.
            if hasattr(query_embedding, 'tolist'):
                query_vector = query_embedding.tolist()
            else:
                query_vector = list(query_embedding)

            # Vector similarity search.
            search_result = self.es_client.vector_search(
                index_name=self.passages_index,
                vector=query_vector,
                field="embedding",
                size=self.top_k
            )

            # Convert hits into LangChain Documents.
            documents = []
            hits = search_result.get("hits", {}).get("hits", [])

            for hit in hits:
                source = hit["_source"]
                score = hit["_score"]

                doc = Document(
                    page_content=source.get("content", ""),
                    metadata={
                        "passage_id": source.get("passage_id", ""),
                        "file_id": source.get("file_id", ""),
                        "evidence": source.get("evidence", ""),
                        "score": score,
                        "source": "es_vector_search"
                    }
                )
                documents.append(doc)

            print(f"ES向量检索完成,找到 {len(documents)} 个相关文档")
            return documents

        except Exception as e:
            print(f"ES向量检索失败: {e}")
            return []

    def test_connection(self) -> bool:
        """Ping the ES cluster; return False on any failure."""
        try:
            return self.es_client.ping()
        except Exception:
            # Fix: was a bare except, which also swallowed
            # KeyboardInterrupt/SystemExit.
            return False

    def get_index_stats(self) -> Dict[str, Any]:
        """Return document count and configuration for the target index.

        On failure the error message is included under the ``error`` key.
        """
        try:
            query = {"match_all": {}}
            result = self.es_client.search(self.passages_index, query, size=0)
            total = result.get("hits", {}).get("total", 0)

            # ES >= 7 returns {"value": n, ...}; older versions return an int.
            if isinstance(total, dict):
                count = total.get("value", 0)
            else:
                count = total

            return {
                "index_name": self.passages_index,
                "document_count": count,
                "top_k": self.top_k
            }
        except Exception as e:
            print(f"获取索引统计信息失败: {e}")
            return {
                "index_name": self.passages_index,
                "document_count": 0,
                "top_k": self.top_k,
                "error": str(e)
            }
|
||||
|
||||
|
||||
def create_es_vector_retriever(
    keyword: str,
    top_k: int = 3,
    **kwargs
) -> ESVectorRetriever:
    """
    Convenience factory for :class:`ESVectorRetriever`.

    Args:
        keyword: ES index keyword.
        top_k: Number of documents to return per query.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        A configured ESVectorRetriever instance.
    """
    return ESVectorRetriever(keyword=keyword, top_k=top_k, **kwargs)
|
||||
1055
AIEC-RAG/retriver/langgraph/graph_nodes.py
Normal file
1055
AIEC-RAG/retriver/langgraph/graph_nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
260
AIEC-RAG/retriver/langgraph/graph_state.py
Normal file
260
AIEC-RAG/retriver/langgraph/graph_state.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""
|
||||
LangGraph状态定义
|
||||
定义工作流中的状态结构
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, TypedDict, Callable
|
||||
from dataclasses import dataclass
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
class QueryState(TypedDict):
    """State dictionary threaded through the LangGraph retrieval workflow."""
    # Basic query information
    original_query: str
    current_iteration: int
    max_iterations: int

    # Debug mode: "0" = auto-detect, "simple" = force simple path,
    # "complex" = force complex path
    debug_mode: str

    # Query complexity assessment
    query_complexity: Dict[str, Any]
    is_complex_query: bool

    # Retrieval results
    all_passages: List[str]        # name kept for compatibility; now stores event info
    all_documents: List[Document]  # now includes event-node documents
    passage_sources: List[str]     # event provenance info

    # Sub-query bookkeeping
    sub_queries: List[str]
    current_sub_queries: List[str]
    decomposed_sub_queries: List[str]  # initial sub-queries from query decomposition

    # Sufficiency check
    sufficiency_check: Dict[str, Any]
    is_sufficient: bool

    # Final result
    final_answer: str

    # Debug information
    debug_info: Dict[str, Any]
    iteration_history: List[Dict[str, Any]]

    # Concept exploration (new). NOTE: PageRank data is kept in local temp
    # storage rather than in this state, to avoid shipping it to LangSmith.
    concept_exploration_results: Dict[str, Any]
    exploration_round: int  # current exploration round (1 or 2)

    # Marker that local PageRank data exists (no actual payload here)
    pagerank_data_available: bool

    # Optional streaming callback
    stream_callback: Optional[Callable]
|
||||
|
||||
|
||||
@dataclass
class RetrievalResult:
    """One retrieval round's output: passages, documents, and provenance."""
    passages: List[str]        # retrieved passage texts
    documents: List[Document]  # corresponding Document objects
    sources: List[str]         # provenance for each passage
    query: str                 # the query that produced this result
    iteration: int             # iteration number the result belongs to
|
||||
|
||||
|
||||
@dataclass
class SufficiencyCheck:
    """Outcome of a sufficiency check over the retrieved evidence."""
    is_sufficient: bool                      # whether evidence suffices
    confidence: float                        # model confidence in the verdict
    reason: str                              # explanation for the verdict
    sub_queries: Optional[List[str]] = None  # follow-up queries when insufficient
|
||||
|
||||
|
||||
def create_initial_state(
    original_query: str,
    max_iterations: int = 2,
    debug_mode: str = "0"
) -> QueryState:
    """
    Build the initial workflow state.

    Args:
        original_query: The user's original query.
        max_iterations: Maximum number of retrieval iterations.
        debug_mode: "0" = auto-detect, "simple" = force simple path,
            "complex" = force complex path.

    Returns:
        A fully initialized QueryState dictionary.
    """
    empty_debug = {
        "retrieval_calls": 0,
        "llm_calls": 0,
        "start_time": None,
        "end_time": None,
    }
    return QueryState(
        original_query=original_query,
        current_iteration=0,
        max_iterations=max_iterations,
        debug_mode=debug_mode,
        query_complexity={},
        is_complex_query=False,
        all_passages=[],
        all_documents=[],
        passage_sources=[],
        sub_queries=[],
        current_sub_queries=[],
        decomposed_sub_queries=[],
        sufficiency_check={},
        is_sufficient=False,
        final_answer="",
        debug_info=empty_debug,
        iteration_history=[],
        concept_exploration_results={},
        exploration_round=0,
        pagerank_data_available=False,
        stream_callback=None,  # set later only when streaming is requested
    )
|
||||
|
||||
|
||||
def update_state_with_retrieval(
    state: "QueryState",
    retrieval_result: "RetrievalResult"
) -> "QueryState":
    """
    Merge one retrieval round's results into the workflow state.

    Args:
        state: Current workflow state (mutated in place).
        retrieval_result: The retrieval round's output.

    Returns:
        The same state object, updated.
    """
    # Accumulate passages, documents, and provenance.
    state["all_passages"].extend(retrieval_result.passages)
    state["all_documents"].extend(retrieval_result.documents)
    state["passage_sources"].extend(retrieval_result.sources)

    # Bump the retrieval-call counter.
    state["debug_info"]["retrieval_calls"] += 1

    # Record this round in the iteration history.
    state["iteration_history"].append({
        "iteration": retrieval_result.iteration,
        "query": retrieval_result.query,
        "passages_count": len(retrieval_result.passages),
        "action": "retrieval",
    })

    return state
|
||||
|
||||
|
||||
def update_state_with_sufficiency_check(
    state: "QueryState",
    sufficiency_check: "SufficiencyCheck"
) -> "QueryState":
    """
    Apply a sufficiency-check verdict to the workflow state.

    Args:
        state: Current workflow state (mutated in place).
        sufficiency_check: The check result to record.

    Returns:
        The same state object, updated.
    """
    verdict = sufficiency_check.is_sufficient
    followups = sufficiency_check.sub_queries

    state["is_sufficient"] = verdict
    state["sufficiency_check"] = {
        "is_sufficient": verdict,
        "confidence": sufficiency_check.confidence,
        "reason": sufficiency_check.reason,
        "iteration": state["current_iteration"],
    }

    # When insufficient and follow-up queries exist, queue them.
    if not verdict and followups:
        state["current_sub_queries"] = followups
        state["sub_queries"].extend(followups)
    else:
        state["current_sub_queries"] = []

    # Bump the LLM-call counter.
    state["debug_info"]["llm_calls"] += 1

    # Record this check in the iteration history.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "sufficiency_check",
        "is_sufficient": verdict,
        "confidence": sufficiency_check.confidence,
        "sub_queries_count": len(followups or []),
    })

    return state
|
||||
|
||||
|
||||
def increment_iteration(state: "QueryState") -> "QueryState":
    """
    Advance the iteration counter by one.

    Args:
        state: Current workflow state (mutated in place).

    Returns:
        The same state object, updated.
    """
    state["current_iteration"] = state["current_iteration"] + 1
    return state
|
||||
|
||||
|
||||
def finalize_state(state: "QueryState", final_answer: str) -> "QueryState":
    """
    Finalize the workflow state with the generated answer.

    Args:
        state: Current workflow state (mutated in place).
        final_answer: The final answer text.

    Returns:
        The same state object, finalized.
    """
    state["final_answer"] = final_answer

    # Record the answer-generation step in the history.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "final_answer_generation",
        "answer_length": len(final_answer),
    })

    return state
|
||||
|
||||
|
||||
def get_state_summary(state: "QueryState") -> Dict[str, Any]:
    """
    Produce a compact summary of the workflow state.

    Args:
        state: Current workflow state (read-only).

    Returns:
        A dictionary of headline counters and flags.
    """
    summary = {
        "original_query": state["original_query"],
        "current_iteration": state["current_iteration"],
        "max_iterations": state["max_iterations"],
        "total_passages": len(state["all_passages"]),
        "total_sub_queries": len(state["sub_queries"]),
        "is_sufficient": state["is_sufficient"],
        "has_final_answer": bool(state["final_answer"]),
        "debug_info": state["debug_info"],
        "iteration_count": len(state["iteration_history"]),
    }
    return summary
|
||||
318
AIEC-RAG/retriver/langgraph/iterative_retriever.py
Normal file
318
AIEC-RAG/retriver/langgraph/iterative_retriever.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""
|
||||
基于LangGraph的迭代检索器
|
||||
实现智能迭代检索工作流
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from langgraph.graph import StateGraph, END
|
||||
|
||||
from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary
|
||||
from retriver.langgraph.graph_nodes import GraphNodes
|
||||
from retriver.langgraph.routing_functions import (
|
||||
should_continue_retrieval,
|
||||
route_by_complexity,
|
||||
route_by_debug_mode
|
||||
)
|
||||
from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever
|
||||
from retriver.langgraph.langchain_components import create_oneapi_llm
|
||||
|
||||
|
||||
class IterativeRetriever:
    """
    LangGraph-based iterative retriever.

    Implements the iterative retrieval flow:
    1. Initial retrieval -> 2. Sufficiency check -> 3. Sub-query generation (if needed)
    -> 4. Parallel retrieval -> 5. Repeat 2-4 until sufficient or the iteration limit
    is reached -> 6. Generate the final answer.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
        """
        Initialize the iterative retriever.

        Args:
            keyword: ES index keyword.
            top_k: Number of documents returned per retrieval.
            max_iterations: Maximum number of iterations.
            max_parallel_retrievals: Maximum number of parallel retrievals.
            oneapi_key: OneAPI key.
            oneapi_base_url: OneAPI base URL.
            model_name: Main LLM model name (used for answer generation).
            embed_model_name: Embedding model name.
            complexity_model_name: Complexity-judgement model (falls back to model_name).
            sufficiency_model_name: Sufficiency-check model (falls back to model_name).
            skip_llm_generation: Skip LLM answer generation (return retrieval results only).
        """
        self.keyword = keyword
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation

        # Fall back to the main model when no dedicated model is configured.
        complexity_model_name = complexity_model_name or model_name
        sufficiency_model_name = sufficiency_model_name or model_name

        # Create the retriever component.
        print("[CONFIG] 初始化检索器组件...")
        self.retriever = create_langchain_hipporag_retriever(
            keyword=keyword,
            top_k=top_k,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            oneapi_model_gen=model_name,
            oneapi_model_embed=embed_model_name
        )

        # Main LLM (used to generate answers).
        self.llm = create_oneapi_llm(
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name
        )

        # Dedicated complexity-judgement LLM, only when a different model is requested;
        # otherwise the main LLM instance is shared.
        if complexity_model_name != model_name:
            print(f" [INFO] 使用独立的复杂度判断模型: {complexity_model_name}")
            self.complexity_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=complexity_model_name
            )
        else:
            self.complexity_llm = self.llm

        # Dedicated sufficiency-check LLM, only when a different model is requested.
        if sufficiency_model_name != model_name:
            print(f" [INFO] 使用独立的充分性检查模型: {sufficiency_model_name}")
            self.sufficiency_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=sufficiency_model_name
            )
        else:
            self.sufficiency_llm = self.llm

        # Node handlers used by the workflow graph.
        self.nodes = GraphNodes(
            retriever=self.retriever,
            llm=self.llm,
            complexity_llm=self.complexity_llm,
            sufficiency_llm=self.sufficiency_llm,
            keyword=keyword,
            max_parallel_retrievals=max_parallel_retrievals,
            skip_llm_generation=skip_llm_generation
        )

        # Build and compile the workflow graph.
        self.workflow = self._build_workflow()

        print("[OK] 迭代检索器初始化完成")

    def _build_workflow(self) -> StateGraph:
        """Build and compile the LangGraph workflow for this retriever."""
        print("[?] 构建工作流图...")

        # Create the state graph over the shared QueryState.
        workflow = StateGraph(QueryState)

        # Add nodes.
        # Query-complexity judgement node.
        workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node)

        # Debug-mode node (may force the simple or complex path).
        workflow.add_node("debug_mode_node", self.nodes.debug_mode_node)

        # Simple-query path.
        workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node)
        workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node)

        # Complex-query path (existing hipporag2 logic).
        workflow.add_node("query_decomposition", self.nodes.query_decomposition_node)
        workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node)
        workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node)
        workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node)
        workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node)
        workflow.add_node("next_iteration", self.nodes.next_iteration_node)
        workflow.add_node("final_answer", self.nodes.final_answer_generation_node)

        # Entry point: start with the query-complexity judgement.
        workflow.set_entry_point("query_complexity_check")

        # After the complexity check, go through the debug-mode node.
        workflow.add_edge("query_complexity_check", "debug_mode_node")

        # Conditional edges: route based on debug mode and the complexity result.
        workflow.add_conditional_edges(
            "debug_mode_node",
            route_by_debug_mode,
            {
                "simple_vector_retrieval": "simple_vector_retrieval",
                "initial_retrieval": "query_decomposition"  # complex path enters at query decomposition
            }
        )

        # Edges of the simple-query path.
        workflow.add_edge("simple_vector_retrieval", "simple_answer_generation")
        workflow.add_edge("simple_answer_generation", END)

        # Edges of the complex-query path (includes query decomposition).
        workflow.add_edge("query_decomposition", "initial_retrieval")  # decomposition feeds parallel initial retrieval
        workflow.add_edge("initial_retrieval", "sufficiency_check")

        # Conditional edges: the sufficiency check decides the next step.
        workflow.add_conditional_edges(
            "sufficiency_check",
            should_continue_retrieval,
            {
                "final_answer": "final_answer",
                "parallel_retrieval": "sub_query_generation",
                "next_iteration": "next_iteration"
            }
        )

        workflow.add_edge("sub_query_generation", "parallel_retrieval")
        workflow.add_edge("parallel_retrieval", "next_iteration")  # bump iteration counter after parallel retrieval
        workflow.add_edge("next_iteration", "sufficiency_check")

        # Terminal node.
        workflow.add_edge("final_answer", END)

        return workflow.compile()

    def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]:
        """
        Run the iterative retrieval workflow.

        Args:
            query: User query.
            mode: Debug mode; "0" = auto-detect, "simple" = force the simple
                path, "complex" = force the complex path.

        Returns:
            A dict with the final answer and detailed run information; on
            failure, a dict with an "error" key and an apology answer.
        """
        print(f"[STARTING] 开始迭代检索: {query}")
        start_time = time.time()

        # Create the initial state.
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=self.max_iterations,
            debug_mode=mode
        )
        initial_state["debug_info"]["start_time"] = start_time

        try:
            # Execute the compiled workflow.
            final_state = self.workflow.invoke(initial_state)

            # Record timing.
            end_time = time.time()
            final_state["debug_info"]["end_time"] = end_time
            final_state["debug_info"]["total_time"] = end_time - start_time

            print(f"[SUCCESS] 迭代检索完成,耗时 {end_time - start_time:.2f}秒")

            # Assemble the result payload.
            return {
                "query": query,
                "answer": final_state["final_answer"],
                "query_complexity": final_state["query_complexity"],
                "is_complex_query": final_state["is_complex_query"],
                "iterations": final_state["current_iteration"],
                "total_passages": len(final_state["all_passages"]),
                "sub_queries": final_state["sub_queries"],
                "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []),
                "initial_retrieval_details": final_state.get("initial_retrieval_details", {}),
                "sufficiency_check": final_state["sufficiency_check"],
                "all_passages": final_state["all_passages"],
                "debug_info": final_state["debug_info"],
                "state_summary": get_state_summary(final_state),
                "iteration_history": final_state["iteration_history"]
            }

        except Exception as e:
            print(f"[ERROR] 迭代检索失败: {e}")
            return {
                "query": query,
                "answer": f"抱歉,检索过程中遇到错误: {str(e)}",
                "error": str(e),
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": time.time() - start_time}
            }

    def retrieve_simple(self, query: str) -> str:
        """
        Simplified retrieval interface that returns only the answer text.

        Args:
            query: User query.

        Returns:
            The final answer string.
        """
        result = self.retrieve(query)
        return result["answer"]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """
        Get retriever statistics.

        Returns:
            A dictionary of configuration values and model information.
        """
        return {
            "keyword": self.keyword,
            "top_k": self.top_k,
            "max_iterations": self.max_iterations,
            "max_parallel_retrievals": self.max_parallel_retrievals,
            "retriever_type": "IterativeRetriever with LangGraph",
            "model_info": {
                "llm_model": getattr(self.llm.oneapi_generator, 'model_name', 'unknown'),
                "embed_model": getattr(self.retriever.embedding_model, 'model_name', 'unknown')
            }
        }
|
||||
|
||||
|
||||
def create_iterative_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    **kwargs
) -> IterativeRetriever:
    """Convenience factory for :class:`IterativeRetriever`.

    Args:
        keyword: ES index keyword.
        top_k: Number of documents returned per retrieval.
        max_iterations: Maximum number of retrieval iterations.
        max_parallel_retrievals: Maximum number of parallel retrievals.
        **kwargs: Forwarded verbatim to the IterativeRetriever constructor.

    Returns:
        A configured IterativeRetriever instance.
    """
    return IterativeRetriever(
        **{
            "keyword": keyword,
            "top_k": top_k,
            "max_iterations": max_iterations,
            "max_parallel_retrievals": max_parallel_retrievals,
        },
        **kwargs,
    )
|
||||
952
AIEC-RAG/retriver/langgraph/langchain_components.py
Normal file
952
AIEC-RAG/retriver/langgraph/langchain_components.py
Normal file
@ -0,0 +1,952 @@
|
||||
"""
|
||||
LangChain组件实现
|
||||
包含LLM、提示词模板、输出解析器等
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import BaseOutputParser, PydanticOutputParser
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Try to import LangSmith's traceable decorator; fall back to a no-op.
try:
    from langsmith import traceable
    LANGSMITH_AVAILABLE = True
except ImportError:
    # NOTE(review): this fallback only supports the *called* form
    # @traceable(...) / @traceable(); a bare @traceable would pass the
    # function in as `name` and return the inner decorator instead of the
    # function — confirm all call sites use the called form.
    def traceable(name=None):
        def decorator(func):
            return func
        return decorator
    LANGSMITH_AVAILABLE = False
|
||||
|
||||
# 添加路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
|
||||
sys.path.append(project_root)
|
||||
|
||||
from retriver.langgraph.dashscope_llm import DashScopeLLM
|
||||
from prompt_loader import get_prompt_loader
|
||||
|
||||
|
||||
class SufficiencyCheckResult(BaseModel):
    """Structured result of the retrieval sufficiency check."""
    # Whether the retrieved information suffices to answer the user query.
    is_sufficient: bool = Field(description="是否足够回答用户查询")
    # Confidence in the judgement, constrained to [0, 1].
    confidence: float = Field(description="置信度,0-1之间", ge=0, le=1)
    # Free-text explanation of the judgement.
    reason: str = Field(description="判断理由")
    # Follow-up sub-queries proposed when the evidence is insufficient.
    sub_queries: Optional[List[str]] = Field(default=None, description="如果不充分,生成的子查询列表")
||||
|
||||
|
||||
class QueryComplexityResult(BaseModel):
    """Structured result of the query-complexity judgement."""
    # Whether the query is considered complex.
    is_complex: bool = Field(description="查询是否复杂")
    # Either "simple" or "complex".
    complexity_level: str = Field(description="复杂度级别:simple/complex")
    # Confidence in the judgement, constrained to [0, 1].
    confidence: float = Field(description="置信度,0-1之间", ge=0, le=1)
    # Free-text explanation of the judgement.
    reason: str = Field(description="判断理由")
||||
|
||||
|
||||
class OneAPILLM(BaseLLM):
    """
    LangChain wrapper around the DashScope LLM (keeps the legacy OneAPI
    interface name for compatibility).

    Talks directly to the Alibaba Cloud DashScope native API and supports
    streaming output.
    """

    # Pydantic field declaration: the wrapped DashScope client
    # (excluded from serialization).
    oneapi_generator: Any = Field(default=None, description="DashScope LLM实例", exclude=True)

    def __init__(self, dashscope_llm: DashScopeLLM, **kwargs):
        super().__init__(**kwargs)
        self.oneapi_generator = dashscope_llm  # attribute name kept for interface compatibility

    def _generate(
        self,
        messages: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ):
        """Generate a response for the first prompt in *messages*.

        Only the first message is used. Token usage is read back from the
        wrapped client and surfaced via ``llm_output``.
        """
        # Simplified handling: take the first message only.
        message = messages[0] if messages else ""

        # Call the DashScope LLM directly; on failure, embed the error
        # message in the response text instead of raising.
        try:
            response_text = self.oneapi_generator._call_api(
                message,
                stop=stop,
                **kwargs
            )
        except Exception as e:
            print(f"[ERROR] DashScope API调用失败: {e}")
            response_text = f"API调用失败: {str(e)}"

        from langchain_core.outputs import LLMResult, Generation

        # Token usage as recorded by the wrapped client (empty dict if absent).
        token_usage = getattr(self.oneapi_generator, 'last_token_usage', {})
        total_usage = getattr(self.oneapi_generator, 'total_token_usage', {})

        # Log token consumption to the console.
        if token_usage:
            print(f"[?] LLM调用Token消耗: 输入{token_usage.get('prompt_tokens', 0)}, "
                  f"输出{token_usage.get('completion_tokens', 0)}, "
                  f"总计{token_usage.get('total_tokens', 0)} "
                  f"[累计: {total_usage.get('total_tokens', 0)}]")

        result = LLMResult(
            generations=[[Generation(text=response_text)]],
            llm_output={
                "model_name": self.oneapi_generator.model_name,
                "token_usage": {
                    "current_call": token_usage,
                    "cumulative_total": total_usage,
                    "prompt_tokens": token_usage.get('prompt_tokens', 0),
                    "completion_tokens": token_usage.get('completion_tokens', 0),
                    "total_tokens": token_usage.get('total_tokens', 0)
                }
            }
        )

        return result

    def _llm_type(self) -> str:
        # LangChain identifier for this LLM implementation.
        return "oneapi_llm"

    def invoke(
        self,
        input: Any,
        config: Optional[dict] = None,
        **kwargs: Any
    ) -> Any:
        """
        Unified invocation interface supporting streaming and non-streaming.

        Args:
            input: Input text.
            config: Config dict; may contain stream_callback etc.
            **kwargs: Additional parameters.

        Returns:
            The generated text or an LLMResult object.
        """
        # Delegate directly to the underlying DashScopeLLM's invoke.
        return self.oneapi_generator.invoke(input, config=config, **kwargs)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        # Parameters that identify this LLM instance for caching/tracing.
        return {
            "model_name": self.oneapi_generator.model_name,
            "api_url": getattr(self.oneapi_generator, 'api_url', 'dashscope')
        }

    def stream(self, prompt: str, **kwargs):
        """
        Stream generated text.

        Args:
            prompt: Input prompt.
            **kwargs: Additional parameters.

        Yields:
            Each token string (delegated to the DashScope client).
        """
        # Delegate to DashScope's streaming method.
        return self.oneapi_generator.stream(prompt, **kwargs)
|
||||
|
||||
|
||||
class SufficiencyCheckParser(BaseOutputParser[SufficiencyCheckResult]):
    """Parser for sufficiency-check LLM output."""

    def parse(self, text: str) -> SufficiencyCheckResult:
        """Parse LLM output into a structured result.

        Tries strict JSON first (after stripping markdown code fences);
        falls back to keyword-based rule parsing on failure.
        """
        try:
            # Strip markdown code-fence markers.
            cleaned_text = text.strip()
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # drop leading ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # drop leading ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # drop trailing ```
            cleaned_text = cleaned_text.strip()

            # Attempt JSON parsing.
            data = json.loads(cleaned_text)
            return SufficiencyCheckResult(**data)
        except (json.JSONDecodeError, ValueError):
            # JSON parsing failed; fall back to rule-based parsing.
            return self._rule_based_parse(text)

    def _rule_based_parse(self, text: str) -> SufficiencyCheckResult:
        """Heuristic fallback: infer the result from keywords in the raw text."""
        text_lower = text.lower().strip()

        # Sufficient only if a positive keyword appears AND no negated form
        # does (note "不充分" contains "充分", hence the explicit negative list).
        is_sufficient = any(keyword in text_lower for keyword in [
            "充分", "足够", "sufficient", "enough", "adequate"
        ]) and not any(keyword in text_lower for keyword in [
            "不充分", "不足够", "insufficient", "not enough", "inadequate"
        ])

        # Fixed default confidence for heuristic results.
        confidence = 0.7

        # Use the (truncated) raw text as the reason.
        reason = text[:200] if len(text) <= 200 else text[:200] + "..."

        # Extract candidate sub-queries when the evidence is insufficient.
        sub_queries = None
        if not is_sufficient:
            # Naive extraction: question-like lines (could be more elaborate).
            lines = text.split('\n')
            queries = []
            for line in lines:
                if '?' in line and len(line.strip()) > 10:
                    queries.append(line.strip())
            sub_queries = queries[:2] if queries else None

        return SufficiencyCheckResult(
            is_sufficient=is_sufficient,
            confidence=confidence,
            reason=reason,
            sub_queries=sub_queries
        )
|
||||
|
||||
|
||||
class QueryComplexityParser(BaseOutputParser[QueryComplexityResult]):
    """Parser for query-complexity judgement LLM output."""

    def parse(self, text: str) -> QueryComplexityResult:
        """Parse LLM output into a structured result.

        Tries strict JSON first (after stripping markdown code fences);
        falls back to keyword-based rule parsing on failure.
        """
        try:
            # Strip markdown code-fence markers.
            cleaned_text = text.strip()
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # drop leading ```json
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # drop leading ```
            if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # drop trailing ```
            cleaned_text = cleaned_text.strip()

            # Attempt JSON parsing.
            data = json.loads(cleaned_text)
            return QueryComplexityResult(**data)
        except (json.JSONDecodeError, ValueError):
            # JSON parsing failed; fall back to rule-based parsing.
            return self._rule_based_parse(text)

    def _rule_based_parse(self, text: str) -> QueryComplexityResult:
        """Heuristic fallback: infer complexity from keywords in the raw text."""
        text_lower = text.lower().strip()

        # Complex if any complexity/relation keyword appears...
        is_complex = any(keyword in text_lower for keyword in [
            "复杂", "complex", "推理", "多步", "关联", "综合", "分析"
        ]) or any(keyword in text_lower for keyword in [
            "需要", "require", "关系", "连接", "因果", "比较"
        ])

        # ...unless an explicit "simple" marker appears, which overrides.
        if any(keyword in text_lower for keyword in [
            "简单", "simple", "直接", "单一", "基础"
        ]):
            is_complex = False

        complexity_level = "complex" if is_complex else "simple"
        confidence = 0.8  # fixed default confidence for heuristic results
        reason = text[:200] if len(text) <= 200 else text[:200] + "..."

        return QueryComplexityResult(
            is_complex=is_complex,
            complexity_level=complexity_level,
            confidence=confidence,
            reason=reason
        )
|
||||
|
||||
|
||||
@dataclass
class ConceptExplorationResult:
    """Result of the initial concept-exploration selection."""
    selected_concept: str    # name of the chosen concept to explore
    exploration_reason: str  # why this concept was selected
    confidence: float        # confidence in the selection, 0-1
    expected_knowledge: str  # knowledge expected from exploring it
||||
|
||||
|
||||
@dataclass
class ConceptContinueResult:
    """Result of a continued concept-exploration step."""
    selected_node: str       # next node to explore
    selected_relation: str   # relation leading to the selected node
    exploration_reason: str  # why this direction was chosen
    confidence: float        # confidence in the choice, 0-1
    expected_knowledge: str  # knowledge expected from this step
|
||||
|
||||
|
||||
@dataclass
class ConceptSufficiencyResult:
    """Result of the concept-exploration sufficiency check."""
    is_sufficient: bool         # whether gathered knowledge suffices
    confidence: float           # confidence in the judgement, 0-1
    reason: str                 # explanation of the judgement
    missing_aspects: List[str]  # aspects still missing, if any
    coverage_score: float       # estimated coverage of the query
|
||||
|
||||
|
||||
@dataclass
class ConceptSubQueryResult:
    """Result of concept-exploration sub-query generation."""
    sub_query: str              # the generated sub-query
    focus_aspects: List[str]    # aspects the sub-query focuses on
    expected_improvements: str  # improvements expected from running it
    confidence: float           # confidence in the sub-query, 0-1
|
||||
|
||||
|
||||
class ConceptExplorationParser:
    """Parses the JSON reply of the initial concept-selection prompt."""

    def parse(self, text: str) -> ConceptExplorationResult:
        """Parse LLM output; on any failure return a low-confidence default."""
        try:
            payload = json.loads(text.strip())
            fields = {
                "selected_concept": payload.get("selected_concept", ""),
                "exploration_reason": payload.get("exploration_reason", ""),
                "confidence": float(payload.get("confidence", 0.5)),
                "expected_knowledge": payload.get("expected_knowledge", ""),
            }
            return ConceptExplorationResult(**fields)
        except (json.JSONDecodeError, ValueError, KeyError) as err:
            # Degrade gracefully instead of propagating parse errors.
            return ConceptExplorationResult(
                selected_concept="",
                exploration_reason=f"解析失败: {str(err)}",
                confidence=0.1,
                expected_knowledge="",
            )
|
||||
|
||||
|
||||
class ConceptContinueParser:
    """Parses the JSON reply that picks the next exploration step."""

    def parse(self, text: str) -> ConceptContinueResult:
        """Parse LLM output; on any failure return a low-confidence default."""
        try:
            payload = json.loads(text.strip())
            fields = {
                "selected_node": payload.get("selected_node", ""),
                "selected_relation": payload.get("selected_relation", ""),
                "exploration_reason": payload.get("exploration_reason", ""),
                "confidence": float(payload.get("confidence", 0.5)),
                "expected_knowledge": payload.get("expected_knowledge", ""),
            }
            return ConceptContinueResult(**fields)
        except (json.JSONDecodeError, ValueError, KeyError) as err:
            # Degrade gracefully instead of propagating parse errors.
            return ConceptContinueResult(
                selected_node="",
                selected_relation="",
                exploration_reason=f"解析失败: {str(err)}",
                confidence=0.1,
                expected_knowledge="",
            )
|
||||
|
||||
|
||||
class ConceptSufficiencyParser:
    """Parses the JSON reply of the concept-exploration sufficiency check."""

    def parse(self, text: str) -> ConceptSufficiencyResult:
        """Parse LLM output; on failure return an 'insufficient' default."""
        try:
            payload = json.loads(text.strip())
            fields = {
                "is_sufficient": bool(payload.get("is_sufficient", False)),
                "confidence": float(payload.get("confidence", 0.5)),
                "reason": payload.get("reason", ""),
                "missing_aspects": payload.get("missing_aspects", []),
                "coverage_score": float(payload.get("coverage_score", 0.0)),
            }
            return ConceptSufficiencyResult(**fields)
        except (json.JSONDecodeError, ValueError, KeyError) as err:
            # Fail towards "not sufficient" so exploration continues.
            return ConceptSufficiencyResult(
                is_sufficient=False,
                confidence=0.1,
                reason=f"解析失败: {str(err)}",
                missing_aspects=["解析错误"],
                coverage_score=0.0,
            )
|
||||
|
||||
|
||||
class ConceptSubQueryParser:
    """Parses the JSON reply of concept-exploration sub-query generation."""

    def parse(self, text: str) -> ConceptSubQueryResult:
        """Parse LLM output; on any failure return a low-confidence default."""
        try:
            payload = json.loads(text.strip())
            fields = {
                "sub_query": payload.get("sub_query", ""),
                "focus_aspects": payload.get("focus_aspects", []),
                "expected_improvements": payload.get("expected_improvements", ""),
                "confidence": float(payload.get("confidence", 0.5)),
            }
            return ConceptSubQueryResult(**fields)
        except (json.JSONDecodeError, ValueError, KeyError) as err:
            # Degrade gracefully instead of propagating parse errors.
            return ConceptSubQueryResult(
                sub_query="",
                focus_aspects=["解析错误"],
                expected_improvements=f"解析失败: {str(err)}",
                confidence=0.1,
            )
|
||||
|
||||
|
||||
# Global prompt loader shared by every template in this module.
_prompt_loader = get_prompt_loader()

# Dynamically loaded prompt template (defined in the external prompt config).
QUERY_COMPLEXITY_CHECK_PROMPT = _prompt_loader.get_prompt_template('query_complexity_check')
|
||||
|
||||
# 保留原始定义作为备份(已注释)
|
||||
'''
|
||||
QUERY_COMPLEXITY_CHECK_PROMPT = PromptTemplate(
|
||||
input_variables=["query"],
|
||||
template="""
|
||||
作为一个智能查询分析助手,请分析用户查询的复杂度,判断该查询是否需要生成多方面的多个子查询来回答。
|
||||
|
||||
用户查询:{query}
|
||||
|
||||
请根据以下标准判断查询复杂度:
|
||||
|
||||
【复杂查询(Complex)特征 - 优先判断】:
|
||||
1. **多问句查询**:包含多个问号(?),涉及多个不同的问题或主题
|
||||
2. **跨领域查询**:涉及多个不同领域或行业的知识(如金融+技术+风险管理等)
|
||||
3. **复合问题**:一个查询中包含多个子查询,即使每个子查询本身较简单
|
||||
4. **关系型查询**:询问实体间的关系、比较、关联等
|
||||
5. **因果推理**:询问原因、结果、影响等
|
||||
6. **综合分析**:需要综合多个信息源进行分析
|
||||
7. **推理链查询**:需要通过知识图谱的路径推理才能回答
|
||||
8. **列表型查询**:要求列举多个项目或要素的问题
|
||||
|
||||
【简单查询(Simple)特征】:
|
||||
1. **单一问句**:只包含一个问号,聚焦于单一主题
|
||||
2. **单领域查询**:仅涉及一个明确的领域或概念
|
||||
3. **直接定义查询**:询问单一概念的定义、特点、属性等
|
||||
4. **单一实体信息查询**:询问某个具体事物的基本信息
|
||||
5. **可以通过文档中的连续文本段落直接回答的单一问题**
|
||||
|
||||
请以JSON格式返回判断结果:
|
||||
{{
|
||||
"is_complex": false, // true表示复杂查询,false表示简单查询
|
||||
"complexity_level": "simple", // "simple" 或 "complex"
|
||||
"confidence": 0.9, // 置信度,0-1之间
|
||||
"reason": "这是一个复杂查询,需要生成多方面的个子查询来回答..."
|
||||
}}
|
||||
|
||||
请确保返回有效的JSON格式。
|
||||
"""
|
||||
)
|
||||
'''
|
||||
|
||||
# Dynamically loaded sufficiency-check prompt (from the external prompt config).
SUFFICIENCY_CHECK_PROMPT = _prompt_loader.get_prompt_template('sufficiency_check')
|
||||
|
||||
# 保留原始定义作为备份(已注释)
|
||||
'''
|
||||
SUFFICIENCY_CHECK_PROMPT = PromptTemplate(
|
||||
input_variables=["query", "passages", "decomposed_sub_queries"],
|
||||
template="""
|
||||
作为一个智能问答助手,请判断仅从给定的信息是否已经足够回答用户的查询。
|
||||
|
||||
用户查询:{query}
|
||||
|
||||
已生成的子查询:{decomposed_sub_queries}
|
||||
|
||||
检索到的信息(包括原始查询和子查询的结果):
|
||||
{passages}
|
||||
|
||||
请分析这些信息是否包含足够的内容来完整回答用户的查询。注意检索结果包含两部分:
|
||||
1. 【事件信息】- 来自知识图谱的事件节点,包含事件描述和上下文
|
||||
2. 【段落信息】- 来自文档的段落内容,包含详细的文本描述
|
||||
|
||||
如果信息充分,请返回JSON格式:
|
||||
{{
|
||||
"is_sufficient": true,
|
||||
"confidence": 0.9,
|
||||
"reason": "事件信息和段落信息包含了回答查询所需的关键内容..."
|
||||
}}
|
||||
|
||||
如果信息不充分,请返回JSON格式,并详细说明缺失的信息:
|
||||
{{
|
||||
"is_sufficient": false,
|
||||
"confidence": 0.5,
|
||||
"reason": "检索信息缺少某些关键内容,具体包括:1) 缺少XXX的详细描述;2) 缺少XXX的具体实例;3) 缺少XXX的应用场景等..."
|
||||
}}
|
||||
|
||||
请确保返回有效的JSON格式。
|
||||
"""
|
||||
)
|
||||
'''
|
||||
|
||||
# Dynamically loaded query-decomposition prompt (from the external prompt config).
QUERY_DECOMPOSITION_PROMPT = _prompt_loader.get_prompt_template('query_decomposition')
|
||||
|
||||
# 保留原始定义作为备份(已注释)
|
||||
'''
|
||||
QUERY_DECOMPOSITION_PROMPT = PromptTemplate(
|
||||
input_variables=["original_query"],
|
||||
template="""
|
||||
你是一个专业的查询分解助手。你的任务是将用户的复合查询分解为2个独立的、适合向量检索的子查询。
|
||||
|
||||
用户原始查询:{original_query}
|
||||
|
||||
【重要】向量检索特点:
|
||||
- 向量检索通过语义相似度匹配,找到与查询最相关的文档段落
|
||||
- 每个子查询应该聚焦一个明确的主题或概念
|
||||
- 子查询应该是完整的、可独立检索的问题
|
||||
|
||||
【分解策略】:
|
||||
|
||||
1. **识别查询结构**:
|
||||
- 仔细查看查询中是否有多个问号(?)分隔的不同问题
|
||||
- 查找连接词:"和"、"以及"、"还有"、"另外"
|
||||
- 查找标点符号:句号(。)、分号(;)等分隔符
|
||||
|
||||
2. **按主题分解**:
|
||||
- 如果查询包含多个独立主题,将其分解为独立的问题
|
||||
- 每个子查询保持完整性,包含必要的上下文信息
|
||||
|
||||
【示例】:
|
||||
|
||||
输入:"混沌工程是什么?混沌工程的基础建设是由什么整合的?"
|
||||
分析:这是两个关于混沌工程的不同问题
|
||||
输出:["混沌工程是什么", "混沌工程的基础建设是由什么整合的"]
|
||||
|
||||
输入:"2023年某年度报告中提到存在数据安全问题的支付平台可能涉及哪些外部技术组件?混沌工程的定义是什么?"
|
||||
分析:这包含两个完全不同的主题
|
||||
输出:["2023年某年度报告中提到存在数据安全问题的支付平台可能涉及哪些外部技术组件", "混沌工程的定义是什么"]
|
||||
|
||||
输入:"什么是人工智能?它的主要应用有哪些?"
|
||||
分析:两个相关但独立的问题
|
||||
输出:["什么是人工智能", "人工智能的主要应用有哪些"]
|
||||
|
||||
输入:"区块链技术的优势和挑战是什么?"
|
||||
分析:单一主题但包含两个方面
|
||||
输出:["区块链技术的优势是什么", "区块链技术面临哪些挑战"]
|
||||
|
||||
【要求】:
|
||||
1. [NOTE] 必须使用自然语言,类似人类会问的问题
|
||||
2. [ERROR] 禁止使用SQL语句、代码或技术查询语法
|
||||
3. [SEARCH] 严格按照查询的自然分割点进行分解
|
||||
4. [?] 每个子查询必须是完整的自然语言问题
|
||||
5. [BLOCKED] 绝不要添加"详细信息"、"相关内容"、"补充信息"等后缀
|
||||
6. [INFO] 保持原始查询中的所有关键信息(时间、地点、对象等)
|
||||
7. [TARGET] 确保两个子查询可以独立进行向量检索
|
||||
|
||||
请严格按照JSON格式返回自然语言查询:
|
||||
{{
|
||||
"sub_queries": ["自然语言子查询1", "自然语言子查询2"]
|
||||
}}
|
||||
"""
|
||||
)
|
||||
'''
|
||||
|
||||
# Dynamically loaded sub-query generation prompt (from the external prompt config).
SUB_QUERY_GENERATION_PROMPT = _prompt_loader.get_prompt_template('sub_query_generation')
|
||||
|
||||
# 保留原始定义作为备份(已注释)
|
||||
'''
|
||||
SUB_QUERY_GENERATION_PROMPT = PromptTemplate(
|
||||
input_variables=["original_query", "existing_passages", "previous_sub_queries", "insufficiency_reason"],
|
||||
template="""
|
||||
基于用户的原始查询、已有的检索结果和充分性检查反馈,生成2个相关的子查询来获取缺失信息。
|
||||
|
||||
原始查询:{original_query}
|
||||
|
||||
之前生成的子查询:{previous_sub_queries}
|
||||
|
||||
已有检索结果(包含事件信息和段落信息):
|
||||
{existing_passages}
|
||||
|
||||
充分性检查反馈(信息不充分的原因):
|
||||
{insufficiency_reason}
|
||||
|
||||
请根据充分性检查反馈中指出的缺失信息,生成2个具体的自然语言子查询来补充这些缺失内容。注意已有检索结果包含:
|
||||
1. 【事件信息】- 来自知识图谱的事件节点
|
||||
2. 【段落信息】- 来自文档的段落内容
|
||||
|
||||
【重要要求】:
|
||||
1. [NOTE] 必须使用自然语言表达,类似人类会问的问题
|
||||
2. [ERROR] 禁止使用SQL语句、代码或技术查询语法
|
||||
3. [OK] 使用疑问句形式,如"什么是..."、"如何..."、"有哪些..."等
|
||||
4. [TARGET] 直接针对充分性检查中指出的缺失信息
|
||||
5. [LINK] 与原始查询高度相关
|
||||
6. [INFO] 查询要具体明确,能够获取到具体的信息
|
||||
7. [BLOCKED] 避免与之前已生成的子查询重复
|
||||
8. [TARGET] 确保每个子查询都能独立检索到有价值的信息
|
||||
|
||||
【示例格式】:
|
||||
- [OK] 正确:"混沌工程的基础建设具体包括哪些组件?"
|
||||
- [ERROR] 错误:"SELECT * FROM chaos_engineering WHERE type='infrastructure'"
|
||||
- [OK] 正确:"供应链风险管理中常见的网络安全威胁有哪些?"
|
||||
- [ERROR] 错误:"SELECT risks FROM supply_chain WHERE category='network'"
|
||||
|
||||
请以JSON格式返回自然语言查询:
|
||||
{{
|
||||
"sub_queries": ["自然语言子查询1", "自然语言子查询2"]
|
||||
}}
|
||||
"""
|
||||
)
|
||||
'''
|
||||
|
||||
# Dynamically loaded simple-answer prompt (from the external prompt config).
SIMPLE_ANSWER_PROMPT = _prompt_loader.get_prompt_template('simple_answer')
|
||||
|
||||
# 保留原始定义作为备份(已注释)
|
||||
'''
|
||||
SIMPLE_ANSWER_PROMPT = PromptTemplate(
|
||||
input_variables=["query", "passages"],
|
||||
template="""
|
||||
基于检索到的信息,请为用户的查询提供一个准确、简洁的答案。
|
||||
|
||||
用户查询:{query}
|
||||
|
||||
检索到的相关信息:
|
||||
{passages}
|
||||
|
||||
请基于这些信息回答用户的查询。注意检索结果包含:
|
||||
1. 【事件信息】- 来自知识图谱的事件节点,提供事件相关的上下文
|
||||
2. 【段落信息】- 来自文档的段落内容,提供详细的文本描述
|
||||
|
||||
要求:
|
||||
1. 直接回答用户的查询
|
||||
2. 严格基于提供的信息,不要编造内容
|
||||
3. 综合利用事件信息和段落信息
|
||||
4. 如果信息不足以完整回答,请明确说明
|
||||
5. 答案要简洁明了,重点突出
|
||||
|
||||
答案:
|
||||
"""
|
||||
)
|
||||
'''
|
||||
|
||||
# Dynamically loaded final-answer prompt (from the external prompt config).
FINAL_ANSWER_PROMPT = _prompt_loader.get_prompt_template('final_answer')
|
||||
|
||||
# 保留原始定义作为备份(已注释)
|
||||
'''
|
||||
FINAL_ANSWER_PROMPT = PromptTemplate(
|
||||
input_variables=["original_query", "all_passages", "sub_queries"],
|
||||
template="""
|
||||
基于所有检索到的信息,请为用户的查询提供一个完整、准确的答案。
|
||||
|
||||
用户查询:{original_query}
|
||||
|
||||
子查询历史:{sub_queries}
|
||||
|
||||
所有检索到的信息:
|
||||
{all_passages}
|
||||
|
||||
请基于这些信息提供一个全面的答案。注意检索结果包含:
|
||||
1. 【事件信息】- 来自知识图谱的事件节点,提供事件相关的上下文
|
||||
2. 【段落信息】- 来自文档的段落内容,提供详细的文本描述
|
||||
|
||||
要求:
|
||||
1. 直接回答用户的查询
|
||||
2. 综合利用事件信息和段落信息,不要编造内容
|
||||
3. 如果信息仍然不足,请明确说明
|
||||
4. 答案要结构清晰,逻辑连贯
|
||||
|
||||
答案:
|
||||
"""
|
||||
)
|
||||
'''
|
||||
|
||||
|
||||
def create_oneapi_llm(
    oneapi_key: Optional[str] = None,
    oneapi_base_url: Optional[str] = None,
    model_name: Optional[str] = None
) -> OneAPILLM:
    """Factory for a DashScope-backed OneAPILLM (keeps the legacy signature).

    Args:
        oneapi_key: DashScope API key; falls back to the ONEAPI_KEY env var.
        oneapi_base_url: Unused; retained for backward compatibility.
        model_name: Model name; falls back to ONEAPI_MODEL_MAX, then "qwen-max".

    Returns:
        An OneAPILLM wrapping a freshly constructed DashScopeLLM.
    """
    # Resolve configuration from arguments first, then the environment.
    resolved_key = oneapi_key or os.getenv("ONEAPI_KEY")
    resolved_model = model_name or os.getenv("ONEAPI_MODEL_MAX", "qwen-max")

    backend = DashScopeLLM(
        api_key=resolved_key,
        model_name=resolved_model
    )
    return OneAPILLM(backend)
|
||||
|
||||
|
||||
def format_passages(passages: List[str]) -> str:
    """Join a list of passages into one numbered, blank-line-separated string."""
    if not passages:
        return "无相关段落"
    return "\n\n".join(
        f"段落{idx}:{text}" for idx, text in enumerate(passages, start=1)
    )
|
||||
|
||||
def format_mixed_passages(documents: List[Document]) -> str:
    """Format a mixed list of event and text documents into one string.

    Events come first under an "【事件信息】" header, then text passages
    under "【段落信息】", separated by a blank entry when both exist.
    """
    if not documents:
        return "无相关信息"

    def _of_type(kind: str):
        # Partition helper keyed on the 'node_type' metadata field.
        return [d for d in documents if d.metadata.get('node_type') == kind]

    parts = []

    events = _of_type('event')
    if events:
        parts.append("【事件信息】")
        parts.extend(f"事件{n}:{d.page_content}" for n, d in enumerate(events, 1))

    texts = _of_type('text')
    if texts:
        if parts:
            parts.append("")  # blank separator between the two sections
        parts.append("【段落信息】")
        parts.extend(f"段落{n}:{d.page_content}" for n, d in enumerate(texts, 1))

    return "\n\n".join(parts)
|
||||
|
||||
|
||||
def format_sub_queries(sub_queries: List[str]) -> str:
    """Render the sub-queries as a single "、"-separated string."""
    return "、".join(sub_queries) if sub_queries else "无子查询"
|
||||
|
||||
|
||||
def format_triplets(triplets: List[Dict[str, str]]) -> str:
    """Render knowledge-graph triplets, one numbered line per triplet."""
    if not triplets:
        return "无三元组信息"
    lines = []
    for idx, triplet in enumerate(triplets, start=1):
        # Missing keys fall back to 'unknown' rather than raising.
        s, r, o = (triplet.get(k, 'unknown') for k in ('source', 'relation', 'target'))
        lines.append(f"{idx}. ({s}) --[{r}]--> ({o})")
    return "\n".join(lines)
|
||||
|
||||
|
||||
def format_exploration_path(path: List[Dict[str, str]]) -> str:
    """Render an exploration path as "(a) --[r]--> (b) -> (b) --[s]--> (c)".

    Each step's optional 'reason' field is intentionally omitted from the
    rendered chain. (The previous implementation built "探索原因" reason
    lines and then filtered them back out, and its ``split(" 探索原因:")``
    could truncate a step whose source/relation/target happened to contain
    that marker; both issues are fixed here.)

    Args:
        path: Steps with 'source', 'relation', 'target' keys (missing keys
            render as 'unknown'); the 'reason' key is ignored.

    Returns:
        The " -> "-joined path, or "无探索路径" for an empty path.
    """
    if not path:
        return "无探索路径"

    steps = []
    for step in path:
        source = step.get('source', 'unknown')
        relation = step.get('relation', 'unknown')
        target = step.get('target', 'unknown')
        steps.append(f"({source}) --[{relation}]--> ({target})")

    return " -> ".join(steps)
|
||||
|
||||
|
||||
# 概念探索相关提示词模板
|
||||
|
||||
# First step of concept exploration: given a starting Node and its connected
# Concept nodes, ask the LLM to pick the single most promising concept to
# explore, returning a JSON object with the selection and its rationale.
CONCEPT_EXPLORATION_INIT_PROMPT = PromptTemplate(
    input_variables=["node_name", "connected_concepts", "user_query", "insufficiency_reason"],
    template="""
作为一个知识图谱概念探索专家,你需要分析一个Node节点及其连接的Concept节点,判断出最值得探索的概念方向。

**探索起点Node节点**: {node_name}

**该Node连接的Concept节点列表**:
{connected_concepts}

**用户原始查询**: {user_query}

**上次充分性检查不通过的原因**: {insufficiency_reason}

**任务**: 根据用户查询和不充分的原因,从连接的Concept列表中选择一个最值得深入探索的概念,这个概念应该最有可能帮助回答用户查询或补充缺失的信息。

**判断标准**:
1. 与用户查询的相关性最高
2. 最有可能补充当前缺失的关键信息
3. 具有较强的延展性,能够引出更多相关知识

请以JSON格式返回结果:
{{
    "selected_concept": "最值得探索的概念名称",
    "exploration_reason": "选择这个概念的详细原因",
    "confidence": 0.9, // 置信度,0-1之间
    "expected_knowledge": "期望从这个概念探索中获得什么知识"
}}

只返回JSON格式的结果,不要添加其他内容。
"""
)
|
||||
|
||||
# Continuation step of concept exploration: given the current node, its
# neighbor triplets and the path walked so far, ask the LLM to choose the
# next node to visit (avoiding revisits), returning a JSON object.
CONCEPT_EXPLORATION_CONTINUE_PROMPT = PromptTemplate(
    input_variables=["current_node", "neighbor_triplets", "exploration_path", "user_query"],
    template="""
作为一个知识图谱概念探索专家,你正在进行概念探索,需要决定下一步的探索方向。

**当前节点**: {current_node}

**当前节点的邻居三元组**:
{neighbor_triplets}

**已有探索路径**:
{exploration_path}

**用户原始查询**: {user_query}

**任务**: 从当前节点的邻居中选择一个最值得继续探索的节点,确保:
1. 不重复之前探索过的节点
2. 选择的节点与用户查询最相关
3. 能够获得更深入的知识

请以JSON格式返回结果:
{{
    "selected_node": "选择的下一个节点名称",
    "selected_relation": "到达该节点的关系",
    "exploration_reason": "选择这个节点的详细原因",
    "confidence": 0.9, // 置信度,0-1之间
    "expected_knowledge": "期望从这个节点探索中获得什么知识"
}}

只返回JSON格式的结果,不要添加其他内容。
"""
)
|
||||
|
||||
# Sufficiency check after concept exploration: ask the LLM whether the
# retrieved passages plus the knowledge gathered during exploration are
# enough to answer the user query, returning a JSON verdict with the
# missing aspects when insufficient.
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT = PromptTemplate(
    input_variables=["user_query", "all_passages", "exploration_knowledge", "insufficiency_reason"],
    template="""
作为一个智能信息充分性评估专家,请综合评估当前信息是否足够回答用户查询。

**用户查询**: {user_query}

**现有检索信息**:
{all_passages}

**概念探索获得的知识**:
{exploration_knowledge}

**上次不充分的原因**: {insufficiency_reason}

**评估任务**:
1. 综合分析检索信息(事件信息+段落信息)和探索知识是否能够完整回答用户查询
2. 判断是否还有关键信息缺失
3. 如果仍然不充分,明确指出还需要什么信息

**评估标准**:
- 信息的完整性:能否覆盖查询的所有方面
- 信息的准确性:提供的信息是否准确可靠
- 信息的相关性:信息与查询的匹配度
- 逻辑连贯性:信息间是否形成完整的逻辑链

请以JSON格式返回评估结果:
{{
    "is_sufficient": false, // true表示信息充分,false表示不充分
    "confidence": 0.85, // 判断的置信度,0-1之间
    "reason": "详细说明充分或不充分的原因",
    "missing_aspects": ["缺失的关键信息1", "缺失的关键信息2"], // 如果不充分,列出缺失的关键方面
    "coverage_score": 0.7 // 当前信息对查询的覆盖程度,0-1之间
}}

只返回JSON格式的结果,不要添加其他内容。
"""
)
|
||||
|
||||
# Sub-query generation for a second exploration round: given what is still
# missing after the first round, ask the LLM to produce one focused
# follow-up query, returning a JSON object.
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT = PromptTemplate(
    input_variables=["user_query", "missing_aspects", "exploration_results"],
    template="""
作为一个智能查询分析专家,根据概念探索的结果和仍然缺失的信息,生成新的子查询来进行第二轮概念探索。

**原始用户查询**: {user_query}

**仍然缺失的关键信息**: {missing_aspects}

**第一轮探索结果**: {exploration_results}

**任务**: 生成一个更精确的子查询,这个查询应该:
1. 专门针对缺失的关键信息
2. 比原始查询更具体和聚焦
3. 能够引导第二轮探索找到缺失的关键信息

**生成原则**:
- 保持与原始查询的相关性
- 专注于最关键的缺失方面
- 使用明确、具体的表述
- 避免过于宽泛或模糊

请以JSON格式返回结果:
{{
    "sub_query": "针对缺失信息生成的新查询",
    "focus_aspects": ["查询重点关注的方面1", "方面2"],
    "expected_improvements": "期望这个子查询如何改善信息充分性",
    "confidence": 0.9
}}

只返回JSON格式的结果,不要添加其他内容。
"""
)
|
||||
|
||||
|
||||
# 导出列表
|
||||
# Public API of this module, grouped by kind: the LLM wrapper, output
# parsers, their result models, prompt templates, the factory, and the
# formatting helpers.
__all__ = [
    # LLM wrapper
    'OneAPILLM',
    # Output parsers
    'SufficiencyCheckParser',
    'QueryComplexityParser',
    'ConceptExplorationParser',
    'ConceptContinueParser',
    'ConceptSufficiencyParser',
    'ConceptSubQueryParser',
    # Parser result models
    'QueryComplexityResult',
    'SufficiencyCheckResult',
    'ConceptExplorationResult',
    'ConceptContinueResult',
    'ConceptSufficiencyResult',
    'ConceptSubQueryResult',
    # Prompt templates
    'QUERY_COMPLEXITY_CHECK_PROMPT',
    'SUFFICIENCY_CHECK_PROMPT',
    'QUERY_DECOMPOSITION_PROMPT',
    'SUB_QUERY_GENERATION_PROMPT',
    'SIMPLE_ANSWER_PROMPT',
    'FINAL_ANSWER_PROMPT',
    'CONCEPT_EXPLORATION_INIT_PROMPT',
    'CONCEPT_EXPLORATION_CONTINUE_PROMPT',
    'CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT',
    'CONCEPT_EXPLORATION_SUB_QUERY_PROMPT',
    # Factory
    'create_oneapi_llm',
    # Formatting helpers
    'format_passages',
    'format_mixed_passages',
    'format_sub_queries',
    'format_triplets',
    'format_exploration_path'
]
|
||||
1082
AIEC-RAG/retriver/langgraph/langchain_hipporag_retriever.py
Normal file
1082
AIEC-RAG/retriver/langgraph/langchain_hipporag_retriever.py
Normal file
File diff suppressed because it is too large
Load Diff
107
AIEC-RAG/retriver/langgraph/routing_functions.py
Normal file
107
AIEC-RAG/retriver/langgraph/routing_functions.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""
|
||||
LangGraph条件边路由函数
|
||||
专门定义工作流的路由决策逻辑,与节点逻辑分离
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from retriver.langgraph.graph_state import QueryState
|
||||
|
||||
|
||||
def route_by_debug_mode(state: "QueryState") -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Route to the next workflow node based on query complexity.

    NOTE(review): despite the name, the routing decision uses only
    ``state['is_complex_query']``; ``debug_mode`` is read purely for the log
    line and never affects the chosen route. This makes the function
    behaviorally identical to ``route_by_complexity``.

    Args:
        state: Current query state; must contain 'is_complex_query' and
            'debug_mode'.

    Returns:
        "initial_retrieval" for complex queries (existing hipporag2 logic),
        "simple_vector_retrieval" for simple queries (direct vector search).
    """
    if state['is_complex_query']:
        print(f"[RELOAD] 路由到复杂检索逻辑 (debug_mode: {state['debug_mode']})")
        return "initial_retrieval"
    print(f"[RELOAD] 路由到简单向量检索 (debug_mode: {state['debug_mode']})")
    return "simple_vector_retrieval"
|
||||
|
||||
|
||||
def route_by_complexity(state: "QueryState") -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Route to the next workflow node based on query complexity.

    Args:
        state: Current query state; must contain 'is_complex_query'.

    Returns:
        "initial_retrieval" for complex queries (existing hipporag2 logic),
        "simple_vector_retrieval" for simple queries (direct vector search).
    """
    # Fix: the log lines were f-strings with no placeholders (lint F541);
    # plain literals emit byte-identical output.
    if state['is_complex_query']:
        print("[RELOAD] 复杂查询,进入HippoRAG2检索逻辑")
        return "initial_retrieval"
    print("[RELOAD] 简单查询,进入直接向量检索")
    return "simple_vector_retrieval"
|
||||
|
||||
|
||||
def should_continue_retrieval(state: "QueryState") -> Literal["final_answer", "parallel_retrieval", "next_iteration"]:
    """Decide the next step of the iterative retrieval workflow.

    Decision order:
      1. Information sufficient                     -> "final_answer"
      2. Iteration budget exhausted                 -> "final_answer"
      3. >=3 consecutive insufficient rounds with
         no new sub-queries (loop guard)            -> "final_answer"
      4. Pending sub-queries                        -> "parallel_retrieval"
      5. Insufficient with budget left              -> "parallel_retrieval"
         (routes through sub-query generation)
      6. Fallback                                   -> "next_iteration"

    Args:
        state: Current query state; must contain 'is_sufficient',
            'current_iteration', 'max_iterations' and 'current_sub_queries';
            'iteration_history' is optional.

    Returns:
        The name of the next node to execute.
    """
    # 1. Sufficient information -> produce the final answer.
    if state['is_sufficient']:
        print("[OK] 信息充分,生成最终答案")
        return "final_answer"

    # 2. Budget exhausted while still insufficient -> answer with what we have.
    if state['current_iteration'] >= state['max_iterations']:
        print(f"[RELOAD] 达到最大迭代次数 ({state['max_iterations']}) 且信息不充分,直接生成最终答案")
        return "final_answer"

    # 3. Loop guard: count the trailing run of insufficient iterations.
    consecutive_insufficient = 0
    for record in reversed(state.get('iteration_history', [])):
        if record.get('is_sufficient', False):
            break
        consecutive_insufficient += 1

    if consecutive_insufficient >= 3 and not state['current_sub_queries']:
        print(f"[WARNING] 连续{consecutive_insufficient}次检索不充分且无新子查询,避免死循环,生成最终答案")
        return "final_answer"

    # 4. Pending sub-queries -> run them in parallel.
    if state['current_sub_queries']:
        print(f"[RELOAD] 执行并行检索子查询 (剩余子查询: {len(state['current_sub_queries'])})")
        return "parallel_retrieval"

    # 5. Insufficient with budget left -> generate sub-queries and retrieve.
    # NOTE(review): this condition is always true here (steps 1 and 2 already
    # returned otherwise), so the "next_iteration" fallback below is
    # unreachable; it is kept only to honor the declared return type.
    if not state['is_sufficient'] and state['current_iteration'] < state['max_iterations']:
        print("[RELOAD] 信息不充分,生成子查询进行并行检索")
        return "parallel_retrieval"  # routes to sub_query_generation → parallel_retrieval

    # 6. Unreachable fallback (see note above).
    print("[RELOAD] 开始下一轮迭代")
    return "next_iteration"
|
||||
|
||||
Reference in New Issue
Block a user