# intent_chain.py
  1. # 标准库导入
  2. import time
  3. from functools import lru_cache
  4. from typing import Dict, List, Any, Optional
  5. # 第三方库导入
  6. from langchain.chains.base import Chain
  7. from pydantic import Field
  8. # 本地导入
  9. from utils.logger_config import setup_logger
  10. from utils.prompt_config import prompt_router
  11. from utils.rag_config import BaseMethod, ConfigManager
  12. # 配置
  13. logger = setup_logger(__name__)
  14. class IntentChain(Chain):
  15. """
  16. 意图链,用于处理用户问题并返回答案
  17. 继承自Chain,并组合BaseMethod的功能
  18. """
  19. # 声明Pydantic字段
  20. config_manager: ConfigManager = Field(default_factory=ConfigManager)
  21. base_method: BaseMethod = Field(default_factory=BaseMethod)
  22. prompt_template: Any = None
  23. def __init__(self, config_path=None):
  24. """
  25. 初始化意图链
  26. Args:
  27. config_path: 配置文件路径,如果为None则使用默认路径
  28. """
  29. super().__init__()
  30. # 使用指定的配置文件初始化
  31. if config_path:
  32. self.config_manager = ConfigManager(config_path)
  33. self.base_method = BaseMethod(config_path)
  34. self._init_prompt_template()
  35. def _init_prompt_template(self) -> None:
  36. """初始化提示词模板"""
  37. try:
  38. self.prompt_template = prompt_router("intent_prompt")
  39. except Exception as e:
  40. logger.error(f"初始化提示词模板失败: {e}")
  41. raise
  42. @property
  43. def input_keys(self) -> List[str]:
  44. """定义输入键"""
  45. return ["question"]
  46. @property
  47. def output_keys(self) -> List[str]:
  48. """定义输出键"""
  49. return ["answer"]
  50. def _format_documents(self, documents: List[str]) -> str:
  51. """
  52. 格式化检索到的文档
  53. Args:
  54. documents: 文档列表
  55. Returns:
  56. 格式化后的文档字符串
  57. """
  58. retriever_text = " ".join([doc for doc in documents])
  59. return retriever_text
  60. def _retrieve_documents(self, question: str) -> List[str]:
  61. """
  62. 检索相关文档
  63. Args:
  64. question: 用户问题
  65. Returns:
  66. 相关文档内容列表
  67. Raises:
  68. Exception: 检索失败时抛出异常
  69. """
  70. try:
  71. retrieved_docs = self.base_method.csv_retriever(question)
  72. if not retrieved_docs:
  73. logger.warning(f"未找到相关文档: {question}")
  74. return []
  75. return [doc.page_content for doc in retrieved_docs]
  76. except Exception as e:
  77. logger.error(f"文档检索失败: {e}")
  78. raise
  79. def _generate_answer(self, context: str, question: str, history: str) -> str:
  80. """
  81. 生成答案
  82. Args:
  83. context: 上下文信息
  84. question: 用户问题
  85. history: 历史对话记录
  86. Returns:
  87. 生成的答案
  88. Raises:
  89. Exception: 生成答案失败时抛出异常
  90. """
  91. try:
  92. prompt = self.prompt_template.format(
  93. history=history,
  94. context=context,
  95. question=question
  96. )
  97. return self.base_method.model_config.llm(prompt).content
  98. except Exception as e:
  99. logger.error(f"生成答案失败: {e}")
  100. raise
  101. def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
  102. """
  103. 处理用户输入并返回答案
  104. Args:
  105. inputs: 包含用户问题的字典
  106. history: 历史对话记录
  107. Returns:
  108. 包含答案的字典
  109. Raises:
  110. KeyError: 输入缺少必要字段
  111. Exception: 处理过程中的其他异常
  112. """
  113. try:
  114. # 参数验证
  115. if "question" not in inputs:
  116. raise KeyError("输入缺少'question'字段")
  117. question = inputs["question"]
  118. if not isinstance(question, str) or not question.strip():
  119. raise ValueError("问题不能为空")
  120. if "history" not in inputs:
  121. raise KeyError("输入缺少'history'字段")
  122. history = inputs["history"]
  123. logger.info(f"intent_chain history message: {history}")
  124. # 检索文档
  125. search_start_time = time.time()
  126. documents = self._retrieve_documents(question)
  127. if not documents:
  128. documents = "NaN"
  129. logger.info(f"意图识别向量库-检索耗时: {time.time() - search_start_time} 秒, 意图识别检索到文档: {documents}")
  130. # 格式化文档
  131. context = self._format_documents(documents)
  132. logger.info(f"intent_chain retriever context: {context}")
  133. # 生成答案
  134. generate_start_time = time.time()
  135. answer = self._generate_answer(context, question, history)
  136. logger.info(f"意图识别耗时: {time.time() - generate_start_time} 秒, 意图识别结果: {answer}")
  137. # 返回结果
  138. return {"answer": answer}
  139. except (KeyError, ValueError) as e:
  140. logger.error(f"输入参数错误: {e}")
  141. raise
  142. except Exception as e:
  143. logger.error(f"意图识别失败: {str(e)}")
  144. return {"answer": "意图识别过程出现错误"}
  145. if __name__ == "__main__":
  146. # 使用示例
  147. chain = IntentChain()
  148. history = "你好,请问有什么可以帮助你?"
  149. try:
  150. result = chain.invoke({"question": "你好,请问有什么可以帮助你?", "history": history})
  151. print(f"回答: {result['answer']}")
  152. except Exception as e:
  153. logger.error(f"处理失败: {e}")