ai_swap.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import json
  2. import random
  3. import websocket
  4. from PIL import Image
  5. from datetime import datetime
  6. from backend.modules.comfyui.image_tools import upload_image, show_gif
  7. from backend.modules.comfyui.network_tools import get_images
  8. from backend.utils.logger_config import setup_logger
  9. from backend.utils.system_config import Config
  10. from backend.config.prompt_config import FACE_PROMPT, CLOTH_PROMPT, PROMPT_PROMPT
  11. system_config = Config('./backend/config/workflow_config.json')
  12. logger = setup_logger(__name__)
  13. def parse_workflow(prompt, face_img, cloth_img, config, face_prompt, cloth_prompt, prompt_prompt) -> dict:
  14. try:
  15. # 上传图片
  16. logger.info("开始上传人脸图片...")
  17. input_face_img = upload_image(face_img, config)
  18. logger.info("人脸图片上传成功")
  19. logger.info("开始上传服装图片...")
  20. input_cloth_img = upload_image(cloth_img, config)
  21. logger.info("服装图片上传成功")
  22. logger.info(f'Work flow args config: workflow-{config.workflowfile}, prompt-{prompt}, seed-{config.seed}')
  23. # 加载工作流配置
  24. try:
  25. with open(config.workflowfile, "r", encoding="utf-8") as workflow:
  26. workflow_config = json.load(workflow)
  27. except FileNotFoundError:
  28. logger.error(f"工作流配置文件不存在: {config.workflowfile}")
  29. raise RuntimeError(f"工作流配置文件不存在: {config.workflowfile}")
  30. except json.JSONDecodeError as e:
  31. logger.error(f"工作流配置文件格式错误: {str(e)}")
  32. raise RuntimeError(f"工作流配置文件格式错误: {str(e)}")
  33. # 更新工作流参数(旧版)
  34. # workflow_config["268"]["inputs"]["text"] = prompt
  35. # workflow_config["271"]["inputs"]["image"] = input_face_img
  36. # workflow_config["265"]["inputs"]["image"] = input_cloth_img
  37. # workflow_config["267"]["inputs"]["seed"] = config.seed
  38. # 更新工作流参数(1014版)
  39. workflow_config["251"]["inputs"]["prompt"] = prompt
  40. workflow_config["248"]["inputs"]["image"] = input_face_img
  41. workflow_config["249"]["inputs"]["image"] = input_cloth_img
  42. workflow_config["251"]["inputs"]["seed"] = config.seed
  43. # 更新工作流参数(1028版)
  44. # workflow_config["253"]["inputs"]["prompt"] = prompt
  45. # workflow_config["248"]["inputs"]["image"] = input_face_img
  46. # workflow_config["249"]["inputs"]["image"] = input_cloth_img
  47. # TODO: 开放人脸识别、服装识别、提示词优化 三类系统提示词
  48. logger.info(f"system_prompt unable: {face_prompt}, {cloth_prompt}, {prompt_prompt}")
  49. # workflow_config["217"]["inputs"]["system_prompt"] = face_prompt
  50. # workflow_config["174"]["inputs"]["system_prompt"] = cloth_prompt
  51. # workflow_config["231"]["inputs"]["system_prompt"] = prompt_prompt
  52. logger.info("开始执行工作流...")
  53. return get_images(workflow_config, config)
  54. except Exception as e:
  55. logger.error(f"工作流解析失败: {str(e)}")
  56. raise RuntimeError(f"工作流解析失败: {str(e)}")
  57. def ai_swap_process(prompt, face_img, cloth_img, face_prompt=FACE_PROMPT, cloth_prompt=CLOTH_PROMPT, prompt_prompt=PROMPT_PROMPT.format(history="暂无历史记录")) -> list:
  58. ws = websocket.WebSocket()
  59. host_url = f"ws://{system_config.server_address}/ws?clientID={system_config.client_id}"
  60. try:
  61. ws.connect(host_url)
  62. logger.info(f'Request to: {host_url}')
  63. except Exception as e:
  64. logger.error(f"无法连接到ComfyUI服务器: {host_url}, 错误: {str(e)}")
  65. raise RuntimeError(f"ComfyUI服务器连接失败: {str(e)}")
  66. # 从kwargs中提取参数
  67. try:
  68. images, history_prompt = parse_workflow(prompt, face_img, cloth_img, system_config, face_prompt, cloth_prompt, prompt_prompt)
  69. logger.info(f"工作流解析完成,获取到 {len(images)} 个节点输出")
  70. except Exception as e:
  71. logger.error(f"工作流解析失败: {str(e)}")
  72. raise RuntimeError(f"工作流解析失败: {str(e)}")
  73. images_cc = []
  74. # 检查是否有图片输出
  75. if not images:
  76. logger.error("没有获取到任何图片输出,请检查ComfyUI服务器状态和工作流配置")
  77. raise RuntimeError("AI处理未返回任何结果,请检查服务器状态")
  78. for node_id in images:
  79. logger.info(f"处理节点 {node_id},包含 {len(images[node_id])} 张图片")
  80. for image_data in images[node_id]:
  81. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  82. GIF_LOCATION = f"{system_config.output_dir}/{system_config.idx}_{system_config.seed}_{timestamp}.png"
  83. try:
  84. with open(GIF_LOCATION, "wb") as binary_file:
  85. binary_file.write(image_data)
  86. show_gif(GIF_LOCATION)
  87. system_config.idx += 1
  88. images_cc.append(Image.open(GIF_LOCATION))
  89. logger.info(f"图片保存成功: {GIF_LOCATION}")
  90. except Exception as e:
  91. logger.error(f"图片保存失败: {str(e)}")
  92. raise RuntimeError(f"图片保存失败: {str(e)}")
  93. logger.info(f'Prompt queue finished! The return result: {len(images_cc)} images')
  94. # 检查是否有处理结果
  95. if not images_cc:
  96. logger.error("没有生成任何图片结果")
  97. raise RuntimeError("AI处理未生成任何图片结果,请检查输入参数和服务器状态")
  98. return images_cc[0], history_prompt
  99. if __name__ == "__main__":
  100. import numpy as np
  101. face_img = np.array(Image.open('backend/data/face.png'))
  102. cloth_img = np.array(Image.open('backend/data/cloth.png'))
  103. ai_swap_process('美女站在海边', face_img, cloth_img)