import os

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation


class ImageSegmenter:
    """Foreground segmenter that wraps a pretrained image-segmentation model.

    The model is loaded once at construction time; each call to
    ``process_image`` then produces a mask for one image and writes it to
    ``output_dir``.
    """

    def __init__(self, model_path='models/RMBG-2.0', device=None,
                 output_dir='data/mask_image'):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            model_path (str): Path of the pretrained model to load.
            device (str | None): Torch device to run on. When ``None``, CUDA
                is used if it is available and the CPU otherwise.
            output_dir (str): Directory where mask images are written.
        """
        if device is None:
            # Fall back to CPU instead of crashing on machines without CUDA.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.output_dir = output_dir

        # trust_remote_code=True executes Python code shipped with the model;
        # only point model_path at a source you trust.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_path, trust_remote_code=True
        )
        torch.set_float32_matmul_precision('high')
        self.model.to(self.device)
        self.model.eval()

        # Data settings: fixed input resolution plus normalization with the
        # commonly used ImageNet mean/std values.
        self.image_size = (1024, 1024)
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Built once here rather than on every process_image call.
        self._to_pil = transforms.ToPILImage()

    def process_image(self, input_image_path):
        """Generate the mask for one image, save it, and return it.

        The mask is written to ``<output_dir>/mask-<input stem>.png``; the
        output directory is created if it does not exist yet.

        Args:
            input_image_path (str): Path of the input image.

        Returns:
            PIL.Image.Image: The mask, resized to the input image's size.
        """
        # Convert to RGB so that RGBA / grayscale / palette images match the
        # 3-channel normalization; the context manager closes the file handle.
        with Image.open(input_image_path) as source:
            image = source.convert('RGB')

        input_images = self.transform_image(image).unsqueeze(0).to(self.device)

        # Prediction: the last element of the model output is taken and
        # passed through a sigmoid so values lie in [0, 1].
        with torch.no_grad():
            preds = self.model(input_images)[-1].sigmoid().cpu()

        # First item of the batch with singleton dimensions dropped
        # (presumably a single-channel mask).
        pred = preds[0].squeeze()
        mask = self._to_pil(pred).resize(image.size)

        # Save the mask image, making sure the target directory exists.
        os.makedirs(self.output_dir, exist_ok=True)
        stem = os.path.splitext(os.path.basename(input_image_path))[0]
        mask_output_path = os.path.join(self.output_dir, f"mask-{stem}.png")
        mask.save(mask_output_path)

        return mask


def mask_image_list(image_paths, segmenter=None):
    """Process a list of images and return their masks.

    Args:
        image_paths (list): Paths of the input images.
        segmenter (ImageSegmenter | None): Segmenter to reuse. When ``None``
            a new ``ImageSegmenter`` with default settings is created, which
            loads the model again.

    Returns:
        list: The generated mask images, in the same order as ``image_paths``.
    """
    if segmenter is None:
        segmenter = ImageSegmenter()
    return [segmenter.process_image(image_path) for image_path in tqdm(image_paths)]


# Usage example
if __name__ == "__main__":
    model_path = 'models/RMBG-2.0'  # model path
    input_image_path = "data/key_frame/frame_00000000.jpg"  # input image path

    segmenter = ImageSegmenter(model_path)  # create the segmenter instance
    mask_image = segmenter.process_image(input_image_path)  # process the image
    print(type(mask_image))