ai_swap_face.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import json
  2. import random
  3. import websocket
  4. from PIL import Image
  5. from datetime import datetime
  6. from backend.modules.comfyui.image_tools import upload_image, show_gif
  7. from backend.modules.comfyui.network_tools import get_images
  8. from backend.utils.logger_config import setup_logger
  9. from backend.utils.system_config import Config
  10. from backend.config.prompt_config import FACE_PROMPT, CLOTH_PROMPT, PROMPT_PROMPT
  11. system_config = Config('./backend/config/ai_swap_face_config.json')
  12. logger = setup_logger(__name__)
  13. def parse_workflow(raw_img, face_img, config) -> dict:
  14. try:
  15. # 上传原始图片
  16. logger.info('开始上传原始图片...')
  17. input_raw_img = upload_image(raw_img, config)
  18. logger.info('原始图片上传成功')
  19. # 上传人脸图片
  20. logger.info('开始上传人脸图片...')
  21. input_face_img = upload_image(face_img, config)
  22. logger.info('人脸图片上传成功')
  23. logger.info(f'Work flow args config: workflow-{config.workflowfile}')
  24. # 加载工作流配置
  25. try:
  26. with open(config.workflowfile, 'r', encoding='utf-8') as workflow:
  27. workflow_config = json.load(workflow)
  28. except FileNotFoundError:
  29. logger.error(f'工作流配置文件不存在: {config.workflowfile}')
  30. raise RuntimeError(f'工作流配置文件不存在: {config.workflowfile}')
  31. except json.JSONDecodeError as e:
  32. logger.error(f'工作流配置文件格式错误: {str(e)}')
  33. raise RuntimeError(f'工作流配置文件格式错误: {str(e)}')
  34. # 更新工作流参数
  35. workflow_config["247"]["inputs"]["image"] = input_raw_img
  36. workflow_config["245"]["inputs"]["image"] = input_face_img
  37. logger.info("开始执行工作流...")
  38. return get_images(workflow_config, config)
  39. except Exception as e:
  40. logger.error(f"工作流解析失败: {str(e)}")
  41. raise RuntimeError(f"工作流解析失败: {str(e)}")
  42. def ai_swap_face_process(raw_img, face_img) -> list:
  43. ws = websocket.WebSocket()
  44. host_url = f"ws://{system_config.server_address}/ws?clientID={system_config.client_id}"
  45. try:
  46. ws.connect(host_url)
  47. logger.info(f'Request to: {host_url}')
  48. except Exception as e:
  49. logger.error(f"无法连接到ComfyUI服务器:{host_url}, 错误:{str(e)}")
  50. raise RuntimeError(f"ComfyUI服务器连接失败: {str(e)}")
  51. try:
  52. images, history_prompt = parse_workflow(raw_img, face_img, system_config)
  53. logger.info(f"工作流解析完成,获取到{len(images)}个节点输出")
  54. except Exception as e:
  55. logger.error(f"工作流解析失败: {str(e)}")
  56. raise RuntimeError(f"工作流解析失败: {str(e)}")
  57. images_cc = []
  58. if not images:
  59. logger.error("没有获取到任何图片输出,请检查ComfyUI服务器状态和工作流配置")
  60. raise RuntimeError("AI处理未返回任何结果,请检查服务器状态")
  61. for node_id in images:
  62. logger.info(f"处理节点 {node_id},包含 {len(images[node_id])} 张图片")
  63. for image_data in images[node_id]:
  64. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  65. GIF_LOCATION = f"{system_config.output_dir}/{system_config.idx}_{timestamp}.png"
  66. try:
  67. with open(GIF_LOCATION, "wb") as binary_file:
  68. binary_file.write(image_data)
  69. show_gif(GIF_LOCATION)
  70. system_config.idx += 1
  71. images_cc.append(Image.open(GIF_LOCATION))
  72. logger.info(f"图片保存成功: {GIF_LOCATION}")
  73. except Exception as e:
  74. logger.error(f"图片保存失败: {str(e)}")
  75. raise RuntimeError(f"图片保存失败: {str(e)}")
  76. logger.info(f'Prompt queue finished! The return result: {len(images_cc)} images')
  77. # 检查是否有处理结果
  78. if not images_cc:
  79. logger.error("没有生成任何图片结果")
  80. raise RuntimeError("AI处理未生成任何图片结果,请检查输入参数和服务器状态")
  81. return images_cc[0], history_prompt
  82. if __name__ == "__main__":
  83. import numpy as np
  84. raw_img = np.array(Image.open('backend/data/02.jpg'))
  85. face_img = np.array(Image.open('backend/data/face.jpg'))
  86. images_cc, history_prompt = ai_swap_face_process(raw_img, face_img)
  87. print(f"历史提示词: {history_prompt}")