task_queue_service.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """
  2. 异步任务队列服务
  3. 实现真正的非阻塞AI换脸换装处理
  4. """
  5. import asyncio
  6. import uuid
  7. import threading
  8. from datetime import datetime
  9. from typing import Dict, Any, Optional, Callable
  10. from enum import Enum
  11. import queue
  12. import time
  13. from backend.utils.logger_config import setup_logger
  14. logger = setup_logger(__name__)
  15. class TaskStatus(Enum):
  16. """任务状态枚举"""
  17. PENDING = "pending" # 等待中
  18. PROCESSING = "processing" # 处理中
  19. COMPLETED = "completed" # 已完成
  20. FAILED = "failed" # 失败
  21. CANCELLED = "cancelled" # 已取消
  22. class TaskQueueService:
  23. """
  24. 异步任务队列服务
  25. 实现真正的非阻塞任务处理
  26. """
  27. def __init__(self, max_workers: int = 3):
  28. """
  29. 初始化任务队列服务
  30. Args:
  31. max_workers: 最大工作线程数
  32. """
  33. self.max_workers = max_workers
  34. self.task_queue = queue.Queue()
  35. self.tasks: Dict[str, Dict[str, Any]] = {}
  36. self.workers: list = []
  37. self.running = False
  38. self.lock = threading.Lock()
  39. logger.info(f"任务队列服务初始化完成,最大工作线程数: {max_workers}")
  40. def start(self):
  41. """启动任务队列服务"""
  42. if self.running:
  43. logger.warning("任务队列服务已在运行")
  44. return
  45. self.running = True
  46. # 启动工作线程
  47. for i in range(self.max_workers):
  48. worker = threading.Thread(target=self._worker_loop, name=f"TaskWorker-{i}")
  49. worker.daemon = True
  50. worker.start()
  51. self.workers.append(worker)
  52. logger.info(f"任务队列服务启动完成,启动 {self.max_workers} 个工作线程")
  53. def stop(self):
  54. """停止任务队列服务"""
  55. if not self.running:
  56. return
  57. self.running = False
  58. # 等待所有工作线程结束
  59. for worker in self.workers:
  60. worker.join(timeout=5)
  61. logger.info("任务队列服务已停止")
  62. def submit_task(self, task_func: Callable, task_args: tuple = (), task_kwargs: dict = None) -> str:
  63. """
  64. 提交任务到队列
  65. Args:
  66. task_func: 任务函数
  67. task_args: 任务参数
  68. task_kwargs: 任务关键字参数
  69. Returns:
  70. str: 任务ID
  71. """
  72. if task_kwargs is None:
  73. task_kwargs = {}
  74. task_id = str(uuid.uuid4())
  75. # 创建任务记录
  76. task_info = {
  77. "id": task_id,
  78. "func": task_func,
  79. "args": task_args,
  80. "kwargs": task_kwargs,
  81. "status": TaskStatus.PENDING,
  82. "created_at": datetime.now(),
  83. "started_at": None,
  84. "completed_at": None,
  85. "result": None,
  86. "error": None,
  87. "progress": 0
  88. }
  89. with self.lock:
  90. self.tasks[task_id] = task_info
  91. # 添加到队列
  92. self.task_queue.put(task_id)
  93. logger.info(f"任务已提交到队列: {task_id}")
  94. return task_id
  95. def _clean_task_for_serialization(self, task: Dict[str, Any]) -> Dict[str, Any]:
  96. """
  97. 清理任务数据,移除不可序列化的对象
  98. Args:
  99. task: 任务数据
  100. Returns:
  101. Dict: 清理后的任务数据
  102. """
  103. cleaned_task = task.copy()
  104. # 移除函数对象
  105. if "func" in cleaned_task:
  106. del cleaned_task["func"]
  107. # 清理结果中的不可序列化对象
  108. if "result" in cleaned_task and isinstance(cleaned_task["result"], dict):
  109. result = cleaned_task["result"]
  110. # 移除PIL图像对象
  111. if "result_image" in result:
  112. del result["result_image"]
  113. return cleaned_task
  114. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  115. """
  116. 获取任务状态
  117. Args:
  118. task_id: 任务ID
  119. Returns:
  120. Dict: 任务状态信息
  121. """
  122. with self.lock:
  123. if task_id not in self.tasks:
  124. return None
  125. return self._clean_task_for_serialization(self.tasks[task_id])
  126. def get_all_tasks(self, user_id: Optional[int] = None) -> list:
  127. """
  128. 获取所有任务状态
  129. Args:
  130. user_id: 用户ID,如果提供则只返回该用户的任务
  131. Returns:
  132. list: 任务列表
  133. """
  134. with self.lock:
  135. tasks = []
  136. for task_id, task in self.tasks.items():
  137. # 检查是否属于指定用户
  138. if user_id is not None:
  139. task_args = task.get("args", ())
  140. if len(task_args) > 0 and task_args[0] != user_id:
  141. continue
  142. tasks.append(self._clean_task_for_serialization(task))
  143. return tasks
  144. def cancel_task(self, task_id: str) -> bool:
  145. """
  146. 取消任务
  147. Args:
  148. task_id: 任务ID
  149. Returns:
  150. bool: 是否成功取消
  151. """
  152. with self.lock:
  153. if task_id not in self.tasks:
  154. return False
  155. task = self.tasks[task_id]
  156. if task["status"] in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
  157. task["status"] = TaskStatus.CANCELLED
  158. task["completed_at"] = datetime.now()
  159. logger.info(f"任务已取消: {task_id}")
  160. return True
  161. return False
  162. def _worker_loop(self):
  163. """工作线程循环"""
  164. thread_name = threading.current_thread().name
  165. while self.running:
  166. try:
  167. # 从队列获取任务
  168. task_id = self.task_queue.get(timeout=1)
  169. with self.lock:
  170. if task_id not in self.tasks:
  171. continue
  172. task = self.tasks[task_id]
  173. # 检查任务是否被取消
  174. if task["status"] == TaskStatus.CANCELLED:
  175. continue
  176. # 更新任务状态为处理中
  177. task["status"] = TaskStatus.PROCESSING
  178. task["started_at"] = datetime.now()
  179. logger.info(f"[{thread_name}] 开始处理任务: {task_id}")
  180. try:
  181. # 执行任务
  182. result = task["func"](*task["args"], **task["kwargs"])
  183. with self.lock:
  184. task["status"] = TaskStatus.COMPLETED
  185. task["result"] = result
  186. task["completed_at"] = datetime.now()
  187. task["progress"] = 100
  188. logger.info(f"[{thread_name}] 任务完成: {task_id}")
  189. except Exception as e:
  190. logger.error(f"[{thread_name}] 任务执行失败: {task_id}, 错误: {str(e)}")
  191. with self.lock:
  192. task["status"] = TaskStatus.FAILED
  193. task["error"] = str(e)
  194. task["completed_at"] = datetime.now()
  195. finally:
  196. self.task_queue.task_done()
  197. except queue.Empty:
  198. # 队列为空,继续循环
  199. continue
  200. except Exception as e:
  201. logger.error(f"[{thread_name}] 工作线程异常: {str(e)}")
  202. time.sleep(1)
  203. def get_queue_stats(self) -> Dict[str, Any]:
  204. """
  205. 获取队列统计信息
  206. Returns:
  207. Dict: 队列统计信息
  208. """
  209. with self.lock:
  210. total_tasks = len(self.tasks)
  211. pending_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.PENDING)
  212. processing_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.PROCESSING)
  213. completed_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.COMPLETED)
  214. failed_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.FAILED)
  215. cancelled_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.CANCELLED)
  216. return {
  217. "total_tasks": total_tasks,
  218. "pending_tasks": pending_tasks,
  219. "processing_tasks": processing_tasks,
  220. "completed_tasks": completed_tasks,
  221. "failed_tasks": failed_tasks,
  222. "cancelled_tasks": cancelled_tasks,
  223. "queue_size": self.task_queue.qsize(),
  224. "active_workers": len([w for w in self.workers if w.is_alive()]),
  225. "max_workers": self.max_workers
  226. }
  227. # 全局任务队列服务实例
  228. task_queue_service = TaskQueueService(max_workers=3)
  229. def get_task_queue_service() -> TaskQueueService:
  230. """获取全局任务队列服务实例"""
  231. return task_queue_service