config_manager.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import os
  2. import importlib.util
  3. from typing import Dict, Any, Optional
  4. import yaml
  5. from utils.logger_config import setup_logger
  6. logger = setup_logger(__name__)
  7. class ConfigManager:
  8. """配置管理器,用于加载和管理提示词配置"""
  9. def __init__(self, config_path: Optional[str] = None):
  10. """
  11. 初始化配置管理器
  12. Args:
  13. config_path: 配置文件路径,如果为None则使用默认路径
  14. """
  15. self.config_path = config_path or os.path.join(
  16. os.path.dirname(os.path.dirname(__file__)),
  17. "config",
  18. "prompts.py"
  19. )
  20. self.config: Dict[str, Any] = {}
  21. self.load_config()
  22. def load_config(self) -> None:
  23. """加载配置文件"""
  24. try:
  25. # 检查是否为Python配置文件
  26. if self.config_path.endswith('.py'):
  27. self._load_py_config()
  28. else:
  29. self._load_yaml_config()
  30. except Exception as e:
  31. logger.error(f"Failed to load config: {str(e)}")
  32. raise
  33. def _load_yaml_config(self) -> None:
  34. """加载YAML配置文件"""
  35. if not os.path.exists(self.config_path):
  36. logger.error(f"Config file not found: {self.config_path}")
  37. raise FileNotFoundError(f"Config file not found: {self.config_path}")
  38. with open(self.config_path, 'r', encoding='utf-8') as f:
  39. self.config = yaml.safe_load(f)
  40. logger.info(f"Successfully loaded YAML config from {self.config_path}")
  41. def _load_py_config(self) -> None:
  42. """从Python文件加载配置"""
  43. if not os.path.exists(self.config_path):
  44. logger.error(f"Config file not found: {self.config_path}")
  45. raise FileNotFoundError(f"Config file not found: {self.config_path}")
  46. # 使用importlib加载Python模块
  47. spec = importlib.util.spec_from_file_location("config_module", self.config_path)
  48. if spec is None:
  49. raise ImportError(f"Could not load spec from {self.config_path}")
  50. module = importlib.util.module_from_spec(spec)
  51. if spec.loader is not None:
  52. spec.loader.exec_module(module)
  53. # 解析Python配置文件中的分层结构
  54. py_config = self._parse_py_config(module)
  55. # 构建与YAML配置相同的结构
  56. self.config = py_config
  57. logger.info(f"Successfully loaded Python config from {self.config_path}")
  58. def _parse_py_config(self, module) -> Dict[str, Any]:
  59. """
  60. 解析Python配置文件中的分层结构
  61. Args:
  62. module: 加载的Python模块
  63. Returns:
  64. Dict: 解析后的配置字典
  65. """
  66. # 初始化配置结构
  67. config = {
  68. "defaults": {
  69. "video": {},
  70. "image": {},
  71. "text": {}
  72. },
  73. "scenarios": {
  74. "news": {
  75. "video": {},
  76. "image": {},
  77. "text": {}
  78. },
  79. "entertainment": {
  80. "video": {},
  81. "image": {},
  82. "text": {}
  83. },
  84. "academic": {
  85. "video": {},
  86. "image": {},
  87. "text": {}
  88. }
  89. }
  90. }
  91. # 查找模块中的所有大写变量并分类
  92. for attr_name in dir(module):
  93. if attr_name.isupper() and not attr_name.startswith('_'): # 只获取大写的变量名
  94. attr_value = getattr(module, attr_name)
  95. if isinstance(attr_value, str):
  96. # 根据变量名前缀分类到不同的媒体类型和场景
  97. self._categorize_config_item(config, attr_name, attr_value)
  98. return config
  99. def _categorize_config_item(self, config: Dict, attr_name: str, attr_value: str) -> None:
  100. """
  101. 根据变量名将配置项分类到适当的媒体类型和场景中
  102. Args:
  103. config: 配置字典
  104. attr_name: 变量名
  105. attr_value: 变量值
  106. """
  107. # 转换为小写以便比较
  108. name_lower = attr_name.lower()
  109. # 确定场景
  110. scenario = None
  111. if 'news' in name_lower:
  112. scenario = 'news'
  113. elif 'entertainment' in name_lower:
  114. scenario = 'entertainment'
  115. elif 'academic' in name_lower:
  116. scenario = 'academic'
  117. # 确定媒体类型
  118. media_type = 'text' # 默认为text
  119. if 'video' in name_lower:
  120. media_type = 'video'
  121. elif 'image' in name_lower:
  122. media_type = 'image'
  123. # 确定提示词类型(去除前缀后的部分)
  124. prompt_type = self._extract_prompt_type(attr_name, scenario, media_type)
  125. # 将配置项放入相应的位置
  126. if scenario:
  127. config['scenarios'][scenario][media_type][prompt_type] = attr_value
  128. else:
  129. config['defaults'][media_type][prompt_type] = attr_value
  130. def _extract_prompt_type(self, attr_name: str, scenario: Optional[str], media_type: str) -> str:
  131. """
  132. 从变量名中提取提示词类型
  133. Args:
  134. attr_name: 变量名
  135. scenario: 场景类型
  136. media_type: 媒体类型
  137. Returns:
  138. str: 提示词类型
  139. """
  140. # 移除场景前缀
  141. name = attr_name
  142. if scenario:
  143. scenario_prefix = scenario.upper()
  144. if name.startswith(scenario_prefix):
  145. name = name[len(scenario_prefix):]
  146. # 移除媒体类型前缀
  147. media_prefix = media_type.upper()
  148. if name.startswith(media_prefix):
  149. name = name[len(media_prefix):]
  150. # 移除下划线分隔符
  151. if name.startswith('_'):
  152. name = name[1:]
  153. # 转换为小写作为提示词类型
  154. return name.lower() if name else 'default'
  155. def get_prompt(self,
  156. media_type: str,
  157. prompt_type: str = "caption",
  158. scenario: Optional[str] = None) -> str:
  159. """
  160. 获取指定类型的提示词
  161. Args:
  162. media_type: 媒体类型 ('video' 或 'image' 或 'text')
  163. prompt_type: 提示词类型 (如 'caption', 'scene' 等)
  164. scenario: 场景类型 (如 'news', 'entertainment' 等)
  165. Returns:
  166. str: 提示词
  167. Raises:
  168. KeyError: 指定的配置不存在
  169. """
  170. try:
  171. # 如果指定了场景,优先使用场景特定配置
  172. if scenario and scenario in self.config.get("scenarios", {}):
  173. scenario_config = self.config["scenarios"][scenario]
  174. if media_type in scenario_config and prompt_type in scenario_config[media_type]:
  175. result = scenario_config[media_type][prompt_type]
  176. # 如果是从Python文件加载的,提取content字段
  177. if isinstance(result, dict) and "content" in result:
  178. return result["content"]
  179. return str(result)
  180. # 回退到默认配置
  181. if prompt_type in self.config["defaults"][media_type]:
  182. prompt_data = self.config["defaults"][media_type][prompt_type]
  183. # 如果是从Python文件加载的,提取content字段
  184. if isinstance(prompt_data, dict) and "content" in prompt_data:
  185. return prompt_data["content"]
  186. return str(prompt_data)
  187. raise KeyError(f"Prompt not found for {media_type}/{prompt_type}")
  188. except KeyError as e:
  189. logger.error(f"Failed to get prompt: {str(e)}")
  190. # 返回基础提示词作为后备
  191. return ("视频里有什么?" if media_type == "video" else "请描述图片内容")
  192. def get_all_prompts(self,
  193. media_type: str,
  194. scenario: Optional[str] = None) -> Dict[str, str]:
  195. """
  196. 获取指定媒体类型的所有提示词
  197. Args:
  198. media_type: 媒体类型 ('video' 或 'image' 或 'text')
  199. scenario: 场景类型 (如 'news', 'entertainment' 等)
  200. Returns:
  201. Dict[str, str]: 提示词类型到提示词的映射
  202. """
  203. prompts = {}
  204. # 获取默认配置
  205. if media_type in self.config["defaults"]:
  206. for key, value in self.config["defaults"][media_type].items():
  207. # 如果是从Python文件加载的,提取content字段
  208. if isinstance(value, dict) and "content" in value:
  209. prompts[key] = value["content"]
  210. else:
  211. prompts[key] = str(value)
  212. # 如果指定了场景,添加或覆盖场景特定配置
  213. if scenario and scenario in self.config.get("scenarios", {}):
  214. scenario_config = self.config["scenarios"][scenario]
  215. if media_type in scenario_config:
  216. for key, value in scenario_config[media_type].items():
  217. # 如果是从Python文件加载的,提取content字段
  218. if isinstance(value, dict) and "content" in value:
  219. prompts[key] = value["content"]
  220. else:
  221. prompts[key] = str(value)
  222. return prompts
  223. if __name__ == "__main__":
  224. # 创建ConfigManager实例
  225. config_manager = ConfigManager()
  226. # 获取所有提示词
  227. all_prompts = config_manager.get_all_prompts("video")
  228. print(all_prompts)
  229. # 获取指定类型的提示词
  230. caption_prompt = config_manager.get_prompt("video", "script")
  231. print(caption_prompt)