# 标准库导入 import json import warnings from pathlib import Path from typing import Dict, Any, Optional # 第三方库导入 from langchain_core.prompts import ChatPromptTemplate from langchain_core.messages import SystemMessage, HumanMessage # 本地导入 from .logger_config import setup_logger # 配置 warnings.filterwarnings("ignore") logger = setup_logger(__name__) class PromptManager: """提示词管理器,负责加载和路由不同类型的提示词模板""" DEFAULT_PROMPT = "answer_prompt.json" REQUIRED_KEYS = {"system_template", "human_template"} def __init__(self, config_dir: str = "./config"): self.config_dir = Path(config_dir) if not self.config_dir.exists(): raise FileNotFoundError(f"配置目录不存在: {self.config_dir}") def load_config(self, intent_type: str) -> Dict[str, Any]: """ 加载指定意图类型的配置文件 Args: intent_type: 意图类型名称 Returns: 配置字典 Raises: FileNotFoundError: 配置文件不存在 json.JSONDecodeError: 配置文件格式错误 """ config_path = self.config_dir / f"{intent_type}.json" try: if not config_path.exists(): logger.warning(f"配置文件 {config_path} 不存在,使用默认配置") config_path = self.config_dir / self.DEFAULT_PROMPT with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) self._validate_config(config) return config except json.JSONDecodeError as e: logger.error(f"配置文件格式错误: {config_path}") raise def _validate_config(self, config: Dict[str, Any]) -> None: """验证配置文件格式是否正确""" missing_keys = self.REQUIRED_KEYS - set(config.keys()) if missing_keys: raise KeyError(f"配置缺少必要字段: {missing_keys}") def get_prompt_template(self, intent_type: str) -> ChatPromptTemplate: """ 获取指定意图类型的提示词模板 Args: intent_type: 意图类型名称 Returns: ChatPromptTemplate: 提示词模板 """ config = self.load_config(intent_type) return ChatPromptTemplate.from_messages([ ("system", config["system_template"]), ("human", config["human_template"]) ]) def prompt_router(intent_type: str) -> ChatPromptTemplate: """ 提示词路由函数(向后兼容的接口) Args: intent_type: 意图类型名称 Returns: ChatPromptTemplate: 提示词模板 """ prompt_manager = PromptManager() return prompt_manager.get_prompt_template(intent_type) if __name__ == "__main__": # 使用示例 try: template = prompt_router("qa") print(template) except Exception as e: logger.error(f"获取提示词模板失败: {e}")