详细介绍torch中的from torch.utils.data.sampler相关知识

PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍:

  1. Sampler基类: Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。
  2. RandomSampler: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。
  3. SequentialSampler: 顺序采样器,它会按照数据集中的顺序,依次选择样本。
  4. SubsetRandomSampler: 子集随机采样器,它会从数据集的指定子集中随机选择样本。可以用于将数据集分成训练集和验证集等子集。
  5. WeightedRandomSampler: 加权随机采样器,它会根据指定的样本权重,进行随机采样。可以用于处理类别不平衡的问题。
  6. BatchSampler: 批次采样器,它会将样本索引分成多个批次,每个批次包含指定数量的样本索引。

这些Sampler类可以通过在DataLoader的构造函数中指定来使用。例如,可以使用RandomSampler来实现随机采样,使用SubsetRandomSampler来实现将数据集分成训练集和验证集。此外,还可以使用函数如WeightedRandomSampler来实现加权随机采样。

下面是使用上述Sampler类和函数的示例代码:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler# 创建一个数据集
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))# 创建一个使用RandomSampler的DataLoader
random_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))# 创建一个使用SequentialSampler的DataLoader
seq_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))# 创建一个使用SubsetRandomSampler的DataLoader
train_indices = [0, 1, 2, 3, 4]
val_indices = [5, 6, 7, 8, 9]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)# 创建一个使用WeightedRandomSampler的DataLoader
weights = [0.1, 0.9]
weighted_sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
weighted_loader = DataLoader(dataset, batch_size=2, sampler=weighted_sampler)# 使用BatchSampler将样本索引分成多个批次
batch_sampler = torch.utils.data.sampler.BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
batch_loader = DataLoader(dataset, batch_sampler=batch_sampler)# 遍历DataLoader,输出每个批次的数据
for data, label in random_loader:print(data, label)for data, label in seq_loader:print(data, label)for data, label in train_loader:print(data, label)for data, label in val_loader:print(data, label)for data, label in weighted_loader:print(data, label)for batch_indices in batch_sampler:batch_data = [dataset[idx] for idx in batch_indices]print(batch_data)

在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们创建了5个不同的DataLoader,每个DataLoader使用不同的采样器(RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler)来从数据集中选择样本。最后,我们遍历这些DataLoader,输出每个批次的数据。

可以通过继承Sampler基类来自定义采样函数。自定义采样函数需要实现__iter__方法和__len__方法。

__iter__方法需要返回一个迭代器,迭代器的每个元素都是数据集中的一个样本的索引。在这个方法中,可以自定义样本索引的选取方式,例如根据某种规则筛选样本或者将数据集分成多个子集。

__len__方法需要返回采样器的样本数量。如果采样器使用的是数据集的全部样本,则返回数据集的长度。

下面是一个自定义采样器的示例代码:

import torch
from torch.utils.data.sampler import Samplerclass CustomSampler(Sampler):def __init__(self, data_source):self.data_source = data_source# 在初始化方法中,可以根据需要对数据集进行处理def __iter__(self):# 在这个方法中,可以自定义样本索引的选取方式# 这里的示例是随机选取样本indices = torch.randperm(len(self.data_source)).tolist()return iter(indices)def __len__(self):# 在这个方法中,需要返回采样器的样本数量# 这里的示例是采样器的样本数量等于数据集的长度return len(self.data_source)

在这个示例中,我们定义了一个名为CustomSampler的采样器类,它继承自Sampler基类。在初始化方法中,我们保存了数据集,并可以根据需要对数据集进行处理。在__iter__方法中,我们自定义了样本索引的选取方式,这里的示例是随机选取样本。在__len__方法中,我们返回了采样器的样本数量,这里的示例是采样器的样本数量等于数据集的长度。

使用自定义采样器时,只需要将它传入DataLoader的构造函数即可:

dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
custom_sampler = CustomSampler(dataset)
loader = DataLoader(dataset, batch_size=2, sampler=custom_sampler)

在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们使用CustomSampler创建了一个采样器,并将它传入DataLoader的构造函数。最后,我们遍历这个DataLoader,输出每个批次的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/190276.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【SpringCloud】设计原则之前后端分离与版本控制

一、设计原则之前后端分离 在传统的 Web 应用开发中,大多数的程序员会将浏览器作为前后端的分界线 将浏览器中用户进行页面展示的部分称之为前端,而将运行在服务器,为前端提供业务逻辑和数据准备的所有代码统称为后端 由于前后端分离这个…

MQTT分析——CONNECT为例子

源代码: using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Net.Sockets;namespace ConsoleApp1 {class Program{static void Main(string[] args){Connect();}/// <summary>/// 向…

Vue3中的动态组件,使用动态组件实现页面的切换。

目录 动态组件 本文主要介绍Vue3中的动态组件&#xff0c;使用动态组件实现页面的切换。 动态组件 在Vue3中&#xff0c;动态组件是通过<component>元素来实现的。动态组件可以根据所设置的组件名称动态地渲染不同的组件。 动态组件可以通过以下步骤来使用&#xff1a;…

SQL Server 2016(在Products表中查询数据)

1、实验环境。 以实验案例一的结果为环境。 2、需求描述。 【1】查询成本低于10元的水果信息。 【2】将所有蔬菜的成本上调1源。 【3】查询成本大于3元并小于40元的产品信息&#xff0c;并按照成本从高到低的顺序显示结果。 【4】查询成本最高的5个产品信息。 【5】查询有…

房产中介管理信息系统的设计与实现

摘 要 随着房地产业的开发&#xff0c;房产中介行业也随之发展起来&#xff0c;由于房改政策的出台&#xff0c;购房、售房、租房的居民越来越多&#xff0c;这对房产中介部门无疑是一个发展的契机。本文结合目前中国城市房产管理的实际情况和现阶段房屋产业的供求关系对房产中…

基于低代码平台开发应用程序

目录 低代码介绍 预研目标 预研产品 1.业务流程 2.用户权限 3.统计图表 4.大屏设计 5.第三方登录 6.分布式调度 小结 近几年&#xff0c;一直对低代码平台有所耳闻&#xff0c;目前已经对低代码平台有了一定的认识&#xff0c;如果能通过一个可视化的配置页面就能完成前端开发&…

Sock0s1.1

信息收集 探测存活主机 发现存活主机为192.168.217.133 探测开放端口 nmap -sT -p- 192.168.217.133 -oA ./ports 发现两个端口开放&#xff0c;分别是22 3128&#xff0c;同时探测到了8080端口&#xff0c;但是显示是关闭的状态。 UDP端口探测 nmap -sU --top-ports 20 1…

linux学习资源

linux书籍资源&#xff08;pdf版&#xff09;&#xff1a; 有需要的请在评论区留言。 《Linux Basics for Hackers》 kaiwan的三部曲&#xff1a; 《Hands-On System Programming with Linux》 《Linux Kernel Programming》 《Linux Kernel Programming Part 2》 《Ma…

编程好处、系统介绍、app演示

编程视频教学地址&#xff1a; 1、编程好处 1.1、自主开发 类似微信、qq等软件应用&#xff0c;解决人们日常生活问题 例如&#xff1a; 1&#xff09;你可以&#xff0c;自己开发一个网站&#xff0c;管理自己的日常生活照片&#xff0c;防止哪一天手机掉了或丢了&#xff0…

【动手学深度学习】(八)数值稳定和模型初始化

文章目录 一、理论知识 一、理论知识 1.神经网络的梯度 考虑如下有d层的神经网络 计算损失l关于参数Wt的梯度&#xff08;链式法则&#xff09; 2.数值稳定性常见的两个问题 3.梯度爆炸 4.梯度爆炸的问题 值超出阈值 对于16位浮点数尤为严重 对学习率敏感 如果学习率太大…

CKafka 一站式搭建数据流转链路,助力长城车联网平台降低运维成本

关于长城智能新能源 长城汽车是一家全球化智能科技公司&#xff0c;业务包括汽车及零部件设计、研发、生产、销售和服务&#xff0c;旗下拥有魏牌、哈弗、坦克、欧拉及长城皮卡。2022年&#xff0c;长城汽车全年销售1,067,523辆&#xff0c;连续7年销量超100万辆。长城汽车面向…

Oracle:左连接、右连接、全外连接、(+)号详解

目录 Oracle 左连接、右连接、全外连接、&#xff08;&#xff09;号详解 1、左外连接&#xff08;LEFT OUTER JOIN/ LEFT JOIN&#xff09; 2、右外连接&#xff08;RIGHT OUTER JOIN/RIGHT JOIN&#xff09; 3、全外连接&#xff08;FULL OUTER JOIN/FULL JOIN&#xff0…

LeetCode哈希表:最长连续序列

LeetCode哈希表&#xff1a;最长连续序列 题目描述 给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1&#xff1a; 输入&…

electerm下载和安装

electerm下载和安装 一、概述 electerm 是一款免费开源、基于electron/ssh2/node-pty/xterm/antd/ subx等libs的终端/ssh/sftp客户端(linux, mac, win)。 而且个人觉得electerm界面更好看一些&#xff0c;操作都是类似的。 二、下载安装 下载地址&#xff1a;https://elec…

算法基础四

括号生成 数字 n 代表生成括号的对数&#xff0c;请你设计一个函数&#xff0c;用于能够生成所有可能的并且 有效的 括号组合。 示例 1&#xff1a; 示例 1&#xff1a; 输入&#xff1a;n 3 输出&#xff1a;[“((()))”,“(()())”,“(())()”,“()(())”,“()()()”] 示例…

vue2 -- 封装 echarts 基础组件

文章目录 🍉1:传递 option 方式🍍1.1 开发环境🍍1.2 创建基础文件🍍1.3 页面使用 -- 桑基图🍍1.4 旭日图🍍1.5 图形组件 graphic🍍1.6 富文本 rich🍉1:传递 option 方式 🍍1.1 开发环境 echarts^5.0.1vue~2.6.10nodev12.14.0🍍1.2 创建基础文件 创建 sr…

【Qt之QPen】

QPen 类是 Qt 框架中的一个类&#xff0c;用于定义绘制的画笔。 QPen 类的常见属性及函数&#xff1a; 属性&#xff1a; color&#xff1a;画笔颜色width&#xff1a;画笔宽度style&#xff1a;画笔风格capStyle&#xff1a;线帽样式joinStyle&#xff1a;线段连接样式 函…

typescript泛型的基本使用

文章目录 泛型规范一、泛型的作用二、any 和 泛型 的区别1: any类型2: 泛型3: 总结 三、泛型的简单使用1.返回任何类型的泛型函数2.代码示例3.返回指定类型的泛型函数 四、泛型接口&#xff08;1&#xff09;错误代码示范&#xff08;2&#xff09;报错说明&#xff08;3&#…

无脑018——win11部署whisper,语音转文字

1.conda创建环境 conda create -n whisper python3.9 conda activate whisper安装pytorch pip install torch1.8.1cu101 torchvision0.9.1cu101 torchaudio0.8.1 -f https://download.pytorch.org/whl/torch_stable.html安装whisper pip install -U openai-whisper2.准备模型…

【论文阅读】CAN网络中基于时序信道的隐蔽认证算法

文章目录 摘要一、引言和动机A 相关工作 二、背景及实验设置A 以前工作中的时钟偏差和局限性B.最坏到达时间C.安装组件 三、优化流量分配A.问题陈述B.优化帧调度 四、协议和结果A.主协议B.对手模型C. 优化流量和单一发送者的结果D.多发送方情况和噪声信道 摘要 以前的研究工作…