VGG16模型实现MNIST图像分类

MNIST图像数据集

MNIST(Modified National Institute of Standards and Technology)是一个经典的机器学习数据集,常用于训练和测试图像处理和机器学习算法,特别是在数字识别领域。该数据集包含了大约 7 万张手写数字图片,其中 6 万张是用于训练,1 万张用于测试。每张图片都是 28x28 像素的灰度图像,展示了从 0 到 9 的手写数字。这些图像已经被处理过,以使得数字在图像中居中且尺寸一致。

MNIST 数据集是一个广泛被用于测试新的机器学习算法的基准,因为它相对较小,易于理解,且可以用于快速验证算法的有效性。许多人使用 MNIST 作为开始学习深度学习的入门数据集,因为它提供了一个简单但具有挑战性的任务,即将手写数字图像分类为相应的数字。

尽管 MNIST 已经存在了很长时间,但它仍然是一个重要的基准数据集,特别是对于新的机器学习研究和算法的初步测试。MINIST数据集中部分图片如下所示:

下载MNIST数据集

由于MINIST作为经典数据集,已经被内嵌在torchvision库中的dataset中了,所以直接使用代码datasets.MNIST进行下载即可。

下载后的文件格式如下图所示。

搭建VGG16图像分类模型

class VGGClassifier(nn.Module):def __init__(self, num_classes):super(VGGClassifier, self).__init__()self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器# 重构网络的第一层卷积层,适配mnist数据的灰度图像格式self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 256),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096nn.ReLU(True),nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。nn.Linear(256, 256),  # 添加一个全连接层,输入和输出维度都为4096nn.ReLU(True),nn.Dropout(),nn.Linear(256, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10))self._initialize_weights()  # 初始化权重参数

定义VGG网络结构如上所示,在上面代码中我定义了一个基于 VGG16 架构分类器的模型。VGG16 是一种经典的卷积神经网络模型,由 16 层深度的卷积层和全连接层组成,所构建的 VGGClassifier 类的网络结构包含两个主要部分:

特征提取器(features):这部分使用了预训练的 VGG16 模型的特征提取器。通过调用 models.vgg16(pretrained=True).features 来加载 VGG16 的特征提取器部分。然后,将第一层卷积层的输入通道数从 3 修改为 1,以适应 MNIST 数据集的灰度图像格式。

分类器(classifier):这部分是自定义的分类器,用于对提取的特征进行分类。首先,通过几个全连接层将特征图展平成一维张量,然后通过一系列的线性层和激活函数对特征进行处理。具体来说,包括:一个包含 256 个神经元的全连接层,输入维度为 512x7x7(经过 VGG16 的特征提取器后的输出尺寸),使用 ReLU 激活函数。一个 Dropout 层,用于防止过拟合,随机关闭一些神经元。一个包含 256 个神经元的全连接层,使用 ReLU 激活函数。再次添加一个 Dropout 层。最后是一个包含 num_classes 个神经元的全连接层,用于输出最终的类别预测结果。

通过上述方式,整个网络结构将 VGG16 的特征提取器和自定义的分类器相结合,以适应 MNIST 数据集的图像分类任务。

构建的VGG网络结构如下图所示:

VGG网络结构图

模型训练

# 定义超参数和训练参数
batch_size = 16  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.001  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。

模型参数设置如下表所示(代码见上)

模型超参数

数值

batchsize

16

num_epochs

5

learning_rate

0.001

num_classes

10

由于MINIST数据集样本数量较大,所以对于上述代码训练速度也会较慢,我考虑使用我的笔记本电脑独显进行运算,却发现电脑显存不够,于是我调小batchsize与epoch,并降低学习率learning rate才让GPU勉强能够运行上面代码,并获得到了模型model.pth,最终获得模型在测试集上面的识别精度为96.7%,精度还是比较高的。(由于笔记本电脑性能有限,在处理较大规模数据的小型项目时速度较慢,故上述代码运行了一下午左右的时间才跑完)。

模型测试

使用上面模型进行手写数字识别的检验。绘制一张图片上面含有9张子图,随机选取识别结果的9张进行展示 。识别效果以及运行结果如下图所示。

 

附录:

 VGG训练代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transformsimport warnings
warnings.filterwarnings("ignore")# 定义数据预处理操作
transform = transforms.Compose([transforms.Resize(224), # 将图像大小调整为(224, 224)transforms.ToTensor(),  # 将图像转换为PyTorch张量transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)class VGGClassifier(nn.Module):def __init__(self, num_classes):super(VGGClassifier, self).__init__()self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器# 重构网络的第一层卷积层,适配mnist数据的灰度图像格式self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 256),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096nn.ReLU(True),nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。nn.Linear(256, 256),  # 添加一个全连接层,输入和输出维度都为4096nn.ReLU(True),nn.Dropout(),nn.Linear(256, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10))self._initialize_weights()  # 初始化权重参数def forward(self, x):x = self.features(x)  # 通过特征提取器提取特征x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量x = self.classifier(x)  # 通过分类器进行分类预测return xdef _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)# 定义超参数和训练参数
batch_size = 16  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.001  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)# 训练模型
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空梯度缓存loss.backward()  # 计算梯度optimizer.step()  # 更新权重参数if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),loss.item()))# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间for images, labels in test_loader:images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备outputs = model(images)  # 模型前向传播,得到预测结果_, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别total += labels.size(0)  # 更新总样本数量correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/56128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

线性代数在大一计算机课程中的重要性

线性代数在大一计算机课程中的重要性 线性代数是一门研究向量空间、矩阵运算和线性变换的数学学科,在计算机科学中有着广泛的应用。大一的计算机课程中,线性代数的学习为学生们掌握许多计算机领域的关键概念打下了坚实的基础。本文将介绍线性代数的基本…

高考技术——pandas使用

百家讲坛,谈论古今,今天我们不聊别的,我们来聊一聊中国的国宝——大熊猫(bushi) 好好,言归正传,我们今天来讲pandas import pandas as pd 申明无需多言,高考主要考察Series和Data…

Milvus向量数据库管理工具[Attu]实践

Attu是一款专为Milvus向量数据库打造的开源数据库管理工具,提供了便捷的图形化界面,极大地简化了对Milvus数据库的操作与管理流程。阿里云Milvus集成了Attu,以便更加高效地管理数据库、集合(Collection)、索引&#xf…

第四次论文问题知识点及问题

1、NP-hard问题 NP-hard,指所有NP问题都能在多项式时间复杂度内归约到的问题。 2、启发式算法 ‌‌启发式算法(heuristic algorithm)是相对于最优化算法提出的。它是一种基于直观或经验构造的算法,旨在以可接受的花费给出待解决…

树莓派3b安装ubuntu18.04服务器系统server配置网线连接

下载ubuntu镜像网址 img镜像,即树莓派官方烧录器使用的镜像网址 ubuntu18.04-server:ARM/RaspberryPi - Ubuntu Wiki 其他版本:Index of /ubuntu/releases 下载后解压即可。 发现使用官方烧录器烧录配置时配置wifi无论如何都不能使用&am…

应对网站IP劫持的有效策略与技术手段

摘要: IP劫持是一种常见的网络攻击方式,攻击者通过非法手段获取目标网站服务器的控制权,进而改变其网络流量的路由路径,导致用户访问错误的站点。本文将介绍如何识别IP劫持,并提供一系列预防和应对措施,以确…

android + tflite 分类APP开发-2

APP开发 build.gradle导入库 //implementation org.tensorflow:tensorflow-android: implementation org.tensorflow:tensorflow-lite:2.4.0 implementation org.tensorflow:tensorflow-lite-support:0.3.1 implementation org.tensorflow:tensorflow-lite-metada…

GO网络编程(三):海量用户通信系统1:登录功能初步

一、准备工作 需求分析 1)用户注册 2)用户登录 3)显示在线用户列表 4)群聊(广播) 5)点对点聊天 6)离线留言 主界面 首先,在项目根目录下初始化mod,然后按照如下结构设计目录: 海量用户通信系统/ ├── go.mod ├── client/ │ ├──…

【阅读笔记】水果轻微损伤的无损检测技术应用

一、水果轻微损伤检测技术以及应用 无损检测技术顾名思义就是指在不破坏水果样品完整性的情况下对样品进行品质鉴定。目前比较常用的农产品水果类无损检测法有:基于红外热成像、机器视觉技术的图像处理方法、光谱检测技术、介电特性技术检测法等。 1.1 基于红外热…

【C++】基于红黑树封装set和map

🚀个人主页:小羊 🚀所属专栏:C 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 前言一、更高维度的泛型二、模版参数三、比较逻辑的重写四、迭代器4.1 const迭代器4.2 重载4.3 - -重载 五、完整代…

在深度学习中,Epoch、迭代次数、批次大小(Batch Size)和学习速率(Learning Rate)是影响模型训练效果的重要超参数。

1. Epoch 定义:Epoch是指整个训练数据集被完整地用来训练一次。影响:增加Epoch的数量可以使模型更充分地学习数据。然而,过高的Epoch可能导致过拟合,即模型在训练集上表现良好,但在测试集上表现不佳。设置&#xff1a…

【C++设计模式】行为型模式:中介者模式

行为型模式:中介者模式 中介者模式通过引入一个中介者对象来集中控制对象之间的交互。这样可以解耦多个对象之间的复杂交互关系,使系统更易于维护和扩展。 假设我们有一个简单的聊天室应用,其中有每个用户可以发送群聊消息给其他用户&#…

阿里P8面试官推荐学习的11大专题:java面试精讲框架文档

本篇文章给大家分享一波,阿里P8面试官推荐学习的11大专题:java面试精讲框架文档,主要包含11大块的内容:spring、springcloud、netty、zookeeper、kafka、Hadoop、HBASE、Cassandra、elasticsearch、spark、flink;希望大…

【C++入门篇 - 3】:从C到C++第二篇

文章目录 从C到C第二篇new和delete命名空间命名空间的访问 cin和coutstring的基本使用 从C到C第二篇 new和delete 在C中用来向系统申请堆区的内存空间 New的作用相当于C语言中的malloc Delete的作用相当于C语言中的free 注意:在C语言中,如果内存不够…

stm32定时器中断和外部中断

一,中断系统的介绍 中断:在主程序运行过程中,出现了特定的中断触发条件(中断源),使得CPU暂停当前正在运行的程序,转而去处理中断程序,处理完成后又返回原来被暂停的位置继续运行 中…

Github 优质项目推荐(第七期):涵盖免费服务、API、低代码、安卓root、深度学习

文章目录 Github优质项目推荐 - 第七期一、【LangGPT】,5.7k stars - 让每个人都成为提示专家二、【awesome-selfhosted】,198k stars - 免费软件网络服务和 Web 应用程序列表三、【public-apis】,315k stars - 免费 API四、【JeecgBoot】&am…

mysql游标的使用

说明: 虽然我们也可以通过筛选条件 WHERE 和 HAVING,或者是限定返回记录的关键字 LIMIT 返回一条记录,但是,却无法在结果集中像指针一样,向前定位一条记录、向后定位一条记录,或者是 随意定位到某一条记录 …

No.3 笔记 | Web安全基础:Web1.0 - 3.0 发展史

大家好!作为一个喜欢探索本质的INTP,我整理了一份简明易懂的Web安全笔记。希望能帮助你轻松掌握这个领域的核心知识。 这份笔记涵盖了Web发展的历程,从静态的Web 1.0到智能化的Web 3.0。我们将探讨URL和HTTP协议,揭示它们在网络中…

新书速览|你好,C++

《你好,C》 本书内容 《你好,C》主要介绍C开发环境的搭建、基础语法知识、面向对象编程思想以及标准模板库的应用,特别针对初学者在学习C过程中可能遇到的难点提供了解决方案。全书共分13章,以一个工资程序的不断优化和完善为线索…

pds 开发流程(pango design suite)使用方法

author: hjjdebug date: 2024年 10月 12日 星期六 13:24:55 CST pds 开发流程(pango design suite)使用方法 基于 Pango Design Suite(PDS) 的FPGA开发流程 盘古设计开发包, 是一个集成开发环境, 就是说把很多功能就集中在了一起的意思. 我…