AIEC-RAG---/AIEC-RAG/retriver/langgraph/iterative_retriever.py
"""
基于LangGraph的迭代检索器
实现智能迭代检索工作流
"""
import time
from typing import Dict, Any, Optional
from langgraph.graph import StateGraph, END
from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary
from retriver.langgraph.graph_nodes import GraphNodes
from retriver.langgraph.routing_functions import (
    should_continue_retrieval,
    route_by_complexity,
    route_by_debug_mode
)
from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever
from retriver.langgraph.langchain_components import create_oneapi_llm


class IterativeRetriever:
    """
    Iterative retriever built on LangGraph.

    Workflow:
        1. Initial retrieval -> 2. Sufficiency check -> 3. Sub-query generation (if needed)
        -> 4. Parallel retrieval -> 5. Repeat steps 2-4 until the evidence is sufficient or
        the iteration limit is reached -> 6. Generate the final answer
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
"""
初始化迭代检索器
Args:
keyword: ES索引关键词
top_k: 每次检索返回的文档数量
max_iterations: 最大迭代次数
max_parallel_retrievals: 最大并行检索数
oneapi_key: OneAPI密钥
oneapi_base_url: OneAPI基础URL
model_name: 主LLM模型名称(用于生成答案)
embed_model_name: 嵌入模型名称
complexity_model_name: 复杂度判断模型(如不指定则使用model_name)
sufficiency_model_name: 充分性检查模型(如不指定则使用model_name)
skip_llm_generation: 是否跳过LLM生成答案仅返回检索结果
"""
        self.keyword = keyword
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation

        # Fall back to the main model when no dedicated models are specified
        complexity_model_name = complexity_model_name or model_name
        sufficiency_model_name = sufficiency_model_name or model_name

        # Create components
        print("[CONFIG] Initializing retriever components...")
        self.retriever = create_langchain_hipporag_retriever(
            keyword=keyword,
            top_k=top_k,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            oneapi_model_gen=model_name,
            oneapi_model_embed=embed_model_name
        )

        # Main LLM used for answer generation
        self.llm = create_oneapi_llm(
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name
        )

        # Dedicated LLM for the complexity check (only if a different model is configured)
        if complexity_model_name != model_name:
            print(f" [INFO] Using a dedicated complexity-check model: {complexity_model_name}")
            self.complexity_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=complexity_model_name
            )
        else:
            self.complexity_llm = self.llm

        # Dedicated LLM for the sufficiency check (only if a different model is configured)
        if sufficiency_model_name != model_name:
            print(f" [INFO] Using a dedicated sufficiency-check model: {sufficiency_model_name}")
            self.sufficiency_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=sufficiency_model_name
            )
        else:
            self.sufficiency_llm = self.llm

        # Create the graph node handlers
        self.nodes = GraphNodes(
            retriever=self.retriever,
            llm=self.llm,
            complexity_llm=self.complexity_llm,
            sufficiency_llm=self.sufficiency_llm,
            keyword=keyword,
            max_parallel_retrievals=max_parallel_retrievals,
            skip_llm_generation=skip_llm_generation
        )

        # Build the workflow graph
        self.workflow = self._build_workflow()
        print("[OK] Iterative retriever initialized")

    def _build_workflow(self) -> StateGraph:
        """Build and compile the LangGraph workflow."""
        print("[?] Building workflow graph...")

        # Create the state graph
        workflow = StateGraph(QueryState)

        # Add nodes
        # New: query complexity classification node
        workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node)
        # New: debug-mode node
        workflow.add_node("debug_mode_node", self.nodes.debug_mode_node)
        # Simple query path
        workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node)
        workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node)
        # Complex query path (existing HippoRAG2 logic)
        workflow.add_node("query_decomposition", self.nodes.query_decomposition_node)
        workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node)
        workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node)
        workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node)
        workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node)
        workflow.add_node("next_iteration", self.nodes.next_iteration_node)
        workflow.add_node("final_answer", self.nodes.final_answer_generation_node)

        # Entry point: start with the query complexity check
        workflow.set_entry_point("query_complexity_check")
        # After the complexity check, proceed to the debug-mode node
        workflow.add_edge("query_complexity_check", "debug_mode_node")

        # Conditional edge: choose a path based on debug mode and the complexity result
        workflow.add_conditional_edges(
            "debug_mode_node",
            route_by_debug_mode,
            {
                "simple_vector_retrieval": "simple_vector_retrieval",
                "initial_retrieval": "query_decomposition"  # the complex path enters query decomposition first
            }
        )

        # Edges for the simple query path
        workflow.add_edge("simple_vector_retrieval", "simple_answer_generation")
        workflow.add_edge("simple_answer_generation", END)

        # Edges for the complex query path (including query decomposition)
        workflow.add_edge("query_decomposition", "initial_retrieval")  # decomposition feeds the parallel initial retrieval
        workflow.add_edge("initial_retrieval", "sufficiency_check")

        # Conditional edge: decide the next step from the sufficiency check result
        workflow.add_conditional_edges(
            "sufficiency_check",
            should_continue_retrieval,
            {
                "final_answer": "final_answer",
                "parallel_retrieval": "sub_query_generation",
                "next_iteration": "next_iteration"
            }
        )
        workflow.add_edge("sub_query_generation", "parallel_retrieval")
        workflow.add_edge("parallel_retrieval", "next_iteration")  # increment the iteration count after parallel retrieval
        workflow.add_edge("next_iteration", "sufficiency_check")

        # Terminal edge
        workflow.add_edge("final_answer", END)

        return workflow.compile()
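
    # Compiled graph topology (illustrative sketch derived from the edges registered in
    # _build_workflow above; node names match the add_node calls):
    #
    #   query_complexity_check -> debug_mode_node
    #   debug_mode_node --route_by_debug_mode--> simple_vector_retrieval | query_decomposition
    #   simple_vector_retrieval -> simple_answer_generation -> END
    #   query_decomposition -> initial_retrieval -> sufficiency_check
    #   sufficiency_check --should_continue_retrieval--> final_answer | sub_query_generation | next_iteration
    #   sub_query_generation -> parallel_retrieval -> next_iteration -> sufficiency_check
    #   final_answer -> END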

    def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]:
        """
        Run the iterative retrieval workflow.

        Args:
            query: User query
            mode: Debug mode; "0" = decide automatically, "simple" = force the simple path,
                "complex" = force the complex path

        Returns:
            Dictionary containing the final answer and detailed information
        """
        print(f"[STARTING] Starting iterative retrieval: {query}")
        start_time = time.time()

        # Create the initial state
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=self.max_iterations,
            debug_mode=mode
        )
        initial_state["debug_info"]["start_time"] = start_time

        try:
            # Run the workflow
            final_state = self.workflow.invoke(initial_state)

            # Record timing information
            end_time = time.time()
            final_state["debug_info"]["end_time"] = end_time
            final_state["debug_info"]["total_time"] = end_time - start_time
            print(f"[SUCCESS] Iterative retrieval finished in {end_time - start_time:.2f}s")

            # Return the result
            return {
                "query": query,
                "answer": final_state["final_answer"],
                "query_complexity": final_state["query_complexity"],
                "is_complex_query": final_state["is_complex_query"],
                "iterations": final_state["current_iteration"],
                "total_passages": len(final_state["all_passages"]),
                "sub_queries": final_state["sub_queries"],
                "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []),
                "initial_retrieval_details": final_state.get("initial_retrieval_details", {}),
                "sufficiency_check": final_state["sufficiency_check"],
                "all_passages": final_state["all_passages"],
                "debug_info": final_state["debug_info"],
                "state_summary": get_state_summary(final_state),
                "iteration_history": final_state["iteration_history"]
            }
        except Exception as e:
            print(f"[ERROR] Iterative retrieval failed: {e}")
            return {
                "query": query,
                "answer": f"Sorry, an error occurred during retrieval: {str(e)}",
                "error": str(e),
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": time.time() - start_time}
            }

    def retrieve_simple(self, query: str) -> str:
        """
        Simplified retrieval interface that returns only the answer.

        Args:
            query: User query

        Returns:
            Final answer string
        """
        result = self.retrieve(query)
        return result["answer"]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """
        Get retriever statistics.

        Returns:
            Dictionary of statistics
        """
        return {
            "keyword": self.keyword,
            "top_k": self.top_k,
            "max_iterations": self.max_iterations,
            "max_parallel_retrievals": self.max_parallel_retrievals,
            "retriever_type": "IterativeRetriever with LangGraph",
            "model_info": {
                "llm_model": getattr(self.llm.oneapi_generator, 'model_name', 'unknown'),
                "embed_model": getattr(self.retriever.embedding_model, 'model_name', 'unknown')
            }
        }


def create_iterative_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    **kwargs
) -> IterativeRetriever:
    """
    Convenience factory for creating an iterative retriever.

    Args:
        keyword: ES index keyword
        top_k: Number of documents returned per retrieval
        max_iterations: Maximum number of iterations
        max_parallel_retrievals: Maximum number of parallel retrievals
        **kwargs: Additional keyword arguments forwarded to IterativeRetriever

    Returns:
        IterativeRetriever instance
    """
    return IterativeRetriever(
        keyword=keyword,
        top_k=top_k,
        max_iterations=max_iterations,
        max_parallel_retrievals=max_parallel_retrievals,
        **kwargs
    )
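

# Minimal usage sketch: the index keyword, environment variable names, and model names
# below are placeholders/assumptions, not values defined by this module; adjust them to
# your deployment before running.
if __name__ == "__main__":
    import os

    retriever = create_iterative_retriever(
        keyword="docs",                                 # hypothetical ES index keyword
        top_k=2,
        max_iterations=2,
        oneapi_key=os.getenv("ONEAPI_KEY"),             # assumed environment variable
        oneapi_base_url=os.getenv("ONEAPI_BASE_URL"),   # assumed environment variable
        model_name="gpt-4o-mini",                       # placeholder model name
        embed_model_name="text-embedding-3-small"       # placeholder embedding model name
    )

    # Full result dict with debug info; retrieve_simple() would return only the answer string
    result = retriever.retrieve("How does the iterative retrieval workflow decide when to stop?", mode="0")
    print(result["answer"])
    print(retriever.get_retrieval_stats())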