"""
Task management module.

Provides the TaskManager class and the StepStatus constants used to manage the
execution of multi-step tasks.
"""

import asyncio
import functools
import inspect
import json
import os
import pickle
import threading
from collections import deque
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set

from .config import get_config
from .logger import get_logger


class StepStatus:
    """String constants describing the lifecycle state of a step."""

    PENDING = "pending"      # not executed yet
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
    SKIPPED = "skipped"      # skipped


class TaskManager:
    """
    Task manager with resumable execution, step retry and parallel execution.

    Features:
        1. Persists the execution state of every step.
        2. Resumes from where the previous run was interrupted.
        3. Supports forcing selected steps to run again.
        4. Supports dependencies between steps.
        5. Caches step outputs (pickle files) and restores them.
        6. Supports parallel execution (parallelisable steps are detected
           from the dependency graph).

    Example:
        >>> manager = TaskManager(state_file="state.json", cache_dir="cache")
        >>> manager.register_step("step1", my_function, depends_on=[])
        >>> manager.register_step("step2", my_function2, depends_on=["step1"])
        >>> manager.register_step("step3", my_function3, depends_on=["step2"])
        >>> manager.register_step("step4", my_function4, depends_on=["step2"])

        >>> # sequential execution
        >>> manager.run_all()

        >>> # parallel execution (step3 and step4 are detected as parallelisable)
        >>> manager.run_all_parallel(max_workers=2)
    """

    def __init__(self, state_file: str = "task_state.json", cache_dir: str = "task_cache"):
        """
        Initialise the task manager.

        Args:
            state_file: Path of the JSON file holding the persisted state.
            cache_dir: Directory used to store pickled step outputs; created
                if it does not exist.
        """
        self.state_file = state_file
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.logger = get_logger("taskflow.manager")

        # Re-entrant on purpose: methods that mutate the state while holding
        # the lock also call _save_state(), which acquires it again.
        self._state_lock = threading.RLock()

        # Load any previously persisted state.
        self.state = self._load_state()
        self.steps: Dict[str, Dict] = {}

    @staticmethod
    def _new_step_state(status: str = StepStatus.PENDING) -> Dict[str, Any]:
        """Return a fresh per-step state record with the given status."""
        return {
            "status": status,
            "started_at": None,
            "completed_at": None,
            "error": None,
            "output_file": None
        }

    @staticmethod
    def _fresh_state() -> Dict:
        """Return an empty top-level state structure."""
        return {
            "steps": {},
            "metadata": {
                "created_at": datetime.now().isoformat(),
                "last_updated": None
            }
        }

    def _load_state(self) -> Dict:
        """
        Load the persisted task state.

        Returns:
            The state read from ``self.state_file``. A fresh state is returned
            when the file does not exist, cannot be parsed, or does not have
            the expected structure (a JSON object whose "steps" and "metadata"
            entries, when present, are objects).
        """
        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                # Validate the shape here so later lookups such as
                # state["steps"] / state["metadata"] cannot fail.
                if not isinstance(loaded, dict):
                    raise ValueError("状态文件内容不是 JSON 对象")
                steps = loaded.setdefault("steps", {})
                metadata = loaded.setdefault("metadata", self._fresh_state()["metadata"])
                if not isinstance(steps, dict) or not isinstance(metadata, dict):
                    raise ValueError("状态文件结构无效")
                return loaded
            except Exception as e:
                self.logger.warning(f"加载任务状态失败,将创建新状态:{e}")

        return self._fresh_state()

    def _save_state(self):
        """
        Persist the task state to ``self.state_file`` (guarded by the state lock).

        The state is written to a temporary sibling file and then moved over
        the real file with ``os.replace`` so that an interrupted write does
        not leave a truncated state file behind. Failures are logged, not
        raised.
        """
        with self._state_lock:
            self.state["metadata"]["last_updated"] = datetime.now().isoformat()
            tmp_path = f"{self.state_file}.tmp"
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(self.state, f, ensure_ascii=False, indent=2)
                os.replace(tmp_path, self.state_file)
            except Exception as e:
                self.logger.error(f"保存状态文件失败: {e}")
                # Best-effort cleanup of the partial temporary file.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _get_cache_path(self, step_name: str) -> Path:
        """Return the path of the cache file used for a step's output."""
        return self.cache_dir / f"{step_name}.pkl"

    def register_step(
        self,
        step_name: str,
        func: Callable,
        depends_on: Optional[List[str]] = None,
        force_rerun: bool = False
    ):
        """
        Register a step.

        Args:
            step_name: Unique name of the step.
            func: Callable implementing the step (sync or async function).
            depends_on: Names of the steps this step depends on.
            force_rerun: When True the step is executed even if its persisted
                status is already "completed".
        """
        self.steps[step_name] = {
            "func": func,
            # Copy so that later mutation of the caller's list has no effect.
            "depends_on": [] if depends_on is None else list(depends_on),
            "force_rerun": force_rerun
        }

        # Create the state record only if none was restored from disk.
        with self._state_lock:
            self.state["steps"].setdefault(step_name, self._new_step_state())

    def get_step_status(self, step_name: str) -> str:
        """Return the status of a step; unknown steps are reported as pending."""
        with self._state_lock:
            step_state = self.state["steps"].get(step_name)
            if step_state is None:
                return StepStatus.PENDING
            return step_state["status"]

    def set_step_status(self, step_name: str, status: str, error: Optional[str] = None):
        """
        Set the status of a step and persist the state.

        Args:
            step_name: Name of the step.
            status: New status (one of the StepStatus constants).
            error: Optional error description stored with the step.
        """
        with self._state_lock:
            step_state = self.state["steps"].setdefault(
                step_name, self._new_step_state(status)
            )
            step_state["status"] = status

            now = datetime.now().isoformat()
            if status == StepStatus.RUNNING:
                step_state["started_at"] = now
                # A new attempt starts: drop the outcome of the previous one
                # so a later success is not reported with a stale error.
                step_state["completed_at"] = None
                step_state["error"] = None
            elif status in (StepStatus.COMPLETED, StepStatus.FAILED):
                step_state["completed_at"] = now

            if error:
                step_state["error"] = str(error)

            self._save_state()

    def save_step_output(self, step_name: str, output: Any):
        """
        Pickle a step's output into the cache directory and record its path.

        Failures (unpicklable output, I/O errors, missing state record) are
        logged as warnings and not raised.
        """
        cache_path = self._get_cache_path(step_name)
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(output, f)

            with self._state_lock:
                self.state["steps"][step_name]["output_file"] = str(cache_path)
                self._save_state()
        except Exception as e:
            self.logger.warning(f"保存步骤 {step_name} 的输出失败: {e}")

    def load_step_output(self, step_name: str) -> Optional[Any]:
        """
        Load a step's cached output.

        Returns:
            The unpickled output, or None when no cache file is recorded, the
            file is missing, or it cannot be loaded.
        """
        with self._state_lock:
            output_file = self.state["steps"].get(step_name, {}).get("output_file")

        if output_file and os.path.exists(output_file):
            try:
                # NOTE(security): pickle.load executes arbitrary code from the
                # file; the cache directory must only contain trusted files.
                with open(output_file, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                self.logger.warning(f"加载步骤 {step_name} 的输出失败: {e}")

        return None

    def check_dependencies(self, step_name: str) -> bool:
        """Return True when every dependency of the step has status "completed"."""
        depends_on = self.steps.get(step_name, {}).get("depends_on", [])

        for dep_step in depends_on:
            dep_status = self.get_step_status(dep_step)
            if dep_status != StepStatus.COMPLETED:
                self.logger.warning(f"步骤 {step_name} 的依赖 {dep_step} 尚未完成(状态: {dep_status})")
                return False

        return True

    def _get_ready_steps(self, step_order: Optional[List[str]] = None) -> List[str]:
        """
        Return the steps that can run now (need running, dependencies met).

        Args:
            step_order: Candidate steps in order; all registered steps when None.

        Returns:
            Names of the steps that are ready to execute.
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        return [
            step_name
            for step_name in step_order
            if step_name in self.steps
            and self.should_run_step(step_name)
            and self.check_dependencies(step_name)
        ]

    def _topological_sort(self, step_order: Optional[List[str]] = None) -> List[List[str]]:
        """
        Group steps into batches using a level-by-level topological sort.

        Only dependencies that are themselves part of ``step_order`` are
        considered. Unregistered names are ignored and duplicate names are
        collapsed so a step can never appear twice in a batch.

        Args:
            step_order: Steps to sort; all registered steps when None.

        Returns:
            A list of batches; the steps inside one batch do not depend on
            each other and may run in parallel. Steps that are part of a
            dependency cycle cannot be scheduled and are left out (a warning
            is logged).
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        # Registered steps only, de-duplicated, original order preserved.
        ordered = list(dict.fromkeys(s for s in step_order if s in self.steps))

        # Build the dependency graph (dep -> dependents) and in-degree table.
        in_degree: Dict[str, int] = {step: 0 for step in ordered}
        graph: Dict[str, List[str]] = {step: [] for step in ordered}

        for step_name in ordered:
            for dep in self.steps[step_name].get("depends_on", []):
                if dep in graph:
                    graph[dep].append(step_name)
                    in_degree[step_name] += 1

        batches: List[List[str]] = []
        queue = deque(step for step in ordered if in_degree[step] == 0)
        scheduled = 0

        while queue:
            # Everything currently queued has in-degree 0 and forms one batch.
            current_batch = [queue.popleft() for _ in range(len(queue))]
            scheduled += len(current_batch)

            # Release the steps that were waiting on this batch.
            for step in current_batch:
                for dependent in graph[step]:
                    in_degree[dependent] -= 1
                    if in_degree[dependent] == 0:
                        queue.append(dependent)

            batches.append(current_batch)

        if scheduled < len(ordered):
            unscheduled = [step for step in ordered if in_degree[step] > 0]
            self.logger.warning(f"以下步骤存在循环依赖,无法调度: {unscheduled}")

        return batches

    def should_run_step(self, step_name: str) -> bool:
        """
        Decide whether a step needs to be executed.

        Returns True when the step is registered with ``force_rerun`` or its
        status is pending, failed, or running (a "running" status at this
        point is treated as an interrupted earlier run). Any other status,
        including completed, returns False.
        """
        if self.steps.get(step_name, {}).get("force_rerun", False):
            return True

        return self.get_step_status(step_name) in (
            StepStatus.PENDING,
            StepStatus.FAILED,
            StepStatus.RUNNING,
        )

    def _should_execute(self, step_name: str) -> bool:
        """
        Shared pre-flight check for run_step / run_step_async.

        Returns:
            True when the step must be executed, False when it can be skipped.

        Raises:
            ValueError: The step is not registered.
            RuntimeError: The step's dependencies are not completed.
        """
        if step_name not in self.steps:
            raise ValueError(f"步骤 {step_name} 未注册")

        if not self.should_run_step(step_name):
            self.logger.info(f"步骤 {step_name} 已完成,跳过执行。使用 load_step_output() 获取结果。")
            return False

        if not self.check_dependencies(step_name):
            raise RuntimeError(f"步骤 {step_name} 的依赖未满足")

        return True

    @staticmethod
    def _run_coroutine_blocking(coro) -> Any:
        """
        Run a coroutine to completion from synchronous code.

        When the calling thread has no running event loop the coroutine is
        run with ``asyncio.run``. When a loop is already running in this
        thread, the coroutine is run with ``asyncio.run`` in a helper thread
        and this call blocks until it finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread.
            return asyncio.run(coro)

        # The loop detection above is kept separate from the execution so a
        # RuntimeError raised by the step itself is never mistaken for
        # "no event loop" (which would execute the step a second time).
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    def run_step(self, step_name: str, *args, **kwargs) -> Any:
        """
        Execute a single step synchronously.

        Args:
            step_name: Name of the step.
            *args, **kwargs: Arguments passed to the step function.

        Returns:
            The step's output, or its cached output when the step does not
            need to run.

        Raises:
            ValueError: The step is not registered.
            RuntimeError: The step's dependencies are not completed.
            Exception: Whatever the step function raises (the step is marked
                failed first).
        """
        if not self._should_execute(step_name):
            return self.load_step_output(step_name)

        func = self.steps[step_name]["func"]

        self.set_step_status(step_name, StepStatus.RUNNING)

        try:
            self.logger.info(f"开始执行步骤: {step_name}")

            if inspect.iscoroutinefunction(func):
                output = self._run_coroutine_blocking(func(*args, **kwargs))
            else:
                output = func(*args, **kwargs)

            self.save_step_output(step_name, output)
            self.set_step_status(step_name, StepStatus.COMPLETED)

            self.logger.info(f"步骤 {step_name} 执行完成")
            return output

        except Exception as e:
            self.set_step_status(step_name, StepStatus.FAILED, error=str(e))
            self.logger.error(f"步骤 {step_name} 执行失败: {e}", exc_info=True)
            raise

    def run_all(self, step_order: Optional[List[str]] = None):
        """
        Execute all steps one after another.

        Execution stops at the first failing step and its exception is
        re-raised.

        Args:
            step_order: Execution order; registration order when None.
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        for step_name in step_order:
            if step_name not in self.steps:
                self.logger.warning(f"步骤 {step_name} 未注册,跳过")
                continue

            try:
                self.run_step(step_name)
            except Exception as e:
                self.logger.error(f"执行步骤 {step_name} 时出错: {e}", exc_info=True)
                # Abort on the first error; later steps are not attempted.
                raise

    def run_all_parallel(
        self,
        step_order: Optional[List[str]] = None,
        max_workers: Optional[int] = None,
        continue_on_error: bool = False
    ):
        """
        Execute all steps, running independent steps in parallel threads.

        The dependency graph is split into batches; the steps of one batch
        run concurrently in a ThreadPoolExecutor. For example, if step3 and
        step4 both depend on step2, they run in parallel once step2 is done.

        Args:
            step_order: Steps to execute; all registered steps when None.
            max_workers: Maximum number of worker threads (defaults to the
                number of steps in the batch).
            continue_on_error: Keep going after a step fails (default False:
                the first error is re-raised).

        Example:
            >>> manager.register_step("step1", func1, depends_on=[])
            >>> manager.register_step("step2", func2, depends_on=["step1"])
            >>> manager.register_step("step3", func3, depends_on=["step2"])
            >>> manager.register_step("step4", func4, depends_on=["step2"])
            >>> # step3 and step4 run in parallel after step2 completes
            >>> manager.run_all_parallel(max_workers=2)
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        batches = self._topological_sort(step_order)

        if not batches:
            self.logger.warning("没有可执行的步骤")
            return

        self.logger.info(f"共 {len(batches)} 个执行批次")

        for batch_idx, batch in enumerate(batches, 1):
            # Keep only the steps that still need to run and whose
            # dependencies are completed.
            ready_steps = [step for step in batch if self.should_run_step(step) and self.check_dependencies(step)]

            if not ready_steps:
                self.logger.info(f"批次 {batch_idx}/{len(batches)}: 所有步骤已完成,跳过")
                continue

            self.logger.info(f"批次 {batch_idx}/{len(batches)}: 执行步骤 {ready_steps} (共 {len(ready_steps)} 个)")

            # A single step is run inline to avoid thread overhead.
            if len(ready_steps) == 1:
                try:
                    self.run_step(ready_steps[0])
                except Exception as e:
                    self.logger.error(f"执行步骤 {ready_steps[0]} 时出错: {e}", exc_info=True)
                    if not continue_on_error:
                        raise
                continue

            workers = max_workers if max_workers is not None else len(ready_steps)
            with ThreadPoolExecutor(max_workers=workers) as executor:
                future_to_step = {
                    executor.submit(self.run_step, step_name): step_name
                    for step_name in ready_steps
                }

                for future in as_completed(future_to_step):
                    step_name = future_to_step[future]
                    try:
                        future.result()
                        self.logger.info(f"步骤 {step_name} 执行完成")
                    except Exception as e:
                        self.logger.error(f"执行步骤 {step_name} 时出错: {e}", exc_info=True)
                        if not continue_on_error:
                            # Cancel whatever has not started yet; steps that
                            # are already running finish before the executor
                            # context exits.
                            for f in future_to_step:
                                f.cancel()
                            raise

    async def run_step_async(self, step_name: str, *args, **kwargs) -> Any:
        """
        Execute a single step from asynchronous code.

        Async step functions are awaited directly; sync step functions run in
        the loop's default executor so they do not block the event loop.

        Args:
            step_name: Name of the step.
            *args, **kwargs: Arguments passed to the step function.

        Returns:
            The step's output, or its cached output when the step does not
            need to run.

        Raises:
            ValueError: The step is not registered.
            RuntimeError: The step's dependencies are not completed.
            Exception: Whatever the step function raises (the step is marked
                failed first).
        """
        if not self._should_execute(step_name):
            return self.load_step_output(step_name)

        func = self.steps[step_name]["func"]

        self.set_step_status(step_name, StepStatus.RUNNING)

        try:
            self.logger.info(f"开始执行步骤: {step_name}")

            if inspect.iscoroutinefunction(func):
                output = await func(*args, **kwargs)
            else:
                loop = asyncio.get_running_loop()
                # run_in_executor() forwards positional arguments only, so
                # keyword arguments are bound with functools.partial.
                call = functools.partial(func, *args, **kwargs)
                output = await loop.run_in_executor(None, call)

            self.save_step_output(step_name, output)
            self.set_step_status(step_name, StepStatus.COMPLETED)

            self.logger.info(f"步骤 {step_name} 执行完成")
            return output

        except Exception as e:
            self.set_step_status(step_name, StepStatus.FAILED, error=str(e))
            self.logger.error(f"步骤 {step_name} 执行失败: {e}", exc_info=True)
            raise

    async def run_all_async(
        self,
        step_order: Optional[List[str]] = None,
        continue_on_error: bool = False
    ):
        """
        Execute all steps, running independent steps concurrently with asyncio.

        The dependency graph is split into batches; the steps of one batch
        are run together with ``asyncio.gather``. For example, if step3 and
        step4 both depend on step2, they run concurrently once step2 is done.

        Args:
            step_order: Steps to execute; all registered steps when None.
            continue_on_error: Keep going after a step fails (default False:
                the first error of a batch is re-raised after the batch has
                finished).

        Example:
            >>> async def main():
            ...     manager = TaskManager(...)
            ...     manager.register_step("step1", async_func1, depends_on=[])
            ...     manager.register_step("step2", async_func2, depends_on=["step1"])
            ...     manager.register_step("step3", async_func3, depends_on=["step2"])
            ...     manager.register_step("step4", async_func4, depends_on=["step2"])
            ...     # step3 and step4 run concurrently after step2 completes
            ...     await manager.run_all_async()
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        batches = self._topological_sort(step_order)

        if not batches:
            self.logger.warning("没有可执行的步骤")
            return

        self.logger.info(f"共 {len(batches)} 个执行批次(异步并发执行)")

        for batch_idx, batch in enumerate(batches, 1):
            # Keep only the steps that still need to run and whose
            # dependencies are completed.
            ready_steps = [step for step in batch if self.should_run_step(step) and self.check_dependencies(step)]

            if not ready_steps:
                self.logger.info(f"批次 {batch_idx}/{len(batches)}: 所有步骤已完成,跳过")
                continue

            self.logger.info(f"批次 {batch_idx}/{len(batches)}: 异步并发执行步骤 {ready_steps} (共 {len(ready_steps)} 个)")

            # A single step is awaited directly to avoid gather overhead.
            if len(ready_steps) == 1:
                try:
                    await self.run_step_async(ready_steps[0])
                except Exception as e:
                    self.logger.error(f"执行步骤 {ready_steps[0]} 时出错: {e}", exc_info=True)
                    if not continue_on_error:
                        raise
                continue

            tasks = [self.run_step_async(step_name) for step_name in ready_steps]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for step_name, result in zip(ready_steps, results):
                # BaseException so that CancelledError is treated as a failure.
                if isinstance(result, BaseException):
                    # No exception is being handled here, so the exception
                    # object itself is passed for the traceback.
                    self.logger.error(f"执行步骤 {step_name} 时出错: {result}", exc_info=result)
                    if not continue_on_error:
                        raise result
                else:
                    self.logger.info(f"步骤 {step_name} 执行完成")

    def reset_step(self, step_name: str):
        """
        Reset a step so that it can be executed again.

        The step's state record (if any) is reset to pending, its cache file
        is deleted if present, and the state is persisted.
        """
        with self._state_lock:
            if step_name in self.state["steps"]:
                self.state["steps"][step_name] = self._new_step_state()

            cache_path = self._get_cache_path(step_name)
            if cache_path.exists():
                cache_path.unlink()

            self._save_state()

        self.logger.info(f"步骤 {step_name} 已重置")

    def reset_all(self):
        """Reset every step that has a state record."""
        with self._state_lock:
            step_names = list(self.state["steps"].keys())

        for step_name in step_names:
            self.reset_step(step_name)

    def get_summary(self) -> Dict:
        """
        Return a summary of the task execution.

        Returns:
            A dict with the counts "total", "completed", "failed" and
            "pending", plus "progress" formatted as "completed/total".
        """
        with self._state_lock:
            statuses = [s["status"] for s in self.state["steps"].values()]

        total = len(statuses)
        completed = statuses.count(StepStatus.COMPLETED)
        failed = statuses.count(StepStatus.FAILED)
        pending = statuses.count(StepStatus.PENDING)

        return {
            "total": total,
            "completed": completed,
            "failed": failed,
            "pending": pending,
            "progress": f"{completed}/{total}" if total > 0 else "0/0"
        }