Files
AIEC-new/AIEC-RAG/retriver/langgraph/graph_state.py

299 lines
8.1 KiB
Python
Raw Permalink Normal View History

2025-10-17 09:31:28 +08:00
"""
LangGraph状态定义
定义工作流中的状态结构
"""
from typing import List, Dict, Any, Optional, TypedDict, Callable
from dataclasses import dataclass
from langchain_core.documents import Document
class QueryState(TypedDict):
    """State structure shared across the LangGraph query workflow."""

    # ---- Basic query information ----
    original_query: str
    current_iteration: int
    max_iterations: int

    # ---- Routing / debug mode ----
    # "0" = decide automatically, "simple" = force the simple path,
    # "complex" = force the complex path.
    debug_mode: str

    # ---- Query complexity ----
    query_complexity: Dict[str, Any]
    is_complex_query: bool

    # ---- Retrieval results ----
    # Field name kept for compatibility with existing code; it now stores
    # event information.
    all_passages: List[str]
    # Now contains event-node documents.
    all_documents: List[Document]
    # Source information for the events.
    passage_sources: List[str]

    # ---- Sub-queries ----
    sub_queries: List[str]
    current_sub_queries: List[str]
    # Initial sub-queries produced by query decomposition.
    decomposed_sub_queries: List[str]

    # ---- Sufficiency check ----
    sufficiency_check: Dict[str, Any]
    is_sufficient: bool

    # ---- Final result ----
    final_answer: str

    # ---- Debug information ----
    debug_info: Dict[str, Any]
    iteration_history: List[Dict[str, Any]]

    # ---- Concept exploration (added later) ----
    # NOTE: PageRank data is no longer kept in the state, to avoid sending it
    # to LangSmith; it lives in local temporary storage instead.
    concept_exploration_results: Dict[str, Any]
    # Current exploration round (1 or 2).
    exploration_round: int
    # Flag for the locally stored PageRank data; carries no actual data.
    pagerank_data_available: bool

    # ---- Streaming ----
    # Optional streaming callback.
    stream_callback: Optional[Callable]

    # ---- Event-to-event triples (added later) ----
    event_triples: List[Dict[str, str]]

    # ---- Task management (added later) ----
    # Task identifier.
    task_id: Optional[str]
    # Whether processing should stop.
    should_stop: bool
@dataclass
class RetrievalResult:
    """Data class holding the result of one retrieval.

    Attributes:
        passages: Retrieved passages; appended to ``all_passages`` when the
            result is merged into the state.
        documents: Retrieved documents; appended to ``all_documents``.
        sources: Source entries; appended to ``passage_sources``.
        query: Query recorded alongside this result in the iteration history.
        iteration: Iteration number recorded in the iteration history.
    """

    passages: List[str]
    documents: List[Document]
    sources: List[str]
    query: str
    iteration: int
@dataclass
class SufficiencyCheck:
    """Data class holding the result of a sufficiency check.

    Attributes:
        is_sufficient: Whether the check judged the current state sufficient.
        confidence: Confidence value reported with the verdict.
        reason: Explanation reported with the verdict.
        sub_queries: Optional follow-up sub-queries; defaults to ``None``.
    """

    is_sufficient: bool
    confidence: float
    reason: str
    sub_queries: Optional[List[str]] = None
def create_initial_state(
    original_query: str,
    max_iterations: int = 2,
    debug_mode: str = "0",
    task_id: Optional[str] = None
) -> QueryState:
    """Build the initial workflow state for a new query.

    Args:
        original_query: The user's original query.
        max_iterations: Maximum number of iterations.
        debug_mode: "0" = decide automatically, "simple" = force the simple
            path, "complex" = force the complex path.
        task_id: Optional task identifier.

    Returns:
        The initial state dictionary.
    """
    # Call counters start at zero; the timestamps are left unset.
    debug_counters: Dict[str, Any] = {
        "retrieval_calls": 0,
        "llm_calls": 0,
        "start_time": None,
        "end_time": None
    }

    # Key order mirrors the field order declared on QueryState.
    initial_fields: Dict[str, Any] = {
        "original_query": original_query,
        "current_iteration": 0,
        "max_iterations": max_iterations,
        "debug_mode": debug_mode,
        "query_complexity": {},
        "is_complex_query": False,
        "all_passages": [],
        "all_documents": [],
        "passage_sources": [],
        "sub_queries": [],
        "current_sub_queries": [],
        "decomposed_sub_queries": [],
        "sufficiency_check": {},
        "is_sufficient": False,
        "final_answer": "",
        "debug_info": debug_counters,
        "iteration_history": [],
        "concept_exploration_results": {},
        "exploration_round": 0,
        "pagerank_data_available": False,
        # Starts as None; set later when it is needed.
        "stream_callback": None,
        "event_triples": [],
        "task_id": task_id,
        # Processing is not stopped initially.
        "should_stop": False,
    }
    return QueryState(**initial_fields)
def update_state_with_retrieval(
    state: QueryState,
    retrieval_result: RetrievalResult
) -> QueryState:
    """Merge a retrieval result into the state.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        retrieval_result: Retrieval result to merge.

    Returns:
        The updated state.
    """
    # Accumulate the newly retrieved passages, documents and sources.
    new_passages = retrieval_result.passages
    state["all_passages"].extend(new_passages)
    state["all_documents"].extend(retrieval_result.documents)
    state["passage_sources"].extend(retrieval_result.sources)

    # Count this retrieval call in the debug info.
    counters = state["debug_info"]
    counters["retrieval_calls"] = counters["retrieval_calls"] + 1

    # Record the step in the iteration history.
    state["iteration_history"].append({
        "iteration": retrieval_result.iteration,
        "query": retrieval_result.query,
        "passages_count": len(new_passages),
        "action": "retrieval"
    })
    return state
def update_state_with_sufficiency_check(
    state: QueryState,
    sufficiency_check: SufficiencyCheck
) -> QueryState:
    """Record a sufficiency-check result in the state.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        sufficiency_check: Result of the sufficiency check.

    Returns:
        The updated state.
    """
    # Copy the sub-queries so the state never shares a list object with the
    # check result; otherwise later edits to state["current_sub_queries"]
    # would silently mutate ``sufficiency_check.sub_queries``.
    follow_ups = list(sufficiency_check.sub_queries or [])

    state["is_sufficient"] = sufficiency_check.is_sufficient
    state["sufficiency_check"] = {
        "is_sufficient": sufficiency_check.is_sufficient,
        "confidence": sufficiency_check.confidence,
        "reason": sufficiency_check.reason,
        "iteration": state["current_iteration"]
    }

    # Only an insufficient verdict that comes with sub-queries schedules
    # follow-up work; every other case clears the pending sub-queries.
    if not sufficiency_check.is_sufficient and follow_ups:
        state["current_sub_queries"] = follow_ups
        state["sub_queries"].extend(follow_ups)
    else:
        state["current_sub_queries"] = []

    # Count this check as one LLM call in the debug info.
    state["debug_info"]["llm_calls"] += 1

    # Record the step in the iteration history.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "sufficiency_check",
        "is_sufficient": sufficiency_check.is_sufficient,
        "confidence": sufficiency_check.confidence,
        "sub_queries_count": len(follow_ups)
    })
    return state
def increment_iteration(state: QueryState) -> QueryState:
    """Advance the iteration counter by one.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.

    Returns:
        The updated state.
    """
    next_iteration = state["current_iteration"] + 1
    state["current_iteration"] = next_iteration
    return state
def finalize_state(state: QueryState, final_answer: str) -> QueryState:
    """Store the final answer and record the step in the iteration history.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        final_answer: Final answer text. ``None`` is treated as an empty
            answer.

    Returns:
        The final state.
    """
    # Normalize a missing answer before touching the state. Otherwise None
    # would be written to state["final_answer"] and len(None) would then
    # raise, leaving the state half-updated.
    answer = "" if final_answer is None else final_answer
    state["final_answer"] = answer

    # Record the step in the iteration history.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "final_answer_generation",
        "answer_length": len(answer)
    })
    return state
def update_state_with_event_triples(
    state: QueryState,
    event_triples: List[Dict[str, str]]
) -> QueryState:
    """Replace the event triples held in the state.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        event_triples: Event-to-event triples, each a dict with the keys
            ``source_entity``, ``relation``, ``target_entity``,
            ``source_evidence`` and ``target_evidence``. ``None`` is treated
            as an empty list.

    Returns:
        The updated state.
    """
    # Store a copy so later changes to the caller's list do not leak into the
    # state (and vice versa). Normalizing None first also avoids crashing on
    # len(None) after the state has already been modified.
    triples = list(event_triples) if event_triples is not None else []
    state["event_triples"] = triples

    # Record the step in the iteration history.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "event_triples_extraction",
        "triples_count": len(triples)
    })
    return state
def get_state_summary(state: QueryState) -> Dict[str, Any]:
    """Return a compact summary of the state.

    Args:
        state: Current state.

    Returns:
        A dictionary with the query, iteration counters, passage and
        sub-query counts, sufficiency flag, whether a final answer exists,
        the debug info (the same object as in the state, not a copy) and the
        number of iteration-history entries.
    """
    summary: Dict[str, Any] = {}

    # Values copied through unchanged.
    for key in ("original_query", "current_iteration", "max_iterations"):
        summary[key] = state[key]

    # Sizes of the accumulated collections.
    summary["total_passages"] = len(state["all_passages"])
    summary["total_sub_queries"] = len(state["sub_queries"])

    summary["is_sufficient"] = state["is_sufficient"]
    # Any non-empty answer counts as having a final answer.
    summary["has_final_answer"] = bool(state["final_answer"])
    summary["debug_info"] = state["debug_info"]
    summary["iteration_count"] = len(state["iteration_history"])
    return summary