first commit
This commit is contained in:
107
AIEC-RAG/retriver/langgraph/routing_functions.py
Normal file
107
AIEC-RAG/retriver/langgraph/routing_functions.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""
|
||||
LangGraph条件边路由函数
|
||||
专门定义工作流的路由决策逻辑,与节点逻辑分离
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from retriver.langgraph.graph_state import QueryState
|
||||
|
||||
|
||||
def route_by_debug_mode(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
|
||||
"""
|
||||
根据调试模式和查询复杂度决定路由
|
||||
|
||||
Args:
|
||||
state: 当前查询状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称:
|
||||
- "simple_vector_retrieval": 简单查询,直接进行向量检索
|
||||
- "initial_retrieval": 复杂查询,进入现有hipporag2逻辑
|
||||
"""
|
||||
if state['is_complex_query']:
|
||||
print(f"[RELOAD] 路由到复杂检索逻辑 (debug_mode: {state['debug_mode']})")
|
||||
return "initial_retrieval"
|
||||
else:
|
||||
print(f"[RELOAD] 路由到简单向量检索 (debug_mode: {state['debug_mode']})")
|
||||
return "simple_vector_retrieval"
|
||||
|
||||
|
||||
def route_by_complexity(state: QueryState) -> Literal["simple_vector_retrieval", "initial_retrieval"]:
|
||||
"""
|
||||
根据查询复杂度决定路由
|
||||
|
||||
Args:
|
||||
state: 当前查询状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称:
|
||||
- "simple_vector_retrieval": 简单查询,直接进行向量检索
|
||||
- "initial_retrieval": 复杂查询,进入现有hipporag2逻辑
|
||||
"""
|
||||
if state['is_complex_query']:
|
||||
print(f"[RELOAD] 复杂查询,进入HippoRAG2检索逻辑")
|
||||
return "initial_retrieval"
|
||||
else:
|
||||
print(f"[RELOAD] 简单查询,进入直接向量检索")
|
||||
return "simple_vector_retrieval"
|
||||
|
||||
|
||||
def should_continue_retrieval(state: QueryState) -> Literal["final_answer", "parallel_retrieval", "next_iteration"]:
|
||||
"""
|
||||
决策函数:决定迭代检索工作流的下一步
|
||||
|
||||
Args:
|
||||
state: 当前查询状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称:
|
||||
- "final_answer": 生成最终答案并结束
|
||||
- "parallel_retrieval": 执行子查询的并行检索
|
||||
- "next_iteration": 进入下一轮迭代
|
||||
|
||||
决策逻辑:
|
||||
1. 如果信息充分 -> "final_answer"
|
||||
2. 如果达到最大迭代次数且信息不充分 -> "final_answer" (直接生成最终答案)
|
||||
3. 如果有待处理的子查询且未达到最大迭代 -> "parallel_retrieval"
|
||||
4. 如果连续多次不充分且没有新子查询 -> "final_answer"(避免死循环)
|
||||
5. 其他情况 -> "next_iteration"
|
||||
"""
|
||||
# 如果信息充分,生成最终答案
|
||||
if state['is_sufficient']:
|
||||
print(f"[OK] 信息充分,生成最终答案")
|
||||
return "final_answer"
|
||||
|
||||
# 如果已经达到最大迭代次数但信息不充分,直接生成最终答案
|
||||
if state['current_iteration'] >= state['max_iterations']:
|
||||
print(f"[RELOAD] 达到最大迭代次数 ({state['max_iterations']}) 且信息不充分,直接生成最终答案")
|
||||
return "final_answer"
|
||||
|
||||
# 防止死循环:检查连续不充分的次数
|
||||
iteration_history = state.get('iteration_history', [])
|
||||
consecutive_insufficient = 0
|
||||
for i in range(len(iteration_history) - 1, -1, -1):
|
||||
if not iteration_history[i].get('is_sufficient', False):
|
||||
consecutive_insufficient += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果连续3次不充分且没有新的子查询,强制结束
|
||||
if consecutive_insufficient >= 3 and not state['current_sub_queries']:
|
||||
print(f"[WARNING] 连续{consecutive_insufficient}次检索不充分且无新子查询,避免死循环,生成最终答案")
|
||||
return "final_answer"
|
||||
|
||||
# 如果还有子查询需要处理,进行并行检索
|
||||
if state['current_sub_queries']:
|
||||
print(f"[RELOAD] 执行并行检索子查询 (剩余子查询: {len(state['current_sub_queries'])})")
|
||||
return "parallel_retrieval"
|
||||
|
||||
# 如果信息不充分且还没达到最大迭代次数,生成子查询
|
||||
if not state['is_sufficient'] and state['current_iteration'] < state['max_iterations']:
|
||||
print(f"[RELOAD] 信息不充分,生成子查询进行并行检索")
|
||||
return "parallel_retrieval" # 这会路由到 sub_query_generation → parallel_retrieval
|
||||
|
||||
# 否则,进行下一轮迭代(这种情况应该很少发生)
|
||||
print(f"[RELOAD] 开始下一轮迭代")
|
||||
return "next_iteration"
|
||||
|
||||
Reference in New Issue
Block a user