import os import uuid import hashlib import fal_client from PIL import Image from datetime import datetime from typing import Dict, Any, Optional from backend.modules.fal_ai.gen_video import video_generator from backend.modules.comfyui.ai_copywriter import gen_copywriter from backend.modules.database.operations import DatabaseOperations from backend.utils.logger_config import setup_logger from backend.utils.tools import download_video from backend.utils.system_config import Config from backend.services.task_queue_service import get_task_queue_service from dotenv import load_dotenv from pathlib import Path env_path = Path("./backend") / ".env" load_dotenv(dotenv_path=env_path) logger = setup_logger(__name__) class AIGenVideoService: """ AI图生视频业务逻辑服务 负责协调AI图生视频的各个模块,包括图像处理、文本生成、数据库操作等。 """ def __init__(self, db_operations: Optional[DatabaseOperations] = None): """初始化AI图生视频业务逻辑服务 Args: db_operations (Optional[DatabaseOperations], optional): 数据库操作对象。默认为None,表示使用默认的DatabaseOperations对象。 """ self.api_key = os.getenv("FAL_KEY") if not self.api_key: logger.warning("未设置FAL_KEY环境变量,无法使用视频生成服务") if self.api_key: fal_client.fal_key = self.api_key self.db_ops = db_operations or DatabaseOperations() self.system_config = Config('./backend/config/ai_gen_video.json') self.task_queue = get_task_queue_service() def submit_gen_video_task( self, user_id: int, image_id: int, prompt: str, quantity: int = 1 ) -> str: """提交生成视频任务 Args: user_id (int): 用户ID image_id (int): 图像ID prompt (str): 提示词 quantity (int, optional): 生成视频的数量。默认为1。 Returns: str: 任务ID """ logger.info(f"提交生成视频任务:user_id={user_id}, image_id={image_id}, prompt={prompt}, quantity={quantity}") # 提交任务到队列 task_id = self.task_queue.submit_task( task_func=self._execute_gen_video_task, task_args=(user_id, image_id, prompt), task_kwargs={ "quantity": max(1, int(quantity)) } ) logger.info(f"任务提交成功:task_id={task_id}") return task_id def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: """ 获取任务状态 Args: task_id: 任务ID Returns: Dict: 任务状态信息 """ return self.task_queue.get_task_status(task_id) def get_user_tasks(self, user_id: int) -> list: """ 获取用户的所有任务 Args: user_id: 用户ID Returns: list: 任务列表 """ return self.task_queue.get_all_tasks(user_id) def cancel_task(self, task_id: str) -> bool: """ 取消任务 Args: task_id: 任务ID Returns: bool: 是否成功取消 """ return self.task_queue.cancel_task(task_id) def _execute_gen_video_task( self, user_id: int, image_id: int, prompt: str, quantity: int = 1 ) -> Dict[str, Any]: """执行生成视频任务 Args: user_id (int): 用户ID image_id (int): 图像ID prompt (str): 提示词 Returns: Dict[str, Any]: 包含任务结果的字典 """ try: logger.info(f"开始执行生成视频任务:user_id={user_id}, image_id={image_id}, prompt={prompt}, quantity={quantity}") # 1. 输入验证 self._validate_inputs(user_id, image_id, prompt) # 2. 获取输入图片 image, image_path = self._get_input_image(image_id) logger.info(f"输入图像:{image}") # 3. 循环生成quantity次 total_count = max(1, int(quantity)) process_record_ids = [] result_video_ids = [] copywriter_texts = [] for _ in range(total_count): # 执行图生视频 result_video = self._video_generator(prompt, image) logger.info(f"生成视频:{result_video}") # 生成文案描述 pil_image = Image.open(image_path) copywriter_text = self._generate_copywriter(pil_image) copywriter_texts.append(copywriter_text) # 保留结果视频 result_video_record = self._save_result_video(user_id, result_video) result_video_ids.append(result_video_record["id"]) # 创建处理记录 process_record = self._create_process_record( user_id, image_id, result_video_record["id"], copywriter_text, prompt ) process_record_ids.append(process_record["id"]) return { "success": True, "count": len(process_record_ids), "process_record_id": process_record_ids[0], "result_video_id": result_video_ids[0], "copywriter_text": copywriter_texts[0] if copywriter_text else None, "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]), "process_record_ids": process_record_ids, "result_video_ids": result_video_ids, "copywriter_texts": copywriter_texts, } except Exception as e: logger.error(f"AI图生视频任务执行失败: {str(e)}", exc_info=True) return { "success": False, "error": str(e), "error_type": type(e).__name__ } def process_gen_video_with_record( self, user_id: int, image_id: int, prompt: str, **kwargs ) -> Dict[str, Any]: """ 完整的AI图生视频业务流程(同步版本,保持向后兼容) Args: user_id: 用户ID image_id: 图片ID prompt: 用户输入的提示词 **kwargs: 其他可选参数 Returns: Dict: 包含处理结果的字典 """ return self._execute_gen_video_task( user_id, image_id, prompt, **kwargs ) def _validate_inputs(self, user_id: int, image_id: int, prompt: str): """ 验证输入参数 Args: user_id: 用户ID image_id: 图片ID prompt: 提示词 """ # 验证用户是否存在 user = self.db_ops.get_user_by_id(user_id) if not user: raise ValueError(f"用户ID {user_id} 不存在") if not user.get("is_active", False): raise ValueError(f"用户ID {user_id} 已被禁用") # 验证图片 image = self.db_ops.get_image_record_by_id(image_id) if not image: raise ValueError(f"图片ID {image_id} 不存在") if image["user_id"] != user_id: raise ValueError(f"图片ID {image_id} 不属于用户ID {user_id}") # 验证提示词 if not prompt or not prompt.strip(): raise ValueError("提示词不能为空") if len(prompt.strip()) > 500: raise ValueError("提示词长度不能超过500字符") logger.info(f"输入验证通过:用户={user_id}, 图片={image_id}") def _get_input_image(self, image_id: int) -> str: """ 获取输入图片数据 Args: image_id: 图片ID Returns: str: 图片数据 """ image_record = self.db_ops.get_image_record_by_id(image_id) if not os.path.exists(image_record["stored_path"]): raise FileNotFoundError(f"图片文件不存在: {image_record['stored_path']}") logger.info(f"配置密钥:{self.api_key}") fal_client.fal_key = self.api_key image_url = fal_client.upload_file(image_record["stored_path"]) return image_url, image_record["stored_path"] def _video_generator( self, prompt: str, image: str ) -> str: """ 执行AI图生视频处理 Args: prompt: 提示词 image: 输入图片数据 Returns: str: 结果视频 """ logger.info(f"开始执行AI图生视频处理,提示词:{prompt}") try: result_video = video_generator.process_task_sync(prompt, image) logger.info("AI图生视频处理完成") return result_video except Exception as e: logger.error(f"AI图生视频处理失败: {str(e)}", exc_info=True) raise RuntimeError(f"AI图生视频处理失败: {str(e)}") def _generate_copywriter(self, result_image: Image.Image) -> str: """ 基于结果图片生成文案描述 Args: result_image: 结果图片 Returns: str: 生成的文案描述 """ logger.info("开始生成文案描述") try: copywriter_text = gen_copywriter(result_image) logger.info("文案描述生成完成") return copywriter_text except Exception as e: logger.error(f"文案生成失败: {str(e)}", exc_info=True) # 文案生成失败不影响主流程,返回默认文案 return "AI图生视频完成,效果很棒!✨" def _save_result_video(self, user_id: int, result_video: str) -> Dict[str, Any]: """ 保存结果视频到数据库和文件系统 Args: user_id: 用户ID result_video: 结果视频 Returns: Dict: 视频记录信息 """ # 生成唯一文件名 timestamp = datetime.now().strftime("%Y%m%d%H%M%S") unique_id = str(uuid.uuid4())[:8] filename = f"result_{user_id}_{timestamp}_{unique_id}.mp4" # 保存到文件系统 file_path = os.path.join(self.system_config.output_dir, filename) download_video(result_video, file_path) # 计算文件大小和哈希值 file_size = os.path.getsize(file_path) video_hash = self._calculate_video_hash(file_path) # 保存到数据库 file_size = os.path.getsize(file_path) video_hash = self._calculate_video_hash(file_path) # 保存到数据库 video_record = self.db_ops.create_image_record( user_id=user_id, image_type="result", original_filename=filename, stored_path=file_path, file_size=file_size, image_hash=video_hash ) logger.info(f"结果视频保存成功:{file_path}, 记录ID: {video_record['id']}") return video_record def _create_process_record( self, user_id: int, image_id: int, result_image_id: int, copywriter_text: str, prompt: str ) -> Dict[str, Any]: """ 创建处理记录 Args: user_id: 用户ID image_id: 输入图片ID result_image_id: 结果图片ID copywriter_text: 文案描述 prompt: 提示词 Returns: Dict: 处理记录信息 """ process_record = self.db_ops.create_process_record( user_id=user_id, face_image_id=image_id, cloth_image_id=image_id, result_image_id=result_image_id, generated_text=copywriter_text, task_type="img2video", prompt=prompt ) # 更新完成时间 self.db_ops.update_process_record( process_record["id"], {"completed_at": datetime.now()} ) logger.info(f"处理记录创建成功: {process_record['id']}") return process_record def _calculate_video_hash(self, video_path: str) -> str: """ 计算视频哈希值 Args: video_path: 视频路径 Returns: str: MD5哈希值 """ hash_func = hashlib.new("md5") try: with open(video_path, 'rb') as f: while chunk := f.read(8192): hash_func.update(chunk) return hash_func.hexdigest() except FileNotFoundError: return f"错误:文件 {video_path} 未找到。" except Exception as e: return f"计算哈希时发生错误:{e}" def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]: """ 获取用户的处理历史记录 Args: user_id: 用户ID page: 页码 page_size: 每页大小 Returns: Dict: 分页的处理记录列表 """ return self.db_ops.get_user_process_records(user_id, page, page_size) def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]: """ 获取处理记录详情 Args: process_id: 处理记录ID user_id: 用户ID(可选,用于权限验证) Returns: Dict: 处理记录详情,包含关联的图片信息 """ process_record = self.db_ops.get_process_record_by_id(process_id) if not process_record: return None # 权限验证 if user_id and process_record["user_id"] != user_id: return None # 获取关联的图片信息 face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"]) cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"]) result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"]) return { "process_record": process_record, "face_image": face_image, "cloth_image": cloth_image, "result_image": result_image } def approve_process_record(self, process_id: int) -> bool: """ 审核通过处理记录 Args: process_id: 处理记录ID Returns: bool: 操作是否成功 """ try: # 检查记录是否存在 process_record = self.db_ops.get_process_record_by_id(process_id) if not process_record: logger.error(f"处理记录不存在: {process_id}") return False # 更新状态为已审核 update_data = { "status": "已审核" } updated_record = self.db_ops.update_process_record(process_id, update_data) if updated_record: logger.info(f"处理记录 {process_id} 审核通过") return True else: logger.error(f"更新处理记录状态失败: {process_id}") return False except Exception as e: logger.error(f"审核处理记录异常: {str(e)}", exc_info=True) return False def delete_result_image(self, process_id: int) -> bool: """删除处理记录中的结果视频""" try: record = self.db_ops.get_process_record_by_id(process_id) if not record: logger.error(f"处理记录不存在: {process_id}") return False update_data = { "result_image_id": None } updated = self.db_ops.update_process_record(process_id, update_data) if not updated: logger.error(f"更新处理记录失败: {process_id}") return False logger.info(f"处理记录 {process_id} 已解除结果图片关联") return True except Exception as e: logger.error(f"删除结果图片失败: {str(e)}", exc_info=True) return False # 创建全局服务实例 ai_gen_video_service = AIGenVideoService() def process_gen_video_with_record( user_id: int, image_id: int, prompt: str, **kwargs ) -> Dict[str, Any]: """ 便捷函数:执行完整的图生视频流程 Args: user_id: 用户ID image_id: 图片ID prompt: 提示词 **kwargs: 其他可选参数 Returns: Dict: 包含处理结果的字典 """ return ai_gen_video_service.process_gen_video_with_record( user_id, image_id, prompt, **kwargs ) if __name__ == "__main__": try: result = process_gen_video_with_record( user_id=4, image_id=187, prompt="美女站在海边" ) if result["success"]: print("处理成功!") print(f"处理记录ID: {result['process_record_id']}") print(f"结果图片ID: {result['result_video_id']}") print(f"文案: {result['copywriter_text']}") else: print(f"处理失败: {result['error']}") except Exception as e: print(f"测试失败: {str(e)}")