机器学习常见的sampling策略 附PyTorch实现

初始工作

定义一个模拟的长尾数据集

import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoadernp.random.seed(0)
random.seed(0)
torch.manual_seed(0)
class LongTailDataset(Dataset):def __init__(self, num_classes=25, max_samples_per_class=100):self.num_classes = num_classesself.max_samples_per_class = max_samples_per_class# Generate number of samples for each class inversely proportional to class indexself.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]self.total_samples = sum(self.samples_per_class)# Generate targets for the datasetself.targets = torch.cat([torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)])def __len__(self):return self.total_samplesdef __getitem__(self, idx):# For simplicity, just return the index as the datareturn idx, self.targets[idx]# Create dataset
batch_size = 64
dataset = LongTailDataset()
print(f'The total number of samples: {len(dataset) // 2}')
print(f'The number of samples per class: {dataset.samples_per_class}')
print(f'The {len(dataset) // 2} th samples of the dataset: {dataset[len(dataset) // 2]}')

Output:

The total number of samples: 187
The number of samples per class: [100, 50, 33, 25, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4]
The 187 th samples of the dataset: (187, tensor(3))

定义一个测试sample一个batch的函数

def test_loader_in_one_batch(test_dataloader: DataLoader, inf: str):print(inf)for (_, target) in test_dataloader:cls_idx, cls_counts = np.unique(target.numpy(), return_counts=True)cls_idx = [int(i) for i in cls_idx]cls_counts = [int(i) for i in cls_counts]print(f'Class indices: {cls_idx}')print(f'Class counts: {cls_counts}')break  # just show one batchprint('-' * 20)

采样介绍

每个类的采样概率可抽象为:\(p_j=\frac{n_jq}{\sum_{i=1}Cn_i^q}\),

  • \(p_j\)表示从j类采样数据的概率;
  • \(C\)表示类别数量;\(n_j\)表示j类样本数;
  • \(q\in\{1,0\}\)

均匀采样

\(q=1\),实例平衡采样(Instance-balanced sampling)(也称uniform sampling),最常见的数据采样方式,每个训练样本被选择的概率相等均为\(\frac{1}{N}\)。对j类的采样,按数据集中j类的基数\(n_j\)进行采样,即\(p{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}Cn_i}\)。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
test_loader_in_one_batch(dataloader, inf='Instance-balanced sampling(Default):')

Output:

Instance-balanced sampling(Default):
Class indices: [0, 1, 2, 3, 4, 5, 6, 10, 11, 13, 15, 17, 18, 19, 20, 21, 23]
Class counts: [13, 10, 4, 4, 4, 6, 5, 3, 1, 4, 3, 1, 2, 1, 1, 1, 1]
--------------------

类平衡采样

实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样(Class-balanced sampling)让所有的类有相同的被采样概率(\(q=0\)):\(p^{\mathbf{CB}}_j=\frac{1}{C}\)。采样可分为两个阶段:1. 从类集中统一选择一个类;2. 对该类中的实例进行统一采样。

这里具体实现使用很多论文都在使用的 Class Aware Sampler,通过循环过采样,使得batch内每个类别的样本数相等。

import random
from torch.utils.data.sampler import Sampler
import numpy as npclass RandomCycleIter:def __init__(self, data, test_mode=False):self.data_list = list(data)self.length = len(self.data_list)self.i = self.length - 1self.test_mode = test_modedef __iter__(self):return selfdef __next__(self):self.i += 1if self.i == self.length:self.i = 0if not self.test_mode:random.shuffle(self.data_list)return self.data_list[self.i]def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):i = 0j = 0while i < n:if j >= num_samples_cls:j = 0if j == 0:temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]] * num_samples_cls))# next(cls_iter) 会返回一个类别的index,# data_iter_list[next(cls_iter)]会返回list,list内包括该类的所有样本的index# 用*解包上面的list,然后内部每个元素重复 num_samples_cls 次,然后用zip打包,再用next取出yield temp_tuple[j]else:yield temp_tuple[j]i += 1j += 1class ClassAwareSampler(Sampler):def __init__(self, data_source, num_samples_cls=1):super().__init__()num_classes = len(np.unique(data_source.targets))self.class_iter = RandomCycleIter(range(num_classes))  # 返回一个循环迭代器,迭代器每次返回一个类的indexcls_data_list = [list() for _ in range(num_classes)]  # N个类,每个类对应一个listfor i, label in enumerate(data_source.targets):cls_data_list[label].append(i)  # 将每个样本的index按照类别放入对应的listself.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]  # 每个类用循环迭代器包装,返回类内sample的indexself.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list)  # 总样本数 = 最大样本数的类的样本数 * 类别数self.num_samples_cls = num_samples_clsdef __iter__(self):return class_aware_sample_generator(self.class_iter, self.data_iter_list,self.num_samples, self.num_samples_cls)def __len__(self):return self.num_samplesdataloader = DataLoader(dataset, batch_size=batch_size, sampler=ClassAwareSampler(dataset))
test_loader_in_one_batch(dataloader, inf='Class-aware sampling:')

Output:

Class-aware sampling:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
Class counts: [2, 2, 2, 3, 2, 3, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 2, 2, 3, 3]
--------------------

类平衡采样的另一种写法(通过调整采样器的类的权重)

最早是把每个类的权重(采样概率)设为样本数倒数:\(p_j=\frac{1}{n_j}\)。[3]提出effective number,对每个类的权重(effective number)调整为:

\[E_n=(1-\beta^n)/(1-\beta),\ \mathrm{where~}\beta=(N-1)/N. \]

并用这个权重调整损失。[4]把这个权重用于采样权重,这里用PyTorch提供的WeightedRandomSampler实现:第一个参数表示每个样本(不是类)的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。

from torch.utils.data.sampler import WeightedRandomSamplerdef imbalance_sampler(targets, mode='inverse'):cls_counts = np.bincount(targets)cls_weights = Noneif mode == 'inverse':cls_weights = 1. / cls_countselif mode == 'effective':beta = (len(targets) - 1) / len(targets)cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))assert cls_weights is not Nonereturn WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)modes = ['inverse', 'effective']
for mode in modes:sampler = imbalance_sampler(dataset.targets.numpy(), mode)dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)test_loader_in_one_batch(dataloader, inf=f'{mode.capitalize()}:')

Output:

Inverse:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24]
Class counts: [1, 3, 1, 2, 5, 4, 2, 1, 3, 3, 1, 3, 3, 3, 6, 3, 3, 1, 3, 5, 1, 3, 1, 3]
--------------------
Effective:
Class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 23, 24]
Class counts: [2, 3, 2, 2, 5, 1, 1, 2, 3, 3, 7, 2, 1, 3, 3, 4, 1, 4, 3, 2, 1, 3, 6]
--------------------

实际中,提到类平衡采样Class-Balanced Re-Sampling,两种实现方式都有可能,注意上下文描述和参考文献的引用。

混合采样策略

最早的混合采样是在 \(0\le epoch\le t\)时采用Instance-balanced采样,\(t\le epoch\le T\)时采用Class-balanced采样,这需要设置合适的超参数t。在[1]中,作者提出了soft版本的混合采样策略:Progressively-balanced sampling。随着epoch的增加每个类的采样概率(权重)\(p_j\)也发生变化:

\[p_j^{\mathbf{PB}}(t)=(1-\frac tT)p_j^{\mathbf{IB}}+\frac tTp_j^{\mathbf{CB}} \]

t表示当前epoch,T表示总epoch数。

运行环境

pytorch                   2.1.1           py3.11_cuda12.1_cudnn8_0    pytorch

最后的最后

感谢你们的阅读和喜欢,我收藏了很多技术干货,可以共享给喜欢我文章的朋友们,如果你肯花时间沉下心去学习,它们一定能帮到你。

因为这个行业不同于其他行业,知识体系实在是过于庞大,知识更新也非常快。作为一个普通人,无法全部学完,所以我们在提升技术的时候,首先需要明确一个目标,然后制定好完整的计划,同时找到好的学习方法,这样才能更快的提升自己。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

五、面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下。
在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/855658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信息量、香农熵、交叉熵、KL散度的意义

文章目录 1. 信息量2. 香农熵3. 交叉熵4. KL散度(Kullback-Leibler Divergence(Relative Entropy)) 1. 信息量 相关概念&#xff1a; 对于一个事件&#xff1a; 小概率 – > 大的信息量大概率 – > 小的信息量多个独立事件的信息量计算可以直接相加 表达公式&#xf…

开关阀(1):定位器与电磁阀的区别

气动阀门带电磁阀是控制气源开关的&#xff0c;如果装配阀门定位器即为调节型&#xff1f; 一般来说&#xff0c;电磁阀就是控制气源通断的&#xff0c;用于阀门快速的全开或全关&#xff0c;电磁阀是得电与失电是起开关作用。 定位器是控制气源压力的大小&#xff0c;控制阀门…

数据结构---二叉树的性质总结

第i层上的节点数 证明: 二叉树的最大节点数 证明: 第一层对应2^0个节点, 累加得到 这是一个等比数列 求和公式: 那么这里的n指的是一共有多少个相加 根据从b到a一共有b-a1个可推出 有(k-1)-01个相加 那么结果为: 叶节点与度为2的节点关系 证明: 假设二叉树的总节点数为 NNN…

解决动态权限路由页面刷新空白404

需要将任意路由 path: /:pathMatch(.*)* 从固定路由中提取出来&#xff0c;在刷新时&#xff0c;等待用户信息获取完毕&#xff0c;将动态路由和任意路由通过 router.addRoute() 重新添加到路由中 // 固定路由 export const constantRoute [ ... ]// 权限路由 export const …

主键的定义,理解

"主键"是数据库中的一个术语&#xff0c;用于标识数据库表中的每一条记录的唯一标识。主键的特点如下&#xff1a; 唯一性&#xff1a;每个表中的主键值必须是唯一的&#xff0c;这样每条记录都能被准确地识别和检索。不可更改性&#xff1a;一旦定义&#xff0c;主…

常数变易法求解非齐次线性微分方程

文章目录 常数变易法求解一阶非齐次线性微分方程常数变易法求解二阶非齐次线性微分方程例题 常数变易法求解一阶非齐次线性微分方程 对于一阶非齐次线性微分方程 y ′ p ( x ) y q ( x ) y p(x)y q(x) y′p(x)yq(x) 先用分离变量法求解对应的齐次方程 y ′ p ( x ) y 0…

SpelExpressionParser评估SpEL(Spring Expression Language)表达式的解析器

是Spring中用于解析和评估SpEL(Spring Expression Language)表达式的解析器,SpEL是一种强大且灵活的表达式语言,广泛用于Spring框架中,以便在运行时解析和评估表达式 主要功能 1.解析和评估表达式:spelExpressionParser可以解析复杂的表达式,并在运行时对其进行评估; 2.访问…

Linux时间子系统7:sleep timer接口定时实现

1、前言 之前的文章中介绍了Linux时间相关的内容&#xff0c;包括用户态/内核态的时间获取&#xff0c;时间的种类&#xff0c;时钟源等&#xff0c;本篇开始的后续几篇文章将介绍Linux系统关于定时相关的服务&#xff0c;这与之前的内容是高度相关的&#xff0c;本篇还是从应用…

SolidWorks科研版更快地开发产品创意

在当今竞争激烈的市场环境中&#xff0c;产品创新的速度和质量直接决定了企业的生死存亡。对于科研人员和设计师来说&#xff0c;如何能够快速、准确地实现产品创意的转化&#xff0c;是摆在面前的一大挑战。SolidWorks科研版作为一款功能强大的三维设计软件&#xff0c;为科研…

正则表达式之三剑客grep

正则表达式匹配的是文本内容&#xff0c;linux的文本三剑客 都是针对文本内容 grep 过滤文本内容 sed 针对文本内容进行增删改查 awk 按行取列 文本三剑客都是按行进行匹配。 grep grep 的作用就是使用正则表达式来匹配文本内容 选项&#xff1a; -m …

centos查找文件 写入的进程

du -sh * 查看目录空间占用、发现大文件&#xff0c;确定进程&#xff0c;结束 yum install lsof 安装lsof 查看文件写入的 进程 2. lsof /root/.influxdbv2/engine/data/bab49411e5f7cbce/autogen/1/000000036-000000002.tsm COMMAND PID USER FD TYPE …

Ubuntu-基础工具配置

基础工具配置 点击左下角 在弹出界面中点击 以下命令都是在上面这个界面执行&#xff08;请大家注意空格&#xff09; 命令输入完后&#xff0c;回车键就是执行,系统会提示输入密码&#xff08;就是你登录的密码&#xff09; 1.安装net工具 &#xff1a;&#xff08;ifconfi…

vue3-自定义指令来实现input框输入限制

文章目录 前言具体实现分析主要部分详细解析导入和类型定义mounted 钩子函数unmounted 钩子函数指令注册使用 总结 前言 使用vue中的自定义指令来实现input框输入限制 其中关键代码强制触发input &#xff0c;来避免&#xff0c;输入规则外的字符时&#xff0c;没触发vue的响…

无需安装就能一键部署Stable Diffusion 3?

一键部署使用SD3&#xff1f;让你的创作更加便捷&#xff01; 前言 厚德云上架SD3! 距离Stable Diffusion 3的上线已经有一阵时间了。从上线至今SD3也是一直好评不断&#xff0c;各项性能的提升也让它荣获“最强开源新模型”的称号。成为了AI绘画设计师们新的香馍馍。 可对于SD…

短期内股票跌了就难受的人有哪些?

短期内股票跌了难受的人&#xff0c;主要是四类 第一类压根就没有打算长期持有&#xff0c;就是玩短线的。这类人来股市是为了一夜暴富的。 第二类人&#xff0c;这类人也是打算一夜暴富的&#xff0c;但是他们会上杠杆&#xff0c;借钱买股票。股价涨了好说&#xff0c;股价…

python网站地图解析

分析&#xff1a; ⽹站的地图&#xff08;sitemap.xml&#xff09;是⼀个XML⽂件&#xff0c;列出了⽹站上所有可访问的⻚⾯的URL。解析⽹站的地图可以⾼效地发现⽹站上所有的⻚⾯&#xff0c;特别是那些可能不容易通过常规爬⾍发现的⻚⾯。 # Python代码&#xff1a; 以下是⼀…

Mac用虚拟机玩游戏很卡 Mac电脑玩游戏怎么流畅运行 苹果电脑怎么畅玩Windows游戏

对于许多Mac电脑用户而言&#xff0c;他们经常面临一个令人头疼的问题&#xff1a;在虚拟机中玩游戏时卡顿严重&#xff0c;影响了游戏体验。下面我们将介绍Mac用虚拟机玩游戏很卡&#xff0c;Mac电脑玩游戏怎么流畅运行的相关内容。 一、Mac用虚拟机玩游戏很卡 下面我们来看…

嵌入式期末复习--补充(答案来自文心一言)

一、第一章 1、常见的RTOS&#xff0c;嵌入式操作系统的特点 RTOS就是实时操作系统。根据响应时间的不同&#xff0c;可分为以下3类&#xff1a; &#xff08;1&#xff09;强实时嵌入式操作系统 响应时间&#xff1a;微妙或毫秒 &#xff08;2&#xff09;一般实时…

删除重复文件如何操作?电脑重复文件删除教程分享:详细!高效!

在数字化时代&#xff0c;我们的电脑中往往存储着大量的文件&#xff0c;这些文件随着时间的推移可能会产生许多重复项。重复文件不仅占用了宝贵的硬盘空间&#xff0c;还可能导致文件管理的混乱。因此&#xff0c;定期删除重复文件是维护电脑健康和提高工作效率的重要步骤。本…

请问为什么下面的HTML代码没有显示内容?

请问下面的HTML程序为什么没有显示内容&#xff1f; <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>HTML教程()</title> <script>function getTime() {var date new Date();var time date.toLocalString…