ark_image_client.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. """
  2. 火山引擎ARK图片生成API客户端
  3. 封装ARK图片生成API的调用,提供类型安全的接口
  4. """
  5. import os
  6. import base64
  7. from typing import Optional, Dict, Any, List
  8. from pathlib import Path
  9. from .base_client import APIClient, APIError
  10. from taskflow.logger import get_logger
  11. from taskflow.config import get_config
  12. logger = get_logger("api_modules.ark_image_client")
  13. def encode_image_to_base64(image_path: str) -> str:
  14. """
  15. 将本地图片文件编码为base64格式
  16. Args:
  17. image_path: 图片文件路径
  18. Returns:
  19. base64编码的图片字符串(包含data:image/...;base64,前缀)
  20. """
  21. try:
  22. with open(image_path, 'rb') as image_file:
  23. image_data = image_file.read()
  24. image_base64 = base64.b64encode(image_data).decode('utf-8')
  25. # 根据文件扩展名确定MIME类型
  26. ext = Path(image_path).suffix.lower()
  27. mime_types = {
  28. '.jpg': 'image/jpeg',
  29. '.jpeg': 'image/jpeg',
  30. '.png': 'image/png',
  31. '.gif': 'image/gif',
  32. '.webp': 'image/webp'
  33. }
  34. mime_type = mime_types.get(ext, 'image/jpeg')
  35. return f"data:{mime_type};base64,{image_base64}"
  36. except Exception as e:
  37. logger.error(f"编码图片失败: {e}")
  38. raise ValueError(f"无法读取或编码图片文件: {image_path}") from e
  39. class ArkImageClient(APIClient):
  40. """
  41. 火山引擎ARK图片生成API客户端
  42. 封装ARK图片生成API的调用,提供便捷的接口
  43. """
  44. DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
  45. DEFAULT_ENDPOINT = "/api/v3/images/generations"
  46. DEFAULT_MODEL = "doubao-seedream-4-0-250828"
  47. def __init__(
  48. self,
  49. api_key: Optional[str] = None,
  50. base_url: Optional[str] = None,
  51. model: Optional[str] = None,
  52. timeout: int = 120,
  53. sequential_generation: str = "disabled",
  54. response_format: str = "url",
  55. stream: bool = False,
  56. watermark: bool = False,
  57. **kwargs
  58. ):
  59. """
  60. 初始化ARK图片生成API客户端
  61. Args:
  62. api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
  63. base_url: API基础URL(默认使用官方URL)
  64. model: 模型名称(如果为None,会尝试从配置中获取)
  65. timeout: 请求超时时间(秒,默认120秒)
  66. sequential_generation: 序列生成开关(默认"disabled")
  67. response_format: 响应格式(默认"url")
  68. stream: 流式响应开关(默认False)
  69. watermark: 水印开关(默认False)
  70. **kwargs: 传递给APIClient的其他参数
  71. """
  72. # 获取API密钥(优先级:参数 > 环境变量 > 配置)
  73. if api_key is None:
  74. api_key = os.getenv("ARK_API_KEY")
  75. if api_key is None:
  76. config = get_config()
  77. api_key = config.get("api.ark.api_key")
  78. if not api_key:
  79. raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
  80. # 获取base_url(优先级:参数 > 配置 > 默认值)
  81. if base_url is None:
  82. config = get_config()
  83. base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
  84. # 获取model(优先级:参数 > 配置 > 默认值)
  85. if model is None:
  86. config = get_config()
  87. model = config.get("api.ark.image_model", self.DEFAULT_MODEL)
  88. super().__init__(
  89. base_url=base_url,
  90. api_key=api_key,
  91. timeout=timeout,
  92. **kwargs
  93. )
  94. # 保存图片生成相关配置
  95. self.model = model
  96. self.sequential_generation = sequential_generation
  97. self.response_format = response_format
  98. self.stream = stream
  99. self.watermark = watermark
  100. logger.info(f"ARK图片生成API客户端初始化完成,模型: {self.model}")
  101. def generate_image(
  102. self,
  103. prompt: str,
  104. size: str = "1440x2560",
  105. reference_image: Optional[List[str]] = None,
  106. **kwargs
  107. ) -> Dict[str, Any]:
  108. """
  109. 生成图片
  110. Args:
  111. prompt: 图片生成提示词(必填)
  112. size: 图片尺寸,格式为"宽x高"(默认"1440x2560")
  113. reference_image: 参考图片列表,可以是:
  114. - 本地文件路径列表(会自动编码为base64)
  115. - HTTP/HTTPS URL列表
  116. - base64编码的字符串列表(包含data:image/...;base64,前缀)
  117. 如果为None,则生成无参考图片列表
  118. **kwargs: 其他请求参数(会覆盖默认配置)
  119. Returns:
  120. API响应数据,包含生成的图片信息
  121. Raises:
  122. APIError: 如果请求失败
  123. ValueError: 如果参数无效
  124. """
  125. if not prompt:
  126. raise ValueError("prompt不能为空")
  127. # 构建请求体
  128. request_data = {
  129. "model": kwargs.get("model", self.model),
  130. "prompt": prompt,
  131. "size": size,
  132. "sequential_image_generation": kwargs.get("sequential_generation", self.sequential_generation),
  133. "response_format": kwargs.get("response_format", self.response_format),
  134. "stream": kwargs.get("stream", self.stream),
  135. "watermark": kwargs.get("watermark", self.watermark),
  136. }
  137. # 如果有参考图片,添加到请求中
  138. if reference_image:
  139. # 判断是本地文件路径还是URL
  140. if reference_image[0].startswith(("http://", "https://")):
  141. # URL格式,直接使用
  142. request_data["image"] = reference_image
  143. elif reference_image[0].startswith("data:image"):
  144. # 已经是base64格式,直接使用
  145. request_data["image"] = reference_image
  146. else:
  147. # 本地文件路径,编码为base64
  148. request_data["image"] = [encode_image_to_base64(image) for image in reference_image]
  149. logger.info(f"发送图片生成请求,模型: {request_data['model']}, 尺寸: {size}")
  150. if reference_image:
  151. logger.info(f"使用参考图片: {reference_image[:50]}...")
  152. try:
  153. response = self.post(
  154. endpoint=self.DEFAULT_ENDPOINT,
  155. json=request_data
  156. )
  157. logger.info("图片生成请求成功")
  158. return response
  159. except APIError as e:
  160. logger.error(f"图片生成请求失败: {e}")
  161. raise
  162. def get_image_url(self, response: Dict[str, Any]) -> Optional[str]:
  163. """
  164. 从响应中提取图片URL
  165. Args:
  166. response: API响应数据
  167. Returns:
  168. 图片URL,如果不存在则返回None
  169. """
  170. try:
  171. if "data" in response and isinstance(response["data"], list):
  172. if len(response["data"]) > 0:
  173. image_data = response["data"][0]
  174. if isinstance(image_data, dict):
  175. # 根据response_format返回相应字段
  176. if self.response_format == "url":
  177. return image_data.get("url")
  178. elif self.response_format == "b64_json":
  179. return image_data.get("b64_json")
  180. return None
  181. except (KeyError, TypeError, IndexError) as e:
  182. logger.warning(f"提取图片URL失败: {e}")
  183. return None
  184. def get_image_urls(self, response: Dict[str, Any]) -> List[str]:
  185. """
  186. 从响应中提取所有图片URL
  187. Args:
  188. response: API响应数据
  189. Returns:
  190. 图片URL列表
  191. """
  192. urls = []
  193. try:
  194. if "data" in response and isinstance(response["data"], list):
  195. for image_data in response["data"]:
  196. if isinstance(image_data, dict):
  197. if self.response_format == "url":
  198. url = image_data.get("url")
  199. elif self.response_format == "b64_json":
  200. url = image_data.get("b64_json")
  201. else:
  202. url = image_data.get("url") or image_data.get("b64_json")
  203. if url:
  204. urls.append(url)
  205. return urls
  206. except (KeyError, TypeError, IndexError) as e:
  207. logger.warning(f"提取图片URL列表失败: {e}")
  208. return []
  209. if __name__ == "__main__":
  210. client = ArkImageClient()
  211. response = client.generate_image(
  212. prompt = "图1中的女生穿着图2中的衣服在街道上散步",
  213. reference_image = ["./data/image/face.jpg", "./data/image/cloth.jpg"],
  214. size = "1440x2560"
  215. )
  216. image_url = client.get_image_url(response)
  217. print(image_url)