# config.py (11 KB)
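"""
Configuration management module.

Provides unified configuration handling, with support for loading settings
from configuration files and from environment variables.
"""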
import copy
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from .logger import get_logger
  10. logger = get_logger("taskflow.config")
  11. class Config:
  12. """
  13. 配置管理类
  14. 支持从多种来源加载配置:
  15. 1. 默认配置
  16. 2. 配置文件(JSON)
  17. 3. 环境变量
  18. 配置优先级:环境变量 > 配置文件 > 默认配置
  19. 使用示例:
  20. >>> config = Config()
  21. >>> config.load_from_file("config.json")
  22. >>> value = config.get("task.state_file")
  23. """
  24. def __init__(self, default_config: Optional[Dict] = None):
  25. """
  26. 初始化配置管理器
  27. Args:
  28. default_config: 默认配置字典
  29. """
  30. self._config: Dict[str, Any] = {}
  31. self._default_config = default_config or self._get_default_config()
  32. self._load_defaults()
  33. def _get_default_config(self) -> Dict[str, Any]:
  34. """
  35. 获取默认配置
  36. Returns:
  37. 默认配置字典
  38. """
  39. return {
  40. "task": {
  41. "state_file": "task_state.json",
  42. "cache_dir": "task_cache",
  43. "auto_save": True,
  44. "save_interval": 1, # 每N个步骤保存一次状态
  45. },
  46. "logging": {
  47. "level": "INFO",
  48. "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  49. "date_format": "%Y-%m-%d %H:%M:%S",
  50. "console_output": True,
  51. "file_output": False,
  52. "log_file": None,
  53. },
  54. "run": {
  55. "base_output_dir": "output",
  56. "run_id_format": "run_{timestamp}",
  57. "timestamp_format": "%Y%m%d_%H%M%S",
  58. },
  59. "io": {
  60. "encoding": "utf-8",
  61. "json_indent": 2,
  62. "create_dirs": True,
  63. },
  64. "api": {
  65. "ark": {
  66. "api_key": None,
  67. "base_url": "https://ark.cn-beijing.volces.com",
  68. "model": "doubao-seed-1-6-251015",
  69. "image_model": "doubao-seedream-4-0-250828",
  70. "timeout": 60,
  71. }
  72. }
  73. }
  74. def _load_defaults(self):
  75. """加载默认配置"""
  76. self._config = self._deep_copy(self._default_config)
  77. def _deep_copy(self, source: Dict) -> Dict:
  78. """深拷贝字典"""
  79. return json.loads(json.dumps(source))
  80. def load_from_file(self, config_file: Union[str, Path]) -> bool:
  81. """
  82. 从配置文件加载配置
  83. Args:
  84. config_file: 配置文件路径(JSON格式)
  85. Returns:
  86. 是否成功加载
  87. """
  88. config_path = Path(config_file)
  89. if not config_path.exists():
  90. logger.warning(f"配置文件不存在: {config_file}")
  91. return False
  92. try:
  93. with open(config_path, 'r', encoding='utf-8') as f:
  94. file_config = json.load(f)
  95. self._merge_config(self._config, file_config)
  96. logger.info(f"成功加载配置文件: {config_file}")
  97. return True
  98. except json.JSONDecodeError as e:
  99. logger.error(f"配置文件格式错误: {e}")
  100. return False
  101. except Exception as e:
  102. logger.error(f"加载配置文件失败: {e}")
  103. return False
  104. def load_from_env(self, prefix: str = "TASKFLOW_"):
  105. """
  106. 从环境变量加载配置
  107. 环境变量命名规则:
  108. - TASKFLOW_TASK_STATE_FILE -> task.state_file
  109. - TASKFLOW_LOGGING_LEVEL -> logging.level
  110. Args:
  111. prefix: 环境变量前缀
  112. """
  113. env_config = {}
  114. for key, value in os.environ.items():
  115. if key.startswith(prefix):
  116. # 移除前缀并转换为配置路径
  117. config_key = key[len(prefix):].lower()
  118. # 将下划线分隔的键转换为嵌套字典路径
  119. keys = config_key.split('_')
  120. # 构建嵌套字典
  121. current = env_config
  122. for k in keys[:-1]:
  123. if k not in current:
  124. current[k] = {}
  125. current = current[k]
  126. # 设置值(尝试转换为合适的类型)
  127. final_key = keys[-1]
  128. current[final_key] = self._parse_env_value(value)
  129. if env_config:
  130. self._merge_config(self._config, env_config)
  131. logger.info(f"从环境变量加载了 {len(env_config)} 个配置项")
  132. def _parse_env_value(self, value: str) -> Any:
  133. """
  134. 解析环境变量值,尝试转换为合适的类型
  135. Args:
  136. value: 环境变量值
  137. Returns:
  138. 转换后的值
  139. """
  140. # 尝试转换为布尔值
  141. if value.lower() in ('true', '1', 'yes', 'on'):
  142. return True
  143. if value.lower() in ('false', '0', 'no', 'off'):
  144. return False
  145. # 尝试转换为数字
  146. try:
  147. if '.' in value:
  148. return float(value)
  149. return int(value)
  150. except ValueError:
  151. pass
  152. # 返回字符串
  153. return value
  154. def _merge_config(self, base: Dict, override: Dict):
  155. """
  156. 合并配置(深度合并)
  157. Args:
  158. base: 基础配置(会被修改)
  159. override: 覆盖配置
  160. """
  161. for key, value in override.items():
  162. if key in base and isinstance(base[key], dict) and isinstance(value, dict):
  163. self._merge_config(base[key], value)
  164. else:
  165. base[key] = value
  166. def get(self, key: str, default: Any = None) -> Any:
  167. """
  168. 获取配置值
  169. 支持点号分隔的嵌套键,如 "task.state_file"
  170. Args:
  171. key: 配置键(支持点号分隔的嵌套路径)
  172. default: 默认值
  173. Returns:
  174. 配置值
  175. 示例:
  176. >>> config.get("task.state_file")
  177. >>> config.get("logging.level", "DEBUG")
  178. """
  179. keys = key.split('.')
  180. current = self._config
  181. try:
  182. for k in keys:
  183. current = current[k]
  184. return current
  185. except (KeyError, TypeError):
  186. return default
  187. def set(self, key: str, value: Any):
  188. """
  189. 设置配置值
  190. 支持点号分隔的嵌套键,如 "task.state_file"
  191. Args:
  192. key: 配置键(支持点号分隔的嵌套路径)
  193. value: 配置值
  194. 示例:
  195. >>> config.set("task.state_file", "custom_state.json")
  196. """
  197. keys = key.split('.')
  198. current = self._config
  199. # 创建嵌套字典结构
  200. for k in keys[:-1]:
  201. if k not in current:
  202. current[k] = {}
  203. elif not isinstance(current[k], dict):
  204. current[k] = {}
  205. current = current[k]
  206. # 设置值
  207. current[keys[-1]] = value
  208. def has(self, key: str) -> bool:
  209. """
  210. 检查配置键是否存在
  211. Args:
  212. key: 配置键(支持点号分隔的嵌套路径)
  213. Returns:
  214. 是否存在
  215. """
  216. keys = key.split('.')
  217. current = self._config
  218. try:
  219. for k in keys:
  220. current = current[k]
  221. return True
  222. except (KeyError, TypeError):
  223. return False
  224. def get_section(self, section: str) -> Dict[str, Any]:
  225. """
  226. 获取配置节
  227. Args:
  228. section: 配置节名称(如 "task", "logging")
  229. Returns:
  230. 配置节字典
  231. """
  232. return self._config.get(section, {})
  233. def to_dict(self) -> Dict[str, Any]:
  234. """
  235. 获取完整配置字典
  236. Returns:
  237. 配置字典的深拷贝
  238. """
  239. return self._deep_copy(self._config)
  240. def save_to_file(self, config_file: Union[str, Path], indent: int = 2):
  241. """
  242. 保存配置到文件
  243. Args:
  244. config_file: 配置文件路径
  245. indent: JSON缩进空格数
  246. """
  247. config_path = Path(config_file)
  248. config_path.parent.mkdir(parents=True, exist_ok=True)
  249. try:
  250. with open(config_path, 'w', encoding='utf-8') as f:
  251. json.dump(self._config, f, ensure_ascii=False, indent=indent)
  252. logger.info(f"配置已保存到: {config_file}")
  253. except Exception as e:
  254. logger.error(f"保存配置失败: {e}")
  255. def reset(self):
  256. """重置为默认配置"""
  257. self._load_defaults()
  258. logger.info("配置已重置为默认值")
  259. def validate(self):
  260. """
  261. 验证配置的有效性
  262. Returns:
  263. (是否有效, 错误列表)
  264. """
  265. errors = []
  266. # 验证task配置
  267. if not isinstance(self.get("task.state_file"), str):
  268. errors.append("task.state_file 必须是字符串")
  269. if not isinstance(self.get("task.cache_dir"), str):
  270. errors.append("task.cache_dir 必须是字符串")
  271. # 验证logging配置
  272. valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
  273. log_level = self.get("logging.level")
  274. if log_level not in valid_log_levels:
  275. errors.append(f"logging.level 必须是以下之一: {valid_log_levels}")
  276. # 验证run配置
  277. if not isinstance(self.get("run.base_output_dir"), str):
  278. errors.append("run.base_output_dir 必须是字符串")
  279. return len(errors) == 0, errors
  280. # 全局配置实例
  281. _global_config: Optional[Config] = None
  282. def get_config(config_file: Optional[Union[str, Path]] = None) -> Config:
  283. """
  284. 获取全局配置实例
  285. 如果全局配置不存在,会创建一个新的配置实例。
  286. 如果提供了配置文件路径,会尝试加载。
  287. Args:
  288. config_file: 可选的配置文件路径
  289. Returns:
  290. 配置实例
  291. """
  292. global _global_config
  293. if _global_config is None:
  294. _global_config = Config()
  295. # 尝试从环境变量加载
  296. _global_config.load_from_env()
  297. # 如果提供了配置文件路径,尝试加载
  298. if config_file:
  299. _global_config.load_from_file(config_file)
  300. else:
  301. # 尝试加载默认配置文件
  302. default_config_files = [
  303. "./config/taskflow_config.example.json",
  304. "./config/taskflow_config.json",
  305. "./config/.taskflow_config.json"
  306. ]
  307. for config_file_path in default_config_files:
  308. if Path(config_file_path).exists():
  309. _global_config.load_from_file(config_file_path)
  310. break
  311. return _global_config
  312. def reset_config():
  313. """重置全局配置"""
  314. global _global_config
  315. _global_config = None