| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- """
- 异步任务队列服务
- 实现真正的非阻塞AI换脸换装处理
- """
- import asyncio
- import uuid
- import threading
- from datetime import datetime
- from typing import Dict, Any, Optional, Callable
- from enum import Enum
- import queue
- import time
- from backend.utils.logger_config import setup_logger
- logger = setup_logger(__name__)
- class TaskStatus(Enum):
- """任务状态枚举"""
- PENDING = "pending" # 等待中
- PROCESSING = "processing" # 处理中
- COMPLETED = "completed" # 已完成
- FAILED = "failed" # 失败
- CANCELLED = "cancelled" # 已取消
- class TaskQueueService:
- """
- 异步任务队列服务
- 实现真正的非阻塞任务处理
- """
-
- def __init__(self, max_workers: int = 3):
- """
- 初始化任务队列服务
-
- Args:
- max_workers: 最大工作线程数
- """
- self.max_workers = max_workers
- self.task_queue = queue.Queue()
- self.tasks: Dict[str, Dict[str, Any]] = {}
- self.workers: list = []
- self.running = False
- self.lock = threading.Lock()
-
- logger.info(f"任务队列服务初始化完成,最大工作线程数: {max_workers}")
-
- def start(self):
- """启动任务队列服务"""
- if self.running:
- logger.warning("任务队列服务已在运行")
- return
-
- self.running = True
-
- # 启动工作线程
- for i in range(self.max_workers):
- worker = threading.Thread(target=self._worker_loop, name=f"TaskWorker-{i}")
- worker.daemon = True
- worker.start()
- self.workers.append(worker)
-
- logger.info(f"任务队列服务启动完成,启动 {self.max_workers} 个工作线程")
-
- def stop(self):
- """停止任务队列服务"""
- if not self.running:
- return
-
- self.running = False
-
- # 等待所有工作线程结束
- for worker in self.workers:
- worker.join(timeout=5)
-
- logger.info("任务队列服务已停止")
-
- def submit_task(self, task_func: Callable, task_args: tuple = (), task_kwargs: dict = None) -> str:
- """
- 提交任务到队列
-
- Args:
- task_func: 任务函数
- task_args: 任务参数
- task_kwargs: 任务关键字参数
-
- Returns:
- str: 任务ID
- """
- if task_kwargs is None:
- task_kwargs = {}
-
- task_id = str(uuid.uuid4())
-
- # 创建任务记录
- task_info = {
- "id": task_id,
- "func": task_func,
- "args": task_args,
- "kwargs": task_kwargs,
- "status": TaskStatus.PENDING,
- "created_at": datetime.now(),
- "started_at": None,
- "completed_at": None,
- "result": None,
- "error": None,
- "progress": 0
- }
-
- with self.lock:
- self.tasks[task_id] = task_info
-
- # 添加到队列
- self.task_queue.put(task_id)
-
- logger.info(f"任务已提交到队列: {task_id}")
- return task_id
-
- def _clean_task_for_serialization(self, task: Dict[str, Any]) -> Dict[str, Any]:
- """
- 清理任务数据,移除不可序列化的对象
-
- Args:
- task: 任务数据
-
- Returns:
- Dict: 清理后的任务数据
- """
- cleaned_task = task.copy()
-
- # 移除函数对象
- if "func" in cleaned_task:
- del cleaned_task["func"]
-
- # 清理结果中的不可序列化对象
- if "result" in cleaned_task and isinstance(cleaned_task["result"], dict):
- result = cleaned_task["result"]
- # 移除PIL图像对象
- if "result_image" in result:
- del result["result_image"]
-
- return cleaned_task
-
- def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
- """
- 获取任务状态
-
- Args:
- task_id: 任务ID
-
- Returns:
- Dict: 任务状态信息
- """
- with self.lock:
- if task_id not in self.tasks:
- return None
-
- return self._clean_task_for_serialization(self.tasks[task_id])
-
- def get_all_tasks(self, user_id: Optional[int] = None) -> list:
- """
- 获取所有任务状态
-
- Args:
- user_id: 用户ID,如果提供则只返回该用户的任务
-
- Returns:
- list: 任务列表
- """
- with self.lock:
- tasks = []
- for task_id, task in self.tasks.items():
- # 检查是否属于指定用户
- if user_id is not None:
- task_args = task.get("args", ())
- if len(task_args) > 0 and task_args[0] != user_id:
- continue
-
- tasks.append(self._clean_task_for_serialization(task))
-
- return tasks
-
- def cancel_task(self, task_id: str) -> bool:
- """
- 取消任务
-
- Args:
- task_id: 任务ID
-
- Returns:
- bool: 是否成功取消
- """
- with self.lock:
- if task_id not in self.tasks:
- return False
-
- task = self.tasks[task_id]
- if task["status"] in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
- task["status"] = TaskStatus.CANCELLED
- task["completed_at"] = datetime.now()
- logger.info(f"任务已取消: {task_id}")
- return True
-
- return False
-
- def _worker_loop(self):
- """工作线程循环"""
- thread_name = threading.current_thread().name
-
- while self.running:
- try:
- # 从队列获取任务
- task_id = self.task_queue.get(timeout=1)
-
- with self.lock:
- if task_id not in self.tasks:
- continue
-
- task = self.tasks[task_id]
-
- # 检查任务是否被取消
- if task["status"] == TaskStatus.CANCELLED:
- continue
-
- # 更新任务状态为处理中
- task["status"] = TaskStatus.PROCESSING
- task["started_at"] = datetime.now()
-
- logger.info(f"[{thread_name}] 开始处理任务: {task_id}")
-
- try:
- # 执行任务
- result = task["func"](*task["args"], **task["kwargs"])
-
- with self.lock:
- task["status"] = TaskStatus.COMPLETED
- task["result"] = result
- task["completed_at"] = datetime.now()
- task["progress"] = 100
-
- logger.info(f"[{thread_name}] 任务完成: {task_id}")
-
- except Exception as e:
- logger.error(f"[{thread_name}] 任务执行失败: {task_id}, 错误: {str(e)}")
-
- with self.lock:
- task["status"] = TaskStatus.FAILED
- task["error"] = str(e)
- task["completed_at"] = datetime.now()
-
- finally:
- self.task_queue.task_done()
-
- except queue.Empty:
- # 队列为空,继续循环
- continue
- except Exception as e:
- logger.error(f"[{thread_name}] 工作线程异常: {str(e)}")
- time.sleep(1)
-
- def get_queue_stats(self) -> Dict[str, Any]:
- """
- 获取队列统计信息
-
- Returns:
- Dict: 队列统计信息
- """
- with self.lock:
- total_tasks = len(self.tasks)
- pending_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.PENDING)
- processing_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.PROCESSING)
- completed_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.COMPLETED)
- failed_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.FAILED)
- cancelled_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.CANCELLED)
-
- return {
- "total_tasks": total_tasks,
- "pending_tasks": pending_tasks,
- "processing_tasks": processing_tasks,
- "completed_tasks": completed_tasks,
- "failed_tasks": failed_tasks,
- "cancelled_tasks": cancelled_tasks,
- "queue_size": self.task_queue.qsize(),
- "active_workers": len([w for w in self.workers if w.is_alive()]),
- "max_workers": self.max_workers
- }
- # 全局任务队列服务实例
- task_queue_service = TaskQueueService(max_workers=3)
- def get_task_queue_service() -> TaskQueueService:
- """获取全局任务队列服务实例"""
- return task_queue_service
|