PyTorch: 基于【VGG16】处理MNIST数据集的图像分类任务【准确率98.9%+】

目录

  • 引言
  • 在Conda虚拟环境下安装pytorch
  • 步骤一:利用代码自动下载mnist数据集
  • 步骤二:搭建基于VGG16的图像分类模型
  • 步骤三:训练模型
  • 步骤四:测试模型
  • 运行结果
  • 后续模型的优化和改进建议
  • 完整代码
  • 结束语

引言

在本博客中,小编将向大家介绍如何使用VGG16处理MNIST数据集的图像分类任务。MNIST数据集是一个常用的手写数字分类数据集,包含60,000个训练样本和10,000个测试样本。我们将使用Python编程语言和PyTorch深度学习框架来实现这个任务。

在Conda虚拟环境下安装pytorch

# CUDA 11.6
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
# CUDA 11.3
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# CUDA 10.2
pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102
# CPU only
pip install torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cpu

步骤一:利用代码自动下载mnist数据集

import torchvision.datasets as datasets  
import torchvision.transforms as transforms  # 定义数据预处理操作  
transform = transforms.Compose([transforms.Resize(224), # 将图像大小调整为(224, 224)transforms.ToTensor(),  # 将图像转换为PyTorch张量transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])# 下载并加载MNIST数据集  
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

步骤二:搭建基于VGG16的图像分类模型

class VGGClassifier(nn.Module):def __init__(self, num_classes):super(VGGClassifier, self).__init__()self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器# 重构VGG16网络的第一层卷积层,适配mnist数据的灰度图像格式self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096nn.ReLU(True),nn.Dropout(), # 随机将一些神经元“关闭”,这样可以有效地防止过拟合。nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10))self._initialize_weights()  # 初始化权重参数def forward(self, x):x = self.features(x)  # 通过特征提取器提取特征x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量x = self.classifier(x)  # 通过分类器进行分类预测return xdef _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

步骤三:训练模型

import torch.optim as optim  
from torch.utils.data import DataLoader  # 定义超参数和训练参数  
batch_size = 64  # 批处理大小  
num_epochs = 5  # 训练轮数
learning_rate = 0.01  # 学习率
num_classes = 10  # 类别数(MNIST数据集有10个类别)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用GPU进行训练,否则使用CPU。# 定义训练集和测试集的数据加载器  
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)  
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)  # 初始化模型和优化器  
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)  
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数  
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)  # 训练模型  
for epoch in range(num_epochs):  for i, (images, labels) in enumerate(train_loader):  images = images.to(device)  # 将图像数据移动到指定设备  labels = labels.to(device)  # 将标签数据移动到指定设备  # 前向传播  outputs = model(images)  loss = criterion(outputs, labels)  # 反向传播和优化  optimizer.zero_grad()  # 清空梯度缓存  loss.backward()  # 计算梯度  optimizer.step()  # 更新权重参数  if (i+1) % 100 == 0:  # 每100个batch打印一次训练信息  print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))  # 保存模型参数  
torch.save(model.state_dict(), './model.pth')

步骤四:测试模型

# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间for images, labels in test_loader:images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备outputs = model(images)  # 模型前向传播,得到预测结果_, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别total += labels.size(0)  # 更新总样本数量correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

运行结果

在这里插入图片描述

后续模型的优化和改进建议

  1. 数据增强:通过旋转、缩放、平移等方式来增加训练数据,从而让模型拥有更好的泛化能力。
  2. 调整模型参数:可以尝试调整模型的参数,比如学习率、批次大小、迭代次数等,来提高模型的性能。
  3. 更换网络结构:可以尝试使用更深的网络结构,如ResNet、DenseNet等,来提高模型的性能。
  4. 调整优化器:本次代码采用SGD优化器,但仍可以尝试使用不同的优化器,如Adam、RMSprop等,来找到最适合我们模型的优化器。
  5. 添加正则化操作:为了防止过拟合,可以添加一些正则化项,如L1正则化、L2正则化等。
  6. 代码目前只有等训练完全结束后才能进入测试阶段,后续可以在每个epoch结束,甚至是指定的迭代次数完成后便进入测试阶段。因为训练完全结束的模型很可能已经过拟合,在测试集上不能表现较强的泛化能力。

完整代码

import torch
import torch.nn as nnimport torch.optim as optim
from torch.utils.data import DataLoaderimport torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transformsimport warnings
warnings.filterwarnings("ignore")# 定义数据预处理操作
transform = transforms.Compose([transforms.Resize(224), # 将图像大小调整为(224, 224)transforms.ToTensor(),  # 将图像转换为PyTorch张量transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化
])# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)class VGGClassifier(nn.Module):def __init__(self, num_classes):super(VGGClassifier, self).__init__()self.features = models.vgg16(pretrained=True).features  # 使用预训练的VGG16模型作为特征提取器# 重构网络的第一层卷积层,适配mnist数据的灰度图像格式self.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),  # 添加一个全连接层,输入特征维度为512x7x7,输出维度为4096nn.ReLU(True),nn.Dropout(), # 随机将一些神经元“关闭”,有效地防止过拟合。nn.Linear(4096, 4096),  # 添加一个全连接层,输入和输出维度都为4096nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),  # 添加一个全连接层,输入维度为4096,输出维度为类别数(10))self._initialize_weights()  # 初始化权重参数def forward(self, x):x = self.features(x)  # 通过特征提取器提取特征x = x.view(x.size(0), -1)  # 将特征张量展平为一维向量x = self.classifier(x)  # 通过分类器进行分类预测return xdef _initialize_weights(self):  # 定义初始化权重的方法,使用Xavier初始化方法for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)# 定义超参数和训练参数
batch_size = 64  # 批处理大小
num_epochs = 5  # 训练轮数(epoch)
learning_rate = 0.01  # 学习率(learning rate)
num_classes = 10  # 类别数(MNIST数据集有10个类别)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU进行训练,如果有GPU则使用第一个GPU(cuda:0)进行训练,否则使用CPU进行训练。# 定义数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 初始化模型和优化器
model = VGGClassifier(num_classes=num_classes).to(device)  # 将模型移动到指定设备(GPU或CPU)
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # 使用随机梯度下降优化器(SGD)# 训练模型
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空梯度缓存loss.backward()  # 计算梯度optimizer.step()  # 更新权重参数if (i + 1) % 100 == 0:  # 每100个batch打印一次训练信息print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),loss.item()))# 训练结束,保存模型参数
torch.save(model.state_dict(), './model.pth')# 加载训练好的模型参数
model.load_state_dict(torch.load('./model.pth'))
model.eval()  # 将模型设置为评估模式,关闭dropout等操作# 定义评估指标变量
correct = 0  # 记录预测正确的样本数量
total = 0  # 记录总样本数量# 测试模型性能
with torch.no_grad():  # 关闭梯度计算,节省内存空间for images, labels in test_loader:images = images.to(device)  # 将图像数据移动到指定设备labels = labels.to(device)  # 将标签数据移动到指定设备outputs = model(images)  # 模型前向传播,得到预测结果_, predicted = torch.max(outputs.data, 1)  # 取预测结果的最大值对应的类别作为预测类别total += labels.size(0)  # 更新总样本数量correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量# 计算模型准确率并打印出来
accuracy = 100 * correct / total  # 计算准确率,将正确预测的样本数量除以总样本数量并乘以100得到百分比形式的准确率。
print('Accuracy of the model on the test images: {} %'.format(accuracy))  # 打印出模型的准确率。

结束语

如果本博文对你有所帮助/启发,可以点个赞/收藏支持一下,如果能够持续关注,小编感激不尽~
如果有相关需求/问题需要小编帮助,欢迎私信~
小编会坚持创作,持续优化博文质量,给读者带来更好de阅读体验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/224087.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

商城后台管理系统--->新闻简报(富文本编辑器,文章,图片上传)

在商城的项目里面需要添加新闻,使用富文本编辑器,我用的是 wangEditor这个编辑器挺好用的,而且也方便简单,官网也是中文的wangEditor 这是做的添加新闻的页面 我用的是SCUI框架,引入的是npm,具体可看官网 npm install wangedit…

【Docker实战】基于Dockerfile搭建LNMP+wordpress

一、项目背景和要求 公司在实际的生产环境中,需要使用Docker 技术在一台主机上创建LNMP服务并运行Wordpress网站平台。 然后对此服务进行相关的性能调优和管理工作 二、架构: nginx172.111.0.10docker-nginxmysql172.111.0.20docker-mysqlPHP172.111…

map 和 multimap 存储区别 、取消自动排序 unordered_map

测试代码 std::map<int, CString > Map1;Map1.insert({ 6, L"HN400*200*11*8" });Map1.insert({ 5, L"HN200*200*11*8" });Map1.insert({ 7, L"HN100*200*11*8" });Map1.insert({ 4, L"HN200*200*11*8" });Map1.insert({ 4, L…

【开发工具】最新VMWare无法识别USB设备,驱动错误,未知错误【2023.12.15】

解决方案1&#xff1a;在这里改下连接方式 多试试 解决方案2 控制面板卸载程序&#xff0c;进行VMWare的修复 解决方案3 对于Windows7系统&#xff0c;切换解决方案1的usb类型为3.1&#xff0c;并下载这个intel的驱动包到虚拟机里 https://www.intel.com/content/www/us/en/do…

科目三 换挡为什么要踩离合器

换挡时需要踩离合器为了切断动力传输&#xff0c;让变速器空转&#xff0c;齿轮才会同步&#xff0c;从而轻松挂挡。 在起步时&#xff0c;当车速达到15km/h时&#xff0c;从一挡换到二挡。 当车速达到25km/h时&#xff0c;可以换成三挡&#xff0c; 达到35km/h左右时&#xf…

高效电商策略:小红书集成CRM与广告推广无代码化

无代码开发的优势 随着科技的不断进步&#xff0c;无代码开发&#xff08;No-Code Development&#xff09;已经成为快速构建系统和应用的新趋势。无代码开发指的是不需要传统编程知识&#xff0c;通过图形化的用户界面和模型驱动逻辑来创建应用程序。这种方式让非技术背景的用…

金蝶云星空协同开发环境应用内执行SQL脚本

文章目录 金蝶云星空协同开发环境应用内执行SQL脚本 金蝶云星空协同开发环境应用内执行SQL脚本

中文字符串逆序输出

今天碰到这个题&#xff0c;让我逆序输出中文字符串&#xff0c;可给我烦死了&#xff0c;之前没有遇到过&#xff0c;也是查了资料才知道&#xff0c;让我太汗颜了。 英文字符串逆序输出很容易&#xff0c;开辟一块空间用来存放逆序后的字符串&#xff0c;从后往前遍历原字符串…

操作系统笔记——储存系统、文件系统(王道408)

文章目录 前言储存系统地址转换内存扩展覆盖交换 储存器分配——连续分配固定大小分区动态分区分配动态分区分配算法 储存器分配——非连续分配页式管理基本思想地址变换硬件快表&#xff08;TLB&#xff09;多级页表 段式管理段页式管理 虚拟储存器——基于交换的内存扩充技术…

题目:区间或 (蓝桥OJ 3691)

题目描述: 解题思路: 本题采用位运算.先求出全部数组每一位各自的前缀和,然后再判断区间内每一位区间和是否为0,不为0则乘上相应的2^n并将各个为的2^n相加,得ans. 实现原理图 题解: #include<bits/stdc.h> using namespace std;const int N 1e5 9;int a[N], prefix[35…

20231215给AIO-3399J适配Rockchip的原始Andoroid10的挖掘机开发板02

20231215给AIO-3399J适配Rockchip的原始Andoroid10的挖掘机开发板02 2023/12/15 15:37 【请严重注意&#xff1a;】如果刷不适配的SDK&#xff0c;可能会引起您的开发板【硬件发生物理】损坏&#xff01; 如果您按照本步骤刷机引起的一切后果&#xff0c;请自行承担责任&#x…

Day09 Liunx高级系统设计11-数据库1

MySQL 简介 数据库DB 数据库&#xff08; DataBase &#xff0c; DB &#xff09;从本质上讲就是一个文件系统&#xff0c;它能够将数据有组织地集合在一起&#xff0c;按照一定的规则长期存储到计算机的磁盘中&#xff0c;并且能够供多个用户共享和使用&#xff0c;同时&…

Linux篇:信号

一、信号的概念&#xff1a; ①进程必须识别能够处理信号&#xff0c;信号没有产生&#xff0c;也要具备处理信号的能力---信号的处理能力属于进程内置功能的一部分 ②进程即便是没有收到信号&#xff0c;也能知道哪些信号该怎么处理。 ③当进程真的受到了一个具体的信号的时候…

猫粮哪个牌子质量好性价比高?分享十款主食冻干猫粮品牌排行榜!

一款好的、健康的主粮对猫整体有很大的提升&#xff0c;主食作为猫的日常饮食&#xff0c;直接关乎着小猫是否能摄入充分的营养&#xff0c;达到最佳的理想状态&#xff0c;因此对于每一位铲屎官来说&#xff0c;主食选得好不好至关重要。面对种类众多的主食&#xff0c;很多人…

c/c++ 结构体、联合体、枚举

结构体 结构体内存对齐规则&#xff1a; 1、结构体的第一个成员对齐到结构体变量起始位置偏移量为0的地址处 2、其他成员变量要对齐到某个数字&#xff08;对齐数&#xff09;的整数倍的地址处。 对齐数&#xff1a;编译器默认的一个对齐数与该成员变量大小的较小值。 vs 中…

编程实际应用实例:洗车店会员管理系统操作教程

一、前言 洗车店在会员管理有时候需要一卡多用&#xff0c;基本也不需要做卡&#xff0c;直接报手机号或车牌号即可完成电子会员卡录入。 下面以 佳易王洗车店会员管理系统软件为例说明&#xff0c; 软件试用版下载或技术支持可以点击下方的官网卡片 如图&#xff1a;这个卡…

【教程】Autojs脚本实现暂停和超时重启功能的思路和示例代码

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 背景介绍 autojs本身不支持暂停脚本&#xff0c;现有网上大部分最直接的做法就是在每条语句后面添加检查是否暂停。当脚本功能和代码量非常打的时候&#xff0c;每一条语句后面都加检测&#xff0c;未免不太现实。…

RFID技术在物流仓储解决方案中的应用

行业现状 当前市场竞争日益激烈&#xff0c;提高生产效率、降低运营成本对企业至关重要。仓库和物流管理在各行业中广泛应用。建立完善的仓库管理流程&#xff0c;提高仓库周转率&#xff0c;减少资金占用&#xff0c;实现资产变现&#xff0c;降低仓库淘汰成本&#xff0c;是…

c++面经总结

C基础语法 C和c的区别 c中new和delete是对内存分配的运算符&#xff0c;取代了c中的malloc和free 标准c中的字符串类取代了标准c函数库头文件中的字符数组处理函数(c中没有字符串类型). 在c中&#xff0c;允许有相同的函数名&#xff0c;不过他们的参数类型不能完全相同&…

【状态机FSM 序列检测 饮料机_2023.12.1】

同步状态机 概念 同步状态机&#xff08;同一脉冲边沿触发&#xff09;&#xff1a;有限个离散状态及某状之间的转移 异步状态机无法综合 分类 Moore状态机 只和状态有关&#xff0c;与输入无关 Mealy状态机 和状态和输入都有关 Mealy型比Moore型少一个状态 结构 由状态寄…