Debug系列 GroupNorm和BatchNorm出现Nan或inf的情况
- 前言
- 这两个函数做了什么
- 可能出现的问题
- 解决方法
- train和eval
- batchsize或channel设置过小
- 可训练参数的问题
- 数值溢出
- 其它的方法
前言
在复现别人论文的实验结果时,按照README乖乖做完之后,却发现损失函数的走向十分诡异,具体表现为在cifar10数据集上运行2万步以前,一切风平浪静,祥和美好,但是突然loss就全部变为nan了。令人费解。在Debug两天后,终于发现并修改了问题。令人发指。
在此记录,也希望各位以后遇到此类问题,能够快速解决提供参考。
如果你也想试试,就看这个论文 论文连接。
这两个函数做了什么
x ′ = x − u i σ i 2 + ϵ ∗ W + γ x'=\frac{x - u_i}{\sqrt{\sigma_i^2+\epsilon}}*W + \gamma x′=σi2+ϵx−ui∗W+γ
Normalize实际上是将当前数据 x x x的分布转换为一个标准正态分布,即均值 u i = 0 u_i=0 ui=0,方差 σ i = 1 \sigma_i=1 σi=1。后面的 W , γ W,\gamma W,γ是可以训练的参数, ϵ \epsilon ϵ是为了实现数学稳定,在分母上增添的小量。
在神经网络中添加这个层,可以改变数据的分布,从而使得不同阶段数据分布相同,从而实现加速神经网络收敛,提高表达能力的效果。
需要说明的是,这两个函数会受到model.eval()和model.train()的影响。
这两个函数本身的区别在于,一个是按batch进行normalization,另一个则是按照channel进行normalization。
可能出现的问题
根据公式,我们会很容易发现,每一轮次,当前层都会计算当前数据batch中的均值和方差。那么显而易见,当数据分布不合理,或出现问题时,均值和方差有可能计算得到inf或nan。
此外为什么有时候,在训练的时候没问题,测试的时候却出现问题了呢?
还有为什么我的模型loss很好,输出却总是黑色图像呢?
为什么loss一开始很好,突然就不好了呢?
解决方法
train和eval
最容易解决的方式便是,如果发现模型只在测试中出现问题,那么不妨一直采用model.train的模式来进行预测。
batchsize或channel设置过小
这两个网络层每次计算实际上是会根据当前数据的batchsize或channel数进行的,那么过小的batchsize当然会导致算法不稳定。一般来说单GPU,batchsize定在32是一个合理的范围,而channel的话我是采用了128以上。
可训练参数的问题
这一说法的话,直接冻结参数即可,以免受到模型收敛的影响。
数值溢出
对了,这就是我最想说的!!!!
你可能会觉得奇怪,这为什么会数值溢出呢?如果没有觉得奇怪,可能大多数人都会想到,是不是 σ = 0 \sigma=0 σ=0?但是转念一想,我不是加了一个 ϵ \epsilon ϵ吗?是的,一开始我也是这么认为的。但是在经过不断Debug的过程中,发现这都不是问题的关键。
可能聪明的你猜到了,这可能与输入数据x有关,但是x怎么会导致数值溢出呢?这是因为,在Normalization的过程中,计算得到的均值和方差可能是很大的值,而这个值的大小如果一旦超过这两个变量本身能表达的最大大小,那么就会导致数值溢出,从而产生inf和nan。
这里可能与python的直觉不符,但是由于pytorch中存在 torch.float16之类的类型,其最大能表示的数值甚至不超过1e6,太抽象了。因此,需要手动将变量的type变为torch.float32以上,才能更好的计算。
如果还是不行,附以下代码,是我手写的GroupNorm,里面可以用clip强行划定 σ \sigma σ的值,从而避免出现inf,当然如果方差本来就是很大,这样clip实际上并不能使数据的分布转变为标准正态分布,但是在超不是特别多的情况下,这种clip实际上是一种妥协的方法,最差情况,即方差接近无限大,那么这样clip实际上相当于将整个GroupNorm的功能给作废。这并不影响整体模型效果,无非是GroupNorm的优良性质没了。
class GroupNorm(nn.GroupNorm):def forward(self, x):if len(x.shape) == 4:N,C,H,W = x.size()G = self.num_groupsassert C % G == 0x = x.view(N,G,-1)mean = x.mean(-1, keepdim=True)var = th.clip(x.var(-1, keepdim=True), max=1e7)x = (x-mean) / (var+self.eps).sqrt()x = x.view(N,C,H,W)else:N,C,H = x.size()G = self.num_groupsassert C % G == 0x = x.view(N,G,-1)mean = x.mean(-1, keepdim=True)var = th.clip(x.var(-1, keepdim=True), max=1e7)x = (x-mean) / (var+self.eps).sqrt()x = x.view(N,C,H)return x
其它的方法
之前看到有一个说法是可以冻结均值和方差????不知道怎么做,可以查一下。