video_generator.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import os
  2. import time
  3. import threading
  4. from pathlib import Path
  5. from typing import Optional, List, Dict, Any, Callable, Union
  6. from tqdm import tqdm
  7. from dataclasses import dataclass, field
  8. from modules.media_understanding.media_captioner import media_captioner
  9. from modules.media_process.media_processor import media_processor
  10. from modules.media_generate.media_generator import (
  11. video_create,
  12. handle_video_result
  13. )
  14. from utils.tools import (
  15. read_json_file,
  16. save_json_file,
  17. )
  18. from utils.logger_config import setup_logger
# Module-level logger obtained from the project's logging helper, named after this module.
logger = setup_logger(__name__)
  20. @dataclass
  21. class VideoGenerationConfig:
  22. """视频生成配置类"""
  23. # 路径配置
  24. output_base_dir: str = "./output"
  25. # 视频生成参数
  26. video_resolution: str = "1080p"
  27. video_ratio: str = "16:9"
  28. watermark_enabled: bool = False
  29. crop_frame: bool = False
  30. # 超时配置
  31. video_generation_timeout: int = 3600 # 秒
  32. polling_interval: int = 5 # 秒
  33. # 脚本生成参数
  34. script_prompt_type: str = "script"
  35. script_scenario: str = "video"
  36. prompt_optimization_prefix: str = "待打磨优化的提示词:"
  37. class VideoClipGenerator:
  38. """视频片段生成器"""
  39. def __init__(self, config: VideoGenerationConfig):
  40. self.config = config
  41. self._completed_tasks = 0
  42. self._lock = threading.Lock()
  43. def generate(self, video_script_data: dict) -> bool:
  44. """
  45. 生成所有视频片段
  46. Args:
  47. script_path: 脚本文件路径
  48. Returns:
  49. 是否全部成功完成
  50. Raises:
  51. FileNotFoundError: 当脚本文件不存在时
  52. ValueError: 当脚本数据无效时
  53. """
  54. # if not os.path.exists(script_path):
  55. # raise FileNotFoundError(f"脚本文件不存在: {script_path}")
  56. # # 读取脚本
  57. # video_script_data = read_json_file(script_path)
  58. # if not video_script_data:
  59. # raise ValueError("无法读取脚本文件或文件为空")
  60. lens_details = []
  61. storyboards = video_script_data.get("storyboards", [])
  62. for storyboard in storyboards:
  63. item_info = storyboard.get("storyboard", [])
  64. lens_details.append(item_info)
  65. lens_details = [item for sublist in lens_details for item in sublist]
  66. if not storyboards:
  67. raise ValueError("脚本中未找到分镜详情")
  68. total_tasks = len(lens_details)
  69. self._completed_tasks = 0
  70. logger.info(f"开始生成 {total_tasks} 个视频片段")
  71. # 创建所有视频生成任务
  72. for idx, lens_item in enumerate(tqdm(lens_details, desc="提交视频任务")):
  73. self._create_video_task(lens_item, idx, total_tasks)
  74. # 保存脚本(包含任务ID)
  75. # save_json_file(video_script_data, script_path)
  76. # 等待所有任务完成
  77. return video_script_data, self._wait_for_completion(total_tasks)
  78. def _create_video_task(
  79. self,
  80. lens_item: Dict[str, Any],
  81. task_index: int,
  82. total_tasks: int
  83. ) -> None:
  84. """
  85. 创建单个视频生成任务
  86. Args:
  87. lens_item: 分镜详情字典
  88. script_path: 脚本文件路径
  89. task_index: 任务索引
  90. total_tasks: 总任务数
  91. """
  92. lens_id = lens_item.get("idx")
  93. motion_prompt = lens_item.get("motion_desc")
  94. image_url = lens_item.get("ff_path")
  95. lens_duration = 4
  96. if not all([motion_prompt, image_url, lens_duration]):
  97. logger.warning(f"分镜 {lens_id} 缺少必要信息,跳过")
  98. return
  99. try:
  100. # 构建生成参数
  101. gen_params = self._build_gen_params(lens_duration)
  102. video_filename = os.path.basename(lens_item["ff_path"]).replace(".png", ".mp4")
  103. logger.info(f"正在生成视频片段 {lens_id}: {video_filename}")
  104. # 创建完成事件
  105. completion_event = threading.Event()
  106. # 包装回调函数
  107. wrapped_callback = self._create_callback_wrapper(
  108. handle_video_result,
  109. completion_event,
  110. task_index,
  111. total_tasks
  112. )
  113. # 提交异步任务
  114. task_id = video_create.create_video_task_async(
  115. prompt=motion_prompt,
  116. image_url=image_url,
  117. gen_params=gen_params,
  118. filename=video_filename,
  119. callback=wrapped_callback
  120. )
  121. if task_id:
  122. lens_item["clip_path"] = f"./output/{video_filename}"
  123. lens_item["task_id"] = task_id
  124. logger.info(f"视频任务 {task_index + 1}/{total_tasks} 已提交: {task_id}")
  125. else:
  126. logger.error(f"视频任务 {task_index + 1}/{total_tasks} 提交失败")
  127. except Exception as e:
  128. logger.error(f"创建视频任务时出错: {e}")
  129. def _build_gen_params(self, duration: float) -> str:
  130. """
  131. 构建视频生成参数字符串
  132. Args:
  133. duration: 视频时长
  134. Returns:
  135. 参数字符串
  136. """
  137. return (
  138. f"--rs {self.config.video_resolution} "
  139. f"--rt {self.config.video_ratio} "
  140. f"--dur {duration} "
  141. f"--wm {'true' if self.config.watermark_enabled else 'false'} "
  142. f"--cf {'true' if self.config.crop_frame else 'false'}"
  143. )
  144. def _create_callback_wrapper(
  145. self,
  146. original_callback: Callable,
  147. event: threading.Event,
  148. task_index: int,
  149. total_tasks: int
  150. ) -> Callable:
  151. """
  152. 创建回调函数包装器
  153. Args:
  154. original_callback: 原始回调函数
  155. event: 完成事件
  156. task_index: 任务索引
  157. total_tasks: 总任务数
  158. Returns:
  159. 包装后的回调函数
  160. """
  161. def wrapper(*args, **kwargs):
  162. try:
  163. # 调用原始回调
  164. if original_callback:
  165. original_callback(*args, **kwargs)
  166. except Exception as e:
  167. logger.error(f"回调函数执行出错: {e}")
  168. finally:
  169. # 标记任务完成
  170. event.set()
  171. with self._lock:
  172. self._completed_tasks += 1
  173. logger.info(f"视频任务 {task_index + 1}/{total_tasks} 完成 "
  174. f"({self._completed_tasks}/{total_tasks})")
  175. return wrapper
  176. def _wait_for_completion(self, total_tasks: int) -> bool:
  177. """
  178. 等待所有任务完成
  179. Args:
  180. total_tasks: 总任务数
  181. Returns:
  182. 是否全部成功完成
  183. """
  184. logger.info(f"等待 {total_tasks} 个视频任务完成...")
  185. start_time = time.time()
  186. timeout = self.config.video_generation_timeout
  187. while self._completed_tasks < total_tasks:
  188. elapsed = time.time() - start_time
  189. if elapsed > timeout:
  190. logger.error(f"视频生成超时({timeout}秒),"
  191. f"已完成 {self._completed_tasks}/{total_tasks} 个任务")
  192. return False
  193. remaining = total_tasks - self._completed_tasks
  194. if remaining > 0:
  195. logger.info(f"等待中... 剩余任务: {remaining}, "
  196. f"已耗时: {int(elapsed)}秒")
  197. time.sleep(self.config.polling_interval)
  198. logger.info("所有视频生成任务已完成")
  199. return True
  200. video_generator = VideoClipGenerator(VideoGenerationConfig())