""" 火山引擎ARK视频生成API异步客户端 封装ARK视频生成API的异步调用,提供类型安全的接口 """ import os import time import asyncio from typing import Optional, Dict, Any, Callable, Tuple from .base_client_async import AsyncAPIClient, APIError from .ark_video_client import TaskStatus # 复用同步版本的枚举 from taskflow.logger import get_logger from taskflow.config import get_config from examples.video_create.utils.tools import upload_file_to_tos, download_video logger = get_logger("api_modules.ark_video_client_async") async def handle_video_result( task_id: str, output_path: str, result: Optional[Dict], error: Optional[str] ) -> None: """ 处理视频生成结果的异步回调函数 Args: task_id: 任务ID output_path: 视频输出路径 result: 任务结果(如果成功) error: 错误信息(如果失败) """ if error: logger.info(f"\n任务 {task_id} 处理失败:{error}") else: video_url = result.get("content", {}).get("video_url") if video_url: # 使用 asyncio.to_thread 在后台线程中执行同步的下载函数 await asyncio.to_thread(download_video, video_url, output_path) logger.info(f"生成视频已下载:{output_path}") else: logger.warning(f"任务 {task_id} 完成但未获取到视频URL") class AsyncArkVideoClient(AsyncAPIClient): """ 火山引擎ARK视频生成API异步客户端 封装ARK视频生成API的异步调用,提供便捷的接口 """ DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com" DEFAULT_ENDPOINT = "/api/v3/contents/generations/tasks" DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528" def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None, timeout: int = 60, poll_interval: int = 5, max_poll_time: int = 500, **kwargs ): """ 初始化ARK视频生成API异步客户端 Args: api_key: API密钥(如果为None,会尝试从环境变量或配置中获取) base_url: API基础URL(默认使用官方URL) model: 模型名称(如果为None,会尝试从配置中获取) timeout: 请求超时时间(秒,默认60秒) poll_interval: 轮询间隔(秒,默认5秒) max_poll_time: 最大轮询总时间(秒,默认500秒) **kwargs: 传递给AsyncAPIClient的其他参数 """ # 获取API密钥(优先级:参数 > 环境变量 > 配置) if api_key is None: api_key = os.getenv("ARK_API_KEY") if api_key is None: config = get_config() api_key = config.get("api.ark.api_key") if not api_key: raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供") # 获取base_url(优先级:参数 > 配置 > 默认值) if base_url is None: config = get_config() base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL) # 获取model(优先级:参数 > 配置 > 默认值) if model is None: config = get_config() model = config.get("api.ark.video_model", self.DEFAULT_MODEL) super().__init__( base_url=base_url, api_key=api_key, timeout=timeout, **kwargs ) # 保存视频生成相关配置 self.model = model self.poll_interval = poll_interval self.max_poll_time = max_poll_time logger.info(f"ARK视频生成API异步客户端初始化完成,模型: {self.model}") async def create_video_task( self, prompt: str, image_url: str, gen_params: str = "", **kwargs ) -> Dict[str, Any]: """ 异步创建视频生成任务 Args: prompt: 视频生成提示词(必填) image_url: 参考图片URL(必填,必须是可访问的HTTP/HTTPS URL) gen_params: 额外的生成参数(可选,会追加到prompt后面) **kwargs: 其他请求参数(会覆盖默认配置) Returns: API响应数据,包含任务ID等信息 Raises: APIError: 如果请求失败 ValueError: 如果参数无效 """ if not prompt or not prompt.strip(): raise ValueError("prompt不能为空") if not image_url or not image_url.strip(): raise ValueError("image_url不能为空") # # 验证image_url是否为URL格式 # if not image_url.startswith(("http://", "https://")): # raise ValueError( # f"image_url必须是HTTP/HTTPS URL格式,当前值: {image_url}。" # "如果是本地文件路径,请先上传到云存储获取URL。" # ) image_url = upload_file_to_tos(image_url) if "http" not in image_url else image_url # 构建请求体 request_data = { "model": kwargs.get("model", self.model), "content": [ { "type": "text", "text": prompt + gen_params }, { "type": "image_url", "image_url": { "url": image_url } } ], **{k: v for k, v in kwargs.items() if k != "model"} } logger.info(f"创建异步视频生成任务,模型: {request_data['model']}, 提示词: {prompt[:50]}...") logger.info(f"参考图片: {image_url}") try: response = await self.post( endpoint=self.DEFAULT_ENDPOINT, json=request_data ) logger.info(f"视频生成任务创建成功,任务ID: {response.get('id', 'unknown')}") return response except APIError as e: logger.error(f"创建视频生成任务失败: {e}") raise async def query_video_task(self, task_id: str) -> Dict[str, Any]: """ 异步查询视频生成任务状态 Args: task_id: 任务ID(从create_video_task响应中获取) Returns: 任务状态详情,包含状态、视频URL等信息 Raises: APIError: 如果请求失败 ValueError: 如果参数无效 """ if not task_id or not task_id.strip(): raise ValueError("task_id不能为空") query_endpoint = f"{self.DEFAULT_ENDPOINT}/{task_id}" logger.debug(f"查询异步视频生成任务状态,任务ID: {task_id}") try: response = await self.get(endpoint=query_endpoint) status = response.get("status", "").lower() logger.debug(f"任务 {task_id} 状态: {status}") return response except APIError as e: logger.error(f"查询视频生成任务状态失败: {e}") raise async def wait_for_task( self, task_id: str, callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None ) -> Dict[str, Any]: """ 异步等待任务完成(异步轮询) Args: task_id: 任务ID callback: 可选的回调函数,参数为 (task_id, result, error) 注意:回调函数如果是异步的,需要使用asyncio.create_task调用 Returns: 任务完成后的结果 Raises: APIError: 如果请求失败 TimeoutError: 如果任务超时 """ start_time = time.time() while True: elapsed = time.time() - start_time if elapsed > self.max_poll_time: error_msg = f"任务超时(超过 {self.max_poll_time} 秒)" logger.error(f"任务 {task_id} {error_msg}") if callback: # 如果回调是协程函数,需要特殊处理 if asyncio.iscoroutinefunction(callback): await callback(task_id, {}, error_msg) else: callback(task_id, {}, error_msg) raise TimeoutError(error_msg) # 查询任务状态 result = await self.query_video_task(task_id) if not result: logger.warning(f"任务 {task_id} 查询结果为空,继续等待...") await asyncio.sleep(self.poll_interval) continue # 解析状态 status = result.get("status", "").lower() if status == TaskStatus.SUCCEEDED: logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒") if callback: if asyncio.iscoroutinefunction(callback): await callback(task_id, result, None) else: callback(task_id, result, None) return result elif status == TaskStatus.FAILED: error_msg = result.get("error", {}).get("message", "未知错误") logger.error(f"任务 {task_id} 失败: {error_msg}") if callback: if asyncio.iscoroutinefunction(callback): await callback(task_id, {}, error_msg) else: callback(task_id, {}, error_msg) raise APIError(f"任务失败: {error_msg}") elif status in [TaskStatus.PENDING, TaskStatus.PROCESSING]: logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}") await asyncio.sleep(self.poll_interval) else: logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...") await asyncio.sleep(self.poll_interval) async def create_and_wait( self, prompt: str, image_url: str, gen_params: str = "", callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None, **kwargs ) -> Dict[str, Any]: """ 异步创建视频生成任务并等待完成(便捷方法) Args: prompt: 视频生成提示词(必填) image_url: 参考图片URL(必填) gen_params: 额外的生成参数(可选) callback: 可选的回调函数,参数为 (task_id, result, error) **kwargs: 其他请求参数 Returns: 任务完成后的结果 Raises: APIError: 如果请求失败 TimeoutError: 如果任务超时 """ # 创建任务 task_response = await self.create_video_task( prompt=prompt, image_url=image_url, gen_params=gen_params, **kwargs ) task_id = task_response.get("id") if not task_id: raise APIError("创建任务成功但未返回任务ID") logger.info(f"任务已创建,任务ID: {task_id},开始等待完成...") # 等待任务完成 return await self.wait_for_task(task_id, callback=callback) async def create_video_task_async( self, prompt: str, image_url: str, gen_params: str = "", callback: Optional[Callable] = handle_video_result, output_path: Optional[str] = None, **kwargs ) -> Tuple[Optional[str], Optional[asyncio.Task]]: """ 创建视频生成任务并立即返回task_id和后台任务对象(不阻塞主流程) 任务会在后台异步任务中轮询,完成后调用回调函数。 调用者可以通过返回的任务对象等待任务完成。 Args: prompt: 视频生成提示词(必填) image_url: 参考图片URL(必填) gen_params: 额外的生成参数(可选,会追加到prompt后面) callback: 可选的回调函数,可以是以下两种签名之一: 1. (task_id, result, error) -> None 2. (task_id, output_path, result, error) -> None 注意:如果是异步函数,需要使用 async def 定义 output_path: 视频输出路径(可选,会传递给回调函数) **kwargs: 其他请求参数(会覆盖默认配置) Returns: 元组 (task_id, background_task): - task_id: 任务ID,如果创建失败则返回None - background_task: 后台异步任务对象,可以用于等待任务完成 如果创建失败则返回None Raises: APIError: 如果创建任务失败 """ # 创建任务 task_response = await self.create_video_task( prompt=prompt, image_url=image_url, gen_params=gen_params, **kwargs ) task_id = task_response.get("id") if not task_id: logger.error("任务提交失败,无法启动后台轮询") return None, None logger.info(f"任务提交成功,task_id: {task_id},启动后台异步轮询...") # 定义后台异步任务包装函数 async def _background_wait(): """后台异步任务:等待任务完成并调用回调""" try: # 等待任务完成 result = await self.wait_for_task(task_id) # 调用回调函数 if callback: import inspect sig = inspect.signature(callback) param_count = len(sig.parameters) if asyncio.iscoroutinefunction(callback): # 异步回调函数 if param_count == 4: await callback(task_id, output_path or "", result, None) else: await callback(task_id, result, None) else: # 同步回调函数 if param_count == 4: # 4参数版本:(task_id, output_path, result, error) callback(task_id, output_path or "", result, None) else: # 3参数版本:(task_id, result, error) callback(task_id, result, None) except Exception as e: error_msg = str(e) logger.error(f"后台异步任务处理失败: {error_msg}") if callback: import inspect sig = inspect.signature(callback) param_count = len(sig.parameters) if asyncio.iscoroutinefunction(callback): if param_count == 4: await callback(task_id, output_path or "", {}, error_msg) else: await callback(task_id, {}, error_msg) else: if param_count == 4: callback(task_id, output_path or "", {}, error_msg) else: callback(task_id, {}, error_msg) # 启动后台异步任务并返回任务对象,以便调用者可以等待 background_task = asyncio.create_task(_background_wait()) # 返回任务ID和任务对象(使用元组) return task_id, background_task def get_video_url(self, result: Dict[str, Any]) -> Optional[str]: """ 从任务结果中提取视频URL Args: result: 任务结果(从query_video_task或wait_for_task返回) Returns: 视频URL,如果不存在则返回None """ try: content = result.get("content", {}) if isinstance(content, dict): return content.get("video_url") return None except (KeyError, TypeError, AttributeError) as e: logger.warning(f"提取视频URL失败: {e}") return None def get_task_status(self, result: Dict[str, Any]) -> Optional[str]: """ 从任务结果中提取任务状态 Args: result: 任务结果 Returns: 任务状态字符串,如果不存在则返回None """ try: return result.get("status", "").lower() except (KeyError, TypeError, AttributeError): return None async def main(): # 示例用法 async with AsyncArkVideoClient() as client: # 方式1:创建任务并等待完成(阻塞) try: result = await client.create_and_wait( prompt="图中的女生在街道上散步", image_url="https://example.com/image.jpg", # 必须是可访问的URL gen_params=" --dur 4" ) video_url = client.get_video_url(result) print(f"视频生成成功,URL: {video_url}") except Exception as e: print(f"视频生成失败: {e}") # 方式2:异步创建任务(不阻塞,后台处理) # task_id = await client.create_video_task_async( # prompt="图中的女生在街道上散步", # image_url="https://example.com/image.jpg", # gen_params=" --dur 4", # callback=handle_video_result, # output_path="./output/video.mp4" # ) # print(f"任务已提交,task_id: {task_id},主流程继续执行...") # # 主流程可以继续执行其他操作 # await asyncio.sleep(10) # 等待一段时间 # 方式3:创建任务后手动查询 # task_response = await client.create_video_task( # prompt="图中的女生在街道上散步", # image_url="https://example.com/image.jpg", # gen_params=" --dur 4" # ) # task_id = task_response.get("id") # result = await client.wait_for_task(task_id) # video_url = client.get_video_url(result) # print(f"视频URL: {video_url}") if __name__ == "__main__": asyncio.run(main())