|
|
@@ -0,0 +1,701 @@
|
|
|
+"""
|
|
|
+火山引擎ARK视频生成API客户端
|
|
|
+封装ARK视频生成API的调用,提供类型安全的接口
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import time
|
|
|
+import threading
|
|
|
+from typing import Optional, Dict, Any, Callable
|
|
|
+from enum import Enum
|
|
|
+
|
|
|
+from examples.video_create.utils.tools import download_video
|
|
|
+from .base_client import APIClient, APIError
|
|
|
+from taskflow.logger import get_logger
|
|
|
+from taskflow.config import get_config
|
|
|
+
|
|
|
+logger = get_logger("api_modules.ark_video_client")
|
|
|
+
|
|
|
+
|
|
|
+class TaskStatus(str, Enum):
|
|
|
+ """任务状态枚举"""
|
|
|
+ PENDING = "pending"
|
|
|
+ PROCESSING = "processing"
|
|
|
+ SUCCEEDED = "succeeded"
|
|
|
+ FAILED = "failed"
|
|
|
+
|
|
|
+def handle_video_result(
|
|
|
+ task_id: str,
|
|
|
+ output_path: str,
|
|
|
+ result: Optional[Dict],
|
|
|
+ error: Optional[str]
|
|
|
+) -> None:
|
|
|
+ if error:
|
|
|
+ logger.info(f"\n任务 {task_id} 处理失败:{error}")
|
|
|
+ else:
|
|
|
+ video_url = result.get("content", {}).get("video_url")
|
|
|
+ download_video(video_url, output_path)
|
|
|
+ logger.info(f"生成视频已下载:{output_path}")
|
|
|
+
|
|
|
+class ArkVideoClient(APIClient):
|
|
|
+ """
|
|
|
+ 火山引擎ARK视频生成API客户端
|
|
|
+
|
|
|
+ 封装ARK视频生成API的调用,提供便捷的接口
|
|
|
+ """
|
|
|
+
|
|
|
+ DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
|
|
|
+ DEFAULT_ENDPOINT = "/api/v3/contents/generations/tasks"
|
|
|
+ DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528"
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ api_key: Optional[str] = None,
|
|
|
+ base_url: Optional[str] = None,
|
|
|
+ model: Optional[str] = None,
|
|
|
+ timeout: int = 60,
|
|
|
+ poll_interval: int = 5,
|
|
|
+ max_poll_time: int = 500,
|
|
|
+ **kwargs
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 初始化ARK视频生成API客户端
|
|
|
+
|
|
|
+ Args:
|
|
|
+ api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
|
|
|
+ base_url: API基础URL(默认使用官方URL)
|
|
|
+ model: 模型名称(如果为None,会尝试从配置中获取)
|
|
|
+ timeout: 请求超时时间(秒,默认60秒)
|
|
|
+ poll_interval: 轮询间隔(秒,默认5秒)
|
|
|
+ max_poll_time: 最大轮询总时间(秒,默认500秒)
|
|
|
+ **kwargs: 传递给APIClient的其他参数
|
|
|
+ """
|
|
|
+ # 获取API密钥(优先级:参数 > 环境变量 > 配置)
|
|
|
+ if api_key is None:
|
|
|
+ api_key = os.getenv("ARK_API_KEY")
|
|
|
+ if api_key is None:
|
|
|
+ config = get_config()
|
|
|
+ api_key = config.get("api.ark.api_key")
|
|
|
+
|
|
|
+ if not api_key:
|
|
|
+ raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
|
|
|
+
|
|
|
+ # 获取base_url(优先级:参数 > 配置 > 默认值)
|
|
|
+ if base_url is None:
|
|
|
+ config = get_config()
|
|
|
+ base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
|
|
|
+
|
|
|
+ # 获取model(优先级:参数 > 配置 > 默认值)
|
|
|
+ if model is None:
|
|
|
+ config = get_config()
|
|
|
+ model = config.get("api.ark.video_model", self.DEFAULT_MODEL)
|
|
|
+
|
|
|
+ super().__init__(
|
|
|
+ base_url=base_url,
|
|
|
+ api_key=api_key,
|
|
|
+ timeout=timeout,
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
+
|
|
|
+ # 保存视频生成相关配置
|
|
|
+ self.model = model
|
|
|
+ self.poll_interval = poll_interval
|
|
|
+ self.max_poll_time = max_poll_time
|
|
|
+
|
|
|
+ logger.info(f"ARK视频生成API客户端初始化完成,模型: {self.model}")
|
|
|
+
|
|
|
+
|
|
|
+ def _build_gen_params(
|
|
|
+ self,
|
|
|
+ duration: Optional[int] = None,
|
|
|
+ ratio: Optional[str] = None,
|
|
|
+ resolution: Optional[str] = None,
|
|
|
+ watermark: Optional[str] = None,
|
|
|
+ camerafixed: Optional[str] = None,
|
|
|
+ **kwargs
|
|
|
+ ) -> str:
|
|
|
+ """
|
|
|
+ 构建生成参数字符串
|
|
|
+
|
|
|
+ 如果参数为None,则使用默认值或从kwargs中获取
|
|
|
+
|
|
|
+ Args:
|
|
|
+ duration: 视频时长(秒,默认4秒)
|
|
|
+ ratio: 视频比例(默认"16:9")
|
|
|
+ resolution: 视频分辨率(默认"1080p")
|
|
|
+ watermark: 水印开关(默认"false")
|
|
|
+ camerafixed: 相机固定开关(默认"false")
|
|
|
+ **kwargs: 其他参数,会优先从kwargs中获取
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 生成参数字符串,格式如:"--dur 4 --rt 16:9 --rs 1080p --wm false --cf false"
|
|
|
+ """
|
|
|
+ # 默认值
|
|
|
+ defaults = {
|
|
|
+ "duration": 4,
|
|
|
+ "ratio": "16:9",
|
|
|
+ "resolution": "1080p",
|
|
|
+ "watermark": "false",
|
|
|
+ "camerafixed": "false"
|
|
|
+ }
|
|
|
+
|
|
|
+ # 从kwargs中获取参数,如果没有则使用传入的参数,再没有则使用默认值
|
|
|
+ # 优先级:kwargs > 显式参数 > 默认值
|
|
|
+ def get_param(key: str, param_value: Any) -> Any:
|
|
|
+ if key in kwargs:
|
|
|
+ return kwargs[key]
|
|
|
+ return param_value if param_value is not None else defaults[key]
|
|
|
+
|
|
|
+ duration = get_param("duration", duration)
|
|
|
+ ratio = get_param("ratio", ratio)
|
|
|
+ resolution = get_param("resolution", resolution)
|
|
|
+ watermark = get_param("watermark", watermark)
|
|
|
+ camerafixed = get_param("camerafixed", camerafixed)
|
|
|
+
|
|
|
+ # 构建参数字符串
|
|
|
+ params = [
|
|
|
+ f"--dur {duration}",
|
|
|
+ f"--rt {ratio}",
|
|
|
+ f"--rs {resolution}",
|
|
|
+ f"--wm {watermark}",
|
|
|
+ f"--cf {camerafixed}"
|
|
|
+ ]
|
|
|
+
|
|
|
+ return " ".join(params) + " "
|
|
|
+
|
|
|
+ def create_video_task(
|
|
|
+ self,
|
|
|
+ prompt: str,
|
|
|
+ image_url: str,
|
|
|
+ gen_params: Optional[str] = None,
|
|
|
+ duration: Optional[int] = None,
|
|
|
+ ratio: Optional[str] = None,
|
|
|
+ resolution: Optional[str] = None,
|
|
|
+ watermark: Optional[str] = None,
|
|
|
+ camerafixed: Optional[str] = None,
|
|
|
+ **kwargs
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 创建视频生成任务
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prompt: 视频生成提示词(必填)
|
|
|
+ image_url: 参考图片URL(必填,必须是可访问的HTTP/HTTPS URL)
|
|
|
+ gen_params: 自定义生成参数字符串(可选,如果提供则忽略其他生成参数)
|
|
|
+ duration: 视频时长(秒,默认4秒)
|
|
|
+ ratio: 视频比例(默认"16:9")
|
|
|
+ resolution: 视频分辨率(默认"1080p")
|
|
|
+ watermark: 水印开关(默认"false")
|
|
|
+ camerafixed: 相机固定开关(默认"false")
|
|
|
+ **kwargs: 其他请求参数(会覆盖默认配置,包括生成参数)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ API响应数据,包含任务ID等信息
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ APIError: 如果请求失败
|
|
|
+ ValueError: 如果参数无效
|
|
|
+
|
|
|
+ 示例:
|
|
|
+ >>> client.create_video_task(
|
|
|
+ ... prompt="一个美丽的风景",
|
|
|
+ ... image_url="https://example.com/image.jpg",
|
|
|
+ ... duration=5,
|
|
|
+ ... ratio="9:16",
|
|
|
+ ... resolution="720p"
|
|
|
+ ... )
|
|
|
+ """
|
|
|
+ if not prompt or not prompt.strip():
|
|
|
+ raise ValueError("prompt不能为空")
|
|
|
+
|
|
|
+ if not image_url or not image_url.strip():
|
|
|
+ raise ValueError("image_url不能为空")
|
|
|
+
|
|
|
+ # 验证image_url是否为URL格式
|
|
|
+ if not image_url.startswith(("http://", "https://")):
|
|
|
+ raise ValueError(
|
|
|
+ f"image_url必须是HTTP/HTTPS URL格式,当前值: {image_url}。"
|
|
|
+ "如果是本地文件路径,请先上传到云存储获取URL。"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 构建生成参数:如果提供了gen_params字符串,直接使用;否则根据参数构建
|
|
|
+ if gen_params is None:
|
|
|
+ # 合并kwargs和显式参数
|
|
|
+ gen_params_kwargs = {
|
|
|
+ "duration": duration,
|
|
|
+ "ratio": ratio,
|
|
|
+ "resolution": resolution,
|
|
|
+ "watermark": watermark,
|
|
|
+ "camerafixed": camerafixed,
|
|
|
+ **kwargs
|
|
|
+ }
|
|
|
+ gen_params = self._build_gen_params(**gen_params_kwargs)
|
|
|
+ else:
|
|
|
+ # 如果提供了gen_params字符串,确保以空格结尾(如果没有)
|
|
|
+ gen_params = gen_params.strip()
|
|
|
+ if gen_params and not gen_params.endswith(" "):
|
|
|
+ gen_params += " "
|
|
|
+
|
|
|
+ # 构建请求体
|
|
|
+ request_data = {
|
|
|
+ "model": kwargs.get("model", self.model),
|
|
|
+ "content": [
|
|
|
+ {
|
|
|
+ "type": "text",
|
|
|
+ "text": prompt + gen_params
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {
|
|
|
+ "url": image_url
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ **{k: v for k, v in kwargs.items() if k != "model"}
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info(f"创建视频生成任务,模型: {request_data['model']}, 提示词: {prompt[:50]}...")
|
|
|
+ logger.info(f"参考图片: {image_url}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ response = self.post(
|
|
|
+ endpoint=self.DEFAULT_ENDPOINT,
|
|
|
+ json=request_data
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info(f"视频生成任务创建成功,任务ID: {response.get('id', 'unknown')}")
|
|
|
+ return response
|
|
|
+
|
|
|
+ except APIError as e:
|
|
|
+ logger.error(f"创建视频生成任务失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def query_video_task(self, task_id: str) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 查询视频生成任务状态
|
|
|
+
|
|
|
+ Args:
|
|
|
+ task_id: 任务ID(从create_video_task响应中获取)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 任务状态详情,包含状态、视频URL等信息
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ APIError: 如果请求失败
|
|
|
+ ValueError: 如果参数无效
|
|
|
+ """
|
|
|
+ if not task_id or not task_id.strip():
|
|
|
+ raise ValueError("task_id不能为空")
|
|
|
+
|
|
|
+ query_endpoint = f"{self.DEFAULT_ENDPOINT}/{task_id}"
|
|
|
+
|
|
|
+ logger.debug(f"查询视频生成任务状态,任务ID: {task_id}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ response = self.get(endpoint=query_endpoint)
|
|
|
+
|
|
|
+ status = response.get("status", "").lower()
|
|
|
+ logger.debug(f"任务 {task_id} 状态: {status}")
|
|
|
+
|
|
|
+ return response
|
|
|
+
|
|
|
+ except APIError as e:
|
|
|
+ logger.error(f"查询视频生成任务状态失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def wait_for_task(
|
|
|
+ self,
|
|
|
+ task_id: str,
|
|
|
+ callback: Optional[Callable] = None,
|
|
|
+ callback_kwargs: Optional[Dict[str, Any]] = None
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 等待任务完成(同步轮询)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ task_id: 任务ID
|
|
|
+ callback: 可选的回调函数,可以是以下两种签名之一:
|
|
|
+ 1. (task_id, result, error) -> None
|
|
|
+ 2. (task_id, output_path, result, error) -> None
|
|
|
+ callback_kwargs: 传递给回调函数的额外关键字参数(如 output_path)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 任务完成后的结果
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ APIError: 如果请求失败
|
|
|
+ TimeoutError: 如果任务超时
|
|
|
+ """
|
|
|
+ start_time = time.time()
|
|
|
+ callback_kwargs = callback_kwargs or {}
|
|
|
+
|
|
|
+ while True:
|
|
|
+ elapsed = time.time() - start_time
|
|
|
+
|
|
|
+ if elapsed > self.max_poll_time:
|
|
|
+ error_msg = f"任务超时(超过 {self.max_poll_time} 秒)"
|
|
|
+ logger.error(f"任务 {task_id} {error_msg}")
|
|
|
+ if callback:
|
|
|
+ # 检查回调函数签名,支持两种格式
|
|
|
+ import inspect
|
|
|
+ sig = inspect.signature(callback)
|
|
|
+ param_count = len(sig.parameters)
|
|
|
+ if param_count == 4:
|
|
|
+ # 4参数版本:(task_id, output_path, result, error)
|
|
|
+ callback(task_id, callback_kwargs.get("output_path", ""), {}, error_msg)
|
|
|
+ else:
|
|
|
+ # 3参数版本:(task_id, result, error)
|
|
|
+ callback(task_id, {}, error_msg)
|
|
|
+ raise TimeoutError(error_msg)
|
|
|
+
|
|
|
+ # 查询任务状态
|
|
|
+ result = self.query_video_task(task_id)
|
|
|
+
|
|
|
+ if not result:
|
|
|
+ logger.warning(f"任务 {task_id} 查询结果为空,继续等待...")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 解析状态
|
|
|
+ status = result.get("status", "").lower()
|
|
|
+
|
|
|
+ if status == TaskStatus.SUCCEEDED:
|
|
|
+ logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒")
|
|
|
+ if callback:
|
|
|
+ # 检查回调函数签名,支持两种格式
|
|
|
+ import inspect
|
|
|
+ sig = inspect.signature(callback)
|
|
|
+ param_count = len(sig.parameters)
|
|
|
+ if param_count == 4:
|
|
|
+ # 4参数版本:(task_id, output_path, result, error)
|
|
|
+ callback(task_id, callback_kwargs.get("output_path", ""), result, None)
|
|
|
+ else:
|
|
|
+ # 3参数版本:(task_id, result, error)
|
|
|
+ callback(task_id, result, None)
|
|
|
+ return result
|
|
|
+
|
|
|
+ elif status == TaskStatus.FAILED:
|
|
|
+ error_msg = result.get("error", {}).get("message", "未知错误")
|
|
|
+ logger.error(f"任务 {task_id} 失败: {error_msg}")
|
|
|
+ if callback:
|
|
|
+ # 检查回调函数签名,支持两种格式
|
|
|
+ import inspect
|
|
|
+ sig = inspect.signature(callback)
|
|
|
+ param_count = len(sig.parameters)
|
|
|
+ if param_count == 4:
|
|
|
+ # 4参数版本:(task_id, output_path, result, error)
|
|
|
+ callback(task_id, callback_kwargs.get("output_path", ""), {}, error_msg)
|
|
|
+ else:
|
|
|
+ # 3参数版本:(task_id, result, error)
|
|
|
+ callback(task_id, {}, error_msg)
|
|
|
+ raise APIError(f"任务失败: {error_msg}")
|
|
|
+
|
|
|
+ elif status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
|
|
|
+ logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+
|
|
|
+ else:
|
|
|
+ logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+
|
|
|
+ def _background_poll(
|
|
|
+ self,
|
|
|
+ task_id: str,
|
|
|
+ callback: Optional[Callable] = None,
|
|
|
+ callback_kwargs: Optional[Dict[str, Any]] = None
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 后台轮询任务状态的线程函数
|
|
|
+
|
|
|
+ Args:
|
|
|
+ task_id: 任务ID
|
|
|
+ callback: 回调函数
|
|
|
+ callback_kwargs: 传递给回调函数的额外参数
|
|
|
+ """
|
|
|
+ callback_kwargs = callback_kwargs or {}
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ while True:
|
|
|
+ elapsed = time.time() - start_time
|
|
|
+
|
|
|
+ if elapsed > self.max_poll_time:
|
|
|
+ error_msg = f"任务超时(超过 {self.max_poll_time} 秒)"
|
|
|
+ logger.error(f"任务 {task_id} {error_msg}")
|
|
|
+ if callback:
|
|
|
+ import inspect
|
|
|
+ sig = inspect.signature(callback)
|
|
|
+ param_count = len(sig.parameters)
|
|
|
+ if param_count == 4:
|
|
|
+ callback(task_id, callback_kwargs.get("output_path", ""), {}, error_msg)
|
|
|
+ else:
|
|
|
+ callback(task_id, {}, error_msg)
|
|
|
+ return
|
|
|
+
|
|
|
+ # 查询任务状态
|
|
|
+ try:
|
|
|
+ result = self.query_video_task(task_id)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"查询任务 {task_id} 状态失败: {e}")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+ continue
|
|
|
+
|
|
|
+ if not result:
|
|
|
+ logger.warning(f"任务 {task_id} 查询结果为空,继续等待...")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 解析状态
|
|
|
+ status = result.get("status", "").lower()
|
|
|
+
|
|
|
+ if status == TaskStatus.SUCCEEDED:
|
|
|
+ logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒")
|
|
|
+ if callback:
|
|
|
+ import inspect
|
|
|
+ sig = inspect.signature(callback)
|
|
|
+ param_count = len(sig.parameters)
|
|
|
+ if param_count == 4:
|
|
|
+ callback(task_id, callback_kwargs.get("output_path", ""), result, None)
|
|
|
+ else:
|
|
|
+ callback(task_id, result, None)
|
|
|
+ return
|
|
|
+
|
|
|
+ elif status == TaskStatus.FAILED:
|
|
|
+ error_msg = result.get("error", {}).get("message", "未知错误")
|
|
|
+ logger.error(f"任务 {task_id} 失败: {error_msg}")
|
|
|
+ if callback:
|
|
|
+ import inspect
|
|
|
+ sig = inspect.signature(callback)
|
|
|
+ param_count = len(sig.parameters)
|
|
|
+ if param_count == 4:
|
|
|
+ callback(task_id, callback_kwargs.get("output_path", ""), {}, error_msg)
|
|
|
+ else:
|
|
|
+ callback(task_id, {}, error_msg)
|
|
|
+ return
|
|
|
+
|
|
|
+ elif status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
|
|
|
+ logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+
|
|
|
+ else:
|
|
|
+ logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...")
|
|
|
+ time.sleep(self.poll_interval)
|
|
|
+
|
|
|
+ def create_video_task_async(
|
|
|
+ self,
|
|
|
+ prompt: str,
|
|
|
+ image_url: str,
|
|
|
+ gen_params: Optional[str] = None,
|
|
|
+ callback: Optional[Callable] = handle_video_result,
|
|
|
+ output_path: Optional[str] = None,
|
|
|
+ duration: Optional[int] = None,
|
|
|
+ ratio: Optional[str] = None,
|
|
|
+ resolution: Optional[str] = None,
|
|
|
+ watermark: Optional[str] = None,
|
|
|
+ camerafixed: Optional[str] = None,
|
|
|
+ **kwargs
|
|
|
+ ) -> Optional[str]:
|
|
|
+ """
|
|
|
+ 创建视频生成任务并立即返回task_id(不阻塞主流程)
|
|
|
+
|
|
|
+ 任务会在后台线程中轮询,完成后调用回调函数。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prompt: 视频生成提示词(必填)
|
|
|
+ image_url: 参考图片URL(必填)
|
|
|
+ gen_params: 自定义生成参数字符串(可选,如果提供则忽略其他生成参数)
|
|
|
+ callback: 可选的回调函数,可以是以下两种签名之一:
|
|
|
+ 1. (task_id, result, error) -> None
|
|
|
+ 2. (task_id, output_path, result, error) -> None
|
|
|
+ output_path: 视频输出路径(可选,会传递给回调函数)
|
|
|
+ duration: 视频时长(秒,默认4秒)
|
|
|
+ ratio: 视频比例(默认"16:9")
|
|
|
+ resolution: 视频分辨率(默认"1080p")
|
|
|
+ watermark: 水印开关(默认"false")
|
|
|
+ camerafixed: 相机固定开关(默认"false")
|
|
|
+ **kwargs: 其他请求参数(会覆盖默认配置)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 任务ID(task_id),如果创建失败则返回None
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ APIError: 如果创建任务失败
|
|
|
+ """
|
|
|
+ # 创建任务
|
|
|
+ task_response = self.create_video_task(
|
|
|
+ prompt=prompt,
|
|
|
+ image_url=image_url,
|
|
|
+ gen_params=gen_params,
|
|
|
+ duration=duration,
|
|
|
+ ratio=ratio,
|
|
|
+ resolution=resolution,
|
|
|
+ watermark=watermark,
|
|
|
+ camerafixed=camerafixed,
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
+
|
|
|
+ task_id = task_response.get("id")
|
|
|
+ if not task_id:
|
|
|
+ logger.error("任务提交失败,无法启动后台轮询")
|
|
|
+ return None
|
|
|
+
|
|
|
+ logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")
|
|
|
+
|
|
|
+ # 准备回调参数
|
|
|
+ callback_kwargs = {}
|
|
|
+ if output_path:
|
|
|
+ callback_kwargs["output_path"] = output_path
|
|
|
+
|
|
|
+ # 启动后台线程轮询结果
|
|
|
+ poll_thread = threading.Thread(
|
|
|
+ target=self._background_poll,
|
|
|
+ args=(task_id, callback),
|
|
|
+ kwargs={"callback_kwargs": callback_kwargs},
|
|
|
+ daemon=True # 守护线程:主程序退出时自动结束
|
|
|
+ )
|
|
|
+ poll_thread.start()
|
|
|
+
|
|
|
+ return task_id
|
|
|
+
|
|
|
+ def create_and_wait(
|
|
|
+ self,
|
|
|
+ prompt: str,
|
|
|
+ image_url: str,
|
|
|
+ gen_params: Optional[str] = None,
|
|
|
+ callback: Optional[Callable] = None,
|
|
|
+ output_path: Optional[str] = None,
|
|
|
+ duration: Optional[int] = None,
|
|
|
+ ratio: Optional[str] = None,
|
|
|
+ resolution: Optional[str] = None,
|
|
|
+ watermark: Optional[str] = None,
|
|
|
+ camerafixed: Optional[str] = None,
|
|
|
+ **kwargs
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 创建视频生成任务并等待完成(同步方法,会阻塞)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prompt: 视频生成提示词(必填)
|
|
|
+ image_url: 参考图片URL(必填)
|
|
|
+ gen_params: 自定义生成参数字符串(可选,如果提供则忽略其他生成参数)
|
|
|
+ callback: 可选的回调函数,可以是以下两种签名之一:
|
|
|
+ 1. (task_id, result, error) -> None
|
|
|
+ 2. (task_id, output_path, result, error) -> None
|
|
|
+ output_path: 视频输出路径(可选,会传递给回调函数)
|
|
|
+ duration: 视频时长(秒,默认4秒)
|
|
|
+ ratio: 视频比例(默认"16:9")
|
|
|
+ resolution: 视频分辨率(默认"1080p")
|
|
|
+ watermark: 水印开关(默认"false")
|
|
|
+ camerafixed: 相机固定开关(默认"false")
|
|
|
+ **kwargs: 其他请求参数(会覆盖默认配置)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 任务完成后的结果
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ APIError: 如果请求失败
|
|
|
+ TimeoutError: 如果任务超时
|
|
|
+ """
|
|
|
+ # 创建任务
|
|
|
+ task_response = self.create_video_task(
|
|
|
+ prompt=prompt,
|
|
|
+ image_url=image_url,
|
|
|
+ gen_params=gen_params,
|
|
|
+ duration=duration,
|
|
|
+ ratio=ratio,
|
|
|
+ resolution=resolution,
|
|
|
+ watermark=watermark,
|
|
|
+ camerafixed=camerafixed,
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
+
|
|
|
+ task_id = task_response.get("id")
|
|
|
+ if not task_id:
|
|
|
+ raise APIError("创建任务成功但未返回任务ID")
|
|
|
+
|
|
|
+ logger.info(f"任务已创建,任务ID: {task_id},开始等待完成...")
|
|
|
+
|
|
|
+ # 等待任务完成
|
|
|
+ callback_kwargs = {}
|
|
|
+ if output_path:
|
|
|
+ callback_kwargs["output_path"] = output_path
|
|
|
+
|
|
|
+ return self.wait_for_task(task_id, callback=callback, callback_kwargs=callback_kwargs)
|
|
|
+
|
|
|
+ def get_video_url(self, result: Dict[str, Any]) -> Optional[str]:
|
|
|
+ """
|
|
|
+ 从任务结果中提取视频URL
|
|
|
+
|
|
|
+ Args:
|
|
|
+ result: 任务结果(从query_video_task或wait_for_task返回)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 视频URL,如果不存在则返回None
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ content = result.get("content", {})
|
|
|
+ if isinstance(content, dict):
|
|
|
+ return content.get("video_url")
|
|
|
+ return None
|
|
|
+ except (KeyError, TypeError, AttributeError) as e:
|
|
|
+ logger.warning(f"提取视频URL失败: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
|
|
|
+ """
|
|
|
+ 从任务结果中提取任务状态
|
|
|
+
|
|
|
+ Args:
|
|
|
+ result: 任务结果
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 任务状态字符串,如果不存在则返回None
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ return result.get("status", "").lower()
|
|
|
+ except (KeyError, TypeError, AttributeError):
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 示例用法
|
|
|
+ client = ArkVideoClient()
|
|
|
+
|
|
|
+ # 方式1:异步创建任务(立即返回task_id,不阻塞主流程)推荐
|
|
|
+ # 使用默认生成参数
|
|
|
+ task_id = client.create_video_task_async(
|
|
|
+ prompt="图中的女生在街道上散步",
|
|
|
+ image_url="https://ark-content-generation-v2-cn-beijing.tos-cn-beijing.volces.com/doubao-seedream-4-5/021766049633300c2c2346a8f7f450117997084af59f03e11d642_0.jpeg?X-Tos-Algorithm=TOS4-HMAC-SHA256&X-Tos-Credential=AKLTYWJkZTExNjA1ZDUyNDc3YzhjNTM5OGIyNjBhNDcyOTQ%2F20251218%2Fcn-beijing%2Ftos%2Frequest&X-Tos-Date=20251218T092054Z&X-Tos-Expires=86400&X-Tos-Signature=a83d797cceaa38226c6489f27892ab9e6651dc7fe84addb37c95fec18358706c&X-Tos-SignedHeaders=host",
|
|
|
+ callback=handle_video_result,
|
|
|
+ output_path="./output/video3.mp4"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 方式1b:使用自定义生成参数(推荐方式)
|
|
|
+ # task_id = client.create_video_task_async(
|
|
|
+ # prompt="图中的女生在街道上散步",
|
|
|
+ # image_url="https://example.com/image.jpg",
|
|
|
+ # duration=5,
|
|
|
+ # ratio="9:16",
|
|
|
+ # resolution="720p",
|
|
|
+ # callback=handle_video_result,
|
|
|
+ # output_path="./output/video2.mp4"
|
|
|
+ # )
|
|
|
+
|
|
|
+ # 方式1c:使用自定义生成参数字符串
|
|
|
+ # task_id = client.create_video_task_async(
|
|
|
+ # prompt="图中的女生在街道上散步",
|
|
|
+ # image_url="https://example.com/image.jpg",
|
|
|
+ # gen_params="--dur 5 --rt 9:16 --rs 720p --wm false --cf false",
|
|
|
+ # callback=handle_video_result,
|
|
|
+ # output_path="./output/video2.mp4"
|
|
|
+ # )
|
|
|
+
|
|
|
+ print(f"任务已提交,task_id: {task_id},主流程继续执行...")
|
|
|
+
|
|
|
+ # 等待视频下载完成(可选)
|
|
|
+ while True:
|
|
|
+ if os.path.exists("./output/video3.mp4"):
|
|
|
+ print(f"视频下载完成,退出循环...")
|
|
|
+ break
|
|
|
+ time.sleep(10)
|
|
|
+ print(f"等待10秒...")
|
|
|
+
|