M2m中的采样

 采样的完整代码

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSamplerdef get_oversampled_data(dataset, num_sample_per_class):""" Generate a list of indices that represents oversampling of the dataset. """targets = np.array(dataset.targets)class_sample_count = np.array([num_sample_per_class[target] for target in targets])weight = 1. / class_sample_countsamples_weight = torch.from_numpy(weight)sampler = WeightedRandomSampler(samples_weight, len(samples_weight))return samplerdef get_val_test_data(dataset, num_test_samples):""" Split dataset into validation and test indices. """num_classes = 10targets = dataset.targetstest_indices = []val_indices = []for i in range(num_classes):indices = [j for j, x in enumerate(targets) if x == i]np.random.shuffle(indices)val_indices.extend(indices[:num_test_samples])test_indices.extend(indices[num_test_samples:num_test_samples*2])return val_indices, test_indicesdef get_oversampled(dataset_name, num_sample_per_class, batch_size, transform_train, transform_test):""" Create training and testing loaders with oversampling for imbalance. """dataset_class = datasets.__dict__[dataset_presets[dataset_name]['class']]dataset_train = dataset_class(root='./data', train=True, download=True, transform=transform_train)dataset_test = dataset_class(root='./data', train=False, download=True, transform=transform_test)# Oversamplingsampler = get_oversampled_data(dataset_train, num_sample_per_class)train_loader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler)# Validation and Test splitval_idx, test_idx = get_val_test_data(dataset_test, 1000)val_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(val_idx))test_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(test_idx))return train_loader, val_loader, test_loader# Configuration and run
dataset_presets = {'cifar10': {'class': 'CIFAR10', 'num_classes': 10}
}
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
num_sample_per_class = [500] * 10  # Pretend we want equal class distributiontrain_loader, val_loader, test_loader = get_oversampled('cifar10', num_sample_per_class, 64, transform, transform)# Print out some info from loaders
for i, (inputs, targets) in enumerate(train_loader):print(f'Batch {i}, Targets Counts: {torch.bincount(targets)}')if i == 1:  # Just show first two batches for demonstrationbreak

WeightedRandomSampler类的__iter__

def __iter__(self) -> Iterator[int]:rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)return iter(rand_tensor.tolist())
  • 方法功能:此方法实现了迭代器协议,允许WeightedRandomSampler对象在迭代中返回一系列随机选择的索引。

过采样的效果

get_oversampled函数中,使用了WeightedRandomSampler来实现过采样的逻辑。这个过程虽然看起来是通过权重调整样本的选取概率,但实际上,通过这种方式也可以达到过采样的效果,尤其是当设置replacement=True时。让我们更详细地分析一下这一点:

权重的分配

权重是根据num_sample_per_class数组分配的,这个数组定义了每个类别希望被采样到的频率。在数据加载过程中,每个类别的样本将根据其在num_sample_per_class中对应的值获得一个权重。权重越大的类别在每次迭代中

被选中的概率也越大。这样,通过调整这些权重,我们可以控制模型在训练过程中看到的每个类别样本的频率,实现对类别不平衡的处理。

过采样的实现

在使用WeightedRandomSampler时,关键的参数是replacement

  • 如果replacement=True:这允许同一个样本在一次抽样中被多次选择,即进行了过采样。对于少数类的样本来说,即使它们在数据集中的绝对数量不多,也可以通过这种方式增加它们在每个训练批次中出现的次数,从而让模型更频繁地从这些少数类样本学习。

  • 如果replacement=False:则每个样本只能被抽样一次,这通常用于不放回的抽样。在这种模式下,WeightedRandomSampler不会直接导致过采样,但可以用来确保每个类别在数据批次中都有均等的代表性,从而帮助模型学习到更平衡的特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17006.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言从头学12——流程控制(一)

C语言程序的执行顺序是从前到后依次序执行的。如果想要控制程序执行的流程,就必须使用 流程控制的语法结构,分为条件执行和循环执行。 1、if语句 if 语句在前面的举例中曾经出现过,这里做详细介绍。该语句用于条件判断,满…

Upstream最新发布2024年汽车网络安全报告-百度网盘下载

Upstream最新发布2024年汽车网络安全报告-百度网盘下载 2024年2月7日,Upstream Security发布了2024年Upstream《GLOBAL AUTOMOTIVE CYBERSECURITY REPORT》。这份报告的第六版着重介绍了汽车网络安全的拐点:从实验性的黑客攻击发展到规模庞大的攻击&…

fpga系列 HDL 00 : 可编程逻辑器件原理

一次性可编程器件(融保险丝实现) 一次性可编程器件(One-Time Programmable Device,简称 OTP)是一种在制造后仅能编程一次的存储设备。OTP器件在编程后数据不可更改。这些器件在很多应用场景中具有独特的优势和用途。 …

【软件设计师】——10.面向对象技术

目录 10.1 基本概念 10.2设计原则 10.3 设计模式的概念与分类 10.4 创建型模式 10.4.1 Singleton 单例模式 10.4.2 Builder 构建器模式 10.4.3 Abstract Factory 抽象工厂模式 10.4.4 Prototype原型模式 10.4.5 Factory Method工厂方法模式 10.5 结构型模式 10.5.1 A…

【LeetCode算法】第83题:删除排序链表中的重复元素

目录 一、题目描述 二、初次解答 三、官方解法 四、总结 一、题目描述 二、初次解答 1. 思路:双指针法,只需遍历一遍。使用low指向前面的元素,high用于查找low后面与low不同内容的节点。将具有不同内容的节点链接在low后面,实…

【c++】菱形虚拟继承的虚函数表如何继承

请看如下代码 #include <iostream>// 基类 class Base { public:virtual void foo() { std::cout << "Base::foo()" << std::endl; }virtual void bar() { std::cout << "Base::bar()" << std::endl; } };// 虚拟继承的中间…

全栈:session用户会话信息,用户浏览记录实例

PHP中的session是一种存储机制&#xff0c;它允许您存储和跟踪用户在访问Web应用程序时的信息。会话通常用于存储用户特定的数据&#xff0c;如用户ID、购物车内容、用户偏好设置等&#xff0c;这些数据需要在多个页面请求之间保持不变。 session详解 1. 会话是如何工作的 会…

西门子S7-1200加入MRP 环网用法

MRP&#xff08;介质冗余&#xff09;功能概述 SIMATIC 设备采用标准的冗余机制为 MRP&#xff08;介质冗余协议&#xff09;&#xff0c;符合 IEC62439-2 标准&#xff0c;典型重新组态时间为 200ms&#xff0c;每个环网最多支持 50个设备。​博途TIA/WINCC社区VX群 ​博途T…

Linux 批量网络远程PXE

一、搭建PXE远程安装服务器 1、yum -y install tftp-server xinetd #安装tftp服务 2、修改vim /etc/xinetd.d/tftpTFTP服务的配置文件 systemctl start tftp systemctl start xinetd 3、yum -y install dhcp #---安装服务 cp /usr/share/doc/dhc…

c 语言 ---- 结构体

什么是结构体 自定义的数据类型 结构体的声明定义 //1.先声明再定义 struct point{int x;int y; };struct point p1,p2;//2.声明的同时定义 struct point{int x;int y; }p1,p2;typedef定义别名 关键字typedef用于为系统固有的或者程序员自定义的数据类型定义一个别名。数据类…

利用Python队列生产者消费者模式构建高效爬虫

目录 一、引言 二、生产者消费者模式概述 三、Python中的队列实现 四、生产者消费者模式在爬虫中的应用 五、实例分析 生产者类&#xff08;Producer&#xff09; 消费者类&#xff08;Consumer&#xff09; 主程序 六、总结 一、引言 随着互联网的发展&#xff0c;信…

Bug:Linux用户拥有r权限但无法打开文件【Linux权限体系】

Bug&#xff1a;Linux用户拥有r权限但无法打开文件【Linux权限体系】 0 问题描述&解决 问题描述&#xff1a; 通过go编写了一个程序&#xff0c;产生的/var/log/xx日志文件发现普通用户无权限打开 - 查看文件权限发现该文件所有者、所有者组、其他用户均有r权限 - 查看该日…

5个好用的AI写论文网站推荐

目录 1.AIQuora论文写作 2.passyyds 答辩PPT 3.AIPassgo论文降AIGC 4.文状元 5.passyyds论文写作 毕业论文是每个毕业生的痛&#xff0c;不管你是本科还是硕士要想顺利毕业你就不得不面对论文。然而&#xff0c;面对论文写作时常常感到无从下手&#xff1a;有时缺乏灵感&a…

ajax应用

在互联网的浩瀚海洋中&#xff0c;信息如同潮水般涌动。对于数据分析师、研究人员或是任何想要从网络上获取信息的人来说&#xff0c;掌握Ajax数据抓取技术无疑是一把开启宝藏的钥匙。Ajax&#xff08;Asynchronous JavaScript and XML&#xff09;不仅让网页交互变得更加流畅&…

【JavaEE进阶】——要想代码不写死,必须得有spring配置(properties和yml配置文件)

目录 本章目标&#xff1a; &#x1f6a9;配置文件 &#x1f6a9;SpringBoot配置文件 &#x1f388;配置⽂件的格式 &#x1f388; properties 配置⽂件说明 &#x1f4dd;properties语法格式 &#x1f4dd;读取配置文件 &#x1f4dd;properties 缺点分析 &#x1f3…

中国科技期刊卓越行动计划重点期刊

https://kyc.webs.nbpt.edu.cn/_upload/article/files/3e/82/8a3a267048079f3487aef8802af8/09b9fa10-af17-4c2c-89ab-aa0facfb107c.pdf

记录下面试(240522)

试问面试有多难&#xff0c;看看今天的面试题。 1、kong都用了哪些插件&#xff1f; 2、zk底层一致性协议是什么&#xff0c;源码级别的理解。 3、mysql mvcc 引擎&#xff0c;索引都怎么实现的。 4、es底层全文索引如何实现的。 4、业务整体架构是什么样的&#xff0c; 5、如何…

操作MySQL数据库

【一】针对库的增删查改&#xff08;文件夹&#xff09; 【1】创建数据库 &#xff08;1&#xff09;语法 创建一个存储数据表的文件夹。 注意&#xff1a;mysql中的编码字符集中utf-8&#xff0c;要换成utf8mb4。SQL语句中的中括号部分表示可选。 create database [if no…

修改uniapp内置组件checkbox的样式

默认情况下 <view style"margin-bottom: 20rpx;"><label style"display: flex;align-items: center;width: fit-content;" click"handleCheck(cxm4s)"><checkbox /><text>车信盟出险4S维保</text></label>…

Java实验08

实验一 demo.java package q8.demo02;public class demo{public static void main(String[] args) {WindowMenu win new WindowMenu("Hello World",20,30,600,290);} }WindowMenu.java package q8.demo02; import javax.swing.*;public class WindowMenu extends…