PyTorch框架学习八——PyTorch数据读取机制(简述)

PyTorch框架学习八——PyTorch数据读取机制(简述)

  • 一、数据
  • 二、DataLoader与Dataset
    • 1.torch.utils.data.DataLoader
    • 2.torch.utils.data.Dataset
  • 三、数据读取整体流程

琢磨了一段时间,终于对PyTorch的数据读取机制有了一点理解,并自己实现了简单数据集(猫狗分类数据集)的读入和训练,这里简单写一写自己的理解,以备日后回顾。

一、数据

简单来说,一个机器学习或深度学习问题可以拆解为五个主要的部分:数据、模型、损失函数、优化器和迭代过程,这五部分每个都可以详细展开,都有非常多的知识点,而一切的开始,都源于数据。

一般数据部分可以分为四个主要的内容去学习:

  1. 数据收集:即获取Img和相应的Label。
  2. 数据划分:划分为训练集、验证集和测试集。
  3. 数据读取:DataLoader。
  4. 数据预处理:transforms。

在PyTorch框架的学习中,前两个不是重点,它们是机器学习基础和Python基础的事。而PyTorch的数据预处理transforms方法在前几次笔记进行了很详细地介绍,这次笔记重点是写一点对数据读取机制的理解,这也是最折磨的一部分,经过了很多次的步进演示,终于对整个数据读取过程有了一个较为完整的印象。

总的来说,DataLoader里比较重要的是Sampler和Dataset,前者负责获取要读取的数据的索引,即读哪些数据,后者决定数据从哪里读取以及如何读取。

二、DataLoader与Dataset

1.torch.utils.data.DataLoader

功能:构建可迭代的数据装载器。

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None)

参数比较多,如下所示:
在这里插入图片描述
介绍几个主要的:

  1. dataset:Dataset类,决定数据从哪读取以及如何读取。
  2. batch_size:批大小,默认为1。
  3. num_works:是否多进程读取数据。
  4. shuffle:每个epoch是否乱序。
  5. drop_last:当样本数不能被batch_size整除时,是否舍弃最后一批数据。

上面涉及到一个小知识点,顺带介绍一下,即Epoch、Iteration、Batchsize之间的关系:

  1. Epoch:所有训练样本都输入到模型中,称为一个epoch。
  2. Iteration:一个Batch的样本输入到模型中,称为一个Iteration。
  3. Batchsize:批大小,决定一个epoch有多少个iteration。

举个栗子:

若样本总数:80,Batchsize:8,则 1 Epoch = 10 Iterations。
若样本总数:87,Batchsize:8,且 drop_last = True,则1 Epoch = 10 Iterations;而drop_last = False时,1 Epoch = 11 Iterations。

2.torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()函数。
这里__getitem__()函数的功能是:接收一个索引,返回一个样本。

三、数据读取整体流程

经过上面简单的介绍,下面来看一下数据读取的整体流程:
在这里插入图片描述

  1. 从DataLoader这个命令开始。
  2. 然后进入到DataLoaderIter里,判断是单进程还是多进程。
  3. 然后进入到Sampler里进行采样,获得一批一批的索引,这些索引就指引了要读取哪些数据。
  4. 然后进入到DatasetFetcher中要依据Sampler获得的Index对数据进行获取。
  5. 在DatasetFetcher调用Dataset类,这里是我们自定义的数据集,数据集一般放在硬盘中,Dataset里面一般都有数据的路径,所以也就能知道了从哪读取数据。
  6. 自定义的Dataset类里再调用__getitem__函数,这里有我们编写的如何读取数据的代码,依据这里的代码读取数据。
  7. 读取出来后可能需要进行图像预处理或数据增强,所以紧接着是transforms方法。
  8. 经过上述的读取,已经得到了图像及其标签,但是还需要将它们组合成batch,就是下面的collate_fn,最后得到了一个batch一个batch的数据。

这个过程中的三个主要问题:

  1. 读哪些数据:Sampler输出要读取的数据的Index。
  2. 从哪读数据:Dataset类中的data_dir,即数据的存放路径。
  3. 怎么读数据:Dataset类中编写的__getitem__()函数。

精力有限,就不在这里写一个具体读取数据的代码了,这里有很多有价值的课程和资料可以学习:深度之眼PyTorch框架

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491699.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

报告 | 2019年全球数字化转型现状研究报告

来源:Prophet2019年,战略数字化转型的重要性已经不止于IT领域,而影响着全公司的竞争力。企业的相关预算直线攀升,利益相关方所关注的颠覆性技术数量急剧增加。数字化项目开始由首席高管主导,并由相互协作的跨职能团队管…

Android调用binder实现权限提升-android学习之旅(81)

当进程A权限较低,而B权限较高时,容易产生提权漏洞 fuzz测试的测试路径 First level Interface是服务 Second level Interface是服务中对应的接口 1.首先获取第一层和第二层接口,及服务以及对应服务提供的接口 2.根据以上信息结合参数类型信息…

PyTorch框架学习九——网络模型的构建

PyTorch框架学习九——网络模型的构建一、概述二、nn.Module三、模型容器Container1.nn.Sequential2.nn.ModuleList3.nn.ModuleDict()4.总结笔记二到八主要介绍与数据有关的内容,这次笔记将开始介绍网络模型有关的内容,首先我们不追求网络内部各层的具体…

中国17种稀土有啥军事用途?没它们,美军技术优势将归零

来源:陶慕剑观察 稀土就是化学元素周期表中镧系元素——镧(La)、铈(Ce)、镨(Pr)、钕(Nd)、钷(Pm)、钐(Sm)、铕(Eu)、钆(Gd)、铽(Tb)、镝(Dy)、钬(Ho)、铒(Er)、铥(Tm)、镱(Yb)、镥(Lu),再加上钪(Sc)和钇(Y)共17种元素。中国稀土占据着众多的世界第一&…

PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)

PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)一、卷积层二、转置卷积层三、池化层1.最大池化nn.MaxPool2d2.平均池化nn.AvgPool2d四、反池化层最大值反池化nn.MaxUnpool2d五、线性层六、激活函数层1.nn.Sigmoid2.nn.…

PyTorch框架学习十一——网络层权值初始化

PyTorch框架学习十一——网络层权值初始化一、均匀分布初始化二、正态分布初始化三、常数初始化四、Xavier 均匀分布初始化五、Xavier正态分布初始化六、kaiming均匀分布初始化前面的笔记介绍了网络模型的搭建,这次将介绍网络层权值的初始化,适当的初始化…

W3C 战败:无权再制定 HTML 和 DOM 标准!

来源:CSDN历史性时刻!——近日,W3C正式宣告战败:HTML和DOM标准制定权将全权移交给浏览器厂商联盟WHATWG。由苹果、Google、微软和Mozilla四大浏览器厂商组成的WHATWG已经与万维网联盟(World Wide Web Consortium&#…

PyTorch框架学习十二——损失函数

PyTorch框架学习十二——损失函数一、损失函数的作用二、18种常见损失函数简述1.L1Loss(MAE)2.MSELoss3.SmoothL1Loss4.交叉熵CrossEntropyLoss5.NLLLoss6.PoissonNLLLoss7.KLDivLoss8.BCELoss9.BCEWithLogitsLoss10.MarginRankingLoss11.HingeEmbedding…

化合物半导体的机遇

来源:国盛证券半导体材料可分为单质半导体及化合物半导体两类,前者如硅(Si)、锗(Ge)等所形成的半导体,后者为砷化镓(GaAs)、氮化镓(GaN)、碳化硅(…

PyTorch框架学习十三——优化器

PyTorch框架学习十三——优化器一、优化器二、Optimizer类1.基本属性2.基本方法三、学习率与动量1.学习率learning rate2.动量、冲量Momentum四、十种常见的优化器(简单罗列)上次笔记简单介绍了一下损失函数的概念以及18种常用的损失函数,这次…

最全芯片产业报告出炉,计算、存储、模拟IC一文扫尽

来源:智东西最近几年, 半导体产业风起云涌。 一方面, 中国半导体异军突起, 另一方面, 全球产业面临超级周期,加上人工智能等新兴应用的崛起,中美科技摩擦频发,全球半导体现状如何&am…

python向CSV文件写内容

f open(r"D:\test.csv", w) f.write(1,2,3\n) f.write(4,5,6\n) f.close() 注意:上面例子中的123456这6个数字会分别写入不同的单元格里,即以逗号作为分隔符将字符串内容分开放到不同单元格 上面例子的图: 如果要把变量的值放入…

PyTorch框架学习十四——学习率调整策略

PyTorch框架学习十四——学习率调整策略一、_LRScheduler类二、六种常见的学习率调整策略1.StepLR2.MultiStepLR3.ExponentialLR4.CosineAnnealingLR5.ReduceLRonPlateau6.LambdaLR在上次笔记优化器的内容中介绍了学习率的概念,但是在整个训练过程中学习率并不是一直…

JavaScript数组常用方法

转载于:https://www.cnblogs.com/kenan9527/p/4926145.html

蕨叶形生物刷新生命史,动物界至少起源于5.7亿年前

来源 :newsweek.com根据发表于《古生物学》期刊(Palaeontology)的一项研究,动物界可能比科学界所知更加古老。研究人员发现,一种名为“美妙春光虫”(Stromatoveris psygmoglena)的海洋生物在埃迪…

PyTorch框架学习十五——可视化工具TensorBoard

PyTorch框架学习十五——可视化工具TensorBoard一、TensorBoard简介二、TensorBoard安装及测试三、TensorBoard的使用1.add_scalar()2.add_scalars()3.add_histogram()4.add_image()5.add_graph()之前的笔记介绍了模型训练中的数据、模型、损失函数和优化器,下面将介…

CNN、RNN、DNN的内部网络结构有什么区别?

来源:AI量化百科神经网络技术起源于上世纪五、六十年代,当时叫感知机(perceptron),拥有输入层、输出层和一个隐含层。输入的特征向量通过隐含层变换达到输出层,在输出层得到分类结果。早期感知机的推动者是…

L2级自动驾驶量产趋势解读

来源:《国盛计算机组》L2 级自动驾驶离我们比想象的更近。18 年下半年部分 L2 车型已面世,凯迪拉克、吉利、长城、长安、上汽等均已推出了 L2 自动驾驶车辆。国内目前在售2872个车型,L2级功能渗透率平均超过25%,豪华车甚至超过了6…

PyTorch框架学习十六——正则化与Dropout

PyTorch框架学习十六——正则化与Dropout一、泛化误差二、L2正则化与权值衰减三、正则化之Dropout补充:这次笔记主要关注防止模型过拟合的两种方法:正则化与Dropout。 一、泛化误差 一般模型的泛化误差可以被分解为三部分:偏差、方差与噪声…

HDU 5510 Bazinga 暴力匹配加剪枝

Bazinga Time Limit: 20 Sec Memory Limit: 256 MB 题目连接 http://acm.hdu.edu.cn/showproblem.php?pid5510 Description Ladies and gentlemen, please sit up straight.Dont tilt your head. Im serious.For n given strings S1,S2,⋯,Sn, labelled from 1 to n, you shou…