Files
AIEC-RAG---/AIEC-RAG/retriver/langgraph/routing_functions.py
2025-09-25 10:33:37 +08:00

108 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LangGraph条件边路由函数
专门定义工作流的路由决策逻辑,与节点逻辑分离
"""
from typing import Literal
from retriver.langgraph.graph_state import QueryState
def route_by_debug_mode(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Choose the retrieval entry node for the current query.

    The route is selected from ``state['is_complex_query']`` alone;
    ``state['debug_mode']`` is read only so that it can be echoed in the
    log line and does not influence the returned node name.

    Args:
        state: Current query state. Must provide ``is_complex_query`` and
            ``debug_mode``.

    Returns:
        Name of the next node:
        - "simple_vector_retrieval": simple query, go straight to vector
          retrieval.
        - "initial_retrieval": complex query, enter the existing hipporag2
          logic.
    """
    takes_complex_path = state['is_complex_query']
    debug_flag = state['debug_mode']

    # Guard clause: handle the simple path first, fall through to the complex one.
    if not takes_complex_path:
        print(f"[RELOAD] 路由到简单向量检索 (debug_mode: {debug_flag})")
        return "simple_vector_retrieval"

    print(f"[RELOAD] 路由到复杂检索逻辑 (debug_mode: {debug_flag})")
    return "initial_retrieval"
def route_by_complexity(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Choose the retrieval entry node from the query's complexity flag.

    Args:
        state: Current query state. Must provide ``is_complex_query``.

    Returns:
        Name of the next node:
        - "simple_vector_retrieval": simple query, go straight to vector
          retrieval.
        - "initial_retrieval": complex query, enter the existing hipporag2
          logic.
    """
    # Select the target node together with its log line, then emit once.
    if state['is_complex_query']:
        next_node, notice = "initial_retrieval", "[RELOAD] 复杂查询进入HippoRAG2检索逻辑"
    else:
        next_node, notice = "simple_vector_retrieval", "[RELOAD] 简单查询,进入直接向量检索"

    print(notice)
    return next_node
def should_continue_retrieval(state: QueryState) -> Literal["final_answer", "parallel_retrieval", "next_iteration"]:
    """Decide the next step of the iterative retrieval workflow.

    Args:
        state: Current query state. Reads the required keys ``is_sufficient``,
            ``current_iteration``, ``max_iterations`` and
            ``current_sub_queries``, plus the optional ``iteration_history``
            (a sequence of dict-like records, each optionally carrying an
            ``is_sufficient`` flag). A missing or ``None`` history is treated
            as empty.

    Returns:
        Name of the next node:
        - "final_answer": generate the final answer and finish.
        - "parallel_retrieval": run parallel retrieval for sub-queries.
        - "next_iteration": start another iteration.

    Decision order:
        1. Information is sufficient -> "final_answer".
        2. Maximum iteration count reached while still insufficient ->
           "final_answer".
        3. The most recent 3 or more history records are all insufficient and
           there are no pending sub-queries -> "final_answer" (avoids looping
           forever).
        4. Pending sub-queries exist -> "parallel_retrieval".
        5. Still insufficient and below the iteration limit ->
           "parallel_retrieval".
        6. Anything else -> "next_iteration".

    Raises:
        KeyError: If one of the required state keys is absent.
    """
    # 1. Enough information has been gathered.
    if state['is_sufficient']:
        print("[OK] 信息充分,生成最终答案")
        return "final_answer"

    # 2. Iteration budget exhausted: answer with whatever has been collected.
    if state['current_iteration'] >= state['max_iterations']:
        print(f"[RELOAD] 达到最大迭代次数 ({state['max_iterations']}) 且信息不充分,直接生成最终答案")
        return "final_answer"

    # 3. Dead-loop protection: count the trailing run of insufficient rounds.
    #    `or []` also covers a history key that is present but set to None,
    #    which previously crashed on len(None).
    iteration_history = state.get('iteration_history') or []
    consecutive_insufficient = 0
    for record in reversed(iteration_history):
        # A record without the flag counts as insufficient.
        if record.get('is_sufficient', False):
            break
        consecutive_insufficient += 1

    stall_limit = 3  # trailing insufficient rounds tolerated before giving up
    pending_sub_queries = state['current_sub_queries']
    if consecutive_insufficient >= stall_limit and not pending_sub_queries:
        print(f"[WARNING] 连续{consecutive_insufficient}次检索不充分且无新子查询,避免死循环,生成最终答案")
        return "final_answer"

    # 4. Sub-queries are still waiting to be retrieved.
    if pending_sub_queries:
        print(f"[RELOAD] 执行并行检索子查询 (剩余子查询: {len(pending_sub_queries)})")
        return "parallel_retrieval"

    # 5. Still insufficient with iterations left. Per the original author's note,
    #    this return value routes to sub_query_generation -> parallel_retrieval.
    #    NOTE: after guards 1 and 2 this condition holds for ordinary bool/int
    #    state values, so the fall-through below is effectively a safety net.
    if not state['is_sufficient'] and state['current_iteration'] < state['max_iterations']:
        print("[RELOAD] 信息不充分,生成子查询进行并行检索")
        return "parallel_retrieval"

    # 6. Fallback: move on to the next iteration.
    print("[RELOAD] 开始下一轮迭代")
    return "next_iteration"