Colab/PyTorch - 006 Mask RCNN Instance Segmentation

  • 1. Background
  • 2. Implementing Mask R-CNN with PyTorch
    • 2.1 Input and Output
    • 2.2 Pretrained Model
    • 2.3 Model Prediction
    • 2.4 Instance Segmentation Pipeline
    • 2.5 Inference
      • Example 1
      • Example 2
      • Example 3
  • 3. Inference Time Comparison (CPU vs. GPU)
  • 4. Summary
  • 5. References

1. Background

In the Background section of 《Colab/PyTorch - 004 Torchvision Semantic Segmentation》 we walked through a list of image-analysis tasks of increasing difficulty.

As our techniques mature and the problems grow more complex, a natural next step is: once an object's bounding box has been detected, find out which pixels inside that box actually belong to the object. Mask R-CNN is one algorithm that does exactly this.

The Mask R-CNN architecture is an extension of Faster R-CNN, which consists of the following components:

  1. Convolutional layers: the input image is passed through several convolutional layers to create a feature map. If you are a beginner, think of the convolutional layers as a black box that takes a 3-channel input image and outputs an "image" with much smaller spatial dimensions (7×7) but a large number of channels (512).
  2. Region Proposal Network (RPN): the output of the convolutional layers is used to train a network that proposes regions containing objects.
  3. Classifier: the same feature map is also used to train a classifier that assigns a label to the object inside each bounding box.

Recall that Faster R-CNN is faster than Fast R-CNN because the feature map is computed once and then reused by both the RPN and the classifier.

Mask R-CNN goes one step further. In addition to feeding the feature map to the RPN and the classifier, it uses it to predict binary masks for the objects inside the bounding boxes. The mask-prediction branch of Mask R-CNN is a Fully Convolutional Network (FCN) of the kind used for semantic segmentation; the only difference is that this FCN is applied to the proposed bounding-box regions and shares its convolutional layers with the RPN and the classifier.

The figure below shows the architecture at a very high level.

[Figure: high-level Mask R-CNN architecture]

2. Implementing Mask R-CNN with PyTorch

To run this on Colab, upload the prepared images to Google Drive. The photos can be downloaded directly, or copied to the directory /content/drive/MyDrive/mask_rcnn/.

# import necessary libraries
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as T
import numpy as np
import cv2
import random
import time
import os

# Test on Google Drive: mount it so the images stored there are accessible
from google.colab import drive
drive.mount('/content/drive')
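
The later snippets use a directory_path variable and a download_image helper that are not defined anywhere in this post. A minimal sketch under those assumptions (the Drive path comes from the post; the helper itself is assumed, not part of the original code) could look like this:

import os
import urllib.request

# Working directory on Google Drive used by the examples below (path taken from the post)
directory_path = "/content/drive/MyDrive/mask_rcnn/"
os.makedirs(directory_path, exist_ok=True)

def download_image(url, save_path):
    """Assumed helper: download an image from url to save_path if it is not already there."""
    if not os.path.exists(save_path):
        urllib.request.urlretrieve(url, save_path)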

2.1 Input and Output

The model expects the input to be a list of tensor images of shape (n, c, h, w), with values in the range 0–1. The image sizes do not need to be fixed.

  • n is the number of images
  • c is the number of channels, 3 for RGB images
  • h is the image height
  • w is the image width

The model returns:

  • the coordinates of the bounding boxes,
  • the class labels the model predicts to be present in the input image, together with their scores,
  • a mask for each predicted class label.
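
For concreteness, here is a minimal sketch (the dummy input size is assumed) showing the structure of the model's input and output:

# A minimal sketch of the input/output contract; the dummy image size is arbitrary
import torch
import torchvision

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

# one dummy RGB image with values in [0, 1]; sizes may differ between images
dummy = [torch.rand(3, 480, 640)]
with torch.no_grad():
    outputs = model(dummy)

print(outputs[0].keys())           # dict_keys(['boxes', 'labels', 'scores', 'masks'])
print(outputs[0]['boxes'].shape)   # (N, 4) boxes as (x1, y1, x2, y2)
print(outputs[0]['masks'].shape)   # (N, 1, H, W) soft masks with values in [0, 1]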

2.2 Pretrained Model

# get the pretrained model from torchvision.models
# Note: pretrained=True will get the pretrained weights for the model.
# model.eval() to use the model for inference
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1`. You can also use `weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:01<00:00, 92.7MB/s]
MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(...)
    (fpn): FeaturePyramidNetwork(...)
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(...)
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(...)
    (box_predictor): FastRCNNPredictor(...)
    (mask_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(14, 14), sampling_ratio=2)
    (mask_head): MaskRCNNHeads(...)
    (mask_predictor): MaskRCNNPredictor(...)
  )
)

2.3 Model Prediction

# These are the classes that are available in the COCO-Dataset
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

def get_prediction(img_path, threshold):
    """get_prediction
    parameters:
      - img_path - path of the input image
      - threshold - minimum score a detection must have to be kept
    method:
      - Image is obtained from the image path
      - the image is converted to an image tensor using PyTorch's Transforms
      - the image is passed through the model to get the predictions
      - masks, classes and bounding boxes are obtained from the model and the soft masks
        are made binary (0 or 1), e.g. the segment of a cat is made 1 and the rest of the
        image is made 0
    """
    img = Image.open(img_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_score = list(pred[0]['scores'].detach().numpy())
    # index of the last prediction whose score is above the threshold
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    masks = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
    # keep only the detections above the threshold
    masks = masks[:pred_t + 1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_class = pred_class[:pred_t + 1]
    return masks, pred_boxes, pred_class
  • The image is obtained from the image path.
  • The image is converted to an image tensor using PyTorch's transforms.
  • The image is passed through the model to get the predictions.
  • Masks, predicted classes and bounding-box coordinates are obtained from the model, and the soft masks are binarized (0 or 1); for example, the pixels belonging to a cat are set to 1 and the rest of the image to 0.
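
A short usage sketch (the image path below is hypothetical) of what get_prediction returns:

# Hypothetical image path in the Drive directory set up earlier
masks, boxes, pred_cls = get_prediction("/content/drive/MyDrive/mask_rcnn/mrcnn_cars.jpg", threshold=0.75)

print(len(pred_cls))          # number of detections kept above the threshold
print(masks.shape)            # (N, H, W) binary masks, one per kept detection
print(boxes[0], pred_cls[0])  # [(x1, y1), (x2, y2)] and the corresponding class label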

Each predicted object's mask is assigned a random colour from a predefined set of 11 colours so that the masks can be visualized on the input image.

def random_colour_masks(image):
    """random_colour_masks
    parameters:
      - image - predicted masks
    method:
      - the mask of each predicted object is given a random colour for visualization
    """
    colours = [[0, 255, 0], [0, 0, 255], [255, 0, 0], [0, 255, 255], [255, 255, 0],
               [255, 0, 255], [80, 70, 180], [250, 80, 190], [245, 145, 50],
               [70, 150, 250], [50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    # colour only the pixels that belong to the mask (pick any of the 11 colours)
    r[image == 1], g[image == 1], b[image == 1] = colours[random.randrange(0, len(colours))]
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask

2.4 Instance Segmentation Pipeline

def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    """instance_segmentation_api
    parameters:
      - img_path - path to the input image
    method:
      - prediction is obtained by get_prediction
      - each mask is given a random colour
      - each mask is blended onto the image in the ratio 1:0.5 with OpenCV
      - the final output is displayed
    """
    masks, boxes, pred_cls = get_prediction(img_path, threshold)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for i in range(len(masks)):
        rgb_mask = random_colour_masks(masks[i])
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
        cv2.rectangle(img, (int(boxes[i][0][0]), int(boxes[i][0][1])),
                      (int(boxes[i][1][0]), int(boxes[i][1][1])),
                      color=(0, 255, 0), thickness=rect_th)
        cv2.putText(img, pred_cls[i], (int(boxes[i][0][0]), int(boxes[i][0][1])),
                    cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)
    plt.figure(figsize=(20, 30))
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.show()
  • The masks, predicted classes and bounding boxes are obtained with get_prediction.
  • Each mask is assigned a random colour from the set of 11 colours.
  • Each mask is blended onto the image with OpenCV in the ratio 1:0.5.
  • The bounding box is drawn with cv2.rectangle and the class name is written as text.
  • The final output is displayed.

2.5 Inference

Example 1

#!wget https://www.wsha.org/wp-content/uploads/banner-diverse-group-of-people-2.jpg -O mrcnn_standing_people.jpg
image_file = "mrcnn_standing_people.jpg"
full_image_path = os.path.join(directory_path, image_file)
download_image("https://www.wsha.org/wp-content/uploads/banner-diverse-group-of-people-2.jpg", full_image_path)instance_segmentation_api(full_image_path, 0.75)

[Figure: instance segmentation result on the group-of-people image]

Example 2

#!wget https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/10best-cars-group-cropped-1542126037.jpg -O mrcnn_cars.jpg
image_file = "mrcnn_cars.jpg"
full_image_path = os.path.join(directory_path, image_file)
download_image("https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/10best-cars-group-cropped-1542126037.jpg", full_image_path)instance_segmentation_api(full_image_path, 0.9, rect_th=5, text_size=5, text_th=5)

[Figure: instance segmentation result on the cars image]

Example 3

#!wget https://cdn.pixabay.com/photo/2013/07/05/01/08/traffic-143391_960_720.jpg -O mrcnn_traffic.jpg
image_file = "mrcnn_traffic.jpg"
full_image_path = os.path.join(directory_path, image_file)
download_image("https://cdn.pixabay.com/photo/2013/07/05/01/08/traffic-143391_960_720.jpg", full_image_path)instance_segmentation_api(full_image_path, 0.6, rect_th=2, text_size=2, text_th=2)

[Figure: instance segmentation result on the traffic image]

3. Inference Time Comparison (CPU vs. GPU)

def check_inference_time(image_path, gpu=False):
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    img = Image.open(image_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    if gpu:
        model.cuda()
        img = img.cuda()
    else:
        model.cpu()
        img = img.cpu()
    start_time = time.time()
    pred = model([img])
    end_time = time.time()
    return end_time - start_time

# Let's run inference on all the downloaded images and average their inference time
# img_paths = [path for path in os.listdir("./") if path.split(".")[-1].lower() in ["jpeg", "jpg", "png"]]
# Get a list of image paths in the specified directory
img_paths = [os.path.join(directory_path, path) for path in os.listdir(directory_path)
             if path.split(".")[-1].lower() in ["jpeg", "jpg", "png"]]

gpu_time = sum([check_inference_time(img_path, gpu=True) for img_path in img_paths]) / len(img_paths)
cpu_time = sum([check_inference_time(img_path, gpu=False) for img_path in img_paths]) / len(img_paths)

print('\n\nAverage Time take by the model with GPU = {}s\nAverage Time take by the model with CPU = {}s'.format(gpu_time, cpu_time))

The GPU is significantly faster than the CPU.

Average Time take by the model with GPU = 0.32508648525584827s
Average Time take by the model with CPU = 8.285651618784124s
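
One caveat: CUDA kernels execute asynchronously, so timing model([img]) with time.time() alone can misreport GPU latency. A minimal sketch (reusing the already-built model and a preprocessed img tensor; not part of the original post) of a stricter measurement with a warm-up pass and explicit synchronization:

import time
import torch

def timed_forward(model, img, gpu=False):
    # Time a single forward pass; synchronize so all queued CUDA work is counted
    device = "cuda" if gpu and torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()
    img = img.to(device)
    with torch.no_grad():
        model([img])                      # warm-up pass (kernel launches, cuDNN autotuning)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        model([img])
        if device == "cuda":
            torch.cuda.synchronize()      # wait for the GPU to finish before stopping the clock
    return time.time() - start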

4. Summary

Overall, applying an off-the-shelf, general-purpose model to this kind of application problem is not complicated.

The hard part is collecting and labelling good data, and designing and training a model for the specific application.

Fortunately, the practical problems we will face later already have well-established algorithms, such as YOLO.

Test code: 006 PyTorch Mask RCNN

5. References

[1] Colab/PyTorch - Getting Started with PyTorch
