# refer_video_create_pipeline.py
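"""
refer_video_create task flow.

End-to-end task flow that creates a new video from a reference video:

1. Create a script from the reference video.
2. Optimize the script prompts (image-generation and video-generation prompts).
3. Generate storyboard images from each image_prompt.
4. Generate video clips from each video_prompt and its storyboard image.
5. Concatenate all video clips.
"""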
import asyncio
import os
from pathlib import Path
from typing import Dict, Optional, Union

from api_modules.ark_client_async import AsyncArkClient
from api_modules.ark_image_client_async import AsyncArkImageClient
from api_modules.ark_video_client_async import AsyncArkVideoClient
from examples.video_create.mcps.concat_clip import concat_videos
from examples.video_create.utils.tools import string_to_json, download_image, upload_file_to_tos
from taskflow import TaskManager, FileIOHandler
from taskflow import get_logger

from ..mcps.script_create import create_script_refer_video
from ..mcps.script_optimate import optimate_script
from ..mcps.script_check import check_image_script, check_video_script
# Module-level logger used by every step of the pipeline below.
logger = get_logger("examples.refer_video_create.pipeline.refer_video_create_pipeline")
  26. class ReferVideoCreatePipeline:
  27. """基于参考视频的视频创作任务流"""
  28. def __init__(self, io_handler: FileIOHandler, output_dir: str, manager: TaskManager):
  29. """
  30. 初始化视频创作任务流
  31. Args:
  32. io_handler: 文件I/O处理器
  33. output_dir: 输出目录
  34. manager: 任务管理器
  35. """
  36. self.io_handler = io_handler
  37. self.output_dir = Path(output_dir)
  38. self.output_dir.mkdir(parents=True, exist_ok=True)
  39. self.manager = manager
  40. async def step1_create_script(self, video_url: str, user_prompt: Optional[str] = None) -> Dict:
  41. """步骤1:从参考视频创建初始脚本"""
  42. if user_prompt is None:
  43. user_prompt = "请开始执行你的任务"
  44. async with AsyncArkClient() as client:
  45. script_text = await create_script_refer_video(
  46. client=client,
  47. video_url=video_url,
  48. user_prompt=user_prompt
  49. )
  50. # 解析JSON字符串
  51. script = self.io_handler.string_to_json(script_text)
  52. output_file = str(self.output_dir / "step1_script.json")
  53. await self.io_handler.write_json_async(script, output_file)
  54. return {
  55. "output_file": output_file,
  56. "data": script
  57. }
  58. async def step2_optimize_prompts(self) -> Dict:
  59. """步骤2:优化脚本提示词(并行优化生图提示词和生视频提示词)"""
  60. previous_output = self.manager.load_step_output("step1")
  61. if previous_output is None:
  62. raise ValueError("步骤1未完成,无法优化提示词")
  63. script = previous_output["data"]
  64. async def optimize_single_lens(client: AsyncArkClient, lens: Dict) -> None:
  65. """优化单个镜头的提示词"""
  66. lens_id = lens.get("lens_id")
  67. lens_params = lens.get("lens_params", "")
  68. core_vision = lens.get("core_vision", "")
  69. # 构建优化提示词:lens_params + core_vision
  70. prompt_text = f"{lens_params} {core_vision}".strip()
  71. # 并行优化生图提示词和生视频提示词
  72. async def optimize_image_prompt():
  73. optimized = await optimate_script(
  74. client=client,
  75. user_prompt=prompt_text,
  76. prompt_type="image"
  77. )
  78. lens["image_prompt"] = optimized.strip()
  79. logger.info(f"镜头 {lens_id} 的生图提示词优化完成")
  80. async def optimize_video_prompt():
  81. optimized = await optimate_script(
  82. client=client,
  83. user_prompt=prompt_text,
  84. prompt_type="video"
  85. )
  86. lens["video_prompt"] = optimized.strip()
  87. logger.info(f"镜头 {lens_id} 的生视频提示词优化完成")
  88. # 并行执行两个优化任务
  89. await asyncio.gather(optimize_image_prompt(), optimize_video_prompt())
  90. async with AsyncArkClient() as client:
  91. # 并行处理所有镜头
  92. tasks = [
  93. optimize_single_lens(client, lens)
  94. for lens in script["lens_details"]
  95. ]
  96. await asyncio.gather(*tasks)
  97. output_file = str(self.output_dir / "step2_optimized_script.json")
  98. await self.io_handler.write_json_async(script, output_file)
  99. return {
  100. "output_file": output_file,
  101. "data": script
  102. }
  103. async def step3_generate_storyboard(
  104. self,
  105. size: Optional[str] = "1440x2560",
  106. refer_image: Optional[list[str]] = None
  107. ) -> Dict:
  108. """
  109. 步骤3:基于image_prompt和用户指定的参考图片生成分镜图片
  110. Args:
  111. size: 生成图片的尺寸,默认为 "1440x2560"
  112. refer_image: 参考图片路径(可选),所有分镜都会参考这张图片生成
  113. 如果为None,则不使用参考图片
  114. """
  115. previous_output = self.manager.load_step_output("step2")
  116. if previous_output is None:
  117. raise ValueError("步骤2未完成,无法生成分镜")
  118. script = previous_output["data"]
  119. # 确保storyboard目录存在
  120. storyboard_dir = self.output_dir / "storyboard"
  121. storyboard_dir.mkdir(parents=True, exist_ok=True)
  122. # 准备参考图片列表(如果提供了参考图片)
  123. reference_images = None
  124. if refer_image is not None and isinstance(refer_image, str):
  125. # 确保是列表格式
  126. reference_images = [refer_image]
  127. logger.info(f"所有分镜将参考图片: {refer_image}")
  128. elif refer_image is not None and isinstance(refer_image, list):
  129. reference_images = refer_image
  130. logger.info(f"所有分镜将参考图片: {refer_image}")
  131. else:
  132. logger.info("不使用参考图片")
  133. async def generate_single_storyboard(
  134. image_client: AsyncArkImageClient,
  135. ark_client: AsyncArkClient,
  136. lens: Dict
  137. ) -> None:
  138. """生成单个镜头的分镜图片(带审查和重试机制)"""
  139. lens_id = lens.get("lens_id")
  140. image_prompt = lens.get("image_prompt")
  141. if not image_prompt:
  142. raise ValueError(f"镜头 {lens_id} 缺少 image_prompt 字段")
  143. image_save_path = str(storyboard_dir / f"lens_{lens_id}_{self.output_dir.name}.png")
  144. # 如果文件已存在,跳过生成
  145. if os.path.exists(image_save_path):
  146. logger.info(f"分镜图片已存在,跳过生成: lens {lens_id}")
  147. lens["storyboard_url"] = image_save_path
  148. return
  149. # 最大重试次数
  150. max_retries = 5
  151. attempt_count = 0
  152. review_passed = False
  153. temp_image_path = str(storyboard_dir / f"lens_{lens_id}_{self.output_dir.name}_temp.png")
  154. last_image_url = None
  155. while attempt_count <= max_retries and not review_passed:
  156. try:
  157. attempt_count += 1
  158. logger.info(f"镜头 {lens_id} 开始第 {attempt_count} 次生成...")
  159. # 生成分镜图片(使用参考图片)
  160. response = await image_client.generate_image(
  161. prompt=image_prompt,
  162. size=size,
  163. reference_image=reference_images
  164. )
  165. image_url = image_client.get_image_url(response)
  166. if not image_url:
  167. raise ValueError(f"镜头 {lens_id} 生成分镜图片失败,未获取到图片URL")
  168. last_image_url = image_url
  169. # 下载图片到临时路径(用于审查)
  170. await asyncio.to_thread(download_image, image_url, temp_image_path)
  171. # 上传图片到TOS获取URL(用于审查)
  172. image_url_for_check = await asyncio.to_thread(upload_file_to_tos, temp_image_path)
  173. logger.info(f"镜头 {lens_id} 第 {attempt_count} 次生成完成,图片已上传: {image_url_for_check}")
  174. # 调用审查函数
  175. check_result_text = await check_image_script(
  176. client=ark_client,
  177. image_prompt=image_prompt,
  178. image_url=image_url_for_check
  179. )
  180. # 解析审查结果
  181. check_result = self.io_handler.string_to_json(check_result_text)
  182. review_result = check_result.get("review_result", False)
  183. result_reason = check_result.get("result_reason", "")
  184. if review_result:
  185. # 审查不通过(review_result为true表示有问题,需要重新生成)
  186. if attempt_count > max_retries:
  187. # 超过最大重试次数,使用最后一次生成的图片
  188. logger.error(
  189. f"镜头 {lens_id} 已达到最大重试次数 ({max_retries}),"
  190. f"审查结果: {result_reason},使用最后一次生成的图片"
  191. )
  192. # 将临时文件重命名为最终文件
  193. if os.path.exists(temp_image_path):
  194. os.rename(temp_image_path, image_save_path)
  195. else:
  196. # 如果临时文件不存在,重新下载
  197. await asyncio.to_thread(download_image, last_image_url, image_save_path)
  198. review_passed = True
  199. else:
  200. # 继续重试
  201. logger.warning(
  202. f"镜头 {lens_id} 第 {attempt_count} 次生成审查不通过: {result_reason},"
  203. f"将重新生成(剩余重试次数: {max_retries - attempt_count})"
  204. )
  205. # 保留临时文件,下次生成时会覆盖
  206. else:
  207. # 审查通过(review_result为false表示通过)
  208. review_passed = True
  209. logger.info(
  210. f"镜头 {lens_id} 第 {attempt_count} 次生成审查通过: {result_reason}"
  211. )
  212. # 将临时文件重命名为最终文件
  213. if os.path.exists(temp_image_path):
  214. os.rename(temp_image_path, image_save_path)
  215. else:
  216. # 如果临时文件不存在,重新下载
  217. await asyncio.to_thread(download_image, last_image_url, image_save_path)
  218. except Exception as e:
  219. logger.error(f"镜头 {lens_id} 第 {attempt_count} 次生成出错: {e}")
  220. if attempt_count > max_retries:
  221. logger.error(f"镜头 {lens_id} 已达到最大重试次数,停止重试")
  222. # 如果还有临时文件,尝试使用它
  223. if os.path.exists(temp_image_path):
  224. os.rename(temp_image_path, image_save_path)
  225. raise
  226. # 继续重试
  227. continue
  228. lens["storyboard_url"] = image_save_path
  229. logger.info(f"镜头 {lens_id} 分镜图片最终生成完成: {image_save_path}(共尝试 {attempt_count} 次)")
  230. async with AsyncArkImageClient() as image_client, AsyncArkClient() as ark_client:
  231. # 并行处理所有镜头
  232. tasks = [
  233. generate_single_storyboard(image_client, ark_client, lens)
  234. for lens in script["lens_details"]
  235. ]
  236. await asyncio.gather(*tasks, return_exceptions=True)
  237. output_file = str(self.output_dir / "step3_storyboard.json")
  238. await self.io_handler.write_json_async(script, output_file)
  239. return {
  240. "output_file": output_file,
  241. "data": script
  242. }
  243. async def step4_generate_video_clips(self) -> Dict:
  244. """步骤4:基于video_prompt和分镜图片生成视频片段"""
  245. previous_output = self.manager.load_step_output("step3")
  246. if previous_output is None:
  247. raise ValueError("步骤3未完成,无法生成视频片段")
  248. script = previous_output["data"]
  249. # 确保video_clips目录存在
  250. video_clip_dir = self.output_dir / "video_clips"
  251. video_clip_dir.mkdir(parents=True, exist_ok=True)
  252. async def process_single_lens(video_client: AsyncArkVideoClient, lens: Dict) -> Optional[asyncio.Task]:
  253. """处理单个镜头的视频生成,返回后台任务"""
  254. lens_id = lens.get("lens_id")
  255. video_prompt = lens.get("video_prompt")
  256. storyboard_url = lens.get("storyboard_url")
  257. lens_duration = lens.get("lens_duration", 4)
  258. if not video_prompt:
  259. raise ValueError(f"镜头 {lens_id} 缺少 video_prompt 字段")
  260. if not storyboard_url:
  261. raise ValueError(f"镜头 {lens_id} 缺少 storyboard_url 字段")
  262. video_save_path = str(video_clip_dir / f"lens_{lens_id}_{self.output_dir.name}.mp4")
  263. # 如果文件已存在,跳过生成
  264. if os.path.exists(video_save_path):
  265. logger.info(f"视频片段已存在,跳过生成: lens {lens_id}")
  266. lens["clip_url"] = video_save_path
  267. return None
  268. # 如果storyboard_url是本地路径,需要上传到TOS获取URL
  269. image_url = storyboard_url
  270. if not storyboard_url.startswith(("http://", "https://")):
  271. # 上传到TOS获取URL
  272. from examples.video_create.utils.tools import upload_file_to_tos
  273. image_url = await asyncio.to_thread(upload_file_to_tos, storyboard_url)
  274. logger.info(f"镜头 {lens_id} 分镜图片已上传到TOS: {image_url}")
  275. # 构建生成参数字符串
  276. gen_params = f" --dur {lens_duration}"
  277. # 创建视频生成任务
  278. task_id, background_task = await video_client.create_video_task_async(
  279. prompt=video_prompt,
  280. image_url=image_url,
  281. gen_params=gen_params,
  282. output_path=video_save_path
  283. )
  284. if background_task is not None:
  285. lens["clip_url"] = video_save_path
  286. logger.info(f"已提交视频生成任务,lens {lens_id}, task_id: {task_id}")
  287. return background_task
  288. else:
  289. logger.error(f"视频生成任务提交失败,lens {lens_id}")
  290. return None
  291. async with AsyncArkVideoClient() as video_client:
  292. # 并行提交所有视频生成任务,收集后台任务
  293. lens_tasks = [
  294. process_single_lens(video_client, lens)
  295. for lens in script["lens_details"]
  296. ]
  297. lens_results = await asyncio.gather(*lens_tasks, return_exceptions=True)
  298. # 展平所有后台任务
  299. all_background_tasks = []
  300. for task in lens_results:
  301. if isinstance(task, Exception):
  302. logger.error(f"处理镜头时出错: {task}")
  303. elif task is not None:
  304. all_background_tasks.append(task)
  305. # 等待所有视频生成和下载完成
  306. if all_background_tasks:
  307. logger.info(f"等待 {len(all_background_tasks)} 个视频生成任务完成...")
  308. await asyncio.gather(*all_background_tasks, return_exceptions=True)
  309. logger.info("所有视频生成任务已完成!")
  310. else:
  311. logger.warning("没有提交任何视频生成任务")
  312. output_file = str(self.output_dir / "step4_video_clips.json")
  313. await self.io_handler.write_json_async(script, output_file)
  314. return {
  315. "output_file": output_file,
  316. "data": script
  317. }
  318. async def step5_concat_clips(self) -> Dict:
  319. """步骤5:拼接所有视频片段"""
  320. previous_output = self.manager.load_step_output("step4")
  321. if previous_output is None:
  322. raise ValueError("步骤4未完成,无法进行视频拼接")
  323. script = previous_output["data"]
  324. # 确保video_save目录存在
  325. video_save_dir = self.output_dir / "video_save"
  326. video_save_dir.mkdir(parents=True, exist_ok=True)
  327. # 收集所有视频片段路径(按lens_id排序)
  328. clips_path = []
  329. for lens in sorted(script["lens_details"], key=lambda x: x.get("lens_id", 0)):
  330. clip_url = lens.get("clip_url")
  331. if clip_url and os.path.exists(clip_url):
  332. clips_path.append(clip_url)
  333. else:
  334. logger.warning(f"镜头 {lens.get('lens_id')} 的视频片段不存在,跳过")
  335. if not clips_path:
  336. raise ValueError("没有可用的视频片段进行拼接")
  337. output_file = str(video_save_dir / "final_video.mp4")
  338. # 拼接视频(使用线程池执行同步操作)
  339. await asyncio.to_thread(concat_videos, clips_path, output_file)
  340. logger.info(f"视频拼接完成: {output_file}")
  341. return {
  342. "output_file": output_file,
  343. "data": "final video"
  344. }