| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655 |
- """
- 任务管理模块
- 提供TaskManager类和StepStatus枚举,用于管理多步骤任务的执行
- """
- import asyncio
- import json
- import os
- import pickle
- import threading
- from collections import deque
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from datetime import datetime
- from typing import Any, Callable, Dict, List, Optional, Set
- from pathlib import Path
- from .config import get_config
- from .logger import get_logger
class StepStatus:
    """Lifecycle states of a single task step.

    The members are plain ``str`` values, so they can be stored in the JSON
    state file as-is and compared with ``==`` against strings read back from it.
    """

    PENDING: str = "pending"      # registered, not executed yet
    RUNNING: str = "running"      # executing (or interrupted during a previous run)
    COMPLETED: str = "completed"  # finished successfully
    FAILED: str = "failed"        # the step function raised
    SKIPPED: str = "skipped"      # step was skipped
class TaskManager:
    """
    Task manager supporting checkpoint/resume, per-step retry and parallel execution.

    Features:
    1. Persists the execution state of every step to a JSON state file.
    2. Resumes from where the previous run stopped (completed steps are skipped).
    3. Can force individual steps to re-run (``force_rerun=True``).
    4. Tracks dependencies between steps.
    5. Caches each step's return value (pickle) so it can be reloaded later.
    6. Runs independent steps in parallel, batching steps by dependency level.

    Example::

        manager = TaskManager(state_file="state.json", cache_dir="cache")
        manager.register_step("step1", my_function, depends_on=[])
        manager.register_step("step2", my_function2, depends_on=["step1"])
        manager.register_step("step3", my_function3, depends_on=["step2"])
        manager.register_step("step4", my_function4, depends_on=["step2"])
        manager.run_all()                        # sequential
        manager.run_all_parallel(max_workers=2)  # step3 and step4 run concurrently
    """

    def __init__(self, state_file: str = "task_state.json", cache_dir: str = "task_cache"):
        """
        Initialise the task manager.

        Args:
            state_file: Path of the JSON file holding per-step execution state.
            cache_dir: Directory for cached step outputs (created if missing).
        """
        # Re-entrant lock: set_step_status()/save_step_output() call _save_state()
        # while already holding it; a plain Lock would deadlock on that re-acquire.
        self._state_lock = threading.RLock()
        self.state_file = state_file
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.logger = get_logger("taskflow.manager")
        # Load any existing state so already-completed steps can be skipped.
        self.state = self._load_state()
        self.steps: Dict[str, Dict] = {}

    @staticmethod
    def _new_step_record(status: str) -> Dict[str, Any]:
        """Return a fresh per-step state record carrying ``status``."""
        return {
            "status": status,
            "started_at": None,
            "completed_at": None,
            "error": None,
            "output_file": None,
        }

    def _load_state(self) -> Dict:
        """
        Load the task state from ``self.state_file``.

        Returns a fresh, empty state when the file is missing, unreadable, not
        valid JSON, or not a JSON object. Missing or malformed top-level sections
        ("steps" / "metadata") are replaced so later code can index them safely.
        """
        fresh_state = {
            "steps": {},
            "metadata": {
                "created_at": datetime.now().isoformat(),
                "last_updated": None,
            },
        }
        if not os.path.exists(self.state_file):
            return fresh_state
        try:
            with open(self.state_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            self.logger.warning(f"加载任务状态失败,将创建新状态:{e}")
            return fresh_state
        if not isinstance(loaded, dict):
            self.logger.warning("加载任务状态失败,将创建新状态:状态文件顶层不是 JSON 对象")
            return fresh_state
        # _save_state() and the step bookkeeping index these sections unconditionally.
        if not isinstance(loaded.get("steps"), dict):
            loaded["steps"] = {}
        if not isinstance(loaded.get("metadata"), dict):
            loaded["metadata"] = fresh_state["metadata"]
        return loaded

    def _save_state(self):
        """
        Persist ``self.state`` to ``self.state_file`` (thread-safe).

        The JSON is first written to ``<state_file>.tmp`` and then moved over the
        real file with ``os.replace``, so an interruption mid-write leaves the
        previous state file intact instead of a truncated one. Errors are logged,
        not raised.
        """
        with self._state_lock:
            self.state["metadata"]["last_updated"] = datetime.now().isoformat()
            tmp_path = f"{self.state_file}.tmp"
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(self.state, f, ensure_ascii=False, indent=2)
                os.replace(tmp_path, self.state_file)
            except Exception as e:
                self.logger.error(f"保存状态文件失败: {e}")

    def _get_cache_path(self, step_name: str) -> Path:
        """Return the pickle cache path for ``step_name`` inside ``cache_dir``."""
        return self.cache_dir / f"{step_name}.pkl"

    def register_step(
        self,
        step_name: str,
        func: Callable,
        depends_on: Optional[List[str]] = None,
        force_rerun: bool = False
    ):
        """
        Register a step.

        Args:
            step_name: Unique step name.
            func: Step function (sync or ``async def``); called with no arguments
                by run_all / run_all_parallel / run_all_async.
            depends_on: Names of steps that must be COMPLETED before this one runs.
            force_rerun: Run the step even if its recorded status is COMPLETED.
        """
        self.steps[step_name] = {
            "func": func,
            # Copy so later mutation of the caller's list cannot change the graph.
            "depends_on": list(depends_on) if depends_on else [],
            "force_rerun": force_rerun,
        }
        with self._state_lock:
            # Keep any status loaded from a previous run: that is what enables resume.
            if step_name not in self.state["steps"]:
                self.state["steps"][step_name] = self._new_step_record(StepStatus.PENDING)

    def get_step_status(self, step_name: str) -> str:
        """Return the recorded status of ``step_name`` (PENDING if unknown); thread-safe."""
        with self._state_lock:
            step_state = self.state["steps"].get(step_name)
            return StepStatus.PENDING if step_state is None else step_state["status"]

    def set_step_status(self, step_name: str, status: str, error: Optional[str] = None):
        """
        Set a step's status, update its timestamps and persist the state (thread-safe).

        Args:
            step_name: Step to update (a record is created if it does not exist).
            status: New status (a StepStatus value).
            error: Optional error text stored on the record.
        """
        with self._state_lock:
            step_state = self.state["steps"].setdefault(step_name, self._new_step_record(status))
            step_state["status"] = status
            now = datetime.now().isoformat()

            if status == StepStatus.RUNNING:
                step_state["started_at"] = now
                # A new attempt begins: drop the finish time and error of the previous one.
                step_state["completed_at"] = None
                step_state["error"] = None
            elif status in (StepStatus.COMPLETED, StepStatus.FAILED):
                step_state["completed_at"] = now

            if error:
                step_state["error"] = str(error)

            # Saved while still holding the (re-entrant) lock so the write is consistent.
            self._save_state()

    def save_step_output(self, step_name: str, output: Any):
        """
        Pickle ``output`` to the step's cache file and record its path (thread-safe).

        The pickle goes to a temporary file that is then moved into place, so a
        failing dump (e.g. an unpicklable object) never leaves a truncated cache
        file behind. Failures are logged as warnings and not raised.
        """
        cache_path = self._get_cache_path(step_name)
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(output, f)
            os.replace(tmp_path, cache_path)
            with self._state_lock:
                step_state = self.state["steps"].setdefault(
                    step_name, self._new_step_record(StepStatus.PENDING)
                )
                step_state["output_file"] = str(cache_path)
                self._save_state()
        except Exception as e:
            self.logger.warning(f"保存步骤 {step_name} 的输出失败: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup: the temp file may never have been created

    def load_step_output(self, step_name: str) -> Optional[Any]:
        """
        Load a step's cached output.

        Returns None when no cache file is recorded, it does not exist, or it
        cannot be read (so a cached ``None`` is indistinguishable from "missing").

        Security: this unpickles the cache file; only use a trusted ``cache_dir``,
        since unpickling attacker-controlled data can execute arbitrary code.
        """
        with self._state_lock:
            output_file = self.state["steps"].get(step_name, {}).get("output_file")

        if output_file and os.path.exists(output_file):
            try:
                with open(output_file, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                self.logger.warning(f"加载步骤 {step_name} 的输出失败: {e}")
        return None

    def check_dependencies(self, step_name: str) -> bool:
        """Return True if every dependency of ``step_name`` is COMPLETED (logs the first unmet one)."""
        step_info = self.steps.get(step_name, {})
        for dep_step in step_info.get("depends_on", []):
            dep_status = self.get_step_status(dep_step)
            if dep_status != StepStatus.COMPLETED:
                self.logger.warning(f"步骤 {step_name} 的依赖 {dep_step} 尚未完成(状态: {dep_status})")
                return False
        return True

    def _get_ready_steps(self, step_order: Optional[List[str]] = None) -> List[str]:
        """
        Return the registered steps that still need to run and whose dependencies are met.

        Args:
            step_order: Candidate steps in order (defaults to all registered steps).
        """
        if step_order is None:
            step_order = list(self.steps.keys())
        return [
            name for name in step_order
            if name in self.steps and self.should_run_step(name) and self.check_dependencies(name)
        ]

    def _topological_sort(self, step_order: Optional[List[str]] = None) -> List[List[str]]:
        """
        Group steps into dependency levels (Kahn's algorithm).

        Steps inside one batch do not depend on each other and may run in
        parallel; every batch depends only on earlier batches. Duplicate names in
        ``step_order`` are collapsed and unregistered names are ignored.
        Dependencies outside ``step_order`` are not counted here (they are
        re-checked at run time by check_dependencies). Steps that can never be
        scheduled because of a dependency cycle are left out and logged.

        Args:
            step_order: Steps to sort (defaults to all registered steps).

        Returns:
            List of batches, each a list of step names.
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        # dict.fromkeys de-duplicates while preserving order; a duplicate would
        # otherwise be queued twice and run twice concurrently in one batch.
        nodes = [name for name in dict.fromkeys(step_order) if name in self.steps]
        in_degree: Dict[str, int] = {name: 0 for name in nodes}
        dependents: Dict[str, List[str]] = {name: [] for name in nodes}

        for name in nodes:
            for dep in dict.fromkeys(self.steps[name].get("depends_on", [])):
                if dep in dependents:
                    dependents[dep].append(name)
                    in_degree[name] += 1

        batches: List[List[str]] = []
        current = [name for name in nodes if in_degree[name] == 0]
        while current:
            batches.append(current)
            next_level: List[str] = []
            for name in current:
                for dependent in dependents[name]:
                    in_degree[dependent] -= 1
                    if in_degree[dependent] == 0:
                        next_level.append(dependent)
            current = next_level

        unscheduled = [name for name in nodes if in_degree[name] > 0]
        if unscheduled:
            self.logger.warning(f"以下步骤存在循环依赖(或依赖循环中的步骤),不会被调度执行: {unscheduled}")

        return batches

    def should_run_step(self, step_name: str) -> bool:
        """
        Decide whether a step needs to execute.

        True when ``force_rerun`` is set, or the status is PENDING, FAILED or
        RUNNING (RUNNING here means the previous run was interrupted). False for
        COMPLETED, SKIPPED or any other status.
        """
        if self.steps.get(step_name, {}).get("force_rerun", False):
            return True
        return self.get_step_status(step_name) in (
            StepStatus.PENDING,
            StepStatus.FAILED,
            StepStatus.RUNNING,
        )

    @staticmethod
    def _run_coroutine_blocking(func: Callable, args: tuple, kwargs: dict) -> Any:
        """
        Run coroutine function ``func(*args, **kwargs)`` to completion from sync code.

        Event-loop selection:
        - no usable loop (none available, or closed): ``asyncio.run`` on a new loop;
        - a loop is already running (called from async code): ``asyncio.run`` in a
          worker thread, blocking until it finishes (this blocks the calling loop;
          prefer ``run_step_async`` from async code);
        - an idle current loop: ``loop.run_until_complete``.

        Only the loop lookup is guarded by ``except RuntimeError``, so a
        RuntimeError raised by the step itself propagates instead of causing the
        step to be executed a second time.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None

        if loop is None or loop.is_closed():
            return asyncio.run(func(*args, **kwargs))
        if loop.is_running():
            with ThreadPoolExecutor(max_workers=1) as executor:
                return executor.submit(lambda: asyncio.run(func(*args, **kwargs))).result()
        return loop.run_until_complete(func(*args, **kwargs))

    def run_step(self, step_name: str, *args, **kwargs) -> Any:
        """
        Execute a single step (sync or async function).

        Args:
            step_name: Registered step name.
            *args, **kwargs: Forwarded to the step function.

        Returns:
            The step's output; for a step that does not need to run, the cached
            output from load_step_output().

        Raises:
            ValueError: The step is not registered.
            RuntimeError: A dependency is not COMPLETED.
            Exception: Whatever the step function raised (the step is marked FAILED first).
        """
        import inspect

        if step_name not in self.steps:
            raise ValueError(f"步骤 {step_name} 未注册")

        if not self.should_run_step(step_name):
            self.logger.info(f"步骤 {step_name} 已完成,跳过执行。使用 load_step_output() 获取结果。")
            return self.load_step_output(step_name)

        if not self.check_dependencies(step_name):
            raise RuntimeError(f"步骤 {step_name} 的依赖未满足")

        func = self.steps[step_name]["func"]
        self.set_step_status(step_name, StepStatus.RUNNING)

        try:
            self.logger.info(f"开始执行步骤: {step_name}")
            if inspect.iscoroutinefunction(func):
                output = self._run_coroutine_blocking(func, args, kwargs)
            else:
                output = func(*args, **kwargs)

            self.save_step_output(step_name, output)
            self.set_step_status(step_name, StepStatus.COMPLETED)
            self.logger.info(f"步骤 {step_name} 执行完成")
            return output

        except Exception as e:
            self.set_step_status(step_name, StepStatus.FAILED, error=str(e))
            self.logger.error(f"步骤 {step_name} 执行失败: {e}", exc_info=True)
            raise

    def run_all(self, step_order: Optional[List[str]] = None):
        """
        Execute steps one after another, stopping at the first failure.

        Args:
            step_order: Execution order (defaults to registration order).
                Unregistered names are logged and skipped.
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        for step_name in step_order:
            if step_name not in self.steps:
                self.logger.warning(f"步骤 {step_name} 未注册,跳过")
                continue

            try:
                self.run_step(step_name)
            except Exception as e:
                self.logger.error(f"执行步骤 {step_name} 时出错: {e}", exc_info=True)
                # Abort the remaining steps; change here if best-effort continuation is wanted.
                raise

    def run_all_parallel(
        self,
        step_order: Optional[List[str]] = None,
        max_workers: Optional[int] = None,
        continue_on_error: bool = False
    ):
        """
        Execute all steps, running independent steps concurrently in threads.

        Steps are grouped into dependency levels by _topological_sort(); each
        level runs in a ThreadPoolExecutor after the previous level finished.
        E.g. if step3 and step4 both depend on step2, they run together after step2.

        Args:
            step_order: Steps to run (defaults to registration order).
            max_workers: Thread pool size (defaults to the size of each batch).
            continue_on_error: If False (default), re-raise the first step error;
                if True, log it and continue (dependents of a failed step are
                then filtered out by check_dependencies).

        Example::

            manager.register_step("step1", func1, depends_on=[])
            manager.register_step("step2", func2, depends_on=["step1"])
            manager.register_step("step3", func3, depends_on=["step2"])
            manager.register_step("step4", func4, depends_on=["step2"])
            manager.run_all_parallel(max_workers=2)  # step3 & step4 in parallel
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        batches = self._topological_sort(step_order)
        if not batches:
            self.logger.warning("没有可执行的步骤")
            return

        self.logger.info(f"共 {len(batches)} 个执行批次")

        for batch_idx, batch in enumerate(batches, 1):
            # Re-evaluated per batch: earlier batches may have completed or failed steps.
            ready_steps = [step for step in batch if self.should_run_step(step) and self.check_dependencies(step)]

            if not ready_steps:
                self.logger.info(f"批次 {batch_idx}/{len(batches)}: 所有步骤已完成,跳过")
                continue

            self.logger.info(f"批次 {batch_idx}/{len(batches)}: 执行步骤 {ready_steps} (共 {len(ready_steps)} 个)")

            if len(ready_steps) == 1:
                # A single step runs inline to avoid thread-pool overhead.
                try:
                    self.run_step(ready_steps[0])
                except Exception as e:
                    self.logger.error(f"执行步骤 {ready_steps[0]} 时出错: {e}", exc_info=True)
                    if not continue_on_error:
                        raise
                continue

            workers = max_workers if max_workers is not None else len(ready_steps)
            with ThreadPoolExecutor(max_workers=workers) as executor:
                future_to_step = {
                    executor.submit(self.run_step, step_name): step_name
                    for step_name in ready_steps
                }
                for future in as_completed(future_to_step):
                    step_name = future_to_step[future]
                    try:
                        future.result()
                    except Exception as e:
                        self.logger.error(f"执行步骤 {step_name} 时出错: {e}", exc_info=True)
                        if not continue_on_error:
                            # Cancel steps not yet started; running ones finish on pool exit.
                            for other in future_to_step:
                                other.cancel()
                            raise
                    else:
                        self.logger.info(f"步骤 {step_name} 执行完成")

    async def run_step_async(self, step_name: str, *args, **kwargs) -> Any:
        """
        Execute a single step from async code.

        Coroutine functions are awaited directly; sync functions run in the
        event loop's default executor so they do not block the loop.

        Args:
            step_name: Registered step name.
            *args, **kwargs: Forwarded to the step function.

        Returns:
            The step's output (or the cached output if the step does not need to run).

        Raises:
            ValueError: The step is not registered.
            RuntimeError: A dependency is not COMPLETED.
            Exception: Whatever the step function raised (the step is marked FAILED first).
        """
        import functools
        import inspect

        if step_name not in self.steps:
            raise ValueError(f"步骤 {step_name} 未注册")

        if not self.should_run_step(step_name):
            self.logger.info(f"步骤 {step_name} 已完成,跳过执行。使用 load_step_output() 获取结果。")
            return self.load_step_output(step_name)

        if not self.check_dependencies(step_name):
            raise RuntimeError(f"步骤 {step_name} 的依赖未满足")

        func = self.steps[step_name]["func"]
        self.set_step_status(step_name, StepStatus.RUNNING)

        try:
            self.logger.info(f"开始执行步骤: {step_name}")

            if inspect.iscoroutinefunction(func):
                output = await func(*args, **kwargs)
            else:
                # run_in_executor() forwards positional arguments only, so bind
                # kwargs with partial (passing **kwargs directly raises TypeError).
                loop = asyncio.get_running_loop()
                output = await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

            self.save_step_output(step_name, output)
            self.set_step_status(step_name, StepStatus.COMPLETED)
            self.logger.info(f"步骤 {step_name} 执行完成")
            return output

        except Exception as e:
            self.set_step_status(step_name, StepStatus.FAILED, error=str(e))
            self.logger.error(f"步骤 {step_name} 执行失败: {e}", exc_info=True)
            raise

    async def run_all_async(
        self,
        step_order: Optional[List[str]] = None,
        continue_on_error: bool = False
    ):
        """
        Execute all steps concurrently with asyncio, level by level.

        Steps are grouped by _topological_sort(); the steps of one level are
        started together with ``asyncio.gather`` after the previous level is done.

        Args:
            step_order: Steps to run (defaults to registration order).
            continue_on_error: If False (default), re-raise the first step error
                of a batch (after the whole batch has finished); if True, log
                errors and continue.

        Example::

            async def main():
                manager = TaskManager(...)
                manager.register_step("step1", async_func1, depends_on=[])
                manager.register_step("step2", async_func2, depends_on=["step1"])
                manager.register_step("step3", async_func3, depends_on=["step2"])
                manager.register_step("step4", async_func4, depends_on=["step2"])
                await manager.run_all_async()  # step3 & step4 run concurrently
        """
        if step_order is None:
            step_order = list(self.steps.keys())

        batches = self._topological_sort(step_order)
        if not batches:
            self.logger.warning("没有可执行的步骤")
            return

        self.logger.info(f"共 {len(batches)} 个执行批次(异步并发执行)")

        for batch_idx, batch in enumerate(batches, 1):
            ready_steps = [step for step in batch if self.should_run_step(step) and self.check_dependencies(step)]

            if not ready_steps:
                self.logger.info(f"批次 {batch_idx}/{len(batches)}: 所有步骤已完成,跳过")
                continue

            self.logger.info(f"批次 {batch_idx}/{len(batches)}: 异步并发执行步骤 {ready_steps} (共 {len(ready_steps)} 个)")

            if len(ready_steps) == 1:
                try:
                    await self.run_step_async(ready_steps[0])
                except Exception as e:
                    self.logger.error(f"执行步骤 {ready_steps[0]} 时出错: {e}", exc_info=True)
                    if not continue_on_error:
                        raise
                continue

            results = await asyncio.gather(
                *(self.run_step_async(step_name) for step_name in ready_steps),
                return_exceptions=True,
            )

            for step_name, result in zip(ready_steps, results):
                # BaseException so a returned CancelledError is not mistaken for success.
                if isinstance(result, BaseException):
                    # Not inside an except block, so pass the exception explicitly
                    # (exc_info=True would log "NoneType: None").
                    self.logger.error(f"执行步骤 {step_name} 时出错: {result}", exc_info=result)
                    if not continue_on_error:
                        raise result
                else:
                    self.logger.info(f"步骤 {step_name} 执行完成")

    def reset_step(self, step_name: str):
        """Reset a step's record to PENDING, delete its cache file and persist the state."""
        with self._state_lock:
            if step_name in self.state["steps"]:
                self.state["steps"][step_name] = self._new_step_record(StepStatus.PENDING)
            try:
                self._get_cache_path(step_name).unlink()
            except FileNotFoundError:
                pass  # nothing cached for this step
            self._save_state()
        self.logger.info(f"步骤 {step_name} 已重置")

    def reset_all(self):
        """Reset every step recorded in the state."""
        with self._state_lock:
            step_names = list(self.state["steps"].keys())
        for step_name in step_names:
            self.reset_step(step_name)

    def get_summary(self) -> Dict:
        """
        Return counts of recorded steps by status.

        Keys: total, completed, failed, pending and progress ("completed/total").
        """
        with self._state_lock:
            statuses = [s.get("status") for s in self.state["steps"].values()]

        total = len(statuses)
        completed = statuses.count(StepStatus.COMPLETED)
        failed = statuses.count(StepStatus.FAILED)
        pending = statuses.count(StepStatus.PENDING)

        return {
            "total": total,
            "completed": completed,
            "failed": failed,
            "pending": pending,
            "progress": f"{completed}/{total}" if total > 0 else "0/0"
        }
|