# rmbg.py (2.6 KB)
  1. import os
  2. from PIL import Image
  3. import torch
  4. from torchvision import transforms
  5. from transformers import AutoModelForImageSegmentation
  6. from tqdm import tqdm
  7. class ImageSegmenter:
  8. def __init__(self, model_path='models/RMBG-2.0'):
  9. """
  10. 初始化图像分割器,加载模型
  11. Args:
  12. model_path (str): 模型的路径
  13. """
  14. self.model = AutoModelForImageSegmentation.from_pretrained(model_path, trust_remote_code=True)
  15. torch.set_float32_matmul_precision(['high', 'highest'][0])
  16. self.model.to('cuda')
  17. self.model.eval()
  18. # 数据设置
  19. self.image_size = (1024, 1024)
  20. self.transform_image = transforms.Compose([
  21. transforms.Resize(self.image_size),
  22. transforms.ToTensor(),
  23. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  24. ])
  25. def process_image(self, input_image_path):
  26. """
  27. 处理输入图像,生成掩膜并保存掩膜图像
  28. Args:
  29. input_image_path (str): 输入图像的路径
  30. Returns:
  31. PIL.Image.Image: 生成的掩膜图像
  32. """
  33. # 读取输入图像
  34. image = Image.open(input_image_path)
  35. input_images = self.transform_image(image).unsqueeze(0).to('cuda')
  36. # 预测
  37. with torch.no_grad():
  38. preds = self.model(input_images)[-1].sigmoid().cpu()
  39. pred = preds[0].squeeze()
  40. pred_pil = transforms.ToPILImage()(pred)
  41. mask = pred_pil.resize(image.size)
  42. # 保存掩膜图像
  43. mask_output_path = f"data/mask_image/mask-{os.path.splitext(os.path.basename(input_image_path))[0]}.png"
  44. mask.save(mask_output_path) # 保存掩膜图像
  45. return mask
  46. def mask_image_list(image_paths):
  47. """
  48. 处理图像列表,返回掩膜列表
  49. Args:
  50. image_paths (list): 输入图像路径列表
  51. Returns:
  52. list: 生成的掩膜图像列表
  53. """
  54. segmenter = ImageSegmenter() # 创建图像分割器实例
  55. masks = []
  56. for image_path in tqdm(image_paths):
  57. mask = segmenter.process_image(image_path) # 处理每个图像
  58. masks.append(mask) # 将掩膜添加到列表中
  59. return masks
  60. # 使用示例
  61. if __name__ == "__main__":
  62. model_path = 'models/RMBG-2.0' # 模型路径
  63. input_image_path = "data/key_frame/frame_00000000.jpg" # 输入图像路径
  64. mask_output_path = "rmbgmask_image.png" # 掩膜图像保存路径
  65. segmenter = ImageSegmenter(model_path) # 创建图像分割器实例
  66. mask_image = segmenter.process_image(input_image_path) # 处理图像
  67. print(type(mask_image))