深度学习修炼(二)——数据集的加载

文章目录

    • 致谢
  • 2 数据集的加载
    • 2.1 框架数据集的加载
    • 2.2 自定义数据集
    • 2.3 准备数据以进行数据加载器训练

致谢

Pytorch自带数据集介绍_godblesstao的博客-CSDN博客_pytorch自带数据集

2 数据集的加载

与sklearn中的datasets自带数据集类似,pytorch框架也为我们提供了数据集以便一系列的模型测试。其数据集作为一个类继承自父类torch.utils.data.Dataset。

2.1 框架数据集的加载

让我们看看torch为我们提供了什么数据集。数据集种类如下所示:

  • 手写字符识别:EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot

  • 实物分类:Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet

  • 人脸识别:CelebA

  • 场景分类:LSUN、Places365

  • 用于object detection:SVHN、VOCDetection、COCODetection

  • 用于semantic/instance segmentation:

  • 语义分割:Cityscapes、VOCSegmentation

  • 语义边界:SBD

  • 用于image captioning:Flickr、COCOCaption

  • 用于video classification:HMDB51、Kinetics

  • 用于3D reconstruction:PhotoTour

  • 用于shadow detectors:SBU

以FashionMNIST数据集为例,我们看一下如何加载数据集。

torch.datasets.FashionMNIST(root = “data”,train = True,download = True,transform = ToTensor())

  • root是存储训练/测试数据的路径
  • train指定训练或测试数据集,当布尔值为True则为训练集,当布尔值为False则为测试集
  • download=True从互联网下载数据(如果无法在本地获得)
  • transform指定特征转换方式,target_transform指定标签转换方式
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensordef load_data():"""加载数据集"""# 1 训练数据集的加载train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 测试数据集的加载test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datatrain_data, test_data = load_data()
print(train_data)

数据集加载完实际上是以类的形式存在的,其不同于sklearn中返回的Bunch。

如果我们想要看看数据集中有啥要怎么做呢?首先,这个数据集是图像分类数据集,说明里面含有的都是图像,为此,我们可以使用subplots存放这些图片。对于这些数据集,我们可以像列表一样手动索引。如train_data[index]

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as pltdef load_data():"""加载数据集"""# 1 训练数据集的加载train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 测试数据集的加载test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""数据集可视化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 从训练集中随机抽出九张图(九个样本)for i in range(1, cols * rows + 1):# 设置索引,索引取值为0到训练集的长度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出对应样本的图片和标签img, label = train_data[sample_idx]# 依次画于事先指定的九宫格图上figure.add_subplot(rows, cols, i)# 设置对应图片的标题plt.title(label_map[label])# 关掉坐标轴plt.axis("off")# 展示图片plt.imshow(img.squeeze(), cmap="gray")# 释放画布plt.show()train_data, test_data = load_data()
show_data(train_data)

out:

image-20220315095159288

上面用到了一个API,即torch.randint()

torch.randint(low=0, high, size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor

  • 用于取随机整数,返回值为张量
  • low:int类型,表明要从分布中提取的最低整数
  • high:int类型,表明要从分布中提取的最高整数1
  • size:元组类型,表明输出张量的形状
  • dtype:返回值张量的数据类型
  • device:返回张量所需的设备
  • requires_grad:布尔类型,表明是否要对返回的张量自动求导。

如:

torch.randint(3, 5, (3,))
tensor([4, 3, 4])

意味生成一个一维的3元素向量,其中向量中的元素取值从3-5取。

2.2 自定义数据集

如果你不想使用框架自带的数据集,那么你可以自己定义一个数据集类。自定义Dataset类必须实现三个函数:__ init __ 、 __ len __ 、__ getitem __。其中图像部分存储于一个文件夹中,标签单独存储在CSV文件中。

在接下来的代码中,让我们看看如何创建一个自定义数据集。

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

对于__ init __ 函数来说,包含加载图像、注释文件和两个转换的目录,在这里我们不做过多讲解,后面会详细介绍。

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform

对于__ len __ 函数,其功能是返回数据集中的样本数。

def __len__(self):return len(self.img_labels)

对于 __ getitem __,其功能是给定索引便能返回对应样本。

def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

在自定义这一部分不用过多的去了解,用着用着就会了,就算不会代码也是通用,需要用的时候看一下复制一下,别搞得自己这么焦虑。

2.3 准备数据以进行数据加载器训练

在pytorch中,数据加载的核心实际上是torch.utils.data.DataLoader类,它支持对torch数据集的python可迭代,换而言之,DataLoader相当于你拿一个水盆,而dataset相当于泉水。DataLoader可以对小批量数据集进行处理,处理内容包括:

  • 地图样式和可迭代样式的数据集
  • 自定义数据集加载顺序
  • 多进程加载数据
  • 自动内存固定

其中地图样式数据集是指自定义数据集,而可迭代样式数据集指的是自带数据集。其他详情对于初学者来说很不友好,这里不做过多解释,你可以理解为这就是个科普知识。

我们来看一下这个API吧。

torch.utils.data.DataLoader(数据集, batch_size=1, shuffle=False)

  • 用于加载样本并且进行批处理
  • 数据集:要加载的数据集
  • batch_size:整数类型,表明每批要加载的样本数,默认为1
  • shuffle:布尔类型,表明是否要洗牌

我们利用上面的API来加载我们上面的Fashion_MNIST吧。

def load_batch_data():"""数据集批处理加载器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloader

既然已经将样本导入加载器,那么我们如何从加载器中读取数据呢?我们可以根据需要循环访问数据集。

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoaderdef load_data():"""加载数据集"""# 1 训练数据集的加载train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 测试数据集的加载test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""数据集可视化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 从训练集中随机抽出九张图(九个样本)for i in range(1, cols * rows + 1):# 设置索引,索引取值为0到训练集的长度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出对应样本的图片和标签img, label = train_data[sample_idx]# 依次画于事先指定的九宫格图上figure.add_subplot(rows, cols, i)# 设置对应图片的标题plt.title(label_map[label])# 关掉坐标轴plt.axis("off")# 展示图片plt.imshow(img.squeeze(), cmap="gray")# 释放画布plt.show()def load_batch_data():"""数据集批处理加载器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloaderdef show_batch_data():"""循环访问数据加载器"""train_dataloader, test_dataloader = load_batch_data()train_feature, train_labels = next(iter(train_dataloader))print(f"特征大小:{train_feature.size()}")print(f"标签大小:{train_labels.size()}")img = train_feature[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"label:{label}")train_data, test_data = load_data()
# show_data(train_data)
show_batch_data()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/398636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Css3之基础-5 Css 背景、渐变属性

一、CSS 背景概述背景属性的作用- 背景样式可以控制 HTML 元素的背景颜色、背景图像等 - 背景色- 设置单一的颜色作为背景- 背景图像- 以图片作为背景- 可以设置图像的位置、平铺、尺寸等二、CSS 背景属性背景色 background-color - 属性用于为元素设置背景色- 接受任何合法的颜…

机器学习的练功方式(十)——岭回归

文章目录十 岭回归10.1 岭回归的接口10.2 岭回归处理房价预测十 岭回归 岭回归是线性回归的改进,有时候迫不得已我们的参数确实不能少,这时候过拟合的现象就可能发生。为了避免过拟合现象的发生,既然不能从减少参数上面下手,那我…

OpenCV修养(一)——引入

文章目录1 引入1.1 OpenCV是啥1.2 OpenCV——Python1 引入 1.1 OpenCV是啥 OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉计算机视觉和机器学习软件库,可以运行在Linux、Windows、Android和Mac OS操作系统上。 它轻量级而且高…

被解放的姜戈08 远走高飞

作者:Vamei 出处:http://www.cnblogs.com/vamei 转载请先与我联系。 之前在单机上实现了一个Django服务器(被解放的姜戈07 马不停蹄),现在我们可以把这个服务器推上一个云平台。这里我选择使用阿里云。 看着复仇的火焰…

OpenCV修养(二)——OpenCV基本操作

文章目录2 OpenCV基本操作2.1 IO操作2.2 图像基本操作2.2.1 图像绘制2.2.1.1 绘制直线2.2.1.2 绘制圆形2.2.1.3 绘制矩形2.2.1.4 添加文字2.2.1.5 试手2.2.2 获取/修改图像的像素点2.2.3 获取图像属性2.2.4 图像通道拆分/合并2.2.5 色彩空间改变2.2.6 边界填充2.3图像算数操作2…

ylbtech-LanguageSamples-Porperties(属性)

ylbtech-Microsoft-CSharpSamples:ylbtech-LanguageSamples-Porperties(属性)1.A,示例(Sample) 返回顶部“属性”示例 本示例演示属性为何是 C# 编程语言必不可少的一个组成部分。它演示了如何声明和使用属性。有关更多信息,请参见属性(C# 编…

Altium Designer敷铜的规则设定

InPolygon 这个词是铺铜对其他网络的设置,铺铜要离其他网络远点,因为腐蚀不干净会对 电路板有影响... 问题一:: 如下图所示,现在想让敷铜与板子边界也就是keepoutlayer的间距小一点,比如0.2MM。而与走线的间距比较大,比如0.8mm。要怎么设置规…

OpenCV修养(三)——图像处理(上)

文章目录致谢3 图像处理(上)3.1 几何变换3.1.1 图像缩放3.1.2 图像平移3.1.3 图像旋转3.1.4 仿射变换3.2 图像阈值3.3 图像平滑3.3.1 图像噪声3.3.1.1 椒盐噪声3.3.1.2 高斯噪声3.3.2 均值滤波3.3.3 方框滤波3.3.4 高斯滤波3.3.5 中值滤波3.3.6 小结3.4 …

机器学习的练功方式(十一)——逻辑回归

文章目录致谢11 逻辑回归11.1 引入11.2 激活函数11.3 损失函数11.4 梯度下降11.5 案例:癌症分类预测致谢 逻辑回归为什么用Sigmoid? - 知乎 (zhihu.com) 逻辑回归中的损失函数的解释_yidiLi的博客-CSDN博客_逻辑回归损失函数 11 逻辑回归 逻辑回归也被称…

【百度地图API】如何制作一张魔兽地图!!——CS地图也可以,哈哈哈

【百度地图API】如何制作一张魔兽地图!!——CS地图也可以,哈哈哈 原文:【百度地图API】如何制作一张魔兽地图!!——CS地图也可以,哈哈哈摘要: 你玩魔兽不?你知道如何做一张魔兽地图不…

linux系统分两种更普遍的包,rpm和tar,这两种安装包如何解压与安装

2019独角兽企业重金招聘Python工程师标准>>> RPM软件包管理器&#xff1a;一种用于互联网下载包的打包及安装工具&#xff0c;它包含在某些Linux分发版中。它生成具有.RPM扩展名的文件。rpm -ivh xxxx.rpm <-安装rpm包 -i install的意思 -v view 查看更详细的…

C++类的数组元素查找最大值问题

找出一个整型数组中的元素的最大值。 1 /*找出一个整型数组中的元素的最大值。*/2 3 #include <iostream>4 using namespace std;5 6 class ArrayMax //创建一个类7 {8 public :9 void set_value(); 10 void max_value(); 11 void sh…

C++从0到1的入门级教学(二)——数据类型

文章目录2 数据类型2.1 简单变量2.2 基本数据类型2.2.1 整型2.2.2 实型&#xff08;浮点型&#xff09;2.2.3 字符型2.2.4 布尔类型2.3 sizeof关键字2.4 类型转换2.5 转义字符2.6 重新谈及变量2.6.1 字面值常量2.6.2 变量2.6.3 列表初始化2.7 数据的输入2 数据类型 2.1 简单变…

深度学习修炼(三)——自动求导机制

文章目录致谢3 自动求导机制3.1 传播机制与计算图3.1.1 前向传播3.1.2 反向传播3.2 自动求导3.3 再来做一次3.4 线性回归3.4.1 回归3.4.2 线性回归的基本元素3.4.3 线性模型3.4.4 线性回归的实现3.4.4.1 获取数据集3.4.4.2 模型搭建3.4.4.3 损失函数3.4.4.4 训练模型3.5 后记致…

深度学习修炼(四)——补充知识

文章目录致谢4 补充知识4.1 微积分4.1.1 导数和微分4.1.2 偏导数4.1.3 梯度4.1.4 链式求导4.2 Hub模块致谢 导数与微分到底有什么区别&#xff1f; - 知乎 (zhihu.com) 4 补充知识 在这一小节的学习中&#xff0c;我们会对上一小节的知识点做一个补充&#xff0c;并且拓展一个…

java使用POI jar包读写xls文件

主要使用poi jar来操作excel文件。代码中用到的数据库表信息见ORACLE之表。使用public ArrayList<Person> getPersonAllRecords()获得所有的记录。 1 public class PersonXLS {2 3 public static void main(String[] args) throws IOException {4 5 …

深度学习修炼(五)——基于pytorch神经网络模型进行气温预测

文章目录5 基于pytorch神经网络模型进行气温预测5.1 实现前的知识补充5.1.1 神经网络的表示5.1.2 隐藏层5.1.3 线性模型出错5.1.4 在网络中加入隐藏层5.1.5 激活函数5.1.6 小批量随机梯度下降5.2 实现的过程5.2.1 预处理5.2.2 搭建网络模型5.3 简化实现5.4 评估模型5 基于pytor…

Android 应用程序集成FaceBook 登录及二次封装

1、首先在Facebook 开发者平台注册一个账号 https://developers.facebook.com/ 开发者后台 https://developers.facebook.com/apps 2、创建账号并且获得 APP ID 图一 图二 图三 图四 图五 3、获取app签名的Key Hashes 值&#xff08;两种方式&#xff09; 3.1方法1&#xff1…

IKAnalyzer进行中文分词和去停用词

最近学习主题模型pLSA、LDA&#xff0c;就想拿来试试中文。首先就是找文本进行切词、去停用词等预处理&#xff0c;这里我找了开源工具IKAnalyzer2012&#xff0c;下载地址&#xff1a;(&#xff1a;(注意&#xff1a;这里尽量下载最新版本&#xff0c;我这里用的IKAnalyzer201…

C++从0到1的入门级教学(六)——函数

文章目录6 函数6.1 概述6.2 函数的定义6.3 函数的调用6.4 值传递6.5 函数的常见形式6.6 函数的声明6.7 函数的分文件编写6 函数 6.1 概述 作用&#xff1a;将一段经常使用的代码封装起来&#xff0c;减少重复代码。 一个较大的程序&#xff0c;一般分为若干个程序块&#xf…