qa_chain.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. # 标准库导入
  2. import re
  3. import json
  4. import time
  5. from functools import lru_cache
  6. from typing import Dict, List, Any, Optional
  7. # 第三方库导入
  8. from langchain.chains.base import Chain
  9. from pydantic import Field
  10. # 本地导入
  11. from utils.logger_config import setup_logger
  12. from utils.prompt_config import prompt_router
  13. from utils.rag_config import BaseMethod, ConfigManager
  14. from module.intent_chain import IntentChain
  15. from utils.common import str_to_json
# Configuration: module-level logger for this file, created via the project's setup_logger helper.
logger = setup_logger(__name__)
  18. # 提取商品SKU
  19. def extract_sku(string):
  20. pattern = re.compile(r'(?=.*[A-Za-z])(?=.*\d)[A-Za-z\d]{6,}')
  21. result = re.findall(pattern, string)
  22. return result
  23. class QAChain(Chain):
  24. """
  25. 意图链,用于处理用户问题并返回答案
  26. 继承自Chain,并组合BaseMethod的功能
  27. """
  28. # 声明Pydantic字段
  29. config_manager: ConfigManager = Field(default_factory=ConfigManager)
  30. base_method: BaseMethod = Field(default_factory=BaseMethod)
  31. intent_chain: IntentChain = Field(default_factory=IntentChain)
  32. # prompt_template: Any = None
  33. def __init__(self, config_path=None):
  34. """
  35. 初始化意图链
  36. Args:
  37. config_path: 配置文件路径,如果为None则使用默认路径
  38. """
  39. super().__init__()
  40. # 使用指定的配置文件初始化
  41. if config_path:
  42. self.config_manager = ConfigManager(config_path)
  43. self.base_method = BaseMethod(config_path)
  44. self.intent_chain = IntentChain("/data/data/luosy/project/AIqa/config/intent_config.json")
  45. def _init_prompt_template(self, prompt_name: str) -> None:
  46. """初始化提示词模板"""
  47. try:
  48. return prompt_router(prompt_name)
  49. except Exception as e:
  50. logger.error(f"初始化提示词模板失败: {e}")
  51. raise
  52. @property
  53. def input_keys(self) -> List[str]:
  54. """定义输入键"""
  55. return ["question"]
  56. @property
  57. def output_keys(self) -> List[str]:
  58. """定义输出键"""
  59. return ["answer"]
  60. def _format_documents(self, documents: List[str]) -> str:
  61. """
  62. 格式化检索到的文档
  63. Args:
  64. documents: 文档列表
  65. Returns:
  66. 格式化后的文档字符串
  67. """
  68. retriever_text = " ".join([doc for doc in documents])
  69. return retriever_text
  70. def _get_history_memory(self, question: str, answer:str):
  71. logger.info(f"qa_chain history message: {self.base_method.memory.buffer_as_str}")
  72. self.base_method.memory.save_context({"input": question}, {"output": answer})
  73. def _retrieve_documents(self, question: str) -> List[str]:
  74. """
  75. 检索相关文档
  76. Args:
  77. question: 用户问题
  78. Returns:
  79. 相关文档内容列表
  80. Raises:
  81. Exception: 检索失败时抛出异常
  82. """
  83. try:
  84. # csv_retriever for search csv file; retriever for search txt file
  85. retrieved_docs = self.base_method.retriever(question)
  86. if not retrieved_docs:
  87. logger.warning(f"未找到相关文档: {question}")
  88. return []
  89. return [doc.page_content for doc in retrieved_docs]
  90. except Exception as e:
  91. logger.error(f"文档检索失败: {e}")
  92. raise
  93. def _retrieve_product(self, question: str) -> List[str]:
  94. sku = extract_sku(question)
  95. if not sku:
  96. return ""
  97. else:
  98. try:
  99. # csv_retriever for search csv file; retriever for search txt file
  100. retrieved_docs = self.base_method.csv_retriever(question, "/data/data/luosy/project/AIqa/vectordb/vectordb_product")
  101. if not retrieved_docs:
  102. logger.warning(f"未找到相关文档: {question}")
  103. return []
  104. return [doc.page_content for doc in retrieved_docs]
  105. except Exception as e:
  106. logger.error(f"文档检索失败: {e}")
  107. raise
  108. def _generate_answer(self, intent: str, context: str, question: str, prompt_name: str) -> str:
  109. """
  110. 生成答案
  111. Args:
  112. context: 上下文信息
  113. question: 用户问题
  114. Returns:
  115. 生成的答案
  116. Raises:
  117. Exception: 生成答案失败时抛出异常
  118. """
  119. try:
  120. prompt = self._init_prompt_template(prompt_name).format(
  121. intent=intent,
  122. history=self.base_method.memory.buffer_as_str,
  123. context=context,
  124. question=question,
  125. )
  126. # print("---------------------------------------------------------------------------------------")
  127. # print(f"prompt: {prompt}")
  128. # print("---------------------------------------------------------------------------------------")
  129. return self.base_method.model_config.llm(prompt).content
  130. except Exception as e:
  131. logger.error(f"生成答案失败: {e}")
  132. raise
  133. def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
  134. """
  135. 处理用户输入并返回答案
  136. Args:
  137. inputs: 包含用户问题的字典
  138. Returns:
  139. 包含答案的字典
  140. Raises:
  141. KeyError: 输入缺少必要字段
  142. Exception: 处理过程中的其他异常
  143. """
  144. try:
  145. # 获取意图
  146. intent_content = self.intent_chain.invoke({"question": inputs["question"], "history": self.base_method.memory.buffer_as_str})["answer"]
  147. # 1. 首先检查 intent_content 是否为空
  148. if not intent_content:
  149. logger.error("intent_content 为空")
  150. return {"error": "意图内容为空"}
  151. # 2. 添加日志记录查看实际内容
  152. logger.debug(f"intent_content: {intent_content}")
  153. # 3. 添加错误处理
  154. try:
  155. json_data = str_to_json(intent_content)
  156. prompt_name = json_data.get("用户意图类型")
  157. if not prompt_name:
  158. return {"error": "未找到用户意图类型"}
  159. except json.JSONDecodeError as e:
  160. logger.error(f"JSON 解析错误: {str(e)}, content: {intent_content}")
  161. return {"error": "JSON 格式错误"}
  162. # 参数验证
  163. if "question" not in inputs:
  164. raise KeyError("输入缺少'question'字段")
  165. question = inputs["question"]
  166. if not isinstance(question, str) or not question.strip():
  167. raise ValueError("问题不能为空")
  168. # 检索问答
  169. search_start_time = time.time()
  170. documents = self._retrieve_documents(question + " " + intent_content)
  171. if not documents:
  172. documents = "NaN"
  173. # print(f"-------------------------------------------------------------------------------\n 自动回复检索结果: \n {documents} \n-------------------------------------------------------------------------------")
  174. logger.info(f"自动回复检索耗时: {time.time() - search_start_time} ,自动回复检索结果: {documents}")
  175. # 检索商品
  176. products = self._retrieve_product(question)
  177. ## if not products:
  178. ## products = "NaN"
  179. # print(f"-------------------------------------------------------------------------------\n 检索商品结果: \n {products} \n-------------------------------------------------------------------------------")
  180. logger.info(f"检索商品结果: {products}")
  181. # 格式化文档
  182. context = self._format_documents(documents)
  183. product_context = self._format_documents(products)
  184. # 生成答案
  185. generate_start_time = time.time()
  186. answer = self._generate_answer(intent_content, context, question, prompt_name)
  187. logger.info(f"自动回复耗时: {time.time() - generate_start_time} ,自动回复结果: {answer}")
  188. # 保存历史记录
  189. self._get_history_memory(question, answer)
  190. if product_context:
  191. self._get_history_memory(product_context, "以上是客户正在咨询的商品")
  192. # 返回答案
  193. return {"answer": answer}
  194. except KeyError as e:
  195. logger.error(f"输入参数错误: {e}")
  196. return {"answer": "系统处理出现错误,请确保输入正确的问题格式。"}
  197. except Exception as e:
  198. logger.error(f"处理问题失败: {e}")
  199. return {"answer": "抱歉,系统处理您的请求时遇到了问题,请稍后重试。"}
  200. if __name__ == "__main__":
  201. # 使用示例
  202. chain = QAChain()
  203. try:
  204. result = chain.invoke({"question": "你好,请问有什么可以帮助你?"})
  205. print(f"回答: {result['answer']}")
  206. except Exception as e:
  207. logger.error(f"处理失败: {e}")