| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378 |
- """
- idea2video 主程序
- 从idea到video的完整任务流
- 包括:
- 1. 从idea到story
- 2. 从story到script
- 3. 从script到video
- """
- import time
- import argparse
- import asyncio
- import logging
- import json
- from pathlib import Path
- from typing import Dict, Optional
- from taskflow import TaskManager, FileIOHandler, RunManager
- from taskflow import setup_logger
- from .pipeline.idea2video_pipeline import Idea2VideoPipeline
- logger = setup_logger("examples.video_create.main", level=logging.INFO)
- def main():
- """主程序"""
- start_time = time.time()
- logger.info("=== idea2video 示例 ===\n")
- # 解析命令行参数
- parser = argparse.ArgumentParser(description="idea2video 主流程")
- parser.add_argument("--idea", type=str, required=True, help="创意")
- parser.add_argument("--user_requirement", type=str, required=False, help="用户要求")
- parser.add_argument("--max_retries", type=int, default=3, help="最大重试次数")
- parser.add_argument(
- "--resume",
- action="store_true",
- help="继续执行上次失败的运行(自动查找最新的未完成运行)"
- )
- parser.add_argument(
- "--run-id",
- type=str,
- default=None,
- help="指定要使用的运行ID(用于继续执行特定运行)"
- )
- parser.add_argument(
- "--new-run",
- action="store_true",
- help="强制创建新的运行目录(即使存在未完成的运行)"
- )
- parser.add_argument(
- "--refer-image-map",
- type=str,
- default=None,
- help="角色参考图片映射(JSON字符串),格式: '{\"角色名\": [\"图片路径1\", \"图片路径2\"]}'"
- )
- parser.add_argument(
- "--refer-image-map-file",
- type=str,
- default=None,
- help="角色参考图片映射文件路径(JSON格式),格式: {\"角色名\": [\"图片路径1\", \"图片路径2\"]}"
- )
- args = parser.parse_args()
- # 1. 创建运行管理器
- run_manager = RunManager(base_output_dir="output")
-
- # 确定运行目录策略
- if args.new_run:
- # 强制创建新运行
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- logger.info("创建新的运行目录")
- elif args.run_id:
- # 使用指定的运行ID
- run_output_dir = run_manager.create_run_directory(run_id=args.run_id)
- run_id = run_manager.get_run_id()
- logger.info(f"使用指定的运行ID: {run_id}")
- elif args.resume:
- # 自动查找最新的未完成运行
- runs = run_manager.list_runs()
- if not runs:
- logger.warning("没有找到已存在的运行,创建新运行目录")
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- else:
- # 查找未完成的运行(检查task_state.json中是否有失败的步骤)
- resume_run_id = None
- for run_info in runs:
- run_path = Path(run_info["path"])
- state_file = run_path / "task_state.json"
-
- if state_file.exists():
- try:
- import json
- with open(state_file, 'r', encoding='utf-8') as f:
- state = json.load(f)
-
- # 检查是否有失败的步骤或待执行的步骤
- steps = state.get("steps", {})
- has_failed = any(
- step.get("status") == "failed"
- for step in steps.values()
- )
- has_pending = any(
- step.get("status") in ["pending", "running"]
- for step in steps.values()
- )
-
- if has_failed or has_pending:
- resume_run_id = run_info["run_id"]
- logger.info(f"找到未完成的运行: {resume_run_id}")
- break
- except Exception as e:
- logger.warning(f"检查运行 {run_info['run_id']} 状态时出错: {e}")
- continue
-
- if resume_run_id:
- run_output_dir = run_manager.create_run_directory(run_id=resume_run_id)
- run_id = run_manager.get_run_id()
- logger.info(f"继续执行运行: {run_id}")
- else:
- logger.info("没有找到未完成的运行,创建新运行目录")
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- else:
- # 默认行为:创建新运行
- run_output_dir = run_manager.create_run_directory()
- run_id = run_manager.get_run_id()
- logger.info("创建新的运行目录")
- logger.info(f"运行ID: {run_id}")
- logger.info(f"输出目录: {run_output_dir}")
- # 2. 创建文件I/O处理器
- io_handler = FileIOHandler()
- # 3. 创建任务管理器
- state_file = str(Path(run_output_dir) / "task_state.json")
- cache_dir = str(Path(run_output_dir) / "task_cache")
- manager = TaskManager(
- state_file=state_file,
- cache_dir=cache_dir
- )
- # 4. 创建视频创作任务流
- pipeline = Idea2VideoPipeline(io_handler, run_output_dir, manager)
- # 5. 注册步骤
- logger.info("注册步骤...\n")
-
- # TaskManager 现在原生支持异步函数,无需包装器
- # 创建异步包装函数(lambda 不能是异步的)
- async def step1_func():
- return await pipeline.step1_develop_story(idea=args.idea, user_requirement=args.user_requirement)
-
- async def step2_func():
- return await pipeline.step2_develop_script(user_requirement=args.user_requirement)
-
- async def step3_func():
- return await pipeline.step3_extract_characters()
-
- async def step4_func():
- return await pipeline.step4_create_storyboard(user_requirement=args.user_requirement)
- async def step5_func():
- """
- 步骤5:生成角色肖像
- 参考图片映射的优先级:
- 1. 命令行参数 --refer-image-map(JSON字符串)
- 2. 命令行参数 --refer-image-map-file(JSON文件路径)
- 3. 从 step3 结果中读取角色数据中的 refer_image 字段
- 4. None(不使用参考图片)
- """
- # 在函数内部重新导入 json,避免嵌套函数作用域问题
- import json
-
- refer_image_map: Optional[Dict[str, list[str]]] = None
-
- # 优先级1:从命令行参数 --refer-image-map 读取(JSON字符串)
- if args.refer_image_map:
- try:
- refer_image_map = json.loads(args.refer_image_map)
- if not isinstance(refer_image_map, dict):
- raise ValueError("refer_image_map 必须是字典类型")
- logger.info(f"从命令行参数读取参考图片映射: {refer_image_map}")
- except json.JSONDecodeError as e:
- logger.error(f"解析 --refer-image-map JSON 字符串失败: {e}")
- raise ValueError(f"无效的 JSON 格式: {e}")
-
- # 优先级2:从命令行参数 --refer-image-map-file 读取(JSON文件)
- elif args.refer_image_map_file:
- try:
- map_file = Path(args.refer_image_map_file)
- if not map_file.exists():
- logger.warning(f"参考图片映射文件不存在: {map_file},将尝试从角色数据中读取")
- # 不设置 refer_image_map,让后续逻辑继续处理
- else:
- with open(map_file, 'r', encoding='utf-8') as f:
- refer_image_map = json.load(f)
- if not isinstance(refer_image_map, dict):
- raise ValueError("refer_image_map 必须是字典类型")
- logger.info(f"从文件读取参考图片映射: {refer_image_map}")
- except json.JSONDecodeError as e:
- logger.error(f"解析参考图片映射文件失败: {e}")
- raise ValueError(f"无效的 JSON 格式: {e}")
- except Exception as e:
- logger.warning(f"读取参考图片映射文件失败: {e},将尝试从角色数据中读取")
-
- # 优先级4:默认不使用参考图片
- if refer_image_map is None:
- logger.info("不使用参考图片")
-
- return await pipeline.step5_generate_portrait(
- size="2048x2048",
- refer_image_map=refer_image_map,
- style="写实"
- )
- async def step6_func():
- return await pipeline.step6_create_camera_tree()
- async def step7_func():
- return await pipeline.step7_generate_video_frames()
- async def step8_func():
- return await pipeline.step8_generate_video()
- async def step9_func():
- return await pipeline.step9_concat_clip()
- manager.register_step(
- "step1",
- step1_func,
- force_rerun=False
- )
- manager.register_step(
- "step2",
- step2_func,
- depends_on=["step1"],
- force_rerun=False
- )
- manager.register_step(
- "step3",
- step3_func,
- depends_on=["step1"],
- force_rerun=False
- )
- manager.register_step(
- "step4",
- step4_func,
- depends_on=["step2", "step3"],
- force_rerun=False
- )
- manager.register_step(
- "step5",
- step5_func,
- depends_on=["step3"],
- force_rerun=False
- )
- manager.register_step(
- "step6",
- step6_func,
- depends_on=["step4"],
- force_rerun=False
- )
- manager.register_step(
- "step7",
- step7_func,
- depends_on=["step5", "step6"],
- force_rerun=False
- )
- manager.register_step(
- "step8",
- step8_func,
- depends_on=["step7"],
- force_rerun=False
- )
- manager.register_step(
- "step9",
- step9_func,
- depends_on=["step8"],
- force_rerun=False
- )
- # 6. 显示当前状态
- summary = manager.get_summary()
- logger.info(f"总步骤数: {summary['total']}")
- logger.info(f"待执行: {summary['pending']}")
- # # 7.0. 执行所有步骤(同步版本&多线程并行版本-V1)
- # logger.info("\n开始执行视频创作...")
- # try:
- # manager.run_all(step_order=["step1", "step2", "step3", "step4"])
- # # 多线程并行执行所有步骤
- # manager.run_all_parallel(
- # step_order=["step1", "step2", "step3", "step4"],
- # max_workers=2,
- # continue_on_error=False
- # )
- # logger.info("\n视频创作完成!")
- # except Exception as e:
- # logger.error(f"\n✗ 执行失败: {e}", exc_info=True)
- # raise
- # 7. 执行所有步骤(使用异步版本,性能更优)
- logger.info("\n开始执行视频创作...")
-
- async def run_pipeline_async():
- """异步执行所有步骤"""
- # 使用异步版本的 run_all_async,性能比 ThreadPoolExecutor 更优
- await manager.run_all_async(
- step_order=["step1", "step2", "step3", "step4", "step5", "step6", "step7", "step8", "step9"],
- continue_on_error=False
- )
- logger.info("\n视频创作完成!")
-
- # 重试机制:执行1次,如果失败则重试 max_retries 次
- # 总共最多执行 max_retries + 1 次
- last_exception = None
- total_attempts = args.max_retries + 1 # 1次正常执行 + max_retries 次重试
-
- for attempt in range(total_attempts):
- try:
- if attempt == 0:
- logger.info(f"第 {attempt + 1} 次执行...")
- else:
- # 重试前等待一段时间,避免快速连续失败
- wait_time = min(2 ** (attempt - 1), 60) # 指数退避,最多60秒
- logger.info(f"等待 {wait_time} 秒后开始第 {attempt + 1} 次执行(第 {attempt} 次重试)...")
- time.sleep(wait_time)
-
- # 运行异步流程
- asyncio.run(run_pipeline_async())
- # 执行成功,退出重试循环
- if attempt > 0:
- logger.info(f"✅ 第 {attempt + 1} 次执行成功(经过 {attempt} 次重试)")
- break
-
- except Exception as e:
- last_exception = e
- if attempt == 0:
- logger.error(f"❌ 第 1 次执行失败: {e}", exc_info=True)
- else:
- logger.error(f"❌ 第 {attempt + 1} 次执行失败(第 {attempt} 次重试): {e}", exc_info=True)
-
- # 如果是最后一次尝试,记录最终失败
- if attempt == total_attempts - 1:
- logger.error(f"\n✗ 执行失败:经过 {total_attempts} 次尝试(1次正常执行 + {args.max_retries} 次重试)后仍然失败")
- raise last_exception
- # 否则继续重试
- continue
-
- # 如果所有重试都失败,这里不应该到达(因为上面已经 raise)
- # 但为了安全起见,还是检查一下
- if last_exception is not None:
- raise last_exception
- logger.info(f"\n所有结果已保存到: {run_output_dir}")
- end_time = time.time()
- logger.info(f"执行时间: {end_time - start_time} 秒")
- # python -m examples.video_create.main --idea "时尚女装" --user_requirement "设计三个场景"
- if __name__ == "__main__":
- main()
|