| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- import base64
- import io
- import json
- import os
- from typing import List, Optional
- from ..logger_setup import logger
- from ..upload_tos import upload_image
- import requests
- from PIL import Image
- DEFAULT_API_BASE_URL = "https://api.openaius.com"  # fallback base URL returned by _get_base_url() when no env override is set
- # GOOGLE_API_KEY / GEMINI_API_KEY: environment variables read by _get_api_key() for the API key
class GeminiAPIError(RuntimeError):
    """Error raised by this module when a Gemini API call cannot be completed.

    Used for a missing API key, transport/HTTP failures, and responses that
    contain no usable image.
    """
def _plugin_dir() -> str:
    """Return the directory component of this module's ``__file__`` path."""
    directory, _filename = os.path.split(__file__)
    return directory
def _get_api_key() -> str:
    """Return the Gemini API key taken from the environment.

    ``GOOGLE_API_KEY`` is consulted first, then ``GEMINI_API_KEY``. Values are
    stripped of surrounding whitespace and a blank value is treated as unset,
    so an accidentally empty variable falls through to the next source instead
    of being sent as a key.

    Returns:
        The first non-blank key found.

    Raises:
        GeminiAPIError: If neither variable holds a non-blank value.
    """
    for var_name in ("GOOGLE_API_KEY", "GEMINI_API_KEY"):
        key = (os.environ.get(var_name) or "").strip()
        if key:
            return key
    # The message lists only the sources that are actually checked.
    raise GeminiAPIError(
        "Google Gemini API key not found. Provide it via environment variable "
        "'GOOGLE_API_KEY' (or 'GEMINI_API_KEY'), or pass 'api_key' explicitly."
    )
def _get_base_url() -> str:
    """Resolve the API base URL.

    Precedence: ``GOOGLE_API_BASE_URL``, then ``GEMINI_API_BASE_URL``, then
    the module default ``DEFAULT_API_BASE_URL``. Empty values count as unset.
    """
    return (
        os.environ.get("GOOGLE_API_BASE_URL")
        or os.environ.get("GEMINI_API_BASE_URL")
        or DEFAULT_API_BASE_URL
    )
def _build_full_endpoint(model: str) -> str:
    """Return the ``generateContent`` URL for ``model`` under the configured base URL."""
    base = _get_base_url()
    # Drop trailing slashes so the join below never produces '//' before the path.
    base = base.rstrip("/")
    return f"{base}/v1beta/models/{model}:generateContent"
def _apply_auth(headers: dict, params: dict, api_key: str) -> None:
    """Attach the API key to an outgoing request by mutating the given dicts.

    Resolution order:
      1. If both ``GEMINI_AUTH_HEADER_NAME`` and ``GEMINI_AUTH_HEADER_VALUE``
         are set, the key is sent as a header. The value is a ``str.format``
         template that may reference ``{api_key}`` (e.g. ``Bearer {api_key}``).
      2. Otherwise the key is sent as a query parameter whose name comes from
         ``GEMINI_QUERY_PARAM_NAME`` and defaults to ``key``.

    Args:
        headers: Request headers; updated in place when header auth is used.
        params: Query parameters; updated in place when query auth is used.
        api_key: The secret to attach.
    """
    header_name = os.environ.get("GEMINI_AUTH_HEADER_NAME")
    header_template = os.environ.get("GEMINI_AUTH_HEADER_VALUE")
    if header_name and header_template:
        headers[header_name] = header_template.format(api_key=api_key)
        return

    query_param = os.environ.get("GEMINI_QUERY_PARAM_NAME") or "key"
    params[query_param] = api_key
def _encode_image_to_base64(img: Image.Image) -> dict:
    """Serialize ``img`` into the inline-image dict placed in request parts.

    The image is turned into bytes by the project helper ``pil_to_png_bytes``
    and then base64-encoded.

    Returns:
        A dict with ``mime_type`` fixed to ``"image/png"`` and ``data`` holding
        the base64 text.
    """
    from utils.image_io import pil_to_png_bytes

    encoded = base64.b64encode(pil_to_png_bytes(img))
    return dict(mime_type="image/png", data=encoded.decode("utf-8"))
def _build_payload(prompt: str, images: List[Image.Image], seed: Optional[int] = None) -> dict:
    """Assemble the JSON body for a ``generateContent`` request.

    Args:
        prompt: Text instruction; always the first part of the user turn.
        images: Images appended after the prompt as ``inline_data`` parts.
        seed: Optional seed. It is coerced with ``int()``; a value that cannot
            be coerced is silently ignored (best effort).

    Returns:
        A dict with ``contents``, ``generationConfig`` and
        ``responseModalities`` keys.
    """
    parts: List[dict] = [{"text": prompt}]
    parts.extend({"inline_data": _encode_image_to_base64(img)} for img in images)

    generation_config = {"response_mime_type": "image/png"}
    if seed is not None:
        try:
            seed_value = int(seed)
        except Exception:
            seed_value = None
        if seed_value is not None:
            # Sent under both commonly seen key names for better proxy compatibility.
            generation_config["seed"] = seed_value
            generation_config["random_seed"] = seed_value

    user_turn = {"role": "user", "parts": parts}
    return {
        "contents": [user_turn],
        "generationConfig": generation_config,
        # Request both IMAGE and TEXT modalities; the image is the expected primary output.
        # NOTE(review): Google's REST reference nests responseModalities inside
        # generationConfig; the top-level placement presumably suits the proxy - confirm.
        "responseModalities": ["IMAGE", "TEXT"],
    }
def _find_inline_image_b64(parts) -> Optional[str]:
    """Return the base64 payload of the first image part in ``parts``, or None."""
    for part in parts or []:
        if not isinstance(part, dict):
            continue
        # Both snake_case and camelCase spellings are accepted.
        inline = part.get("inline_data") or part.get("inlineData")
        if not isinstance(inline, dict):
            continue
        mime = inline.get("mime_type") or inline.get("mimeType") or ""
        data_b64 = inline.get("data")
        if data_b64 and isinstance(mime, str) and mime.startswith("image/"):
            return data_b64
    return None


def _extract_image_bytes_from_response(resp_json: dict) -> bytes:
    """Pull the first inline image out of a ``generateContent`` response.

    Candidates are searched in order (``candidates[*].content.parts[*]``),
    then, as a fallback, top-level ``contents[*].parts[*]``. Entries that are
    not dicts are skipped rather than crashing the scan.

    Args:
        resp_json: Decoded JSON body of the response.

    Returns:
        The decoded image bytes.

    Raises:
        GeminiAPIError: If the request was blocked, no candidates were
            returned, no image part exists, or the image data is not valid
            base64.
    """
    candidates = resp_json.get("candidates") or []
    if not candidates:
        # Without candidates, a promptFeedback block reason is the best diagnostic.
        feedback = resp_json.get("promptFeedback") or {}
        block_reason = feedback.get("blockReason") or feedback.get("block_reason")
        if block_reason:
            raise GeminiAPIError(f"Gemini blocked the request: {block_reason}")
        raise GeminiAPIError("Gemini API returned no candidates.")

    part_lists = []
    for cand in candidates:
        content = cand.get("content") if isinstance(cand, dict) else None
        if isinstance(content, dict):
            part_lists.append(content.get("parts"))
    # Fallback: parts placed directly under top-level "contents".
    for content in resp_json.get("contents") or []:
        if isinstance(content, dict):
            part_lists.append(content.get("parts"))

    for parts in part_lists:
        data_b64 = _find_inline_image_b64(parts)
        if data_b64 is None:
            continue
        try:
            return base64.b64decode(data_b64)
        except (ValueError, TypeError) as ex:
            # binascii.Error is a ValueError subclass; surface it as the module's error type.
            raise GeminiAPIError(f"Gemini returned undecodable image data: {ex}") from ex

    first = candidates[0]
    finish_reason = first.get("finishReason") if isinstance(first, dict) else None
    raise GeminiAPIError(
        f"No image found in response. finishReason={finish_reason}, raw={json.dumps(resp_json)[:800]}"
    )
def call_gemini_generate_image(
    prompt: str,
    images: List[Image.Image],
    model: str = "gemini-2.5-flash-image",
    api_key: Optional[str] = None,
    timeout: float = 60.0,
    seed: Optional[int] = None,
) -> str:
    """Generate an image with Gemini from a prompt plus images, then upload it.

    Args:
        prompt: Text instruction sent with the images.
        images: Input images, sent inline in the request.
        model: Model name inserted into the endpoint path.
        api_key: Key to use; when omitted it is resolved by ``_get_api_key()``.
        timeout: Timeout in seconds for the HTTP request.
        seed: Optional seed forwarded in the generation config.

    Returns:
        Whatever ``upload_image`` returns for the generated image (presumably
        its URL - confirm against ``upload_image``). Note that this is not the
        raw image bytes.

    Raises:
        GeminiAPIError: On a missing key, network failure, non-200 status,
            unparsable response, or a response without a usable image.
    """
    key = api_key or _get_api_key()
    url = _build_full_endpoint(model)
    # Auth is attached separately via headers/params below, not via this string.
    logger.debug(f"Gemini endpoint: {url}")
    headers = {
        "Content-Type": "application/json",
    }
    params = {}
    _apply_auth(headers, params, key)
    payload = _build_payload(prompt, images, seed=seed)

    try:
        resp = requests.post(url, headers=headers, params=params, json=payload, timeout=timeout)
    except requests.RequestException as ex:
        raise GeminiAPIError(f"Network error calling Gemini API: {ex}") from ex

    if resp.status_code != 200:
        # Common format: {"error": {"code": ..., "message": ..., "status": ...}},
        # but the body may also be non-JSON, a list, or carry a plain-string error.
        try:
            err_json = resp.json()
        except ValueError:
            err_json = None
        message = None
        if isinstance(err_json, dict):
            err = err_json.get("error")
            if isinstance(err, dict):
                message = err.get("message")
            elif err:
                message = str(err)
        if not message and err_json:
            message = json.dumps(err_json)
        raise GeminiAPIError(f"Gemini API error {resp.status_code}: {message or resp.text}")

    try:
        resp_json = resp.json()
    except ValueError as ex:
        raise GeminiAPIError(f"Failed to parse Gemini response JSON: {ex}") from ex
    if not isinstance(resp_json, dict):
        raise GeminiAPIError(
            f"Unexpected Gemini response shape: {type(resp_json).__name__}"
        )

    img_bytes = _extract_image_bytes_from_response(resp_json)
    try:
        img = Image.open(io.BytesIO(img_bytes))
    except OSError as ex:
        # PIL's UnidentifiedImageError is an OSError subclass.
        raise GeminiAPIError(f"Gemini returned image data that could not be opened: {ex}") from ex
    return upload_image(img)
def call_gemini_generate_image_from_images(
    prompt: str,
    image_urls: List[str],
    model: str = "gemini-2.5-flash-image",
    resolution: str = "1K",
    api_key: Optional[str] = None,
    timeout: float = 60.0,
    seed: Optional[int] = None,
) -> Optional[Image.Image]:
    """Generate an image from a prompt and a list of source-image URLs.

    Every failure is logged and reported as ``None``; nothing is raised.

    Args:
        prompt: Text prompt.
        image_urls: URLs of the input images; each one is downloaded and
            converted to RGB.
        model: Model name to use.
        resolution: Resolution label ("1K" or "2K"). Accepted for interface
            compatibility; the body does not reference it.
        api_key: API key (optional); resolved via ``_get_api_key()`` if omitted.
        timeout: Timeout in seconds for the generation request (downloads use
            a fixed 30 seconds).
        seed: Random seed (optional).

    Returns:
        The generated PIL image, or ``None`` on failure.
    """
    try:
        # Download each URL into a PIL image; a single failed download aborts the call.
        source_images = []
        for image_url in image_urls:
            try:
                fetched = requests.get(image_url, timeout=30)
                fetched.raise_for_status()
                source_images.append(Image.open(io.BytesIO(fetched.content)).convert("RGB"))
            except Exception as e:
                logger.error(f"下载图片失败 {image_url}: {e}")
                return None

        if len(source_images) == 0:
            logger.error("没有成功下载任何图片")
            return None

        # Build and send the generation request.
        resolved_key = api_key or _get_api_key()
        endpoint = _build_full_endpoint(model)
        request_headers = {"Content-Type": "application/json"}
        query = {}
        _apply_auth(request_headers, query, resolved_key)
        body = _build_payload(prompt, source_images, seed=seed)

        try:
            resp = requests.post(
                endpoint,
                headers=request_headers,
                params=query,
                json=body,
                timeout=timeout,
            )
        except requests.RequestException as ex:
            logger.error(f"网络错误: {ex}")
            return None

        if resp.status_code != 200:
            # Prefer the structured error message; fall back to the raw body text.
            try:
                err_json = resp.json()
                err = err_json.get("error") or {}
                message = err.get("message") or json.dumps(err_json)
                logger.error(f"Gemini API错误 {resp.status_code}: {message}")
            except Exception:
                logger.error(f"Gemini API错误 {resp.status_code}: {resp.text}")
            return None

        try:
            resp_json = resp.json()
        except ValueError as ex:
            logger.error(f"解析响应JSON失败: {ex}")
            return None

        generated = _extract_image_bytes_from_response(resp_json)
        return Image.open(io.BytesIO(generated))

    except Exception as e:
        logger.error(f"生成图片时出错: {e}")
        return None
if __name__ == "__main__":
    # Manual smoke test. Usage: <module> IMAGE_PATH [PROMPT]
    # The API key is resolved from the environment by call_gemini_generate_image
    # (through _get_api_key); credentials must never be hard-coded in source.
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH [PROMPT]")
    source_path = sys.argv[1]
    edit_prompt = sys.argv[2] if len(sys.argv) > 2 else "把这条上衣变成红色"

    # Keep the file open only while the request needs the pixel data.
    with Image.open(source_path) as source_image:
        res = call_gemini_generate_image(edit_prompt, [source_image])

    # The call returns the upload_image() result rather than raw bytes, so it is
    # reported instead of being written to a binary file.
    print(res)
|