""" AI换脸换装API接口 提供RESTful API接口,支持换脸换装、历史记录查询等功能 """ from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Body from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.responses import FileResponse from fastapi.staticfiles import StaticFiles import os from pydantic import BaseModel, Field from typing import Optional, List, Dict, Any, Union import uvicorn from datetime import datetime, timedelta from pydantic import BaseModel from sqlalchemy import func, and_ from backend.services.ai_swap_service import AISwapService, process_swap_with_record from backend.modules.database.operations import DatabaseOperations from backend.utils.logger_config import setup_logger from backend.api.auth_api import router as auth_router from backend.api.user_material_api import router as user_material_router from backend.api.user_template_api import router as user_template_router from backend.modules.database.models import ProcessRecord from backend.api.auto_post_api import router as auto_post_router from backend.api.ai_swap_bg_api import router as ai_swap_bg_router from backend.api.ai_swap_cloth_api import router as ai_swap_cloth_router from backend.api.ai_swap_face_api import router as ai_swap_face_router from backend.api.ai_gen_video_api import router as ai_gen_video_router logger = setup_logger(__name__) # 创建FastAPI应用 app = FastAPI( title="AI换脸换装服务", description="包含用户认证与管理、AI换脸换装等API", version="1.0.0" ) # 挂载静态目录:使用与当前文件相对的绝对路径,避免工作目录变化导致路径失效 current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(os.path.dirname(current_dir)) materials_dir = os.path.abspath(os.path.join(project_root, 'backend', 'data', 'materials')) output_dir = os.path.abspath(os.path.join(project_root, 'backend', 'output')) # 确保目录存在(StaticFiles 需要已存在的目录) os.makedirs(materials_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True) allowed_origins = [ "http://localhost:5173", "http://127.0.0.1:5173", "http://localhost:3000", "http://127.0.0.1:3000", "http://10.41.175.254:5173", "http://10.41.175.254:3000", # 局域网访问(当前机器IP) "http://10.41.175.43:5173", "http://10.41.175.43:3000", # 局域网访问(域名访问) "http://ai-swap.local:5173", ] # 为静态资源单独包裹 CORS 中间件,确保跨域下载(fetch blob)也包含 CORS 头 materials_app = StaticFiles(directory=materials_dir) materials_app = CORSMiddleware( app=materials_app, allow_origins=["*"], # 静态资源放宽到任意来源,避免浏览器下载报CORS allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) app.mount("/materials", materials_app, name="materials") output_app = StaticFiles(directory=output_dir) output_app = CORSMiddleware( app=output_app, allow_origins=["*"], # 静态资源放宽到任意来源,避免浏览器下载报CORS allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) app.mount("/output", output_app, name="output") # 添加CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=allowed_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 注册认证API app.include_router(auth_router, prefix="/auth", tags=["auth"]) app.include_router(user_material_router) app.include_router(auto_post_router, prefix="/auto_post") app.include_router(user_template_router) app.include_router(ai_swap_bg_router) app.include_router(ai_swap_cloth_router) app.include_router(ai_swap_face_router) app.include_router(ai_gen_video_router) # 创建服务实例 ai_swap_service = AISwapService() db_ops = DatabaseOperations() # 启动任务队列服务 from backend.services.task_queue_service import get_task_queue_service task_queue_service = get_task_queue_service() task_queue_service.start() # ==================== 数据模型 ==================== class SwapRequest(BaseModel): """换脸换装请求模型""" user_id: int = Field(..., description="用户ID") face_image_id: int = Field(..., description="人脸图片ID") cloth_image_id: int = Field(..., description="服装图片ID") prompt: str = Field(..., description="提示词", max_length=500) face_prompt: Optional[str] = Field(None, description="人脸识别提示词") cloth_prompt: Optional[str] = Field(None, description="服装识别提示词") prompt_prompt: Optional[str] = Field(None, description="提示词优化") quantity: Optional[int] = Field(1, ge=1, le=10, description="每组生成数量") class SwapResponse(BaseModel): """换脸换装响应模型""" success: bool task_id: Optional[str] = None process_record_id: Optional[int] = None result_image_id: Optional[int] = None copywriter_text: Optional[str] = None history_prompt: Optional[str] = None error: Optional[str] = None error_type: Optional[str] = None class TaskStatusResponse(BaseModel): """任务状态响应模型""" task_id: str status: str progress: int created_at: str started_at: Optional[str] = None completed_at: Optional[str] = None result: Optional[Dict[str, Any]] = None error: Optional[str] = None class UserTasksResponse(BaseModel): """用户任务列表响应模型""" tasks: List[Dict[str, Any]] total: int class ProcessHistoryResponse(BaseModel): """处理历史响应模型""" records: List[Dict[str, Any]] total: int page: int page_size: int total_pages: int class ProcessDetailResponse(BaseModel): """处理详情响应模型""" process_record: Dict[str, Any] face_image: Optional[Dict[str, Any]] = None cloth_image: Optional[Dict[str, Any]] = None result_image: Optional[Dict[str, Any]] = None class UpdateTextRequest(BaseModel): """更新文本请求模型""" user_id: int = Field(..., description="用户ID") title: str = Field(..., description="标题") content: str = Field(..., description="文案内容") label: str = Field(..., description="话题标签") class Config: json_schema_extra = { "example": { "user_id": 1, "title": "示例标题", "content": "示例文案内容", "label": "示例标签" } } class ApproveRecordRequest(BaseModel): """审核记录请求模型""" user_id: int = Field(..., description="用户ID") class Config: json_schema_extra = { "example": { "user_id": 1 } } class ApproveRecordResponse(BaseModel): """审核记录响应模型""" success: bool message: str process_record: Optional[Dict[str, Any]] = None class DeleteResultImageRequest(BaseModel): """删除结果图片请求模型""" user_id: int = Field(..., description="用户ID") class Config: json_schema_extra = { "example": { "user_id": 1 } } class DeleteResultImageResponse(BaseModel): """删除结果图片响应模型""" success: bool message: str class DashboardStatsResponse(BaseModel): """仪表盘统计数据响应模型""" today_generated: int = Field(..., description="今日生成数量") pending_review: int = Field(..., description="待审核数量") published: int = Field(..., description="已发布数量") # ==================== 中间件和依赖 ==================== async def verify_user(user_id: int): """验证用户是否存在且有效""" user = db_ops.get_user_by_id(user_id) if not user: raise HTTPException(status_code=404, detail=f"用户ID {user_id} 不存在") if not user.get("is_active", False): raise HTTPException(status_code=403, detail="用户账户已被禁用") return user # ==================== API接口 ==================== @app.post("/api/v1/swap", response_model=SwapResponse, tags=["AI换脸换装"]) async def process_swap(request: SwapRequest): """ 执行AI换脸换装(异步任务) - **user_id**: 用户ID - **face_image_id**: 人脸图片ID - **cloth_image_id**: 服装图片ID - **prompt**: 提示词 - **face_prompt**: 人脸识别提示词(可选) - **cloth_prompt**: 服装识别提示词(可选) - **prompt_prompt**: 提示词优化(可选) """ try: # 验证用户 await verify_user(request.user_id) # 提交任务到队列(非阻塞) task_id = ai_swap_service.submit_swap_task( user_id=request.user_id, face_image_id=request.face_image_id, cloth_image_id=request.cloth_image_id, prompt=request.prompt, face_prompt=request.face_prompt, cloth_prompt=request.cloth_prompt, prompt_prompt=request.prompt_prompt, quantity=request.quantity or 1 ) logger.info(f"用户 {request.user_id} 换脸换装任务已提交,任务ID: {task_id}") return SwapResponse( success=True, task_id=task_id, process_record_id=None, result_image_id=None, copywriter_text=None, history_prompt=None ) except HTTPException: raise except Exception as e: logger.error(f"API处理异常: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") @app.get("/api/v1/tasks/{task_id}", response_model=TaskStatusResponse, tags=["任务管理"]) async def get_task_status(task_id: str): """ 获取任务状态 - **task_id**: 任务ID """ try: task_info = ai_swap_service.get_task_status(task_id) if not task_info: raise HTTPException(status_code=404, detail="任务不存在") return TaskStatusResponse( task_id=task_info["id"], status=task_info["status"].value, progress=task_info["progress"], created_at=task_info["created_at"].isoformat(), started_at=task_info["started_at"].isoformat() if task_info["started_at"] else None, completed_at=task_info["completed_at"].isoformat() if task_info["completed_at"] else None, result=task_info["result"], error=task_info["error"] ) except HTTPException: raise except Exception as e: logger.error(f"获取任务状态异常: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") @app.get("/api/v1/users/{user_id}/tasks", response_model=UserTasksResponse, tags=["任务管理"]) async def get_user_tasks(user_id: int): """ 获取用户的所有任务 - **user_id**: 用户ID """ try: # 验证用户 await verify_user(user_id) tasks = ai_swap_service.get_user_tasks(user_id) return UserTasksResponse( tasks=tasks, total=len(tasks) ) except HTTPException: raise except Exception as e: logger.error(f"获取用户任务异常: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") @app.delete("/api/v1/tasks/{task_id}", tags=["任务管理"]) async def cancel_task(task_id: str): """ 取消任务 - **task_id**: 任务ID """ try: success = ai_swap_service.cancel_task(task_id) if not success: raise HTTPException(status_code=404, detail="任务不存在或无法取消") return {"success": True, "message": "任务已取消"} except HTTPException: raise except Exception as e: logger.error(f"取消任务异常: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") @app.get("/api/v1/tasks/queue/stats", tags=["任务管理"]) async def get_queue_stats(): """ 获取任务队列统计信息 """ try: stats = ai_swap_service.task_queue.get_queue_stats() return stats except Exception as e: logger.error(f"获取队列统计异常: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") @app.get("/api/v1/users/{user_id}/process-history", response_model=ProcessHistoryResponse, tags=["历史记录"]) async def get_user_process_history( user_id: int, page: int = 1, page_size: int = 20, user: Dict[str, Any] = Depends(verify_user) ): """ 获取用户的处理历史记录 - **user_id**: 用户ID - **page**: 页码(默认1) - **page_size**: 每页大小(默认20,最大100) """ try: # 限制每页大小 page_size = min(page_size, 100) # 获取历史记录 history = ai_swap_service.get_user_process_history(user_id, page, page_size) return ProcessHistoryResponse( records=history["records"], total=history["total"], page=history["page"], page_size=history["page_size"], total_pages=history["total_pages"] ) except Exception as e: logger.error(f"获取用户历史记录失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"获取历史记录失败: {str(e)}") @app.get("/api/v1/process/{process_id}", response_model=ProcessDetailResponse, tags=["历史记录"]) async def get_process_detail( process_id: int, user_id: Optional[int] = None ): """ 获取处理记录详情 - **process_id**: 处理记录ID - **user_id**: 用户ID(可选,用于权限验证) """ try: # 如果提供了用户ID,先验证用户 if user_id: await verify_user(user_id) # 获取处理详情 detail = ai_swap_service.get_process_detail(process_id, user_id) if not detail: raise HTTPException(status_code=404, detail="处理记录不存在或无权限访问") return ProcessDetailResponse( process_record=detail["process_record"], face_image=detail["face_image"], cloth_image=detail["cloth_image"], result_image=detail["result_image"] ) except HTTPException: raise except Exception as e: logger.error(f"获取处理详情失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"获取处理详情失败: {str(e)}") @app.put("/api/v1/process/{process_id}/text", response_model=ProcessDetailResponse, tags=["历史记录"]) async def update_process_text( process_id: int, request: UpdateTextRequest = Body(...), ): """ 更新处理记录的文本内容 Args: process_id: 处理记录ID request: 更新文本请求数据 Returns: ProcessDetailResponse: 更新后的处理记录详情 """ try: # 验证用户 user = await verify_user(request.user_id) # 获取处理详情 detail = ai_swap_service.get_process_detail(process_id, request.user_id) if not detail: raise HTTPException(status_code=404, detail="处理记录不存在") # 验证权限 if detail["process_record"]["user_id"] != request.user_id: raise HTTPException(status_code=403, detail="无权限修改此记录") # 更新文本内容 generated_text = f"- 标题:{request.title}\n- 文案:{request.content}\n- 话题标签:{request.label}" # 更新数据库记录 update_data = { "generated_text": generated_text } updated_record = db_ops.update_process_record(process_id, update_data) if not updated_record: raise HTTPException(status_code=500, detail="更新数据库失败") # 重新获取更新后的记录 detail = ai_swap_service.get_process_detail(process_id, request.user_id) if not detail: raise HTTPException(status_code=500, detail="获取更新后的记录失败") return ProcessDetailResponse( process_record=detail["process_record"], face_image=detail["face_image"], cloth_image=detail["cloth_image"], result_image=detail["result_image"] ) except HTTPException: raise except Exception as e: logger.error(f"更新处理记录文本失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"更新处理记录文本失败: {str(e)}") @app.post("/api/v1/process/{process_id}/approve", response_model=ApproveRecordResponse, tags=["历史记录"]) async def approve_process_record( process_id: int, request: ApproveRecordRequest = Body(...) ): """ 审核处理记录 Args: process_id: 处理记录ID request: 审核请求数据 Returns: ApproveRecordResponse: 审核结果 """ try: # 验证用户 await verify_user(request.user_id) # 获取处理详情 detail = ai_swap_service.get_process_detail(process_id, request.user_id) if not detail: raise HTTPException(status_code=404, detail="处理记录不存在") # 验证权限 if detail["process_record"]["user_id"] != request.user_id: raise HTTPException(status_code=403, detail="无权限审核此记录") # 审核通过 success = ai_swap_service.approve_process_record(process_id) if not success: raise HTTPException(status_code=500, detail="审核失败") # 重新获取审核后的记录 detail = ai_swap_service.get_process_detail(process_id, request.user_id) if not detail: raise HTTPException(status_code=500, detail="获取审核后的记录失败") return ApproveRecordResponse( success=True, message="处理记录已审核通过", process_record=detail["process_record"] ) except HTTPException: raise except Exception as e: logger.error(f"审核处理记录失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"审核处理记录失败: {str(e)}") @app.delete("/api/v1/process/{process_id}/result-image", response_model=DeleteResultImageResponse, tags=["历史记录"]) async def delete_result_image( process_id: int, request: DeleteResultImageRequest = Body(...) ): """ 删除处理记录的结果图片 Args: process_id: 处理记录ID request: 删除请求数据 Returns: DeleteResultImageResponse: 删除结果 """ try: # 验证用户 await verify_user(request.user_id) # 获取处理详情 detail = ai_swap_service.get_process_detail(process_id, request.user_id) if not detail: raise HTTPException(status_code=404, detail="处理记录不存在") # 验证权限 if detail["process_record"]["user_id"] != request.user_id: raise HTTPException(status_code=403, detail="无权限删除此图片") # 删除图片 success = ai_swap_service.delete_result_image(process_id) if not success: raise HTTPException(status_code=500, detail="删除图片失败") return DeleteResultImageResponse( success=True, message="结果图片已删除" ) except HTTPException: raise except Exception as e: logger.error(f"删除结果图片失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"删除结果图片失败: {str(e)}") @app.get("/api/v1/users/{user_id}/images", tags=["图片管理"]) async def get_user_images( user_id: int, image_type: Optional[str] = None, page: int = 1, page_size: int = 20, user: Dict[str, Any] = Depends(verify_user) ): """ 获取用户的图片列表 - **user_id**: 用户ID - **image_type**: 图片类型(face/cloth/result,可选) - **page**: 页码(默认1) - **page_size**: 每页大小(默认20,最大100) """ try: # 限制每页大小 page_size = min(page_size, 100) # 验证图片类型 if image_type and image_type not in ["face", "cloth", "result", "original"]: raise HTTPException(status_code=400, detail="图片类型必须是 face、cloth、result 或 original") # 获取图片列表 images = db_ops.get_user_images(user_id, image_type, page, page_size) return { "images": images["images"], "total": images["total"], "page": images["page"], "page_size": images["page_size"], "total_pages": images["total_pages"] } except HTTPException: raise except Exception as e: logger.error(f"获取用户图片列表失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"获取图片列表失败: {str(e)}") @app.get("/api/v1/dashboard/stats", response_model=DashboardStatsResponse, tags=["仪表盘"]) async def get_dashboard_stats(user_id: int): """ 获取仪表盘统计数据 Args: user_id: 用户ID Returns: DashboardStatsResponse: 包含今日生成、待审核、已发布的数量 """ try: with db_ops.db_connection.get_session() as session: # 获取今天的开始时间(0点) today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) # 取该用户所有记录的关键字段,按组合(face+cloth)聚合 rows = ( session.query( ProcessRecord.face_image_id, ProcessRecord.cloth_image_id, ProcessRecord.status, ProcessRecord.completed_at, ProcessRecord.result_image_id, ) .filter(ProcessRecord.user_id == user_id) .all() ) # 组合 -> 最新记录(以 completed_at 最大为准) latest_by_group = {} # 今日生成的组合集合(任意一条记录今天完成即可计一次) groups_generated_today = set() for face_id, cloth_id, status, completed_at, result_image_id in rows: if completed_at and completed_at >= today_start and result_image_id is not None: groups_generated_today.add((face_id, cloth_id)) key = (face_id, cloth_id) prev = latest_by_group.get(key) if prev is None: latest_by_group[key] = (status, completed_at, result_image_id) else: _, prev_time, _ = prev # None 视为更早 if (prev_time or datetime.min) < (completed_at or datetime.min): latest_by_group[key] = (status, completed_at, result_image_id) # 以组合为单位统计 today_generated = len(groups_generated_today) pending_review = 0 published = 0 for (_face_id, _cloth_id), (status, _t, _rid) in latest_by_group.items(): if status == "待审核": pending_review += 1 elif status == "已发布": published += 1 return DashboardStatsResponse( today_generated=today_generated, pending_review=pending_review, published=published, ) except Exception as e: logger.error(f"获取仪表盘统计数据失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"获取仪表盘统计数据失败: {str(e)}") @app.get("/api/v1/statistics", tags=["系统统计"]) async def get_system_statistics(): """ 获取系统统计信息 """ try: stats = db_ops.get_statistics() return { "success": True, "statistics": stats } except Exception as e: logger.error(f"获取系统统计失败: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=f"获取系统统计失败: {str(e)}") # ==================== 健康检查 ==================== @app.get("/health", tags=["系统"]) async def health_check(): """健康检查接口""" return {"status": "healthy", "service": "ai_swap_api"} @app.get("/") def root(): return {"msg": "AI换脸换装服务已启动"} # ==================== 静态文件直链下载(包含CORS响应头) ==================== @app.get("/api/v1/output/{filename}") async def get_output_file(filename: str): """为结果图片提供带CORS头的直链下载""" file_path = os.path.join(output_dir, filename) if not os.path.isfile(file_path): raise HTTPException(status_code=404, detail="文件不存在") headers = { "Access-Control-Allow-Origin": "*", "Access-Control-Expose-Headers": "Content-Disposition", } return FileResponse(file_path, headers=headers, filename=filename) @app.get("/api/v1/materials/{filename}") async def get_material_file(filename: str): """为素材图片提供带CORS头的直链下载""" file_path = os.path.join(materials_dir, filename) if not os.path.isfile(file_path): raise HTTPException(status_code=404, detail="文件不存在") headers = { "Access-Control-Allow-Origin": "*", "Access-Control-Expose-Headers": "Content-Disposition", } return FileResponse(file_path, headers=headers, filename=filename) # ==================== 错误处理 ==================== @app.exception_handler(HTTPException) async def http_exception_handler(request, exc): """HTTP异常处理器""" return JSONResponse( status_code=exc.status_code, content={ "success": False, "error": exc.detail, "error_type": "HTTPException" } ) @app.exception_handler(Exception) async def general_exception_handler(request, exc): """通用异常处理器""" logger.error(f"未处理的异常: {str(exc)}", exc_info=True) return JSONResponse( status_code=500, content={ "success": False, "error": "服务器内部错误", "error_type": "InternalServerError" } ) # ==================== 启动配置 ==================== if __name__ == "__main__": """ 启动API服务器 """ host = "0.0.0.0" port = 8002 reload = True print(f"启动AI Swap API服务器...") print(f"服务器地址: http://{host}:{port}") print(f"API文档: http://{host}:{port}/docs") print(f"自动重载: {reload}") uvicorn.run( "backend.api.ai_swap_api:app", host=host, port=port, reload=reload, log_level="info" )