gemini_client_request.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. import base64
  2. import io
  3. import json
  4. import os
  5. from typing import List, Optional
  6. from ..logger_setup import logger
  7. from ..upload_tos import upload_image
  8. import requests
  9. from PIL import Image
# Endpoint root used when neither GOOGLE_API_BASE_URL nor GEMINI_API_BASE_URL
# is set in the environment (see _get_base_url).
DEFAULT_API_BASE_URL = "https://api.openaius.com"
# The API key is not stored in this module: it is read from the GOOGLE_API_KEY
# (or GEMINI_API_KEY) environment variable unless a caller passes one explicitly.
  12. class GeminiAPIError(RuntimeError):
  13. pass
  14. def _plugin_dir() -> str:
  15. return os.path.dirname(__file__)
  16. def _get_api_key() -> str:
  17. # 1) Environment variables override
  18. key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
  19. if key:
  20. return key
  21. raise GeminiAPIError(
  22. "Google Gemini API key not found. Provide it via environment variable 'GOOGLE_API_KEY' (or 'GEMINI_API_KEY'),"
  23. " config file 'gemini_config.json' (field 'api_key'), or a 'gemini_api_key.txt' file in the plugin directory."
  24. )
  25. def _get_base_url() -> str:
  26. # 2) Base URL via env
  27. base = os.environ.get("GOOGLE_API_BASE_URL") or os.environ.get("GEMINI_API_BASE_URL")
  28. if base:
  29. return base
  30. return DEFAULT_API_BASE_URL
  31. def _build_full_endpoint(model: str) -> str:
  32. base = _get_base_url().rstrip("/")
  33. return f"{base}/v1beta/models/{model}:generateContent"
  34. def _apply_auth(headers: dict, params: dict, api_key: str) -> None:
  35. """Apply authentication to headers or query according to env/config.
  36. Priority:
  37. - Env vars (header)
  38. - Config (auth_header_name/auth_header_value_template)
  39. - Query param (env GEMINI_QUERY_PARAM_NAME) or config (query_param_name)
  40. - Default query param 'key'
  41. """
  42. cfg = {}
  43. # Extra headers from config
  44. extra_headers = cfg.get("extra_headers") or {}
  45. if isinstance(extra_headers, dict):
  46. headers.update({str(k): str(v) for k, v in extra_headers.items()})
  47. # Auth via header (env)
  48. env_hdr_name = os.environ.get("GEMINI_AUTH_HEADER_NAME")
  49. env_hdr_value = os.environ.get("GEMINI_AUTH_HEADER_VALUE")
  50. if env_hdr_name and env_hdr_value:
  51. headers[str(env_hdr_name)] = env_hdr_value.format(api_key=api_key)
  52. return
  53. # Auth via header (config)
  54. hdr_name = cfg.get("auth_header_name")
  55. hdr_value_tmpl = cfg.get("auth_header_value_template")
  56. if hdr_name and hdr_value_tmpl:
  57. headers[str(hdr_name)] = str(hdr_value_tmpl).format(api_key=api_key)
  58. return
  59. # Auth via query param (env or config)
  60. query_param = os.environ.get("GEMINI_QUERY_PARAM_NAME") or cfg.get("query_param_name") or "key"
  61. params[str(query_param)] = api_key
  62. def _encode_image_to_base64(img: Image.Image) -> dict:
  63. from utils.image_io import pil_to_png_bytes
  64. png_bytes = pil_to_png_bytes(img)
  65. return {
  66. "mime_type": "image/png",
  67. "data": base64.b64encode(png_bytes).decode("utf-8"),
  68. }
  69. def _build_payload(prompt: str, images: List[Image.Image], seed: Optional[int] = None) -> dict:
  70. parts: List[dict] = [{"text": prompt}]
  71. for img in images:
  72. parts.append({"inline_data": _encode_image_to_base64(img)})
  73. generation_config = {
  74. "response_mime_type": "image/png",
  75. }
  76. if seed is not None:
  77. # Include commonly seen keys for better proxy compatibility
  78. try:
  79. generation_config["seed"] = int(seed)
  80. generation_config["random_seed"] = int(seed)
  81. except Exception:
  82. pass
  83. payload = {
  84. "contents": [
  85. {
  86. "role": "user",
  87. "parts": parts,
  88. }
  89. ],
  90. # Match requirement: request both IMAGE and TEXT modalities, but we expect image primary
  91. "generationConfig": generation_config,
  92. "responseModalities": ["IMAGE", "TEXT"],
  93. }
  94. return payload
  95. def _extract_image_bytes_from_response(resp_json: dict) -> bytes:
  96. # Typical success path: candidates[0].content.parts[*].inline_data.data (base64)
  97. candidates = resp_json.get("candidates") or []
  98. if not candidates:
  99. # Check for promptFeedback block
  100. feedback = resp_json.get("promptFeedback") or {}
  101. block_reason = feedback.get("blockReason") or feedback.get("block_reason")
  102. if block_reason:
  103. raise GeminiAPIError(f"Gemini blocked the request: {block_reason}")
  104. raise GeminiAPIError("Gemini API returned no candidates.")
  105. # inspect first candidate with inline image data
  106. for cand in candidates:
  107. content = cand.get("content") or {}
  108. parts = content.get("parts") or []
  109. for part in parts:
  110. inline = part.get("inline_data") or part.get("inlineData")
  111. if inline and isinstance(inline, dict):
  112. mime = inline.get("mime_type") or inline.get("mimeType") or ""
  113. data_b64 = inline.get("data")
  114. if data_b64 and mime.startswith("image/"):
  115. return base64.b64decode(data_b64)
  116. # Fallback: check ground-level parts
  117. contents = resp_json.get("contents") or []
  118. for content in contents:
  119. for part in content.get("parts", []):
  120. inline = part.get("inline_data") or part.get("inlineData")
  121. if inline and isinstance(inline, dict):
  122. mime = inline.get("mime_type") or inline.get("mimeType") or ""
  123. data_b64 = inline.get("data")
  124. if data_b64 and mime.startswith("image/"):
  125. return base64.b64decode(data_b64)
  126. # If we reach here, extract helpful diagnostics
  127. finish_reason = candidates[0].get("finishReason") if candidates else None
  128. raise GeminiAPIError(
  129. f"No image found in response. finishReason={finish_reason}, raw={json.dumps(resp_json)[:800]}"
  130. )
  131. def call_gemini_generate_image(
  132. prompt: str,
  133. images: List[Image.Image],
  134. model: str = "gemini-2.5-flash-image",
  135. api_key: Optional[str] = None,
  136. timeout: float = 60.0,
  137. seed: Optional[int] = None,
  138. ) -> bytes:
  139. """Call Gemini API to generate an image response from prompt + images.
  140. Returns PNG bytes on success, raises GeminiAPIError on failure.
  141. """
  142. key = api_key or _get_api_key()
  143. url = _build_full_endpoint(model)
  144. print(url)
  145. headers = {
  146. "Content-Type": "application/json",
  147. }
  148. params = {}
  149. _apply_auth(headers, params, key)
  150. payload = _build_payload(prompt, images, seed=seed)
  151. try:
  152. resp = requests.post(url, headers=headers, params=params, json=payload, timeout=timeout)
  153. except requests.RequestException as ex:
  154. raise GeminiAPIError(f"Network error calling Gemini API: {ex}")
  155. if resp.status_code != 200:
  156. # Try to parse error payload
  157. try:
  158. err_json = resp.json()
  159. except Exception:
  160. err_json = None
  161. if err_json:
  162. # Common format: {"error": {"code":..., "message":..., "status":...}}
  163. err = err_json.get("error") or {}
  164. message = err.get("message") or json.dumps(err_json)
  165. raise GeminiAPIError(f"Gemini API error {resp.status_code}: {message}")
  166. raise GeminiAPIError(f"Gemini API error {resp.status_code}: {resp.text}")
  167. try:
  168. resp_json = resp.json()
  169. except ValueError as ex:
  170. raise GeminiAPIError(f"Failed to parse Gemini response JSON: {ex}")
  171. img_bytes=_extract_image_bytes_from_response(resp_json)
  172. img = Image.open(io.BytesIO(img_bytes))
  173. img_url=upload_image(img)
  174. return img_url
  175. def call_gemini_generate_image_from_images(
  176. prompt: str,
  177. image_urls: List[str],
  178. model: str = "gemini-2.5-flash-image",
  179. resolution: str = "1K",
  180. api_key: Optional[str] = None,
  181. timeout: float = 60.0,
  182. seed: Optional[int] = None,
  183. ) -> Optional[Image.Image]:
  184. """
  185. 从图片URL列表生成图片
  186. Args:
  187. prompt: 提示词
  188. image_urls: 图片URL列表
  189. model: 使用的模型
  190. resolution: 分辨率 ("1K" 或 "2K")
  191. api_key: API密钥(可选)
  192. timeout: 超时时间
  193. seed: 随机种子(可选)
  194. Returns:
  195. PIL Image对象,失败返回None
  196. """
  197. try:
  198. # 下载图片URL并转换为PIL Image
  199. images = []
  200. for image_url in image_urls:
  201. try:
  202. response = requests.get(image_url, timeout=30)
  203. response.raise_for_status()
  204. img = Image.open(io.BytesIO(response.content)).convert("RGB")
  205. images.append(img)
  206. except Exception as e:
  207. logger.error(f"下载图片失败 {image_url}: {e}")
  208. return None
  209. if not images:
  210. logger.error("没有成功下载任何图片")
  211. return None
  212. # 调用生成函数
  213. key = api_key or _get_api_key()
  214. url = _build_full_endpoint(model)
  215. headers = {
  216. "Content-Type": "application/json",
  217. }
  218. params = {}
  219. _apply_auth(headers, params, key)
  220. payload = _build_payload(prompt, images, seed=seed)
  221. try:
  222. resp = requests.post(url, headers=headers, params=params, json=payload, timeout=timeout)
  223. except requests.RequestException as ex:
  224. logger.error(f"网络错误: {ex}")
  225. return None
  226. if resp.status_code != 200:
  227. try:
  228. err_json = resp.json()
  229. err = err_json.get("error") or {}
  230. message = err.get("message") or json.dumps(err_json)
  231. logger.error(f"Gemini API错误 {resp.status_code}: {message}")
  232. except Exception:
  233. logger.error(f"Gemini API错误 {resp.status_code}: {resp.text}")
  234. return None
  235. try:
  236. resp_json = resp.json()
  237. except ValueError as ex:
  238. logger.error(f"解析响应JSON失败: {ex}")
  239. return None
  240. img_bytes = _extract_image_bytes_from_response(resp_json)
  241. img = Image.open(io.BytesIO(img_bytes))
  242. return img
  243. except Exception as e:
  244. logger.error(f"生成图片时出错: {e}")
  245. return None
  246. if __name__=="__main__":
  247. image=r"D:\ppt\企业微信截图_1760499008629.png"
  248. image=[Image.open(image)]
  249. res=call_gemini_generate_image("把这条上衣变成红色",image,api_key="sk-qnsfJw0vsAitlnrXcOeBYrLbTv9LXfsN1m3jIUfMJagan5IR")
  250. with open("output.png", "wb") as f:
  251. f.write(res)