| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
- """
- 配置管理模块
- 提供统一的配置管理功能,支持从配置文件、环境变量加载配置
- """
import copy
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from .logger import get_logger
# Logger shared by all configuration code in this module.
logger = get_logger("taskflow.config")
class Config:
    """
    Configuration manager.

    Configuration can come from three sources:
    1. defaults (built-in, or passed to the constructor)
    2. a JSON configuration file (``load_from_file``)
    3. environment variables (``load_from_env``)

    Every load is deep-merged over the current state, so the intended
    precedence -- environment > file > defaults -- holds when the sources
    are loaded in that order (file first, environment last).

    Example:
        >>> config = Config()
        >>> config.load_from_file("config.json")
        >>> value = config.get("task.state_file")
    """

    # Sentinel that distinguishes "key absent" from "key present but None".
    _MISSING = object()

    # Case-insensitive spellings accepted as booleans in environment values.
    _TRUE_WORDS = ("true", "yes", "on")
    _FALSE_WORDS = ("false", "no", "off")

    def __init__(self, default_config: Optional[Dict] = None):
        """
        Create a manager populated with the defaults.

        Args:
            default_config: Defaults to use instead of the built-in ones.
                ``None`` or an empty dict selects the built-in defaults. The
                dict is deep-copied, so later changes made by the caller do
                not leak into this instance or into ``reset()``.
        """
        self._config: Dict[str, Any] = {}
        self._default_config = self._deep_copy(
            default_config or self._get_default_config()
        )
        self._load_defaults()

    def _get_default_config(self) -> Dict[str, Any]:
        """
        Return the built-in default configuration.

        Returns:
            A freshly built nested dict of default settings.
        """
        return {
            "task": {
                "state_file": "task_state.json",
                "cache_dir": "task_cache",
                "auto_save": True,
                "save_interval": 1,  # save the state every N steps
            },
            "logging": {
                "level": "INFO",
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                "date_format": "%Y-%m-%d %H:%M:%S",
                "console_output": True,
                "file_output": False,
                "log_file": None,
            },
            "run": {
                "base_output_dir": "output",
                "run_id_format": "run_{timestamp}",
                "timestamp_format": "%Y%m%d_%H%M%S",
            },
            "io": {
                "encoding": "utf-8",
                "json_indent": 2,
                "create_dirs": True,
            },
            "api": {
                "ark": {
                    "api_key": None,
                    "base_url": "https://ark.cn-beijing.volces.com",
                    "model": "doubao-seed-1-6-251015",
                    "image_model": "doubao-seedream-4-0-250828",
                    "timeout": 60,
                }
            }
        }

    def _load_defaults(self):
        """Replace the current configuration with a copy of the defaults."""
        self._config = self._deep_copy(self._default_config)

    def _deep_copy(self, source: Dict) -> Dict:
        """
        Return a deep copy of ``source``.

        ``copy.deepcopy`` is used instead of a JSON round-trip so that values
        which are not JSON-serialisable (e.g. ``Path``) and non-string keys
        are preserved rather than raising or being rewritten.
        """
        return copy.deepcopy(source)

    def load_from_file(self, config_file: Union[str, Path]) -> bool:
        """
        Deep-merge a JSON configuration file into the current configuration.

        Args:
            config_file: Path to a UTF-8 JSON file whose top level is an object.

        Returns:
            True if the file was read and merged, False otherwise (the reason
            is logged).
        """
        config_path = Path(config_file)

        if not config_path.exists():
            logger.warning(f"配置文件不存在: {config_file}")
            return False

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                file_config = json.load(f)
        except json.JSONDecodeError as e:
            logger.error(f"配置文件格式错误: {e}")
            return False
        except (OSError, ValueError) as e:
            # Unreadable file, directory path, invalid UTF-8, ...
            logger.error(f"加载配置文件失败: {e}")
            return False

        # A JSON array/scalar at the top level cannot be merged into a dict.
        if not isinstance(file_config, dict):
            logger.error(
                f"配置文件格式错误: 顶层必须是JSON对象, 实际为 {type(file_config).__name__}"
            )
            return False

        self._merge_config(self._config, file_config)
        logger.info(f"成功加载配置文件: {config_file}")
        return True

    def load_from_env(self, prefix: str = "TASKFLOW_"):
        """
        Overlay configuration from environment variables.

        The part of the variable name after ``prefix`` is lower-cased and
        split on ``_``. The pieces are matched against the existing
        configuration, so keys that themselves contain underscores resolve
        correctly:

        - TASKFLOW_TASK_STATE_FILE -> task.state_file
        - TASKFLOW_LOGGING_LEVEL   -> logging.level
        - TASKFLOW_API_ARK_API_KEY -> api.ark.api_key

        Names that match no existing key fall back to one nesting level per
        piece (TASKFLOW_FOO_BAR -> foo.bar). Values are converted according
        to the type of the value currently stored at the target key.

        Args:
            prefix: Only variables whose name starts with this are considered.
        """
        env_config, applied = self._collect_env_overrides(os.environ, prefix)
        if applied:
            self._merge_config(self._config, env_config)
            logger.info(f"从环境变量加载了 {applied} 个配置项")

    def _collect_env_overrides(self, environ, prefix: str) -> tuple:
        """
        Build a nested override dict from a name -> value mapping.

        Args:
            environ: Mapping of variable names to string values
                (``os.environ`` in normal use).
            prefix: Only names starting with this prefix are used.

        Returns:
            ``(overrides, count)``: the nested override dict and the number
            of variables that contributed to it.
        """
        overrides: Dict[str, Any] = {}
        count = 0
        # Sorted for a deterministic result; a more specific name such as
        # PREFIX_A_B is then applied after PREFIX_A and wins the conflict.
        for name, raw in sorted(environ.items()):
            if not name.startswith(prefix):
                continue
            # Drop empty pieces produced by a bare prefix or a doubled "_".
            parts = [piece for piece in name[len(prefix):].lower().split('_') if piece]
            if not parts:
                continue
            path = self._resolve_env_path(parts)
            current = self._lookup(path, None)
            self._set_by_path(overrides, path, self._coerce_env_value(raw, current))
            count += 1
        return overrides, count

    def _resolve_env_path(self, parts: list) -> list:
        """
        Map lower-cased, ``_``-separated name pieces onto a config key path.

        At each level the longest run of pieces that, joined with ``_``,
        names an existing key is chosen. A run that consumes all remaining
        pieces is accepted as-is; a shorter one is accepted only if its value
        is a dict to descend into. Once nothing matches, every remaining piece
        becomes its own level.

        Example: ``["task", "state", "file"]`` -> ``["task", "state_file"]``.
        """
        path = []
        node: Any = self._config
        index = 0
        while index < len(parts):
            remaining = len(parts) - index
            match = None
            if isinstance(node, dict):
                for size in range(remaining, 0, -1):
                    candidate = '_'.join(parts[index:index + size])
                    if candidate in node and (
                        size == remaining or isinstance(node[candidate], dict)
                    ):
                        match = (candidate, size)
                        break
            if match is None:
                path.extend(parts[index:])
                break
            key, size = match
            path.append(key)
            node = node[key]
            index += size
        return path

    def _coerce_env_value(self, raw: str, current: Any) -> Any:
        """
        Convert an environment value, guided by the value it will replace.

        - bool targets accept ``1``/``0`` as well as true/false/yes/no/on/off;
        - str targets keep the raw text (so a file name ``"1"`` stays a string);
        - int/float targets are parsed as such when possible;
        - otherwise (missing key, ``None``, dict, unparseable text) the generic
          ``_parse_env_value`` is used.
        """
        lowered = raw.strip().lower()
        if isinstance(current, bool):
            if lowered in self._TRUE_WORDS or lowered == "1":
                return True
            if lowered in self._FALSE_WORDS or lowered == "0":
                return False
        elif isinstance(current, str):
            return raw
        elif isinstance(current, int):
            try:
                return int(raw)
            except ValueError:
                pass
        elif isinstance(current, float):
            try:
                return float(raw)
            except ValueError:
                pass
        return self._parse_env_value(raw)

    def _parse_env_value(self, value: str) -> Any:
        """
        Best-effort conversion of an environment variable value.

        ``true/yes/on`` and ``false/no/off`` (case-insensitive) become
        booleans. Text containing ``.`` is tried as a float, anything else as
        an int, and text that parses as neither is returned unchanged.
        ``"1"``/``"0"`` are deliberately integers, not booleans, so numeric
        settings are not turned into True/False.

        Args:
            value: Raw environment variable value.

        Returns:
            The converted value.
        """
        lowered = value.lower()
        if lowered in self._TRUE_WORDS:
            return True
        if lowered in self._FALSE_WORDS:
            return False
        try:
            return float(value) if '.' in value else int(value)
        except ValueError:
            return value

    def _merge_config(self, base: Dict, override: Dict):
        """
        Deep-merge ``override`` into ``base`` in place.

        Nested dicts present on both sides are merged recursively; any other
        value from ``override`` replaces the one in ``base``.

        Args:
            base: Configuration that is modified.
            override: Values that take precedence.
        """
        for key, value in override.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._merge_config(base[key], value)
            else:
                base[key] = value

    def _lookup(self, path: list, default: Any) -> Any:
        """Return the value at ``path`` in the configuration, or ``default``."""
        node: Any = self._config
        try:
            for segment in path:
                node = node[segment]
        except (KeyError, TypeError):
            return default
        return node

    @staticmethod
    def _set_by_path(target: Dict[str, Any], path: list, value: Any) -> None:
        """
        Store ``value`` at ``path`` inside ``target``.

        Intermediate entries that are missing or are not dicts are replaced
        by new empty dicts.
        """
        node = target
        for segment in path[:-1]:
            child = node.get(segment)
            if not isinstance(child, dict):
                child = {}
                node[segment] = child
            node = child
        node[path[-1]] = value

    def get(self, key: str, default: Any = None) -> Any:
        """
        Return a configuration value.

        Args:
            key: Dot-separated path such as ``"task.state_file"``.
            default: Returned when the path does not exist.

        Returns:
            The stored value, or ``default``.

        Example:
            >>> config.get("task.state_file")
            >>> config.get("logging.level", "DEBUG")
        """
        return self._lookup(key.split('.'), default)

    def set(self, key: str, value: Any):
        """
        Set a configuration value, creating intermediate sections as needed.

        Existing non-dict values along the path are replaced by sections.

        Args:
            key: Dot-separated path such as ``"task.state_file"``.
            value: Value to store.

        Example:
            >>> config.set("task.state_file", "custom_state.json")
        """
        self._set_by_path(self._config, key.split('.'), value)

    def has(self, key: str) -> bool:
        """
        Report whether a configuration path exists (even if its value is None).

        Args:
            key: Dot-separated path.

        Returns:
            True if the path exists.
        """
        return self._lookup(key.split('.'), self._MISSING) is not self._MISSING

    def get_section(self, section: str) -> Dict[str, Any]:
        """
        Return a top-level section such as ``"task"`` or ``"logging"``.

        The stored section object itself is returned (not a copy), so
        mutating it changes the configuration. A missing section yields a new
        empty dict that is not attached to the configuration.

        Args:
            section: Top-level key.

        Returns:
            The section, or ``{}``.
        """
        return self._config.get(section, {})

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the full configuration.

        Returns:
            A deep copy, safe for the caller to modify.
        """
        return self._deep_copy(self._config)

    def save_to_file(self, config_file: Union[str, Path], indent: int = 2) -> bool:
        """
        Write the current configuration to a UTF-8 JSON file.

        The parent directory is created if needed. The configuration is
        serialised before the file is opened, so a non-serialisable value no
        longer leaves a truncated file behind. All failures (directory
        creation, serialisation, writing) are logged rather than raised.

        Args:
            config_file: Destination path.
            indent: Number of spaces used for JSON indentation.

        Returns:
            True if the file was written, False if an error was logged.
        """
        config_path = Path(config_file)
        try:
            payload = json.dumps(self._config, ensure_ascii=False, indent=indent)
            config_path.parent.mkdir(parents=True, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"保存配置失败: {e}")
            return False
        logger.info(f"配置已保存到: {config_file}")
        return True

    def reset(self):
        """Discard all changes and restore the defaults."""
        self._load_defaults()
        logger.info("配置已重置为默认值")

    def validate(self):
        """
        Check a few settings for obviously invalid values.

        Returns:
            A ``(is_valid, errors)`` tuple; ``is_valid`` is True when the
            ``errors`` list of messages is empty.
        """
        errors = []

        # task section
        if not isinstance(self.get("task.state_file"), str):
            errors.append("task.state_file 必须是字符串")

        if not isinstance(self.get("task.cache_dir"), str):
            errors.append("task.cache_dir 必须是字符串")

        # logging section
        valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        log_level = self.get("logging.level")
        if log_level not in valid_log_levels:
            errors.append(f"logging.level 必须是以下之一: {valid_log_levels}")

        # run section
        if not isinstance(self.get("run.base_output_dir"), str):
            errors.append("run.base_output_dir 必须是字符串")

        return len(errors) == 0, errors
# Process-wide Config singleton: created lazily by get_config() and
# cleared by reset_config().
_global_config: Optional[Config] = None
def get_config(config_file: Optional[Union[str, Path]] = None) -> Config:
    """
    Return the process-wide Config instance, creating it on first use.

    On first use the instance is built from the defaults, then a JSON
    configuration file, then environment variables. This order gives the
    documented precedence: environment > file > defaults.

    The file is ``config_file`` when given. Otherwise it is the first existing
    entry of the default search list. The real configuration files are
    searched before ``taskflow_config.example.json``, so a checked-in example
    does not shadow the user's actual configuration.

    Args:
        config_file: Optional path to a JSON configuration file. It is only
            used when the global instance is created; call ``reset_config()``
            first to rebuild with a different file.

    Returns:
        The shared Config instance.
    """
    global _global_config

    if _global_config is None:
        config = Config()

        if config_file:
            config.load_from_file(config_file)
        else:
            # Real configuration first; the example file is only a fallback.
            default_config_files = (
                "./config/taskflow_config.json",
                "./config/.taskflow_config.json",
                "./config/taskflow_config.example.json",
            )
            for candidate in default_config_files:
                if Path(candidate).exists():
                    config.load_from_file(candidate)
                    break

        # Environment is applied last so it overrides file values.
        config.load_from_env()

        # Publish only a fully loaded instance.
        _global_config = config

    return _global_config
def reset_config() -> None:
    """
    Drop the shared Config instance.

    The next get_config() call therefore builds and loads a fresh one.
    """
    global _global_config
    _global_config = None
|