# main.py (13 KB)
  1. from typing import List, Optional,Any
  2. from datetime import datetime
  3. import requests
  4. import os
  5. from PIL import Image
  6. import io
  7. from fastapi import FastAPI
  8. from fastapi.middleware.cors import CORSMiddleware
  9. from pydantic import BaseModel, Field
  10. # import fal_client
  11. from scr.check import process_image_pair_with_gemini
  12. from scr.conf import size_dict,check_prompt
  13. from scr.upload_tos import process_cropped_upload
  14. from scr.gemini_client_request import call_gemini_generate_image, call_gemini_generate_image_from_images
  15. from scr.utils.image_io import pil_to_png_bytes
  16. from scr.llm import llm_request
  17. from scr.logger_setup import logger
  18. from scr.sketch import generate_sketch
  19. from dotenv import load_dotenv
  20. load_dotenv()
  21. # FAL API 配置
  22. FAL_KEY = os.getenv("FAL_KEY")
  23. if FAL_KEY:
  24. os.environ["FAL_KEY"] = FAL_KEY
  25. else:
  26. logger.warning("FAL_KEY 未在环境变量中设置,相关功能可能不可用")
  27. # LLM 配置 - 用于指令优化
  28. LLM_API_KEY = os.getenv("LLM_API_KEY")
  29. LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
  30. LLM_MODEL = os.getenv("LLM_MODEL", "qwen-vl-max-latest")
  31. if not LLM_API_KEY:
  32. logger.warning("LLM_API_KEY 未在环境变量中设置,指令优化功能不可用")
  33. app = FastAPI(title="Nano Banana Image Service", version="1.0.0")
  34. app.add_middleware(
  35. CORSMiddleware,
  36. allow_origins=["*"],
  37. allow_credentials=True,
  38. allow_methods=["*"],
  39. allow_headers=["*"],
  40. )
  41. class EditImageRequest(BaseModel):
  42. prompt: str = Field(..., description="用于编辑图像的提示词")
  43. image_urls: List[str] = Field(..., description="用于编辑的原始图像URL列表")
  44. ratio: Optional[tuple[int, int]] = Field(None, description="编辑比例(例如 '16:9',可选)")
  45. image_size: Optional[tuple[int, int]] = Field(None, description="图像大小(可选)")
  46. # class EditImageRequest(BaseModel):
  47. # prompt: str = Field(..., description="用于编辑图像的提示词")
  48. # model_image_urls: List[str] = Field(..., description="需要换装的图像URL列表")
  49. # garment_image_urls: List[str] = Field(..., description="用于编辑的原始图像URL列表")
  50. # ratio: Optional[tuple[int, int]] = Field(None, description="编辑比例(例如 '16:9',可选)")
  51. # image_size: Optional[tuple[int, int]] = Field(None, description="图像大小(可选)")
  52. class GenerateImageRequest(BaseModel):
  53. prompt: str = Field(..., description="用于文生图的提示词")
  54. class GenerateSketchRequest(BaseModel):
  55. image: str = Field(..., description="款式图片URL")
  56. prompt: Optional[str] = Field("", description="提示词(可选)")
  57. auto_mode: Optional[bool] = Field(False, description="是否自动模式(可选)")
  58. multi_garment: Optional[bool] = Field(False, description="是否多服装模式(可选)")
  59. class OptimizeInstructionRequest(BaseModel):
  60. instruction: str = Field(..., description="用户输入的原始编辑指令(中文)")
  61. # 初始化 LLM 实例(用于指令优化)
  62. llm_optimizer = None
  63. def get_llm_optimizer():
  64. """获取或创建 LLM 优化器实例"""
  65. global llm_optimizer
  66. if llm_optimizer is None:
  67. if not LLM_API_KEY:
  68. raise RuntimeError("指令优化器不可用:缺少 LLM_API_KEY 环境变量")
  69. llm_optimizer = llm_request(LLM_API_KEY, LLM_BASE_URL, LLM_MODEL)
  70. logger.info("LLM 指令优化器已初始化")
  71. return llm_optimizer
  72. # def _on_queue_update(update):
  73. # if isinstance(update, fal_client.InProgress):
  74. # logs = getattr(update, "logs", None)
  75. # if logs:
  76. # for log in logs:
  77. # msg = log.get('message')
  78. # logger.info(f"处理中: {msg}")
  79. # from typing import Any
  80. # def _extract_first_image_url(result: dict[str, Any]) -> Optional[str]:
  81. # if not isinstance(result, dict):
  82. # return None
  83. # images = result.get("images") or []
  84. # if images and isinstance(images, list):
  85. # first = images[0]
  86. # if isinstance(first, dict):
  87. # return first.get("url") or first.get("result_url")
  88. # return result.get("url") or result.get("result_url")
  89. @app.get("/health")
  90. def health():
  91. logger.info("health check")
  92. return {"status": "ok"}
  93. @app.post("/optimize_instruction")
  94. def optimize_instruction(req: OptimizeInstructionRequest):
  95. """
  96. 优化用户的图片编辑指令
  97. 将中文描述转换为 Nano Banana 可理解的英文指令,并自动添加细节保留要求
  98. """
  99. logger.info(f"/optimize_instruction 请求: instruction={req.instruction}")
  100. try:
  101. # 获取 LLM 优化器
  102. optimizer = get_llm_optimizer()
  103. # 优化指令
  104. optimized = optimizer.optimize_edit_instruction(req.instruction)
  105. logger.info(f"/optimize_instruction 优化结果: {optimized}")
  106. return {
  107. "success": True,
  108. "original_instruction": req.instruction,
  109. "optimized_instruction": optimized,
  110. }
  111. except Exception as e:
  112. logger.exception("指令优化接口异常")
  113. return {
  114. "success": False,
  115. "error": str(e),
  116. "original_instruction": req.instruction
  117. }
  118. def edit_images_api(
  119. prompt: str,
  120. image_urls: List[str],
  121. aspect_ratio: Optional[List[int]] = None,
  122. resolution: str = "1K"
  123. ) -> Optional[Image.Image]:
  124. """
  125. 使用 Gemini API 编辑图片
  126. Args:
  127. prompt: 编辑提示词
  128. image_urls: 图片URL列表
  129. aspect_ratio: 宽高比(可选,目前未使用,保留用于兼容)
  130. resolution: 分辨率 ("1K" 或 "2K"),目前未使用,保留用于兼容
  131. Returns:
  132. PIL Image对象,失败返回None
  133. """
  134. try:
  135. logger.info(f"调用 edit_images_api: prompt={prompt[:100]}..., image_count={len(image_urls)}")
  136. # 使用 Gemini API 生成/编辑图片
  137. pil_image = call_gemini_generate_image_from_images(
  138. prompt=prompt,
  139. image_urls=image_urls,
  140. model="gemini-2.5-flash-image",
  141. resolution=resolution
  142. )
  143. if pil_image is None:
  144. logger.error("Gemini API 返回 None")
  145. return None
  146. # 如果指定了宽高比,可以在这里调整图片尺寸
  147. # 注意:Gemini API 本身不支持 aspect_ratio 参数,所以这里只是兼容性处理
  148. if aspect_ratio and len(aspect_ratio) == 2:
  149. # 可以在这里添加图片尺寸调整逻辑
  150. pass
  151. return pil_image
  152. except Exception as e:
  153. logger.error(f"edit_images_api 调用失败: {e}")
  154. return None
  155. @app.post("/edit_image")
  156. def edit_image(req: EditImageRequest):
  157. logger.info(f"/edit_image 请求: prompt={req.prompt[:200]}...,图片链接:{req.image_urls},图片尺寸:{req.image_size},图片比例:{req.ratio} image_count={len(req.image_urls)}")
  158. try:
  159. pil_image = None
  160. if not req.image_urls:
  161. return {"success": False, "error": "图片链接不能为空"}
  162. if req.image_size:
  163. w, h = req.image_size
  164. ratio = (w, h)
  165. ar = f"{w}x{h}"
  166. aspect_ratio = size_dict.get(ar, [1, 1]) # 使用 get 避免 KeyError
  167. # 先上传获取URL,用于后续编辑
  168. pil_image = edit_images_api(
  169. req.prompt,
  170. req.image_urls,
  171. aspect_ratio=aspect_ratio,
  172. resolution="2K"
  173. )
  174. if pil_image is None:
  175. return {"success": False, "error": "编辑图片失败"}
  176. pil_image = pil_image.resize(req.image_size)
  177. logger.info(f"图片已resize到: {req.image_size}")
  178. elif req.ratio:
  179. pil_image = edit_images_api(
  180. req.prompt,
  181. req.image_urls,
  182. aspect_ratio=req.ratio,
  183. resolution="1K"
  184. )
  185. else:
  186. pil_image = edit_images_api(
  187. req.prompt,
  188. req.image_urls,
  189. resolution="1K"
  190. )
  191. if pil_image is None:
  192. return {"success": False, "error": "编辑图片失败"}
  193. # 上传最终图片并获取URL
  194. image_url = process_cropped_upload(pil_image)
  195. return {
  196. "success": image_url is not None,
  197. "description": "生成的图片如下",
  198. "image_url": image_url,
  199. "raw": {"url": image_url},
  200. }
  201. except Exception as e:
  202. logger.exception("编辑接口异常")
  203. return {"success": False, "error": str(e)}
  204. # @app.post("/edit_image_new")
  205. # def edit_image_new(req: EditImageRequest):
  206. # logger.info(f"/edit_image 请求: prompt={req.prompt[:200]}...,模特图片链接:{req.model_image_urls},衣服图片链接:{req.garment_image_urls},图片尺寸:{req.image_size},图片比例:{req.ratio} image_count={len(req.model_image_urls)}")
  207. # try:
  208. # if req.image_size:
  209. # w,h=req.image_size
  210. # ratio = (w,h)
  211. # ar = f"{w}x{h}"
  212. # aspect_ratio = size_dict[ar]
  213. # elif req.ratio:
  214. # aspect_ratio = req.ratio
  215. # else:
  216. # aspect_ratio = [1,1]
  217. # pil_image=None
  218. # pil_image_list=[]
  219. # if not req.model_image_urls:
  220. # return {"success": False, "error": "model_image_urls图片链接不能为空"}
  221. # if not req.garment_image_urls:
  222. # return {"success": False, "error": "garment_image_urls图片链接不能为空"}
  223. # for image_url in req.model_image_urls:
  224. # for garment_image_url in req.garment_image_urls:
  225. # pil_image=edit_images_api(req.prompt,[image_url,garment_image_url],aspect_ratio=aspect_ratio,resolution="2K")
  226. # if pil_image is None:
  227. # continue
  228. # if req.image_size:
  229. # pil_image=pil_image.resize(req.image_size)
  230. # temp_image_url=process_cropped_upload(pil_image)
  231. # if temp_image_url is None:
  232. # continue
  233. # logger.info(f"图片已上传: {temp_image_url}")
  234. # pil_image_list.append(temp_image_url)
  235. # if len(pil_image_list) == 0:
  236. # return {"success": False, "error": "编辑图片失败"}
  237. # return {
  238. # "success": len(pil_image_list) > 0,
  239. # "description": "生成的图片如下",
  240. # "image_url": pil_image_list,
  241. # "raw": {"url": pil_image_list},
  242. # }
  243. # except Exception as e:
  244. # logger.exception("编辑接口异常")
  245. # return {"success": False, "error": str(e)}
  246. @app.post("/text_to_image")
  247. def text_to_image(req: GenerateImageRequest):
  248. logger.info(f"/text_to_image 请求: prompt={req.prompt[:200]}...")
  249. try:
  250. pil_image=call_gemini_generate_image(req.prompt)
  251. if pil_image is None:
  252. return {"success": False, "error": "生成图片失败"}
  253. # 上传图片并获取URL
  254. image_url = process_cropped_upload(pil_image)
  255. result={"url":image_url}
  256. logger.info(f"/text_to_image 最终结果: {result}")
  257. return {
  258. "success": True,
  259. "image_url": image_url,
  260. "raw": result,
  261. }
  262. except Exception as e:
  263. logger.exception("文生图接口异常")
  264. return {"success": False, "error": str(e)}
  265. @app.post("/generate_sketch")
  266. def generate_sketch_api(req: GenerateSketchRequest):
  267. """
  268. 生成线稿图接口
  269. 使用scr.sketch模块生成线稿图,支持自动质量检查
  270. """
  271. logger.info(f"/generate_sketch 线稿图请求: image={req.image[:100] if req.image else 'None'}..., prompt={req.prompt[:100] if req.prompt else 'None'}")
  272. try:
  273. # 检查是否提供了款式图片
  274. if not req.image:
  275. return {"success": False, "error": "请上传款式图片"}
  276. # 调用线稿图生成模块
  277. sketch_url = generate_sketch(
  278. image_url=req.image,
  279. prompt=req.prompt if req.prompt else None,
  280. max_retries=2,
  281. auto_check=True
  282. )
  283. if sketch_url is None:
  284. return {"success": False, "error": "生成线稿图失败"}
  285. logger.info(f"/generate_sketch 最终结果: {sketch_url}")
  286. return {
  287. "success": True,
  288. "description": "线稿图生成成功",
  289. "image_url": sketch_url,
  290. "raw": {"url": sketch_url},
  291. }
  292. except Exception as e:
  293. logger.exception("线稿生成接口异常")
  294. return {"success": False, "error": str(e)}
  295. if __name__ == "__main__":
  296. import uvicorn
  297. uvicorn.run(app, host="0.0.0.0", port=9002)