# ai_gen_video_service.py

  1. import os
  2. import uuid
  3. import hashlib
  4. import fal_client
  5. from PIL import Image
  6. from datetime import datetime
  7. from typing import Dict, Any, Optional
  8. from backend.modules.fal_ai.gen_video import video_generator
  9. from backend.modules.comfyui.ai_copywriter import gen_copywriter
  10. from backend.modules.database.operations import DatabaseOperations
  11. from backend.utils.logger_config import setup_logger
  12. from backend.utils.tools import download_video
  13. from backend.utils.system_config import Config
  14. from backend.services.task_queue_service import get_task_queue_service
  15. from dotenv import load_dotenv
  16. from pathlib import Path
  17. env_path = Path("./backend") / ".env"
  18. load_dotenv(dotenv_path=env_path)
  19. logger = setup_logger(__name__)
  20. class AIGenVideoService:
  21. """
  22. AI图生视频业务逻辑服务
  23. 负责协调AI图生视频的各个模块,包括图像处理、文本生成、数据库操作等。
  24. """
  25. def __init__(self, db_operations: Optional[DatabaseOperations] = None):
  26. """初始化AI图生视频业务逻辑服务
  27. Args:
  28. db_operations (Optional[DatabaseOperations], optional): 数据库操作对象。默认为None,表示使用默认的DatabaseOperations对象。
  29. """
  30. self.api_key = os.getenv("FAL_KEY")
  31. if not self.api_key:
  32. logger.warning("未设置FAL_KEY环境变量,无法使用视频生成服务")
  33. if self.api_key:
  34. fal_client.fal_key = self.api_key
  35. self.db_ops = db_operations or DatabaseOperations()
  36. self.system_config = Config('./backend/config/ai_gen_video.json')
  37. self.task_queue = get_task_queue_service()
  38. def submit_gen_video_task(
  39. self,
  40. user_id: int,
  41. image_id: int,
  42. prompt: str,
  43. quantity: int = 1
  44. ) -> str:
  45. """提交生成视频任务
  46. Args:
  47. user_id (int): 用户ID
  48. image_id (int): 图像ID
  49. prompt (str): 提示词
  50. quantity (int, optional): 生成视频的数量。默认为1。
  51. Returns:
  52. str: 任务ID
  53. """
  54. logger.info(f"提交生成视频任务:user_id={user_id}, image_id={image_id}, prompt={prompt}, quantity={quantity}")
  55. # 提交任务到队列
  56. task_id = self.task_queue.submit_task(
  57. task_func=self._execute_gen_video_task,
  58. task_args=(user_id, image_id, prompt),
  59. task_kwargs={
  60. "quantity": max(1, int(quantity))
  61. }
  62. )
  63. logger.info(f"任务提交成功:task_id={task_id}")
  64. return task_id
  65. def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
  66. """
  67. 获取任务状态
  68. Args:
  69. task_id: 任务ID
  70. Returns:
  71. Dict: 任务状态信息
  72. """
  73. return self.task_queue.get_task_status(task_id)
  74. def get_user_tasks(self, user_id: int) -> list:
  75. """
  76. 获取用户的所有任务
  77. Args:
  78. user_id: 用户ID
  79. Returns:
  80. list: 任务列表
  81. """
  82. return self.task_queue.get_all_tasks(user_id)
  83. def cancel_task(self, task_id: str) -> bool:
  84. """
  85. 取消任务
  86. Args:
  87. task_id: 任务ID
  88. Returns:
  89. bool: 是否成功取消
  90. """
  91. return self.task_queue.cancel_task(task_id)
  92. def _execute_gen_video_task(
  93. self,
  94. user_id: int,
  95. image_id: int,
  96. prompt: str,
  97. quantity: int = 1
  98. ) -> Dict[str, Any]:
  99. """执行生成视频任务
  100. Args:
  101. user_id (int): 用户ID
  102. image_id (int): 图像ID
  103. prompt (str): 提示词
  104. Returns:
  105. Dict[str, Any]: 包含任务结果的字典
  106. """
  107. try:
  108. logger.info(f"开始执行生成视频任务:user_id={user_id}, image_id={image_id}, prompt={prompt}, quantity={quantity}")
  109. # 1. 输入验证
  110. self._validate_inputs(user_id, image_id, prompt)
  111. # 2. 获取输入图片
  112. image, image_path = self._get_input_image(image_id)
  113. logger.info(f"输入图像:{image}")
  114. # 3. 循环生成quantity次
  115. total_count = max(1, int(quantity))
  116. process_record_ids = []
  117. result_video_ids = []
  118. copywriter_texts = []
  119. for _ in range(total_count):
  120. # 执行图生视频
  121. result_video = self._video_generator(prompt, image)
  122. logger.info(f"生成视频:{result_video}")
  123. # 生成文案描述
  124. pil_image = Image.open(image_path)
  125. copywriter_text = self._generate_copywriter(pil_image)
  126. copywriter_texts.append(copywriter_text)
  127. # 保留结果视频
  128. result_video_record = self._save_result_video(user_id, result_video)
  129. result_video_ids.append(result_video_record["id"])
  130. # 创建处理记录
  131. process_record = self._create_process_record(
  132. user_id,
  133. image_id,
  134. result_video_record["id"],
  135. copywriter_text,
  136. prompt
  137. )
  138. process_record_ids.append(process_record["id"])
  139. return {
  140. "success": True,
  141. "count": len(process_record_ids),
  142. "process_record_id": process_record_ids[0],
  143. "result_video_id": result_video_ids[0],
  144. "copywriter_text": copywriter_texts[0] if copywriter_text else None,
  145. "process_record": self.db_ops.get_process_record_by_id(process_record_ids[0]),
  146. "process_record_ids": process_record_ids,
  147. "result_video_ids": result_video_ids,
  148. "copywriter_texts": copywriter_texts,
  149. }
  150. except Exception as e:
  151. logger.error(f"AI图生视频任务执行失败: {str(e)}", exc_info=True)
  152. return {
  153. "success": False,
  154. "error": str(e),
  155. "error_type": type(e).__name__
  156. }
  157. def process_gen_video_with_record(
  158. self,
  159. user_id: int,
  160. image_id: int,
  161. prompt: str,
  162. **kwargs
  163. ) -> Dict[str, Any]:
  164. """
  165. 完整的AI图生视频业务流程(同步版本,保持向后兼容)
  166. Args:
  167. user_id: 用户ID
  168. image_id: 图片ID
  169. prompt: 用户输入的提示词
  170. **kwargs: 其他可选参数
  171. Returns:
  172. Dict: 包含处理结果的字典
  173. """
  174. return self._execute_gen_video_task(
  175. user_id, image_id, prompt, **kwargs
  176. )
  177. def _validate_inputs(self, user_id: int, image_id: int, prompt: str):
  178. """
  179. 验证输入参数
  180. Args:
  181. user_id: 用户ID
  182. image_id: 图片ID
  183. prompt: 提示词
  184. """
  185. # 验证用户是否存在
  186. user = self.db_ops.get_user_by_id(user_id)
  187. if not user:
  188. raise ValueError(f"用户ID {user_id} 不存在")
  189. if not user.get("is_active", False):
  190. raise ValueError(f"用户ID {user_id} 已被禁用")
  191. # 验证图片
  192. image = self.db_ops.get_image_record_by_id(image_id)
  193. if not image:
  194. raise ValueError(f"图片ID {image_id} 不存在")
  195. if image["user_id"] != user_id:
  196. raise ValueError(f"图片ID {image_id} 不属于用户ID {user_id}")
  197. # 验证提示词
  198. if not prompt or not prompt.strip():
  199. raise ValueError("提示词不能为空")
  200. if len(prompt.strip()) > 500:
  201. raise ValueError("提示词长度不能超过500字符")
  202. logger.info(f"输入验证通过:用户={user_id}, 图片={image_id}")
  203. def _get_input_image(self, image_id: int) -> str:
  204. """
  205. 获取输入图片数据
  206. Args:
  207. image_id: 图片ID
  208. Returns:
  209. str: 图片数据
  210. """
  211. image_record = self.db_ops.get_image_record_by_id(image_id)
  212. if not os.path.exists(image_record["stored_path"]):
  213. raise FileNotFoundError(f"图片文件不存在: {image_record['stored_path']}")
  214. logger.info(f"配置密钥:{self.api_key}")
  215. fal_client.fal_key = self.api_key
  216. image_url = fal_client.upload_file(image_record["stored_path"])
  217. return image_url, image_record["stored_path"]
  218. def _video_generator(
  219. self,
  220. prompt: str,
  221. image: str
  222. ) -> str:
  223. """
  224. 执行AI图生视频处理
  225. Args:
  226. prompt: 提示词
  227. image: 输入图片数据
  228. Returns:
  229. str: 结果视频
  230. """
  231. logger.info(f"开始执行AI图生视频处理,提示词:{prompt}")
  232. try:
  233. result_video = video_generator.process_task_sync(prompt, image)
  234. logger.info("AI图生视频处理完成")
  235. return result_video
  236. except Exception as e:
  237. logger.error(f"AI图生视频处理失败: {str(e)}", exc_info=True)
  238. raise RuntimeError(f"AI图生视频处理失败: {str(e)}")
  239. def _generate_copywriter(self, result_image: Image.Image) -> str:
  240. """
  241. 基于结果图片生成文案描述
  242. Args:
  243. result_image: 结果图片
  244. Returns:
  245. str: 生成的文案描述
  246. """
  247. logger.info("开始生成文案描述")
  248. try:
  249. copywriter_text = gen_copywriter(result_image)
  250. logger.info("文案描述生成完成")
  251. return copywriter_text
  252. except Exception as e:
  253. logger.error(f"文案生成失败: {str(e)}", exc_info=True)
  254. # 文案生成失败不影响主流程,返回默认文案
  255. return "AI图生视频完成,效果很棒!✨"
  256. def _save_result_video(self, user_id: int, result_video: str) -> Dict[str, Any]:
  257. """
  258. 保存结果视频到数据库和文件系统
  259. Args:
  260. user_id: 用户ID
  261. result_video: 结果视频
  262. Returns:
  263. Dict: 视频记录信息
  264. """
  265. # 生成唯一文件名
  266. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  267. unique_id = str(uuid.uuid4())[:8]
  268. filename = f"result_{user_id}_{timestamp}_{unique_id}.mp4"
  269. # 保存到文件系统
  270. file_path = os.path.join(self.system_config.output_dir, filename)
  271. download_video(result_video, file_path)
  272. # 计算文件大小和哈希值
  273. file_size = os.path.getsize(file_path)
  274. video_hash = self._calculate_video_hash(file_path)
  275. # 保存到数据库
  276. file_size = os.path.getsize(file_path)
  277. video_hash = self._calculate_video_hash(file_path)
  278. # 保存到数据库
  279. video_record = self.db_ops.create_image_record(
  280. user_id=user_id,
  281. image_type="result",
  282. original_filename=filename,
  283. stored_path=file_path,
  284. file_size=file_size,
  285. image_hash=video_hash
  286. )
  287. logger.info(f"结果视频保存成功:{file_path}, 记录ID: {video_record['id']}")
  288. return video_record
  289. def _create_process_record(
  290. self,
  291. user_id: int,
  292. image_id: int,
  293. result_image_id: int,
  294. copywriter_text: str,
  295. prompt: str
  296. ) -> Dict[str, Any]:
  297. """
  298. 创建处理记录
  299. Args:
  300. user_id: 用户ID
  301. image_id: 输入图片ID
  302. result_image_id: 结果图片ID
  303. copywriter_text: 文案描述
  304. prompt: 提示词
  305. Returns:
  306. Dict: 处理记录信息
  307. """
  308. process_record = self.db_ops.create_process_record(
  309. user_id=user_id,
  310. face_image_id=image_id,
  311. cloth_image_id=image_id,
  312. result_image_id=result_image_id,
  313. generated_text=copywriter_text,
  314. task_type="img2video",
  315. prompt=prompt
  316. )
  317. # 更新完成时间
  318. self.db_ops.update_process_record(
  319. process_record["id"],
  320. {"completed_at": datetime.now()}
  321. )
  322. logger.info(f"处理记录创建成功: {process_record['id']}")
  323. return process_record
  324. def _calculate_video_hash(self, video_path: str) -> str:
  325. """
  326. 计算视频哈希值
  327. Args:
  328. video_path: 视频路径
  329. Returns:
  330. str: MD5哈希值
  331. """
  332. hash_func = hashlib.new("md5")
  333. try:
  334. with open(video_path, 'rb') as f:
  335. while chunk := f.read(8192):
  336. hash_func.update(chunk)
  337. return hash_func.hexdigest()
  338. except FileNotFoundError:
  339. return f"错误:文件 {video_path} 未找到。"
  340. except Exception as e:
  341. return f"计算哈希时发生错误:{e}"
  342. def get_user_process_history(self, user_id: int, page: int = 1, page_size: int = 20) -> Dict[str, Any]:
  343. """
  344. 获取用户的处理历史记录
  345. Args:
  346. user_id: 用户ID
  347. page: 页码
  348. page_size: 每页大小
  349. Returns:
  350. Dict: 分页的处理记录列表
  351. """
  352. return self.db_ops.get_user_process_records(user_id, page, page_size)
  353. def get_process_detail(self, process_id: int, user_id: Optional[int] = None) -> Optional[Dict[str, Any]]:
  354. """
  355. 获取处理记录详情
  356. Args:
  357. process_id: 处理记录ID
  358. user_id: 用户ID(可选,用于权限验证)
  359. Returns:
  360. Dict: 处理记录详情,包含关联的图片信息
  361. """
  362. process_record = self.db_ops.get_process_record_by_id(process_id)
  363. if not process_record:
  364. return None
  365. # 权限验证
  366. if user_id and process_record["user_id"] != user_id:
  367. return None
  368. # 获取关联的图片信息
  369. face_image = self.db_ops.get_image_record_by_id(process_record["face_image_id"])
  370. cloth_image = self.db_ops.get_image_record_by_id(process_record["cloth_image_id"])
  371. result_image = self.db_ops.get_image_record_by_id(process_record["result_image_id"])
  372. return {
  373. "process_record": process_record,
  374. "face_image": face_image,
  375. "cloth_image": cloth_image,
  376. "result_image": result_image
  377. }
  378. def approve_process_record(self, process_id: int) -> bool:
  379. """
  380. 审核通过处理记录
  381. Args:
  382. process_id: 处理记录ID
  383. Returns:
  384. bool: 操作是否成功
  385. """
  386. try:
  387. # 检查记录是否存在
  388. process_record = self.db_ops.get_process_record_by_id(process_id)
  389. if not process_record:
  390. logger.error(f"处理记录不存在: {process_id}")
  391. return False
  392. # 更新状态为已审核
  393. update_data = {
  394. "status": "已审核"
  395. }
  396. updated_record = self.db_ops.update_process_record(process_id, update_data)
  397. if updated_record:
  398. logger.info(f"处理记录 {process_id} 审核通过")
  399. return True
  400. else:
  401. logger.error(f"更新处理记录状态失败: {process_id}")
  402. return False
  403. except Exception as e:
  404. logger.error(f"审核处理记录异常: {str(e)}", exc_info=True)
  405. return False
  406. def delete_result_image(self, process_id: int) -> bool:
  407. """删除处理记录中的结果视频"""
  408. try:
  409. record = self.db_ops.get_process_record_by_id(process_id)
  410. if not record:
  411. logger.error(f"处理记录不存在: {process_id}")
  412. return False
  413. update_data = {
  414. "result_image_id": None
  415. }
  416. updated = self.db_ops.update_process_record(process_id, update_data)
  417. if not updated:
  418. logger.error(f"更新处理记录失败: {process_id}")
  419. return False
  420. logger.info(f"处理记录 {process_id} 已解除结果图片关联")
  421. return True
  422. except Exception as e:
  423. logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
  424. return False
# Module-level shared service instance, created at import time.
# NOTE(review): constructing it here runs AIGenVideoService.__init__ on import
# (reads FAL_KEY, creates DatabaseOperations and Config, fetches the task
# queue service) - confirm this import-time side effect is intended.
ai_gen_video_service = AIGenVideoService()
  427. def process_gen_video_with_record(
  428. user_id: int,
  429. image_id: int,
  430. prompt: str,
  431. **kwargs
  432. ) -> Dict[str, Any]:
  433. """
  434. 便捷函数:执行完整的图生视频流程
  435. Args:
  436. user_id: 用户ID
  437. image_id: 图片ID
  438. prompt: 提示词
  439. **kwargs: 其他可选参数
  440. Returns:
  441. Dict: 包含处理结果的字典
  442. """
  443. return ai_gen_video_service.process_gen_video_with_record(
  444. user_id, image_id, prompt, **kwargs
  445. )
  446. if __name__ == "__main__":
  447. try:
  448. result = process_gen_video_with_record(
  449. user_id=4,
  450. image_id=187,
  451. prompt="美女站在海边"
  452. )
  453. if result["success"]:
  454. print("处理成功!")
  455. print(f"处理记录ID: {result['process_record_id']}")
  456. print(f"结果图片ID: {result['result_video_id']}")
  457. print(f"文案: {result['copywriter_text']}")
  458. else:
  459. print(f"处理失败: {result['error']}")
  460. except Exception as e:
  461. print(f"测试失败: {str(e)}")