Files
AIEC-RAG---/AIEC-RAG/retriver/langgraph/routing_functions.py

108 lines
4.2 KiB
Python
Raw Permalink Normal View History

2025-09-25 10:33:37 +08:00
"""
LangGraph条件边路由函数
专门定义工作流的路由决策逻辑与节点逻辑分离
"""
from typing import Literal
from retriver.langgraph.graph_state import QueryState
def route_by_debug_mode(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
"""
根据调试模式和查询复杂度决定路由
Args:
state: 当前查询状态
Returns:
下一个节点名称:
- "simple_vector_retrieval": 简单查询直接进行向量检索
- "initial_retrieval": 复杂查询进入现有hipporag2逻辑
"""
if state['is_complex_query']:
print(f"[RELOAD] 路由到复杂检索逻辑 (debug_mode: {state['debug_mode']})")
return "initial_retrieval"
else:
print(f"[RELOAD] 路由到简单向量检索 (debug_mode: {state['debug_mode']})")
return "simple_vector_retrieval"
def route_by_complexity(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
"""
根据查询复杂度决定路由
Args:
state: 当前查询状态
Returns:
下一个节点名称:
- "simple_vector_retrieval": 简单查询直接进行向量检索
- "initial_retrieval": 复杂查询进入现有hipporag2逻辑
"""
if state['is_complex_query']:
print(f"[RELOAD] 复杂查询进入HippoRAG2检索逻辑")
return "initial_retrieval"
else:
print(f"[RELOAD] 简单查询,进入直接向量检索")
return "simple_vector_retrieval"
def should_continue_retrieval(state: QueryState) -> Literal["final_answer", "parallel_retrieval", "next_iteration"]:
"""
决策函数决定迭代检索工作流的下一步
Args:
state: 当前查询状态
Returns:
下一个节点名称:
- "final_answer": 生成最终答案并结束
- "parallel_retrieval": 执行子查询的并行检索
- "next_iteration": 进入下一轮迭代
决策逻辑:
1. 如果信息充分 -> "final_answer"
2. 如果达到最大迭代次数且信息不充分 -> "final_answer" (直接生成最终答案)
3. 如果有待处理的子查询且未达到最大迭代 -> "parallel_retrieval"
4. 如果连续多次不充分且没有新子查询 -> "final_answer"(避免死循环)
5. 其他情况 -> "next_iteration"
"""
# 如果信息充分,生成最终答案
if state['is_sufficient']:
print(f"[OK] 信息充分,生成最终答案")
return "final_answer"
# 如果已经达到最大迭代次数但信息不充分,直接生成最终答案
if state['current_iteration'] >= state['max_iterations']:
print(f"[RELOAD] 达到最大迭代次数 ({state['max_iterations']}) 且信息不充分,直接生成最终答案")
return "final_answer"
# 防止死循环:检查连续不充分的次数
iteration_history = state.get('iteration_history', [])
consecutive_insufficient = 0
for i in range(len(iteration_history) - 1, -1, -1):
if not iteration_history[i].get('is_sufficient', False):
consecutive_insufficient += 1
else:
break
# 如果连续3次不充分且没有新的子查询强制结束
if consecutive_insufficient >= 3 and not state['current_sub_queries']:
print(f"[WARNING] 连续{consecutive_insufficient}次检索不充分且无新子查询,避免死循环,生成最终答案")
return "final_answer"
# 如果还有子查询需要处理,进行并行检索
if state['current_sub_queries']:
print(f"[RELOAD] 执行并行检索子查询 (剩余子查询: {len(state['current_sub_queries'])})")
return "parallel_retrieval"
# 如果信息不充分且还没达到最大迭代次数,生成子查询
if not state['is_sufficient'] and state['current_iteration'] < state['max_iterations']:
print(f"[RELOAD] 信息不充分,生成子查询进行并行检索")
return "parallel_retrieval" # 这会路由到 sub_query_generation → parallel_retrieval
# 否则,进行下一轮迭代(这种情况应该很少发生)
print(f"[RELOAD] 开始下一轮迭代")
return "next_iteration"