import os import uuid import random import hashlib import numpy as np from PIL import Image from datetime import datetime from typing import Dict, Any, Optional, Tuple from backend.modules.comfyui.ai_swap_cloth import ai_swap_cloth_process from backend.modules.comfyui.ai_copywriter import gen_copywriter from backend.modules.database.operations import DatabaseOperations from backend.modules.database.models import ProcessRecord, ImageRecord, User from backend.modules.database.connection import DatabaseConnection from backend.modules.database.models import SystemConfig from backend.utils.logger_config import setup_logger from backend.utils.system_config import Config from backend.services.task_queue_service import get_task_queue_service, TaskStatus logger = setup_logger(__name__) class AISwapClothService: """ AI换衣服业务逻辑服务 负责协调AI处理、文案生成、数据库操作等完整业务流程 """ def __init__(self, db_operations: Optional[DatabaseOperations] = None): """ 初始化服务 Args: db_operations: 数据库操作对象,如果为None则创建默认实例 """ self.db_ops = db_operations or DatabaseOperations() self.system_config = Config('./backend/config/ai_swap_cloth_config.json') self.task_queue = get_task_queue_service() # 确保输出目录存在 os.makedirs(self.system_config.output_dir, exist_ok=True) def submit_swap_cloth_task( self, user_id: int, raw_image_id: int, cloth_image_id: int, quantity: int = 1 ) -> str: """ 提交换衣服任务到队列(非阻塞) Args: user_id: 用户ID raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID quantity: 生成数量(默认1) Returns: str: 任务ID """ logger.info(f"提交用户 {user_id} 的换衣服任务到队列") # 提交任务到队列 task_id = self.task_queue.submit_task( task_func=self._execute_swap_cloth_task, task_args=(user_id, raw_image_id, cloth_image_id), task_kwargs={ "quantity": max(1, int(quantity)) } ) logger.info(f"任务已提交到队列,任务ID: {task_id}") return task_id def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: """ 获取任务状态 Args: task_id: 任务ID Returns: Dict: 任务状态信息 """ return self.task_queue.get_task_status(task_id) def get_user_tasks(self, user_id: int) -> list: """ 获取用户的所有任务 Args: user_id: 用户ID Returns: list: 任务列表 """ return self.task_queue.get_all_tasks(user_id) def cancel_task(self, task_id: str) -> bool: """ 取消任务 Args: task_id: 任务ID Returns: bool: 是否成功取消 """ return self.task_queue.cancel_task(task_id) def _execute_swap_cloth_task( self, user_id: int, raw_image_id: int, cloth_image_id: int, quantity: int = 1 ) -> Dict[str, Any]: """ 执行AI换衣服任务(在线程池中运行) Args: user_id: 用户ID raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID quantity: 生成数量(默认1) Returns: Dict: 包含处理结果的字典 """ try: logger.info(f"开始执行用户 {user_id} 的换衣服任务") # 1. 输入验证 self._validate_inputs(user_id, raw_image_id, cloth_image_id) # 2. 获取输入图片 raw_image, cloth_image = self._get_input_images(raw_image_id, cloth_image_id) # 3. 循环生成 quantity 次 total_count = max(1, int(quantity)) process_record_ids = [] result_image_ids = [] copywriter_texts = [] history_prompt_last = None # 为了提升多张图的差异性,循环内随机化种子(若底层支持) try: from backend.modules.comfyui import ai_swap_cloth as comfy_ai_swap_cloth_module except Exception: comfy_ai_swap_cloth_module = None for _ in range(total_count): if comfy_ai_swap_cloth_module is not None: try: comfy_ai_swap_cloth_module.system_config.seed = random.randint(1, 10_000_000) except Exception: pass # 执行AI换衣服 result_image, history_prompt = self._process_ai_swap_cloth( raw_image, cloth_image ) history_prompt_last = history_prompt # 生成文案描述 copywriter_text = self._generate_copywriter(result_image) copywriter_texts.append(copywriter_text) # 保存结果图片 result_image_record = self._save_result_image(user_id, result_image) result_image_ids.append(result_image_record["id"]) # 创建处理记录 process_record = self._create_process_record( user_id, raw_image_id, cloth_image_id, result_image_record["id"], copywriter_text, prompt="this is the prompt of swap cloth task" ) process_record_ids.append(process_record["id"]) logger.info( f"用户 {user_id} 的换衣服任务完成,共生成 {len(process_record_ids)} 张,首个记录ID: {process_record_ids[0]}" ) # 为兼容旧前端,保留单值字段,同时返回批量字段 return { "success": True, "count": len(process_record_ids), "process_record_id": process_record_ids[0], "result_image_id": result_image_ids[0], "copywriter_text": copywriter_texts[0] if copywriter_texts else None, "history_prompt": history_prompt_last, "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]), "process_record_ids": process_record_ids, "result_image_ids": result_image_ids, "copywriter_texts": copywriter_texts, } except Exception as e: logger.error(f"换衣服任务执行失败: {str(e)}") return { "success": False, "error": str(e), "error_type": type(e).__name__ } def process_swap_cloth_with_record( self, user_id: int, raw_image_id: int, cloth_image_id: int, **kwargs ) -> Dict[str, Any]: """ 完整的AI换衣服业务流程(同步版本,保持向后兼容) Args: user_id: 用户ID raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID **kwargs: 其他可选参数 Returns: Dict: 包含处理结果的字典 """ return self._execute_swap_cloth_task( user_id, raw_image_id, cloth_image_id, **kwargs ) def _validate_inputs(self, user_id: int, raw_image_id: int, cloth_image_id: int): """ 验证输入参数 Args: user_id: 用户ID raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID """ # 验证用户是否存在 user = self.db_ops.get_user_by_id(user_id) if not user: raise ValueError(f"用户ID {user_id} 不存在") if not user.get("is_active", False): raise ValueError(f"用户ID {user_id} 已被禁用") # 验证原始图片 raw_image = self.db_ops.get_image_record_by_id(raw_image_id) if not raw_image: raise ValueError(f"原始图片ID {raw_image_id} 不存在") if raw_image["image_type"] != "original": raise ValueError(f"原始图片ID {raw_image_id} 不是原始图片类型") if raw_image["user_id"] != user_id: raise ValueError(f"原始图片ID {raw_image_id} 不属于用户 {user_id}") # 验证衣服图片 cloth_image = self.db_ops.get_image_record_by_id(cloth_image_id) if not cloth_image: raise ValueError(f"衣服图片ID {cloth_image_id} 不存在") if cloth_image["image_type"] != "cloth": raise ValueError(f"衣服图片ID {cloth_image_id} 不是衣服图片类型") if cloth_image["user_id"] != user_id: raise ValueError(f"衣服图片ID {cloth_image_id} 不属于用户 {user_id}") logger.info(f"输入验证通过: 用户={user_id}, 原始图片={raw_image_id}, 衣服图片={cloth_image_id}") def _get_input_images(self, raw_image_id: int, cloth_image_id: int) -> Tuple[np.ndarray, np.ndarray]: """ 获取输入图片数据 Args: raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID Returns: Tuple: (原始图片数组, 衣服图片数组) """ # 获取原始图片 raw_image = self.db_ops.get_image_record_by_id(raw_image_id) if not os.path.exists(raw_image["stored_path"]): raise FileNotFoundError(f"原始图片文件不存在: {raw_image['stored_path']}") raw_image = np.array(Image.open(raw_image["stored_path"])) # 获取衣服图片 cloth_record = self.db_ops.get_image_record_by_id(cloth_image_id) if not os.path.exists(cloth_record["stored_path"]): raise FileNotFoundError(f"衣服图片文件不存在: {cloth_record['stored_path']}") cloth_image = np.array(Image.open(cloth_record["stored_path"])) logger.info(f"成功加载输入图片: 原始图片={raw_image.shape}, 衣服图片={cloth_image.shape}") return raw_image, cloth_image def _process_ai_swap_cloth( self, raw_image: np.ndarray, cloth_image: np.ndarray ) -> Tuple[Image.Image, str]: """ 执行AI换衣服处理 Args: raw_image: 原始图片数组 cloth_image: 衣服图片数组 Returns: Tuple: (结果图片, 历史提示词) """ logger.info(f"开始执行AI换衣服处理,原始图片={raw_image.shape}, 衣服图片={cloth_image.shape}") try: result_image, history_prompt = ai_swap_cloth_process( raw_image, cloth_image ) logger.info("AI换衣服处理完成") return result_image, history_prompt except Exception as e: logger.error(f"AI换衣服处理失败: {str(e)}", exc_info=True) def _generate_copywriter(self, result_image: Image.Image) -> str: """ 基于结果图片生成文案描述 Args: result_image: 结果图片 Returns: str: 生成的文案描述 """ logger.info("开始生成文案描述") try: copywriter_text = gen_copywriter(result_image) logger.info("文案描述生成完成") return copywriter_text except Exception as e: logger.error(f"文案生成失败: {str(e)}", exc_info=True) # 文案生成失败不影响主流程,返回默认文案 return "AI换衣服完成,效果很棒!✨" def _save_result_image(self, user_id: int, result_image: Image.Image) -> Dict[str, Any]: """ 保存结果图片到数据库和文件系统 Args: user_id: 用户ID result_image: 结果图片 Returns: Dict: 图片记录信息 """ # 生成唯一文件名 timestamp = datetime.now().strftime("%Y%m%d%H%M%S") unique_id = str(uuid.uuid4())[:8] filename = f"result_{user_id}_{timestamp}_{unique_id}.png" # 保存到文件系统 file_path = os.path.join(self.system_config.output_dir, filename) result_image.save(file_path, "PNG") # 计算文件大小和哈希值 file_size = os.path.getsize(file_path) image_hash = self._calculate_image_hash(file_path) # 保存到数据库 image_record = self.db_ops.create_image_record( user_id=user_id, image_type="result", original_filename=filename, stored_path=file_path, file_size=file_size, image_hash=image_hash ) logger.info(f"结果图片保存成功: {file_path}, 记录ID: {image_record['id']}") return image_record def _create_process_record( self, user_id: int, raw_image_id: int, cloth_image_id: int, result_image_id: int, copywriter_text: str, prompt: str ) -> Dict[str, Any]: """ 创建处理记录 Args: user_id: 用户ID raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID result_image_id: 结果图片ID copywriter_text: 文案描述 prompt: 提示词 Returns: Dict: 处理记录信息 """ process_record = self.db_ops.create_process_record( user_id=user_id, face_image_id=raw_image_id, cloth_image_id=cloth_image_id, result_image_id=result_image_id, generated_text=copywriter_text, task_type="swap_cloth", prompt=prompt ) # 更新完成时间 self.db_ops.update_process_record( process_record["id"], {"completed_at": datetime.now()} ) logger.info(f"处理记录创建成功: {process_record['id']}") return process_record def _calculate_image_hash(self, image_path: str) -> str: """ 计算图片哈希值 Args: image_path: 图片路径 Returns: str: MD5哈希值 """ with open(image_path, 'rb') as f: return hashlib.md5(f.read()).hexdigest() def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """ 获取用户的处理历史记录 Args: user_id: 用户ID page: 页码 page_size: 每页大小 Returns: Dict: 分页的处理记录列表 """ return self.db_ops.get_user_process_records(user_id, page, page_size) def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]: """ 获取处理记录详情 Args: process_id: 处理记录ID user_id: 用户ID(可选,用于权限验证) Returns: Dict: 处理记录详情,包含关联的图片信息 """ process_record = self.db_ops.get_process_record_by_id(process_id) if not process_record: return None # 权限验证 if user_id and process_record["user_id"] != user_id: return None # 获取关联的图片信息 face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"]) cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"]) result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"]) return { "process_record": process_record, "face_image": face_image, "cloth_image": cloth_image, "result_image": result_image } def approve_process_record(self, process_id: int) -> bool: """ 审核通过处理记录 Args: process_id: 处理记录ID Returns: bool: 操作是否成功 """ try: # 检查记录是否存在 process_record = self.db_ops.get_process_record_by_id(process_id) if not process_record: logger.error(f"处理记录不存在: {process_id}") return False # 更新状态为已审核 update_data = { "status": "已审核" } updated_record = self.db_ops.update_process_record(process_id, update_data) if updated_record: logger.info(f"处理记录 {process_id} 审核通过") return True else: logger.error(f"更新处理记录状态失败: {process_id}") return False except Exception as e: logger.error(f"审核处理记录异常: {str(e)}", exc_info=True) return False def delete_result_image(self, process_id: int) -> bool: """ 删除处理记录的结果图片: - 清空 ProcessRecord.result_image_id 与 generated_text 中与图片相关的内容保持不变 - 可选地将状态置为“待审核”或维持原状态;此处不更改状态 - 不删除底层图片文件与 ImageRecord,仅解除关联,避免误删数据 """ try: record = self.db_ops.get_process_record_by_id(process_id) if not record: logger.error(f"处理记录不存在: {process_id}") return False update_data = { "result_image_id": None } updated = self.db_ops.update_process_record(process_id, update_data) if not updated: logger.error(f"更新处理记录失败: {process_id}") return False logger.info(f"处理记录 {process_id} 已解除结果图片关联") return True except Exception as e: logger.error(f"删除结果图片失败: {str(e)}", exc_info=True) return False # 创建全局服务实例 ai_swap_cloth_service = AISwapClothService() def process_swap_cloth_with_record( user_id: int, raw_image_id: int, cloth_image_id: int, **kwargs ) -> Dict[str, Any]: """ 便捷函数:执行完整的换衣服流程 Args: user_id: 用户ID raw_image_id: 原始图片ID cloth_image_id: 衣服图片ID **kwargs: 其他可选参数 Returns: Dict: 包含处理结果的字典 """ return ai_swap_cloth_service.process_swap_cloth_with_record( user_id, raw_image_id, cloth_image_id, **kwargs ) if __name__ == "__main__": try: # 假设用户ID为1,原始图片ID为1,衣服图片ID为1 result = process_swap_cloth_with_record( user_id=5, raw_image_id=201, cloth_image_id=404 ) if result["success"]: print("处理成功!") print(f"处理记录ID: {result['process_record_id']}") print(f"结果图片ID: {result['result_image_id']}") print(f"文案: {result['copywriter_text']}") else: print(f"处理失败: {result['error']}") except Exception as e: print(f"发生错误: {str(e)}") import traceback traceback.print_exc()