"""
Volcengine ARK video-generation API client.

Wraps the ARK content-generation task endpoints (create a video task, query
its status, poll until it finishes) behind a small, typed interface.
"""
import inspect
import os
import threading
import time
from enum import Enum
from typing import Any, Callable, Dict, Optional, Tuple

from examples.video_create.utils.tools import download_video
from .base_client import APIClient, APIError
from taskflow.logger import get_logger
from taskflow.config import get_config

logger = get_logger("api_modules.ark_video_client")


class TaskStatus(str, Enum):
    """Task status strings reported by the ARK task-query endpoint."""

    PENDING = "pending"
    PROCESSING = "processing"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    # Status values listed in the ARK video-generation API reference. With
    # only the four members above, live "queued"/"running" tasks were logged
    # as unknown and a terminal "cancelled" task was polled until timeout.
    QUEUED = "queued"
    RUNNING = "running"
    CANCELLED = "cancelled"


# Plain-string sets so membership tests do not depend on str/Enum hashing.
_IN_PROGRESS_STATUSES = frozenset(
    s.value
    for s in (
        TaskStatus.PENDING,
        TaskStatus.PROCESSING,
        TaskStatus.QUEUED,
        TaskStatus.RUNNING,
    )
)
_FAILED_STATUSES = frozenset(
    s.value for s in (TaskStatus.FAILED, TaskStatus.CANCELLED)
)


def handle_video_result(
    task_id: str, output_path: str, result: Optional[Dict], error: Optional[str]
) -> None:
    """
    Default completion callback: download the generated video on success.

    Args:
        task_id: ID of the finished task.
        output_path: Local file path to save the video to.
        result: Task query response (used only when ``error`` is falsy).
        error: Failure message, or None/empty on success.
    """
    if error:
        logger.error(f"\n任务 {task_id} 处理失败:{error}")
        return

    content = (result or {}).get("content") or {}
    video_url = content.get("video_url") if isinstance(content, dict) else None
    if not video_url:
        # Avoid calling download_video(None, ...) on a malformed result.
        logger.warning(f"任务 {task_id} 结果中缺少video_url,无法下载")
        return
    if not output_path:
        # create_video_task_async uses this callback by default even when no
        # output_path was given; don't try to write to an empty path.
        logger.warning(f"任务 {task_id} 未指定output_path,跳过下载: {video_url}")
        return

    download_video(video_url, output_path)
    logger.info(f"生成视频已下载:{output_path}")


class ArkVideoClient(APIClient):
    """
    Client for the Volcengine ARK video-generation API.

    Offers a blocking flow (``create_and_wait``) and a fire-and-forget flow
    (``create_video_task_async``) that polls in a daemon thread and reports
    through a callback.
    """

    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
    DEFAULT_ENDPOINT = "/api/v3/contents/generations/tasks"
    DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: int = 60,
        poll_interval: int = 5,
        max_poll_time: int = 500,
        **kwargs
    ):
        """
        Initialize the ARK video-generation client.

        Args:
            api_key: API key. If None, read from the ARK_API_KEY environment
                variable, then from config key ``api.ark.api_key``.
            base_url: API base URL. If None, read from config key
                ``api.ark.base_url``, falling back to DEFAULT_BASE_URL.
            model: Model name. If None, read from config key
                ``api.ark.video_model``, falling back to DEFAULT_MODEL.
            timeout: Per-request timeout in seconds (default 60).
            poll_interval: Seconds between status polls (default 5).
            max_poll_time: Total polling budget in seconds (default 500).
            **kwargs: Forwarded to APIClient.

        Raises:
            ValueError: If no API key can be resolved.
        """
        # API key priority: argument > environment variable > config.
        if api_key is None:
            # An exported-but-empty variable is treated as unset so the
            # config fallback still gets a chance.
            api_key = os.getenv("ARK_API_KEY") or None
        if api_key is None:
            api_key = get_config().get("api.ark.api_key")
        if not api_key:
            raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")

        # base_url / model priority: argument > config > class default.
        if base_url is None:
            base_url = get_config().get("api.ark.base_url", self.DEFAULT_BASE_URL)
        if model is None:
            model = get_config().get("api.ark.video_model", self.DEFAULT_MODEL)

        super().__init__(
            base_url=base_url,
            api_key=api_key,
            timeout=timeout,
            **kwargs
        )

        self.model = model
        self.poll_interval = poll_interval
        self.max_poll_time = max_poll_time

        logger.info(f"ARK视频生成API客户端初始化完成,模型: {self.model}")

    def _build_gen_params(
        self,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Build the ``--flag value`` generation-parameter string.

        Precedence for each value: kwargs > explicit argument > default.

        Args:
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: May override any of the above by name; other keys are ignored.

        Returns:
            e.g. ``"--dur 4 --rt 16:9 --rs 1080p --wm false --cf false "``
            (note the trailing space).
        """
        defaults = {
            "duration": 4,
            "ratio": "16:9",
            "resolution": "1080p",
            "watermark": "false",
            "camerafixed": "false"
        }

        def get_param(key: str, param_value: Any) -> Any:
            if key in kwargs:
                return kwargs[key]
            return param_value if param_value is not None else defaults[key]

        duration = get_param("duration", duration)
        ratio = get_param("ratio", ratio)
        resolution = get_param("resolution", resolution)
        watermark = get_param("watermark", watermark)
        camerafixed = get_param("camerafixed", camerafixed)

        params = [
            f"--dur {duration}",
            f"--rt {ratio}",
            f"--rs {resolution}",
            f"--wm {watermark}",
            f"--cf {camerafixed}"
        ]

        return " ".join(params) + " "

    def create_video_task(
        self,
        prompt: str,
        image_url: str,
        gen_params: Optional[str] = None,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Create an image-to-video generation task.

        Args:
            prompt: Text prompt (required, non-blank).
            image_url: Reference image; must be an http(s) URL.
            gen_params: Raw parameter string (e.g. "--dur 5 --rt 9:16"). When
                given, the individual generation arguments below are ignored.
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Extra request-body fields; ``model`` overrides the
                client's model. Also consulted for generation parameters.

        Returns:
            The API response, which carries the task ID under ``"id"``.

        Raises:
            ValueError: If prompt/image_url are empty or image_url is not http(s).
            APIError: If the request fails.

        Example:
            >>> client.create_video_task(
            ...     prompt="一个美丽的风景",
            ...     image_url="https://example.com/image.jpg",
            ...     duration=5,
            ...     ratio="9:16",
            ...     resolution="720p"
            ... )
        """
        if not prompt or not prompt.strip():
            raise ValueError("prompt不能为空")

        if not image_url or not image_url.strip():
            raise ValueError("image_url不能为空")

        if not image_url.startswith(("http://", "https://")):
            raise ValueError(
                f"image_url必须是HTTP/HTTPS URL格式,当前值: {image_url}。"
                "如果是本地文件路径,请先上传到云存储获取URL。"
            )

        if gen_params is None:
            gen_params = self._build_gen_params(
                duration=duration,
                ratio=ratio,
                resolution=resolution,
                watermark=watermark,
                camerafixed=camerafixed,
                **kwargs
            )
        gen_params = gen_params.strip()

        # The "--flag value" parameters must be separated from the prompt by
        # whitespace; plain concatenation glued the first flag onto the last
        # word of the prompt (e.g. "...散步--dur 4").
        text = f"{prompt.rstrip()} {gen_params}" if gen_params else prompt

        request_data = {
            "model": kwargs.get("model", self.model),
            "content": [
                {
                    "type": "text",
                    "text": text
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                }
            ],
            **{k: v for k, v in kwargs.items() if k != "model"}
        }

        logger.info(f"创建视频生成任务,模型: {request_data['model']}, 提示词: {prompt[:50]}...")
        logger.info(f"参考图片: {image_url}")

        try:
            response = self.post(
                endpoint=self.DEFAULT_ENDPOINT,
                json=request_data
            )
            logger.info(f"视频生成任务创建成功,任务ID: {response.get('id', 'unknown')}")
            return response
        except APIError as e:
            logger.error(f"创建视频生成任务失败: {e}")
            raise

    def query_video_task(self, task_id: str) -> Dict[str, Any]:
        """
        Fetch the current state of a video-generation task.

        Args:
            task_id: ID returned by create_video_task.

        Returns:
            The task detail response (status, content/video_url, error, ...).

        Raises:
            ValueError: If task_id is blank.
            APIError: If the request fails.
        """
        if not task_id or not task_id.strip():
            raise ValueError("task_id不能为空")

        query_endpoint = f"{self.DEFAULT_ENDPOINT}/{task_id}"
        logger.debug(f"查询视频生成任务状态,任务ID: {task_id}")

        try:
            response = self.get(endpoint=query_endpoint)
            # Tolerate an empty response or a null "status" field here;
            # callers decide how to handle them.
            status = self._extract_status(response) if response else ""
            logger.debug(f"任务 {task_id} 状态: {status}")
            return response
        except APIError as e:
            logger.error(f"查询视频生成任务状态失败: {e}")
            raise

    @staticmethod
    def _extract_status(result: Dict[str, Any]) -> str:
        """Return the lower-cased ``status`` of a task response ("" if missing or null)."""
        return str(result.get("status") or "").lower()

    @staticmethod
    def _extract_error_message(result: Dict[str, Any], status: str = "") -> str:
        """Return a human-readable failure message from a failed/cancelled task response."""
        error = result.get("error")
        message = None
        if isinstance(error, dict):
            message = error.get("message")
        elif error:
            message = str(error)
        if message:
            return message
        if status == TaskStatus.CANCELLED.value:
            return "任务已取消"
        return "未知错误"

    @staticmethod
    def _invoke_callback(
        callback: Optional[Callable],
        task_id: str,
        callback_kwargs: Dict[str, Any],
        result: Dict[str, Any],
        error: Optional[str]
    ) -> None:
        """
        Call ``callback`` using whichever of the two supported signatures it has.

        A 4-parameter callback receives ``(task_id, output_path, result, error)``;
        anything else receives ``(task_id, result, error)``.
        """
        if not callback:
            return
        try:
            param_count = len(inspect.signature(callback).parameters)
        except (TypeError, ValueError):
            # Some builtins/C callables expose no signature; use the 3-arg form.
            param_count = 3
        if param_count == 4:
            callback(task_id, callback_kwargs.get("output_path", ""), result, error)
        else:
            callback(task_id, result, error)

    def _poll_until_done(
        self, task_id: str, tolerate_query_errors: bool
    ) -> Tuple[Dict[str, Any], Optional[str], bool]:
        """
        Poll ``task_id`` until it reaches a terminal state or ``max_poll_time`` elapses.

        Args:
            task_id: Task to poll.
            tolerate_query_errors: If True, an exception raised while querying
                is logged and polling continues; if False it propagates.

        Returns:
            ``(result, error, timed_out)``:
            success -> ``(result, None, False)``;
            failed/cancelled -> ``({}, message, False)``;
            timeout -> ``({}, message, True)``.
        """
        # monotonic() is immune to wall-clock adjustments during long waits.
        start = time.monotonic()

        while True:
            elapsed = time.monotonic() - start
            if elapsed > self.max_poll_time:
                error_msg = f"任务超时(超过 {self.max_poll_time} 秒)"
                logger.error(f"任务 {task_id} {error_msg}")
                return {}, error_msg, True

            try:
                result = self.query_video_task(task_id)
            except Exception as e:
                if not tolerate_query_errors:
                    raise
                logger.error(f"查询任务 {task_id} 状态失败: {e}")
                time.sleep(self.poll_interval)
                continue

            if not result:
                logger.warning(f"任务 {task_id} 查询结果为空,继续等待...")
                time.sleep(self.poll_interval)
                continue

            status = self._extract_status(result)

            if status == TaskStatus.SUCCEEDED.value:
                logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒")
                return result, None, False

            if status in _FAILED_STATUSES:
                error_msg = self._extract_error_message(result, status)
                logger.error(f"任务 {task_id} 失败: {error_msg}")
                return {}, error_msg, False

            if status in _IN_PROGRESS_STATUSES:
                logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}")
            else:
                logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...")
            time.sleep(self.poll_interval)

    def wait_for_task(
        self,
        task_id: str,
        callback: Optional[Callable] = None,
        callback_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Block until the task finishes (synchronous polling).

        Args:
            task_id: Task to wait for.
            callback: Optional callback with either signature
                1. (task_id, result, error) -> None
                2. (task_id, output_path, result, error) -> None
                It is invoked before this method returns or raises.
            callback_kwargs: Extra values for the callback (``output_path``).

        Returns:
            The final task response on success.

        Raises:
            APIError: If a query fails, or the task fails / is cancelled.
            TimeoutError: If the task does not finish within max_poll_time.
        """
        callback_kwargs = callback_kwargs or {}
        result, error_msg, timed_out = self._poll_until_done(
            task_id, tolerate_query_errors=False
        )

        if error_msg is None:
            self._invoke_callback(callback, task_id, callback_kwargs, result, None)
            return result

        self._invoke_callback(callback, task_id, callback_kwargs, {}, error_msg)
        if timed_out:
            raise TimeoutError(error_msg)
        raise APIError(f"任务失败: {error_msg}")

    def _background_poll(
        self,
        task_id: str,
        callback: Optional[Callable] = None,
        callback_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Thread target: poll a task and report the outcome via ``callback``.

        Query errors are logged and retried. Nothing is raised: an exception
        from the callback is logged so it does not vanish with the thread.

        Args:
            task_id: Task to poll.
            callback: Callback (see wait_for_task for supported signatures).
            callback_kwargs: Extra values for the callback (``output_path``).
        """
        callback_kwargs = callback_kwargs or {}
        try:
            result, error_msg, _ = self._poll_until_done(
                task_id, tolerate_query_errors=True
            )
            self._invoke_callback(callback, task_id, callback_kwargs, result, error_msg)
        except Exception as e:
            logger.error(f"任务 {task_id} 后台轮询或回调异常: {e}")

    def create_video_task_async(
        self,
        prompt: str,
        image_url: str,
        gen_params: Optional[str] = None,
        callback: Optional[Callable] = handle_video_result,
        output_path: Optional[str] = None,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> Optional[str]:
        """
        Create a video task and return its ID immediately (non-blocking).

        Polling happens in a daemon thread, which is killed if the main
        program exits first, so a long-running caller must keep the process
        alive until the callback has run.

        Args:
            prompt: Text prompt (required).
            image_url: Reference image URL (required).
            gen_params: Raw parameter string; overrides the individual args.
            callback: Callback with either signature
                1. (task_id, result, error) -> None
                2. (task_id, output_path, result, error) -> None
                Defaults to handle_video_result (downloads to output_path).
            output_path: Passed to the callback as ``output_path``.
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Extra request-body fields.

        Returns:
            The task ID, or None if the response carried no ID.

        Raises:
            APIError: If creating the task fails.
        """
        task_response = self.create_video_task(
            prompt=prompt,
            image_url=image_url,
            gen_params=gen_params,
            duration=duration,
            ratio=ratio,
            resolution=resolution,
            watermark=watermark,
            camerafixed=camerafixed,
            **kwargs
        )

        task_id = task_response.get("id")
        if not task_id:
            logger.error("任务提交失败,无法启动后台轮询")
            return None

        logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")

        callback_kwargs = {}
        if output_path:
            callback_kwargs["output_path"] = output_path

        poll_thread = threading.Thread(
            target=self._background_poll,
            args=(task_id, callback),
            kwargs={"callback_kwargs": callback_kwargs},
            name=f"ark-video-poll-{task_id}",
            daemon=True
        )
        poll_thread.start()

        return task_id

    def create_and_wait(
        self,
        prompt: str,
        image_url: str,
        gen_params: Optional[str] = None,
        callback: Optional[Callable] = None,
        output_path: Optional[str] = None,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Create a video task and block until it finishes.

        Args:
            prompt: Text prompt (required).
            image_url: Reference image URL (required).
            gen_params: Raw parameter string; overrides the individual args.
            callback: Optional callback (see wait_for_task for signatures).
            output_path: Passed to the callback as ``output_path``.
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Extra request-body fields.

        Returns:
            The final task response on success.

        Raises:
            APIError: If a request fails, no task ID is returned, or the task fails.
            TimeoutError: If the task does not finish within max_poll_time.
        """
        task_response = self.create_video_task(
            prompt=prompt,
            image_url=image_url,
            gen_params=gen_params,
            duration=duration,
            ratio=ratio,
            resolution=resolution,
            watermark=watermark,
            camerafixed=camerafixed,
            **kwargs
        )

        task_id = task_response.get("id")
        if not task_id:
            raise APIError("创建任务成功但未返回任务ID")

        logger.info(f"任务已创建,任务ID: {task_id},开始等待完成...")

        callback_kwargs = {}
        if output_path:
            callback_kwargs["output_path"] = output_path
        return self.wait_for_task(task_id, callback=callback, callback_kwargs=callback_kwargs)

    def get_video_url(self, result: Dict[str, Any]) -> Optional[str]:
        """
        Extract the video URL from a task response.

        Args:
            result: Response from query_video_task or wait_for_task.

        Returns:
            ``result["content"]["video_url"]`` if present, otherwise None.
        """
        try:
            content = result.get("content", {})
            if isinstance(content, dict):
                return content.get("video_url")
            return None
        except (KeyError, TypeError, AttributeError) as e:
            logger.warning(f"提取视频URL失败: {e}")
            return None

    def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
        """
        Extract the lower-cased task status from a task response.

        Args:
            result: A task response.

        Returns:
            The status string ("" if the key is absent), or None if it
            cannot be read (e.g. a null status or a non-dict result).
        """
        try:
            return result.get("status", "").lower()
        except (KeyError, TypeError, AttributeError):
            return None


if __name__ == "__main__":
    # Example usage.
    client = ArkVideoClient()
    demo_output_path = "./output/video3.mp4"

    # The demo previously hard-coded a pre-signed TOS URL that embedded an
    # access-key id and expired after 24h; supply a fresh URL via env var.
    demo_image_url = os.getenv("ARK_DEMO_IMAGE_URL", "https://example.com/image.jpg")

    # Option 1 (recommended): submit asynchronously; returns the task_id
    # immediately and polls in the background. Uses default generation params.
    task_id = client.create_video_task_async(
        prompt="图中的女生在街道上散步",
        image_url=demo_image_url,
        callback=handle_video_result,
        output_path=demo_output_path
    )

    # Option 1b: individual generation parameters.
    # task_id = client.create_video_task_async(
    #     prompt="图中的女生在街道上散步",
    #     image_url="https://example.com/image.jpg",
    #     duration=5,
    #     ratio="9:16",
    #     resolution="720p",
    #     callback=handle_video_result,
    #     output_path="./output/video2.mp4"
    # )

    # Option 1c: raw generation-parameter string.
    # task_id = client.create_video_task_async(
    #     prompt="图中的女生在街道上散步",
    #     image_url="https://example.com/image.jpg",
    #     gen_params="--dur 5 --rt 9:16 --rs 720p --wm false --cf false",
    #     callback=handle_video_result,
    #     output_path="./output/video2.mp4"
    # )

    print(f"任务已提交,task_id: {task_id},主流程继续执行...")

    # Optionally wait for the download. Bounded: if the task fails or times
    # out no file ever appears, and the old loop spun forever.
    if task_id:
        deadline = time.monotonic() + client.max_poll_time + 60
        while not os.path.exists(demo_output_path):
            if time.monotonic() > deadline:
                print("等待超时,视频未下载完成,退出循环...")
                break
            time.sleep(10)
            print("等待10秒...")
        else:
            print("视频下载完成,退出循环...")