""" 火山引擎ARK图片生成API异步客户端 封装ARK图片生成API的异步调用,提供类型安全的接口 """ import os import base64 import asyncio import aiohttp from typing import Optional, Dict, Any, List, Callable from pathlib import Path from .base_client_async import AsyncAPIClient, APIError, RetryConfig from .ark_image_client import encode_image_to_base64 # 复用同步版本的编码函数 from taskflow.logger import get_logger from taskflow.config import get_config logger = get_logger("api_modules.ark_image_client_async") def handle_image_result( task_id: str, output_path: str, result: Optional[Dict], error: Optional[str] ) -> None: """处理图片生成结果的回调函数""" if error: logger.info(f"\n任务 {task_id} 处理失败:{error}") else: from examples.video_create.utils.tools import download_image image_url = result.get("data", [{}])[0].get("url") if result.get("data") else None if image_url: download_image(image_url, output_path) logger.info(f"生成图片已下载:{output_path}") else: logger.warning(f"任务 {task_id} 完成但未获取到图片URL") class AsyncArkImageClient(AsyncAPIClient): """ 火山引擎ARK图片生成API异步客户端 封装ARK图片生成API的异步调用,提供便捷的接口 """ DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com" DEFAULT_ENDPOINT = "/api/v3/images/generations" DEFAULT_MODEL = "doubao-seedream-4-0-250828" def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None, timeout: int = 120, sequential_generation: str = "disabled", response_format: str = "url", stream: bool = False, watermark: bool = False, **kwargs ): """ 初始化ARK图片生成API异步客户端 Args: api_key: API密钥(如果为None,会尝试从环境变量或配置中获取) base_url: API基础URL(默认使用官方URL) model: 模型名称(如果为None,会尝试从配置中获取) timeout: 请求超时时间(秒,默认120秒) sequential_generation: 序列生成开关(默认"disabled") response_format: 响应格式(默认"url") stream: 流式响应开关(默认False) watermark: 水印开关(默认False) **kwargs: 传递给AsyncAPIClient的其他参数 """ # 获取API密钥(优先级:参数 > 环境变量 > 配置) if api_key is None: api_key = os.getenv("ARK_API_KEY") if api_key is None: config = get_config() api_key = config.get("api.ark.api_key") if not api_key: raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供") # 获取base_url(优先级:参数 > 配置 > 默认值) if base_url is None: config = get_config() base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL) # 获取model(优先级:参数 > 配置 > 默认值) if model is None: config = get_config() model = config.get("api.ark.image_model", self.DEFAULT_MODEL) # 创建自定义重试配置 retry_config = RetryConfig( max_retries=3, backoff_factor=3.0, retry_on_status=(500, 502, 429, 503, 504), retry_on_exception=(aiohttp.ClientError, asyncio.TimeoutError) ) super().__init__( base_url=base_url, api_key=api_key, timeout=timeout, retry_config=retry_config, **kwargs ) # 保存图片生成相关配置 self.model = model self.sequential_generation = sequential_generation self.response_format = response_format self.stream = stream self.watermark = watermark logger.info(f"ARK图片生成API异步客户端初始化完成,模型: {self.model}") async def create_image_task( self, prompt: str, size: str = "1440x2560", reference_image: Optional[List[str]] = None, **kwargs ) -> Dict[str, Any]: """ 异步创建图片生成任务 Args: prompt: 图片生成提示词(必填) size: 图片尺寸,格式为"宽x高"(默认"1440x2560") reference_image: 参考图片列表,可以是: - 本地文件路径列表(会自动编码为base64) - HTTP/HTTPS URL列表 - base64编码的字符串列表(包含data:image/...;base64,前缀) 如果为None,则生成无参考图片列表 **kwargs: 其他请求参数(会覆盖默认配置) Returns: API响应数据,包含生成的图片信息 Raises: APIError: 如果请求失败 ValueError: 如果参数无效 """ if not prompt or not prompt.strip(): raise ValueError("prompt不能为空") # 构建请求体 request_data = { "model": kwargs.get("model", self.model), "prompt": prompt, "size": size, "sequential_image_generation": kwargs.get("sequential_generation", self.sequential_generation), "response_format": kwargs.get("response_format", self.response_format), "stream": kwargs.get("stream", self.stream), "watermark": kwargs.get("watermark", self.watermark), } # 如果有参考图片,添加到请求中 if reference_image and len(reference_image) > 0: # 判断是本地文件路径还是URL if reference_image[0].startswith(("http://", "https://")): # URL格式,直接使用 request_data["image"] = reference_image elif reference_image[0].startswith("data:image"): # 已经是base64格式,直接使用 request_data["image"] = reference_image else: # 本地文件路径,编码为base64(使用线程池避免阻塞事件循环) loop = asyncio.get_event_loop() request_data["image"] = [await loop.run_in_executor(None, encode_image_to_base64, image) for image in reference_image] logger.info(f"创建异步图片生成任务,模型: {request_data['model']}, 尺寸: {size}") if reference_image and len(reference_image) > 0: logger.info(f"使用参考图片: {reference_image[:50]}...") else: logger.info("未使用参考图片") # 记录请求数据(用于调试) logger.debug(f"请求数据: {request_data}") try: response = await self.post( endpoint=self.DEFAULT_ENDPOINT, json=request_data ) logger.info("图片生成任务创建成功") return response except APIError as e: logger.error(f"创建图片生成任务失败: {e}") logger.error(f"请求数据: {request_data}") if e.response: logger.error(f"API错误响应: {e.response}") raise async def query_image_task(self, task_id: str) -> Dict[str, Any]: """ 查询图片生成任务状态 注意:图片生成API通常是同步的,此方法主要用于接口一致性。 如果任务已完成,直接返回结果;否则返回待处理状态。 Args: task_id: 任务ID(对于图片生成,这通常是响应中的某个标识符) Returns: 任务状态详情,包含图片URL等信息 Raises: APIError: 如果请求失败 ValueError: 如果参数无效 """ # 图片生成API通常是同步的,不需要查询 # 此方法主要用于接口一致性 logger.warning("图片生成API是同步的,query_image_task方法可能不适用") raise NotImplementedError("图片生成API是同步的,不需要查询任务状态") async def wait_for_task( self, task_id: str, callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None ) -> Dict[str, Any]: """ 等待任务完成 注意:图片生成API通常是同步的,此方法主要用于接口一致性。 对于图片生成,任务通常在create_image_task时就已经完成。 Args: task_id: 任务ID callback: 可选的回调函数,参数为 (task_id, result, error) Returns: 任务完成后的结果 Raises: APIError: 如果请求失败 """ # 图片生成API通常是同步的,不需要等待 # 此方法主要用于接口一致性 logger.warning("图片生成API是同步的,wait_for_task方法可能不适用") raise NotImplementedError("图片生成API是同步的,不需要等待任务完成") async def create_image_task_async( self, prompt: str, size: str = "1440x2560", reference_image: Optional[List[str]] = None, callback: Optional[Callable] = handle_image_result, output_path: Optional[str] = None, **kwargs ) -> Optional[str]: """ 创建图片生成任务并在后台任务中处理(不阻塞主流程) 任务会在后台异步任务中执行,完成后调用回调函数。 Args: prompt: 图片生成提示词(必填) size: 图片尺寸,格式为"宽x高"(默认"1440x2560") reference_image: 参考图片列表(可选) callback: 可选的回调函数,可以是以下两种签名之一: 1. (task_id, result, error) -> None 2. (task_id, output_path, result, error) -> None output_path: 图片输出路径(可选,会传递给回调函数) **kwargs: 其他请求参数(会覆盖默认配置) Returns: 任务ID(task_id),如果创建失败则返回None Raises: APIError: 如果创建任务失败 """ # 生成一个简单的任务ID(基于时间戳) import time task_id = f"img_{int(time.time() * 1000)}" async def _background_task(): """后台任务:执行图片生成并调用回调""" try: # 创建图片生成任务 result = await self.create_image_task( prompt=prompt, size=size, reference_image=reference_image, **kwargs ) # 调用回调函数 if callback: import inspect sig = inspect.signature(callback) param_count = len(sig.parameters) if param_count == 4: # 4参数版本:(task_id, output_path, result, error) callback(task_id, output_path or "", result, None) else: # 3参数版本:(task_id, result, error) callback(task_id, result, None) except Exception as e: error_msg = str(e) logger.error(f"后台图片生成任务失败: {error_msg}") if callback: import inspect sig = inspect.signature(callback) param_count = len(sig.parameters) if param_count == 4: callback(task_id, output_path or "", {}, error_msg) else: callback(task_id, {}, error_msg) # 启动后台任务 asyncio.create_task(_background_task()) logger.info(f"图片生成任务已提交,task_id: {task_id},后台处理中...") return task_id async def create_and_wait( self, prompt: str, size: str = "1440x2560", reference_image: Optional[List[str]] = None, callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None, **kwargs ) -> Dict[str, Any]: """ 创建图片生成任务并等待完成(便捷方法) Args: prompt: 图片生成提示词(必填) size: 图片尺寸,格式为"宽x高"(默认"1440x2560") reference_image: 参考图片列表(可选) callback: 可选的回调函数,参数为 (task_id, result, error) **kwargs: 其他请求参数 Returns: 任务完成后的结果 Raises: APIError: 如果请求失败 """ # 创建任务(图片生成是同步的,所以直接返回结果) result = await self.create_image_task( prompt=prompt, size=size, reference_image=reference_image, **kwargs ) # 生成一个简单的任务ID import time task_id = f"img_{int(time.time() * 1000)}" logger.info(f"图片生成任务完成,任务ID: {task_id}") # 调用回调函数(如果提供) if callback: if asyncio.iscoroutinefunction(callback): await callback(task_id, result, None) else: callback(task_id, result, None) return result def get_image_url(self, response: Dict[str, Any]) -> Optional[str]: """ 从响应中提取图片URL Args: response: API响应数据(从create_image_task或create_and_wait返回) Returns: 图片URL,如果不存在则返回None """ try: if "data" in response and isinstance(response["data"], list): if len(response["data"]) > 0: image_data = response["data"][0] if isinstance(image_data, dict): # 根据response_format返回相应字段 if self.response_format == "url": return image_data.get("url") elif self.response_format == "b64_json": return image_data.get("b64_json") return None except (KeyError, TypeError, IndexError) as e: logger.warning(f"提取图片URL失败: {e}") return None def get_image_urls(self, response: Dict[str, Any]) -> List[str]: """ 从响应中提取所有图片URL Args: response: API响应数据 Returns: 图片URL列表 """ urls = [] try: if "data" in response and isinstance(response["data"], list): for image_data in response["data"]: if isinstance(image_data, dict): if self.response_format == "url": url = image_data.get("url") elif self.response_format == "b64_json": url = image_data.get("b64_json") else: url = image_data.get("url") or image_data.get("b64_json") if url: urls.append(url) return urls except (KeyError, TypeError, IndexError) as e: logger.warning(f"提取图片URL列表失败: {e}") return [] def get_task_status(self, result: Dict[str, Any]) -> Optional[str]: """ 从任务结果中提取任务状态 Args: result: 任务结果(从create_image_task或create_and_wait返回) Returns: 任务状态字符串,如果不存在则返回None """ try: # 图片生成API通常是同步的,如果返回了数据,则认为成功 if result.get("data"): return "succeeded" return None except (KeyError, TypeError, AttributeError): return None # 保持向后兼容:generate_image作为便捷方法 async def generate_image( self, prompt: str, size: str = "1440x2560", reference_image: Optional[List[str]] = None, **kwargs ) -> Dict[str, Any]: """ 异步生成图片(便捷方法,等同于create_and_wait) Args: prompt: 图片生成提示词(必填) size: 图片尺寸,格式为"宽x高"(默认"1440x2560") reference_image: 参考图片列表(可选) **kwargs: 其他请求参数 Returns: API响应数据,包含生成的图片信息 Raises: APIError: 如果请求失败 ValueError: 如果参数无效 """ return await self.create_and_wait( prompt=prompt, size=size, reference_image=reference_image, **kwargs ) async def main(): # 示例用法 async with AsyncArkImageClient() as client: # 方式1:创建任务并等待完成(推荐) try: result = await client.create_and_wait( prompt="图1中的女生穿着图2中的衣服在街道上散步,目视前方,手牵着一只小狗", reference_image=["./data/image/face.jpg", "./data/image/cloth.jpg"], size="1440x2560" ) image_url = client.get_image_url(result) print(f"图片生成成功,URL: {image_url}") except Exception as e: print(f"图片生成失败: {e}") # 方式2:使用便捷方法generate_image # response = await client.generate_image( # prompt="一个美丽的风景", # size="1440x2560" # ) # image_url = client.get_image_url(response) # print(f"图片URL: {image_url}") # 方式3:异步创建任务(后台处理) # task_id = await client.create_image_task_async( # prompt="一个美丽的风景", # callback=handle_image_result, # output_path="./output/image.jpg" # ) # print(f"任务已提交,task_id: {task_id}") if __name__ == "__main__": asyncio.run(main())