| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701 |
- """
- 火山引擎ARK视频生成API客户端
- 封装ARK视频生成API的调用,提供类型安全的接口
- """
- import os
- import time
- import threading
- from typing import Optional, Dict, Any, Callable
- from enum import Enum
- from examples.video_create.utils.tools import download_video
- from .base_client import APIClient, APIError
- from taskflow.logger import get_logger
- from taskflow.config import get_config
- logger = get_logger("api_modules.ark_video_client")
class TaskStatus(str, Enum):
    """Status values reported by the API for a video generation task.

    The enum mixes in ``str``, so every member compares equal to its raw
    lowercase status string (``TaskStatus.FAILED == "failed"``).
    """

    # Non-terminal states: the polling loops keep waiting on these.
    PENDING = "pending"
    PROCESSING = "processing"

    # Terminal states: polling stops on either of these.
    SUCCEEDED = "succeeded"
    FAILED = "failed"
def handle_video_result(
    task_id: str,
    output_path: str,
    result: Optional[Dict],
    error: Optional[str]
) -> None:
    """Default completion callback: download the generated video or log the failure.

    Args:
        task_id: ID of the finished task (used for logging only).
        output_path: Destination path handed to ``download_video``.
        result: Task result payload; the video URL is read from
            ``result["content"]["video_url"]``. May be ``None`` or empty.
        error: Error message when the task failed or timed out, otherwise
            ``None``.
    """
    if error:
        # A failed task is an error condition, so log it at error level.
        logger.error(f"\n任务 {task_id} 处理失败:{error}")
        return

    # ``result`` is Optional and "content" may be missing or null, so guard
    # every step instead of chaining .get() on possibly-None values.
    content = (result or {}).get("content") or {}
    video_url = content.get("video_url") if isinstance(content, dict) else None
    if not video_url:
        logger.error(f"任务 {task_id} 的结果中缺少video_url,无法下载视频")
        return

    download_video(video_url, output_path)
    logger.info(f"生成视频已下载:{output_path}")
class ArkVideoClient(APIClient):
    """
    Client for the Volcengine ARK video generation API.

    Wraps task creation, status queries and polling (either blocking or in a
    background thread) behind a small set of convenience methods.
    """

    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
    DEFAULT_ENDPOINT = "/api/v3/contents/generations/tasks"
    DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: int = 60,
        poll_interval: int = 5,
        max_poll_time: int = 500,
        **kwargs
    ):
        """
        Initialise the ARK video generation client.

        Args:
            api_key: API key. When not given, the ``ARK_API_KEY`` environment
                variable and then the ``api.ark.api_key`` config entry are
                tried.
            base_url: API base URL. Falls back to the ``api.ark.base_url``
                config entry, then to ``DEFAULT_BASE_URL``.
            model: Model name. Falls back to the ``api.ark.video_model``
                config entry, then to ``DEFAULT_MODEL``.
            timeout: Per-request timeout in seconds.
            poll_interval: Seconds to sleep between status queries.
            max_poll_time: Maximum total polling time in seconds.
            **kwargs: Forwarded unchanged to ``APIClient.__init__``.

        Raises:
            ValueError: If no API key can be found in any of the sources.
        """
        # Priority: argument > environment variable > config. Empty strings
        # count as "not provided", so a blank ARK_API_KEY does not mask a key
        # that is present in the config.
        if not api_key:
            api_key = os.getenv("ARK_API_KEY")
        if not api_key:
            api_key = get_config().get("api.ark.api_key")
        if not api_key:
            raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")

        # Priority for both settings: argument > config > class default.
        if base_url is None:
            base_url = get_config().get("api.ark.base_url", self.DEFAULT_BASE_URL)
        if model is None:
            model = get_config().get("api.ark.video_model", self.DEFAULT_MODEL)

        super().__init__(
            base_url=base_url,
            api_key=api_key,
            timeout=timeout,
            **kwargs
        )

        # Video-generation specific settings.
        self.model = model
        self.poll_interval = poll_interval
        self.max_poll_time = max_poll_time

        logger.info(f"ARK视频生成API客户端初始化完成,模型: {self.model}")

    def _build_gen_params(
        self,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Build the generation-parameter string appended to the prompt.

        Args:
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Accepted for call-site compatibility and ignored. The
                five settings above can only ever arrive through the named
                parameters.

        Returns:
            A string such as
            ``"--dur 4 --rt 16:9 --rs 1080p --wm false --cf false "``
            (note the trailing space, kept for backward compatibility).
        """
        defaults = {
            "duration": 4,
            "ratio": "16:9",
            "resolution": "1080p",
            "watermark": "false",
            "camerafixed": "false",
        }
        explicit = {
            "duration": duration,
            "ratio": ratio,
            "resolution": resolution,
            "watermark": watermark,
            "camerafixed": camerafixed,
        }
        # An explicit value wins; None means "use the default".
        chosen = {
            name: explicit[name] if explicit[name] is not None else fallback
            for name, fallback in defaults.items()
        }

        flags = [
            f"--dur {chosen['duration']}",
            f"--rt {chosen['ratio']}",
            f"--rs {chosen['resolution']}",
            f"--wm {chosen['watermark']}",
            f"--cf {chosen['camerafixed']}",
        ]
        return " ".join(flags) + " "

    def create_video_task(
        self,
        prompt: str,
        image_url: str,
        gen_params: Optional[str] = None,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Create a video generation task.

        Args:
            prompt: Text prompt (required, must not be blank).
            image_url: Reference image URL (required, must start with
                ``http://`` or ``https://``).
            gen_params: Ready-made generation-parameter string. When given,
                the individual settings below are ignored.
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Extra top-level fields merged into the request body.
                ``model`` overrides the client's model for this call.

        Returns:
            The API response, which carries the task ID under ``"id"``.

        Raises:
            APIError: If the request fails.
            ValueError: If ``prompt`` or ``image_url`` is invalid.

        Example:
            client.create_video_task(
                prompt="一个美丽的风景",
                image_url="https://example.com/image.jpg",
                duration=5,
                ratio="9:16",
                resolution="720p",
            )
        """
        if not prompt or not prompt.strip():
            raise ValueError("prompt不能为空")

        if not image_url or not image_url.strip():
            raise ValueError("image_url不能为空")

        if not image_url.startswith(("http://", "https://")):
            raise ValueError(
                f"image_url必须是HTTP/HTTPS URL格式,当前值: {image_url}。"
                "如果是本地文件路径,请先上传到云存储获取URL。"
            )

        if gen_params is None:
            gen_params = self._build_gen_params(
                duration=duration,
                ratio=ratio,
                resolution=resolution,
                watermark=watermark,
                camerafixed=camerafixed,
            )
        gen_params = gen_params.strip()

        # The flags must be separated from the prompt by whitespace; plain
        # concatenation would glue "--dur" onto the prompt's last word.
        text = f"{prompt.rstrip()} {gen_params}" if gen_params else prompt

        request_data = {
            "model": kwargs.get("model", self.model),
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
            # Every other extra keyword becomes a top-level request field.
            **{k: v for k, v in kwargs.items() if k != "model"}
        }

        logger.info(f"创建视频生成任务,模型: {request_data['model']}, 提示词: {prompt[:50]}...")
        logger.info(f"参考图片: {image_url}")

        try:
            response = self.post(
                endpoint=self.DEFAULT_ENDPOINT,
                json=request_data
            )
        except APIError as e:
            logger.error(f"创建视频生成任务失败: {e}")
            raise

        logger.info(f"视频生成任务创建成功,任务ID: {response.get('id', 'unknown')}")
        return response

    def query_video_task(self, task_id: str) -> Dict[str, Any]:
        """
        Query the current state of a video generation task.

        Args:
            task_id: Task ID taken from the ``create_video_task`` response.

        Returns:
            The API response for the task, returned unchanged.

        Raises:
            APIError: If the request fails.
            ValueError: If ``task_id`` is empty or blank.
        """
        if not task_id or not task_id.strip():
            raise ValueError("task_id不能为空")

        query_endpoint = f"{self.DEFAULT_ENDPOINT}/{task_id}"
        logger.debug(f"查询视频生成任务状态,任务ID: {task_id}")

        try:
            response = self.get(endpoint=query_endpoint)
        except APIError as e:
            logger.error(f"查询视频生成任务状态失败: {e}")
            raise

        # Only inspect the payload when it is a dict: an empty or None
        # response is passed through so callers can retry on it.
        if isinstance(response, dict):
            logger.debug(f"任务 {task_id} 状态: {self._status_of(response)}")
        return response

    @staticmethod
    def _status_of(result: Dict[str, Any]) -> str:
        """Return the lowercase status of a task payload ("" if missing or null)."""
        return str(result.get("status") or "").lower()

    @staticmethod
    def _error_message_of(result: Dict[str, Any]) -> str:
        """Return a non-empty error message for a failed task payload."""
        error = result.get("error")
        if isinstance(error, dict):
            # Never return an empty message: callbacks use the truthiness
            # of the error argument to tell failure from success.
            return error.get("message") or "未知错误"
        return str(error) if error else "未知错误"

    @staticmethod
    def _invoke_callback(
        callback: Optional[Callable],
        task_id: str,
        callback_kwargs: Dict[str, Any],
        result: Optional[Dict[str, Any]],
        error: Optional[str]
    ) -> None:
        """Call ``callback`` using whichever of the two supported signatures it has.

        A callback with exactly four parameters is called as
        ``(task_id, output_path, result, error)``; any other callback is
        called as ``(task_id, result, error)``.
        """
        if not callback:
            return
        import inspect

        if len(inspect.signature(callback).parameters) == 4:
            callback(task_id, callback_kwargs.get("output_path", ""), result, error)
        else:
            callback(task_id, result, error)

    def _poll_task(
        self,
        task_id: str,
        callback: Optional[Callable],
        callback_kwargs: Optional[Dict[str, Any]],
        retry_on_query_error: bool
    ) -> tuple:
        """
        Poll a task until it succeeds, fails or ``max_poll_time`` elapses.

        The callback (if any) is invoked exactly once, just before returning.

        Args:
            task_id: Task ID to poll.
            callback: Optional completion callback.
            callback_kwargs: Extra data for the callback (``output_path``).
            retry_on_query_error: When True, exceptions raised by a status
                query are logged and polling continues; when False they
                propagate to the caller.

        Returns:
            ``(result, error_msg, timed_out)``: on success ``error_msg`` is
            None; on failure or timeout ``error_msg`` is a non-empty string.
        """
        callback_kwargs = callback_kwargs or {}
        started = time.time()

        while True:
            elapsed = time.time() - started

            if elapsed > self.max_poll_time:
                error_msg = f"任务超时(超过 {self.max_poll_time} 秒)"
                logger.error(f"任务 {task_id} {error_msg}")
                self._invoke_callback(callback, task_id, callback_kwargs, {}, error_msg)
                return None, error_msg, True

            try:
                result = self.query_video_task(task_id)
            except Exception as e:
                if not retry_on_query_error:
                    raise
                logger.error(f"查询任务 {task_id} 状态失败: {e}")
                time.sleep(self.poll_interval)
                continue

            if not result:
                logger.warning(f"任务 {task_id} 查询结果为空,继续等待...")
                time.sleep(self.poll_interval)
                continue

            status = self._status_of(result)

            if status == TaskStatus.SUCCEEDED:
                logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒")
                self._invoke_callback(callback, task_id, callback_kwargs, result, None)
                return result, None, False

            if status == TaskStatus.FAILED:
                error_msg = self._error_message_of(result)
                logger.error(f"任务 {task_id} 失败: {error_msg}")
                self._invoke_callback(callback, task_id, callback_kwargs, {}, error_msg)
                return result, error_msg, False

            if status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
                logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}")
            else:
                logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...")
            time.sleep(self.poll_interval)

    def wait_for_task(
        self,
        task_id: str,
        callback: Optional[Callable] = None,
        callback_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Block until the task finishes (synchronous polling).

        Args:
            task_id: Task ID.
            callback: Optional callback with one of two signatures:
                1. ``(task_id, result, error) -> None``
                2. ``(task_id, output_path, result, error) -> None``
            callback_kwargs: Extra data for the callback; ``output_path`` is
                passed to four-parameter callbacks.

        Returns:
            The task result once the task has succeeded.

        Raises:
            APIError: If a query fails or the task reports failure.
            TimeoutError: If the task does not finish within
                ``max_poll_time`` seconds.
        """
        result, error_msg, timed_out = self._poll_task(
            task_id, callback, callback_kwargs, retry_on_query_error=False
        )
        if timed_out:
            raise TimeoutError(error_msg)
        if error_msg is not None:
            raise APIError(f"任务失败: {error_msg}")
        return result

    def _background_poll(
        self,
        task_id: str,
        callback: Optional[Callable] = None,
        callback_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Thread target that polls a task and reports the outcome via callback.

        Unlike ``wait_for_task`` this never raises: query errors are retried
        and the final outcome is delivered only through ``callback``.

        Args:
            task_id: Task ID.
            callback: Completion callback.
            callback_kwargs: Extra data for the callback.
        """
        try:
            self._poll_task(
                task_id, callback, callback_kwargs, retry_on_query_error=True
            )
        except Exception as e:
            # Thread entry point: log anything that escapes (e.g. an error
            # raised inside the callback) through the module logger.
            logger.error(f"任务 {task_id} 后台轮询异常终止: {e}")

    def create_video_task_async(
        self,
        prompt: str,
        image_url: str,
        gen_params: Optional[str] = None,
        callback: Optional[Callable] = handle_video_result,
        output_path: Optional[str] = None,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> Optional[str]:
        """
        Create a task and return its ID immediately, without blocking.

        The task is polled in a daemon thread; ``callback`` is invoked when
        it succeeds, fails or times out.

        Args:
            prompt: Text prompt (required).
            image_url: Reference image URL (required).
            gen_params: Ready-made generation-parameter string. When given,
                the individual settings below are ignored.
            callback: Optional callback with one of two signatures:
                1. ``(task_id, result, error) -> None``
                2. ``(task_id, output_path, result, error) -> None``
            output_path: Output path handed to four-parameter callbacks.
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Extra request fields, see ``create_video_task``.

        Returns:
            The task ID, or None when the response carried no ID.

        Raises:
            APIError: If creating the task fails.
        """
        task_response = self.create_video_task(
            prompt=prompt,
            image_url=image_url,
            gen_params=gen_params,
            duration=duration,
            ratio=ratio,
            resolution=resolution,
            watermark=watermark,
            camerafixed=camerafixed,
            **kwargs
        )

        task_id = task_response.get("id")
        if not task_id:
            logger.error("任务提交失败,无法启动后台轮询")
            return None

        logger.info(f"任务提交成功,task_id: {task_id},启动后台轮询...")

        callback_kwargs = {"output_path": output_path} if output_path else {}

        poll_thread = threading.Thread(
            target=self._background_poll,
            args=(task_id, callback),
            kwargs={"callback_kwargs": callback_kwargs},
            daemon=True  # daemon thread: ends when the main program exits
        )
        poll_thread.start()

        return task_id

    def create_and_wait(
        self,
        prompt: str,
        image_url: str,
        gen_params: Optional[str] = None,
        callback: Optional[Callable] = None,
        output_path: Optional[str] = None,
        duration: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        watermark: Optional[str] = None,
        camerafixed: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Create a task and block until it finishes.

        Args:
            prompt: Text prompt (required).
            image_url: Reference image URL (required).
            gen_params: Ready-made generation-parameter string. When given,
                the individual settings below are ignored.
            callback: Optional callback with one of two signatures:
                1. ``(task_id, result, error) -> None``
                2. ``(task_id, output_path, result, error) -> None``
            output_path: Output path handed to four-parameter callbacks.
            duration: Video length in seconds (default 4).
            ratio: Aspect ratio (default "16:9").
            resolution: Resolution (default "1080p").
            watermark: Watermark switch (default "false").
            camerafixed: Fixed-camera switch (default "false").
            **kwargs: Extra request fields, see ``create_video_task``.

        Returns:
            The task result once the task has succeeded.

        Raises:
            APIError: If a request fails, no task ID is returned, or the
                task reports failure.
            TimeoutError: If the task does not finish in time.
        """
        task_response = self.create_video_task(
            prompt=prompt,
            image_url=image_url,
            gen_params=gen_params,
            duration=duration,
            ratio=ratio,
            resolution=resolution,
            watermark=watermark,
            camerafixed=camerafixed,
            **kwargs
        )

        task_id = task_response.get("id")
        if not task_id:
            raise APIError("创建任务成功但未返回任务ID")

        logger.info(f"任务已创建,任务ID: {task_id},开始等待完成...")

        callback_kwargs = {"output_path": output_path} if output_path else {}
        return self.wait_for_task(task_id, callback=callback, callback_kwargs=callback_kwargs)

    def get_video_url(self, result: Dict[str, Any]) -> Optional[str]:
        """
        Extract the video URL from a task result.

        Args:
            result: Task result as returned by ``query_video_task`` or
                ``wait_for_task``.

        Returns:
            ``result["content"]["video_url"]`` when ``content`` is a dict,
            otherwise None.
        """
        try:
            content = result.get("content", {})
        except (KeyError, TypeError, AttributeError) as e:
            logger.warning(f"提取视频URL失败: {e}")
            return None
        if isinstance(content, dict):
            return content.get("video_url")
        return None

    def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
        """
        Extract the task status from a task result.

        Args:
            result: Task result.

        Returns:
            The lowercase status string ("" when the key is absent), or None
            when ``result`` or its status value cannot be read as a string.
        """
        try:
            return result.get("status", "").lower()
        except (KeyError, TypeError, AttributeError):
            return None
if __name__ == "__main__":
    # Example usage.
    client = ArkVideoClient()
    output_file = "./output/video3.mp4"

    # Option 1 (recommended): create the task asynchronously. The call
    # returns the task_id at once and polling happens in the background.
    # This variant uses the default generation parameters.
    # NOTE: the image URL below is a pre-signed link (X-Tos-Expires=86400),
    # so it is time-limited; replace it with your own image URL.
    task_id = client.create_video_task_async(
        prompt="图中的女生在街道上散步",
        image_url="https://ark-content-generation-v2-cn-beijing.tos-cn-beijing.volces.com/doubao-seedream-4-5/021766049633300c2c2346a8f7f450117997084af59f03e11d642_0.jpeg?X-Tos-Algorithm=TOS4-HMAC-SHA256&X-Tos-Credential=AKLTYWJkZTExNjA1ZDUyNDc3YzhjNTM5OGIyNjBhNDcyOTQ%2F20251218%2Fcn-beijing%2Ftos%2Frequest&X-Tos-Date=20251218T092054Z&X-Tos-Expires=86400&X-Tos-Signature=a83d797cceaa38226c6489f27892ab9e6651dc7fe84addb37c95fec18358706c&X-Tos-SignedHeaders=host",
        callback=handle_video_result,
        output_path=output_file
    )

    # Option 1b: individual generation parameters (recommended form).
    # task_id = client.create_video_task_async(
    #     prompt="图中的女生在街道上散步",
    #     image_url="https://example.com/image.jpg",
    #     duration=5,
    #     ratio="9:16",
    #     resolution="720p",
    #     callback=handle_video_result,
    #     output_path="./output/video2.mp4"
    # )

    # Option 1c: a ready-made generation-parameter string.
    # task_id = client.create_video_task_async(
    #     prompt="图中的女生在街道上散步",
    #     image_url="https://example.com/image.jpg",
    #     gen_params="--dur 5 --rt 9:16 --rs 720p --wm false --cf false",
    #     callback=handle_video_result,
    #     output_path="./output/video2.mp4"
    # )

    if task_id is None:
        # No background thread was started, so waiting would never end.
        raise SystemExit("任务提交失败,未获得task_id")

    print(f"任务已提交,task_id: {task_id},主流程继续执行...")

    # Optionally wait for the download. The wait is bounded: if the task
    # fails or times out the callback never creates the file, and an
    # unbounded loop would spin forever. The extra 60 seconds is a grace
    # period for the download that follows a successful task.
    deadline = time.time() + client.max_poll_time + 60
    while True:
        if os.path.exists(output_file):
            print("视频下载完成,退出循环...")
            break
        if time.time() > deadline:
            raise SystemExit(f"等待超时,未生成视频文件: {output_file}")
        time.sleep(10)
        print("等待10秒...")
|