import asyncio
import time
import traceback

from pymilvus import (
    AnnSearchRequest,
    CollectionSchema,
    DataType,
    FieldSchema,
    Function,
    FunctionType,
    MilvusClient,
    WeightedRanker,
)
from typing_extensions import Annotated

from config import milvus_url, logger_search
from embedding_search import model_embedding_bge, bge_rerank


async def create_collection(collection_name):
    """Create the file-search collection and its indexes; no-op if it already exists."""
    client = MilvusClient(uri=milvus_url)
    if client.has_collection(collection_name=collection_name):
        return False
    fields = [
        FieldSchema(name="client_id", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="filename_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535, is_primary=True, enable_analyzer=True),
        FieldSchema(name="content_embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="content_sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
    ]
    # BM25 function: Milvus derives the sparse vector from the analyzed `content`
    # field, so `content_sparse` must not be supplied at insert time.
    bm25_function = Function(
        name="text_bm25_emb",
        input_field_names=["content"],
        output_field_names=["content_sparse"],
        function_type=FunctionType.BM25,
    )
    schema = CollectionSchema(fields=fields, enable_dynamic_field=True, auto_id=False)
    schema.add_function(bm25_function)

    index_params = client.prepare_index_params()
    index_params.add_index(field_name="client_id", index_type="INVERTED")
    index_params.add_index(field_name="file_name", index_type="INVERTED")
    index_params.add_index(
        field_name="content_embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )
    index_params.add_index(
        field_name="filename_embedding",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 128},
    )
    index_params.add_index(
        field_name="content_sparse",
        index_type="AUTOINDEX",
        metric_type="BM25",
    )
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
        index_params=index_params,
    )
    return True
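
# Usage sketch (hypothetical collection name): creation is idempotent, so callers
# can safely run it before every write; it returns False when the collection
# already exists and True when it was just created.
#
#     created = asyncio.run(create_collection('demo_files'))
#     print('created new collection' if created else 'collection already existed')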


async def update_mulvus_file(client_id, file_name, chunks, collection_name='gloria_files'):
    """Embed a file's chunks and write them to Milvus, replacing any existing rows."""
    try:
        start_time = time.time()
        client = MilvusClient(uri=milvus_url)
        await create_collection(collection_name=collection_name)
        client.load_collection(collection_name=collection_name)
        # Delete any existing rows for this file so the insert acts as an upsert.
        try:
            client.delete(collection_name=collection_name, filter=f'file_name == "{file_name}"')
        except Exception as e:
            logger_search.error(f"Failed to delete old rows for {file_name}: {e}")
        final_data = []
        filename_embedding = await asyncio.to_thread(model_embedding_bge, file_name)
        for chunk in chunks:
            chunk_embedding = await asyncio.to_thread(model_embedding_bge, chunk)
            final_data.append({
                'client_id': client_id,
                'file_name': file_name,
                'filename_embedding': filename_embedding,
                'content': chunk,
                'content_embedding': chunk_embedding,
            })
        await asyncio.to_thread(client.insert, collection_name=collection_name, data=final_data)

        status = {"result": "succeed"}
        end_time = time.time()
        logger_search.info(f"{file_name} embedded into Milvus in {end_time - start_time} s")
        client.release_collection(collection_name=collection_name)
        client.close()
        return status
    except Exception as e:
        logger_search.error(f"Embedding file {file_name} for client {client_id} into Milvus failed: {e}")
        traceback.print_exc()
        status = {"result": "failed"}
        return status
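
# Usage sketch (hypothetical data): each chunk becomes one row keyed by its text,
# and re-uploading a file first deletes rows with the same file_name, so the call
# behaves as an upsert.
#
#     chunks = ['first paragraph ...', 'second paragraph ...']
#     status = asyncio.run(update_mulvus_file('client-001', 'handbook.txt', chunks))
#     # status -> {'result': 'succeed'} or {'result': 'failed'}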


async def get_filedata_from_milvus(query, collection_name='gloria_files', file_name=None, threshold=0.5):
    """Hybrid-search the file collection; keep hits whose fused score exceeds `threshold`."""
    try:
        client = MilvusClient(uri=milvus_url)
        query_embedding = await asyncio.to_thread(model_embedding_bge, query)
        time0 = time.time()
        client.load_collection(collection_name=collection_name)
        time1 = time.time()
        logger_search.info(f"load time: {time1 - time0} s")

        # Optional scoping to a single file, applied to each request via `expr`.
        search_filter = f'file_name == "{file_name}"' if file_name else None

        # ANN request against the file-name vector field.
        request_1 = AnnSearchRequest(
            data=[query_embedding],  # input must be a list of vectors
            anns_field="filename_embedding",
            param={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=300,
            expr=search_filter,
        )
        # ANN request against the chunk-content vector field.
        request_2 = AnnSearchRequest(
            data=[query_embedding],
            anns_field="content_embedding",
            param={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=300,
            expr=search_filter,
        )
        # Full-text (BM25) request against the sparse field; takes the raw query text.
        req_temp = AnnSearchRequest(
            data=[query],
            anns_field='content_sparse',
            param={"metric_type": 'BM25', "params": {'drop_ratio_search': 0.2}},
            limit=50,
            expr=search_filter,
        )
        # Fuse the three result lists (file name 0.2, dense content 0.5, BM25 0.5).
        rerank = WeightedRanker(0.2, 0.5, 0.5)
        res = await asyncio.to_thread(
            client.hybrid_search,
            collection_name=collection_name,
            reqs=[request_1, request_2, req_temp],
            ranker=rerank,
            output_fields=['content', 'file_name'],
        )
        client.release_collection(collection_name=collection_name)
        client.close()
        final_answer = []
        for hits in res:
            for hit in hits:
                if hit.get('distance', 0) > threshold:
                    entity = hit.get('entity')
                    final_answer.append({'file_name': entity.get("file_name"), 'content': entity.get("content")})
    except Exception as e:
        logger_search.error(e)
        final_answer = []
    logger_search.info(f'Found {len(final_answer)} entries in the Milvus file knowledge base')
    return final_answer
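
# Usage sketch (hypothetical query): one hybrid_search call fans out to the two
# dense fields plus the BM25 sparse field and fuses the scores with WeightedRanker;
# hits scoring at or below `threshold` are dropped.
#
#     hits = asyncio.run(get_filedata_from_milvus('stocktake schedule', file_name=None))
#     # hits -> [{'file_name': ..., 'content': ...}, ...]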


async def rerank_file_data(query, data):
    lookup = {}  # maps chunk text back to its source row
    chunks = []
    final_chunks = {}
    for item in data:
        content = item.get('content', '')
        # Keep a reference to the original row so no second lookup is needed.
        lookup[content] = item
        chunks.append(content)
    # Rerank the chunks against the query (parameter names follow bge_rerank's signature).
    logger_search.info(chunks)
    reranked = bge_rerank(query=query, indicates=chunks, socre=0.3)

    send_data = [lookup[chunk] for chunk in reranked if chunk in lookup]

    # Deduplicate while preserving the reranked order.
    seen = set()
    deduped = []
    for d in send_data:
        key = tuple(d.items())
        if key not in seen:
            seen.add(key)
            deduped.append(d)
    send_data = deduped
    logger_search.info(send_data)
    # Group chunk text by file name, concatenating chunks from the same file.
    for sed in send_data:
        name = sed.get('file_name')
        if name not in final_chunks:
            final_chunks[name] = ''
        final_chunks[name] += sed.get('content')
    logger_search.info(final_chunks)
    search_results_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items())])
    return final_chunks, search_results_str
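
# Usage sketch: the returned mapping groups reranked chunks by file name, and the
# string form is pre-numbered for direct use in a prompt.
#
#     final_chunks, text = asyncio.run(rerank_file_data('stocktake schedule', hits))
#     # final_chunks -> {'handbook.txt': 'chunk one ...chunk two ...'}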


async def get_search_results(query: Annotated[str, "the question to search for"], summary=True):
    data = await get_filedata_from_milvus(query=query)
    final_data, search_str = await rerank_file_data(query=query, data=data)
    if not final_data:
        # Reranking dropped every hit; fall back to the raw Milvus results.
        final_chunks = {}
        for sed in data:
            name = sed.get('file_name')
            if name not in final_chunks:
                final_chunks[name] = ''
            final_chunks[name] += sed.get('content')
        search_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k, s) in enumerate(final_chunks.items())])
        final_data = final_chunks
    logger_search.info(search_str)
    return final_data, search_str


if __name__ == '__main__':
    from agent import get_content_summary

    final_data, result = asyncio.run(get_search_results(query='每月盘点时间有要求吗?'))
    data, res = asyncio.run(get_content_summary(question='每月盘点时间有要求吗?', res_info=result, final_data=final_data))
    logger_search.info(data)
    logger_search.info(res)