| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- """
- refer_video_create 主程序
- 基于参考视频创建视频的完整任务流
- 包括:
- 1. 从参考视频创建脚本
- 2. 优化脚本提示词(生图提示词和生视频提示词)
- 3. 基于image_prompt生成分镜
- 4. 基于video_prompt和分镜生成视频
- 5. 拼接所有视频片段
- """
- import time
- import argparse
- import asyncio
- import logging
- from pathlib import Path
- from taskflow import TaskManager, FileIOHandler, RunManager
- from taskflow import setup_logger
- from .pipeline.refer_video_create_pipeline import ReferVideoCreatePipeline
- logger = setup_logger("examples.refer_video_create.main", level=logging.INFO)
- def main():
- """主程序"""
- start_time = time.time()
- logger.info("=== refer_video_create 示例 ===\n")
- # 解析命令行参数
- parser = argparse.ArgumentParser(description="refer_video_create 主流程")
- parser.add_argument("--video-url", type=str, required=True, help="参考视频URL或路径")
- parser.add_argument("--user-prompt", type=str, required=False, default=None, help="用户提示词(可选)")
- parser.add_argument("--size", type=str, default="1440x2560", help="生成分镜图片的尺寸(默认: 1440x2560)")
- parser.add_argument(
- "--refer-image",
- nargs="*",
- default=None,
- help="参考图片路径(可选),所有分镜都会参考这张图片生成"
- )
- parser.add_argument("--max-retries", type=int, default=3, help="最大重试次数")
- parser.add_argument(
- "--resume",
- action="store_true",
- help="继续执行上次失败的运行(自动查找最新的未完成运行)"
- )
- parser.add_argument(
- "--run-id",
- type=str,
- default=None,
- help="指定要使用的运行ID(用于继续执行特定运行)"
- )
- parser.add_argument(
- "--new-run",
- action="store_true",
- help="强制创建新的运行目录(即使存在未完成的运行)"
- )
- args = parser.parse_args()
- # 1. 创建运行管理器
- run_manager = RunManager(base_output_dir="output")
-
- # 确定运行目录策略
- if args.new_run:
- # 强制创建新运行
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- logger.info("创建新的运行目录")
- elif args.run_id:
- # 使用指定的运行ID
- run_output_dir = run_manager.create_run_directory(run_id=args.run_id)
- run_id = run_manager.get_run_id()
- logger.info(f"使用指定的运行ID: {run_id}")
- elif args.resume:
- # 自动查找最新的未完成运行
- runs = run_manager.list_runs()
- if not runs:
- logger.warning("没有找到已存在的运行,创建新运行目录")
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- else:
- # 查找未完成的运行(检查task_state.json中是否有失败的步骤)
- resume_run_id = None
- for run_info in runs:
- run_path = Path(run_info["path"])
- state_file = run_path / "task_state.json"
-
- if state_file.exists():
- try:
- import json as json_module
- with open(state_file, 'r', encoding='utf-8') as f:
- state = json_module.load(f)
-
- # 检查是否有失败的步骤或待执行的步骤
- steps = state.get("steps", {})
- has_failed = any(
- step.get("status") == "failed"
- for step in steps.values()
- )
- has_pending = any(
- step.get("status") in ["pending", "running"]
- for step in steps.values()
- )
-
- if has_failed or has_pending:
- resume_run_id = run_info["run_id"]
- logger.info(f"找到未完成的运行: {resume_run_id}")
- break
- except Exception as e:
- logger.warning(f"检查运行 {run_info['run_id']} 状态时出错: {e}")
- continue
-
- if resume_run_id:
- run_output_dir = run_manager.create_run_directory(run_id=resume_run_id)
- run_id = run_manager.get_run_id()
- logger.info(f"继续执行运行: {run_id}")
- else:
- logger.info("没有找到未完成的运行,创建新运行目录")
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- else:
- # 默认行为:创建新运行
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- logger.info("创建新的运行目录")
- logger.info(f"运行ID: {run_id}")
- logger.info(f"输出目录: {run_output_dir}")
- # 2. 创建文件I/O处理器
- io_handler = FileIOHandler()
- # 3. 创建任务管理器
- state_file = str(Path(run_output_dir) / "task_state.json")
- cache_dir = str(Path(run_output_dir) / "task_cache")
- manager = TaskManager(
- state_file=state_file,
- cache_dir=cache_dir
- )
- # 4. 创建视频创作任务流
- pipeline = ReferVideoCreatePipeline(io_handler, run_output_dir, manager)
- # 5. 注册步骤
- logger.info("注册步骤...\n")
-
- # TaskManager 现在原生支持异步函数,无需包装器
- # 创建异步包装函数(lambda 不能是异步的)
- async def step1_func():
- return await pipeline.step1_create_script(
- video_url=args.video_url,
- user_prompt=args.user_prompt
- )
-
- async def step2_func():
- return await pipeline.step2_optimize_prompts()
-
- async def step3_func():
- """
- 步骤3:生成分镜图片
- 如果指定了 --refer-image,所有分镜都会参考这张图片生成
- """
- refer_image = args.refer_image
-
- if refer_image:
- for image_item in refer_image:
- # 检查文件是否存在
- refer_image_path = Path(image_item)
- if not refer_image_path.exists():
- logger.warning(f"参考图片不存在: {image_item},将不使用参考图片")
- refer_image = None
- else:
- logger.info(f"使用参考图片: {image_item}")
- else:
- logger.info("不使用参考图片")
-
- return await pipeline.step3_generate_storyboard(
- size=args.size,
- refer_image=refer_image
- )
- async def step4_func():
- return await pipeline.step4_generate_video_clips()
- async def step5_func():
- return await pipeline.step5_concat_clips()
- manager.register_step(
- "step1",
- step1_func,
- force_rerun=False
- )
- manager.register_step(
- "step2",
- step2_func,
- depends_on=["step1"],
- force_rerun=False
- )
- manager.register_step(
- "step3",
- step3_func,
- depends_on=["step2"],
- force_rerun=False
- )
- manager.register_step(
- "step4",
- step4_func,
- depends_on=["step3"],
- force_rerun=False
- )
- manager.register_step(
- "step5",
- step5_func,
- depends_on=["step4"],
- force_rerun=False
- )
- # 6. 显示当前状态
- summary = manager.get_summary()
- logger.info(f"总步骤数: {summary['total']}")
- logger.info(f"待执行: {summary['pending']}")
- # 7. 执行所有步骤(使用异步版本,性能更优)
- logger.info("\n开始执行视频创作...")
-
- async def run_pipeline_async():
- """异步执行所有步骤"""
- # 使用异步版本的 run_all_async,性能比 ThreadPoolExecutor 更优
- await manager.run_all_async(
- step_order=["step1", "step2", "step3", "step4", "step5"],
- continue_on_error=False
- )
- logger.info("\n视频创作完成!")
-
- # 重试机制:执行1次,如果失败则重试 max_retries 次
- # 总共最多执行 max_retries + 1 次
- last_exception = None
- total_attempts = args.max_retries + 1 # 1次正常执行 + max_retries 次重试
-
- for attempt in range(total_attempts):
- try:
- if attempt == 0:
- logger.info(f"第 {attempt + 1} 次执行...")
- else:
- # 重试前等待一段时间,避免快速连续失败
- wait_time = min(2 ** (attempt - 1), 60) # 指数退避,最多60秒
- logger.info(f"等待 {wait_time} 秒后开始第 {attempt + 1} 次执行(第 {attempt} 次重试)...")
- time.sleep(wait_time)
-
- # 运行异步流程
- asyncio.run(run_pipeline_async())
- # 执行成功,退出重试循环
- if attempt > 0:
- logger.info(f"✅ 第 {attempt + 1} 次执行成功(经过 {attempt} 次重试)")
- break
-
- except Exception as e:
- last_exception = e
- if attempt == 0:
- logger.error(f"❌ 第 1 次执行失败: {e}", exc_info=True)
- else:
- logger.error(f"❌ 第 {attempt + 1} 次执行失败(第 {attempt} 次重试): {e}", exc_info=True)
-
- # 如果是最后一次尝试,记录最终失败
- if attempt == total_attempts - 1:
- logger.error(f"\n✗ 执行失败:经过 {total_attempts} 次尝试(1次正常执行 + {args.max_retries} 次重试)后仍然失败")
- raise last_exception
- # 否则继续重试
- continue
-
- # 如果所有重试都失败,这里不应该到达(因为上面已经 raise)
- # 但为了安全起见,还是检查一下
- if last_exception is not None:
- raise last_exception
- logger.info(f"\n所有结果已保存到: {run_output_dir}")
- end_time = time.time()
- logger.info(f"执行时间: {end_time - start_time} 秒")
- # python -m examples.refer_video_create.main --video-url "video.mp4" --user-prompt "请开始执行你的任务"
- if __name__ == "__main__":
- main()
|