108 lines
4.3 KiB
Python
108 lines
4.3 KiB
Python
|
|
"""
LangGraph conditional-edge routing functions.

Defines the workflow's routing decisions, kept separate from the node logic.
"""
|
|||
|
|
|
|||
|
|
from typing import Literal
|
|||
|
|
from retriver.langgraph.graph_state import QueryState
|
|||
|
|
|
|||
|
|
|
|||
|
|
def route_by_debug_mode(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Choose the retrieval entry node for the current query.

    Despite the function name, the branch is decided only by
    ``state['is_complex_query']``. ``state['debug_mode']`` is read solely to
    be echoed in the log line.

    Args:
        state: Current query state; must provide ``is_complex_query`` and
            ``debug_mode``.

    Returns:
        Name of the next node:
        - ``"simple_vector_retrieval"``: simple query, go straight to vector
          retrieval.
        - ``"initial_retrieval"``: complex query, enter the existing
          HippoRAG2 retrieval flow.
    """
    # Guard clause for the simple path; the complex path falls through.
    # debug_mode is looked up inside each branch so a missing key surfaces
    # only after is_complex_query has been evaluated.
    if not state['is_complex_query']:
        print(f"[RELOAD] 路由到简单向量检索 (debug_mode: {state['debug_mode']})")
        return "simple_vector_retrieval"

    print(f"[RELOAD] 路由到复杂检索逻辑 (debug_mode: {state['debug_mode']})")
    return "initial_retrieval"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def route_by_complexity(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
    """Route the query to a retrieval node based on its complexity flag.

    Args:
        state: Current query state; must provide ``is_complex_query``.

    Returns:
        Name of the next node:
        - ``"initial_retrieval"``: complex query, enter the existing
          HippoRAG2 retrieval flow.
        - ``"simple_vector_retrieval"``: simple query, go straight to vector
          retrieval.
    """
    # Pick (target node, log message) in one expression, then emit both.
    next_node, message = (
        ("initial_retrieval", "[RELOAD] 复杂查询,进入HippoRAG2检索逻辑")
        if state['is_complex_query']
        else ("simple_vector_retrieval", "[RELOAD] 简单查询,进入直接向量检索")
    )
    print(message)
    return next_node
|
|||
|
|
|
|||
|
|
|
|||
|
|
def should_continue_retrieval(state: QueryState) -> Literal["final_answer", "parallel_retrieval"]:
    """Decide the next step of the iterative retrieval workflow.

    Args:
        state: Current query state. Reads ``is_sufficient``,
            ``current_iteration``, ``max_iterations``, ``current_sub_queries``
            and, optionally, ``iteration_history`` (a sequence of dict-like
            records whose ``is_sufficient`` flag is inspected; a missing key
            or ``None`` is treated as an empty history).

    Returns:
        Name of the next node:
        - ``"final_answer"``: generate the final answer and finish.
        - ``"parallel_retrieval"``: process sub-queries with parallel
          retrieval (including the case where sub-queries still have to be
          generated).

    Decision order:
        1. Information is sufficient -> ``"final_answer"``.
        2. ``current_iteration >= max_iterations`` (and still insufficient)
           -> ``"final_answer"``, answering with what has been gathered.
        3. The most recent 3+ history entries are all insufficient and there
           are no pending sub-queries -> ``"final_answer"`` (loop guard).
        4. Pending sub-queries exist -> ``"parallel_retrieval"``.
        5. Otherwise (insufficient, iterations left, no sub-queries yet) ->
           ``"parallel_retrieval"``.
    """
    # 1. Enough information: finish.
    if state['is_sufficient']:
        print(f"[OK] 信息充分,生成最终答案")
        return "final_answer"

    # 2. Iteration budget exhausted while still insufficient: finish anyway.
    if state['current_iteration'] >= state['max_iterations']:
        print(f"[RELOAD] 达到最大迭代次数 ({state['max_iterations']}) 且信息不充分,直接生成最终答案")
        return "final_answer"

    # Loop guard: length of the trailing run of insufficient iterations,
    # counted from the newest record back to the first sufficient one.
    # `or []` also tolerates an explicit None stored under the key.
    iteration_history = state.get('iteration_history') or []
    consecutive_insufficient = 0
    for record in reversed(iteration_history):
        if record.get('is_sufficient', False):
            break
        consecutive_insufficient += 1

    sub_queries = state['current_sub_queries']

    # 3. Repeatedly insufficient and nothing new to ask: stop to avoid an
    #    endless loop.
    if consecutive_insufficient >= 3 and not sub_queries:
        print(f"[WARNING] 连续{consecutive_insufficient}次检索不充分且无新子查询,避免死循环,生成最终答案")
        return "final_answer"

    # 4. Pending sub-queries: retrieve them in parallel.
    if sub_queries:
        print(f"[RELOAD] 执行并行检索子查询 (剩余子查询: {len(sub_queries)})")
        return "parallel_retrieval"

    # 5. Guards 1 and 2 already guarantee "insufficient and iterations left"
    #    here, so no further condition is needed.
    # NOTE(review): an earlier comment said this edge reaches
    # sub_query_generation before parallel_retrieval; that mapping lives in
    # the graph's conditional-edge table, not here — confirm there.
    print(f"[RELOAD] 信息不充分,生成子查询进行并行检索")
    return "parallel_retrieval"
|
|||
|
|
|