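# 前提条件：已用 labelImg 标注图片，得到 Pascal VOC 格式（.xml）的边界框（bbox）。 / Prerequisite: images annotated with labelImg, giving Pascal VOC (.xml) bounding boxes.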
# 1. 代码 (Code)
import glob
import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
# Batch-segment images with SAM, using labelImg (Pascal VOC XML) boxes as
# prompts. For every image in IMAGE_DIR the matching "<stem>.xml" in XML_DIR is
# read, each <object>/<bndbox> becomes a box prompt, and the union of all
# predicted masks is written to SAVE_DIR as "<stem>.png": pixels covered by any
# object are 0 (black), all other pixels are 255 (white).

CHECKPOINT = "./weight/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"
IMAGE_DIR = r"D:\Desktop\mult_test\images"
SAVE_DIR = r"D:\Desktop\mult_test\mask"
# Directory holding the labelImg annotation files.
XML_DIR = r"D:\Desktop\mult_test\label"
# File extensions treated as images (compared case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def list_images(image_dir, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of files in *image_dir* with an image extension.

    Args:
        image_dir: directory to scan (not recursive).
        extensions: lowercase extensions, including the dot, to accept.

    Returns:
        list[str]: matching file paths, sorted for a deterministic order.
    """
    return sorted(
        path
        for path in glob.glob(os.path.join(image_dir, "*"))
        if os.path.splitext(path)[1].lower() in extensions
    )


def read_boxes(xml_file):
    """Read every object's bounding box from a Pascal VOC XML annotation.

    Args:
        xml_file: path to the annotation file.

    Returns:
        list[list[float]]: one ``[xmin, ymin, xmax, ymax]`` per ``<object>``
        that has a ``<bndbox>``; empty if the file has no objects.

    Raises:
        OSError / xml.etree.ElementTree.ParseError: unreadable or invalid XML.
    """
    root = ET.parse(xml_file).getroot()
    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        if bndbox is None:
            continue
        # float() rather than int(): some tools write coordinates like "12.0",
        # and SAM's box transform works in floating point anyway.
        boxes.append(
            [float(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")]
        )
    return boxes


def combine_masks(masks):
    """Merge per-box masks into a single 8-bit mask image.

    Args:
        masks: tensor or array shaped (num_boxes, 1, H, W), as returned by
            ``SamPredictor.predict_torch(..., multimask_output=False)``.

    Returns:
        np.ndarray: uint8 (H, W) image that is 0 where any mask is set and
        255 elsewhere. uint8 is required because cv2.imwrite cannot save an
        int64 array.
    """
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    masks = np.asarray(masks)
    covered = (masks[:, 0] == 1).any(axis=0)
    return np.where(covered, 0, 255).astype(np.uint8)


def main():
    """Run SAM over every annotated image and write one PNG mask per image."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    # cv2.imwrite fails (returns False) if the target directory is missing.
    os.makedirs(SAVE_DIR, exist_ok=True)

    for image_file in list_images(IMAGE_DIR):
        stem = os.path.splitext(os.path.basename(image_file))[0]
        xml_file = os.path.join(XML_DIR, stem + ".xml")
        if not os.path.isfile(xml_file):
            print(f"skip {image_file}: annotation {xml_file} not found")
            continue

        image = cv2.imread(image_file)
        if image is None:
            print(f"skip {image_file}: cv2 could not read the image")
            continue

        boxes = read_boxes(xml_file)
        if boxes:
            # cv2 loads BGR; tell SAM so it converts to the model's RGB order.
            predictor.set_image(image, image_format="BGR")
            input_boxes = torch.tensor(boxes, dtype=torch.float, device=predictor.device)
            transformed_boxes = predictor.transform.apply_boxes_torch(
                input_boxes, image.shape[:2]
            )
            masks, _, _ = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            mask = combine_masks(masks)
        else:
            # No annotated objects: emit an all-background mask so every image
            # still gets a mask file.
            mask = np.full(image.shape[:2], 255, dtype=np.uint8)

        # Always PNG: JPEG compression would corrupt the binary 0/255 values.
        out_path = os.path.join(SAVE_DIR, stem + ".png")
        if not cv2.imwrite(out_path, mask):
            raise OSError(f"failed to write mask {out_path}")


if __name__ == "__main__":
    main()
# 2. 效果展示 (Results)