Dataset和DataLoader构建数据通道

重点在第二部分的构建数据通道和第三部分的加载数据集

Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。

Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。

而DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

在绝大部分情况下,用户只需实现Dataset的__len__方法和__getitem__方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

一,Dataset和DataLoader概述

1,获取一个batch数据的步骤

让我们考虑一下从一个数据集中获取一个batch的数据需要哪些步骤。

(假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m)

1,首先我们要确定数据集的长度n。

结果类似:n = 1000。

2,然后我们从0到n-1的范围中抽样出m个数(batch大小)。

假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]

3,接着我们从数据集中去取这m个数对应下标的元素。

拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]

4,最后我们将结果整理成两个张量作为输出。

拿到的结果是两个张量,类似batch = (features,labels),

其中 features = torch.stack([X[1],X[4],X[8],X[9]])

labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

2,Dataset和DataLoader的功能分工

上述第1个步骤确定数据集的长度是由 Dataset的__len__ 方法实现的。

第2个步骤从0到n-1的范围中抽样出m个数的方法是由 DataLoader的 sampler和 batch_sampler参数指定的。

sampler参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样。

batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。

第3个步骤的核心逻辑根据下标取数据集中的元素 是由 Dataset的 __getitem__方法实现的。

第4个步骤的逻辑由DataLoader的参数collate_fn指定。一般情况下也无需用户设置。

3,Dataset和DataLoader的主要接口

伪代码,实际应用意义不大

import torch 
class Dataset(object):def __init__(self):passdef __len__(self):raise NotImplementedErrordef __getitem__(self,index):raise NotImplementedErrorclass DataLoader(object):def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):self.dataset = datasetself.sampler =torch.utils.data.RandomSampler if shuffle else \torch.utils.data.SequentialSamplerself.batch_sampler = torch.utils.data.BatchSamplerself.sample_iter = self.batch_sampler(self.sampler(range(len(dataset))),batch_size = batch_size,drop_last = drop_last)def __next__(self):indices = next(self.sample_iter)batch = self.collate_fn([self.dataset[i] for i in indices])return batch

二,使用Dataset创建数据集

Dataset创建数据集常用的方法有:

使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。

使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。

继承 torch.utils.data.Dataset 创建自定义数据集。

此外,还可以通过

torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。

调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。

1,根据Tensor创建数据集

  1. 头文件:
import numpy as np 
import torch 
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split 
  1. 根据Tensor创建数据集
from sklearn import datasets 
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
  1. 分割成训练集和预测集
n_train = int(len(ds_iris)*0.8)
n_valid = len(ds_iris) - n_train
ds_train,ds_valid = random_split(ds_iris,[n_train,n_valid])
  1. 使用DataLoader加载数据集
dl_train,dl_valid = DataLoader(ds_train,batch_size = 8),DataLoader(ds_valid,batch_size = 8)#查看数据集
for features,labels in dl_train:print(features,labels)break
  1. 演示加法运算符(+)的合并作用
ds_data = ds_train + ds_validprint('len(ds_train) = ',len(ds_train))
print('len(ds_valid) = ',len(ds_valid))
print('len(ds_train+ds_valid) = ',len(ds_data))print(type(ds_data))

2,根据图片目录创建图片数据集

  1. 头文件:
import numpy as np 
import torch 
from torch.utils.data import DataLoader
from torchvision import transforms,datasets 
  1. 图片加载:
from PIL import Image
img = Image.open('./data/cat.jpeg')
  1. 随机数值翻转
transforms.RandomVerticalFlip()(img)
  1. 随机旋转
transforms.RandomRotation(45)(img)
  1. 定义图片增强操作
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), #随机水平翻转transforms.RandomVerticalFlip(), #随机垂直翻转transforms.RandomRotation(45),  #随机在45度角度内旋转transforms.ToTensor() #转换成张量]
) transform_valid = transforms.Compose([transforms.ToTensor()]
)
  1. 根据图片目录创建数据集
ds_train = datasets.ImageFolder("./data/cifar2/train/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("./data/cifar2/test/",transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())print(ds_train.class_to_idx)
  1. 使用DataLoader加载数据集
#注意:windows用户要把num_workers去掉,容易报错
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)for features,labels in dl_train:print(features.shape)print(labels.shape)break

三,使用DataLoader加载数据集

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

DataLoader的函数签名

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)

一般情况下,我们仅仅会配置 dataset, batch_size, shuffle, num_workers, drop_last这五个参数,其他参数使用默认值即可。
dataset : 数据集
batch_size: 批次大小
shuffle: 是否乱序
sampler: 样本采样函数,一般无需设置。
batch_sampler: 批次采样函数,一般无需设置。
num_workers: 使用多进程读取数据,设置的进程数。
collate_fn: 整理一个批次数据的函数。
pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
timeout: 加载一个数据批次的最长等待时间,一般无需设置。
worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。

#构建输入数据管道
ds = TensorDataset(torch.arange(1,50))
dl = DataLoader(ds,batch_size = 10,shuffle= True,num_workers=2,drop_last = True)
#迭代数据
for batch, in dl:print(batch)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389359.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

铁拳nat映射_铁拳如何重塑我的数据可视化设计流程

铁拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

Django2 Web 实战03-文件上传

作者:Hubery 时间:2018.10.31 接上文:接上文:Django2 Web 实战02-用户注册登录退出 视频是一种可视化媒介,因此视频数据库至少应该存储图像。让用户上传文件是个很大的隐患,因此接下来会讨论这俩话题&#…

BZOJ.2738.矩阵乘法(整体二分 二维树状数组)

题目链接 BZOJ洛谷 整体二分。把求序列第K小的树状数组改成二维树状数组就行了。 初始答案区间有点大,离散化一下。 因为这题是一开始给点,之后询问,so可以先处理该区间值在l~mid的修改,再处理询问。即二分标准可以直接用点的标号…

从数据库里读值往TEXT文本里写

/// <summary> /// 把预定内容导入到Text文档 /// </summary> private void ChangeDbToText() { this.RecNum.Visibletrue; //建立文件&#xff0c;并打开 string oneLine ""; string filename "Storage Card/YD" DateTime.Now.…

DengAI —如何应对数据科学竞赛? (EDA)

了解机器学习 (Understanding ML) This article is based on my entry into DengAI competition on the DrivenData platform. I’ve managed to score within 0.2% (14/9069 as on 02 Jun 2020). Some of the ideas presented here are strictly designed for competitions li…

Pytorch模型层简单介绍

模型层layers 深度学习模型一般由各种模型层组合而成。 torch.nn中内置了非常丰富的各种模型层。它们都属于nn.Module的子类&#xff0c;具备参数管理功能。 例如&#xff1a; nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.Co…

有效沟通的技能有哪些_如何有效地展示您的数据科学或软件工程技能

有效沟通的技能有哪些What is the most important thing to do after you got your skills to be a data scientist? It has to be to show off your skills. Otherwise, there is no use of your skills. If you want to get a job or freelance or start a start-up, you ha…

java.net.SocketException: Software caused connection abort: socket write erro

场景&#xff1a;接口测试 编辑器&#xff1a;eclipse 版本&#xff1a;Version: 2018-09 (4.9.0) testng版本&#xff1a;TestNG version 6.14.0 执行testng.xml时报错信息&#xff1a; 出现此报错原因之一&#xff1a;网上有人说是testng版本与eclipse版本不一致造成的&#…

[博客..配置?]博客园美化

博客园搞定时间 -> 18年6月27日 [让我歇会儿 搞这个费脑子 代码一个都看不懂] 转载于:https://www.cnblogs.com/Steinway/p/9235437.html

使用K-Means对美因河畔法兰克福的社区进行聚类

介绍 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

Pytorch损失函数losses简介

一般来说&#xff0c;监督学习的目标函数由损失函数和正则化项组成。(Objective Loss Regularization) Pytorch中的损失函数一般在训练模型时候指定。 注意Pytorch中内置的损失函数的参数和tensorflow不同&#xff0c;是y_pred在前&#xff0c;y_true在后&#xff0c;而Ten…

读取Mc1000的 唯一 ID 机器号

先引用Symbol.ResourceCoordination 然后引用命名空间 using System;using System.Security.Cryptography;using System.IO; 以下为类程序 /// <summary> /// 获取设备id /// </summary> /// <returns></returns> public static string GetDevi…

样本均值的抽样分布_抽样分布样本均值

样本均值的抽样分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩转ceph性能测试---对象存储(一)

笔者最近在工作中需要测试ceph的rgw&#xff0c;于是边测试边学习。首先工具采用的intel的一个开源工具cosbench&#xff0c;这也是业界主流的对象存储测试工具。 1、cosbench的安装&#xff0c;启动下载最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]绝世好题

Description 题库链接 给定一个长度为 \(n\) 的数列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最长长度&#xff0c;满足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位与&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 为二进制第 \(i…

因果关系和相关关系 大数据_数据科学中的相关性与因果关系

因果关系和相关关系 大数据Let’s jump into it right away.让我们马上进入。 相关性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

Pytorch构建模型的3种方法

这个地方一直是我思考的地方&#xff01;因为学的代码太多了&#xff0c;构建的模型各有不同&#xff0c;这里记录一下&#xff01; 可以使用以下3种方式构建模型&#xff1a; 1&#xff0c;继承nn.Module基类构建自定义模型。 2&#xff0c;使用nn.Sequential按层顺序构建模…

vue取数据第一个数据_我作为数据科学家的第一个月

vue取数据第一个数据A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了数据科学家的第一份工作&#xff0c;并且像任何新工作一样&#xff0c;一…

Flask-SocketIO 简单使用指南

Flask-SocketIO 使 Flask 应用程序能够访问客户端和服务器之间的低延迟双向通信。客户端应用程序可以使用 Javascript&#xff0c;C &#xff0c;Java 和 Swift 中的任何 SocketIO 官方客户端库或任何兼容的客户端来建立与服务器的永久连接。 安装 直接使用 pip 来安装&#xf…

STL-开篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;标准模板库 定义&#xff1a; c引入的一个标准类库 特点&#xff1a;1&#xff09;数据结构和算法的 c实现&#xff08; 采用模板类和模板函数&#xff09;2&#xff09;数据的存储和算法的分离3&#xff09;高…