| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374 |
- from typing import List, Optional,Any
- from datetime import datetime
- import requests
- import os
- from PIL import Image
- import io
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel, Field
- # import fal_client
- from scr.check import process_image_pair_with_gemini
- from scr.conf import size_dict,check_prompt
- from scr.upload_tos import process_cropped_upload
- from scr.gemini_client_request import call_gemini_generate_image, call_gemini_generate_image_from_images
- from scr.utils.image_io import pil_to_png_bytes
- from scr.llm import llm_request
- from scr.logger_setup import logger
- from scr.sketch import generate_sketch
- from dotenv import load_dotenv
# Load environment variables via python-dotenv before any setting is read.
load_dotenv()

# FAL API configuration.
FAL_KEY = os.environ.get("FAL_KEY")
if not FAL_KEY:
    logger.warning("FAL_KEY 未在环境变量中设置,相关功能可能不可用")
else:
    # Write the key back into the process environment under the same name.
    os.environ["FAL_KEY"] = FAL_KEY

# LLM configuration, consumed by the instruction optimizer.
LLM_API_KEY = os.environ.get("LLM_API_KEY")
LLM_BASE_URL = os.environ.get(
    "LLM_BASE_URL",
    "https://dashscope.aliyuncs.com/compatible-mode/v1",
)
LLM_MODEL = os.environ.get("LLM_MODEL", "qwen-vl-max-latest")
if not LLM_API_KEY:
    logger.warning("LLM_API_KEY 未在环境变量中设置,指令优化功能不可用")

# ASGI application object that the route decorators attach to.
app = FastAPI(title="Nano Banana Image Service", version="1.0.0")

# NOTE(review): every origin, method and header is allowed together with
# credentials. This is very permissive -- confirm it is intended outside of
# development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body for POST /edit_image.
class EditImageRequest(BaseModel):
    # Editing instruction (required).
    prompt: str = Field(default=..., description="用于编辑图像的提示词")
    # URLs of the source images to edit (required).
    image_urls: List[str] = Field(default=..., description="用于编辑的原始图像URL列表")
    # Optional aspect ratio as a pair of ints.
    # NOTE(review): the description's example '16:9' is a string, while the
    # declared type is a pair of ints -- confirm which form clients send.
    ratio: Optional[tuple[int, int]] = Field(
        default=None,
        description="编辑比例(例如 '16:9',可选)",
    )
    # Optional target size as a pair of ints; the handler unpacks it as (w, h).
    image_size: Optional[tuple[int, int]] = Field(default=None, description="图像大小(可选)")
- # class EditImageRequest(BaseModel):
- # prompt: str = Field(..., description="用于编辑图像的提示词")
- # model_image_urls: List[str] = Field(..., description="需要换装的图像URL列表")
- # garment_image_urls: List[str] = Field(..., description="用于编辑的原始图像URL列表")
- # ratio: Optional[tuple[int, int]] = Field(None, description="编辑比例(例如 '16:9',可选)")
- # image_size: Optional[tuple[int, int]] = Field(None, description="图像大小(可选)")
# Request body for POST /text_to_image.
class GenerateImageRequest(BaseModel):
    # Text-to-image prompt (required).
    prompt: str = Field(default=..., description="用于文生图的提示词")
# Request body for POST /generate_sketch.
class GenerateSketchRequest(BaseModel):
    # URL of the garment/style image (required).
    image: str = Field(default=..., description="款式图片URL")
    # Optional extra prompt; defaults to an empty string.
    prompt: Optional[str] = Field(default="", description="提示词(可选)")
    # NOTE(review): the two flags below are accepted but not read by the
    # /generate_sketch handler in this file.
    auto_mode: Optional[bool] = Field(default=False, description="是否自动模式(可选)")
    multi_garment: Optional[bool] = Field(default=False, description="是否多服装模式(可选)")
# Request body for POST /optimize_instruction.
class OptimizeInstructionRequest(BaseModel):
    # Raw edit instruction as typed by the user (required).
    instruction: str = Field(default=..., description="用户输入的原始编辑指令(中文)")
-
# Lazily created LLM client used for instruction optimization.
llm_optimizer = None


def get_llm_optimizer():
    """Return the shared LLM optimizer, creating it on first use.

    Raises:
        RuntimeError: If no instance exists yet and LLM_API_KEY is not set.
    """
    global llm_optimizer

    # Fast path: an instance was already built by an earlier call.
    if llm_optimizer is not None:
        return llm_optimizer

    if not LLM_API_KEY:
        raise RuntimeError("指令优化器不可用:缺少 LLM_API_KEY 环境变量")

    llm_optimizer = llm_request(LLM_API_KEY, LLM_BASE_URL, LLM_MODEL)
    logger.info("LLM 指令优化器已初始化")
    return llm_optimizer
- # def _on_queue_update(update):
- # if isinstance(update, fal_client.InProgress):
- # logs = getattr(update, "logs", None)
- # if logs:
- # for log in logs:
- # msg = log.get('message')
- # logger.info(f"处理中: {msg}")
- # from typing import Any
- # def _extract_first_image_url(result: dict[str, Any]) -> Optional[str]:
- # if not isinstance(result, dict):
- # return None
- # images = result.get("images") or []
- # if images and isinstance(images, list):
- # first = images[0]
- # if isinstance(first, dict):
- # return first.get("url") or first.get("result_url")
- # return result.get("url") or result.get("result_url")
@app.get("/health")
def health():
    """Log the probe and return a static OK payload."""
    logger.info("health check")
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/optimize_instruction")
def optimize_instruction(req: OptimizeInstructionRequest):
    """Optimize a user's image-editing instruction.

    Delegates to the LLM optimizer's ``optimize_edit_instruction``. Per the
    original author's notes, this converts a Chinese description into an
    English instruction for Nano Banana and adds detail-preservation
    requirements; that behaviour lives in the optimizer, not in this handler.

    Returns:
        A dict with ``success`` and ``original_instruction``, plus either
        ``optimized_instruction`` on success or ``error`` on failure.
    """
    logger.info(f"/optimize_instruction 请求: instruction={req.instruction}")
    try:
        optimized = get_llm_optimizer().optimize_edit_instruction(req.instruction)
        logger.info(f"/optimize_instruction 优化结果: {optimized}")
        response = {
            "success": True,
            "original_instruction": req.instruction,
            "optimized_instruction": optimized,
        }
    except Exception as e:
        # Any failure (missing API key, LLM error, ...) is reported in-band.
        logger.exception("指令优化接口异常")
        response = {
            "success": False,
            "error": str(e),
            "original_instruction": req.instruction,
        }
    return response
def edit_images_api(
    prompt: str,
    image_urls: List[str],
    aspect_ratio: Optional[List[int]] = None,
    resolution: str = "1K"
) -> Optional[Image.Image]:
    """Edit or compose images through the Gemini image helper.

    Args:
        prompt: Editing instruction sent to the model.
        image_urls: URLs of the input images.
        aspect_ratio: Accepted for backward compatibility only. It is not
            forwarded to the Gemini helper and has no effect on the result.
        resolution: Forwarded unchanged to
            ``call_gemini_generate_image_from_images``; callers in this file
            pass "1K" or "2K".

    Returns:
        The generated PIL image, or None when the helper returns None or
        raises. Failures are logged, never propagated.
    """
    try:
        logger.info(f"调用 edit_images_api: prompt={prompt[:100]}..., image_count={len(image_urls)}")

        pil_image = call_gemini_generate_image_from_images(
            prompt=prompt,
            image_urls=image_urls,
            model="gemini-2.5-flash-image",
            resolution=resolution,
        )
    except Exception as e:
        # logger.exception keeps the traceback, matching the other handlers
        # in this module.
        logger.exception(f"edit_images_api 调用失败: {e}")
        return None

    if pil_image is None:
        logger.error("Gemini API 返回 None")
    return pil_image
@app.post("/edit_image")
def edit_image(req: EditImageRequest):
    """Edit the supplied images and return the URL of the uploaded result.

    Behaviour by request field:
      * ``image_size`` set: generate at "2K", then resize to exactly that
        size. The aspect-ratio hint is looked up in ``size_dict`` under the
        key "<w>x<h>", falling back to [1, 1].
      * otherwise ``ratio`` set: generate at "1K" with that ratio hint.
      * neither set: generate at "1K" with no ratio hint.

    Returns:
        On failure, ``{"success": False, "error": ...}``. Otherwise a dict
        with ``success`` (False when the upload returned no URL),
        ``description``, ``image_url`` and ``raw``.
    """
    logger.info(f"/edit_image 请求: prompt={req.prompt[:200]}...,图片链接:{req.image_urls},图片尺寸:{req.image_size},图片比例:{req.ratio} image_count={len(req.image_urls)}")
    try:
        if not req.image_urls:
            return {"success": False, "error": "图片链接不能为空"}

        aspect_ratio = None
        resolution = "1K"
        if req.image_size:
            w, h = req.image_size
            # Reject unusable sizes before the remote generation call;
            # otherwise resize() would only fail after the image was made.
            if w <= 0 or h <= 0:
                return {"success": False, "error": "图片尺寸必须为正整数"}
            # .get avoids a KeyError for sizes missing from size_dict.
            aspect_ratio = size_dict.get(f"{w}x{h}", [1, 1])
            resolution = "2K"
        elif req.ratio:
            aspect_ratio = req.ratio

        pil_image = edit_images_api(
            req.prompt,
            req.image_urls,
            aspect_ratio=aspect_ratio,
            resolution=resolution,
        )
        if pil_image is None:
            return {"success": False, "error": "编辑图片失败"}

        if req.image_size:
            pil_image = pil_image.resize(req.image_size)
            logger.info(f"图片已resize到: {req.image_size}")

        # Upload the final image and report its URL.
        image_url = process_cropped_upload(pil_image)
        return {
            "success": image_url is not None,
            "description": "生成的图片如下",
            "image_url": image_url,
            "raw": {"url": image_url},
        }
    except Exception as e:
        logger.exception("编辑接口异常")
        return {"success": False, "error": str(e)}
- # @app.post("/edit_image_new")
- # def edit_image_new(req: EditImageRequest):
- # logger.info(f"/edit_image 请求: prompt={req.prompt[:200]}...,模特图片链接:{req.model_image_urls},衣服图片链接:{req.garment_image_urls},图片尺寸:{req.image_size},图片比例:{req.ratio} image_count={len(req.model_image_urls)}")
- # try:
- # if req.image_size:
- # w,h=req.image_size
- # ratio = (w,h)
- # ar = f"{w}x{h}"
- # aspect_ratio = size_dict[ar]
- # elif req.ratio:
-
- # aspect_ratio = req.ratio
- # else:
- # aspect_ratio = [1,1]
- # pil_image=None
- # pil_image_list=[]
- # if not req.model_image_urls:
- # return {"success": False, "error": "model_image_urls图片链接不能为空"}
- # if not req.garment_image_urls:
- # return {"success": False, "error": "garment_image_urls图片链接不能为空"}
- # for image_url in req.model_image_urls:
- # for garment_image_url in req.garment_image_urls:
- # pil_image=edit_images_api(req.prompt,[image_url,garment_image_url],aspect_ratio=aspect_ratio,resolution="2K")
- # if pil_image is None:
- # continue
- # if req.image_size:
- # pil_image=pil_image.resize(req.image_size)
-
- # temp_image_url=process_cropped_upload(pil_image)
- # if temp_image_url is None:
- # continue
- # logger.info(f"图片已上传: {temp_image_url}")
- # pil_image_list.append(temp_image_url)
-
- # if len(pil_image_list) == 0:
- # return {"success": False, "error": "编辑图片失败"}
-
- # return {
- # "success": len(pil_image_list) > 0,
- # "description": "生成的图片如下",
- # "image_url": pil_image_list,
- # "raw": {"url": pil_image_list},
- # }
- # except Exception as e:
- # logger.exception("编辑接口异常")
- # return {"success": False, "error": str(e)}
@app.post("/text_to_image")
def text_to_image(req: GenerateImageRequest):
    """Generate an image from a text prompt and return its uploaded URL.

    Returns:
        ``{"success": False, "error": ...}`` when generation fails or an
        exception is raised. Otherwise a dict with ``success``,
        ``image_url`` and ``raw``. ``success`` is False when the upload
        returned no URL, the same rule /edit_image applies.
    """
    logger.info(f"/text_to_image 请求: prompt={req.prompt[:200]}...")
    try:
        pil_image = call_gemini_generate_image(req.prompt)
        if pil_image is None:
            return {"success": False, "error": "生成图片失败"}

        # Upload the image and obtain its URL.
        image_url = process_cropped_upload(pil_image)
        result = {"url": image_url}
        logger.info(f"/text_to_image 最终结果: {result}")
        return {
            # A missing URL means the upload did not succeed; do not claim
            # success in that case.
            "success": image_url is not None,
            "image_url": image_url,
            "raw": result,
        }
    except Exception as e:
        logger.exception("文生图接口异常")
        return {"success": False, "error": str(e)}
@app.post("/generate_sketch")
def generate_sketch_api(req: GenerateSketchRequest):
    """Generate a line-art sketch for a garment image.

    Calls ``generate_sketch`` with ``max_retries=2`` and ``auto_check=True``.
    An empty prompt is passed on as None.

    Returns:
        ``{"success": False, "error": ...}`` when the image is missing,
        generation returns None, or an exception is raised. Otherwise a dict
        with ``success``, ``description``, ``image_url`` and ``raw``.
    """
    logger.info(f"/generate_sketch 线稿图请求: image={req.image[:100] if req.image else 'None'}..., prompt={req.prompt[:100] if req.prompt else 'None'}")

    try:
        # A garment image is mandatory.
        if not req.image:
            return {"success": False, "error": "请上传款式图片"}

        sketch_prompt = req.prompt or None
        sketch_url = generate_sketch(
            image_url=req.image,
            prompt=sketch_prompt,
            max_retries=2,
            auto_check=True,
        )
        if sketch_url is None:
            return {"success": False, "error": "生成线稿图失败"}

        logger.info(f"/generate_sketch 最终结果: {sketch_url}")
        raw_payload = {"url": sketch_url}
        return {
            "success": True,
            "description": "线稿图生成成功",
            "image_url": sketch_url,
            "raw": raw_payload,
        }
    except Exception as e:
        logger.exception("线稿生成接口异常")
        return {"success": False, "error": str(e)}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 9002. uvicorn is imported here so it is only needed when the
    # module is run directly.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=9002)
|