from pymilvus import (
    MilvusClient, DataType, FieldSchema, CollectionSchema,
    AnnSearchRequest, WeightedRanker, Function, FunctionType,
)
from config import milvus_url, logger_search
from embedding_search import model_embedding_bge, bge_rerank
import time
import asyncio
import traceback
from typing_extensions import Annotated


async def create_collection(collection_name):
    client = MilvusClient(uri=milvus_url)
    if client.has_collection(collection_name=collection_name):
        # client.drop_collection(collection_name=collection_name)
        # print('ok')
        return False
    fields = [
        FieldSchema(name="client_id", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="filename_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535, is_primary=True, enable_analyzer=True),
        FieldSchema(name="content_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="content_sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
    ]
    # BM25 function: Milvus analyzes `content` and generates the sparse vectors itself,
    # so `content_sparse` is never supplied at insert time.
    bm25_function = Function(
        name="text_bm25_emb",
        input_field_names=["content"],
        output_field_names=["content_sparse"],
        function_type=FunctionType.BM25,
    )
    schema = CollectionSchema(fields=fields, enable_dynamic_field=True, auto_id=False)
    schema.add_function(bm25_function)

    index_params = client.prepare_index_params()
    index_params.add_index(field_name="client_id", index_type="INVERTED")
    index_params.add_index(field_name="file_name", index_type="INVERTED")
    index_params.add_index(
        field_name="content_embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )
    index_params.add_index(
        field_name="filename_embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )
    index_params.add_index(field_name="content_sparse", index_type="AUTOINDEX", metric_type="BM25")
    client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)
    return True


async def update_mulvus_file(client_id, file_name, chunks, collection_name='gloria_files'):
    try:
        start_time = time.time()
        # Connect to Milvus and make sure the target collection exists before loading it
        client = MilvusClient(uri=milvus_url)
        await create_collection(collection_name=collection_name)
        client.load_collection(collection_name=collection_name)
        # Delete any existing chunks of this file so a re-upload replaces the old data
        try:
            client.delete(collection_name=collection_name, filter=f'file_name == "{file_name}"')
        except Exception as e:
            logger_search.error(e)
        final_data = []
        filename_embedding = await asyncio.to_thread(model_embedding_bge, file_name)
        for chunk in chunks:
            chunk_embedding = await asyncio.to_thread(model_embedding_bge, chunk)
            final_data.append({
                'client_id': client_id,
                'file_name': file_name,
                'filename_embedding': filename_embedding,
                'content': chunk,
                'content_embedding': chunk_embedding,
            })
        await asyncio.to_thread(client.insert, collection_name=collection_name, data=final_data)
        status = {"result": "succeed"}
        end_time = time.time()
        logger_search.info(f"{file_name} embedded into Milvus in {end_time - start_time} s")
        client.release_collection(collection_name=collection_name)
        client.close()
        return status
    except Exception as e:
        logger_search.error(f"Failed to embed file {file_name} for client {client_id} into Milvus: {e}")
        traceback.print_exc()
        status = {"result": "failed"}
        return status


async def get_filedata_from_milvus(query, collection_name='gloria_files', file_name=None, threshold=0.5):
    try:
        client = MilvusClient(uri=milvus_url)
        query_embedding = await asyncio.to_thread(model_embedding_bge, query)
        time0 = time.time()
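        # Hybrid retrieval: three ANN requests over the same collection are fused by
        # WeightedRanker below: dense cosine search on the file-name vector, dense
        # cosine search on the chunk-content vector, and sparse BM25 keyword search
        # on the content_sparse field that Milvus derives from `content`.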
        client.load_collection(collection_name=collection_name)
        time1 = time.time()
        logger_search.info(f"load time: {time1 - time0} s")
        search_filter = f'file_name == "{file_name}"' if file_name else ''
        # ANN request 1: match the query against the file-name embeddings
        search_param_1 = {
            "data": [query_embedding],  # input must be a 2-D list of vectors
            "anns_field": "filename_embedding",  # file-name vector field
            "param": {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            },
            "limit": 300,
        }
        request_1 = AnnSearchRequest(**search_param_1)
        # ANN request 2: match the query against the chunk-content embeddings
        search_param_2 = {
            "data": [query_embedding],
            "anns_field": "content_embedding",  # chunk-content vector field
            "param": {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            },
            "limit": 300,
        }
        request_2 = AnnSearchRequest(**search_param_2)
        # ANN request 3: sparse BM25 keyword search; the raw query text is passed in,
        # since Milvus computes the sparse vector via the collection's BM25 function
        request_3 = AnnSearchRequest(
            data=[query],
            anns_field='content_sparse',
            param={
                "metric_type": 'BM25',
                "params": {'drop_ratio_search': 0.2},
            },
            limit=50,
        )
        # Hybrid search, fusing the three requests with weights
        # 0.2 (file name), 0.5 (content), 0.5 (BM25)
        rerank = WeightedRanker(0.2, 0.5, 0.5)
        res = await asyncio.to_thread(
            client.hybrid_search,
            collection_name=collection_name,
            reqs=[request_1, request_2, request_3],
            ranker=rerank,
            output_fields=['content', 'file_name'],
            filter=search_filter,
        )
        client.release_collection(collection_name=collection_name)
        client.close()
        final_answer = []
        for hits in res:
            for hit in hits:
                if hit.get('distance', 0) > threshold:
                    entity = hit.get('entity')
                    final_answer.append({'file_name': entity.get("file_name"), 'content': entity.get("content")})
    except Exception as e:
        logger_search.error(e)
        final_answer = []
    logger_search.info(f'Retrieved {len(final_answer)} records from the Milvus file knowledge base')
    return final_answer


async def rerank_file_data(query, data):
    lookup = {}  # single dict mapping chunk content back to its source item
    chunks = []
    final_chunks = {}
    for item in data:
        content = item.get('content', '')
        # keep a reference to the original item so it can be recovered after reranking
        lookup[content] = item
        chunks.append(content)
    # run the reranker; `indicates` and `socre` follow bge_rerank's signature in embedding_search
    logger_search.info(chunks)
    reranked = bge_rerank(query=query, indicates=chunks, socre=0.3)
    send_data = [lookup[chunk] for chunk in reranked if chunk in lookup]
    # de-duplicate while preserving the reranked order (a set comprehension would scramble it)
    seen = set()
    deduped = []
    for d in send_data:
        key = tuple(d.items())
        if key not in seen:
            seen.add(key)
            deduped.append(d)
    send_data = deduped
    logger_search.info(send_data)
    # merge chunks per file: the first chunk initializes the entry, later ones are appended
    for sed in send_data:
        file_name = sed.get('file_name')
        if file_name not in final_chunks:
            final_chunks[file_name] = sed.get('content')
        else:
            final_chunks[file_name] += sed.get('content')
    logger_search.info(final_chunks)
    search_results_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items())])
    return final_chunks, search_results_str


async def get_search_results(query: Annotated[str, "the question to search for"], summary=True):
    data = await get_filedata_from_milvus(query=query)
    final_data, search_str = await rerank_file_data(query=query, data=data)
    if not final_data:
        # fall back to the raw Milvus hits when the reranker filters everything out
        final_chunks = {}
        for sed in data:
            file_name = sed.get('file_name')
            if file_name not in final_chunks:
                final_chunks[file_name] = sed.get('content')
            else:
                final_chunks[file_name] += sed.get('content')
        search_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items())])
        final_data = final_chunks
    logger_search.info(search_str)
    return final_data, search_str


if __name__ == '__main__':
    from agent import get_content_summary
    # test query: "Are there requirements for the monthly stocktake schedule?"
    final_data, result = asyncio.run(get_search_results(query='每月盘点时间有要求吗?'))
    data, res = asyncio.run(get_content_summary(question='每月盘点时间有要求吗?', res_info=result, final_data=final_data))
    logger_search.info(data)
    logger_search.info(res)
    # x = asyncio.run(create_collection(collection_name='gloria_files'))
    # print(x)
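# Ingestion sketch (hypothetical client id, file name, and chunks): how a document
# would be embedded with update_mulvus_file before it becomes searchable through
# get_search_results.
#
#     chunks = ["first chunk of text ...", "second chunk of text ..."]  # pre-split document text
#     asyncio.run(update_mulvus_file(client_id='demo', file_name='demo.docx', chunks=chunks))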