|
@@ -0,0 +1,841 @@
|
|
|
|
|
+"""
|
|
|
|
|
+AI换脸换装API接口
|
|
|
|
|
+提供RESTful API接口,支持换脸换装、历史记录查询等功能
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Body
|
|
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
+from fastapi.responses import JSONResponse
|
|
|
|
|
+from starlette.responses import FileResponse
|
|
|
|
|
+from fastapi.staticfiles import StaticFiles
|
|
|
|
|
+import os
|
|
|
|
|
+from pydantic import BaseModel, Field
|
|
|
|
|
+from typing import Optional, List, Dict, Any, Union
|
|
|
|
|
+import uvicorn
|
|
|
|
|
+from datetime import datetime, timedelta
|
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
+from sqlalchemy import func, and_
|
|
|
|
|
+
|
|
|
|
|
+from backend.services.ai_swap_service import AISwapService, process_swap_with_record
|
|
|
|
|
+from backend.modules.database.operations import DatabaseOperations
|
|
|
|
|
+from backend.utils.logger_config import setup_logger
|
|
|
|
|
+from backend.api.auth_api import router as auth_router
|
|
|
|
|
+from backend.api.user_material_api import router as user_material_router
|
|
|
|
|
+from backend.api.user_template_api import router as user_template_router
|
|
|
|
|
+from backend.modules.database.models import ProcessRecord
|
|
|
|
|
+from backend.api.auto_post_api import router as auto_post_router
|
|
|
|
|
+from backend.api.ai_swap_bg_api import router as ai_swap_bg_router
|
|
|
|
|
+from backend.api.ai_swap_cloth_api import router as ai_swap_cloth_router
|
|
|
|
|
+from backend.api.ai_swap_face_api import router as ai_swap_face_router
|
|
|
|
|
+from backend.api.ai_gen_video_api import router as ai_gen_video_router
|
|
|
|
|
+
|
|
|
|
|
+logger = setup_logger(__name__)
|
|
|
|
|
+
|
|
|
|
|
+# 创建FastAPI应用
|
|
|
|
|
+app = FastAPI(
|
|
|
|
|
+ title="AI换脸换装服务",
|
|
|
|
|
+ description="包含用户认证与管理、AI换脸换装等API",
|
|
|
|
|
+ version="1.0.0"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+# 挂载静态目录:使用与当前文件相对的绝对路径,避免工作目录变化导致路径失效
|
|
|
|
|
+current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
+project_root = os.path.dirname(os.path.dirname(current_dir))
|
|
|
|
|
+
|
|
|
|
|
+materials_dir = os.path.abspath(os.path.join(project_root, 'backend', 'data', 'materials'))
|
|
|
|
|
+output_dir = os.path.abspath(os.path.join(project_root, 'backend', 'output'))
|
|
|
|
|
+
|
|
|
|
|
+# 确保目录存在(StaticFiles 需要已存在的目录)
|
|
|
|
|
+os.makedirs(materials_dir, exist_ok=True)
|
|
|
|
|
+os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
+
|
|
|
|
|
+allowed_origins = [
|
|
|
|
|
+ "http://localhost:5173",
|
|
|
|
|
+ "http://127.0.0.1:5173",
|
|
|
|
|
+ "http://localhost:3000",
|
|
|
|
|
+ "http://127.0.0.1:3000",
|
|
|
|
|
+ "http://10.41.175.254:5173",
|
|
|
|
|
+ "http://10.41.175.254:3000",
|
|
|
|
|
+ # 局域网访问(当前机器IP)
|
|
|
|
|
+ "http://10.41.175.43:5173",
|
|
|
|
|
+ "http://10.41.175.43:3000",
|
|
|
|
|
+ # 局域网访问(域名访问)
|
|
|
|
|
+ "http://ai-swap.local:5173",
|
|
|
|
|
+]
|
|
|
|
|
+
|
|
|
|
|
+# 为静态资源单独包裹 CORS 中间件,确保跨域下载(fetch blob)也包含 CORS 头
|
|
|
|
|
+materials_app = StaticFiles(directory=materials_dir)
|
|
|
|
|
+materials_app = CORSMiddleware(
|
|
|
|
|
+ app=materials_app,
|
|
|
|
|
+ allow_origins=["*"], # 静态资源放宽到任意来源,避免浏览器下载报CORS
|
|
|
|
|
+ allow_credentials=False,
|
|
|
|
|
+ allow_methods=["*"],
|
|
|
|
|
+ allow_headers=["*"],
|
|
|
|
|
+)
|
|
|
|
|
+app.mount("/materials", materials_app, name="materials")
|
|
|
|
|
+
|
|
|
|
|
+output_app = StaticFiles(directory=output_dir)
|
|
|
|
|
+output_app = CORSMiddleware(
|
|
|
|
|
+ app=output_app,
|
|
|
|
|
+ allow_origins=["*"], # 静态资源放宽到任意来源,避免浏览器下载报CORS
|
|
|
|
|
+ allow_credentials=False,
|
|
|
|
|
+ allow_methods=["*"],
|
|
|
|
|
+ allow_headers=["*"],
|
|
|
|
|
+)
|
|
|
|
|
+app.mount("/output", output_app, name="output")
|
|
|
|
|
+
|
|
|
|
|
+# 添加CORS中间件
|
|
|
|
|
+app.add_middleware(
|
|
|
|
|
+ CORSMiddleware,
|
|
|
|
|
+ allow_origins=allowed_origins,
|
|
|
|
|
+ allow_credentials=True,
|
|
|
|
|
+ allow_methods=["*"],
|
|
|
|
|
+ allow_headers=["*"],
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+# 注册认证API
|
|
|
|
|
+app.include_router(auth_router, prefix="/auth", tags=["auth"])
|
|
|
|
|
+app.include_router(user_material_router)
|
|
|
|
|
+app.include_router(auto_post_router, prefix="/auto_post")
|
|
|
|
|
+app.include_router(user_template_router)
|
|
|
|
|
+app.include_router(ai_swap_bg_router)
|
|
|
|
|
+app.include_router(ai_swap_cloth_router)
|
|
|
|
|
+app.include_router(ai_swap_face_router)
|
|
|
|
|
+app.include_router(ai_gen_video_router)
|
|
|
|
|
+
|
|
|
|
|
+# 创建服务实例
|
|
|
|
|
+ai_swap_service = AISwapService()
|
|
|
|
|
+db_ops = DatabaseOperations()
|
|
|
|
|
+
|
|
|
|
|
+# 启动任务队列服务
|
|
|
|
|
+from backend.services.task_queue_service import get_task_queue_service
|
|
|
|
|
+task_queue_service = get_task_queue_service()
|
|
|
|
|
+task_queue_service.start()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==================== 数据模型 ====================
|
|
|
|
|
+
|
|
|
|
|
+class SwapRequest(BaseModel):
|
|
|
|
|
+ """换脸换装请求模型"""
|
|
|
|
|
+ user_id: int = Field(..., description="用户ID")
|
|
|
|
|
+ face_image_id: int = Field(..., description="人脸图片ID")
|
|
|
|
|
+ cloth_image_id: int = Field(..., description="服装图片ID")
|
|
|
|
|
+ prompt: str = Field(..., description="提示词", max_length=500)
|
|
|
|
|
+ face_prompt: Optional[str] = Field(None, description="人脸识别提示词")
|
|
|
|
|
+ cloth_prompt: Optional[str] = Field(None, description="服装识别提示词")
|
|
|
|
|
+ prompt_prompt: Optional[str] = Field(None, description="提示词优化")
|
|
|
|
|
+ quantity: Optional[int] = Field(1, ge=1, le=10, description="每组生成数量")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class SwapResponse(BaseModel):
|
|
|
|
|
+ """换脸换装响应模型"""
|
|
|
|
|
+ success: bool
|
|
|
|
|
+ task_id: Optional[str] = None
|
|
|
|
|
+ process_record_id: Optional[int] = None
|
|
|
|
|
+ result_image_id: Optional[int] = None
|
|
|
|
|
+ copywriter_text: Optional[str] = None
|
|
|
|
|
+ history_prompt: Optional[str] = None
|
|
|
|
|
+ error: Optional[str] = None
|
|
|
|
|
+ error_type: Optional[str] = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TaskStatusResponse(BaseModel):
|
|
|
|
|
+ """任务状态响应模型"""
|
|
|
|
|
+ task_id: str
|
|
|
|
|
+ status: str
|
|
|
|
|
+ progress: int
|
|
|
|
|
+ created_at: str
|
|
|
|
|
+ started_at: Optional[str] = None
|
|
|
|
|
+ completed_at: Optional[str] = None
|
|
|
|
|
+ result: Optional[Dict[str, Any]] = None
|
|
|
|
|
+ error: Optional[str] = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class UserTasksResponse(BaseModel):
|
|
|
|
|
+ """用户任务列表响应模型"""
|
|
|
|
|
+ tasks: List[Dict[str, Any]]
|
|
|
|
|
+ total: int
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ProcessHistoryResponse(BaseModel):
|
|
|
|
|
+ """处理历史响应模型"""
|
|
|
|
|
+ records: List[Dict[str, Any]]
|
|
|
|
|
+ total: int
|
|
|
|
|
+ page: int
|
|
|
|
|
+ page_size: int
|
|
|
|
|
+ total_pages: int
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ProcessDetailResponse(BaseModel):
|
|
|
|
|
+ """处理详情响应模型"""
|
|
|
|
|
+ process_record: Dict[str, Any]
|
|
|
|
|
+ face_image: Optional[Dict[str, Any]] = None
|
|
|
|
|
+ cloth_image: Optional[Dict[str, Any]] = None
|
|
|
|
|
+ result_image: Optional[Dict[str, Any]] = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class UpdateTextRequest(BaseModel):
|
|
|
|
|
+ """更新文本请求模型"""
|
|
|
|
|
+ user_id: int = Field(..., description="用户ID")
|
|
|
|
|
+ title: str = Field(..., description="标题")
|
|
|
|
|
+ content: str = Field(..., description="文案内容")
|
|
|
|
|
+ label: str = Field(..., description="话题标签")
|
|
|
|
|
+
|
|
|
|
|
+ class Config:
|
|
|
|
|
+ json_schema_extra = {
|
|
|
|
|
+ "example": {
|
|
|
|
|
+ "user_id": 1,
|
|
|
|
|
+ "title": "示例标题",
|
|
|
|
|
+ "content": "示例文案内容",
|
|
|
|
|
+ "label": "示例标签"
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ApproveRecordRequest(BaseModel):
|
|
|
|
|
+ """审核记录请求模型"""
|
|
|
|
|
+ user_id: int = Field(..., description="用户ID")
|
|
|
|
|
+
|
|
|
|
|
+ class Config:
|
|
|
|
|
+ json_schema_extra = {
|
|
|
|
|
+ "example": {
|
|
|
|
|
+ "user_id": 1
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ApproveRecordResponse(BaseModel):
|
|
|
|
|
+ """审核记录响应模型"""
|
|
|
|
|
+ success: bool
|
|
|
|
|
+ message: str
|
|
|
|
|
+ process_record: Optional[Dict[str, Any]] = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class DeleteResultImageRequest(BaseModel):
|
|
|
|
|
+ """删除结果图片请求模型"""
|
|
|
|
|
+ user_id: int = Field(..., description="用户ID")
|
|
|
|
|
+
|
|
|
|
|
+ class Config:
|
|
|
|
|
+ json_schema_extra = {
|
|
|
|
|
+ "example": {
|
|
|
|
|
+ "user_id": 1
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class DeleteResultImageResponse(BaseModel):
|
|
|
|
|
+ """删除结果图片响应模型"""
|
|
|
|
|
+ success: bool
|
|
|
|
|
+ message: str
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class DashboardStatsResponse(BaseModel):
|
|
|
|
|
+ """仪表盘统计数据响应模型"""
|
|
|
|
|
+ today_generated: int = Field(..., description="今日生成数量")
|
|
|
|
|
+ pending_review: int = Field(..., description="待审核数量")
|
|
|
|
|
+ published: int = Field(..., description="已发布数量")
|
|
|
|
|
+
|
|
|
|
|
+# ==================== 中间件和依赖 ====================
|
|
|
|
|
+
|
|
|
|
|
+async def verify_user(user_id: int):
|
|
|
|
|
+ """验证用户是否存在且有效"""
|
|
|
|
|
+ user = db_ops.get_user_by_id(user_id)
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail=f"用户ID {user_id} 不存在")
|
|
|
|
|
+ if not user.get("is_active", False):
|
|
|
|
|
+ raise HTTPException(status_code=403, detail="用户账户已被禁用")
|
|
|
|
|
+ return user
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==================== API接口 ====================
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/api/v1/swap", response_model=SwapResponse, tags=["AI换脸换装"])
|
|
|
|
|
+async def process_swap(request: SwapRequest):
|
|
|
|
|
+ """
|
|
|
|
|
+ 执行AI换脸换装(异步任务)
|
|
|
|
|
+
|
|
|
|
|
+ - **user_id**: 用户ID
|
|
|
|
|
+ - **face_image_id**: 人脸图片ID
|
|
|
|
|
+ - **cloth_image_id**: 服装图片ID
|
|
|
|
|
+ - **prompt**: 提示词
|
|
|
|
|
+ - **face_prompt**: 人脸识别提示词(可选)
|
|
|
|
|
+ - **cloth_prompt**: 服装识别提示词(可选)
|
|
|
|
|
+ - **prompt_prompt**: 提示词优化(可选)
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 验证用户
|
|
|
|
|
+ await verify_user(request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 提交任务到队列(非阻塞)
|
|
|
|
|
+ task_id = ai_swap_service.submit_swap_task(
|
|
|
|
|
+ user_id=request.user_id,
|
|
|
|
|
+ face_image_id=request.face_image_id,
|
|
|
|
|
+ cloth_image_id=request.cloth_image_id,
|
|
|
|
|
+ prompt=request.prompt,
|
|
|
|
|
+ face_prompt=request.face_prompt,
|
|
|
|
|
+ cloth_prompt=request.cloth_prompt,
|
|
|
|
|
+ prompt_prompt=request.prompt_prompt,
|
|
|
|
|
+ quantity=request.quantity or 1
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"用户 {request.user_id} 换脸换装任务已提交,任务ID: {task_id}")
|
|
|
|
|
+ return SwapResponse(
|
|
|
|
|
+ success=True,
|
|
|
|
|
+ task_id=task_id,
|
|
|
|
|
+ process_record_id=None,
|
|
|
|
|
+ result_image_id=None,
|
|
|
|
|
+ copywriter_text=None,
|
|
|
|
|
+ history_prompt=None
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"API处理异常: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/tasks/{task_id}", response_model=TaskStatusResponse, tags=["任务管理"])
|
|
|
|
|
+async def get_task_status(task_id: str):
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取任务状态
|
|
|
|
|
+
|
|
|
|
|
+ - **task_id**: 任务ID
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ task_info = ai_swap_service.get_task_status(task_id)
|
|
|
|
|
+ if not task_info:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="任务不存在")
|
|
|
|
|
+
|
|
|
|
|
+ return TaskStatusResponse(
|
|
|
|
|
+ task_id=task_info["id"],
|
|
|
|
|
+ status=task_info["status"].value,
|
|
|
|
|
+ progress=task_info["progress"],
|
|
|
|
|
+ created_at=task_info["created_at"].isoformat(),
|
|
|
|
|
+ started_at=task_info["started_at"].isoformat() if task_info["started_at"] else None,
|
|
|
|
|
+ completed_at=task_info["completed_at"].isoformat() if task_info["completed_at"] else None,
|
|
|
|
|
+ result=task_info["result"],
|
|
|
|
|
+ error=task_info["error"]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取任务状态异常: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/users/{user_id}/tasks", response_model=UserTasksResponse, tags=["任务管理"])
|
|
|
|
|
+async def get_user_tasks(user_id: int):
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取用户的所有任务
|
|
|
|
|
+
|
|
|
|
|
+ - **user_id**: 用户ID
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 验证用户
|
|
|
|
|
+ await verify_user(user_id)
|
|
|
|
|
+
|
|
|
|
|
+ tasks = ai_swap_service.get_user_tasks(user_id)
|
|
|
|
|
+
|
|
|
|
|
+ return UserTasksResponse(
|
|
|
|
|
+ tasks=tasks,
|
|
|
|
|
+ total=len(tasks)
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取用户任务异常: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.delete("/api/v1/tasks/{task_id}", tags=["任务管理"])
|
|
|
|
|
+async def cancel_task(task_id: str):
|
|
|
|
|
+ """
|
|
|
|
|
+ 取消任务
|
|
|
|
|
+
|
|
|
|
|
+ - **task_id**: 任务ID
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ success = ai_swap_service.cancel_task(task_id)
|
|
|
|
|
+ if not success:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="任务不存在或无法取消")
|
|
|
|
|
+
|
|
|
|
|
+ return {"success": True, "message": "任务已取消"}
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"取消任务异常: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/tasks/queue/stats", tags=["任务管理"])
|
|
|
|
|
+async def get_queue_stats():
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取任务队列统计信息
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ stats = ai_swap_service.task_queue.get_queue_stats()
|
|
|
|
|
+ return stats
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取队列统计异常: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/users/{user_id}/process-history", response_model=ProcessHistoryResponse, tags=["历史记录"])
|
|
|
|
|
+async def get_user_process_history(
|
|
|
|
|
+ user_id: int,
|
|
|
|
|
+ page: int = 1,
|
|
|
|
|
+ page_size: int = 20,
|
|
|
|
|
+ user: Dict[str, Any] = Depends(verify_user)
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取用户的处理历史记录
|
|
|
|
|
+
|
|
|
|
|
+ - **user_id**: 用户ID
|
|
|
|
|
+ - **page**: 页码(默认1)
|
|
|
|
|
+ - **page_size**: 每页大小(默认20,最大100)
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 限制每页大小
|
|
|
|
|
+ page_size = min(page_size, 100)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取历史记录
|
|
|
|
|
+ history = ai_swap_service.get_user_process_history(user_id, page, page_size)
|
|
|
|
|
+
|
|
|
|
|
+ return ProcessHistoryResponse(
|
|
|
|
|
+ records=history["records"],
|
|
|
|
|
+ total=history["total"],
|
|
|
|
|
+ page=history["page"],
|
|
|
|
|
+ page_size=history["page_size"],
|
|
|
|
|
+ total_pages=history["total_pages"]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取用户历史记录失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取历史记录失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/process/{process_id}", response_model=ProcessDetailResponse, tags=["历史记录"])
|
|
|
|
|
+async def get_process_detail(
|
|
|
|
|
+ process_id: int,
|
|
|
|
|
+ user_id: Optional[int] = None
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取处理记录详情
|
|
|
|
|
+
|
|
|
|
|
+ - **process_id**: 处理记录ID
|
|
|
|
|
+ - **user_id**: 用户ID(可选,用于权限验证)
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 如果提供了用户ID,先验证用户
|
|
|
|
|
+ if user_id:
|
|
|
|
|
+ await verify_user(user_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取处理详情
|
|
|
|
|
+ detail = ai_swap_service.get_process_detail(process_id, user_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="处理记录不存在或无权限访问")
|
|
|
|
|
+
|
|
|
|
|
+ return ProcessDetailResponse(
|
|
|
|
|
+ process_record=detail["process_record"],
|
|
|
|
|
+ face_image=detail["face_image"],
|
|
|
|
|
+ cloth_image=detail["cloth_image"],
|
|
|
|
|
+ result_image=detail["result_image"]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取处理详情失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取处理详情失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.put("/api/v1/process/{process_id}/text", response_model=ProcessDetailResponse, tags=["历史记录"])
|
|
|
|
|
+async def update_process_text(
|
|
|
|
|
+ process_id: int,
|
|
|
|
|
+ request: UpdateTextRequest = Body(...),
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ 更新处理记录的文本内容
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ process_id: 处理记录ID
|
|
|
|
|
+ request: 更新文本请求数据
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ ProcessDetailResponse: 更新后的处理记录详情
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 验证用户
|
|
|
|
|
+ user = await verify_user(request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取处理详情
|
|
|
|
|
+ detail = ai_swap_service.get_process_detail(process_id, request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="处理记录不存在")
|
|
|
|
|
+
|
|
|
|
|
+ # 验证权限
|
|
|
|
|
+ if detail["process_record"]["user_id"] != request.user_id:
|
|
|
|
|
+ raise HTTPException(status_code=403, detail="无权限修改此记录")
|
|
|
|
|
+
|
|
|
|
|
+ # 更新文本内容
|
|
|
|
|
+ generated_text = f"- 标题:{request.title}\n- 文案:{request.content}\n- 话题标签:{request.label}"
|
|
|
|
|
+
|
|
|
|
|
+ # 更新数据库记录
|
|
|
|
|
+ update_data = {
|
|
|
|
|
+ "generated_text": generated_text
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ updated_record = db_ops.update_process_record(process_id, update_data)
|
|
|
|
|
+
|
|
|
|
|
+ if not updated_record:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="更新数据库失败")
|
|
|
|
|
+
|
|
|
|
|
+ # 重新获取更新后的记录
|
|
|
|
|
+ detail = ai_swap_service.get_process_detail(process_id, request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="获取更新后的记录失败")
|
|
|
|
|
+
|
|
|
|
|
+ return ProcessDetailResponse(
|
|
|
|
|
+ process_record=detail["process_record"],
|
|
|
|
|
+ face_image=detail["face_image"],
|
|
|
|
|
+ cloth_image=detail["cloth_image"],
|
|
|
|
|
+ result_image=detail["result_image"]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"更新处理记录文本失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"更新处理记录文本失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/api/v1/process/{process_id}/approve", response_model=ApproveRecordResponse, tags=["历史记录"])
|
|
|
|
|
+async def approve_process_record(
|
|
|
|
|
+ process_id: int,
|
|
|
|
|
+ request: ApproveRecordRequest = Body(...)
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ 审核处理记录
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ process_id: 处理记录ID
|
|
|
|
|
+ request: 审核请求数据
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ ApproveRecordResponse: 审核结果
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 验证用户
|
|
|
|
|
+ await verify_user(request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取处理详情
|
|
|
|
|
+ detail = ai_swap_service.get_process_detail(process_id, request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="处理记录不存在")
|
|
|
|
|
+
|
|
|
|
|
+ # 验证权限
|
|
|
|
|
+ if detail["process_record"]["user_id"] != request.user_id:
|
|
|
|
|
+ raise HTTPException(status_code=403, detail="无权限审核此记录")
|
|
|
|
|
+
|
|
|
|
|
+ # 审核通过
|
|
|
|
|
+ success = ai_swap_service.approve_process_record(process_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not success:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="审核失败")
|
|
|
|
|
+
|
|
|
|
|
+ # 重新获取审核后的记录
|
|
|
|
|
+ detail = ai_swap_service.get_process_detail(process_id, request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="获取审核后的记录失败")
|
|
|
|
|
+
|
|
|
|
|
+ return ApproveRecordResponse(
|
|
|
|
|
+ success=True,
|
|
|
|
|
+ message="处理记录已审核通过",
|
|
|
|
|
+ process_record=detail["process_record"]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"审核处理记录失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"审核处理记录失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.delete("/api/v1/process/{process_id}/result-image", response_model=DeleteResultImageResponse, tags=["历史记录"])
|
|
|
|
|
+async def delete_result_image(
|
|
|
|
|
+ process_id: int,
|
|
|
|
|
+ request: DeleteResultImageRequest = Body(...)
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ 删除处理记录的结果图片
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ process_id: 处理记录ID
|
|
|
|
|
+ request: 删除请求数据
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ DeleteResultImageResponse: 删除结果
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 验证用户
|
|
|
|
|
+ await verify_user(request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取处理详情
|
|
|
|
|
+ detail = ai_swap_service.get_process_detail(process_id, request.user_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not detail:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="处理记录不存在")
|
|
|
|
|
+
|
|
|
|
|
+ # 验证权限
|
|
|
|
|
+ if detail["process_record"]["user_id"] != request.user_id:
|
|
|
|
|
+ raise HTTPException(status_code=403, detail="无权限删除此图片")
|
|
|
|
|
+
|
|
|
|
|
+ # 删除图片
|
|
|
|
|
+ success = ai_swap_service.delete_result_image(process_id)
|
|
|
|
|
+
|
|
|
|
|
+ if not success:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="删除图片失败")
|
|
|
|
|
+
|
|
|
|
|
+ return DeleteResultImageResponse(
|
|
|
|
|
+ success=True,
|
|
|
|
|
+ message="结果图片已删除"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"删除结果图片失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/users/{user_id}/images", tags=["图片管理"])
|
|
|
|
|
+async def get_user_images(
|
|
|
|
|
+ user_id: int,
|
|
|
|
|
+ image_type: Optional[str] = None,
|
|
|
|
|
+ page: int = 1,
|
|
|
|
|
+ page_size: int = 20,
|
|
|
|
|
+ user: Dict[str, Any] = Depends(verify_user)
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取用户的图片列表
|
|
|
|
|
+
|
|
|
|
|
+ - **user_id**: 用户ID
|
|
|
|
|
+ - **image_type**: 图片类型(face/cloth/result,可选)
|
|
|
|
|
+ - **page**: 页码(默认1)
|
|
|
|
|
+ - **page_size**: 每页大小(默认20,最大100)
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 限制每页大小
|
|
|
|
|
+ page_size = min(page_size, 100)
|
|
|
|
|
+
|
|
|
|
|
+ # 验证图片类型
|
|
|
|
|
+ if image_type and image_type not in ["face", "cloth", "result", "original"]:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="图片类型必须是 face、cloth、result 或 original")
|
|
|
|
|
+
|
|
|
|
|
+ # 获取图片列表
|
|
|
|
|
+ images = db_ops.get_user_images(user_id, image_type, page, page_size)
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "images": images["images"],
|
|
|
|
|
+ "total": images["total"],
|
|
|
|
|
+ "page": images["page"],
|
|
|
|
|
+ "page_size": images["page_size"],
|
|
|
|
|
+ "total_pages": images["total_pages"]
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except HTTPException:
|
|
|
|
|
+ raise
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取用户图片列表失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取图片列表失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/dashboard/stats", response_model=DashboardStatsResponse, tags=["仪表盘"])
|
|
|
|
|
+async def get_dashboard_stats(user_id: int):
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取仪表盘统计数据
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ user_id: 用户ID
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ DashboardStatsResponse: 包含今日生成、待审核、已发布的数量
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ with db_ops.db_connection.get_session() as session:
|
|
|
|
|
+ # 获取今天的开始时间(0点)
|
|
|
|
|
+ today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
|
|
|
|
+
|
|
|
|
|
+ # 取该用户所有记录的关键字段,按组合(face+cloth)聚合
|
|
|
|
|
+ rows = (
|
|
|
|
|
+ session.query(
|
|
|
|
|
+ ProcessRecord.face_image_id,
|
|
|
|
|
+ ProcessRecord.cloth_image_id,
|
|
|
|
|
+ ProcessRecord.status,
|
|
|
|
|
+ ProcessRecord.completed_at,
|
|
|
|
|
+ ProcessRecord.result_image_id,
|
|
|
|
|
+ )
|
|
|
|
|
+ .filter(ProcessRecord.user_id == user_id)
|
|
|
|
|
+ .all()
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 组合 -> 最新记录(以 completed_at 最大为准)
|
|
|
|
|
+ latest_by_group = {}
|
|
|
|
|
+ # 今日生成的组合集合(任意一条记录今天完成即可计一次)
|
|
|
|
|
+ groups_generated_today = set()
|
|
|
|
|
+
|
|
|
|
|
+ for face_id, cloth_id, status, completed_at, result_image_id in rows:
|
|
|
|
|
+ if completed_at and completed_at >= today_start and result_image_id is not None:
|
|
|
|
|
+ groups_generated_today.add((face_id, cloth_id))
|
|
|
|
|
+
|
|
|
|
|
+ key = (face_id, cloth_id)
|
|
|
|
|
+ prev = latest_by_group.get(key)
|
|
|
|
|
+ if prev is None:
|
|
|
|
|
+ latest_by_group[key] = (status, completed_at, result_image_id)
|
|
|
|
|
+ else:
|
|
|
|
|
+ _, prev_time, _ = prev
|
|
|
|
|
+ # None 视为更早
|
|
|
|
|
+ if (prev_time or datetime.min) < (completed_at or datetime.min):
|
|
|
|
|
+ latest_by_group[key] = (status, completed_at, result_image_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 以组合为单位统计
|
|
|
|
|
+ today_generated = len(groups_generated_today)
|
|
|
|
|
+
|
|
|
|
|
+ pending_review = 0
|
|
|
|
|
+ published = 0
|
|
|
|
|
+ for (_face_id, _cloth_id), (status, _t, _rid) in latest_by_group.items():
|
|
|
|
|
+ if status == "待审核":
|
|
|
|
|
+ pending_review += 1
|
|
|
|
|
+ elif status == "已发布":
|
|
|
|
|
+ published += 1
|
|
|
|
|
+
|
|
|
|
|
+ return DashboardStatsResponse(
|
|
|
|
|
+ today_generated=today_generated,
|
|
|
|
|
+ pending_review=pending_review,
|
|
|
|
|
+ published=published,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取仪表盘统计数据失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取仪表盘统计数据失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/statistics", tags=["系统统计"])
|
|
|
|
|
+async def get_system_statistics():
|
|
|
|
|
+ """
|
|
|
|
|
+ 获取系统统计信息
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ stats = db_ops.get_statistics()
|
|
|
|
|
+ return {
|
|
|
|
|
+ "success": True,
|
|
|
|
|
+ "statistics": stats
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"获取系统统计失败: {str(e)}", exc_info=True)
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取系统统计失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==================== 健康检查 ====================
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/health", tags=["系统"])
|
|
|
|
|
+async def health_check():
|
|
|
|
|
+ """健康检查接口"""
|
|
|
|
|
+ return {"status": "healthy", "service": "ai_swap_api"}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/")
|
|
|
|
|
+def root():
|
|
|
|
|
+ return {"msg": "AI换脸换装服务已启动"}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==================== 静态文件直链下载(包含CORS响应头) ====================
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/output/{filename}")
|
|
|
|
|
+async def get_output_file(filename: str):
|
|
|
|
|
+ """为结果图片提供带CORS头的直链下载"""
|
|
|
|
|
+ file_path = os.path.join(output_dir, filename)
|
|
|
|
|
+ if not os.path.isfile(file_path):
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="文件不存在")
|
|
|
|
|
+ headers = {
|
|
|
|
|
+ "Access-Control-Allow-Origin": "*",
|
|
|
|
|
+ "Access-Control-Expose-Headers": "Content-Disposition",
|
|
|
|
|
+ }
|
|
|
|
|
+ return FileResponse(file_path, headers=headers, filename=filename)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/api/v1/materials/{filename}")
|
|
|
|
|
+async def get_material_file(filename: str):
|
|
|
|
|
+ """为素材图片提供带CORS头的直链下载"""
|
|
|
|
|
+ file_path = os.path.join(materials_dir, filename)
|
|
|
|
|
+ if not os.path.isfile(file_path):
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="文件不存在")
|
|
|
|
|
+ headers = {
|
|
|
|
|
+ "Access-Control-Allow-Origin": "*",
|
|
|
|
|
+ "Access-Control-Expose-Headers": "Content-Disposition",
|
|
|
|
|
+ }
|
|
|
|
|
+ return FileResponse(file_path, headers=headers, filename=filename)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==================== 错误处理 ====================
|
|
|
|
|
+
|
|
|
|
|
+@app.exception_handler(HTTPException)
|
|
|
|
|
+async def http_exception_handler(request, exc):
|
|
|
|
|
+ """HTTP异常处理器"""
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ status_code=exc.status_code,
|
|
|
|
|
+ content={
|
|
|
|
|
+ "success": False,
|
|
|
|
|
+ "error": exc.detail,
|
|
|
|
|
+ "error_type": "HTTPException"
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.exception_handler(Exception)
|
|
|
|
|
+async def general_exception_handler(request, exc):
|
|
|
|
|
+ """通用异常处理器"""
|
|
|
|
|
+ logger.error(f"未处理的异常: {str(exc)}", exc_info=True)
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ status_code=500,
|
|
|
|
|
+ content={
|
|
|
|
|
+ "success": False,
|
|
|
|
|
+ "error": "服务器内部错误",
|
|
|
|
|
+ "error_type": "InternalServerError"
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# ==================== 启动配置 ====================
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ """
|
|
|
|
|
+ 启动API服务器
|
|
|
|
|
+ """
|
|
|
|
|
+ host = "0.0.0.0"
|
|
|
|
|
+ port = 8002
|
|
|
|
|
+ reload = True
|
|
|
|
|
+
|
|
|
|
|
+ print(f"启动AI Swap API服务器...")
|
|
|
|
|
+ print(f"服务器地址: http://{host}:{port}")
|
|
|
|
|
+ print(f"API文档: http://{host}:{port}/docs")
|
|
|
|
|
+ print(f"自动重载: {reload}")
|
|
|
|
|
+
|
|
|
|
|
+ uvicorn.run(
|
|
|
|
|
+ "backend.api.ai_swap_api:app",
|
|
|
|
|
+ host=host,
|
|
|
|
|
+ port=port,
|
|
|
|
|
+ reload=reload,
|
|
|
|
|
+ log_level="info"
|
|
|
|
|
+ )
|