auto_post_api.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import os
  2. import json
  3. from typing import List, Optional, Dict, Any
  4. from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
  5. from pydantic import BaseModel, Field, validator
  6. from datetime import datetime
  7. import time
  8. import threading
  9. import sys
  10. current_dir = os.path.dirname(os.path.abspath(__file__))
  11. backend_dir = os.path.dirname(current_dir)
  12. if backend_dir not in sys.path:
  13. sys.path.insert(0, backend_dir)
  14. from services.auto_post_service import auto_post_service
  15. from modules.database.operations import DatabaseOperations
  16. from utils.logger_config import setup_logger
  17. logger = setup_logger(__name__)
  18. router = APIRouter()
  19. # 创建线程锁防止并发冲突
  20. post_lock = threading.Lock()
  21. db_ops = DatabaseOperations()
  22. # ============== 任务状态托管 ==============
  23. # 简单的内存任务注册表(进程内有效,重启后清空)
  24. _task_registry_lock = threading.Lock()
  25. _task_registry: Dict[str, Dict[str, Any]] = {}
  26. class AutoPostTaskItemStatus(BaseModel):
  27. index: int
  28. schedule_time: Optional[str] = None
  29. status: str = "pending" # pending | running | success | failure
  30. error: Optional[str] = None
  31. started_at: Optional[str] = None
  32. completed_at: Optional[str] = None
  33. class AutoPostTaskStatus(BaseModel):
  34. task_id: str
  35. name: str
  36. platform: str
  37. frequency: Optional[str] = None
  38. scheduled_times: List[str] = []
  39. total: int
  40. success: int
  41. failure: int
  42. next_run: Optional[str] = None
  43. last_run: Optional[str] = None
  44. is_completed: bool = False
  45. created_at: str
  46. started_at: Optional[str] = None
  47. completed_at: Optional[str] = None
  48. items: List[AutoPostTaskItemStatus] = []
  49. class BatchPostItem(BaseModel):
  50. image_paths: List[str]
  51. title: str
  52. description: str
  53. topics: List[str]
  54. schedule_time: Optional[str] = None
  55. if_clean: bool = False
  56. # 新增:该发布项关联的处理记录ID列表(同一组合内的拆分记录)
  57. record_ids: List[int] = []
  58. class BatchPostRequest(BaseModel):
  59. name: str = Field("定时发布")
  60. platform: str = Field(..., description="发布平台,如 xiaohongshu")
  61. frequency: Optional[str] = None
  62. scheduled_times: List[str] = []
  63. tasks: List[BatchPostItem]
  64. # 新增:创建该任务的用户ID,用于任务隔离
  65. user_id: Optional[int] = None
  66. def _save_task_status(task_id: str, status: Dict[str, Any]) -> None:
  67. with _task_registry_lock:
  68. _task_registry[task_id] = status
  69. def _get_task_status(task_id: str) -> Optional[Dict[str, Any]]:
  70. with _task_registry_lock:
  71. return _task_registry.get(task_id)
  72. def _list_tasks(user_id: Optional[int] = None) -> List[Dict[str, Any]]:
  73. with _task_registry_lock:
  74. tasks = list(_task_registry.values())
  75. if user_id is not None:
  76. tasks = [t for t in tasks if t.get("user_id") == user_id]
  77. # 返回按创建时间倒序的列表
  78. return sorted(tasks, key=lambda x: x.get("created_at", ""), reverse=True)
  79. class XiaohongshuPostRequest(BaseModel):
  80. """小红书发布请求数据结构"""
  81. image_paths: List[str]
  82. title: str
  83. description: str
  84. topics: List[str]
  85. schedule_time: Optional[str] = None
  86. if_clean: bool = False
  87. class PostResponse(BaseModel):
  88. """API响应数据结构"""
  89. success: bool
  90. message: str
  91. task_id: Optional[str] = None
  92. timestamp: str
  93. class CreateBatchTaskResponse(BaseModel):
  94. success: bool
  95. message: str
  96. task_id: str
  97. total_tasks: int
  98. timestamp: str
  99. def _run_post_task(req: XiaohongshuPostRequest) -> None:
  100. """在后台线程中执行发布任务,避免阻塞事件循环。"""
  101. try:
  102. with post_lock:
  103. auto_post_service.post_to_xiaohongshu(
  104. image_paths=req.image_paths,
  105. title=req.title,
  106. description=req.description,
  107. topics=req.topics,
  108. schedule_time=req.schedule_time,
  109. if_clean=req.if_clean,
  110. )
  111. logger.info("Background post task finished successfully")
  112. except Exception as e:
  113. logger.exception(f"Background post task failed: {str(e)}")
  114. @router.post("/xiaohongshu/post", response_model=PostResponse, tags=["发布管理"])
  115. async def post_to_xiaohongshu(request: XiaohongshuPostRequest, background_tasks: BackgroundTasks):
  116. """
  117. 发布内容到小红书
  118. - **image_paths**: 图片路径列表
  119. - **title**: 内容标题
  120. - **description**: 详细描述
  121. - **topics**: 话题标签列表
  122. - **schedule_time**: 定时发布时间 (格式: YYYY-MM-DD HH:MM)
  123. - **if_clean**: 是否清理浏览器缓存 (默认False)
  124. """
  125. task_id = f"task_{int(time.time())}"
  126. logger.info(f"Received new post task {task_id}")
  127. # 验证图片路径
  128. for path in request.image_paths:
  129. if not os.path.exists(path):
  130. logger.error(f"Image not found: {path}")
  131. raise HTTPException(
  132. status_code=400,
  133. )
  134. # 验证时间格式
  135. if request.schedule_time:
  136. try:
  137. datetime.strptime(request.schedule_time, "%Y-%m-%d %H:%M")
  138. except ValueError:
  139. logger.error(f"Invalid time format: {request.schedule_time}")
  140. raise HTTPException(
  141. status_code=400,
  142. detail="时间格式错误,请使用 YYYY-MM-DD HH:MM 格式"
  143. )
  144. # 启动独立线程执行耗时任务,避免阻塞事件循环
  145. worker = threading.Thread(target=_run_post_task, args=(request,), daemon=True)
  146. worker.start()
  147. return {
  148. "success": True,
  149. "message": "发布任务已开始执行",
  150. "task_id": task_id,
  151. "timestamp": datetime.now().isoformat(),
  152. }
  153. # ============== 批量任务接口:后端托管任务状态 ==============
  154. def _validate_item_paths(item: BatchPostItem) -> None:
  155. for path in item.image_paths:
  156. if not os.path.exists(path):
  157. logger.error(f"Image not found: {path}")
  158. raise HTTPException(status_code=400)
  159. def _run_batch_post_task(task_id: str) -> None:
  160. status = _get_task_status(task_id)
  161. if not status:
  162. return
  163. try:
  164. status["started_at"] = datetime.now().isoformat()
  165. _save_task_status(task_id, status)
  166. items: List[Dict[str, Any]] = status["items"]
  167. total = len(items)
  168. for idx, item in enumerate(items):
  169. # 更新当前项状态
  170. item["status"] = "running"
  171. item["started_at"] = datetime.now().isoformat()
  172. _save_task_status(task_id, status)
  173. try:
  174. # 使用锁保护底层浏览器自动化(独占)
  175. with post_lock:
  176. auto_post_service.post_to_xiaohongshu(
  177. image_paths=item["image_paths"],
  178. title=item["title"],
  179. description=item["description"],
  180. topics=item["topics"],
  181. schedule_time=item.get("schedule_time"),
  182. if_clean=item.get("if_clean", False),
  183. )
  184. # 成功
  185. item["status"] = "success"
  186. item["completed_at"] = datetime.now().isoformat()
  187. status["success"] += 1
  188. status["last_run"] = item.get("schedule_time")
  189. status["next_run"] = items[idx + 1].get("schedule_time") if idx < total - 1 else None
  190. # 成功后批量更新该组合内的所有处理记录为“已发布”
  191. try:
  192. record_ids: List[int] = item.get("record_ids", []) or []
  193. for pid in record_ids:
  194. try:
  195. db_ops.update_process_record(int(pid), {"status": "已发布"})
  196. except Exception as ue:
  197. logger.warning(f"更新处理记录状态失败 process_id={pid}: {ue}")
  198. except Exception as ge:
  199. logger.warning(f"批量更新组合关联记录状态失败: {ge}")
  200. except Exception as e:
  201. logger.exception(f"Batch post item failed: {e}")
  202. item["status"] = "failure"
  203. item["error"] = str(e)
  204. item["completed_at"] = datetime.now().isoformat()
  205. status["failure"] += 1
  206. status["last_run"] = item.get("schedule_time")
  207. status["next_run"] = items[idx + 1].get("schedule_time") if idx < total - 1 else None
  208. # 每个子任务结束都保存一次
  209. _save_task_status(task_id, status)
  210. status["is_completed"] = True
  211. status["completed_at"] = datetime.now().isoformat()
  212. _save_task_status(task_id, status)
  213. except Exception as e:
  214. logger.exception(f"Batch task failed: {e}")
  215. status["is_completed"] = True
  216. status["completed_at"] = datetime.now().isoformat()
  217. status["error"] = str(e)
  218. _save_task_status(task_id, status)
  219. @router.post("/tasks/batch", response_model=CreateBatchTaskResponse, tags=["发布管理"])
  220. async def create_batch_post_task(request: BatchPostRequest):
  221. # 验证各项图片与时间格式
  222. for item in request.tasks:
  223. _validate_item_paths(item)
  224. if item.schedule_time:
  225. try:
  226. datetime.strptime(item.schedule_time, "%Y-%m-%d %H:%M")
  227. except ValueError:
  228. raise HTTPException(status_code=400, detail="时间格式错误,请使用 YYYY-MM-DD HH:MM 格式")
  229. task_id = f"task_{int(time.time() * 1000)}"
  230. created_at = datetime.now().isoformat()
  231. # 初始化任务状态
  232. items_status: List[Dict[str, Any]] = []
  233. for idx, item in enumerate(request.tasks):
  234. items_status.append({
  235. "index": idx,
  236. "schedule_time": item.schedule_time,
  237. "status": "pending",
  238. "error": None,
  239. "started_at": None,
  240. "completed_at": None,
  241. # 原始数据,供执行用
  242. "image_paths": item.image_paths,
  243. "title": item.title,
  244. "description": item.description,
  245. "topics": item.topics,
  246. "if_clean": item.if_clean,
  247. "record_ids": item.record_ids or [],
  248. })
  249. status: Dict[str, Any] = {
  250. "task_id": task_id,
  251. "name": request.name,
  252. "platform": request.platform,
  253. "frequency": request.frequency,
  254. "scheduled_times": request.scheduled_times,
  255. "total": len(items_status),
  256. "success": 0,
  257. "failure": 0,
  258. "next_run": items_status[0].get("schedule_time") if items_status else None,
  259. "last_run": None,
  260. "is_completed": False,
  261. "created_at": created_at,
  262. "started_at": None,
  263. "completed_at": None,
  264. "items": items_status,
  265. "user_id": request.user_id,
  266. }
  267. _save_task_status(task_id, status)
  268. # 启动后台线程执行任务
  269. worker = threading.Thread(target=_run_batch_post_task, args=(task_id,), daemon=True)
  270. worker.start()
  271. return CreateBatchTaskResponse(
  272. success=True,
  273. message="批量发布任务已创建",
  274. task_id=task_id,
  275. total_tasks=len(items_status),
  276. timestamp=created_at,
  277. )
  278. @router.get("/tasks/{task_id}/status", response_model=AutoPostTaskStatus, tags=["发布管理"])
  279. async def get_auto_post_task_status(task_id: str):
  280. status = _get_task_status(task_id)
  281. if not status:
  282. raise HTTPException(status_code=404, detail="任务不存在")
  283. # 过滤掉执行用原始字段,仅返回状态相关
  284. sanitized_items = []
  285. for item in status["items"]:
  286. sanitized_items.append({
  287. "index": item["index"],
  288. "schedule_time": item.get("schedule_time"),
  289. "status": item.get("status"),
  290. "error": item.get("error"),
  291. "started_at": item.get("started_at"),
  292. "completed_at": item.get("completed_at"),
  293. })
  294. return AutoPostTaskStatus(
  295. task_id=status["task_id"],
  296. name=status.get("name", "定时发布"),
  297. platform=status.get("platform", "xiaohongshu"),
  298. frequency=status.get("frequency"),
  299. scheduled_times=status.get("scheduled_times", []),
  300. total=status.get("total", 0),
  301. success=status.get("success", 0),
  302. failure=status.get("failure", 0),
  303. next_run=status.get("next_run"),
  304. last_run=status.get("last_run"),
  305. is_completed=status.get("is_completed", False),
  306. created_at=status.get("created_at"),
  307. started_at=status.get("started_at"),
  308. completed_at=status.get("completed_at"),
  309. items=sanitized_items,
  310. )
  311. @router.get("/tasks", tags=["发布管理"])
  312. async def list_auto_post_tasks(user_id: Optional[int] = None):
  313. tasks = _list_tasks(user_id)
  314. result = []
  315. for status in tasks:
  316. result.append({
  317. "task_id": status["task_id"],
  318. "name": status.get("name"),
  319. "platform": status.get("platform"),
  320. "frequency": status.get("frequency"),
  321. "scheduled_times": status.get("scheduled_times", []),
  322. "total": status.get("total", 0),
  323. "success": status.get("success", 0),
  324. "failure": status.get("failure", 0),
  325. "next_run": status.get("next_run"),
  326. "last_run": status.get("last_run"),
  327. "is_completed": status.get("is_completed", False),
  328. "created_at": status.get("created_at"),
  329. "started_at": status.get("started_at"),
  330. "completed_at": status.get("completed_at"),
  331. })
  332. return {"tasks": result, "total": len(result)}
  333. if __name__ == "__main__":
  334. abs_path = os.path.abspath(__file__).replace("\\", "/").split("backend")[0]
  335. print(abs_path)