深度学习——线性神经网络(五、图像分类数据集——Fashion-MNIST数据集)

目录

  • 5.1 读取数据集
  • 5.2 读取小批量
  • 5.3 整合所有组件

  MNIST数据集是图像分类中广泛使用的数据集之一,但是作为基准数据集过于简单,在本小节将使用类似但更复杂的Fashion-MNIST数据集。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l# 这个函数的目的是设置图形显示格式为SVG(Scalable Vector Graphics),
# 这是一种基于矢量的图形格式,可以清晰地缩放而不失真。
d2l.use_svg_display()

5.1 读取数据集

  可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans,download=True)

在这里插入图片描述
  Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集中的6000张图像和测试数据集中的1000张图像组成。因此,训练集和测试集分别总共包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

print(len(mnist_train))
print(len(mnist_test))
60000
10000

  每个输入图像的高度和宽度均为28像素,数据集由灰度图像组成,其通道数为1.

  在图像处理和计算机视觉中,“通道”一词常用来描述图像中颜色信息的存储方式。每个通道代表图像中一种颜色的成分,不同的颜色模式会有不用的通道数。
  灰度图像的通道数为1,在灰度图像中,每个像素只有一个强度值,表示黑白之间的不同灰度级别,不包含颜色信息。

print(mnist_train[0][0].shape)
torch.Size([1, 28, 28])

  Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
  以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

  现在创建一个可视化函数来查看样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""创建一个函数来可视化这些样本,绘制图像列表,目的是在一张图中显示多个图像。imgs是要显示的图像列表,num_rows是创建的子图的行数,num_cols是创建的子图的列数,该子图没有设置标题,调整子图大小的缩放因子默认为1.5"""figsize = (num_cols * scale, num_rows * scale) # 计算整个子图的尺寸,基于子图的行数和列数以及缩放因子来决定_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # figsize 参数设置了整个图形的大小axes = axes.flatten() # 将子图网格展平为一维数组,方便后续遍历for i, (ax, img) in enumerate(zip(axes, imgs)):"""使用enumerate函数和zip函数来迭代两个列表:axes和imgs。这个循环将同时遍历这两个列表,并将它们对应的元素组合在一起,然后进行处理。其中enumerate函数用于跟踪循环的当前迭代次数(即索引i),并返回每个元素及其索引。"""if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)# 子图中隐藏坐标轴。具体来说,它们分别隐藏了x轴和y轴ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i]) # 用来给每个子图设置标题plt.show()plt.savefig('class')return axesX, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) # 用于拿到第一个小批量,批量大小为18
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

在这里插入图片描述

5.2 读取小批量

  为了使我们在读取训练集和测试集时更容易,使用内置的数据迭代器,而不是从开始创建。在每次迭代中,数据加载器都会读取一小批量数据,大小为batch_size,通过内置的数据迭代器,我们可以随机打乱所有样本,从而无偏见地读取小批量。

batch_size = 256 # 设置批量大小def get_dataloader_workers():"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())# 看一下读取训练数据所需的时间
timer = d2l.Timer()
for X, y in train_iter:continue
print(f'{timer.stop():.2f} sec')
2.36 sec

  下面设置了不同的进程数所需的时间。设置的8个进程数读取小批量所需的时间比较少。
在这里插入图片描述

5.3 整合所有组件

  现在我们定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

  我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

   小结:
  数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程的可能性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/57797.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端零基础入门到上班:【Day3】从零开始构建网页骨架HTML

HTML 基础入门&#xff1a;从零开始构建网页骨架 目录 1. 什么是 HTML&#xff1f;HTML 的核心作用 2. HTML 基本结构2.1 DOCTYPE 声明2.2 <html> 标签2.3 <head> 标签2.4 <body> 标签 3. HTML 常用标签详解3.1 标题标签3.2 段落和文本标签3.3 链接标签3.4 图…

使用Python来下一场深夜雪

效果图&#xff1a;&#xff08;真实情况是动态的&#xff09; 完整代码&#xff1a; import turtle import random# 初始化画布 turtle.bgcolor("#001f3f") # 偏深蓝色的背景 turtle.title("下雪的画面") turtle.speed(0) turtle.hideturtle() turtle.t…

创建ODBC数据源SQLConfigDataSource函数的用法

网络上没有这个函数能实际落地的用法说明&#xff0c;我实践后整理一下&#xff1a; 1.头文件与额外依赖库&#xff1a; #include <odbcinst.h> #pragma comment(lib, "legacy_stdio_definitions.lib") 2.调用函数&#xff1a; if (!SQLConfigDataSourceW(…

集创赛比赛细则了解

一、赛道划分 数字与SOC设计 紫光展锐杯不推荐大家参加&#xff0c;设计比较复杂 Arm杯是芯片IP封装测试&#xff0c;在FPGA上做外部总线协议设计。 Robei杯是作为FPGA的应用背景&#xff0c;包括控制算法 平头哥杯是阿里旗下专注于VSC的平台。通过平头哥的平台实现专门的应用…

【C语言】控制台学生成绩管理系统

文章目录 C语言编程&#xff1a;学生成绩管理系统一、程序概述二、代码实现三、程序解释 C语言编程&#xff1a;学生成绩管理系统 在这篇文章中&#xff0c;我们将一起探讨如何使用C语言来创建一个简单的学生成绩管理系统。这个系统将允许用户输入学生数量、学号和成绩&#x…

Web刷题日记1---清风

[GDOUCTF 2023]EZ WEB 题目网站在NSSCTF 这个题目有一个新的知识点&#xff0c;对于我来说比较的少见吧&#xff0c;第一次遇见。em...是什么呢?后面再说 进入靶场&#xff0c;比较突兀&#xff0c;点了这个button后&#xff0c;提示flag在附近 查看源码&#xff0c;有提示…

react18中使用redux管理公共数据仓库实现数据immutable更新

Immutable.js出自Facebook&#xff0c;是最流行的不可变数据结构的实现之一。它实现了完全的持久化数据结构&#xff0c;使用结构共享。所有的更新操作都会返回新的值&#xff0c;但是在内部结构是共享的&#xff0c;来减少内存占用。Immutablejs官网 在上一篇介绍redux的文章&…

FFMPEG+Qt 实时显示本机USB摄像头1080p画面以及同步录制mp4视频

FFMPEGQt 实时显示本机USB摄像头1080p画面以及同步录制mp4视频 文章目录 FFMPEGQt 实时显示本机USB摄像头1080p画面以及同步录制mp4视频1、前言1.1 目标1.2 一些说明 2、效果3、代码3.1 思路3.2 工程目录3.3 核心代码 4、全部代码获取 1、前言 本文通过FFMPEG(7.0.2)与Qt(5.13.…

有色行业测温取样机器人 - SNK施努卡

SNK施努卡有色行业熔炼车间机器人测温取样 在有色行业&#xff0c;测温取样机器人专门设计用于自动化处理高温熔体的温度监测和样品采集任务。这类机器人在铜、铝、锌等金属冶炼过程中扮演着关键角色&#xff0c;以提高生产效率、确保产品质量并增强工作安全性。 主要工作项 …

基于 matlab 计算 TPI(地形位置指数)

1. TPI 简介 地形位置指数算法由 Weiss 提出&#xff0c;主要是根据局部地形高程对各类地貌单元提取。 其基本原理为&#xff1a;在邻域分析方法的基础上&#xff0c;计算每个栅格的高程值和该栅格领域内所有栅格的平均高程之间的差值&#xff0c;正值表示该栅格点高于领域内栅…

element ui中el-image组件查看图片的坑

比如说上传组件使用el-image-viewer组件去看&#xff0c;如果用错了&#xff0c;你会发现&#xff0c;你每次只能看一张图片 <template><div><el-upload action"#" list-type"picture-card" :auto-upload"false" :file-list"…

Spring Cloud --- Sentinel 熔断规则

熔断规则 慢调用比例 发送10个请求&#xff0c;每个请求理想响应时长为200毫秒。统计1秒钟&#xff0c;如果10个请求响应时间超过200毫秒的比例大于等于10%&#xff0c;则触发熔断&#xff0c;熔断5秒。 异常比例 1秒内&#xff0c;发送请求出现异常率为20%&#xff0c;则触…

arcgis中dem转模型导入3dmax

文末分享素材 效果 1、准备数据 (1)DEM (2)DOM 2、打开arcscene软件 3、加载DEM、DOM数据 4、设置DOM的高度为DEM

LabVIEW中句柄与引用

在LabVIEW中&#xff0c;句柄&#xff08;Handle&#xff09; 是一种用于引用特定资源或对象的标识符。它类似于指针&#xff0c;允许程序在内存中管理和操作复杂的资源&#xff0c;而不需要直接访问资源本身。句柄用于管理动态分配的资源&#xff0c;如队列、文件、网络连接、…

Vision-Language Models for Vision Tasks: A Survey阅读笔记

虽然LLM的文章还没都看完&#xff0c;但是终究是开始看起来了VLM&#xff0c;首当其冲&#xff0c;当然是做一片文献综述啦。这篇文章比较早了&#xff0c;2024年2月份出的last version。 文章链接&#xff1a;https://arxiv.org/abs/2304.00685 GitHub链接&#xff1a;GitHu…

Java Web开发教程:从入门到精通

Java Web开发教程&#xff1a;从入门到精通 前言 在当今互联网时代&#xff0c;Web开发已成为一个炙手可热的领域。Java作为一种成熟的编程语言&#xff0c;以其稳定性和跨平台性&#xff0c;成为了Web开发的热门选择。本文将带您从基础知识入手&#xff0c;逐步深入Java Web…

C#与C++交互开发系列(十):数组传递的几种形式

前言 在C#和C的交互开发中&#xff0c;数组传递是一个非常常见且实用的场景。数组可以作为方法的参数&#xff0c;也可以作为响应结果返回。在本篇博客中&#xff0c;我们将探讨几种常见的数组传递方式&#xff0c;展示如何在C#与C之间进行有效的数据交换。我们将主要介绍以下…

代谢组数据分析(二十):通过WGCNA识别核心代谢物

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍识别核心基因加载R包导入数据数据预处理检查数据完整性计算软阈值soft根据软阈值构建接矩阵和拓扑重叠矩阵聚类并构建网络拓扑重叠热图查看具体模块的代谢物表达热图识别表型相关模…

word表格跨页后自动生成的顶部横线【去除方法】

Hello World! Its been a long time. 这一年重心放在了科研、做事、追寻新的经历上&#xff0c;事有正事、琐事、幸事、哀事&#xff0c;内心与认知成长了一些&#xff0c;思想成熟了几分&#xff0c;技艺也有若干收获。不管怎样&#xff0c;来打个卡吧&#xff0c;纪念一下&…

边缘计算路由网关R40钡铼技术3LAN口1WAN口Modbus协议

在当今快速发展的工业互联网时代&#xff0c;随着物联网&#xff08;IoT&#xff09;与大数据分析的日益融合&#xff0c;边缘计算成为了提高数据处理效率、降低延迟的关键技术。 产品特点&#xff1a; 多接口支持&#xff1a;R40B拥有3个LAN口和1个WAN口的设计&#xff0c;能…