first commit

This commit is contained in:
闫旭隆
2025-09-25 10:33:37 +08:00
commit 34839c2654
387 changed files with 149159 additions and 0 deletions

View File

@ -0,0 +1,14 @@
"""
智能检索系统
基于LangChain + LangGraph实现的迭代检索框架
"""
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever, create_langchain_hipporag_retriever
from retriver.langgraph.iterative_retriever import IterativeRetriever, create_iterative_retriever
__all__ = [
"LangChainHippoRAGRetriever",
"IterativeRetriever",
"create_langchain_hipporag_retriever",
"create_iterative_retriever"
]

View File

@ -0,0 +1,699 @@
"""
概念探索节点实现
在PageRank收集后进行知识图谱概念探索
"""
import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from retriver.langgraph.graph_state import QueryState
from retriver.langgraph.langchain_components import (
OneAPILLM,
ConceptExplorationParser,
ConceptContinueParser,
ConceptSufficiencyParser,
ConceptSubQueryParser,
CONCEPT_EXPLORATION_INIT_PROMPT,
CONCEPT_EXPLORATION_CONTINUE_PROMPT,
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT,
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT,
format_passages,
format_mixed_passages,
format_triplets,
format_exploration_path
)
class ConceptExplorationNodes:
    """Concept-exploration node implementations.

    After the PageRank collection step, these LangGraph nodes walk the
    knowledge graph: they start from the top-ranked ``Node``-typed vertices,
    hop through adjacent ``Concept`` vertices under LLM guidance, and fold
    the gathered knowledge back into the query state.
    """
    def __init__(
        self,
        retriever,  # LangChainHippoRAGRetriever; assumed to expose .graph_data (networkx-style) — TODO confirm
        llm: OneAPILLM,
        max_exploration_steps: int = 3,
        max_parallel_explorations: int = 3
    ):
        """
        Initialize the concept-exploration node handler.
        Args:
            retriever: HippoRAG retriever used to access the knowledge graph
            llm: OneAPI LLM that drives every exploration decision
            max_exploration_steps: maximum hops per exploration branch
            max_parallel_explorations: maximum number of branches explored in parallel
        """
        self.retriever = retriever
        self.llm = llm
        self.max_exploration_steps = max_exploration_steps
        self.max_parallel_explorations = max_parallel_explorations
        # One output parser per LLM prompt used below.
        self.concept_exploration_parser = ConceptExplorationParser()
        self.concept_continue_parser = ConceptContinueParser()
        self.concept_sufficiency_parser = ConceptSufficiencyParser()
        self.concept_sub_query_parser = ConceptSubQueryParser()
    def _extract_response_text(self, response):
        """Extract plain text from the various response shapes the LLM may return."""
        if hasattr(response, 'generations') and response.generations:
            # LLMResult-style object: take the first generation's text.
            return response.generations[0][0].text
        elif hasattr(response, 'content'):
            # Chat-message-style object.
            return response.content
        elif isinstance(response, dict) and 'response' in response:
            return response['response']
        else:
            # Last resort: stringify whatever we got.
            return str(response)
    def _get_top_node_candidates(self, state: QueryState) -> List[Dict[str, Any]]:
        """
        Extract the TOP-3 vertices labeled as ``Node`` from the PageRank results.
        Args:
            state: current query state
        Returns:
            Up to three dicts (node_id, score, label, type), highest score first
        """
        print("[SEARCH] 从PageRank结果中提取TOP-3 Node节点")
        pagerank_summary = state.get('pagerank_summary', {})
        if not pagerank_summary:
            print("[ERROR] 没有找到PageRank汇总信息")
            return []
        # All vertices that carry label/type information.
        all_nodes_with_labels = pagerank_summary.get('all_nodes_with_labels', [])
        if not all_nodes_with_labels:
            print("[ERROR] 没有找到带标签的节点信息")
            return []
        # Keep only vertices whose type marks them as "Node".
        node_candidates = []
        for node_info in all_nodes_with_labels:
            node_type = node_info.get('type', '').lower()
            label = node_info.get('label', '')
            # Substring match: any type containing "node" qualifies.
            if 'node' in node_type.lower() or node_type == 'node':
                node_candidates.append(node_info)
        # Sort by score descending and keep the top three.
        node_candidates.sort(key=lambda x: x.get('score', 0), reverse=True)
        top_3_nodes = node_candidates[:3]
        print(f"[OK] 找到 {len(node_candidates)} 个Node节点选择TOP-3:")
        for i, node in enumerate(top_3_nodes):
            node_id = node.get('node_id', 'unknown')
            score = node.get('score', 0)
            label = node.get('label', 'unknown')
            print(f" {i+1}. {node_id} ('{label}') - Score: {score:.6f}")
        return top_3_nodes
    def _get_connected_concepts(self, node_id: str) -> List[str]:
        """
        Collect the labels of all Concept vertices adjacent to the given Node.
        Args:
            node_id: graph vertex id
        Returns:
            List of connected Concept labels (empty on any failure)
        """
        try:
            # Graph access goes through the retriever's cached graph data.
            if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居")
                return []
            connected_concepts = []
            graph = self.retriever.graph_data
            # Bail out if the vertex does not exist.
            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []
            # Walk all neighbours and keep Concept-typed ones.
            neighbors = list(graph.neighbors(node_id))
            for neighbor_id in neighbors:
                neighbor_data = graph.nodes.get(neighbor_id, {})
                neighbor_type = neighbor_data.get('type', '').lower()
                neighbor_label = neighbor_data.get('label', neighbor_data.get('name', neighbor_id))
                # Substring match: any type containing "concept" qualifies.
                if 'concept' in neighbor_type.lower() or neighbor_type == 'concept':
                    connected_concepts.append(neighbor_label)
            print(f"[OK] 节点 {node_id} 连接的Concept数量: {len(connected_concepts)}")
            return connected_concepts
        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 连接的概念失败: {e}")
            return []
    def _get_neighbor_triplets(self, node_id: str, excluded_nodes: Optional[List[str]] = None) -> List[Dict[str, str]]:
        """
        Get all neighbour triplets of a vertex, skipping already-visited vertices.
        Args:
            node_id: current vertex id
            excluded_nodes: vertex ids to exclude (typically the visited set)
        Returns:
            Triplet dicts with source/relation/target labels plus the raw ids
        """
        try:
            if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
                print(f"[ERROR] 无法访问图数据获取节点 {node_id} 的邻居三元组")
                return []
            if excluded_nodes is None:
                excluded_nodes = []
            triplets = []
            graph = self.retriever.graph_data
            # Bail out if the vertex does not exist.
            if node_id not in graph.nodes:
                print(f"[ERROR] 节点 {node_id} 不存在于图中")
                return []
            # One triplet per (node, neighbour) edge.
            for neighbor_id in graph.neighbors(node_id):
                # Skip vertices already visited on this branch.
                if neighbor_id in excluded_nodes:
                    continue
                # Edge relation, falling back to its label, then a generic tag.
                edge_data = graph.get_edge_data(node_id, neighbor_id, {})
                relation = edge_data.get('relation', edge_data.get('label', 'related_to'))
                # Human-readable labels for both endpoints.
                source_label = graph.nodes.get(node_id, {}).get('label',
                    graph.nodes.get(node_id, {}).get('name', node_id))
                target_label = graph.nodes.get(neighbor_id, {}).get('label',
                    graph.nodes.get(neighbor_id, {}).get('name', neighbor_id))
                triplets.append({
                    'source': source_label,
                    'relation': relation,
                    'target': target_label,
                    'source_id': node_id,
                    'target_id': neighbor_id
                })
            print(f"[OK] 获取节点 {node_id} 的邻居三元组数量: {len(triplets)}")
            return triplets
        except Exception as e:
            print(f"[ERROR] 获取节点 {node_id} 的邻居三元组失败: {e}")
            return []
    def _explore_single_branch(self, node_info: Dict[str, Any], user_query: str, insufficiency_reason: str) -> Dict[str, Any]:
        """
        Run the full exploration path for one branch.
        Args:
            node_info: starting Node vertex info
            user_query: user query driving the exploration
            insufficiency_reason: why current evidence was judged insufficient
        Returns:
            Result dict: start_node, exploration_path, visited_nodes,
            knowledge_gained, success flag and optional error
        """
        node_id = node_info.get('node_id', '')
        node_label = node_info.get('label', '')
        print(f"[STARTING] 开始探索分支: {node_label} ({node_id})")
        exploration_result = {
            'start_node': node_info,
            'exploration_path': [],
            'visited_nodes': [node_id],
            'knowledge_gained': [],
            'success': False,
            'error': None
        }
        try:
            # Step 1: collect connected Concept vertices and pick the best direction.
            connected_concepts = self._get_connected_concepts(node_id)
            if not connected_concepts:
                print(f"[WARNING] 节点 {node_label} 没有连接的Concept节点")
                exploration_result['error'] = "没有连接的Concept节点"
                return exploration_result
            # Ask the LLM to choose the most promising concept.
            concept_selection = self._select_best_concept(node_label, connected_concepts, user_query, insufficiency_reason)
            if not concept_selection or not concept_selection.selected_concept:
                print(f"[WARNING] 未能为节点 {node_label} 选择到有效概念")
                exploration_result['error'] = "未能选择有效概念"
                return exploration_result
            selected_concept = concept_selection.selected_concept
            print(f"[POINT] 选择探索概念: {selected_concept}")
            # Record the initial selection decision as step 0.
            exploration_result['knowledge_gained'].append({
                'step': 0,
                'action': 'concept_selection',
                'selected_concept': selected_concept,
                'reason': concept_selection.exploration_reason,
                'expected_knowledge': concept_selection.expected_knowledge
            })
            # Resolve the chosen concept label back to a vertex id.
            concept_node_id = self._find_concept_node_id(selected_concept)
            if not concept_node_id:
                print(f"[WARNING] 未找到概念 {selected_concept} 对应的节点ID")
                exploration_result['error'] = f"未找到概念节点ID: {selected_concept}"
                return exploration_result
            # Multi-hop exploration starting from the concept vertex.
            current_node_id = concept_node_id
            exploration_result['visited_nodes'].append(current_node_id)
            for step in range(1, self.max_exploration_steps + 1):
                print(f"[SEARCH] 探索步骤 {step}/{self.max_exploration_steps}")
                # Neighbour triplets of the current vertex, excluding visited ones.
                neighbor_triplets = self._get_neighbor_triplets(current_node_id, exploration_result['visited_nodes'])
                if not neighbor_triplets:
                    print(f"[WARNING] 节点 {current_node_id} 没有未访问的邻居,探索结束")
                    break
                # Ask the LLM for the next hop.
                exploration_path_str = format_exploration_path(exploration_result['exploration_path'])
                continue_result = self._continue_exploration(current_node_id, neighbor_triplets, exploration_path_str, user_query)
                if not continue_result or not continue_result.selected_node:
                    print(f"[WARNING] 步骤 {step} 未能选择下一个探索节点")
                    break
                # Resolve the chosen label to a vertex id; reject revisits.
                next_node_id = self._find_node_id_by_label(continue_result.selected_node)
                if not next_node_id or next_node_id in exploration_result['visited_nodes']:
                    print(f"[WARNING] 选择的节点无效或已访问: {continue_result.selected_node}")
                    break
                # Record this hop on the exploration path.
                step_info = {
                    'source': current_node_id,
                    'relation': continue_result.selected_relation,
                    'target': next_node_id,
                    'reason': continue_result.exploration_reason
                }
                exploration_result['exploration_path'].append(step_info)
                # Record the knowledge gained at this step.
                exploration_result['knowledge_gained'].append({
                    'step': step,
                    'action': 'exploration_step',
                    'selected_node': continue_result.selected_node,
                    'relation': continue_result.selected_relation,
                    'reason': continue_result.exploration_reason,
                    'expected_knowledge': continue_result.expected_knowledge
                })
                # Advance to the chosen vertex.
                current_node_id = next_node_id
                exploration_result['visited_nodes'].append(current_node_id)
                print(f"[OK] 步骤 {step} 完成: {continue_result.selected_node}")
            exploration_result['success'] = True
            print(f"[SUCCESS] 分支探索完成: {len(exploration_result['exploration_path'])}")
        except Exception as e:
            print(f"[ERROR] 分支探索失败: {e}")
            exploration_result['error'] = str(e)
        return exploration_result
    def _select_best_concept(self, node_name: str, connected_concepts: List[str], user_query: str, insufficiency_reason: str):
        """Ask the LLM to pick the best concept to explore; returns parsed result or None on failure."""
        concepts_str = "\n".join([f"- {concept}" for concept in connected_concepts])
        prompt = CONCEPT_EXPLORATION_INIT_PROMPT.format(
            node_name=node_name,
            connected_concepts=concepts_str,
            user_query=user_query,
            insufficiency_reason=insufficiency_reason
        )
        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_exploration_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 概念选择失败: {e}")
            return None
    def _continue_exploration(self, current_node: str, neighbor_triplets: List[Dict[str, str]], exploration_path: str, user_query: str):
        """Ask the LLM for the next hop; returns parsed result or None on failure."""
        triplets_str = format_triplets(neighbor_triplets)
        prompt = CONCEPT_EXPLORATION_CONTINUE_PROMPT.format(
            current_node=current_node,
            neighbor_triplets=triplets_str,
            exploration_path=exploration_path,
            user_query=user_query
        )
        try:
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            return self.concept_continue_parser.parse(response_text)
        except Exception as e:
            print(f"[ERROR] 探索继续失败: {e}")
            return None
    def _find_concept_node_id(self, concept_label: str) -> Optional[str]:
        """Resolve a concept label (or raw id) to a Concept-typed vertex id; None if not found."""
        try:
            if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
                return None
            graph = self.retriever.graph_data
            # Linear scan: match on label/name or on the id itself, Concept-typed only.
            for node_id, node_data in graph.nodes(data=True):
                node_label = node_data.get('label', node_data.get('name', ''))
                node_type = node_data.get('type', '').lower()
                if (node_label == concept_label or node_id == concept_label) and 'concept' in node_type:
                    return node_id
            return None
        except Exception as e:
            print(f"[ERROR] 查找概念节点ID失败: {e}")
            return None
    def _find_node_id_by_label(self, node_label: str) -> Optional[str]:
        """Resolve any vertex label (or raw id) to its vertex id; None if not found."""
        try:
            if not hasattr(self.retriever, 'graph_data') or not self.retriever.graph_data:
                return None
            graph = self.retriever.graph_data
            # Linear scan over all vertices, any type.
            for node_id, node_data in graph.nodes(data=True):
                label = node_data.get('label', node_data.get('name', ''))
                if label == node_label or node_id == node_label:
                    return node_id
            return None
        except Exception as e:
            print(f"[ERROR] 查找节点ID失败: {e}")
            return None
    def concept_exploration_node(self, state: QueryState) -> QueryState:
        """
        Concept-exploration graph node.
        1. Pick the TOP-3 Node vertices from the PageRank results.
        2. Explore the branches in parallel.
        3. Merge the branch results into the state.
        Args:
            state: current query state
        Returns:
            Updated state
        """
        print(f"[?] 开始概念探索 (轮次 {state['exploration_round'] + 1})")
        # Bump the exploration round counter.
        state['exploration_round'] += 1
        # TOP-3 Node candidates from PageRank.
        top_nodes = self._get_top_node_candidates(state)
        if not top_nodes:
            print("[ERROR] 未找到合适的Node节点进行探索")
            state['concept_exploration_results'] = {'error': '未找到合适的Node节点'}
            return state
        # Why the previous evidence was judged insufficient.
        insufficiency_reason = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
        user_query = state['original_query']
        # On the second round, use the sub-query generated after round 1.
        if state['exploration_round'] == 2:
            sub_query_info = state['concept_exploration_results'].get('round_1_results', {}).get('sub_query_info', {})
            if sub_query_info and sub_query_info.get('sub_query'):
                user_query = sub_query_info['sub_query']
                print(f"[TARGET] 第二轮探索使用子查询: {user_query}")
        # Explore the branches concurrently.
        exploration_results = []
        with ThreadPoolExecutor(max_workers=min(len(top_nodes), self.max_parallel_explorations)) as executor:
            # One task per branch.
            futures = {
                executor.submit(self._explore_single_branch, node_info, user_query, insufficiency_reason): node_info
                for node_info in top_nodes
            }
            # Collect results as they finish.
            for future in as_completed(futures):
                node_info = futures[future]
                try:
                    result = future.result()
                    exploration_results.append(result)
                    print(f"[OK] 分支 {node_info.get('label', 'unknown')} 探索完成")
                except Exception as e:
                    print(f"[ERROR] 分支 {node_info.get('label', 'unknown')} 探索失败: {e}")
                    exploration_results.append({
                        'start_node': node_info,
                        'success': False,
                        'error': str(e)
                    })
        # Merge branch results under this round's key.
        round_key = f'round_{state["exploration_round"]}_results'
        state['concept_exploration_results'][round_key] = {
            'exploration_results': exploration_results,
            'total_branches': len(top_nodes),
            'successful_branches': len([r for r in exploration_results if r.get('success', False)]),
            'query_used': user_query
        }
        print(f"[SUCCESS] 第{state['exploration_round']}轮概念探索完成")
        print(f" 总分支: {len(top_nodes)}, 成功: {len([r for r in exploration_results if r.get('success', False)])}")
        return state
    def concept_exploration_sufficiency_check_node(self, state: QueryState) -> QueryState:
        """
        Sufficiency-check node for concept exploration.
        Jointly judges whether the document passages plus the explored
        knowledge are enough to answer the query.
        Args:
            state: current query state
        Returns:
            Updated state (is_sufficient flag plus per-round check details)
        """
        print(f"[THINK] 执行概念探索充分性检查")
        # Format retrieved evidence: mixed format when all_documents exists,
        # otherwise the legacy passage-only format.
        if 'all_documents' in state and state['all_documents']:
            formatted_passages = format_mixed_passages(state['all_documents'])
        else:
            formatted_passages = format_passages(state['all_passages'])
        # Knowledge gathered by exploration so far.
        exploration_knowledge = self._format_exploration_knowledge(state['concept_exploration_results'])
        # Original insufficiency reason.
        original_insufficiency = state.get('sufficiency_check', {}).get('reason', '信息不够充分')
        # Query in effect (may be the round-1 sub-query).
        current_query = state['original_query']
        if state['exploration_round'] == 2:
            sub_query_info = state['concept_exploration_results'].get('round_1_results', {}).get('sub_query_info', {})
            if sub_query_info and sub_query_info.get('sub_query'):
                current_query = sub_query_info['sub_query']
        # Build the prompt.
        prompt = CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT.format(
            user_query=current_query,
            all_passages=formatted_passages,
            exploration_knowledge=exploration_knowledge,
            insufficiency_reason=original_insufficiency
        )
        try:
            # Let the LLM judge sufficiency.
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            # Parse the structured verdict.
            sufficiency_result = self.concept_sufficiency_parser.parse(response_text)
            # Update the top-level flag.
            state['is_sufficient'] = sufficiency_result.is_sufficient
            # Store the detailed verdict under the current round.
            round_key = f'round_{state["exploration_round"]}_results'
            if round_key in state['concept_exploration_results']:
                state['concept_exploration_results'][round_key]['sufficiency_check'] = {
                    'is_sufficient': sufficiency_result.is_sufficient,
                    'confidence': sufficiency_result.confidence,
                    'reason': sufficiency_result.reason,
                    'missing_aspects': sufficiency_result.missing_aspects,
                    'coverage_score': sufficiency_result.coverage_score
                }
            # Debug bookkeeping.
            state["debug_info"]["llm_calls"] += 1
            print(f"[INFO] 概念探索充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'}")
            print(f" 置信度: {sufficiency_result.confidence:.2f}, 覆盖度: {sufficiency_result.coverage_score:.2f}")
            if not sufficiency_result.is_sufficient:
                print(f" 缺失方面: {sufficiency_result.missing_aspects}")
            return state
        except Exception as e:
            print(f"[ERROR] 概念探索充分性检查失败: {e}")
            # On failure, assume insufficient.
            state['is_sufficient'] = False
            return state
    def concept_sub_query_generation_node(self, state: QueryState) -> QueryState:
        """
        Sub-query generation node for concept exploration.
        If round 1 was still insufficient, generate a targeted sub-query
        for the second exploration round.
        Args:
            state: current query state
        Returns:
            Updated state (sub_query_info stored under round_1_results)
        """
        print(f"[TARGET] 生成概念探索子查询")
        # Round-1 results and the missing aspects from its sufficiency check.
        round_1_results = state['concept_exploration_results'].get('round_1_results', {})
        sufficiency_info = round_1_results.get('sufficiency_check', {})
        missing_aspects = sufficiency_info.get('missing_aspects', ['需要更多相关信息'])
        # Compact summary of round-1 exploration.
        exploration_summary = self._summarize_exploration_results(round_1_results)
        # Build the prompt.
        prompt = CONCEPT_EXPLORATION_SUB_QUERY_PROMPT.format(
            user_query=state['original_query'],
            missing_aspects=str(missing_aspects),
            exploration_results=exploration_summary
        )
        try:
            # Let the LLM generate the sub-query.
            response = self.llm.invoke(prompt)
            response_text = self._extract_response_text(response)
            # Parse the structured result.
            sub_query_result = self.concept_sub_query_parser.parse(response_text)
            # Store the sub-query for the next round.
            state['concept_exploration_results']['round_1_results']['sub_query_info'] = {
                'sub_query': sub_query_result.sub_query,
                'focus_aspects': sub_query_result.focus_aspects,
                'expected_improvements': sub_query_result.expected_improvements,
                'confidence': sub_query_result.confidence
            }
            # Debug bookkeeping.
            state["debug_info"]["llm_calls"] += 1
            print(f"[OK] 生成概念探索子查询: {sub_query_result.sub_query}")
            print(f" 关注方面: {sub_query_result.focus_aspects}")
            return state
        except Exception as e:
            print(f"[ERROR] 概念探索子查询生成失败: {e}")
            # On failure, fall back to a default sub-query with low confidence.
            default_sub_query = state['original_query'] + " 补充详细信息"
            state['concept_exploration_results']['round_1_results']['sub_query_info'] = {
                'sub_query': default_sub_query,
                'focus_aspects': ['补充信息'],
                'expected_improvements': '获取更多相关信息',
                'confidence': 0.3
            }
            return state
    def _format_exploration_knowledge(self, exploration_results: Dict[str, Any]) -> str:
        """Render all rounds of exploration knowledge as a prompt-ready string."""
        if not exploration_results:
            return "无探索知识"
        knowledge_parts = []
        for round_key, round_data in exploration_results.items():
            # Only per-round entries; skip bookkeeping keys like 'error'.
            if not round_key.startswith('round_'):
                continue
            round_num = round_key.split('_')[1]
            knowledge_parts.append(f"\n=== 第{round_num}轮探索结果 ===")
            exploration_list = round_data.get('exploration_results', [])
            for i, branch in enumerate(exploration_list):
                # Failed branches contribute nothing.
                if not branch.get('success', False):
                    continue
                start_node = branch.get('start_node', {})
                knowledge_parts.append(f"\n分支{i+1} (起点: {start_node.get('label', 'unknown')}):")
                knowledge_gained = branch.get('knowledge_gained', [])
                for knowledge in knowledge_gained:
                    action = knowledge.get('action', '')
                    if action == 'concept_selection':
                        concept = knowledge.get('selected_concept', '')
                        reason = knowledge.get('reason', '')
                        knowledge_parts.append(f"- 选择探索概念: {concept}")
                        knowledge_parts.append(f" 原因: {reason}")
                    elif action == 'exploration_step':
                        node = knowledge.get('selected_node', '')
                        relation = knowledge.get('relation', '')
                        reason = knowledge.get('reason', '')
                        knowledge_parts.append(f"- 探索到: {node} (关系: {relation})")
                        knowledge_parts.append(f" 原因: {reason}")
        return "\n".join(knowledge_parts) if knowledge_parts else "无有效探索知识"
    def _summarize_exploration_results(self, round_results: Dict[str, Any]) -> str:
        """Summarize one round of exploration as a short description for prompting."""
        if not round_results:
            return "无探索结果"
        exploration_list = round_results.get('exploration_results', [])
        successful_count = len([r for r in exploration_list if r.get('success', False)])
        total_count = len(exploration_list)
        summary_parts = [
            f"第一轮探索: {successful_count}/{total_count} 个分支成功",
        ]
        # Key findings: the concepts chosen by successful branches.
        key_findings = []
        for branch in exploration_list:
            if not branch.get('success', False):
                continue
            knowledge_gained = branch.get('knowledge_gained', [])
            for knowledge in knowledge_gained:
                if knowledge.get('action') == 'concept_selection':
                    concept = knowledge.get('selected_concept', '')
                    if concept:
                        key_findings.append(f"探索了概念: {concept}")
        if key_findings:
            summary_parts.append("主要发现:")
            summary_parts.extend([f"- {finding}" for finding in key_findings[:3]])  # at most three
        return "\n".join(summary_parts)

View File

@ -0,0 +1,269 @@
"""
阿里云DashScope嵌入模型实现
直接调用阿里云文本嵌入API不通过OneAPI兼容层
"""
import os
import sys
import json
import time
import requests
import numpy as np
from typing import List, Union, Optional
from dotenv import load_dotenv
# 添加项目根目录到路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
# 加载环境变量
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
class DashScopeEmbeddingModel(BaseEmbeddingModel):
    """
    Alibaba Cloud DashScope embedding model.
    Calls the DashScope text-embedding REST API directly (not via the
    OneAPI compatibility layer).
    """
    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize the DashScope embedding model.
        Args:
            api_key: DashScope API key; defaults to the ONEAPI_KEY env var
            model_name: embedding model name, e.g. text-embedding-v3
            max_retries: maximum retry attempts per request
            retry_delay: base retry delay in seconds (scaled linearly per attempt)
        Raises:
            ValueError: when no API key is available
        """
        # The base class expects a sentence encoder; we bypass it and call
        # the REST API ourselves, so pass None.
        super().__init__(sentence_encoder=None)
        self.api_key = api_key or os.getenv('ONEAPI_KEY')  # reuse the existing env var
        self.model_name = model_name or os.getenv('ONEAPI_MODEL_EMBED', 'text-embedding-v3')
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        if not self.api_key:
            raise ValueError("DashScope API Key未设置请在.env文件中设置ONEAPI_KEY")
        # DashScope text-embedding endpoint.
        self.api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
        # Static request headers: bearer auth + JSON body.
        self.headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }
        print(f"[OK] DashScope嵌入模型初始化完成: {self.model_name}")
        self._test_connection()
    def encode(self, texts, **kwargs):
        """
        Implementation of the base-class abstract method.
        Encode text(s) into embedding vectors.
        Args:
            texts: a single string or a list of strings
            **kwargs: accepted for interface compatibility (batch_size,
                show_progress_bar, query_type, ...) but not used here
        Returns:
            numpy array — 1-D for a single text, 2-D for a list of texts
        """
        import numpy as np
        if isinstance(texts, str):
            # Single text.
            result = self.embed_query(texts)
            return np.array(result)
        else:
            # List of texts.
            results = self.embed_texts(texts)
            return np.array(results)
    def _test_connection(self):
        """Probe the embedding API with a tiny request; prints warnings only, never raises."""
        try:
            # Minimal one-text payload.
            test_payload = {
                "model": self.model_name,
                "input": {
                    "texts": ["test"]
                }
            }
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=test_payload,
                timeout=(5, 15)  # short timeouts: this is only a connectivity probe
            )
            if response.status_code == 200:
                result = response.json()
                if result.get('output') and result['output'].get('embeddings'):
                    print("[OK] DashScope嵌入API连接测试成功")
                else:
                    print(f"[WARNING] DashScope嵌入API响应格式异常: {result}")
            else:
                print(f"[WARNING] DashScope嵌入API连接异常状态码: {response.status_code}")
                print(f" 响应: {response.text[:200]}")
        except requests.exceptions.Timeout:
            print("[WARNING] DashScope嵌入API连接超时但将在实际请求时重试")
        except requests.exceptions.ConnectionError:
            print("[WARNING] DashScope嵌入API连接失败请检查网络状态")
        except Exception as e:
            print(f"[WARNING] DashScope嵌入API连接测试出错: {e}")
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of texts.
        Args:
            texts: input strings
        Returns:
            One embedding vector (list of floats) per input text
        """
        import numpy as np
        if not texts:
            return []
        # The API supports batching, but large inputs must be split into chunks.
        batch_size = 10  # texts per request
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_embeddings = self._embed_batch(batch_texts)
            all_embeddings.extend(batch_embeddings)
        return all_embeddings
    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed one batch of texts with retry/backoff; raises RuntimeError once retries are exhausted."""
        payload = {
            "model": self.model_name,
            "input": {
                "texts": texts
            }
        }
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=(30, 120)  # 30 s connect, 120 s read
                )
                if response.status_code == 200:
                    result = response.json()
                    # A body-level error code means the call failed despite HTTP 200.
                    if result.get('code'):
                        error_msg = result.get('message', f'API错误代码: {result["code"]}')
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(f"DashScope嵌入API错误: {error_msg}")
                        else:
                            print(f"API错误正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
                            time.sleep(self.retry_delay * (attempt + 1))  # linear backoff
                            continue
                    # Pull the embedding vectors out of the response.
                    if result.get('output') and result['output'].get('embeddings'):
                        embeddings_data = result['output']['embeddings']
                        # Each item is either {'embedding': [...]} or a bare vector list.
                        embeddings = []
                        for embedding_item in embeddings_data:
                            if isinstance(embedding_item, dict) and 'embedding' in embedding_item:
                                embeddings.append(embedding_item['embedding'])
                            elif isinstance(embedding_item, list):
                                embeddings.append(embedding_item)
                            else:
                                print(f"[WARNING] 未知的嵌入格式: {type(embedding_item)}")
                        return embeddings
                    else:
                        error_msg = f"API响应格式错误: {result}"
                        if attempt == self.max_retries - 1:
                            raise RuntimeError(error_msg)
                        else:
                            print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})")
                            time.sleep(self.retry_delay * (attempt + 1))
                else:
                    error_text = response.text[:500] if response.text else "无响应内容"
                    error_msg = f"API请求失败状态码: {response.status_code}, 响应: {error_text}"
                    if attempt == self.max_retries - 1:
                        raise RuntimeError(error_msg)
                    else:
                        print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}")
                        time.sleep(self.retry_delay * (attempt + 1))
            except KeyboardInterrupt:
                # User interrupts must not be swallowed by the retry loop.
                print(f"\n[WARNING] 用户中断请求")
                raise KeyboardInterrupt("用户中断请求")
            except requests.exceptions.Timeout as e:
                error_msg = f"请求超时: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}")
                else:
                    print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))
            except requests.exceptions.ConnectionError as e:
                error_msg = f"连接错误: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}")
                else:
                    print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
                    time.sleep(self.retry_delay * (attempt + 1))
            except requests.RequestException as e:
                # Catch-all for other requests-level failures.
                error_msg = f"网络请求异常: {str(e)}"
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}")
                else:
                    print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
                    time.sleep(self.retry_delay * (attempt + 1))
        raise RuntimeError("所有重试都失败了")
    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.
        Args:
            text: query string
        Returns:
            Embedding vector; empty list if the API returned nothing
        """
        embeddings = self.embed_texts([text])
        return embeddings[0] if embeddings else []
def create_dashscope_embedding_model(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeEmbeddingModel:
    """Convenience factory for a :class:`DashScopeEmbeddingModel`.

    Args:
        api_key: DashScope API key; the constructor falls back to the
            environment when this is None.
        model_name: embedding model name; falls back to the environment
            when None.
        **kwargs: extra keyword arguments forwarded verbatim to the
            ``DashScopeEmbeddingModel`` constructor (e.g. ``max_retries``,
            ``retry_delay``).

    Returns:
        A fully initialized ``DashScopeEmbeddingModel`` instance.
    """
    return DashScopeEmbeddingModel(
        model_name=model_name,
        api_key=api_key,
        **kwargs,
    )

View File

@ -0,0 +1,605 @@
"""
阿里云DashScope原生LLM实现
直接调用阿里云通义千问API不通过OneAPI兼容层
"""
import os
import sys
import json
import time
import requests
from typing import List, Dict, Any, Optional, Union, Iterator, Callable
from langchain_core.language_models import BaseLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import LLMResult, Generation
from dotenv import load_dotenv
# 加载环境变量
load_dotenv(os.path.join(os.path.dirname(__file__), '..', '..', '.env'))
class DashScopeLLM(BaseLLM):
"""
阿里云DashScope原生LLM实现
直接调用通义千问API
"""
# Pydantic字段定义
api_key: str = ""
model_name: str = "qwen-turbo"
max_retries: int = 3
retry_delay: float = 1.0
api_url: str = ""
headers: dict = {}
last_token_usage: dict = {}
total_token_usage: dict = {}
def __init__(
self,
api_key: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
**kwargs
):
"""
初始化DashScope LLM
Args:
api_key: 阿里云DashScope API Key
model_name: 模型名称,如 qwen-turbo, qwen-plus, qwen-max
max_retries: 最大重试次数
retry_delay: 重试延迟时间(秒)
"""
# 先设置字段值
api_key_value = api_key or os.getenv('ONEAPI_KEY') # 复用现有环境变量
model_name_value = model_name or os.getenv('ONEAPI_MODEL_GEN', 'qwen-turbo')
if not api_key_value:
raise ValueError("DashScope API Key未设置请在.env文件中设置ONEAPI_KEY")
# DashScope API端点
api_url_value = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
# 设置请求头
headers_value = {
'Authorization': f'Bearer {api_key_value}',
'Content-Type': 'application/json',
'X-DashScope-SSE': 'disable' # 禁用流式响应
}
# Token使用统计
last_token_usage_value = {}
total_token_usage_value = {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
'call_count': 0
}
# 初始化父类,传递所有字段
super().__init__(
api_key=api_key_value,
model_name=model_name_value,
max_retries=max_retries,
retry_delay=retry_delay,
api_url=api_url_value,
headers=headers_value,
last_token_usage=last_token_usage_value,
total_token_usage=total_token_usage_value,
**kwargs
)
print(f"[OK] DashScope LLM初始化完成: {self.model_name}")
self._test_connection()
def _test_connection(self):
"""测试DashScope连接"""
try:
# 发送一个简单的测试请求
test_payload = {
"model": self.model_name,
"input": {
"messages": [
{"role": "user", "content": "hello"}
]
},
"parameters": {
"max_tokens": 10,
"temperature": 0.1
}
}
response = requests.post(
self.api_url,
headers=self.headers,
json=test_payload,
timeout=(5, 15) # 短超时进行连接测试
)
if response.status_code == 200:
result = response.json()
if result.get('output') and result['output'].get('text'):
print("[OK] DashScope连接测试成功")
else:
print(f"[WARNING] DashScope响应格式异常: {result}")
else:
print(f"[WARNING] DashScope连接异常状态码: {response.status_code}")
print(f" 响应: {response.text[:200]}")
except requests.exceptions.Timeout:
print("[WARNING] DashScope连接超时但将在实际请求时重试")
except requests.exceptions.ConnectionError:
print("[WARNING] DashScope连接失败请检查网络状态")
except Exception as e:
print(f"[WARNING] DashScope连接测试出错: {e}")
@property
def _llm_type(self) -> str:
return "dashscope"
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""生成响应"""
generations = []
for prompt in prompts:
try:
response_text = self._call_api(prompt, **kwargs)
generations.append([Generation(text=response_text)])
except Exception as e:
print(f"[ERROR] DashScope API调用失败: {e}")
generations.append([Generation(text=f"API调用失败: {str(e)}")])
return LLMResult(generations=generations)
def _call_api(self, prompt: str, **kwargs) -> str:
"""调用DashScope API"""
# 构建请求payload
payload = {
"model": self.model_name,
"input": {
"messages": [
{"role": "user", "content": prompt}
]
},
"parameters": {
"max_tokens": kwargs.get("max_tokens", 2048),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.8),
}
}
# 如果有stop参数添加到payload中
stop = kwargs.get("stop")
if stop:
payload["parameters"]["stop"] = stop
for attempt in range(self.max_retries):
try:
response = requests.post(
self.api_url,
headers=self.headers,
json=payload,
timeout=(30, 120) # 连接30秒读取120秒
)
if response.status_code == 200:
result = response.json()
# 检查是否有错误
if result.get('code'):
error_msg = result.get('message', f'API错误代码: {result["code"]}')
if attempt == self.max_retries - 1:
raise RuntimeError(f"DashScope API错误: {error_msg}")
else:
print(f"API错误正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
time.sleep(self.retry_delay * (attempt + 1))
continue
# 提取响应内容
if result.get('output') and result['output'].get('text'):
content = result['output']['text']
# 提取Token使用信息
usage_info = result.get('usage', {})
self.last_token_usage = {
'prompt_tokens': usage_info.get('input_tokens', 0),
'completion_tokens': usage_info.get('output_tokens', 0),
'total_tokens': usage_info.get('total_tokens', 0)
}
# 更新累计统计
self.total_token_usage['prompt_tokens'] += self.last_token_usage['prompt_tokens']
self.total_token_usage['completion_tokens'] += self.last_token_usage['completion_tokens']
self.total_token_usage['total_tokens'] += self.last_token_usage['total_tokens']
self.total_token_usage['call_count'] += 1
return content.strip()
else:
error_msg = f"API响应格式错误: {result}"
if attempt == self.max_retries - 1:
raise RuntimeError(error_msg)
else:
print(f"响应格式错误,正在重试 ({attempt + 2}/{self.max_retries})")
time.sleep(self.retry_delay * (attempt + 1))
else:
error_text = response.text[:500] if response.text else "无响应内容"
error_msg = f"API请求失败状态码: {response.status_code}, 响应: {error_text}"
if attempt == self.max_retries - 1:
raise RuntimeError(error_msg)
else:
print(f"请求失败,正在重试 ({attempt + 2}/{self.max_retries}): 状态码 {response.status_code}")
time.sleep(self.retry_delay * (attempt + 1))
except KeyboardInterrupt:
print(f"\n[WARNING] 用户中断请求")
raise KeyboardInterrupt("用户中断请求")
except requests.exceptions.Timeout as e:
error_msg = f"请求超时: {str(e)}"
if attempt == self.max_retries - 1:
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍超时: {error_msg}")
else:
print(f"请求超时,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
time.sleep(self.retry_delay * (attempt + 1))
except requests.exceptions.ConnectionError as e:
error_msg = f"连接错误: {str(e)}"
if attempt == self.max_retries - 1:
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍无法连接: {error_msg}")
else:
print(f"连接错误,正在重试 ({attempt + 2}/{self.max_retries}): {error_msg}")
time.sleep(self.retry_delay * (attempt + 1))
except requests.RequestException as e:
error_msg = f"网络请求异常: {str(e)}"
if attempt == self.max_retries - 1:
raise RuntimeError(f"经过 {self.max_retries} 次重试后仍失败: {error_msg}")
else:
print(f"网络异常,正在重试 ({attempt + 2}/{self.max_retries}): {str(e)[:100]}")
time.sleep(self.retry_delay * (attempt + 1))
raise RuntimeError("所有重试都失败了")
def generate_response(self, messages, max_new_tokens=4096, temperature=0.7, response_format=None, **kwargs):
    """Generate a completion for one conversation or a batch of conversations.

    Args:
        messages: A single message list, or a list of message lists
            (batch mode is detected by the first element being a list).
        max_new_tokens: Generation cap forwarded to the API.
        temperature: Sampling temperature forwarded to the API.
        response_format: Accepted for interface compatibility; unused here.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        A string for single input, or a list of strings for batch input
        (failed batch items become empty strings).
    """
    batch_mode = (
        isinstance(messages, list)
        and len(messages) > 0
        and isinstance(messages[0], list)
    )
    if not batch_mode:
        # Single conversation: format into one prompt and call the API.
        prompt = self._format_messages(messages)
        return self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
    outputs = []
    for conversation in messages:
        try:
            prompt = self._format_messages(conversation)
            text = self._call_api(prompt, max_tokens=max_new_tokens, temperature=temperature)
        except Exception as e:
            # One failed item must not abort the whole batch.
            print(f"批量处理中的单个请求失败: {str(e)}")
            text = ""
        outputs.append(text)
    return outputs
def _format_messages(self, messages):
    """Flatten a chat message list into a single role-labelled prompt string.

    Unknown roles are silently skipped; a missing role defaults to "user".
    Messages are joined with blank lines.
    """
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    rendered = []
    for message in messages:
        label = role_labels.get(message.get('role', 'user'))
        if label is not None:
            rendered.append(f"{label}: {message.get('content', '')}")
    return "\n\n".join(rendered)
def ner(self, text: str) -> str:
    """Extract named entities from *text* as a comma-separated string.

    Args:
        text: Input text (typically a question).

    Returns:
        Comma-separated entities, or the original text if the LLM call fails.
    """
    conversation = [
        {
            "role": "system",
            "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ...",
        },
        {
            "role": "user",
            "content": f"Extract the named entities from: {text}",
        },
    ]
    try:
        return self.generate_response(conversation, max_new_tokens=1024, temperature=0.7)
    except Exception as e:
        # Fall back to the raw text so downstream retrieval still has input.
        print(f"NER失败: {str(e)}")
        return text
def filter_triples_with_entity_event(self, question: str, triples: str) -> str:
    """Filter a triple list down to the facts relevant to *question*.

    The LLM is instructed to return a strict subset of the input facts as a
    JSON object of the form {"fact": [[subject, relation, object], ...]}.

    Args:
        question: The user question used as the relevance criterion.
        triples: JSON string containing the candidate triples.

    Returns:
        JSON string with a "fact" key. If the LLM call fails, the original
        *triples* string is returned unchanged; if the LLM output cannot be
        parsed, an empty fact list is returned.
    """
    import json
    import json_repair
    # Filtering prompt (Chinese) that stresses strict-subset selection.
    filter_messages = [
        {
            "role": "system",
            "content": """你是一个知识图谱事实过滤专家,擅长根据问题相关性筛选事实。
关键要求:
1. 你只能从提供的输入列表中选择事实 - 绝对不能创建或生成新的事实
2. 你的输出必须是输入事实的严格子集
3. 只包含与回答问题直接相关的事实
4. 如果不确定,宁可选择更少的事实,也不要选择更多
5. 保持每个事实的准确格式:[主语, 关系, 宾语]
过滤规则:
- 只选择包含与问题直接相关的实体或关系的事实
- 不能修改、改写或创建输入事实的变体
- 不能添加看起来相关但不在输入中的事实
- 输出事实数量必须 ≤ 输入事实数量
返回格式为包含"fact"键的JSON对象值为选中的事实数组。
示例:
输入事实:[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]]
问题A和B是什么关系
正确输出:{"fact": [["A", "关系1", "B"]]}
错误做法:添加新事实或修改现有事实"""
        },
        {
            "role": "user",
            "content": f"""问题:{question}
待筛选的输入事实:
{triples}
从上述输入中仅选择最相关的事实来回答问题。记住:只能是严格的子集!"""
        }
    ]
    try:
        response = self.generate_response(
            filter_messages,
            max_new_tokens=4096,
            temperature=0.0
        )
    except Exception as e:
        # LLM failure: degrade gracefully by keeping the unfiltered triples.
        print(f"三元组过滤失败: {str(e)}")
        return triples
    # Try strict JSON first, then json_repair for slightly malformed output.
    # Anything unparseable degrades to an empty fact list rather than
    # propagating garbage downstream.
    for loads in (json.loads, json_repair.loads):
        try:
            parsed = loads(response)
        except Exception:
            # Was a bare `except:` on the json_repair path; narrowed so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            continue
        if isinstance(parsed, dict) and 'fact' in parsed:
            return json.dumps(parsed, ensure_ascii=False)
        return json.dumps({"fact": []}, ensure_ascii=False)
    return json.dumps({"fact": []}, ensure_ascii=False)
def generate_with_context(self,
                          question: str,
                          context: str,
                          max_new_tokens: int = 1024,
                          temperature: float = 0.7) -> str:
    """Answer *question* grounded in the supplied *context*.

    Args:
        question: The question to answer.
        context: Supporting context placed before the question.
        max_new_tokens: Generation cap.
        temperature: Sampling temperature.

    Returns:
        The generated answer, or a fixed apology string on failure.
    """
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant. Answer the question based on the provided context. Think step by step.",
        },
        {
            "role": "user",
            "content": f"{context}\n\n{question}\nThought:",
        },
    ]
    try:
        return self.generate_response(conversation, max_new_tokens, temperature)
    except Exception as e:
        print(f"基于上下文生成失败: {str(e)}")
        return "抱歉,我无法基于提供的上下文回答这个问题。"
def get_token_usage(self) -> Dict[str, Any]:
    """Return a snapshot of last-call and cumulative token statistics.

    Copies are returned so callers cannot mutate the internal counters.
    """
    snapshot = {
        "last_usage": dict(self.last_token_usage),
        "total_usage": dict(self.total_token_usage),
        "model_name": self.model_name,
    }
    return snapshot
def stream_generate(
    self,
    prompt: str,
    stream_callback: Optional[Callable[[str], None]] = None,
    **kwargs
) -> str:
    """
    Stream a completion via DashScope's OpenAI-compatible endpoint.

    Args:
        prompt: Input prompt (sent as a single user message).
        stream_callback: Optional callable invoked with each text chunk
            as it arrives.
        **kwargs: May contain max_tokens (default 2048), temperature
            (default 0.7), top_p (default 0.8); forwarded to the API.

    Returns:
        The full generated text (stripped). Falls back to the non-streaming
        `_call_api` path if the OpenAI SDK is missing or streaming fails.
    """
    try:
        # Use the OpenAI Python SDK against DashScope's compatible endpoint.
        from openai import OpenAI
        # Build an OpenAI-compatible client pointed at DashScope.
        client = OpenAI(
            api_key=self.api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        # Start the streaming request; include_usage makes the server send
        # a final usage chunk with token counts.
        stream = client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt}
            ],
            stream=True,
            max_tokens=kwargs.get("max_tokens", 2048),
            temperature=kwargs.get("temperature", 0.7),
            top_p=kwargs.get("top_p", 0.8),
            stream_options={"include_usage": True}
        )
        full_text = ""
        chunk_count = 0
        print(f"[DEBUG] 开始接收OpenAI兼容流式响应...")
        for chunk in stream:
            # Extract the incremental text, if this chunk carries any.
            if chunk.choices and len(chunk.choices) > 0:
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    text_chunk = delta.content
                    full_text += text_chunk
                    chunk_count += 1
                    if stream_callback:
                        stream_callback(text_chunk)
            # The usage chunk (if present) carries final token counts.
            if hasattr(chunk, 'usage') and chunk.usage:
                self._update_token_usage({
                    'input_tokens': chunk.usage.prompt_tokens,
                    'output_tokens': chunk.usage.completion_tokens,
                    'total_tokens': chunk.usage.total_tokens
                })
        print(f"[DEBUG] 流式结束,共收到 {chunk_count} 个chunk总长度 {len(full_text)} 字符")
        return full_text.strip()
    except ImportError:
        # OpenAI SDK not installed: degrade to the non-streaming path.
        print("[WARNING] OpenAI SDK未安装降级到非流式")
        return self._call_api(prompt, **kwargs)
    except Exception as e:
        # Any streaming failure also degrades to the non-streaming path.
        print(f"[WARNING] OpenAI兼容流式失败: {e},降级到非流式")
        return self._call_api(prompt, **kwargs)
def _update_token_usage(self, usage_info: Dict[str, Any]):
    """Record the latest call's token counts and fold them into the totals."""
    prompt = usage_info.get('input_tokens', 0)
    completion = usage_info.get('output_tokens', 0)
    total = usage_info.get('total_tokens', 0)
    self.last_token_usage = {
        'prompt_tokens': prompt,
        'completion_tokens': completion,
        'total_tokens': total,
    }
    # Accumulate into the running totals and bump the call counter.
    self.total_token_usage['prompt_tokens'] += prompt
    self.total_token_usage['completion_tokens'] += completion
    self.total_token_usage['total_tokens'] += total
    self.total_token_usage['call_count'] += 1
def invoke(
    self,
    input: Union[str, List[str]],
    config: Optional[dict] = None,
    **kwargs
) -> Union[str, LLMResult]:
    """Unified entry point supporting streaming and non-streaming calls.

    Args:
        input: Prompt string, or a list of prompts for batch handling.
        config: Optional config dict; a callback may be supplied at
            config["metadata"]["stream_callback"] to enable streaming.
        **kwargs: Forwarded to the underlying generation call.

    Returns:
        Generated text for string input, or an LLMResult for batch input.
    """
    # Pull a stream callback out of the config metadata, if any.
    callback = None
    if config:
        callback = config.get('metadata', {}).get('stream_callback')
    if not isinstance(input, str):
        # List input: delegate to the batch generation path.
        return self._generate(input, **kwargs)
    if callback:
        # Streaming path for single-prompt input.
        return self.stream_generate(input, callback, **kwargs)
    # Plain, non-streaming single-prompt call.
    return self._call_api(input, **kwargs)
def create_dashscope_llm(
    api_key: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs
) -> DashScopeLLM:
    """Convenience factory for a DashScopeLLM instance.

    Args:
        api_key: Optional API key override.
        model_name: Optional model name override.
        **kwargs: Forwarded to the DashScopeLLM constructor.
    """
    return DashScopeLLM(api_key=api_key, model_name=model_name, **kwargs)

View File

@ -0,0 +1,169 @@
"""
ES向量检索器
用于直接与ES向量库进行向量匹配检索
"""
import os
import sys
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
# 添加路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
from elasticsearch_vectorization.es_client_wrapper import ESClientWrapper
from elasticsearch_vectorization.config import ElasticsearchConfig
class ESVectorRetriever:
    """Vector-similarity retriever that queries Elasticsearch directly.

    Embeds the query text with DashScope and runs a vector search against
    the passages index derived from *keyword*.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 3,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        embed_model_name: Optional[str] = None
    ):
        """
        Initialize the ES vector retriever.

        Args:
            keyword: Key used to derive the ES passages index name.
            top_k: Number of documents returned per query.
            oneapi_key: API key for the embedding service.
            oneapi_base_url: Base URL for the embedding service
                (NOTE(review): currently not forwarded to the embedding
                model — confirm whether it should be).
            embed_model_name: Embedding model name.
        """
        self.keyword = keyword
        self.top_k = top_k
        # Embedding model used to vectorize incoming queries.
        self.embedding_model = DashScopeEmbeddingModel(
            api_key=oneapi_key,
            model_name=embed_model_name
        )
        # Low-level Elasticsearch client wrapper.
        self.es_client = ESClientWrapper()
        # Resolve the passages index name for this keyword.
        self.passages_index = ElasticsearchConfig.get_passages_index_name(keyword)
        print(f"ES向量检索器初始化完成目标索引: {self.passages_index}")

    def retrieve(self, query: str) -> List[Document]:
        """Embed *query* and return the top-k most similar passages.

        Args:
            query: Query text.

        Returns:
            Matching Documents with score/provenance metadata; an empty
            list on any failure (errors are logged, not raised).
        """
        try:
            # Encode the query; take the first (only) embedding.
            query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)[0]
            # Normalize to a plain Python list for the ES request payload.
            if hasattr(query_embedding, 'tolist'):
                query_vector = query_embedding.tolist()
            else:
                query_vector = list(query_embedding)
            # Run the vector similarity search.
            search_result = self.es_client.vector_search(
                index_name=self.passages_index,
                vector=query_vector,
                field="embedding",
                size=self.top_k
            )
            # Convert hits into LangChain Document objects.
            documents = []
            hits = search_result.get("hits", {}).get("hits", [])
            for hit in hits:
                source = hit["_source"]
                score = hit["_score"]
                doc = Document(
                    page_content=source.get("content", ""),
                    metadata={
                        "passage_id": source.get("passage_id", ""),
                        "file_id": source.get("file_id", ""),
                        "evidence": source.get("evidence", ""),
                        "score": score,
                        "source": "es_vector_search"
                    }
                )
                documents.append(doc)
            print(f"ES向量检索完成找到 {len(documents)} 个相关文档")
            return documents
        except Exception as e:
            print(f"ES向量检索失败: {e}")
            return []

    def test_connection(self) -> bool:
        """Return True if Elasticsearch answers a ping, False otherwise."""
        try:
            return self.es_client.ping()
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # are not silently swallowed.
            return False

    def get_index_stats(self) -> Dict[str, Any]:
        """Return basic statistics (document count) for the target index.

        On failure, the count is reported as 0 and the error message is
        included under the "error" key.
        """
        try:
            query = {"match_all": {}}
            result = self.es_client.search(self.passages_index, query, size=0)
            total = result.get("hits", {}).get("total", 0)
            # ES 7+ returns {"value": N, ...}; older versions return an int.
            if isinstance(total, dict):
                count = total.get("value", 0)
            else:
                count = total
            return {
                "index_name": self.passages_index,
                "document_count": count,
                "top_k": self.top_k
            }
        except Exception as e:
            print(f"获取索引统计信息失败: {e}")
            return {
                "index_name": self.passages_index,
                "document_count": 0,
                "top_k": self.top_k,
                "error": str(e)
            }
def create_es_vector_retriever(
    keyword: str,
    top_k: int = 3,
    **kwargs
) -> ESVectorRetriever:
    """Convenience factory for an ESVectorRetriever.

    Args:
        keyword: ES index keyword.
        top_k: Number of documents returned per query.
        **kwargs: Forwarded to the ESVectorRetriever constructor.

    Returns:
        A configured ESVectorRetriever instance.
    """
    return ESVectorRetriever(keyword=keyword, top_k=top_k, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,260 @@
"""
LangGraph状态定义
定义工作流中的状态结构
"""
from typing import List, Dict, Any, Optional, TypedDict, Callable
from dataclasses import dataclass
from langchain_core.documents import Document
class QueryState(TypedDict):
    """Workflow state shared by all LangGraph nodes in the retrieval graph."""
    # --- Basic query info ---
    original_query: str
    current_iteration: int
    max_iterations: int
    # Debug mode: "0" = auto-route, "simple" = force the simple path,
    # "complex" = force the complex path.
    debug_mode: str
    # --- Query complexity ---
    query_complexity: Dict[str, Any]
    is_complex_query: bool
    # --- Retrieval results ---
    all_passages: List[str]  # name kept for compatibility; now stores event info
    all_documents: List[Document]  # now includes event-node documents
    passage_sources: List[str]  # provenance of each event/passage
    # --- Sub-queries ---
    sub_queries: List[str]
    current_sub_queries: List[str]
    decomposed_sub_queries: List[str]  # initial sub-queries from query decomposition
    # --- Sufficiency check ---
    sufficiency_check: Dict[str, Any]
    is_sufficient: bool
    # --- Final result ---
    final_answer: str
    # --- Debug info ---
    debug_info: Dict[str, Any]
    iteration_history: List[Dict[str, Any]]
    # --- Concept exploration (new) ---
    # NOTE: PageRank data is no longer stored in the state (to avoid
    # LangSmith transfer overhead); it lives in local temporary storage.
    concept_exploration_results: Dict[str, Any]
    exploration_round: int  # current exploration round (1 or 2)
    # Flag indicating local PageRank data is available (no actual data here).
    pagerank_data_available: bool
    # Optional streaming callback.
    stream_callback: Optional[Callable]
@dataclass
class RetrievalResult:
    """Outcome of one retrieval round, merged into QueryState afterwards."""
    passages: List[str]  # retrieved passage texts
    documents: List[Document]  # corresponding Document objects
    sources: List[str]  # provenance of each passage
    query: str  # the query that produced this round
    iteration: int  # iteration number of the round
@dataclass
class SufficiencyCheck:
    """Verdict on whether the collected evidence suffices to answer."""
    is_sufficient: bool  # True when no further retrieval is needed
    confidence: float  # confidence in the verdict, 0-1
    reason: str  # free-text justification
    sub_queries: Optional[List[str]] = None  # follow-up queries when insufficient
def create_initial_state(
    original_query: str,
    max_iterations: int = 2,
    debug_mode: str = "0"
) -> QueryState:
    """Build a fresh workflow state for a new query run.

    Args:
        original_query: The user's original query text.
        max_iterations: Upper bound on retrieval iterations.
        debug_mode: "0" = auto-routing; "simple"/"complex" force a path.

    Returns:
        A fully initialised QueryState dictionary.
    """
    # Counters collected while the workflow runs.
    fresh_debug_info = {
        "retrieval_calls": 0,
        "llm_calls": 0,
        "start_time": None,
        "end_time": None,
    }
    state = QueryState(
        original_query=original_query,
        current_iteration=0,
        max_iterations=max_iterations,
        debug_mode=debug_mode,
        query_complexity={},
        is_complex_query=False,
        all_passages=[],
        all_documents=[],
        passage_sources=[],
        sub_queries=[],
        current_sub_queries=[],
        decomposed_sub_queries=[],
        sufficiency_check={},
        is_sufficient=False,
        final_answer="",
        debug_info=fresh_debug_info,
        iteration_history=[],
        concept_exploration_results={},
        exploration_round=0,
        pagerank_data_available=False,
        stream_callback=None,  # set later when streaming is requested
    )
    return state
def update_state_with_retrieval(
    state: QueryState,
    retrieval_result: RetrievalResult
) -> QueryState:
    """Merge one retrieval round into the running workflow state.

    Appends the new passages/documents/sources, bumps the retrieval-call
    counter, and records the round in the iteration history.

    Args:
        state: Current workflow state (mutated in place).
        retrieval_result: The round to merge.

    Returns:
        The same state object, updated.
    """
    state["all_passages"] += retrieval_result.passages
    state["all_documents"] += retrieval_result.documents
    state["passage_sources"] += retrieval_result.sources
    state["debug_info"]["retrieval_calls"] += 1
    # Log this round for later inspection.
    state["iteration_history"].append({
        "iteration": retrieval_result.iteration,
        "query": retrieval_result.query,
        "passages_count": len(retrieval_result.passages),
        "action": "retrieval",
    })
    return state
def update_state_with_sufficiency_check(
    state: QueryState,
    sufficiency_check: SufficiencyCheck
) -> QueryState:
    """Apply a sufficiency-check verdict to the workflow state.

    Records the verdict, queues new sub-queries when the evidence is
    insufficient, and appends the check to the iteration history.

    Args:
        state: Current workflow state (mutated in place).
        sufficiency_check: The verdict to apply.

    Returns:
        The same state object, updated.
    """
    state["is_sufficient"] = sufficiency_check.is_sufficient
    state["sufficiency_check"] = {
        "is_sufficient": sufficiency_check.is_sufficient,
        "confidence": sufficiency_check.confidence,
        "reason": sufficiency_check.reason,
        "iteration": state["current_iteration"],
    }
    # Queue follow-up sub-queries only when insufficient and some exist.
    new_queries = sufficiency_check.sub_queries
    if sufficiency_check.is_sufficient or not new_queries:
        state["current_sub_queries"] = []
    else:
        state["current_sub_queries"] = new_queries
        state["sub_queries"].extend(new_queries)
    state["debug_info"]["llm_calls"] += 1
    # Log the check for later inspection.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "sufficiency_check",
        "is_sufficient": sufficiency_check.is_sufficient,
        "confidence": sufficiency_check.confidence,
        "sub_queries_count": len(sufficiency_check.sub_queries or []),
    })
    return state
def increment_iteration(state: QueryState) -> QueryState:
    """Advance the workflow to the next iteration and return the state."""
    state["current_iteration"] = state["current_iteration"] + 1
    return state
def finalize_state(state: QueryState, final_answer: str) -> QueryState:
    """Store the final answer and log the generation step in history.

    Args:
        state: Current workflow state (mutated in place).
        final_answer: The answer text produced by the workflow.

    Returns:
        The same state object, finalized.
    """
    state["final_answer"] = final_answer
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "final_answer_generation",
        "answer_length": len(final_answer),
    })
    return state
def get_state_summary(state: QueryState) -> Dict[str, Any]:
    """Condense the workflow state into a small reporting dictionary."""
    summary: Dict[str, Any] = {
        "original_query": state["original_query"],
        "current_iteration": state["current_iteration"],
        "max_iterations": state["max_iterations"],
        # Aggregate counts rather than the raw collections.
        "total_passages": len(state["all_passages"]),
        "total_sub_queries": len(state["sub_queries"]),
        "is_sufficient": state["is_sufficient"],
        "has_final_answer": bool(state["final_answer"]),
        "debug_info": state["debug_info"],
        "iteration_count": len(state["iteration_history"]),
    }
    return summary

View File

@ -0,0 +1,318 @@
"""
基于LangGraph的迭代检索器
实现智能迭代检索工作流
"""
import time
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, END
from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary
from retriver.langgraph.graph_nodes import GraphNodes
from retriver.langgraph.routing_functions import (
should_continue_retrieval,
route_by_complexity,
route_by_debug_mode
)
from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever
from retriver.langgraph.langchain_components import create_oneapi_llm
class IterativeRetriever:
    """
    LangGraph-based iterative retriever.

    Workflow:
    1. initial retrieval -> 2. sufficiency check -> 3. sub-query generation
    (if needed) -> 4. parallel retrieval -> 5. repeat 2-4 until sufficient or
    the iteration limit is reached -> 6. final answer generation.
    """
    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
        """
        Initialize the iterative retriever.

        Args:
            keyword: ES index keyword.
            top_k: Number of documents returned per retrieval.
            max_iterations: Maximum number of iterations.
            max_parallel_retrievals: Maximum concurrent retrievals.
            oneapi_key: OneAPI key.
            oneapi_base_url: OneAPI base URL.
            model_name: Main LLM model (used to generate answers).
            embed_model_name: Embedding model name.
            complexity_model_name: Model for complexity judgement
                (falls back to model_name when not given).
            sufficiency_model_name: Model for sufficiency checking
                (falls back to model_name when not given).
            skip_llm_generation: If True, skip LLM answer generation and
                return retrieval results only.
        """
        self.keyword = keyword
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation
        # Fall back to the main model when no dedicated models are given.
        complexity_model_name = complexity_model_name or model_name
        sufficiency_model_name = sufficiency_model_name or model_name
        # Build the components.
        print("[CONFIG] 初始化检索器组件...")
        self.retriever = create_langchain_hipporag_retriever(
            keyword=keyword,
            top_k=top_k,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            oneapi_model_gen=model_name,
            oneapi_model_embed=embed_model_name
        )
        # Main LLM, used for answer generation.
        self.llm = create_oneapi_llm(
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name
        )
        # Dedicated complexity-judgement LLM (only when a different model
        # was requested; otherwise reuse the main LLM instance).
        if complexity_model_name != model_name:
            print(f" [INFO] 使用独立的复杂度判断模型: {complexity_model_name}")
            self.complexity_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=complexity_model_name
            )
        else:
            self.complexity_llm = self.llm
        # Dedicated sufficiency-check LLM (same reuse logic as above).
        if sufficiency_model_name != model_name:
            print(f" [INFO] 使用独立的充分性检查模型: {sufficiency_model_name}")
            self.sufficiency_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=sufficiency_model_name
            )
        else:
            self.sufficiency_llm = self.llm
        # Node handlers that implement each graph node.
        self.nodes = GraphNodes(
            retriever=self.retriever,
            llm=self.llm,
            complexity_llm=self.complexity_llm,
            sufficiency_llm=self.sufficiency_llm,
            keyword=keyword,
            max_parallel_retrievals=max_parallel_retrievals,
            skip_llm_generation=skip_llm_generation
        )
        # Compile the workflow graph once at construction time.
        self.workflow = self._build_workflow()
        print("[OK] 迭代检索器初始化完成")
    def _build_workflow(self) -> StateGraph:
        """Build and compile the LangGraph workflow.

        Returns:
            The compiled state graph (note: `workflow.compile()` returns a
            compiled graph, not a raw StateGraph).
        """
        print("[?] 构建工作流图...")
        # Create the state graph over QueryState.
        workflow = StateGraph(QueryState)
        # --- Nodes ---
        # Query complexity judgement node.
        workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node)
        # Debug-mode override node.
        workflow.add_node("debug_mode_node", self.nodes.debug_mode_node)
        # Simple-query path.
        workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node)
        workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node)
        # Complex-query path (existing hipporag2 logic).
        workflow.add_node("query_decomposition", self.nodes.query_decomposition_node)
        workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node)
        workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node)
        workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node)
        workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node)
        workflow.add_node("next_iteration", self.nodes.next_iteration_node)
        workflow.add_node("final_answer", self.nodes.final_answer_generation_node)
        # --- Edges ---
        # Entry point: start with the complexity judgement.
        workflow.set_entry_point("query_complexity_check")
        # After the complexity check, apply any debug-mode override.
        workflow.add_edge("query_complexity_check", "debug_mode_node")
        # Conditional edge: route by debug mode and complexity verdict.
        workflow.add_conditional_edges(
            "debug_mode_node",
            route_by_debug_mode,
            {
                "simple_vector_retrieval": "simple_vector_retrieval",
                "initial_retrieval": "query_decomposition"  # complex path enters decomposition first
            }
        )
        # Simple-path edges.
        workflow.add_edge("simple_vector_retrieval", "simple_answer_generation")
        workflow.add_edge("simple_answer_generation", END)
        # Complex-path edges (including query decomposition).
        workflow.add_edge("query_decomposition", "initial_retrieval")  # decomposition -> parallel initial retrieval
        workflow.add_edge("initial_retrieval", "sufficiency_check")
        # Conditional edge: next step depends on the sufficiency verdict.
        workflow.add_conditional_edges(
            "sufficiency_check",
            should_continue_retrieval,
            {
                "final_answer": "final_answer",
                "parallel_retrieval": "sub_query_generation",
                "next_iteration": "next_iteration"
            }
        )
        workflow.add_edge("sub_query_generation", "parallel_retrieval")
        workflow.add_edge("parallel_retrieval", "next_iteration")  # bump iteration after parallel retrieval
        workflow.add_edge("next_iteration", "sufficiency_check")
        # Terminal node.
        workflow.add_edge("final_answer", END)
        return workflow.compile()
    def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]:
        """
        Run the iterative retrieval workflow.

        Args:
            query: User query.
            mode: Debug mode; "0" = auto, "simple" forces the simple path,
                "complex" forces the complex path.

        Returns:
            Dict with the final answer and detailed diagnostics; on failure
            a reduced dict containing an "error" key.
        """
        print(f"[STARTING] 开始迭代检索: {query}")
        start_time = time.time()
        # Build the initial workflow state.
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=self.max_iterations,
            debug_mode=mode
        )
        initial_state["debug_info"]["start_time"] = start_time
        try:
            # Execute the compiled workflow.
            final_state = self.workflow.invoke(initial_state)
            # Record timing.
            end_time = time.time()
            final_state["debug_info"]["end_time"] = end_time
            final_state["debug_info"]["total_time"] = end_time - start_time
            print(f"[SUCCESS] 迭代检索完成,耗时 {end_time - start_time:.2f} 秒")
            # Assemble the result payload.
            return {
                "query": query,
                "answer": final_state["final_answer"],
                "query_complexity": final_state["query_complexity"],
                "is_complex_query": final_state["is_complex_query"],
                "iterations": final_state["current_iteration"],
                "total_passages": len(final_state["all_passages"]),
                "sub_queries": final_state["sub_queries"],
                "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []),
                "initial_retrieval_details": final_state.get("initial_retrieval_details", {}),
                "sufficiency_check": final_state["sufficiency_check"],
                "all_passages": final_state["all_passages"],
                "debug_info": final_state["debug_info"],
                "state_summary": get_state_summary(final_state),
                "iteration_history": final_state["iteration_history"]
            }
        except Exception as e:
            # Degrade to an error payload rather than raising to the caller.
            print(f"[ERROR] 迭代检索失败: {e}")
            return {
                "query": query,
                "answer": f"抱歉,检索过程中遇到错误: {str(e)}",
                "error": str(e),
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": time.time() - start_time}
            }
    def retrieve_simple(self, query: str) -> str:
        """
        Simplified interface: run retrieval and return only the answer text.

        Args:
            query: User query.

        Returns:
            Final answer string.
        """
        result = self.retrieve(query)
        return result["answer"]
    def get_retrieval_stats(self) -> Dict[str, Any]:
        """
        Return the retriever's configuration and model info.

        Returns:
            Statistics dictionary (configuration values and model names).
        """
        return {
            "keyword": self.keyword,
            "top_k": self.top_k,
            "max_iterations": self.max_iterations,
            "max_parallel_retrievals": self.max_parallel_retrievals,
            "retriever_type": "IterativeRetriever with LangGraph",
            "model_info": {
                "llm_model": getattr(self.llm.oneapi_generator, 'model_name', 'unknown'),
                "embed_model": getattr(self.retriever.embedding_model, 'model_name', 'unknown')
            }
        }
def create_iterative_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    **kwargs
) -> IterativeRetriever:
    """Convenience factory for an IterativeRetriever.

    Args:
        keyword: ES index keyword.
        top_k: Documents returned per retrieval.
        max_iterations: Maximum iteration count.
        max_parallel_retrievals: Maximum concurrent retrievals.
        **kwargs: Forwarded to the IterativeRetriever constructor.

    Returns:
        A configured IterativeRetriever instance.
    """
    return IterativeRetriever(
        keyword=keyword,
        top_k=top_k,
        max_iterations=max_iterations,
        max_parallel_retrievals=max_parallel_retrievals,
        **kwargs,
    )

View File

@ -0,0 +1,952 @@
"""
LangChain组件实现
包含LLM、提示词模板、输出解析器等
"""
import os
import sys
import json
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
from langchain_core.language_models import BaseLLM
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import BaseOutputParser, PydanticOutputParser
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
# Try to import LangSmith's `traceable` decorator; fall back to a no-op
# decorator so the module also works without LangSmith installed.
try:
    from langsmith import traceable
    LANGSMITH_AVAILABLE = True
except ImportError:
    def traceable(name=None):
        # No-op stand-in: returns the decorated function unchanged.
        def decorator(func):
            return func
        return decorator
    LANGSMITH_AVAILABLE = False
# 添加路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langgraph.dashscope_llm import DashScopeLLM
from prompt_loader import get_prompt_loader
class SufficiencyCheckResult(BaseModel):
    """Structured verdict on whether retrieved evidence suffices to answer."""
    is_sufficient: bool = Field(description="是否足够回答用户查询")
    confidence: float = Field(description="置信度0-1之间", ge=0, le=1)
    reason: str = Field(description="判断理由")
    sub_queries: Optional[List[str]] = Field(default=None, description="如果不充分,生成的子查询列表")
class QueryComplexityResult(BaseModel):
    """Structured verdict on whether a query is simple or complex."""
    is_complex: bool = Field(description="查询是否复杂")
    complexity_level: str = Field(description="复杂度级别simple/complex")
    confidence: float = Field(description="置信度0-1之间", ge=0, le=1)
    reason: str = Field(description="判断理由")
class OneAPILLM(BaseLLM):
    """
    LangChain wrapper around DashScopeLLM, kept interface-compatible with
    the original OneAPI-based implementation.

    Talks to the Alibaba Cloud DashScope native API directly and supports
    streaming output.
    """
    # Pydantic field declaration; excluded from serialization because the
    # underlying client object is not serializable.
    oneapi_generator: Any = Field(default=None, description="DashScope LLM实例", exclude=True)
    def __init__(self, dashscope_llm: DashScopeLLM, **kwargs):
        super().__init__(**kwargs)
        self.oneapi_generator = dashscope_llm  # legacy attribute name kept for compatibility
    def _generate(
        self,
        messages: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ):
        """Generate a response for the first prompt in *messages*.

        Simplified: only the first prompt is used; the result is wrapped in
        a LangChain LLMResult that also carries token-usage info.
        """
        # Simplification: take only the first message.
        message = messages[0] if messages else ""
        # Call the DashScope LLM directly; on failure, return the error text
        # as the generation instead of raising.
        try:
            response_text = self.oneapi_generator._call_api(
                message,
                stop=stop,
                **kwargs
            )
        except Exception as e:
            print(f"[ERROR] DashScope API调用失败: {e}")
            response_text = f"API调用失败: {str(e)}"
        from langchain_core.outputs import LLMResult, Generation
        # Pull token-usage counters from the underlying client.
        token_usage = getattr(self.oneapi_generator, 'last_token_usage', {})
        total_usage = getattr(self.oneapi_generator, 'total_token_usage', {})
        # Print token info to the console for visibility.
        if token_usage:
            print(f"[?] LLM调用Token消耗: 输入{token_usage.get('prompt_tokens', 0)}, "
                  f"输出{token_usage.get('completion_tokens', 0)}, "
                  f"总计{token_usage.get('total_tokens', 0)} "
                  f"[累计: {total_usage.get('total_tokens', 0)}]")
        result = LLMResult(
            generations=[[Generation(text=response_text)]],
            llm_output={
                "model_name": self.oneapi_generator.model_name,
                "token_usage": {
                    "current_call": token_usage,
                    "cumulative_total": total_usage,
                    "prompt_tokens": token_usage.get('prompt_tokens', 0),
                    "completion_tokens": token_usage.get('completion_tokens', 0),
                    "total_tokens": token_usage.get('total_tokens', 0)
                }
            }
        )
        return result
    # NOTE(review): LangChain declares `_llm_type` as a @property on the base
    # class; here it is a plain method — confirm against the langchain_core
    # version in use whether the decorator should be added.
    def _llm_type(self) -> str:
        return "oneapi_llm"
    def invoke(
        self,
        input: Any,
        config: Optional[dict] = None,
        **kwargs: Any
    ) -> Any:
        """
        Unified entry point supporting streaming and non-streaming calls.

        Args:
            input: Input text.
            config: Config dict; may carry a stream_callback in metadata.
            **kwargs: Extra generation parameters.

        Returns:
            Generated text or an LLMResult object.
        """
        # Delegate directly to the underlying DashScopeLLM's invoke.
        return self.oneapi_generator.invoke(input, config=config, **kwargs)
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        # Parameters LangChain uses to identify/cache this LLM instance.
        return {
            "model_name": self.oneapi_generator.model_name,
            "api_url": getattr(self.oneapi_generator, 'api_url', 'dashscope')
        }
    def stream(self, prompt: str, **kwargs):
        """
        Stream generated text.

        Args:
            prompt: Input prompt.
            **kwargs: Extra generation parameters.

        Yields:
            Token strings (delegated to the underlying DashScope stream).
        """
        # Delegate to the underlying DashScope streaming method.
        return self.oneapi_generator.stream(prompt, **kwargs)
class SufficiencyCheckParser(BaseOutputParser[SufficiencyCheckResult]):
    """Parse the sufficiency-check LLM output into a structured result."""

    def parse(self, text: str) -> SufficiencyCheckResult:
        """Prefer strict JSON; fall back to keyword heuristics on failure."""
        cleaned = self._strip_code_fence(text)
        try:
            return SufficiencyCheckResult(**json.loads(cleaned))
        except (json.JSONDecodeError, ValueError):
            # Not valid JSON (or failed model validation): use heuristics.
            return self._rule_based_parse(text)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding markdown ``` / ```json fence, if present."""
        cleaned = text.strip()
        if cleaned.startswith('```json'):
            cleaned = cleaned[7:]
        elif cleaned.startswith('```'):
            cleaned = cleaned[3:]
        if cleaned.endswith('```'):
            cleaned = cleaned[:-3]
        return cleaned.strip()

    def _rule_based_parse(self, text: str) -> SufficiencyCheckResult:
        """Heuristic fallback when the output is not valid JSON."""
        lowered = text.lower().strip()
        positive = ("充分", "足够", "sufficient", "enough", "adequate")
        negative = ("不充分", "不足够", "insufficient", "not enough", "inadequate")
        # Sufficient only if a positive keyword appears without any negative one.
        sufficient = any(k in lowered for k in positive) and not any(
            k in lowered for k in negative
        )
        # When insufficient, treat question-mark lines as candidate sub-queries.
        sub_queries = None
        if not sufficient:
            candidates = [
                line.strip()
                for line in text.split('\n')
                if '?' in line and len(line.strip()) > 10
            ]
            sub_queries = candidates[:2] if candidates else None
        summary = text[:200] if len(text) <= 200 else text[:200] + "..."
        return SufficiencyCheckResult(
            is_sufficient=sufficient,
            confidence=0.7,  # fixed default; the heuristic has no real signal
            reason=summary,
            sub_queries=sub_queries,
        )
class QueryComplexityParser(BaseOutputParser[QueryComplexityResult]):
    """Parse the query-complexity LLM output into a structured result."""

    def parse(self, text: str) -> QueryComplexityResult:
        """Prefer strict JSON; fall back to keyword heuristics on failure."""
        body = text.strip()
        # Drop a surrounding markdown code fence if present.
        if body.startswith('```json'):
            body = body[7:]
        elif body.startswith('```'):
            body = body[3:]
        if body.endswith('```'):
            body = body[:-3]
        body = body.strip()
        try:
            return QueryComplexityResult(**json.loads(body))
        except (json.JSONDecodeError, ValueError):
            return self._rule_based_parse(text)

    def _rule_based_parse(self, text: str) -> QueryComplexityResult:
        """Keyword-based fallback classification."""
        lowered = text.lower().strip()
        strong_signals = ["复杂", "complex", "推理", "多步", "关联", "综合", "分析"]
        weak_signals = ["需要", "require", "关系", "连接", "因果", "比较"]
        simple_signals = ["简单", "simple", "直接", "单一", "基础"]
        complex_query = any(k in lowered for k in strong_signals) or any(
            k in lowered for k in weak_signals
        )
        # An explicit "simple" marker overrides all complexity signals.
        if any(k in lowered for k in simple_signals):
            complex_query = False
        summary = text[:200] if len(text) <= 200 else text[:200] + "..."
        return QueryComplexityResult(
            is_complex=complex_query,
            complexity_level="complex" if complex_query else "simple",
            confidence=0.8,  # fixed default; heuristic has no calibrated score
            reason=summary,
        )
@dataclass
class ConceptExplorationResult:
    """LLM's initial choice of which concept to explore in the graph."""
    selected_concept: str  # the concept chosen for exploration
    exploration_reason: str  # justification for the choice
    confidence: float  # confidence in the choice, 0-1
    expected_knowledge: str  # knowledge expected to be gained
@dataclass
class ConceptContinueResult:
    """LLM's choice of the next node/relation when continuing exploration."""
    selected_node: str  # next graph node to visit
    selected_relation: str  # relation to follow to that node
    exploration_reason: str  # justification for the choice
    confidence: float  # confidence in the choice, 0-1
    expected_knowledge: str  # knowledge expected to be gained
@dataclass
class ConceptSufficiencyResult:
    """Verdict on whether concept exploration has gathered enough knowledge."""
    is_sufficient: bool  # True when no further exploration is needed
    confidence: float  # confidence in the verdict, 0-1
    reason: str  # free-text justification
    missing_aspects: List[str]  # aspects still uncovered
    coverage_score: float  # estimated coverage of the query's needs
@dataclass
class ConceptSubQueryResult:
    """Sub-query generated from concept-exploration findings."""
    sub_query: str  # the generated follow-up query
    focus_aspects: List[str]  # aspects the sub-query targets
    expected_improvements: str  # what the sub-query should improve
    confidence: float  # confidence in the sub-query, 0-1
class ConceptExplorationParser:
    """Parser for the initial concept-selection LLM response."""

    def parse(self, text: str) -> ConceptExplorationResult:
        """Parse *text* (expected JSON) into a ConceptExplorationResult.

        Falls back to a low-confidence empty result when parsing fails.
        """
        try:
            payload = json.loads(text.strip())
            concept = payload.get("selected_concept", "")
            reason = payload.get("exploration_reason", "")
            score = float(payload.get("confidence", 0.5))
            knowledge = payload.get("expected_knowledge", "")
        except (json.JSONDecodeError, ValueError, KeyError) as e:
            # Degrade gracefully: empty selection with very low confidence.
            return ConceptExplorationResult(
                selected_concept="",
                exploration_reason=f"解析失败: {str(e)}",
                confidence=0.1,
                expected_knowledge="",
            )
        return ConceptExplorationResult(
            selected_concept=concept,
            exploration_reason=reason,
            confidence=score,
            expected_knowledge=knowledge,
        )
class ConceptContinueParser:
    """Parser for the continue-exploration LLM response."""

    def parse(self, text: str) -> ConceptContinueResult:
        """Parse *text* (expected JSON) into a ConceptContinueResult.

        Falls back to a low-confidence empty result when parsing fails.
        """
        try:
            payload = json.loads(text.strip())
            node = payload.get("selected_node", "")
            relation = payload.get("selected_relation", "")
            reason = payload.get("exploration_reason", "")
            score = float(payload.get("confidence", 0.5))
            knowledge = payload.get("expected_knowledge", "")
        except (json.JSONDecodeError, ValueError, KeyError) as e:
            # Degrade gracefully: empty selection with very low confidence.
            return ConceptContinueResult(
                selected_node="",
                selected_relation="",
                exploration_reason=f"解析失败: {str(e)}",
                confidence=0.1,
                expected_knowledge="",
            )
        return ConceptContinueResult(
            selected_node=node,
            selected_relation=relation,
            exploration_reason=reason,
            confidence=score,
            expected_knowledge=knowledge,
        )
class ConceptSufficiencyParser:
    """Parser for the concept-exploration sufficiency-check LLM response."""

    def parse(self, text: str) -> ConceptSufficiencyResult:
        """Parse *text* (expected JSON) into a ConceptSufficiencyResult.

        Falls back to an "insufficient" low-confidence result on failure.
        """
        try:
            payload = json.loads(text.strip())
            sufficient = bool(payload.get("is_sufficient", False))
            score = float(payload.get("confidence", 0.5))
            reason = payload.get("reason", "")
            missing = payload.get("missing_aspects", [])
            coverage = float(payload.get("coverage_score", 0.0))
        except (json.JSONDecodeError, ValueError, KeyError) as e:
            # Degrade gracefully: report "not sufficient" so exploration
            # continues rather than stopping on a parse error.
            return ConceptSufficiencyResult(
                is_sufficient=False,
                confidence=0.1,
                reason=f"解析失败: {str(e)}",
                missing_aspects=["解析错误"],
                coverage_score=0.0,
            )
        return ConceptSufficiencyResult(
            is_sufficient=sufficient,
            confidence=score,
            reason=reason,
            missing_aspects=missing,
            coverage_score=coverage,
        )
class ConceptSubQueryParser:
    """Parses the LLM response that proposes a follow-up sub-query.

    Expects a JSON object; malformed output yields a low-confidence
    empty result instead of raising.
    """

    def parse(self, text: str) -> ConceptSubQueryResult:
        """Parse *text* into a ConceptSubQueryResult (never raises)."""
        try:
            payload = json.loads(text.strip())
            result = ConceptSubQueryResult(
                sub_query=payload.get("sub_query", ""),
                focus_aspects=payload.get("focus_aspects", []),
                expected_improvements=payload.get("expected_improvements", ""),
                confidence=float(payload.get("confidence", 0.5)),
            )
        except (json.JSONDecodeError, ValueError, KeyError) as err:
            # Malformed LLM output: return an empty, low-confidence result.
            result = ConceptSubQueryResult(
                sub_query="",
                focus_aspects=["解析错误"],
                expected_improvements=f"解析失败: {err}",
                confidence=0.1,
            )
        return result
# Shared prompt loader; all templates below are fetched from it so prompt
# text can be edited in configuration without touching this module.
_prompt_loader = get_prompt_loader()

# Dynamically loaded prompt templates.
# NOTE: the original inline PromptTemplate definitions that were kept here as
# commented-out "backups" have been removed as dead code — the authoritative
# template text lives in the prompt configuration files, and the previous
# inline versions remain available in version control history.
QUERY_COMPLEXITY_CHECK_PROMPT = _prompt_loader.get_prompt_template('query_complexity_check')
SUFFICIENCY_CHECK_PROMPT = _prompt_loader.get_prompt_template('sufficiency_check')
QUERY_DECOMPOSITION_PROMPT = _prompt_loader.get_prompt_template('query_decomposition')
SUB_QUERY_GENERATION_PROMPT = _prompt_loader.get_prompt_template('sub_query_generation')
SIMPLE_ANSWER_PROMPT = _prompt_loader.get_prompt_template('simple_answer')
FINAL_ANSWER_PROMPT = _prompt_loader.get_prompt_template('final_answer')
def create_oneapi_llm(
    oneapi_key: Optional[str] = None,
    oneapi_base_url: Optional[str] = None,
    model_name: Optional[str] = None
) -> OneAPILLM:
    """Convenience factory for a DashScope-backed LLM (legacy-compatible interface).

    Args:
        oneapi_key: Aliyun DashScope API key; falls back to the ONEAPI_KEY env var.
        oneapi_base_url: No longer used; retained only for backward compatibility.
        model_name: Model name; falls back to the ONEAPI_MODEL_MAX env var,
            then to "qwen-max".

    Returns:
        An OneAPILLM instance wrapping a DashScopeLLM.
    """
    resolved_key = oneapi_key or os.getenv("ONEAPI_KEY")
    resolved_model = model_name or os.getenv("ONEAPI_MODEL_MAX", "qwen-max")

    backend = DashScopeLLM(api_key=resolved_key, model_name=resolved_model)
    return OneAPILLM(backend)
def format_passages(passages: List[str]) -> str:
    """Render passages as a numbered list separated by blank lines.

    Returns the placeholder "无相关段落" when the list is empty.
    """
    if not passages:
        return "无相关段落"
    return "\n\n".join(
        f"段落{i}{passage}" for i, passage in enumerate(passages, 1)
    )
def format_mixed_passages(documents: List[Document]) -> str:
    """Render a mixed list of event/text documents as two labelled sections.

    Documents whose metadata ``node_type`` is "event" are listed first under
    the 【事件信息】 heading, followed by "text" documents under 【段落信息】.
    Returns the placeholder "无相关信息" when the list is empty.
    """
    if not documents:
        return "无相关信息"

    events = [d for d in documents if d.metadata.get('node_type') == 'event']
    texts = [d for d in documents if d.metadata.get('node_type') == 'text']

    parts = []
    if events:
        parts.append("【事件信息】")
        parts.extend(f"事件{i}{d.page_content}" for i, d in enumerate(events, 1))
    if texts:
        if parts:
            # Blank entry separates the two sections in the joined output.
            parts.append("")
        parts.append("【段落信息】")
        parts.extend(f"段落{i}{d.page_content}" for i, d in enumerate(texts, 1))
    return "\n\n".join(parts)
def format_sub_queries(sub_queries: List[str]) -> str:
    """Concatenate sub-queries into one string; "无子查询" when the list is empty."""
    return "".join(sub_queries) if sub_queries else "无子查询"
def format_triplets(triplets: List[Dict[str, str]]) -> str:
    """Render triplets as numbered '(source) --[relation]--> (target)' lines.

    Missing keys fall back to 'unknown'; an empty list yields "无三元组信息".
    """
    if not triplets:
        return "无三元组信息"
    lines = []
    for idx, triplet in enumerate(triplets, 1):
        src = triplet.get('source', 'unknown')
        rel = triplet.get('relation', 'unknown')
        tgt = triplet.get('target', 'unknown')
        lines.append(f"{idx}. ({src}) --[{rel}]--> ({tgt})")
    return "\n".join(lines)
def format_exploration_path(path: List[Dict[str, str]]) -> str:
    """Render an exploration path as '(A) --[r]--> (B) -> (B) --[s]--> (C)'.

    Each step's 'reason' field is intentionally ignored: the previous
    implementation appended reason lines to the list and then filtered
    them back out in the return expression, so they never reached the
    output. That dead code is removed; the rendered string is unchanged.

    Args:
        path: Steps with 'source'/'relation'/'target' keys (missing keys
            fall back to 'unknown'); 'reason' is not rendered.

    Returns:
        The arrow-joined path, or "无探索路径" for an empty path.
    """
    if not path:
        return "无探索路径"
    steps = []
    for step in path:
        source = step.get('source', 'unknown')
        relation = step.get('relation', 'unknown')
        target = step.get('target', 'unknown')
        steps.append(f"({source}) --[{relation}]--> ({target})")
    return " -> ".join(steps)
# Prompt templates for knowledge-graph concept exploration.
# INIT: given a Node, its connected Concept neighbours, the user query and the
# reason the last sufficiency check failed, pick the single most promising
# concept to explore. The LLM must reply with JSON only (parsed by
# ConceptExplorationParser).
CONCEPT_EXPLORATION_INIT_PROMPT = PromptTemplate(
    input_variables=["node_name", "connected_concepts", "user_query", "insufficiency_reason"],
    template="""
作为一个知识图谱概念探索专家你需要分析一个Node节点及其连接的Concept节点判断出最值得探索的概念方向。
**探索起点Node节点**: {node_name}
**该Node连接的Concept节点列表**:
{connected_concepts}
**用户原始查询**: {user_query}
**上次充分性检查不通过的原因**: {insufficiency_reason}
**任务**: 根据用户查询和不充分的原因从连接的Concept列表中选择一个最值得深入探索的概念这个概念应该最有可能帮助回答用户查询或补充缺失的信息。
**判断标准**:
1. 与用户查询的相关性最高
2. 最有可能补充当前缺失的关键信息
3. 具有较强的延展性,能够引出更多相关知识
请以JSON格式返回结果
{{
"selected_concept": "最值得探索的概念名称",
"exploration_reason": "选择这个概念的详细原因",
"confidence": 0.9, // 置信度0-1之间
"expected_knowledge": "期望从这个概念探索中获得什么知识"
}}
只返回JSON格式的结果不要添加其他内容。
"""
)
# CONTINUE: given the current node, its neighbour triplets and the path
# explored so far, pick the next node/relation to follow without revisiting
# earlier nodes. JSON-only reply, parsed by ConceptContinueParser.
CONCEPT_EXPLORATION_CONTINUE_PROMPT = PromptTemplate(
    input_variables=["current_node", "neighbor_triplets", "exploration_path", "user_query"],
    template="""
作为一个知识图谱概念探索专家,你正在进行概念探索,需要决定下一步的探索方向。
**当前节点**: {current_node}
**当前节点的邻居三元组**:
{neighbor_triplets}
**已有探索路径**:
{exploration_path}
**用户原始查询**: {user_query}
**任务**: 从当前节点的邻居中选择一个最值得继续探索的节点,确保:
1. 不重复之前探索过的节点
2. 选择的节点与用户查询最相关
3. 能够获得更深入的知识
请以JSON格式返回结果
{{
"selected_node": "选择的下一个节点名称",
"selected_relation": "到达该节点的关系",
"exploration_reason": "选择这个节点的详细原因",
"confidence": 0.9, // 置信度0-1之间
"expected_knowledge": "期望从这个节点探索中获得什么知识"
}}
只返回JSON格式的结果不要添加其他内容。
"""
)
# SUFFICIENCY: evaluate whether retrieved passages plus exploration knowledge
# can fully answer the user query; if not, list the missing aspects and a
# coverage score. JSON-only reply, parsed by ConceptSufficiencyParser.
CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT = PromptTemplate(
    input_variables=["user_query", "all_passages", "exploration_knowledge", "insufficiency_reason"],
    template="""
作为一个智能信息充分性评估专家,请综合评估当前信息是否足够回答用户查询。
**用户查询**: {user_query}
**现有检索信息**:
{all_passages}
**概念探索获得的知识**:
{exploration_knowledge}
**上次不充分的原因**: {insufficiency_reason}
**评估任务**:
1. 综合分析检索信息(事件信息+段落信息)和探索知识是否能够完整回答用户查询
2. 判断是否还有关键信息缺失
3. 如果仍然不充分,明确指出还需要什么信息
**评估标准**:
- 信息的完整性:能否覆盖查询的所有方面
- 信息的准确性:提供的信息是否准确可靠
- 信息的相关性:信息与查询的匹配度
- 逻辑连贯性:信息间是否形成完整的逻辑链
请以JSON格式返回评估结果
{{
"is_sufficient": false, // true表示信息充分false表示不充分
"confidence": 0.85, // 判断的置信度0-1之间
"reason": "详细说明充分或不充分的原因",
"missing_aspects": ["缺失的关键信息1", "缺失的关键信息2"], // 如果不充分,列出缺失的关键方面
"coverage_score": 0.7 // 当前信息对查询的覆盖程度0-1之间
}}
只返回JSON格式的结果不要添加其他内容。
"""
)
# SUB_QUERY: from the first exploration round's results and the still-missing
# aspects, generate one focused sub-query to drive a second round.
# JSON-only reply, parsed by ConceptSubQueryParser.
CONCEPT_EXPLORATION_SUB_QUERY_PROMPT = PromptTemplate(
    input_variables=["user_query", "missing_aspects", "exploration_results"],
    template="""
作为一个智能查询分析专家,根据概念探索的结果和仍然缺失的信息,生成新的子查询来进行第二轮概念探索。
**原始用户查询**: {user_query}
**仍然缺失的关键信息**: {missing_aspects}
**第一轮探索结果**: {exploration_results}
**任务**: 生成一个更精确的子查询,这个查询应该:
1. 专门针对缺失的关键信息
2. 比原始查询更具体和聚焦
3. 能够引导第二轮探索找到缺失的关键信息
**生成原则**:
- 保持与原始查询的相关性
- 专注于最关键的缺失方面
- 使用明确、具体的表述
- 避免过于宽泛或模糊
请以JSON格式返回结果
{{
"sub_query": "针对缺失信息生成的新查询",
"focus_aspects": ["查询重点关注的方面1", "方面2"],
"expected_improvements": "期望这个子查询如何改善信息充分性",
"confidence": 0.9
}}
只返回JSON格式的结果不要添加其他内容。
"""
)
# Public API of this module: LLM wrapper/factory, structured-output parsers,
# result dataclasses, prompt templates, and string-formatting helpers.
__all__ = [
    'OneAPILLM',
    'SufficiencyCheckParser',
    'QueryComplexityParser',
    'ConceptExplorationParser',
    'ConceptContinueParser',
    'ConceptSufficiencyParser',
    'ConceptSubQueryParser',
    'QueryComplexityResult',
    'SufficiencyCheckResult',
    'ConceptExplorationResult',
    'ConceptContinueResult',
    'ConceptSufficiencyResult',
    'ConceptSubQueryResult',
    'QUERY_COMPLEXITY_CHECK_PROMPT',
    'SUFFICIENCY_CHECK_PROMPT',
    'QUERY_DECOMPOSITION_PROMPT',
    'SUB_QUERY_GENERATION_PROMPT',
    'SIMPLE_ANSWER_PROMPT',
    'FINAL_ANSWER_PROMPT',
    'CONCEPT_EXPLORATION_INIT_PROMPT',
    'CONCEPT_EXPLORATION_CONTINUE_PROMPT',
    'CONCEPT_EXPLORATION_SUFFICIENCY_PROMPT',
    'CONCEPT_EXPLORATION_SUB_QUERY_PROMPT',
    'create_oneapi_llm',
    'format_passages',
    'format_mixed_passages',
    'format_sub_queries',
    'format_triplets',
    'format_exploration_path'
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,107 @@
"""
LangGraph条件边路由函数
专门定义工作流的路由决策逻辑,与节点逻辑分离
"""
from typing import Literal
from retriver.langgraph.graph_state import QueryState
def route_by_debug_mode(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Choose the next node based on query complexity; debug mode is only logged.

    Args:
        state: Current query state.

    Returns:
        "initial_retrieval" for complex queries (existing hipporag2 pipeline),
        "simple_vector_retrieval" for simple ones (direct vector retrieval).
    """
    if not state['is_complex_query']:
        print(f"[RELOAD] 路由到简单向量检索 (debug_mode: {state['debug_mode']})")
        return "simple_vector_retrieval"

    print(f"[RELOAD] 路由到复杂检索逻辑 (debug_mode: {state['debug_mode']})")
    return "initial_retrieval"
def route_by_complexity(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Choose the next node based solely on query complexity.

    Args:
        state: Current query state.

    Returns:
        "initial_retrieval" for complex queries (HippoRAG2 pipeline),
        "simple_vector_retrieval" for simple ones (direct vector retrieval).
    """
    if not state['is_complex_query']:
        print(f"[RELOAD] 简单查询,进入直接向量检索")
        return "simple_vector_retrieval"

    print(f"[RELOAD] 复杂查询进入HippoRAG2检索逻辑")
    return "initial_retrieval"
def should_continue_retrieval(state: QueryState) -> Literal["final_answer", "parallel_retrieval", "next_iteration"]:
    """Decide the next step of the iterative retrieval workflow.

    Decision order (first match wins):
      1. Information is sufficient                         -> "final_answer"
      2. Max iterations reached while still insufficient   -> "final_answer"
      3. 3+ consecutive insufficient rounds and no pending
         sub-queries (dead-loop guard)                     -> "final_answer"
      4. Pending sub-queries remain                        -> "parallel_retrieval"
      5. Insufficient with iterations left                 -> "parallel_retrieval"
         (routes through sub_query_generation first)
      6. Anything else (should be rare)                    -> "next_iteration"

    Args:
        state: Current query state.

    Returns:
        The name of the next workflow node.
    """
    if state['is_sufficient']:
        print(f"[OK] 信息充分,生成最终答案")
        return "final_answer"

    if state['current_iteration'] >= state['max_iterations']:
        print(f"[RELOAD] 达到最大迭代次数 ({state['max_iterations']}) 且信息不充分,直接生成最终答案")
        return "final_answer"

    # Count how many of the most recent iterations were insufficient in a row.
    consecutive_insufficient = 0
    for record in reversed(state.get('iteration_history', [])):
        if record.get('is_sufficient', False):
            break
        consecutive_insufficient += 1

    if consecutive_insufficient >= 3 and not state['current_sub_queries']:
        print(f"[WARNING] 连续{consecutive_insufficient}次检索不充分且无新子查询,避免死循环,生成最终答案")
        return "final_answer"

    if state['current_sub_queries']:
        print(f"[RELOAD] 执行并行检索子查询 (剩余子查询: {len(state['current_sub_queries'])})")
        return "parallel_retrieval"

    if not state['is_sufficient'] and state['current_iteration'] < state['max_iterations']:
        print(f"[RELOAD] 信息不充分,生成子查询进行并行检索")
        return "parallel_retrieval"  # routes to sub_query_generation → parallel_retrieval

    print(f"[RELOAD] 开始下一轮迭代")
    return "next_iteration"