ai_swap_face_service.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import os
  2. import uuid
  3. import random
  4. import hashlib
  5. import numpy as np
  6. from PIL import Image
  7. from datetime import datetime
  8. from typing import Dict, Any, Optional, Tuple
  9. from backend.modules.comfyui.ai_swap_face import ai_swap_face_process
  10. from backend.modules.comfyui.ai_copywriter import gen_copywriter
  11. from backend.modules.database.operations import DatabaseOperations
  12. from backend.modules.database.models import ProcessRecord, ImageRecord, User
  13. from backend.modules.database.connection import DatabaseConnection
  14. from backend.modules.database.models import SystemConfig
  15. from backend.utils.logger_config import setup_logger
  16. from backend.utils.system_config import Config
  17. from backend.services.task_queue_service import get_task_queue_service, TaskStatus
  18. logger = setup_logger(__name__)
  19. class AISwapFaceService:
  20. """
  21. AI换脸业务逻辑服务
  22. 负责协调AI处理、文案生成、数据库操作等完整业务流程
  23. """
  24. def __init__(self, db_operations: Optional[DatabaseOperations] = None):
  25. """
  26. 初始化服务
  27. Args:
  28. db_operations: 数据库操作对象,如果为None则创建默认实例
  29. """
  30. self.db_ops = db_operations or DatabaseOperations()
  31. self.system_config = Config('./backend/config/ai_swap_face_config.json')
  32. self.task_queue = get_task_queue_service()
  33. # 确保输出目录存在
  34. os.makedirs(self.system_config.output_dir, exist_ok=True)
  35. def submit_swap_face_task(
  36. self,
  37. user_id: int,
  38. raw_image_id: int,
  39. face_image_id: int,
  40. quantity: int = 1
  41. ) -> str:
  42. """
  43. 提交换脸任务到队列(非阻塞)
  44. Args:
  45. user_id: 用户ID
  46. raw_image_id: 原始图片ID
  47. face_image_id: 人脸图片ID
  48. quantity: 生成数量(默认1)
  49. Returns:
  50. str: 任务ID
  51. """
  52. logger.info(f"提交用户 {user_id} 的换脸任务到队列")
  53. # 提交任务到队列
  54. task_id = self.task_queue.submit_task(
  55. task_func=self._execute_swap_face_task,
  56. task_args=(user_id, raw_image_id, face_image_id),
  57. task_kwargs={
  58. "quantity": max(1, int(quantity))
  59. }
  60. )
  61. logger.info(f"任务已提交到队列,任务ID: {task_id}")
  62. return task_id
  63. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  64. """
  65. 获取任务状态
  66. Args:
  67. task_id: 任务ID
  68. Returns:
  69. Dict: 任务状态信息
  70. """
  71. return self.task_queue.get_task_status(task_id)
  72. def get_user_tasks(self, user_id: int) -> list:
  73. """
  74. 获取用户的所有任务
  75. Args:
  76. user_id: 用户ID
  77. Returns:
  78. list: 任务列表
  79. """
  80. return self.task_queue.get_all_tasks(user_id)
  81. def cancel_task(self, task_id: str) -> bool:
  82. """
  83. 取消任务
  84. Args:
  85. task_id: 任务ID
  86. Returns:
  87. bool: 是否成功取消
  88. """
  89. return self.task_queue.cancel_task(task_id)
  90. def _execute_swap_face_task(
  91. self,
  92. user_id: int,
  93. raw_image_id: int,
  94. face_image_id: int,
  95. quantity: int = 1
  96. ) -> Dict[str, Any]:
  97. """
  98. 执行AI换脸任务(在线程池中运行)
  99. Args:
  100. user_id: 用户ID
  101. raw_image_id: 原始图片ID
  102. face_image_id: 人脸图片ID
  103. quantity: 生成数量(默认1)
  104. Returns:
  105. Dict: 包含处理结果的字典
  106. """
  107. try:
  108. logger.info(f"开始执行用户 {user_id} 的换脸任务")
  109. # 1. 输入验证
  110. self._validate_inputs(user_id, raw_image_id, face_image_id)
  111. # 2. 获取输入图片
  112. raw_image, face_image = self._get_input_images(raw_image_id, face_image_id)
  113. # 3. 循环生成 quantity 次
  114. total_count = max(1, int(quantity))
  115. process_record_ids = []
  116. result_image_ids = []
  117. copywriter_texts = []
  118. history_prompt_last = None
  119. logger.info(f"开始执行用户 {user_id} 的换脸任务,共生成 {total_count} 张")
  120. # 为了提升多张图的差异性,循环内随机化种子(若底层支持)
  121. try:
  122. from backend.modules.comfyui import ai_swap_face as comfy_ai_swap_face_module
  123. except Exception:
  124. comfy_ai_swap_face_module = None
  125. for _ in range(total_count):
  126. if comfy_ai_swap_face_module is not None:
  127. try:
  128. comfy_ai_swap_face_module.system_config.seed = random.randint(1, 10_000_000)
  129. except Exception:
  130. pass
  131. # 执行AI换脸
  132. result_image, history_prompt = self._process_ai_swap_face(
  133. raw_image, face_image
  134. )
  135. history_prompt_last = history_prompt
  136. # 生成文案描述
  137. copywriter_text = self._generate_copywriter(result_image)
  138. copywriter_texts.append(copywriter_text)
  139. # 保存结果图片
  140. result_image_record = self._save_result_image(user_id, result_image)
  141. result_image_ids.append(result_image_record["id"])
  142. # 创建处理记录
  143. process_record = self._create_process_record(
  144. user_id, raw_image_id, face_image_id,
  145. result_image_record["id"], copywriter_text, prompt="this is the prompt of swap face task"
  146. )
  147. process_record_ids.append(process_record["id"])
  148. logger.info(
  149. f"用户 {user_id} 的换脸任务完成,共生成 {len(process_record_ids)} 张,首个记录ID: {process_record_ids[0]}"
  150. )
  151. # 为兼容旧前端,保留单值字段,同时返回批量字段
  152. return {
  153. "success": True,
  154. "count": len(process_record_ids),
  155. "process_record_id": process_record_ids[0],
  156. "result_image_id": result_image_ids[0],
  157. "copywriter_text": copywriter_texts[0] if copywriter_texts else None,
  158. "history_prompt": history_prompt_last,
  159. "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
  160. "process_record_ids": process_record_ids,
  161. "result_image_ids": result_image_ids,
  162. "copywriter_texts": copywriter_texts,
  163. }
  164. except Exception as e:
  165. logger.error(f"换脸任务执行失败: {str(e)}")
  166. return {
  167. "success": False,
  168. "error": str(e),
  169. "error_type": type(e).__name__
  170. }
  171. def process_swap_face_with_record(
  172. self,
  173. user_id: int,
  174. raw_image_id: int,
  175. face_image_id: int,
  176. **kwargs
  177. ) -> Dict[str, Any]:
  178. """
  179. 完整的AI换脸业务流程(同步版本,保持向后兼容)
  180. Args:
  181. """
  182. return self._execute_swap_face_task(
  183. user_id, raw_image_id, face_image_id, **kwargs
  184. )
  185. def _validate_inputs(self, user_id: int, raw_image_id: int, face_image_id: int):
  186. """
  187. 验证输入参数
  188. Args:
  189. user_id: 用户ID
  190. raw_image_id: 原始图片ID
  191. face_image_id: 人脸图片ID
  192. """
  193. # 验证用户是否存在
  194. user = self.db_ops.get_user_by_id(user_id)
  195. if not user:
  196. raise ValueError(f"用户ID {user_id} 不存在")
  197. if not user.get("is_active", False):
  198. raise ValueError(f"用户ID {user_id} 已被禁用")
  199. # 验证原始图片
  200. raw_image = self.db_ops.get_image_record_by_id(raw_image_id)
  201. if not raw_image:
  202. raise ValueError(f"原始图片ID {raw_image_id} 不存在")
  203. if raw_image["image_type"] != "original":
  204. raise ValueError(f"原始图片ID {raw_image_id} 不是原始图片类型")
  205. if raw_image["user_id"] != user_id:
  206. raise ValueError(f"原始图片ID {raw_image_id} 不属于用户 {user_id}")
  207. # 验证人脸图片
  208. face_image = self.db_ops.get_image_record_by_id(face_image_id)
  209. if not face_image:
  210. raise ValueError(f"人脸图片ID {face_image_id} 不存在")
  211. if face_image["image_type"] != "face":
  212. raise ValueError(f"人脸图片ID {face_image_id} 不是人脸图片类型")
  213. if face_image["user_id"] != user_id:
  214. raise ValueError(f"人脸图片ID {face_image_id} 不属于用户 {user_id}")
  215. logger.info(f"输入验证通过: 用户={user_id}, 原始图片={raw_image_id}, 人脸图片={face_image_id}")
  216. def _get_input_images(self, raw_image_id: int, face_image_id: int) -> Tuple[np.ndarray, np.ndarray]:
  217. """
  218. 获取输入图片数据
  219. Args:
  220. raw_image_id: 原始图片ID
  221. face_image_id: 人脸图片ID
  222. Returns:
  223. Tuple: (原始图片数组, 人脸图片数组)
  224. """
  225. # 获取原始图片
  226. raw_image = self.db_ops.get_image_record_by_id(raw_image_id)
  227. if not os.path.exists(raw_image["stored_path"]):
  228. raise FileNotFoundError(f"原始图片文件不存在: {raw_image['stored_path']}")
  229. raw_image = np.array(Image.open(raw_image["stored_path"]))
  230. # 获取人脸图片
  231. face_image = self.db_ops.get_image_record_by_id(face_image_id)
  232. if not os.path.exists(face_image["stored_path"]):
  233. raise FileNotFoundError(f"人脸图片文件不存在: {face_image['stored_path']}")
  234. face_image = np.array(Image.open(face_image["stored_path"]))
  235. logger.info(f"成功加载输入图片: 原始图片={raw_image.shape}, 人脸图片={face_image.shape}")
  236. return raw_image, face_image
  237. def _process_ai_swap_face(
  238. self,
  239. raw_image: np.ndarray,
  240. face_image: np.ndarray
  241. ) -> Tuple[Image.Image, str]:
  242. """
  243. 执行AI换脸处理
  244. Args:
  245. raw_image: 原始图片数组
  246. face_image: 人脸图片数组
  247. Returns:
  248. Tuple: (结果图片, 历史提示词)
  249. """
  250. logger.info(f"开始执行AI换脸处理,原始图片={raw_image.shape}, 人脸图片={face_image.shape}")
  251. try:
  252. result_image, history_prompt = ai_swap_face_process(
  253. raw_image, face_image
  254. )
  255. logger.info("AI换脸处理完成")
  256. return result_image, history_prompt
  257. except Exception as e:
  258. logger.error(f"AI换脸处理失败: {str(e)}", exc_info=True)
  259. def _generate_copywriter(self, result_image: Image.Image) -> str:
  260. """
  261. 基于结果图片生成文案描述
  262. Args:
  263. result_image: 结果图片
  264. Returns:
  265. str: 生成的文案描述
  266. """
  267. logger.info("开始生成文案描述")
  268. try:
  269. copywriter_text = gen_copywriter(result_image)
  270. logger.info("文案描述生成完成")
  271. return copywriter_text
  272. except Exception as e:
  273. logger.error(f"文案生成失败: {str(e)}", exc_info=True)
  274. # 文案生成失败不影响主流程,返回默认文案
  275. return "AI换衣服完成,效果很棒!✨"
  276. def _save_result_image(self, user_id: int, result_image: Image.Image) -> Dict[str, Any]:
  277. """
  278. 保存结果图片到数据库和文件系统
  279. Args:
  280. user_id: 用户ID
  281. result_image: 结果图片
  282. Returns:
  283. Dict: 图片记录信息
  284. """
  285. # 生成唯一文件名
  286. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  287. unique_id = str(uuid.uuid4())[:8]
  288. filename = f"result_{user_id}_{timestamp}_{unique_id}.png"
  289. # 保存到文件系统
  290. file_path = os.path.join(self.system_config.output_dir, filename)
  291. result_image.save(file_path, "PNG")
  292. # 计算文件大小和哈希值
  293. file_size = os.path.getsize(file_path)
  294. image_hash = self._calculate_image_hash(file_path)
  295. # 保存到数据库
  296. image_record = self.db_ops.create_image_record(
  297. user_id=user_id,
  298. image_type="result",
  299. original_filename=filename,
  300. stored_path=file_path,
  301. file_size=file_size,
  302. image_hash=image_hash
  303. )
  304. logger.info(f"结果图片保存成功: {file_path}, 记录ID: {image_record['id']}")
  305. return image_record
  306. def _create_process_record(
  307. self,
  308. user_id: int,
  309. raw_image_id: int,
  310. face_image_id: int,
  311. result_image_id: int,
  312. copywriter_text: str,
  313. prompt: str
  314. ) -> Dict[str, Any]:
  315. """
  316. 创建处理记录
  317. Args:
  318. user_id: 用户ID
  319. raw_image_id: 原始图片ID
  320. face_image_id: 人脸图片ID
  321. result_image_id: 结果图片ID
  322. copywriter_text: 文案描述
  323. prompt: 提示词
  324. Returns:
  325. Dict: 处理记录信息
  326. """
  327. logger.info(f"开始创建处理记录,用户ID={user_id}, 原始图片ID={raw_image_id}, 人脸图片ID={face_image_id}, 结果图片ID={result_image_id}")
  328. try:
  329. process_record = self.db_ops.create_process_record(
  330. user_id=user_id,
  331. face_image_id=face_image_id,
  332. cloth_image_id=raw_image_id,
  333. result_image_id=result_image_id,
  334. generated_text=copywriter_text,
  335. task_type="swap_face",
  336. prompt=prompt
  337. )
  338. # 更新完成时间
  339. self.db_ops.update_process_record(
  340. process_record["id"],
  341. {"completed_at": datetime.now()}
  342. )
  343. logger.info(f"处理记录创建成功: {process_record['id']}")
  344. return process_record
  345. except Exception as e:
  346. logger.error(f"创建处理记录失败: {str(e)}", exc_info=True)
  347. return None
  348. def _calculate_image_hash(self, image_path: str) -> str:
  349. """
  350. 计算图片哈希值
  351. Args:
  352. image_path: 图片路径
  353. Returns:
  354. str: MD5哈希值
  355. """
  356. with open(image_path, 'rb') as f:
  357. return hashlib.md5(f.read()).hexdigest()
  358. def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  359. """
  360. 获取用户的处理历史记录
  361. Args:
  362. user_id: 用户ID
  363. page: 页码
  364. page_size: 每页大小
  365. Returns:
  366. Dict: 分页的处理记录列表
  367. """
  368. return self.db_ops.get_user_process_records(user_id, page, page_size)
  369. def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
  370. """
  371. 获取处理记录详情
  372. Args:
  373. process_id: 处理记录ID
  374. user_id: 用户ID(可选,用于权限验证)
  375. Returns:
  376. Dict: 处理记录详情,包含关联的图片信息
  377. """
  378. process_record = self.db_ops.get_process_record_by_id(process_id)
  379. if not process_record:
  380. return None
  381. # 权限验证
  382. if user_id and process_record["user_id"] != user_id:
  383. return None
  384. # 获取关联的图片信息
  385. face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
  386. cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
  387. result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
  388. return {
  389. "process_record": process_record,
  390. "face_image": face_image,
  391. "cloth_image": cloth_image,
  392. "result_image": result_image
  393. }
  394. def approve_process_record(self, process_id: int) -> bool:
  395. """
  396. 审核通过处理记录
  397. Args:
  398. process_id: 处理记录ID
  399. Returns:
  400. bool: 操作是否成功
  401. """
  402. try:
  403. # 检查记录是否存在
  404. process_record = self.db_ops.get_process_record_by_id(process_id)
  405. if not process_record:
  406. logger.error(f"处理记录不存在: {process_id}")
  407. return False
  408. # 更新状态为已审核
  409. update_data = {
  410. "status": "已审核"
  411. }
  412. updated_record = self.db_ops.update_process_record(process_id, update_data)
  413. if updated_record:
  414. logger.info(f"处理记录 {process_id} 审核通过")
  415. return True
  416. else:
  417. logger.error(f"更新处理记录状态失败: {process_id}")
  418. return False
  419. except Exception as e:
  420. logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
  421. return False
  422. def delete_result_image(self, process_id: int) -> bool:
  423. """
  424. 删除处理记录的结果图片:
  425. - 清空 ProcessRecord.result_image_id 与 generated_text 中与图片相关的内容保持不变
  426. - 可选地将状态置为“待审核”或维持原状态;此处不更改状态
  427. - 不删除底层图片文件与 ImageRecord,仅解除关联,避免误删数据
  428. """
  429. try:
  430. record = self.db_ops.get_process_record_by_id(process_id)
  431. if not record:
  432. logger.error(f"处理记录不存在: {process_id}")
  433. return False
  434. update_data = {
  435. "result_image_id": None
  436. }
  437. updated = self.db_ops.update_process_record(process_id, update_data)
  438. if not updated:
  439. logger.error(f"更新处理记录失败: {process_id}")
  440. return False
  441. logger.info(f"处理记录 {process_id} 已解除结果图片关联")
  442. return True
  443. except Exception as e:
  444. logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
  445. return False
  446. # 创建全局服务实例
  447. ai_swap_face_service = AISwapFaceService()
  448. def process_swap_face_with_record(
  449. user_id: int,
  450. raw_image_id: int,
  451. face_image_id: int,
  452. **kwargs
  453. ) -> Dict[str, Any]:
  454. """
  455. 便捷函数:执行完整的换脸流程
  456. Args:
  457. user_id: 用户ID
  458. raw_image_id: 原始图片ID
  459. face_image_id: 人脸图片ID
  460. **kwargs: 其他可选参数
  461. Returns:
  462. Dict: 包含处理结果的字典
  463. """
  464. return ai_swap_face_service.process_swap_face_with_record(
  465. user_id, raw_image_id, face_image_id, **kwargs
  466. )
  467. if __name__ == "__main__":
  468. try:
  469. result = process_swap_face_with_record(
  470. user_id=5,
  471. raw_image_id=201,
  472. face_image_id=199
  473. )
  474. if result["success"]:
  475. print("处理成功!")
  476. print(f"处理记录ID: {result['process_record_id']}")
  477. print(f"结果图片ID: {result['result_image_id']}")
  478. print(f"文案: {result['copywriter_text']}")
  479. else:
  480. print(f"处理失败: {result['error']}")
  481. except Exception as e:
  482. print(f"发生错误: {str(e)}")
  483. import traceback
  484. traceback.print_exc()