# Standard library imports
import os
import time
import json
import warnings
from uuid import uuid4
from typing import List, Dict, Any, Optional

# Make the project root importable
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Third-party imports
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# LangChain core components
from langchain.schema import Document
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

# LangChain text splitters
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter
)

# LangChain vector stores
from langchain_community.vectorstores import (
    FAISS,
    Chroma
)

# LangChain document loaders
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    JSONLoader,
    DirectoryLoader,
    UnstructuredMarkdownLoader
)

# LangChain retrieval and compression
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    CrossEncoderReranker,
    EmbeddingsFilter,
    DocumentCompressorPipeline,
    LLMChainFilter
)

# LangChain models and embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# LangChain conversation memory
from langchain.memory import (
    ConversationBufferWindowMemory,
    ConversationSummaryMemory
)

# Local modules
from utils.logger_config import setup_logger

# Global setup
warnings.filterwarnings("ignore")
logger = setup_logger(__name__)
load_dotenv()

class ConfigManager:
    _instances = {}  # One instance per config file path
    _configs = {}    # Cached config contents, keyed by config file path

    def __new__(cls, config_path=None):
        """
        Create a ConfigManager instance.

        Args:
            config_path: Path to the config file; the default path is used when None.

        Returns:
            The ConfigManager instance for that config path.
        """
        if config_path is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(current_dir, "..", "config", "qa_config.json")

        config_path = os.path.abspath(config_path)

        if config_path not in cls._instances:
            cls._instances[config_path] = super(ConfigManager, cls).__new__(cls)
            cls._instances[config_path]._config_path = config_path
        return cls._instances[config_path]

    def __init__(self, config_path=None):
        """
        Initialize the ConfigManager.

        Args:
            config_path: Path to the config file; the default path is used when None.
        """
        if config_path is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(current_dir, "..", "config", "qa_config.json")

        config_path = os.path.abspath(config_path)
        self._config_path = config_path

        if config_path not in self._configs:
            self._load_config()

    def _load_config(self):
        """Load the config file."""
        try:
            if not os.path.exists(self._config_path):
                raise FileNotFoundError(f"Config file does not exist: {self._config_path}")

            with open(self._config_path, 'r', encoding='utf-8') as f:
                self._configs[self._config_path] = json.load(f)
            self._validate_config()
        except Exception as e:
            logger.error(f"Error while loading config file: {e}")
            raise

    def _validate_config(self):
        """Validate that the config file contains all required fields."""
        required_fields = [
            "embed_model_name",
            "reranker_model",
            "reranker_k",
            "split_chunks_size",
            "split_overlap_size",
            "filter_threshold",
            "retrievel_k",
            "llm_model_name",
            "llm_api_key",
            "llm_base_url",
            "max_memory_size",
            "persist_path",
            "file_path",
            "query_db"
        ]
        config = self._configs[self._config_path]
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Config file is missing required field: {field}")

    @property
    def config(self):
        """Return the config for the current config path."""
        return self._configs[self._config_path]
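
# For reference, a sketch of the JSON layout that _validate_config expects. All values
# below are illustrative placeholders, not the project's real settings; the last two keys
# are read by ModelConfig below but are not part of the required-field check:
#
# {
#     "embed_model_name": "BAAI/bge-small-zh-v1.5",
#     "reranker_model": "BAAI/bge-reranker-base",
#     "reranker_k": 3,
#     "split_chunks_size": 500,
#     "split_overlap_size": 50,
#     "filter_threshold": 0.5,
#     "retrievel_k": 10,
#     "llm_model_name": "gpt-4o-mini",
#     "llm_api_key": "sk-...",
#     "llm_base_url": "https://api.openai.com/v1",
#     "max_memory_size": 5,
#     "persist_path": ["./chroma/intent_db"],
#     "file_path": ["./data/intent.csv"],
#     "query_db": "./chroma/intent_db",
#     "embed_model_kwargs": {"device": "cpu"},
#     "embed_encode_kwargs": {"normalize_embeddings": true}
# }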

class ModelConfig:
    """Model-related configuration."""
    def __init__(self, config):
        self.config = config
        self.embedding_model = self._init_embedding_model()
        self.rerank_model = self._init_rerank_model()
        self.hg_embedding = self._init_hg_embedding()
        self.llm = self._init_llm()

    def _init_embedding_model(self):
        return DefineEmbedding(self.config["embed_model_name"])

    def _init_rerank_model(self):
        return HuggingFaceCrossEncoder(model_name=self.config["reranker_model"])

    def _init_hg_embedding(self):
        return HuggingFaceEmbeddings(
            model_name=self.config["embed_model_name"],
            model_kwargs=self.config["embed_model_kwargs"],
            encode_kwargs=self.config["embed_encode_kwargs"]
        )

    def _init_llm(self):
        return ChatOpenAI(
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            model=self.config["llm_model_name"],
            api_key=self.config["llm_api_key"],
            base_url=self.config["llm_base_url"]
        )

class BaseConfig:
    def __init__(self, config_path=None):
        """
        Initialize the shared base configuration.

        Args:
            config_path: Path to the config file; the default path is used when None.
        """
        self.config_manager = ConfigManager(config_path)
        self.config = self.config_manager.config
        self.model_config = ModelConfig(self.config)

        # Models
        self.embedding_model = self.model_config.embedding_model
        self.rerank_model = self.model_config.rerank_model
        self.hg_embedding = self.model_config.hg_embedding
        self.llm = self.model_config.llm

        # Document chunking
        self.text_splitter = CharacterTextSplitter(
            separator="##",
            chunk_size=self.config["split_chunks_size"],
            chunk_overlap=self.config["split_overlap_size"],
            is_separator_regex=False
        )

        # Compressors and filters
        self.compressor = CrossEncoderReranker(
            model=self.rerank_model,
            top_n=self.config["reranker_k"]
        )
        self.llm_filter = LLMChainFilter.from_llm(self.llm)
        self.embed_filter = EmbeddingsFilter(
            embeddings=self.hg_embedding,
            similarity_threshold=self.config["filter_threshold"]
        )

        # Document compression pipeline
        self.pipeline_compressor = DocumentCompressorPipeline(
            transformers=[self.compressor]
        )

        # Conversation memory
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",
            return_messages=True,
            k=self.config["max_memory_size"]
        )
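
# Note: pipeline_compressor above applies only the cross-encoder reranker, while
# embed_filter and llm_filter are initialized alongside it. If cheaper pre-filtering
# before reranking were wanted, one possible (untested) composition inside __init__ is:
#     self.pipeline_compressor = DocumentCompressorPipeline(
#         transformers=[self.embed_filter, self.compressor]
#     )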

class BotMethod(BaseConfig):
    def __init__(self, config_path=None):
        """Initialize BotMethod with the shared bot configuration."""
        super().__init__(config_path)

class BaseMethod(BaseConfig):
    def __init__(self, config_path=None):
        """Initialize BaseMethod, including the configuration and the vector databases."""
        super().__init__(config_path)
        self._init_vector_db()

    def _init_vector_db(self) -> None:
        """Create or update the vector databases."""
        start_time = time.time()  # Start of vectorization
        try:
            file_paths = self.config["file_path"]
            persist_paths = self.config["persist_path"]
            if isinstance(file_paths, list) and isinstance(persist_paths, list):
                for file_path, persist_path in zip(file_paths, persist_paths):
                    # Check whether the vector database directory exists
                    if not os.path.exists(persist_path):
                        logger.info(f"Vector database directory does not exist, creating a new one: {persist_path}")
                        os.makedirs(persist_path, exist_ok=True)
                        self.update_vecdb(file_path, persist_path)
                    else:
                        # Compare the CSV file's modification time with the database directory's
                        csv_mtime = os.path.getmtime(file_path)
                        db_mtime = os.path.getmtime(persist_path)

                        if csv_mtime > db_mtime:
                            logger.info("CSV file has changed, rebuilding the vector database")
                            self.update_vecdb(file_path, persist_path)
                        else:
                            logger.info("Vector database is up to date, no rebuild needed")
            else:
                raise ValueError("file_path and persist_path must both be lists")
        except Exception as e:
            logger.error(f"Failed to initialize vector database: {e}")
            raise
        finally:
            end_time = time.time()  # End of vectorization
            logger.info(f"Vectorization took {end_time - start_time:.2f} seconds")

    def split_txt(self, file_path):
        """Read a .txt file and split it into Document chunks."""
        documents = []
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
            documents.append(text)
        return self.text_splitter.create_documents(documents)

    def split_csv(self, file_path):
        """Split a CSV file into per-first-column and per-row Document lists."""
        # Encoding note: the intent-recognition knowledge base is GBK; product-info test set 2 is UTF-8
        data = pd.read_csv(file_path, encoding="utf-8")
        headers = data.columns.tolist()

        # One document per cell of the first column
        first_column = data.iloc[:, 0].tolist()
        first_column_document = []
        for idx, text in enumerate(first_column):
            if isinstance(text, str) and text.strip():
                first_column_document.append(Document(metadata={'index': idx}, page_content=text))

        # One document per row, concatenating "header:value" pairs
        result_list = []
        for _, row in data.iterrows():
            row_content = []
            for header, value in zip(headers, row):
                if pd.notna(value):
                    row_content.append(f"{header}:{value}")
            result_list.append("。".join(row_content))
        first_row_document = []
        for idx, text in enumerate(result_list):
            if isinstance(text, str) and text.strip():
                first_row_document.append(Document(metadata={'index': idx}, page_content=text))

        logger.info(
            "\n-------------------------------------------------------------------------------\n"
            f" Intent-recognition knowledge base (first_column_document): \n {first_column_document} \n"
            " -------------------------------------------------------------------------------\n"
        )
        logger.info(
            "\n-------------------------------------------------------------------------------\n"
            f" Intent-recognition knowledge base (first_row_document): \n {first_row_document} \n"
            " -------------------------------------------------------------------------------\n"
        )
        return first_column_document, first_row_document
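
    # For illustration only (hypothetical row): a CSV row {"name": "Apple", "price": "10"}
    # becomes one row document with page_content "name:Apple。price:10" and the original
    # row index stored in metadata["index"], which csv_retriever later uses to map a
    # first-column hit back to its full row.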

    def update_vecdb(self, file_path, persist_path):
        """Split the source file and (re)build the Chroma vector database at persist_path."""
        _, ext = os.path.splitext(file_path)
        try:
            if ext == ".txt":
                chunks = self.split_txt(file_path)
            elif ext == ".csv":
                chunks, _ = self.split_csv(file_path)
            else:
                raise ValueError(f"Unsupported file type: {ext}")
        except Exception as e:
            logger.warning(f"Something went wrong while updating the vector db: {e}")
            raise
        vecdb = Chroma.from_documents(documents=chunks, embedding=self.embedding_model, persist_directory=persist_path)
        vecdb.persist()

    def retriever(self, query, query_db=None):
        """Retrieve documents for the query, then rerank them with the compression pipeline."""
        if query_db:
            vecdb = Chroma(persist_directory=query_db, embedding_function=self.embedding_model)
        else:
            vecdb = Chroma(persist_directory=self.config["query_db"], embedding_function=self.embedding_model)
        try:
            compression_retriever = ContextualCompressionRetriever(
                base_compressor=self.pipeline_compressor,
                base_retriever=vecdb.as_retriever(search_kwargs={"k": self.config["retrievel_k"]})
            )
            return compression_retriever.invoke(query)
        except Exception as e:
            logger.error(f"Error while building or running the retriever: {e}")
            raise

    def csv_retriever(self, query, query_db=None):
        """Retrieve from the CSV knowledge base and map the hit back to its full row document."""
        try:
            retriever_result = self.retriever(query, query_db)
            _, chunk_query = self.split_csv(self.config["file_path"][0])
            if not retriever_result:
                logger.warning("Retrieval returned no results, returning an empty list")
                return []
            else:
                retriever_index = retriever_result[0].metadata['index']
                retriever_context = [document for document in chunk_query if document.metadata['index'] == retriever_index]
                return retriever_context
        except Exception as e:
            logger.error(f"Error in csv_retriever: {e}")
            return []

class DefineEmbedding:
    """Minimal embedding wrapper exposing the interface LangChain vector stores expect."""
    def __init__(self, model_path):
        self.model = SentenceTransformer(model_path)

    def embed_documents(self, texts: list) -> list:
        embedding = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        return embedding.tolist()

    def embed_query(self, text: str) -> list:
        embedding = self.model.encode(text)
        return embedding.tolist()

    def __call__(self, texts: list) -> list:
        return self.embed_documents(texts)

if __name__ == "__main__":
    pass
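    # A minimal usage sketch, assuming the default qa_config.json and the CSV knowledge
    # base it points to exist; the query string below is purely illustrative.
    bot = BaseMethod()
    results = bot.csv_retriever("product price inquiry")
    for doc in results:
        logger.info(doc.page_content)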