# Standard library imports
import re
import json
import time
from functools import lru_cache
from typing import Dict, List, Any, Optional, Union

# Third-party imports
from langchain.chains.base import Chain
from pydantic import Field

# Local imports
from utils.logger_config import setup_logger
from utils.prompt_config import prompt_router
from utils.rag_config import BaseMethod, ConfigManager
from module.intent_chain import IntentChain
from utils.common import str_to_json

# Configuration
logger = setup_logger(__name__)

# Default locations of the intent config and the product vector store.
# They used to be hard-coded inline; QAChain(...) can now override them.
DEFAULT_INTENT_CONFIG_PATH = "/data/data/luosy/project/AIqa/config/intent_config.json"
DEFAULT_PRODUCT_DB_PATH = "/data/data/luosy/project/AIqa/vectordb/vectordb_product"

# User-facing fallback answer returned whenever the chain cannot answer.
FALLBACK_ANSWER = "抱歉,系统处理您的请求时遇到了问题,请稍后重试。"

# A standalone ASCII alphanumeric run of length >= 6 containing at least one
# letter and one digit. The lookaheads are limited to [A-Za-z\d]* so they only
# inspect the candidate token itself; an unbounded `.*` would let a digit
# elsewhere in the sentence qualify a letters-only word. The edge lookarounds
# stop matches from starting/ending inside a longer run (\b is unsuitable here
# because CJK characters count as \w, so "货号ABC123" has no boundary before A).
_SKU_PATTERN = re.compile(
    r"(?<![A-Za-z\d])(?=[A-Za-z\d]*[A-Za-z])(?=[A-Za-z\d]*\d)[A-Za-z\d]{6,}(?![A-Za-z\d])"
)


def extract_sku(string):
    """Extract candidate product SKUs from free-form text.

    A candidate is a standalone run of ASCII letters/digits, at least six
    characters long, containing at least one letter and at least one digit.

    Args:
        string: User text to scan. ``None`` or ``""`` yields no matches.

    Returns:
        List of matched SKU strings in order of appearance (empty if none).
    """
    if not string:
        return []
    return _SKU_PATTERN.findall(string)


def _error_response(reason: str) -> Dict[str, str]:
    """Build the chain output for an early-exit failure.

    LangChain's Chain checks that every key in ``output_keys`` (``"answer"``)
    is present in the dict returned by ``_call``, so the user-facing fallback
    goes under ``"answer"``. The specific reason is kept under ``"error"`` for
    callers that inspect it.
    """
    return {"answer": FALLBACK_ANSWER, "error": reason}


class QAChain(Chain):
    """Question-answering chain.

    Flow: classify the user's intent with ``IntentChain``; retrieve FAQ
    documents (and, when the question contains a SKU-like token, product
    documents); ask the LLM for an answer using the prompt selected by the
    detected intent type; store the exchange in ``base_method.memory``.
    """

    # Pydantic fields
    config_manager: ConfigManager = Field(default_factory=ConfigManager)
    base_method: BaseMethod = Field(default_factory=BaseMethod)
    intent_chain: IntentChain = Field(default_factory=IntentChain)
    # Vector-store directory searched by _retrieve_product.
    product_db_path: str = Field(default=DEFAULT_PRODUCT_DB_PATH)

    def __init__(
        self,
        config_path: Optional[str] = None,
        intent_config_path: str = DEFAULT_INTENT_CONFIG_PATH,
        product_db_path: str = DEFAULT_PRODUCT_DB_PATH,
    ):
        """Initialize the chain.

        Args:
            config_path: RAG config file for ConfigManager/BaseMethod; when
                falsy the field default factories are kept.
            intent_config_path: Config file passed to IntentChain.
            product_db_path: Vector-store directory used for product lookup.
        """
        super().__init__(product_db_path=product_db_path)
        # Rebuild the helpers from the given config file when one is supplied.
        if config_path:
            self.config_manager = ConfigManager(config_path)
            self.base_method = BaseMethod(config_path)
        # NOTE(review): the original file's line structure was lost; this
        # assignment is assumed to run regardless of config_path — confirm.
        self.intent_chain = IntentChain(intent_config_path)

    def _init_prompt_template(self, prompt_name: str) -> Any:
        """Return the prompt template registered under ``prompt_name``.

        Raises:
            Exception: whatever ``prompt_router`` raises, re-raised after logging.
        """
        try:
            return prompt_router(prompt_name)
        except Exception as e:
            logger.error(f"初始化提示词模板失败: {e}")
            raise

    @property
    def input_keys(self) -> List[str]:
        """Input keys expected by the chain."""
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys produced by the chain."""
        return ["answer"]

    def _format_documents(self, documents: Union[List[str], str]) -> str:
        """Join retrieved document texts into one context string.

        Args:
            documents: List of document texts, or an already-formatted string.

        Returns:
            The texts joined by single spaces ("" for an empty list).
        """
        if isinstance(documents, str):
            # Joining a str would interleave spaces between its characters.
            return documents
        return " ".join(documents)

    def _get_history_memory(self, question: str, answer: str):
        """Log the current conversation buffer, then append (question, answer) to it."""
        logger.info(f"qa_chain history message: {self.base_method.memory.buffer_as_str}")
        self.base_method.memory.save_context({"input": question}, {"output": answer})

    def _retrieve_documents(self, question: str) -> List[str]:
        """Retrieve FAQ/text documents relevant to ``question``.

        Args:
            question: Query text.

        Returns:
            ``page_content`` of each retrieved document; empty list if none.

        Raises:
            Exception: retrieval failure, re-raised after logging.
        """
        try:
            # csv_retriever searches CSV stores; retriever searches text stores.
            retrieved_docs = self.base_method.retriever(question)
            if not retrieved_docs:
                logger.warning(f"未找到相关文档: {question}")
                return []
            return [doc.page_content for doc in retrieved_docs]
        except Exception as e:
            logger.error(f"文档检索失败: {e}")
            raise

    def _retrieve_product(self, question: str) -> List[str]:
        """Retrieve product documents when the question contains a SKU-like token.

        Args:
            question: User question.

        Returns:
            ``page_content`` of each retrieved product document; an empty list
            when no SKU is present or nothing was found.

        Raises:
            Exception: retrieval failure, re-raised after logging.
        """
        if not extract_sku(question):
            return []
        try:
            # Search the product CSV vector store with the full question text.
            retrieved_docs = self.base_method.csv_retriever(question, self.product_db_path)
            if not retrieved_docs:
                logger.warning(f"未找到相关文档: {question}")
                return []
            return [doc.page_content for doc in retrieved_docs]
        except Exception as e:
            logger.error(f"文档检索失败: {e}")
            raise

    def _generate_answer(self, intent: str, context: str, question: str, prompt_name: str) -> str:
        """Generate an answer with the LLM.

        Args:
            intent: Raw intent-classification output.
            context: Retrieved context text.
            question: User question.
            prompt_name: Name of the prompt template to use.

        Returns:
            The ``content`` of the LLM response.

        Raises:
            Exception: template lookup / formatting / LLM failure, re-raised after logging.
        """
        try:
            prompt = self._init_prompt_template(prompt_name).format(
                intent=intent,
                history=self.base_method.memory.buffer_as_str,
                context=context,
                question=question,
            )
            return self.base_method.model_config.llm(prompt).content
        except Exception as e:
            logger.error(f"生成答案失败: {e}")
            raise

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Answer the user's question.

        Args:
            inputs: Dict containing the ``"question"`` string.

        Returns:
            ``{"answer": ...}``. On early failures (empty intent output,
            unparsable intent JSON, missing intent type) the answer is the
            generic fallback and an ``"error"`` key describes the reason. Any
            exception is logged and converted into a fallback answer.
        """
        try:
            # Validate first so an invalid question never reaches the intent LLM.
            if "question" not in inputs:
                raise KeyError("输入缺少'question'字段")
            question = inputs["question"]
            if not isinstance(question, str) or not question.strip():
                raise ValueError("问题不能为空")

            # Classify intent.
            intent_content = self.intent_chain.invoke(
                {"question": question, "history": self.base_method.memory.buffer_as_str}
            )["answer"]
            if not intent_content:
                logger.error("intent_content 为空")
                return _error_response("意图内容为空")
            logger.debug(f"intent_content: {intent_content}")

            # Parse the intent output; json.JSONDecodeError subclasses ValueError.
            try:
                json_data = str_to_json(intent_content)
            except (ValueError, TypeError) as e:
                logger.error(f"JSON 解析错误: {str(e)}, content: {intent_content}")
                return _error_response("JSON 格式错误")
            if not isinstance(json_data, dict):
                logger.error(
                    f"JSON 解析错误: expected object, got {type(json_data).__name__}, "
                    f"content: {intent_content}"
                )
                return _error_response("JSON 格式错误")
            prompt_name = json_data.get("用户意图类型")
            if not prompt_name:
                return _error_response("未找到用户意图类型")

            # FAQ retrieval.
            search_start_time = time.perf_counter()
            documents = self._retrieve_documents(question + " " + intent_content)
            if not documents:
                # Explicit placeholder so the prompt shows retrieval found nothing.
                documents = ["NaN"]
            logger.info(
                f"自动回复检索耗时: {time.perf_counter() - search_start_time} ,自动回复检索结果: {documents}"
            )

            # Product retrieval.
            products = self._retrieve_product(question)
            logger.info(f"检索商品结果: {products}")

            context = self._format_documents(documents)
            product_context = self._format_documents(products)

            # Generate the answer.
            generate_start_time = time.perf_counter()
            answer = self._generate_answer(intent_content, context, question, prompt_name)
            logger.info(f"自动回复耗时: {time.perf_counter() - generate_start_time} ,自动回复结果: {answer}")

            # Save history; product context is stored as an extra exchange.
            self._get_history_memory(question, answer)
            if product_context:
                self._get_history_memory(product_context, "以上是客户正在咨询的商品")

            return {"answer": answer}
        except KeyError as e:
            logger.error(f"输入参数错误: {e}")
            return {"answer": "系统处理出现错误,请确保输入正确的问题格式。"}
        except Exception as e:
            logger.error(f"处理问题失败: {e}")
            return {"answer": FALLBACK_ANSWER}


if __name__ == "__main__":
    # Usage example
    chain = QAChain()
    try:
        result = chain.invoke({"question": "你好,请问有什么可以帮助你?"})
        print(f"回答: {result['answer']}")
    except Exception as e:
        logger.error(f"处理失败: {e}")