ai_swap_api.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841
  1. """
  2. AI换脸换装API接口
  3. 提供RESTful API接口,支持换脸换装、历史记录查询等功能
  4. """
  5. from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Body
  6. from fastapi.middleware.cors import CORSMiddleware
  7. from fastapi.responses import JSONResponse
  8. from starlette.responses import FileResponse
  9. from fastapi.staticfiles import StaticFiles
  10. import os
  11. from pydantic import BaseModel, Field
  12. from typing import Optional, List, Dict, Any, Union
  13. import uvicorn
  14. from datetime import datetime, timedelta
  15. from pydantic import BaseModel
  16. from sqlalchemy import func, and_
  17. from backend.services.ai_swap_service import AISwapService, process_swap_with_record
  18. from backend.modules.database.operations import DatabaseOperations
  19. from backend.utils.logger_config import setup_logger
  20. from backend.api.auth_api import router as auth_router
  21. from backend.api.user_material_api import router as user_material_router
  22. from backend.api.user_template_api import router as user_template_router
  23. from backend.modules.database.models import ProcessRecord
  24. from backend.api.auto_post_api import router as auto_post_router
  25. from backend.api.ai_swap_bg_api import router as ai_swap_bg_router
  26. from backend.api.ai_swap_cloth_api import router as ai_swap_cloth_router
  27. from backend.api.ai_swap_face_api import router as ai_swap_face_router
  28. from backend.api.ai_gen_video_api import router as ai_gen_video_router
  29. logger = setup_logger(__name__)
  30. # 创建FastAPI应用
  31. app = FastAPI(
  32. title="AI换脸换装服务",
  33. description="包含用户认证与管理、AI换脸换装等API",
  34. version="1.0.0"
  35. )
  36. # 挂载静态目录:使用与当前文件相对的绝对路径,避免工作目录变化导致路径失效
  37. current_dir = os.path.dirname(os.path.abspath(__file__))
  38. project_root = os.path.dirname(os.path.dirname(current_dir))
  39. materials_dir = os.path.abspath(os.path.join(project_root, 'backend', 'data', 'materials'))
  40. output_dir = os.path.abspath(os.path.join(project_root, 'backend', 'output'))
  41. # 确保目录存在(StaticFiles 需要已存在的目录)
  42. os.makedirs(materials_dir, exist_ok=True)
  43. os.makedirs(output_dir, exist_ok=True)
  44. allowed_origins = [
  45. "http://localhost:5173",
  46. "http://127.0.0.1:5173",
  47. "http://localhost:3000",
  48. "http://127.0.0.1:3000",
  49. "http://10.41.175.254:5173",
  50. "http://10.41.175.254:3000",
  51. # 局域网访问(当前机器IP)
  52. "http://10.41.175.43:5173",
  53. "http://10.41.175.43:3000",
  54. # 局域网访问(域名访问)
  55. "http://ai-swap.local:5173",
  56. ]
  57. # 为静态资源单独包裹 CORS 中间件,确保跨域下载(fetch blob)也包含 CORS 头
  58. materials_app = StaticFiles(directory=materials_dir)
  59. materials_app = CORSMiddleware(
  60. app=materials_app,
  61. allow_origins=["*"], # 静态资源放宽到任意来源,避免浏览器下载报CORS
  62. allow_credentials=False,
  63. allow_methods=["*"],
  64. allow_headers=["*"],
  65. )
  66. app.mount("/materials", materials_app, name="materials")
  67. output_app = StaticFiles(directory=output_dir)
  68. output_app = CORSMiddleware(
  69. app=output_app,
  70. allow_origins=["*"], # 静态资源放宽到任意来源,避免浏览器下载报CORS
  71. allow_credentials=False,
  72. allow_methods=["*"],
  73. allow_headers=["*"],
  74. )
  75. app.mount("/output", output_app, name="output")
  76. # 添加CORS中间件
  77. app.add_middleware(
  78. CORSMiddleware,
  79. allow_origins=allowed_origins,
  80. allow_credentials=True,
  81. allow_methods=["*"],
  82. allow_headers=["*"],
  83. )
  84. # 注册认证API
  85. app.include_router(auth_router, prefix="/auth", tags=["auth"])
  86. app.include_router(user_material_router)
  87. app.include_router(auto_post_router, prefix="/auto_post")
  88. app.include_router(user_template_router)
  89. app.include_router(ai_swap_bg_router)
  90. app.include_router(ai_swap_cloth_router)
  91. app.include_router(ai_swap_face_router)
  92. app.include_router(ai_gen_video_router)
  93. # 创建服务实例
  94. ai_swap_service = AISwapService()
  95. db_ops = DatabaseOperations()
  96. # 启动任务队列服务
  97. from backend.services.task_queue_service import get_task_queue_service
  98. task_queue_service = get_task_queue_service()
  99. task_queue_service.start()
  100. # ==================== 数据模型 ====================
  101. class SwapRequest(BaseModel):
  102. """换脸换装请求模型"""
  103. user_id: int = Field(..., description="用户ID")
  104. face_image_id: int = Field(..., description="人脸图片ID")
  105. cloth_image_id: int = Field(..., description="服装图片ID")
  106. prompt: str = Field(..., description="提示词", max_length=500)
  107. face_prompt: Optional[str] = Field(None, description="人脸识别提示词")
  108. cloth_prompt: Optional[str] = Field(None, description="服装识别提示词")
  109. prompt_prompt: Optional[str] = Field(None, description="提示词优化")
  110. quantity: Optional[int] = Field(1, ge=1, le=10, description="每组生成数量")
  111. class SwapResponse(BaseModel):
  112. """换脸换装响应模型"""
  113. success: bool
  114. task_id: Optional[str] = None
  115. process_record_id: Optional[int] = None
  116. result_image_id: Optional[int] = None
  117. copywriter_text: Optional[str] = None
  118. history_prompt: Optional[str] = None
  119. error: Optional[str] = None
  120. error_type: Optional[str] = None
  121. class TaskStatusResponse(BaseModel):
  122. """任务状态响应模型"""
  123. task_id: str
  124. status: str
  125. progress: int
  126. created_at: str
  127. started_at: Optional[str] = None
  128. completed_at: Optional[str] = None
  129. result: Optional[Dict[str, Any]] = None
  130. error: Optional[str] = None
  131. class UserTasksResponse(BaseModel):
  132. """用户任务列表响应模型"""
  133. tasks: List[Dict[str, Any]]
  134. total: int
  135. class ProcessHistoryResponse(BaseModel):
  136. """处理历史响应模型"""
  137. records: List[Dict[str, Any]]
  138. total: int
  139. page: int
  140. page_size: int
  141. total_pages: int
  142. class ProcessDetailResponse(BaseModel):
  143. """处理详情响应模型"""
  144. process_record: Dict[str, Any]
  145. face_image: Optional[Dict[str, Any]] = None
  146. cloth_image: Optional[Dict[str, Any]] = None
  147. result_image: Optional[Dict[str, Any]] = None
  148. class UpdateTextRequest(BaseModel):
  149. """更新文本请求模型"""
  150. user_id: int = Field(..., description="用户ID")
  151. title: str = Field(..., description="标题")
  152. content: str = Field(..., description="文案内容")
  153. label: str = Field(..., description="话题标签")
  154. class Config:
  155. json_schema_extra = {
  156. "example": {
  157. "user_id": 1,
  158. "title": "示例标题",
  159. "content": "示例文案内容",
  160. "label": "示例标签"
  161. }
  162. }
  163. class ApproveRecordRequest(BaseModel):
  164. """审核记录请求模型"""
  165. user_id: int = Field(..., description="用户ID")
  166. class Config:
  167. json_schema_extra = {
  168. "example": {
  169. "user_id": 1
  170. }
  171. }
  172. class ApproveRecordResponse(BaseModel):
  173. """审核记录响应模型"""
  174. success: bool
  175. message: str
  176. process_record: Optional[Dict[str, Any]] = None
  177. class DeleteResultImageRequest(BaseModel):
  178. """删除结果图片请求模型"""
  179. user_id: int = Field(..., description="用户ID")
  180. class Config:
  181. json_schema_extra = {
  182. "example": {
  183. "user_id": 1
  184. }
  185. }
  186. class DeleteResultImageResponse(BaseModel):
  187. """删除结果图片响应模型"""
  188. success: bool
  189. message: str
  190. class DashboardStatsResponse(BaseModel):
  191. """仪表盘统计数据响应模型"""
  192. today_generated: int = Field(..., description="今日生成数量")
  193. pending_review: int = Field(..., description="待审核数量")
  194. published: int = Field(..., description="已发布数量")
  195. # ==================== 中间件和依赖 ====================
  196. async def verify_user(user_id: int):
  197. """验证用户是否存在且有效"""
  198. user = db_ops.get_user_by_id(user_id)
  199. if not user:
  200. raise HTTPException(status_code=404, detail=f"用户ID {user_id} 不存在")
  201. if not user.get("is_active", False):
  202. raise HTTPException(status_code=403, detail="用户账户已被禁用")
  203. return user
  204. # ==================== API接口 ====================
  205. @app.post("/api/v1/swap", response_model=SwapResponse, tags=["AI换脸换装"])
  206. async def process_swap(request: SwapRequest):
  207. """
  208. 执行AI换脸换装(异步任务)
  209. - **user_id**: 用户ID
  210. - **face_image_id**: 人脸图片ID
  211. - **cloth_image_id**: 服装图片ID
  212. - **prompt**: 提示词
  213. - **face_prompt**: 人脸识别提示词(可选)
  214. - **cloth_prompt**: 服装识别提示词(可选)
  215. - **prompt_prompt**: 提示词优化(可选)
  216. """
  217. try:
  218. # 验证用户
  219. await verify_user(request.user_id)
  220. # 提交任务到队列(非阻塞)
  221. task_id = ai_swap_service.submit_swap_task(
  222. user_id=request.user_id,
  223. face_image_id=request.face_image_id,
  224. cloth_image_id=request.cloth_image_id,
  225. prompt=request.prompt,
  226. face_prompt=request.face_prompt,
  227. cloth_prompt=request.cloth_prompt,
  228. prompt_prompt=request.prompt_prompt,
  229. quantity=request.quantity or 1
  230. )
  231. logger.info(f"用户 {request.user_id} 换脸换装任务已提交,任务ID: {task_id}")
  232. return SwapResponse(
  233. success=True,
  234. task_id=task_id,
  235. process_record_id=None,
  236. result_image_id=None,
  237. copywriter_text=None,
  238. history_prompt=None
  239. )
  240. except HTTPException:
  241. raise
  242. except Exception as e:
  243. logger.error(f"API处理异常: {str(e)}", exc_info=True)
  244. raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
  245. @app.get("/api/v1/tasks/{task_id}", response_model=TaskStatusResponse, tags=["任务管理"])
  246. async def get_task_status(task_id: str):
  247. """
  248. 获取任务状态
  249. - **task_id**: 任务ID
  250. """
  251. try:
  252. task_info = ai_swap_service.get_task_status(task_id)
  253. if not task_info:
  254. raise HTTPException(status_code=404, detail="任务不存在")
  255. return TaskStatusResponse(
  256. task_id=task_info["id"],
  257. status=task_info["status"].value,
  258. progress=task_info["progress"],
  259. created_at=task_info["created_at"].isoformat(),
  260. started_at=task_info["started_at"].isoformat() if task_info["started_at"] else None,
  261. completed_at=task_info["completed_at"].isoformat() if task_info["completed_at"] else None,
  262. result=task_info["result"],
  263. error=task_info["error"]
  264. )
  265. except HTTPException:
  266. raise
  267. except Exception as e:
  268. logger.error(f"获取任务状态异常: {str(e)}", exc_info=True)
  269. raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
  270. @app.get("/api/v1/users/{user_id}/tasks", response_model=UserTasksResponse, tags=["任务管理"])
  271. async def get_user_tasks(user_id: int):
  272. """
  273. 获取用户的所有任务
  274. - **user_id**: 用户ID
  275. """
  276. try:
  277. # 验证用户
  278. await verify_user(user_id)
  279. tasks = ai_swap_service.get_user_tasks(user_id)
  280. return UserTasksResponse(
  281. tasks=tasks,
  282. total=len(tasks)
  283. )
  284. except HTTPException:
  285. raise
  286. except Exception as e:
  287. logger.error(f"获取用户任务异常: {str(e)}", exc_info=True)
  288. raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
  289. @app.delete("/api/v1/tasks/{task_id}", tags=["任务管理"])
  290. async def cancel_task(task_id: str):
  291. """
  292. 取消任务
  293. - **task_id**: 任务ID
  294. """
  295. try:
  296. success = ai_swap_service.cancel_task(task_id)
  297. if not success:
  298. raise HTTPException(status_code=404, detail="任务不存在或无法取消")
  299. return {"success": True, "message": "任务已取消"}
  300. except HTTPException:
  301. raise
  302. except Exception as e:
  303. logger.error(f"取消任务异常: {str(e)}", exc_info=True)
  304. raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
  305. @app.get("/api/v1/tasks/queue/stats", tags=["任务管理"])
  306. async def get_queue_stats():
  307. """
  308. 获取任务队列统计信息
  309. """
  310. try:
  311. stats = ai_swap_service.task_queue.get_queue_stats()
  312. return stats
  313. except Exception as e:
  314. logger.error(f"获取队列统计异常: {str(e)}", exc_info=True)
  315. raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}")
  316. @app.get("/api/v1/users/{user_id}/process-history", response_model=ProcessHistoryResponse, tags=["历史记录"])
  317. async def get_user_process_history(
  318. user_id: int,
  319. page: int = 1,
  320. page_size: int = 20,
  321. user: Dict[str, Any] = Depends(verify_user)
  322. ):
  323. """
  324. 获取用户的处理历史记录
  325. - **user_id**: 用户ID
  326. - **page**: 页码(默认1)
  327. - **page_size**: 每页大小(默认20,最大100)
  328. """
  329. try:
  330. # 限制每页大小
  331. page_size = min(page_size, 100)
  332. # 获取历史记录
  333. history = ai_swap_service.get_user_process_history(user_id, page, page_size)
  334. return ProcessHistoryResponse(
  335. records=history["records"],
  336. total=history["total"],
  337. page=history["page"],
  338. page_size=history["page_size"],
  339. total_pages=history["total_pages"]
  340. )
  341. except Exception as e:
  342. logger.error(f"获取用户历史记录失败: {str(e)}", exc_info=True)
  343. raise HTTPException(status_code=500, detail=f"获取历史记录失败: {str(e)}")
  344. @app.get("/api/v1/process/{process_id}", response_model=ProcessDetailResponse, tags=["历史记录"])
  345. async def get_process_detail(
  346. process_id: int,
  347. user_id: Optional[int] = None
  348. ):
  349. """
  350. 获取处理记录详情
  351. - **process_id**: 处理记录ID
  352. - **user_id**: 用户ID(可选,用于权限验证)
  353. """
  354. try:
  355. # 如果提供了用户ID,先验证用户
  356. if user_id:
  357. await verify_user(user_id)
  358. # 获取处理详情
  359. detail = ai_swap_service.get_process_detail(process_id, user_id)
  360. if not detail:
  361. raise HTTPException(status_code=404, detail="处理记录不存在或无权限访问")
  362. return ProcessDetailResponse(
  363. process_record=detail["process_record"],
  364. face_image=detail["face_image"],
  365. cloth_image=detail["cloth_image"],
  366. result_image=detail["result_image"]
  367. )
  368. except HTTPException:
  369. raise
  370. except Exception as e:
  371. logger.error(f"获取处理详情失败: {str(e)}", exc_info=True)
  372. raise HTTPException(status_code=500, detail=f"获取处理详情失败: {str(e)}")
  373. @app.put("/api/v1/process/{process_id}/text", response_model=ProcessDetailResponse, tags=["历史记录"])
  374. async def update_process_text(
  375. process_id: int,
  376. request: UpdateTextRequest = Body(...),
  377. ):
  378. """
  379. 更新处理记录的文本内容
  380. Args:
  381. process_id: 处理记录ID
  382. request: 更新文本请求数据
  383. Returns:
  384. ProcessDetailResponse: 更新后的处理记录详情
  385. """
  386. try:
  387. # 验证用户
  388. user = await verify_user(request.user_id)
  389. # 获取处理详情
  390. detail = ai_swap_service.get_process_detail(process_id, request.user_id)
  391. if not detail:
  392. raise HTTPException(status_code=404, detail="处理记录不存在")
  393. # 验证权限
  394. if detail["process_record"]["user_id"] != request.user_id:
  395. raise HTTPException(status_code=403, detail="无权限修改此记录")
  396. # 更新文本内容
  397. generated_text = f"- 标题:{request.title}\n- 文案:{request.content}\n- 话题标签:{request.label}"
  398. # 更新数据库记录
  399. update_data = {
  400. "generated_text": generated_text
  401. }
  402. updated_record = db_ops.update_process_record(process_id, update_data)
  403. if not updated_record:
  404. raise HTTPException(status_code=500, detail="更新数据库失败")
  405. # 重新获取更新后的记录
  406. detail = ai_swap_service.get_process_detail(process_id, request.user_id)
  407. if not detail:
  408. raise HTTPException(status_code=500, detail="获取更新后的记录失败")
  409. return ProcessDetailResponse(
  410. process_record=detail["process_record"],
  411. face_image=detail["face_image"],
  412. cloth_image=detail["cloth_image"],
  413. result_image=detail["result_image"]
  414. )
  415. except HTTPException:
  416. raise
  417. except Exception as e:
  418. logger.error(f"更新处理记录文本失败: {str(e)}", exc_info=True)
  419. raise HTTPException(status_code=500, detail=f"更新处理记录文本失败: {str(e)}")
  420. @app.post("/api/v1/process/{process_id}/approve", response_model=ApproveRecordResponse, tags=["历史记录"])
  421. async def approve_process_record(
  422. process_id: int,
  423. request: ApproveRecordRequest = Body(...)
  424. ):
  425. """
  426. 审核处理记录
  427. Args:
  428. process_id: 处理记录ID
  429. request: 审核请求数据
  430. Returns:
  431. ApproveRecordResponse: 审核结果
  432. """
  433. try:
  434. # 验证用户
  435. await verify_user(request.user_id)
  436. # 获取处理详情
  437. detail = ai_swap_service.get_process_detail(process_id, request.user_id)
  438. if not detail:
  439. raise HTTPException(status_code=404, detail="处理记录不存在")
  440. # 验证权限
  441. if detail["process_record"]["user_id"] != request.user_id:
  442. raise HTTPException(status_code=403, detail="无权限审核此记录")
  443. # 审核通过
  444. success = ai_swap_service.approve_process_record(process_id)
  445. if not success:
  446. raise HTTPException(status_code=500, detail="审核失败")
  447. # 重新获取审核后的记录
  448. detail = ai_swap_service.get_process_detail(process_id, request.user_id)
  449. if not detail:
  450. raise HTTPException(status_code=500, detail="获取审核后的记录失败")
  451. return ApproveRecordResponse(
  452. success=True,
  453. message="处理记录已审核通过",
  454. process_record=detail["process_record"]
  455. )
  456. except HTTPException:
  457. raise
  458. except Exception as e:
  459. logger.error(f"审核处理记录失败: {str(e)}", exc_info=True)
  460. raise HTTPException(status_code=500, detail=f"审核处理记录失败: {str(e)}")
  461. @app.delete("/api/v1/process/{process_id}/result-image", response_model=DeleteResultImageResponse, tags=["历史记录"])
  462. async def delete_result_image(
  463. process_id: int,
  464. request: DeleteResultImageRequest = Body(...)
  465. ):
  466. """
  467. 删除处理记录的结果图片
  468. Args:
  469. process_id: 处理记录ID
  470. request: 删除请求数据
  471. Returns:
  472. DeleteResultImageResponse: 删除结果
  473. """
  474. try:
  475. # 验证用户
  476. await verify_user(request.user_id)
  477. # 获取处理详情
  478. detail = ai_swap_service.get_process_detail(process_id, request.user_id)
  479. if not detail:
  480. raise HTTPException(status_code=404, detail="处理记录不存在")
  481. # 验证权限
  482. if detail["process_record"]["user_id"] != request.user_id:
  483. raise HTTPException(status_code=403, detail="无权限删除此图片")
  484. # 删除图片
  485. success = ai_swap_service.delete_result_image(process_id)
  486. if not success:
  487. raise HTTPException(status_code=500, detail="删除图片失败")
  488. return DeleteResultImageResponse(
  489. success=True,
  490. message="结果图片已删除"
  491. )
  492. except HTTPException:
  493. raise
  494. except Exception as e:
  495. logger.error(f"删除结果图片失败: {str(e)}", exc_info=True)
  496. raise HTTPException(status_code=500, detail=f"删除结果图片失败: {str(e)}")
  497. @app.get("/api/v1/users/{user_id}/images", tags=["图片管理"])
  498. async def get_user_images(
  499. user_id: int,
  500. image_type: Optional[str] = None,
  501. page: int = 1,
  502. page_size: int = 20,
  503. user: Dict[str, Any] = Depends(verify_user)
  504. ):
  505. """
  506. 获取用户的图片列表
  507. - **user_id**: 用户ID
  508. - **image_type**: 图片类型(face/cloth/result,可选)
  509. - **page**: 页码(默认1)
  510. - **page_size**: 每页大小(默认20,最大100)
  511. """
  512. try:
  513. # 限制每页大小
  514. page_size = min(page_size, 100)
  515. # 验证图片类型
  516. if image_type and image_type not in ["face", "cloth", "result", "original"]:
  517. raise HTTPException(status_code=400, detail="图片类型必须是 face、cloth、result 或 original")
  518. # 获取图片列表
  519. images = db_ops.get_user_images(user_id, image_type, page, page_size)
  520. return {
  521. "images": images["images"],
  522. "total": images["total"],
  523. "page": images["page"],
  524. "page_size": images["page_size"],
  525. "total_pages": images["total_pages"]
  526. }
  527. except HTTPException:
  528. raise
  529. except Exception as e:
  530. logger.error(f"获取用户图片列表失败: {str(e)}", exc_info=True)
  531. raise HTTPException(status_code=500, detail=f"获取图片列表失败: {str(e)}")
  532. @app.get("/api/v1/dashboard/stats", response_model=DashboardStatsResponse, tags=["仪表盘"])
  533. async def get_dashboard_stats(user_id: int):
  534. """
  535. 获取仪表盘统计数据
  536. Args:
  537. user_id: 用户ID
  538. Returns:
  539. DashboardStatsResponse: 包含今日生成、待审核、已发布的数量
  540. """
  541. try:
  542. with db_ops.db_connection.get_session() as session:
  543. # 获取今天的开始时间(0点)
  544. today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
  545. # 取该用户所有记录的关键字段,按组合(face+cloth)聚合
  546. rows = (
  547. session.query(
  548. ProcessRecord.face_image_id,
  549. ProcessRecord.cloth_image_id,
  550. ProcessRecord.status,
  551. ProcessRecord.completed_at,
  552. ProcessRecord.result_image_id,
  553. )
  554. .filter(ProcessRecord.user_id == user_id)
  555. .all()
  556. )
  557. # 组合 -> 最新记录(以 completed_at 最大为准)
  558. latest_by_group = {}
  559. # 今日生成的组合集合(任意一条记录今天完成即可计一次)
  560. groups_generated_today = set()
  561. for face_id, cloth_id, status, completed_at, result_image_id in rows:
  562. if completed_at and completed_at >= today_start and result_image_id is not None:
  563. groups_generated_today.add((face_id, cloth_id))
  564. key = (face_id, cloth_id)
  565. prev = latest_by_group.get(key)
  566. if prev is None:
  567. latest_by_group[key] = (status, completed_at, result_image_id)
  568. else:
  569. _, prev_time, _ = prev
  570. # None 视为更早
  571. if (prev_time or datetime.min) < (completed_at or datetime.min):
  572. latest_by_group[key] = (status, completed_at, result_image_id)
  573. # 以组合为单位统计
  574. today_generated = len(groups_generated_today)
  575. pending_review = 0
  576. published = 0
  577. for (_face_id, _cloth_id), (status, _t, _rid) in latest_by_group.items():
  578. if status == "待审核":
  579. pending_review += 1
  580. elif status == "已发布":
  581. published += 1
  582. return DashboardStatsResponse(
  583. today_generated=today_generated,
  584. pending_review=pending_review,
  585. published=published,
  586. )
  587. except Exception as e:
  588. logger.error(f"获取仪表盘统计数据失败: {str(e)}", exc_info=True)
  589. raise HTTPException(status_code=500, detail=f"获取仪表盘统计数据失败: {str(e)}")
  590. @app.get("/api/v1/statistics", tags=["系统统计"])
  591. async def get_system_statistics():
  592. """
  593. 获取系统统计信息
  594. """
  595. try:
  596. stats = db_ops.get_statistics()
  597. return {
  598. "success": True,
  599. "statistics": stats
  600. }
  601. except Exception as e:
  602. logger.error(f"获取系统统计失败: {str(e)}", exc_info=True)
  603. raise HTTPException(status_code=500, detail=f"获取系统统计失败: {str(e)}")
  604. # ==================== 健康检查 ====================
  605. @app.get("/health", tags=["系统"])
  606. async def health_check():
  607. """健康检查接口"""
  608. return {"status": "healthy", "service": "ai_swap_api"}
  609. @app.get("/")
  610. def root():
  611. return {"msg": "AI换脸换装服务已启动"}
  612. # ==================== 静态文件直链下载(包含CORS响应头) ====================
  613. @app.get("/api/v1/output/{filename}")
  614. async def get_output_file(filename: str):
  615. """为结果图片提供带CORS头的直链下载"""
  616. file_path = os.path.join(output_dir, filename)
  617. if not os.path.isfile(file_path):
  618. raise HTTPException(status_code=404, detail="文件不存在")
  619. headers = {
  620. "Access-Control-Allow-Origin": "*",
  621. "Access-Control-Expose-Headers": "Content-Disposition",
  622. }
  623. return FileResponse(file_path, headers=headers, filename=filename)
  624. @app.get("/api/v1/materials/{filename}")
  625. async def get_material_file(filename: str):
  626. """为素材图片提供带CORS头的直链下载"""
  627. file_path = os.path.join(materials_dir, filename)
  628. if not os.path.isfile(file_path):
  629. raise HTTPException(status_code=404, detail="文件不存在")
  630. headers = {
  631. "Access-Control-Allow-Origin": "*",
  632. "Access-Control-Expose-Headers": "Content-Disposition",
  633. }
  634. return FileResponse(file_path, headers=headers, filename=filename)
  635. # ==================== 错误处理 ====================
  636. @app.exception_handler(HTTPException)
  637. async def http_exception_handler(request, exc):
  638. """HTTP异常处理器"""
  639. return JSONResponse(
  640. status_code=exc.status_code,
  641. content={
  642. "success": False,
  643. "error": exc.detail,
  644. "error_type": "HTTPException"
  645. }
  646. )
  647. @app.exception_handler(Exception)
  648. async def general_exception_handler(request, exc):
  649. """通用异常处理器"""
  650. logger.error(f"未处理的异常: {str(exc)}", exc_info=True)
  651. return JSONResponse(
  652. status_code=500,
  653. content={
  654. "success": False,
  655. "error": "服务器内部错误",
  656. "error_type": "InternalServerError"
  657. }
  658. )
  659. # ==================== 启动配置 ====================
  660. if __name__ == "__main__":
  661. """
  662. 启动API服务器
  663. """
  664. host = "0.0.0.0"
  665. port = 8002
  666. reload = True
  667. print(f"启动AI Swap API服务器...")
  668. print(f"服务器地址: http://{host}:{port}")
  669. print(f"API文档: http://{host}:{port}/docs")
  670. print(f"自动重载: {reload}")
  671. uvicorn.run(
  672. "backend.api.ai_swap_api:app",
  673. host=host,
  674. port=port,
  675. reload=reload,
  676. log_level="info"
  677. )