prompt_config.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # 标准库导入
  2. import json
  3. import warnings
  4. from pathlib import Path
  5. from typing import Dict, Any, Optional
  6. # 第三方库导入
  7. from langchain_core.prompts import ChatPromptTemplate
  8. from langchain_core.messages import SystemMessage, HumanMessage
  9. # 本地导入
  10. from .logger_config import setup_logger
  11. # 配置
  12. warnings.filterwarnings("ignore")
  13. logger = setup_logger(__name__)
  14. class PromptManager:
  15. """提示词管理器,负责加载和路由不同类型的提示词模板"""
  16. DEFAULT_PROMPT = "answer_prompt.json"
  17. REQUIRED_KEYS = {"system_template", "human_template"}
  18. def __init__(self, config_dir: str = "./config"):
  19. self.config_dir = Path(config_dir)
  20. if not self.config_dir.exists():
  21. raise FileNotFoundError(f"配置目录不存在: {self.config_dir}")
  22. def load_config(self, intent_type: str) -> Dict[str, Any]:
  23. """
  24. 加载指定意图类型的配置文件
  25. Args:
  26. intent_type: 意图类型名称
  27. Returns:
  28. 配置字典
  29. Raises:
  30. FileNotFoundError: 配置文件不存在
  31. json.JSONDecodeError: 配置文件格式错误
  32. """
  33. config_path = self.config_dir / f"{intent_type}.json"
  34. try:
  35. if not config_path.exists():
  36. logger.warning(f"配置文件 {config_path} 不存在,使用默认配置")
  37. config_path = self.config_dir / self.DEFAULT_PROMPT
  38. with open(config_path, "r", encoding="utf-8") as f:
  39. config = json.load(f)
  40. self._validate_config(config)
  41. return config
  42. except json.JSONDecodeError as e:
  43. logger.error(f"配置文件格式错误: {config_path}")
  44. raise
  45. def _validate_config(self, config: Dict[str, Any]) -> None:
  46. """验证配置文件格式是否正确"""
  47. missing_keys = self.REQUIRED_KEYS - set(config.keys())
  48. if missing_keys:
  49. raise KeyError(f"配置缺少必要字段: {missing_keys}")
  50. def get_prompt_template(self, intent_type: str) -> ChatPromptTemplate:
  51. """
  52. 获取指定意图类型的提示词模板
  53. Args:
  54. intent_type: 意图类型名称
  55. Returns:
  56. ChatPromptTemplate: 提示词模板
  57. """
  58. config = self.load_config(intent_type)
  59. return ChatPromptTemplate.from_messages([
  60. ("system", config["system_template"]),
  61. ("human", config["human_template"])
  62. ])
  63. def prompt_router(intent_type: str) -> ChatPromptTemplate:
  64. """
  65. 提示词路由函数(向后兼容的接口)
  66. Args:
  67. intent_type: 意图类型名称
  68. Returns:
  69. ChatPromptTemplate: 提示词模板
  70. """
  71. prompt_manager = PromptManager()
  72. return prompt_manager.get_prompt_template(intent_type)
  73. if __name__ == "__main__":
  74. # 使用示例
  75. try:
  76. template = prompt_router("qa")
  77. print(template)
  78. except Exception as e:
  79. logger.error(f"获取提示词模板失败: {e}")