356 lines
11 KiB
Python
356 lines
11 KiB
Python
"""
|
||
任务管理器
|
||
用于管理检索任务的生命周期,支持任务取消
|
||
"""
|
||
|
||
import uuid
|
||
import time
|
||
import threading
|
||
from typing import Dict, Any, Optional
|
||
from datetime import datetime, timedelta
|
||
from enum import Enum
|
||
|
||
|
||
class TaskStatus(Enum):
    """Lifecycle states of a task.

    Each member's value is the lowercase string used when a task is
    serialized.
    """

    # Created, not started yet.
    PENDING = "pending"
    # Currently being executed.
    RUNNING = "running"
    # Finished successfully.
    COMPLETED = "completed"
    # Stopped on request before finishing.
    CANCELLED = "cancelled"
    # Finished with an error (including a timeout).
    FAILED = "failed"
|
||
|
||
|
||
class TaskInfo:
    """Mutable record holding the state of a single task."""

    def __init__(self, task_id: str, query: str):
        """Create a record for a new task in the PENDING state.

        Args:
            task_id: Unique identifier of the task.
            query: Query text the task was created for.
        """
        self.task_id = task_id
        self.query = query
        self.status = TaskStatus.PENDING
        # Timestamps are seconds since the epoch as returned by time.time().
        self.created_at = time.time()
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        self.progress: int = 0
        self.current_step: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary snapshot of the task.

        The stored ``result`` is not part of the snapshot; the computed
        ``elapsed_time`` is.
        """
        return {
            "task_id": self.task_id,
            "query": self.query,
            "status": self.status.value,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "progress": self.progress,
            "current_step": self.current_step,
            "error": self.error,
            "elapsed_time": self.get_elapsed_time(),
        }

    def get_elapsed_time(self) -> float:
        """Return the seconds spent on the task so far.

        Returns:
            ``0.0`` when the task has not been started; otherwise the time
            from ``started_at`` to ``completed_at``, or to now when the task
            has no completion timestamp yet.
        """
        # Compare against None explicitly so a legitimate 0.0 timestamp is
        # not mistaken for "unset".
        if self.started_at is None:
            return 0.0
        end_time = self.completed_at if self.completed_at is not None else time.time()
        return end_time - self.started_at
|
||
|
||
|
||
class TaskManager:
    """Registry that tracks tasks, their progress and their cancellation flags.

    Every task has a ``TaskInfo`` record in ``self.tasks`` and a
    ``threading.Event`` in ``self.stop_events`` that workers can poll through
    ``should_stop``. Methods that read or modify task state hold
    ``self._lock`` (a re-entrant lock). A daemon thread periodically fails
    tasks that ran longer than ``task_timeout`` and drops old finished tasks.
    """

    def __init__(
        self,
        max_tasks: int = 1000,
        task_timeout: float = 1800,
        cleanup_interval: float = 60,
        finished_retention: float = 1800,
        overflow_retention: float = 300,
    ):
        """Initialise the manager and start the background cleanup thread.

        Args:
            max_tasks: Maximum number of task records kept at once.
            task_timeout: Seconds a running task may take before the cleanup
                thread marks it as failed (default 30 minutes).
            cleanup_interval: Seconds between two cleanup passes.
            finished_retention: Seconds a finished task is kept before the
                cleanup thread removes it.
            overflow_retention: Seconds a finished task is kept before
                ``create_task`` may evict it when ``max_tasks`` is reached.
        """
        self.tasks: Dict[str, TaskInfo] = {}
        self.stop_events: Dict[str, threading.Event] = {}
        self.max_tasks = max_tasks
        self.task_timeout = task_timeout
        self.cleanup_interval = cleanup_interval
        self.finished_retention = finished_retention
        self.overflow_retention = overflow_retention
        self._lock = threading.RLock()
        # Set by shutdown() to wake the cleanup thread and make it exit.
        self._shutdown_event = threading.Event()

        # Start the cleanup thread.
        self._cleanup_thread = threading.Thread(target=self._cleanup_expired_tasks, daemon=True)
        self._cleanup_thread.start()

    def create_task(self, query: str) -> str:
        """Register a new PENDING task and return its id.

        Args:
            query: Query text the task is created for.

        Returns:
            The generated task id (a UUID4 string).

        Raises:
            RuntimeError: If ``max_tasks`` records exist and evicting old
                finished tasks did not free a slot.
        """
        with self._lock:
            if len(self.tasks) >= self.max_tasks:
                # Try to make room before rejecting the request.
                self._cleanup_completed_tasks()
                if len(self.tasks) >= self.max_tasks:
                    raise RuntimeError(f"任务数量已达上限 ({self.max_tasks})")

            task_id = str(uuid.uuid4())
            self.tasks[task_id] = TaskInfo(task_id, query)
            self.stop_events[task_id] = threading.Event()

            print(f"[TaskManager] 创建任务: {task_id[:8]}... for query: {query[:50]}...")
            return task_id

    def start_task(self, task_id: str) -> bool:
        """Move a PENDING task to RUNNING and record its start time.

        Args:
            task_id: Id of the task to start.

        Returns:
            True if the task existed and was PENDING, False otherwise.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None or task.status != TaskStatus.PENDING:
                return False
            task.status = TaskStatus.RUNNING
            task.started_at = time.time()
            print(f"[TaskManager] 开始任务: {task_id[:8]}...")
            return True

    def complete_task(self, task_id: str, result: Optional[Dict[str, Any]] = None) -> bool:
        """Mark a task as COMPLETED and store its result.

        Args:
            task_id: Id of the task.
            result: Result payload to store on the task, if any.

        Returns:
            True if the task exists, False otherwise.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            # NOTE(review): this overwrites any previous status, including
            # CANCELLED or a timeout FAILED - confirm that is intended.
            task.status = TaskStatus.COMPLETED
            task.completed_at = time.time()
            task.result = result
            task.progress = 100
            print(f"[TaskManager] 完成任务: {task_id[:8]}... 耗时: {task.get_elapsed_time():.2f}s")
            return True

    def fail_task(self, task_id: str, error: str) -> bool:
        """Mark a task as FAILED and store the error message.

        Args:
            task_id: Id of the task.
            error: Error description to store on the task.

        Returns:
            True if the task exists, False otherwise.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            # NOTE(review): like complete_task, this overwrites any previous
            # status - confirm that is intended.
            task.status = TaskStatus.FAILED
            task.completed_at = time.time()
            task.error = error
            print(f"[TaskManager] 任务失败: {task_id[:8]}... 错误: {error}")
            return True

    def cancel_task(self, task_id: str) -> bool:
        """Signal a task to stop and mark it CANCELLED.

        The stop event is set whenever one exists for ``task_id``. The status
        changes only for tasks that are PENDING or RUNNING, so a task
        cancelled before it starts no longer stays PENDING forever.

        Args:
            task_id: Id of the task to cancel.

        Returns:
            True if a PENDING or RUNNING task was moved to CANCELLED,
            False otherwise.
        """
        with self._lock:
            stop_event = self.stop_events.get(task_id)
            if stop_event is None:
                return False
            stop_event.set()

            task = self.tasks.get(task_id)
            if task is None or task.status not in (TaskStatus.PENDING, TaskStatus.RUNNING):
                return False
            task.status = TaskStatus.CANCELLED
            task.completed_at = time.time()
            print(f"[TaskManager] 取消任务: {task_id[:8]}...")
            return True

    def should_stop(self, task_id: str) -> bool:
        """Tell whether the stop signal of a task has been set.

        Args:
            task_id: Id of the task.

        Returns:
            True if a stop event exists for the task and is set, False
            otherwise (including for unknown ids).
        """
        # Single lookup: a separate membership test followed by indexing
        # could raise KeyError if the entry is removed in between.
        stop_event = self.stop_events.get(task_id)
        return stop_event is not None and stop_event.is_set()

    def update_progress(self, task_id: str, progress: int, current_step: str = "") -> bool:
        """Update the progress of a task.

        Args:
            task_id: Id of the task.
            progress: Progress value; clamped to the range 0-100.
            current_step: Description of the current step; an empty string
                leaves the stored description unchanged.

        Returns:
            True if the task exists, False otherwise.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            task.progress = max(0, min(100, progress))
            if current_step:
                task.current_step = current_step
            return True

    def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the dictionary snapshot of a task, or None if it is unknown.

        Args:
            task_id: Id of the task.
        """
        with self._lock:
            task = self.tasks.get(task_id)
            return None if task is None else task.to_dict()

    def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return dictionary snapshots of all tasks, keyed by task id."""
        with self._lock:
            return {tid: task.to_dict() for tid, task in self.tasks.items()}

    def cleanup_task(self, task_id: str) -> bool:
        """Release the stop event of a task.

        The task record itself is kept so it can still be queried; it is
        removed later by the cleanup passes.

        Args:
            task_id: Id of the task.

        Returns:
            True if a stop event was removed, False otherwise.
        """
        with self._lock:
            removed = self.stop_events.pop(task_id, None) is not None
            if removed:
                print(f"[TaskManager] 清理任务资源: {task_id[:8]}...")
            return removed

    def shutdown(self, timeout: Optional[float] = 5.0) -> None:
        """Stop the background cleanup thread.

        Args:
            timeout: Maximum seconds to wait for the thread to exit;
                ``None`` waits without limit.
        """
        self._shutdown_event.set()
        # Joining the current thread would raise RuntimeError.
        if self._cleanup_thread is not threading.current_thread():
            self._cleanup_thread.join(timeout)

    @staticmethod
    def _is_finished(task: "TaskInfo") -> bool:
        """Return True if the task is COMPLETED, CANCELLED or FAILED."""
        return task.status in (TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED)

    def _remove_finished_tasks(self, min_age: float, now: Optional[float] = None) -> int:
        """Drop finished tasks completed more than ``min_age`` seconds ago.

        Args:
            min_age: Age threshold in seconds, measured from ``completed_at``.
            now: Reference timestamp; defaults to the current time.

        Returns:
            Number of tasks removed.
        """
        with self._lock:
            if now is None:
                now = time.time()
            stale = [
                task_id
                for task_id, task in self.tasks.items()
                if self._is_finished(task)
                and task.completed_at is not None
                and (now - task.completed_at) > min_age
            ]
            for task_id in stale:
                del self.tasks[task_id]
                self.stop_events.pop(task_id, None)
            return len(stale)

    def _cleanup_completed_tasks(self):
        """Evict finished tasks older than ``overflow_retention`` (internal)."""
        with self._lock:
            removed = self._remove_finished_tasks(self.overflow_retention)
            if removed:
                print(f"[TaskManager] 清理了 {removed} 个已完成的任务")

    def _expire_and_purge(self) -> None:
        """Run one cleanup pass: time out running tasks, then purge old ones."""
        with self._lock:
            now = time.time()
            for task_id, task in self.tasks.items():
                timed_out = (
                    task.status == TaskStatus.RUNNING
                    and task.started_at is not None
                    and (now - task.started_at) > self.task_timeout
                )
                if not timed_out:
                    continue
                print(f"[TaskManager] 任务超时: {task_id[:8]}...")
                task.status = TaskStatus.FAILED
                task.error = "Task timeout"
                task.completed_at = now
                # Tell the worker to stop, if its event still exists.
                stop_event = self.stop_events.get(task_id)
                if stop_event is not None:
                    stop_event.set()

            # Use the same timestamp so tasks timed out in this pass are not
            # purged immediately.
            removed = self._remove_finished_tasks(self.finished_retention, now)
            if removed:
                print(f"[TaskManager] 清理了 {removed} 个过期任务")

    def _cleanup_expired_tasks(self):
        """Background loop: run a cleanup pass every ``cleanup_interval`` seconds.

        The loop exits once ``shutdown`` has been called.
        """
        # Event.wait returns False on timeout and True once the event is set.
        while not self._shutdown_event.wait(self.cleanup_interval):
            try:
                self._expire_and_purge()
            except Exception as e:
                print(f"[TaskManager] 清理线程错误: {e}")
                # Back off for 10 seconds after an error.
                self._shutdown_event.wait(10)
|
||
|
||
|
||
# Module-level TaskManager instance, created when this module is imported.
task_manager = TaskManager()
|
||
|
||
|
||
def get_task_manager() -> TaskManager:
    """Return the module-level ``task_manager`` instance."""
    return task_manager