file_process.py

import PyPDF2
import docx
import nltk, subprocess, os
from config import nltk_path, converter
from typing import List, Union
from pathlib import Path
import re
import pandas as pd
from nltk.tokenize import sent_tokenize
import spacy
from langdetect import detect
from milvus_process import update_mulvus_file
import fitz
from marker.output import text_from_rendered

try:
    nltk.data.find(nltk_path)
except LookupError:
    nltk.download('punkt')

try:
    nlp = spacy.load("zh_core_web_sm")
except OSError:
    nlp = None  # zh_core_web_sm model not installed; split_text_spacy() will fail without it


class DocumentProcessor:
    def __init__(self):
        pass

    def get_file_len(self, file_path: Union[str, Path]) -> int:
        # read_file() returns a list of chunks, so this counts chunks, not characters
        text = self.read_file(file_path)
        length = len(text)
        del text
        return length
    @staticmethod
    def convert_docx_to_doc(input_path):
        # Despite its name, this converts a legacy .doc file to .docx
        # using LibreOffice in headless mode.
        output_dir = os.path.dirname(input_path)
        command = [
            "soffice", "--headless", "--convert-to", "docx", input_path, "--outdir", output_dir
        ]
        subprocess.run(command, check=True)

    def _read_doc(self, file_path) -> List[str]:
        """Read a legacy .doc file by converting it to .docx first."""
        self.convert_docx_to_doc(file_path)
        old_file = Path(file_path)                # original .doc file
        new_file = old_file.with_suffix(".docx")  # converted .docx file
        if old_file.exists():                     # make sure the old file exists
            old_file.unlink()                     # delete the old file
        doc = docx.Document(new_file)             # read the .docx
        text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
        return self.create_chunks(text)
    def read_file(self, file_path: Union[str, Path]) -> List[str]:
        """
        Read documents in different formats.
        Args:
            file_path: path to the file
        Returns:
            List[str]: extracted text chunks
        """
        file_path = Path(file_path)
        extension = file_path.suffix.lower()
        if extension == '.pdf':
            return self._read_pdf(file_path)
        elif extension == '.docx':
            return self._read_docx(file_path)
        elif extension == '.doc':
            return self._read_doc(file_path)
        elif extension == '.txt':
            return self._read_txt(file_path)
        elif extension == '.csv':
            return self._read_csv(file_path)
        elif extension == '.xlsx':
            return self._read_excel(file_path)
        else:
            raise ValueError(f"Unsupported file format: {extension}")
    def _read_pdf(self, file_path) -> List[str]:
        """Read a PDF file."""
        rendered = converter(str(file_path))
        text, _, images = text_from_rendered(rendered)
        return self.create_chunks(text=text)

    def _read_docx(self, file_path: Path) -> List[str]:
        """Read a Word (.docx) document."""
        doc = docx.Document(file_path)
        text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
        return self.create_chunks(text)

    def _read_txt(self, file_path: Path) -> List[str]:
        """Read a plain-text file."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return self.create_chunks(file.read())

    def _read_excel(self, file_path: Path) -> List[str]:
        """Read an Excel workbook, sheet by sheet."""
        df = pd.read_excel(file_path, sheet_name=None)
        text = ""
        for sheet_name, sheet_df in df.items():
            text += f"\nSheet: {sheet_name}\n"
            text += sheet_df.to_csv(index=False, sep=' ', header=True)
        return self.create_chunks(text)

    def _read_csv(self, file_path: Path) -> List[str]:
        """Read a CSV file."""
        df = pd.read_csv(file_path)
        return self.create_chunks(df.to_csv(index=False, sep=' ', header=True))
    def _clean_text(self, text: str) -> str:
        """
        Clean the text:
        - remove redundant whitespace
        - normalize line breaks
        """
        # Collapse runs of whitespace into a single space
        text = re.sub(r'\s+', ' ', text)
        # Normalize line breaks
        text = text.replace('\r\n', '\n').replace('\r', '\n')
        # Drop empty lines
        text = '\n'.join(line.strip() for line in text.split('\n') if line.strip())
        return text.strip()
    def split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences.
        Args:
            text: input text
        Returns:
            List[str]: list of sentences
        """
        # Sentence segmentation via NLTK
        sentences = sent_tokenize(text)
        return sentences
    def force_split_sentence(self, sentence: str, max_length: int) -> List[str]:
        """
        Force-split an over-long sentence by character count.
        Args:
            sentence (str): input sentence
            max_length (int): maximum fragment length
        Returns:
            List[str]: list of sentence fragments
        """
        # Use punctuation marks as secondary split points
        punctuation = '。,;!?,.;!?'
        parts = []
        current_part = ''
        # Prefer splitting at punctuation
        chars = list(sentence)
        for char in chars:
            current_part += char
            # Split once the fragment reaches max_length and ends at punctuation,
            # or once it exceeds max_length by 20% while still searching for punctuation
            if (len(current_part) >= max_length and char in punctuation) or \
               (len(current_part) >= max_length * 1.2):
                parts.append(current_part)
                current_part = ''
        # Handle the remainder
        if current_part:
            # If the remainder is still too long, hard-split it by length
            while len(current_part) > max_length:
                parts.append(current_part[:max_length] + '...')
                current_part = '...' + current_part[max_length:]
            parts.append(current_part)
        return parts
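    # Illustrative behaviour (assumed inputs, not executed here): with max_length=10,
    # a long comma-separated sentence is cut at the first punctuation mark reached at or
    # after the 10th character, or unconditionally once it reaches 1.2 * max_length
    # characters; any trailing remainder longer than max_length is hard-cut, with '...'
    # markers joining the pieces.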
    def split_text_nltk(self, text: str, chunk_size: int = 1500, overlap_size: int = 100) -> List[str]:
        """
        Split text into chunks using NLTK sentence tokenization, with chunk overlap
        and over-long sentence handling.
        Args:
            text (str): input text
            chunk_size (int): approximate number of characters per chunk
            overlap_size (int): number of overlapping characters between adjacent chunks
        Returns:
            List[str]: list of text chunks
        """
        text = self._clean_text(text)
        sentences = nltk.sent_tokenize(text)
        chunks = self.process_sentences(sentences=sentences, chunk_size=chunk_size, overlap_size=overlap_size)
        return chunks
    def split_text_spacy(self, text: str, chunk_size: int = 500, overlap_size: int = 100) -> List[str]:
        """
        Split Chinese text into chunks using SpaCy sentence segmentation, with chunk
        overlap and over-long sentence handling.
        Args:
            text (str): input Chinese text
            chunk_size (int): approximate number of characters per chunk
            overlap_size (int): number of overlapping characters between adjacent chunks
        Returns:
            List[str]: list of text chunks
        """
        text = self._clean_text(text)
        doc = nlp(text)
        sentences = [sent.text for sent in doc.sents]
        chunks = self.process_sentences(sentences=sentences, chunk_size=chunk_size, overlap_size=overlap_size)
        return chunks
    def process_sentences(self, sentences, chunk_size: int = 500, overlap_size: int = 100):
        chunks = []
        current_chunk = []
        current_length = 0
        for sentence in sentences:
            # Handle over-long sentences
            if len(sentence) > chunk_size:
                # First flush whatever is already in the current chunk
                if current_chunk:
                    chunks.append("".join(current_chunk))
                    current_chunk = []
                    current_length = 0
                # Force-split the over-long sentence
                sentence_parts = self.force_split_sentence(sentence, chunk_size)
                for part in sentence_parts:
                    chunks.append(part)
                continue
            # Normal handling for sentences of ordinary length
            if current_length + len(sentence) <= chunk_size:
                current_chunk.append(sentence)
                current_length += len(sentence)
            else:
                if current_chunk:
                    chunks.append("".join(current_chunk))
                # Build the overlap from the tail of the previous chunk
                overlap_chars = 0
                overlap_sentences = []
                for prev_sentence in reversed(current_chunk):
                    if overlap_chars + len(prev_sentence) <= overlap_size:
                        overlap_sentences.insert(0, prev_sentence)
                        overlap_chars += len(prev_sentence)
                    else:
                        break
                current_chunk = overlap_sentences + [sentence]
                current_length = sum(len(s) for s in current_chunk)
        if current_chunk:
            chunks.append("".join(current_chunk))
        return chunks
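    # Illustrative behaviour (assumed values, not executed): with chunk_size=500 and
    # overlap_size=100, sentences are packed into a chunk until the next sentence would
    # push it past 500 characters; the next chunk then starts with the trailing sentences
    # of the previous chunk, up to 100 characters, so adjacent chunks share some context.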
    def create_chunks(self, text: str, chunk_size=300, overlap_size=100) -> List[str]:
        is_chinese = self.is_chinese_text(text)
        if is_chinese:
            # Detected as Chinese text: split with SpaCy
            chunks = self.split_text_spacy(text, chunk_size=chunk_size, overlap_size=overlap_size)
        else:
            # Detected as non-Chinese text: split with NLTK
            chunks = self.split_text_nltk(text, chunk_size=chunk_size, overlap_size=overlap_size)
        return chunks
    def is_chinese_text(self, text: str) -> bool:
        """
        Decide whether the text is predominantly Chinese.
        Args:
            text (str): input text
        Returns:
            bool: True if the text is Chinese, otherwise False
        """
        char_ratio = 0.0
        try:
            chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
            total_chars = len(re.findall(r'\w', text)) + chinese_chars
            char_ratio = chinese_chars / max(total_chars, 1)
            if char_ratio > 0.1:
                return True
            # Fall back to langdetect for language detection
            lang = detect(text)
            # If detection fails, fall back to the character ratio
            if not lang:
                raise Exception("Language detection failed")
            return lang in ('zh-cn', 'zh-tw', 'zh')
        except Exception:
            return char_ratio > 0.1
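    # Worked example (assumed input, and assuming every character matches \w): a
    # 100-character string with 15 CJK characters gives total_chars = 100 + 15 = 115
    # and char_ratio = 15 / 115 ≈ 0.13 > 0.1, so it is treated as Chinese without
    # ever calling langdetect.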
    def process_document(self, file_path: Union[str, Path], chunk_size=1000, overlap_size=250) -> List[str]:
        """
        Main entry point for processing a document.
        Args:
            file_path: path to the document
        Returns:
            List[str]: list of processed text chunks
        """
        # read_file() already returns chunks; join them back into plain text
        # before re-chunking with the sizes requested here.
        text = "".join(self.read_file(file_path))
        chunks = self.create_chunks(text=text, chunk_size=chunk_size, overlap_size=overlap_size)
        return chunks
if __name__ == '__main__':
    import asyncio

    processor = DocumentProcessor()
    # Process a document
    chunks = processor.read_file("./tests/test.pdf")
    # Print the results
    # for i, chunk in enumerate(chunks):
    #     print(f"Chunk {i+1}:")
    #     print(chunk)
    #     print(len(chunk))
    #     print("-" * 50)
    status = asyncio.run(update_mulvus_file(client_id='test', file_name='test.pdf', chunks=chunks))
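    # Alternative usage sketch (illustrative; assumes the same test file and Milvus setup):
    # process_document() re-chunks the extracted text with custom sizes before upload.
    # chunks = processor.process_document("./tests/test.pdf", chunk_size=1000, overlap_size=250)
    # status = asyncio.run(update_mulvus_file(client_id='test', file_name='test.pdf', chunks=chunks))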