Files
AIEC-new/AIEC-RAG/task_manager.py
2025-10-17 09:31:28 +08:00

356 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Task manager.
Manages the lifecycle of retrieval tasks and supports task cancellation.
"""
import uuid
import time
import threading
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from enum import Enum
class TaskStatus(Enum):
    """Enumeration of the states a task can be in."""

    # Initial state assigned by TaskInfo.__init__.
    PENDING = "pending"
    # Set by TaskManager.start_task.
    RUNNING = "running"
    # Set by TaskManager.complete_task.
    COMPLETED = "completed"
    # Set by TaskManager.cancel_task.
    CANCELLED = "cancelled"
    # Set by TaskManager.fail_task and by the timeout sweep.
    FAILED = "failed"
class TaskInfo:
    """Mutable record holding the state of a single task."""

    def __init__(self, task_id: str, query: str):
        """Create a record for ``task_id`` in the PENDING state.

        Args:
            task_id: Identifier of the task.
            query: Query text the task was created for.
        """
        # Identity of the task.
        self.task_id = task_id
        self.query = query

        # Lifecycle: status plus wall-clock timestamps from time.time().
        self.status = TaskStatus.PENDING
        self.created_at = time.time()
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None

        # Outcome: at most one of these is filled in once the task ends.
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None

        # Progress reporting.
        self.progress: int = 0
        self.current_step: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary snapshot of the task.

        The snapshot carries the identity, status value, timestamps,
        progress fields, error and elapsed time; ``result`` is not included.
        """
        snapshot: Dict[str, Any] = {
            "task_id": self.task_id,
            "query": self.query,
            "status": self.status.value,
        }
        # These attributes are copied verbatim under their own names.
        for attr in (
            "created_at",
            "started_at",
            "completed_at",
            "progress",
            "current_step",
            "error",
        ):
            snapshot[attr] = getattr(self, attr)
        snapshot["elapsed_time"] = self.get_elapsed_time()
        return snapshot

    def get_elapsed_time(self) -> float:
        """Return seconds between start and completion (or now if unfinished).

        Returns 0 when the task has no start timestamp.
        """
        if not self.started_at:
            return 0
        finished_at = self.completed_at if self.completed_at else time.time()
        return finished_at - self.started_at
class TaskManager:
    """Registry of retrieval tasks with cooperative cancellation.

    Each task has a ``TaskInfo`` record in ``self.tasks`` and a
    ``threading.Event`` in ``self.stop_events``. Workers poll
    ``should_stop`` and are expected to exit once the event is set.
    State changes are made while holding a re-entrant lock, and a daemon
    thread started by the constructor periodically fails tasks that ran
    longer than ``task_timeout`` and drops old finished tasks.
    """

    # Seconds a finished task is kept before create_task may evict it to
    # make room when the task limit is reached.
    _EVICT_FINISHED_AFTER_SECONDS = 300
    # Seconds a finished task is kept before the background sweep drops it.
    _PURGE_FINISHED_AFTER_SECONDS = 1800
    # Pause before each background sweep.
    _SWEEP_INTERVAL_SECONDS = 60
    # Extra pause after a sweep raised an exception.
    _SWEEP_ERROR_BACKOFF_SECONDS = 10

    def __init__(self, max_tasks: int = 1000, task_timeout: float = 1800):
        """Initialise the manager and start the background cleanup thread.

        Args:
            max_tasks: Maximum number of task records kept at once.
            task_timeout: Seconds a RUNNING task may run before the sweep
                marks it FAILED (default 30 minutes).
        """
        self.tasks: Dict[str, TaskInfo] = {}
        self.stop_events: Dict[str, threading.Event] = {}
        self.max_tasks = max_tasks
        self.task_timeout = task_timeout
        # Re-entrant because create_task calls _cleanup_completed_tasks
        # while already holding the lock.
        self._lock = threading.RLock()

        # Daemon thread so it never blocks interpreter shutdown.
        self._cleanup_thread = threading.Thread(target=self._cleanup_expired_tasks, daemon=True)
        self._cleanup_thread.start()

    def create_task(self, query: str) -> str:
        """Register a new PENDING task and its stop event.

        Args:
            query: Query text the task is created for.

        Returns:
            The generated task id (a UUID4 string).

        Raises:
            RuntimeError: If the task limit is still reached after evicting
                finished tasks.
        """
        with self._lock:
            if len(self.tasks) >= self.max_tasks:
                # Try to free capacity before rejecting the request.
                self._cleanup_completed_tasks()
                if len(self.tasks) >= self.max_tasks:
                    raise RuntimeError(f"任务数量已达上限 ({self.max_tasks})")

            task_id = str(uuid.uuid4())
            self.tasks[task_id] = TaskInfo(task_id, query)
            self.stop_events[task_id] = threading.Event()

            print(f"[TaskManager] 创建任务: {task_id[:8]}... for query: {query[:50]}...")
            return task_id

    def start_task(self, task_id: str) -> bool:
        """Move a PENDING task to RUNNING and record its start time.

        Args:
            task_id: Task id.

        Returns:
            True if the task existed and was PENDING, otherwise False.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None or task.status != TaskStatus.PENDING:
                return False
            task.status = TaskStatus.RUNNING
            task.started_at = time.time()
            print(f"[TaskManager] 开始任务: {task_id[:8]}...")
            return True

    def complete_task(self, task_id: str, result: Optional[Dict[str, Any]] = None) -> bool:
        """Mark a task COMPLETED, store its result and set progress to 100.

        Args:
            task_id: Task id.
            result: Task result, if any.

        Returns:
            True if the task exists, otherwise False.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            task.status = TaskStatus.COMPLETED
            task.completed_at = time.time()
            task.result = result
            task.progress = 100
            print(f"[TaskManager] 完成任务: {task_id[:8]}... 耗时: {task.get_elapsed_time():.2f}s")
            return True

    def fail_task(self, task_id: str, error: str) -> bool:
        """Mark a task FAILED and record the error message.

        Args:
            task_id: Task id.
            error: Error message.

        Returns:
            True if the task exists, otherwise False.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            task.status = TaskStatus.FAILED
            task.completed_at = time.time()
            task.error = error
            print(f"[TaskManager] 任务失败: {task_id[:8]}... 错误: {error}")
            return True

    def cancel_task(self, task_id: str) -> bool:
        """Signal a task to stop and mark it CANCELLED if it has not finished.

        The stop event is always set. The status changes only for PENDING
        or RUNNING tasks, so a finished task keeps its final status. Marking
        PENDING tasks matters: otherwise a task cancelled before it started
        would stay PENDING, never become eligible for cleanup, and could
        still be started.

        Args:
            task_id: Task id.

        Returns:
            True if a stop event existed for the task, otherwise False.
        """
        with self._lock:
            stop_event = self.stop_events.get(task_id)
            if stop_event is None:
                return False

            stop_event.set()

            task = self.tasks.get(task_id)
            if task is not None and task.status in (TaskStatus.PENDING, TaskStatus.RUNNING):
                task.status = TaskStatus.CANCELLED
                task.completed_at = time.time()

            print(f"[TaskManager] 取消任务: {task_id[:8]}...")
            return True

    def should_stop(self, task_id: str) -> bool:
        """Report whether the task has been signalled to stop.

        Does not take the manager lock. A single dictionary lookup is used
        so that a concurrent removal of the event by another thread cannot
        raise ``KeyError`` between a membership test and the access.

        Args:
            task_id: Task id.

        Returns:
            True if a stop event exists for the task and is set.
        """
        stop_event = self.stop_events.get(task_id)
        return stop_event is not None and stop_event.is_set()

    def update_progress(self, task_id: str, progress: int, current_step: str = "") -> bool:
        """Update a task's progress and, optionally, its current step.

        Args:
            task_id: Task id.
            progress: Progress value; clamped to the range 0-100.
            current_step: Step description; an empty string leaves the
                previous description unchanged.

        Returns:
            True if the task exists, otherwise False.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            task.progress = max(0, min(100, progress))
            if current_step:
                task.current_step = current_step
            return True

    def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the task's dictionary snapshot, or None if it is unknown.

        Args:
            task_id: Task id.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            return None if task is None else task.to_dict()

    def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return dictionary snapshots of all tasks, keyed by task id."""
        with self._lock:
            return {tid: task.to_dict() for tid, task in self.tasks.items()}

    def cleanup_task(self, task_id: str) -> bool:
        """Release the task's stop event while keeping its record.

        The ``TaskInfo`` record stays available for queries and is removed
        later by the cleanup passes.

        Args:
            task_id: Task id.

        Returns:
            True if a stop event was removed, otherwise False.
        """
        with self._lock:
            removed = self.stop_events.pop(task_id, None) is not None
            if removed:
                print(f"[TaskManager] 清理任务资源: {task_id[:8]}...")
            return removed

    @staticmethod
    def _is_finished(task) -> bool:
        """Return True if the task is COMPLETED, CANCELLED or FAILED."""
        return task.status in (TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED)

    def _remove_tasks(self, task_ids) -> None:
        """Delete the records and stop events of the given tasks.

        Callers must hold ``self._lock``.
        """
        for task_id in task_ids:
            del self.tasks[task_id]
            self.stop_events.pop(task_id, None)

    def _cleanup_completed_tasks(self):
        """Evict finished tasks that ended more than five minutes ago."""
        with self._lock:
            now = time.time()
            to_remove = [
                task_id
                for task_id, task in self.tasks.items()
                if self._is_finished(task)
                and task.completed_at
                and (now - task.completed_at) > self._EVICT_FINISHED_AFTER_SECONDS
            ]
            self._remove_tasks(to_remove)
            if to_remove:
                print(f"[TaskManager] 清理了 {len(to_remove)} 个已完成的任务")

    def _sweep_once(self):
        """Run one pass: time out long-running tasks, drop old finished ones."""
        with self._lock:
            now = time.time()
            to_remove = []

            for task_id, task in self.tasks.items():
                # Fail RUNNING tasks that exceeded the timeout and tell the
                # worker to stop.
                if (
                    task.status == TaskStatus.RUNNING
                    and task.started_at
                    and (now - task.started_at) > self.task_timeout
                ):
                    print(f"[TaskManager] 任务超时: {task_id[:8]}...")
                    task.status = TaskStatus.FAILED
                    task.error = "Task timeout"
                    task.completed_at = now
                    stop_event = self.stop_events.get(task_id)
                    if stop_event is not None:
                        stop_event.set()

                # Drop tasks that finished more than 30 minutes ago. A task
                # timed out just above has completed_at == now, so it is
                # kept for later queries.
                if (
                    self._is_finished(task)
                    and task.completed_at
                    and (now - task.completed_at) > self._PURGE_FINISHED_AFTER_SECONDS
                ):
                    to_remove.append(task_id)

            self._remove_tasks(to_remove)
            if to_remove:
                print(f"[TaskManager] 清理了 {len(to_remove)} 个过期任务")

    def _cleanup_expired_tasks(self):
        """Background loop: sleep one minute, then sweep; never returns."""
        while True:
            try:
                time.sleep(self._SWEEP_INTERVAL_SECONDS)
                self._sweep_once()
            except Exception as e:
                # Keep the thread alive; report and back off briefly.
                print(f"[TaskManager] 清理线程错误: {e}")
                time.sleep(self._SWEEP_ERROR_BACKOFF_SECONDS)
# Global task manager instance, created when this module is imported.
task_manager = TaskManager()


def get_task_manager() -> TaskManager:
    """Return the global task manager instance."""
    return task_manager