|
@@ -1,5 +1,5 @@
|
|
|
from pymilvus import DataType, FieldSchema, CollectionSchema, AnnSearchRequest, WeightedRanker, Function, FunctionType
|
|
|
-from config import milvus_url, logger_search
|
|
|
+from config import milvus_url
|
|
|
from embedding_search import model_embedding_bge, bge_rerank
|
|
|
import time, asyncio
|
|
|
import traceback
|
|
@@ -93,13 +93,13 @@ async def update_mulvus_file(client_id, file_name, chunks, collection_name='glor
|
|
|
|
|
|
status = {"result": "succeed"}
|
|
|
end_time = time.time()
|
|
|
- logger_search.info(f"{file_name} 嵌入milvus单次耗时: {end_time - start_time} s")
|
|
|
+
|
|
|
client.release_collection(collection_name=collection_name)
|
|
|
client.close()
|
|
|
return status
|
|
|
|
|
|
except Exception as e:
|
|
|
- logger_search.error(f"处理 {client_id} 的文件 {file_name} 嵌入milvus报错如下:{e}")
|
|
|
+
|
|
|
traceback.print_exc()
|
|
|
status = {"result": "failed"}
|
|
|
return status
|
|
@@ -112,7 +112,7 @@ async def get_filedata_from_milvus(query, collection_name='gloria_files', file_n
|
|
|
time0 = time.time()
|
|
|
client.load_collection(collection_name=collection_name)
|
|
|
time1 = time.time()
|
|
|
- logger_search.info(f"load 时间: {time1 - time0} s")
|
|
|
+
|
|
|
|
|
|
search_filter = f'file_name == "{file_name}"' if file_name else ''
|
|
|
|
|
@@ -170,9 +170,9 @@ async def get_filedata_from_milvus(query, collection_name='gloria_files', file_n
|
|
|
entity = hit.get('entity')
|
|
|
final_answer.append({'file_name': entity.get("file_name"), 'content': entity.get("content")})
|
|
|
except Exception as e:
|
|
|
- logger_search.error(e)
|
|
|
+
|
|
|
final_answer = []
|
|
|
- logger_search.info(f'从milvus文件知识库搜索到{len(final_answer)}条信息')
|
|
|
+
|
|
|
return final_answer
|
|
|
|
|
|
async def rerank_file_data(query, data):
|
|
@@ -185,7 +185,7 @@ async def rerank_file_data(query, data):
|
|
|
lookup[content] = item
|
|
|
chunks.append(content)
|
|
|
# 执行重排序
|
|
|
- logger_search.info(chunks)
|
|
|
+
|
|
|
reranked = bge_rerank(query=query, indicates=chunks, socre=0.3)
|
|
|
|
|
|
|
|
@@ -195,12 +195,12 @@ async def rerank_file_data(query, data):
|
|
|
|
|
|
|
|
|
send_data = [dict(t) for t in {tuple(d.items()) for d in send_data}]
|
|
|
- logger_search.info(send_data)
|
|
|
+
|
|
|
for sed in send_data:
|
|
|
if not sed.get('file_name') in final_chunks:
|
|
|
final_chunks[sed.get('file_name')] = sed.get('content')
|
|
|
final_chunks[sed.get('file_name')] += sed.get('content')
|
|
|
- logger_search.info(final_chunks)
|
|
|
+
|
|
|
search_results_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k,s) in enumerate(final_chunks.items())])
|
|
|
|
|
|
return final_chunks, search_results_str
|
|
@@ -217,7 +217,7 @@ async def get_search_results(query: Annotated[str, "需要搜索的问题"], sum
|
|
|
search_str = "\n".join([f"[{i+1}]: \n {k}: \n{s} \n" for i, (k,s) in enumerate(final_chunks.items())])
|
|
|
final_data = final_chunks
|
|
|
|
|
|
- logger_search.info(search_str)
|
|
|
+
|
|
|
return final_data, search_str
|
|
|
|
|
|
|
|
@@ -225,8 +225,8 @@ if __name__ == '__main__':
|
|
|
from agent import get_content_summary
|
|
|
final_data, result = asyncio.run(get_search_results(query='每月盘点时间有要求吗?'))
|
|
|
data, res = asyncio.run(get_content_summary(question='每月盘点时间有要求吗?',res_info=result, final_data=final_data))
|
|
|
- logger_search.info(data)
|
|
|
- logger_search.info(res)
|
|
|
+
|
|
|
+
|
|
|
# x = asyncio.run(create_collection(collection_name='gloria_files'))
|
|
|
# print(x)
|
|
|
pass
|