main.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. """
  2. refer_video_create 主程序
  3. 基于参考视频创建视频的完整任务流
  4. 包括:
  5. 1. 从参考视频创建脚本
  6. 2. 优化脚本提示词(生图提示词和生视频提示词)
  7. 3. 基于image_prompt生成分镜
  8. 4. 基于video_prompt和分镜生成视频
  9. 5. 拼接所有视频片段
  10. """
  11. import time
  12. import argparse
  13. import asyncio
  14. import logging
  15. from pathlib import Path
  16. from taskflow import TaskManager, FileIOHandler, RunManager
  17. from taskflow import setup_logger
  18. from .pipeline.refer_video_create_pipeline import ReferVideoCreatePipeline
  19. logger = setup_logger("examples.refer_video_create.main", level=logging.INFO)
  20. def main():
  21. """主程序"""
  22. start_time = time.time()
  23. logger.info("=== refer_video_create 示例 ===\n")
  24. # 解析命令行参数
  25. parser = argparse.ArgumentParser(description="refer_video_create 主流程")
  26. parser.add_argument("--video-url", type=str, required=True, help="参考视频URL或路径")
  27. parser.add_argument("--user-prompt", type=str, required=False, default=None, help="用户提示词(可选)")
  28. parser.add_argument("--size", type=str, default="1440x2560", help="生成分镜图片的尺寸(默认: 1440x2560)")
  29. parser.add_argument(
  30. "--refer-image",
  31. nargs="*",
  32. default=None,
  33. help="参考图片路径(可选),所有分镜都会参考这张图片生成"
  34. )
  35. parser.add_argument("--max-retries", type=int, default=3, help="最大重试次数")
  36. parser.add_argument(
  37. "--resume",
  38. action="store_true",
  39. help="继续执行上次失败的运行(自动查找最新的未完成运行)"
  40. )
  41. parser.add_argument(
  42. "--run-id",
  43. type=str,
  44. default=None,
  45. help="指定要使用的运行ID(用于继续执行特定运行)"
  46. )
  47. parser.add_argument(
  48. "--new-run",
  49. action="store_true",
  50. help="强制创建新的运行目录(即使存在未完成的运行)"
  51. )
  52. args = parser.parse_args()
  53. # 1. 创建运行管理器
  54. run_manager = RunManager(base_output_dir="output")
  55. # 确定运行目录策略
  56. if args.new_run:
  57. # 强制创建新运行
  58. run_output_dir = run_manager.create_run_directory()
  59. run_id = run_manager.get_run_id()
  60. logger.info("创建新的运行目录")
  61. elif args.run_id:
  62. # 使用指定的运行ID
  63. run_output_dir = run_manager.create_run_directory(run_id=args.run_id)
  64. run_id = run_manager.get_run_id()
  65. logger.info(f"使用指定的运行ID: {run_id}")
  66. elif args.resume:
  67. # 自动查找最新的未完成运行
  68. runs = run_manager.list_runs()
  69. if not runs:
  70. logger.warning("没有找到已存在的运行,创建新运行目录")
  71. run_output_dir = run_manager.create_run_directory()
  72. run_id = run_manager.get_run_id()
  73. else:
  74. # 查找未完成的运行(检查task_state.json中是否有失败的步骤)
  75. resume_run_id = None
  76. for run_info in runs:
  77. run_path = Path(run_info["path"])
  78. state_file = run_path / "task_state.json"
  79. if state_file.exists():
  80. try:
  81. import json as json_module
  82. with open(state_file, 'r', encoding='utf-8') as f:
  83. state = json_module.load(f)
  84. # 检查是否有失败的步骤或待执行的步骤
  85. steps = state.get("steps", {})
  86. has_failed = any(
  87. step.get("status") == "failed"
  88. for step in steps.values()
  89. )
  90. has_pending = any(
  91. step.get("status") in ["pending", "running"]
  92. for step in steps.values()
  93. )
  94. if has_failed or has_pending:
  95. resume_run_id = run_info["run_id"]
  96. logger.info(f"找到未完成的运行: {resume_run_id}")
  97. break
  98. except Exception as e:
  99. logger.warning(f"检查运行 {run_info['run_id']} 状态时出错: {e}")
  100. continue
  101. if resume_run_id:
  102. run_output_dir = run_manager.create_run_directory(run_id=resume_run_id)
  103. run_id = run_manager.get_run_id()
  104. logger.info(f"继续执行运行: {run_id}")
  105. else:
  106. logger.info("没有找到未完成的运行,创建新运行目录")
  107. run_output_dir = run_manager.create_run_directory()
  108. run_id = run_manager.get_run_id()
  109. else:
  110. # 默认行为:创建新运行
  111. run_output_dir = run_manager.create_run_directory()
  112. run_id = run_manager.get_run_id()
  113. logger.info("创建新的运行目录")
  114. logger.info(f"运行ID: {run_id}")
  115. logger.info(f"输出目录: {run_output_dir}")
  116. # 2. 创建文件I/O处理器
  117. io_handler = FileIOHandler()
  118. # 3. 创建任务管理器
  119. state_file = str(Path(run_output_dir) / "task_state.json")
  120. cache_dir = str(Path(run_output_dir) / "task_cache")
  121. manager = TaskManager(
  122. state_file=state_file,
  123. cache_dir=cache_dir
  124. )
  125. # 4. 创建视频创作任务流
  126. pipeline = ReferVideoCreatePipeline(io_handler, run_output_dir, manager)
  127. # 5. 注册步骤
  128. logger.info("注册步骤...\n")
  129. # TaskManager 现在原生支持异步函数,无需包装器
  130. # 创建异步包装函数(lambda 不能是异步的)
  131. async def step1_func():
  132. return await pipeline.step1_create_script(
  133. video_url=args.video_url,
  134. user_prompt=args.user_prompt
  135. )
  136. async def step2_func():
  137. return await pipeline.step2_optimize_prompts()
  138. async def step3_func():
  139. """
  140. 步骤3:生成分镜图片
  141. 如果指定了 --refer-image,所有分镜都会参考这张图片生成
  142. """
  143. refer_image = args.refer_image
  144. if refer_image:
  145. for image_item in refer_image:
  146. # 检查文件是否存在
  147. refer_image_path = Path(image_item)
  148. if not refer_image_path.exists():
  149. logger.warning(f"参考图片不存在: {image_item},将不使用参考图片")
  150. refer_image = None
  151. else:
  152. logger.info(f"使用参考图片: {image_item}")
  153. else:
  154. logger.info("不使用参考图片")
  155. return await pipeline.step3_generate_storyboard(
  156. size=args.size,
  157. refer_image=refer_image
  158. )
  159. async def step4_func():
  160. return await pipeline.step4_generate_video_clips()
  161. async def step5_func():
  162. return await pipeline.step5_concat_clips()
  163. manager.register_step(
  164. "step1",
  165. step1_func,
  166. force_rerun=False
  167. )
  168. manager.register_step(
  169. "step2",
  170. step2_func,
  171. depends_on=["step1"],
  172. force_rerun=False
  173. )
  174. manager.register_step(
  175. "step3",
  176. step3_func,
  177. depends_on=["step2"],
  178. force_rerun=False
  179. )
  180. manager.register_step(
  181. "step4",
  182. step4_func,
  183. depends_on=["step3"],
  184. force_rerun=False
  185. )
  186. manager.register_step(
  187. "step5",
  188. step5_func,
  189. depends_on=["step4"],
  190. force_rerun=False
  191. )
  192. # 6. 显示当前状态
  193. summary = manager.get_summary()
  194. logger.info(f"总步骤数: {summary['total']}")
  195. logger.info(f"待执行: {summary['pending']}")
  196. # 7. 执行所有步骤(使用异步版本,性能更优)
  197. logger.info("\n开始执行视频创作...")
  198. async def run_pipeline_async():
  199. """异步执行所有步骤"""
  200. # 使用异步版本的 run_all_async,性能比 ThreadPoolExecutor 更优
  201. await manager.run_all_async(
  202. step_order=["step1", "step2", "step3", "step4", "step5"],
  203. continue_on_error=False
  204. )
  205. logger.info("\n视频创作完成!")
  206. # 重试机制:执行1次,如果失败则重试 max_retries 次
  207. # 总共最多执行 max_retries + 1 次
  208. last_exception = None
  209. total_attempts = args.max_retries + 1 # 1次正常执行 + max_retries 次重试
  210. for attempt in range(total_attempts):
  211. try:
  212. if attempt == 0:
  213. logger.info(f"第 {attempt + 1} 次执行...")
  214. else:
  215. # 重试前等待一段时间,避免快速连续失败
  216. wait_time = min(2 ** (attempt - 1), 60) # 指数退避,最多60秒
  217. logger.info(f"等待 {wait_time} 秒后开始第 {attempt + 1} 次执行(第 {attempt} 次重试)...")
  218. time.sleep(wait_time)
  219. # 运行异步流程
  220. asyncio.run(run_pipeline_async())
  221. # 执行成功,退出重试循环
  222. if attempt > 0:
  223. logger.info(f"✅ 第 {attempt + 1} 次执行成功(经过 {attempt} 次重试)")
  224. break
  225. except Exception as e:
  226. last_exception = e
  227. if attempt == 0:
  228. logger.error(f"❌ 第 1 次执行失败: {e}", exc_info=True)
  229. else:
  230. logger.error(f"❌ 第 {attempt + 1} 次执行失败(第 {attempt} 次重试): {e}", exc_info=True)
  231. # 如果是最后一次尝试,记录最终失败
  232. if attempt == total_attempts - 1:
  233. logger.error(f"\n✗ 执行失败:经过 {total_attempts} 次尝试(1次正常执行 + {args.max_retries} 次重试)后仍然失败")
  234. raise last_exception
  235. # 否则继续重试
  236. continue
  237. # 如果所有重试都失败,这里不应该到达(因为上面已经 raise)
  238. # 但为了安全起见,还是检查一下
  239. if last_exception is not None:
  240. raise last_exception
  241. logger.info(f"\n所有结果已保存到: {run_output_dir}")
  242. end_time = time.time()
  243. logger.info(f"执行时间: {end_time - start_time} 秒")
  244. # python -m examples.refer_video_create.main --video-url "video.mp4" --user-prompt "请开始执行你的任务"
  245. if __name__ == "__main__":
  246. main()