Files
AIEC-new/AIEC-RAG/retriver/langgraph/graph_state.py
2025-10-17 09:31:28 +08:00

299 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LangGraph状态定义
定义工作流中的状态结构
"""
from typing import List, Dict, Any, Optional, TypedDict, Callable
from dataclasses import dataclass
from langchain_core.documents import Document
class QueryState(TypedDict):
    """Typed layout of the query state used by the workflow."""

    # --- Basic query information ---
    original_query: str
    current_iteration: int
    max_iterations: int

    # Debug mode: "0" = decide automatically, 'simple' = force the simple
    # path, 'complex' = force the complex path.
    debug_mode: str

    # --- Query complexity ---
    query_complexity: Dict[str, Any]
    is_complex_query: bool

    # --- Retrieval results ---
    # Field name kept for compatibility with existing code; it now stores
    # event information.
    all_passages: List[str]
    # Now contains event-node documents.
    all_documents: List[Document]
    # Event source information.
    passage_sources: List[str]

    # --- Sub-queries ---
    sub_queries: List[str]
    current_sub_queries: List[str]
    # Initial sub-queries produced by query decomposition.
    decomposed_sub_queries: List[str]

    # --- Sufficiency check ---
    sufficiency_check: Dict[str, Any]
    is_sufficient: bool

    # --- Final result ---
    final_answer: str

    # --- Debug information ---
    debug_info: Dict[str, Any]
    iteration_history: List[Dict[str, Any]]

    # --- Concept exploration (added later) ---
    # NOTE: PageRank data is no longer stored in the state, to avoid sending
    # it through LangSmith; it is kept in local temporary storage instead.
    concept_exploration_results: Dict[str, Any]
    # Current exploration round (1 or 2).
    exploration_round: int
    # Flag for the locally stored PageRank data; carries no actual data.
    pagerank_data_available: bool

    # --- Streaming callback (optional) ---
    stream_callback: Optional[Callable]

    # --- Event-to-event triples (added later) ---
    event_triples: List[Dict[str, str]]

    # --- Task management (added later) ---
    # Task ID.
    task_id: Optional[str]
    # Whether processing should stop.
    should_stop: bool
@dataclass
class RetrievalResult:
    """Data class holding the result of one retrieval.

    Attributes:
        passages: Retrieved passages.
        documents: Retrieved documents.
        sources: Sources of the retrieved passages.
        query: Query associated with this result.
        iteration: Iteration number associated with this result.
    """

    passages: List[str]
    documents: List[Document]
    sources: List[str]
    query: str
    iteration: int
@dataclass
class SufficiencyCheck:
    """Data class holding the result of a sufficiency check.

    Attributes:
        is_sufficient: Whether the check judged the result sufficient.
        confidence: Confidence value attached to the judgement.
        reason: Reason given for the judgement.
        sub_queries: Optional list of sub-queries; defaults to ``None``.
    """

    is_sufficient: bool
    confidence: float
    reason: str
    sub_queries: Optional[List[str]] = None
def create_initial_state(
    original_query: str,
    max_iterations: int = 2,
    debug_mode: str = "0",
    task_id: Optional[str] = None
) -> QueryState:
    """
    Build the initial workflow state.

    Args:
        original_query: The user's original query.
        max_iterations: Maximum number of iterations.
        debug_mode: Debug mode; "0" = decide automatically,
            "simple" = force the simple path, "complex" = force the
            complex path.
        task_id: Task ID (optional).

    Returns:
        The initial state dictionary.
    """
    # Call counters start at zero; the timestamps are not set here.
    debug_counters = {
        "retrieval_calls": 0,
        "llm_calls": 0,
        "start_time": None,
        "end_time": None,
    }

    initial_fields = {
        "original_query": original_query,
        "current_iteration": 0,
        "max_iterations": max_iterations,
        "debug_mode": debug_mode,
        "query_complexity": {},
        "is_complex_query": False,
        "all_passages": [],
        "all_documents": [],
        "passage_sources": [],
        "sub_queries": [],
        "current_sub_queries": [],
        "decomposed_sub_queries": [],
        "sufficiency_check": {},
        "is_sufficient": False,
        "final_answer": "",
        "debug_info": debug_counters,
        "iteration_history": [],
        "concept_exploration_results": {},
        "exploration_round": 0,
        "pagerank_data_available": False,
        # Starts as None; set later when needed.
        "stream_callback": None,
        # Starts as an empty list.
        "event_triples": [],
        "task_id": task_id,
        # Starts in the "do not stop" state.
        "should_stop": False,
    }
    return QueryState(**initial_fields)
def update_state_with_retrieval(
    state: QueryState,
    retrieval_result: RetrievalResult
) -> QueryState:
    """
    Merge a retrieval result into the state.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        retrieval_result: Retrieval result to merge in.

    Returns:
        The updated state.
    """
    # Accumulate the newly retrieved passages, documents and sources.
    state["all_passages"].extend(retrieval_result.passages)
    state["all_documents"].extend(retrieval_result.documents)
    state["passage_sources"].extend(retrieval_result.sources)

    # Count this retrieval call in the debug information.
    counters = state["debug_info"]
    counters["retrieval_calls"] = counters["retrieval_calls"] + 1

    # Record the retrieval in the iteration history.
    history_entry = dict(
        iteration=retrieval_result.iteration,
        query=retrieval_result.query,
        passages_count=len(retrieval_result.passages),
        action="retrieval",
    )
    state["iteration_history"].append(history_entry)
    return state
def update_state_with_sufficiency_check(
    state: QueryState,
    sufficiency_check: SufficiencyCheck
) -> QueryState:
    """
    Merge a sufficiency-check result into the state.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        sufficiency_check: Result of the sufficiency check.

    Returns:
        The updated state.
    """
    sufficient = sufficiency_check.is_sufficient
    follow_ups = sufficiency_check.sub_queries

    state["is_sufficient"] = sufficient
    state["sufficiency_check"] = {
        "is_sufficient": sufficient,
        "confidence": sufficiency_check.confidence,
        "reason": sufficiency_check.reason,
        "iteration": state["current_iteration"],
    }

    # When the result is insufficient and sub-queries were supplied, they
    # become the current sub-queries and are added to the running total.
    if not sufficient and follow_ups:
        # Store a copy: the state must not share a list object with the
        # check result, otherwise editing one silently edits the other.
        state["current_sub_queries"] = list(follow_ups)
        state["sub_queries"].extend(follow_ups)
    else:
        state["current_sub_queries"] = []

    # Count this check as one LLM call in the debug information.
    state["debug_info"]["llm_calls"] += 1

    # Record the check in the iteration history. The count reflects what the
    # check supplied, even when the sub-queries were not adopted above.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "sufficiency_check",
        "is_sufficient": sufficient,
        "confidence": sufficiency_check.confidence,
        "sub_queries_count": len(follow_ups or []),
    })
    return state
def increment_iteration(state: QueryState) -> QueryState:
    """
    Advance the iteration counter by one.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.

    Returns:
        The updated state.
    """
    next_iteration = state["current_iteration"] + 1
    state["current_iteration"] = next_iteration
    return state
def finalize_state(state: QueryState, final_answer: str) -> QueryState:
    """
    Store the final answer and log it in the iteration history.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        final_answer: The final answer.

    Returns:
        The final state.
    """
    state["final_answer"] = final_answer

    # Record answer generation in the iteration history.
    history_entry = dict(
        iteration=state["current_iteration"],
        action="final_answer_generation",
        answer_length=len(final_answer),
    )
    state["iteration_history"].append(history_entry)
    return state
def update_state_with_event_triples(
    state: QueryState,
    event_triples: List[Dict[str, str]]
) -> QueryState:
    """
    Replace the state's event triples and log the update.

    The state is modified in place and the same object is returned. The
    triples are stored as a new list (shallow copy), so later changes to the
    caller's list do not leak into the state; the triple dictionaries
    themselves are still shared.

    Args:
        state: Current state.
        event_triples: List of event-to-event triples, in the format
            [{'source_entity': 'event A', 'relation': 'relation',
              'target_entity': 'event B', 'source_evidence': 'source',
              'target_evidence': 'source'}, ...]

    Returns:
        The updated state.
    """
    # Copy before touching the state: a non-iterable argument (e.g. None)
    # then fails here, instead of leaving the state half-updated.
    stored_triples = list(event_triples)
    state["event_triples"] = stored_triples

    # Record the update in the iteration history.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "event_triples_extraction",
        "triples_count": len(stored_triples),
    })
    return state
def get_state_summary(state: QueryState) -> Dict[str, Any]:
    """
    Build a summary of the state.

    The ``debug_info`` entry is a shallow copy, so the summary stays a
    snapshot: later updates to the state's counters do not change a summary
    already returned, and editing the summary does not change the state.

    Args:
        state: Current state.

    Returns:
        Dictionary summarising the state.
    """
    return {
        "original_query": state["original_query"],
        "current_iteration": state["current_iteration"],
        "max_iterations": state["max_iterations"],
        "total_passages": len(state["all_passages"]),
        "total_sub_queries": len(state["sub_queries"]),
        "is_sufficient": state["is_sufficient"],
        # True only when the final answer is non-empty.
        "has_final_answer": bool(state["final_answer"]),
        # Snapshot rather than the live dict held by the state.
        "debug_info": dict(state["debug_info"]),
        "iteration_count": len(state["iteration_history"]),
    }