模型训练中出现loss为NaN怎么办?

在这里插入图片描述

文章目录

  • 一、模型训练中出现loss为NaN原因
    • 1. 学习率过高
    • 2. 梯度消失或爆炸
    • 3. 数据不平衡或异常
    • 4. 模型不稳定
    • 5. 过拟合
  • 二、 针对梯度消失或爆炸的解决方案
    • 1. 使用`torch.autograd.detect_anomaly()`
    • 2. 使用 torchviz 可视化计算图
    • 3. 检查梯度的数值范围
    • 4. 调整梯度剪裁
  • 三、更具体的办法
    • 3.1 可能导致梯度爆炸的部分
    • 3.2 解决方案

一、模型训练中出现loss为NaN原因

1. 学习率过高

在训练的某个阶段,学习率可能设置得过高,导致模型参数更新幅度过大,甚至可能出现数值不稳定的情况。你可以尝试降低学习率,并观察训练过程中的变化。

2. 梯度消失或爆炸

如果模型的某些层出现梯度消失或爆炸的问题,可能会导致loss变得异常低。你可以检查梯度的大小,确保它们在合理范围内。

3. 数据不平衡或异常

训练数据中可能存在异常值或分布不平衡的情况,导致模型在某些批次的训练过程中出现异常。你可以检查数据集,确保数据质量。

4. 模型不稳定

模型架构或训练过程中的某些设置可能导致不稳定,比如过深的网络、过复杂的模型等。你可以尝试简化模型架构或添加正则化项。

5. 过拟合

模型可能在某些阶段已经过拟合到训练数据上,导致训练loss异常低而验证loss较高。你可以通过早停法(early stopping)、正则化、数据增强等方法来缓解过拟合问题。
解决方法

  1. 调节学习率:适当降低学习率,观察训练过程中的变化。
  2. 检查梯度:通过torch.autograd检查梯度的大小,确保没有出现梯度消失或爆炸。
  3. 数据检查:确保数据集没有异常值或分布不平衡的情况。
  4. 模型架构:简化模型架构,增加正则化项,如L2正则化、dropout等。
  5. 验证集监控:通过监控验证集的loss和指标,防止过拟合。\

二、 针对梯度消失或爆炸的解决方案

使用 torch.autograd.detect_anomaly() 和相关工具确实可以帮助你检测和排除训练过程中出现的梯度问题。以下是如何在你的代码中使用这些工具来检测异常和可视化梯度的示例。

1. 使用torch.autograd.detect_anomaly()

这个函数可以帮助检测反向传播过程中出现的异常,并输出具体的错误信息和位置。

import torch# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()

2. 使用 torchviz 可视化计算图

torchviz 是一个可以帮助你可视化计算图的工具,这对于调试复杂的模型非常有用。

首先,安装 torchviz:

pip install torchviz

然后,可以使用以下代码来生成和保存计算图:

from torchviz import make_dot# 定义模型
model = MyModel()# 输入数据
inputs = torch.randn(56, 1024, 28, 28)# 获取模型输出
outputs = model(inputs)# 创建计算图
dot = make_dot(outputs, params=dict(model.named_parameters()))# 保存计算图
dot.format = 'png'
dot.render('model_graph')

3. 检查梯度的数值范围

你可以在每个训练步骤之后检查模型中各个参数的梯度,以确保梯度的数值范围正常。

# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()# 检查梯度数值范围for name, param in model.named_parameters():if param.grad is not None:grad_min = param.grad.min().item()grad_max = param.grad.max().item()print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')optimizer.step()

4. 调整梯度剪裁

在训练过程中,可以使用梯度剪裁来防止梯度爆炸。以下是如何在 PyTorch 中实现梯度剪裁的示例:

# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()# 梯度剪裁torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()

通过以上方法,可以更好地检测和调试训练过程中出现的梯度问题,提高模型的训练稳定性和效率。如果在使用过程中发现任何异常或需要进一步调试,请随时提供更多细节。

三、更具体的办法

3.1 可能导致梯度爆炸的部分

  1. ReLU 激活函数的使用:激活函数可参考激活函数汇总
    ReLU 是一种常见的激活函数,但如果输入有较大的正值,经过 ReLU 之后,这些值会直接传递下去,可能导致后续层的梯度爆炸。考虑使用其他激活函数,如 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

    embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
    
  2. 特征插值:
    插值操作可能会生成较大的值,尤其是在上采样过程中。如果插值后的值过大,可能会导致梯度爆炸。
    upsample_feat = F.interpolate(feat_high, scale_factor=2., mode=‘nearest’)

  3. 特征拼接:
    多个特征拼接后,如果这些特征值过大,会导致拼接后的张量值过大,进而影响后续层的梯度。

    inner_out = self.fpn_blocks[len(proj_feats) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
    
  4. 全连接层:
    全连接层的权重初始化方式可能会导致梯度爆炸。确保使用了合适的初始化方法,如 Xavier 初始化或 He 初始化。

  5. 权重共享:
    如果多个部分共享权重,需要确保这些共享权重不会导致梯度的累积效应。

3.2 解决方案

  1. 梯度剪裁:
    在反向传播过程中使用梯度剪裁,可以防止梯度爆炸。你可以在 optimizer.step() 之前加上梯度剪裁。

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  2. 使用更稳定的激活函数:
    尝试使用 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

  3. 检查权重初始化:
    确保所有层的权重初始化方式合理,避免初始值过大。

    for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
  4. 监控梯度值:
    在每次反向传播后,监控梯度的值,确保梯度不会爆炸。

    for name, param in model.named_parameters():if param.grad is not None:grad_min = param.grad.min().item()grad_max = param.grad.max().item()print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')
    

Enjoy~

∼ O n e p e r s o n g o f a s t e r , a g r o u p o f p e o p l e c a n g o f u r t h e r ∼ \sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim One person go faster, a group of people can go further

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/47679.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app开发日志:unicloud使用时遇到的问题解决汇总(不断补充)

插件安装后提示与原数据库表冲突(2024.7.18) 安装uni-admin后再安装uni-cms,在uni-admin中添加好菜单,结果提示该错误 回到hbuilder中uniCloud/database中找到冲突的部分 比较一下,选中老的删除 opendb-news-articl…

HarmonyOS根据官网写案列~ArkTs从简单地页面开始

Entry Component struct Index {State message: string 快速入门;build() {Column() {Text(this.message).fontSize(24).fontWeight(700).width(100%).textAlign(TextAlign.Start).padding({ left: 16 }).fontFamily(HarmonyHeiTi-Bold).lineHeight(33)Scroll() {Column() {Ba…

eclipse免安装版64位 2018版本

前言 eclipse是一个开放源代码的、基于Java的可扩展开发平台。就其本身而言,它只是一个框架和一组服务,用于通过插件组件构建开发环境。 一、下载地址 下载地址:http://source/download 选择如下图红色框文件内容下载 二、安装步骤 1、…

Day16_集合与迭代器

Day16-集合 Day16 集合与迭代器1.1 集合的概念 集合继承图1.2 Collection接口1、添加元素2、删除元素3、查询与获取元素不过当我们实际使用都是使用的他的子类Arraylist!!! 1.3 API演示1、演示添加2、演示删除3、演示查询与获取元素 2 Iterat…

分词任务介绍-(十)

分词任务 中文分词正向最大匹配实现方式一实现方式二 反向最大匹配双向最大匹配jieba分词上述分词方法的缺点总结基于机器学习 总结分词技术经验总结 中文分词 正向最大匹配 分词的步骤 1.收集整理一个词表,类似于字典。如下图 2.对于待分词的句子,或者…

机器学习西瓜书笔记(一)

机器学习西瓜书笔记 第一章(chapter 1) 绪论 (参考机器学习西瓜书)第一节(section 1)引言第二节(section 2)基本术语第三节(section 3)假设空间第四节(section 4)归纳偏好第五节(section 5)发展历程第六节(section 6)应用现状第一章(chapter 1) 绪论 (参考机…

C++编程小游戏------斗罗大陆(1)魂力测评和武魂觉醒

#include <bits/stdc.h> #include <windows.h> using namespace std; string name,wh; int hl,wh1; int gj50,fy50,jy5000,hp60; int main() { // 共十个武魂["昊天锤","蓝电霸王龙","七杀剑","火凤凰","尖尾雨燕&qu…

麒麟系统arm架构上部署开发环境。

今天早早来到公司&#xff0c;这也是再公司搬址前在老地址待得最后一天&#xff0c;昨天把前面重要的一个任务也完成的差不多了&#xff0c;遂现在记录一下。 收到任务&#xff0c;要将公司的开发环境和生产环境配置在银河麒麟v10服务器上。这个服务器是向华为申请得到的&…

实战:Eureka的概念作用以及用法详解

概叙 什么是Eureka&#xff1f; Netflix Eureka 是一款由 Netflix 开源的基于 REST 服务的注册中心&#xff0c;用于提供服务发现功能。Spring Cloud Eureka 是 Spring Cloud Netflix 微服务套件的一部分&#xff0c;基于 Netflix Eureka 进行了二次封装&#xff0c;主要负责…

高性能分布式IO系统BL205 OPC UA耦合器

边缘计算是指在网络的边缘位置进行数据处理和分析&#xff0c;而不是将所有数据都传送到云端或中心服务器&#xff0c;这样可以减少延迟、降低带宽需求、提高响应速度并增强数据安全性。 钡铼BL205耦合器就内置边缘计算功能&#xff0c;它不依赖上位机和云平台&#xff0c;就能…

数据仓库实践:使用 SQL 计算材料BOM成本单价

背景 在制造业财务数据分析建设过程中&#xff0c;有时需要通过BOM汇总计算材料的单价&#xff0c;一般会有采购核价&#xff0c;库存成本&#xff0c;还有下阶材料单价按用量汇总得到的单价参与。 这些单价来源一般会根据优先级获取并在计算后作为最终的BOM 单价结果。参与财…

iOS ------ 编译链接

编译流程分析 编译可以分为四步&#xff1a; 预处理&#xff08;Prepressing)编译&#xff08;Compilation&#xff09;汇编 &#xff08;Assembly)链接&#xff08;Linking&#xff09; 预编译&#xff08;Prepressing&#xff09; 过程是源文件main.c和相关头文件被&#…

window11 部署llama.cpp并运行Qwen2-0.5B-Instruct-GGUF

吾名爱妃&#xff0c;性好静亦好动。好编程&#xff0c;常沉浸于代码之世界&#xff0c;思维纵横&#xff0c;力求逻辑之严密&#xff0c;算法之精妙。亦爱篮球&#xff0c;驰骋球场&#xff0c;尽享挥洒汗水之乐。且喜跑步&#xff0c;尤钟马拉松&#xff0c;长途奔袭&#xf…

FastAPI 学习之路(五十九)封装统一的json返回处理工具

在本篇文章之前的接口&#xff0c;我们每个接口异常返回的数据格式都不一样&#xff0c;处理起来也没有那么方便&#xff0c;因此我们可以封装一个统一的json。 from fastapi import status from fastapi.responses import JSONResponse, Response from typing import Unionde…

= null 和 is null;SQL中关于NULL处理的4个陷阱;三值逻辑

一、概述 1、NULL参与的所有的比较和算术运算符(>,,<,<>,<,>,,-,*,/) 结果为unknown&#xff1b; 2、unknown的逻辑运算(AND、OR、NOT&#xff09;遵循三值运算的真值表&#xff1b; 3、如果运算结果直接返回用户&#xff0c;使用NULL来标识unknown 4、如…

Go语言并发编程-Channel通信_2

Channel通信 Channel概述 不要通过共享内存的方式进行通信&#xff0c;而是应该通过通信的方式共享内存 这是Go语言最核心的设计模式之一。 在很多主流的编程语言中&#xff0c;多个线程传递数据的方式一般都是共享内存&#xff0c;而Go语言中多Goroutine通信的主要方案是Cha…

JavaEE:Lombok工具包的使用以及EditStarter插件的安装

Lombok是一个Java工具库&#xff0c;通过添加注解的方式&#xff0c;简化Java的开发。 目录 1、引入依赖 2、使用 3、原理解释 4、更多使用 5、更快捷的引入依赖 1、引入依赖 <dependency><groupId>org.projectlombok</groupId><artifactId>lomb…

pdf提取其中一页怎么操作?提取PDF其中一页的方法

pdf提取其中一页怎么操作&#xff1f;需要从一个PDF文件中提取特定页码的操作通常是在处理文档时常见的需求。这种操作允许用户选择性地获取所需的信息&#xff0c;而不必操作整个文档。通过选择性提取页面&#xff0c;你可以更高效地管理和利用PDF文件的内容&#xff0c;无论是…

Linux编辑器——vim的使用

目录 vim的基本概念 命令模式 底行模式 插入模式 注释和取消注释 普通用户进行sudo提权 vim配置问题 vim的基本概念 一般使用的vim有三种模式&#xff1a; 命令模式 底行模式和插入模式&#xff0c;可以进行转换&#xff1b; vim filename 打开vim&#xff0c;进入的…

ffmpeg ffplay.c 源码分析

1 ffplay.c的意义 ffplay.c是FFmpeg源码⾃带的播放器&#xff0c;调⽤FFmpeg和SDL API实现⼀个⾮常有⽤的播放器。 例如哔哩哔哩著名开源项⽬ijkplayer也是基于ffplay.c进⾏⼆次开发。 ffplay实现了播放器的主体功能&#xff0c;掌握其原理对于我们独⽴开发播放器⾮常有帮助…