利用Pytorch预训练模型进行图像分类

Use Pre-trained models for Image Classification.

# This post is rectified on the base of https://learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/# And we have re-orginaized the code script.

预训练模型(Pre-trained models)是在ImageNet等大型基准数据集上训练的神经网络模型。深度学习社区从这些开源模型中受益匪浅。此外,预训练模型也是计算机视觉研究取得快速进展的一个重要因素。其他研究人员和从业人员可以使用这些最先进的模型,而不是从头开始重新训练。

# Here are some examples of classic pre-trained models.

在这里插入图片描述

在详细介绍如何使用预训练模型进行图像分类之前,我们先来看看有哪些预训练模型。我们将以 AlexNet 和 ResNet101 为例进行讨论。这两个网络都在 ImageNet 数据集上训练过。

ImageNet 数据集拥有超过 1400 万张由斯坦福大学维护的图像。它被广泛用于各种与图像相关的深度学习项目。这些图像属于不同的类别或标签。预训练模型(如 AlexNet 和 ResNet101)的目的是将图像作为输入并预测其类别。

这里的 "预训练 "是指,深度学习架构 AlexNet 和 ResNet101 已经在某个(庞大的)数据集上进行过训练,因此带有由此产生的权重和偏差。架构与权重和偏置之间的区别应该非常明显,因为我们将在下一节看到,TorchVision 同时拥有架构和预训练模型。

1.1 Model Inference Process

由于我们将重点讨论如何使用预先训练好的模型来预测输入的类别(标签),因此我们也来讨论一下其中涉及的过程。这个过程被称为模型推理。整个过程包括以下主要步骤:

(1) 读取输入图像;
(2) 对图像进行转换;例如resize、center crop、normalization等;
(3) 前向传递:使用预训练的模型权重来获得输出向量,而输出向量中的每个元素都描述了模型对于输入图像属于特定类别的置信度预测结果;
(4) 预测结果:基于获得的置信度分数,显示预测结果。

1.2 Loading Pre-Trained Network using TorchVision

# [Optinal Step]
# %pip install torchvision
# Load necessary packages.
from PIL import Image
import torch
import torchvision
from torchvision import models
from torchvision import transformsprint(torch.__version__)
print(torchvision.__version__)
2.0.0
0.15.0
# Check the different models and architectures available to us.
dir(models)
['AlexNet','AlexNet_Weights','ConvNeXt','ConvNeXt_Base_Weights','ConvNeXt_Large_Weights','ConvNeXt_Small_Weights','ConvNeXt_Tiny_Weights','DenseNet','DenseNet121_Weights','DenseNet161_Weights','DenseNet169_Weights','DenseNet201_Weights','EfficientNet','EfficientNet_B0_Weights','EfficientNet_B1_Weights','EfficientNet_B2_Weights','EfficientNet_B3_Weights','EfficientNet_B4_Weights','EfficientNet_B5_Weights','EfficientNet_B6_Weights','EfficientNet_B7_Weights','EfficientNet_V2_L_Weights','EfficientNet_V2_M_Weights','EfficientNet_V2_S_Weights','GoogLeNet','GoogLeNetOutputs','GoogLeNet_Weights','Inception3','InceptionOutputs','Inception_V3_Weights','MNASNet','MNASNet0_5_Weights','MNASNet0_75_Weights','MNASNet1_0_Weights','MNASNet1_3_Weights','MaxVit','MaxVit_T_Weights','MobileNetV2','MobileNetV3','MobileNet_V2_Weights','MobileNet_V3_Large_Weights','MobileNet_V3_Small_Weights','RegNet','RegNet_X_16GF_Weights','RegNet_X_1_6GF_Weights','RegNet_X_32GF_Weights','RegNet_X_3_2GF_Weights','RegNet_X_400MF_Weights','RegNet_X_800MF_Weights','RegNet_X_8GF_Weights','RegNet_Y_128GF_Weights','RegNet_Y_16GF_Weights','RegNet_Y_1_6GF_Weights','RegNet_Y_32GF_Weights','RegNet_Y_3_2GF_Weights','RegNet_Y_400MF_Weights','RegNet_Y_800MF_Weights','RegNet_Y_8GF_Weights','ResNeXt101_32X8D_Weights','ResNeXt101_64X4D_Weights','ResNeXt50_32X4D_Weights','ResNet','ResNet101_Weights','ResNet152_Weights','ResNet18_Weights','ResNet34_Weights','ResNet50_Weights','ShuffleNetV2','ShuffleNet_V2_X0_5_Weights','ShuffleNet_V2_X1_0_Weights','ShuffleNet_V2_X1_5_Weights','ShuffleNet_V2_X2_0_Weights','SqueezeNet','SqueezeNet1_0_Weights','SqueezeNet1_1_Weights','SwinTransformer','Swin_B_Weights','Swin_S_Weights','Swin_T_Weights','Swin_V2_B_Weights','Swin_V2_S_Weights','Swin_V2_T_Weights','VGG','VGG11_BN_Weights','VGG11_Weights','VGG13_BN_Weights','VGG13_Weights','VGG16_BN_Weights','VGG16_Weights','VGG19_BN_Weights','VGG19_Weights','ViT_B_16_Weights','ViT_B_32_Weights','ViT_H_14_Weights','ViT_L_16_Weights','ViT_L_32_Weights','VisionTransformer','Weights','WeightsEnum','Wide_ResNet101_2_Weights','Wide_ResNet50_2_Weights','_GoogLeNetOutputs','_InceptionOutputs','__builtins__','__cached__','__doc__','__file__','__loader__','__name__','__package__','__path__','__spec__','_api','_meta','_utils','alexnet','convnext','convnext_base','convnext_large','convnext_small','convnext_tiny','densenet','densenet121','densenet161','densenet169','densenet201','detection','efficientnet','efficientnet_b0','efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4','efficientnet_b5','efficientnet_b6','efficientnet_b7','efficientnet_v2_l','efficientnet_v2_m','efficientnet_v2_s','get_model','get_model_builder','get_model_weights','get_weight','googlenet','inception','inception_v3','list_models','maxvit','maxvit_t','mnasnet','mnasnet0_5','mnasnet0_75','mnasnet1_0','mnasnet1_3','mobilenet','mobilenet_v2','mobilenet_v3_large','mobilenet_v3_small','mobilenetv2','mobilenetv3','optical_flow','quantization','regnet','regnet_x_16gf','regnet_x_1_6gf','regnet_x_32gf','regnet_x_3_2gf','regnet_x_400mf','regnet_x_800mf','regnet_x_8gf','regnet_y_128gf','regnet_y_16gf','regnet_y_1_6gf','regnet_y_32gf','regnet_y_3_2gf','regnet_y_400mf','regnet_y_800mf','regnet_y_8gf','resnet','resnet101','resnet152','resnet18','resnet34','resnet50','resnext101_32x8d','resnext101_64x4d','resnext50_32x4d','segmentation','shufflenet_v2_x0_5','shufflenet_v2_x1_0','shufflenet_v2_x1_5','shufflenet_v2_x2_0','shufflenetv2','squeezenet','squeezenet1_0','squeezenet1_1','swin_b','swin_s','swin_t','swin_transformer','swin_v2_b','swin_v2_s','swin_v2_t','vgg','vgg11','vgg11_bn','vgg13','vgg13_bn','vgg16','vgg16_bn','vgg19','vgg19_bn','video','vision_transformer','vit_b_16','vit_b_32','vit_h_14','vit_l_16','vit_l_32','wide_resnet101_2','wide_resnet50_2']

以AlexNet为例,我们可以看到还有一个名称为alexnet的条目。其中,大写的名称是Python类(AlexNet),而alexnet是一个便于操作的函数(convenience function),用于从AlexNet类返回实例化的模型。

这些方便函数也可以有不同的参数集,例如:densenet121、densenet161、densenet169以及densenet201,都是DenseNet的实例,但层数分别为121,161,169和201.

1.3. Using AlexNet for Image Classification

AlexnetNet是图像识别领域早期的一个突破性网络结构,相关文章可以参考Understanding Alexnet。该网络架构如下:

在这里插入图片描述

Step 1: Load the pre-trained model

# Create an instance of the network.
alexnet = models.alexnet(pretrained=True)
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
# Note: Pytorch模型的扩展名通常为.pt或.pth
# Check the model details.
print(alexnet)
AlexNet((features): Sequential((0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))(1): ReLU(inplace=True)(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(4): ReLU(inplace=True)(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(7): ReLU(inplace=True)(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(9): ReLU(inplace=True)(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))(classifier): Sequential((0): Dropout(p=0.5, inplace=False)(1): Linear(in_features=9216, out_features=4096, bias=True)(2): ReLU(inplace=True)(3): Dropout(p=0.5, inplace=False)(4): Linear(in_features=4096, out_features=4096, bias=True)(5): ReLU(inplace=True)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

Step 2: Specify image transformations

# Use transforms to compose all the data transformations.
transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])     # Three numbers for RGB Channels.
# transforms.Resize: Resize the input images to 256x256 pixels.
# transforms.CenterCrop: Crop the image to 224×224 pixels about the center.
# transforms.Normalize: Normalize the image by setting its mean and standard deviation to the specified values.
# transforms.ToTensor: Convert the image to Pytorch tensor datatype.

Step 3: Load the input image and pre-process it.

# Download image
# !wget https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -O dog.jpg
img = Image.open("./dog.jpg")
img

在这里插入图片描述

# Pre-process the image.
trans_img = transform(img)img_batch = torch.unsqueeze(trans_img, 0)

Step 4: Model Inference

# Set the model to eval model.
alexnet.eval()out = alexnet(img_batch)
print(out.shape)
torch.Size([1, 1000])
# Download classes text file
!wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
--2023-12-14 21:30:09--  https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 0.0.0.0, ::
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|0.0.0.0|:443... failed: Connection refused.
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|::|:443... failed: Connection refused.
# Load labels.
with open('imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]
# Find out the maximum score.
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())
Labrador retriever 41.58513259887695
# The model predicts the image to be of a Labrador Retriever with a 41.58% confidence.
_, indices = torch.sort(out, descending=True)
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
[('Labrador retriever', 41.58513259887695),('golden retriever', 16.59164810180664),('Saluki, gazelle hound', 16.286897659301758),('whippet', 2.8539111614227295),('Ibizan hound, Ibizan Podenco', 2.39247727394104)]

1.4. Using ResNet for Image Classification

# Load the resnet101 model.
resnet = models.resnet101(pretrained=True)# Set the model to eval mode.
resnet.eval()# carry out model inference.
out = resnet(img_batch)# Print the top 5 classes predicted by the model.
_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /home/wsl_ubuntu/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:51<00:00, 3.47MB/s] [('Labrador retriever', 48.255577087402344),('dingo, warrigal, warragal, Canis dingo', 7.900773048400879),('golden retriever', 6.91691780090332),('Eskimo dog, husky', 3.6434390544891357),('bull mastiff', 3.046128273010254)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222034.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大型科技公司与初创公司:选择哪一个?

你有没有想过&#xff0c;特别是在你职业生涯的开始&#xff0c;选择什么类型的公司&#xff1f;它应该是一家像谷歌、亚马逊、Meta 这样的大型科技公司&#xff0c;还是为一家小型初创公司工作。在本文中&#xff0c;我们将讨论实际差异是什么&#xff0c;并帮助你选择最适合你…

『OPEN3D』1.8.3 多份点云配准

多份点云配准是将多份点云数据在全局空间中对齐的过程。通常,输入是一组数据(例如点云或RGBD图像){Pi}。输出是一组刚性变换{Ti},使得经过变换的点云在全局空间中对齐。 NNNNNathan 本专栏地址: https://blog.csdn.net/qq_41366026/category_12186023.html 此处是…

DHCP—动态主机配置协议

动态主机配置协议DHCP&#xff08;Dynamic Host Configuration Protocol&#xff0c;动态主机配置协议&#xff09;是RFC 1541&#xff08;已被RFC 2131取代&#xff09;定义的标准协议&#xff0c;该协议允许服务器向客户端动态分配IP地址和配置信息。 DHCP协议支持C/S&#x…

RocketMQ 总体概括

目录 概述RocketMQ 领域模型MQ 解决的问题电商平台案例初步设计引入中间件设计 MQ 选型结束 概述 官网地址 RocketMQ 领域模型 官方领域模型概述 下面图&#xff0c;是在自己理解的基础上&#xff0c;对官方的模型图添加了一些。 Topic&#xff1a;主题&#xff0c;可以理解…

Java网络编程——基于UDP的数据报和套接字

java.net.ServerSocket与java.net.Socket建立在TCP的基础上。TCP是网络传输层的一种可靠的数据传输协议。如果数据在传输途中被丢失或损坏&#xff0c;那么TCP会保证再次发送数据&#xff1b;如果数据到达接收方的顺序被打乱&#xff0c;那么TCP会在接收方重新恢复数据的正确顺…

扬声器(喇叭)

扬声器(喇叭) 电子元器件百科 文章目录 扬声器(喇叭)前言一、扬声器(喇叭)是什么二、扬声器(喇叭)的类别三、扬声器(喇叭)的应用场景四、扬声器(喇叭)的作用原理总结前言 扬声器广泛应用于音响系统、公共广播系统、汽车音响、电视、电脑和移动设备等各种电子设备…

Linux基本开发工具

编译器和自动化构建工具 一、编译器——gcc、g1. 安装 gcc/g2. 使用3. 链接库4. 拓展命令&#xff1a;od/file/ldd/readelf 二、自动化构建项目——make、makefile1. 介绍2. 使用例子touch——change file timestampsstat——display file or file system status修改时间 .PHON…

Qt 文字描边(基础篇)

项目中有时需要文字描边的功能 1.基础的绘制文字 使用drawtext处理 void MainWindow::paintEvent(QPaintEvent *event) {QPainter painter(this);painter.setRenderHint(QPainter::Antialiasing, true);painter.setRenderHint(QPainter::SmoothPixmapTransform, true);painte…

ceph的osd盘删除操作和iscsi扩展

ceph的osd盘删除操作 拓展:osd磁盘的删除(这里以删除node1上的osd.0磁盘为例) 1, 查看osd磁盘状态 [rootnode1 ceph]# ceph osd tree ID CLASS WEIGHT TYPE NAME STATUS REWEIGHT PRI-AFF -1 0.00298 root default -3 0.00099 host node10 hdd 0.00…

【Vins轨迹】pose_graph位姿图加载EVO精度评定

1. Vins的位姿图加载功能 如果想要对slam运行后的位姿轨迹进行评定&#xff0c;需要将数据保存到output文件夹中。 其中pose_graph.txt含有的信息&#xff1a;关键帧id、时间戳、vio的xyz、优化后的xyz、vio的四元数、优化后的四元数、回环到的关键帧id、回环信息&#xff08…

【十】python复合模式

10.1 复合模式简介 在前面的栏目中我们了解了各种设计模式。正如我们所看到的&#xff0c;设计模式可分为三大类:结构型、创建型和行为型设计模式。同时&#xff0c;我们还给出了每种类型的相应示例。然而&#xff0c;在软件实现中&#xff0c;模式并是不孤立地工作的。对于所…

HPM5300系列--第一篇 命令行开发调试环境搭建

一、目的 在之前的博客中《HPM6750系列--第二篇 搭建Ubuntu开发环境》、 《HPM6750系列--第三篇 搭建MACOS编译和调试环境》我们介绍了HPM6750evkmini开发环境的搭建过程&#xff0c;由于HPM5300系列共用一套hpm-sdk&#xff0c;故HPM5300的开发调试环境的搭建过程基本和之前的…

智能故障诊断期刊推荐【中文期刊】

控制与决策 http://kzyjc.alljournals.cn/kzyjc/home 兵工学报 http://www.co-journal.com/CN/1000-1093/home.shtml 计算机集成制造系统 http://jsjjc.soripan.net/ 机械工程学报 http://www.cjmenet.com.cn/CN/0577-6686/home.shtml 太阳能学报 https://www.tynxb.org.c…

Visual Studio Code中的任务配置文件tasks.json中的可选任务组tasks详解

☞ ░ 前往老猿Python博客 ░ https://blog.csdn.net/LaoYuanPython 一、引言 vscode是支持通过配置可以实现类似Visual C等IDE开发工具使用菜单和快捷键直接进行程序编译构建的&#xff0c;这样构建的任务可以结合后续的调试配置进行IDE环境的程序调试&#xff0c;不过在之前…

12. IO

1.File类 • File 类代表与平台无关的文件和目录。 • File 能新建、删除、重命名文件和目录&#xff0c;但 File 不能访问文件内容本身。如果需要访问文件内容本身&#xff0c;则需要使用输入/输出流。 1).File的常用方法 在这里插入图片描述 2).遍历给定目录所有文件 …

MySQL增删改查

查询数据 MySQL 数据库使用 SQL SELECT 语句来查询数据。以下为在 MySQL 数据库中查询数据通用的 SELECT 语法&#xff1a; SELECT column_name,column_name FROM table_name [WHERE Clause] [LIMIT N][ OFFSET M] 查询语句中你可以使用一个或者多个表&#xff0c;表之间使用…

联想笔记本如何安装Vmware ESXi

环境&#xff1a; Vmware ESXi 8.0 Vmware ESXi 6.7 联想E14笔记本 问题描述&#xff1a; 联想笔记本如何安装Vmware ESXi 解决方案&#xff1a; 1.官网下载镜像文件 https://customerconnect.vmware.com/en/downloads/search?queryesxi%208 下载 2.没有账户注册一个 …

什么时候使用匿名类,匿名类解决了什么问题?为什么需要匿名类 ?

匿名类通常在以下场景下使用&#xff1a; 一次性使用&#xff1a; 当你需要创建一个类的实例&#xff0c;但该类只在一个地方使用&#xff0c;而不打算在其他地方重复使用时&#xff0c;可以考虑使用匿名类。 简化代码&#xff1a; 当创建一个小型的、一次性的类会让代码更简洁…

浅析特征增强个性化在CTR预估中的经典方法和效果对比

在CTR预估中&#xff0c;主流都采用特征embeddingMLP的方式&#xff0c;其中特征非常关键。然而对于相同的特征&#xff0c;在不同的样本中&#xff0c;表征是相同的&#xff0c;这种方式输入到下游模型&#xff0c;会限制模型的表达能力。为了解决这个问题&#xff0c;CTR预估…

【每日一题】用邮票贴满网格图

文章目录 Tag题目来源题目解读解题思路方法一&#xff1a;二维前缀和二维差分 写在最后 Tag 【二维前缀和】【二维差分】【矩阵】【2023-12-14】 题目来源 2132. 用邮票贴满网格图 题目解读 在 01 矩阵中&#xff0c;判断是否可以用给定尺寸的邮票将所有 0 位置都覆盖住&…