Files
AIEC-RAG/retriver/langgraph/graph_state.py

256 lines
6.8 KiB
Python
Raw Normal View History

2025-09-24 09:29:12 +08:00
"""
LangGraph状态定义
定义工作流中的状态结构
"""
from typing import List, Dict, Any, Optional, TypedDict
from dataclasses import dataclass
from langchain_core.documents import Document
class QueryState(TypedDict):
    """Typed dictionary describing the state passed through the query workflow.

    The field names are the dictionary keys used at runtime, so they must
    stay exactly as declared here.
    """

    # --- Basic query information ---
    original_query: str
    current_iteration: int
    max_iterations: int

    # Debug mode: "0" = decide automatically, "simple" = force the simple
    # path, "complex" = force the complex path.
    debug_mode: str

    # --- Query complexity ---
    query_complexity: Dict[str, Any]
    is_complex_query: bool

    # --- Retrieval results ---
    # Field name kept for compatibility with existing code; it now stores
    # event information.
    all_passages: List[str]
    # Now contains event-node documents.
    all_documents: List[Document]
    # Source information for the events.
    passage_sources: List[str]

    # --- Sub-queries ---
    sub_queries: List[str]
    current_sub_queries: List[str]
    # Initial sub-queries generated by query decomposition.
    decomposed_sub_queries: List[str]

    # --- Sufficiency check ---
    sufficiency_check: Dict[str, Any]
    is_sufficient: bool

    # --- Final result ---
    final_answer: str

    # --- Debug information ---
    debug_info: Dict[str, Any]
    iteration_history: List[Dict[str, Any]]

    # --- Concept exploration (newly added) ---
    # Note: PageRank data is no longer stored in the state, to avoid
    # transmitting it to LangSmith; it lives in local temporary storage.
    concept_exploration_results: Dict[str, Any]
    # Current exploration round (1 or 2).
    exploration_round: int
    # Marker for the locally stored PageRank data; holds no actual data.
    pagerank_data_available: bool
@dataclass
class RetrievalResult:
    """Container for the outcome of one retrieval call.

    Attributes:
        passages: Retrieved passage texts.
        documents: Retrieved documents.
        sources: Source information for the retrieved items.
        query: Query string associated with this result.
        iteration: Iteration number associated with this result.
    """

    passages: List[str]
    documents: List[Document]
    sources: List[str]
    query: str
    iteration: int
@dataclass
class SufficiencyCheck:
    """Container for the result of a sufficiency check.

    Attributes:
        is_sufficient: Whether the check judged the information sufficient.
        confidence: Confidence value reported by the check.
        reason: Explanation reported by the check.
        sub_queries: Optional follow-up sub-queries; defaults to ``None``.
    """

    is_sufficient: bool
    confidence: float
    reason: str
    sub_queries: Optional[List[str]] = None
def create_initial_state(
    original_query: str,
    max_iterations: int = 2,
    debug_mode: str = "0"
) -> QueryState:
    """Build the state a workflow run starts from.

    Args:
        original_query: The user's original query.
        max_iterations: Maximum number of iterations.
        debug_mode: "0" decides the path automatically, "simple" forces the
            simple path, "complex" forces the complex path.

    Returns:
        A new state dictionary: iteration and exploration round at 0, all
        collections empty, both boolean outcome flags ``False``, an empty
        final answer, and debug counters at zero with unset timestamps.
    """
    # Call counters start at zero; the timestamps are left unset here.
    debug_counters: Dict[str, Any] = {
        "retrieval_calls": 0,
        "llm_calls": 0,
        "start_time": None,
        "end_time": None,
    }

    # A TypedDict instance is a plain dict at runtime, so a literal yields
    # the same object a QueryState(...) call would.
    initial: QueryState = {
        "original_query": original_query,
        "current_iteration": 0,
        "max_iterations": max_iterations,
        "debug_mode": debug_mode,
        "query_complexity": {},
        "is_complex_query": False,
        "all_passages": [],
        "all_documents": [],
        "passage_sources": [],
        "sub_queries": [],
        "current_sub_queries": [],
        "decomposed_sub_queries": [],
        "sufficiency_check": {},
        "is_sufficient": False,
        "final_answer": "",
        "debug_info": debug_counters,
        "iteration_history": [],
        "concept_exploration_results": {},
        "exploration_round": 0,
        "pagerank_data_available": False,
    }
    return initial
def update_state_with_retrieval(
    state: QueryState,
    retrieval_result: RetrievalResult
) -> QueryState:
    """Merge one retrieval result into the state, in place.

    Args:
        state: Current state; it is mutated and also returned.
        retrieval_result: The retrieval result to merge.

    Returns:
        The same ``state`` object, with the result's passages, documents and
        sources appended, the retrieval call counter increased by one, and a
        "retrieval" entry appended to the iteration history.
    """
    # Append the new passages, documents and sources to the accumulated ones.
    state["all_passages"].extend(retrieval_result.passages)
    state["all_documents"].extend(retrieval_result.documents)
    state["passage_sources"].extend(retrieval_result.sources)

    # Count this retrieval call in the debug information.
    debug_info = state["debug_info"]
    debug_info["retrieval_calls"] = debug_info["retrieval_calls"] + 1

    # Record the step in the iteration history.
    history_entry = dict(
        iteration=retrieval_result.iteration,
        query=retrieval_result.query,
        passages_count=len(retrieval_result.passages),
        action="retrieval",
    )
    state["iteration_history"].append(history_entry)

    return state
def update_state_with_sufficiency_check(
    state: QueryState,
    sufficiency_check: SufficiencyCheck
) -> QueryState:
    """Record a sufficiency-check result in the state, in place.

    Args:
        state: Current state; it is mutated and also returned.
        sufficiency_check: The sufficiency-check result to record.

    Returns:
        The same ``state`` object. ``is_sufficient`` and the
        ``sufficiency_check`` summary are overwritten, the LLM call counter
        is increased by one, and a "sufficiency_check" entry is appended to
        the iteration history. When the check is not sufficient and carries
        sub-queries, they become ``current_sub_queries`` and are appended to
        ``sub_queries``; otherwise ``current_sub_queries`` is reset to an
        empty list.
    """
    is_sufficient = sufficiency_check.is_sufficient
    confidence = sufficiency_check.confidence

    # Copy the sub-queries so the state never shares a list object with the
    # check result: an in-place change to one must not leak into the other.
    follow_ups = list(sufficiency_check.sub_queries or [])

    state["is_sufficient"] = is_sufficient
    state["sufficiency_check"] = {
        "is_sufficient": is_sufficient,
        "confidence": confidence,
        "reason": sufficiency_check.reason,
        "iteration": state["current_iteration"],
    }

    # Only an insufficient result that proposes sub-queries schedules more
    # work; in every other case the pending sub-queries are cleared.
    if not is_sufficient and follow_ups:
        state["current_sub_queries"] = follow_ups
        state["sub_queries"].extend(follow_ups)
    else:
        state["current_sub_queries"] = []

    # Count the LLM call in the debug information.
    state["debug_info"]["llm_calls"] += 1

    # Record the step in the iteration history. The count reflects the
    # sub-queries carried by the check, whether or not they were scheduled.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "sufficiency_check",
        "is_sufficient": is_sufficient,
        "confidence": confidence,
        "sub_queries_count": len(follow_ups),
    })
    return state
def increment_iteration(state: QueryState) -> QueryState:
    """Advance the iteration counter by one, in place.

    Args:
        state: Current state; it is mutated and also returned.

    Returns:
        The same ``state`` object with ``current_iteration`` increased by 1.
    """
    next_iteration = state["current_iteration"] + 1
    state["current_iteration"] = next_iteration
    return state
def finalize_state(state: QueryState, final_answer: str) -> QueryState:
    """Store the final answer in the state, in place.

    Args:
        state: Current state; it is mutated and also returned.
        final_answer: The final answer text.

    Returns:
        The same ``state`` object with ``final_answer`` set and a
        "final_answer_generation" entry appended to the iteration history.
    """
    state["final_answer"] = final_answer

    # Record the step in the iteration history, noting the answer's length.
    history_entry = dict(
        iteration=state["current_iteration"],
        action="final_answer_generation",
        answer_length=len(final_answer),
    )
    state["iteration_history"].append(history_entry)

    return state
def get_state_summary(state: QueryState) -> Dict[str, Any]:
    """Produce a compact summary of the state.

    Args:
        state: Current state; it is only read.

    Returns:
        A new dictionary with the query, the iteration counters, the numbers
        of passages, sub-queries and history entries, the sufficiency flag,
        whether a non-empty final answer exists, and the state's
        ``debug_info`` dictionary (the same object, not a copy).
    """
    # Values copied over under their original key names.
    summary: Dict[str, Any] = {
        key: state[key]
        for key in ("original_query", "current_iteration", "max_iterations")
    }

    # Sizes and flags derived from the accumulated data.
    summary["total_passages"] = len(state["all_passages"])
    summary["total_sub_queries"] = len(state["sub_queries"])
    summary["is_sufficient"] = state["is_sufficient"]
    summary["has_final_answer"] = bool(state["final_answer"])
    summary["debug_info"] = state["debug_info"]
    summary["iteration_count"] = len(state["iteration_history"])
    return summary