""" 火山引擎ARK图片生成API客户端 封装ARK图片生成API的调用,提供类型安全的接口 """ import os import base64 from typing import Optional, Dict, Any, List from pathlib import Path from .base_client import APIClient, APIError from taskflow.logger import get_logger from taskflow.config import get_config logger = get_logger("api_modules.ark_image_client") def encode_image_to_base64(image_path: str) -> str: """ 将本地图片文件编码为base64格式 Args: image_path: 图片文件路径 Returns: base64编码的图片字符串(包含data:image/...;base64,前缀) """ try: with open(image_path, 'rb') as image_file: image_data = image_file.read() image_base64 = base64.b64encode(image_data).decode('utf-8') # 根据文件扩展名确定MIME类型 ext = Path(image_path).suffix.lower() mime_types = { '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png', '.gif': 'image/gif', '.webp': 'image/webp' } mime_type = mime_types.get(ext, 'image/jpeg') return f"data:{mime_type};base64,{image_base64}" except Exception as e: logger.error(f"编码图片失败: {e}") raise ValueError(f"无法读取或编码图片文件: {image_path}") from e class ArkImageClient(APIClient): """ 火山引擎ARK图片生成API客户端 封装ARK图片生成API的调用,提供便捷的接口 """ DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com" DEFAULT_ENDPOINT = "/api/v3/images/generations" DEFAULT_MODEL = "doubao-seedream-4-0-250828" def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None, timeout: int = 120, sequential_generation: str = "disabled", response_format: str = "url", stream: bool = False, watermark: bool = False, **kwargs ): """ 初始化ARK图片生成API客户端 Args: api_key: API密钥(如果为None,会尝试从环境变量或配置中获取) base_url: API基础URL(默认使用官方URL) model: 模型名称(如果为None,会尝试从配置中获取) timeout: 请求超时时间(秒,默认120秒) sequential_generation: 序列生成开关(默认"disabled") response_format: 响应格式(默认"url") stream: 流式响应开关(默认False) watermark: 水印开关(默认False) **kwargs: 传递给APIClient的其他参数 """ # 获取API密钥(优先级:参数 > 环境变量 > 配置) if api_key is None: api_key = os.getenv("ARK_API_KEY") if api_key is None: config = get_config() api_key = config.get("api.ark.api_key") if not api_key: raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供") # 获取base_url(优先级:参数 > 配置 > 默认值) if base_url is None: config = get_config() base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL) # 获取model(优先级:参数 > 配置 > 默认值) if model is None: config = get_config() model = config.get("api.ark.image_model", self.DEFAULT_MODEL) super().__init__( base_url=base_url, api_key=api_key, timeout=timeout, **kwargs ) # 保存图片生成相关配置 self.model = model self.sequential_generation = sequential_generation self.response_format = response_format self.stream = stream self.watermark = watermark logger.info(f"ARK图片生成API客户端初始化完成,模型: {self.model}") def generate_image( self, prompt: str, size: str = "1440x2560", reference_image: Optional[List[str]] = None, **kwargs ) -> Dict[str, Any]: """ 生成图片 Args: prompt: 图片生成提示词(必填) size: 图片尺寸,格式为"宽x高"(默认"1440x2560") reference_image: 参考图片列表,可以是: - 本地文件路径列表(会自动编码为base64) - HTTP/HTTPS URL列表 - base64编码的字符串列表(包含data:image/...;base64,前缀) 如果为None,则生成无参考图片列表 **kwargs: 其他请求参数(会覆盖默认配置) Returns: API响应数据,包含生成的图片信息 Raises: APIError: 如果请求失败 ValueError: 如果参数无效 """ if not prompt: raise ValueError("prompt不能为空") # 构建请求体 request_data = { "model": kwargs.get("model", self.model), "prompt": prompt, "size": size, "sequential_image_generation": kwargs.get("sequential_generation", self.sequential_generation), "response_format": kwargs.get("response_format", self.response_format), "stream": kwargs.get("stream", self.stream), "watermark": kwargs.get("watermark", self.watermark), } # 如果有参考图片,添加到请求中 if reference_image: # 判断是本地文件路径还是URL if reference_image[0].startswith(("http://", "https://")): # URL格式,直接使用 request_data["image"] = reference_image elif reference_image[0].startswith("data:image"): # 已经是base64格式,直接使用 request_data["image"] = reference_image else: # 本地文件路径,编码为base64 request_data["image"] = [encode_image_to_base64(image) for image in reference_image] logger.info(f"发送图片生成请求,模型: {request_data['model']}, 尺寸: {size}") if reference_image: logger.info(f"使用参考图片: {reference_image[:50]}...") try: response = self.post( endpoint=self.DEFAULT_ENDPOINT, json=request_data ) logger.info("图片生成请求成功") return response except APIError as e: logger.error(f"图片生成请求失败: {e}") raise def get_image_url(self, response: Dict[str, Any]) -> Optional[str]: """ 从响应中提取图片URL Args: response: API响应数据 Returns: 图片URL,如果不存在则返回None """ try: if "data" in response and isinstance(response["data"], list): if len(response["data"]) > 0: image_data = response["data"][0] if isinstance(image_data, dict): # 根据response_format返回相应字段 if self.response_format == "url": return image_data.get("url") elif self.response_format == "b64_json": return image_data.get("b64_json") return None except (KeyError, TypeError, IndexError) as e: logger.warning(f"提取图片URL失败: {e}") return None def get_image_urls(self, response: Dict[str, Any]) -> List[str]: """ 从响应中提取所有图片URL Args: response: API响应数据 Returns: 图片URL列表 """ urls = [] try: if "data" in response and isinstance(response["data"], list): for image_data in response["data"]: if isinstance(image_data, dict): if self.response_format == "url": url = image_data.get("url") elif self.response_format == "b64_json": url = image_data.get("b64_json") else: url = image_data.get("url") or image_data.get("b64_json") if url: urls.append(url) return urls except (KeyError, TypeError, IndexError) as e: logger.warning(f"提取图片URL列表失败: {e}") return [] if __name__ == "__main__": client = ArkImageClient() response = client.generate_image( prompt = "图1中的女生穿着图2中的衣服在街道上散步", reference_image = ["./data/image/face.jpg", "./data/image/cloth.jpg"], size = "1440x2560" ) image_url = client.get_image_url(response) print(image_url)