[pytorch、学习] - 3.5 图像分类数据集

参考

3.5. 图像分类数据集

在介绍shftmax回归的实现前我们先引入一个多类图像分类数据集

本章开始使用pytorch实现啦~

本节我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口
  2. torchvision.models: 包含常用的模型(含预训练模型),例如 AlexNet、VGG、ResNet等
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法

3.5.1. 获取数据集

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..')  # 为了导入上层的d2lzh_pytorch
import d2lzh_pytorch as d2l

下面,我们通过torchvision的torchvision.dataset来下载这个数据集。第一次调用时会自动从网上下载获取数据。我们通过参数train来指定获取训练集或测试数据集(testing data)。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

数据集下载如果比较慢,建议通过网址下载4个,然后根据下载的目录导入即可。如下图,是下载了(4个文件,这里只说一个,其他类似)train-images-idx3-ubyte.gz到本地的目录C:\Users\1/Datasets/FashionMNIST\FashionMNIST\raw\下.可从网址直接下载4个(无需解压)到该目录下在执行以上代码.

在这里插入图片描述

# 上面的 mnist_train 和 mnist_test都是 torch.utils.data.Datasets的子类
# 所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本
# 训练集中和测试集中的每个类别的图像分别是6000和1000。因为有10个类别,所以训练集和测试集的样本数分别为60000和10000
print(type(mnist_train))   # <class 'totchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test))   # 60000 10000
# 通过下标访问任意样本
feature, label = mnist_train[0]
print(feature.shape, label)   # torch.Size([1, 28, 28]) tensor(5)

在这里插入图片描述
变量feature对应高和宽均为28像素的图像。由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0.0, 1.0]的32位浮点数。需要注意的是,feature的尺寸是(C * H * W)的,而不是(H * W * C)。第一维是通道数,因为数据通道数为1.后面两维分别是图像的高和宽

Fashion-MNIST中一共包括了10个类别,分别0、1、2、3、4、5、6、7、8、9

# 本函数已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
#     text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']text_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']return [text_labels[int(i)] for i in labels]# 定义一个可以在一行里面画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):d2l.use_svg_display()_, figs = plt.subplots(1, len(images), figsize= (12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()X, y = [], []
for i in range(10):# 从数据集中取出10个X.append(mnist_train[i][0])y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

在这里插入图片描述

3.5.2. 读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试集上评价模型的表现。前面说过, mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例

batch_size = 256
if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)start = time.time()
for X,y in train_iter:continue
print('%.2f sec' % (time.time() - start))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250194.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python自动化第三周---文件读写

1.python文件对象提供了三个“读”方法&#xff1a; read()、readline() 和 readlines()。每种方法可以接受一个变量以限制每次读取的数据量。 read() 每次读取整个文件&#xff0c;它通常用于将文件内容放到一个字符串变量中。如果文件大于可用内存&#xff0c;为了保险起见&a…

最详细的java泛型详解

来源&#xff1a;最详细的java泛型详解 对java的泛型特性的了解仅限于表面的浅浅一层&#xff0c;直到在学习设计模式时发现有不了解的用法&#xff0c;才想起详细的记录一下。 本文参考java 泛型详解、Java中的泛型方法、 java泛型详解 1. 概述 泛型在java中有很重要的地位&a…

[pytorch、学习] - 3.6 softmax回归的从零开始实现

参考 3.6 softmax回归的从零开始实现 import torch import torchvision import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.6.1. 获取和读取数据 batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_si…

Django基础必备三件套: HttpResponse render redirect

1. HttpResponse : 它的作用是内部传入一个字符串参数, 然后发给浏览器 def index(request):return HttpResponse(ok) 2. render : 可以接收三个参数, 一是request参数, 二是待渲染的 html 模板文件, 三是保存具体数据的字典参数 def index(request):return render(request, …

React 简单实例 (React-router + webpack + Antd )

React Demo Github 地址 经过React Native 的洗礼之后&#xff0c;写了这个 demo &#xff1b;React 是为了使前端的V层更具组件化&#xff0c;能更好的复用&#xff0c;同时可以让你从操作dom中解脱出来&#xff0c;只需要操作数据就会改变相应的dom&#xff1b; 而React Nat…

[pytorch、学习] - 3.7 softmax回归的简洁实现

参考 3.7. softmax回归的简洁实现 使用pytorch实现softmax import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.7.1. 获取和读取数据 batch_size 256 train_iter…

【模板】NTT

NTT模板 #include<bits/stdc.h> using namespace std; #define LL long long const int MAXL22; const int MAXN1<<MAXL; const int Mod998244353; int rev[MAXN],A[MAXN],B[MAXN],C[MAXN]; int fast_pow(int a,int b){int ans1;while(b){if(b&1)ans1ll*ans*a%…

centos 7 php7 yum源

rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpmrpm -Uvh https://mirror.webtatic.com/yum/el7/webtatic-release.rpm 转载于:https://www.cnblogs.com/myJuly/p/10008252.html

[pytorch、学习] - 3.9 多重感知机的从零开始实现

参考 3.9 多重感知机的从零开始实现 import torch import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.9.1. 获取和读取数据 batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)3.9.2. 定义模型参…

C语言逗号运算符和逗号表达式基础总结

逗号运算符的作用&#xff1a; 1&#xff0c;起分隔符的作用&#xff1a; 定义变量用于分隔变量&#xff1a;int a,b输入或输出时用于分隔输出表列 printf("%d%d",a,b) 2,用于逗号表达式的顺序运算符 语法&#xff1a;表达式1&#xff0c;表达式2&#xff0c;...,表达…

java基础-泛型举例详解

泛型 泛型是JDK5.0增加的新特性&#xff0c;泛型的本质是参数化类型&#xff0c;即所操作的数据类型被指定为一个参数。这种类型参数可以在类、接口、和方法的创建中&#xff0c;分别被称为泛型类、泛型接口、泛型方法。 一、认识泛型 在没有泛型之前,通过对类型Object的引用来…

MySQL数据库视图(view),视图定义、创建视图、修改视图

原文链接&#xff1a;https://blog.csdn.net/moxigandashu/article/details/63254901转载于:https://www.cnblogs.com/chrdai/p/9131881.html

[pytorch、学习] - 3.10 多重感知机的简洁实现

参考 3.10. 多重感知机的简洁实现 import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.10.1. 定义模型 num_inputs, num_outputs, num_hiddens 784, 10, 256 # 参…

【汇编语言】——第三章课后总结

第三章 的书本上主要有以下几个内容&#xff1a; 1.内存中字的存储 字单元&#xff1a;即存放一个字型数据&#xff08;16位&#xff09;的内存单元&#xff0c;由两个地址连续的内存单元组成。 小端法&#xff1a;高地址内存单元中存放字型数据的高位字节&#xff0c;低地址内…

如何从 Android 手机免费恢复已删除的通话记录/历史记录?

有一个有合作意向的人给我打电话&#xff0c;但我没有接听。更糟糕的是&#xff0c;我错误地将其删除&#xff0c;认为这是一个骚扰电话。那么有没有办法从 Android 手机恢复已删除的通话记录呢&#xff1f;” 塞缪尔问道。如何在 Android 上恢复已删除的通话记录&#xff1f;如…

springBoot 登录拦截器

1、首选创建一个继承HandlerInterceptor的拦截器 import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse;import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; /*** 拦…

[pytorch、学习] - 3.11 模型选择、欠拟合和过拟合

参考 3.11 模型选择、欠拟合和过拟合 3.11.1 训练误差和泛化误差 在解释上述现象之前&#xff0c;我们需要区分训练误差&#xff08;training error&#xff09;和泛化误差&#xff08;generalization error&#xff09;。通俗来讲&#xff0c;前者指模型在训练数据集上表现…

关于'java' 不是内部或外部命令,也不是可运行的程序 或批处理文件 和 错误: 找不到或无法加载主类 helloworld的问题...

一、前几天电脑重装了一次系统将java配置的环境变量都弄没了&#xff0c;自己添加了两个新的变量JAVA_HOME&#xff08;自己jdk的地址&#xff09;以及在path中添加%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin; 然后因为这几天都是用eclipse进行编程的&#xff0c;没有出现问题&#…

spring-boot注解详解(一)

spring-boot注解详解(一) SpringBootApplication SpringBootApplication (默认属性)Configuration EnableAutoConfiguration ComponentScan。 Configuration&#xff1a;提到Configuration就要提到他的搭档Bean。使用这两个注解就可以创建一个简单的spring配置类&#xf…

前端基础-jQuery的优点以及用法

一、jQuery介绍 jQuery是一个轻量级的、兼容多浏览器的JavaScript库。jQuery使用户能够更方便地处理HTML Document、Events、实现动画效果、方便地进行Ajax交互&#xff0c;能够极大地简化JavaScript编程。它的宗旨就是&#xff1a;“Write less, do more.“二、jQuery的优势 一…