""" 基于LangGraph的迭代检索器 实现智能迭代检索工作流 """ import time from typing import Dict, Any, Optional from langgraph.graph import StateGraph, END from retriver.langgraph.graph_state import QueryState, create_initial_state, get_state_summary from retriver.langgraph.graph_nodes import GraphNodes from retriver.langgraph.routing_functions import ( should_continue_retrieval, route_by_complexity, route_by_debug_mode ) from retriver.langgraph.langchain_hipporag_retriever import create_langchain_hipporag_retriever from retriver.langgraph.langchain_components import create_oneapi_llm class IterativeRetriever: """ 基于LangGraph的迭代检索器 实现迭代检索流程: 1. 初始检索 -> 2. 充分性检查 -> 3. 子查询生成(如需要) -> 4. 并行检索 -> 5. 重复2-4直到充分或达到最大迭代次数 -> 6. 生成最终答案 """ def __init__( self, keyword: str, top_k: int = 2, max_iterations: int = 2, max_parallel_retrievals: int = 2, oneapi_key: Optional[str] = None, oneapi_base_url: Optional[str] = None, model_name: Optional[str] = None, embed_model_name: Optional[str] = None, complexity_model_name: Optional[str] = None, sufficiency_model_name: Optional[str] = None, skip_llm_generation: bool = False ): """ 初始化迭代检索器 Args: keyword: ES索引关键词 top_k: 每次检索返回的文档数量 max_iterations: 最大迭代次数 max_parallel_retrievals: 最大并行检索数 oneapi_key: OneAPI密钥 oneapi_base_url: OneAPI基础URL model_name: 主LLM模型名称(用于生成答案) embed_model_name: 嵌入模型名称 complexity_model_name: 复杂度判断模型(如不指定则使用model_name) sufficiency_model_name: 充分性检查模型(如不指定则使用model_name) skip_llm_generation: 是否跳过LLM生成答案(仅返回检索结果) """ self.keyword = keyword self.top_k = top_k self.max_iterations = max_iterations self.max_parallel_retrievals = max_parallel_retrievals self.skip_llm_generation = skip_llm_generation # 使用默认值(如果没有指定特定模型) complexity_model_name = complexity_model_name or model_name sufficiency_model_name = sufficiency_model_name or model_name # 创建组件 print("[CONFIG] 初始化检索器组件...") self.retriever = create_langchain_hipporag_retriever( keyword=keyword, top_k=top_k, oneapi_key=oneapi_key, oneapi_base_url=oneapi_base_url, oneapi_model_gen=model_name, oneapi_model_embed=embed_model_name ) # 创建主LLM(用于生成答案) self.llm = create_oneapi_llm( oneapi_key=oneapi_key, oneapi_base_url=oneapi_base_url, model_name=model_name ) # 创建复杂度判断LLM(如果模型不同) if complexity_model_name != model_name: print(f" [INFO] 使用独立的复杂度判断模型: {complexity_model_name}") self.complexity_llm = create_oneapi_llm( oneapi_key=oneapi_key, oneapi_base_url=oneapi_base_url, model_name=complexity_model_name ) else: self.complexity_llm = self.llm # 创建充分性检查LLM(如果模型不同) if sufficiency_model_name != model_name: print(f" [INFO] 使用独立的充分性检查模型: {sufficiency_model_name}") self.sufficiency_llm = create_oneapi_llm( oneapi_key=oneapi_key, oneapi_base_url=oneapi_base_url, model_name=sufficiency_model_name ) else: self.sufficiency_llm = self.llm # 创建节点处理器 self.nodes = GraphNodes( retriever=self.retriever, llm=self.llm, complexity_llm=self.complexity_llm, sufficiency_llm=self.sufficiency_llm, keyword=keyword, max_parallel_retrievals=max_parallel_retrievals, skip_llm_generation=skip_llm_generation ) # 构建工作流图 self.workflow = self._build_workflow() print("[OK] 迭代检索器初始化完成") def _build_workflow(self) -> StateGraph: """构建LangGraph工作流""" print("[?] 构建工作流图...") # 创建状态图 workflow = StateGraph(QueryState) # 添加节点 # 新增:查询复杂度判断节点 workflow.add_node("query_complexity_check", self.nodes.query_complexity_check_node) # 新增:调试模式节点 workflow.add_node("debug_mode_node", self.nodes.debug_mode_node) # 简单查询路径 workflow.add_node("simple_vector_retrieval", self.nodes.simple_vector_retrieval_node) workflow.add_node("simple_answer_generation", self.nodes.simple_answer_generation_node) # 复杂查询路径(现有hipporag2逻辑) workflow.add_node("query_decomposition", self.nodes.query_decomposition_node) workflow.add_node("initial_retrieval", self.nodes.initial_retrieval_node) workflow.add_node("sufficiency_check", self.nodes.sufficiency_check_node) workflow.add_node("sub_query_generation", self.nodes.sub_query_generation_node) workflow.add_node("parallel_retrieval", self.nodes.parallel_retrieval_node) workflow.add_node("next_iteration", self.nodes.next_iteration_node) workflow.add_node("final_answer", self.nodes.final_answer_generation_node) # 设置入口点:从查询复杂度判断开始 workflow.set_entry_point("query_complexity_check") # 复杂度检查后进入调试模式节点 workflow.add_edge("query_complexity_check", "debug_mode_node") # 条件边:根据调试模式和复杂度判断结果决定路径 workflow.add_conditional_edges( "debug_mode_node", route_by_debug_mode, { "simple_vector_retrieval": "simple_vector_retrieval", "initial_retrieval": "query_decomposition" # 复杂路径先进入查询分解节点 } ) # 简单查询路径的边 workflow.add_edge("simple_vector_retrieval", "simple_answer_generation") workflow.add_edge("simple_answer_generation", END) # 复杂查询路径的边(包含查询分解逻辑) workflow.add_edge("query_decomposition", "initial_retrieval") # 查询分解后进入并行初始检索 workflow.add_edge("initial_retrieval", "sufficiency_check") # 条件边:根据充分性检查结果决定下一步 workflow.add_conditional_edges( "sufficiency_check", should_continue_retrieval, { "final_answer": "final_answer", "parallel_retrieval": "sub_query_generation", "next_iteration": "next_iteration" } ) workflow.add_edge("sub_query_generation", "parallel_retrieval") workflow.add_edge("parallel_retrieval", "next_iteration") # 并行检索后增加迭代次数 workflow.add_edge("next_iteration", "sufficiency_check") # 结束节点 workflow.add_edge("final_answer", END) return workflow.compile() def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]: """ 执行迭代检索 Args: query: 用户查询 mode: 调试模式,"0"=自动判断,"simple"=强制简单路径,"complex"=强制复杂路径 Returns: 包含最终答案和详细信息的字典 """ print(f"[STARTING] 开始迭代检索: {query}") start_time = time.time() # 创建初始状态 initial_state = create_initial_state( original_query=query, max_iterations=self.max_iterations, debug_mode=mode ) initial_state["debug_info"]["start_time"] = start_time try: # 执行工作流 final_state = self.workflow.invoke(initial_state) # 记录结束时间 end_time = time.time() final_state["debug_info"]["end_time"] = end_time final_state["debug_info"]["total_time"] = end_time - start_time print(f"[SUCCESS] 迭代检索完成,耗时 {end_time - start_time:.2f}秒") # 返回结果 return { "query": query, "answer": final_state["final_answer"], "query_complexity": final_state["query_complexity"], "is_complex_query": final_state["is_complex_query"], "iterations": final_state["current_iteration"], "total_passages": len(final_state["all_passages"]), "sub_queries": final_state["sub_queries"], "decomposed_sub_queries": final_state.get("decomposed_sub_queries", []), "initial_retrieval_details": final_state.get("initial_retrieval_details", {}), "sufficiency_check": final_state["sufficiency_check"], "all_passages": final_state["all_passages"], "debug_info": final_state["debug_info"], "state_summary": get_state_summary(final_state), "iteration_history": final_state["iteration_history"] } except Exception as e: print(f"[ERROR] 迭代检索失败: {e}") return { "query": query, "answer": f"抱歉,检索过程中遇到错误: {str(e)}", "error": str(e), "iterations": 0, "total_passages": 0, "sub_queries": [], "debug_info": {"error": str(e), "total_time": time.time() - start_time} } def retrieve_simple(self, query: str) -> str: """ 简单检索接口,只返回答案 Args: query: 用户查询 Returns: 最终答案字符串 """ result = self.retrieve(query) return result["answer"] def get_retrieval_stats(self) -> Dict[str, Any]: """ 获取检索器统计信息 Returns: 统计信息字典 """ return { "keyword": self.keyword, "top_k": self.top_k, "max_iterations": self.max_iterations, "max_parallel_retrievals": self.max_parallel_retrievals, "retriever_type": "IterativeRetriever with LangGraph", "model_info": { "llm_model": getattr(self.llm.oneapi_generator, 'model_name', 'unknown'), "embed_model": getattr(self.retriever.embedding_model, 'model_name', 'unknown') } } def create_iterative_retriever( keyword: str, top_k: int = 2, max_iterations: int = 2, max_parallel_retrievals: int = 2, **kwargs ) -> IterativeRetriever: """ 创建迭代检索器的便捷函数 Args: keyword: ES索引关键词 top_k: 每次检索返回的文档数量 max_iterations: 最大迭代次数 max_parallel_retrievals: 最大并行检索数 **kwargs: 其他参数 Returns: 迭代检索器实例 """ return IterativeRetriever( keyword=keyword, top_k=top_k, max_iterations=max_iterations, max_parallel_retrievals=max_parallel_retrievals, **kwargs )