12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- import os
- from PIL import Image
- import torch
- from torchvision import transforms
- from transformers import AutoModelForImageSegmentation
- from tqdm import tqdm
class ImageSegmenter:
    """Produce single-channel foreground masks with a pretrained segmentation model.

    The model is loaded via ``AutoModelForImageSegmentation.from_pretrained``
    with ``trust_remote_code=True`` and run in eval mode.
    """

    def __init__(self, model_path='models/RMBG-2.0', device=None):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            model_path (str): Path of the pretrained model to load.
            device (str | torch.device | None): Device the model and inputs
                are placed on. Defaults to ``'cuda'`` when CUDA is available,
                otherwise ``'cpu'``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_path, trust_remote_code=True
        )
        # Process-wide torch setting: trade some float32 matmul precision for
        # speed (e.g. TF32 on supporting GPUs).
        torch.set_float32_matmul_precision('high')
        self.model.to(self.device)
        self.model.eval()

        # Preprocessing: resize to a fixed 1024x1024, convert to a tensor and
        # normalise with the standard ImageNet mean/std values.
        self.image_size = (1024, 1024)
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def process_image(self, input_image_path, mask_dir='data/mask_image'):
        """Predict a mask for one image, save it as a PNG and return it.

        Args:
            input_image_path (str): Path of the input image.
            mask_dir (str): Directory the mask is written to, as
                ``mask-<input file stem>.png``. Created if missing.

        Returns:
            PIL.Image.Image: The mask, resized to the input image's size.
        """
        # Normalize expects exactly three channels, so force RGB (RGBA,
        # grayscale and palette images would otherwise fail). The context
        # manager closes the file handle PIL keeps open for lazy loading.
        with Image.open(input_image_path) as img:
            image = img.convert('RGB')

        input_images = self.transform_image(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # The model output is indexed with [-1]; presumably a list of
            # predictions with the final one last — confirm against the
            # model's remote code. sigmoid squashes it into (0, 1).
            preds = self.model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # Convert the float map to a PIL image and scale back to the
        # original resolution.
        mask = transforms.ToPILImage()(pred).resize(image.size)

        # Make sure the output directory exists before saving.
        os.makedirs(mask_dir, exist_ok=True)
        stem = os.path.splitext(os.path.basename(input_image_path))[0]
        mask.save(os.path.join(mask_dir, f"mask-{stem}.png"))
        return mask
-
def mask_image_list(image_paths, model_path='models/RMBG-2.0'):
    """Generate a mask for every image in a list using one shared segmenter.

    Each mask is also written to disk by ``ImageSegmenter.process_image``.
    A tqdm progress bar is shown while iterating.

    Args:
        image_paths (iterable of str): Paths of the input images.
        model_path (str): Model path passed to ``ImageSegmenter``.

    Returns:
        list: Mask images (``PIL.Image.Image``), in the same order as
        ``image_paths``.
    """
    image_paths = list(image_paths)
    if not image_paths:
        # Nothing to process: avoid loading the model (and touching the GPU).
        return []
    segmenter = ImageSegmenter(model_path)
    return [segmenter.process_image(path) for path in tqdm(image_paths)]
# Usage example
if __name__ == "__main__":
    model_path = 'models/RMBG-2.0'  # Model path
    input_image_path = "data/key_frame/frame_00000000.jpg"  # Input image path
    segmenter = ImageSegmenter(model_path)  # Create the segmenter
    # process_image saves the mask under data/mask_image/ and returns it.
    mask_image = segmenter.process_image(input_image_path)
    print(type(mask_image))
|