深度学习模型数值稳定性——梯度衰减和梯度爆炸的说明

文章目录

      • 0. 前言
      • 1. 为什么会出现梯度衰减和梯度爆炸?
      • 2. 如何提高数值稳定性?
        • 2.1 随机初始化模型参数
        • 2.2 梯度裁剪(Gradient Clipping)
        • 2.3 正则化
        • 2.4 Batch Normalization
        • 2.5 LSTM?Short Cut!

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文的主旨是说明深度学习网络模型中关于数值稳定性的常见问题:梯度衰减(vanishing)和爆炸(explosion),以及常见的解决方法。

本文的部分内容、观点及配图借鉴了多伦多大学计算机科学学院讲座——Lecture 15: Exploding and Vanishing Gradients内容,以及Dive into deep learning第3.15章节《数值稳定性和模型初始化》。

1. 为什么会出现梯度衰减和梯度爆炸?

用下面简化的全连接神经元网络讲解,这个全连接神经元网络每层只有一个神经元,可以看作是一串神经元连接而成的网络。

在这里插入图片描述

在前向传播中,由于数值的传递需要经过非线性的激活函数 σ ( ) \sigma() σ()(例如Sigmoid、Tanh函数),其数值大小被限制住了,因此前向传播一般不存在数值稳定性的问题

在反向传播中,例如求解输出 y y y对权重 w 1 w_1 w1的偏导为:
∂ y ∂ w 1 = σ ′ ( z n ) w n ⋅ σ ′ ( z n − 1 ) w n − 1 ⋅ ⋅ ⋅ σ ′ ( z 1 ) x \frac{\partial y}{\partial w_1}=\sigma'(z_n)w_n · \sigma'(z_{n-1})w_{n-1} ··· \sigma'(z_{1})x w1y=σ(zn)wnσ(zn1)wn1⋅⋅⋅σ(z1)x
z n = { w n ⋅ h n − 1 + b n , n > 1 w 1 ⋅ x + b 1 , n = 1 z_n= \left \{\begin{array}{cc} w_n·h_{n-1}+b_n, & n>1\\ w_1·x+b_1, & n=1 \end{array} \right. zn={wnhn1+bn,w1x+b1,n>1n=1
这里就可以看出,如果权重 w n w_n wn的初始选择不合理,或者 w n w_n wn在逐渐优化过程中,出现导致 σ ′ ( z n ) w n \sigma'(z_n)w_n σ(zn)wn大部分或全部大于1或者小于1的情况,且网络足够深,就会导致反向传播的偏导出现数值不稳定——梯度衰减或者梯度爆炸。

再简化点理解,假设 σ ′ ( z n ) w n = 0.8 \sigma'(z_n)w_n=0.8 σ(zn)wn=0.8,有50层网络深度, 0. 8 50 = 0.000014 0.8^{50}=0.000014 0.850=0.000014;假设 σ ′ ( z n ) w n = 1.2 \sigma'(z_n)w_n=1.2 σ(zn)wn=1.2,有50层网络深度, 1. 2 50 = 9100 1.2^{50}=9100 1.250=9100

参考Lecture 15: Exploding and Vanishing Gradients的另一种解释数值稳定性的方法是:深度学习网络类似于非线性方程的迭代使用,例如 f ( x ) = 3.5 x ( 1 − x ) f(x)=3.5x(1-x) f(x)=3.5x(1x)经过多次迭代 y = f ( f ( ⋅ ⋅ ⋅ f ( x ) ) ) y=f(f(···f(x))) y=f(f(⋅⋅⋅f(x)))后的情况如下图:
在这里插入图片描述
可见,非线性函数再经历多次迭代后会呈现复杂且混沌的表现,在这个实例中仅经历6次迭代后就出现了偏导很大的情况(对应梯度爆炸)。

我们也应该注意到经历6次迭代后也出现了 ∂ y ∂ x ≈ 0 \frac{\partial y}{\partial x}≈0 xy0的区域(对应梯度衰减)。

2. 如何提高数值稳定性?

2.1 随机初始化模型参数

这是最简单、最常用的对抗梯度衰减和梯度爆炸的方法。上文已经说明: σ ′ ( z n ) w n \sigma'(z_n)w_n σ(zn)wn大部分或全部大于1或者小于1的情况,且网络足够深,就容易发生数值不稳定的情况。如果随机初始化模型参数,就会很大程度上避免因为 w n w_n wn的初始选择不合理导致的梯度衰减或爆炸。

Xavier随机初始化是一种常用的方法:假设某隐藏层输入个数为 a a a,输出个数为 b b b,Xavier随机初始化会将该层中的权重参数随机采样于 ( − 6 a + b , 6 a + b ) (-\sqrt{\frac{6}{a+b}},\sqrt{\frac{6}{a+b}}) (a+b6 ,a+b6 )

2.2 梯度裁剪(Gradient Clipping)

这是一种人为限制梯度过大或过小的方法,其思路是给原本的梯度 g g g加上一个系数,在 g g g的绝对值过大时对其进行缩小,反之亦然。这个系数为:
η ∣ ∣ g ∣ ∣ \frac{\eta}{||g||} ∣∣g∣∣η

其中 η \eta η为超参数, ∣ ∣ g ∣ ∣ ||g|| ∣∣g∣∣为梯度的二范数。

增加这个系数后虽然会导致这个结果并非是真正的损失函数对于权重的偏导数,但是能够维持数值稳定性。

2.3 正则化

这是一种抑制梯度爆炸的方法。我之前介绍过正则化方法:基于PyTorch实战权重衰减——L2范数正则化方法(附代码),其思想是在损失函数中增加权重的范数作为惩罚项:
l o s s = 1 n Σ ( y − y ^ ) 2 + λ 2 n ∣ ∣ w ∣ ∣ 2 loss = \dfrac{1}{n} \Sigma (y - \widehat{y})^2+ \dfrac{\lambda}{2n}||w||^2 loss=n1Σ(yy )2+2nλ∣∣w2
在深度学习模型不断地迭代(学习)过程中, l o s s loss loss越来越小导致权重的范数也越来越小,也就抑制了梯度爆炸。

2.4 Batch Normalization

Batch Normalization(批标准化)是基于Normalization(归一化)增加scaling和shifting的一种数据标准化处理方式,其具体作用原理可以参考:关于Batch Normalization的说明。

Batch Normalization能维持数值稳定性的基本原理与梯度裁剪类似:都是对数值人为增加缩放,维持数值保持在一个不大不小的合理范围内。两者的区别是梯度裁剪在反向传播过程中直接作用于损失函数对权重的偏导数;而Batch Normalization在正向传播中对某层的输出进行标准化处理,间接维持对权重偏导的稳定性。

这里需要指出的是:由于输入 x x x也参与了偏导的计算,如果 x x x是一个高维向量,那对于输入 x x x的Batch Normalization处理也是必要的。

2.5 LSTM?Short Cut!

很多文章说明LSTM(长短周期记忆)网络有助于维持数值稳定性,我最初看到这些文章时大为不解——因为我们是需要通用的方法来改进提高现有模型的数值稳定性,而不是直接替换成LSTM网络模型,况且LSTM也不是万能的深度学习模型,不可能遇到梯度衰减或者梯度爆炸就把模型替换成LSTM。

如果不知道LSTM是什么可以看下:LSTM(长短期记忆)网络的算法介绍及数学推导

后来我看到Lecture 15: Exploding and Vanishing Gradients明白了其中的误解:这篇文章通篇都在用RNN为例来说明数值稳定性。对于RNN来说,LSTM确实是一个改进的模型,因为其内部维持“长期记忆”的“门”结构确实有助于提升数值稳定性。

我想大部分把LSTM单列出来说明可以提升数值稳定性的文章都误会了。

而Short Cut这种结构才是提升数值稳定性的普适规则,LSTM仅是改善RNN的一个特例而已。
在这里插入图片描述

Short Cut的具体作用机理可以参考He Kaiming的原文:Deep Residual Learning for Image Recognition

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/55870.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode-中等题】2. 两数相加

文章目录 题目方法一:借助一个进制位,以及更新尾结点方法一改进:相比较第一种,给head一个临时头节点(开始节点),最后返回的时候返回head.next,这样可以省去第一次的判断 题目 方法一…

JVM——类加载与字节码技术—类文件结构

由源文件被编译成字节码文件,然后经过类加载器进行类加载,了解类加载的各个阶段,了解有哪些类加载器,加载到虚拟机中执行字节码指令,执行时使用解释器进行解释执行,解释时对热点代码进行运行期的编译处理。…

idea的debug断点的使用

添加断点(目前不知道如何添加断点,就给AutoConfigurationImportSelector的每个方法都加上断点): 然后将StockApplication启动类以debug方式运行,然后程序就会停在119行 点击上边的step over让程序往下运行一行&#x…

《入门级-Cocos2dx4.0 塔防游戏开发》---第七课:游戏界面开发(自定义Layer)

目录 一、开发环境 二、开发内容 2.1 添加资源文件 2.2 游戏MenuLayer开发 2.3 GameLayer开发 三、演示效果 四、知识点 4.1 sprite、layer、scene区别 4.2 setAnchorPoint 一、开发环境 操作系统:UOS1060专业版本。 cocos2dx:版本4.0 环境搭建教程&…

2.3.Dubbo的基本应用- 异步调用 、泛化调用 、动态配置

异步调用 官网地址: http://dubbo.apache.org/zh/docs/v2.7/user/examples/async-call/ 理解起来比较容易, 主要要理解CompletableFuture, 如果不理解, 就直接把它理解为Future 其他异步调用方式:Dubbo 同步调用太慢&a…

web、HTTP协议

目录 一、Web基础 1.1 HTML概述 1.1.1 HTML的文件结构 1.2 HTML中的部分基本标签 二.HTTP协议 2.1.http概念 2.2.HTTP协议版本 2.3.http请求方法 2.4.HTTP请求访问的完整过程 2.5.http状态码 2.6.http请求报文和响应报文 2.7.HTTP连接优化 三.httpd介绍 3.1.http…

RK3399平台开发系列讲解(存储篇)Linux 存储系统的 I/O 栈

平台内核版本安卓版本RK3399Linux4.4Android7.1🚀返回专栏总目录 文章目录 一、Linux 存储系统全景二、Linux 存储系统的缓存沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将介绍 Linux 存储系统的 I/O 原理。 一、Linux 存储系统全景 我们可以把 Linux 存储系…

JUC并发编程(一)

JUC并发编程 1. 查看进程和线程的方法1.1 Windows1.2 Linux 1. 查看进程和线程的方法 1.1 Windows 任务管理器可以查看进程和线程数&#xff0c;也可以用来杀死进程tasklist 查看进程taskkill 杀死进程 1.2 Linux ps -fe 查看所有进程ps -fT -p <PID> 查看某个进程&a…

10*1000【2】

知识: -----------金融科技背后的技术---------------- -------------三个数字化趋势 1.数据爆炸&#xff1a;internet of everything&#xff08;iot&#xff09;&#xff1b;实时贡献数据&#xff1b;公有云服务->提供了灵活的计算和存储。 2.由计算能力驱动的&#x…

【跟小嘉学 Rust 编程】十三、函数式语言特性:迭代器和闭包

系列文章目录 【跟小嘉学 Rust 编程】一、Rust 编程基础 【跟小嘉学 Rust 编程】二、Rust 包管理工具使用 【跟小嘉学 Rust 编程】三、Rust 的基本程序概念 【跟小嘉学 Rust 编程】四、理解 Rust 的所有权概念 【跟小嘉学 Rust 编程】五、使用结构体关联结构化数据 【跟小嘉学…

Leetcode 191.位1的个数

编写一个函数&#xff0c;输入是一个无符号整数&#xff08;以二进制串的形式&#xff09;&#xff0c;返回其二进制表达式中数字位数为 1 的个数&#xff08;也被称为汉明重量&#xff09;。 提示&#xff1a; 请注意&#xff0c;在某些语言&#xff08;如 Java&#xff09;中…

2023/08/27

一、图片引入 项目中往往不使用相对路径引入文件&#xff0c;一般都使用实现绝对路径引入文件。 方式一&#xff1a;【适用vue2&#xff0c;不适用vue3】 <img :src"require(/assets/images/home/bottom_can.png)" alt"">方式二&#xff1a;【适用…

mac m1 docker 安装kafka和zookeeper

获取本地ip地址 ifconfig en0 192.168.0.105. 下面的ip都会使用到 1、拉取镜像 docker pull wurstmeister/zookeeper docker pull wurstmeister/kafka 2、启动容器 启动 zookeeper docker run -d --name zookeeper -p 2181:2181 映射 3、 启动 kafka 注意&#xff…

计网-All

路由器的功能与路由表的查看_路由器路由表_傻傻小猪哈哈的博客-CSDN博客路由基础-直连路由、静态路由与动态路由的概念_MikeVane-bb的博客-CSDN博客路由器的功能与路由表的查看_路由器路由表_傻傻小猪哈哈的博客-CSDN博客 直连路由就是路由器直接连了一个网段&#xff0c;他就…

一个短视频去水印小程序,附源码

闲来无事&#xff0c;开发了一个短视频去水印小程序&#xff0c;目前支持抖音、快手&#xff0c;后续再加上别的平台。 因为平台原因&#xff0c;就不放二维码了&#xff0c;你可以直接微信搜索【万能老助手】这里贴一张效果图。 页面非常简单&#xff0c;这里就不过多介绍了&…

Git企业开发控制理论和实操-从入门到深入(五)|标签管理

前言 那么这里博主先安利一些干货满满的专栏了&#xff01; 首先是博主的高质量博客的汇总&#xff0c;这个专栏里面的博客&#xff0c;都是博主最最用心写的一部分&#xff0c;干货满满&#xff0c;希望对大家有帮助。 高质量博客汇总 然后就是博主最近最花时间的一个专栏…

正则表达式总结

作为软件工程师&#xff0c;工作中经常都需要使用正则表达式进行搜索&#xff0c;替换&#xff0c;验证数据&#xff08;手机号、邮箱、账号&#xff09;等。 但没有系统的学习总结过。现在就来学习总结一下。 认识元字符 元字符就是一些特殊符号&#xff0c;代表一些特殊意思…

初始Netty

文章目录 目录 文章目录 前言 一、netty 总结 前言 认识netty 一、netty Netty是一个基于Java的高性能网络应用框架&#xff0c;用于快速开发可扩展的网络服务器和客户端。它提供了易于使用的抽象API&#xff0c;使开发人员能够轻松地构建各种网络应用程序&#xff0c;包括…

JavaScript基础语法

一、JavaScript编写方式 位置一&#xff1a;HTML代码行内&#xff08;不推荐&#xff09; <!-- 方式一&#xff1a;行内编写 --> <a href"javascript:alert(hello world)">hello world</a> <!-- 方式二&#xff1a;行内编写&#xff0c;通过监…

学信息系统项目管理师第4版系列03_文件与标准

审核未通过&#xff0c;删除文件部分&#xff0c;仅保留标准化相关内容&#xff0c;重发 12. 标准化 12.1. 采用国际标准和国外先进标准的程度分为等同采用、修改采用和等效采用 3 种 12.1.1. 【高21上选20】 12.1.2. 采用指与国际标准在技术内容和文本结构上相同,或者与国…