25.梯度消失和梯度爆炸

深度学习中的梯度消失与梯度爆炸:定义、原因、解决办法与残差网络

一、引言

在深度学习的训练过程中,梯度消失(Gradient Vanishing)和梯度爆炸(Gradient Exploding)是两个常见且棘手的问题。它们严重阻碍了深层神经网络的训练效率和效果。本文将深入探讨这两个问题的定义、原因、解决办法,并介绍残差网络(ResNet)如何解决这些问题。

二、梯度消失与梯度爆炸的定义

梯度消失

梯度消失指的是在训练深层神经网络时,由于链式法则的连乘效应,当网络层数过深时,梯度在反向传播过程中会逐渐减小到接近于0,导致深层网络的权重无法得到有效的更新。

梯度爆炸

梯度爆炸则相反,指的是在训练深层神经网络时,梯度在反向传播过程中逐渐增大,甚至以指数级速度增长,导致权重更新过大,破坏网络的稳定性。

三、梯度消失与梯度爆炸的原因

链式法则

在反向传播过程中,梯度是通过链式法则逐层传递的。如果网络层数过深,且激活函数的梯度小于1(如Sigmoid函数),那么在多层连续相乘后,梯度会逐渐减小到接近于0,导致梯度消失;而如果梯度大于1,则会导致梯度爆炸。

初始化权重

网络权重的初始化方式也会影响梯度的传播。如果初始权重过大,可能导致梯度在反向传播过程中迅速增大,引起梯度爆炸;如果初始权重过小,则可能导致梯度在传播过程中逐渐减小,引起梯度消失。

四、梯度消失与梯度爆炸的解决办法

1.预训练与微调(Pre-training and Fine-tuning):早期的一种方法,先在一个大型数据集上进行预训练,然后在特定任务上进行微调。这种方法可以减轻梯度消失和爆炸的问题,但现在已经较少使用。

2.梯度裁剪

梯度裁剪是一种直接控制梯度大小的方法。在反向传播过程中,如果梯度的范数超过某个阈值,就将其截断为阈值大小。这样可以有效防止梯度爆炸。

3.使用ReLU激活函数

ReLU(Rectified Linear Unit)激活函数在输入大于0时梯度为1,不会出现梯度消失的问题;而在输入小于0时梯度为0,有助于稀疏化网络。因此,使用ReLU激活函数可以有效缓解梯度消失和梯度爆炸的问题。

4.改进版的ReLU激活函数:为了解决ReLU的缺点,研究者提出了多种改进版的ReLU函数,如Leaky ReLU、Parametric ReLU(PReLU)、Exponential Linear Unit(ELU)等。

5.Batch Normalization

Batch Normalization是一种有效的正则化方法,它通过规范化每一层的输入来加速网络训练。在训练过程中,Batch Normalization会对每一层的输入进行标准化处理,使其具有均值为0、方差为1的分布。这样可以减小梯度对初始权重的依赖,从而缓解梯度消失和梯度爆炸的问题。

6.残差网络(ResNet)

残差网络通过引入残差连接(shortcut connections)来解决梯度消失和梯度爆炸的问题。残差连接允许梯度在反向传播时绕过某些层直接传播到较浅的层,从而有效避免了梯度消失的问题。同时,由于残差连接的存在,网络在训练时可以更容易地学习到恒等映射(identity mapping),这有助于保持网络的稳定性并防止梯度爆炸。

 

五、残差网络(ResNet)的实现

基于残差网络(ResNet)的实现,我们可以进一步探讨其结构、特点以及在实际应用中的优势。以下是对ResNet实现的详细解析:

1. 残差块(Residual Block)

残差块是ResNet的核心组件,它解决了随着网络深度增加出现的性能下降(也称为退化问题)的问题。残差块的设计基于恒等映射(identity mapping)的思想,允许网络在必要时跳过一些层,从而更直接地传播梯度。

残差块的基本结构如下:

  • 包含两个或多个卷积层(以及可能的批量归一化层和激活函数层)。
  • 引入了一个跨层的连接(即shortcut或skip connection),将输入直接连接到输出。

这样的结构可以表示为:

H(x)=F(x)+x

其中,x 是输入,F(x) 是残差函数(即卷积层等结构所学习的映射),H(x) 是最终的输出。

2. 残差网络的构建

ResNet由多个残差块堆叠而成,形成一个深层的神经网络结构。根据具体的任务和网络规模,可以设计不同深度和宽度的ResNet。

在构建ResNet时,需要考虑以下几点:

  • 深度:通常,增加网络深度可以提高性能,但也会增加计算量和过拟合的风险。因此,需要根据任务和数据集的大小选择合适的深度。
  • 宽度:每个残差块的宽度(即卷积层的通道数)也会影响网络的性能。较宽的残差块可以提取更多的特征,但也会增加计算量。
  • 残差块的类型:根据残差块中卷积层的数量和连接方式,可以设计不同类型的残差块,如基本的残差块(包含两个卷积层)和瓶颈残差块(包含三个卷积层,其中第一个和最后一个卷积层的通道数较少,以减少计算量)。

3. 实现细节

在实现ResNet时,需要注意以下细节:

  • 初始化:使用合适的权重初始化方法,如He初始化,可以加速训练并提高模型的性能。
  • 批量归一化:在每个卷积层后添加批量归一化层,可以加速训练并缓解过拟合问题。
  • 激活函数:使用ReLU或类似的激活函数,以增加模型的非线性表达能力。
  • 下采样:在需要减小特征图尺寸时,可以使用步长为2的卷积层或池化层进行下采样。同时,为了确保残差连接能够匹配输入和输出的尺寸,可以在shortcut连接中添加一个额外的卷积层或池化层进行下采样。

4. 应用与优势

ResNet在多个领域都取得了显著的性能提升,特别是在图像分类、目标检测等任务中。其优势主要体现在以下几个方面:

  • 解决了深度神经网络中的退化问题,使得训练更深层的网络成为可能。
  • 通过引入残差连接,缓解了梯度消失和梯度爆炸的问题,提高了模型的训练效率和稳定性。
  • 具有较强的特征提取能力,可以学习到更丰富的层次化特征表示。
  • 具有良好的泛化能力,可以在不同的数据集和任务上取得较好的性能。

总之,ResNet通过引入残差连接的思想,成功解决了深度神经网络中的退化问题,并在多个领域取得了显著的性能提升。其实现细节和应用优势也为我们设计更优秀的深度学习模型提供了有益的参考。

import torch  
import torch.nn as nn  class BasicBlock(nn.Module):  expansion = 1  def __init__(self, in_channels, out_channels, stride=1, downsample=None):  super(BasicBlock, self).__init__()  # 第一个卷积层,不改变通道数  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)  self.bn1 = nn.BatchNorm2d(out_channels)  self.relu = nn.ReLU(inplace=True)  # 第二个卷积层,不改变通道数和步长  self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)  self.bn2 = nn.BatchNorm2d(out_channels)  # 如果需要下采样,则使用1x1卷积改变通道数并降低空间分辨率  self.downsample = downsample  def forward(self, x):  residual = x  # 经过两个卷积层  out = self.conv1(x)  out = self.bn1(out)  out = self.relu(out)  out = self.conv2(out)  out = self.bn2(out)  # 如果需要进行下采样,则对输入x进行同样的操作  if self.downsample is not None:  residual = self.downsample(x)  # 将残差连接添加到输出上  out += residual  out = self.relu(out)  return out  class ResNet(nn.Module):  def __init__(self, block, layers, num_classes=10):  super(ResNet, self).__init__()  # 输入为3通道的图像,大小为224x224  self.in_channels = 64  # 初始的卷积层  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  self.bn1 = nn.BatchNorm2d(64)  self.relu = nn.ReLU(inplace=True)  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 构建残差块  self.layer1 = self._make_layer(block, 64, layers[0])  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 全连接层进行分类  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  self.fc = nn.Linear(512 * block.expansion, num_classes)  def _make_layer(self, block, out_channels, blocks, stride=1):  downsample = None  if stride != 1 or self.in_channels != out_channels * block.expansion:  downsample = nn.Sequential(  nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),  nn.BatchNorm2d(out_channels * block.expansion)  )  layers = []  layers.append(block(self.in_channels, out_channels, stride, downsample))  self.in_channels = out_channels * block.expansion  for _ in range(1, blocks):  layers.append(block(self.in_channels, out_channels))  return nn.Sequential(*layers)  def forward(self, x):  out = self.conv1(x)  out = self.bn1(out)  out = self.relu(out)  out = self.maxpool(out)  # 传递输入到各个残差层  out = self.layer1(out)  out = self.layer2(out)  out = self.layer3(out)  out = self.layer4(out)  # 对输出进行全局平均池化,展平  out = self.avgpool(out)  out = torch.flatten(out, 1)  # 全连接层进行分类  out = self.fc(out)  return out  # 示例:定义一个ResNet18  
def resnet18(num_classes=1000):  return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)  # 实例化ResNet18模型  
model = resnet18(num_classes=10)  # 假设有10个类别  # 打印模型结构  
print(model)  # 如果你有数据的话,可以继续编写代码进行训练  
# 例如,加载数据集、定义损失函数、优化器、训练循环等  # 示例:定义损失函数和优化器(这里只是示例,你需要根据实际情况设置)  
# criterion = nn.CrossEntropyLoss()  
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 注意:上面的代码只是一个示例,并没有实际的数据加载和训练过程。  
# 在实际使用中,你需要添加数据加载、训练循环、验证等步骤来完整实现ResNet的训练。

以上代码定义了一个简单的ResNet模型,并给出了一个实例化ResNet18的示例。ResNet18包含4个残差层,每个层包含2个BasicBlock。你可以根据实际需求调整层数和每层的Block数量来构建不同深度的ResNet模型。同时,你还需要定义损失函数和优化器,并编写数据加载和训练循环的代码来完成模型的训练过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28360.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python时间序列分析库

Sktime Welcome to sktime — sktime documentation 用于ML/AI和时间序列的统一API,用于模型构建、拟合、应用和验证支持各种学习任务,包括预测、时间序列分类、回归、聚类。复合模型构建,包括具有转换、集成、调整和精简功能的管道scikit学习式界面约定的交互式用户体验Pro…

一个比官方strings.Title更精简高效的将字符串中所有单词首字母转换为大小写的go函数

在go语言的官方包 strings中,官方提供了Title函数用于将字符串中的单词首字母转换为大写,这个函数很绕,对于要转换的字符串先是一个Map循环,然后接着又是一个Map循环,且函数调函数掉了好多层,而且最新版本中已经标记为…

【原创】springboot+mysql小区用水监控管理系统设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

介绍一个 SpringBoot 集成各种场景的项目

springboot-demo 今天给大家介绍一个 SpringBoot 集成各种场景的项目,可以用来学习,也可以开箱即用,无需重复造轮子!包含中英文使用说明文档 a simple springboot demo with some components for example: redis,solr,rockmq an…

洛谷 P3379:最近公共祖先(LCA)← 倍增+链式前向星

【题目来源】https://www.luogu.com.cn/problem/P3379【题目描述】 如题,给定一棵有根多叉树,请求出指定两个点直接最近的公共祖先。【输入格式】 第一行包含三个正整数 N,M,S,分别表示树的结点个数、询问的个数和树根结点的序号。 接下来 N−…

AI 定位!GeoSpyAI上传一张图片分析具体位置 不可思议! ! !

🏡作者主页:点击! 🤖常见AI大模型部署:点击! 🤖Ollama部署LLM专栏:点击! ⏰️创作时间:2024年6月16日12点23分 🀄️文章质量:94分…

动态规划日常刷题

力扣70.爬楼梯 class Solution {public int climbStairs(int n) {return dfs(n);}//递归 //注意每次你可以爬 1 或 2 个台阶//如果最后一步是1 就先爬到n-1 把它缩小成0-n-1的范围//如果最后一步是2 就先爬到n-2 把它缩小成0-n-2的范围 private int dfs(int i){if(i < 1){r…

代码随想录算法训练营第五十九天|115.不同的子序列、 583. 两个字符串的删除操作、72. 编辑距离、编辑距离总结篇

代码随想录算法训练营第五十九天 115.不同的子序列 题目链接&#xff1a;115.不同的子序列 确定dp数组以及下标的含义&#xff1a;dp[i][j] &#xff1a;以下标i - 1为结尾的s&#xff0c;和以下标j - 1为结尾的t&#xff0c;s中t的个数dp[i][j]确定递推公式&#xff1a; s[…

vue3 如何给表单添加表单效验+正则表达式

校验要求 我们的表单中有密码、电话号码 &#xff0c;两项。 我们设置用密码为3到20位的非空字符 电话号码就用目前用的电话号码正则表达式&#xff0c;要求手机号码以 1 开头&#xff0c;第二位为 3 到 9 之间的数字&#xff0c;后面跟着任意 9 个数字&#xff0c;总共是 11…

Intel HDSLB 高性能四层负载均衡器 — 代码剖析和高级特性

目录 文章目录 目录前言代码剖析软件架构目录结构配置解析启动流程分析数据面 jobs 注册数据面 jobs 执行 转发流程分析收包阶段L2 处理阶段L3 处理阶段L4 处理阶段 高级特性大象流转发优化快慢路径分离转发优化报文基础转发优化 最后参考文档 前言 在前 2 篇文章中&#xff0…

超详细的描述UItralytics中的特征增强方法

目录 yolov8导航 YOLOv8(附带各种任务详细说明链接) 各个特征增强参数概述 各个参数优缺点与调整技巧 1. hsv_h (色调调整) 2. hsv_s (饱和度调整) 3. hsv_v (亮度调整) 4. degrees (图像旋转) 5. translate (图像平移) 6. scale (图像缩放) 7. shear (图像剪切) …

Qt第三方库QHotKey设置小键盘数字快捷键

一、看了一圈没有找到可以设置小键盘的情况。 这两天在研究快捷键的使用。发现qt的里的快捷键不是全局的。找了两个第三方快捷键QHotKey&#xff0c;还有一个QxtGlobalShortcut。但是这两个都不能设置小键盘的数字。 比如QKeySequenceEdit &#xff08;Ctrl1&#xff09; 这个…

【SpringBoot】SpringBoot:构建实时聊天应用

文章目录 引言项目初始化添加依赖 配置WebSocket创建WebSocket配置类创建WebSocket处理器 创建前端页面创建聊天页面 测试与部署示例&#xff1a;编写单元测试 部署扩展功能用户身份验证消息持久化群组聊天 结论 引言 随着实时通信技术的快速发展&#xff0c;聊天应用在现代We…

Luma AI如何注册:文生视频领域的新星

文章目录 Luma AI如何注册&#xff1a;文生视频领域的新星一、Luma 注册方式二、Luma 的效果三、Luma 的优势四、Luma 的功能总结 Luma AI如何注册&#xff1a;文生视频领域的新星 近年来&#xff0c;Luma AI 凭借其在文生视频领域的创新技术&#xff0c;逐渐成为行业的新星。…

MySQL基础——多表查询和事务

目录 1多表关系 2多表查询概述 3连接查询 3.1内连接 3.2左外连接 3.3右外连接 3.4自连接 4联合查询 5子查询 5.1标量子查询(子查询结果为单个值) 5.2列子查询(子查询结果为一列) 5.3行子查询(子查询结果为一行) 5.4表子查询(子查询结果为多行多列) 6事务简介和操…

vulnhub靶场-xxe打靶教程

目录 靶机导入 信息收集 发现IP 端口扫描 目录扫描 漏洞利用 靶机下载地址&#xff1a;XXE Lab: 1 ~ VulnHub 靶机导入 导入虚拟机 开启虚拟机 信息收集 发现IP arp-scan -l 发现靶机IP是192.168.202.150 端口扫描 使用nmap进行扫描 nmap -sS -A 192.168.202.150 …

EasyRecovery2024数据恢复神器#电脑必备良品

EasyRecovery数据恢复软件&#xff0c;让你的数据重见天日&#xff01; 大家好&#xff01;今天我要给大家种草一个非常实用的软件——EasyRecovery数据恢复软件&#xff01;你是不是也曾经遇到过不小心删除了重要的文件&#xff0c;或者电脑突然崩溃导致数据丢失的尴尬情况呢&…

初识PHP

一、格式 每行以分号结尾 <?phpecho hello; ?>二、echo函数和print函数 作用&#xff1a;两个函数都是输出内容到页面中&#xff0c;多用于代码调试。 <?php echo "<h1 styletext-align: center;>test</h1>"; print "<h1 stylet…

【vue3中使用$refs】

在使用uniapp官网里的uni-popup弹出层组件时&#xff0c;要将vue2转换成vue3,&#xff0c;这里遇到了一个问题&#xff1a;vue2可以通过this访问到绑定的ref&#xff0c;但是vue3没有了this,应该怎么办呢&#xff1f; 解决方法&#xff1a; !

Footer组件在home 、search 显示,在登录、注册隐藏

footer组件显示与隐藏 我们可以根据组件身上的$route获取当前路由的信息&#xff0c;通过路由路径判断Footer显示与隐藏。配置的路由的时候&#xff0c;可以给路由添加路由元信息【meta】&#xff0c;路由需要配置对象&#xff0c;它的key不能瞎写、胡写、乱写 <template&…