"""
Asynchronous task queue service.

Runs AI image-generation jobs (or any other callable) on a pool of background
worker threads so that submitting a task never blocks the caller.
"""
import asyncio
import uuid
import threading
from datetime import datetime
from typing import Dict, Any, Optional, Callable
from enum import Enum
import queue
import time

from backend.utils.logger_config import setup_logger

logger = setup_logger(__name__)


class TaskStatus(Enum):
    """Lifecycle states of a queued task.

    Task records store the member's string ``value`` (e.g. ``"pending"``),
    never the enum member itself, so records stay JSON-friendly and all
    comparisons inside this module are string comparisons.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class TaskQueueService:
    """
    Thread-backed task queue.

    :meth:`submit_task` returns a task ID immediately. ``max_workers`` daemon
    threads pull task IDs from a FIFO queue and run the associated callable.
    Task records live in ``self.tasks``, guarded by ``self.lock``, and are
    never evicted.
    """

    def __init__(self, max_workers: int = 3):
        """
        Create the service without starting any worker threads.

        :param max_workers: number of worker threads started by :meth:`start`
        """
        self.max_workers = max_workers
        self.task_queue = queue.Queue()
        self.tasks: Dict[str, Dict[str, Any]] = {}
        self.workers: list = []
        self.running = False
        self.lock = threading.Lock()

    def start(self):
        """Start ``max_workers`` daemon worker threads; a second call only logs a warning."""
        if self.running:
            logger.warning("任务队列服务已启动,请勿重复启动")
            return
        self.running = True
        for i in range(self.max_workers):
            worker = threading.Thread(target=self._worker_loop, name=f"TaskWorker-{i}")
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
        logger.info(f"任务队列服务启动完成,启动 {self.max_workers} 个工作线程")

    def stop(self):
        """Signal workers to exit and wait up to 5 seconds for each of them."""
        if not self.running:
            return
        self.running = False
        for worker in self.workers:
            worker.join(timeout=5)
        # Drop threads that have exited so restarts don't accumulate stale entries;
        # a worker still finishing a long task stays listed until it exits.
        self.workers = [w for w in self.workers if w.is_alive()]
        logger.info("任务队列服务已停止")

    def submit_task(self, task_func: Callable, task_args: tuple = (), task_kwargs: Optional[dict] = None) -> str:
        """
        Register a task and enqueue it for a worker thread.

        :param task_func: callable to run
        :param task_args: positional arguments for ``task_func``
        :param task_kwargs: keyword arguments for ``task_func`` (defaults to ``{}``)
        :return: the new task's ID (a UUID4 string)
        """
        if task_kwargs is None:
            task_kwargs = {}
        task_id = str(uuid.uuid4())
        task_info = {
            "id": task_id,
            "func": task_func,
            "args": task_args,
            "kwargs": task_kwargs,
            "status": TaskStatus.PENDING.value,
            "created_at": datetime.now(),
            "start_at": None,
            "completed_at": None,
            "result": None,
            "error": None,
            "progress": 0,
        }
        with self.lock:
            self.tasks[task_id] = task_info
        self.task_queue.put(task_id)
        logger.info(f"任务已提交到队列:{task_id}")
        return task_id

    def _clean_task_for_serialization(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of a task record without its non-serializable parts.

        Removes the ``func`` callable and, when the result is a dict, its
        ``result_image`` entry. The stored record is never modified.

        :param task: task record
        :return: cleaned copy of the record
        """
        cleaned_task = task.copy()
        cleaned_task.pop("func", None)
        result = cleaned_task.get("result")
        if isinstance(result, dict) and "result_image" in result:
            # Copy the result dict first: the shallow copy above shares it with
            # the stored task, and deleting in place would lose the image.
            result = dict(result)
            del result["result_image"]
            cleaned_task["result"] = result
        return cleaned_task

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Look up one task.

        :param task_id: task ID
        :return: cleaned task record, or ``None`` if the ID is unknown
        """
        with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                return None
            return self._clean_task_for_serialization(task)

    def get_all_tasks(self, user_id: Optional[int] = None) -> list:
        """
        List task records, optionally filtered by user.

        :param user_id: when given, skip tasks whose first positional argument
            differs from it
        :return: list of cleaned task records
        """
        with self.lock:
            tasks = []
            for task in self.tasks.values():
                if user_id is not None:
                    # Assumes the user ID is the first positional task argument.
                    # NOTE(review): tasks submitted with no positional args are
                    # returned for every user — confirm this is intended.
                    task_args = task.get("args", ())
                    if len(task_args) > 0 and task_args[0] != user_id:
                        continue
                tasks.append(self._clean_task_for_serialization(task))
            return tasks

    def cancel_task(self, task_id: str) -> bool:
        """
        Mark a pending or running task as cancelled.

        A pending task will be skipped by the workers. A running task cannot
        be interrupted, but its eventual outcome is discarded and it stays
        cancelled.

        :param task_id: task ID
        :return: ``True`` if the task's status was changed to cancelled
        """
        with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                return False
            if task["status"] in (TaskStatus.PENDING.value, TaskStatus.PROCESSING.value):
                task["status"] = TaskStatus.CANCELLED.value
                task["completed_at"] = datetime.now()
                logger.info(f"任务已取消: {task_id}")
                return True
            return False

    def _worker_loop(self):
        """Worker thread body: pull task IDs until ``self.running`` becomes False."""
        thread_name = threading.current_thread().name
        while self.running:
            try:
                # Time out periodically so the loop can observe stop().
                task_id = self.task_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._run_task(task_id, thread_name)
            except Exception as e:
                logger.error(f"{thread_name} 线程出错:{e}")
                time.sleep(1)
            finally:
                # Every successful get() must be matched by task_done(),
                # including for skipped tasks, or Queue.join() never returns.
                self.task_queue.task_done()

    def _run_task(self, task_id: str, thread_name: str) -> None:
        """Run one dequeued task and record its outcome in the task record."""
        with self.lock:
            task = self.tasks.get(task_id)
            # Unknown IDs and tasks cancelled while still queued are skipped.
            if task is None or task["status"] == TaskStatus.CANCELLED.value:
                return
            task["status"] = TaskStatus.PROCESSING.value
            task["start_at"] = datetime.now()
            func, args, kwargs = task["func"], task["args"], task["kwargs"]

        logger.info(f"{thread_name} 正在处理任务:{task_id}")
        # The callable runs outside the lock so other threads are not blocked.
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            logger.error(f"{thread_name} 任务出错:{task_id}: {e}")
            with self.lock:
                # Honour a cancellation that arrived while the task was running.
                if task["status"] != TaskStatus.CANCELLED.value:
                    task["status"] = TaskStatus.FAILED.value
                    task["error"] = str(e)
                    task["completed_at"] = datetime.now()
            return

        with self.lock:
            if task["status"] != TaskStatus.CANCELLED.value:
                task["status"] = TaskStatus.COMPLETED.value
                task["result"] = result
                task["completed_at"] = datetime.now()
                task["progress"] = 100
        logger.info(f"{thread_name} 任务完成:{task_id}")

    def get_queue_stats(self) -> Dict[str, Any]:
        """Return per-status task counts, queue size and worker counts."""
        with self.lock:
            counts = {status.value: 0 for status in TaskStatus}
            for task in self.tasks.values():
                counts[task["status"]] = counts.get(task["status"], 0) + 1
            return {
                "total_tasks": len(self.tasks),
                "pending_tasks": counts[TaskStatus.PENDING.value],
                "processing_tasks": counts[TaskStatus.PROCESSING.value],
                "completed_tasks": counts[TaskStatus.COMPLETED.value],
                "failed_tasks": counts[TaskStatus.FAILED.value],
                "cancelled_tasks": counts[TaskStatus.CANCELLED.value],
                "queue_size": self.task_queue.qsize(),
                "active_workers": len([w for w in self.workers if w.is_alive()]),
                "max_workers": self.max_workers,
            }


# Module-level shared service instance (workers are not started here; call start()).
task_queue_service = TaskQueueService(max_workers=3)


def get_task_queue_service() -> TaskQueueService:
    """Return the module-level shared TaskQueueService instance."""
    return task_queue_service