ai_swap_cloth_service.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import os
  2. import uuid
  3. import random
  4. import hashlib
  5. import numpy as np
  6. from PIL import Image
  7. from datetime import datetime
  8. from typing import Dict, Any, Optional, Tuple
  9. from backend.modules.comfyui.ai_swap_cloth import ai_swap_cloth_process
  10. from backend.modules.comfyui.ai_copywriter import gen_copywriter
  11. from backend.modules.database.operations import DatabaseOperations
  12. from backend.modules.database.models import ProcessRecord, ImageRecord, User
  13. from backend.modules.database.connection import DatabaseConnection
  14. from backend.modules.database.models import SystemConfig
  15. from backend.utils.logger_config import setup_logger
  16. from backend.utils.system_config import Config
  17. from backend.services.task_queue_service import get_task_queue_service, TaskStatus
  18. logger = setup_logger(__name__)
  19. class AISwapClothService:
  20. """
  21. AI换衣服业务逻辑服务
  22. 负责协调AI处理、文案生成、数据库操作等完整业务流程
  23. """
  24. def __init__(self, db_operations: Optional[DatabaseOperations] = None):
  25. """
  26. 初始化服务
  27. Args:
  28. db_operations: 数据库操作对象,如果为None则创建默认实例
  29. """
  30. self.db_ops = db_operations or DatabaseOperations()
  31. self.system_config = Config('./backend/config/ai_swap_cloth_config.json')
  32. self.task_queue = get_task_queue_service()
  33. # 确保输出目录存在
  34. os.makedirs(self.system_config.output_dir, exist_ok=True)
  35. def submit_swap_cloth_task(
  36. self,
  37. user_id: int,
  38. raw_image_id: int,
  39. cloth_image_id: int,
  40. quantity: int = 1
  41. ) -> str:
  42. """
  43. 提交换衣服任务到队列(非阻塞)
  44. Args:
  45. user_id: 用户ID
  46. raw_image_id: 原始图片ID
  47. cloth_image_id: 衣服图片ID
  48. quantity: 生成数量(默认1)
  49. Returns:
  50. str: 任务ID
  51. """
  52. logger.info(f"提交用户 {user_id} 的换衣服任务到队列")
  53. # 提交任务到队列
  54. task_id = self.task_queue.submit_task(
  55. task_func=self._execute_swap_cloth_task,
  56. task_args=(user_id, raw_image_id, cloth_image_id),
  57. task_kwargs={
  58. "quantity": max(1, int(quantity))
  59. }
  60. )
  61. logger.info(f"任务已提交到队列,任务ID: {task_id}")
  62. return task_id
  63. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  64. """
  65. 获取任务状态
  66. Args:
  67. task_id: 任务ID
  68. Returns:
  69. Dict: 任务状态信息
  70. """
  71. return self.task_queue.get_task_status(task_id)
  72. def get_user_tasks(self, user_id: int) -> list:
  73. """
  74. 获取用户的所有任务
  75. Args:
  76. user_id: 用户ID
  77. Returns:
  78. list: 任务列表
  79. """
  80. return self.task_queue.get_all_tasks(user_id)
  81. def cancel_task(self, task_id: str) -> bool:
  82. """
  83. 取消任务
  84. Args:
  85. task_id: 任务ID
  86. Returns:
  87. bool: 是否成功取消
  88. """
  89. return self.task_queue.cancel_task(task_id)
  90. def _execute_swap_cloth_task(
  91. self,
  92. user_id: int,
  93. raw_image_id: int,
  94. cloth_image_id: int,
  95. quantity: int = 1
  96. ) -> Dict[str, Any]:
  97. """
  98. 执行AI换衣服任务(在线程池中运行)
  99. Args:
  100. user_id: 用户ID
  101. raw_image_id: 原始图片ID
  102. cloth_image_id: 衣服图片ID
  103. quantity: 生成数量(默认1)
  104. Returns:
  105. Dict: 包含处理结果的字典
  106. """
  107. try:
  108. logger.info(f"开始执行用户 {user_id} 的换衣服任务")
  109. # 1. 输入验证
  110. self._validate_inputs(user_id, raw_image_id, cloth_image_id)
  111. # 2. 获取输入图片
  112. raw_image, cloth_image = self._get_input_images(raw_image_id, cloth_image_id)
  113. # 3. 循环生成 quantity 次
  114. total_count = max(1, int(quantity))
  115. process_record_ids = []
  116. result_image_ids = []
  117. copywriter_texts = []
  118. history_prompt_last = None
  119. # 为了提升多张图的差异性,循环内随机化种子(若底层支持)
  120. try:
  121. from backend.modules.comfyui import ai_swap_cloth as comfy_ai_swap_cloth_module
  122. except Exception:
  123. comfy_ai_swap_cloth_module = None
  124. for _ in range(total_count):
  125. if comfy_ai_swap_cloth_module is not None:
  126. try:
  127. comfy_ai_swap_cloth_module.system_config.seed = random.randint(1, 10_000_000)
  128. except Exception:
  129. pass
  130. # 执行AI换衣服
  131. result_image, history_prompt = self._process_ai_swap_cloth(
  132. raw_image, cloth_image
  133. )
  134. history_prompt_last = history_prompt
  135. # 生成文案描述
  136. copywriter_text = self._generate_copywriter(result_image)
  137. copywriter_texts.append(copywriter_text)
  138. # 保存结果图片
  139. result_image_record = self._save_result_image(user_id, result_image)
  140. result_image_ids.append(result_image_record["id"])
  141. # 创建处理记录
  142. process_record = self._create_process_record(
  143. user_id, raw_image_id, cloth_image_id,
  144. result_image_record["id"], copywriter_text, prompt="this is the prompt of swap cloth task"
  145. )
  146. process_record_ids.append(process_record["id"])
  147. logger.info(
  148. f"用户 {user_id} 的换衣服任务完成,共生成 {len(process_record_ids)} 张,首个记录ID: {process_record_ids[0]}"
  149. )
  150. # 为兼容旧前端,保留单值字段,同时返回批量字段
  151. return {
  152. "success": True,
  153. "count": len(process_record_ids),
  154. "process_record_id": process_record_ids[0],
  155. "result_image_id": result_image_ids[0],
  156. "copywriter_text": copywriter_texts[0] if copywriter_texts else None,
  157. "history_prompt": history_prompt_last,
  158. "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
  159. "process_record_ids": process_record_ids,
  160. "result_image_ids": result_image_ids,
  161. "copywriter_texts": copywriter_texts,
  162. }
  163. except Exception as e:
  164. logger.error(f"换衣服任务执行失败: {str(e)}")
  165. return {
  166. "success": False,
  167. "error": str(e),
  168. "error_type": type(e).__name__
  169. }
  170. def process_swap_cloth_with_record(
  171. self,
  172. user_id: int,
  173. raw_image_id: int,
  174. cloth_image_id: int,
  175. **kwargs
  176. ) -> Dict[str, Any]:
  177. """
  178. 完整的AI换衣服业务流程(同步版本,保持向后兼容)
  179. Args:
  180. user_id: 用户ID
  181. raw_image_id: 原始图片ID
  182. cloth_image_id: 衣服图片ID
  183. **kwargs: 其他可选参数
  184. Returns:
  185. Dict: 包含处理结果的字典
  186. """
  187. return self._execute_swap_cloth_task(
  188. user_id, raw_image_id, cloth_image_id, **kwargs
  189. )
  190. def _validate_inputs(self, user_id: int, raw_image_id: int, cloth_image_id: int):
  191. """
  192. 验证输入参数
  193. Args:
  194. user_id: 用户ID
  195. raw_image_id: 原始图片ID
  196. cloth_image_id: 衣服图片ID
  197. """
  198. # 验证用户是否存在
  199. user = self.db_ops.get_user_by_id(user_id)
  200. if not user:
  201. raise ValueError(f"用户ID {user_id} 不存在")
  202. if not user.get("is_active", False):
  203. raise ValueError(f"用户ID {user_id} 已被禁用")
  204. # 验证原始图片
  205. raw_image = self.db_ops.get_image_record_by_id(raw_image_id)
  206. if not raw_image:
  207. raise ValueError(f"原始图片ID {raw_image_id} 不存在")
  208. if raw_image["image_type"] != "original":
  209. raise ValueError(f"原始图片ID {raw_image_id} 不是原始图片类型")
  210. if raw_image["user_id"] != user_id:
  211. raise ValueError(f"原始图片ID {raw_image_id} 不属于用户 {user_id}")
  212. # 验证衣服图片
  213. cloth_image = self.db_ops.get_image_record_by_id(cloth_image_id)
  214. if not cloth_image:
  215. raise ValueError(f"衣服图片ID {cloth_image_id} 不存在")
  216. if cloth_image["image_type"] != "cloth":
  217. raise ValueError(f"衣服图片ID {cloth_image_id} 不是衣服图片类型")
  218. if cloth_image["user_id"] != user_id:
  219. raise ValueError(f"衣服图片ID {cloth_image_id} 不属于用户 {user_id}")
  220. logger.info(f"输入验证通过: 用户={user_id}, 原始图片={raw_image_id}, 衣服图片={cloth_image_id}")
  221. def _get_input_images(self, raw_image_id: int, cloth_image_id: int) -> Tuple[np.ndarray, np.ndarray]:
  222. """
  223. 获取输入图片数据
  224. Args:
  225. raw_image_id: 原始图片ID
  226. cloth_image_id: 衣服图片ID
  227. Returns:
  228. Tuple: (原始图片数组, 衣服图片数组)
  229. """
  230. # 获取原始图片
  231. raw_image = self.db_ops.get_image_record_by_id(raw_image_id)
  232. if not os.path.exists(raw_image["stored_path"]):
  233. raise FileNotFoundError(f"原始图片文件不存在: {raw_image['stored_path']}")
  234. raw_image = np.array(Image.open(raw_image["stored_path"]))
  235. # 获取衣服图片
  236. cloth_record = self.db_ops.get_image_record_by_id(cloth_image_id)
  237. if not os.path.exists(cloth_record["stored_path"]):
  238. raise FileNotFoundError(f"衣服图片文件不存在: {cloth_record['stored_path']}")
  239. cloth_image = np.array(Image.open(cloth_record["stored_path"]))
  240. logger.info(f"成功加载输入图片: 原始图片={raw_image.shape}, 衣服图片={cloth_image.shape}")
  241. return raw_image, cloth_image
  242. def _process_ai_swap_cloth(
  243. self,
  244. raw_image: np.ndarray,
  245. cloth_image: np.ndarray
  246. ) -> Tuple[Image.Image, str]:
  247. """
  248. 执行AI换衣服处理
  249. Args:
  250. raw_image: 原始图片数组
  251. cloth_image: 衣服图片数组
  252. Returns:
  253. Tuple: (结果图片, 历史提示词)
  254. """
  255. logger.info(f"开始执行AI换衣服处理,原始图片={raw_image.shape}, 衣服图片={cloth_image.shape}")
  256. try:
  257. result_image, history_prompt = ai_swap_cloth_process(
  258. raw_image, cloth_image
  259. )
  260. logger.info("AI换衣服处理完成")
  261. return result_image, history_prompt
  262. except Exception as e:
  263. logger.error(f"AI换衣服处理失败: {str(e)}", exc_info=True)
  264. def _generate_copywriter(self, result_image: Image.Image) -> str:
  265. """
  266. 基于结果图片生成文案描述
  267. Args:
  268. result_image: 结果图片
  269. Returns:
  270. str: 生成的文案描述
  271. """
  272. logger.info("开始生成文案描述")
  273. try:
  274. copywriter_text = gen_copywriter(result_image)
  275. logger.info("文案描述生成完成")
  276. return copywriter_text
  277. except Exception as e:
  278. logger.error(f"文案生成失败: {str(e)}", exc_info=True)
  279. # 文案生成失败不影响主流程,返回默认文案
  280. return "AI换衣服完成,效果很棒!✨"
  281. def _save_result_image(self, user_id: int, result_image: Image.Image) -> Dict[str, Any]:
  282. """
  283. 保存结果图片到数据库和文件系统
  284. Args:
  285. user_id: 用户ID
  286. result_image: 结果图片
  287. Returns:
  288. Dict: 图片记录信息
  289. """
  290. # 生成唯一文件名
  291. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  292. unique_id = str(uuid.uuid4())[:8]
  293. filename = f"result_{user_id}_{timestamp}_{unique_id}.png"
  294. # 保存到文件系统
  295. file_path = os.path.join(self.system_config.output_dir, filename)
  296. result_image.save(file_path, "PNG")
  297. # 计算文件大小和哈希值
  298. file_size = os.path.getsize(file_path)
  299. image_hash = self._calculate_image_hash(file_path)
  300. # 保存到数据库
  301. image_record = self.db_ops.create_image_record(
  302. user_id=user_id,
  303. image_type="result",
  304. original_filename=filename,
  305. stored_path=file_path,
  306. file_size=file_size,
  307. image_hash=image_hash
  308. )
  309. logger.info(f"结果图片保存成功: {file_path}, 记录ID: {image_record['id']}")
  310. return image_record
  311. def _create_process_record(
  312. self,
  313. user_id: int,
  314. raw_image_id: int,
  315. cloth_image_id: int,
  316. result_image_id: int,
  317. copywriter_text: str,
  318. prompt: str
  319. ) -> Dict[str, Any]:
  320. """
  321. 创建处理记录
  322. Args:
  323. user_id: 用户ID
  324. raw_image_id: 原始图片ID
  325. cloth_image_id: 衣服图片ID
  326. result_image_id: 结果图片ID
  327. copywriter_text: 文案描述
  328. prompt: 提示词
  329. Returns:
  330. Dict: 处理记录信息
  331. """
  332. process_record = self.db_ops.create_process_record(
  333. user_id=user_id,
  334. face_image_id=raw_image_id,
  335. cloth_image_id=cloth_image_id,
  336. result_image_id=result_image_id,
  337. generated_text=copywriter_text,
  338. task_type="swap_cloth",
  339. prompt=prompt
  340. )
  341. # 更新完成时间
  342. self.db_ops.update_process_record(
  343. process_record["id"],
  344. {"completed_at": datetime.now()}
  345. )
  346. logger.info(f"处理记录创建成功: {process_record['id']}")
  347. return process_record
  348. def _calculate_image_hash(self, image_path: str) -> str:
  349. """
  350. 计算图片哈希值
  351. Args:
  352. image_path: 图片路径
  353. Returns:
  354. str: MD5哈希值
  355. """
  356. with open(image_path, 'rb') as f:
  357. return hashlib.md5(f.read()).hexdigest()
  358. def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  359. """
  360. 获取用户的处理历史记录
  361. Args:
  362. user_id: 用户ID
  363. page: 页码
  364. page_size: 每页大小
  365. Returns:
  366. Dict: 分页的处理记录列表
  367. """
  368. return self.db_ops.get_user_process_records(user_id, page, page_size)
  369. def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
  370. """
  371. 获取处理记录详情
  372. Args:
  373. process_id: 处理记录ID
  374. user_id: 用户ID(可选,用于权限验证)
  375. Returns:
  376. Dict: 处理记录详情,包含关联的图片信息
  377. """
  378. process_record = self.db_ops.get_process_record_by_id(process_id)
  379. if not process_record:
  380. return None
  381. # 权限验证
  382. if user_id and process_record["user_id"] != user_id:
  383. return None
  384. # 获取关联的图片信息
  385. face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
  386. cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
  387. result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
  388. return {
  389. "process_record": process_record,
  390. "face_image": face_image,
  391. "cloth_image": cloth_image,
  392. "result_image": result_image
  393. }
  394. def approve_process_record(self, process_id: int) -> bool:
  395. """
  396. 审核通过处理记录
  397. Args:
  398. process_id: 处理记录ID
  399. Returns:
  400. bool: 操作是否成功
  401. """
  402. try:
  403. # 检查记录是否存在
  404. process_record = self.db_ops.get_process_record_by_id(process_id)
  405. if not process_record:
  406. logger.error(f"处理记录不存在: {process_id}")
  407. return False
  408. # 更新状态为已审核
  409. update_data = {
  410. "status": "已审核"
  411. }
  412. updated_record = self.db_ops.update_process_record(process_id, update_data)
  413. if updated_record:
  414. logger.info(f"处理记录 {process_id} 审核通过")
  415. return True
  416. else:
  417. logger.error(f"更新处理记录状态失败: {process_id}")
  418. return False
  419. except Exception as e:
  420. logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
  421. return False
  422. def delete_result_image(self, process_id: int) -> bool:
  423. """
  424. 删除处理记录的结果图片:
  425. - 清空 ProcessRecord.result_image_id 与 generated_text 中与图片相关的内容保持不变
  426. - 可选地将状态置为“待审核”或维持原状态;此处不更改状态
  427. - 不删除底层图片文件与 ImageRecord,仅解除关联,避免误删数据
  428. """
  429. try:
  430. record = self.db_ops.get_process_record_by_id(process_id)
  431. if not record:
  432. logger.error(f"处理记录不存在: {process_id}")
  433. return False
  434. update_data = {
  435. "result_image_id": None
  436. }
  437. updated = self.db_ops.update_process_record(process_id, update_data)
  438. if not updated:
  439. logger.error(f"更新处理记录失败: {process_id}")
  440. return False
  441. logger.info(f"处理记录 {process_id} 已解除结果图片关联")
  442. return True
  443. except Exception as e:
  444. logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
  445. return False
  446. # 创建全局服务实例
  447. ai_swap_cloth_service = AISwapClothService()
  448. def process_swap_cloth_with_record(
  449. user_id: int,
  450. raw_image_id: int,
  451. cloth_image_id: int,
  452. **kwargs
  453. ) -> Dict[str, Any]:
  454. """
  455. 便捷函数:执行完整的换衣服流程
  456. Args:
  457. user_id: 用户ID
  458. raw_image_id: 原始图片ID
  459. cloth_image_id: 衣服图片ID
  460. **kwargs: 其他可选参数
  461. Returns:
  462. Dict: 包含处理结果的字典
  463. """
  464. return ai_swap_cloth_service.process_swap_cloth_with_record(
  465. user_id, raw_image_id, cloth_image_id, **kwargs
  466. )
  467. if __name__ == "__main__":
  468. try:
  469. # 假设用户ID为1,原始图片ID为1,衣服图片ID为1
  470. result = process_swap_cloth_with_record(
  471. user_id=5,
  472. raw_image_id=201,
  473. cloth_image_id=404
  474. )
  475. if result["success"]:
  476. print("处理成功!")
  477. print(f"处理记录ID: {result['process_record_id']}")
  478. print(f"结果图片ID: {result['result_image_id']}")
  479. print(f"文案: {result['copywriter_text']}")
  480. else:
  481. print(f"处理失败: {result['error']}")
  482. except Exception as e:
  483. print(f"发生错误: {str(e)}")
  484. import traceback
  485. traceback.print_exc()