| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512 |
- """
- 火山引擎ARK图片生成API异步客户端
- 封装ARK图片生成API的异步调用,提供类型安全的接口
- """
- import os
- import base64
- import asyncio
- import aiohttp
- from typing import Optional, Dict, Any, List, Callable
- from pathlib import Path
- from .base_client_async import AsyncAPIClient, APIError, RetryConfig
- from .ark_image_client import encode_image_to_base64 # 复用同步版本的编码函数
- from taskflow.logger import get_logger
- from taskflow.config import get_config
- logger = get_logger("api_modules.ark_image_client_async")
def handle_image_result(
    task_id: str,
    output_path: str,
    result: Optional[Dict],
    error: Optional[str]
) -> None:
    """Default completion callback: download the first generated image.

    Args:
        task_id: Identifier of the generation task (used only for logging).
        output_path: Destination path handed to ``download_image``.
        result: Response payload of the generation request. The first entry
            of its ``"data"`` list is expected to carry a ``"url"`` field.
            May be ``None`` or empty when the task failed.
        error: Error description when the task failed, otherwise ``None``.
    """
    if error:
        logger.error(f"\n任务 {task_id} 处理失败:{error}")
        return

    # Tolerate a missing or malformed payload instead of raising
    # AttributeError/KeyError/IndexError out of the callback.
    entries = result.get("data") if isinstance(result, dict) else None
    first_entry = entries[0] if isinstance(entries, list) and entries else None
    image_url = first_entry.get("url") if isinstance(first_entry, dict) else None

    if not image_url:
        logger.warning(f"任务 {task_id} 完成但未获取到图片URL")
        return

    # Imported lazily so the helper is only required when there is
    # actually something to download.
    from examples.video_create.utils.tools import download_image

    download_image(image_url, output_path)
    logger.info(f"生成图片已下载:{output_path}")
class AsyncArkImageClient(AsyncAPIClient):
    """Asynchronous client for the Volcengine ARK image generation API.

    Wraps the image generation endpoint behind convenience coroutines and
    helpers for extracting image URLs / base64 payloads from responses.
    """

    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
    DEFAULT_ENDPOINT = "/api/v3/images/generations"
    DEFAULT_MODEL = "doubao-seedream-4-0-250828"

    # Reference images starting with one of these prefixes are sent as-is;
    # everything else is treated as a local file path and base64-encoded.
    _PASSTHROUGH_PREFIXES = ("http://", "https://", "data:image")

    # Maximum number of characters of a reference image string written to logs.
    _LOG_PREVIEW_CHARS = 50

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: int = 120,
        sequential_generation: str = "disabled",
        response_format: str = "url",
        stream: bool = False,
        watermark: bool = False,
        **kwargs
    ):
        """Initialise the client.

        Args:
            api_key: API key. When ``None`` it is read from the ``ARK_API_KEY``
                environment variable, then from the ``api.ark.api_key`` config
                entry.
            base_url: API base URL. When ``None`` the ``api.ark.base_url``
                config entry is used, falling back to ``DEFAULT_BASE_URL``.
            model: Model name. When ``None`` the ``api.ark.image_model`` config
                entry is used, falling back to ``DEFAULT_MODEL``.
            timeout: Request timeout forwarded to ``AsyncAPIClient``
                (seconds, per the original documentation).
            sequential_generation: Default value of the
                ``sequential_image_generation`` request field.
            response_format: Default ``response_format`` request field
                (``"url"`` or ``"b64_json"``).
            stream: Default ``stream`` request field.
            watermark: Default ``watermark`` request field.
            **kwargs: Extra arguments forwarded to ``AsyncAPIClient``.

        Raises:
            ValueError: If no API key can be resolved.
        """
        if api_key is None:
            api_key = os.getenv("ARK_API_KEY")

        # Load the configuration once, and only if some value still needs it.
        config = None
        if api_key is None or base_url is None or model is None:
            config = get_config()

        if api_key is None:
            api_key = config.get("api.ark.api_key")

        if not api_key:
            raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")

        if base_url is None:
            base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)

        if model is None:
            model = config.get("api.ark.image_model", self.DEFAULT_MODEL)

        # Retry transient server-side failures, rate limiting and
        # network-level errors.
        retry_config = RetryConfig(
            max_retries=3,
            backoff_factor=3.0,
            retry_on_status=(500, 502, 429, 503, 504),
            retry_on_exception=(aiohttp.ClientError, asyncio.TimeoutError)
        )
        super().__init__(
            base_url=base_url,
            api_key=api_key,
            timeout=timeout,
            retry_config=retry_config,
            **kwargs
        )

        # Per-client defaults for the request body.
        self.model = model
        self.sequential_generation = sequential_generation
        self.response_format = response_format
        self.stream = stream
        self.watermark = watermark

        # Strong references to fire-and-forget tasks; the event loop itself
        # only keeps weak references, so unreferenced tasks may be collected
        # before they finish.
        self._background_tasks = set()

        logger.info(f"ARK图片生成API异步客户端初始化完成,模型: {self.model}")

    @staticmethod
    def _new_task_id() -> str:
        """Return a locally generated task id that is unique per call.

        A random suffix is appended to the millisecond timestamp because
        several tasks can be submitted within the same millisecond.
        """
        import time
        import uuid

        return f"img_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"

    @classmethod
    def _summarize_request(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``request_data`` that is safe to write to logs.

        Entries of the ``"image"`` field can be whole base64-encoded files;
        they are truncated so log lines stay readable.
        """
        if "image" not in request_data:
            return request_data

        limit = cls._LOG_PREVIEW_CHARS
        summary = dict(request_data)
        summary["image"] = [
            image if len(image) <= limit
            else f"{image[:limit]}...<{len(image)} chars>"
            for image in request_data["image"]
        ]
        return summary

    async def _prepare_reference_images(self, reference_image: List[str]) -> List[str]:
        """Convert reference images into the form sent to the API.

        Each entry is classified on its own: URLs and ``data:image`` strings
        pass through unchanged, anything else is treated as a local file path
        and encoded with ``encode_image_to_base64`` in the default executor so
        file I/O does not block the event loop. Input order is preserved.
        """
        loop = asyncio.get_running_loop()

        async def _resolve(item: str) -> str:
            text = str(item)
            if text.startswith(self._PASSTHROUGH_PREFIXES):
                return text
            return await loop.run_in_executor(None, encode_image_to_base64, text)

        return list(await asyncio.gather(*(_resolve(item) for item in reference_image)))

    async def _invoke_callback(
        self,
        callback: Callable,
        task_id: str,
        output_path: Optional[str],
        result: Dict[str, Any],
        error: Optional[str]
    ) -> None:
        """Call ``callback`` using whichever supported signature it declares.

        Callbacks declaring exactly four parameters receive
        ``(task_id, output_path, result, error)``; all others receive
        ``(task_id, result, error)``. Awaitable return values are awaited,
        so coroutine callbacks are supported.
        """
        import inspect

        if len(inspect.signature(callback).parameters) == 4:
            outcome = callback(task_id, output_path or "", result, error)
        else:
            outcome = callback(task_id, result, error)

        if inspect.isawaitable(outcome):
            await outcome

    async def create_image_task(
        self,
        prompt: str,
        size: str = "1440x2560",
        reference_image: Optional[List[str]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Send an image generation request and return the API response.

        Args:
            prompt: Generation prompt; must not be empty or whitespace-only.
            size: Image size in ``"<width>x<height>"`` form.
            reference_image: Optional reference images. Each entry may be a
                local file path (encoded to base64 automatically), an
                HTTP/HTTPS URL, or a ``data:image/...;base64,`` string.
                Entries of different kinds may be mixed; a single string is
                accepted as a one-element list.
            **kwargs: Per-call overrides for ``model``,
                ``sequential_generation``, ``response_format``, ``stream``
                and ``watermark``. Other keys are ignored.

        Returns:
            The response returned by ``self.post``.

        Raises:
            APIError: If the request fails (logged, then re-raised).
            ValueError: If ``prompt`` is empty.
        """
        if not prompt or not prompt.strip():
            raise ValueError("prompt不能为空")

        request_data = {
            "model": kwargs.get("model", self.model),
            "prompt": prompt,
            "size": size,
            "sequential_image_generation": kwargs.get("sequential_generation", self.sequential_generation),
            "response_format": kwargs.get("response_format", self.response_format),
            "stream": kwargs.get("stream", self.stream),
            "watermark": kwargs.get("watermark", self.watermark),
        }

        # A bare string would otherwise be iterated character by character.
        if isinstance(reference_image, str):
            reference_image = [reference_image]

        if reference_image:
            request_data["image"] = await self._prepare_reference_images(reference_image)

        logger.info(f"创建异步图片生成任务,模型: {request_data['model']}, 尺寸: {size}")
        if reference_image:
            # Truncate every entry: data URIs can be megabytes long.
            previews = [str(item)[:self._LOG_PREVIEW_CHARS] for item in reference_image]
            logger.info(f"使用参考图片: {previews}...")
        else:
            logger.info("未使用参考图片")

        logger.debug(f"请求数据: {self._summarize_request(request_data)}")

        try:
            response = await self.post(
                endpoint=self.DEFAULT_ENDPOINT,
                json=request_data
            )
        except APIError as e:
            logger.error(f"创建图片生成任务失败: {e}")
            logger.error(f"请求数据: {self._summarize_request(request_data)}")
            if e.response:
                logger.error(f"API错误响应: {e.response}")
            raise

        logger.info("图片生成任务创建成功")
        return response

    async def query_image_task(self, task_id: str) -> Dict[str, Any]:
        """Placeholder kept for interface parity with other task clients.

        The result is already returned by ``create_image_task``, so there is
        nothing to query.

        Args:
            task_id: Task identifier (unused).

        Raises:
            NotImplementedError: Always.
        """
        logger.warning("图片生成API是同步的,query_image_task方法可能不适用")
        raise NotImplementedError("图片生成API是同步的,不需要查询任务状态")

    async def wait_for_task(
        self,
        task_id: str,
        callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None
    ) -> Dict[str, Any]:
        """Placeholder kept for interface parity with other task clients.

        The result is already returned by ``create_image_task``, so there is
        nothing to wait for.

        Args:
            task_id: Task identifier (unused).
            callback: Unused.

        Raises:
            NotImplementedError: Always.
        """
        logger.warning("图片生成API是同步的,wait_for_task方法可能不适用")
        raise NotImplementedError("图片生成API是同步的,不需要等待任务完成")

    async def create_image_task_async(
        self,
        prompt: str,
        size: str = "1440x2560",
        reference_image: Optional[List[str]] = None,
        callback: Optional[Callable] = handle_image_result,
        output_path: Optional[str] = None,
        **kwargs
    ) -> Optional[str]:
        """Submit a generation request that runs as a background task.

        The request is executed in an ``asyncio`` task; this coroutine returns
        as soon as the task has been scheduled. Failures of the request are
        not raised here: they are logged and reported to ``callback`` through
        its ``error`` argument (with an empty ``result`` dict).

        Args:
            prompt: Generation prompt.
            size: Image size in ``"<width>x<height>"`` form.
            reference_image: Optional reference images
                (see ``create_image_task``).
            callback: Completion callback, plain or coroutine function, with
                one of the signatures ``(task_id, result, error)`` or
                ``(task_id, output_path, result, error)``. ``None`` disables
                the callback.
            output_path: Output path handed to four-parameter callbacks
                (an empty string is passed when it is ``None``).
            **kwargs: Per-call overrides forwarded to ``create_image_task``.

        Returns:
            The locally generated task id.
        """
        task_id = self._new_task_id()

        async def _background_task():
            """Run the request, then report the outcome exactly once."""
            try:
                result = await self.create_image_task(
                    prompt=prompt,
                    size=size,
                    reference_image=reference_image,
                    **kwargs
                )
                error_msg = None
            except Exception as e:
                result, error_msg = {}, str(e)
                logger.error(f"后台图片生成任务失败: {error_msg}")

            if not callback:
                return

            # Kept outside the try-block above so that an exception raised by
            # the callback is not mistaken for a generation failure (which
            # would invoke the callback a second time).
            try:
                await self._invoke_callback(callback, task_id, output_path, result, error_msg)
            except Exception as e:
                logger.error(f"任务 {task_id} 回调执行失败: {e}")

        # Hold a strong reference until the task is done.
        pending = getattr(self, "_background_tasks", None)
        if pending is None:
            pending = self._background_tasks = set()
        task = asyncio.create_task(_background_task())
        pending.add(task)
        task.add_done_callback(pending.discard)

        logger.info(f"图片生成任务已提交,task_id: {task_id},后台处理中...")

        return task_id

    async def create_and_wait(
        self,
        prompt: str,
        size: str = "1440x2560",
        reference_image: Optional[List[str]] = None,
        callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate an image and return the response once it is available.

        Args:
            prompt: Generation prompt.
            size: Image size in ``"<width>x<height>"`` form.
            reference_image: Optional reference images
                (see ``create_image_task``).
            callback: Optional plain or coroutine function called as
                ``callback(task_id, result, None)`` after a successful
                request. It is not called when the request fails.
            **kwargs: Per-call overrides forwarded to ``create_image_task``.

        Returns:
            The API response.

        Raises:
            APIError: If the request fails.
            ValueError: If ``prompt`` is empty.
        """
        result = await self.create_image_task(
            prompt=prompt,
            size=size,
            reference_image=reference_image,
            **kwargs
        )

        # The id is generated locally, purely for logging and the callback.
        task_id = self._new_task_id()
        logger.info(f"图片生成任务完成,任务ID: {task_id}")

        if callback:
            import inspect

            outcome = callback(task_id, result, None)
            if inspect.isawaitable(outcome):
                await outcome

        return result

    def _extract_image_value(self, image_data: Any) -> Optional[str]:
        """Return the image payload stored in one entry of ``response["data"]``.

        The field matching ``self.response_format`` is preferred; the other
        known field is used as a fallback so that responses to requests with
        a per-call ``response_format`` override are still readable.
        """
        if not isinstance(image_data, dict):
            return None

        if getattr(self, "response_format", "url") == "b64_json":
            preferred, fallback = "b64_json", "url"
        else:
            preferred, fallback = "url", "b64_json"
        return image_data.get(preferred) or image_data.get(fallback)

    def get_image_url(self, response: Dict[str, Any]) -> Optional[str]:
        """Return the URL (or base64 payload) of the first generated image.

        Args:
            response: Response returned by ``create_image_task`` or
                ``create_and_wait``.

        Returns:
            The value taken from the first entry of ``response["data"]``, or
            ``None`` when the response holds no usable image.
        """
        try:
            entries = response.get("data")
        except AttributeError as e:
            logger.warning(f"提取图片URL失败: {e}")
            return None

        if isinstance(entries, list) and entries:
            return self._extract_image_value(entries[0])
        return None

    def get_image_urls(self, response: Dict[str, Any]) -> List[str]:
        """Return the URLs (or base64 payloads) of all generated images.

        Args:
            response: Response returned by ``create_image_task`` or
                ``create_and_wait``.

        Returns:
            One value per entry of ``response["data"]`` that carries an
            image, in response order; empty when there is none.
        """
        try:
            entries = response.get("data")
        except AttributeError as e:
            logger.warning(f"提取图片URL列表失败: {e}")
            return []

        if not isinstance(entries, list):
            return []

        values = (self._extract_image_value(entry) for entry in entries)
        return [value for value in values if value]

    def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
        """Derive a task status from a generation response.

        Args:
            result: Response returned by ``create_image_task`` or
                ``create_and_wait``.

        Returns:
            ``"succeeded"`` when ``result`` has a non-empty ``"data"`` field,
            otherwise ``None``.
        """
        try:
            # The response either carries image data or it does not; there
            # is no intermediate state to report.
            if result.get("data"):
                return "succeeded"
            return None
        except (KeyError, TypeError, AttributeError):
            return None

    # Kept for backward compatibility: alias of ``create_and_wait``.
    async def generate_image(
        self,
        prompt: str,
        size: str = "1440x2560",
        reference_image: Optional[List[str]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate an image; convenience wrapper around ``create_and_wait``.

        Args:
            prompt: Generation prompt.
            size: Image size in ``"<width>x<height>"`` form.
            reference_image: Optional reference images
                (see ``create_image_task``).
            **kwargs: Forwarded to ``create_and_wait``.

        Returns:
            The API response.

        Raises:
            APIError: If the request fails.
            ValueError: If ``prompt`` is empty.
        """
        return await self.create_and_wait(
            prompt=prompt,
            size=size,
            reference_image=reference_image,
            **kwargs
        )
async def main():
    """Example: generate one image from two local reference pictures."""
    demo_prompt = "图1中的女生穿着图2中的衣服在街道上散步,目视前方,手牵着一只小狗"
    demo_references = ["./data/image/face.jpg", "./data/image/cloth.jpg"]

    async with AsyncArkImageClient() as client:
        # Option 1 (recommended): create the task and wait for the result.
        try:
            response = await client.create_and_wait(
                prompt=demo_prompt,
                size="1440x2560",
                reference_image=demo_references,
            )
            print(f"图片生成成功,URL: {client.get_image_url(response)}")
        except Exception as exc:
            print(f"图片生成失败: {exc}")

        # Option 2: the generate_image convenience wrapper.
        # response = await client.generate_image(
        #     prompt="一个美丽的风景",
        #     size="1440x2560"
        # )
        # image_url = client.get_image_url(response)
        # print(f"图片URL: {image_url}")

        # Option 3: submit the task for background processing.
        # task_id = await client.create_image_task_async(
        #     prompt="一个美丽的风景",
        #     callback=handle_image_result,
        #     output_path="./output/image.jpg"
        # )
        # print(f"任务已提交,task_id: {task_id}")


if __name__ == "__main__":
    # Run the example only when the module is executed as a script.
    asyncio.run(main())
|