# AIEC-new/AIEC-RAG/retriver/langgraph/iterative_retriever.py
"""
基于LangGraph的迭代检索器
实现智能迭代检索工作流
"""
import time
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, END
from exceptions import TaskCancelledException, WorkflowStoppedException
from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary
from retriver.langgraph.graph_nodes import GraphNodes
from retriver.langgraph.routing_functions import (
should_continue_retrieval,
route_by_complexity,
route_by_debug_mode
)
from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever
from retriver.langgraph.langchain_components import create_oneapi_llm
class IterativeRetriever:
    """
    LangGraph-based iterative retriever.

    Implements the iterative retrieval flow:
    1. Initial retrieval -> 2. Sufficiency check -> 3. Sub-query generation (if needed) ->
    4. Parallel retrieval -> 5. Repeat 2-4 until the evidence is sufficient or the maximum
    number of iterations is reached -> 6. Generate the final answer
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
        """
        Initialize the iterative retriever.

        Args:
            keyword: ES index keyword
            top_k: Number of documents returned per retrieval
            max_iterations: Maximum number of iterations
            max_parallel_retrievals: Maximum number of parallel retrievals
            oneapi_key: OneAPI key
            oneapi_base_url: OneAPI base URL
            model_name: Main LLM model name (used for answer generation)
            embed_model_name: Embedding model name
            complexity_model_name: Model for query-complexity classification (falls back to model_name if not set)
            sufficiency_model_name: Model for the sufficiency check (falls back to model_name if not set)
            skip_llm_generation: If True, skip LLM answer generation and return only retrieval results
        """
        self.keyword = keyword
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation

        # Fall back to the main model when no dedicated models are specified
        complexity_model_name = complexity_model_name or model_name
        sufficiency_model_name = sufficiency_model_name or model_name

        # Create components
        print("[CONFIG] Initializing retriever components...")
        self.retriever = create_langchain_hipporag_retriever(
            keyword=keyword,
            top_k=top_k,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            oneapi_model_gen=model_name,
            oneapi_model_embed=embed_model_name
        )

        # Create the main LLM used for answer generation
        self.llm = create_oneapi_llm(
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name
        )

        # Create a dedicated complexity-classification LLM if a different model is configured
        if complexity_model_name != model_name:
            print(f" [INFO] Using a dedicated complexity-classification model: {complexity_model_name}")
            self.complexity_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=complexity_model_name
            )
        else:
            self.complexity_llm = self.llm

        # Create a dedicated sufficiency-check LLM if a different model is configured
        if sufficiency_model_name != model_name:
            print(f" [INFO] Using a dedicated sufficiency-check model: {sufficiency_model_name}")
            self.sufficiency_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=sufficiency_model_name
            )
        else:
            self.sufficiency_llm = self.llm

        # Create the graph-node handler
        self.nodes = GraphNodes(
            retriever=self.retriever,
            llm=self.llm,
            complexity_llm=self.complexity_llm,
            sufficiency_llm=self.sufficiency_llm,
            keyword=keyword,
            max_parallel_retrievals=max_parallel_retrievals,
            skip_llm_generation=skip_llm_generation
        )

        # Build the workflow graph
        self.workflow = self._build_workflow()
        print("[OK] Iterative retriever initialized")

    def _build_workflow(self) -> StateGraph:
        """Build the LangGraph workflow."""
        print("[BUILD] Building workflow graph...")

        # Create the state graph
        workflow = StateGraph(QueryState)

        # Add nodes
        # New: query-complexity classification node
        workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node)
        # New: debug-mode node
        workflow.add_node("debug_mode_node", self.nodes.debug_mode_node)
        # Simple-query path
        workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node)
        workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node)
        # Complex-query path (existing hipporag2 logic)
        workflow.add_node("query_decomposition", self.nodes.query_decomposition_node)
        workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node)
        workflow.add_node("event_triples_extraction", self.nodes.event_triples_extraction_node)  # New: event-triple extraction node
        workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node)
        workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node)
        workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node)
        workflow.add_node("next_iteration", self.nodes.next_iteration_node)
        workflow.add_node("final_answer", self.nodes.final_answer_generation_node)

        # Entry point: start with query-complexity classification
        workflow.set_entry_point("query_complexity_check")

        # After the complexity check, proceed to the debug-mode node
        workflow.add_edge("query_complexity_check", "debug_mode_node")

        # Conditional edges: choose the path based on debug mode and the complexity result
        workflow.add_conditional_edges(
            "debug_mode_node",
            route_by_debug_mode,
            {
                "simple_vector_retrieval": "simple_vector_retrieval",
                "initial_retrieval": "query_decomposition"  # Complex path enters query decomposition first
            }
        )

        # Edges for the simple-query path
        workflow.add_edge("simple_vector_retrieval", "simple_answer_generation")
        workflow.add_edge("simple_answer_generation", END)

        # Edges for the complex-query path (includes query decomposition)
        workflow.add_edge("query_decomposition", "initial_retrieval")  # Decomposition feeds the parallel initial retrieval
        workflow.add_edge("initial_retrieval", "event_triples_extraction")  # Extract event triples after the initial retrieval
        workflow.add_edge("event_triples_extraction", "sufficiency_check")  # Then run the sufficiency check

        # Conditional edges: the next step depends on the sufficiency-check result
        workflow.add_conditional_edges(
            "sufficiency_check",
            should_continue_retrieval,
            {
                "final_answer": "final_answer",
                "parallel_retrieval": "sub_query_generation"
            }
        )
        workflow.add_edge("sub_query_generation", "parallel_retrieval")
        workflow.add_edge("parallel_retrieval", "next_iteration")  # Parallel retrieval flows straight into the next iteration
        workflow.add_edge("next_iteration", "event_triples_extraction")  # Re-extract event triples after each iteration
        # Note: this forms a loop: event_triples_extraction -> sufficiency_check -> (possibly) parallel_retrieval -> next_iteration -> event_triples_extraction

        # Terminal edge
        workflow.add_edge("final_answer", END)

        return workflow.compile()
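
    # Graph topology, summarized from the edges registered above (comment added
    # for orientation only; it introduces no new behavior):
    #
    #   query_complexity_check -> debug_mode_node
    #     debug_mode_node --(simple)--->  simple_vector_retrieval -> simple_answer_generation -> END
    #     debug_mode_node --(complex)-->  query_decomposition -> initial_retrieval -> event_triples_extraction
    #   event_triples_extraction -> sufficiency_check
    #     sufficiency_check --(sufficient)---> final_answer -> END
    #     sufficiency_check --(insufficient)-> sub_query_generation -> parallel_retrieval
    #                                          -> next_iteration -> event_triples_extraction  (loop)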

    def retrieve(self, query: str, mode: str = "0", task_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Run the iterative retrieval workflow.

        Args:
            query: User query
            mode: Debug mode; "0" = auto-detect, "simple" = force the simple path, "complex" = force the complex path
            task_id: Task ID (optional)

        Returns:
            Dictionary containing the final answer and detailed run information
        """
        print(f"[STARTING] Starting iterative retrieval: {query}")
        start_time = time.time()

        # Create the initial state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=self.max_iterations,
            debug_mode=mode,
            task_id=task_id  # Pass the task ID through
        )
        initial_state["debug_info"]["start_time"] = start_time

        try:
            # Run the workflow
            final_state = self.workflow.invoke(initial_state)

            # Record the end time
            end_time = time.time()
            final_state["debug_info"]["end_time"] = end_time
            final_state["debug_info"]["total_time"] = end_time - start_time
            print(f"[SUCCESS] Iterative retrieval finished in {end_time - start_time:.2f}s")

            # Return the result
            return {
                "query": query,
                "answer": final_state["final_answer"],
                "query_complexity": final_state["query_complexity"],
                "is_complex_query": final_state["is_complex_query"],
                "iterations": final_state["current_iteration"],
                "total_passages": len(final_state["all_passages"]),
                "sub_queries": final_state["sub_queries"],
                "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []),
                "initial_retrieval_details": final_state.get("initial_retrieval_details", {}),
                "sufficiency_check": final_state["sufficiency_check"],
                "all_passages": final_state["all_passages"],
                "debug_info": final_state["debug_info"],
                "state_summary": get_state_summary(final_state),
                "iteration_history": final_state["iteration_history"]
            }

        except (TaskCancelledException, WorkflowStoppedException) as e:
            print(f"[CANCELLED] Iterative retrieval was cancelled: {e}")
            return {
                "query": query,
                "answer": "The task has been cancelled",
                "cancelled": True,
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"cancelled": True, "message": str(e), "total_time": time.time() - start_time}
            }
        except Exception as e:
            print(f"[ERROR] Iterative retrieval failed: {e}")
            return {
                "query": query,
                "answer": f"Sorry, an error occurred during retrieval: {str(e)}",
                "error": str(e),
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": time.time() - start_time}
            }

    def retrieve_simple(self, query: str) -> str:
        """
        Simplified retrieval interface that returns only the answer.

        Args:
            query: User query

        Returns:
            Final answer string
        """
        result = self.retrieve(query)
        return result["answer"]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """
        Get retriever statistics.

        Returns:
            Dictionary of statistics
        """
        return {
            "keyword": self.keyword,
            "top_k": self.top_k,
            "max_iterations": self.max_iterations,
            "max_parallel_retrievals": self.max_parallel_retrievals,
            "retriever_type": "IterativeRetriever with LangGraph",
            "model_info": {
                "llm_model": getattr(self.llm.oneapi_generator, 'model_name', 'unknown'),
                "embed_model": getattr(self.retriever.embedding_model, 'model_name', 'unknown')
            }
        }


def create_iterative_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    **kwargs
) -> IterativeRetriever:
    """
    Convenience function for creating an iterative retriever.

    Args:
        keyword: ES index keyword
        top_k: Number of documents returned per retrieval
        max_iterations: Maximum number of iterations
        max_parallel_retrievals: Maximum number of parallel retrievals
        **kwargs: Additional keyword arguments passed through to IterativeRetriever

    Returns:
        IterativeRetriever instance
    """
    return IterativeRetriever(
        keyword=keyword,
        top_k=top_k,
        max_iterations=max_iterations,
        max_parallel_retrievals=max_parallel_retrievals,
        **kwargs
    )
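

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The ES keyword, OneAPI credentials,
# base URL, and model names below are placeholders, not values shipped with the
# project; they only show how the constructor arguments documented above fit
# together.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    retriever = create_iterative_retriever(
        keyword="example_index",                           # hypothetical ES index keyword
        top_k=2,
        max_iterations=2,
        max_parallel_retrievals=2,
        oneapi_key="sk-xxx",                               # placeholder credential
        oneapi_base_url="https://oneapi.example.com/v1",   # placeholder URL
        model_name="gpt-4o-mini",                          # placeholder model name
        embed_model_name="text-embedding-3-small",         # placeholder embedding model
    )
    # mode="0" lets the workflow decide between the simple and complex paths
    result = retriever.retrieve("What is iterative retrieval?", mode="0")
    print(result["answer"])
    print(retriever.get_retrieval_stats())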