# milvus_process.py
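"""Milvus helpers: collection creation, file-chunk ingestion, and hybrid (dense + BM25) retrieval with reranking."""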

import time
import asyncio
import traceback

from pymilvus import (
    MilvusClient, DataType, FieldSchema, CollectionSchema,
    AnnSearchRequest, WeightedRanker, Function, FunctionType,
)
from typing_extensions import Annotated

from config import milvus_url, logger_search
from embedding_search import model_embedding_bge, bge_rerank


async def create_collection(collection_name):
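    """Create the collection with dense, sparse (BM25) and scalar fields.

    Returns False if the collection already exists, True after creating it.
    """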
    client = MilvusClient(uri=milvus_url)
    if client.has_collection(collection_name=collection_name):
        # client.drop_collection(collection_name=collection_name)
        return False
    fields = [
        FieldSchema(name="client_id", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="filename_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535, is_primary=True, enable_analyzer=True),
        FieldSchema(name="content_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="content_sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
    ]
    # BM25 function: Milvus derives the sparse vector from the analyzed `content` field.
    bm25_function = Function(
        name="text_bm25_emb",
        input_field_names=["content"],
        output_field_names=["content_sparse"],
        function_type=FunctionType.BM25,
    )
    schema = CollectionSchema(fields=fields, enable_dynamic_field=True, auto_id=False)
    schema.add_function(bm25_function)
    index_params = client.prepare_index_params()
    index_params.add_index(field_name="client_id", index_type="INVERTED")
    index_params.add_index(field_name="file_name", index_type="INVERTED")
    index_params.add_index(
        field_name="content_embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )
    index_params.add_index(
        field_name="filename_embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )
    index_params.add_index(
        field_name="content_sparse",
        index_type="AUTOINDEX",
        metric_type="BM25",
    )
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params,
    )
    return True


async def update_mulvus_file(client_id, file_name, chunks, collection_name='gloria_files'):
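    """Embed file_name and its content chunks, then (re)insert them into the collection.

    Existing rows for the same file_name are deleted first, so re-uploading a file replaces it.
    Returns {"result": "succeed"} or {"result": "failed"}.
    """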
    try:
        start_time = time.time()
        client = MilvusClient(uri=milvus_url)
        await create_collection(collection_name=collection_name)
        client.load_collection(collection_name=collection_name)
        # Remove any existing rows for this file before re-inserting.
        try:
            client.delete(collection_name=collection_name, filter=f'file_name == "{file_name}"')
        except Exception as e:
            logger_search.error(f"Failed to delete old rows for {file_name}: {e}")
        final_data = []
        filename_embedding = await asyncio.to_thread(model_embedding_bge, file_name)
        for chunk in chunks:
            chunk_embedding = await asyncio.to_thread(model_embedding_bge, chunk)
            final_data.append({
                'client_id': client_id,
                'file_name': file_name,
                'filename_embedding': filename_embedding,
                'content': chunk,
                'content_embedding': chunk_embedding,
            })
        await asyncio.to_thread(client.insert, collection_name=collection_name, data=final_data)
        status = {"result": "succeed"}
        end_time = time.time()
        logger_search.info(f"{file_name} Milvus ingestion took {end_time - start_time} s")
        client.release_collection(collection_name=collection_name)
        client.close()
        return status
    except Exception as e:
        logger_search.error(f"Embedding file {file_name} for client {client_id} into Milvus failed: {e}")
        traceback.print_exc()
        status = {"result": "failed"}
        return status


async def get_filedata_from_milvus(query, collection_name='gloria_files', file_name=None, threshold=0.5):
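    """Run a hybrid search (filename vector + content vector + BM25 full text) for `query`.

    Only hits whose fused score exceeds `threshold` are kept; an optional `file_name`
    restricts the search to a single file. Returns [{'file_name': ..., 'content': ...}, ...].
    """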
    try:
        client = MilvusClient(uri=milvus_url)
        query_embedding = await asyncio.to_thread(model_embedding_bge, query)
        time0 = time.time()
        client.load_collection(collection_name=collection_name)
        time1 = time.time()
        logger_search.info(f"load time: {time1 - time0} s")
        search_filter = f'file_name == "{file_name}"' if file_name else ''
        # ANN request against the file-name vector field (data must be a 2-D list of vectors).
        search_param_1 = {
            "data": [query_embedding],
            "anns_field": "filename_embedding",
            "param": {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            },
            "limit": 300,
        }
        request_1 = AnnSearchRequest(**search_param_1)
        # ANN request against the chunk-content vector field.
        search_param_2 = {
            "data": [query_embedding],
            "anns_field": "content_embedding",
            "param": {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            },
            "limit": 300,
        }
        request_2 = AnnSearchRequest(**search_param_2)
        # Full-text (BM25) request; it takes the raw query string, not an embedding.
        req_temp = AnnSearchRequest(
            data=[query],
            anns_field='content_sparse',
            param={
                "metric_type": 'BM25',
                "params": {'drop_ratio_search': 0.2},
            },
            limit=50,
        )
        # Hybrid search with weights 0.2 (file name), 0.5 (content), 0.5 (BM25), in request order.
        rerank = WeightedRanker(0.2, 0.5, 0.5)
        res = await asyncio.to_thread(
            client.hybrid_search,
            collection_name=collection_name,
            reqs=[request_1, request_2, req_temp],
            ranker=rerank,
            output_fields=['content', 'file_name'],
            filter=search_filter,
        )
        client.release_collection(collection_name=collection_name)
        client.close()
        final_answer = []
        for hits in res:
            for hit in hits:
                if hit.get('distance', 0) > threshold:
                    entity = hit.get('entity')
                    final_answer.append({'file_name': entity.get("file_name"), 'content': entity.get("content")})
    except Exception as e:
        logger_search.error(e)
        final_answer = []
    logger_search.info(f'Retrieved {len(final_answer)} records from the Milvus file knowledge base')
    return final_answer


async def rerank_file_data(query, data):
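    """Rerank Milvus hits with bge_rerank, deduplicate them, and merge chunks per file.

    Returns (final_chunks, search_results_str): a {file_name: concatenated_content} dict
    and a numbered plain-text rendering of it for prompting.
    """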
    lookup = {}  # maps chunk content back to its source record, avoiding a later search
    chunks = []
    final_chunks = {}
    for item in data:
        content = item.get('content', '')
        lookup[content] = item
        chunks.append(content)
    # Rerank the candidate chunks against the query (keyword spellings follow bge_rerank's signature).
    logger_search.info(chunks)
    reranked = bge_rerank(query=query, indicates=chunks, socre=0.3)
    send_data = [lookup[chunk] for chunk in reranked if chunk in lookup]
    # Deduplicate while preserving the reranked order.
    seen = set()
    deduped = []
    for d in send_data:
        key = tuple(d.items())
        if key not in seen:
            seen.add(key)
            deduped.append(d)
    send_data = deduped
    logger_search.info(send_data)
    # Concatenate the surviving chunks per file.
    for sed in send_data:
        if sed.get('file_name') not in final_chunks:
            final_chunks[sed.get('file_name')] = ''
        final_chunks[sed.get('file_name')] += sed.get('content')
    logger_search.info(final_chunks)
    search_results_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items())])
    return final_chunks, search_results_str


async def get_search_results(query: Annotated[str, "the question to search for"], summary=True):
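    """Retrieve chunks from Milvus for `query` and rerank them; fall back to the raw hits if reranking filters everything out."""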
    data = await get_filedata_from_milvus(query=query)
    final_data, search_str = await rerank_file_data(query=query, data=data)
    if not final_data:
        # Fall back to the unreranked Milvus hits, concatenated per file.
        final_chunks = {}
        for sed in data:
            if sed.get('file_name') not in final_chunks:
                final_chunks[sed.get('file_name')] = ''
            final_chunks[sed.get('file_name')] += sed.get('content')
        search_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items())])
        final_data = final_chunks
    logger_search.info(search_str)
    return final_data, search_str


if __name__ == '__main__':
    from agent import get_content_summary

    final_data, result = asyncio.run(get_search_results(query='每月盘点时间有要求吗?'))
    data, res = asyncio.run(get_content_summary(question='每月盘点时间有要求吗?', res_info=result, final_data=final_data))
    logger_search.info(data)
    logger_search.info(res)
    # x = asyncio.run(create_collection(collection_name='gloria_files'))
    # print(x)
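
    # Minimal ingestion sketch (client_id, file_name and chunks below are hypothetical,
    # for illustration only; chunking is assumed to happen upstream):
    # chunks = ['first chunk text', 'second chunk text']
    # status = asyncio.run(update_mulvus_file(client_id='demo', file_name='demo.txt', chunks=chunks))
    # logger_search.info(status)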