123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- # 标准库导入
- import json
- import warnings
- from pathlib import Path
- from typing import Dict, Any, Optional
- # 第三方库导入
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.messages import SystemMessage, HumanMessage
- # 本地导入
- from .logger_config import setup_logger
- # 配置
- warnings.filterwarnings("ignore")
- logger = setup_logger(__name__)
- class PromptManager:
- """提示词管理器,负责加载和路由不同类型的提示词模板"""
-
- DEFAULT_PROMPT = "answer_prompt.json"
- REQUIRED_KEYS = {"system_template", "human_template"}
-
- def __init__(self, config_dir: str = "./config"):
- self.config_dir = Path(config_dir)
- if not self.config_dir.exists():
- raise FileNotFoundError(f"配置目录不存在: {self.config_dir}")
-
- def load_config(self, intent_type: str) -> Dict[str, Any]:
- """
- 加载指定意图类型的配置文件
-
- Args:
- intent_type: 意图类型名称
-
- Returns:
- 配置字典
-
- Raises:
- FileNotFoundError: 配置文件不存在
- json.JSONDecodeError: 配置文件格式错误
- """
- config_path = self.config_dir / f"{intent_type}.json"
-
- try:
- if not config_path.exists():
- logger.warning(f"配置文件 {config_path} 不存在,使用默认配置")
- config_path = self.config_dir / self.DEFAULT_PROMPT
-
- with open(config_path, "r", encoding="utf-8") as f:
- config = json.load(f)
-
- self._validate_config(config)
- return config
-
- except json.JSONDecodeError as e:
- logger.error(f"配置文件格式错误: {config_path}")
- raise
-
- def _validate_config(self, config: Dict[str, Any]) -> None:
- """验证配置文件格式是否正确"""
- missing_keys = self.REQUIRED_KEYS - set(config.keys())
- if missing_keys:
- raise KeyError(f"配置缺少必要字段: {missing_keys}")
-
- def get_prompt_template(self, intent_type: str) -> ChatPromptTemplate:
- """
- 获取指定意图类型的提示词模板
-
- Args:
- intent_type: 意图类型名称
-
- Returns:
- ChatPromptTemplate: 提示词模板
- """
- config = self.load_config(intent_type)
-
- return ChatPromptTemplate.from_messages([
- ("system", config["system_template"]),
- ("human", config["human_template"])
- ])
- def prompt_router(intent_type: str) -> ChatPromptTemplate:
- """
- 提示词路由函数(向后兼容的接口)
-
- Args:
- intent_type: 意图类型名称
-
- Returns:
- ChatPromptTemplate: 提示词模板
- """
- prompt_manager = PromptManager()
- return prompt_manager.get_prompt_template(intent_type)
- if __name__ == "__main__":
- # 使用示例
- try:
- template = prompt_router("qa")
- print(template)
- except Exception as e:
- logger.error(f"获取提示词模板失败: {e}")
|