import os import importlib.util from typing import Dict, Any, Optional import yaml from utils.logger_config import setup_logger logger = setup_logger(__name__) class ConfigManager: """配置管理器,用于加载和管理提示词配置""" def __init__(self, config_path: Optional[str] = None): """ 初始化配置管理器 Args: config_path: 配置文件路径,如果为None则使用默认路径 """ self.config_path = config_path or os.path.join( os.path.dirname(os.path.dirname(__file__)), "config", "prompts.py" ) self.config: Dict[str, Any] = {} self.load_config() def load_config(self) -> None: """加载配置文件""" try: # 检查是否为Python配置文件 if self.config_path.endswith('.py'): self._load_py_config() else: self._load_yaml_config() except Exception as e: logger.error(f"Failed to load config: {str(e)}") raise def _load_yaml_config(self) -> None: """加载YAML配置文件""" if not os.path.exists(self.config_path): logger.error(f"Config file not found: {self.config_path}") raise FileNotFoundError(f"Config file not found: {self.config_path}") with open(self.config_path, 'r', encoding='utf-8') as f: self.config = yaml.safe_load(f) logger.info(f"Successfully loaded YAML config from {self.config_path}") def _load_py_config(self) -> None: """从Python文件加载配置""" if not os.path.exists(self.config_path): logger.error(f"Config file not found: {self.config_path}") raise FileNotFoundError(f"Config file not found: {self.config_path}") # 使用importlib加载Python模块 spec = importlib.util.spec_from_file_location("config_module", self.config_path) if spec is None: raise ImportError(f"Could not load spec from {self.config_path}") module = importlib.util.module_from_spec(spec) if spec.loader is not None: spec.loader.exec_module(module) # 解析Python配置文件中的分层结构 py_config = self._parse_py_config(module) # 构建与YAML配置相同的结构 self.config = py_config logger.info(f"Successfully loaded Python config from {self.config_path}") def _parse_py_config(self, module) -> Dict[str, Any]: """ 解析Python配置文件中的分层结构 Args: module: 加载的Python模块 Returns: Dict: 解析后的配置字典 """ # 初始化配置结构 config = { "defaults": { "video": {}, "image": {}, "text": {} }, "scenarios": { "news": { "video": {}, "image": {}, "text": {} }, "entertainment": { "video": {}, "image": {}, "text": {} }, "academic": { "video": {}, "image": {}, "text": {} } } } # 查找模块中的所有大写变量并分类 for attr_name in dir(module): if attr_name.isupper() and not attr_name.startswith('_'): # 只获取大写的变量名 attr_value = getattr(module, attr_name) if isinstance(attr_value, str): # 根据变量名前缀分类到不同的媒体类型和场景 self._categorize_config_item(config, attr_name, attr_value) return config def _categorize_config_item(self, config: Dict, attr_name: str, attr_value: str) -> None: """ 根据变量名将配置项分类到适当的媒体类型和场景中 Args: config: 配置字典 attr_name: 变量名 attr_value: 变量值 """ # 转换为小写以便比较 name_lower = attr_name.lower() # 确定场景 scenario = None if 'news' in name_lower: scenario = 'news' elif 'entertainment' in name_lower: scenario = 'entertainment' elif 'academic' in name_lower: scenario = 'academic' # 确定媒体类型 media_type = 'text' # 默认为text if 'video' in name_lower: media_type = 'video' elif 'image' in name_lower: media_type = 'image' # 确定提示词类型(去除前缀后的部分) prompt_type = self._extract_prompt_type(attr_name, scenario, media_type) # 将配置项放入相应的位置 if scenario: config['scenarios'][scenario][media_type][prompt_type] = attr_value else: config['defaults'][media_type][prompt_type] = attr_value def _extract_prompt_type(self, attr_name: str, scenario: Optional[str], media_type: str) -> str: """ 从变量名中提取提示词类型 Args: attr_name: 变量名 scenario: 场景类型 media_type: 媒体类型 Returns: str: 提示词类型 """ # 移除场景前缀 name = attr_name if scenario: scenario_prefix = scenario.upper() if name.startswith(scenario_prefix): name = name[len(scenario_prefix):] # 移除媒体类型前缀 media_prefix = media_type.upper() if name.startswith(media_prefix): name = name[len(media_prefix):] # 移除下划线分隔符 if name.startswith('_'): name = name[1:] # 转换为小写作为提示词类型 return name.lower() if name else 'default' def get_prompt(self, media_type: str, prompt_type: str = "caption", scenario: Optional[str] = None) -> str: """ 获取指定类型的提示词 Args: media_type: 媒体类型 ('video' 或 'image' 或 'text') prompt_type: 提示词类型 (如 'caption', 'scene' 等) scenario: 场景类型 (如 'news', 'entertainment' 等) Returns: str: 提示词 Raises: KeyError: 指定的配置不存在 """ try: # 如果指定了场景,优先使用场景特定配置 if scenario and scenario in self.config.get("scenarios", {}): scenario_config = self.config["scenarios"][scenario] if media_type in scenario_config and prompt_type in scenario_config[media_type]: result = scenario_config[media_type][prompt_type] # 如果是从Python文件加载的,提取content字段 if isinstance(result, dict) and "content" in result: return result["content"] return str(result) # 回退到默认配置 if prompt_type in self.config["defaults"][media_type]: prompt_data = self.config["defaults"][media_type][prompt_type] # 如果是从Python文件加载的,提取content字段 if isinstance(prompt_data, dict) and "content" in prompt_data: return prompt_data["content"] return str(prompt_data) raise KeyError(f"Prompt not found for {media_type}/{prompt_type}") except KeyError as e: logger.error(f"Failed to get prompt: {str(e)}") # 返回基础提示词作为后备 return ("视频里有什么?" if media_type == "video" else "请描述图片内容") def get_all_prompts(self, media_type: str, scenario: Optional[str] = None) -> Dict[str, str]: """ 获取指定媒体类型的所有提示词 Args: media_type: 媒体类型 ('video' 或 'image' 或 'text') scenario: 场景类型 (如 'news', 'entertainment' 等) Returns: Dict[str, str]: 提示词类型到提示词的映射 """ prompts = {} # 获取默认配置 if media_type in self.config["defaults"]: for key, value in self.config["defaults"][media_type].items(): # 如果是从Python文件加载的,提取content字段 if isinstance(value, dict) and "content" in value: prompts[key] = value["content"] else: prompts[key] = str(value) # 如果指定了场景,添加或覆盖场景特定配置 if scenario and scenario in self.config.get("scenarios", {}): scenario_config = self.config["scenarios"][scenario] if media_type in scenario_config: for key, value in scenario_config[media_type].items(): # 如果是从Python文件加载的,提取content字段 if isinstance(value, dict) and "content" in value: prompts[key] = value["content"] else: prompts[key] = str(value) return prompts if __name__ == "__main__": # 创建ConfigManager实例 config_manager = ConfigManager() # 获取所有提示词 all_prompts = config_manager.get_all_prompts("video") print(all_prompts) # 获取指定类型的提示词 caption_prompt = config_manager.get_prompt("video", "script") print(caption_prompt)