# AIEC-RAG/retriver/langgraph/graph_nodes.py
"""
LangGraph工作流节点
实现具体的工作流节点逻辑
"""
import json
import asyncio
from typing import Dict, Any, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
# LangSmith automatically traces LangChain LLM calls
from retriver.langgraph.graph_state import (
QueryState,
RetrievalResult,
SufficiencyCheck,
update_state_with_retrieval,
update_state_with_sufficiency_check,
increment_iteration,
finalize_state
)
from retriver.langgraph.langchain_hipporag_retriever import LangChainHippoRAGRetriever
from retriver.langgraph.langchain_components import (
OneAPILLM,
SufficiencyCheckParser,
QueryComplexityParser,
QUERY_COMPLEXITY_CHECK_PROMPT,
SUFFICIENCY_CHECK_PROMPT,
QUERY_DECOMPOSITION_PROMPT,
SUB_QUERY_GENERATION_PROMPT,
SIMPLE_ANSWER_PROMPT,
FINAL_ANSWER_PROMPT,
format_passages,
format_mixed_passages,
format_sub_queries
)
from retriver.langgraph.es_vector_retriever import ESVectorRetriever
from prompt_loader import get_prompt_loader
class GraphNodes:
"""工作流节点实现类"""
def __init__(
self,
retriever: LangChainHippoRAGRetriever,
llm: OneAPILLM,
keyword: str,
max_parallel_retrievals: int = 2,
simple_retrieval_top_k: int = 3,
complexity_llm: Optional[OneAPILLM] = None,
sufficiency_llm: Optional[OneAPILLM] = None,
skip_llm_generation: bool = False
):
"""
初始化节点处理器
Args:
retriever: HippoRAG检索器
llm: OneAPI LLM (用于生成答案)
keyword: ES索引关键词
max_parallel_retrievals: 最大并行检索数
simple_retrieval_top_k: 简单检索返回文档数
complexity_llm: 复杂度判断专用LLM如果不指定则使用llm
sufficiency_llm: 充分性检查专用LLM如果不指定则使用llm
skip_llm_generation: 是否跳过LLM生成答案仅返回检索结果
"""
self.retriever = retriever
self.llm = llm
        self.complexity_llm = complexity_llm or llm  # fall back to the main LLM if not specified
        self.sufficiency_llm = sufficiency_llm or llm  # fall back to the main LLM if not specified
self.keyword = keyword
self.max_parallel_retrievals = max_parallel_retrievals
self.skip_llm_generation = skip_llm_generation
self.sufficiency_parser = SufficiencyCheckParser()
self.complexity_parser = QueryComplexityParser()
        # Create the ES vector retriever used for simple queries
self.es_vector_retriever = ESVectorRetriever(
keyword=keyword,
top_k=simple_retrieval_top_k
)
        # Get the prompt loader
self.prompt_loader = get_prompt_loader()
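    # --- Illustrative construction sketch (not part of the original file) ---
    # How a caller might instantiate this node handler. The retriever/LLM
    # constructor arguments below are hypothetical placeholders; their real
    # signatures are defined elsewhere in the project.
    #
    #   retriever = LangChainHippoRAGRetriever(...)   # hypothetical args
    #   answer_llm = OneAPILLM(...)                   # hypothetical args
    #   nodes = GraphNodes(
    #       retriever=retriever,
    #       llm=answer_llm,
    #       keyword="my_es_index",                    # assumed ES index keyword
    #       max_parallel_retrievals=2,
    #       simple_retrieval_top_k=3,
    #   )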
def _extract_response_text(self, response):
"""统一的响应文本提取方法"""
if hasattr(response, 'generations') and response.generations:
return response.generations[0][0].text
elif hasattr(response, 'content'):
return response.content
elif isinstance(response, dict) and 'response' in response:
return response['response']
else:
return str(response)
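    # Example of the response shapes handled above (sketch): an LLMResult-like
    # object exposing .generations[0][0].text, a message object exposing
    # .content, a dict such as {"response": "..."}, or any other object via str().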
def query_complexity_check_node(self, state: QueryState) -> QueryState:
"""
查询复杂度判断节点
判断用户查询是否需要复杂的知识图谱推理
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 执行查询复杂度判断: {state['original_query']}")
        # Check whether the LLM call should be skipped
if self.prompt_loader.should_skip_llm('query_complexity_check'):
print(f"[NEXT] 跳过复杂度检查LLM调用默认为简单查询")
state['is_complex'] = False
state['complexity_level'] = 'simple'
return state
        # Build the prompt
prompt = QUERY_COMPLEXITY_CHECK_PROMPT.format(
query=state['original_query']
)
try:
            # Call the dedicated complexity-check LLM
response = self.complexity_llm.invoke(prompt)
response_text = self._extract_response_text(response)
            # Parse the response
complexity_result = self.complexity_parser.parse(response_text)
            # Update the state
state['query_complexity'] = {
"is_complex": complexity_result.is_complex,
"complexity_level": complexity_result.complexity_level,
"confidence": complexity_result.confidence,
"reason": complexity_result.reason
}
state['is_complex_query'] = complexity_result.is_complex
            # Update debug info
state["debug_info"]["llm_calls"] += 1
print(f"[INFO] 复杂度判断结果: {'复杂' if complexity_result.is_complex else '简单'} "
f"(置信度: {complexity_result.confidence:.2f})")
print(f"理由: {complexity_result.reason}")
return state
except Exception as e:
print(f"[ERROR] 复杂度判断失败: {e}")
            # If the check fails, default to a complex query and follow the existing logic
state['query_complexity'] = {
"is_complex": True,
"complexity_level": "complex",
"confidence": 0.5,
"reason": f"复杂度判断失败: {str(e)},默认使用复杂检索"
}
state['is_complex_query'] = True
return state
def debug_mode_node(self, state: QueryState) -> QueryState:
"""
调试模式节点
根据用户设置的debug_mode参数决定是否覆盖复杂度判断结果
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 执行调试模式检查: mode={state['debug_mode']}")
        # Save the original complexity-check result
original_is_complex = state['is_complex_query']
if state['debug_mode'] == 'simple':
print("[?] 调试模式: 强制使用简单检索路径")
            # Force the query to be treated as simple, regardless of the original complexity result
state['is_complex_query'] = False
            # Keep the original complexity result but record debug info
if 'debug_override' not in state['query_complexity']:
state['query_complexity']['debug_override'] = {
'original_complexity': original_is_complex,
'debug_mode': 'simple',
'override_reason': '调试模式强制使用简单检索路径'
}
elif state['debug_mode'] == 'complex':
print("[?] 调试模式: 强制使用复杂检索路径")
            # Force the query to be treated as complex, regardless of the original complexity result
state['is_complex_query'] = True
            # Keep the original complexity result but record debug info
if 'debug_override' not in state['query_complexity']:
state['query_complexity']['debug_override'] = {
'original_complexity': original_is_complex,
'debug_mode': 'complex',
'override_reason': '调试模式强制使用复杂检索路径'
}
else:
            # debug_mode = "0" or any other value: use the original complexity-check result
print("[?] 调试模式: 使用自动复杂度判断结果")
return state
def query_decomposition_node(self, state: QueryState) -> QueryState:
"""
查询分解节点
将用户原始查询分解为2个便于分头检索的子查询
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[TARGET] 执行查询分解: {state['original_query']}")
        # Check whether the LLM call should be skipped
if self.prompt_loader.should_skip_llm('query_decomposition'):
print(f"[NEXT] 跳过查询分解LLM调用使用原查询")
            # Do not decompose; use the original query directly
state['decomposed_sub_queries'] = []
state['sub_queries'] = [state['original_query']]
return state
        # Build the prompt
prompt = QUERY_DECOMPOSITION_PROMPT.format(
original_query=state['original_query']
)
try:
            # Call the LLM to generate sub-queries
response = self.llm.invoke(prompt)
response_text = self._extract_response_text(response)
print(f"[?] LLM原始响应: {response_text[:200]}...")
            # Clean up and extract the JSON content
cleaned_text = response_text.strip()
            # If the response contains "答案:" ("Answer:"), extract what follows it
if "答案:" in cleaned_text or "答案:" in cleaned_text:
                # Find the content after the "答案:" marker
answer_markers = ["答案:", "答案:"]
for marker in answer_markers:
if marker in cleaned_text:
cleaned_text = cleaned_text.split(marker, 1)[1].strip()
break
            # Strip markdown code-block markers
if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # strip leading ```json
elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # strip leading ```
if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # strip trailing ```
cleaned_text = cleaned_text.strip()
            # Try to extract the JSON part (from the first '{' to the last '}')
if '{' in cleaned_text and '}' in cleaned_text:
start_idx = cleaned_text.find('{')
end_idx = cleaned_text.rfind('}') + 1
cleaned_text = cleaned_text[start_idx:end_idx]
            # Parse the sub-queries
try:
data = json.loads(cleaned_text)
raw_decomposed_queries = data.get('sub_queries', [])
                # Handle sub-queries in different formats, consistent with sub_query_generation_node
decomposed_sub_queries = []
for item in raw_decomposed_queries:
if isinstance(item, str):
decomposed_sub_queries.append(item)
elif isinstance(item, dict) and 'query' in item:
decomposed_sub_queries.append(item['query'])
elif isinstance(item, dict):
query_text = item.get('text', item.get('content', str(item)))
decomposed_sub_queries.append(query_text)
else:
decomposed_sub_queries.append(str(item))
                # Make sure every entry is a string
decomposed_sub_queries = [str(query).strip() for query in decomposed_sub_queries if query]
print(f"[OK] JSON解析成功获得子查询: {decomposed_sub_queries}")
except json.JSONDecodeError as e:
print(f"[ERROR] JSON解析失败: {e}")
print(f"尝试规则提取...")
                # If JSON parsing fails, fall back to simple rule-based extraction
lines = response_text.split('\n')
decomposed_sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
print(f"规则提取结果: {decomposed_sub_queries}")
            # Make sure there are sub-queries
if not decomposed_sub_queries:
print(f"[WARNING] LLM未生成有效子查询使用简单分解策略...")
                # Simple decomposition by punctuation
original = state['original_query']
                if '?' in original:
                    parts = [part.strip() for part in original.split('?') if part.strip()]
if len(parts) >= 2:
decomposed_sub_queries = [parts[0], parts[1]]
print(f"按中文问号分解: {decomposed_sub_queries}")
elif '?' in original:
parts = [part.strip() for part in original.split('?') if part.strip()]
if len(parts) >= 2:
decomposed_sub_queries = [parts[0], parts[1]]
print(f"按英文问号分解: {decomposed_sub_queries}")
                # If there are still none, use the original query
if not decomposed_sub_queries:
print(f"[WARNING] 无法自动分解,使用原查询作为两个子查询")
decomposed_sub_queries = [original, original]
elif len(decomposed_sub_queries) == 1:
print(f"[WARNING] 只获得1个子查询补充第二个")
                # If there is only one sub-query, use the original query as the second one
decomposed_sub_queries.append(state['original_query'])
            # Limit to 2 sub-queries
decomposed_sub_queries = decomposed_sub_queries[:2]
            # Update the state: store the initially decomposed sub-queries
state['decomposed_sub_queries'] = decomposed_sub_queries
state['sub_queries'].extend(decomposed_sub_queries)
            # Update debug info
state["debug_info"]["llm_calls"] += 1
print(f"[OK] 查询分解完成: {decomposed_sub_queries}")
return state
except Exception as e:
print(f"[ERROR] 查询分解失败: {e}")
            # If generation fails, use default sub-queries
default_sub_queries = [
state['original_query'] + " 详细信息",
state['original_query'] + " 相关内容"
]
state['decomposed_sub_queries'] = default_sub_queries
state['sub_queries'].extend(default_sub_queries)
return state
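    # --- Illustrative example (not part of the original file) ---
    # The parsing above tolerates LLM replies of roughly this shape (the exact
    # wording is dictated by QUERY_DECOMPOSITION_PROMPT, which is not shown here):
    #
    #   答案:
    #   ```json
    #   {"sub_queries": ["sub-question 1", "sub-question 2"]}
    #   ```
    #
    # The "答案:" prefix and the ```json fences are stripped, the first {...}
    # block is parsed, and each entry may be a plain string or a dict carrying
    # a "query"/"text"/"content" field.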
def simple_vector_retrieval_node(self, state: QueryState) -> QueryState:
"""
简单向量检索节点
直接与ES向量库中的文本段落进行向量匹配
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[SEARCH] 执行简单向量检索: {state['original_query']}")
try:
            # Retrieve relevant documents with the ES vector retriever
documents = self.es_vector_retriever.retrieve(state['original_query'])
            # Extract passage content and source info
passages = [doc.page_content for doc in documents]
sources = [f"简单检索-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
            # Create the retrieval result
retrieval_result = RetrievalResult(
passages=passages,
documents=documents,
sources=sources,
query=state['original_query'],
iteration=state['current_iteration']
)
            # Update the state
updated_state = update_state_with_retrieval(state, retrieval_result)
print(f"[OK] 简单向量检索完成,获得 {len(passages)} 个段落")
return updated_state
except Exception as e:
print(f"[ERROR] 简单向量检索失败: {e}")
            # If retrieval fails, return an empty result
retrieval_result = RetrievalResult(
passages=[],
documents=[],
sources=[],
query=state['original_query'],
iteration=state['current_iteration']
)
return update_state_with_retrieval(state, retrieval_result)
def simple_answer_generation_node(self, state: QueryState) -> QueryState:
"""
简单答案生成节点
基于简单检索的结果生成答案
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[NOTE] 生成简单查询答案")
        # Format retrieval results: use the mixed format if all_documents is present, otherwise the traditional format
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
        # Check whether LLM generation should be skipped (global config or prompt-level config)
if self.skip_llm_generation or self.prompt_loader.should_skip_llm('simple_answer'):
if self.skip_llm_generation:
print(f"[SEARCH] 跳过LLM生成skip_llm_generation=true直接返回检索结果")
else:
print(f"[NEXT] 跳过简单答案LLM调用simple_answer.skip_llm=true直接返回检索结果")
final_answer = f"检索到的信息:\n{formatted_passages}"
state['final_answer'] = final_answer
return finalize_state(state, final_answer)
        # Build the prompt
prompt = SIMPLE_ANSWER_PROMPT.format(
query=state['original_query'],
passages=formatted_passages
)
try:
            # Call the LLM to generate the answer
response = self.llm.invoke(prompt)
final_answer = self._extract_response_text(response)
if not final_answer.strip():
final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"
            # Finalize the state
updated_state = finalize_state(state, final_answer)
print(f"[OK] 简单答案生成完成 (长度: {len(final_answer)} 字符)")
return updated_state
except Exception as e:
print(f"[ERROR] 简单答案生成失败: {e}")
error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
return finalize_state(state, error_answer)
def initial_retrieval_node(self, state: QueryState) -> QueryState:
"""
并行初始检索节点
使用原始查询和2个分解的子查询并行进行检索
每个查询返回混合检索结果TOP-10个事件节点 + TOP-3个段落节点
Args:
state: 当前状态
Returns:
更新后的状态
"""
        # Prepare the list of queries to retrieve: the original query + 2 sub-queries
original_query = state['original_query']
sub_queries = state.get('decomposed_sub_queries', [])
        all_queries = [original_query] + sub_queries[:2]  # at most 3 queries in total
print(f"[SEARCH] 执行并行初始检索 - {len(all_queries)} 个查询")
for i, query in enumerate(all_queries):
query_type = "原始查询" if i == 0 else f"子查询{i}"
print(f" {query_type}: {query}")
def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, List, str]:
"""检索单个查询,返回文档、段落、源信息"""
import time
            # Set the label according to the query type
if index == 0:
query_label = "原始查询"
else:
query_label = f"子查询{index}"
start_time = time.time()
print(f"[STARTING] {query_label} 开始检索 [{time.strftime('%H:%M:%S', time.localtime(start_time))}]")
try:
documents = self.retriever.invoke(query)
                # The retriever now returns mixed results (events + passages); no count limit is applied
                top_documents = documents  # use all retrieved documents
passages = [doc.page_content for doc in top_documents]
                # Set the source identifier according to query type; supports mixed node types
sources = []
for doc in top_documents:
                    # Prefer node_id, then fall back to passage_id
doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id', 'unknown')
node_type = doc.metadata.get('node_type', 'unknown')
if index == 0:
sources.append(f"原始查询-{node_type}-{doc_id}")
else:
sources.append(f"子查询{index}-{node_type}-{doc_id}")
end_time = time.time()
duration = end_time - start_time
print(f"[OK] {query_label} 检索完成 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒,获得 {len(passages)} 个内容(事件+段落)")
return index, documents, passages, sources, query_label
except Exception as e:
end_time = time.time()
duration = end_time - start_time
print(f"[ERROR] {query_label} 检索失败 [{time.strftime('%H:%M:%S', time.localtime(end_time))}] - 耗时: {duration:.2f}秒 - 错误: {e}")
return index, [], [], [], query_label
        # Execute retrievals in parallel
all_documents = []
all_passages = []
all_sources = []
retrieval_details = {}
import time
parallel_start_time = time.time()
print(f"[FAST] 开始并行执行 {len(all_queries)} 个检索任务 [{time.strftime('%H:%M:%S', time.localtime(parallel_start_time))}]")
with ThreadPoolExecutor(max_workers=min(3, len(all_queries))) as executor:
            # Submit retrieval tasks
futures = {
executor.submit(retrieve_single_query, query, i): (query, i)
for i, query in enumerate(all_queries)
}
print(f"[?] 所有 {len(futures)} 个检索任务已提交到线程池")
            # Collect results
for future in as_completed(futures):
query, query_index = futures[future]
try:
index, documents, passages, sources, query_label = future.result()
                    # Record retrieval details
retrieval_details[query_label] = {
'query': query,
'passages_count': len(passages),
'documents_count': len(documents)
}
                    # Merge results (no deduplication yet; handled later in one pass)
all_documents.extend(documents)
all_passages.extend(passages)
all_sources.extend(sources)
except Exception as e:
print(f"[ERROR] 查询 {query_index+1} 处理失败: {e}")
parallel_end_time = time.time()
parallel_duration = parallel_end_time - parallel_start_time
print(f"[TARGET] 并行检索全部完成 [{time.strftime('%H:%M:%S', time.localtime(parallel_end_time))}] - 总耗时: {parallel_duration:.2f}")
        # Deduplicate (by document ID or by content)
unique_documents = []
unique_passages = []
unique_sources = []
seen_passage_ids = set()
seen_content_hashes = set()
for i, (doc, passage, source) in enumerate(zip(all_documents, all_passages, all_sources)):
            # Try to deduplicate by node_id or passage_id (supports mixed node types)
doc_id = None
if doc:
doc_id = doc.metadata.get('node_id') or doc.metadata.get('passage_id')
if doc_id and doc_id in seen_passage_ids:
continue
            # Use a content hash as a backup deduplication key
content_hash = hash(passage.strip())
if content_hash in seen_content_hashes:
continue
            # Add to the deduplicated results
unique_documents.append(doc)
unique_passages.append(passage)
unique_sources.append(source)
if doc_id:
seen_passage_ids.add(doc_id)
seen_content_hashes.add(content_hash)
removed_count = len(all_passages) - len(unique_passages)
if removed_count > 0:
print(f"[SEARCH] 去重处理: 移除了 {removed_count} 个重复内容")
        # Create the retrieval result
query_description = f"并行检索: 原始查询 + {len(sub_queries)} 个子查询"
retrieval_result = RetrievalResult(
passages=unique_passages,
documents=unique_documents,
sources=unique_sources,
query=query_description,
iteration=state['current_iteration']
)
        # Update the state
updated_state = update_state_with_retrieval(state, retrieval_result)
        # Store retrieval details in the state for later analysis
updated_state['initial_retrieval_details'] = retrieval_details
        # Collect PageRank score info (the complete PageRank result for each query)
for i, query in enumerate(all_queries):
try:
                # Skip PageRank data collection to avoid shipping large payloads to LangSmith
                # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                # Only set a flag indicating the data is available; the actual data is handled inside HippoRAG
updated_state['pagerank_data_available'] = True
except Exception as e:
print(f"[WARNING] 收集查询{i+1}的PageRank分数失败: {e}")
total_passages_before = len(all_passages)
total_passages_after = len(unique_passages)
print(f"[SUCCESS] 并行初始检索完成")
print(f" 检索前总内容: {total_passages_before}, 去重后: {total_passages_after}")
print(f" 原始查询: {retrieval_details.get('原始查询', {}).get('passages_count', 0)} 个内容(事件+段落)")
for i in range(1, len(all_queries)):
key = f"子查询{i}"
count = retrieval_details.get(key, {}).get('passages_count', 0)
print(f" 子查询{i}: {count} 个内容(事件+段落)")
return updated_state
def sufficiency_check_node(self, state: QueryState) -> QueryState:
"""
充分性检查节点
判断当前检索到的信息是否足够回答用户查询
包含对分解子查询的处理
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 执行充分性检查 (迭代 {state['current_iteration']})")
        # Check whether the LLM call should be skipped
if self.prompt_loader.should_skip_llm('sufficiency_check'):
print(f"[NEXT] 跳过充分性检查LLM调用默认为充分")
state['is_sufficient'] = True
state['sufficiency_confidence'] = 1.0
return state
        # Format retrieval results: use the mixed format if all_documents is present, otherwise the traditional format
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
        # Format the decomposed sub-queries
decomposed_sub_queries = state.get('decomposed_sub_queries', [])
formatted_decomposed_queries = format_sub_queries(decomposed_sub_queries) if decomposed_sub_queries else ""
        # Build the prompt, including the decomposed sub-query info
prompt = SUFFICIENCY_CHECK_PROMPT.format(
query=state['original_query'],
passages=formatted_passages,
decomposed_sub_queries=formatted_decomposed_queries
)
        # Call the dedicated sufficiency-check LLM
try:
response = self.sufficiency_llm.invoke(prompt)
response_text = self._extract_response_text(response)
            # Parse the response
sufficiency_result = self.sufficiency_parser.parse(response_text)
            # Create the sufficiency-check result
sufficiency_check = SufficiencyCheck(
is_sufficient=sufficiency_result.is_sufficient,
confidence=sufficiency_result.confidence,
reason=sufficiency_result.reason,
sub_queries=sufficiency_result.sub_queries
)
            # Update the state
updated_state = update_state_with_sufficiency_check(state, sufficiency_check)
            # Update debug info
updated_state["debug_info"]["llm_calls"] += 1
print(f"[INFO] 充分性检查结果: {'充分' if sufficiency_result.is_sufficient else '不充分'} "
f"(置信度: {sufficiency_result.confidence:.2f})")
print(f" 基于 {len(state['all_passages'])} 个段落 (来自原始查询和{len(decomposed_sub_queries)}个子查询)")
if not sufficiency_result.is_sufficient and sufficiency_result.sub_queries:
print(f"[TARGET] 生成新的子查询: {sufficiency_result.sub_queries}")
return updated_state
except Exception as e:
print(f"[ERROR] 充分性检查失败: {e}")
            # If the check fails, assume insufficiency and generate default sub-queries
sufficiency_check = SufficiencyCheck(
is_sufficient=False,
confidence=0.5,
reason=f"充分性检查失败: {str(e)}",
sub_queries=[state['original_query'] + " 详细信息"]
)
return update_state_with_sufficiency_check(state, sufficiency_check)
def sub_query_generation_node(self, state: QueryState) -> QueryState:
"""
子查询生成节点
如果充分性检查不通过,生成子查询
考虑之前已经生成的分解子查询
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[TARGET] 生成子查询")
        # Check whether the LLM call should be skipped
if self.prompt_loader.should_skip_llm('sub_query_generation'):
print(f"[NEXT] 跳过子查询生成LLM调用")
            # Do not generate new queries; proceed to the next iteration
state['current_iteration'] += 1
return state
        # If sub-queries already exist, return them directly
if state['current_sub_queries']:
print(f"[OK] 使用现有子查询: {state['current_sub_queries']}")
return state
        # Format existing retrieval results: use the mixed format if all_documents is present, otherwise the traditional format
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
        # Format all previously generated sub-queries (including the decomposed ones)
previous_sub_queries = state.get('sub_queries', [])
formatted_previous_queries = format_sub_queries(previous_sub_queries) if previous_sub_queries else ""
        # Get the insufficiency reason from the sufficiency check
sufficiency_check = state.get('sufficiency_check', {})
insufficiency_reason = sufficiency_check.get('reason', '信息不够充分,需要更多相关信息')
        # Build the prompt, including feedback from the sufficiency check
prompt = SUB_QUERY_GENERATION_PROMPT.format(
original_query=state['original_query'],
existing_passages=formatted_passages,
previous_sub_queries=formatted_previous_queries,
insufficiency_reason=insufficiency_reason
)
try:
            # Call the LLM to generate sub-queries
response = self.llm.invoke(prompt)
response_text = self._extract_response_text(response)
            # Update debug info
state["debug_info"]["llm_calls"] += 1
            # Clean up and extract the JSON content
cleaned_text = response_text.strip()
            # If the response contains "答案:" ("Answer:"), extract what follows it
if "答案:" in cleaned_text or "答案:" in cleaned_text:
                # Find the content after the "答案:" marker
answer_markers = ["答案:", "答案:"]
for marker in answer_markers:
if marker in cleaned_text:
cleaned_text = cleaned_text.split(marker, 1)[1].strip()
break
            # Strip markdown code-block markers
if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text[7:]  # strip leading ```json
elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text[3:]  # strip leading ```
if cleaned_text.endswith('```'):
                cleaned_text = cleaned_text[:-3]  # strip trailing ```
cleaned_text = cleaned_text.strip()
            # Try to extract the JSON part (from the first '{' to the last '}')
if '{' in cleaned_text and '}' in cleaned_text:
start_idx = cleaned_text.find('{')
end_idx = cleaned_text.rfind('}') + 1
cleaned_text = cleaned_text[start_idx:end_idx]
            # Parse the sub-queries
try:
data = json.loads(cleaned_text)
raw_sub_queries = data.get('sub_queries', [])
                # Handle sub-queries in different formats
sub_queries = []
for item in raw_sub_queries:
if isinstance(item, str):
                        # A plain string: use it directly
sub_queries.append(item)
elif isinstance(item, dict) and 'query' in item:
                        # A dict with a 'query' field: extract it
sub_queries.append(item['query'])
elif isinstance(item, dict):
                        # A dict without a 'query' field: try to find the query text
query_text = item.get('text', item.get('content', str(item)))
sub_queries.append(query_text)
else:
                        # Anything else: convert to a string
sub_queries.append(str(item))
except json.JSONDecodeError:
                # If JSON parsing fails, fall back to simple rule-based extraction
lines = response_text.split('\n')
sub_queries = [line.strip() for line in lines if '?' in line and len(line.strip()) > 10][:2]
            # Make sure there are sub-queries and that they are strings
if not sub_queries:
sub_queries = [state['original_query'] + " 更多细节"]
            # Make sure all sub-queries are strings
sub_queries = [str(query).strip() for query in sub_queries if query]
            # Update the state
            state['current_sub_queries'] = sub_queries[:2]  # at most 2 sub-queries
state['sub_queries'].extend(state['current_sub_queries'])
print(f"[OK] 生成子查询: {state['current_sub_queries']}")
print(f" (避免与之前的子查询重复: {formatted_previous_queries}")
return state
except Exception as e:
print(f"[ERROR] 子查询生成失败: {e}")
            # If generation fails, use a default sub-query
default_sub_query = state['original_query'] + " 补充信息"
state['current_sub_queries'] = [default_sub_query]
state['sub_queries'].append(default_sub_query)
return state
def parallel_retrieval_node(self, state: QueryState) -> QueryState:
"""
并行检索节点
使用子查询并行进行检索
Args:
state: 当前状态
Returns:
更新后的状态
"""
sub_queries = state['current_sub_queries']
if not sub_queries:
print("[WARNING] 没有子查询,跳过并行检索")
return state
print(f"[?] 并行检索 {len(sub_queries)} 个子查询")
def retrieve_single_query(query: str, index: int) -> Tuple[int, List, List, str]:
"""检索单个查询"""
try:
documents = self.retriever.invoke(query)
passages = [doc.page_content for doc in documents]
sources = [f"子查询{index+1}-{doc.metadata.get('passage_id', 'unknown')}" for doc in documents]
return index, documents, passages, sources
except Exception as e:
print(f"[ERROR] 子查询 {index+1} 检索失败: {e}")
return index, [], [], []
        # Execute retrievals in parallel
all_new_documents = []
all_new_passages = []
all_new_sources = []
with ThreadPoolExecutor(max_workers=self.max_parallel_retrievals) as executor:
            # Submit retrieval tasks
futures = {
executor.submit(retrieve_single_query, query, i): (query, i)
for i, query in enumerate(sub_queries)
}
            # Collect results
for future in as_completed(futures):
query, query_index = futures[future]
try:
index, documents, passages, sources = future.result()
                    if passages:  # only add non-empty results
all_new_documents.extend(documents)
all_new_passages.extend(passages)
all_new_sources.extend(sources)
print(f"[OK] 子查询 {index+1} 完成,获得 {len(passages)} 个段落")
else:
print(f"[WARNING] 子查询 {index+1} 无结果")
except Exception as e:
print(f"[ERROR] 子查询 {query_index+1} 处理失败: {e}")
        # Update the state
if all_new_passages:
retrieval_result = RetrievalResult(
passages=all_new_passages,
documents=all_new_documents,
sources=all_new_sources,
query=f"并行检索: {', '.join(sub_queries)}",
iteration=state['current_iteration']
)
state = update_state_with_retrieval(state, retrieval_result)
print(f"[SUCCESS] 并行检索完成,总共获得 {len(all_new_passages)} 个新段落")
            # Collect PageRank score info for the sub-queries
for i, query in enumerate(sub_queries):
try:
                    # Skip PageRank data collection to avoid shipping large payloads to LangSmith
                    # complete_ppr_info = self.retriever.get_complete_pagerank_scores(query)
                    # Only set a flag indicating the data is available; the actual data is handled inside HippoRAG
state['pagerank_data_available'] = True
except Exception as e:
print(f"[WARNING] 收集并行子查询{i+1}的PageRank分数失败: {e}")
else:
print("[WARNING] 并行检索无有效结果")
        # Clear the current sub-queries
state['current_sub_queries'] = []
return state
def final_answer_generation_node(self, state: QueryState) -> QueryState:
"""
最终答案生成节点
基于所有检索到的信息生成最终答案
Args:
state: 当前状态
Returns:
更新后的状态
"""
        # Format all retrieval results: use the mixed format if all_documents is present, otherwise the traditional format
if 'all_documents' in state and state['all_documents']:
formatted_passages = format_mixed_passages(state['all_documents'])
else:
formatted_passages = format_passages(state['all_passages'])
formatted_sub_queries = format_sub_queries(state['sub_queries'])
        # Check whether LLM generation should be skipped
if self.skip_llm_generation:
print(f"[SEARCH] 跳过LLM生成直接返回检索结果")
            # Build a formatted retrieval summary to serve as the final answer
retrieval_summary = f"""【检索结果汇总】
查询问题:{state['original_query']}
检索到 {len(state['all_passages'])} 个相关段落:
{formatted_passages}
"""
if state['sub_queries']:
retrieval_summary += f"""
相关子查询:
{formatted_sub_queries}
"""
retrieval_summary += f"""
检索统计:
- 查询复杂度:{state.get('query_complexity', 'unknown')}
- 是否复杂查询:{state.get('is_complex_query', False)}
- 迭代次数:{state.get('current_iteration', 0)}
- 信息充分性:{state.get('is_sufficient', False)}
"""
            # Finalize the state
updated_state = finalize_state(state, retrieval_summary)
print(f"[OK] 检索结果返回完成 (长度: {len(retrieval_summary)} 字符)")
return updated_state
        # Original LLM generation logic
print(f"[NOTE] 生成最终答案")
        # Check whether the LLM call should be skipped
if self.prompt_loader.should_skip_llm('final_answer'):
print(f"[NEXT] 跳过最终答案LLM调用直接返回检索结果")
final_answer = f"检索到的信息:\n{formatted_passages}"
state['final_answer'] = final_answer
return finalize_state(state, final_answer)
        # Build the prompt
prompt = FINAL_ANSWER_PROMPT.format(
original_query=state['original_query'],
all_passages=formatted_passages,
sub_queries=formatted_sub_queries
)
try:
            # Call the LLM to generate the final answer
response = self.llm.invoke(prompt)
final_answer = self._extract_response_text(response)
if not final_answer.strip():
final_answer = "抱歉,基于当前检索到的信息,无法提供完整的答案。"
            # Finalize the state
updated_state = finalize_state(state, final_answer)
print(f"[OK] 最终答案生成完成 (长度: {len(final_answer)} 字符)")
return updated_state
except Exception as e:
print(f"[ERROR] 最终答案生成失败: {e}")
error_answer = f"抱歉,在生成答案时遇到错误: {str(e)}"
return finalize_state(state, error_answer)
def next_iteration_node(self, state: QueryState) -> QueryState:
"""
下一轮迭代节点
增加迭代计数并准备下一轮
Args:
state: 当前状态
Returns:
更新后的状态
"""
print(f"[?] 进入迭代 {state['current_iteration'] + 1}")
        # Increment the iteration count
updated_state = increment_iteration(state)
return updated_state
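
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): one possible way to
# wire these nodes into a LangGraph StateGraph. The node names, routing
# functions, and edge layout below are assumptions for demonstration only;
# the project's actual graph builder lives elsewhere and may differ.
# ---------------------------------------------------------------------------
def build_workflow_sketch(nodes: "GraphNodes"):
    """Compile a minimal LangGraph workflow from a GraphNodes instance (sketch only)."""
    from langgraph.graph import StateGraph, END  # assumed available alongside this project

    graph = StateGraph(QueryState)

    # Register the node callables under illustrative names
    graph.add_node("complexity_check", nodes.query_complexity_check_node)
    graph.add_node("debug_mode", nodes.debug_mode_node)
    graph.add_node("simple_retrieval", nodes.simple_vector_retrieval_node)
    graph.add_node("simple_answer", nodes.simple_answer_generation_node)
    graph.add_node("decompose", nodes.query_decomposition_node)
    graph.add_node("initial_retrieval", nodes.initial_retrieval_node)
    graph.add_node("sufficiency_check", nodes.sufficiency_check_node)
    graph.add_node("sub_query_generation", nodes.sub_query_generation_node)
    graph.add_node("parallel_retrieval", nodes.parallel_retrieval_node)
    graph.add_node("next_iteration", nodes.next_iteration_node)
    graph.add_node("final_answer", nodes.final_answer_generation_node)

    graph.set_entry_point("complexity_check")
    graph.add_edge("complexity_check", "debug_mode")

    # Route to the simple or complex path based on the (possibly overridden) complexity flag
    graph.add_conditional_edges(
        "debug_mode",
        lambda state: "complex" if state["is_complex_query"] else "simple",
        {"simple": "simple_retrieval", "complex": "decompose"},
    )
    graph.add_edge("simple_retrieval", "simple_answer")
    graph.add_edge("simple_answer", END)

    graph.add_edge("decompose", "initial_retrieval")
    graph.add_edge("initial_retrieval", "sufficiency_check")

    # Loop until the sufficiency check passes (iteration limits omitted in this sketch)
    graph.add_conditional_edges(
        "sufficiency_check",
        lambda state: "sufficient" if state.get("is_sufficient") else "insufficient",
        {"sufficient": "final_answer", "insufficient": "sub_query_generation"},
    )
    graph.add_edge("sub_query_generation", "parallel_retrieval")
    graph.add_edge("parallel_retrieval", "next_iteration")
    graph.add_edge("next_iteration", "sufficiency_check")
    graph.add_edge("final_answer", END)

    return graph.compile()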