Pytorch建立MyDataLoader过程详解

简介

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device=‘’)

详细:DataLoader

自己基于DataLoader实现各个模块

代码实现

MyDataset基于torch中的Data实现对个人数据集的载入,例如图像和标签载入
SingleSampler基于torch中的Sampler实现对于数据的batch个数图像的载入,例如,Batch_Size=4,实现对所有数据中选取4个索引作为一组,然后在MyDataset中基于__getitem__根据图像索引去进行图像操作
MyBathcSampler基于torch的BatchSampler实现自己对于batch_size数据的处理。需要基于SingleSampler实现Sampler的处理,更为灵活。MyBatchSampler的存在会自动覆盖DataLoader中的batch_size参数
注:Sampler的实现,将会与shuffer冲突,shuffer是在没有实现sampler前提下去自动判断选择的sampler类型
collate_fn是实现将batch_size的图像数据进行打包,遍历过程中就可以实现batch_size的images和labels对应
在这里插入图片描述

sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Samplerclass MyDataset(Dataset):def __init__(self) -> None:self.data = torch.arange(20)def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]@staticmethoddef collate_fn(batch):return torch.stack(batch, 0)class MyBatchSampler(BatchSampler):def __init__(self, sampler: Sampler[int], batch_size: int) -> None:self._sampler = samplerself._batch_size = batch_sizedef __iter__(self) -> Iterator[List[int]]:batch = []for idx in self._sampler:batch.append(idx)if len(batch) == self._batch_size:yield batchbatch = []yield batchdef __len__(self):return len(self._sampler) // self._batch_sizeclass SingleSampler(Sampler):def __init__(self, data_source) -> None:self._data = data_sourceself.num_samples = len(self._data)def __iter__(self):# 顺序采样# indices = range(len(self._data))# 随机采样indices = torch.randperm(self.num_samples).tolist()return iter(indices)def __len__(self):return self.num_samplestrain_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_size=4, sampler=single_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:print(data)

batch_sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Samplerclass MyDataset(Dataset):def __init__(self) -> None:self.data = torch.arange(20)def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]@staticmethoddef collate_fn(batch):return torch.stack(batch, 0)class MyBatchSampler(BatchSampler):def __init__(self, sampler: Sampler[int], batch_size: int) -> None:self._sampler = samplerself._batch_size = batch_sizedef __iter__(self) -> Iterator[List[int]]:batch = []for idx in self._sampler:batch.append(idx)if len(batch) == self._batch_size:yield batchbatch = []yield batchdef __len__(self):return len(self._sampler) // self._batch_sizeclass SingleSampler(Sampler):def __init__(self, data_source) -> None:self._data = data_sourceself.num_samples = len(self._data)def __iter__(self):# 顺序采样# indices = range(len(self._data))# 随机采样indices = torch.randperm(self.num_samples).tolist()return iter(indices)def __len__(self):return self.num_samplestrain_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:print(data)

参考

Sampler:https://blog.csdn.net/lidc1004/article/details/115005612

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/53038.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Python爬虫定制化开发自己需要的数据集

在数据驱动的时代,获取准确、丰富的数据对于许多项目和业务至关重要。本文将介绍如何使用Python爬虫进行定制化开发,以满足个性化的数据需求,帮助你构建自己需要的数据集,为数据分析和应用提供有力支持。 1.确定数据需求和采集目…

flutter ios webview不能打开http地址

参考 1、iOS添加信任 webview_flutter 在使用过程中会iOS出现无法加载HTTP请求的情况&#xff0c; 但是Flutter 却可以加载HTTP请求。这就与两个的框架有关了&#xff0c;Flutter是独立于UIKit框架的。 解决方案就是在iOS 的info.plist中添加对HTTP的信任。 <key>NSApp…

拼多多淘宝大量缓存商品数据用什么格式提供比较好?

众所周知&#xff0c;淘宝拼多多是我国主流的电商平台&#xff0c;其上有大量的商品数据。很多商家会通过API来访问他们的商品数据&#xff0c;根据API的调用次数收费。第三方数据公司提供电商数据接口API&#xff0c;采集实时数据。但是&#xff0c;在他们的服务器上有大量的缓…

【2023钉钉杯复赛】A题 智能手机用户监测数据分析 Python代码分析

【2023钉钉杯复赛】A题 智能手机用户监测数据分析 Python代码分析 1 题目 一、问题背景 近年来&#xff0c;随着智能手机的产生&#xff0c;发展到爆炸式的普及增长&#xff0c;不仅推动了中 国智能手机市场的发展和扩大&#xff0c;还快速的促进手机软件的开发。近年中国智能…

【教程】Java 集成Mongodb

【教程】Java 集成Mongodb 依赖 <dependency><groupId>org.mongodb</groupId><artifactId>mongo-java-driver</artifactId><version>3.12.14</version></dependency> <dependency><groupId>cn.hutool</groupId…

网络安全应急响应预案培训

应急响应预案的培训是为了更好地应对网络突发状况&#xff0c;实施演 练计划所做的每一项工作&#xff0c;其培训过程主要针对应急预案涉及的相 关内容进行培训学习。做好应急预案的培训工作能使各级人员明确 自身职责&#xff0c;是做好应急响应工作的基础与前提。应急响应…

CleanMyMac2024永久版Mac清理工具

Mac电脑作为相对封闭的一个系统&#xff0c;它会中毒吗&#xff1f;如果有一天Mac电脑产生了疑似中毒或者遭到恶意不知名攻击的现象&#xff0c;那又应该如何从容应对呢&#xff1f;这些问题都是小编使用Mac系统一段时间后产生的疑惑&#xff0c;通过一番搜索研究&#xff0c;小…

人机识别:走近智能时代的大门

在当今数字化快速发展的时代&#xff0c;人机识别技术正成为引领人工智能革命的重要一环。人机识别&#xff0c;即通过计算机视觉和模式识别技术&#xff0c;使机器能够自动识别、分析、理解和处理人类的信息&#xff0c;逐渐渗透到我们的生活和工作中。从简单的人脸识别到更复…

Redis 7 教程 数据类型 基础篇

🌹 引导 Commands | Redishttps://redis.io/commands/Redis命令中心(Redis commands) -- Redis中国用户组(CRUG)Redis命令大全,显示全部已知的redis命令,redis集群相关命令,近期也会翻译过来,Redis命令参考,也可以直接输入命令进行命令检索。

图为科技_边缘计算在智能安防领域的作用

边缘计算在智能安防领域发挥着重要的作用。智能安防系统通常需要处理大量的图像、视频和传感器数据&#xff0c;并对其进行实时分析和处理。边缘计算可以将计算和数据处理功能移动到离数据源更接近的地方&#xff0c;例如摄像头、传感器设备或安防终端。 以下是边缘计算在智能…

网络爬虫到底是个啥?

网络爬虫到底是个啥&#xff1f; 当涉及到网络爬虫技术时&#xff0c;需要考虑多个方面&#xff0c;从网页获取到最终的数据处理和分析&#xff0c;每个阶段都有不同的算法和策略。以下是这些方面的详细解释&#xff1a; 网页获取&#xff08;Web Crawling&#xff09;&#x…

10 - 网络通信优化之通信协议:如何优化RPC网络通信?

微服务框架中 SpringCloud 和 Dubbo 的使用最为广泛&#xff0c;行业内也一直存在着对两者的比较&#xff0c;很多技术人会为这两个框架哪个更好而争辩。 我记得我们部门在搭建微服务框架时&#xff0c;也在技术选型上纠结良久&#xff0c;还曾一度有过激烈的讨论。当前 Sprin…

URI、URL、URIBuilder、UriBuilder、UriComponentsBuilder说明及基本使用

之前想过直接获取url通过拼接字符串的方式实现,但是这种只是暂时的,后续地址如果有变化或参数很多,去岂不是要拼接很长,由于这些等等原因,所以找了一些方法实现 java.net.URI URI全称是Uniform Resource Identifier,也就是统一资源标识符,它是一种采用特定的语法标识一…

强化学习时序差分学习方法--SARSA算法

强化学习时序差分学习方法--SARSA算法 介绍示例代码 介绍 SARSA&#xff08;State-Action-Reward-State-Action&#xff09;是一种强化学习算法&#xff0c;用于解决马尔可夫决策过程&#xff08;MDP&#xff09;中的问题。SARSA算法属于基于值的强化学习算法&#xff0c;用于…

Redis添加LocalDateTime时间序列化/反序列化Java 8报‘jackson-datatype-jsr310’问题

错误信息&#xff1a; com.fasterxml.jackson.databind.exc.InvalidDefinitionException: Java 8 date/time type java.time.LocalDateTime not supported by default: add Module "com.fasterxml.jackson.datatype:jackson-datatype-jsr310" to enable handling (t…

Navicat 连接 mysql 问题

需要将mysql配置文件设置为远程任意ip可登陆&#xff0c;注释掉一下两行配置 # bind-address>->--- 127.0.0.1 # mysqlx-bind-address>-- 127.0.0.1Cant connect to MySQL server on "192.168.137.139 (10013 "Unknown error") 检查Navicat是否联网H…

OSCS开源安全周报第 56 期:Apache Airflow Spark Provider 任意文件读取漏洞

本周安全态势综述 OSCS 社区共收录安全漏洞 3 个&#xff0c;公开漏洞值得关注的是 Apache NiFi 连接 URL 验证绕过漏洞(CVE-2023-40037)、PowerJob 未授权访问漏洞(CVE-2023-36106)、Apache Airflow Spark Provider 任意文件读取漏洞(CVE-2023-40272)。 针对 NPM 、PyPI 仓库…

stm32之点亮LED

今天&#xff0c;记录一下stm32如何点亮一个LED,程序本身十分简单&#xff0c;但主要是学习编程的格式。 led.h #ifndef _led_H #define _led_H#include "system.h"/* LED时钟端口、引脚定义 */ #define LED1_PORT GPIOB #define LED1_PIN GPIO_Pin_5 #d…

开发一款AR导览导航小程序多少钱?ar地图微信小程序 ar导航 源码

随着科技的不断发展&#xff0c;增强现实&#xff08;AR&#xff09;技术在不同领域展现出了巨大的潜力。AR导览小程序作为其中的一种应用形式&#xff0c;为用户提供了全新的观赏和学习体验。然而&#xff0c;开发一款高质量的AR导览小程序需要投入大量的时间、人力和技术资源…

❤ Ant Design Vue 2.28的使用

❤ Ant Design Vue 2.28 弹窗 //按钮 <a-button type"primary" click"showModal">Open Modal</a-button>//窗口 <a-modal v-model:visible"visible" title"Basic Modal" ok"handleOk"><p>Some con…