文章目录
- 深度学习中的 Batch 概念
- 为什么关注批次内的类别分布?
- 合理的类别分布策略
- 1. 保持与整体数据集的类别比例一致
- 2. 固定每个类别的采样数量
- 3. 动态采样(自适应采样)
- 不同场景下的选择
- Batch 大小与类别数之间的关系
- 结语
- 使用 PyTorch 的 `WeightedRandomSampler` 来平衡批次类别分布
- 代码示例
- 运行结果与讲解
- 自定义 `Sampler` 的思路
深度学习中的 Batch 概念
在深度学习的训练过程中,我们通常不会把整个数据集一次性送入模型进行前向传播和反向传播,而是将数据划分成多个批次(batch)
进行迭代训练。每一个批次包含了若干条训练样本,批次的大小
即为我们常说的 batch_size。其目的是为了在保证一定的计算效率的同时,让模型在每个迭代过程中可以对数据进行一定程度的采样
,从而更好地学习到数据中的特性。
但是,batch_size 并不只影响计算效率和显存使用,它还有另外一个关键影响因素,就是 批次内数据的类别分布是否均衡。在分类任务中,若批次内类别分布与真实数据分布差异过大,可能导致模型在训练时受到的梯度更新不稳定
,甚至在某些训练轮数里过度偏向某些类别。这会使得模型整体的收敛过程变得较为困难
,影响模型的最终表现。
为什么关注批次内的类别分布?
对于一个多分类任务(例如有 10 个不同类别要识别),如果我们使用随机采样的方式在每个 batch 中抽取数据,理论上这能够让批次平均下来与整体数据分布相