【Pytorch】16.使用ImageFolder加载自定义MNIST数据集训练手写数字识别网络(包含数据集下载)

数据集下载

MINST_PNG_Training在github的项目目录中的datasets中有MNIST的png格式数据集的压缩包

用于训练的神经网络模型

在这里插入图片描述

自定义数据集训练

在前文【Pytorch】13.搭建完整的CIFAR10模型我们已经知道了基本搭建神经网络的框架了,但是其中的数据集使用的torchvision中的CIFAR10官方数据集进行训练的

train_dataset = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,transform=torchvision.transforms.ToTensor())

在这里插入图片描述

本文将用图片格式的数据集进行训练
在这里插入图片描述
我们通过

# Dataset CIFAR10
#     Number of datapoints: 60000
#     Root location: ../datasets
#     Split: Train
#     StandardTransform
# Transform: ToTensor()
print(train_dataset)

可以看到我们下载的数据集是这种格式的,所以我们的主要问题就是如何将自定义的数据集获取,并且转化为这种形式,剩下的步骤就和上文相同了

数据类型进行转化

我们的首要目的是,根据数据集的地址,分别将数据转化为train_datasettest_dataset
我们需要调用ImageFolder方法来进行操作

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *# 训练集地址
train_root = "../datasets/mnist_png/training"
# 测试集地址
test_root = '../datasets/mnist_png/testing'# 进行数据的处理,定义数据转换
data_transform = transforms.Compose([transforms.Resize((28, 28)),transforms.Grayscale(),transforms.ToTensor()])# 加载数据集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)

首先我们需要将数据进行处理,通过transforms.Compose获取对象data_transform
其中进行了三步操作

  • 将图片大小变为28*28像素便于输入网络模型
  • 将图片转化为灰度格式,因为手写数字识别不需要三通道的图片,只需要灰度图像就可以识别,而png格式的图片是四通道
  • 将图片转化为tensor数据类型

然后通过ImageFolder给出图片的地址与转化类型,就可以实现与我们在官方下载数据集相同的格式

# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#                ToTensor()
#            )
print(train_dataset)

其他与前文【Pytorch】13.搭建完整的CIFAR10模型基本相同

完整代码

网络模型

import torch
from torch import nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(2, stride=2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(3136, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = self.relu1(x)x = self.pool1(x)x = self.conv2(x)x = self.relu2(x)x = self.pool2(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)return xif __name__ == "__main__":model = Net()input = torch.ones((1, 1, 28, 28))output = model(input)print(output.shape)

训练过程

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *# 训练集地址
train_root = "../datasets/mnist_png/training"
# 测试集地址
test_root = '../datasets/mnist_png/testing'# 进行数据的处理,定义数据转换
data_transform = transforms.Compose([transforms.Resize((28, 28)),transforms.Grayscale(),transforms.ToTensor()])# 加载数据集
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#                Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#                ToTensor()
#            )
# print(train_dataset)# print(train_dataset[0])train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")model = Net().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)epoch = 10writer = SummaryWriter('../logs')
total_step = 0for i in range(epoch):model.train()pre_step = 0pre_loss = 0for data in train_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = loss_fn(outputs, labels)loss.backward()optimizer.step()pre_loss = pre_loss + loss.item()pre_step += 1total_step += 1if pre_step % 100 == 0:print(f"Epoch: {i+1} ,pre_loss = {pre_loss/pre_step}")writer.add_scalar('train_loss', pre_loss / pre_step, total_step)model.eval()pre_accuracy = 0with torch.no_grad():for data in test_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)outputs = model(images)pre_accuracy += outputs.argmax(1).eq(labels).sum().item()print(f"Test_accuracy: {pre_accuracy/len(test_dataset)}")writer.add_scalar('test_accuracy', pre_accuracy / len(test_dataset), i)torch.save(model, f'../models/model{i}.pth')writer.close()

参考文章

【CNN】搭建AlexNet网络——并处理自定义的数据集(猫狗分类)
How to download MNIST images as PNGs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/14251.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flutter 中的 WidgetInspector 小部件:全面指南

Flutter 中的 WidgetInspector 小部件:全面指南 Flutter 的 WidgetInspector 是一个强大的工具,它允许开发者在运行时检查和操作他们的 widget 树。这在调试复杂的布局和 widget 结构时尤其有用。本文将为您提供一个全面的指南,帮助您了解如…

Excel 按顺序去重再编号

Excel的A有重复数据: A1Cow2Chicken3Horse4Butterfly5Cow 现在要去除重复,用自然数按顺序进行编号,结果写在相邻列: AB1Cow12Chicken23Horse34Butterfly45Cow1 使用 SPL XLL,输入公式并向下拖: spl(&q…

RISC-V压缩指令扩展测试

概述 RISC-V定义了压缩指令扩展(compressed instruction-set extension ),命名为“C”扩展。压缩指令使用16位宽指令替换32位宽指令,从而减少代码量。这个C扩展可运用在RV32、RV64和RV128指令集上,通常使用“RVC”来表…

Double 4 VR情景实训教学系统在商务洽谈课堂上的应用

随着科技的不断发展,VR(虚拟现实)技术已经逐渐渗透到各个领域。在商务洽谈课堂上,Double 4 VR情景实训教学系统不仅可以为学生提供身临其境的模拟环境,还可以通过互动和交互式学习方式,增强学生的学习体验和…

贝锐向日葵打造农机设备远程运维支持方案

当物联网“万物互联”的概念向第一产业赋能,农机设备的智能化程度也越来越高。 所谓农业物联网,即在应用层将大量的传感器节点构成监控网络,通过各种传感器采集信息,以帮助农民及时发现问题,并准确地判定发生问题的位…

QT 使用QZipReader 进行文件解压缩

目录 1、QZipReader 概述 2、解压示例 3、说明 1、QZipReader 概述 QZipReader 是一个方便的工具,用于在 Qt 应用程序中解压 ZIP 压缩包。它提供了读取 ZIP 文件的接口,并能提取其中的内容。以下是如何使用 QZipReader 解压 ZIP 文件的示例代码&#…

List、IList、ArrayList 和 Dictionary

List 类型: 泛型类命名空间: System.Collections.Generic作用: List<T> 表示一个强类型的对象列表&#xff0c;可以通过索引访问。提供了搜索、排序和操作列表的方法。特点: 类型安全&#xff0c;性能较好&#xff0c;适用于需要强类型和高效操作的场景。例子: List<…

每日一练 - BGP Keepalive 报文详解

01 真题题目 关于 BGP 的 Keepalive 报文消息的描述,错误的是&#xff1a; A.Keepalive 周期性的在两个 BGP 邻居之间发送 B.缺省情况下,Keepalive 的时间间隔是 180s C.Keepalive 报文主要用于对等路由器间的运行状态和链路的可用性确认 D.Keepalive 报文的组成只包含一个…

Web安全:SQL注入之时间盲注原理+步骤+实战操作

「作者简介」&#xff1a;2022年北京冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础对安全知识体系进行总结与归纳&#xff0c;著作适用于快速入门的 《网络安全自学教程》&#xff0c;内容涵盖系统安全、信息收集等…

ICML2024高分论文!大模型计算效率暴涨至200%,来自中国AI公司

前段时间&#xff0c;KAN突然爆火&#xff0c;成为可以替代MLP的一种全新神经网络架构&#xff0c;200个参数顶30万参数&#xff1b;而且&#xff0c;GPT-4o的生成速度也是惊艳了一众大模型爱好者。 大家开始意识到—— 大模型的计算效率很重要&#xff0c;提升大模型的token…

前端加载excel文件数据 XLSX插件的使用

npm i xlsx import axios from axios; axios //这里用自己封装的http是不行的&#xff0c;踩过坑.get(url,{ responseType: "arraybuffer" }).then((re) > {console.log(re)let res re.datavar XLSX require("xlsx");let wb XLSX.read(r…

黑龙江大学文学院古代文学教研室安家琪副教授

女&#xff0c;生于1990年。兰州大学文学学士、硕士&#xff0c;上海交通大学文学博士&#xff0c;曾赴台湾东华大学交流&#xff0c;研究方向为明清诗文与唐代文学。 在《文艺理论研究》、《苏州大学学报》、《唐史论丛》、《中国社会科学报》等期刊发表论文20余篇&#xff0…

2024年 电工杯 (A题)大学生数学建模挑战赛 | 园区微电网风光储协调优化配置 | 数学建模完整代码解析

DeepVisionary 每日深度学习前沿科技推送&顶会论文&数学建模与科技信息前沿资讯分享&#xff0c;与你一起了解前沿科技知识&#xff01; 本次DeepVisionary带来的是电工杯的详细解读&#xff1a; 完整内容可以在文章末尾全文免费领取&阅读&#xff01; 问题重述…

干就对了!

成年人的世界哪有那么容易&#xff0c;不过都在负重前行&#xff0c;谁不是一边抱怨着&#xff0c;一边咬牙坚持&#xff0c;一边崩溃&#xff0c;一边还要自我安慰。 想改变&#xff0c;想更好&#xff0c;我们都有很多想法。 想再多不如动手做一次。一旦开始做了&#xff0…

前端手写文件上传;使用input实现文件拖动上传

使用input实现文件拖动上传 vue2代码&#xff1a; <template><div><div class"drop-area" dragenter"highlight" dragover"highlight" dragleave"unhighlight" drop"handleDrop"click"handleClick&quo…

听说京东618裁员没?上午还在赶需求,下午就开会通知被裁了~

文末还有最新面经共享群&#xff0c;没准能让你刷到意向公司的面试真题呢。 京东也要向市场输送人才了? 在群里看到不少群友转发京东裁员相关的内容&#xff1a; 我特地去网上搜索了相关资料&#xff0c;看看网友的分享&#xff1a; 想不到马上就618了&#xff0c;东哥竟然抢…

Python 机器学习 基础 之 模型评估与改进 【模型评估与改进 / 交叉验证】的简单说明

Python 机器学习 基础 之 模型评估与改进 【模型评估与改进 / 交叉验证】的简单说明 目录 Python 机器学习 基础 之 模型评估与改进 【模型评估与改进 / 交叉验证】的简单说明 一、简单介绍 二、模型评估与改进 三、交叉验证 1、scikit-learn 中的交叉验证 2、交叉验证的…

stm32工程综合实验_延时及中断优先级

待下载综合实验 ![在这里插入图片描述](https://img-blog.csdnimg.cn/161fa4e200bb4022bf384e80a3af8797.jpg 很好的编程思想模式及资料(富莱xx电子)

【repo系列】repo常用命令的使用

前言 repo是一种代码版本管理工具&#xff0c;它是由一系列的Python脚本组成&#xff0c;封装了一系列的Git命令&#xff0c;用来统一管理多个Git仓库。 本文章描述repo常用命令的使用。 常用命令 初始化 repo init 初始化代码仓 repo init [options]常用options: -u URL…

JDBC——API详解

一、DriverManager 1、用于注册驱动程序&#xff1a;registerDriver(Driver driver)。 更常用的是Class.forName("com.mysql.jdbc.Driver")是由于Driver中包含了registerDriver(Driver driver)&#xff0c;值得注意的是&#xff0c;是mysql5之后的版本中&#xff0…