rag_config.py

# Standard library imports
import os
import time
import json
import warnings
from uuid import uuid4
from typing import List, Dict, Any, Optional

# Make the project root importable
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Third-party imports
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# LangChain core components
from langchain.schema import Document
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

# LangChain text splitting
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter
)

# LangChain vector stores
from langchain_community.vectorstores import (
    FAISS,
    Chroma
)

# LangChain document loaders
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    JSONLoader,
    DirectoryLoader,
    UnstructuredMarkdownLoader
)

# LangChain retrieval and compression
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    CrossEncoderReranker,
    EmbeddingsFilter,
    DocumentCompressorPipeline,
    LLMChainFilter
)

# LangChain models and embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# LangChain conversation memory
from langchain.memory import (
    ConversationBufferWindowMemory,
    ConversationSummaryMemory
)

# Local module imports
from utils.logger_config import setup_logger

# Global configuration
warnings.filterwarnings("ignore")
logger = setup_logger(__name__)
load_dotenv()
class ConfigManager:
    """Singleton-per-config-file manager that loads and validates JSON configs."""

    _instances = {}  # one instance per resolved config path
    _configs = {}    # cached config contents, keyed by config path

    @staticmethod
    def _resolve_config_path(config_path=None):
        """Resolve the config path, defaulting to <project>/config/qa_config.json."""
        if config_path is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(current_dir, "..", "config", "qa_config.json")
        return os.path.abspath(config_path)

    def __new__(cls, config_path=None):
        """
        Create (or reuse) the ConfigManager instance for the given config file.
        Args:
            config_path: path to the config file; if None, the default path is used.
        Returns:
            The ConfigManager instance for that path.
        """
        config_path = cls._resolve_config_path(config_path)
        if config_path not in cls._instances:
            cls._instances[config_path] = super(ConfigManager, cls).__new__(cls)
            cls._instances[config_path]._config_path = config_path
        return cls._instances[config_path]

    def __init__(self, config_path=None):
        """
        Initialize the ConfigManager.
        Args:
            config_path: path to the config file; if None, the default path is used.
        """
        config_path = self._resolve_config_path(config_path)
        self._config_path = config_path
        if config_path not in self._configs:
            self._load_config()

    def _load_config(self):
        """Load the config file from disk."""
        try:
            if not os.path.exists(self._config_path):
                raise FileNotFoundError(f"Config file does not exist: {self._config_path}")
            with open(self._config_path, "r", encoding="utf-8") as f:
                self._configs[self._config_path] = json.load(f)
            self._validate_config()
        except Exception as e:
            logger.error(f"Error while loading config file: {e}")
            raise

    def _validate_config(self):
        """Check that all required fields are present in the config."""
        required_fields = [
            "embed_model_name",
            "reranker_model",
            "reranker_k",
            "split_chunks_size",
            "split_overlap_size",
            "filter_threshold",
            "retrievel_k",
            "llm_model_name",
            "llm_api_key",
            "llm_base_url",
            "max_memory_size",
            "persist_path",
            "file_path",
            "query_db"
        ]
        config = self._configs[self._config_path]
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Config file is missing required field: {field}")

    @property
    def config(self):
        """Return the currently loaded config."""
        return self._configs[self._config_path]
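# A hypothetical qa_config.json illustrating the fields validated above
# (all values are placeholders, not the project's real settings):
#
# {
#     "embed_model_name": "BAAI/bge-small-zh-v1.5",
#     "reranker_model": "BAAI/bge-reranker-base",
#     "reranker_k": 3,
#     "split_chunks_size": 500,
#     "split_overlap_size": 50,
#     "filter_threshold": 0.6,
#     "retrievel_k": 10,
#     "llm_model_name": "gpt-4o-mini",
#     "llm_api_key": "sk-...",
#     "llm_base_url": "https://api.openai.com/v1",
#     "max_memory_size": 5,
#     "persist_path": ["./chroma/intent_db", "./chroma/product_db"],
#     "file_path": ["./data/intent.csv", "./data/products.csv"],
#     "query_db": "./chroma/intent_db",
#     "embed_model_kwargs": {"device": "cpu"},
#     "embed_encode_kwargs": {"normalize_embeddings": true}
# }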
class ModelConfig:
    """Model-related configuration: embeddings, reranker and LLM."""

    def __init__(self, config):
        self.config = config
        self.embedding_model = self._init_embedding_model()
        self.rerank_model = self._init_rerank_model()
        self.hg_embedding = self._init_hg_embedding()
        self.llm = self._init_llm()

    def _init_embedding_model(self):
        # SentenceTransformer wrapper used to embed documents for the Chroma stores
        return DefineEmbedding(self.config["embed_model_name"])

    def _init_rerank_model(self):
        # Cross-encoder used to rerank retrieved chunks
        return HuggingFaceCrossEncoder(model_name=self.config["reranker_model"])
    def _init_hg_embedding(self):
        # HuggingFace embeddings backing the EmbeddingsFilter; assumes the embedding
        # checkpoint (not the reranker) is the intended model here
        return HuggingFaceEmbeddings(
            model_name=self.config["embed_model_name"],
            model_kwargs=self.config["embed_model_kwargs"],
            encode_kwargs=self.config["embed_encode_kwargs"]
        )
    def _init_llm(self):
        return ChatOpenAI(
            temperature=0,
            max_tokens=None,
            timeout=None,
            max_retries=2,
            model=self.config["llm_model_name"],
            api_key=self.config["llm_api_key"],
            base_url=self.config["llm_base_url"]
        )
class BaseConfig:
    def __init__(self, config_path=None):
        """
        Initialize the base configuration.
        Args:
            config_path: path to the config file; if None, the default path is used.
        """
        self.config_manager = ConfigManager(config_path)
        self.config = self.config_manager.config
        self.model_config = ModelConfig(self.config)
        # Models
        self.embedding_model = self.model_config.embedding_model
        self.rerank_model = self.model_config.rerank_model
        self.hg_embedding = self.model_config.hg_embedding
        self.llm = self.model_config.llm
        # Document chunking
        self.text_splitter = CharacterTextSplitter(
            separator="##",
            chunk_size=self.config["split_chunks_size"],
            chunk_overlap=self.config["split_overlap_size"],
            is_separator_regex=False
        )
        # Compressor and filters
        self.compressor = CrossEncoderReranker(
            model=self.rerank_model,
            top_n=self.config["reranker_k"]
        )
        self.llm_filter = LLMChainFilter.from_llm(self.llm)
        self.embed_filter = EmbeddingsFilter(
            embeddings=self.hg_embedding,
            similarity_threshold=self.config["filter_threshold"]
        )
        # Document compression pipeline
        self.pipeline_compressor = DocumentCompressorPipeline(
            transformers=[self.compressor]
        )
        # Conversation memory
        self.memory = ConversationBufferWindowMemory(
            memory_key="chat_history",
            return_messages=True,
            k=self.config["max_memory_size"]
        )
class BotMethod(BaseConfig):
    def __init__(self, config_path=None):
        """Initialize BotMethod with the shared bot configuration."""
        super().__init__(config_path)
class BaseMethod(BaseConfig):
    def __init__(self, config_path=None):
        """Initialize BaseMethod, including configuration and vector databases."""
        super().__init__(config_path)
        self._init_vector_db()

    def _init_vector_db(self) -> None:
        """Initialize or update the vector databases."""
        start_time = time.time()  # start of vectorization
        try:
            file_paths = self.config["file_path"]
            persist_paths = self.config["persist_path"]
            if isinstance(file_paths, list) and isinstance(persist_paths, list):
                for file_path, persist_path in zip(file_paths, persist_paths):
                    # Create the vector store directory if it does not exist yet
                    if not os.path.exists(persist_path):
                        logger.info(f"Vector store directory not found, creating a new one: {persist_path}")
                        os.makedirs(persist_path, exist_ok=True)
                        self.update_vecdb(file_path, persist_path)
                    else:
                        # Compare the source file's and the vector store's modification times
                        csv_mtime = os.path.getmtime(file_path)
                        db_mtime = os.path.getmtime(persist_path)
                        if csv_mtime > db_mtime:
                            logger.info("Source CSV has changed, rebuilding the vector store")
                            self.update_vecdb(file_path, persist_path)
                        else:
                            logger.info("Vector store is up to date, no rebuild needed")
            else:
                raise ValueError("file_path and persist_path must both be lists")
        except Exception as e:
            logger.error(f"Failed to initialize vector databases: {e}")
            raise
        finally:
            end_time = time.time()  # end of vectorization
            logger.info(f"Vectorization took {end_time - start_time:.2f} seconds")
    def split_txt(self, file_path):
        """Read a plain-text file and split it into chunks with the text splitter."""
        documents = []
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
            documents.append(text)
        return self.text_splitter.create_documents(documents)

    def split_csv(self, file_path):
        """Build per-cell and per-row documents from a CSV knowledge base."""
        # GBK for the intent-recognition knowledge base; UTF-8 for product-info test set 2
        data = pd.read_csv(file_path, encoding="utf-8")
        headers = data.columns.tolist()
        # One document per value in the first column
        first_column = data.iloc[:, 0].tolist()
        first_column_document = []
        for idx, text in enumerate(first_column):
            if isinstance(text, str) and text.strip():
                first_column_document.append(Document(metadata={'index': idx}, page_content=text))
        # One document per row, concatenating "header:value" pairs
        result_list = []
        for _, row in data.iterrows():
            row_content = []
            for header, value in zip(headers, row):
                if pd.notna(value):
                    row_content.append(f"{header}:{value}")
            result_list.append("。".join(row_content))
        first_row_document = []
        for idx, text in enumerate(result_list):
            if isinstance(text, str) and text.strip():
                first_row_document.append(Document(metadata={'index': idx}, page_content=text))
        logger.info(f"\n-------------------------------------------------------------------------------\n Intent-recognition knowledge base (first_column_document): \n {first_column_document} \n -------------------------------------------------------------------------------\n")
        logger.info(f"\n-------------------------------------------------------------------------------\n Intent-recognition knowledge base (first_row_document): \n {first_row_document} \n -------------------------------------------------------------------------------\n")
        return first_column_document, first_row_document
    def update_vecdb(self, file_path, persist_path):
        """(Re)build the Chroma vector store for a single source file."""
        _, ext = os.path.splitext(file_path)
        try:
            if ext == ".txt":
                chunks = self.split_txt(file_path)
            elif ext == ".csv":
                chunks, _ = self.split_csv(file_path)
            else:
                raise ValueError(f"Unsupported file type: {ext}")
        except Exception as e:
            logger.warning(f"Something went wrong while splitting documents for the vector db: {e}")
            raise
        vecdb = Chroma.from_documents(documents=chunks, embedding=self.embedding_model, persist_directory=persist_path)
        vecdb.persist()
    def retriever(self, query, query_db=None):
        """Retrieve and rerank documents for a query from the given (or default) vector store."""
        persist_dir = query_db if query_db else self.config["query_db"]
        vecdb = Chroma(persist_directory=persist_dir, embedding_function=self.embedding_model)
        try:
            compression_retriever = ContextualCompressionRetriever(
                base_compressor=self.pipeline_compressor,
                base_retriever=vecdb.as_retriever(search_kwargs={"k": self.config["retrievel_k"]})
            )
            return compression_retriever.invoke(query)
        except Exception as e:
            logger.error(f"Error while running the retriever: {e}")
            raise

    def csv_retriever(self, query, query_db=None):
        """Map the top retrieved first-column hit back to its full CSV row document."""
        try:
            retriever_result = self.retriever(query, query_db)
            _, chunk_query = self.split_csv(self.config["file_path"][0])
            if not retriever_result:
                logger.warning("Retrieval returned no results, returning an empty list")
                return []
            retriever_index = retriever_result[0].metadata['index']
            retriever_context = [document for document in chunk_query if document.metadata['index'] == retriever_index]
            return retriever_context
        except Exception as e:
            logger.error(f"Error in csv_retriever: {e}")
            return []
class DefineEmbedding:
    """Minimal embedding wrapper exposing the embed_documents/embed_query interface."""

    def __init__(self, model_path):
        self.model = SentenceTransformer(model_path)

    def embed_documents(self, texts: list) -> list:
        embedding = self.model.encode(texts, batch_size=32, show_progress_bar=True)
        return embedding.tolist()

    def embed_query(self, text: str) -> list:
        embedding = self.model.encode(text)
        return embedding.tolist()

    def __call__(self, texts: list) -> list:
        return self.embed_documents(texts)
if __name__ == "__main__":
    pass
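# --- Usage sketch (illustrative only) ---
# A minimal example of how BaseMethod might be driven once qa_config.json and the
# source files it points to are in place; the query string is a placeholder:
#
#     bot = BaseMethod()                            # loads config, builds/refreshes the Chroma stores
#     hits = bot.retriever("example question")      # reranked chunks from config["query_db"]
#     rows = bot.csv_retriever("example question")  # full CSV row(s) matching the top hit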