import torch
import numpy as np

from config import bge_model, bge_rerank_model, bge_rerank_tokenizer, logger_search
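
# config is expected to provide: bge_model (a loaded BGE embedding model such
# as FlagEmbedding's BGEM3FlagModel), bge_rerank_model / bge_rerank_tokenizer
# (a Hugging Face cross-encoder reranker and its tokenizer), and
# logger_search (a standard logging.Logger).
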
def model_embedding_bge(chunk):
    """Embed a string or a list of strings with the BGE model.

    Returns a single vector (list of floats) for a string input, or a
    list of vectors for a list input.
    """
    try:
        sentences = [chunk] if isinstance(chunk, str) else chunk

        def _encode():
            embeddings = bge_model.encode(
                sentences,
                batch_size=12,
                max_length=8192,
            )['dense_vecs']
            return [emb.tolist() for emb in embeddings]

        # Call bge_model.encode defensively: if the model exposes a
        # "use_fp16" attribute, disable it temporarily and restore it
        # after a successful call.
        temp_use_fp16 = None
        if hasattr(bge_model, 'use_fp16'):
            temp_use_fp16 = bge_model.use_fp16
            bge_model.use_fp16 = False
        try:
            result = _encode()
            # Restore the original fp16 setting.
            if temp_use_fp16 is not None:
                bge_model.use_fp16 = temp_use_fp16
        except RuntimeError as e:
            if "expected scalar type Float but found Half" in str(e):
                # Dtype mismatch: cast the underlying model to float32,
                # keep fp16 disabled, and retry once.
                if hasattr(bge_model, 'model') and hasattr(bge_model.model, 'float'):
                    bge_model.model.float()
                if hasattr(bge_model, 'use_fp16'):
                    bge_model.use_fp16 = False
                result = _encode()
            else:
                raise
        return result if len(result) > 1 else result[0]
    except Exception:
        # Log and re-raise (assumes logger_search is a logging.Logger).
        logger_search.exception("model_embedding_bge failed")
        raise
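
# Usage sketch (illustrative; assumes config.bge_model is loaded and its
# encode() returns a dict containing a 'dense_vecs' key, as BGE-M3 does):
#   vec = model_embedding_bge("hello world")        # -> list[float]
#   vecs = model_embedding_bge(["foo", "bar"])      # -> two vectors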

def fill_embed_nan(vector):
    """Replace NaN with 0.0 and +/-inf with +/-1.0 in an embedding vector."""
    return np.nan_to_num(vector, nan=0.0, posinf=1.0, neginf=-1.0)
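
# Example (illustrative): sanitize non-finite values before storing a vector:
#   fill_embed_nan([0.1, float('nan'), float('inf')])
#   -> array([0.1, 0. , 1. ])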

def bge_rerank(query, candidates, n=5, score=0.6):
    """Rerank candidate passages against a query with the BGE reranker.

    Returns at most n candidates whose sigmoid relevance probability
    exceeds `score`, ordered from most to least relevant.
    """
    with torch.no_grad():
        results = {}
        for candidate in candidates:
            inputs = bge_rerank_tokenizer(
                [[query, candidate]],
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            ).to(device='cuda:0')  # assumes the reranker lives on cuda:0
            logits = bge_rerank_model(**inputs, return_dict=True).logits.view(-1).float()[0]
            prob = torch.sigmoid(logits).item()
            if prob > score:
                results[candidate] = prob
        sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
        return [x[0] for x in sorted_results[:n]]
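
# Usage sketch (illustrative; assumes the reranker weights, e.g.
# BAAI/bge-reranker-base, are loaded on cuda:0 via config):
#   docs = ["Paris is the capital of France.", "Bananas are yellow."]
#   bge_rerank("capital of France", docs, n=2, score=0.5)
#   -> ["Paris is the capital of France."]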

if __name__ == '__main__':
    pass