深度学习修炼(二)——数据集的加载

文章目录

    • 致谢
  • 2 数据集的加载
    • 2.1 框架数据集的加载
    • 2.2 自定义数据集
    • 2.3 准备数据以进行数据加载器训练

致谢

Pytorch自带数据集介绍_godblesstao的博客-CSDN博客_pytorch自带数据集

2 数据集的加载

与sklearn中的datasets自带数据集类似,pytorch框架也为我们提供了数据集以便一系列的模型测试。其数据集作为一个类继承自父类torch.utils.data.Dataset。

2.1 框架数据集的加载

让我们看看torch为我们提供了什么数据集。数据集种类如下所示:

  • 手写字符识别:EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot

  • 实物分类:Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet

  • 人脸识别:CelebA

  • 场景分类:LSUN、Places365

  • 用于object detection:SVHN、VOCDetection、COCODetection

  • 用于semantic/instance segmentation:

  • 语义分割:Cityscapes、VOCSegmentation

  • 语义边界:SBD

  • 用于image captioning:Flickr、COCOCaption

  • 用于video classification:HMDB51、Kinetics

  • 用于3D reconstruction:PhotoTour

  • 用于shadow detectors:SBU

以FashionMNIST数据集为例,我们看一下如何加载数据集。

torch.datasets.FashionMNIST(root = “data”,train = True,download = True,transform = ToTensor())

  • root是存储训练/测试数据的路径
  • train指定训练或测试数据集,当布尔值为True则为训练集,当布尔值为False则为测试集
  • download=True从互联网下载数据(如果无法在本地获得)
  • transform指定特征转换方式,target_transform指定标签转换方式
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensordef load_data():"""加载数据集"""# 1 训练数据集的加载train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 测试数据集的加载test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datatrain_data, test_data = load_data()
print(train_data)

数据集加载完实际上是以类的形式存在的,其不同于sklearn中返回的Bunch。

如果我们想要看看数据集中有啥要怎么做呢?首先,这个数据集是图像分类数据集,说明里面含有的都是图像,为此,我们可以使用subplots存放这些图片。对于这些数据集,我们可以像列表一样手动索引。如train_data[index]

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as pltdef load_data():"""加载数据集"""# 1 训练数据集的加载train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 测试数据集的加载test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""数据集可视化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 从训练集中随机抽出九张图(九个样本)for i in range(1, cols * rows + 1):# 设置索引,索引取值为0到训练集的长度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出对应样本的图片和标签img, label = train_data[sample_idx]# 依次画于事先指定的九宫格图上figure.add_subplot(rows, cols, i)# 设置对应图片的标题plt.title(label_map[label])# 关掉坐标轴plt.axis("off")# 展示图片plt.imshow(img.squeeze(), cmap="gray")# 释放画布plt.show()train_data, test_data = load_data()
show_data(train_data)

out:

image-20220315095159288

上面用到了一个API,即torch.randint()

torch.randint(low=0, high, size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor

  • 用于取随机整数,返回值为张量
  • low:int类型,表明要从分布中提取的最低整数
  • high:int类型,表明要从分布中提取的最高整数1
  • size:元组类型,表明输出张量的形状
  • dtype:返回值张量的数据类型
  • device:返回张量所需的设备
  • requires_grad:布尔类型,表明是否要对返回的张量自动求导。

如:

torch.randint(3, 5, (3,))
tensor([4, 3, 4])

意味生成一个一维的3元素向量,其中向量中的元素取值从3-5取。

2.2 自定义数据集

如果你不想使用框架自带的数据集,那么你可以自己定义一个数据集类。自定义Dataset类必须实现三个函数:__ init __ 、 __ len __ 、__ getitem __。其中图像部分存储于一个文件夹中,标签单独存储在CSV文件中。

在接下来的代码中,让我们看看如何创建一个自定义数据集。

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

对于__ init __ 函数来说,包含加载图像、注释文件和两个转换的目录,在这里我们不做过多讲解,后面会详细介绍。

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

对于__ len __ 函数,其功能是返回数据集中的样本数。

def __len__(self):return len(self.img_labels)

对于 __ getitem __,其功能是给定索引便能返回对应样本。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

在自定义这一部分不用过多的去了解,用着用着就会了,就算不会代码也是通用,需要用的时候看一下复制一下,别搞得自己这么焦虑。

2.3 准备数据以进行数据加载器训练

在pytorch中,数据加载的核心实际上是torch.utils.data.DataLoader类,它支持对torch数据集的python可迭代,换而言之,DataLoader相当于你拿一个水盆,而dataset相当于泉水。DataLoader可以对小批量数据集进行处理,处理内容包括:

  • 地图样式和可迭代样式的数据集
  • 自定义数据集加载顺序
  • 多进程加载数据
  • 自动内存固定

其中地图样式数据集是指自定义数据集,而可迭代样式数据集指的是自带数据集。其他详情对于初学者来说很不友好,这里不做过多解释,你可以理解为这就是个科普知识。

我们来看一下这个API吧。

torch.utils.data.DataLoader(数据集, batch_size=1, shuffle=False)

  • 用于加载样本并且进行批处理
  • 数据集:要加载的数据集
  • batch_size:整数类型,表明每批要加载的样本数,默认为1
  • shuffle:布尔类型,表明是否要洗牌

我们利用上面的API来加载我们上面的Fashion_MNIST吧。

def load_batch_data():"""数据集批处理加载器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloader

既然已经将样本导入加载器,那么我们如何从加载器中读取数据呢?我们可以根据需要循环访问数据集。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoaderdef load_data():"""加载数据集"""# 1 训练数据集的加载train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 测试数据集的加载test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""数据集可视化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 从训练集中随机抽出九张图(九个样本)for i in range(1, cols * rows + 1):# 设置索引,索引取值为0到训练集的长度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出对应样本的图片和标签img, label = train_data[sample_idx]# 依次画于事先指定的九宫格图上figure.add_subplot(rows, cols, i)# 设置对应图片的标题plt.title(label_map[label])# 关掉坐标轴plt.axis("off")# 展示图片plt.imshow(img.squeeze(), cmap="gray")# 释放画布plt.show()def load_batch_data():"""数据集批处理加载器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloaderdef show_batch_data():"""循环访问数据加载器"""train_dataloader, test_dataloader = load_batch_data()train_feature, train_labels = next(iter(train_dataloader))print(f"特征大小:{train_feature.size()}")print(f"标签大小:{train_labels.size()}")img = train_feature[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"label:{label}")train_data, test_data = load_data()
# show_data(train_data)
show_batch_data()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/398636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Node.js 函数

在JavaScript中,一个函数可以作为另一个函数接收一个参数。我们可以先定义一个函数,然后传递,也可以在传递参数的地方直接定义函数。 Node.js中函数的使用与Javascript类似,举例来说,你可以这样做: functio…

在Entity Framework 4.0中使用 Repository 和 Unit of Work 模式

【原文地址】Using Repository and Unit of Work patterns with Entity Framework 4.0 【原文发表日期】 16 June 09 04:08 PM 如果你一直在关注这个博客的话,你知道我最近在讨论我们加到Entity Framework 4.0中的POCO功能的方方面面,新加的POCO支持促成…

Css3之基础-5 Css 背景、渐变属性

一、CSS 背景概述背景属性的作用- 背景样式可以控制 HTML 元素的背景颜色、背景图像等 - 背景色- 设置单一的颜色作为背景- 背景图像- 以图片作为背景- 可以设置图像的位置、平铺、尺寸等二、CSS 背景属性背景色 background-color - 属性用于为元素设置背景色- 接受任何合法的颜…

[Angularjs]锚点操作服务$anchorScroll

写在前面 有个单页应用的项目中,需要通过锚点进行页面的定位。但angularjs的路由会出现跟锚点冲突,angularjs会将锚点当成路由进行解析,造成跳转到这个页面,而我们需要的只是跳转到当前的锚点位置。angularjs的路由格式#/home/en。…

机器学习的练功方式(十)——岭回归

文章目录十 岭回归10.1 岭回归的接口10.2 岭回归处理房价预测十 岭回归 岭回归是线性回归的改进,有时候迫不得已我们的参数确实不能少,这时候过拟合的现象就可能发生。为了避免过拟合现象的发生,既然不能从减少参数上面下手,那我…

js产生随机数

<script>document.write(parseInt(10*Math.random()));  //输出0&#xff5e;10之间的随机整数document.write(Math.floor(Math.random()*101));  //输出1&#xff5e;10之间的随机整数function RndNum(n){var rnd"";for(var i0;i<n;i)rndMath.floor(Math…

JS实现在输入框内输入@时,邮箱账号自动补全

<!DOCTYPE HTML><html lang"en"><head><meta charset"utf-8"/><title>邮箱自动补全</title><style type"text/css">.wrap{width:200px;margin:0 auto;}h1{font-size:36px;text-align:center;line-hei…

OpenCV修养(一)——引入

文章目录1 引入1.1 OpenCV是啥1.2 OpenCV——Python1 引入 1.1 OpenCV是啥 OpenCV是一个基于Apache2.0许可&#xff08;开源&#xff09;发行的跨平台计算机视觉计算机视觉和机器学习软件库&#xff0c;可以运行在Linux、Windows、Android和Mac OS操作系统上。 它轻量级而且高…

动态创建 Plist 文件

简介 Property List&#xff0c;属性列表文件&#xff0c;它是一种用来存储串行化后的对象的文件。属性列表文件的扩展名为.plist &#xff0c;因此通常被称为 plist文件&#xff0c;文件是xml格式的。 写入plist文件 在开发过程中&#xff0c;有时候需要把程序的一些配置保存下…

[Everyday Mathematics]20150101

(1). 设 $f(x),g(x)$ 在 $[a,b]$ 上同时单调递增或单调递减, 试证: \[ (b-a)\int_a^b f(x)g(x)\mathrm{\,d}x \geq \int_a^b f(x)\mathrm{\,d}x\cdot \int_a^b g(x)\mathrm{\,d}x. \] (2). 试证: \[ c\in (0,1)\Rightarrow \int_c^1 \dfrac{e^t}{t}\mathrm{\,d}t \geq e\cdot …

被解放的姜戈08 远走高飞

作者&#xff1a;Vamei 出处&#xff1a;http://www.cnblogs.com/vamei 转载请先与我联系。 之前在单机上实现了一个Django服务器&#xff08;被解放的姜戈07 马不停蹄&#xff09;&#xff0c;现在我们可以把这个服务器推上一个云平台。这里我选择使用阿里云。 看着复仇的火焰…

OpenCV修养(二)——OpenCV基本操作

文章目录2 OpenCV基本操作2.1 IO操作2.2 图像基本操作2.2.1 图像绘制2.2.1.1 绘制直线2.2.1.2 绘制圆形2.2.1.3 绘制矩形2.2.1.4 添加文字2.2.1.5 试手2.2.2 获取/修改图像的像素点2.2.3 获取图像属性2.2.4 图像通道拆分/合并2.2.5 色彩空间改变2.2.6 边界填充2.3图像算数操作2…

ylbtech-LanguageSamples-Porperties(属性)

ylbtech-Microsoft-CSharpSamples:ylbtech-LanguageSamples-Porperties(属性)1.A&#xff0c;示例(Sample) 返回顶部“属性”示例 本示例演示属性为何是 C# 编程语言必不可少的一个组成部分。它演示了如何声明和使用属性。有关更多信息&#xff0c;请参见属性&#xff08;C# 编…

Altium Designer敷铜的规则设定

InPolygon 这个词是铺铜对其他网络的设置,铺铜要离其他网络远点,因为腐蚀不干净会对 电路板有影响... 问题一:: 如下图所示&#xff0c;现在想让敷铜与板子边界也就是keepoutlayer的间距小一点&#xff0c;比如0.2MM。而与走线的间距比较大&#xff0c;比如0.8mm。要怎么设置规…

OpenCV修养(三)——图像处理(上)

文章目录致谢3 图像处理&#xff08;上&#xff09;3.1 几何变换3.1.1 图像缩放3.1.2 图像平移3.1.3 图像旋转3.1.4 仿射变换3.2 图像阈值3.3 图像平滑3.3.1 图像噪声3.3.1.1 椒盐噪声3.3.1.2 高斯噪声3.3.2 均值滤波3.3.3 方框滤波3.3.4 高斯滤波3.3.5 中值滤波3.3.6 小结3.4 …

程序员福利各大平台免费接口,非常适用

电商接口 京东获取单个商品价格接口: http://p.3.cn/prices/mgets?skuIdsJ_商品ID&type1 ps:商品ID这么获取:http://item.jd.com/954086.html 物流接口 快递接口: http://www.kuaidi100.com/query?type快递公司代号&postid快递单号 ps:快递公司编码:申通”shentong”…

CF940D Alena And The Heater

思路&#xff1a; 模拟。 实现&#xff1a; 1 #include <bits/stdc.h>2 using namespace std;3 const int INF 1e9;4 int a[100005], n;5 string b;6 int main()7 {8 while (cin >> n)9 { 10 for (int i 0; i < n; i) cin >> a[i]; 11 …

Android工程开发笔记一

Android工程开发笔记<一> ---------------------------------------不同 APP相互调用 activity 1.ComponentName() Intent _Intent new Intent(Intent.ACTION_MAIN); _Intent.setComponent(new ComponentName("com.semp.skipdemo002","com.semp.skipdemo…

机器学习的练功方式(十一)——逻辑回归

文章目录致谢11 逻辑回归11.1 引入11.2 激活函数11.3 损失函数11.4 梯度下降11.5 案例&#xff1a;癌症分类预测致谢 逻辑回归为什么用Sigmoid&#xff1f; - 知乎 (zhihu.com) 逻辑回归中的损失函数的解释_yidiLi的博客-CSDN博客_逻辑回归损失函数 11 逻辑回归 逻辑回归也被称…

ODB——基于c++的ORM映射框架尝试(安装)

这篇博客应该是和之前的重拾cgi一起的。当时为了模仿java的web框架&#xff0c;从页面的模板&#xff0c;到数据库的ORM&#xff0c;都找个对应的库来进行尝试。数据库用的就是ODB&#xff0c;官方网站是http://www.codesynthesis.com/products/odb/。 1、安装 odb是直接提供源…