前段时间想把我模型的输入由DWT子带改为分块的图像块,一顿魔改后,模型跑着跑着损失就朝着奇怪的方向跑去了:要么突然增大,要么变为NAN。
为什么训练损失会突然变为NAN呢?这个作者将模型训练过程中loss为NAN或INF的原因解释得好好详尽(感谢):Pytorch训练模型损失Loss为Nan或者无穷大(INF)原因_pytorch loss nan-CSDN博客https://blog.csdn.net/ytusdc/article/details/122321907 我经过输入几番输入打印测试,确认我的输入确实没有问题,那么问题只能出现在模型的前向传播或者反向梯度传播过程中。我跟着这个作者的排查思路,最终定位问题出在梯度反向传播上,于是通过梯度剪裁成功解决NAN问题(我还增大了batch_size的大小,输入修改后,我发现模型运算量减小了,显存支持我每个step跑更大的batch_size了)。pytorch训练过程中出现nan的排查思路_torch判断nan-CSDN博客https://blog.csdn.net/mch2869253130/article/details/111034068修改部分:
if mode == 'train':# # 1.debug loss# assert torch.isnan(total_loss).sum() == 0, print(total_loss)total_loss.backward()# # 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试nn.utils.clip_grad_norm(net.parameters(), max_norm=3, norm_type=2)optim.step()optim.zero_grad()
梯度剪裁:
对超出值域范围的梯度进行约束,避免梯度持续大于1,造成梯度爆炸。(没办法规避梯度消失)
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type)
- parameters参数是需要进行梯度裁剪的参数列表。通常是模型的参数列表,即model.parameters();
- max_norm参数可以理解为梯度(默认是L2 范数)范数的最大阈值;
- norm_type参数可以理解为指定范数的类型,比如norm_type=1 表示使用L1 范数,norm_type=2 表示使用L2 范数。
【Pytorch】梯度裁剪——torch.nn.utils.clip_grad_norm_的原理及计算过程-CSDN博客https://blog.csdn.net/m0_46412065/article/details/131396098?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522170435889016800215059432%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=170435889016800215059432&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-131396098-null-null.142^v99^pc_search_result_base7&utm_term=%E6%A2%AF%E5%BA%A6%E5%89%AA%E8%A3%81&spm=1018.2226.3001.4187