AIEC-RAG/retriver/langgraph/graph_nodes.py

"""
LangGraph工作流节点
实现具体的工作流节点逻辑
"""
import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
# LangSmith automatically traces LangChain LLM calls
from retriver.langgraph.graph_state import (
QueryState,
RetrievalResult,
SufficiencyCheck,
update_state_with_retrieval,
update_state_with_sufficiency_check,
increment_iteration,
finalize_state
)
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever
from retriver.langgraph.langchain_components import (
OneAPILLM,
SufficiencyCheckParser,
QueryComplexityParser,
QUERY_COMPLEXITY_CHECK_PROMPT,
SUFFICIENCY_CHECK_PROMPT,
QUERY_DECOMPOSITION_PROMPT,
SUB_QUERY_GENERATION_PROMPT,
SIMPLE_ANSWER_PROMPT,
FINAL_ANSWER_PROMPT,
format_passages,
format_mixed_passages,
format_sub_queries
)
from retriver.langgraph.es_vector_retriever import ESVectorRetriever
from prompt_loader import get_prompt_loader
class GraphNodes:
"""工作流节点实现类"""
def __init__(
self,
retriever: LangChainHippoRAGRetriever,
llm: OneAPILLM,
keyword: str,
max_parallel_retrievals: int = 2,
simple_retrieval_top_k: int = 3,
complexity_llm: Optional[OneAPILLM] = None,
sufficiency_llm: Optional[OneAPILLM] = None,
skip_llm_generation: bool = False
):
"""
初始化节点处理器
Args:
retriever: HippoRAG检索器
llm: OneAPI LLM (用于生成答案)
keyword: ES索引关键词
max_parallel_retrievals: 最大并行检索数
simple_retrieval_top_k: 简单检索返回文档数
complexity_llm: 复杂度判断专用LLM如果不指定则使用llm
sufficiency_llm: 充分性检查专用LLM如果不指定则使用llm
skip_llm_generation: 是否跳过LLM生成答案仅返回检索结果
"""
self.retriever = retriever
self.llm = llm
self.complexity_llm = complexity_llm or llm  # fall back to the main LLM if not provided
self.sufficiency_llm = sufficiency_llm or llm  # fall back to the main LLM if not provided
self.keyword = keyword
self.max_parallel_retrievals = max_parallel_retrievals
self.skip_llm_generation = skip_llm_generation
self.sufficiency_parser = SufficiencyCheckParser()
self.complexity_parser = QueryComplexityParser()
# Create the ES vector retriever used for simple queries
self.es_vector_retriever = ESVectorRetriever(
keyword=keyword,
top_k=simple_retrieval_top_k
)
# Get the prompt loader
self.prompt_loader = get_prompt_loader()
def _extract_response_text(self, response):
"""统一的响应文本提取方法"""
if hasattr(response, 'generations') and response.generations:
return response.generations[0][0].text
elif hasattr(response, 'content'):
return response.content
elif isinstance(response, dict) and 'response' in response:
return response['response']
else:
return str(response)
def query_complexity_check_node(self, state: QueryState) -> QueryState:
"""
查询复杂度判断节点
判断用户查询是否需要复杂的知识图谱推理
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 执行查询复杂度判断: {state['original_query']}")
# 检查是否跳过LLM
if self.prompt_loader.should_skip_llm('query_complexity_check'):
print(f"[NEXT] 跳过复杂度检查LLM调用默认为简单查询")
state['is_complex'] = False
state['complexity_level'] = 'simple'
return state
# Build the prompt
prompt = QUERY_COMPLEXITY_CHECK_PROMPT.format(
query=state['original_query']
)
try:
# Call the dedicated complexity-check LLM
response = self.complexity_llm.invoke(prompt)
response_text = self._extract_response_text(response)
# Parse the response
complexity_result = self.complexity_parser.parse(response_text)
# Update the state
state['query_complexity'] = {
"is_complex": complexity_result.is_complex,
"complexity_level": complexity_result.complexity_level,
"confidence": complexity_result.confidence,
"reason": complexity_result.reason
}
state['is_complex_query'] = complexity_result.is_complex
# Update debug info
state["debug_info"]["llm_calls"] += 1
print(f"[INFO] Complexity result: {'complex' if complexity_result.is_complex else 'simple'} "
f"(confidence: {complexity_result.confidence:.2f})")
print(f"Reason: {complexity_result.reason}")
return state
except Exception as e:
print(f"[ERROR] Complexity check failed: {e}")
# If the check fails, default to a complex query and follow the existing path
state['query_complexity'] = {
"is_complex": True,
"complexity_level": "complex",
"confidence": 0.5,
"reason": f"复杂度判断失败: {str(e)},默认使用复杂检索"
}
state['is_complex_query'] = True
return state
def debug_mode_node(self, state: QueryState) -> QueryState:
"""
调试模式节点
根据用户设置的debug_mode参数决定是否覆盖复杂度判断结果
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 执行调试模式检查: mode={state['debug_mode']}")
# 保存原始复杂度判断结果
original_is_complex = state['is_complex_query']
if state['debug_mode'] == 'simple':
print("[?] 调试模式: 强制使用简单检索路径")
# 强制设置为简单查询,无论原始复杂度判断结果如何
state['is_complex_query'] = False
# 保留原始复杂度判断结果,但添加调试信息
if 'debug_override' not in state['query_complexity']:
state['query_complexity']['debug_override'] = {
'original_complexity': original_is_complex,
'debug_mode': 'simple',
'override_reason': '调试模式强制使用简单检索路径'
}
elif state['debug_mode'] == 'complex':
print("[?] 调试模式: 强制使用复杂检索路径")
# 强制设置为复杂查询,无论原始复杂度判断结果如何
state['is_complex_query'] = True
# 保留原始复杂度判断结果,但添加调试信息
if 'debug_override' not in state['query_complexity']:
state['query_complexity']['debug_override'] = {
'original_complexity': original_is_complex,
'debug_mode': 'complex',
'override_reason': '调试模式强制使用复杂检索路径'
}
else:
# debug_mode == "0" or any other value: use the automatic complexity-check result
print("[?] Debug mode: using the automatic complexity-check result")
return state
def query_decomposition_node(self, state: QueryState) -> QueryState:
"""
查询分解节点
将用户原始查询分解为2个便于分头检索的子查询
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[TARGET] 执行查询分解: {state['original_query']}")
# 检查是否跳过LLM
if self.prompt_loader.should_skip_llm('query_decomposition'):
print(f"[NEXT] 跳过查询分解LLM调用使用原查询")
# 不分解,直接使用原查询
state['decomposed_sub_queries'] = []
state['sub_queries'] = [state['original_query']]
return state
# Build the prompt
prompt = QUERY_DECOMPOSITION_PROMPT.format(
original_query=state['original_query']
)
try:
# Call the LLM to generate sub-queries
response = self.llm.invoke(prompt)
response_text = self._extract_response_text(response)
print(f"[?] Raw LLM response: {response_text[:200]}...")
# Clean up the response and extract the JSON content
cleaned_text = response_text.strip()
# If the response contains an "答案:" marker, keep only what follows it
if "答案:" in cleaned_text or "答案:" in cleaned_text:
# Find the content after the "答案:" marker
answer_markers = ["答案:", "答案:"]
for marker in answer_markers:
if marker in cleaned_text:
cleaned_text = cleaned_text.split(marker, 1)[1].strip()
break
# Strip markdown code-fence markers
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]  # remove leading ```json
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]  # remove leading ```
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]  # remove trailing ```
cleaned_text = cleaned_text.strip()
# Try to extract the JSON part (from the first '{' to the last '}')
if '{' in cleaned_text and '}' in cleaned_text:
start_idx = cleaned_text.find('{')
end_idx = cleaned_text.rfind('}') + 1
cleaned_text = cleaned_text[start_idx:end_idx]
# Parse the sub-queries
try:
data = json.loads(cleaned_text)
raw_decomposed_queries = data.get('sub_queries', [])
# Handle sub-queries in different formats (kept consistent with sub_query_generation_node)
decomposed_sub_queries = []
for item in raw_decomposed_queries:
if isinstance(item, str):
decomposed_sub_queries.append(item)
elif isinstance(item, dict) and 'query' in item:
decomposed_sub_queries.append(item['query'])
elif isinstance(item, dict):
query_text = item.get('text', item.get('content', str(item)))
decomposed_sub_queries.append(query_text)
else:
decomposed_sub_queries.append(str(item))
# Make sure every sub-query is a string
decomposed_sub_queries = [str(query).strip() for query in decomposed_sub_queries if query]
print(f"[OK] JSON parsed successfully; sub-queries: {decomposed_sub_queries}")
except json.JSONDecodeError as e:
print(f"[ERROR] JSON parsing failed: {e}")
print(f"Falling back to rule-based extraction...")
# If JSON parsing fails, fall back to a simple rule-based extraction
lines = response_text.split('\n')
decomposed_sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
print(f"Rule-based extraction result: {decomposed_sub_queries}")
# Make sure we have sub-queries
if not decomposed_sub_queries:
print(f"[WARNING] The LLM produced no valid sub-queries; falling back to simple splitting...")
# Split on question marks as a simple fallback
original = state['original_query']
if '?' in original:
parts = [part.strip() for part in original.split('?') if part.strip()]
if len(parts) >= 2:
decomposed_sub_queries = [parts[0], parts[1]]
print(f"Split on the full-width question mark: {decomposed_sub_queries}")
elif '?' in original:
parts = [part.strip() for part in original.split('?') if part.strip()]
if len(parts) >= 2:
decomposed_sub_queries = [parts[0], parts[1]]
print(f"Split on the ASCII question mark: {decomposed_sub_queries}")
# If there is still nothing, use the original query
if not decomposed_sub_queries:
print(f"[WARNING] Automatic decomposition failed; using the original query as both sub-queries")
decomposed_sub_queries = [original, original]
elif len(decomposed_sub_queries) == 1:
print(f"[WARNING] Only 1 sub-query obtained; adding a second one")
# If there is only one sub-query, use the original query as the second
decomposed_sub_queries.append(state['original_query'])
# Keep at most 2 sub-queries
decomposed_sub_queries = decomposed_sub_queries[:2]
# Update the state - store the initially decomposed sub-queries
state['decomposed_sub_queries'] = decomposed_sub_queries
state['sub_queries'].extend(decomposed_sub_queries)
# Update debug info
state["debug_info"]["llm_calls"] += 1
print(f"[OK] Query decomposition finished: {decomposed_sub_queries}")
return state
except Exception as e:
print(f"[ERROR] Query decomposition failed: {e}")
# If generation fails, fall back to default sub-queries
default_sub_queries = [
state['original_query'] + " 详细信息",
state['original_query'] + " 相关内容"
]
state['decomposed_sub_queries'] = default_sub_queries
state['sub_queries'].extend(default_sub_queries)
return state
def simple_vector_retrieval_node(self, state: QueryState) -> QueryState:
"""
简单向量检索节点
直接与ES向量库中的文本段落进行向量匹配
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[SEARCH] 执行简单向量检索: {state['original_query']}")
try:
# 使用ES向量检索器检索相关文档
documents = self.es_vector_retriever.retrieve(state['original_query'])
# Extract passage contents and source information
passages = [doc.page_content for doc in documents]
sources = [f"简单检索-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
# Create the retrieval result
retrieval_result = RetrievalResult(
passages=passages,
documents=documents,
sources=sources,
query=state['original_query'],
iteration=state['current_iteration']
)
# Update the state
updated_state = update_state_with_retrieval(state, retrieval_result)
print(f"[OK] Simple vector retrieval finished with {len(passages)} passages")
return updated_state
except Exception as e:
print(f"[ERROR] Simple vector retrieval failed: {e}")
# If retrieval fails, return an empty result
retrieval_result = RetrievalResult(
passages=[],
documents=[],
sources=[],
query=state['original_query'],
iteration=state['current_iteration']
)
return update_state_with_retrieval(state, retrieval_result)
def simple_answer_generation_node(self, state: QueryState) -> QueryState:
"""
简单答案生成节点
基于简单检索的结果生成答案
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[NOTE] 生成简单查询答案")
# 格式化检索结果如果有all_documents则使用混合格式否则使用传统格式
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
# 检查是否跳过LLM生成全局配置或prompt级别配置
if self.skip_llm_generation or self.prompt_loader.should_skip_llm('simple_answer'):
if self.skip_llm_generation:
print(f"[SEARCH] 跳过LLM生成skip_llm_generation=true直接返回检索结果")
else:
print(f"[NEXT] 跳过简单答案LLM调用simple_answer.skip_llm=true直接返回检索结果")
final_answer = f"检索到的信息:\n{formatted_passages}"
state['final_answer'] = final_answer
return finalize_state(state, final_answer)
# Build the prompt
prompt = SIMPLE_ANSWER_PROMPT.format(
query=state['original_query'],
passages=formatted_passages
)
try:
# Call the LLM to generate the answer
response = self.llm.invoke(prompt)
final_answer = self._extract_response_text(response)
if not final_answer.strip():
final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"
# Finalize the state
updated_state = finalize_state(state, final_answer)
print(f"[OK] 简单答案生成完成 (长度: {len(final_answer)} 字符)")
return updated_state
except Exception as e:
print(f"[ERROR] Simple answer generation failed: {e}")
error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
return finalize_state(state, error_answer)
def initial_retrieval_node(self, state: QueryState) -> QueryState:
"""
并行初始检索节点
使用原始查询和2个分解的子查询并行进行检索
每个查询返回混合检索结果TOP-10个事件节点 + TOP-3个段落节点
Args:
state: 当前状态
Returns:
更新后的状态
"""
# 准备要检索的查询列表:原始查询 + 2个子查询
original_query = state['original_query']
sub_queries = state.get('decomposed_sub_queries', [])
all_queries = [original_query] + sub_queries[:2] # 确保最多3个查询
print(f"[SEARCH] 执行并行初始检索 - {len(all_queries)} 个查询")
for i, query in enumerate(all_queries):
query_type = "原始查询" if i == 0 else f"子查询{i}"
print(f" {query_type}: {query}")
def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, List, str]:
"""检索单个查询,返回文档、段落、源信息"""
import time
# 根据查询类型设置标签
if index == 0:
query_label = "原始查询"
else:
query_label = f"子查询{index}"
start_time = time.time()
print(f"[STARTING] {query_label} 开始检索 [{time.strftime('%H:%M:%S', time.localtime(start_time))}]")
try:
documents = self.retriever.invoke(query)
# The retriever now returns mixed results (events + passages); no longer capped here
top_documents = documents  # use all retrieved documents
passages = [doc.page_content for doc in top_documents]
# Build source labels by query type; supports mixed node types
sources = []
for doc in top_documents:
# Prefer node_id, then fall back to passage_id
doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id', 'unknown')
node_type = doc.metadata.get('node_type', 'unknown')
if index == 0:
sources.append(f"原始查询-{node_type}-{doc_id}")
else:
sources.append(f"子查询{index}-{node_type}-{doc_id}")
end_time = time.time()
duration = end_time - start_time
print(f"[OK] {query_label} 检索完成 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒,获得 {len(passages)} 个内容(事件+段落)")
return index, documents, passages, sources, query_label
except Exception as e:
end_time = time.time()
duration = end_time - start_time
print(f"[ERROR] {query_label} 检索失败 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒 - 错误: {e}")
return index, [], [], [], query_label
# Run the retrievals in parallel
all_documents = []
all_passages = []
all_sources = []
retrieval_details = {}
import time
parallel_start_time = time.time()
print(f"[FAST] 开始并行执行 {len(all_queries)} 个检索任务 [{time.strftime('%H:%M:%S', time.localtime(parallel_start_time))}]")
with ThreadPoolExecutor(max_workers=min(3, len(all_queries))) as executor:
# Submit the retrieval tasks
futures = {
executor.submit(retrieve_single_query, query, i): (query, i)
for i, query in enumerate(all_queries)
}
print(f"[?] 所有 {len(futures)} 个检索任务已提交到线程池")
# 收集结果
for future in as_completed(futures):
query, query_index = futures[future]
try:
index, documents, passages, sources, query_label = future.result()
# Record retrieval details
retrieval_details[query_label] = {
'query': query,
'passages_count': len(passages),
'documents_count': len(documents)
}
# Merge the results (deduplication happens later in one pass)
all_documents.extend(documents)
all_passages.extend(passages)
all_sources.extend(sources)
except Exception as e:
print(f"[ERROR] 查询 {query_index+1} 处理失败: {e}")
parallel_end_time = time.time()
parallel_duration = parallel_end_time - parallel_start_time
print(f"[TARGET] 并行检索全部完成 [{time.strftime('%H:%M:%S', time.localtime(parallel_end_time))}] - 总耗时: {parallel_duration:.2f}")
# 去重处理基于文档ID或内容去重
unique_documents = []
unique_passages = []
unique_sources = []
seen_passage_ids = set()
seen_content_hashes = set()
for i, (doc, passage, source) in enumerate(zip(all_documents, all_passages, all_sources)):
# Try to deduplicate by node_id or passage_id (supports mixed node types)
doc_id = None
if doc:
doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id')
if doc_id and doc_id in seen_passage_ids:
continue
# Use a content hash as a backup deduplication key
content_hash = hash(passage.strip())
if content_hash in seen_content_hashes:
continue
# Keep this entry in the deduplicated results
unique_documents.append(doc)
unique_passages.append(passage)
unique_sources.append(source)
if doc_id:
seen_passage_ids.add(doc_id)
seen_content_hashes.add(content_hash)
removed_count = len(all_passages) - len(unique_passages)
if removed_count > 0:
print(f"[SEARCH] 去重处理: 移除了 {removed_count} 个重复内容")
# 创建检索结果
query_description = f"并行检索: 原始查询 + {len(sub_queries)} 个子查询"
retrieval_result = RetrievalResult(
passages=unique_passages,
documents=unique_documents,
sources=unique_sources,
query=query_description,
iteration=state['current_iteration']
)
# Update the state
updated_state = update_state_with_retrieval(state, retrieval_result)
# Store the retrieval details in the state for later analysis
updated_state['initial_retrieval_details'] = retrieval_details
# Collect PageRank score info (full PageRank results per query)
for i, query in enumerate(all_queries):
try:
# Skip PageRank data collection to avoid shipping large payloads to LangSmith
# complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
# Only set a flag indicating the data is available; the actual data stays inside HippoRAG
updated_state['pagerank_data_available'] = True
except Exception as e:
print(f"[WARNING] Failed to collect PageRank scores for query {i+1}: {e}")
total_passages_before = len(all_passages)
total_passages_after = len(unique_passages)
print(f"[SUCCESS] 并行初始检索完成")
print(f" 检索前总内容: {total_passages_before}, 去重后: {total_passages_after}")
print(f" 原始查询: {retrieval_details.get('原始查询', {}).get('passages_count', 0)} 个内容(事件+段落)")
for i in range(1, len(all_queries)):
key = f"子查询{i}"
count = retrieval_details.get(key, {}).get('passages_count', 0)
print(f" 子查询{i}: {count} 个内容(事件+段落)")
return updated_state
def sufficiency_check_node(self, state: QueryState) -> QueryState:
"""
充分性检查节点
判断当前检索到的信息是否足够回答用户查询
包含对分解子查询的处理
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 执行充分性检查 (迭代 {state['current_iteration']})")
# 检查是否跳过LLM
if self.prompt_loader.should_skip_llm('sufficiency_check'):
print(f"[NEXT] 跳过充分性检查LLM调用默认为充分")
state['is_sufficient'] = True
state['sufficiency_confidence'] = 1.0
return state
# Format the retrieval results (use the mixed format when all_documents is present, otherwise the legacy format)
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
# Format the decomposed sub-queries
decomposed_sub_queries = state.get('decomposed_sub_queries', [])
formatted_decomposed_queries = format_sub_queries(decomposed_sub_queries) if decomposed_sub_queries else ""
# Build the prompt, including the decomposed sub-queries
prompt = SUFFICIENCY_CHECK_PROMPT.format(
query=state['original_query'],
passages=formatted_passages,
decomposed_sub_queries=formatted_decomposed_queries
)
# Call the dedicated sufficiency-check LLM
try:
response = self.sufficiency_llm.invoke(prompt)
response_text = self._extract_response_text(response)
# Parse the response
sufficiency_result = self.sufficiency_parser.parse(response_text)
# Create the sufficiency-check result
sufficiency_check = SufficiencyCheck(
is_sufficient=sufficiency_result.is_sufficient,
confidence=sufficiency_result.confidence,
reason=sufficiency_result.reason,
sub_queries=sufficiency_result.sub_queries
)
# Update the state
updated_state = update_state_with_sufficiency_check(state, sufficiency_check)
# Update debug info
updated_state["debug_info"]["llm_calls"] += 1
print(f"[INFO] Sufficiency result: {'sufficient' if sufficiency_result.is_sufficient else 'insufficient'} "
f"(confidence: {sufficiency_result.confidence:.2f})")
print(f" based on {len(state['all_passages'])} passages (from the original query and {len(decomposed_sub_queries)} sub-queries)")
if not sufficiency_result.is_sufficient and sufficiency_result.sub_queries:
print(f"[TARGET] New sub-queries generated: {sufficiency_result.sub_queries}")
return updated_state
except Exception as e:
print(f"[ERROR] Sufficiency check failed: {e}")
# If the check fails, assume insufficiency and generate a default sub-query
sufficiency_check = SufficiencyCheck(
is_sufficient=False,
confidence=0.5,
reason=f"充分性检查失败: {str(e)}",
sub_queries=[state['original_query'] + " 详细信息"]
)
return update_state_with_sufficiency_check(state, sufficiency_check)
def sub_query_generation_node(self, state: QueryState) -> QueryState:
"""
子查询生成节点
如果充分性检查不通过生成子查询
考虑之前已经生成的分解子查询
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[TARGET] 生成子查询")
# 检查是否跳过LLM
if self.prompt_loader.should_skip_llm('sub_query_generation'):
print(f"[NEXT] 跳过子查询生成LLM调用")
# 不生成新查询,继续下一轮
state['current_iteration'] += 1
return state
# 如果已经有子查询,直接返回
if state['current_sub_queries']:
print(f"[OK] 使用现有子查询: {state['current_sub_queries']}")
return state
# Format the existing retrieval results (use the mixed format when all_documents is present, otherwise the legacy format)
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
# Format all previously generated sub-queries (including the decomposed ones)
previous_sub_queries = state.get('sub_queries', [])
formatted_previous_queries = format_sub_queries(previous_sub_queries) if previous_sub_queries else ""
# Get the insufficiency reason from the sufficiency check
sufficiency_check = state.get('sufficiency_check', {})
insufficiency_reason = sufficiency_check.get('reason', '信息不够充分,需要更多相关信息')
# Build the prompt, including the sufficiency-check feedback
prompt = SUB_QUERY_GENERATION_PROMPT.format(
original_query=state['original_query'],
existing_passages=formatted_passages,
previous_sub_queries=formatted_previous_queries,
insufficiency_reason=insufficiency_reason
)
try:
# Call the LLM to generate sub-queries
response = self.llm.invoke(prompt)
response_text = self._extract_response_text(response)
# Update debug info
state["debug_info"]["llm_calls"] += 1
# Clean up the response and extract the JSON content
cleaned_text = response_text.strip()
# If the response contains an "答案:" marker, keep only what follows it
if "答案:" in cleaned_text or "答案:" in cleaned_text:
# Find the content after the "答案:" marker
answer_markers = ["答案:", "答案:"]
for marker in answer_markers:
if marker in cleaned_text:
cleaned_text = cleaned_text.split(marker, 1)[1].strip()
break
# Strip markdown code-fence markers
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:]  # remove leading ```json
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:]  # remove leading ```
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3]  # remove trailing ```
cleaned_text = cleaned_text.strip()
# Try to extract the JSON part (from the first '{' to the last '}')
if '{' in cleaned_text and '}' in cleaned_text:
start_idx = cleaned_text.find('{')
end_idx = cleaned_text.rfind('}') + 1
cleaned_text = cleaned_text[start_idx:end_idx]
# Parse the sub-queries
try:
data = json.loads(cleaned_text)
raw_sub_queries = data.get('sub_queries', [])
# Handle sub-queries in different formats
sub_queries = []
for item in raw_sub_queries:
if isinstance(item, str):
# Strings are used directly
sub_queries.append(item)
elif isinstance(item, dict) and 'query' in item:
# Dicts with a 'query' field: extract that field
sub_queries.append(item['query'])
elif isinstance(item, dict):
# Dicts without a 'query' field: try to locate the query text
query_text = item.get('text', item.get('content', str(item)))
sub_queries.append(query_text)
else:
# Anything else: convert to a string
sub_queries.append(str(item))
except json.JSONDecodeError:
# If JSON parsing fails, fall back to a simple rule-based extraction
lines = response_text.split('\n')
sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
# Make sure there is at least one sub-query
if not sub_queries:
sub_queries = [state['original_query'] + " 更多细节"]
# Ensure every sub-query is a string
sub_queries = [str(query).strip() for query in sub_queries if query]
# Update the state
state['current_sub_queries'] = sub_queries[:2]  # at most 2 sub-queries
state['sub_queries'].extend(state['current_sub_queries'])
print(f"[OK] Generated sub-queries: {state['current_sub_queries']}")
print(f" (avoiding duplication with previous sub-queries: {formatted_previous_queries})")
return state
except Exception as e:
print(f"[ERROR] Sub-query generation failed: {e}")
# If generation fails, fall back to a default sub-query
default_sub_query = state['original_query'] + " 补充信息"
state['current_sub_queries'] = [default_sub_query]
state['sub_queries'].append(default_sub_query)
return state
def parallel_retrieval_node(self, state: QueryState) -> QueryState:
"""
并行检索节点
使用子查询并行进行检索
Args:
state: 当前状态
Returns:
更新后的状态
"""
sub_queries = state['current_sub_queries']
if not sub_queries:
print("[WARNING] 没有子查询,跳过并行检索")
return state
print(f"[?] 并行检索 {len(sub_queries)} 个子查询")
def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, str]:
"""检索单个查询"""
try:
documents = self.retriever.invoke(query)
passages = [doc.page_content for doc in documents]
sources = [f"子查询{index+1}-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
return index, documents, passages, sources
except Exception as e:
print(f"[ERROR] 子查询 {index+1} 检索失败: {e}")
return index, [], [], []
# Run the retrievals in parallel
all_new_documents = []
all_new_passages = []
all_new_sources = []
with ThreadPoolExecutor(max_workers=self.max_parallel_retrievals) as executor:
# Submit the retrieval tasks
futures = {
executor.submit(retrieve_single_query, query, i): (query, i)
for i, query in enumerate(sub_queries)
}
# Collect the results
for future in as_completed(futures):
query, query_index = futures[future]
try:
index, documents, passages, sources = future.result()
if passages:  # only add non-empty results
all_new_documents.extend(documents)
all_new_passages.extend(passages)
all_new_sources.extend(sources)
print(f"[OK] 子查询 {index+1} 完成,获得 {len(passages)} 个段落")
else:
print(f"[WARNING] 子查询 {index+1} 无结果")
except Exception as e:
print(f"[ERROR] 子查询 {query_index+1} 处理失败: {e}")
# Update the state
if all_new_passages:
retrieval_result = RetrievalResult(
passages=all_new_passages,
documents=all_new_documents,
sources=all_new_sources,
query=f"并行检索: {', '.join(sub_queries)}",
iteration=state['current_iteration']
)
state = update_state_with_retrieval(state, retrieval_result)
print(f"[SUCCESS] 并行检索完成,总共获得 {len(all_new_passages)} 个新段落")
# 收集子查询的PageRank分数信息
for i, query in enumerate(sub_queries):
try:
# 跳过PageRank数据收集以避免LangSmith传输大量数据
# complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
# 仅设置标识表示数据可用实际数据在HippoRAG内部处理
state['pagerank_data_available'] = True
except Exception as e:
print(f"[WARNING] 收集并行子查询{i+1}的PageRank分数失败: {e}")
else:
print("[WARNING] 并行检索无有效结果")
# 清空当前子查询
state['current_sub_queries'] = []
return state
def final_answer_generation_node(self, state: QueryState) -> QueryState:
"""
最终答案生成节点
基于所有检索到的信息生成最终答案
Args:
state: 当前状态
Returns:
更新后的状态
"""
# 格式化所有检索结果如果有all_documents则使用混合格式否则使用传统格式
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
formatted_sub_queries = format_sub_queries(state['sub_queries'])
# 检查是否跳过LLM生成
if self.skip_llm_generation:
print(f"[SEARCH] 跳过LLM生成直接返回检索结果")
# Build the formatted retrieval results as the final answer
retrieval_summary = f"""【检索结果汇总】
查询问题:{state['original_query']}
检索到 {len(state['all_passages'])} 个相关段落
{formatted_passages}
"""
if state['sub_queries']:
retrieval_summary += f"""
相关子查询:
{formatted_sub_queries}
"""
retrieval_summary += f"""
检索统计:
- 查询复杂度:{state.get('query_complexity', 'unknown')}
- 是否复杂查询:{state.get('is_complex_query', False)}
- 迭代次数:{state.get('current_iteration', 0)}
- 信息充分性:{state.get('is_sufficient', False)}
"""
# Finalize the state
updated_state = finalize_state(state, retrieval_summary)
print(f"[OK] Retrieval results returned (length: {len(retrieval_summary)} characters)")
return updated_state
# Original LLM generation logic
print(f"[NOTE] Generating the final answer")
# Check whether the LLM call should be skipped
if self.prompt_loader.should_skip_llm('final_answer'):
print(f"[NEXT] Skipping final-answer LLM call; returning retrieval results directly")
final_answer = f"检索到的信息:\n{formatted_passages}"
state['final_answer'] = final_answer
return finalize_state(state, final_answer)
# Build the prompt
prompt = FINAL_ANSWER_PROMPT.format(
original_query=state['original_query'],
all_passages=formatted_passages,
sub_queries=formatted_sub_queries
)
try:
# Call the LLM to generate the final answer
response = self.llm.invoke(prompt)
final_answer = self._extract_response_text(response)
if not final_answer.strip():
final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"
# Finalize the state
updated_state = finalize_state(state, final_answer)
print(f"[OK] 最终答案生成完成 (长度: {len(final_answer)} 字符)")
return updated_state
except Exception as e:
print(f"[ERROR] Final answer generation failed: {e}")
error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
return finalize_state(state, error_answer)
def next_iteration_node(self, state: QueryState) -> QueryState:
"""
下一轮迭代节点
增加迭代计数并准备下一轮
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 进入迭代 {state['current_iteration'] + 1}")
# 增加迭代次数
updated_state = increment_iteration(state)
return updated_state
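
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal outline of how GraphNodes might be
# wired up and a node invoked. The constructor arguments shown for
# LangChainHippoRAGRetriever and OneAPILLM are placeholders/assumptions, not their
# actual signatures; in practice these nodes are registered in a LangGraph workflow
# and driven by it rather than called directly.
#
#   retriever = LangChainHippoRAGRetriever(...)  # configured HippoRAG retriever
#   llm = OneAPILLM(...)                         # configured OneAPI LLM client
#   nodes = GraphNodes(
#       retriever=retriever,
#       llm=llm,
#       keyword="my_es_index",                   # hypothetical ES index keyword
#       max_parallel_retrievals=2,
#       simple_retrieval_top_k=3,
#   )
#   # initial_state: a QueryState dict prepared via the graph_state helpers
#   state = nodes.query_complexity_check_node(initial_state)
#   state = nodes.debug_mode_node(state)
# ---------------------------------------------------------------------------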