task_queue_service_new.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. """
  2. 异步任务队列服务
  3. 实现非阻塞AI生图任务
  4. """
  5. import asyncio
  6. import uuid
  7. import threading
  8. from datetime import datetime
  9. from typing import Dict, Any, Optional, Callable
  10. from enum import Enum
  11. import queue
  12. import time
  13. from backend.utils.logger_config import setup_logger
  14. logger = setup_logger(__name__)
  15. class TaskStatus(Enum):
  16. """任务状态枚举"""
  17. PENDING = "pending"
  18. PROCESSING = "processing"
  19. COMPLETED = "completed"
  20. FAILED = "failed"
  21. CANCELLED = "cancelled"
  22. class TaskQueueService:
  23. """
  24. 异步任务队列服务
  25. 实现真正的非阻塞任务处理
  26. """
  27. def __init__(self, max_workers: int = 3):
  28. """
  29. 初始化异步任务队列服务
  30. """
  31. self.max_workers = max_workers
  32. self.task_queue = queue.Queue()
  33. self.tasks: Dict[str, Dict[str, Any]] = {}
  34. self.workers: list = []
  35. self.running = False
  36. self.lock = threading.Lock()
  37. def start(self):
  38. """启动任务队列服务"""
  39. if self.running:
  40. logger.warning("任务队列服务已启动,请勿重复启动")
  41. return
  42. self.running = True
  43. # 启动工作线程
  44. for i in range(self.max_workers):
  45. worker = threading.Thread(target=self._worker_loop, name=f"TaskWorker-{i}")
  46. worker.daemon = True
  47. worker.start()
  48. self.workers.append(worker)
  49. logger.info("任务队列服务启动完成,启动 {self.max_workers} 个工作线程")
  50. def stop(self):
  51. """停止任务队列服务"""
  52. if not self.running:
  53. return
  54. self.running = False
  55. # 等待所有工作线程结束
  56. for worker in self.workers:
  57. worker.join(timeout=5)
  58. logger.info("任务队列服务已停止")
  59. def submit_task(self, task_func: Callable, task_args: tuple = (), task_kwargs: dict = None) -> str:
  60. """
  61. 提交一个任务
  62. :param task_func: 任务函数
  63. :param task_args: 任务参数
  64. :param task_kwargs: 任务关键字参数
  65. :return: 任务ID
  66. """
  67. if task_kwargs is None:
  68. task_kwargs = {}
  69. task_id = str(uuid.uuid4())
  70. # 创建任务记录
  71. task_info = {
  72. "id": task_id,
  73. "func": task_func,
  74. "args": task_args,
  75. "kwargs": task_kwargs,
  76. "status": TaskStatus.PENDING.value,
  77. "created_at": datetime.now(),
  78. "start_at": None,
  79. "completed_at": None,
  80. "result": None,
  81. "error": None
  82. "progress": 0
  83. }
  84. with self.lock:
  85. self.tasks[task_id] = task_info
  86. # 添加到队列
  87. self.task_queue.put(task_id)
  88. logger.info(f"任务已提交到队列:{task_id}")
  89. return task_id
  90. def _clean_task_for_serialization(self, task: Dict[str, Any]) -> Dict[str, Any]:
  91. """
  92. 清理任务信息,用于序列化
  93. :param task: 任务信息
  94. :return: 清理后的任务信息
  95. """
  96. cleaned_task = task.copy()
  97. # 移除函数对象
  98. if "func" in cleaned_task:
  99. del cleaned_task["func"]
  100. # 清理结果中不可序列化的对象
  101. if "result" in cleaned_task and isinstance(cleaned_task["result"], dict):
  102. result = cleaned_task["result"]
  103. # 移除PIL图像对象
  104. if "result_image" in result:
  105. del result["result_image"]
  106. return cleaned_task
  107. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  108. """
  109. 获取任务状态
  110. :param task_id: 任务ID
  111. :return: 任务状态
  112. """
  113. with self.lock:
  114. if task_id not in self.tasks:
  115. return None
  116. return self._clean_task_for_serialization(self.tasks[task_id])
  117. def get_all_tasks(self, user_id: Optional[int] = None) -> list:
  118. """
  119. 获取所有任务
  120. :param user_id: 用户ID
  121. :return: 所有任务列表
  122. """
  123. with self.lock:
  124. tasks = []
  125. for task_id, task in self.tasks.items():
  126. # 检查是否属于指定用户
  127. if user_id is not None:
  128. task_args = task.get("args", ())
  129. if len(task_args) > 0 and task_args[0] != user_id:
  130. continue
  131. tasks.append(self._clean_task_for_serialization(task))
  132. return tasks
  133. def cancel_task(self, task_id: str) -> bool:
  134. """
  135. 取消任务
  136. :param task_id: 任务ID
  137. :return: 是否成功取消
  138. """
  139. with self.lock:
  140. if task_id not in self.tasks:
  141. return False
  142. task = self.tasks[task_id]
  143. if task["status"] in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
  144. task["status"] = TaskStatus.CANCELLED
  145. task["completed_at"] = datetime.now()
  146. logger.info(f"任务已取消: {task_id}")
  147. return True
  148. return False
  149. def _worker_loop(self):
  150. """工作线程循环"""
  151. thread_name = threading.current_thread().name
  152. while self.running:
  153. try:
  154. # 从队列获取任务
  155. task_id = self.task_queue.get(timeout=1)
  156. while self.lock:
  157. if task_id not in self.tasks:
  158. continue
  159. task = self.tasks[task_id]
  160. # 检查任务是否被取消
  161. if task["status"] == TaskStatus.CANCELLED:
  162. continue
  163. # 更新任务状态为处理中
  164. task["status"] = TaskStatus.PROCESSING
  165. task["start_at"] = datetime.now()
  166. logger.info(f"{thread_name} 正在处理任务:{task_id}")
  167. try:
  168. # 执行任务
  169. result = task["func"](*task["args"], **task["kwargs"])
  170. with self.lock:
  171. task["status"] = TaskStatus.COMPLETED
  172. task["result"] = result
  173. task["completed_at"] = datetime.now()
  174. task["progress"] = 100
  175. logger.info(f"{thread_name} 任务完成:{task_id}")
  176. except Exception as e:
  177. logger.error(f"{thread_name} 任务出错:{task_id}")
  178. with self.lock:
  179. task["status"] = TaskStatus.FAILED
  180. task["error"] = str(e)
  181. task["completed_at"] = datetime.now()
  182. finally:
  183. self.task_queue.task_done()
  184. except queue.Empty:
  185. # 队列为空,防止循环
  186. continue
  187. except Exception as e:
  188. logger.error(f"{thread_name} 线程出错:{e}")
  189. time.sleep(1)
  190. def get_queue_stats(self) -> Dict[str, Any]:
  191. """获取队列统计信息"""
  192. with self.lock:
  193. total_tasks = len(self.tasks)
  194. pending_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.PENDING)
  195. processing_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.PROCESSING)
  196. completed_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.COMPLETED)
  197. failed_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.FAILED)
  198. cancelled_tasks = sum(1 for task in self.tasks.values() if task["status"] == TaskStatus.CANCELLED)
  199. return {
  200. "total_tasks": total_tasks,
  201. "pending_tasks": pending_tasks,
  202. "processing_tasks": processing_tasks,
  203. "completed_tasks": completed_tasks,
  204. "failed_tasks": failed_tasks,
  205. "cancelled_tasks": cancelled_tasks,
  206. "queue_size": self.task_queu.qsize(),
  207. "active_workers": len([w for w in self.workers if w.is_alive()]),
  208. "max_workers": self.max_workers
  209. }
# Global task queue service instance, created at import time and configured
# for 3 workers; no worker threads run until its start() method is called.
task_queue_service = TaskQueueService(max_workers=3)
  212. def get_task_queue_service() -> TaskQueueService:
  213. return task_queue_service