335 lines
13 KiB
Python
335 lines
13 KiB
Python
"""
|
||
基于LangGraph的迭代检索器
|
||
实现智能迭代检索工作流
|
||
"""
|
||
|
||
import time
|
||
from typing import Dict, Any, Optional
|
||
from langgraph.graph import StateGraph, END
|
||
from exceptions import TaskCancelledException, WorkflowStoppedException
|
||
|
||
from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary
|
||
from retriver.langgraph.graph_nodes import GraphNodes
|
||
from retriver.langgraph.routing_functions import (
|
||
should_continue_retrieval,
|
||
route_by_complexity,
|
||
route_by_debug_mode
|
||
)
|
||
from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever
|
||
from retriver.langgraph.langchain_components import create_oneapi_llm
|
||
|
||
|
||
class IterativeRetriever:
    """LangGraph-based iterative retriever.

    Implements the iterative retrieval flow:
        1. initial retrieval -> 2. sufficiency check -> 3. sub-query
        generation (if needed) -> 4. parallel retrieval -> 5. repeat 2-4
        until the context is sufficient or ``max_iterations`` is reached
        -> 6. generate the final answer.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
        """Initialize the iterative retriever.

        Args:
            keyword: ES index keyword.
            top_k: Number of documents returned per retrieval.
            max_iterations: Maximum number of retrieval iterations.
            max_parallel_retrievals: Maximum number of parallel retrievals.
            oneapi_key: OneAPI key.
            oneapi_base_url: OneAPI base URL.
            model_name: Main LLM model name (used to generate answers).
            embed_model_name: Embedding model name.
            complexity_model_name: Model for query-complexity judgement
                (falls back to ``model_name`` when not given).
            sufficiency_model_name: Model for sufficiency checks
                (falls back to ``model_name`` when not given).
            skip_llm_generation: Skip LLM answer generation and return
                retrieval results only.
        """
        self.keyword = keyword
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation

        # Fall back to the main model when no dedicated model is specified.
        complexity_model_name = complexity_model_name or model_name
        sufficiency_model_name = sufficiency_model_name or model_name

        # Build the underlying components.
        print("[CONFIG] 初始化检索器组件...")
        self.retriever = create_langchain_hipporag_retriever(
            keyword=keyword,
            top_k=top_k,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            oneapi_model_gen=model_name,
            oneapi_model_embed=embed_model_name
        )

        # Main LLM (answer generation).
        self.llm = create_oneapi_llm(
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name
        )

        # Dedicated complexity-judgement LLM, only when the model differs;
        # otherwise reuse the main LLM instance.
        if complexity_model_name != model_name:
            print(f" [INFO] 使用独立的复杂度判断模型: {complexity_model_name}")
            self.complexity_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=complexity_model_name
            )
        else:
            self.complexity_llm = self.llm

        # Dedicated sufficiency-check LLM, only when the model differs.
        if sufficiency_model_name != model_name:
            print(f" [INFO] 使用独立的充分性检查模型: {sufficiency_model_name}")
            self.sufficiency_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=sufficiency_model_name
            )
        else:
            self.sufficiency_llm = self.llm

        # Node handlers used by the workflow graph.
        self.nodes = GraphNodes(
            retriever=self.retriever,
            llm=self.llm,
            complexity_llm=self.complexity_llm,
            sufficiency_llm=self.sufficiency_llm,
            keyword=keyword,
            max_parallel_retrievals=max_parallel_retrievals,
            skip_llm_generation=skip_llm_generation
        )

        # Compile the workflow graph once at construction time.
        self.workflow = self._build_workflow()

        print("[OK] 迭代检索器初始化完成")

    def _build_workflow(self) -> Any:
        """Build and compile the LangGraph workflow.

        Returns:
            The compiled graph produced by ``StateGraph.compile()``.
            (Annotated as ``Any`` because ``compile()`` returns a compiled
            runnable, not a ``StateGraph``.)
        """
        print("[?] 构建工作流图...")

        # Create the state graph over the shared QueryState schema.
        workflow = StateGraph(QueryState)

        # --- Nodes ---
        # Query-complexity judgement node.
        workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node)

        # Debug-mode node (may force the simple or complex path).
        workflow.add_node("debug_mode_node", self.nodes.debug_mode_node)

        # Simple-query path.
        workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node)
        workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node)

        # Complex-query path (existing hipporag2 logic).
        workflow.add_node("query_decomposition", self.nodes.query_decomposition_node)
        workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node)
        workflow.add_node("event_triples_extraction", self.nodes.event_triples_extraction_node)  # event-triple extraction node
        workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node)
        workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node)
        workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node)
        workflow.add_node("next_iteration", self.nodes.next_iteration_node)
        workflow.add_node("final_answer", self.nodes.final_answer_generation_node)

        # Entry point: start with the complexity judgement.
        workflow.set_entry_point("query_complexity_check")

        # After the complexity check, go through the debug-mode node.
        workflow.add_edge("query_complexity_check", "debug_mode_node")

        # Conditional edge: debug mode + complexity decide the path.
        workflow.add_conditional_edges(
            "debug_mode_node",
            route_by_debug_mode,
            {
                "simple_vector_retrieval": "simple_vector_retrieval",
                "initial_retrieval": "query_decomposition"  # complex path enters query decomposition first
            }
        )

        # Simple-path edges.
        workflow.add_edge("simple_vector_retrieval", "simple_answer_generation")
        workflow.add_edge("simple_answer_generation", END)

        # Complex-path edges (including query decomposition).
        workflow.add_edge("query_decomposition", "initial_retrieval")        # decomposition -> parallel initial retrieval
        workflow.add_edge("initial_retrieval", "event_triples_extraction")   # initial retrieval -> event-triple extraction
        workflow.add_edge("event_triples_extraction", "sufficiency_check")   # triple extraction -> sufficiency check

        # Conditional edge: sufficiency check decides the next step.
        workflow.add_conditional_edges(
            "sufficiency_check",
            should_continue_retrieval,
            {
                "final_answer": "final_answer",
                "parallel_retrieval": "sub_query_generation"
            }
        )

        workflow.add_edge("sub_query_generation", "parallel_retrieval")
        workflow.add_edge("parallel_retrieval", "next_iteration")            # parallel retrieval -> next iteration
        workflow.add_edge("next_iteration", "event_triples_extraction")      # iterate: extract event triples again
        # NOTE: this forms the loop event_triples_extraction -> sufficiency_check
        # -> (possibly) parallel_retrieval -> next_iteration -> event_triples_extraction.

        # Terminal edge.
        workflow.add_edge("final_answer", END)

        return workflow.compile()

    def retrieve(self, query: str, mode: str = "0", task_id: Optional[str] = None) -> Dict[str, Any]:
        """Run the iterative retrieval workflow.

        Args:
            query: User query.
            mode: Debug mode: "0" = auto-detect, "simple" = force the
                simple path, "complex" = force the complex path.
            task_id: Optional task ID (used for cancellation tracking).

        Returns:
            A dict with the final answer and detailed diagnostics; on
            cancellation or failure, a reduced dict with ``cancelled`` or
            ``error`` keys respectively.
        """
        print(f"[STARTING] 开始迭代检索: {query}")
        start_time = time.time()

        # Create the initial graph state.
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=self.max_iterations,
            debug_mode=mode,
            task_id=task_id  # propagate the task ID
        )
        initial_state["debug_info"]["start_time"] = start_time

        try:
            # Execute the compiled workflow to completion.
            final_state = self.workflow.invoke(initial_state)

            # Record timing information.
            end_time = time.time()
            final_state["debug_info"]["end_time"] = end_time
            final_state["debug_info"]["total_time"] = end_time - start_time

            print(f"[SUCCESS] 迭代检索完成,耗时 {end_time - start_time:.2f}秒")

            # Use .get() with neutral defaults throughout so a state that
            # lacks an optional key (e.g. after the simple-query path)
            # cannot raise KeyError and be misreported as a retrieval error.
            all_passages = final_state.get("all_passages", [])
            return {
                "query": query,
                "answer": final_state.get("final_answer", ""),
                "query_complexity": final_state.get("query_complexity"),
                "is_complex_query": final_state.get("is_complex_query", False),
                "iterations": final_state.get("current_iteration", 0),
                "total_passages": len(all_passages),
                "sub_queries": final_state.get("sub_queries", []),
                "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []),
                "initial_retrieval_details": final_state.get("initial_retrieval_details", {}),
                "sufficiency_check": final_state.get("sufficiency_check", {}),
                "all_passages": all_passages,
                "debug_info": final_state.get("debug_info", {}),
                "state_summary": get_state_summary(final_state),
                "iteration_history": final_state.get("iteration_history", [])
            }

        except (TaskCancelledException, WorkflowStoppedException) as e:
            # Deliberate cancellation: report it without treating it as an error.
            print(f"[CANCELLED] 迭代检索被取消: {e}")
            return {
                "query": query,
                "answer": "任务已被取消",
                "cancelled": True,
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"cancelled": True, "message": str(e), "total_time": time.time() - start_time}
            }

        except Exception as e:
            # Any other failure: return a structured error payload.
            print(f"[ERROR] 迭代检索失败: {e}")
            return {
                "query": query,
                "answer": f"抱歉,检索过程中遇到错误: {str(e)}",
                "error": str(e),
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": time.time() - start_time}
            }

    def retrieve_simple(self, query: str) -> str:
        """Simplified retrieval interface that returns only the answer.

        Args:
            query: User query.

        Returns:
            The final answer string.
        """
        result = self.retrieve(query)
        return result["answer"]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """Return configuration and model statistics for this retriever.

        Returns:
            A dict of configuration values plus best-effort model names.
        """
        # Guard both attribute hops: the previous code called
        # getattr(self.llm.oneapi_generator, ...), which raised
        # AttributeError whenever the intermediate attribute was missing,
        # defeating the 'unknown' fallback.
        llm_model = getattr(
            getattr(self.llm, 'oneapi_generator', None), 'model_name', 'unknown'
        )
        embed_model = getattr(
            getattr(self.retriever, 'embedding_model', None), 'model_name', 'unknown'
        )
        return {
            "keyword": self.keyword,
            "top_k": self.top_k,
            "max_iterations": self.max_iterations,
            "max_parallel_retrievals": self.max_parallel_retrievals,
            "retriever_type": "IterativeRetriever with LangGraph",
            "model_info": {
                "llm_model": llm_model,
                "embed_model": embed_model
            }
        }
|
||
|
||
|
||
def create_iterative_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    **kwargs
) -> IterativeRetriever:
    """Convenience factory for :class:`IterativeRetriever`.

    Args:
        keyword: ES index keyword.
        top_k: Number of documents returned per retrieval.
        max_iterations: Maximum number of retrieval iterations.
        max_parallel_retrievals: Maximum number of parallel retrievals.
        **kwargs: Additional keyword arguments forwarded verbatim to the
            ``IterativeRetriever`` constructor (API keys, model names, ...).

    Returns:
        A fully initialized ``IterativeRetriever`` instance.
    """
    # Thin pass-through: every explicit parameter plus the extras goes
    # straight to the constructor unchanged.
    return IterativeRetriever(
        keyword=keyword,
        top_k=top_k,
        max_iterations=max_iterations,
        max_parallel_retrievals=max_parallel_retrievals,
        **kwargs,
    )