| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- import os
- import importlib.util
- from typing import Dict, Any, Optional
- import yaml
- from utils.logger_config import setup_logger
- logger = setup_logger(__name__)
class ConfigManager:
    """Load prompt configuration and serve prompts by media type and scenario.

    The configuration is held in ``self.config`` as a nested dict::

        {
            "defaults":  {media_type: {prompt_type: prompt}},
            "scenarios": {scenario: {media_type: {prompt_type: prompt}}},
        }

    It is loaded either from a Python file (path ending in ``.py``), whose
    upper-case string variables are sorted into that structure by name, or
    from a YAML file that is used as-is.
    """

    # Media types and scenarios pre-created when parsing a Python config file.
    # The order of _SCENARIOS is also the match priority in
    # _categorize_config_item.
    _MEDIA_TYPES = ("video", "image", "text")
    _SCENARIOS = ("news", "entertainment", "academic")

    def __init__(self, config_path: Optional[str] = None):
        """Create the manager and load the configuration immediately.

        Args:
            config_path: Path of the configuration file. When None (or
                empty), ``config/prompts.py`` located one directory above
                this file's directory is used.

        Raises:
            Exception: Whatever ``load_config`` raises.
        """
        self.config_path = config_path or os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "config",
            "prompts.py",
        )
        self.config: Dict[str, Any] = {}
        self.load_config()

    def load_config(self) -> None:
        """Load ``self.config_path`` into ``self.config``.

        Paths ending in ``.py`` are loaded as Python files; anything else
        is parsed as YAML. Any failure is logged and re-raised unchanged.
        """
        try:
            if self.config_path.endswith('.py'):
                self._load_py_config()
            else:
                self._load_yaml_config()
        except Exception as e:
            logger.error(f"Failed to load config: {str(e)}")
            raise

    def _load_yaml_config(self) -> None:
        """Load the configuration from a YAML file.

        Raises:
            FileNotFoundError: The file does not exist.
            ValueError: The YAML document root is not a mapping.
        """
        if not os.path.exists(self.config_path):
            logger.error(f"Config file not found: {self.config_path}")
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)

        # An empty YAML document parses to None; treat it as an empty config
        # so that later ``self.config.get(...)`` calls keep working.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Config root must be a mapping: {self.config_path}"
            )

        self.config = loaded
        logger.info(f"Successfully loaded YAML config from {self.config_path}")

    def _load_py_config(self) -> None:
        """Load the configuration by executing a Python file.

        Raises:
            FileNotFoundError: The file does not exist.
            ImportError: No import spec or loader could be created for it.
        """
        if not os.path.exists(self.config_path):
            logger.error(f"Config file not found: {self.config_path}")
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        spec = importlib.util.spec_from_file_location("config_module", self.config_path)
        # Without a loader the module would never be executed and the
        # config would silently end up empty, so fail loudly instead.
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec from {self.config_path}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Arrange the module's variables in the same layered structure a
        # YAML config is expected to have.
        self.config = self._parse_py_config(module)
        logger.info(f"Successfully loaded Python config from {self.config_path}")

    def _parse_py_config(self, module) -> Dict[str, Any]:
        """Build the layered config dict from a loaded config module.

        Only attributes whose name is upper-case, does not start with an
        underscore, and whose value is a ``str`` are considered.

        Args:
            module: Object whose attributes hold the prompt strings.

        Returns:
            Dict with ``defaults`` and ``scenarios`` sections, each media
            type (and each known scenario) pre-populated with an empty dict.
        """
        config: Dict[str, Any] = {
            "defaults": {media: {} for media in self._MEDIA_TYPES},
            "scenarios": {
                scenario: {media: {} for media in self._MEDIA_TYPES}
                for scenario in self._SCENARIOS
            },
        }

        for attr_name in dir(module):
            if not attr_name.isupper() or attr_name.startswith('_'):
                continue
            attr_value = getattr(module, attr_name)
            if isinstance(attr_value, str):
                self._categorize_config_item(config, attr_name, attr_value)

        return config

    def _categorize_config_item(self, config: Dict, attr_name: str, attr_value: str) -> None:
        """Store one prompt in ``config`` under the slot implied by its name.

        The scenario and media type are detected by case-insensitive
        substring match on the variable name. A name with no scenario goes
        to ``defaults``; a name with neither ``video`` nor ``image`` is
        treated as ``text``.

        Args:
            config: Layered config dict to update in place.
            attr_name: Variable name of the prompt.
            attr_value: Prompt text.
        """
        name_lower = attr_name.lower()

        # First matching scenario wins, in _SCENARIOS order.
        scenario = next((s for s in self._SCENARIOS if s in name_lower), None)
        # 'video' takes precedence over 'image'; 'text' is the fallback.
        media_type = next((m for m in ("video", "image") if m in name_lower), "text")

        prompt_type = self._extract_prompt_type(attr_name, scenario, media_type)

        if scenario:
            config['scenarios'][scenario][media_type][prompt_type] = attr_value
        else:
            config['defaults'][media_type][prompt_type] = attr_value

    def _extract_prompt_type(self, attr_name: str, scenario: Optional[str], media_type: str) -> str:
        """Derive the prompt type from a variable name.

        A leading scenario prefix and then a leading media-type prefix are
        removed, together with the underscores separating them, and the
        remainder is lower-cased. For example ``NEWS_VIDEO_CAPTION`` with
        scenario ``news`` and media type ``video`` yields ``caption``.

        Args:
            attr_name: Variable name of the prompt.
            scenario: Detected scenario, or None.
            media_type: Detected media type.

        Returns:
            The lower-cased prompt type, or ``'default'`` when nothing is
            left after removing the prefixes.
        """
        prefixes = [scenario.upper()] if scenario else []
        prefixes.append(media_type.upper())

        name = attr_name
        for prefix in prefixes:
            # Drop the separator left by the previous prefix first;
            # otherwise the media prefix in 'NEWS_VIDEO_X' is never seen.
            name = name.lstrip('_')
            if name.startswith(prefix):
                name = name[len(prefix):]
        name = name.lstrip('_')

        return name.lower() if name else 'default'

    @staticmethod
    def _prompt_text(value: Any) -> Any:
        """Return the prompt stored in ``value``.

        A dict entry carrying a ``content`` key is unwrapped to that
        field; any other value is converted with ``str``.
        """
        if isinstance(value, dict) and "content" in value:
            return value["content"]
        return str(value)

    def get_prompt(self,
                   media_type: str,
                   prompt_type: str = "caption",
                   scenario: Optional[str] = None) -> str:
        """Return one prompt, preferring the scenario-specific entry.

        Lookup order: ``scenarios[scenario][media_type][prompt_type]`` when
        a known scenario is given, then ``defaults[media_type][prompt_type]``.
        If neither exists the error is logged and a built-in fallback
        prompt is returned instead of raising.

        Args:
            media_type: Media type ('video', 'image' or 'text').
            prompt_type: Prompt type (e.g. 'caption', 'scene').
            scenario: Scenario (e.g. 'news', 'entertainment'), or None.

        Returns:
            The prompt text, or the fallback prompt when not found.
        """
        try:
            scenarios = self.config.get("scenarios") or {}
            if scenario and scenario in scenarios:
                # ``or {}`` tolerates null sections in a YAML config.
                scenario_prompts = (scenarios[scenario] or {}).get(media_type) or {}
                if prompt_type in scenario_prompts:
                    return self._prompt_text(scenario_prompts[prompt_type])

            defaults = self.config.get("defaults") or {}
            # An unknown media type raises KeyError here and is handled below.
            default_prompts = defaults[media_type] or {}
            if prompt_type in default_prompts:
                return self._prompt_text(default_prompts[prompt_type])

            raise KeyError(f"Prompt not found for {media_type}/{prompt_type}")

        except KeyError as e:
            logger.error(f"Failed to get prompt: {str(e)}")
            # Basic fallback prompt; every non-video media type gets the
            # image prompt.
            return ("视频里有什么?" if media_type == "video" else "请描述图片内容")

    def get_all_prompts(self,
                        media_type: str,
                        scenario: Optional[str] = None) -> Dict[str, str]:
        """Return every prompt for a media type.

        Default prompts are collected first; when a known scenario is
        given, its prompts are added on top and override defaults with the
        same prompt type. Missing sections simply contribute nothing.

        Args:
            media_type: Media type ('video', 'image' or 'text').
            scenario: Scenario (e.g. 'news', 'entertainment'), or None.

        Returns:
            Mapping of prompt type to prompt text.
        """
        defaults = self.config.get("defaults") or {}
        layers = [defaults.get(media_type)]

        scenarios = self.config.get("scenarios") or {}
        if scenario and scenario in scenarios:
            layers.append((scenarios[scenario] or {}).get(media_type))

        prompts: Dict[str, str] = {}
        # Later layers (the scenario) overwrite earlier ones (the defaults).
        for layer in layers:
            for key, value in (layer or {}).items():
                prompts[key] = self._prompt_text(value)

        return prompts
if __name__ == "__main__":
    # Manual smoke test: build a manager from the default config path.
    manager = ConfigManager()

    # Show every prompt available for the 'video' media type.
    video_prompts = manager.get_all_prompts("video")
    print(video_prompts)

    # Look up a single prompt by media type and prompt type.
    script_prompt = manager.get_prompt("video", "script")
    print(script_prompt)
|