| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565 |
- import os
- import uuid
- import hashlib
- import fal_client
- from PIL import Image
- from datetime import datetime
- from typing import Dict, Any, Optional
- from backend.modules.fal_ai.gen_video import video_generator
- from backend.modules.comfyui.ai_copywriter import gen_copywriter
- from backend.modules.database.operations import DatabaseOperations
- from backend.utils.logger_config import setup_logger
- from backend.utils.tools import download_video
- from backend.utils.system_config import Config
- from backend.services.task_queue_service import get_task_queue_service
- from dotenv import load_dotenv
- from pathlib import Path
- env_path = Path("./backend") / ".env"
- load_dotenv(dotenv_path=env_path)
- logger = setup_logger(__name__)
- class AIGenVideoService:
- """
- AI图生视频业务逻辑服务
- 负责协调AI图生视频的各个模块,包括图像处理、文本生成、数据库操作等。
- """
- def __init__(self, db_operations: Optional[DatabaseOperations] = None):
- """初始化AI图生视频业务逻辑服务
- Args:
- db_operations (Optional[DatabaseOperations], optional): 数据库操作对象。默认为None,表示使用默认的DatabaseOperations对象。
- """
- self.api_key = os.getenv("FAL_KEY")
- if not self.api_key:
- logger.warning("未设置FAL_KEY环境变量,无法使用视频生成服务")
- if self.api_key:
- fal_client.fal_key = self.api_key
- self.db_ops = db_operations or DatabaseOperations()
- self.system_config = Config('./backend/config/ai_gen_video.json')
- self.task_queue = get_task_queue_service()
- def submit_gen_video_task(
- self,
- user_id: int,
- image_id: int,
- prompt: str,
- quantity: int = 1
- ) -> str:
- """提交生成视频任务
- Args:
- user_id (int): 用户ID
- image_id (int): 图像ID
- prompt (str): 提示词
- quantity (int, optional): 生成视频的数量。默认为1。
- Returns:
- str: 任务ID
- """
- logger.info(f"提交生成视频任务:user_id={user_id}, image_id={image_id}, prompt={prompt}, quantity={quantity}")
- # 提交任务到队列
- task_id = self.task_queue.submit_task(
- task_func=self._execute_gen_video_task,
- task_args=(user_id, image_id, prompt),
- task_kwargs={
- "quantity": max(1, int(quantity))
- }
- )
- logger.info(f"任务提交成功:task_id={task_id}")
- return task_id
- def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
- """
- 获取任务状态
-
- Args:
- task_id: 任务ID
-
- Returns:
- Dict: 任务状态信息
- """
- return self.task_queue.get_task_status(task_id)
- def get_user_tasks(self, user_id: int) -> list:
- """
- 获取用户的所有任务
-
- Args:
- user_id: 用户ID
-
- Returns:
- list: 任务列表
- """
- return self.task_queue.get_all_tasks(user_id)
- def cancel_task(self, task_id: str) -> bool:
- """
- 取消任务
-
- Args:
- task_id: 任务ID
-
- Returns:
- bool: 是否成功取消
- """
- return self.task_queue.cancel_task(task_id)
- def _execute_gen_video_task(
- self,
- user_id: int,
- image_id: int,
- prompt: str,
- quantity: int = 1
- ) -> Dict[str, Any]:
- """执行生成视频任务
- Args:
- user_id (int): 用户ID
- image_id (int): 图像ID
- prompt (str): 提示词
- Returns:
- Dict[str, Any]: 包含任务结果的字典
- """
- try:
- logger.info(f"开始执行生成视频任务:user_id={user_id}, image_id={image_id}, prompt={prompt}, quantity={quantity}")
- # 1. 输入验证
- self._validate_inputs(user_id, image_id, prompt)
- # 2. 获取输入图片
- image, image_path = self._get_input_image(image_id)
- logger.info(f"输入图像:{image}")
- # 3. 循环生成quantity次
- total_count = max(1, int(quantity))
- process_record_ids = []
- result_video_ids = []
- copywriter_texts = []
- for _ in range(total_count):
- # 执行图生视频
- result_video = self._video_generator(prompt, image)
- logger.info(f"生成视频:{result_video}")
- # 生成文案描述
- pil_image = Image.open(image_path)
- copywriter_text = self._generate_copywriter(pil_image)
- copywriter_texts.append(copywriter_text)
- # 保留结果视频
- result_video_record = self._save_result_video(user_id, result_video)
- result_video_ids.append(result_video_record["id"])
- # 创建处理记录
- process_record = self._create_process_record(
- user_id,
- image_id,
- result_video_record["id"],
- copywriter_text,
- prompt
- )
- process_record_ids.append(process_record["id"])
- return {
- "success": True,
- "count": len(process_record_ids),
- "process_record_id": process_record_ids[0],
- "result_video_id": result_video_ids[0],
- "copywriter_text": copywriter_texts[0] if copywriter_text else None,
- "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
- "process_record_ids": process_record_ids,
- "result_video_ids": result_video_ids,
- "copywriter_texts": copywriter_texts,
- }
- except Exception as e:
- logger.error(f"AI图生视频任务执行失败: {str(e)}", exc_info=True)
- return {
- "success": False,
- "error": str(e),
- "error_type": type(e).__name__
- }
- def process_gen_video_with_record(
- self,
- user_id: int,
- image_id: int,
- prompt: str,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 完整的AI图生视频业务流程(同步版本,保持向后兼容)
- Args:
- user_id: 用户ID
- image_id: 图片ID
- prompt: 用户输入的提示词
- **kwargs: 其他可选参数
-
- Returns:
- Dict: 包含处理结果的字典
- """
- return self._execute_gen_video_task(
- user_id, image_id, prompt, **kwargs
- )
- def _validate_inputs(self, user_id: int, image_id: int, prompt: str):
- """
- 验证输入参数
-
- Args:
- user_id: 用户ID
- image_id: 图片ID
- prompt: 提示词
- """
- # 验证用户是否存在
- user = self.db_ops.get_user_by_id(user_id)
- if not user:
- raise ValueError(f"用户ID {user_id} 不存在")
-
- if not user.get("is_active", False):
- raise ValueError(f"用户ID {user_id} 已被禁用")
- # 验证图片
- image = self.db_ops.get_image_record_by_id(image_id)
- if not image:
- raise ValueError(f"图片ID {image_id} 不存在")
- if image["user_id"] != user_id:
- raise ValueError(f"图片ID {image_id} 不属于用户ID {user_id}")
- # 验证提示词
- if not prompt or not prompt.strip():
- raise ValueError("提示词不能为空")
- if len(prompt.strip()) > 500:
- raise ValueError("提示词长度不能超过500字符")
- logger.info(f"输入验证通过:用户={user_id}, 图片={image_id}")
- def _get_input_image(self, image_id: int) -> str:
- """
- 获取输入图片数据
- Args:
- image_id: 图片ID
- Returns:
- str: 图片数据
- """
- image_record = self.db_ops.get_image_record_by_id(image_id)
- if not os.path.exists(image_record["stored_path"]):
- raise FileNotFoundError(f"图片文件不存在: {image_record['stored_path']}")
- logger.info(f"配置密钥:{self.api_key}")
- fal_client.fal_key = self.api_key
- image_url = fal_client.upload_file(image_record["stored_path"])
- return image_url, image_record["stored_path"]
-
- def _video_generator(
- self,
- prompt: str,
- image: str
- ) -> str:
- """
- 执行AI图生视频处理
- Args:
- prompt: 提示词
- image: 输入图片数据
- Returns:
- str: 结果视频
- """
- logger.info(f"开始执行AI图生视频处理,提示词:{prompt}")
- try:
- result_video = video_generator.process_task_sync(prompt, image)
- logger.info("AI图生视频处理完成")
- return result_video
-
- except Exception as e:
- logger.error(f"AI图生视频处理失败: {str(e)}", exc_info=True)
- raise RuntimeError(f"AI图生视频处理失败: {str(e)}")
- def _generate_copywriter(self, result_image: Image.Image) -> str:
- """
- 基于结果图片生成文案描述
-
- Args:
- result_image: 结果图片
-
- Returns:
- str: 生成的文案描述
- """
- logger.info("开始生成文案描述")
-
- try:
- copywriter_text = gen_copywriter(result_image)
- logger.info("文案描述生成完成")
- return copywriter_text
-
- except Exception as e:
- logger.error(f"文案生成失败: {str(e)}", exc_info=True)
- # 文案生成失败不影响主流程,返回默认文案
- return "AI图生视频完成,效果很棒!✨"
- def _save_result_video(self, user_id: int, result_video: str) -> Dict[str, Any]:
- """
- 保存结果视频到数据库和文件系统
-
- Args:
- user_id: 用户ID
- result_video: 结果视频
-
- Returns:
- Dict: 视频记录信息
- """
- # 生成唯一文件名
- timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
- unique_id = str(uuid.uuid4())[:8]
- filename = f"result_{user_id}_{timestamp}_{unique_id}.mp4"
-
- # 保存到文件系统
- file_path = os.path.join(self.system_config.output_dir, filename)
- download_video(result_video, file_path)
- # 计算文件大小和哈希值
- file_size = os.path.getsize(file_path)
- video_hash = self._calculate_video_hash(file_path)
- # 保存到数据库
- file_size = os.path.getsize(file_path)
- video_hash = self._calculate_video_hash(file_path)
- # 保存到数据库
- video_record = self.db_ops.create_image_record(
- user_id=user_id,
- image_type="result",
- original_filename=filename,
- stored_path=file_path,
- file_size=file_size,
- image_hash=video_hash
- )
- logger.info(f"结果视频保存成功:{file_path}, 记录ID: {video_record['id']}")
- return video_record
- def _create_process_record(
- self,
- user_id: int,
- image_id: int,
- result_image_id: int,
- copywriter_text: str,
- prompt: str
- ) -> Dict[str, Any]:
- """
- 创建处理记录
-
- Args:
- user_id: 用户ID
- image_id: 输入图片ID
- result_image_id: 结果图片ID
- copywriter_text: 文案描述
- prompt: 提示词
- Returns:
- Dict: 处理记录信息
- """
- process_record = self.db_ops.create_process_record(
- user_id=user_id,
- face_image_id=image_id,
- cloth_image_id=image_id,
- result_image_id=result_image_id,
- generated_text=copywriter_text,
- task_type="img2video",
- prompt=prompt
- )
- # 更新完成时间
- self.db_ops.update_process_record(
- process_record["id"],
- {"completed_at": datetime.now()}
- )
- logger.info(f"处理记录创建成功: {process_record['id']}")
- return process_record
- def _calculate_video_hash(self, video_path: str) -> str:
- """
- 计算视频哈希值
-
- Args:
- video_path: 视频路径
-
- Returns:
- str: MD5哈希值
- """
- hash_func = hashlib.new("md5")
- try:
- with open(video_path, 'rb') as f:
- while chunk := f.read(8192):
- hash_func.update(chunk)
- return hash_func.hexdigest()
- except FileNotFoundError:
- return f"错误:文件 {video_path} 未找到。"
- except Exception as e:
- return f"计算哈希时发生错误:{e}"
- def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
- """
- 获取用户的处理历史记录
-
- Args:
- user_id: 用户ID
- page: 页码
- page_size: 每页大小
-
- Returns:
- Dict: 分页的处理记录列表
- """
- return self.db_ops.get_user_process_records(user_id, page, page_size)
- def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
- """
- 获取处理记录详情
-
- Args:
- process_id: 处理记录ID
- user_id: 用户ID(可选,用于权限验证)
-
- Returns:
- Dict: 处理记录详情,包含关联的图片信息
- """
- process_record = self.db_ops.get_process_record_by_id(process_id)
- if not process_record:
- return None
-
- # 权限验证
- if user_id and process_record["user_id"] != user_id:
- return None
-
- # 获取关联的图片信息
- face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
- cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
- result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
-
- return {
- "process_record": process_record,
- "face_image": face_image,
- "cloth_image": cloth_image,
- "result_image": result_image
- }
- def approve_process_record(self, process_id: int) -> bool:
- """
- 审核通过处理记录
-
- Args:
- process_id: 处理记录ID
-
- Returns:
- bool: 操作是否成功
- """
- try:
- # 检查记录是否存在
- process_record = self.db_ops.get_process_record_by_id(process_id)
- if not process_record:
- logger.error(f"处理记录不存在: {process_id}")
- return False
-
- # 更新状态为已审核
- update_data = {
- "status": "已审核"
- }
-
- updated_record = self.db_ops.update_process_record(process_id, update_data)
-
- if updated_record:
- logger.info(f"处理记录 {process_id} 审核通过")
- return True
- else:
- logger.error(f"更新处理记录状态失败: {process_id}")
- return False
-
- except Exception as e:
- logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
- return False
- def delete_result_image(self, process_id: int) -> bool:
- """删除处理记录中的结果视频"""
- try:
- record = self.db_ops.get_process_record_by_id(process_id)
- if not record:
- logger.error(f"处理记录不存在: {process_id}")
- return False
- update_data = {
- "result_image_id": None
- }
- updated = self.db_ops.update_process_record(process_id, update_data)
- if not updated:
- logger.error(f"更新处理记录失败: {process_id}")
- return False
- logger.info(f"处理记录 {process_id} 已解除结果图片关联")
- return True
- except Exception as e:
- logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
- return False
-
- # 创建全局服务实例
- ai_gen_video_service = AIGenVideoService()
- def process_gen_video_with_record(
- user_id: int,
- image_id: int,
- prompt: str,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 便捷函数:执行完整的图生视频流程
- Args:
- user_id: 用户ID
- image_id: 图片ID
- prompt: 提示词
- **kwargs: 其他可选参数
-
- Returns:
- Dict: 包含处理结果的字典
- """
- return ai_gen_video_service.process_gen_video_with_record(
- user_id, image_id, prompt, **kwargs
- )
- if __name__ == "__main__":
- try:
- result = process_gen_video_with_record(
- user_id=4,
- image_id=187,
- prompt="美女站在海边"
- )
- if result["success"]:
- print("处理成功!")
- print(f"处理记录ID: {result['process_record_id']}")
- print(f"结果图片ID: {result['result_video_id']}")
- print(f"文案: {result['copywriter_text']}")
- else:
- print(f"处理失败: {result['error']}")
-
- except Exception as e:
- print(f"测试失败: {str(e)}")
|