| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- """
- 异步任务队列服务
- 实现非阻塞AI生图任务
- """
- import asyncio
- import uuid
- import threading
- from datetime import datetime
- from typing import Dict, Any, Optional, Callable
- from enum import Enum
- import queue
- import time
- from backend.utils.logger_config import setup_logger
- logger = setup_logger(__name__)
class TaskStatus(str, Enum):
    """Lifecycle states of a queued task.

    The ``str`` mix-in makes every member compare equal to (and hash like)
    its plain string value, so a status stored as ``"pending"`` and a status
    stored as ``TaskStatus.PENDING`` are interchangeable in comparisons,
    membership tests and dictionary lookups. ``.value`` and lookup by value
    (``TaskStatus("pending")``) behave as with a plain ``Enum``.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TaskQueueService:
    """Thread-backed task queue that runs submitted callables off the caller's thread.

    Callables are registered with :meth:`submit_task`, which returns a task id
    immediately. Worker threads started by :meth:`start` pull ids from an
    internal ``queue.Queue`` and execute the callables. Each task is tracked
    as a dict in ``self.tasks``; its ``"status"`` entry always holds the
    *string value* of a ``TaskStatus`` member. All reads and writes of task
    records go through ``self.lock``.
    """

    def __init__(self, max_workers: int = 3):
        """Create an idle service; no threads run until :meth:`start` is called.

        :param max_workers: number of worker threads :meth:`start` will spawn
        """
        self.max_workers = max_workers
        self.task_queue = queue.Queue()
        self.tasks: Dict[str, Dict[str, Any]] = {}
        self.workers: list = []
        self.running = False
        # Plain (non-reentrant) lock: helpers called while it is held must
        # not try to acquire it again.
        self.lock = threading.Lock()

    def start(self):
        """Spawn ``max_workers`` daemon worker threads; a second call is a no-op."""
        if self.running:
            logger.warning("任务队列服务已启动,请勿重复启动")
            return
        self.running = True
        for i in range(self.max_workers):
            worker = threading.Thread(target=self._worker_loop, name=f"TaskWorker-{i}")
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
        logger.info(f"任务队列服务启动完成,启动 {self.max_workers} 个工作线程")

    def stop(self):
        """Ask the workers to exit and wait up to 5 seconds for each of them."""
        if not self.running:
            return
        self.running = False
        for worker in self.workers:
            worker.join(timeout=5)
        # Forget threads that have exited so a later start() does not pile
        # new workers on top of dead ones.
        self.workers = [worker for worker in self.workers if worker.is_alive()]
        logger.info("任务队列服务已停止")

    def submit_task(self, task_func: Callable, task_args: tuple = (),
                    task_kwargs: Optional[dict] = None) -> str:
        """Register a callable for background execution.

        :param task_func: callable to run on a worker thread
        :param task_args: positional arguments passed to ``task_func``
        :param task_kwargs: keyword arguments passed to ``task_func``
        :return: the generated task id (a UUID4 string)
        """
        if task_kwargs is None:
            task_kwargs = {}
        task_id = str(uuid.uuid4())
        task_info = {
            "id": task_id,
            "func": task_func,
            "args": task_args,
            "kwargs": task_kwargs,
            "status": TaskStatus.PENDING.value,
            "created_at": datetime.now(),
            "start_at": None,
            "completed_at": None,
            "result": None,
            "error": None,
            "progress": 0,
        }
        # Record the task before queueing its id so a worker that picks the
        # id up immediately can always find the record.
        with self.lock:
            self.tasks[task_id] = task_info
        self.task_queue.put(task_id)
        logger.info(f"任务已提交到队列:{task_id}")
        return task_id

    def _clean_task_for_serialization(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of a task record with the entries this method strips removed.

        The ``"func"`` entry and, when the result is a dict, its
        ``"result_image"`` entry are left out. The stored record and its
        result dict are never modified.

        :param task: task record as stored in ``self.tasks``
        :return: a new dict safe to hand to callers
        """
        cleaned_task = {key: value for key, value in task.items() if key != "func"}
        result = cleaned_task.get("result")
        if isinstance(result, dict) and "result_image" in result:
            # Build a new result dict: deleting the key in place would also
            # remove the image from the stored task.
            cleaned_task["result"] = {
                key: value for key, value in result.items() if key != "result_image"
            }
        return cleaned_task

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Look up one task.

        :param task_id: id returned by :meth:`submit_task`
        :return: cleaned copy of the task record, or ``None`` if the id is unknown
        """
        with self.lock:
            task = self.tasks.get(task_id)
            if task is None:
                return None
            return self._clean_task_for_serialization(task)

    def get_all_tasks(self, user_id: Optional[int] = None) -> list:
        """List cleaned copies of the tracked tasks.

        :param user_id: when given, tasks whose first positional argument
            differs from it are skipped; tasks submitted without positional
            arguments are always included
        :return: list of cleaned task records
        """
        with self.lock:
            tasks = []
            for task in self.tasks.values():
                if user_id is not None:
                    task_args = task.get("args", ())
                    if len(task_args) > 0 and task_args[0] != user_id:
                        continue
                tasks.append(self._clean_task_for_serialization(task))
            return tasks

    def cancel_task(self, task_id: str) -> bool:
        """Mark a pending or processing task as cancelled.

        A pending task is skipped when a worker dequeues it. A task that is
        already running is not interrupted; its outcome is discarded when the
        callable returns.

        :param task_id: id returned by :meth:`submit_task`
        :return: ``True`` if the task was marked cancelled, ``False`` if the
            id is unknown or the task has already finished
        """
        cancellable = (TaskStatus.PENDING.value, TaskStatus.PROCESSING.value)
        with self.lock:
            task = self.tasks.get(task_id)
            if task is None or task["status"] not in cancellable:
                return False
            task["status"] = TaskStatus.CANCELLED.value
            task["completed_at"] = datetime.now()
        logger.info(f"任务已取消: {task_id}")
        return True

    def _worker_loop(self):
        """Worker thread body: dequeue task ids and run them until ``running`` is cleared."""
        thread_name = threading.current_thread().name
        while self.running:
            try:
                # Time out regularly so the loop re-checks self.running.
                task_id = self.task_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._run_task(task_id, thread_name)
            except Exception as e:
                logger.error(f"{thread_name} 线程出错:{e}")
                time.sleep(1)
            finally:
                # Every successful get() is matched by exactly one
                # task_done(), including for skipped tasks.
                self.task_queue.task_done()

    def _run_task(self, task_id: str, thread_name: str) -> None:
        """Execute one dequeued task and record its outcome in the task record."""
        with self.lock:
            task = self.tasks.get(task_id)
            # Unknown id, or cancelled while still waiting in the queue.
            if task is None or task["status"] == TaskStatus.CANCELLED.value:
                return
            task["status"] = TaskStatus.PROCESSING.value
            task["start_at"] = datetime.now()
        logger.info(f"{thread_name} 正在处理任务:{task_id}")
        try:
            # Run the callable outside the lock so other threads can keep
            # querying and submitting tasks meanwhile.
            result = task["func"](*task["args"], **task["kwargs"])
        except Exception as e:
            logger.error(f"{thread_name} 任务出错:{task_id}")
            with self.lock:
                # Keep a cancellation that arrived while the task was running.
                if task["status"] != TaskStatus.CANCELLED.value:
                    task["status"] = TaskStatus.FAILED.value
                    task["error"] = str(e)
                    task["completed_at"] = datetime.now()
            return
        with self.lock:
            if task["status"] == TaskStatus.CANCELLED.value:
                # Cancelled mid-run: drop the result, keep the cancelled state.
                return
            task["status"] = TaskStatus.COMPLETED.value
            task["result"] = result
            task["completed_at"] = datetime.now()
            task["progress"] = 100
        logger.info(f"{thread_name} 任务完成:{task_id}")

    def get_queue_stats(self) -> Dict[str, Any]:
        """Return task counts per status plus queue and worker figures."""
        with self.lock:
            counts = {status.value: 0 for status in TaskStatus}
            for task in self.tasks.values():
                status = task["status"]
                if status in counts:
                    counts[status] += 1
            return {
                "total_tasks": len(self.tasks),
                "pending_tasks": counts[TaskStatus.PENDING.value],
                "processing_tasks": counts[TaskStatus.PROCESSING.value],
                "completed_tasks": counts[TaskStatus.COMPLETED.value],
                "failed_tasks": counts[TaskStatus.FAILED.value],
                "cancelled_tasks": counts[TaskStatus.CANCELLED.value],
                "queue_size": self.task_queue.qsize(),
                "active_workers": len([w for w in self.workers if w.is_alive()]),
                "max_workers": self.max_workers,
            }
# Module-level task queue service instance, created at import time with
# three workers configured; it is not started here.
task_queue_service = TaskQueueService(max_workers=3)
def get_task_queue_service() -> TaskQueueService:
    """Return the module-level ``task_queue_service`` instance."""
    return task_queue_service
|