ai_swap_cloth_api.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. """
  2. AI换衣服API接口
  3. 提供RESTful API接口,支持AI换衣服、历史记录查询等功能
  4. """
  5. import os
  6. import uvicorn
  7. from pydantic import BaseModel, Field
  8. from sqlalchemy import func, and_
  9. from datetime import datetime, timedelta
  10. from typing import Optional, List, Dict, Any, Union
  11. from starlette.responses import FileResponse
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from fastapi.responses import JSONResponse
  14. from fastapi.staticfiles import StaticFiles
  15. from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Body, APIRouter
  16. from backend.services.ai_swap_cloth_service import AISwapClothService, process_swap_cloth_with_record
  17. from backend.modules.database.operations import DatabaseOperations
  18. from backend.utils.logger_config import setup_logger
  19. from backend.api.auth_api import router as auth_router
  20. from backend.api.user_material_api import router as user_material_router
  21. from backend.api.user_template_api import router as user_template_router
  22. from backend.modules.database.models import ProcessRecord
  23. from backend.api.auto_post_api import router as auto_post_router
logger = setup_logger(__name__)
# Router holding this module's endpoints; presumably included into the app elsewhere — confirm.
router = APIRouter()
# Create service instances (module-level objects reused by every request handled in this module)
ai_swap_cloth_service = AISwapClothService()
db_ops = DatabaseOperations()
# Start the task queue service.
# NOTE(review): this runs at import time, so merely importing this module starts the
# queue. Consider moving it into an application startup/lifespan hook.
from backend.services.task_queue_service import get_task_queue_service
task_queue_service = get_task_queue_service()
task_queue_service.start()
  33. # ==================== 数据模型 ====================
  34. class SwapClothRequest(BaseModel):
  35. user_id: int = Field(..., description="用户ID")
  36. raw_image_id: int = Field(..., description="原始图片ID")
  37. cloth_image_id: int = Field(..., description="衣服图片ID")
  38. quantity: Optional[int] = Field(1, ge=1, le=10, description="每组生成数量")
  39. class SwapClothResponse(BaseModel):
  40. success: bool
  41. task_id: Optional[str] = None
  42. process_record_id: Optional[int] = None
  43. result_image_id: Optional[int] = None
  44. copywriter_text: Optional[str] = None
  45. history_prompt: Optional[str] = None
  46. error: Optional[str] = None
  47. error_type: Optional[str] = None
  48. # ==================== 中间件和依赖 ====================
  49. async def verify_user(user_id: int):
  50. """验证用户是否存在且有效"""
  51. user = db_ops.get_user_by_id(user_id)
  52. if not user:
  53. raise HTTPException(status_code=404, detail=f"用户ID {user_id} 不存在")
  54. if not user.get("is_active", False):
  55. raise HTTPException(status_code=403, detail="用户账户已被禁用")
  56. return user
  57. # ==================== API路由 ====================
  58. @router.post("/api/v1/swap-cloth", response_model=SwapClothResponse, tags=["AI换衣服"])
  59. async def process_swap_cloth(request: SwapClothRequest):
  60. """
  61. 执行AI换衣服(异步任务)
  62. - **user_id**: 用户ID
  63. - **raw_image_id**: 原始图片ID
  64. - **cloth_image_id**: 衣服图片ID
  65. - **quantity**: 每组生成数量(可选)
  66. """
  67. try:
  68. # 验证用户
  69. await verify_user(request.user_id)
  70. # 提交任务到队列(非阻塞)
  71. task_id = ai_swap_cloth_service.submit_swap_cloth_task(
  72. user_id=request.user_id,
  73. raw_image_id=request.raw_image_id,
  74. cloth_image_id=request.cloth_image_id,
  75. quantity=request.quantity or 1
  76. )
  77. logger.info(f"用户 {request.user_id} 换衣服任务已提交,任务ID: {task_id}")
  78. return SwapClothResponse(
  79. success=True,
  80. task_id=task_id,
  81. process_record_id=None,
  82. result_image_id=None,
  83. copywriter_text=None,
  84. history_prompt=None
  85. )
  86. except HTTPException:
  87. raise
  88. except Exception as e:
  89. logger.error(f"AI换衣服API处理异常: {str(e)}", exc_info=True)
  90. raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")