利用Pytorch预训练模型进行图像分类

Use Pre-trained models for Image Classification.

# This post is rectified on the base of https://learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/# And we have re-orginaized the code script.

预训练模型(Pre-trained models)是在ImageNet等大型基准数据集上训练的神经网络模型。深度学习社区从这些开源模型中受益匪浅。此外,预训练模型也是计算机视觉研究取得快速进展的一个重要因素。其他研究人员和从业人员可以使用这些最先进的模型,而不是从头开始重新训练。

# Here are some examples of classic pre-trained models.

在这里插入图片描述

在详细介绍如何使用预训练模型进行图像分类之前,我们先来看看有哪些预训练模型。我们将以 AlexNet 和 ResNet101 为例进行讨论。这两个网络都在 ImageNet 数据集上训练过。

ImageNet 数据集拥有超过 1400 万张由斯坦福大学维护的图像。它被广泛用于各种与图像相关的深度学习项目。这些图像属于不同的类别或标签。预训练模型(如 AlexNet 和 ResNet101)的目的是将图像作为输入并预测其类别。

这里的 "预训练 "是指,深度学习架构 AlexNet 和 ResNet101 已经在某个(庞大的)数据集上进行过训练,因此带有由此产生的权重和偏差。架构与权重和偏置之间的区别应该非常明显,因为我们将在下一节看到,TorchVision 同时拥有架构和预训练模型。

1.1 Model Inference Process

由于我们将重点讨论如何使用预先训练好的模型来预测输入的类别(标签),因此我们也来讨论一下其中涉及的过程。这个过程被称为模型推理。整个过程包括以下主要步骤:

(1) 读取输入图像;
(2) 对图像进行转换;例如resize、center crop、normalization等;
(3) 前向传递:使用预训练的模型权重来获得输出向量,而输出向量中的每个元素都描述了模型对于输入图像属于特定类别的置信度预测结果;
(4) 预测结果:基于获得的置信度分数,显示预测结果。

1.2 Loading Pre-Trained Network using TorchVision

# [Optinal Step]
# %pip install torchvision
# Load necessary packages.
from PIL import Image
import torch
import torchvision
from torchvision import models
from torchvision import transformsprint(torch.__version__)
print(torchvision.__version__)
2.0.0
0.15.0
# Check the different models and architectures available to us.
dir(models)
['AlexNet','AlexNet_Weights','ConvNeXt','ConvNeXt_Base_Weights','ConvNeXt_Large_Weights','ConvNeXt_Small_Weights','ConvNeXt_Tiny_Weights','DenseNet','DenseNet121_Weights','DenseNet161_Weights','DenseNet169_Weights','DenseNet201_Weights','EfficientNet','EfficientNet_B0_Weights','EfficientNet_B1_Weights','EfficientNet_B2_Weights','EfficientNet_B3_Weights','EfficientNet_B4_Weights','EfficientNet_B5_Weights','EfficientNet_B6_Weights','EfficientNet_B7_Weights','EfficientNet_V2_L_Weights','EfficientNet_V2_M_Weights','EfficientNet_V2_S_Weights','GoogLeNet','GoogLeNetOutputs','GoogLeNet_Weights','Inception3','InceptionOutputs','Inception_V3_Weights','MNASNet','MNASNet0_5_Weights','MNASNet0_75_Weights','MNASNet1_0_Weights','MNASNet1_3_Weights','MaxVit','MaxVit_T_Weights','MobileNetV2','MobileNetV3','MobileNet_V2_Weights','MobileNet_V3_Large_Weights','MobileNet_V3_Small_Weights','RegNet','RegNet_X_16GF_Weights','RegNet_X_1_6GF_Weights','RegNet_X_32GF_Weights','RegNet_X_3_2GF_Weights','RegNet_X_400MF_Weights','RegNet_X_800MF_Weights','RegNet_X_8GF_Weights','RegNet_Y_128GF_Weights','RegNet_Y_16GF_Weights','RegNet_Y_1_6GF_Weights','RegNet_Y_32GF_Weights','RegNet_Y_3_2GF_Weights','RegNet_Y_400MF_Weights','RegNet_Y_800MF_Weights','RegNet_Y_8GF_Weights','ResNeXt101_32X8D_Weights','ResNeXt101_64X4D_Weights','ResNeXt50_32X4D_Weights','ResNet','ResNet101_Weights','ResNet152_Weights','ResNet18_Weights','ResNet34_Weights','ResNet50_Weights','ShuffleNetV2','ShuffleNet_V2_X0_5_Weights','ShuffleNet_V2_X1_0_Weights','ShuffleNet_V2_X1_5_Weights','ShuffleNet_V2_X2_0_Weights','SqueezeNet','SqueezeNet1_0_Weights','SqueezeNet1_1_Weights','SwinTransformer','Swin_B_Weights','Swin_S_Weights','Swin_T_Weights','Swin_V2_B_Weights','Swin_V2_S_Weights','Swin_V2_T_Weights','VGG','VGG11_BN_Weights','VGG11_Weights','VGG13_BN_Weights','VGG13_Weights','VGG16_BN_Weights','VGG16_Weights','VGG19_BN_Weights','VGG19_Weights','ViT_B_16_Weights','ViT_B_32_Weights','ViT_H_14_Weights','ViT_L_16_Weights','ViT_L_32_Weights','VisionTransformer','Weights','WeightsEnum','Wide_ResNet101_2_Weights','Wide_ResNet50_2_Weights','_GoogLeNetOutputs','_InceptionOutputs','__builtins__','__cached__','__doc__','__file__','__loader__','__name__','__package__','__path__','__spec__','_api','_meta','_utils','alexnet','convnext','convnext_base','convnext_large','convnext_small','convnext_tiny','densenet','densenet121','densenet161','densenet169','densenet201','detection','efficientnet','efficientnet_b0','efficientnet_b1','efficientnet_b2','efficientnet_b3','efficientnet_b4','efficientnet_b5','efficientnet_b6','efficientnet_b7','efficientnet_v2_l','efficientnet_v2_m','efficientnet_v2_s','get_model','get_model_builder','get_model_weights','get_weight','googlenet','inception','inception_v3','list_models','maxvit','maxvit_t','mnasnet','mnasnet0_5','mnasnet0_75','mnasnet1_0','mnasnet1_3','mobilenet','mobilenet_v2','mobilenet_v3_large','mobilenet_v3_small','mobilenetv2','mobilenetv3','optical_flow','quantization','regnet','regnet_x_16gf','regnet_x_1_6gf','regnet_x_32gf','regnet_x_3_2gf','regnet_x_400mf','regnet_x_800mf','regnet_x_8gf','regnet_y_128gf','regnet_y_16gf','regnet_y_1_6gf','regnet_y_32gf','regnet_y_3_2gf','regnet_y_400mf','regnet_y_800mf','regnet_y_8gf','resnet','resnet101','resnet152','resnet18','resnet34','resnet50','resnext101_32x8d','resnext101_64x4d','resnext50_32x4d','segmentation','shufflenet_v2_x0_5','shufflenet_v2_x1_0','shufflenet_v2_x1_5','shufflenet_v2_x2_0','shufflenetv2','squeezenet','squeezenet1_0','squeezenet1_1','swin_b','swin_s','swin_t','swin_transformer','swin_v2_b','swin_v2_s','swin_v2_t','vgg','vgg11','vgg11_bn','vgg13','vgg13_bn','vgg16','vgg16_bn','vgg19','vgg19_bn','video','vision_transformer','vit_b_16','vit_b_32','vit_h_14','vit_l_16','vit_l_32','wide_resnet101_2','wide_resnet50_2']

以AlexNet为例,我们可以看到还有一个名称为alexnet的条目。其中,大写的名称是Python类(AlexNet),而alexnet是一个便于操作的函数(convenience function),用于从AlexNet类返回实例化的模型。

这些方便函数也可以有不同的参数集,例如:densenet121、densenet161、densenet169以及densenet201,都是DenseNet的实例,但层数分别为121,161,169和201.

1.3. Using AlexNet for Image Classification

AlexnetNet是图像识别领域早期的一个突破性网络结构,相关文章可以参考Understanding Alexnet。该网络架构如下:

在这里插入图片描述

Step 1: Load the pre-trained model

# Create an instance of the network.
alexnet = models.alexnet(pretrained=True)
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
# Note: Pytorch模型的扩展名通常为.pt或.pth
# Check the model details.
print(alexnet)
AlexNet((features): Sequential((0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))(1): ReLU(inplace=True)(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(4): ReLU(inplace=True)(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(7): ReLU(inplace=True)(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(9): ReLU(inplace=True)(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))(classifier): Sequential((0): Dropout(p=0.5, inplace=False)(1): Linear(in_features=9216, out_features=4096, bias=True)(2): ReLU(inplace=True)(3): Dropout(p=0.5, inplace=False)(4): Linear(in_features=4096, out_features=4096, bias=True)(5): ReLU(inplace=True)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

Step 2: Specify image transformations

# Use transforms to compose all the data transformations.
transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])     # Three numbers for RGB Channels.
# transforms.Resize: Resize the input images to 256x256 pixels.
# transforms.CenterCrop: Crop the image to 224×224 pixels about the center.
# transforms.Normalize: Normalize the image by setting its mean and standard deviation to the specified values.
# transforms.ToTensor: Convert the image to Pytorch tensor datatype.

Step 3: Load the input image and pre-process it.

# Download image
# !wget https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -O dog.jpg
img = Image.open("./dog.jpg")
img

在这里插入图片描述

# Pre-process the image.
trans_img = transform(img)img_batch = torch.unsqueeze(trans_img, 0)

Step 4: Model Inference

# Set the model to eval model.
alexnet.eval()out = alexnet(img_batch)
print(out.shape)
torch.Size([1, 1000])
# Download classes text file
!wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
--2023-12-14 21:30:09--  https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 0.0.0.0, ::
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|0.0.0.0|:443... failed: Connection refused.
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|::|:443... failed: Connection refused.
# Load labels.
with open('imagenet_classes.txt') as f:classes = [line.strip() for line in f.readlines()]
# Find out the maximum score.
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())
Labrador retriever 41.58513259887695
# The model predicts the image to be of a Labrador Retriever with a 41.58% confidence.
_, indices = torch.sort(out, descending=True)
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
[('Labrador retriever', 41.58513259887695),('golden retriever', 16.59164810180664),('Saluki, gazelle hound', 16.286897659301758),('whippet', 2.8539111614227295),('Ibizan hound, Ibizan Podenco', 2.39247727394104)]

1.4. Using ResNet for Image Classification

# Load the resnet101 model.
resnet = models.resnet101(pretrained=True)# Set the model to eval mode.
resnet.eval()# carry out model inference.
out = resnet(img_batch)# Print the top 5 classes predicted by the model.
_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /home/wsl_ubuntu/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:51<00:00, 3.47MB/s] [('Labrador retriever', 48.255577087402344),('dingo, warrigal, warragal, Canis dingo', 7.900773048400879),('golden retriever', 6.91691780090332),('Eskimo dog, husky', 3.6434390544891357),('bull mastiff', 3.046128273010254)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222034.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++标识线程

c标识线程 线程ID类型为std::thread::id&#xff0c;它有两种方式获取。 直接通过std::thread对象的成员函数get_id()来获取。如果thread对象没有与任何执行线程相关联&#xff0c;get_id()将返回std::thread::id对象&#xff0c;它按照默认的构造方式生成&#xff0c;表示线程…

大型科技公司与初创公司:选择哪一个?

你有没有想过&#xff0c;特别是在你职业生涯的开始&#xff0c;选择什么类型的公司&#xff1f;它应该是一家像谷歌、亚马逊、Meta 这样的大型科技公司&#xff0c;还是为一家小型初创公司工作。在本文中&#xff0c;我们将讨论实际差异是什么&#xff0c;并帮助你选择最适合你…

『OPEN3D』1.8.3 多份点云配准

多份点云配准是将多份点云数据在全局空间中对齐的过程。通常,输入是一组数据(例如点云或RGBD图像){Pi}。输出是一组刚性变换{Ti},使得经过变换的点云在全局空间中对齐。 NNNNNathan 本专栏地址: https://blog.csdn.net/qq_41366026/category_12186023.html 此处是…

orcad(Cadence)常用库olb介绍

基础库&#xff1a; CAPSYM.OLB capsym 共35个零件&#xff0c;存放电源&#xff0c;地&#xff0c;输入输出口&#xff0c;标题栏等。DISCRETE.OLB 共872个零件&#xff0c;存放分立式元件&#xff0c;如电阻&#xff0c;电容&#xff0c;电感&#xff0c;开关&#xff0c;变…

DHCP—动态主机配置协议

动态主机配置协议DHCP&#xff08;Dynamic Host Configuration Protocol&#xff0c;动态主机配置协议&#xff09;是RFC 1541&#xff08;已被RFC 2131取代&#xff09;定义的标准协议&#xff0c;该协议允许服务器向客户端动态分配IP地址和配置信息。 DHCP协议支持C/S&#x…

pom配置文件重要标签探究

文章目录 dependencies标签dependencyManagement标签两者辨析repositories标签properties标签 dependencies标签 <dependencies>标签用于指定项目的依赖项列表。这些依赖项可以是应用程序代码所需的库&#xff0c;也可以是Spring Boot和其他第三方库。<dependencies&…

RocketMQ 总体概括

目录 概述RocketMQ 领域模型MQ 解决的问题电商平台案例初步设计引入中间件设计 MQ 选型结束 概述 官网地址 RocketMQ 领域模型 官方领域模型概述 下面图&#xff0c;是在自己理解的基础上&#xff0c;对官方的模型图添加了一些。 Topic&#xff1a;主题&#xff0c;可以理解…

打家劫舍Ⅱ java

题目描述 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋&#xff0c;每间房内都藏有一定的现金。这个地方所有的房屋都 围成一圈 &#xff0c;这意味着第一个房屋和最后一个房屋是紧挨着的。同时&#xff0c;相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的…

深入解析MySQL中内连接、外连接的区别及实践应用

​嗨&#xff0c;大家好&#xff0c;欢迎来到程序猿漠然公众号&#xff0c;我是漠然。在数据库查询中&#xff0c;连接是一种常用的操作&#xff0c;用于从两个或多个表中获取数据。本文将详细介绍MySQL中的内连接、外连接的概念、区别以及实践应用&#xff0c;帮助大家更好地理…

Java网络编程——基于UDP的数据报和套接字

java.net.ServerSocket与java.net.Socket建立在TCP的基础上。TCP是网络传输层的一种可靠的数据传输协议。如果数据在传输途中被丢失或损坏&#xff0c;那么TCP会保证再次发送数据&#xff1b;如果数据到达接收方的顺序被打乱&#xff0c;那么TCP会在接收方重新恢复数据的正确顺…

vscode报错Pylance client: couldn‘t create connection to server.

问题描述&#xff1a; 一打开vscode&#xff0c;右下角就弹报错&#xff0c;Pylance client: couldn’t create connection to server.&#xff0c;让我打开output&#xff0c;打开后似乎是在说连不上server 因为连不上server&#xff0c;所以我的python代码没法解析&#xff0…

2.4 C语言之运算符

2.4 C语言之运算符 一、算术运算符二、关系运算符三、逻辑运算符四、自增自减运算符五、按位运算符六、赋值运算符七、条件表达式八、运算符优先级与求值次序 一、算术运算符 二元算术运算符包括&#xff1a;(加)、-(减)、*(乘)、/(除)、%(取模) 整数除法会截断结果中的小数部…

扬声器(喇叭)

扬声器(喇叭) 电子元器件百科 文章目录 扬声器(喇叭)前言一、扬声器(喇叭)是什么二、扬声器(喇叭)的类别三、扬声器(喇叭)的应用场景四、扬声器(喇叭)的作用原理总结前言 扬声器广泛应用于音响系统、公共广播系统、汽车音响、电视、电脑和移动设备等各种电子设备…

Linux基本开发工具

编译器和自动化构建工具 一、编译器——gcc、g1. 安装 gcc/g2. 使用3. 链接库4. 拓展命令&#xff1a;od/file/ldd/readelf 二、自动化构建项目——make、makefile1. 介绍2. 使用例子touch——change file timestampsstat——display file or file system status修改时间 .PHON…

Qt 文字描边(基础篇)

项目中有时需要文字描边的功能 1.基础的绘制文字 使用drawtext处理 void MainWindow::paintEvent(QPaintEvent *event) {QPainter painter(this);painter.setRenderHint(QPainter::Antialiasing, true);painter.setRenderHint(QPainter::SmoothPixmapTransform, true);painte…

ceph的osd盘删除操作和iscsi扩展

ceph的osd盘删除操作 拓展:osd磁盘的删除(这里以删除node1上的osd.0磁盘为例) 1, 查看osd磁盘状态 [rootnode1 ceph]# ceph osd tree ID CLASS WEIGHT TYPE NAME STATUS REWEIGHT PRI-AFF -1 0.00298 root default -3 0.00099 host node10 hdd 0.00…

【Vins轨迹】pose_graph位姿图加载EVO精度评定

1. Vins的位姿图加载功能 如果想要对slam运行后的位姿轨迹进行评定&#xff0c;需要将数据保存到output文件夹中。 其中pose_graph.txt含有的信息&#xff1a;关键帧id、时间戳、vio的xyz、优化后的xyz、vio的四元数、优化后的四元数、回环到的关键帧id、回环信息&#xff08…

【Py/Java/C++三种语言OD2023C卷真题】20天拿下华为OD笔试【DP】2023C-分班【欧弟算法】全网注释最详细分类最全的华为OD真题题解

文章目录 题目描述与示例题目描述输入描述输出描述示例一输入输出 示例二输入输出 解题思路代码PythonJavaC时空复杂度 华为OD算法/大厂面试高频题算法练习冲刺训练 题目描述与示例 题目描述 幼儿园两个班的小朋友在排队时混在了一起&#xff0c;每位小朋友都知道自己是否与前…

【十】python复合模式

10.1 复合模式简介 在前面的栏目中我们了解了各种设计模式。正如我们所看到的&#xff0c;设计模式可分为三大类:结构型、创建型和行为型设计模式。同时&#xff0c;我们还给出了每种类型的相应示例。然而&#xff0c;在软件实现中&#xff0c;模式并是不孤立地工作的。对于所…

HPM5300系列--第一篇 命令行开发调试环境搭建

一、目的 在之前的博客中《HPM6750系列--第二篇 搭建Ubuntu开发环境》、 《HPM6750系列--第三篇 搭建MACOS编译和调试环境》我们介绍了HPM6750evkmini开发环境的搭建过程&#xff0c;由于HPM5300系列共用一套hpm-sdk&#xff0c;故HPM5300的开发调试环境的搭建过程基本和之前的…