123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- # 标准库导入
- import re
- import json
- import time
- from functools import lru_cache
- from typing import Dict, List, Any, Optional
- # 第三方库导入
- from langchain.chains.base import Chain
- from pydantic import Field
- # 本地导入
- from utils.logger_config import setup_logger
- from utils.prompt_config import prompt_router
- from utils.rag_config import BaseMethod, ConfigManager
- from module.intent_chain import IntentChain
- from utils.common import str_to_json
# Configuration: module-level logger, named after this module.
logger = setup_logger(__name__)
# Extract product SKUs from free text.
#
# A SKU is a whole run of ASCII letters/digits, at least 6 characters long,
# that contains at least one letter AND at least one digit.
# - The look-behind / trailing look-ahead pin each match to a complete
#   alphanumeric run.
# - The letter/digit look-aheads are restricted to that run. The previous
#   pattern used ``.*``, which let a letter or digit anywhere later in the
#   string satisfy the check, so "abcdef 1" wrongly yielded "abcdef".
# - ``\b`` is deliberately not used: CJK characters count as word characters,
#   so "请问ABC123有货吗" has no word boundary before "A".
# Compiled once at import time instead of on every call.
_SKU_PATTERN = re.compile(
    r"(?<![A-Za-z\d])"          # start of an alphanumeric run
    r"(?=[A-Za-z\d]*[A-Za-z])"  # the run contains a letter
    r"(?=[A-Za-z\d]*\d)"        # the run contains a digit
    r"[A-Za-z\d]{6,}"           # at least 6 characters long
    r"(?![A-Za-z\d])"           # end of the run
)


def extract_sku(string):
    """Return every SKU-like token found in *string*, in order of appearance.

    Args:
        string: free-form user text; ``None`` or ``""`` yields ``[]``.

    Returns:
        A list of matched substrings (possibly empty).
    """
    if not string:
        return []
    return _SKU_PATTERN.findall(string)
class QAChain(Chain):
    """
    Question-answering chain: processes a user question and returns an answer.

    Per call:
      1. classify the intent with ``IntentChain``;
      2. retrieve reply documents, plus product records when the question
         contains a SKU-like token;
      3. format the intent-specific prompt and query the LLM exposed by
         ``BaseMethod``;
      4. save the exchange to ``base_method.memory``, which later prompts
         read as the conversation history.

    Inherits from ``Chain`` and composes the functionality of ``BaseMethod``.
    """

    # Declared Pydantic fields. A default factory only runs when the field is
    # not supplied explicitly to ``super().__init__``.
    config_manager: ConfigManager = Field(default_factory=ConfigManager)
    base_method: BaseMethod = Field(default_factory=BaseMethod)
    intent_chain: IntentChain = Field(default_factory=IntentChain)

    def __init__(
        self,
        config_path=None,
        intent_config_path="/data/data/luosy/project/AIqa/config/intent_config.json",
    ):
        """
        Initialize the chain.

        Args:
            config_path: RAG configuration file. When falsy, the
                default-constructed ``ConfigManager`` / ``BaseMethod`` are used.
            intent_config_path: configuration file handed to ``IntentChain``.
                Defaults to the previously hard-coded location.
        """
        # Build the components first and pass them to Pydantic, so the default
        # factories are not also run just to have their results discarded.
        # Previously IntentChain was always constructed twice, and so were
        # ConfigManager/BaseMethod whenever a config path was given.
        components = {"intent_chain": IntentChain(intent_config_path)}
        if config_path:
            components["config_manager"] = ConfigManager(config_path)
            components["base_method"] = BaseMethod(config_path)
        super().__init__(**components)

    def _init_prompt_template(self, prompt_name: str) -> Any:
        """
        Look up the prompt template for ``prompt_name`` via ``prompt_router``.

        Returns:
            The template object; callers use its ``format`` method.

        Raises:
            Exception: any lookup failure is logged and re-raised.
        """
        try:
            return prompt_router(prompt_name)
        except Exception as e:
            logger.error(f"初始化提示词模板失败: {e}")
            raise

    @property
    def input_keys(self) -> List[str]:
        """Input keys accepted by the chain."""
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys produced by the chain."""
        return ["answer"]

    @staticmethod
    def _error_output(message: str) -> Dict[str, str]:
        """
        Build the output for a request that cannot be answered.

        LangChain's ``Chain`` rejects outputs that lack any key listed in
        ``output_keys``. Returning only ``{"error": ...}`` would therefore make
        ``invoke`` raise. A generic ``"answer"`` is always included, and the
        ``"error"`` detail is kept next to it.
        """
        return {
            "answer": "抱歉,系统处理您的请求时遇到了问题,请稍后重试。",
            "error": message,
        }

    def _format_documents(self, documents: List[str]) -> str:
        """
        Join retrieved document texts into one context string.

        Args:
            documents: document texts. A bare string is returned unchanged
                rather than being split into characters.

        Returns:
            The space-joined text, or ``""`` for an empty input.
        """
        if isinstance(documents, str):
            return documents
        return " ".join(documents)

    def _get_history_memory(self, question: str, answer: str) -> None:
        """
        Log the current conversation buffer, then append one exchange to it.

        Despite the name, this method *saves* history: ``question`` and
        ``answer`` are written to ``base_method.memory`` via ``save_context``.
        """
        logger.info(f"qa_chain history message: {self.base_method.memory.buffer_as_str}")
        self.base_method.memory.save_context({"input": question}, {"output": answer})

    def _retrieve_documents(self, question: str) -> List[str]:
        """
        Retrieve reply documents relevant to ``question``.

        Args:
            question: search query text.

        Returns:
            The ``page_content`` of each retrieved document, or ``[]`` when
            nothing was found.

        Raises:
            Exception: any retriever failure is logged and re-raised.
        """
        try:
            # csv_retriever searches CSV files; retriever searches text files.
            retrieved_docs = self.base_method.retriever(question)
            if not retrieved_docs:
                logger.warning(f"未找到相关文档: {question}")
                return []
            return [doc.page_content for doc in retrieved_docs]
        except Exception as e:
            logger.error(f"文档检索失败: {e}")
            raise

    def _retrieve_product(
        self,
        question: str,
        vectordb_path: str = "/data/data/luosy/project/AIqa/vectordb/vectordb_product",
    ) -> List[str]:
        """
        Retrieve product records when ``question`` mentions a SKU-like token.

        Args:
            question: the user's question.
            vectordb_path: product vector store passed to ``csv_retriever``.
                Defaults to the previously hard-coded location.

        Returns:
            The ``page_content`` of each retrieved record. Returns ``[]`` when
            the question has no SKU or nothing was found (previously ``""`` in
            the no-SKU case, contradicting the annotation).

        Raises:
            Exception: any retriever failure is logged and re-raised.
        """
        # NOTE(review): the SKU is only used as a gate; the full question is
        # what gets searched. Confirm this is intended.
        if not extract_sku(question):
            return []
        try:
            retrieved_docs = self.base_method.csv_retriever(question, vectordb_path)
            if not retrieved_docs:
                logger.warning(f"未找到相关文档: {question}")
                return []
            return [doc.page_content for doc in retrieved_docs]
        except Exception as e:
            logger.error(f"文档检索失败: {e}")
            raise

    def _generate_answer(self, intent: str, context: str, question: str, prompt_name: str) -> str:
        """
        Format the intent-specific prompt and ask the LLM for an answer.

        Args:
            intent: raw intent-classification output.
            context: retrieved document text.
            question: the user's question.
            prompt_name: name of the prompt template to use.

        Returns:
            The ``content`` of the LLM response.

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            prompt = self._init_prompt_template(prompt_name).format(
                intent=intent,
                history=self.base_method.memory.buffer_as_str,
                context=context,
                question=question,
            )
            return self.base_method.model_config.llm(prompt).content
        except Exception as e:
            logger.error(f"生成答案失败: {e}")
            raise

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """
        Process the user's input and return an answer.

        Args:
            inputs: dict containing the user's ``"question"``.

        Returns:
            A dict with an ``"answer"`` key. Failures yield a user-facing
            apology, and intent-parsing failures also carry an ``"error"``
            detail. Exceptions are never propagated to the caller.
        """
        try:
            # Validate input before anything reads it or spends an LLM call.
            if "question" not in inputs:
                raise KeyError("输入缺少'question'字段")
            question = inputs["question"]
            if not isinstance(question, str) or not question.strip():
                raise ValueError("问题不能为空")

            # Classify the intent.
            intent_content = self.intent_chain.invoke(
                {"question": question, "history": self.base_method.memory.buffer_as_str}
            )["answer"]
            if not intent_content:
                logger.error("intent_content 为空")
                return self._error_output("意图内容为空")
            logger.debug(f"intent_content: {intent_content}")

            # The intent output is JSON; its "用户意图类型" selects the prompt.
            try:
                json_data = str_to_json(intent_content)
            except json.JSONDecodeError as e:
                logger.error(f"JSON 解析错误: {str(e)}, content: {intent_content}")
                return self._error_output("JSON 格式错误")
            prompt_name = json_data.get("用户意图类型") if isinstance(json_data, dict) else None
            if not prompt_name:
                return self._error_output("未找到用户意图类型")

            # Retrieve auto-reply documents.
            search_start_time = time.time()
            documents = self._retrieve_documents(question + " " + intent_content)
            if not documents:
                # Placeholder context. It must be a list: joining the bare
                # string "NaN" produced "N a N".
                documents = ["NaN"]
            logger.info(f"自动回复检索耗时: {time.time() - search_start_time} ,自动回复检索结果: {documents}")

            # Retrieve product records.
            products = self._retrieve_product(question)
            logger.info(f"检索商品结果: {products}")

            context = self._format_documents(documents)
            product_context = self._format_documents(products)

            # Generate the answer.
            generate_start_time = time.time()
            answer = self._generate_answer(intent_content, context, question, prompt_name)
            logger.info(f"自动回复耗时: {time.time() - generate_start_time} ,自动回复结果: {answer}")

            # Save history; product records are stored as an extra exchange so
            # later turns know which product is being discussed.
            self._get_history_memory(question, answer)
            if product_context:
                self._get_history_memory(product_context, "以上是客户正在咨询的商品")
            return {"answer": answer}

        except KeyError as e:
            logger.error(f"输入参数错误: {e}")
            return {"answer": "系统处理出现错误,请确保输入正确的问题格式。"}
        except Exception as e:
            logger.error(f"处理问题失败: {e}")
            return {"answer": "抱歉,系统处理您的请求时遇到了问题,请稍后重试。"}
if __name__ == "__main__":
    # Demo: build a chain with the default configuration and ask one question.
    demo_question = "你好,请问有什么可以帮助你?"
    qa_chain = QAChain()
    try:
        reply = qa_chain.invoke({"question": demo_question})
        print(f"回答: {reply['answer']}")
    except Exception as err:
        logger.error(f"处理失败: {err}")
|