# ai_swap_bg_service.py
import contextlib
import hashlib
import os
import random
import uuid
from datetime import datetime
from typing import Dict, Any, Optional, Tuple

import numpy as np
from PIL import Image

from backend.modules.comfyui.ai_swap_bg import ai_swap_bg_process
from backend.modules.comfyui.ai_copywriter import gen_copywriter
from backend.modules.database.operations import DatabaseOperations
from backend.modules.database.models import ProcessRecord, ImageRecord, User
from backend.modules.database.connection import DatabaseConnection
from backend.modules.database.models import SystemConfig
from backend.utils.logger_config import setup_logger
from backend.utils.system_config import Config
from backend.services.task_queue_service import get_task_queue_service, TaskStatus

logger = setup_logger(__name__)
  19. class AISwapBgService:
  20. """
  21. AI换背景业务逻辑服务
  22. 负责协调AI处理、文案生成、数据库操作等完整业务流程
  23. """
  24. def __init__(self, db_operations: Optional[DatabaseOperations] = None):
  25. """
  26. 初始化服务
  27. Args:
  28. db_operations: 数据库操作对象,如果为None则创建默认实例
  29. """
  30. self.db_ops = db_operations or DatabaseOperations()
  31. self.system_config = Config('./backend/config/ai_swap_bg_config.json')
  32. self.task_queue = get_task_queue_service()
  33. # 确保输出目录存在
  34. os.makedirs(self.system_config.output_dir, exist_ok=True)
  35. def submit_swap_bg_task(
  36. self,
  37. user_id: int,
  38. image_id: int,
  39. prompt: str,
  40. quantity: int = 1
  41. ) -> str:
  42. """
  43. 提交换背景任务到队列(非阻塞)
  44. Args:
  45. user_id: 用户ID
  46. image_id: 图片ID
  47. prompt: 提示词
  48. quantity: 生成数量
  49. Returns:
  50. str: 任务ID
  51. """
  52. logger.info(f"提交用户 {user_id} 的换背景任务到队列")
  53. # 提交任务到队列
  54. task_id = self.task_queue.submit_task(
  55. task_func=self._execute_swap_bg_task,
  56. task_args=(user_id, image_id, prompt),
  57. task_kwargs={
  58. "quantity": max(1, int(quantity))
  59. }
  60. )
  61. logger.info(f"任务已提交到队列,任务ID: {task_id}")
  62. return task_id
  63. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  64. """
  65. 获取任务状态
  66. Args:
  67. task_id: 任务ID
  68. Returns:
  69. Dict: 任务状态信息
  70. """
  71. return self.task_queue.get_task_status(task_id)
  72. def get_user_tasks(self, user_id: int) -> list:
  73. """
  74. 获取用户的所有任务
  75. Args:
  76. user_id: 用户ID
  77. Returns:
  78. list: 任务列表
  79. """
  80. return self.task_queue.get_all_tasks(user_id)
  81. def cancel_task(self, task_id: str) -> bool:
  82. """
  83. 取消任务
  84. Args:
  85. task_id: 任务ID
  86. Returns:
  87. bool: 是否成功取消
  88. """
  89. return self.task_queue.cancel_task(task_id)
  90. def _execute_swap_bg_task(
  91. self,
  92. user_id: int,
  93. image_id: int,
  94. prompt: str,
  95. quantity: int = 1
  96. ) -> Dict[str, Any]:
  97. """
  98. 执行换背景任务(在线程池中运行)
  99. Args:
  100. user_id: 用户ID
  101. image_id: 图片ID
  102. prompt: 提示词
  103. quantity: 生成数量
  104. Returns:
  105. Dict: 包含处理结果的字典
  106. """
  107. try:
  108. logger.info(f"开始执行用户 {user_id} 的换背景任务")
  109. # 1. 输入验证
  110. self._validate_inputs(user_id, image_id, prompt)
  111. # 2. 获取输入图片
  112. image = self._get_input_image(image_id)
  113. # 3. 循环生成 quantity 次
  114. total_count = max(1, int(quantity))
  115. process_record_ids = []
  116. result_image_ids = []
  117. copywriter_texts = []
  118. history_prompt_last = None
  119. # 为了提升多张图的差异性,循环内随机化种子(若底层支持)
  120. try:
  121. from backend.modules.comfyui import ai_swap_bg as comfy_ai_swap_bg_module
  122. except Exception:
  123. comfy_ai_swap_bg_module = None
  124. for _ in range(total_count):
  125. if comfy_ai_swap_bg_module is not None:
  126. try:
  127. comfy_ai_swap_bg_module.system_config.seed = random.randint(1, 10_000_000)
  128. except Exception:
  129. pass
  130. # 执行AI换背景
  131. result_image, history_prompt = self._process_ai_swap_bg(
  132. prompt, image
  133. )
  134. history_prompt_last = history_prompt
  135. # 生成文案描述
  136. copywriter_text = self._generate_copywriter(result_image)
  137. copywriter_texts.append(copywriter_text)
  138. # 保存结果图片
  139. result_image_record = self._save_result_image(user_id, result_image)
  140. result_image_ids.append(result_image_record["id"])
  141. # 创建处理记录
  142. process_record = self._create_process_record(
  143. user_id, image_id, result_image_record["id"], copywriter_text, prompt
  144. )
  145. process_record_ids.append(process_record["id"])
  146. logger.info(
  147. f"用户 {user_id} 的换背景任务完成,共生成 {len(process_record_ids)} 张,首个记录ID: {process_record_ids[0]}"
  148. )
  149. # 为兼容旧前端,保留单值字段,同时返回批量字段
  150. return {
  151. "success": True,
  152. "count": len(process_record_ids),
  153. "process_record_id": process_record_ids[0],
  154. "result_image_id": result_image_ids[0],
  155. "copywriter_text": copywriter_texts[0] if copywriter_texts else None,
  156. "history_prompt": history_prompt_last,
  157. "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
  158. "process_record_ids": process_record_ids,
  159. "result_image_ids": result_image_ids,
  160. "copywriter_texts": copywriter_texts,
  161. }
  162. except Exception as e:
  163. logger.error(f"AI换背景任务执行失败: {str(e)}", exc_info=True)
  164. return {
  165. "success": False,
  166. "error": str(e),
  167. "error_type": type(e).__name__
  168. }
  169. def process_swap_bg_with_record(
  170. self,
  171. user_id: int,
  172. image_id: int,
  173. prompt: str,
  174. **kwargs
  175. ) -> Dict[str, Any]:
  176. """
  177. 完整的AI换背景业务流程(同步版本,保持向后兼容)
  178. Args:
  179. user_id: 用户ID
  180. image_id: 图片ID
  181. prompt: 用户输入的提示词
  182. **kwargs: 其他可选参数
  183. Returns:
  184. Dict: 包含处理结果的字典
  185. """
  186. return self._execute_swap_bg_task(
  187. user_id, image_id, prompt, **kwargs
  188. )
  189. def _validate_inputs(self, user_id: int, image_id: int, prompt: str):
  190. """
  191. 验证输入参数
  192. Args:
  193. user_id: 用户ID
  194. image_id: 图片ID
  195. prompt: 提示词
  196. """
  197. # 验证用户是否存在
  198. user = self.db_ops.get_user_by_id(user_id)
  199. if not user:
  200. raise ValueError(f"用户ID {user_id} 不存在")
  201. if not user.get("is_active", False):
  202. raise ValueError(f"用户ID {user_id} 已被禁用")
  203. # 验证图片
  204. image = self.db_ops.get_image_record_by_id(image_id)
  205. if not image:
  206. raise ValueError(f"图片ID {image_id} 不存在")
  207. if image["image_type"] != "original":
  208. raise ValueError(f"图片ID {image_id} 不是原始图片类型")
  209. if image["user_id"] != user_id:
  210. raise ValueError(f"图片ID {image_id} 不属于用户 {user_id}")
  211. # 验证提示词
  212. if not prompt or not prompt.strip():
  213. raise ValueError("提示词不能为空")
  214. if len(prompt.strip()) > 500:
  215. raise ValueError("提示词长度不能超过500字符")
  216. logger.info(f"输入验证通过: 用户={user_id}, 图片={image_id}")
  217. def _get_input_image(self, image_id: int) -> np.ndarray:
  218. """
  219. 获取输入图片数据
  220. Args:
  221. image_id: 图片ID
  222. Returns:
  223. np.ndarray: 图片数据
  224. """
  225. # 获取图片记录
  226. image_record = self.db_ops.get_image_record_by_id(image_id)
  227. if not os.path.exists(image_record["stored_path"]):
  228. raise FileNotFoundError(f"图片文件不存在: {image_record['stored_path']}")
  229. # 读取图片数据
  230. image = np.array(Image.open(image_record["stored_path"]))
  231. return image
  232. def _process_ai_swap_bg(
  233. self,
  234. prompt: str,
  235. image: np.ndarray
  236. ) -> Tuple[Image.Image, str]:
  237. """
  238. 执行AI换背景处理
  239. Args:
  240. prompt: 提示词
  241. image: 输入图片数据
  242. Returns:
  243. Tuple: (结果图片, 历史提示词)
  244. """
  245. logger.info(f"开始执行AI换背景处理,提示词: {prompt}")
  246. try:
  247. result_image, history_prompt = ai_swap_bg_process(
  248. prompt, image
  249. )
  250. logger.info("AI换背景处理完成")
  251. return result_image, history_prompt
  252. except Exception as e:
  253. logger.error(f"AI换背景处理失败: {str(e)}", exc_info=True)
  254. raise RuntimeError(f"AI换背景处理失败: {str(e)}")
  255. def _generate_copywriter(self, result_image: Image.Image) -> str:
  256. """
  257. 基于结果图片生成文案描述
  258. Args:
  259. result_image: 结果图片
  260. Returns:
  261. str: 生成的文案描述
  262. """
  263. logger.info("开始生成文案描述")
  264. try:
  265. copywriter_text = gen_copywriter(result_image)
  266. logger.info("文案描述生成完成")
  267. return copywriter_text
  268. except Exception as e:
  269. logger.error(f"文案生成失败: {str(e)}", exc_info=True)
  270. # 文案生成失败不影响主流程,返回默认文案
  271. return "AI换脸换装完成,效果很棒!✨"
  272. def _save_result_image(self, user_id: int, result_image: Image.Image) -> Dict[str, Any]:
  273. """
  274. 保存结果图片到数据库和文件系统
  275. Args:
  276. user_id: 用户ID
  277. result_image: 结果图片
  278. Returns:
  279. Dict: 图片记录信息
  280. """
  281. # 生成唯一文件名
  282. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  283. unique_id = str(uuid.uuid4())[:8]
  284. filename = f"result_{user_id}_{timestamp}_{unique_id}.png"
  285. # 保存到文件系统
  286. file_path = os.path.join(self.system_config.output_dir, filename)
  287. result_image.save(file_path, "PNG")
  288. # 计算文件大小和哈希值
  289. file_size = os.path.getsize(file_path)
  290. image_hash = self._calculate_image_hash(file_path)
  291. # 保存到数据库
  292. image_record = self.db_ops.create_image_record(
  293. user_id=user_id,
  294. image_type="result",
  295. original_filename=filename,
  296. stored_path=file_path,
  297. file_size=file_size,
  298. image_hash=image_hash
  299. )
  300. logger.info(f"结果图片保存成功: {file_path}, 记录ID: {image_record['id']}")
  301. return image_record
  302. def _create_process_record(
  303. self,
  304. user_id: int,
  305. image_id: int,
  306. result_image_id: int,
  307. copywriter_text: str,
  308. prompt: str
  309. ) -> Dict[str, Any]:
  310. """
  311. 创建处理记录
  312. Args:
  313. user_id: 用户ID
  314. image_id: 输入图片ID
  315. result_image_id: 结果图片ID
  316. copywriter_text: 文案描述
  317. prompt: 提示词
  318. Returns:
  319. Dict: 处理记录信息
  320. """
  321. process_record = self.db_ops.create_process_record(
  322. user_id=user_id,
  323. face_image_id=image_id,
  324. cloth_image_id=image_id,
  325. result_image_id=result_image_id,
  326. generated_text=copywriter_text,
  327. task_type="swap_bg",
  328. prompt=prompt
  329. )
  330. # 更新完成时间
  331. self.db_ops.update_process_record(
  332. process_record["id"],
  333. {"completed_at": datetime.now()}
  334. )
  335. logger.info(f"处理记录创建成功: {process_record['id']}")
  336. return process_record
  337. def _calculate_image_hash(self, image_path: str) -> str:
  338. """
  339. 计算图片哈希值
  340. Args:
  341. image_path: 图片路径
  342. Returns:
  343. str: MD5哈希值
  344. """
  345. with open(image_path, 'rb') as f:
  346. return hashlib.md5(f.read()).hexdigest()
  347. def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  348. """
  349. 获取用户的处理历史记录
  350. Args:
  351. user_id: 用户ID
  352. page: 页码
  353. page_size: 每页大小
  354. Returns:
  355. Dict: 分页的处理记录列表
  356. """
  357. return self.db_ops.get_user_process_records(user_id, page, page_size)
  358. def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
  359. """
  360. 获取处理记录详情
  361. Args:
  362. process_id: 处理记录ID
  363. user_id: 用户ID(可选,用于权限验证)
  364. Returns:
  365. Dict: 处理记录详情,包含关联的图片信息
  366. """
  367. process_record = self.db_ops.get_process_record_by_id(process_id)
  368. if not process_record:
  369. return None
  370. # 权限验证
  371. if user_id and process_record["user_id"] != user_id:
  372. return None
  373. # 获取关联的图片信息
  374. face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
  375. cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
  376. result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
  377. return {
  378. "process_record": process_record,
  379. "face_image": face_image,
  380. "cloth_image": cloth_image,
  381. "result_image": result_image
  382. }
  383. def approve_process_record(self, process_id: int) -> bool:
  384. """
  385. 审核通过处理记录
  386. Args:
  387. process_id: 处理记录ID
  388. Returns:
  389. bool: 操作是否成功
  390. """
  391. try:
  392. # 检查记录是否存在
  393. process_record = self.db_ops.get_process_record_by_id(process_id)
  394. if not process_record:
  395. logger.error(f"处理记录不存在: {process_id}")
  396. return False
  397. # 更新状态为已审核
  398. update_data = {
  399. "status": "已审核"
  400. }
  401. updated_record = self.db_ops.update_process_record(process_id, update_data)
  402. if updated_record:
  403. logger.info(f"处理记录 {process_id} 审核通过")
  404. return True
  405. else:
  406. logger.error(f"更新处理记录状态失败: {process_id}")
  407. return False
  408. except Exception as e:
  409. logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
  410. return False
  411. def delete_result_image(self, process_id: int) -> bool:
  412. """
  413. 删除处理记录的结果图片:
  414. - 清空 ProcessRecord.result_image_id 与 generated_text 中与图片相关的内容保持不变
  415. - 可选地将状态置为“待审核”或维持原状态;此处不更改状态
  416. - 不删除底层图片文件与 ImageRecord,仅解除关联,避免误删数据
  417. """
  418. try:
  419. record = self.db_ops.get_process_record_by_id(process_id)
  420. if not record:
  421. logger.error(f"处理记录不存在: {process_id}")
  422. return False
  423. update_data = {
  424. "result_image_id": None
  425. }
  426. updated = self.db_ops.update_process_record(process_id, update_data)
  427. if not updated:
  428. logger.error(f"更新处理记录失败: {process_id}")
  429. return False
  430. logger.info(f"处理记录 {process_id} 已解除结果图片关联")
  431. return True
  432. except Exception as e:
  433. logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
  434. return False
  435. # 创建全局服务实例
  436. ai_swap_bg_service = AISwapBgService()
  437. def process_swap_bg_with_record(
  438. user_id: int,
  439. image_id: int,
  440. prompt: str,
  441. **kwargs
  442. ) -> Dict[str, Any]:
  443. """
  444. 便捷函数:执行完整的换背景流程
  445. Args:
  446. user_id: 用户ID
  447. image_id: 图片ID
  448. prompt: 提示词
  449. **kwargs: 其他可选参数
  450. Returns:
  451. Dict: 包含处理结果的字典
  452. """
  453. return ai_swap_bg_service.process_swap_bg_with_record(
  454. user_id, image_id, prompt, **kwargs
  455. )
  456. if __name__ == "__main__":
  457. try:
  458. # 假设用户ID为1,图片ID为1
  459. result = process_swap_bg_with_record(
  460. user_id=4,
  461. image_id=187,
  462. prompt="美女站在海边"
  463. )
  464. if result["success"]:
  465. print("处理成功!")
  466. print(f"处理记录ID: {result['process_record_id']}")
  467. print(f"结果图片ID: {result['result_image_id']}")
  468. print(f"文案: {result['copywriter_text']}")
  469. else:
  470. print(f"处理失败: {result['error']}")
  471. except Exception as e:
  472. print(f"测试失败: {str(e)}")