"""
Configuration management module.

Provides a single place to manage configuration, with support for loading
values from a JSON configuration file and from environment variables.
"""

import copy
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from .logger import get_logger

logger = get_logger("taskflow.config")

# Sentinel that distinguishes "key not found" from a stored value of None.
_MISSING = object()


class Config:
    """
    Configuration manager.

    Configuration can come from several sources:
        1. Built-in defaults
        2. A configuration file (JSON)
        3. Environment variables

    Intended priority: environment variables > configuration file > defaults.
    Each ``load_*`` call merges on top of the current state, so the caller
    obtains that priority by loading the file first and the environment last.

    Example:
        >>> config = Config()
        >>> config.load_from_file("config.json")
        >>> value = config.get("task.state_file")
    """

    def __init__(self, default_config: Optional[Dict] = None):
        """
        Initialise the configuration manager.

        Args:
            default_config: Dictionary of default values. When omitted (or
                empty) the built-in defaults from ``_get_default_config``
                are used.
        """
        self._config: Dict[str, Any] = {}
        self._default_config = default_config or self._get_default_config()
        self._load_defaults()

    def _get_default_config(self) -> Dict[str, Any]:
        """
        Build the built-in default configuration.

        Returns:
            A freshly constructed dictionary of defaults.
        """
        return {
            "task": {
                "state_file": "task_state.json",
                "cache_dir": "task_cache",
                "auto_save": True,
                "save_interval": 1,  # persist state every N steps
            },
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                "date_format": "%Y-%m-%d %H:%M:%S",
                "console_output": True,
                "file_output": False,
                "log_file": None,
            },
            "run": {
                "base_output_dir": "output",
                "run_id_format": "run_{timestamp}",
                "timestamp_format": "%Y%m%d_%H%M%S",
            },
            "io": {
                "encoding": "utf-8",
                "json_indent": 2,
                "create_dirs": True,
            },
            "api": {
                "ark": {
                    "api_key": None,
                    "base_url": "https://ark.cn-beijing.volces.com",
                    "model": "doubao-seed-1-6-251015",
                    "image_model": "doubao-seedream-4-0-250828",
                    "timeout": 60,
                }
            },
        }

    def _load_defaults(self):
        """Replace the active configuration with a copy of the defaults."""
        self._config = self._deep_copy(self._default_config)

    def _deep_copy(self, source: Dict) -> Dict:
        """
        Return a deep copy of ``source``.

        ``copy.deepcopy`` is used rather than a JSON round-trip so that
        values which are not JSON-serialisable do not raise, and tuples and
        non-string keys keep their types.
        """
        return copy.deepcopy(source)

    def load_from_file(self, config_file: Union[str, Path]) -> bool:
        """
        Merge settings from a JSON configuration file into the active config.

        Args:
            config_file: Path to the configuration file (JSON).

        Returns:
            True if the file was read and merged; False if it does not
            exist, cannot be read or parsed, or its top level is not a JSON
            object. Failures are logged, never raised.
        """
        config_path = Path(config_file)
        if not config_path.exists():
            logger.warning(f"配置文件不存在: {config_file}")
            return False

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                file_config = json.load(f)
        except json.JSONDecodeError as e:
            logger.error(f"配置文件格式错误: {e}")
            return False
        except Exception as e:
            # Boundary handler: this method reports failure via its return
            # value instead of raising.
            logger.error(f"加载配置文件失败: {e}")
            return False

        # Only a JSON object can be merged; reject arrays/scalars explicitly
        # instead of relying on an incidental AttributeError.
        if not isinstance(file_config, dict):
            logger.error(
                f"配置文件格式错误: 顶层必须是JSON对象,实际为 {type(file_config).__name__}"
            )
            return False

        self._merge_config(self._config, file_config)
        logger.info(f"成功加载配置文件: {config_file}")
        return True

    def load_from_env(self, prefix: str = "TASKFLOW_"):
        """
        Merge settings from environment variables into the active config.

        Naming rule (after the prefix is removed and the name lower-cased):
            - TASKFLOW_TASK_STATE_FILE -> task.state_file
            - TASKFLOW_LOGGING_LEVEL   -> logging.level

        Because ``_`` is both the level separator and a legal character in
        key names, segments are resolved against the keys that already exist
        in the active configuration (see ``_resolve_env_path``). Names that
        match nothing fall back to one nesting level per ``_``.

        Args:
            prefix: Only variables whose name starts with this are used.
        """
        env_config: Dict[str, Any] = {}
        applied = 0

        for name, raw_value in os.environ.items():
            if not name.startswith(prefix):
                continue

            config_key = name[len(prefix):].lower()
            if not config_key:
                # The variable name is exactly the prefix: no key to set.
                continue

            path = self._resolve_env_path(config_key.split('_'))

            # Walk/create the nested dictionaries leading to the final key.
            node = env_config
            for part in path[:-1]:
                child = node.get(part)
                if not isinstance(child, dict):
                    # A scalar set by another variable sits here; replace it
                    # so nesting cannot raise TypeError.
                    child = {}
                    node[part] = child
                node = child

            node[path[-1]] = self._parse_env_value(raw_value)
            applied += 1

        if env_config:
            self._merge_config(self._config, env_config)
            logger.info(f"从环境变量加载了 {applied} 个配置项")

    def _resolve_env_path(self, segments: List[str]) -> List[str]:
        """
        Map ``_``-separated name segments onto a config key path.

        At each level the longest run of remaining segments that, joined
        with ``_``, names an existing key is consumed as one path element.
        When nothing matches, a single segment is consumed and the rest of
        the path is treated as unknown (one level per segment).

        Args:
            segments: Lower-cased variable name split on ``_``.

        Returns:
            The list of keys forming the path into the configuration.
        """
        path: List[str] = []
        node: Any = self._config
        index = 0
        total = len(segments)

        while index < total:
            matched = False
            if isinstance(node, dict):
                # Try the longest candidate first so "state_file" wins over
                # "state" when both could apply.
                for end in range(total, index, -1):
                    candidate = '_'.join(segments[index:end])
                    if candidate in node:
                        path.append(candidate)
                        node = node[candidate]
                        index = end
                        matched = True
                        break
            if not matched:
                path.append(segments[index])
                node = None  # below an unknown key nothing can match
                index += 1

        return path

    def _parse_env_value(self, value: str) -> Any:
        """
        Convert an environment variable string to a more specific type.

        Args:
            value: Raw environment variable value.

        Returns:
            ``True``/``False`` for the recognised boolean words, otherwise a
            float (when the text contains ``.``) or an int if it parses, and
            otherwise the original string.
        """
        # NOTE(review): "1" and "0" are treated as booleans here, so they
        # never reach the integer branch below.
        lowered = value.lower()
        if lowered in ('true', '1', 'yes', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'off'):
            return False

        try:
            return float(value) if '.' in value else int(value)
        except ValueError:
            return value

    def _merge_config(self, base: Dict, override: Dict):
        """
        Deep-merge ``override`` into ``base``.

        Nested dictionaries present on both sides are merged recursively;
        any other value in ``override`` replaces the one in ``base``.

        Args:
            base: Dictionary to update (modified in place).
            override: Dictionary whose values take precedence.
        """
        for key, value in override.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._merge_config(base[key], value)
            else:
                base[key] = value

    def _lookup(self, key: str) -> Any:
        """
        Follow a dot-separated key path through the active configuration.

        Returns:
            The stored value, or the module-level ``_MISSING`` sentinel when
            any step of the path cannot be followed.
        """
        current: Any = self._config
        try:
            for part in key.split('.'):
                current = current[part]
        except (KeyError, TypeError):
            return _MISSING
        return current

    def get(self, key: str, default: Any = None) -> Any:
        """
        Return the value stored under ``key``.

        Args:
            key: Configuration key; dots separate nesting levels, e.g.
                "task.state_file".
            default: Value returned when the path does not exist.

        Returns:
            The stored value (which may itself be ``None``), or ``default``
            when the path cannot be resolved.

        Example:
            >>> config.get("task.state_file")
            >>> config.get("logging.level", "DEBUG")
        """
        found = self._lookup(key)
        return default if found is _MISSING else found

    def set(self, key: str, value: Any):
        """
        Store ``value`` under ``key``, creating intermediate levels as needed.

        An intermediate entry that exists but is not a dictionary is
        replaced by an empty dictionary.

        Args:
            key: Configuration key; dots separate nesting levels.
            value: Value to store.

        Example:
            >>> config.set("task.state_file", "custom_state.json")
        """
        *parents, leaf = key.split('.')
        current = self._config
        for part in parents:
            if not isinstance(current.get(part), dict):
                current[part] = {}
            current = current[part]
        current[leaf] = value

    def has(self, key: str) -> bool:
        """
        Report whether ``key`` resolves to a stored value.

        Args:
            key: Configuration key; dots separate nesting levels.

        Returns:
            True if every step of the path exists, otherwise False.
        """
        return self._lookup(key) is not _MISSING

    def get_section(self, section: str) -> Dict[str, Any]:
        """
        Return a top-level configuration section.

        Args:
            section: Section name, e.g. "task" or "logging".

        Returns:
            The object stored under that top-level key, or an empty dict
            when the key is absent.
        """
        # NOTE(review): this hands out the live internal object, not a copy
        # (unlike to_dict), so caller mutations change the active config.
        return self._config.get(section, {})

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the whole configuration.

        Returns:
            A deep copy of the active configuration dictionary.
        """
        return self._deep_copy(self._config)

    def save_to_file(self, config_file: Union[str, Path], indent: int = 2) -> bool:
        """
        Write the active configuration to a JSON file.

        Missing parent directories are created. Failures while creating the
        directory or writing the file are logged, not raised.

        Args:
            config_file: Destination path.
            indent: Number of spaces used for JSON indentation.

        Returns:
            True if the file was written, False otherwise.
        """
        config_path = Path(config_file)
        try:
            config_path.parent.mkdir(parents=True, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self._config, f, ensure_ascii=False, indent=indent)
        except Exception as e:
            logger.error(f"保存配置失败: {e}")
            return False

        logger.info(f"配置已保存到: {config_file}")
        return True

    def reset(self):
        """Discard all loaded settings and restore the defaults."""
        self._load_defaults()
        logger.info("配置已重置为默认值")

    def validate(self) -> Tuple[bool, List[str]]:
        """
        Check a few key settings for type and value validity.

        Returns:
            A tuple ``(is_valid, errors)`` where ``errors`` is the list of
            human-readable problems found (empty when valid).
        """
        errors: List[str] = []

        # task section
        if not isinstance(self.get("task.state_file"), str):
            errors.append("task.state_file 必须是字符串")
        if not isinstance(self.get("task.cache_dir"), str):
            errors.append("task.cache_dir 必须是字符串")

        # logging section
        valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        log_level = self.get("logging.level")
        if log_level not in valid_log_levels:
            errors.append(f"logging.level 必须是以下之一: {valid_log_levels}")

        # run section
        if not isinstance(self.get("run.base_output_dir"), str):
            errors.append("run.base_output_dir 必须是字符串")

        return len(errors) == 0, errors


# Process-wide configuration instance, created lazily by get_config().
_global_config: Optional[Config] = None


def get_config(config_file: Optional[Union[str, Path]] = None) -> Config:
    """
    Return the process-wide configuration instance, creating it if needed.

    On first use a ``Config`` is created and populated in priority order:
    defaults, then a configuration file, then environment variables, so
    that environment variables override file values.

    Args:
        config_file: Optional path to a configuration file. It is only
            consulted when the global instance is first created; when
            omitted, a list of default locations is probed and the first
            existing file is loaded.

    Returns:
        The shared ``Config`` instance.
    """
    global _global_config

    if _global_config is None:
        _global_config = Config()

        if config_file:
            _global_config.load_from_file(config_file)
        else:
            # NOTE(review): the example file is probed first, so it wins
            # over a real config file when both exist — confirm intended.
            default_config_files = [
                "./config/taskflow_config.example.json",
                "./config/taskflow_config.json",
                "./config/.taskflow_config.json"
            ]
            for config_file_path in default_config_files:
                if Path(config_file_path).exists():
                    _global_config.load_from_file(config_file_path)
                    break

        # Environment variables are applied last so they take precedence
        # over file values.
        _global_config.load_from_env()

    return _global_config


def reset_config():
    """Drop the global instance so the next get_config() call rebuilds it."""
    global _global_config
    _global_config = None