"""Thin client for a Gemini-compatible ``generateContent`` image endpoint.

The module builds a multimodal request (text prompt + inline PNG images),
sends it with ``requests`` and extracts the first inline image found in the
response.
"""

import base64
import io
import json
import os
import sys
from typing import List, Optional

import requests
from PIL import Image

from ..logger_setup import logger
from ..upload_tos import upload_image

DEFAULT_API_BASE_URL = "https://api.openaius.com"
# The API key is read from the GOOGLE_API_KEY / GEMINI_API_KEY environment
# variables (see _get_api_key) unless passed explicitly to the public functions.

# Timeout (seconds) for downloading each input image URL.
_DOWNLOAD_TIMEOUT_SECONDS = 30


class GeminiAPIError(RuntimeError):
    """Raised for configuration, transport and response errors of this client."""


def _plugin_dir() -> str:
    """Return the directory that contains this module."""
    return os.path.dirname(__file__)


def _get_api_key() -> str:
    """Return the API key from the environment.

    ``GOOGLE_API_KEY`` takes precedence over ``GEMINI_API_KEY``.

    Raises:
        GeminiAPIError: if neither variable is set to a non-empty value.
    """
    key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")
    if key:
        return key
    # Only environment variables are consulted here, so the message must not
    # advertise config-file lookups that do not exist.
    raise GeminiAPIError(
        "Google Gemini API key not found. Provide it via environment variable "
        "'GOOGLE_API_KEY' (or 'GEMINI_API_KEY'), or pass 'api_key' explicitly."
    )


def _get_base_url() -> str:
    """Return the API base URL.

    ``GOOGLE_API_BASE_URL`` or ``GEMINI_API_BASE_URL`` override
    ``DEFAULT_API_BASE_URL``.
    """
    base = os.environ.get("GOOGLE_API_BASE_URL") or os.environ.get("GEMINI_API_BASE_URL")
    if base:
        return base
    return DEFAULT_API_BASE_URL


def _build_full_endpoint(model: str) -> str:
    """Return the full ``generateContent`` URL for *model*."""
    base = _get_base_url().rstrip("/")
    return f"{base}/v1beta/models/{model}:generateContent"


def _apply_auth(headers: dict, params: dict, api_key: str) -> None:
    """Add authentication to *headers* or *params* in place.

    Priority:
      1. Header from env: ``GEMINI_AUTH_HEADER_NAME`` + ``GEMINI_AUTH_HEADER_VALUE``
         (the value is a ``str.format`` template that may reference ``{api_key}``).
      2. Header from config (``auth_header_name`` / ``auth_header_value_template``).
      3. Query parameter named by env ``GEMINI_QUERY_PARAM_NAME``, then config
         ``query_param_name``, then the default ``"key"``.

    The config dict is currently always empty, so the config-driven branches
    (including ``extra_headers``) are inert until a config source is wired in.
    """
    cfg: dict = {}

    # Extra headers from config
    extra_headers = cfg.get("extra_headers") or {}
    if isinstance(extra_headers, dict):
        headers.update({str(k): str(v) for k, v in extra_headers.items()})

    # Auth via header (env)
    env_hdr_name = os.environ.get("GEMINI_AUTH_HEADER_NAME")
    env_hdr_value = os.environ.get("GEMINI_AUTH_HEADER_VALUE")
    if env_hdr_name and env_hdr_value:
        headers[str(env_hdr_name)] = env_hdr_value.format(api_key=api_key)
        return

    # Auth via header (config)
    hdr_name = cfg.get("auth_header_name")
    hdr_value_tmpl = cfg.get("auth_header_value_template")
    if hdr_name and hdr_value_tmpl:
        headers[str(hdr_name)] = str(hdr_value_tmpl).format(api_key=api_key)
        return

    # Auth via query param (env or config)
    query_param = (
        os.environ.get("GEMINI_QUERY_PARAM_NAME")
        or cfg.get("query_param_name")
        or "key"
    )
    params[str(query_param)] = api_key


def _encode_image_to_base64(img: Image.Image) -> dict:
    """Return an ``inline_data`` dict holding *img* as base64-encoded PNG."""
    # Imported lazily, as in the original code; the helper lives in the project.
    from utils.image_io import pil_to_png_bytes

    png_bytes = pil_to_png_bytes(img)
    return {
        "mime_type": "image/png",
        "data": base64.b64encode(png_bytes).decode("utf-8"),
    }


def _build_payload(prompt: str, images: List[Image.Image], seed: Optional[int] = None) -> dict:
    """Build the JSON request body: one user turn with the prompt and images.

    Args:
        prompt: Text instruction, sent as the first part.
        images: Images appended as inline PNG parts, in order.
        seed: Optional seed; when convertible to ``int`` it is sent under both
            ``seed`` and ``random_seed``. An unconvertible seed is skipped with
            a warning rather than failing the request.
    """
    parts: List[dict] = [{"text": prompt}]
    parts.extend({"inline_data": _encode_image_to_base64(img)} for img in images)

    generation_config = {
        "response_mime_type": "image/png",
    }
    if seed is not None:
        try:
            seed_value = int(seed)
        except (TypeError, ValueError, OverflowError) as ex:
            # Best effort: the seed is optional, so do not fail the request.
            logger.warning(f"Ignoring invalid seed {seed!r}: {ex}")
        else:
            # Include commonly seen keys for better proxy compatibility
            generation_config["seed"] = seed_value
            generation_config["random_seed"] = seed_value

    # NOTE(review): the official Gemini REST schema appears to place
    # responseModalities inside generationConfig and may not accept an image
    # response_mime_type; the layout is kept as-is because the default endpoint
    # is a third-party proxy. Confirm against the proxy before changing.
    return {
        "contents": [
            {
                "role": "user",
                "parts": parts,
            }
        ],
        # Request both IMAGE and TEXT modalities; the image is the expected result.
        "generationConfig": generation_config,
        "responseModalities": ["IMAGE", "TEXT"],
    }


def _first_inline_image(parts) -> Optional[bytes]:
    """Return the decoded bytes of the first inline image in *parts*, or None.

    Both ``inline_data``/``mime_type`` and ``inlineData``/``mimeType`` spellings
    are accepted. Entries that are not dicts are skipped.

    Raises:
        GeminiAPIError: if an image part carries data that is not valid base64.
    """
    for part in parts or []:
        if not isinstance(part, dict):
            continue
        inline = part.get("inline_data") or part.get("inlineData")
        if not isinstance(inline, dict):
            continue
        mime = inline.get("mime_type") or inline.get("mimeType") or ""
        data_b64 = inline.get("data")
        if data_b64 and isinstance(mime, str) and mime.startswith("image/"):
            try:
                return base64.b64decode(data_b64)
            except (ValueError, TypeError) as ex:
                raise GeminiAPIError(
                    f"Invalid base64 image data in Gemini response: {ex}"
                ) from ex
    return None


def _extract_image_bytes_from_response(resp_json: dict) -> bytes:
    """Return the bytes of the first inline image in a parsed response.

    Searches ``candidates[*].content.parts`` first, then top-level
    ``contents[*].parts`` as a fallback.

    Raises:
        GeminiAPIError: if the response is not a JSON object, has no candidates
            (including a reported block reason), or contains no image part.
    """
    if not isinstance(resp_json, dict):
        raise GeminiAPIError("Gemini API returned an unexpected response format.")

    candidates = resp_json.get("candidates") or []
    if not candidates:
        # Check for promptFeedback block
        feedback = resp_json.get("promptFeedback") or {}
        block_reason = feedback.get("blockReason") or feedback.get("block_reason")
        if block_reason:
            raise GeminiAPIError(f"Gemini blocked the request: {block_reason}")
        raise GeminiAPIError("Gemini API returned no candidates.")

    # Typical success path: any candidate whose parts carry inline image data.
    for cand in candidates:
        if not isinstance(cand, dict):
            continue
        content = cand.get("content") or {}
        if isinstance(content, dict):
            image = _first_inline_image(content.get("parts"))
            if image is not None:
                return image

    # Fallback: check top-level contents
    for content in resp_json.get("contents") or []:
        if isinstance(content, dict):
            image = _first_inline_image(content.get("parts"))
            if image is not None:
                return image

    # Nothing found: surface diagnostics to help debugging.
    first = candidates[0]
    finish_reason = first.get("finishReason") if isinstance(first, dict) else None
    raise GeminiAPIError(
        f"No image found in response. finishReason={finish_reason}, raw={json.dumps(resp_json)[:800]}"
    )


def _post_generate_content(
    prompt: str,
    images: List[Image.Image],
    model: str,
    api_key: Optional[str],
    timeout: float,
    seed: Optional[int],
) -> requests.Response:
    """Build and send the ``generateContent`` request; return the raw response.

    ``requests.RequestException`` is deliberately not handled here so each
    caller can report transport failures in its own way.

    Raises:
        GeminiAPIError: if no API key is available.
    """
    key = api_key or _get_api_key()
    url = _build_full_endpoint(model)
    # The key travels in headers/params, so logging the URL does not expose it.
    logger.debug(f"Gemini request URL: {url}")

    headers = {
        "Content-Type": "application/json",
    }
    params: dict = {}
    _apply_auth(headers, params, key)
    payload = _build_payload(prompt, images, seed=seed)
    return requests.post(url, headers=headers, params=params, json=payload, timeout=timeout)


def _error_message_from_response(resp: requests.Response) -> str:
    """Return the most specific error text available in a non-200 response.

    Prefers ``error.message`` from a JSON body of the common form
    ``{"error": {"code": ..., "message": ..., "status": ...}}``, accepts a
    plain-string ``error`` field, then falls back to the serialized JSON body
    and finally to the raw response text.
    """
    try:
        err_json = resp.json()
    except ValueError:
        return resp.text
    if not isinstance(err_json, dict) or not err_json:
        return resp.text
    err = err_json.get("error")
    if isinstance(err, dict) and err.get("message"):
        return str(err["message"])
    if isinstance(err, str) and err:
        return err
    return json.dumps(err_json)


def call_gemini_generate_image(
    prompt: str,
    images: List[Image.Image],
    model: str = "gemini-2.5-flash-image",
    api_key: Optional[str] = None,
    timeout: float = 60.0,
    seed: Optional[int] = None,
) -> str:
    """Generate an image from *prompt* + *images* and upload the result.

    Args:
        prompt: Text instruction for the model.
        images: Input images sent inline with the prompt.
        model: Model name used in the endpoint path.
        api_key: API key; falls back to the environment when omitted.
        timeout: Timeout in seconds for the API request.
        seed: Optional generation seed.

    Returns:
        Whatever ``upload_image`` returns for the generated image, presumably
        its URL (verify against ``upload_image``). Note this is not raw PNG bytes.

    Raises:
        GeminiAPIError: on a missing key, network failure, non-200 status,
            unparseable response, or a response without an image.
    """
    try:
        resp = _post_generate_content(prompt, images, model, api_key, timeout, seed)
    except requests.RequestException as ex:
        raise GeminiAPIError(f"Network error calling Gemini API: {ex}") from ex

    if resp.status_code != 200:
        message = _error_message_from_response(resp)
        raise GeminiAPIError(f"Gemini API error {resp.status_code}: {message}")

    try:
        resp_json = resp.json()
    except ValueError as ex:
        raise GeminiAPIError(f"Failed to parse Gemini response JSON: {ex}") from ex

    img_bytes = _extract_image_bytes_from_response(resp_json)
    img = Image.open(io.BytesIO(img_bytes))
    return upload_image(img)


def call_gemini_generate_image_from_images(
    prompt: str,
    image_urls: List[str],
    model: str = "gemini-2.5-flash-image",
    resolution: str = "1K",
    api_key: Optional[str] = None,
    timeout: float = 60.0,
    seed: Optional[int] = None,
) -> Optional[Image.Image]:
    """Generate an image from a list of image URLs.

    Every URL is downloaded and converted to RGB before the request is sent.
    All failures are logged and reported as ``None``; this function does not
    raise.

    Args:
        prompt: Text instruction for the model.
        image_urls: URLs of the input images.
        model: Model name used in the endpoint path.
        resolution: Resolution label ("1K" or "2K"). Currently unused by the
            request; accepted for interface compatibility.
        api_key: API key (optional); falls back to the environment.
        timeout: Timeout in seconds for the API request.
        seed: Random seed (optional).

    Returns:
        The generated PIL image, or ``None`` on any failure (including an
        empty *image_urls* or a single failed download).
    """
    try:
        # Download each URL and convert it to a PIL image.
        images = []
        for image_url in image_urls:
            try:
                response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
                response.raise_for_status()
                img = Image.open(io.BytesIO(response.content)).convert("RGB")
                images.append(img)
            except Exception as e:
                logger.error(f"下载图片失败 {image_url}: {e}")
                return None

        if not images:
            logger.error("没有成功下载任何图片")
            return None

        # Call the generation endpoint.
        try:
            resp = _post_generate_content(prompt, images, model, api_key, timeout, seed)
        except requests.RequestException as ex:
            logger.error(f"网络错误: {ex}")
            return None

        if resp.status_code != 200:
            message = _error_message_from_response(resp)
            logger.error(f"Gemini API错误 {resp.status_code}: {message}")
            return None

        try:
            resp_json = resp.json()
        except ValueError as ex:
            logger.error(f"解析响应JSON失败: {ex}")
            return None

        img_bytes = _extract_image_bytes_from_response(resp_json)
        return Image.open(io.BytesIO(img_bytes))

    except Exception as e:
        # Top-level boundary: this API reports every failure as None.
        logger.error(f"生成图片时出错: {e}")
        return None


if __name__ == "__main__":
    # Manual smoke test. The API key must come from the environment
    # (GOOGLE_API_KEY / GEMINI_API_KEY); never commit a key to source.
    if len(sys.argv) < 2:
        sys.exit("usage: <this module> INPUT_IMAGE_PATH")
    with Image.open(sys.argv[1]) as source_image:
        result = call_gemini_generate_image("把这条上衣变成红色", [source_image])
    # The function returns the upload result rather than image bytes, so it is
    # printed instead of being written to a binary file.
    print(result)