ai_swap_service.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. import os
  2. import hashlib
  3. import uuid
  4. from datetime import datetime
  5. from typing import Dict, Any, Optional, Tuple
  6. import random
  7. from PIL import Image
  8. import numpy as np
  9. from backend.modules.comfyui.ai_swap import ai_swap_process
  10. from backend.modules.comfyui.ai_copywriter import gen_copywriter
  11. from backend.modules.database.operations import DatabaseOperations
  12. from backend.modules.database.models import ProcessRecord, ImageRecord, User
  13. from backend.modules.database.connection import DatabaseConnection
  14. from backend.modules.database.models import SystemConfig
  15. from backend.utils.logger_config import setup_logger
  16. from backend.utils.system_config import Config
  17. from backend.services.task_queue_service import get_task_queue_service, TaskStatus
  18. logger = setup_logger(__name__)
  19. class AISwapService:
  20. """
  21. AI换脸换装业务逻辑服务
  22. 负责协调AI处理、文案生成、数据库操作等完整业务流程
  23. """
  24. def __init__(self, db_operations: Optional[DatabaseOperations] = None):
  25. """
  26. 初始化服务
  27. Args:
  28. db_operations: 数据库操作对象,如果为None则创建默认实例
  29. """
  30. self.db_ops = db_operations or DatabaseOperations()
  31. self.system_config = Config('./backend/config/workflow_config.json')
  32. self.task_queue = get_task_queue_service()
  33. # 确保输出目录存在
  34. os.makedirs(self.system_config.output_dir, exist_ok=True)
  35. def submit_swap_task(
  36. self,
  37. user_id: int,
  38. face_image_id: int,
  39. cloth_image_id: int,
  40. prompt: str,
  41. face_prompt: Optional[str] = None,
  42. cloth_prompt: Optional[str] = None,
  43. prompt_prompt: Optional[str] = None,
  44. quantity: int = 1
  45. ) -> str:
  46. """
  47. 提交AI换脸换装任务到队列(非阻塞)
  48. Args:
  49. user_id: 用户ID
  50. face_image_id: 人脸图片ID
  51. cloth_image_id: 服装图片ID
  52. prompt: 用户输入的提示词
  53. face_prompt: 人脸识别提示词(可选)
  54. cloth_prompt: 服装识别提示词(可选)
  55. prompt_prompt: 提示词优化(可选)
  56. Returns:
  57. str: 任务ID
  58. """
  59. logger.info(f"提交用户 {user_id} 的换脸换装任务到队列")
  60. # 提交任务到队列
  61. task_id = self.task_queue.submit_task(
  62. task_func=self._execute_swap_task,
  63. task_args=(user_id, face_image_id, cloth_image_id, prompt),
  64. task_kwargs={
  65. "face_prompt": face_prompt,
  66. "cloth_prompt": cloth_prompt,
  67. "prompt_prompt": prompt_prompt,
  68. "quantity": max(1, int(quantity))
  69. }
  70. )
  71. logger.info(f"任务已提交到队列,任务ID: {task_id}")
  72. return task_id
  73. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  74. """
  75. 获取任务状态
  76. Args:
  77. task_id: 任务ID
  78. Returns:
  79. Dict: 任务状态信息
  80. """
  81. return self.task_queue.get_task_status(task_id)
  82. def get_user_tasks(self, user_id: int) -> list:
  83. """
  84. 获取用户的所有任务
  85. Args:
  86. user_id: 用户ID
  87. Returns:
  88. list: 任务列表
  89. """
  90. return self.task_queue.get_all_tasks(user_id)
  91. def cancel_task(self, task_id: str) -> bool:
  92. """
  93. 取消任务
  94. Args:
  95. task_id: 任务ID
  96. Returns:
  97. bool: 是否成功取消
  98. """
  99. return self.task_queue.cancel_task(task_id)
  100. def _execute_swap_task(
  101. self,
  102. user_id: int,
  103. face_image_id: int,
  104. cloth_image_id: int,
  105. prompt: str,
  106. face_prompt: Optional[str] = None,
  107. cloth_prompt: Optional[str] = None,
  108. prompt_prompt: Optional[str] = None,
  109. quantity: int = 1
  110. ) -> Dict[str, Any]:
  111. """
  112. 执行AI换脸换装任务(在线程池中运行)
  113. Args:
  114. user_id: 用户ID
  115. face_image_id: 人脸图片ID
  116. cloth_image_id: 服装图片ID
  117. prompt: 用户输入的提示词
  118. face_prompt: 人脸识别提示词(可选)
  119. cloth_prompt: 服装识别提示词(可选)
  120. prompt_prompt: 提示词优化(可选)
  121. Returns:
  122. Dict: 包含处理结果的字典
  123. """
  124. try:
  125. logger.info(f"开始执行用户 {user_id} 的换脸换装任务")
  126. # 1. 输入验证
  127. self._validate_inputs(user_id, face_image_id, cloth_image_id, prompt)
  128. # 2. 获取输入图片
  129. face_image, cloth_image = self._get_input_images(face_image_id, cloth_image_id)
  130. # 3. 循环生成 quantity 次
  131. total_count = max(1, int(quantity))
  132. process_record_ids = []
  133. result_image_ids = []
  134. copywriter_texts = []
  135. history_prompt_last = None
  136. # 为了提升多张图的差异性,循环内随机化种子(若底层支持)
  137. try:
  138. from backend.modules.comfyui import ai_swap as comfy_ai_swap_module
  139. except Exception:
  140. comfy_ai_swap_module = None
  141. for _ in range(total_count):
  142. if comfy_ai_swap_module is not None:
  143. try:
  144. comfy_ai_swap_module.system_config.seed = random.randint(1, 10_000_000)
  145. except Exception:
  146. pass
  147. # 执行AI换脸换装
  148. result_image, history_prompt = self._process_ai_swap(
  149. prompt, face_image, cloth_image, face_prompt, cloth_prompt, prompt_prompt
  150. )
  151. history_prompt_last = history_prompt
  152. # 生成文案描述
  153. copywriter_text = self._generate_copywriter(result_image)
  154. copywriter_texts.append(copywriter_text)
  155. # 保存结果图片
  156. result_image_record = self._save_result_image(user_id, result_image)
  157. result_image_ids.append(result_image_record["id"])
  158. # 创建处理记录
  159. process_record = self._create_process_record(
  160. user_id, face_image_id, cloth_image_id,
  161. result_image_record["id"], copywriter_text, prompt
  162. )
  163. process_record_ids.append(process_record["id"])
  164. logger.info(
  165. f"用户 {user_id} 的换脸换装任务完成,共生成 {len(process_record_ids)} 张,首个记录ID: {process_record_ids[0]}"
  166. )
  167. # 为兼容旧前端,保留单值字段,同时返回批量字段
  168. return {
  169. "success": True,
  170. "count": len(process_record_ids),
  171. "process_record_id": process_record_ids[0],
  172. "result_image_id": result_image_ids[0],
  173. "copywriter_text": copywriter_texts[0] if copywriter_texts else None,
  174. "history_prompt": history_prompt_last,
  175. "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
  176. "process_record_ids": process_record_ids,
  177. "result_image_ids": result_image_ids,
  178. "copywriter_texts": copywriter_texts,
  179. }
  180. except Exception as e:
  181. logger.error(f"AI换脸换装任务执行失败: {str(e)}", exc_info=True)
  182. return {
  183. "success": False,
  184. "error": str(e),
  185. "error_type": type(e).__name__
  186. }
  187. def process_swap_with_record(
  188. self,
  189. user_id: int,
  190. face_image_id: int,
  191. cloth_image_id: int,
  192. prompt: str,
  193. face_prompt: Optional[str] = None,
  194. cloth_prompt: Optional[str] = None,
  195. prompt_prompt: Optional[str] = None
  196. ) -> Dict[str, Any]:
  197. """
  198. 完整的AI换脸换装业务流程(同步版本,保持向后兼容)
  199. Args:
  200. user_id: 用户ID
  201. face_image_id: 人脸图片ID
  202. cloth_image_id: 服装图片ID
  203. prompt: 用户输入的提示词
  204. face_prompt: 人脸识别提示词(可选)
  205. cloth_prompt: 服装识别提示词(可选)
  206. prompt_prompt: 提示词优化(可选)
  207. Returns:
  208. Dict: 包含处理结果的字典
  209. """
  210. return self._execute_swap_task(
  211. user_id, face_image_id, cloth_image_id, prompt,
  212. face_prompt, cloth_prompt, prompt_prompt
  213. )
  214. def _validate_inputs(self, user_id: int, face_image_id: int, cloth_image_id: int, prompt: str):
  215. """
  216. 验证输入参数
  217. Args:
  218. user_id: 用户ID
  219. face_image_id: 人脸图片ID
  220. cloth_image_id: 服装图片ID
  221. prompt: 提示词
  222. Raises:
  223. ValueError: 当输入验证失败时
  224. """
  225. # 验证用户是否存在
  226. user = self.db_ops.get_user_by_id(user_id)
  227. if not user:
  228. raise ValueError(f"用户ID {user_id} 不存在")
  229. if not user.get("is_active", False):
  230. raise ValueError(f"用户ID {user_id} 已被禁用")
  231. # 验证人脸图片
  232. face_image = self.db_ops.get_image_record_by_id(face_image_id)
  233. if not face_image:
  234. raise ValueError(f"人脸图片ID {face_image_id} 不存在")
  235. if face_image["image_type"] != "face":
  236. raise ValueError(f"图片ID {face_image_id} 不是人脸图片类型")
  237. if face_image["user_id"] != user_id:
  238. raise ValueError(f"人脸图片ID {face_image_id} 不属于用户 {user_id}")
  239. # 验证服装图片
  240. cloth_image = self.db_ops.get_image_record_by_id(cloth_image_id)
  241. if not cloth_image:
  242. raise ValueError(f"服装图片ID {cloth_image_id} 不存在")
  243. if cloth_image["image_type"] != "cloth":
  244. raise ValueError(f"图片ID {cloth_image_id} 不是服装图片类型")
  245. if cloth_image["user_id"] != user_id:
  246. raise ValueError(f"服装图片ID {cloth_image_id} 不属于用户 {user_id}")
  247. # 验证提示词
  248. if not prompt or not prompt.strip():
  249. raise ValueError("提示词不能为空")
  250. if len(prompt.strip()) > 500:
  251. raise ValueError("提示词长度不能超过500字符")
  252. logger.info(f"输入验证通过: 用户={user_id}, 人脸图片={face_image_id}, 服装图片={cloth_image_id}")
  253. def _get_input_images(self, face_image_id: int, cloth_image_id: int) -> Tuple[np.ndarray, np.ndarray]:
  254. """
  255. 获取输入图片数据
  256. Args:
  257. face_image_id: 人脸图片ID
  258. cloth_image_id: 服装图片ID
  259. Returns:
  260. Tuple: (人脸图片数组, 服装图片数组)
  261. """
  262. # 获取人脸图片
  263. face_record = self.db_ops.get_image_record_by_id(face_image_id)
  264. if not os.path.exists(face_record["stored_path"]):
  265. raise FileNotFoundError(f"人脸图片文件不存在: {face_record['stored_path']}")
  266. face_image = np.array(Image.open(face_record["stored_path"]))
  267. # 获取服装图片
  268. cloth_record = self.db_ops.get_image_record_by_id(cloth_image_id)
  269. if not os.path.exists(cloth_record["stored_path"]):
  270. raise FileNotFoundError(f"服装图片文件不存在: {cloth_record['stored_path']}")
  271. cloth_image = np.array(Image.open(cloth_record["stored_path"]))
  272. logger.info(f"成功加载输入图片: 人脸={face_image.shape}, 服装={cloth_image.shape}")
  273. return face_image, cloth_image
  274. def _process_ai_swap(
  275. self,
  276. prompt: str,
  277. face_img: np.ndarray,
  278. cloth_img: np.ndarray,
  279. face_prompt: Optional[str] = None,
  280. cloth_prompt: Optional[str] = None,
  281. prompt_prompt: Optional[str] = None
  282. ) -> Tuple[Image.Image, str]:
  283. """
  284. 执行AI换脸换装处理
  285. Args:
  286. prompt: 提示词
  287. face_img: 人脸图片数组
  288. cloth_img: 服装图片数组
  289. face_prompt: 人脸识别提示词
  290. cloth_prompt: 服装识别提示词
  291. prompt_prompt: 提示词优化
  292. Returns:
  293. Tuple: (结果图片, 历史提示词)
  294. """
  295. logger.info(f"开始执行AI换脸换装处理,提示词: {prompt}")
  296. try:
  297. result_image, history_prompt = ai_swap_process(
  298. prompt, face_img, cloth_img, face_prompt, cloth_prompt, prompt_prompt
  299. )
  300. logger.info("AI换脸换装处理完成")
  301. return result_image, history_prompt
  302. except Exception as e:
  303. logger.error(f"AI换脸换装处理失败: {str(e)}", exc_info=True)
  304. raise RuntimeError(f"AI换脸换装处理失败: {str(e)}")
  305. def _generate_copywriter(self, result_image: Image.Image) -> str:
  306. """
  307. 基于结果图片生成文案描述
  308. Args:
  309. result_image: 结果图片
  310. Returns:
  311. str: 生成的文案描述
  312. """
  313. logger.info("开始生成文案描述")
  314. try:
  315. copywriter_text = gen_copywriter(result_image)
  316. logger.info("文案描述生成完成")
  317. return copywriter_text
  318. except Exception as e:
  319. logger.error(f"文案生成失败: {str(e)}", exc_info=True)
  320. # 文案生成失败不影响主流程,返回默认文案
  321. return "AI换脸换装完成,效果很棒!✨"
  322. def _save_result_image(self, user_id: int, result_image: Image.Image) -> Dict[str, Any]:
  323. """
  324. 保存结果图片到数据库和文件系统
  325. Args:
  326. user_id: 用户ID
  327. result_image: 结果图片
  328. Returns:
  329. Dict: 图片记录信息
  330. """
  331. # 生成唯一文件名
  332. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  333. unique_id = str(uuid.uuid4())[:8]
  334. filename = f"result_{user_id}_{timestamp}_{unique_id}.png"
  335. # 保存到文件系统
  336. file_path = os.path.join(self.system_config.output_dir, filename)
  337. result_image.save(file_path, "PNG")
  338. # 计算文件大小和哈希值
  339. file_size = os.path.getsize(file_path)
  340. image_hash = self._calculate_image_hash(file_path)
  341. # 保存到数据库
  342. image_record = self.db_ops.create_image_record(
  343. user_id=user_id,
  344. image_type="result",
  345. original_filename=filename,
  346. stored_path=file_path,
  347. file_size=file_size,
  348. image_hash=image_hash
  349. )
  350. logger.info(f"结果图片保存成功: {file_path}, 记录ID: {image_record['id']}")
  351. return image_record
  352. def _create_process_record(
  353. self,
  354. user_id: int,
  355. face_image_id: int,
  356. cloth_image_id: int,
  357. result_image_id: int,
  358. copywriter_text: str,
  359. prompt: str
  360. ) -> Dict[str, Any]:
  361. """
  362. 创建处理记录
  363. Args:
  364. user_id: 用户ID
  365. face_image_id: 人脸图片ID
  366. cloth_image_id: 服装图片ID
  367. result_image_id: 结果图片ID
  368. copywriter_text: 文案描述
  369. prompt: 提示词
  370. Returns:
  371. Dict: 处理记录信息
  372. """
  373. process_record = self.db_ops.create_process_record(
  374. user_id=user_id,
  375. face_image_id=face_image_id,
  376. cloth_image_id=cloth_image_id,
  377. result_image_id=result_image_id,
  378. generated_text=copywriter_text,
  379. task_type="swap_all",
  380. prompt=prompt
  381. )
  382. # 更新完成时间
  383. self.db_ops.update_process_record(
  384. process_record["id"],
  385. {"completed_at": datetime.now()}
  386. )
  387. logger.info(f"处理记录创建成功: {process_record['id']}")
  388. return process_record
  389. def _calculate_image_hash(self, image_path: str) -> str:
  390. """
  391. 计算图片哈希值
  392. Args:
  393. image_path: 图片路径
  394. Returns:
  395. str: MD5哈希值
  396. """
  397. with open(image_path, 'rb') as f:
  398. return hashlib.md5(f.read()).hexdigest()
  399. def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  400. """
  401. 获取用户的处理历史记录
  402. Args:
  403. user_id: 用户ID
  404. page: 页码
  405. page_size: 每页大小
  406. Returns:
  407. Dict: 分页的处理记录列表
  408. """
  409. return self.db_ops.get_user_process_records(user_id, page, page_size)
  410. def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
  411. """
  412. 获取处理记录详情
  413. Args:
  414. process_id: 处理记录ID
  415. user_id: 用户ID(可选,用于权限验证)
  416. Returns:
  417. Dict: 处理记录详情,包含关联的图片信息
  418. """
  419. process_record = self.db_ops.get_process_record_by_id(process_id)
  420. if not process_record:
  421. return None
  422. # 权限验证
  423. if user_id and process_record["user_id"] != user_id:
  424. return None
  425. # 获取关联的图片信息
  426. face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
  427. cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
  428. result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
  429. return {
  430. "process_record": process_record,
  431. "face_image": face_image,
  432. "cloth_image": cloth_image,
  433. "result_image": result_image
  434. }
  435. def approve_process_record(self, process_id: int) -> bool:
  436. """
  437. 审核通过处理记录
  438. Args:
  439. process_id: 处理记录ID
  440. Returns:
  441. bool: 操作是否成功
  442. """
  443. try:
  444. # 检查记录是否存在
  445. process_record = self.db_ops.get_process_record_by_id(process_id)
  446. if not process_record:
  447. logger.error(f"处理记录不存在: {process_id}")
  448. return False
  449. # 更新状态为已审核
  450. update_data = {
  451. "status": "已审核"
  452. }
  453. updated_record = self.db_ops.update_process_record(process_id, update_data)
  454. if updated_record:
  455. logger.info(f"处理记录 {process_id} 审核通过")
  456. return True
  457. else:
  458. logger.error(f"更新处理记录状态失败: {process_id}")
  459. return False
  460. except Exception as e:
  461. logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
  462. return False
  463. def delete_result_image(self, process_id: int) -> bool:
  464. """
  465. 删除处理记录的结果图片:
  466. - 清空 ProcessRecord.result_image_id 与 generated_text 中与图片相关的内容保持不变
  467. - 可选地将状态置为“待审核”或维持原状态;此处不更改状态
  468. - 不删除底层图片文件与 ImageRecord,仅解除关联,避免误删数据
  469. """
  470. try:
  471. record = self.db_ops.get_process_record_by_id(process_id)
  472. if not record:
  473. logger.error(f"处理记录不存在: {process_id}")
  474. return False
  475. update_data = {
  476. "result_image_id": None
  477. }
  478. updated = self.db_ops.update_process_record(process_id, update_data)
  479. if not updated:
  480. logger.error(f"更新处理记录失败: {process_id}")
  481. return False
  482. logger.info(f"处理记录 {process_id} 已解除结果图片关联")
  483. return True
  484. except Exception as e:
  485. logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
  486. return False
  487. # 创建全局服务实例
  488. ai_swap_service = AISwapService()
  489. def process_swap_with_record(
  490. user_id: int,
  491. face_image_id: int,
  492. cloth_image_id: int,
  493. prompt: str,
  494. **kwargs
  495. ) -> Dict[str, Any]:
  496. """
  497. 便捷函数:执行完整的换脸换装流程
  498. Args:
  499. user_id: 用户ID
  500. face_image_id: 人脸图片ID
  501. cloth_image_id: 服装图片ID
  502. prompt: 提示词
  503. **kwargs: 其他参数
  504. Returns:
  505. Dict: 处理结果
  506. """
  507. return ai_swap_service.process_swap_with_record(
  508. user_id, face_image_id, cloth_image_id, prompt, **kwargs
  509. )
  510. if __name__ == "__main__":
  511. # 测试代码
  512. try:
  513. # 假设用户ID为1,图片ID为1和2
  514. result = process_swap_with_record(
  515. user_id=4,
  516. face_image_id=191,
  517. cloth_image_id=190,
  518. prompt="美女在沙漠绿洲行走"
  519. )
  520. if result["success"]:
  521. print("处理成功!")
  522. print(f"处理记录ID: {result['process_record_id']}")
  523. print(f"结果图片ID: {result['result_image_id']}")
  524. print(f"文案: {result['copywriter_text']}")
  525. else:
  526. print(f"处理失败: {result['error']}")
  527. except Exception as e:
  528. print(f"测试失败: {str(e)}")