| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- """
- 火山引擎ARK图片生成API客户端
- 封装ARK图片生成API的调用,提供类型安全的接口
- """
- import os
- import base64
- from typing import Optional, Dict, Any, List
- from pathlib import Path
- from .base_client import APIClient, APIError
- from taskflow.logger import get_logger
- from taskflow.config import get_config
- logger = get_logger("api_modules.ark_image_client")
def encode_image_to_base64(image_path: str) -> str:
    """Encode a local image file as a base64 ``data:`` URL.

    The MIME type is chosen from the file extension (case-insensitive).
    Extensions that are not in the lookup table fall back to ``image/jpeg``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``data:<mime-type>;base64,<payload>``.

    Raises:
        ValueError: If the path is invalid or the file cannot be opened or
            read. The underlying exception is chained as ``__cause__``.
    """
    # Extension -> MIME type. BMP and TIFF are listed explicitly so they are
    # no longer mislabelled as JPEG by the fallback below.
    mime_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".tif": "image/tiff",
        ".tiff": "image/tiff",
    }

    # Only path handling and file I/O can fail here; keep the try block
    # limited to them and translate the failure into a ValueError.
    try:
        ext = Path(image_path).suffix.lower()
        with open(image_path, "rb") as image_file:
            image_data = image_file.read()
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"编码图片失败: {e}")
        raise ValueError(f"无法读取或编码图片文件: {image_path}") from e

    image_base64 = base64.b64encode(image_data).decode("utf-8")
    mime_type = mime_types.get(ext, "image/jpeg")
    return f"data:{mime_type};base64,{image_base64}"
class ArkImageClient(APIClient):
    """Client for the Volcengine ARK image generation API.

    Builds the JSON request body for the image generation endpoint, sends it
    through the ``post`` method inherited from ``APIClient`` and offers
    helpers for pulling the generated image values out of the response.
    """

    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
    DEFAULT_ENDPOINT = "/api/v3/images/generations"
    DEFAULT_MODEL = "doubao-seedream-4-0-250828"

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        timeout: int = 120,
        sequential_generation: str = "disabled",
        response_format: str = "url",
        stream: bool = False,
        watermark: bool = False,
        **kwargs
    ):
        """Initialise the client and resolve credentials and defaults.

        Args:
            api_key: API key. When empty, the ``ARK_API_KEY`` environment
                variable is tried, then the ``api.ark.api_key`` config entry.
            base_url: API base URL. When ``None``, the ``api.ark.base_url``
                config entry is used, falling back to ``DEFAULT_BASE_URL``.
            model: Model name. When ``None``, the ``api.ark.image_model``
                config entry is used, falling back to ``DEFAULT_MODEL``.
            timeout: Request timeout in seconds, forwarded to ``APIClient``.
            sequential_generation: Default value sent as
                ``sequential_image_generation``.
            response_format: Default value sent as ``response_format``; also
                selects the preferred field in ``get_image_url(s)``.
            stream: Default value sent as ``stream``.
            watermark: Default value sent as ``watermark``.
            **kwargs: Extra keyword arguments forwarded to ``APIClient``.

        Raises:
            ValueError: If no API key can be found in any source.
        """
        # Priority: argument > environment > config. Using ``or`` (instead of
        # ``is None``) lets an empty argument or an empty environment
        # variable fall through to the next source instead of failing early.
        api_key = (
            api_key
            or os.getenv("ARK_API_KEY")
            or get_config().get("api.ark.api_key")
        )
        if not api_key:
            raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")

        # Priority: argument > config > class default. The trailing ``or``
        # guards against a config entry that exists but is empty.
        if base_url is None:
            base_url = (
                get_config().get("api.ark.base_url", self.DEFAULT_BASE_URL)
                or self.DEFAULT_BASE_URL
            )
        if model is None:
            model = (
                get_config().get("api.ark.image_model", self.DEFAULT_MODEL)
                or self.DEFAULT_MODEL
            )

        super().__init__(
            base_url=base_url,
            api_key=api_key,
            timeout=timeout,
            **kwargs
        )

        # Per-client defaults used when building each generation request.
        self.model = model
        self.sequential_generation = sequential_generation
        self.response_format = response_format
        self.stream = stream
        self.watermark = watermark

        logger.info(f"ARK图片生成API客户端初始化完成,模型: {self.model}")

    @staticmethod
    def _resolve_reference_image(image: str) -> str:
        """Return *image* unchanged if it is a URL or data URL, else encode the local file."""
        if image.startswith(("http://", "https://", "data:image")):
            return image
        return encode_image_to_base64(image)

    def generate_image(
        self,
        prompt: str,
        size: str = "1440x2560",
        reference_image: Optional[List[str]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Send an image generation request.

        Args:
            prompt: Generation prompt; must be non-empty.
            size: Value sent as ``size`` (for example ``"1440x2560"``).
            reference_image: Optional reference images. Each entry may be an
                HTTP/HTTPS URL, a ``data:image/...;base64,`` string, or a
                local file path (which is base64-encoded automatically).
                Entries are resolved individually, so kinds may be mixed.
                A single string is accepted and treated as a one-item list.
            **kwargs: Per-request overrides. ``model``,
                ``sequential_generation``, ``response_format``, ``stream`` and
                ``watermark`` replace the client defaults. Any other keyword
                is copied verbatim into the request body.

        Returns:
            Whatever ``self.post`` returns for the request.

        Raises:
            APIError: Re-raised from ``self.post`` when the request fails.
            ValueError: If ``prompt`` is empty or a local reference image
                cannot be read.
        """
        if not prompt:
            raise ValueError("prompt不能为空")

        # Work on a copy so recognised overrides can be popped and the rest
        # forwarded instead of being silently dropped.
        extra_params = dict(kwargs)
        request_data = {
            "model": extra_params.pop("model", self.model),
            "prompt": prompt,
            "size": size,
            "sequential_image_generation": extra_params.pop(
                "sequential_generation", self.sequential_generation
            ),
            "response_format": extra_params.pop(
                "response_format", self.response_format
            ),
            "stream": extra_params.pop("stream", self.stream),
            "watermark": extra_params.pop("watermark", self.watermark),
        }
        request_data.update(extra_params)

        if reference_image:
            # A bare string would otherwise be iterated character by character.
            if isinstance(reference_image, str):
                reference_image = [reference_image]
            # Classify every entry on its own rather than trusting the first.
            request_data["image"] = [
                self._resolve_reference_image(image) for image in reference_image
            ]

        logger.info(f"发送图片生成请求,模型: {request_data['model']}, 尺寸: {size}")
        if reference_image:
            # Truncate each entry so base64 payloads do not flood the log.
            preview = [str(image)[:50] for image in reference_image]
            logger.info(f"使用参考图片: {preview}...")

        try:
            response = self.post(
                endpoint=self.DEFAULT_ENDPOINT,
                json=request_data
            )
        except APIError as e:
            logger.error(f"图片生成请求失败: {e}")
            raise

        logger.info("图片生成请求成功")
        return response

    @staticmethod
    def _response_items(response: Dict[str, Any]) -> List[Any]:
        """Return ``response["data"]`` when it is a list, otherwise an empty list."""
        if "data" in response and isinstance(response["data"], list):
            return response["data"]
        return []

    def _extract_image_value(self, image_data: Any) -> Optional[str]:
        """Return the image value of one response item, preferring ``self.response_format``."""
        if not isinstance(image_data, dict):
            return None
        # Try the field matching the client default first, then the other
        # one, so a per-request ``response_format`` override still resolves.
        if self.response_format == "b64_json":
            fields = ("b64_json", "url")
        else:
            fields = ("url", "b64_json")
        for field in fields:
            value = image_data.get(field)
            if value:
                return value
        return None

    def get_image_url(self, response: Dict[str, Any]) -> Optional[str]:
        """Extract the image value of the first item in a response.

        Args:
            response: Response data as returned by ``generate_image``.

        Returns:
            The ``url`` or ``b64_json`` value of the first entry in
            ``response["data"]`` (the field matching ``self.response_format``
            is preferred), or ``None`` when it is not available.
        """
        try:
            items = self._response_items(response)
        except (KeyError, TypeError, IndexError) as e:
            logger.warning(f"提取图片URL失败: {e}")
            return None

        if not items:
            return None
        return self._extract_image_value(items[0])

    def get_image_urls(self, response: Dict[str, Any]) -> List[str]:
        """Extract the image values of every item in a response.

        Args:
            response: Response data as returned by ``generate_image``.

        Returns:
            The ``url`` or ``b64_json`` values of all entries in
            ``response["data"]`` that carry one; an empty list when there are
            none or the response cannot be inspected.
        """
        try:
            items = self._response_items(response)
        except (KeyError, TypeError, IndexError) as e:
            logger.warning(f"提取图片URL列表失败: {e}")
            return []

        values = (self._extract_image_value(item) for item in items)
        return [value for value in values if value]
if __name__ == "__main__":
    # Manual smoke test: generate one image from two local reference
    # pictures and print the value extracted from the response.
    demo_client = ArkImageClient()
    demo_request = {
        "prompt": "图1中的女生穿着图2中的衣服在街道上散步",
        "reference_image": ["./data/image/face.jpg", "./data/image/cloth.jpg"],
        "size": "1440x2560",
    }
    demo_response = demo_client.generate_image(**demo_request)
    print(demo_client.get_image_url(demo_response))
|