first commit
This commit is contained in:
318
retriver/langgraph/iterative_retriever.py
Normal file
318
retriver/langgraph/iterative_retriever.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""
|
||||
基于LangGraph的迭代检索器
|
||||
实现智能迭代检索工作流
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from langgraph.graph import StateGraph, END
|
||||
|
||||
from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary
|
||||
from retriver.langgraph.graph_nodes import GraphNodes
|
||||
from retriver.langgraph.routing_functions import (
|
||||
should_continue_retrieval,
|
||||
route_by_complexity,
|
||||
route_by_debug_mode
|
||||
)
|
||||
from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever
|
||||
from retriver.langgraph.langchain_components import create_oneapi_llm
|
||||
|
||||
|
||||
class IterativeRetriever:
    """LangGraph-based iterative retriever.

    Workflow:
        1. initial retrieval -> 2. sufficiency check -> 3. sub-query
        generation (when needed) -> 4. parallel retrieval -> 5. repeat
        steps 2-4 until the context is sufficient or the iteration limit
        is reached -> 6. final answer generation.

    A lightweight "simple" path (plain vector retrieval + answer) is also
    wired into the graph and selected by a query-complexity check or an
    explicit debug mode.
    """

    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
        """Initialize the iterative retriever.

        Args:
            keyword: ES index keyword.
            top_k: Number of documents returned by each retrieval.
            max_iterations: Maximum number of retrieval iterations.
            max_parallel_retrievals: Maximum number of parallel retrievals.
            oneapi_key: OneAPI key.
            oneapi_base_url: OneAPI base URL.
            model_name: Main LLM model name (used for answer generation).
            embed_model_name: Embedding model name.
            complexity_model_name: Model for the query-complexity check;
                falls back to ``model_name`` when not given.
            sufficiency_model_name: Model for the sufficiency check;
                falls back to ``model_name`` when not given.
            skip_llm_generation: When True, skip LLM answer generation and
                return retrieval results only.
        """
        self.keyword = keyword
        self.top_k = top_k
        self.max_iterations = max_iterations
        self.max_parallel_retrievals = max_parallel_retrievals
        self.skip_llm_generation = skip_llm_generation

        # Fall back to the main model when no dedicated model is given.
        complexity_model_name = complexity_model_name or model_name
        sufficiency_model_name = sufficiency_model_name or model_name

        # Build the underlying retriever component.
        print("[CONFIG] 初始化检索器组件...")
        self.retriever = create_langchain_hipporag_retriever(
            keyword=keyword,
            top_k=top_k,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            oneapi_model_gen=model_name,
            oneapi_model_embed=embed_model_name
        )

        # Main LLM (final answer generation).
        self.llm = create_oneapi_llm(
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name
        )

        # Dedicated complexity-check LLM, only when a different model was
        # requested; otherwise reuse the main LLM instance.
        if complexity_model_name != model_name:
            print(f" [INFO] 使用独立的复杂度判断模型: {complexity_model_name}")
            self.complexity_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=complexity_model_name
            )
        else:
            self.complexity_llm = self.llm

        # Dedicated sufficiency-check LLM (same fallback logic as above).
        if sufficiency_model_name != model_name:
            print(f" [INFO] 使用独立的充分性检查模型: {sufficiency_model_name}")
            self.sufficiency_llm = create_oneapi_llm(
                oneapi_key=oneapi_key,
                oneapi_base_url=oneapi_base_url,
                model_name=sufficiency_model_name
            )
        else:
            self.sufficiency_llm = self.llm

        # Node handlers used by the workflow graph.
        self.nodes = GraphNodes(
            retriever=self.retriever,
            llm=self.llm,
            complexity_llm=self.complexity_llm,
            sufficiency_llm=self.sufficiency_llm,
            keyword=keyword,
            max_parallel_retrievals=max_parallel_retrievals,
            skip_llm_generation=skip_llm_generation
        )

        # Build and compile the workflow graph once at construction time.
        self.workflow = self._build_workflow()

        print("[OK] 迭代检索器初始化完成")

    def _build_workflow(self):
        """Build and compile the LangGraph workflow.

        Returns:
            The compiled graph. NOTE: ``StateGraph.compile()`` returns a
            compiled-graph object, not a ``StateGraph`` — the original
            ``-> StateGraph`` annotation was wrong and has been dropped.
        """
        print("[?] 构建工作流图...")

        # State graph over the shared QueryState schema.
        workflow = StateGraph(QueryState)

        # --- Nodes -------------------------------------------------------
        # Query-complexity check (graph entry point).
        workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node)

        # Debug-mode override node.
        workflow.add_node("debug_mode_node", self.nodes.debug_mode_node)

        # Simple-query path.
        workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node)
        workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node)

        # Complex-query path (existing hipporag2-style iterative logic).
        workflow.add_node("query_decomposition", self.nodes.query_decomposition_node)
        workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node)
        workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node)
        workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node)
        workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node)
        workflow.add_node("next_iteration", self.nodes.next_iteration_node)
        workflow.add_node("final_answer", self.nodes.final_answer_generation_node)

        # --- Edges -------------------------------------------------------
        # Start with the complexity check, then the debug-mode node.
        workflow.set_entry_point("query_complexity_check")
        workflow.add_edge("query_complexity_check", "debug_mode_node")

        # Route by debug mode / judged complexity.
        workflow.add_conditional_edges(
            "debug_mode_node",
            route_by_debug_mode,
            {
                "simple_vector_retrieval": "simple_vector_retrieval",
                "initial_retrieval": "query_decomposition"  # complex path starts with query decomposition
            }
        )

        # Simple path: retrieve, answer, done.
        workflow.add_edge("simple_vector_retrieval", "simple_answer_generation")
        workflow.add_edge("simple_answer_generation", END)

        # Complex path: decomposition feeds the (parallel) initial retrieval.
        workflow.add_edge("query_decomposition", "initial_retrieval")
        workflow.add_edge("initial_retrieval", "sufficiency_check")

        # Route by the sufficiency-check outcome.
        workflow.add_conditional_edges(
            "sufficiency_check",
            should_continue_retrieval,
            {
                "final_answer": "final_answer",
                "parallel_retrieval": "sub_query_generation",
                "next_iteration": "next_iteration"
            }
        )

        workflow.add_edge("sub_query_generation", "parallel_retrieval")
        workflow.add_edge("parallel_retrieval", "next_iteration")  # bump the iteration counter after retrieval
        workflow.add_edge("next_iteration", "sufficiency_check")

        # Terminal node.
        workflow.add_edge("final_answer", END)

        return workflow.compile()

    def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]:
        """Run the iterative retrieval workflow.

        Args:
            query: User query.
            mode: Debug mode — "0" = auto-detect, "simple" = force the
                simple path, "complex" = force the complex path.

        Returns:
            A dict with the final answer and detailed diagnostics; on
            failure, a reduced dict containing an "error" key.
        """
        print(f"[STARTING] 开始迭代检索: {query}")
        start_time = time.time()

        # Build the initial graph state.
        initial_state = create_initial_state(
            original_query=query,
            max_iterations=self.max_iterations,
            debug_mode=mode
        )
        initial_state["debug_info"]["start_time"] = start_time

        try:
            # Execute the compiled workflow.
            final_state = self.workflow.invoke(initial_state)

            # Record timing.
            end_time = time.time()
            final_state["debug_info"]["end_time"] = end_time
            final_state["debug_info"]["total_time"] = end_time - start_time

            print(f"[SUCCESS] 迭代检索完成,耗时 {end_time - start_time:.2f}秒")

            # Use .get() for keys that are not guaranteed to be populated
            # on every path (e.g. the simple path may skip complex-path
            # fields) — a missing key must not turn a successful run into
            # the error branch below.
            all_passages = final_state.get("all_passages", [])
            return {
                "query": query,
                "answer": final_state.get("final_answer", ""),
                "query_complexity": final_state.get("query_complexity"),
                "is_complex_query": final_state.get("is_complex_query"),
                "iterations": final_state.get("current_iteration", 0),
                "total_passages": len(all_passages),
                "sub_queries": final_state.get("sub_queries", []),
                "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []),
                "initial_retrieval_details": final_state.get("initial_retrieval_details", {}),
                "sufficiency_check": final_state.get("sufficiency_check"),
                "all_passages": all_passages,
                "debug_info": final_state.get("debug_info", {}),
                "state_summary": get_state_summary(final_state),
                "iteration_history": final_state.get("iteration_history", [])
            }

        except Exception as e:
            print(f"[ERROR] 迭代检索失败: {e}")
            return {
                "query": query,
                "answer": f"抱歉,检索过程中遇到错误: {str(e)}",
                "error": str(e),
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": time.time() - start_time}
            }

    def retrieve_simple(self, query: str) -> str:
        """Simplified retrieval interface that returns only the answer.

        Args:
            query: User query.

        Returns:
            The final answer string.
        """
        result = self.retrieve(query)
        return result["answer"]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the retriever's configuration.

        Returns:
            Statistics dict with configuration values and model names.
        """
        # Guard both levels of the nested attribute lookups: the backing
        # LLM/retriever implementations may not expose these attributes.
        generator = getattr(self.llm, "oneapi_generator", None)
        embedder = getattr(self.retriever, "embedding_model", None)
        return {
            "keyword": self.keyword,
            "top_k": self.top_k,
            "max_iterations": self.max_iterations,
            "max_parallel_retrievals": self.max_parallel_retrievals,
            "retriever_type": "IterativeRetriever with LangGraph",
            "model_info": {
                "llm_model": getattr(generator, "model_name", "unknown"),
                "embed_model": getattr(embedder, "model_name", "unknown")
            }
        }
|
||||
|
||||
|
||||
def create_iterative_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    **kwargs
) -> IterativeRetriever:
    """Convenience factory that builds an :class:`IterativeRetriever`.

    Args:
        keyword: ES index keyword.
        top_k: Number of documents returned by each retrieval.
        max_iterations: Maximum number of retrieval iterations.
        max_parallel_retrievals: Maximum number of parallel retrievals.
        **kwargs: Any additional keyword arguments, forwarded verbatim to
            the ``IterativeRetriever`` constructor (API keys, model
            names, etc.).

    Returns:
        A configured ``IterativeRetriever`` instance.
    """
    # Thin pass-through: all construction logic lives in the class itself.
    return IterativeRetriever(
        keyword=keyword,
        top_k=top_k,
        max_iterations=max_iterations,
        max_parallel_retrievals=max_parallel_retrievals,
        **kwargs
    )
|
||||
Reference in New Issue
Block a user