| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- import os
- import json
- from typing import List, Optional, Dict, Any
- from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
- from pydantic import BaseModel, Field, validator
- from datetime import datetime
- import time
- import threading
- import sys
# Resolve this file's directory and its parent directory, which is treated as
# the backend root.
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.dirname(current_dir)
# Prepend the backend root to sys.path (only once) so the absolute imports that
# follow (services.*, modules.*, utils.*) resolve; presumably this lets the file
# work regardless of the current working directory.
if backend_dir not in sys.path:
    sys.path.insert(0, backend_dir)
- from services.auto_post_service import auto_post_service
- from modules.database.operations import DatabaseOperations
- from utils.logger_config import setup_logger
logger = setup_logger(__name__)
router = APIRouter()
# Thread lock that prevents concurrent conflicts: both the single-post worker and
# the batch worker hold it around auto_post_service.post_to_xiaohongshu, so at
# most one post (browser automation run) executes at a time in this process.
post_lock = threading.Lock()
# Database access; used by the batch worker to mark processing records as published.
db_ops = DatabaseOperations()
# ============== Task status registry ==============
# Simple in-memory task registry: valid only within this process and emptied on
# restart. Maps task_id -> mutable status dict. The dict itself is guarded by
# _task_registry_lock; the status dicts it holds are mutated in place by the
# batch worker without taking that lock.
_task_registry_lock = threading.Lock()
_task_registry: Dict[str, Dict[str, Any]] = {}
class AutoPostTaskItemStatus(BaseModel):
    """Public status of one item (a single post) inside a batch task."""

    index: int  # position of the item within the batch, starting at 0
    schedule_time: Optional[str] = None  # scheduled publish time as sent by the client ("YYYY-MM-DD HH:MM")
    status: str = "pending"  # pending | running | success | failure
    error: Optional[str] = None  # exception message when status is "failure"
    started_at: Optional[str] = None  # ISO-8601 timestamp when the item started running
    completed_at: Optional[str] = None  # ISO-8601 timestamp when the item finished
class AutoPostTaskStatus(BaseModel):
    """Public status snapshot of a batch post task, returned by GET /tasks/{task_id}/status."""

    task_id: str
    name: str  # display name supplied when the task was created
    platform: str  # platform name supplied when the task was created
    frequency: Optional[str] = None  # echoed from the create request
    scheduled_times: List[str] = []  # echoed from the create request
    total: int  # number of items in the batch
    success: int  # items posted successfully so far
    failure: int  # items whose post raised so far
    next_run: Optional[str] = None  # schedule_time of the next item to run (None after the last one)
    last_run: Optional[str] = None  # schedule_time (not wall-clock time) of the most recently finished item
    is_completed: bool = False  # True once the worker has finished (or aborted) the batch
    created_at: str  # ISO-8601 creation timestamp
    started_at: Optional[str] = None  # ISO-8601 time the worker started
    completed_at: Optional[str] = None  # ISO-8601 time the worker finished
    items: List[AutoPostTaskItemStatus] = []  # per-item status, in batch order
class BatchPostItem(BaseModel):
    """One post inside a batch request; the fields mirror XiaohongshuPostRequest."""

    image_paths: List[str]  # server-side file paths; each must exist or the request is rejected with 400
    title: str
    description: str
    topics: List[str]
    schedule_time: Optional[str] = None  # must match "YYYY-MM-DD HH:MM" when provided
    if_clean: bool = False
    # IDs of the processing records associated with this post item (the split
    # records within the same combination). After a successful post each one is
    # updated to status "已发布" (published).
    record_ids: List[int] = []
class BatchPostRequest(BaseModel):
    """Request body for creating a batch post task (POST /tasks/batch)."""

    # Display name of the task; defaults to "定时发布" ("scheduled post").
    name: str = Field("定时发布")
    # Target platform. NOTE(review): it is only stored and reported; the batch
    # worker always calls post_to_xiaohongshu regardless of this value.
    platform: str = Field(..., description="发布平台,如 xiaohongshu")
    # Stored on the task status unchanged; not used for scheduling here.
    frequency: Optional[str] = None
    # Stored on the task status unchanged; execution uses each item's schedule_time.
    scheduled_times: List[str] = []
    # The individual posts, executed sequentially in list order.
    tasks: List[BatchPostItem]
    # ID of the user who created the task, used for task isolation: GET /tasks
    # filters by it. NOTE(review): the per-task status endpoint does not check it.
    user_id: Optional[int] = None
def _save_task_status(task_id: str, status: Dict[str, Any]) -> None:
    """Store (or overwrite) the status dict for ``task_id`` in the in-memory registry.

    The dict object itself is stored, not a copy, so later in-place mutations
    by the caller are visible to readers of the registry.
    """
    with _task_registry_lock:
        _task_registry.update({task_id: status})
def _get_task_status(task_id: str) -> Optional[Dict[str, Any]]:
    """Return the live status dict registered under ``task_id``, or None if unknown."""
    with _task_registry_lock:
        found = _task_registry.get(task_id, None)
    return found
def _list_tasks(user_id: Optional[int] = None) -> List[Dict[str, Any]]:
    """List registered task status dicts, newest ``created_at`` first.

    Args:
        user_id: when given, only tasks whose ``user_id`` equals it are returned.
    """
    with _task_registry_lock:
        snapshot = list(_task_registry.values())
    if user_id is None:
        selected = snapshot
    else:
        selected = [entry for entry in snapshot if entry.get("user_id") == user_id]
    # Tasks without created_at sort as "" and therefore end up last.
    return sorted(selected, key=lambda entry: entry.get("created_at", ""), reverse=True)
class XiaohongshuPostRequest(BaseModel):
    """Request body for publishing a single post to Xiaohongshu."""
    image_paths: List[str]  # server-side file paths; each must exist or the endpoint returns 400
    title: str
    description: str
    topics: List[str]  # topic tags
    schedule_time: Optional[str] = None  # must match "YYYY-MM-DD HH:MM" when provided
    if_clean: bool = False  # whether to clear the browser cache (per the endpoint docs)
class PostResponse(BaseModel):
    """API response body returned by the single-post endpoint."""
    success: bool
    message: str
    task_id: Optional[str] = None  # identifier generated for the request
    timestamp: str  # ISO-8601 time the response was produced
class CreateBatchTaskResponse(BaseModel):
    """Response body of POST /tasks/batch."""

    success: bool
    message: str
    task_id: str  # key for polling GET /tasks/{task_id}/status
    total_tasks: int  # number of items in the created batch
    timestamp: str  # ISO-8601 creation time of the task (same value as its created_at)
def _run_post_task(req: XiaohongshuPostRequest) -> None:
    """Publish one Xiaohongshu post; meant to run on a background thread.

    Running the blocking post off the event loop keeps the API responsive.
    ``post_lock`` is held for the whole call so it never overlaps another
    post. Any error is logged with its traceback and swallowed, because no
    caller is waiting on this thread.
    """
    try:
        post_kwargs = {
            "image_paths": req.image_paths,
            "title": req.title,
            "description": req.description,
            "topics": req.topics,
            "schedule_time": req.schedule_time,
            "if_clean": req.if_clean,
        }
        with post_lock:
            auto_post_service.post_to_xiaohongshu(**post_kwargs)
        logger.info("Background post task finished successfully")
    except Exception as err:
        logger.exception(f"Background post task failed: {str(err)}")
@router.post("/xiaohongshu/post", response_model=PostResponse, tags=["发布管理"])
async def post_to_xiaohongshu(request: XiaohongshuPostRequest, background_tasks: BackgroundTasks):
    """
    Publish content to Xiaohongshu.
    - **image_paths**: list of image file paths (must exist on the server)
    - **title**: post title
    - **description**: post body text
    - **topics**: list of topic tags
    - **schedule_time**: scheduled publish time (format: YYYY-MM-DD HH:MM)
    - **if_clean**: whether to clear the browser cache (default False)

    The post runs on a daemon thread and this endpoint returns immediately.
    The returned task_id is informational only: it is not registered in the
    task registry, so it cannot be polled via /tasks/{task_id}/status.
    """
    import uuid  # local import keeps this block self-contained

    # Millisecond timestamp plus a random suffix. The previous second-resolution
    # ID was shared by every request arriving within the same second.
    task_id = f"task_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
    logger.info(f"Received new post task {task_id}")

    # Validate image paths up front so the client gets an immediate, explicit 400
    # that names the missing file (previously the detail was left empty).
    for path in request.image_paths:
        if not os.path.exists(path):
            logger.error(f"Image not found: {path}")
            raise HTTPException(
                status_code=400,
                detail=f"图片不存在: {path}",
            )

    # Validate the schedule time format.
    if request.schedule_time:
        try:
            datetime.strptime(request.schedule_time, "%Y-%m-%d %H:%M")
        except ValueError:
            logger.error(f"Invalid time format: {request.schedule_time}")
            raise HTTPException(
                status_code=400,
                detail="时间格式错误,请使用 YYYY-MM-DD HH:MM 格式"
            )

    # Run the long blocking post on a separate thread so the event loop is not
    # blocked. background_tasks is intentionally unused; it is kept so the
    # endpoint signature stays unchanged.
    worker = threading.Thread(target=_run_post_task, args=(request,), daemon=True)
    worker.start()
    return {
        "success": True,
        "message": "发布任务已开始执行",
        "task_id": task_id,
        "timestamp": datetime.now().isoformat(),
    }
- # ============== 批量任务接口:后端托管任务状态 ==============
def _validate_item_paths(item: BatchPostItem) -> None:
    """Reject a batch item whose image files do not all exist on the server.

    Raises:
        HTTPException: status 400. ``detail`` names the first missing path, so
            the client can tell which file is wrong. Previously the detail was
            empty and the client only saw a generic "Bad Request".
    """
    for path in item.image_paths:
        if not os.path.exists(path):
            logger.error(f"Image not found: {path}")
            raise HTTPException(status_code=400, detail=f"图片不存在: {path}")
def _mark_records_published(entry: Dict[str, Any]) -> None:
    """Best-effort: set every processing record linked to ``entry`` to status "已发布".

    Each record is updated independently. Failures are only logged as
    warnings and never propagate, so a database problem cannot turn a
    successful post into a failed item.
    """
    try:
        for pid in entry.get("record_ids", []) or []:
            try:
                db_ops.update_process_record(int(pid), {"status": "已发布"})
            except Exception as ue:
                logger.warning(f"更新处理记录状态失败 process_id={pid}: {ue}")
    except Exception as ge:
        logger.warning(f"批量更新组合关联记录状态失败: {ge}")


def _run_batch_post_task(task_id: str) -> None:
    """Execute all items of a registered batch task in order, recording progress.

    Runs on the worker thread started by create_batch_post_task. Each item is
    posted while holding ``post_lock``. A failing item is recorded and the
    batch continues with the next one. The status dict is updated in place
    and re-saved after every state change. If ``task_id`` is not registered,
    this returns without doing anything.
    """
    status = _get_task_status(task_id)
    if not status:
        return
    try:
        status["started_at"] = datetime.now().isoformat()
        _save_task_status(task_id, status)

        entries: List[Dict[str, Any]] = status["items"]
        last_position = len(entries) - 1
        for position, entry in enumerate(entries):
            # Schedule time of the following item (None for the final one); reported as next_run.
            upcoming = entries[position + 1].get("schedule_time") if position < last_position else None

            entry["status"] = "running"
            entry["started_at"] = datetime.now().isoformat()
            _save_task_status(task_id, status)

            try:
                # post_lock keeps the underlying browser automation exclusive.
                with post_lock:
                    auto_post_service.post_to_xiaohongshu(
                        image_paths=entry["image_paths"],
                        title=entry["title"],
                        description=entry["description"],
                        topics=entry["topics"],
                        schedule_time=entry.get("schedule_time"),
                        if_clean=entry.get("if_clean", False),
                    )
                entry["status"] = "success"
                entry["completed_at"] = datetime.now().isoformat()
                status["success"] += 1
                status["last_run"] = entry.get("schedule_time")
                status["next_run"] = upcoming
                _mark_records_published(entry)
            except Exception as e:
                logger.exception(f"Batch post item failed: {e}")
                entry["status"] = "failure"
                entry["error"] = str(e)
                entry["completed_at"] = datetime.now().isoformat()
                status["failure"] += 1
                status["last_run"] = entry.get("schedule_time")
                status["next_run"] = upcoming

            # Persist once per finished item.
            _save_task_status(task_id, status)

        status["is_completed"] = True
        status["completed_at"] = datetime.now().isoformat()
        _save_task_status(task_id, status)
    except Exception as e:
        # Unexpected error outside per-item handling: close the task out with the error.
        logger.exception(f"Batch task failed: {e}")
        status["is_completed"] = True
        status["completed_at"] = datetime.now().isoformat()
        status["error"] = str(e)
        _save_task_status(task_id, status)
@router.post("/tasks/batch", response_model=CreateBatchTaskResponse, tags=["发布管理"])
async def create_batch_post_task(request: BatchPostRequest):
    """Create a batch post task, register its status, and run it on a background thread.

    Every item is validated before anything is registered, so a rejected
    request leaves nothing in the registry. The returned ``task_id`` can be
    polled via GET /tasks/{task_id}/status.

    Raises:
        HTTPException: 400 if any image path is missing or any ``schedule_time``
            is not in ``YYYY-MM-DD HH:MM`` format.
    """
    import uuid  # local import keeps this block self-contained

    # Validate each item's image paths and time format.
    for item in request.tasks:
        _validate_item_paths(item)
        if item.schedule_time:
            try:
                datetime.strptime(item.schedule_time, "%Y-%m-%d %H:%M")
            except ValueError:
                # Logged for parity with the single-post endpoint.
                logger.error(f"Invalid time format: {item.schedule_time}")
                raise HTTPException(status_code=400, detail="时间格式错误,请使用 YYYY-MM-DD HH:MM 格式")

    # The registry is keyed by task_id, so two tasks sharing an ID would overwrite
    # each other's status entry. A millisecond timestamp alone collides for
    # requests arriving in the same millisecond; the random suffix prevents that.
    task_id = f"task_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"
    created_at = datetime.now().isoformat()

    # Per-item state: public status fields plus the raw inputs the worker needs.
    items_status: List[Dict[str, Any]] = [
        {
            "index": idx,
            "schedule_time": item.schedule_time,
            "status": "pending",
            "error": None,
            "started_at": None,
            "completed_at": None,
            # Raw data used for execution.
            "image_paths": item.image_paths,
            "title": item.title,
            "description": item.description,
            "topics": item.topics,
            "if_clean": item.if_clean,
            "record_ids": item.record_ids or [],
        }
        for idx, item in enumerate(request.tasks)
    ]
    status: Dict[str, Any] = {
        "task_id": task_id,
        "name": request.name,
        "platform": request.platform,
        "frequency": request.frequency,
        "scheduled_times": request.scheduled_times,
        "total": len(items_status),
        "success": 0,
        "failure": 0,
        "next_run": items_status[0].get("schedule_time") if items_status else None,
        "last_run": None,
        "is_completed": False,
        "created_at": created_at,
        "started_at": None,
        "completed_at": None,
        "items": items_status,
        "user_id": request.user_id,
    }
    _save_task_status(task_id, status)

    # Start the background worker that executes the items.
    worker = threading.Thread(target=_run_batch_post_task, args=(task_id,), daemon=True)
    worker.start()
    return CreateBatchTaskResponse(
        success=True,
        message="批量发布任务已创建",
        task_id=task_id,
        total_tasks=len(items_status),
        timestamp=created_at,
    )
@router.get("/tasks/{task_id}/status", response_model=AutoPostTaskStatus, tags=["发布管理"])
async def get_auto_post_task_status(task_id: str):
    """Return the public status of a batch task.

    Execution-only fields stored on each item (image paths, title, etc.) are
    stripped; only status-related fields are returned.
    NOTE(review): there is no user_id ownership check here.

    Raises:
        HTTPException: 404 if ``task_id`` is not in the in-memory registry.
    """
    record = _get_task_status(task_id)
    if not record:
        raise HTTPException(status_code=404, detail="任务不存在")

    public_item_keys = ("schedule_time", "status", "error", "started_at", "completed_at")
    visible_items = [
        {"index": entry["index"], **{key: entry.get(key) for key in public_item_keys}}
        for entry in record["items"]
    ]

    return AutoPostTaskStatus(
        task_id=record["task_id"],
        name=record.get("name", "定时发布"),
        platform=record.get("platform", "xiaohongshu"),
        frequency=record.get("frequency"),
        scheduled_times=record.get("scheduled_times", []),
        total=record.get("total", 0),
        success=record.get("success", 0),
        failure=record.get("failure", 0),
        next_run=record.get("next_run"),
        last_run=record.get("last_run"),
        is_completed=record.get("is_completed", False),
        created_at=record.get("created_at"),
        started_at=record.get("started_at"),
        completed_at=record.get("completed_at"),
        items=visible_items,
    )
@router.get("/tasks", tags=["发布管理"])
async def list_auto_post_tasks(user_id: Optional[int] = None):
    """List task summaries (without per-item details), newest first.

    When ``user_id`` is given, only tasks created with that user_id are returned.
    """
    summaries = [
        {
            "task_id": entry["task_id"],
            "name": entry.get("name"),
            "platform": entry.get("platform"),
            "frequency": entry.get("frequency"),
            "scheduled_times": entry.get("scheduled_times", []),
            "total": entry.get("total", 0),
            "success": entry.get("success", 0),
            "failure": entry.get("failure", 0),
            "next_run": entry.get("next_run"),
            "last_run": entry.get("last_run"),
            "is_completed": entry.get("is_completed", False),
            "created_at": entry.get("created_at"),
            "started_at": entry.get("started_at"),
            "completed_at": entry.get("completed_at"),
        }
        for entry in _list_tasks(user_id)
    ]
    return {"tasks": summaries, "total": len(summaries)}
if __name__ == "__main__":
    # Debug helper: print the portion of this file's absolute path (with forward
    # slashes) that precedes the first occurrence of "backend".
    normalized_path = os.path.abspath(__file__).replace("\\", "/")
    abs_path = normalized_path.partition("backend")[0]
    print(abs_path)
|