ark_image_client_async.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. """
  2. 火山引擎ARK图片生成API异步客户端
  3. 封装ARK图片生成API的异步调用,提供类型安全的接口
  4. """
  5. import os
  6. import base64
  7. import asyncio
  8. import aiohttp
  9. from typing import Optional, Dict, Any, List, Callable
  10. from pathlib import Path
  11. from .base_client_async import AsyncAPIClient, APIError, RetryConfig
  12. from .ark_image_client import encode_image_to_base64 # 复用同步版本的编码函数
  13. from taskflow.logger import get_logger
  14. from taskflow.config import get_config
  15. logger = get_logger("api_modules.ark_image_client_async")
  16. def handle_image_result(
  17. task_id: str,
  18. output_path: str,
  19. result: Optional[Dict],
  20. error: Optional[str]
  21. ) -> None:
  22. """处理图片生成结果的回调函数"""
  23. if error:
  24. logger.info(f"\n任务 {task_id} 处理失败:{error}")
  25. else:
  26. from examples.video_create.utils.tools import download_image
  27. image_url = result.get("data", [{}])[0].get("url") if result.get("data") else None
  28. if image_url:
  29. download_image(image_url, output_path)
  30. logger.info(f"生成图片已下载:{output_path}")
  31. else:
  32. logger.warning(f"任务 {task_id} 完成但未获取到图片URL")
  33. class AsyncArkImageClient(AsyncAPIClient):
  34. """
  35. 火山引擎ARK图片生成API异步客户端
  36. 封装ARK图片生成API的异步调用,提供便捷的接口
  37. """
  38. DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com"
  39. DEFAULT_ENDPOINT = "/api/v3/images/generations"
  40. DEFAULT_MODEL = "doubao-seedream-4-0-250828"
  41. def __init__(
  42. self,
  43. api_key: Optional[str] = None,
  44. base_url: Optional[str] = None,
  45. model: Optional[str] = None,
  46. timeout: int = 120,
  47. sequential_generation: str = "disabled",
  48. response_format: str = "url",
  49. stream: bool = False,
  50. watermark: bool = False,
  51. **kwargs
  52. ):
  53. """
  54. 初始化ARK图片生成API异步客户端
  55. Args:
  56. api_key: API密钥(如果为None,会尝试从环境变量或配置中获取)
  57. base_url: API基础URL(默认使用官方URL)
  58. model: 模型名称(如果为None,会尝试从配置中获取)
  59. timeout: 请求超时时间(秒,默认120秒)
  60. sequential_generation: 序列生成开关(默认"disabled")
  61. response_format: 响应格式(默认"url")
  62. stream: 流式响应开关(默认False)
  63. watermark: 水印开关(默认False)
  64. **kwargs: 传递给AsyncAPIClient的其他参数
  65. """
  66. # 获取API密钥(优先级:参数 > 环境变量 > 配置)
  67. if api_key is None:
  68. api_key = os.getenv("ARK_API_KEY")
  69. if api_key is None:
  70. config = get_config()
  71. api_key = config.get("api.ark.api_key")
  72. if not api_key:
  73. raise ValueError("ARK API密钥未提供,请通过参数、环境变量ARK_API_KEY或配置文件提供")
  74. # 获取base_url(优先级:参数 > 配置 > 默认值)
  75. if base_url is None:
  76. config = get_config()
  77. base_url = config.get("api.ark.base_url", self.DEFAULT_BASE_URL)
  78. # 获取model(优先级:参数 > 配置 > 默认值)
  79. if model is None:
  80. config = get_config()
  81. model = config.get("api.ark.image_model", self.DEFAULT_MODEL)
  82. # 创建自定义重试配置
  83. retry_config = RetryConfig(
  84. max_retries=3,
  85. backoff_factor=3.0,
  86. retry_on_status=(500, 502, 429, 503, 504),
  87. retry_on_exception=(aiohttp.ClientError, asyncio.TimeoutError)
  88. )
  89. super().__init__(
  90. base_url=base_url,
  91. api_key=api_key,
  92. timeout=timeout,
  93. retry_config=retry_config,
  94. **kwargs
  95. )
  96. # 保存图片生成相关配置
  97. self.model = model
  98. self.sequential_generation = sequential_generation
  99. self.response_format = response_format
  100. self.stream = stream
  101. self.watermark = watermark
  102. logger.info(f"ARK图片生成API异步客户端初始化完成,模型: {self.model}")
  103. async def create_image_task(
  104. self,
  105. prompt: str,
  106. size: str = "1440x2560",
  107. reference_image: Optional[List[str]] = None,
  108. **kwargs
  109. ) -> Dict[str, Any]:
  110. """
  111. 异步创建图片生成任务
  112. Args:
  113. prompt: 图片生成提示词(必填)
  114. size: 图片尺寸,格式为"宽x高"(默认"1440x2560")
  115. reference_image: 参考图片列表,可以是:
  116. - 本地文件路径列表(会自动编码为base64)
  117. - HTTP/HTTPS URL列表
  118. - base64编码的字符串列表(包含data:image/...;base64,前缀)
  119. 如果为None,则生成无参考图片列表
  120. **kwargs: 其他请求参数(会覆盖默认配置)
  121. Returns:
  122. API响应数据,包含生成的图片信息
  123. Raises:
  124. APIError: 如果请求失败
  125. ValueError: 如果参数无效
  126. """
  127. if not prompt or not prompt.strip():
  128. raise ValueError("prompt不能为空")
  129. # 构建请求体
  130. request_data = {
  131. "model": kwargs.get("model", self.model),
  132. "prompt": prompt,
  133. "size": size,
  134. "sequential_image_generation": kwargs.get("sequential_generation", self.sequential_generation),
  135. "response_format": kwargs.get("response_format", self.response_format),
  136. "stream": kwargs.get("stream", self.stream),
  137. "watermark": kwargs.get("watermark", self.watermark),
  138. }
  139. # 如果有参考图片,添加到请求中
  140. if reference_image and len(reference_image) > 0:
  141. # 判断是本地文件路径还是URL
  142. if reference_image[0].startswith(("http://", "https://")):
  143. # URL格式,直接使用
  144. request_data["image"] = reference_image
  145. elif reference_image[0].startswith("data:image"):
  146. # 已经是base64格式,直接使用
  147. request_data["image"] = reference_image
  148. else:
  149. # 本地文件路径,编码为base64(使用线程池避免阻塞事件循环)
  150. loop = asyncio.get_event_loop()
  151. request_data["image"] = [await loop.run_in_executor(None, encode_image_to_base64, image) for image in reference_image]
  152. logger.info(f"创建异步图片生成任务,模型: {request_data['model']}, 尺寸: {size}")
  153. if reference_image and len(reference_image) > 0:
  154. logger.info(f"使用参考图片: {reference_image[:50]}...")
  155. else:
  156. logger.info("未使用参考图片")
  157. # 记录请求数据(用于调试)
  158. logger.debug(f"请求数据: {request_data}")
  159. try:
  160. response = await self.post(
  161. endpoint=self.DEFAULT_ENDPOINT,
  162. json=request_data
  163. )
  164. logger.info("图片生成任务创建成功")
  165. return response
  166. except APIError as e:
  167. logger.error(f"创建图片生成任务失败: {e}")
  168. logger.error(f"请求数据: {request_data}")
  169. if e.response:
  170. logger.error(f"API错误响应: {e.response}")
  171. raise
  172. async def query_image_task(self, task_id: str) -> Dict[str, Any]:
  173. """
  174. 查询图片生成任务状态
  175. 注意:图片生成API通常是同步的,此方法主要用于接口一致性。
  176. 如果任务已完成,直接返回结果;否则返回待处理状态。
  177. Args:
  178. task_id: 任务ID(对于图片生成,这通常是响应中的某个标识符)
  179. Returns:
  180. 任务状态详情,包含图片URL等信息
  181. Raises:
  182. APIError: 如果请求失败
  183. ValueError: 如果参数无效
  184. """
  185. # 图片生成API通常是同步的,不需要查询
  186. # 此方法主要用于接口一致性
  187. logger.warning("图片生成API是同步的,query_image_task方法可能不适用")
  188. raise NotImplementedError("图片生成API是同步的,不需要查询任务状态")
  189. async def wait_for_task(
  190. self,
  191. task_id: str,
  192. callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None
  193. ) -> Dict[str, Any]:
  194. """
  195. 等待任务完成
  196. 注意:图片生成API通常是同步的,此方法主要用于接口一致性。
  197. 对于图片生成,任务通常在create_image_task时就已经完成。
  198. Args:
  199. task_id: 任务ID
  200. callback: 可选的回调函数,参数为 (task_id, result, error)
  201. Returns:
  202. 任务完成后的结果
  203. Raises:
  204. APIError: 如果请求失败
  205. """
  206. # 图片生成API通常是同步的,不需要等待
  207. # 此方法主要用于接口一致性
  208. logger.warning("图片生成API是同步的,wait_for_task方法可能不适用")
  209. raise NotImplementedError("图片生成API是同步的,不需要等待任务完成")
  210. async def create_image_task_async(
  211. self,
  212. prompt: str,
  213. size: str = "1440x2560",
  214. reference_image: Optional[List[str]] = None,
  215. callback: Optional[Callable] = handle_image_result,
  216. output_path: Optional[str] = None,
  217. **kwargs
  218. ) -> Optional[str]:
  219. """
  220. 创建图片生成任务并在后台任务中处理(不阻塞主流程)
  221. 任务会在后台异步任务中执行,完成后调用回调函数。
  222. Args:
  223. prompt: 图片生成提示词(必填)
  224. size: 图片尺寸,格式为"宽x高"(默认"1440x2560")
  225. reference_image: 参考图片列表(可选)
  226. callback: 可选的回调函数,可以是以下两种签名之一:
  227. 1. (task_id, result, error) -> None
  228. 2. (task_id, output_path, result, error) -> None
  229. output_path: 图片输出路径(可选,会传递给回调函数)
  230. **kwargs: 其他请求参数(会覆盖默认配置)
  231. Returns:
  232. 任务ID(task_id),如果创建失败则返回None
  233. Raises:
  234. APIError: 如果创建任务失败
  235. """
  236. # 生成一个简单的任务ID(基于时间戳)
  237. import time
  238. task_id = f"img_{int(time.time() * 1000)}"
  239. async def _background_task():
  240. """后台任务:执行图片生成并调用回调"""
  241. try:
  242. # 创建图片生成任务
  243. result = await self.create_image_task(
  244. prompt=prompt,
  245. size=size,
  246. reference_image=reference_image,
  247. **kwargs
  248. )
  249. # 调用回调函数
  250. if callback:
  251. import inspect
  252. sig = inspect.signature(callback)
  253. param_count = len(sig.parameters)
  254. if param_count == 4:
  255. # 4参数版本:(task_id, output_path, result, error)
  256. callback(task_id, output_path or "", result, None)
  257. else:
  258. # 3参数版本:(task_id, result, error)
  259. callback(task_id, result, None)
  260. except Exception as e:
  261. error_msg = str(e)
  262. logger.error(f"后台图片生成任务失败: {error_msg}")
  263. if callback:
  264. import inspect
  265. sig = inspect.signature(callback)
  266. param_count = len(sig.parameters)
  267. if param_count == 4:
  268. callback(task_id, output_path or "", {}, error_msg)
  269. else:
  270. callback(task_id, {}, error_msg)
  271. # 启动后台任务
  272. asyncio.create_task(_background_task())
  273. logger.info(f"图片生成任务已提交,task_id: {task_id},后台处理中...")
  274. return task_id
  275. async def create_and_wait(
  276. self,
  277. prompt: str,
  278. size: str = "1440x2560",
  279. reference_image: Optional[List[str]] = None,
  280. callback: Optional[Callable[[str, Dict[str, Any], Optional[str]], None]] = None,
  281. **kwargs
  282. ) -> Dict[str, Any]:
  283. """
  284. 创建图片生成任务并等待完成(便捷方法)
  285. Args:
  286. prompt: 图片生成提示词(必填)
  287. size: 图片尺寸,格式为"宽x高"(默认"1440x2560")
  288. reference_image: 参考图片列表(可选)
  289. callback: 可选的回调函数,参数为 (task_id, result, error)
  290. **kwargs: 其他请求参数
  291. Returns:
  292. 任务完成后的结果
  293. Raises:
  294. APIError: 如果请求失败
  295. """
  296. # 创建任务(图片生成是同步的,所以直接返回结果)
  297. result = await self.create_image_task(
  298. prompt=prompt,
  299. size=size,
  300. reference_image=reference_image,
  301. **kwargs
  302. )
  303. # 生成一个简单的任务ID
  304. import time
  305. task_id = f"img_{int(time.time() * 1000)}"
  306. logger.info(f"图片生成任务完成,任务ID: {task_id}")
  307. # 调用回调函数(如果提供)
  308. if callback:
  309. if asyncio.iscoroutinefunction(callback):
  310. await callback(task_id, result, None)
  311. else:
  312. callback(task_id, result, None)
  313. return result
  314. def get_image_url(self, response: Dict[str, Any]) -> Optional[str]:
  315. """
  316. 从响应中提取图片URL
  317. Args:
  318. response: API响应数据(从create_image_task或create_and_wait返回)
  319. Returns:
  320. 图片URL,如果不存在则返回None
  321. """
  322. try:
  323. if "data" in response and isinstance(response["data"], list):
  324. if len(response["data"]) > 0:
  325. image_data = response["data"][0]
  326. if isinstance(image_data, dict):
  327. # 根据response_format返回相应字段
  328. if self.response_format == "url":
  329. return image_data.get("url")
  330. elif self.response_format == "b64_json":
  331. return image_data.get("b64_json")
  332. return None
  333. except (KeyError, TypeError, IndexError) as e:
  334. logger.warning(f"提取图片URL失败: {e}")
  335. return None
  336. def get_image_urls(self, response: Dict[str, Any]) -> List[str]:
  337. """
  338. 从响应中提取所有图片URL
  339. Args:
  340. response: API响应数据
  341. Returns:
  342. 图片URL列表
  343. """
  344. urls = []
  345. try:
  346. if "data" in response and isinstance(response["data"], list):
  347. for image_data in response["data"]:
  348. if isinstance(image_data, dict):
  349. if self.response_format == "url":
  350. url = image_data.get("url")
  351. elif self.response_format == "b64_json":
  352. url = image_data.get("b64_json")
  353. else:
  354. url = image_data.get("url") or image_data.get("b64_json")
  355. if url:
  356. urls.append(url)
  357. return urls
  358. except (KeyError, TypeError, IndexError) as e:
  359. logger.warning(f"提取图片URL列表失败: {e}")
  360. return []
  361. def get_task_status(self, result: Dict[str, Any]) -> Optional[str]:
  362. """
  363. 从任务结果中提取任务状态
  364. Args:
  365. result: 任务结果(从create_image_task或create_and_wait返回)
  366. Returns:
  367. 任务状态字符串,如果不存在则返回None
  368. """
  369. try:
  370. # 图片生成API通常是同步的,如果返回了数据,则认为成功
  371. if result.get("data"):
  372. return "succeeded"
  373. return None
  374. except (KeyError, TypeError, AttributeError):
  375. return None
  376. # 保持向后兼容:generate_image作为便捷方法
  377. async def generate_image(
  378. self,
  379. prompt: str,
  380. size: str = "1440x2560",
  381. reference_image: Optional[List[str]] = None,
  382. **kwargs
  383. ) -> Dict[str, Any]:
  384. """
  385. 异步生成图片(便捷方法,等同于create_and_wait)
  386. Args:
  387. prompt: 图片生成提示词(必填)
  388. size: 图片尺寸,格式为"宽x高"(默认"1440x2560")
  389. reference_image: 参考图片列表(可选)
  390. **kwargs: 其他请求参数
  391. Returns:
  392. API响应数据,包含生成的图片信息
  393. Raises:
  394. APIError: 如果请求失败
  395. ValueError: 如果参数无效
  396. """
  397. return await self.create_and_wait(
  398. prompt=prompt,
  399. size=size,
  400. reference_image=reference_image,
  401. **kwargs
  402. )
  403. async def main():
  404. # 示例用法
  405. async with AsyncArkImageClient() as client:
  406. # 方式1:创建任务并等待完成(推荐)
  407. try:
  408. result = await client.create_and_wait(
  409. prompt="图1中的女生穿着图2中的衣服在街道上散步,目视前方,手牵着一只小狗",
  410. reference_image=["./data/image/face.jpg", "./data/image/cloth.jpg"],
  411. size="1440x2560"
  412. )
  413. image_url = client.get_image_url(result)
  414. print(f"图片生成成功,URL: {image_url}")
  415. except Exception as e:
  416. print(f"图片生成失败: {e}")
  417. # 方式2:使用便捷方法generate_image
  418. # response = await client.generate_image(
  419. # prompt="一个美丽的风景",
  420. # size="1440x2560"
  421. # )
  422. # image_url = client.get_image_url(response)
  423. # print(f"图片URL: {image_url}")
  424. # 方式3:异步创建任务(后台处理)
  425. # task_id = await client.create_image_task_async(
  426. # prompt="一个美丽的风景",
  427. # callback=handle_image_result,
  428. # output_path="./output/image.jpg"
  429. # )
  430. # print(f"任务已提交,task_id: {task_id}")
  431. if __name__ == "__main__":
  432. asyncio.run(main())