# Standard library imports
import os
import time
import json
import warnings
from uuid import uuid4
from typing import List, Dict, Any, Optional

# Set the Python path so local packages resolve
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Third-party imports
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# LangChain core components
from langchain.schema import Document
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

# LangChain text processing
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter
)

# LangChain vector stores
from langchain_community.vectorstores import (
    FAISS,
    Chroma
)

# LangChain document loaders
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    JSONLoader,
    DirectoryLoader,
    UnstructuredMarkdownLoader
)

# LangChain retrieval and compression
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    CrossEncoderReranker,
    EmbeddingsFilter,
    DocumentCompressorPipeline,
    LLMChainFilter
)

# LangChain models and embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# LangChain memory management
from langchain.memory import (
    ConversationBufferWindowMemory,
    ConversationSummaryMemory
)

# Local module imports
from utils.logger_config import setup_logger

# Configuration
warnings.filterwarnings("ignore")
logger = setup_logger(__name__)
load_dotenv()


class ConfigManager:
    _instances = {}  # One instance per config file path
    _configs = {}    # Parsed config contents, keyed by config file path

    @staticmethod
    def _default_config_path():
        """Resolve the default config file path.

        Note: the original used "config.json" in __new__ but "qa_config.json"
        in __init__, so the singleton key and the loaded file could diverge.
        Both now resolve to qa_config.json, the file __init__ actually loaded.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.abspath(os.path.join(current_dir, "..", "config", "qa_config.json"))

    def __new__(cls, config_path=None):
        """
        Create a ConfigManager instance.

        Args:
            config_path: Path to the config file; if None, the default path is used.

        Returns:
            A ConfigManager instance (one per distinct config path).
        """
        if config_path is None:
            config_path = cls._default_config_path()
        config_path = os.path.abspath(config_path)
        if config_path not in cls._instances:
            cls._instances[config_path] = super(ConfigManager, cls).__new__(cls)
            cls._instances[config_path]._config_path = config_path
        return cls._instances[config_path]

    def __init__(self, config_path=None):
        """
        Initialize the ConfigManager.

        Args:
            config_path: Path to the config file; if None, the default path is used.
        """
        if config_path is None:
            config_path = self._default_config_path()
        config_path = os.path.abspath(config_path)
        self._config_path = config_path
        if config_path not in self._configs:
            self._load_config()

    def _load_config(self):
        """Load the config file."""
        try:
            if not os.path.exists(self._config_path):
                raise FileNotFoundError(f"Config file does not exist: {self._config_path}")
            with open(self._config_path, 'r', encoding='utf-8') as f:
                self._configs[self._config_path] = json.load(f)
            self._validate_config()
        except Exception as e:
            logger.error(f"Error while loading config file: {e}")
            raise

    def _validate_config(self):
        """Validate that the config file contains all required fields."""
        required_fields = [
            "embed_model_name",
            "reranker_model",
            "reranker_k",
            "split_chunks_size",
            "split_overlap_size",
            "filter_threshold",
            "retrievel_k",
            "llm_model_name",
            "llm_api_key",
            "llm_base_url",
            "max_memory_size",
            "persist_path",
            "file_path",
            "query_db",
            # Accessed by ModelConfig._init_hg_embedding, so validate them too.
            "embed_model_kwargs",
            "embed_encode_kwargs"
        ]
        config = self._configs[self._config_path]
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Config file is missing required field: {field}")

    @property
    def config(self):
        """Return the current configuration."""
        return self._configs[self._config_path]
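
# Illustrative sketch (not part of the original module): demonstrates the
# per-path singleton behavior of ConfigManager. The path "config/other.json"
# is hypothetical and only shows that distinct paths get distinct instances.
# Defined as a function so nothing runs on import.
def _demo_config_manager():
    cfg_a = ConfigManager()  # Resolves to the default qa_config.json
    cfg_b = ConfigManager()  # Same resolved path, so the same cached instance
    assert cfg_a is cfg_b
    # A different path would yield a separate instance with its own config:
    # cfg_c = ConfigManager("config/other.json")
    return cfg_a.config
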
class ModelConfig:
    """Model-related configuration."""

    def __init__(self, config):
        self.config = config
        self.embedding_model = self._init_embedding_model()
        self.rerank_model = self._init_rerank_model()
        self.hg_embedding = self._init_hg_embedding()
        self.llm = self._init_llm()

    def _init_embedding_model(self):
        return DefineEmbedding(self.config["embed_model_name"])

    def _init_rerank_model(self):
        return HuggingFaceCrossEncoder(model_name=self.config["reranker_model"])

    def _init_hg_embedding(self):
        # The original passed reranker_model here; the embeddings filter needs
        # an embedding model, so embed_model_name is used instead.
        return HuggingFaceEmbeddings(
            model_name=self.config["embed_model_name"],
            model_kwargs=self.config["embed_model_kwargs"],
            encode_kwargs=self.config["embed_encode_kwargs"]
        )

    def _init_llm(self):
        return ChatOpenAI(
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            model=self.config["llm_model_name"],
            api_key=self.config["llm_api_key"],
            base_url=self.config["llm_base_url"]
        )


class BaseConfig:
    def __init__(self, config_path=None):
        """
        Initialize the base configuration.

        Args:
            config_path: Path to the config file; if None, the default path is used.
        """
        self.config_manager = ConfigManager(config_path)
        self.config = self.config_manager.config
        self.model_config = ModelConfig(self.config)

        # Initialize models
        self.embedding_model = self.model_config.embedding_model
        self.rerank_model = self.model_config.rerank_model
        self.hg_embedding = self.model_config.hg_embedding
        self.llm = self.model_config.llm

        # Document chunking
        self.text_splitter = CharacterTextSplitter(
            separator="##",
            chunk_size=self.config["split_chunks_size"],
            chunk_overlap=self.config["split_overlap_size"],
            is_separator_regex=False
        )

        # Initialize the compressor and filters
        self.compressor = CrossEncoderReranker(
            model=self.rerank_model,
            top_n=self.config["reranker_k"]
        )
        self.llm_filter = LLMChainFilter.from_llm(self.llm)
        self.embed_filter = EmbeddingsFilter(
            embeddings=self.hg_embedding,
            similarity_threshold=self.config["filter_threshold"]
        )

        # Document compression pipeline (llm_filter and embed_filter are
        # constructed above and can be appended to the pipeline if needed)
        self.pipeline_compressor = DocumentCompressorPipeline(
            transformers=[self.compressor]
        )

        # Conversation memory
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",
            return_messages=True,
            k=self.config["max_memory_size"]
        )


class BotMethod(BaseConfig):
    def __init__(self, config_path=None):
        """Initialize BotMethod, completing the bot configuration."""
        super().__init__(config_path)


class BaseMethod(BaseConfig):
    def __init__(self, config_path=None):
        """Initialize BaseMethod, including configuration and the vector database."""
        super().__init__(config_path)
        self._init_vector_db()

    def _init_vector_db(self) -> None:
        """Initialize or update the vector database."""
        start_time = time.time()  # Record when vectorization starts
        try:
            file_paths = self.config["file_path"]
            persist_paths = self.config["persist_path"]
            if isinstance(file_paths, list) and isinstance(persist_paths, list):
                for file_path, persist_path in zip(file_paths, persist_paths):
                    # Check whether the vector DB directory exists
                    if not os.path.exists(persist_path):
                        logger.info(f"Vector DB directory does not exist, creating a new one: {persist_path}")
                        os.makedirs(persist_path, exist_ok=True)
                        self.update_vecdb(file_path, persist_path)
                    else:
                        # Compare the CSV file's mtime against the DB directory's mtime
                        csv_mtime = os.path.getmtime(file_path)
                        db_mtime = os.path.getmtime(persist_path)
                        if csv_mtime > db_mtime:
                            logger.info("CSV file has changed, rebuilding the vector database")
                            self.update_vecdb(file_path, persist_path)
                        else:
                            logger.info("Vector database is up to date, no update needed")
            else:
                raise ValueError("file_path and persist_path must both be lists")
        except Exception as e:
            logger.error(f"Failed to initialize the vector database: {e}")
            raise
        finally:
            end_time = time.time()  # Record when vectorization ends
            logger.info(f"Vectorization took {end_time - start_time} seconds")

    def split_txt(self, file_path):
        documents = []
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
            documents.append(text)
        return self.text_splitter.create_documents(documents)

    def split_csv(self, file_path):
        # GBK for the intent-recognition knowledge base; UTF-8 for product-info test set 2
        data = pd.read_csv(file_path, encoding="utf-8")
        headers = data.columns.tolist()

        # One document per first-column cell
        first_column = data.iloc[:, 0].tolist()
        first_column_document = []
        for idx, text in enumerate(first_column):
            if isinstance(text, str) and text.strip():
                first_column_document.append(Document(metadata={'index': idx}, page_content=text))

        # One document per row, serialized as "header:value" pairs
        result_list = []
        for _, row in data.iterrows():
            row_content = []
            for header, value in zip(headers, row):
                if pd.notna(value):
                    row_content.append(f"{header}:{value}")
            result_list.append("。".join(row_content))
        first_row_document = []
        for idx, text in enumerate(result_list):
            if isinstance(text, str) and text.strip():
                first_row_document.append(Document(metadata={'index': idx}, page_content=text))

        logger.info(f"\n{'-' * 79}\nIntent-recognition knowledge base (first_column_document):\n{first_column_document}\n{'-' * 79}\n")
        logger.info(f"\n{'-' * 79}\nIntent-recognition knowledge base (first_row_document):\n{first_row_document}\n{'-' * 79}\n")
        return first_column_document, first_row_document

    def update_vecdb(self, file_path, persist_path):
        _, ext = os.path.splitext(file_path)
        try:
            if ext == ".txt":
                chunks = self.split_txt(file_path)
            elif ext == ".csv":
                chunks, _ = self.split_csv(file_path)
            else:
                raise ValueError(f"Unsupported file type: {ext}")
        except Exception as e:
            # The original logged a warning and fell through, which raised a
            # NameError on the undefined `chunks`; re-raise instead.
            logger.warning(f"Something went wrong while updating the vector DB: {e}")
            raise
        vecdb = Chroma.from_documents(
            documents=chunks,
            embedding=self.embedding_model,
            persist_directory=persist_path
        )
        vecdb.persist()

    def retriever(self, query, query_db=None):
        persist_directory = query_db if query_db else self.config["query_db"]
        vecdb = Chroma(persist_directory=persist_directory, embedding_function=self.embedding_model)
        try:
            compression_retriever = ContextualCompressionRetriever(
                base_compressor=self.pipeline_compressor,
                base_retriever=vecdb.as_retriever(search_kwargs={"k": self.config["retrievel_k"]})
            )
            return compression_retriever.invoke(query)
        except Exception as e:
            logger.error(f"Error while creating the retriever: {e}")
            raise

    def csv_retriever(self, query, query_db=None):
        try:
            retriever_result = self.retriever(query, query_db)
            _, chunk_query = self.split_csv(self.config["file_path"][0])
            if not retriever_result:
                logger.warning("Retrieval returned no results, returning an empty list")
                return []
            # Map the best first-column hit back to the full-row document
            # that shares the same index.
            retriever_index = retriever_result[0].metadata['index']
            retriever_context = [document for document in chunk_query
                                 if document.metadata['index'] == retriever_index]
            return retriever_context
        except Exception as e:
            logger.error(f"Error in csv_retriever: {e}")
            return []
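
# Illustrative sketch (not part of the original module): end-to-end retrieval
# with BaseMethod. It assumes qa_config.json points at real CSV/TXT files and
# writable persist paths so the vector DBs can be built; the query string is
# hypothetical. Defined as a function so nothing runs on import.
def _demo_retrieval():
    bot = BaseMethod()  # Triggers _init_vector_db on construction
    # Compressed, reranked top-k documents from the default query_db:
    docs = bot.retriever("what is the return policy")
    # For CSV knowledge bases, csv_retriever maps the best first-column match
    # back to the full-row document with the same index:
    rows = bot.csv_retriever("what is the return policy")
    return docs, rows
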
class DefineEmbedding:
    def __init__(self, model_path):
        self.model = SentenceTransformer(model_path)

    def embed_documents(self, texts: list) -> list:
        embedding = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        return embedding.tolist()

    def embed_query(self, text: str) -> list:
        embedding = self.model.encode(text)
        return embedding.tolist()

    def __call__(self, texts: list) -> list:
        # The original annotated texts as str; embed_documents expects a list.
        return self.embed_documents(texts)
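
# Minimal smoke test (a sketch, replacing the original empty guard): it
# assumes qa_config.json exists with valid model names, API credentials, and
# file/persist paths, and the query string is hypothetical.
if __name__ == "__main__":
    bot = BaseMethod()  # Builds or refreshes the configured vector databases
    for doc in bot.retriever("example query"):
        print(doc.page_content)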