| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512 |
- """
- 火山引擎ARK视频生成API异步客户端
- 封装ARK视频生成API的异步调用,提供类型安全的接口
- """
- import os
- import time
- import asyncio
- from typing import Optional, Dict, Any, Callable, Tuple
- from .base_client_async import AsyncAPIClient, APIError
- from .ark_video_client import TaskStatus # 复用同步版本的枚举
- from taskflow.logger import get_logger
- from taskflow.config import get_config
- from examples.video_create.utils.tools import upload_file_to_tos, download_video
- logger = get_logger("api_modules.ark_video_client_async")
- async def handle_video_result(
- task_id: str,
- output_path: str,
- result: Optional[Dict],
- error: Optional[str]
- ) -> None:
- """
- 处理视频生成结果的异步回调函数
-
- Args:
- task_id: 任务ID
- output_path: 视频输出路径
- result: 任务结果(如果成功)
- error: 错误信息(如果失败)
- """
- if error:
- logger.info(f"\n任务 {task_id} 处理失败:{error}")
- else:
- video_url = result.get("content", {}).get("video_url")
- if video_url:
- # 使用 asyncio.to_thread 在后台线程中执行同步的下载函数
- await asyncio.to_thread(download_video, video_url, output_path)
- logger.info(f"生成视频已下载:{output_path}")
- else:
- logger.warning(f"任务 {task_id} 完成但未获取到视频URL")
- class AsyncArkVideoClient(AsyncAPIClient):
- """
- 火山引擎ARK视频生成API异步客户端
-
- 封装ARK视频生成API的异步调用,提供便捷的接口
- """
-
- DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
- DEFAULT_ENDPOINT = "/api/v3/contents/generations/tasks"
- DEFAULT_MODEL = "doubao-seedance-1-0-pro-250528"
-
- def __init__(
- self,
- api_key: Optional[str] = None,
- base_url: Optional[str] = None,
- model: Optional[str] = None,
- timeout: int = 60,
- poll_interval: int = 5,
- max_poll_time: int = 500,
- **kwargs
- ):
- """
- 初始化ARK视频生成API异步客户端
-
- Args:
- api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
- base_url: API基础URL(默认使用官方URL)
- model: 模型名称(如果为None,会尝试从配置中获取)
- timeout: 请求超时时间(秒,默认60秒)
- poll_interval: 轮询间隔(秒,默认5秒)
- max_poll_time: 最大轮询总时间(秒,默认500秒)
- **kwargs: 传递给AsyncAPIClient的其他参数
- """
- # 获取API密钥(优先级:参数 > 环境变量 > 配置)
- if api_key is None:
- api_key = os.getenv("ARK_API_KEY")
- if api_key is None:
- config = get_config()
- api_key = config.get("api.ark.api_key")
-
- if not api_key:
- raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
-
- # 获取base_url(优先级:参数 > 配置 > 默认值)
- if base_url is None:
- config = get_config()
- base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
-
- # 获取model(优先级:参数 > 配置 > 默认值)
- if model is None:
- config = get_config()
- model = config.get("api.ark.video_model", self.DEFAULT_MODEL)
-
- super().__init__(
- base_url=base_url,
- api_key=api_key,
- timeout=timeout,
- **kwargs
- )
-
- # 保存视频生成相关配置
- self.model = model
- self.poll_interval = poll_interval
- self.max_poll_time = max_poll_time
-
- logger.info(f"ARK视频生成API异步客户端初始化完成,模型: {self.model}")
-
- async def create_video_task(
- self,
- prompt: str,
- image_url: str,
- gen_params: str = "",
- **kwargs
- ) -> Dict[str, Any]:
- """
- 异步创建视频生成任务
-
- Args:
- prompt: 视频生成提示词(必填)
- image_url: 参考图片URL(必填,必须是可访问的HTTP/HTTPS URL)
- gen_params: 额外的生成参数(可选,会追加到prompt后面)
- **kwargs: 其他请求参数(会覆盖默认配置)
-
- Returns:
- API响应数据,包含任务ID等信息
-
- Raises:
- APIError: 如果请求失败
- ValueError: 如果参数无效
- """
- if not prompt or not prompt.strip():
- raise ValueError("prompt不能为空")
-
- if not image_url or not image_url.strip():
- raise ValueError("image_url不能为空")
-
- # # 验证image_url是否为URL格式
- # if not image_url.startswith(("http://", "https://")):
- # raise ValueError(
- # f"image_url必须是HTTP/HTTPS URL格式,当前值: {image_url}。"
- # "如果是本地文件路径,请先上传到云存储获取URL。"
- # )
- image_url = upload_file_to_tos(image_url) if "http" not in image_url else image_url
-
- # 构建请求体
- request_data = {
- "model": kwargs.get("model", self.model),
- "content": [
- {
- "type": "text",
- "text": prompt + gen_params
- },
- {
- "type": "image_url",
- "image_url": {
- "url": image_url
- }
- }
- ],
- **{k: v for k, v in kwargs.items() if k != "model"}
- }
-
- logger.info(f"创建异步视频生成任务,模型: {request_data['model']}, 提示词: {prompt[:50]}...")
- logger.info(f"参考图片: {image_url}")
-
- try:
- response = await self.post(
- endpoint=self.DEFAULT_ENDPOINT,
- json=request_data
- )
-
- logger.info(f"视频生成任务创建成功,任务ID: {response.get('id', 'unknown')}")
- return response
-
- except APIError as e:
- logger.error(f"创建视频生成任务失败: {e}")
- raise
-
- async def query_video_task(self, task_id: str) -> Dict[str, Any]:
- """
- 异步查询视频生成任务状态
-
- Args:
- task_id: 任务ID(从create_video_task响应中获取)
-
- Returns:
- 任务状态详情,包含状态、视频URL等信息
-
- Raises:
- APIError: 如果请求失败
- ValueError: 如果参数无效
- """
- if not task_id or not task_id.strip():
- raise ValueError("task_id不能为空")
-
- query_endpoint = f"{self.DEFAULT_ENDPOINT}/{task_id}"
-
- logger.debug(f"查询异步视频生成任务状态,任务ID: {task_id}")
-
- try:
- response = await self.get(endpoint=query_endpoint)
-
- status = response.get("status", "").lower()
- logger.debug(f"任务 {task_id} 状态: {status}")
-
- return response
-
- except APIError as e:
- logger.error(f"查询视频生成任务状态失败: {e}")
- raise
-
- async def wait_for_task(
- self,
- task_id: str,
- callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None
- ) -> Dict[str, Any]:
- """
- 异步等待任务完成(异步轮询)
-
- Args:
- task_id: 任务ID
- callback: 可选的回调函数,参数为 (task_id, result, error)
- 注意:回调函数如果是异步的,需要使用asyncio.create_task调用
-
- Returns:
- 任务完成后的结果
-
- Raises:
- APIError: 如果请求失败
- TimeoutError: 如果任务超时
- """
- start_time = time.time()
-
- while True:
- elapsed = time.time() - start_time
-
- if elapsed > self.max_poll_time:
- error_msg = f"任务超时(超过 {self.max_poll_time} 秒)"
- logger.error(f"任务 {task_id} {error_msg}")
- if callback:
- # 如果回调是协程函数,需要特殊处理
- if asyncio.iscoroutinefunction(callback):
- await callback(task_id, {}, error_msg)
- else:
- callback(task_id, {}, error_msg)
- raise TimeoutError(error_msg)
-
- # 查询任务状态
- result = await self.query_video_task(task_id)
-
- if not result:
- logger.warning(f"任务 {task_id} 查询结果为空,继续等待...")
- await asyncio.sleep(self.poll_interval)
- continue
-
- # 解析状态
- status = result.get("status", "").lower()
-
- if status == TaskStatus.SUCCEEDED:
- logger.info(f"任务 {task_id} 完成,耗时: {int(elapsed)}秒")
- if callback:
- if asyncio.iscoroutinefunction(callback):
- await callback(task_id, result, None)
- else:
- callback(task_id, result, None)
- return result
-
- elif status == TaskStatus.FAILED:
- error_msg = result.get("error", {}).get("message", "未知错误")
- logger.error(f"任务 {task_id} 失败: {error_msg}")
- if callback:
- if asyncio.iscoroutinefunction(callback):
- await callback(task_id, {}, error_msg)
- else:
- callback(task_id, {}, error_msg)
- raise APIError(f"任务失败: {error_msg}")
-
- elif status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
- logger.info(f"任务 {task_id} 处理中({int(elapsed)}秒),状态: {status}")
- await asyncio.sleep(self.poll_interval)
-
- else:
- logger.warning(f"任务 {task_id} 未知状态: {status},继续等待...")
- await asyncio.sleep(self.poll_interval)
-
- async def create_and_wait(
- self,
- prompt: str,
- image_url: str,
- gen_params: str = "",
- callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None,
- **kwargs
- ) -> Dict[str, Any]:
- """
- 异步创建视频生成任务并等待完成(便捷方法)
-
- Args:
- prompt: 视频生成提示词(必填)
- image_url: 参考图片URL(必填)
- gen_params: 额外的生成参数(可选)
- callback: 可选的回调函数,参数为 (task_id, result, error)
- **kwargs: 其他请求参数
-
- Returns:
- 任务完成后的结果
-
- Raises:
- APIError: 如果请求失败
- TimeoutError: 如果任务超时
- """
- # 创建任务
- task_response = await self.create_video_task(
- prompt=prompt,
- image_url=image_url,
- gen_params=gen_params,
- **kwargs
- )
-
- task_id = task_response.get("id")
- if not task_id:
- raise APIError("创建任务成功但未返回任务ID")
-
- logger.info(f"任务已创建,任务ID: {task_id},开始等待完成...")
-
- # 等待任务完成
- return await self.wait_for_task(task_id, callback=callback)
-
- async def create_video_task_async(
- self,
- prompt: str,
- image_url: str,
- gen_params: str = "",
- callback: Optional[Callable] = handle_video_result,
- output_path: Optional[str] = None,
- **kwargs
- ) -> Tuple[Optional[str], Optional[asyncio.Task]]:
- """
- 创建视频生成任务并立即返回task_id和后台任务对象(不阻塞主流程)
-
- 任务会在后台异步任务中轮询,完成后调用回调函数。
- 调用者可以通过返回的任务对象等待任务完成。
-
- Args:
- prompt: 视频生成提示词(必填)
- image_url: 参考图片URL(必填)
- gen_params: 额外的生成参数(可选,会追加到prompt后面)
- callback: 可选的回调函数,可以是以下两种签名之一:
- 1. (task_id, result, error) -> None
- 2. (task_id, output_path, result, error) -> None
- 注意:如果是异步函数,需要使用 async def 定义
- output_path: 视频输出路径(可选,会传递给回调函数)
- **kwargs: 其他请求参数(会覆盖默认配置)
-
- Returns:
- 元组 (task_id, background_task):
- - task_id: 任务ID,如果创建失败则返回None
- - background_task: 后台异步任务对象,可以用于等待任务完成
- 如果创建失败则返回None
-
- Raises:
- APIError: 如果创建任务失败
- """
- # 创建任务
- task_response = await self.create_video_task(
- prompt=prompt,
- image_url=image_url,
- gen_params=gen_params,
- **kwargs
- )
-
- task_id = task_response.get("id")
- if not task_id:
- logger.error("任务提交失败,无法启动后台轮询")
- return None, None
-
- logger.info(f"任务提交成功,task_id: {task_id},启动后台异步轮询...")
-
- # 定义后台异步任务包装函数
- async def _background_wait():
- """后台异步任务:等待任务完成并调用回调"""
- try:
- # 等待任务完成
- result = await self.wait_for_task(task_id)
-
- # 调用回调函数
- if callback:
- import inspect
- sig = inspect.signature(callback)
- param_count = len(sig.parameters)
-
- if asyncio.iscoroutinefunction(callback):
- # 异步回调函数
- if param_count == 4:
- await callback(task_id, output_path or "", result, None)
- else:
- await callback(task_id, result, None)
- else:
- # 同步回调函数
- if param_count == 4:
- # 4参数版本:(task_id, output_path, result, error)
- callback(task_id, output_path or "", result, None)
- else:
- # 3参数版本:(task_id, result, error)
- callback(task_id, result, None)
- except Exception as e:
- error_msg = str(e)
- logger.error(f"后台异步任务处理失败: {error_msg}")
- if callback:
- import inspect
- sig = inspect.signature(callback)
- param_count = len(sig.parameters)
-
- if asyncio.iscoroutinefunction(callback):
- if param_count == 4:
- await callback(task_id, output_path or "", {}, error_msg)
- else:
- await callback(task_id, {}, error_msg)
- else:
- if param_count == 4:
- callback(task_id, output_path or "", {}, error_msg)
- else:
- callback(task_id, {}, error_msg)
-
- # 启动后台异步任务并返回任务对象,以便调用者可以等待
- background_task = asyncio.create_task(_background_wait())
-
- # 返回任务ID和任务对象(使用元组)
- return task_id, background_task
-
- def get_video_url(self, result: Dict[str, Any]) -> Optional[str]:
- """
- 从任务结果中提取视频URL
-
- Args:
- result: 任务结果(从query_video_task或wait_for_task返回)
-
- Returns:
- 视频URL,如果不存在则返回None
- """
- try:
- content = result.get("content", {})
- if isinstance(content, dict):
- return content.get("video_url")
- return None
- except (KeyError, TypeError, AttributeError) as e:
- logger.warning(f"提取视频URL失败: {e}")
- return None
-
- def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
- """
- 从任务结果中提取任务状态
-
- Args:
- result: 任务结果
-
- Returns:
- 任务状态字符串,如果不存在则返回None
- """
- try:
- return result.get("status", "").lower()
- except (KeyError, TypeError, AttributeError):
- return None
- async def main():
- # 示例用法
- async with AsyncArkVideoClient() as client:
- # 方式1:创建任务并等待完成(阻塞)
- try:
- result = await client.create_and_wait(
- prompt="图中的女生在街道上散步",
- image_url="https://example.com/image.jpg", # 必须是可访问的URL
- gen_params=" --dur 4"
- )
- video_url = client.get_video_url(result)
- print(f"视频生成成功,URL: {video_url}")
- except Exception as e:
- print(f"视频生成失败: {e}")
-
- # 方式2:异步创建任务(不阻塞,后台处理)
- # task_id = await client.create_video_task_async(
- # prompt="图中的女生在街道上散步",
- # image_url="https://example.com/image.jpg",
- # gen_params=" --dur 4",
- # callback=handle_video_result,
- # output_path="./output/video.mp4"
- # )
- # print(f"任务已提交,task_id: {task_id},主流程继续执行...")
- # # 主流程可以继续执行其他操作
- # await asyncio.sleep(10) # 等待一段时间
-
- # 方式3:创建任务后手动查询
- # task_response = await client.create_video_task(
- # prompt="图中的女生在街道上散步",
- # image_url="https://example.com/image.jpg",
- # gen_params=" --dur 4"
- # )
- # task_id = task_response.get("id")
- # result = await client.wait_for_task(task_id)
- # video_url = client.get_video_url(result)
- # print(f"视频URL: {video_url}")
- if __name__ == "__main__":
- asyncio.run(main())
|