【Pytorch笔记】5.DataLoader、Dataset、自定义Dataset

参考
深度之眼官方账号 - 02-01 Dataloader与Dataset.mp4

torch.utils.data.DataLoader

功能:构建可迭代的数据装载器。

data.DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)

dataset:Dataset类,决定数据从哪里读取;
batch_size:批大小;
shuffle:每个epoch是否乱序;
sampler:定义从数据集中抽取数据的策略,指明sampler时不能指明shuffle。
batch_sampler:和sampler相似,但是一次返回一个批大小的下标。和batch_sizeshufflesamplerdrop_last互斥;
num_workers:使用该DataLoader加载数据的子进程的数量。如果该值为0,意味着只有主进程使用该DataLoader加载数据。
collate_fn:合并样本列表,形成一个小型tensor批次。在使用map-style dataset批量加载时使用。
pin_memory:True时,DataLoader在返回tensor数据时会先将这些数据复制到device/CUDA的pinned memory中。
drop_last:当样本数不能被batch_size整除时,是否舍弃最后一批数据。
timeout:如果是正数,则表示获取批次的超时值。要求设置为非负数。
worker_init_fn:如果不是None,那么这个函数会调用所有带有worker_id的worker的子进程。

Epoch:所有训练样本都已输入到模型中,称为一个Epoch。
Iteration:一批样本输入到模型中,称为一个Iteration。
BatchSize:批大小,决定一个Epoch有多少个Iteration。

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承他,并且复写__getitem__()

__getitem__():接收一个索引,返回一个样本。

torch.utils.data.TensorDataset

由tensor封装的数据集,传入一些在第1维长度相同的tensor,形成dataset。

DataLoader如何使用?

import torch
from torch.utils import datadef data_gen(num_examples):X = torch.normal(0, 1, (num_examples, 2))y = torch.normal(0, 1, (num_examples, 1))return X, ydef load_array(data_arrays, batch_size, need_shuffle=True):dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=need_shuffle)data_A, data_B = data_gen(30)
# 定义一个数据迭代器data_iter
data_iter = load_array((data_A, data_B), 10, False)# 使用for循环,iter每次会指向下一组数据
# 每组数据返回我们封装进去的那些tensor,封装几个就返回几个
for X, y in iter(data_iter):print("X\n", X)print("y\n", y)

输出:

Xtensor([[-0.2338, -0.4072],[ 1.6080, -0.0880],[-1.0243, -0.6245],[ 0.6012,  1.9620],[ 1.1876,  0.9539],[-0.5972,  0.9251],[-0.4918, -0.1340],[ 2.3297,  0.1833],[-0.8487,  0.3370],[-0.0600,  2.3769]])
ytensor([[ 1.0454],[ 0.5992],[ 2.0075],[ 0.2727],[-1.6845],[ 0.0845],[ 1.0992],[ 0.5103],[-0.6727],[-0.1900]])
Xtensor([[-0.1419, -1.5535],[-1.6436,  0.8680],[-0.4432, -0.7703],[ 0.3822,  0.4675],[ 0.4000,  1.3471],[ 0.9776,  2.0103],[ 0.1298,  2.7382],[ 0.2664, -0.6223],[-1.0774,  0.0734],[-0.1904, -1.3299]])
ytensor([[-0.5979],[-0.5432],[ 0.2951],[ 0.2811],[-0.5997],[ 0.8073],[ 1.4356],[ 1.1555],[-0.3368],[-0.0626]])
Xtensor([[-1.4326, -0.3407],[-1.1878, -1.5619],[ 0.3498,  1.5307],[-0.8174,  0.6017],[ 0.8076,  0.8295],[ 2.6239,  1.1669],[-1.2598,  1.4309],[ 0.3365,  0.1765],[-0.4472, -0.6882],[ 0.6732, -0.0742]])
ytensor([[ 0.5114],[ 1.0669],[-1.5565],[ 0.4512],[ 3.2071],[ 0.4752],[-1.5981],[ 0.0035],[-0.2723],[-1.3634]])

自定义Dataset

import torch
from torch.utils import data# 自定义的Dataset
# 每一条数据包含X的一行数据(1x2的tensor)和y的一行数据(1x1的tensor)
class MyTensorDataset(data.Dataset):def __init__(self, tensor_list):# 初始化超类 (data.Dataset)super(MyTensorDataset, self).__init__()self.data = tensor_listdef __getitem__(self, index):# 这个函数必须重写# 返回dataset中下标编号为index的数据sample = {'X': self.data[0][index], 'y': self.data[1][index]}return sampledef __len__(self):# 返回这个dataset中的数据个数return len(self.data[0])# 随机生成数据
def data_gen(num_examples):X = torch.normal(0, 1, (num_examples, 2))y = torch.normal(0, 1, (num_examples, 1))return X, y# 使用自己的Dataset
def load_array(data_arrays, batch_size, need_shuffle=True):X, y = data_arrays  dataset = MyTensorDataset((X, y)) return data.DataLoader(dataset, batch_size, shuffle=need_shuffle)data_A, data_B = data_gen(30)
data_iter = load_array((data_A, data_B), 10, False)for sample in iter(data_iter):print("sample:\n", sample)

输出:

sample:{'X': tensor([[ 0.2149,  0.6216],[ 0.4691,  0.1862],[-1.7705, -1.5983],[-0.0196, -0.5903],[-0.1313,  0.3206],[ 0.1898, -0.2575],[ 1.5934, -0.4720],[-0.8343, -0.2181],[ 1.6159, -0.5473],[-1.2662, -0.0218]]), 'y': tensor([[ 0.8886],[ 0.1653],[ 0.8054],[-0.0725],[-0.4806],[-0.4661],[-0.4040],[ 0.6192],[-0.2522],[ 1.4091]])}
sample:{'X': tensor([[ 0.6441, -0.5759],[-0.7285, -1.0021],[ 0.1250, -0.2333],[ 0.3196,  0.7762],[ 0.1429,  0.4667],[ 1.0751, -0.4867],[ 0.1664,  0.3489],[ 0.1616, -0.1998],[ 0.6707, -0.4678],[-1.7778, -2.4658]]), 'y': tensor([[-0.2342],[ 0.1402],[ 0.5768],[ 1.7898],[-0.6802],[-2.3584],[ 0.7048],[ 0.1848],[ 0.1225],[ 0.5535]])}
sample:{'X': tensor([[ 0.1694,  0.2863],[ 0.3062, -0.7494],[ 0.6844,  1.9278],[ 0.9141,  0.3842],[-1.2314, -1.4933],[ 0.1568, -1.4182],[ 1.7723, -0.4890],[ 0.5734,  1.0614],[ 0.9536, -0.2866],[ 0.2510,  0.3375]]), 'y': tensor([[ 0.9802],[-0.5557],[ 0.7763],[ 1.1688],[-1.0067],[-0.4044],[-0.2745],[-0.3661],[-0.6058],[ 0.6905]])}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个案例熟悉使用pytorch

文章目录 1. 完整模型的训练套路1.2 导入必要的包1.3 准备数据集1.3.1 使用公开数据集:1.3.2 获取训练集、测试集长度:1.3.3 利用 DataLoader来加载数据集 1.4 搭建神经网络1.4.1 测试搭建的模型1.4.2 创建用于训练的模型 1.5 定义损失函数和优化器1.6 使…

JDK8 Stream测试

如何创建一个流Stream,三种方法:测试 1、通过 java.util.Collection.stream() 2、通过数组来创建流 3、静态方法:使用Stream的静态方法:of()、iterate()、generate() public class StreamJ {public static void main(String[] arg…

redis持久化与调优

一 、Redis 高可用: 在web服务器中,高可用是指服务器可以正常访问的时间,衡量的标准是在多长时间内可以提供正常服务(99.9%、99.99%、99.999%等等)。但是在Redis语境中,高可用的含义似乎要宽泛一些&#x…

POJ 2886 Who Gets the Most Candies? 树状数组+二分

一、题目大意 我们有N个孩子,每个人带着一张卡片,一起顺时针围成一个圈来玩游戏,第一回合时,第k个孩子被淘汰,然后他说出他卡片上的数字A,如果A是一个正数,那么下一个回合他左边的第A个孩子被淘…

通过usb串口发送接收数据

USB通信使用系统api,USB转串口通信使用第三方库usb-serial-for-android, 串口通信使用Google官方库android-serialport-api。x 引入包后在本地下载的位置:C:\Users\Administrator\.gradle\caches\modules-2\files-2.1 在 Android 中&#x…

【python海洋专题十一】colormap调色

【python海洋专题十一】colormap调色 上期内容 本期内容 图像的函数包调用! Part01. 自带颜色条Colormap 调用方法: cmap3plt.get_cmap(ocean)查询方法! Part02. seaborn函数包 01:sns.cubehelix_palette cmap5 sns.cu…

string类的模拟实现(万字讲解超详细)

目录 前言 1.命名空间的使用 2.string的成员变量 3.构造函数 4.析构函数 5.拷贝构造 5.1 swap交换函数的实现 6.赋值运算符重载 7.迭代器部分 8.数据容量控制 8.1 size和capacity 8.2 empty 9.数据修改部分 9.1 push_back 9.2 append添加字符串 9.3 运算符重载…

OpenCV利用Camshift实现目标追踪

目录 原理 做法 代码实现 结果展示 原理 做法 代码实现 import numpy as np import cv2 as cv# 读取视频 cap cv.VideoCapture(video.mp4)# 检查视频是否成功打开 if not cap.isOpened():print("Error: Cannot open video file.")exit()# 获取第一帧图像&#x…

SpringCloud Alibaba - Sentinel 微服务保护解决雪崩问题、Hystrix 区别、安装及使用

目录 一、Sentinel 1.1、背景:雪崩问题 1.2、雪崩问题的解决办法 1.2.1、超时处理 缺陷:为什么这里只是 “缓解” 雪崩问题,而不是百分之百解决了雪问题呢? 1.2.2、舱壁模式 缺陷:资源浪费 1.2.3、熔断降级 1.…

OK3568 forlinx系统编译过程及问题汇总

1. 共享文件夹无法加载;通过网上把文件夹加载后,拷贝文件很慢,任务管理器查看发现硬盘读写速率很低。解决办法:重新安装vmware tools。 2. 拷贝Linux源码到虚拟机,解压。 3. 虚拟机基本库安装 forlinxubuntu:~$ sudo…

『力扣每日一题12』:只出现一次的数字

一、题目 给你一个 非空 整数数组 nums ,除了某个元素只出现一次以外,其余每个元素均出现两次。找出那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法来解决此问题,且该算法只使用常量额外空间。 示例 1 : 输入&…

WVP-28181协议视频平台搭建教程

28181协议视频平台搭建教程 安装mysql安装redis安装ZLMediaKit安装28181协议视频平台安装依赖下载源码编译静态页面打包项目, 生成可执行jar修改配置文件启动WVP 项目地址: https://github.com/648540858/wvp-GB28181-pro 说明: wvp-GB28181-pro 依赖redis和mysql中…

案例题--信息系统架构设计

案例题--信息系统架构设计 概念 以扩展了解为主,主要关注图 概念 架构的组成:构件,连接件,约束 构件:组成元素 连接件:构件之间的连接方式 约束:构件和连接件之间的约束 上应,下技&a…

Linux CentOS7 vim多窗口编辑

我们在用vim编辑文件时,有各种需求。如有时需要在多个文件之间来回操作,一会关闭一个文件,一会再打开另外一个文件,这样来回操作显得太笨拙。有时,vim编辑多行的大文件,来回查看、编辑前面一部分及最后一部…

【Axure】元件库和母版、常见的原型规范、静态原型页面制作

添加现有元件库 点击元件库——载入 当然也可以创建元件库,自己画自己保存 建立京东秒杀母版 静态原型页面的制作 框架 选择以iphone8的界面大小为例,顶部状态栏高度为20 左侧类似于标尺,因为图标、文字离最左侧的间距是不一样的 信…

基于Kylin的数据统计分析平台架构设计与实现

目录 1 前言 2 关键模块 2.1 数据仓库的搭建 2.2 ETL 2.3 Kylin数据分析系统 2.4 数据可视化系统 2.5 报表模块 3 最终成果 4 遇到问题 1 前言 这是在TP-LINK公司云平台部门做的一个项目,总体包括云上数据统计平台的架构设计和组件开发,在此只做…

深入了解 Linux 中的 AWK 命令:文本处理的瑞士军刀

简介 在Linux和Unix操作系统中,文本处理是一个常见的任务。AWK命令是一个强大的文本处理工具,专门进行文本截取和分析,它允许你在文本文件中查找、过滤、处理和格式化数据。本文将深入介绍Linux中的AWK命令,让你了解其基本用法和…

Linux-Centos中配置docker

1.安装yum工具 yum install -y yum-utils 2.配置yam源头 yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 3.安装docker yum install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin 4. 查看d…

ElasticSearch第四讲:ES详解:ElasticSearch和Kibana安装

ElasticSearch第四讲:ES详解:ElasticSearch和Kibana安装 本文是ElasticSearch第四讲:ElasticSearch和Kibana安装,主要介绍ElasticSearch和Kibana的安装。了解完ElasticSearch基础和Elastic Stack生态后,我们便可以开始…

数据库的备份与恢复

数据备份的重要性 备份的主要目的是灾难恢复。 在生产环境中,数据的安全性至关重要。 任何数据的丢失都可能产生严重的后果。 造成数据丢失的原因: 程序错误人为操作错误运算错误磁盘故障灾难(如火灾、地震)和盗窃 数据库备份…