first commit
This commit is contained in:
356
AIEC-RAG/task_manager.py
Normal file
356
AIEC-RAG/task_manager.py
Normal file
@ -0,0 +1,356 @@
|
||||
"""
|
||||
任务管理器
|
||||
用于管理检索任务的生命周期,支持任务取消
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务状态枚举"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskInfo:
|
||||
"""任务信息类"""
|
||||
|
||||
def __init__(self, task_id: str, query: str):
|
||||
self.task_id = task_id
|
||||
self.query = query
|
||||
self.status = TaskStatus.PENDING
|
||||
self.created_at = time.time()
|
||||
self.started_at: Optional[float] = None
|
||||
self.completed_at: Optional[float] = None
|
||||
self.result: Optional[Dict[str, Any]] = None
|
||||
self.error: Optional[str] = None
|
||||
self.progress: int = 0
|
||||
self.current_step: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"task_id": self.task_id,
|
||||
"query": self.query,
|
||||
"status": self.status.value,
|
||||
"created_at": self.created_at,
|
||||
"started_at": self.started_at,
|
||||
"completed_at": self.completed_at,
|
||||
"progress": self.progress,
|
||||
"current_step": self.current_step,
|
||||
"error": self.error,
|
||||
"elapsed_time": self.get_elapsed_time()
|
||||
}
|
||||
|
||||
def get_elapsed_time(self) -> float:
|
||||
"""获取已用时间"""
|
||||
if self.started_at:
|
||||
end_time = self.completed_at or time.time()
|
||||
return end_time - self.started_at
|
||||
return 0
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""任务管理器"""
|
||||
|
||||
def __init__(self, max_tasks: int = 1000, task_timeout: float = 1800):
|
||||
"""
|
||||
初始化任务管理器
|
||||
|
||||
Args:
|
||||
max_tasks: 最大任务数量
|
||||
task_timeout: 任务超时时间(秒),默认30分钟
|
||||
"""
|
||||
self.tasks: Dict[str, TaskInfo] = {}
|
||||
self.stop_events: Dict[str, threading.Event] = {}
|
||||
self.max_tasks = max_tasks
|
||||
self.task_timeout = task_timeout
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# 启动清理线程
|
||||
self._cleanup_thread = threading.Thread(target=self._cleanup_expired_tasks, daemon=True)
|
||||
self._cleanup_thread.start()
|
||||
|
||||
def create_task(self, query: str) -> str:
|
||||
"""
|
||||
创建新任务
|
||||
|
||||
Args:
|
||||
query: 查询内容
|
||||
|
||||
Returns:
|
||||
任务ID
|
||||
"""
|
||||
with self._lock:
|
||||
# 检查任务数量限制
|
||||
if len(self.tasks) >= self.max_tasks:
|
||||
# 清理已完成的任务
|
||||
self._cleanup_completed_tasks()
|
||||
|
||||
# 如果仍然超限,拒绝创建
|
||||
if len(self.tasks) >= self.max_tasks:
|
||||
raise RuntimeError(f"任务数量已达上限 ({self.max_tasks})")
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 创建任务信息
|
||||
task_info = TaskInfo(task_id, query)
|
||||
self.tasks[task_id] = task_info
|
||||
|
||||
# 创建停止事件
|
||||
self.stop_events[task_id] = threading.Event()
|
||||
|
||||
print(f"[TaskManager] 创建任务: {task_id[:8]}... for query: {query[:50]}...")
|
||||
return task_id
|
||||
|
||||
def start_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
开始执行任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks[task_id]
|
||||
if task.status == TaskStatus.PENDING:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.started_at = time.time()
|
||||
print(f"[TaskManager] 开始任务: {task_id[:8]}...")
|
||||
return True
|
||||
return False
|
||||
|
||||
def complete_task(self, task_id: str, result: Dict[str, Any] = None) -> bool:
|
||||
"""
|
||||
完成任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
result: 任务结果
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks[task_id]
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.completed_at = time.time()
|
||||
task.result = result
|
||||
task.progress = 100
|
||||
print(f"[TaskManager] 完成任务: {task_id[:8]}... 耗时: {task.get_elapsed_time():.2f}s")
|
||||
return True
|
||||
return False
|
||||
|
||||
def fail_task(self, task_id: str, error: str) -> bool:
|
||||
"""
|
||||
标记任务失败
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
error: 错误信息
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks[task_id]
|
||||
task.status = TaskStatus.FAILED
|
||||
task.completed_at = time.time()
|
||||
task.error = error
|
||||
print(f"[TaskManager] 任务失败: {task_id[:8]}... 错误: {error}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
取消任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.stop_events:
|
||||
# 设置停止信号
|
||||
self.stop_events[task_id].set()
|
||||
|
||||
# 更新任务状态
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks[task_id]
|
||||
if task.status == TaskStatus.RUNNING:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.completed_at = time.time()
|
||||
print(f"[TaskManager] 取消任务: {task_id[:8]}...")
|
||||
return True
|
||||
return False
|
||||
|
||||
def should_stop(self, task_id: str) -> bool:
|
||||
"""
|
||||
检查任务是否应该停止
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
是否应该停止
|
||||
"""
|
||||
if task_id in self.stop_events:
|
||||
return self.stop_events[task_id].is_set()
|
||||
return False
|
||||
|
||||
def update_progress(self, task_id: str, progress: int, current_step: str = "") -> bool:
|
||||
"""
|
||||
更新任务进度
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
progress: 进度(0-100)
|
||||
current_step: 当前步骤描述
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks[task_id]
|
||||
task.progress = max(0, min(100, progress))
|
||||
if current_step:
|
||||
task.current_step = current_step
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取任务信息
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
任务信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.tasks:
|
||||
return self.tasks[task_id].to_dict()
|
||||
return None
|
||||
|
||||
def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取所有任务信息
|
||||
|
||||
Returns:
|
||||
所有任务信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
return {tid: task.to_dict() for tid, task in self.tasks.items()}
|
||||
|
||||
def cleanup_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
清理任务资源
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
with self._lock:
|
||||
removed = False
|
||||
|
||||
# 移除停止事件
|
||||
if task_id in self.stop_events:
|
||||
del self.stop_events[task_id]
|
||||
removed = True
|
||||
|
||||
# 注意:任务信息暂时保留,供查询使用
|
||||
# 将由清理线程定期清理
|
||||
|
||||
if removed:
|
||||
print(f"[TaskManager] 清理任务资源: {task_id[:8]}...")
|
||||
|
||||
return removed
|
||||
|
||||
def _cleanup_completed_tasks(self):
|
||||
"""清理已完成的任务(内部方法)"""
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
to_remove = []
|
||||
|
||||
for task_id, task in self.tasks.items():
|
||||
# 清理已完成/取消/失败且超过5分钟的任务
|
||||
if task.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED]:
|
||||
if task.completed_at and (current_time - task.completed_at) > 300:
|
||||
to_remove.append(task_id)
|
||||
|
||||
for task_id in to_remove:
|
||||
del self.tasks[task_id]
|
||||
if task_id in self.stop_events:
|
||||
del self.stop_events[task_id]
|
||||
|
||||
if to_remove:
|
||||
print(f"[TaskManager] 清理了 {len(to_remove)} 个已完成的任务")
|
||||
|
||||
def _cleanup_expired_tasks(self):
|
||||
"""清理过期任务(后台线程)"""
|
||||
while True:
|
||||
try:
|
||||
time.sleep(60) # 每分钟检查一次
|
||||
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
to_remove = []
|
||||
|
||||
for task_id, task in self.tasks.items():
|
||||
# 超时的运行中任务
|
||||
if task.status == TaskStatus.RUNNING:
|
||||
if task.started_at and (current_time - task.started_at) > self.task_timeout:
|
||||
print(f"[TaskManager] 任务超时: {task_id[:8]}...")
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = "Task timeout"
|
||||
task.completed_at = current_time
|
||||
# 设置停止信号
|
||||
if task_id in self.stop_events:
|
||||
self.stop_events[task_id].set()
|
||||
|
||||
# 清理长时间完成的任务(超过30分钟)
|
||||
if task.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED]:
|
||||
if task.completed_at and (current_time - task.completed_at) > 1800:
|
||||
to_remove.append(task_id)
|
||||
|
||||
# 移除过期任务
|
||||
for task_id in to_remove:
|
||||
del self.tasks[task_id]
|
||||
if task_id in self.stop_events:
|
||||
del self.stop_events[task_id]
|
||||
|
||||
if to_remove:
|
||||
print(f"[TaskManager] 清理了 {len(to_remove)} 个过期任务")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[TaskManager] 清理线程错误: {e}")
|
||||
time.sleep(10) # 出错后等待10秒
|
||||
|
||||
|
||||
# 全局任务管理器实例
|
||||
task_manager = TaskManager()
|
||||
|
||||
|
||||
def get_task_manager() -> TaskManager:
|
||||
"""获取全局任务管理器实例"""
|
||||
return task_manager
|
||||
Reference in New Issue
Block a user