main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. """
  2. idea2video 主程序
  3. 从idea到video的完整任务流
  4. 包括:
  5. 1. 从idea到story
  6. 2. 从story到script
  7. 3. 从script到video
  8. """
  9. import time
  10. import argparse
  11. import asyncio
  12. import logging
  13. import json
  14. from pathlib import Path
  15. from typing import Dict, Optional
  16. from taskflow import TaskManager, FileIOHandler, RunManager
  17. from taskflow import setup_logger
  18. from .pipeline.idea2video_pipeline import Idea2VideoPipeline
  # Module-level logger used throughout this entry-point script; it is built by
  # taskflow's setup_logger and configured at INFO level.
  logger = setup_logger("examples.video_create.main", level=logging.INFO)
  def _build_arg_parser() -> argparse.ArgumentParser:
      """Build the command-line parser for the idea2video entry point."""
      parser = argparse.ArgumentParser(description="idea2video 主流程")
      parser.add_argument("--idea", type=str, required=True, help="创意")
      parser.add_argument("--user_requirement", type=str, required=False, help="用户要求")
      parser.add_argument("--max_retries", type=int, default=3, help="最大重试次数")
      parser.add_argument(
          "--resume",
          action="store_true",
          help="继续执行上次失败的运行(自动查找最新的未完成运行)",
      )
      parser.add_argument(
          "--run-id",
          type=str,
          default=None,
          help="指定要使用的运行ID(用于继续执行特定运行)",
      )
      parser.add_argument(
          "--new-run",
          action="store_true",
          help="强制创建新的运行目录(即使存在未完成的运行)",
      )
      parser.add_argument(
          "--refer-image-map",
          type=str,
          default=None,
          help="角色参考图片映射(JSON字符串),格式: '{\"角色名\": [\"图片路径1\", \"图片路径2\"]}'",
      )
      parser.add_argument(
          "--refer-image-map-file",
          type=str,
          default=None,
          help="角色参考图片映射文件路径(JSON格式),格式: {\"角色名\": [\"图片路径1\", \"图片路径2\"]}",
      )
      return parser


  def _find_unfinished_run(runs) -> Optional[str]:
      """Return the id of the first run whose saved state has unfinished steps.

      Args:
          runs: Run descriptors as returned by ``RunManager.list_runs()``. Each
              entry is read through its ``"path"`` and ``"run_id"`` keys.

      Returns:
          The ``run_id`` of the first run (in the order given) whose
          ``task_state.json`` contains a step with status ``failed``,
          ``pending`` or ``running``; ``None`` when there is no such run.
      """
      unfinished_statuses = ("failed", "pending", "running")
      for run_info in runs:
          state_file = Path(run_info["path"]) / "task_state.json"
          if not state_file.exists():
              continue
          try:
              with open(state_file, 'r', encoding='utf-8') as f:
                  state = json.load(f)
              steps = state.get("steps", {})
              is_unfinished = any(
                  step.get("status") in unfinished_statuses
                  for step in steps.values()
              )
          except Exception as e:
              # Deliberately broad: this is a best-effort scan, and one corrupt
              # or unexpectedly shaped state file must not prevent resuming
              # another run. The problem is logged and the run is skipped.
              logger.warning(f"检查运行 {run_info['run_id']} 状态时出错: {e}")
              continue
          if is_unfinished:
              logger.info(f"找到未完成的运行: {run_info['run_id']}")
              return run_info["run_id"]
      return None


  def _prepare_run_directory(args: argparse.Namespace, run_manager) -> tuple:
      """Choose or create the run directory according to the CLI flags.

      Precedence is ``--new-run``, then ``--run-id``, then ``--resume``; with
      none of them a fresh run is created. ``--resume`` falls back to a fresh
      run when no run exists or none is unfinished.

      Args:
          args: Parsed command-line arguments.
          run_manager: The ``RunManager`` used to list and create runs.

      Returns:
          ``(run_output_dir, run_id)`` as reported by ``run_manager``.
      """
      if args.new_run:
          run_output_dir = run_manager.create_run_directory()
          logger.info("创建新的运行目录")
          return run_output_dir, run_manager.get_run_id()

      if args.run_id:
          run_output_dir = run_manager.create_run_directory(run_id=args.run_id)
          run_id = run_manager.get_run_id()
          logger.info(f"使用指定的运行ID: {run_id}")
          return run_output_dir, run_id

      if args.resume:
          runs = run_manager.list_runs()
          if not runs:
              logger.warning("没有找到已存在的运行,创建新运行目录")
              run_output_dir = run_manager.create_run_directory()
              return run_output_dir, run_manager.get_run_id()

          resume_run_id = _find_unfinished_run(runs)
          if resume_run_id:
              # NOTE(review): presumably create_run_directory() reuses the
              # existing directory when given a known run_id -- confirm
              # against RunManager.
              run_output_dir = run_manager.create_run_directory(run_id=resume_run_id)
              run_id = run_manager.get_run_id()
              logger.info(f"继续执行运行: {run_id}")
              return run_output_dir, run_id

          logger.info("没有找到未完成的运行,创建新运行目录")
          run_output_dir = run_manager.create_run_directory()
          return run_output_dir, run_manager.get_run_id()

      # Default behaviour: always start a fresh run.
      run_output_dir = run_manager.create_run_directory()
      logger.info("创建新的运行目录")
      return run_output_dir, run_manager.get_run_id()


  def _load_refer_image_map(inline_json: Optional[str],
                            map_path: Optional[str]) -> Optional[Dict[str, list]]:
      """Resolve the character reference-image map from the CLI inputs.

      ``inline_json`` (``--refer-image-map``) takes precedence over
      ``map_path`` (``--refer-image-map-file``).

      Args:
          inline_json: JSON object given directly on the command line, or None.
          map_path: Path of a JSON file holding the object, or None.

      Returns:
          The parsed dict, or ``None`` when neither input is given or the file
          is missing or cannot be read.

      Raises:
          ValueError: If the JSON is malformed or its top-level value is not a
              dict. This now applies to both sources; previously the file
              variant logged a warning and still passed the invalid value on.
      """
      if inline_json:
          try:
              loaded = json.loads(inline_json)
          except json.JSONDecodeError as e:
              logger.error(f"解析 --refer-image-map JSON 字符串失败: {e}")
              raise ValueError(f"无效的 JSON 格式: {e}") from e
          if not isinstance(loaded, dict):
              raise ValueError("refer_image_map 必须是字典类型")
          logger.info(f"从命令行参数读取参考图片映射: {loaded}")
          return loaded

      if map_path:
          map_file = Path(map_path)
          if not map_file.exists():
              logger.warning(f"参考图片映射文件不存在: {map_file},将尝试从角色数据中读取")
              return None
          try:
              with open(map_file, 'r', encoding='utf-8') as f:
                  loaded = json.load(f)
          except json.JSONDecodeError as e:
              logger.error(f"解析参考图片映射文件失败: {e}")
              raise ValueError(f"无效的 JSON 格式: {e}") from e
          except (OSError, UnicodeDecodeError) as e:
              # An unreadable file is treated as "no map supplied".
              logger.warning(f"读取参考图片映射文件失败: {e},将尝试从角色数据中读取")
              return None
          # Validated outside the try block so the error is not swallowed.
          if not isinstance(loaded, dict):
              raise ValueError("refer_image_map 必须是字典类型")
          logger.info(f"从文件读取参考图片映射: {loaded}")
          return loaded

      return None


  def _run_with_retries(run_once, max_retries: int) -> None:
      """Call ``run_once`` until it succeeds or the retry budget is exhausted.

      One regular attempt is always made, followed by up to ``max_retries``
      retries. Before retry ``k`` (1-based) the function sleeps
      ``min(2 ** (k - 1), 60)`` seconds.

      Args:
          run_once: Zero-argument callable performing one complete attempt.
          max_retries: Number of retries after the first attempt; negative
              values are treated as 0 so that one attempt is always made.

      Raises:
          Exception: Whatever the final attempt raised, once every attempt has
              failed. A failure that is followed by a successful retry is
              logged but not raised.
      """
      max_retries = max(max_retries, 0)
      total_attempts = max_retries + 1  # 1 regular attempt + max_retries retries
      for attempt in range(total_attempts):
          if attempt == 0:
              logger.info(f"第 {attempt + 1} 次执行...")
          else:
              # Exponential backoff, capped at 60 seconds.
              wait_time = min(2 ** (attempt - 1), 60)
              logger.info(f"等待 {wait_time} 秒后开始第 {attempt + 1} 次执行(第 {attempt} 次重试)...")
              time.sleep(wait_time)

          try:
              run_once()
          except Exception as e:
              if attempt == 0:
                  logger.error(f"❌ 第 1 次执行失败: {e}", exc_info=True)
              else:
                  logger.error(f"❌ 第 {attempt + 1} 次执行失败(第 {attempt} 次重试): {e}", exc_info=True)
              if attempt == total_attempts - 1:
                  logger.error(f"\n✗ 执行失败:经过 {total_attempts} 次尝试(1次正常执行 + {max_retries} 次重试)后仍然失败")
                  raise
              continue

          if attempt > 0:
              logger.info(f"✅ 第 {attempt + 1} 次执行成功(经过 {attempt} 次重试)")
          return


  def main():
      """Run the idea -> story -> script -> video pipeline from the command line.

      Parses the CLI arguments, selects or creates a run directory, registers
      the nine pipeline steps with their dependencies on a ``TaskManager`` and
      executes them asynchronously, retrying the whole run on failure.

      Raises:
          Exception: The error raised by the last attempt when all attempts
              (1 + ``--max_retries``) fail.
      """
      start_time = time.monotonic()
      logger.info("=== idea2video 示例 ===\n")

      args = _build_arg_parser().parse_args()

      # 1. Run manager and run directory.
      run_manager = RunManager(base_output_dir="output")
      run_output_dir, run_id = _prepare_run_directory(args, run_manager)
      logger.info(f"运行ID: {run_id}")
      logger.info(f"输出目录: {run_output_dir}")

      # 2. File I/O handler.
      io_handler = FileIOHandler()

      # 3. Task manager; its state file and cache live inside the run directory.
      run_root = Path(run_output_dir)
      manager = TaskManager(
          state_file=str(run_root / "task_state.json"),
          cache_dir=str(run_root / "task_cache"),
      )

      # 4. The video-creation pipeline.
      pipeline = Idea2VideoPipeline(io_handler, run_output_dir, manager)

      # 5. Register the steps. Each step is a zero-argument coroutine function
      #    that closes over ``args`` and ``pipeline``.
      logger.info("注册步骤...\n")

      async def step1_func():
          return await pipeline.step1_develop_story(idea=args.idea, user_requirement=args.user_requirement)

      async def step2_func():
          return await pipeline.step2_develop_script(user_requirement=args.user_requirement)

      async def step3_func():
          return await pipeline.step3_extract_characters()

      async def step4_func():
          return await pipeline.step4_create_storyboard(user_requirement=args.user_requirement)

      async def step5_func():
          """Step 5: generate character portraits.

          The reference-image map comes from ``--refer-image-map`` first, then
          from ``--refer-image-map-file``; otherwise ``None`` is passed.
          NOTE(review): reading ``refer_image`` from the step-3 character data
          is not done here; presumably the pipeline handles it when the map is
          ``None`` -- confirm against step5_generate_portrait.
          """
          refer_image_map = _load_refer_image_map(
              args.refer_image_map, args.refer_image_map_file
          )
          if refer_image_map is None:
              logger.info("不使用参考图片")
          return await pipeline.step5_generate_portrait(
              size="2048x2048",
              refer_image_map=refer_image_map,
              style="写实",
          )

      async def step6_func():
          return await pipeline.step6_create_camera_tree()

      async def step7_func():
          return await pipeline.step7_generate_video_frames()

      async def step8_func():
          return await pipeline.step8_generate_video()

      async def step9_func():
          return await pipeline.step9_concat_clip()

      # (name, coroutine function, dependencies). The same table drives both
      # registration and the execution order, so the two cannot drift apart.
      registrations = (
          ("step1", step1_func, None),
          ("step2", step2_func, ["step1"]),
          ("step3", step3_func, ["step1"]),
          ("step4", step4_func, ["step2", "step3"]),
          ("step5", step5_func, ["step3"]),
          ("step6", step6_func, ["step4"]),
          ("step7", step7_func, ["step5", "step6"]),
          ("step8", step8_func, ["step7"]),
          ("step9", step9_func, ["step8"]),
      )
      for step_name, step_func, dependencies in registrations:
          # step1 has no dependencies and is registered without ``depends_on``.
          extra = {} if dependencies is None else {"depends_on": dependencies}
          manager.register_step(step_name, step_func, force_rerun=False, **extra)

      # 6. Show the current state.
      summary = manager.get_summary()
      logger.info(f"总步骤数: {summary['total']}")
      logger.info(f"待执行: {summary['pending']}")

      # 7. Execute all steps through the TaskManager's async runner.
      logger.info("\n开始执行视频创作...")
      step_order = [step_name for step_name, _, _ in registrations]

      async def run_pipeline_async():
          """Run every registered step once, stopping at the first error."""
          await manager.run_all_async(
              step_order=list(step_order),  # fresh list for every attempt
              continue_on_error=False,
          )
          logger.info("\n视频创作完成!")

      # Each attempt gets its own event loop via asyncio.run(); the backoff
      # sleep happens between loops, so the blocking time.sleep() is safe.
      _run_with_retries(lambda: asyncio.run(run_pipeline_async()), args.max_retries)

      logger.info(f"\n所有结果已保存到: {run_output_dir}")
      logger.info(f"执行时间: {time.monotonic() - start_time} 秒")
  # Script entry point. Example invocation (run as a module, since this file
  # uses a relative import for the pipeline):
  # python -m examples.video_create.main --idea "时尚女装" --user_requirement "设计三个场景"
  if __name__ == "__main__":
      main()