embedding_search.py

from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from config import bge_model, bge_rerank_model, bge_rerank_tokenizer, logger_search


def model_embedding_bge(chunk):
    """Encode a single string or a list of strings into dense BGE embeddings."""
    try:
        if isinstance(chunk, str):
            sentences = [chunk]
        else:
            sentences = chunk
        # Call bge_model.encode in a safer way
        try:
            # If the model exposes a "use_fp16" attribute, disable it temporarily
            temp_use_fp16 = None
            if hasattr(bge_model, 'use_fp16'):
                temp_use_fp16 = bge_model.use_fp16
                bge_model.use_fp16 = False
            embeddings = bge_model.encode(
                sentences,
                batch_size=12,
                max_length=8192,
            )['dense_vecs']
            result = [emb.tolist() for emb in embeddings]
            # Restore the original setting
            if temp_use_fp16 is not None:
                bge_model.use_fp16 = temp_use_fp16
        except RuntimeError as e:
            if "expected scalar type Float but found Half" in str(e):
                # Convert the model to float32 and retry
                if hasattr(bge_model, 'model') and hasattr(bge_model.model, 'float'):
                    bge_model.model.float()
                # Disable fp16
                if hasattr(bge_model, 'use_fp16'):
                    bge_model.use_fp16 = False
                embeddings = bge_model.encode(
                    sentences,
                    batch_size=12,
                    max_length=8192,
                )['dense_vecs']
                result = [emb.tolist() for emb in embeddings]
            else:
                raise
        # Return a single vector when only one sentence was encoded, else a list of vectors
        return result if len(result) > 1 else result[0]
    except Exception:
        raise


def fill_embed_nan(vector):
    """Replace NaN with 0.0 and +/-inf with +/-1.0 so the vector is safe to index."""
    return np.nan_to_num(vector, nan=0.0, posinf=1.0, neginf=-1.0)
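
# Example (illustrative):
#   fill_embed_nan([0.1, float('nan'), float('inf')])  # -> array([0.1, 0., 1.])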


def bge_rerank(query, indicates, n=5, score=0.6):
    """Score each candidate against the query with the BGE reranker and return the
    top-n candidates whose sigmoid relevance probability exceeds `score`."""
    with torch.no_grad():
        results = {}
        for chos in indicates:
            inputs_1 = bge_rerank_tokenizer(
                [[query, chos]], padding=True, truncation=True,
                return_tensors='pt', max_length=512,
            ).to(device='cuda:0')
            scores_1 = bge_rerank_model(**inputs_1, return_dict=True).logits.view(-1).float()[0]
            probs = torch.sigmoid(scores_1)
            if probs > score:
                results[chos] = probs
        sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
        # Keep at most n candidates, highest probability first
        final_result = [x[0] for x in sorted_results[:n]]
        return final_result
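
# Example usage (illustrative sketch; assumes the reranker objects in config are a
# HuggingFace sequence-classification cross-encoder and its tokenizer on cuda:0,
# matching the calls above):
#   top = bge_rerank("what is vector search?",
#                    ["Vector search finds nearest embeddings.", "Cats sleep a lot."],
#                    n=1, score=0.5)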


if __name__ == '__main__':
    pass