M2m中的采样

 采样的完整代码

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSamplerdef get_oversampled_data(dataset, num_sample_per_class):""" Generate a list of indices that represents oversampling of the dataset. """targets = np.array(dataset.targets)class_sample_count = np.array([num_sample_per_class[target] for target in targets])weight = 1. / class_sample_countsamples_weight = torch.from_numpy(weight)sampler = WeightedRandomSampler(samples_weight, len(samples_weight))return samplerdef get_val_test_data(dataset, num_test_samples):""" Split dataset into validation and test indices. """num_classes = 10targets = dataset.targetstest_indices = []val_indices = []for i in range(num_classes):indices = [j for j, x in enumerate(targets) if x == i]np.random.shuffle(indices)val_indices.extend(indices[:num_test_samples])test_indices.extend(indices[num_test_samples:num_test_samples*2])return val_indices, test_indicesdef get_oversampled(dataset_name, num_sample_per_class, batch_size, transform_train, transform_test):""" Create training and testing loaders with oversampling for imbalance. """dataset_class = datasets.__dict__[dataset_presets[dataset_name]['class']]dataset_train = dataset_class(root='./data', train=True, download=True, transform=transform_train)dataset_test = dataset_class(root='./data', train=False, download=True, transform=transform_test)# Oversamplingsampler = get_oversampled_data(dataset_train, num_sample_per_class)train_loader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler)# Validation and Test splitval_idx, test_idx = get_val_test_data(dataset_test, 1000)val_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(val_idx))test_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(test_idx))return train_loader, val_loader, test_loader# Configuration and run
dataset_presets = {'cifar10': {'class': 'CIFAR10', 'num_classes': 10}
}
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
num_sample_per_class = [500] * 10  # Pretend we want equal class distributiontrain_loader, val_loader, test_loader = get_oversampled('cifar10', num_sample_per_class, 64, transform, transform)# Print out some info from loaders
for i, (inputs, targets) in enumerate(train_loader):print(f'Batch {i}, Targets Counts: {torch.bincount(targets)}')if i == 1:  # Just show first two batches for demonstrationbreak

WeightedRandomSampler类的__iter__

def __iter__(self) -> Iterator[int]:rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)return iter(rand_tensor.tolist())
  • 方法功能:此方法实现了迭代器协议,允许WeightedRandomSampler对象在迭代中返回一系列随机选择的索引。

过采样的效果

get_oversampled函数中,使用了WeightedRandomSampler来实现过采样的逻辑。这个过程虽然看起来是通过权重调整样本的选取概率,但实际上,通过这种方式也可以达到过采样的效果,尤其是当设置replacement=True时。让我们更详细地分析一下这一点:

权重的分配

权重是根据num_sample_per_class数组分配的,这个数组定义了每个类别希望被采样到的频率。在数据加载过程中,每个类别的样本将根据其在num_sample_per_class中对应的值获得一个权重。权重越大的类别在每次迭代中

被选中的概率也越大。这样,通过调整这些权重,我们可以控制模型在训练过程中看到的每个类别样本的频率,实现对类别不平衡的处理。

过采样的实现

在使用WeightedRandomSampler时,关键的参数是replacement

  • 如果replacement=True:这允许同一个样本在一次抽样中被多次选择,即进行了过采样。对于少数类的样本来说,即使它们在数据集中的绝对数量不多,也可以通过这种方式增加它们在每个训练批次中出现的次数,从而让模型更频繁地从这些少数类样本学习。

  • 如果replacement=False:则每个样本只能被抽样一次,这通常用于不放回的抽样。在这种模式下,WeightedRandomSampler不会直接导致过采样,但可以用来确保每个类别在数据批次中都有均等的代表性,从而帮助模型学习到更平衡的特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17006.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Upstream最新发布2024年汽车网络安全报告-百度网盘下载

Upstream最新发布2024年汽车网络安全报告-百度网盘下载 2024年2月7日,Upstream Security发布了2024年Upstream《GLOBAL AUTOMOTIVE CYBERSECURITY REPORT》。这份报告的第六版着重介绍了汽车网络安全的拐点:从实验性的黑客攻击发展到规模庞大的攻击&…

fpga系列 HDL 00 : 可编程逻辑器件原理

一次性可编程器件(融保险丝实现) 一次性可编程器件(One-Time Programmable Device,简称 OTP)是一种在制造后仅能编程一次的存储设备。OTP器件在编程后数据不可更改。这些器件在很多应用场景中具有独特的优势和用途。 …

【LeetCode算法】第83题:删除排序链表中的重复元素

目录 一、题目描述 二、初次解答 三、官方解法 四、总结 一、题目描述 二、初次解答 1. 思路:双指针法,只需遍历一遍。使用low指向前面的元素,high用于查找low后面与low不同内容的节点。将具有不同内容的节点链接在low后面,实…

【c++】菱形虚拟继承的虚函数表如何继承

请看如下代码 #include <iostream>// 基类 class Base { public:virtual void foo() { std::cout << "Base::foo()" << std::endl; }virtual void bar() { std::cout << "Base::bar()" << std::endl; } };// 虚拟继承的中间…

西门子S7-1200加入MRP 环网用法

MRP&#xff08;介质冗余&#xff09;功能概述 SIMATIC 设备采用标准的冗余机制为 MRP&#xff08;介质冗余协议&#xff09;&#xff0c;符合 IEC62439-2 标准&#xff0c;典型重新组态时间为 200ms&#xff0c;每个环网最多支持 50个设备。​博途TIA/WINCC社区VX群 ​博途T…

Linux 批量网络远程PXE

一、搭建PXE远程安装服务器 1、yum -y install tftp-server xinetd #安装tftp服务 2、修改vim /etc/xinetd.d/tftpTFTP服务的配置文件 systemctl start tftp systemctl start xinetd 3、yum -y install dhcp #---安装服务 cp /usr/share/doc/dhc…

利用Python队列生产者消费者模式构建高效爬虫

目录 一、引言 二、生产者消费者模式概述 三、Python中的队列实现 四、生产者消费者模式在爬虫中的应用 五、实例分析 生产者类&#xff08;Producer&#xff09; 消费者类&#xff08;Consumer&#xff09; 主程序 六、总结 一、引言 随着互联网的发展&#xff0c;信…

Bug:Linux用户拥有r权限但无法打开文件【Linux权限体系】

Bug&#xff1a;Linux用户拥有r权限但无法打开文件【Linux权限体系】 0 问题描述&解决 问题描述&#xff1a; 通过go编写了一个程序&#xff0c;产生的/var/log/xx日志文件发现普通用户无权限打开 - 查看文件权限发现该文件所有者、所有者组、其他用户均有r权限 - 查看该日…

5个好用的AI写论文网站推荐

目录 1.AIQuora论文写作 2.passyyds 答辩PPT 3.AIPassgo论文降AIGC 4.文状元 5.passyyds论文写作 毕业论文是每个毕业生的痛&#xff0c;不管你是本科还是硕士要想顺利毕业你就不得不面对论文。然而&#xff0c;面对论文写作时常常感到无从下手&#xff1a;有时缺乏灵感&a…

【JavaEE进阶】——要想代码不写死,必须得有spring配置(properties和yml配置文件)

目录 本章目标&#xff1a; &#x1f6a9;配置文件 &#x1f6a9;SpringBoot配置文件 &#x1f388;配置⽂件的格式 &#x1f388; properties 配置⽂件说明 &#x1f4dd;properties语法格式 &#x1f4dd;读取配置文件 &#x1f4dd;properties 缺点分析 &#x1f3…

修改uniapp内置组件checkbox的样式

默认情况下 <view style"margin-bottom: 20rpx;"><label style"display: flex;align-items: center;width: fit-content;" click"handleCheck(cxm4s)"><checkbox /><text>车信盟出险4S维保</text></label>…

3d火灾救援模拟仿真培训软件复用性强

消防VR安全逃生体验系统是深圳VR公司华锐视点引入了前沿的VR虚拟现实、web3d开发和多媒体交互技术&#xff0c;为用户打造了一个逼真的火灾现场应急逃生模拟演练环境。 相比传统的消防逃生模拟演练&#xff0c;消防VR安全逃生体验系统包含知识讲解和模拟实训演练&#xff0c;体…

【百度智能体】5分钟打造一款为你写情书的智能体

目录 前言一、智能体特点二、应用场景三、打造一款写情书智能体1、名称2、简介3、填写人物设定&#xff1a;4、开场白5、引导示例6、预览 最后 前言 智能体作为人工智能领域的一个重要概念&#xff0c;是指能够自主感知环境、做出决策并执行行动的系统。它具备自主性、交互性、…

单元测试(了解)

单元测试定义 针对最小功能单元&#xff08;方法&#xff09;&#xff0c;编写测试代码对其进行正确性测试 之前如何进行单元测试&#xff1f;有什么问题&#xff1f; main中编写测试代码&#xff0c;调用方法测试 问题&#xff1a; 无法自动化测试 每个方法的测试可能不是…

EPSON爱普生RTC RA8900CE/RA8000CE+松下Panasonic电池组合

RTC是一种实时时钟&#xff0c;用于记录和跟踪时间&#xff0c;具有独立供电和时钟功能。在某些应用场景中&#xff0c;为了保证RTC在断电或者其他异常情况下依然能够正常工作&#xff0c;需要备份电池方案来提供稳定的供电。本文将介绍EPSON爱普生RTC RA8900CE/RA8000CE松下Pa…

【吊打面试官系列】Java高并发篇 - AQS 支持几种同步方式 ?

大家好&#xff0c;我是锋哥。今天分享关于 【AQS 支持几种同步方式 &#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; AQS 支持几种同步方式 &#xff1f; 1、独占式 2、共享式 这样方便使用者实现不同类型的同步组件&#xff0c;独占式如 ReentrantLock&…

VUE3.0学习-模版语法

安装Node.js的过程相对直接&#xff0c;以下是详细的步骤指导&#xff0c;适用于大多数操作系统&#xff1a; ### 1. 访问Node.js官方网站 首先&#xff0c;打开浏览器&#xff0c;访问 [Node.js 官方网站](https://nodejs.org/)。 ### 2. 选择合适的版本下载 在Node.js官网上…

OpenHarmony 实战开发——ArkUI中的线程和看门狗机制

一、前言 本文主要分析ArkUI中涉及的线程和看门狗机制。 二、ArkUI中的线程 应用Ability首次创建界面的流程大致如下&#xff1a; 说明&#xff1a; • AceContainer是一个容器类&#xff0c;由前端、任务执行器、资源管理器、渲染管线、视图等聚合而成&#xff0c;提供了生…

C++ 头文件优化

C 是一种灵活的语言&#xff0c;所以需要一种积极的方法来分析和减少编译时依赖。一种常见的达到这个目的的方法是&#xff0c;将依赖从头文件里转移到源代码文件里。实现这个目的的方法叫做提前声明。 简而言之&#xff0c;这些声明告诉编译器某个函数接受和返回哪些参数&…

Python操作MySQL实战

文章导读 本文用于巩固Pymysql操作MySQL与MySQL操作的知识点&#xff0c;实现一个简易的音乐播放器&#xff0c;拟实现的功能包括&#xff1a;用户登录&#xff0c;窗口显示&#xff0c;加载本地音乐&#xff0c;加入和删除播放列表&#xff0c;播放音乐。 点击此处获取参考源…