299 lines
8.1 KiB
Python
299 lines
8.1 KiB
Python
|
|
"""
|
|||
|
|
LangGraph状态定义
|
|||
|
|
定义工作流中的状态结构
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from typing import List, Dict, Any, Optional, TypedDict, Callable
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from langchain_core.documents import Document
|
|||
|
|
|
|||
|
|
|
|||
|
|
class QueryState(TypedDict):
    """Typed schema of the state dictionary used by the query workflow."""

    # --- Basic query information ---
    original_query: str
    current_iteration: int
    max_iterations: int

    # Debug mode: 0 = decide automatically, 'simple' = force the simple path,
    # 'complex' = force the complex path.
    debug_mode: str

    # --- Query complexity ---
    query_complexity: Dict[str, Any]
    is_complex_query: bool

    # --- Retrieval results ---
    # Field name kept for compatibility with existing code; it now stores
    # event information.
    all_passages: List[str]
    all_documents: List[Document]  # now contains event-node documents
    passage_sources: List[str]  # event source information

    # --- Sub-queries ---
    sub_queries: List[str]
    current_sub_queries: List[str]
    decomposed_sub_queries: List[str]  # initial sub-queries produced by query decomposition

    # --- Sufficiency check ---
    sufficiency_check: Dict[str, Any]
    is_sufficient: bool

    # --- Final result ---
    final_answer: str

    # --- Debug information ---
    debug_info: Dict[str, Any]
    iteration_history: List[Dict[str, Any]]

    # --- Concept exploration (added later) ---
    # NOTE: PageRank data is no longer kept in the state, to avoid sending it
    # to LangSmith; it lives in local temporary storage instead.
    concept_exploration_results: Dict[str, Any]
    exploration_round: int  # current exploration round (1 or 2)

    # Flag for the local PageRank store (does not carry the data itself).
    pagerank_data_available: bool

    # Streaming callback (optional).
    stream_callback: Optional[Callable]

    # Event-to-event triples (added later).
    event_triples: List[Dict[str, str]]

    # --- Task management (added later) ---
    task_id: Optional[str]  # task ID
    should_stop: bool  # whether the run should stop
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RetrievalResult:
    """Data class bundling the result of a retrieval."""

    # Retrieved content.
    passages: List[str]
    documents: List[Document]
    sources: List[str]

    # Query and iteration number recorded with this result.
    query: str
    iteration: int
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class SufficiencyCheck:
    """Data class holding the result of a sufficiency check."""

    is_sufficient: bool
    confidence: float
    reason: str

    # Optional list of sub-queries; None when not supplied.
    sub_queries: Optional[List[str]] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_initial_state(
    original_query: str,
    max_iterations: int = 2,
    debug_mode: str = "0",
    task_id: Optional[str] = None
) -> QueryState:
    """
    Build the starting state for a new query.

    Args:
        original_query: The user's original query.
        max_iterations: Maximum number of iterations.
        debug_mode: Debug mode; "0" = decide automatically, "simple" = force
            the simple path, "complex" = force the complex path.
        task_id: Task ID (optional).

    Returns:
        The initial state dictionary.
    """
    # Call counters start at zero; the timestamps are left unset here.
    debug_counters: Dict[str, Any] = {
        "retrieval_calls": 0,
        "llm_calls": 0,
        "start_time": None,
        "end_time": None,
    }

    # A TypedDict is an ordinary dict at runtime, so this literal produces the
    # same object that calling QueryState(...) with these keywords would.
    initial: QueryState = {
        # Caller-supplied values and iteration bookkeeping.
        "original_query": original_query,
        "current_iteration": 0,
        "max_iterations": max_iterations,
        "debug_mode": debug_mode,
        # Complexity analysis has not run yet.
        "query_complexity": {},
        "is_complex_query": False,
        # Nothing retrieved so far; every list is a fresh object per call.
        "all_passages": [],
        "all_documents": [],
        "passage_sources": [],
        "sub_queries": [],
        "current_sub_queries": [],
        "decomposed_sub_queries": [],
        # No sufficiency verdict and no answer yet.
        "sufficiency_check": {},
        "is_sufficient": False,
        "final_answer": "",
        "debug_info": debug_counters,
        "iteration_history": [],
        "concept_exploration_results": {},
        "exploration_round": 0,
        "pagerank_data_available": False,
        "stream_callback": None,  # starts as None; set later when needed
        "event_triples": [],  # starts as an empty list
        "task_id": task_id,  # task ID
        "should_stop": False,  # starts as "do not stop"
    }
    return initial
|
|||
|
|
|
|||
|
|
|
|||
|
|
def update_state_with_retrieval(
    state: QueryState,
    retrieval_result: RetrievalResult
) -> QueryState:
    """
    Fold a retrieval result into the state.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        retrieval_result: Retrieval result to merge in.

    Returns:
        The updated state.
    """
    new_passages = retrieval_result.passages

    # Append the new passages, documents and sources to the accumulated lists.
    state["all_passages"].extend(new_passages)
    state["all_documents"].extend(retrieval_result.documents)
    state["passage_sources"].extend(retrieval_result.sources)

    # Count this retrieval call in the debug information.
    counters = state["debug_info"]
    counters["retrieval_calls"] = counters["retrieval_calls"] + 1

    # Record the step in the iteration history.
    state["iteration_history"].append({
        "iteration": retrieval_result.iteration,
        "query": retrieval_result.query,
        "passages_count": len(new_passages),
        "action": "retrieval",
    })

    return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
def update_state_with_sufficiency_check(
    state: QueryState,
    sufficiency_check: SufficiencyCheck
) -> QueryState:
    """
    Update the state with the result of a sufficiency check.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        sufficiency_check: Sufficiency check result.

    Returns:
        The updated state.
    """
    sufficient = sufficiency_check.is_sufficient

    # Take a shallow copy so the state never shares a list object with the
    # SufficiencyCheck; otherwise a later in-place edit of
    # state["current_sub_queries"] would silently change the check as well.
    # None is normalised to an empty list.
    follow_ups = list(sufficiency_check.sub_queries or [])

    state["is_sufficient"] = sufficient
    state["sufficiency_check"] = {
        "is_sufficient": sufficient,
        "confidence": sufficiency_check.confidence,
        "reason": sufficiency_check.reason,
        "iteration": state["current_iteration"],
    }

    # If the check is insufficient and sub-queries were given, queue them as
    # the current sub-queries and add them to the running list; otherwise
    # clear the current sub-queries.
    if not sufficient and follow_ups:
        state["current_sub_queries"] = follow_ups
        state["sub_queries"].extend(follow_ups)
    else:
        state["current_sub_queries"] = []

    # Count this step as one LLM call in the debug information.
    state["debug_info"]["llm_calls"] += 1

    # Record the step in the iteration history. The sub-query count is
    # reported whatever the verdict was.
    state["iteration_history"].append({
        "iteration": state["current_iteration"],
        "action": "sufficiency_check",
        "is_sufficient": sufficient,
        "confidence": sufficiency_check.confidence,
        "sub_queries_count": len(follow_ups),
    })

    return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
def increment_iteration(state: QueryState) -> QueryState:
    """
    Advance the iteration counter by one.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.

    Returns:
        The updated state.
    """
    next_iteration = state["current_iteration"] + 1
    state["current_iteration"] = next_iteration
    return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
def finalize_state(state: QueryState, final_answer: str) -> QueryState:
    """
    Complete the state by storing the final answer.

    The state is modified in place and the same object is returned.

    Args:
        state: Current state.
        final_answer: Final answer text.

    Returns:
        The final state.
    """
    state["final_answer"] = final_answer

    # Record the step in the iteration history.
    history = state["iteration_history"]
    history.append({
        "iteration": state["current_iteration"],
        "action": "final_answer_generation",
        "answer_length": len(final_answer),
    })

    return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
def update_state_with_event_triples(
    state: QueryState,
    event_triples: List[Dict[str, str]]
) -> QueryState:
    """
    Store event triples in the state.

    The state is modified in place and the same object is returned; any
    previously stored triples are replaced.

    Args:
        state: Current state.
        event_triples: List of event-to-event triples, in the format
            [{'source_entity': 'event A', 'relation': 'relation',
              'target_entity': 'event B', 'source_evidence': 'source',
              'target_evidence': 'source'}, ...]

    Returns:
        The updated state.
    """
    state["event_triples"] = event_triples

    # Record the step in the iteration history.
    history_entry = {
        "iteration": state["current_iteration"],
        "action": "event_triples_extraction",
        "triples_count": len(event_triples),
    }
    history = state["iteration_history"]
    history.append(history_entry)

    return state
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_state_summary(state: QueryState) -> Dict[str, Any]:
    """
    Produce a summary of the state.

    Args:
        state: Current state.

    Returns:
        The state summary dictionary.
    """
    # Values copied across unchanged.
    summary: Dict[str, Any] = {
        "original_query": state["original_query"],
        "current_iteration": state["current_iteration"],
        "max_iterations": state["max_iterations"],
    }

    # Sizes of the accumulated lists.
    summary["total_passages"] = len(state["all_passages"])
    summary["total_sub_queries"] = len(state["sub_queries"])

    summary["is_sufficient"] = state["is_sufficient"]
    # True when the final answer is non-empty.
    summary["has_final_answer"] = bool(state["final_answer"])
    # The debug dictionary is passed through by reference, not copied.
    summary["debug_info"] = state["debug_info"]
    summary["iteration_count"] = len(state["iteration_history"])

    return summary
|