| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592 |
- import os
- import uuid
- import random
- import hashlib
- import numpy as np
- from PIL import Image
- from datetime import datetime
- from typing import Dict, Any, Optional, Tuple
- from backend.modules.comfyui.ai_swap_cloth import ai_swap_cloth_process
- from backend.modules.comfyui.ai_copywriter import gen_copywriter
- from backend.modules.database.operations import DatabaseOperations
- from backend.modules.database.models import ProcessRecord, ImageRecord, User
- from backend.modules.database.connection import DatabaseConnection
- from backend.modules.database.models import SystemConfig
- from backend.utils.logger_config import setup_logger
- from backend.utils.system_config import Config
- from backend.services.task_queue_service import get_task_queue_service, TaskStatus
- logger = setup_logger(__name__)
- class AISwapClothService:
- """
- AI换衣服业务逻辑服务
- 负责协调AI处理、文案生成、数据库操作等完整业务流程
- """
- def __init__(self, db_operations: Optional[DatabaseOperations] = None):
- """
- 初始化服务
- Args:
- db_operations: 数据库操作对象,如果为None则创建默认实例
- """
- self.db_ops = db_operations or DatabaseOperations()
- self.system_config = Config('./backend/config/ai_swap_cloth_config.json')
- self.task_queue = get_task_queue_service()
- # 确保输出目录存在
- os.makedirs(self.system_config.output_dir, exist_ok=True)
- def submit_swap_cloth_task(
- self,
- user_id: int,
- raw_image_id: int,
- cloth_image_id: int,
- quantity: int = 1
- ) -> str:
- """
- 提交换衣服任务到队列(非阻塞)
- Args:
- user_id: 用户ID
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- quantity: 生成数量(默认1)
- Returns:
- str: 任务ID
- """
- logger.info(f"提交用户 {user_id} 的换衣服任务到队列")
- # 提交任务到队列
- task_id = self.task_queue.submit_task(
- task_func=self._execute_swap_cloth_task,
- task_args=(user_id, raw_image_id, cloth_image_id),
- task_kwargs={
- "quantity": max(1, int(quantity))
- }
- )
- logger.info(f"任务已提交到队列,任务ID: {task_id}")
- return task_id
- def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
- """
- 获取任务状态
-
- Args:
- task_id: 任务ID
-
- Returns:
- Dict: 任务状态信息
- """
- return self.task_queue.get_task_status(task_id)
-
- def get_user_tasks(self, user_id: int) -> list:
- """
- 获取用户的所有任务
-
- Args:
- user_id: 用户ID
-
- Returns:
- list: 任务列表
- """
- return self.task_queue.get_all_tasks(user_id)
-
- def cancel_task(self, task_id: str) -> bool:
- """
- 取消任务
-
- Args:
- task_id: 任务ID
-
- Returns:
- bool: 是否成功取消
- """
- return self.task_queue.cancel_task(task_id)
- def _execute_swap_cloth_task(
- self,
- user_id: int,
- raw_image_id: int,
- cloth_image_id: int,
- quantity: int = 1
- ) -> Dict[str, Any]:
- """
- 执行AI换衣服任务(在线程池中运行)
- Args:
- user_id: 用户ID
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- quantity: 生成数量(默认1)
- Returns:
- Dict: 包含处理结果的字典
- """
- try:
- logger.info(f"开始执行用户 {user_id} 的换衣服任务")
- # 1. 输入验证
- self._validate_inputs(user_id, raw_image_id, cloth_image_id)
- # 2. 获取输入图片
- raw_image, cloth_image = self._get_input_images(raw_image_id, cloth_image_id)
- # 3. 循环生成 quantity 次
- total_count = max(1, int(quantity))
- process_record_ids = []
- result_image_ids = []
- copywriter_texts = []
- history_prompt_last = None
- # 为了提升多张图的差异性,循环内随机化种子(若底层支持)
- try:
- from backend.modules.comfyui import ai_swap_cloth as comfy_ai_swap_cloth_module
- except Exception:
- comfy_ai_swap_cloth_module = None
- for _ in range(total_count):
- if comfy_ai_swap_cloth_module is not None:
- try:
- comfy_ai_swap_cloth_module.system_config.seed = random.randint(1, 10_000_000)
- except Exception:
- pass
- # 执行AI换衣服
- result_image, history_prompt = self._process_ai_swap_cloth(
- raw_image, cloth_image
- )
- history_prompt_last = history_prompt
- # 生成文案描述
- copywriter_text = self._generate_copywriter(result_image)
- copywriter_texts.append(copywriter_text)
- # 保存结果图片
- result_image_record = self._save_result_image(user_id, result_image)
- result_image_ids.append(result_image_record["id"])
- # 创建处理记录
- process_record = self._create_process_record(
- user_id, raw_image_id, cloth_image_id,
- result_image_record["id"], copywriter_text, prompt="this is the prompt of swap cloth task"
- )
- process_record_ids.append(process_record["id"])
- logger.info(
- f"用户 {user_id} 的换衣服任务完成,共生成 {len(process_record_ids)} 张,首个记录ID: {process_record_ids[0]}"
- )
- # 为兼容旧前端,保留单值字段,同时返回批量字段
- return {
- "success": True,
- "count": len(process_record_ids),
- "process_record_id": process_record_ids[0],
- "result_image_id": result_image_ids[0],
- "copywriter_text": copywriter_texts[0] if copywriter_texts else None,
- "history_prompt": history_prompt_last,
- "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
- "process_record_ids": process_record_ids,
- "result_image_ids": result_image_ids,
- "copywriter_texts": copywriter_texts,
- }
-
- except Exception as e:
- logger.error(f"换衣服任务执行失败: {str(e)}")
- return {
- "success": False,
- "error": str(e),
- "error_type": type(e).__name__
- }
-
- def process_swap_cloth_with_record(
- self,
- user_id: int,
- raw_image_id: int,
- cloth_image_id: int,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 完整的AI换衣服业务流程(同步版本,保持向后兼容)
-
- Args:
- user_id: 用户ID
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- **kwargs: 其他可选参数
- Returns:
- Dict: 包含处理结果的字典
- """
- return self._execute_swap_cloth_task(
- user_id, raw_image_id, cloth_image_id, **kwargs
- )
-
- def _validate_inputs(self, user_id: int, raw_image_id: int, cloth_image_id: int):
- """
- 验证输入参数
- Args:
- user_id: 用户ID
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- """
- # 验证用户是否存在
- user = self.db_ops.get_user_by_id(user_id)
- if not user:
- raise ValueError(f"用户ID {user_id} 不存在")
-
- if not user.get("is_active", False):
- raise ValueError(f"用户ID {user_id} 已被禁用")
- # 验证原始图片
- raw_image = self.db_ops.get_image_record_by_id(raw_image_id)
- if not raw_image:
- raise ValueError(f"原始图片ID {raw_image_id} 不存在")
-
- if raw_image["image_type"] != "original":
- raise ValueError(f"原始图片ID {raw_image_id} 不是原始图片类型")
-
- if raw_image["user_id"] != user_id:
- raise ValueError(f"原始图片ID {raw_image_id} 不属于用户 {user_id}")
- # 验证衣服图片
- cloth_image = self.db_ops.get_image_record_by_id(cloth_image_id)
- if not cloth_image:
- raise ValueError(f"衣服图片ID {cloth_image_id} 不存在")
-
- if cloth_image["image_type"] != "cloth":
- raise ValueError(f"衣服图片ID {cloth_image_id} 不是衣服图片类型")
- if cloth_image["user_id"] != user_id:
- raise ValueError(f"衣服图片ID {cloth_image_id} 不属于用户 {user_id}")
- logger.info(f"输入验证通过: 用户={user_id}, 原始图片={raw_image_id}, 衣服图片={cloth_image_id}")
-
- def _get_input_images(self, raw_image_id: int, cloth_image_id: int) -> Tuple[np.ndarray, np.ndarray]:
- """
- 获取输入图片数据
-
- Args:
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- Returns:
- Tuple: (原始图片数组, 衣服图片数组)
- """
- # 获取原始图片
- raw_image = self.db_ops.get_image_record_by_id(raw_image_id)
- if not os.path.exists(raw_image["stored_path"]):
- raise FileNotFoundError(f"原始图片文件不存在: {raw_image['stored_path']}")
-
- raw_image = np.array(Image.open(raw_image["stored_path"]))
- # 获取衣服图片
- cloth_record = self.db_ops.get_image_record_by_id(cloth_image_id)
- if not os.path.exists(cloth_record["stored_path"]):
- raise FileNotFoundError(f"衣服图片文件不存在: {cloth_record['stored_path']}")
- cloth_image = np.array(Image.open(cloth_record["stored_path"]))
- logger.info(f"成功加载输入图片: 原始图片={raw_image.shape}, 衣服图片={cloth_image.shape}")
- return raw_image, cloth_image
-
- def _process_ai_swap_cloth(
- self,
- raw_image: np.ndarray,
- cloth_image: np.ndarray
- ) -> Tuple[Image.Image, str]:
- """
- 执行AI换衣服处理
- Args:
- raw_image: 原始图片数组
- cloth_image: 衣服图片数组
- Returns:
- Tuple: (结果图片, 历史提示词)
- """
- logger.info(f"开始执行AI换衣服处理,原始图片={raw_image.shape}, 衣服图片={cloth_image.shape}")
- try:
- result_image, history_prompt = ai_swap_cloth_process(
- raw_image, cloth_image
- )
- logger.info("AI换衣服处理完成")
- return result_image, history_prompt
- except Exception as e:
- logger.error(f"AI换衣服处理失败: {str(e)}", exc_info=True)
- def _generate_copywriter(self, result_image: Image.Image) -> str:
- """
- 基于结果图片生成文案描述
-
- Args:
- result_image: 结果图片
-
- Returns:
- str: 生成的文案描述
- """
- logger.info("开始生成文案描述")
- try:
- copywriter_text = gen_copywriter(result_image)
- logger.info("文案描述生成完成")
- return copywriter_text
- except Exception as e:
- logger.error(f"文案生成失败: {str(e)}", exc_info=True)
- # 文案生成失败不影响主流程,返回默认文案
- return "AI换衣服完成,效果很棒!✨"
- def _save_result_image(self, user_id: int, result_image: Image.Image) -> Dict[str, Any]:
- """
- 保存结果图片到数据库和文件系统
-
- Args:
- user_id: 用户ID
- result_image: 结果图片
-
- Returns:
- Dict: 图片记录信息
- """
- # 生成唯一文件名
- timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
- unique_id = str(uuid.uuid4())[:8]
- filename = f"result_{user_id}_{timestamp}_{unique_id}.png"
-
- # 保存到文件系统
- file_path = os.path.join(self.system_config.output_dir, filename)
- result_image.save(file_path, "PNG")
-
- # 计算文件大小和哈希值
- file_size = os.path.getsize(file_path)
- image_hash = self._calculate_image_hash(file_path)
-
- # 保存到数据库
- image_record = self.db_ops.create_image_record(
- user_id=user_id,
- image_type="result",
- original_filename=filename,
- stored_path=file_path,
- file_size=file_size,
- image_hash=image_hash
- )
-
- logger.info(f"结果图片保存成功: {file_path}, 记录ID: {image_record['id']}")
- return image_record
- def _create_process_record(
- self,
- user_id: int,
- raw_image_id: int,
- cloth_image_id: int,
- result_image_id: int,
- copywriter_text: str,
- prompt: str
- ) -> Dict[str, Any]:
- """
- 创建处理记录
-
- Args:
- user_id: 用户ID
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- result_image_id: 结果图片ID
- copywriter_text: 文案描述
- prompt: 提示词
- Returns:
- Dict: 处理记录信息
- """
- process_record = self.db_ops.create_process_record(
- user_id=user_id,
- face_image_id=raw_image_id,
- cloth_image_id=cloth_image_id,
- result_image_id=result_image_id,
- generated_text=copywriter_text,
- task_type="swap_cloth",
- prompt=prompt
- )
- # 更新完成时间
- self.db_ops.update_process_record(
- process_record["id"],
- {"completed_at": datetime.now()}
- )
- logger.info(f"处理记录创建成功: {process_record['id']}")
- return process_record
-
- def _calculate_image_hash(self, image_path: str) -> str:
- """
- 计算图片哈希值
-
- Args:
- image_path: 图片路径
-
- Returns:
- str: MD5哈希值
- """
- with open(image_path, 'rb') as f:
- return hashlib.md5(f.read()).hexdigest()
- def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
- """
- 获取用户的处理历史记录
-
- Args:
- user_id: 用户ID
- page: 页码
- page_size: 每页大小
-
- Returns:
- Dict: 分页的处理记录列表
- """
- return self.db_ops.get_user_process_records(user_id, page, page_size)
-
- def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
- """
- 获取处理记录详情
-
- Args:
- process_id: 处理记录ID
- user_id: 用户ID(可选,用于权限验证)
-
- Returns:
- Dict: 处理记录详情,包含关联的图片信息
- """
- process_record = self.db_ops.get_process_record_by_id(process_id)
- if not process_record:
- return None
-
- # 权限验证
- if user_id and process_record["user_id"] != user_id:
- return None
-
- # 获取关联的图片信息
- face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
- cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
- result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
-
- return {
- "process_record": process_record,
- "face_image": face_image,
- "cloth_image": cloth_image,
- "result_image": result_image
- }
-
- def approve_process_record(self, process_id: int) -> bool:
- """
- 审核通过处理记录
-
- Args:
- process_id: 处理记录ID
-
- Returns:
- bool: 操作是否成功
- """
- try:
- # 检查记录是否存在
- process_record = self.db_ops.get_process_record_by_id(process_id)
- if not process_record:
- logger.error(f"处理记录不存在: {process_id}")
- return False
-
- # 更新状态为已审核
- update_data = {
- "status": "已审核"
- }
-
- updated_record = self.db_ops.update_process_record(process_id, update_data)
-
- if updated_record:
- logger.info(f"处理记录 {process_id} 审核通过")
- return True
- else:
- logger.error(f"更新处理记录状态失败: {process_id}")
- return False
-
- except Exception as e:
- logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
- return False
- def delete_result_image(self, process_id: int) -> bool:
- """
- 删除处理记录的结果图片:
- - 清空 ProcessRecord.result_image_id 与 generated_text 中与图片相关的内容保持不变
- - 可选地将状态置为“待审核”或维持原状态;此处不更改状态
- - 不删除底层图片文件与 ImageRecord,仅解除关联,避免误删数据
- """
- try:
- record = self.db_ops.get_process_record_by_id(process_id)
- if not record:
- logger.error(f"处理记录不存在: {process_id}")
- return False
- update_data = {
- "result_image_id": None
- }
- updated = self.db_ops.update_process_record(process_id, update_data)
- if not updated:
- logger.error(f"更新处理记录失败: {process_id}")
- return False
- logger.info(f"处理记录 {process_id} 已解除结果图片关联")
- return True
- except Exception as e:
- logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
- return False
- # 创建全局服务实例
- ai_swap_cloth_service = AISwapClothService()
- def process_swap_cloth_with_record(
- user_id: int,
- raw_image_id: int,
- cloth_image_id: int,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 便捷函数:执行完整的换衣服流程
- Args:
- user_id: 用户ID
- raw_image_id: 原始图片ID
- cloth_image_id: 衣服图片ID
- **kwargs: 其他可选参数
- Returns:
- Dict: 包含处理结果的字典
- """
- return ai_swap_cloth_service.process_swap_cloth_with_record(
- user_id, raw_image_id, cloth_image_id, **kwargs
- )
- if __name__ == "__main__":
- try:
- # 假设用户ID为1,原始图片ID为1,衣服图片ID为1
- result = process_swap_cloth_with_record(
- user_id=5,
- raw_image_id=201,
- cloth_image_id=404
- )
-
- if result["success"]:
- print("处理成功!")
- print(f"处理记录ID: {result['process_record_id']}")
- print(f"结果图片ID: {result['result_image_id']}")
- print(f"文案: {result['copywriter_text']}")
- else:
- print(f"处理失败: {result['error']}")
-
- except Exception as e:
- print(f"发生错误: {str(e)}")
- import traceback
- traceback.print_exc()
|