ark_video_client_async.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. """
  2. 火山引擎ARK视频生成API异步客户端
  3. 封装ARK视频生成API的异步调用,提供类型安全的接口
  4. """
  5. import os
  6. import time
  7. import asyncio
  8. from typing import Optional, Dict, Any, Callable, Tuple
  9. from .base_client_async import AsyncAPIClient, APIError
  10. from .ark_video_client import TaskStatus # 复用同步版本的枚举
  11. from taskflow.logger import get_logger
  12. from taskflow.config import get_config
  13. from examples.video_create.utils.tools import upload_file_to_tos, download_video
  14. logger = get_logger("api_modules.ark_video_client_async")
  15. async def handle_video_result(
  16. task_id: str,
  17. output_path: str,
  18. result: Optional[Dict],
  19. error: Optional[str]
  20. ) -> None:
  21. """
  22. 处理视频生成结果的异步回调函数
  23. Args:
  24. task_id: 任务ID
  25. output_path: 视频输出路径
  26. result: 任务结果(如果成功)
  27. error: 错误信息(如果失败)
  28. """
  29. if error:
  30. logger.info(f"\n任务 {task_id} 处理失败:{error}")
  31. else:
  32. video_url = result.get("content", {}).get("video_url")
  33. if video_url:
  34. # 使用 asyncio.to_thread 在后台线程中执行同步的下载函数
  35. await asyncio.to_thread(download_video, video_url, output_path)
  36. logger.info(f"生成视频已下载:{output_path}")
  37. else:
  38. logger.warning(f"任务 {task_id} 完成但未获取到视频URL")
  39. class AsyncArkVideoClient(AsyncAPIClient):
  40. """
  41. 火山引擎ARK视频生成API异步客户端
  42. 封装ARK视频生成API的异步调用,提供便捷的接口
  43. """
  44. DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
  45. DEFAULT_ENDPOINT = "/api/v3/contents/generations/tasks"
  46. DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528"
  47. def __init__(
  48. self,
  49. api_key: Optional[str] = None,
  50. base_url: Optional[str] = None,
  51. model: Optional[str] = None,
  52. timeout: int = 60,
  53. poll_interval: int = 5,
  54. max_poll_time: int = 500,
  55. **kwargs
  56. ):
  57. """
  58. 初始化ARK视频生成API异步客户端
  59. Args:
  60. api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
  61. base_url: API基础URL(默认使用官方URL)
  62. model: 模型名称(如果为None,会尝试从配置中获取)
  63. timeout: 请求超时时间(秒,默认60秒)
  64. poll_interval: 轮询间隔(秒,默认5秒)
  65. max_poll_time: 最大轮询总时间(秒,默认500秒)
  66. **kwargs: 传递给AsyncAPIClient的其他参数
  67. """
  68. # 获取API密钥(优先级:参数 > 环境变量 > 配置)
  69. if api_key is None:
  70. api_key = os.getenv("ARK_API_KEY")
  71. if api_key is None:
  72. config = get_config()
  73. api_key = config.get("api.ark.api_key")
  74. if not api_key:
  75. raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
  76. # 获取base_url(优先级:参数 > 配置 > 默认值)
  77. if base_url is None:
  78. config = get_config()
  79. base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
  80. # 获取model(优先级:参数 > 配置 > 默认值)
  81. if model is None:
  82. config = get_config()
  83. model = config.get("api.ark.video_model", self.DEFAULT_MODEL)
  84. super().__init__(
  85. base_url=base_url,
  86. api_key=api_key,
  87. timeout=timeout,
  88. **kwargs
  89. )
  90. # 保存视频生成相关配置
  91. self.model = model
  92. self.poll_interval = poll_interval
  93. self.max_poll_time = max_poll_time
  94. logger.info(f"ARK视频生成API异步客户端初始化完成,模型: {self.model}")
  95. async def create_video_task(
  96. self,
  97. prompt: str,
  98. image_url: str,
  99. gen_params: str = "",
  100. **kwargs
  101. ) -> Dict[str, Any]:
  102. """
  103. 异步创建视频生成任务
  104. Args:
  105. prompt: 视频生成提示词(必填)
  106. image_url: 参考图片URL(必填,必须是可访问的HTTP/HTTPS URL)
  107. gen_params: 额外的生成参数(可选,会追加到prompt后面)
  108. **kwargs: 其他请求参数(会覆盖默认配置)
  109. Returns:
  110. API响应数据,包含任务ID等信息
  111. Raises:
  112. APIError: 如果请求失败
  113. ValueError: 如果参数无效
  114. """
  115. if not prompt or not prompt.strip():
  116. raise ValueError("prompt不能为空")
  117. if not image_url or not image_url.strip():
  118. raise ValueError("image_url不能为空")
  119. # # 验证image_url是否为URL格式
  120. # if not image_url.startswith(("http://", "https://")):
  121. # raise ValueError(
  122. # f"image_url必须是HTTP/HTTPS URL格式,当前值: {image_url}。"
  123. # "如果是本地文件路径,请先上传到云存储获取URL。"
  124. # )
  125. image_url = upload_file_to_tos(image_url) if "http" not in image_url else image_url
  126. # 构建请求体
  127. request_data = {
  128. "model": kwargs.get("model", self.model),
  129. "content": [
  130. {
  131. "type": "text",
  132. "text": prompt + gen_params
  133. },
  134. {
  135. "type": "image_url",
  136. "image_url": {
  137. "url": image_url
  138. }
  139. }
  140. ],
  141. **{k: v for k, v in kwargs.items() if k != "model"}
  142. }
  143. logger.info(f"创建异步视频生成任务,模型: {request_data['model']}, 提示词: {prompt[:50]}...")
  144. logger.info(f"参考图片: {image_url}")
  145. try:
  146. response = await self.post(
  147. endpoint=self.DEFAULT_ENDPOINT,
  148. json=request_data
  149. )
  150. logger.info(f"视频生成任务创建成功,任务ID: {response.get('id', 'unknown')}")
  151. return response
  152. except APIError as e:
  153. logger.error(f"创建视频生成任务失败: {e}")
  154. raise
  155. async def query_video_task(self, task_id: str) -> Dict[str, Any]:
  156. """
  157. 异步查询视频生成任务状态
  158. Args:
  159. task_id: 任务ID(从create_video_task响应中获取)
  160. Returns:
  161. 任务状态详情,包含状态、视频URL等信息
  162. Raises:
  163. APIError: 如果请求失败
  164. ValueError: 如果参数无效
  165. """
  166. if not task_id or not task_id.strip():
  167. raise ValueError("task_id不能为空")
  168. query_endpoint = f"{self.DEFAULT_ENDPOINT}/{task_id}"
  169. logger.debug(f"查询异步视频生成任务状态,任务ID: {task_id}")
  170. try:
  171. response = await self.get(endpoint=query_endpoint)
  172. status = response.get("status", "").lower()
  173. logger.debug(f"任务 {task_id} 状态: {status}")
  174. return response
  175. except APIError as e:
  176. logger.error(f"查询视频生成任务状态失败: {e}")
  177. raise
  178. async def wait_for_task(
  179. self,
  180. task_id: str,
  181. callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None
  182. ) -> Dict[str, Any]:
  183. """
  184. 异步等待任务完成(异步轮询)
  185. Args:
  186. task_id: 任务ID
  187. callback: 可选的回调函数,参数为 (task_id, result, error)
  188. 注意:回调函数如果是异步的,需要使用asyncio.create_task调用
  189. Returns:
  190. 任务完成后的结果
  191. Raises:
  192. APIError: 如果请求失败
  193. TimeoutError: 如果任务超时
  194. """
  195. start_time = time.time()
  196. while True:
  197. elapsed = time.time() - start_time
  198. if elapsed > self.max_poll_time:
  199. error_msg = f"任务超时(超过 {self.max_poll_time} 秒)"
  200. logger.error(f"任务 {task_id} {error_msg}")
  201. if callback:
  202. # 如果回调是协程函数,需要特殊处理
  203. if asyncio.iscoroutinefunction(callback):
  204. await callback(task_id, {}, error_msg)
  205. else:
  206. callback(task_id, {}, error_msg)
  207. raise TimeoutError(error_msg)
  208. # 查询任务状态
  209. result = await self.query_video_task(task_id)
  210. if not result:
  211. logger.warning(f"任务 {task_id} 查询结果为空,继续等待...")
  212. await asyncio.sleep(self.poll_interval)
  213. continue
  214. # 解析状态
  215. status = result.get("status", "").lower()
  216. if status == TaskStatus.SUCCEEDED:
  217. logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒")
  218. if callback:
  219. if asyncio.iscoroutinefunction(callback):
  220. await callback(task_id, result, None)
  221. else:
  222. callback(task_id, result, None)
  223. return result
  224. elif status == TaskStatus.FAILED:
  225. error_msg = result.get("error", {}).get("message", "未知错误")
  226. logger.error(f"任务 {task_id} 失败: {error_msg}")
  227. if callback:
  228. if asyncio.iscoroutinefunction(callback):
  229. await callback(task_id, {}, error_msg)
  230. else:
  231. callback(task_id, {}, error_msg)
  232. raise APIError(f"任务失败: {error_msg}")
  233. elif status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
  234. logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}")
  235. await asyncio.sleep(self.poll_interval)
  236. else:
  237. logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...")
  238. await asyncio.sleep(self.poll_interval)
  239. async def create_and_wait(
  240. self,
  241. prompt: str,
  242. image_url: str,
  243. gen_params: str = "",
  244. callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None,
  245. **kwargs
  246. ) -> Dict[str, Any]:
  247. """
  248. 异步创建视频生成任务并等待完成(便捷方法)
  249. Args:
  250. prompt: 视频生成提示词(必填)
  251. image_url: 参考图片URL(必填)
  252. gen_params: 额外的生成参数(可选)
  253. callback: 可选的回调函数,参数为 (task_id, result, error)
  254. **kwargs: 其他请求参数
  255. Returns:
  256. 任务完成后的结果
  257. Raises:
  258. APIError: 如果请求失败
  259. TimeoutError: 如果任务超时
  260. """
  261. # 创建任务
  262. task_response = await self.create_video_task(
  263. prompt=prompt,
  264. image_url=image_url,
  265. gen_params=gen_params,
  266. **kwargs
  267. )
  268. task_id = task_response.get("id")
  269. if not task_id:
  270. raise APIError("创建任务成功但未返回任务ID")
  271. logger.info(f"任务已创建,任务ID: {task_id},开始等待完成...")
  272. # 等待任务完成
  273. return await self.wait_for_task(task_id, callback=callback)
  274. async def create_video_task_async(
  275. self,
  276. prompt: str,
  277. image_url: str,
  278. gen_params: str = "",
  279. callback: Optional[Callable] = handle_video_result,
  280. output_path: Optional[str] = None,
  281. **kwargs
  282. ) -> Tuple[Optional[str], Optional[asyncio.Task]]:
  283. """
  284. 创建视频生成任务并立即返回task_id和后台任务对象(不阻塞主流程)
  285. 任务会在后台异步任务中轮询,完成后调用回调函数。
  286. 调用者可以通过返回的任务对象等待任务完成。
  287. Args:
  288. prompt: 视频生成提示词(必填)
  289. image_url: 参考图片URL(必填)
  290. gen_params: 额外的生成参数(可选,会追加到prompt后面)
  291. callback: 可选的回调函数,可以是以下两种签名之一:
  292. 1. (task_id, result, error) -> None
  293. 2. (task_id, output_path, result, error) -> None
  294. 注意:如果是异步函数,需要使用 async def 定义
  295. output_path: 视频输出路径(可选,会传递给回调函数)
  296. **kwargs: 其他请求参数(会覆盖默认配置)
  297. Returns:
  298. 元组 (task_id, background_task):
  299. - task_id: 任务ID,如果创建失败则返回None
  300. - background_task: 后台异步任务对象,可以用于等待任务完成
  301. 如果创建失败则返回None
  302. Raises:
  303. APIError: 如果创建任务失败
  304. """
  305. # 创建任务
  306. task_response = await self.create_video_task(
  307. prompt=prompt,
  308. image_url=image_url,
  309. gen_params=gen_params,
  310. **kwargs
  311. )
  312. task_id = task_response.get("id")
  313. if not task_id:
  314. logger.error("任务提交失败,无法启动后台轮询")
  315. return None, None
  316. logger.info(f"任务提交成功,task_id: {task_id},启动后台异步轮询...")
  317. # 定义后台异步任务包装函数
  318. async def _background_wait():
  319. """后台异步任务:等待任务完成并调用回调"""
  320. try:
  321. # 等待任务完成
  322. result = await self.wait_for_task(task_id)
  323. # 调用回调函数
  324. if callback:
  325. import inspect
  326. sig = inspect.signature(callback)
  327. param_count = len(sig.parameters)
  328. if asyncio.iscoroutinefunction(callback):
  329. # 异步回调函数
  330. if param_count == 4:
  331. await callback(task_id, output_path or "", result, None)
  332. else:
  333. await callback(task_id, result, None)
  334. else:
  335. # 同步回调函数
  336. if param_count == 4:
  337. # 4参数版本:(task_id, output_path, result, error)
  338. callback(task_id, output_path or "", result, None)
  339. else:
  340. # 3参数版本:(task_id, result, error)
  341. callback(task_id, result, None)
  342. except Exception as e:
  343. error_msg = str(e)
  344. logger.error(f"后台异步任务处理失败: {error_msg}")
  345. if callback:
  346. import inspect
  347. sig = inspect.signature(callback)
  348. param_count = len(sig.parameters)
  349. if asyncio.iscoroutinefunction(callback):
  350. if param_count == 4:
  351. await callback(task_id, output_path or "", {}, error_msg)
  352. else:
  353. await callback(task_id, {}, error_msg)
  354. else:
  355. if param_count == 4:
  356. callback(task_id, output_path or "", {}, error_msg)
  357. else:
  358. callback(task_id, {}, error_msg)
  359. # 启动后台异步任务并返回任务对象,以便调用者可以等待
  360. background_task = asyncio.create_task(_background_wait())
  361. # 返回任务ID和任务对象(使用元组)
  362. return task_id, background_task
  363. def get_video_url(self, result: Dict[str, Any]) -> Optional[str]:
  364. """
  365. 从任务结果中提取视频URL
  366. Args:
  367. result: 任务结果(从query_video_task或wait_for_task返回)
  368. Returns:
  369. 视频URL,如果不存在则返回None
  370. """
  371. try:
  372. content = result.get("content", {})
  373. if isinstance(content, dict):
  374. return content.get("video_url")
  375. return None
  376. except (KeyError, TypeError, AttributeError) as e:
  377. logger.warning(f"提取视频URL失败: {e}")
  378. return None
  379. def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
  380. """
  381. 从任务结果中提取任务状态
  382. Args:
  383. result: 任务结果
  384. Returns:
  385. 任务状态字符串,如果不存在则返回None
  386. """
  387. try:
  388. return result.get("status", "").lower()
  389. except (KeyError, TypeError, AttributeError):
  390. return None
  391. async def main():
  392. # 示例用法
  393. async with AsyncArkVideoClient() as client:
  394. # 方式1:创建任务并等待完成(阻塞)
  395. try:
  396. result = await client.create_and_wait(
  397. prompt="图中的女生在街道上散步",
  398. image_url="https://example.com/image.jpg", # 必须是可访问的URL
  399. gen_params=" --dur 4"
  400. )
  401. video_url = client.get_video_url(result)
  402. print(f"视频生成成功,URL: {video_url}")
  403. except Exception as e:
  404. print(f"视频生成失败: {e}")
  405. # 方式2:异步创建任务(不阻塞,后台处理)
  406. # task_id = await client.create_video_task_async(
  407. # prompt="图中的女生在街道上散步",
  408. # image_url="https://example.com/image.jpg",
  409. # gen_params=" --dur 4",
  410. # callback=handle_video_result,
  411. # output_path="./output/video.mp4"
  412. # )
  413. # print(f"任务已提交,task_id: {task_id},主流程继续执行...")
  414. # # 主流程可以继续执行其他操作
  415. # await asyncio.sleep(10) # 等待一段时间
  416. # 方式3:创建任务后手动查询
  417. # task_response = await client.create_video_task(
  418. # prompt="图中的女生在街道上散步",
  419. # image_url="https://example.com/image.jpg",
  420. # gen_params=" --dur 4"
  421. # )
  422. # task_id = task_response.get("id")
  423. # result = await client.wait_for_task(task_id)
  424. # video_url = client.get_video_url(result)
  425. # print(f"视频URL: {video_url}")
  426. if __name__ == "__main__":
  427. asyncio.run(main())