为什么RNN(循环神经网络)存在梯度消失和梯度爆炸?

1️⃣ 原理分析

在这里插入图片描述
RNN前向传播的公式为:

  • x t x_t xt是t时刻的输入
  • s t s_t st是t时刻的记忆, s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(U\cdot x_t+W\cdot s_{t-1}) st=f(Uxt+Wst1),f表示激活函数, s t − 1 s_{t-1} st1表示t-1时刻的记忆
  • o t o_t ot是t时刻的输出, o t = s o f t m a x ( V ⋅ s t ) o_t=softmax(V\cdot s_t) ot=softmax(Vst)

采用交叉熵作为损失函数:
L = ∑ i = 1 T − o t ˉ l o g o t L=\sum_{i=1}^{T}-\bar{o_{t}}logo_{t} L=i=1Totˉlogot
其中T代表时间步的长度, o ˉ t \bar o_{t} oˉt代表ground truth, o t o_t ot代表预测的输出。

假设有三个时间步, t = 1 , 2 , 3 t=1,2,3 t=1,2,3。假设初始记忆 s t = 0 s_t=0 st=0,则 t = 1 t=1 t=1时的记忆和输出为:
s 1 = f ( U x 1 + W s 0 ) o 1 = f [ V ⋅ f ( U x 1 + W s 0 ) ] \begin{aligned}&s_1=f(Ux_1+Ws_0)\\&o_{1}=f[V\cdot f(Ux_{1}+Ws_{0})]\end{aligned} s1=f(Ux1+Ws0)o1=f[Vf(Ux1+Ws0)]
t = 2 t=2 t=2时的记忆和输出为:
s 2 = f ( U x 2 + W s 1 ) o 2 = f [ V ⋅ f ( U x 2 + W s 1 ) ] = f [ V ⋅ f ( U x 2 + W f ( U x 1 + W s 0 ) ) ] \begin{aligned}&s_2=f(Ux_2+Ws_1)\\&o_{2}=f[V\cdot f(Ux_{2}+Ws_{1})]=f[V\cdot f(Ux_{2}+Wf(Ux_1+Ws_0))]\end{aligned} s2=f(Ux2+Ws1)o2=f[Vf(Ux2+Ws1)]=f[Vf(Ux2+Wf(Ux1+Ws0))]

这样很晕,我来画个箭头:
在这里插入图片描述
可以发现 s 2 s_2 s2 s 1 s_1 s1的函数

t = 3 t=3 t=3时的记忆和输出为:
s 3 = f ( U x 3 + W s 2 ) o 3 = f [ V ⋅ f ( U x 3 + W s 2 ) ] = f [ V ⋅ f ( U x 3 + W f ( U x 2 + W s 1 ) ) ] = f [ V ⋅ f ( U x 3 + W f ( U x 2 + W f ( U x 1 + W s 0 ) ) ) ] \begin{aligned}&s_3=f(Ux_3+Ws_2)\\&o_{3}=f[V\cdot f(Ux_{3}+Ws_{2})]=f[V\cdot f(Ux_{3}+Wf(Ux_2+Ws_1))]=f[V\cdot f(Ux_{3}+Wf(Ux_2+Wf(Ux_1+Ws_0)))] \end{aligned} s3=f(Ux3+Ws2)o3=f[Vf(Ux3+Ws2)]=f[Vf(Ux3+Wf(Ux2+Ws1))]=f[Vf(Ux3+Wf(Ux2+Wf(Ux1+Ws0)))]
画个箭头:
在这里插入图片描述
可以发现 s 3 s_3 s3 s 2 s_2 s2的函数,又 s 2 s_2 s2 s 1 s_1 s1的函数,因此 s 3 s_3 s3包含 s 2 s_2 s2 s 1 s_1 s1

然后我们来分析反向传播:BPTT(Back-Propagation Through Time,时间上的反向传播)是针对RNN的训练算法,它的核心依然是基于梯度下降的反向传播。对于RNN来说,主要参数包括U、W和V。
在这里插入图片描述
以t=3时举例子,求U,V,W的梯度:
∂ L 3 ∂ V = ∂ L 3 ∂ o 3 ∂ o 3 ∂ V 3 ◯ ∂ L 3 ∂ W = ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ W + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 2 ∂ s 2 ∂ W + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ W 4 ◯ ∂ L 3 ∂ U = ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ U + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 2 ∂ s 2 ∂ U + ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ U 5 ◯ \begin{aligned} &\frac{\partial L_3}{\partial V} =\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial V}\textcircled{3} \\ &\frac{\partial L_3}{\partial W} =\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_2}\frac{\partial s_2}{\partial W}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial W}\textcircled{4} \\ &\frac{\partial L_3}{\partial U} =\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_2}\frac{\partial s_2}{\partial U}+\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial U}\textcircled{5} \end{aligned} VL3=o3L3Vo33WL3=o3L3s3o3Ws3+o3L3s2o3Ws2+o3L3s3o3s2s3s1s2Ws14UL3=o3L3s3o3Us3+o3L3s2o3Us2+o3L3s3o3s2s3s1s2Us15

对于公式⑤可以简写成:
∂ L 3 ∂ U = ∑ k = 0 3 ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s k ∂ s k ∂ U \frac{\partial L_3}{\partial U}=\sum_{k=0}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial U} UL3=k=03o3L3s3o3sks3Usk

由于 ∂ s 3 ∂ s k \frac{\partial s_3}{\partial s_k} sks3也需要链式法则,即 ∂ s 3 ∂ s 1 = ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 \frac{\partial s_3}{\partial s_1}=\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1} s1s3=s2s3s1s2。因此公式可以进一步修改为:

∂ L 3 ∂ U = ∑ k = 1 3 ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ∂ s 3 ∂ s k ∂ s k ∂ U = ∑ k = 1 3 ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ( ∏ j = k + 1 3 ∂ s j ∂ s j − 1 ) ∂ s k ∂ U 6 ◯ \frac{\partial L_3}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial U}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}(\prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial U}\textcircled{6} UL3=k=13o3L3s3o3sks3Usk=k=13o3L3s3o3(j=k+13sj1sj)Usk6

同理,对公式④也可以写为:
∂ L 3 ∂ W = ∑ k = 1 3 ∂ L 3 ∂ o 3 ∂ o 3 ∂ s 3 ( ∏ j = k + 1 3 ∂ s j ∂ s j − 1 ) ∂ s k ∂ W 7 ◯ \frac{\partial L_3}{\partial W}=\sum_{k=1}^3\frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial s_3}(\prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial W}\textcircled{7} WL3=k=13o3L3s3o3(j=k+13sj1sj)Wsk7

观察③式,对与V的偏导不存在依赖关系。

观察④和⑤式,对W和U求偏导的时候,存在长期依赖关系。原因是前向传播的时候 s t s_t st会随着时间向前传播,而 s t s_t st是W、U的函数。

假设激活函数为tanh,将⑥⑦中累乘部分取出来:
∏ j = k + 1 3 ∂ s j ∂ s j − 1 = ∏ j = k + 1 3 t a n h ′ W \prod_{j=k+1}^3\frac{\partial s_j}{\partial s_{j-1}}=\prod_{j=k+1}^3tanh^{'}W j=k+13sj1sj=j=k+13tanhW
例如: s 3 = f ( U x 3 + W s 2 ) s_3=f(Ux_3+Ws_2) s3=f(Ux3+Ws2) ∂ s 3 ∂ s 2 = t a n h ′ ( U ) W \frac{\partial s3}{\partial s_{2}}=tanh'(U) W s2s3=tanh(U)W
在这里插入图片描述

由上图可知,tanh的梯度最大为1,通常情况下会小于1,因此当t很大的时候,例如t=100时,⑥⑦中的累乘部分 ∏ j = k + 1 100 t a n h ′ W \prod_{j=k+1}^{100}tanh^{^{\prime}}W j=k+1100tanhW将趋于0,因此t=100时对于W和U的梯度将趋于0,导致梯度消失。

分析完tanh,再来分析一下W,如果W中的值太大,那么产生问题就是梯度爆炸


2️⃣ 总结

  • RNN存在梯度消失的原因是:隐藏层的输出 s t s_t st会向前传播,这样导致在反向传播求梯度时存在一个累乘项,这个累乘项由激活函数的梯度参数W组成,如果我们采用tanh作为激活函数,其梯度小于1,时间步越多,累乘项越趋近于0,导致梯度消失。
  • RNN存在梯度爆炸的原因:参数W如果过大,则会导致累乘项逐渐变大,导致梯度爆炸

3️⃣ 参考

RNN梯度消失与梯度爆炸的原因 - Hideonbush的文章 - 知乎


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/60583.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NUXT3学习日记二(样式配置、引入组件库、区分在服务端还是在客户端渲染)

上一章已经给大家分享官网下载的nuxt3了,下面正式进入我所要说的内容吧 一、初始化样式 想必大家从我的git下载下来的nuxt3,能看到nuxt.config.ts这个文件了吧。 这里我们有两种css配置方式 1、css:[~/assets/base.scss] 这种方式不能在scss文件中定义…

2024AAAI | DiffRAW: 利用扩散模型从手机RAW图生成单反相机质量的RGB图像

文章标题:《DiffRAW: Leveraging Diffusion Model to Generate DSLR-Comparable Perceptual Quality sRGB from Smartphone RAW Images》 原文链接:DiffRAW 本文是清华大学深圳研究院联合华为发表在AAAI-2024上的论文(小声bb:华…

计算机视觉 ---图像模糊

1、图像模糊的作用: 减少噪声: 在图像获取过程中,例如通过相机拍摄或者传感器采集,可能会受到各种因素的干扰,从而引入噪声。这些噪声在图像上表现为一些孤立的、不符合图像主体内容的像素变化,如椒盐噪声&…

串口DMA接收不定长数据

STM32F767—>串口通信接收不定长数据的处理方法_stm32串口超时中断-CSDN博客 STM32-HAL库串口DMA空闲中断的正确使用方式解析SBUS信号_stm32 hal usart2 dma-CSDN博客 #define USART1_RxBuffSize 100 extern DMA_HandleTypeDef hdma_usart1_rx; //此处声明的变量在…

web实验3:虚拟主机基于不同端口、目录、IP、域名访问不同页面

创建配置文件: 创建那几个目录及文件,并且写内容: 为网卡ens160添加一个 IPv4 地址192.168.234.199/24: 再重新激活一下网卡ens160: 重启服务: 关闭防火墙、改宽松模式: 查看nginx端口监听情况:…

Jmeter中的监听器(二)

5--JSR223 Listener 用途 自定义数据处理:使用脚本语言处理测试结果,实现高度定制化的数据处理和分析。实时监控:实时处理和显示测试结果。集成外部系统:将测试结果发送到外部系统进行进一步处理和分析。日志记录:记…

计算机组成原理——进位计数制

1.认识不同进制 通常的我们日常生活中用到的都是十进制,比如买东西或者期末成绩等等,当然肯定不止这一种进制方法,相关的还有二进制、八进制、十六进制,还有古罗马数字,通常古罗马数字近似可以看作是五进制的数&#x…

function and task

任务和函数 在Verilog语言中提供了任务和函数,可以将较大的行为级设计划分为较小的代码段,允许设计者将需要在多个地方重复使用的相同代码提取出来,编写成任务和函数,这样可以使代码更加简洁和易懂。 1.1任务 任务的定义 任务定义…

【C++】用红黑树封装set和map

在C标准库中,set容器和map容器的底层都是红黑树,它们的各种接口都是基于红黑树来实现的,我们在这篇文章中已经模拟实现了红黑树 ->【C】红黑树,接下来我们在此红黑树的基础上来看看如何封装set和map。 一、共用一颗红黑树 我…

类与实例

1 问题如何理解类与实例? 2 方法 类与实例 类(class)的概述:用来描述具有相同的属性和方法的对象的集合。它定义了该集合中每个对象所共有的属性和方法。对象是类的实例。 类是一类事物,实例是具体的一个事物。 编程与生活是相通的&#xff0…

2024/11/4 计网强化

10: 17: 22: 09: 11: 12: 13: 14: 15: 18: 19: 20: 21: 16:

力扣104 : 二叉树最大深度

补:二叉树的最大深度 描述: 给定一个二叉树 root ,返回其最大深度。二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 何解? 树一般常用递归:递到叶子节点开始倒着处理

机器情绪及抑郁症算法

🏡作者主页:点击! 🤖编程探索专栏:点击! ⏰️创作时间:2024年11月12日17点02分 点击开启你的论文编程之旅https://www.aspiringcode.com/content?id17230869054974 计算机来理解你的情绪&a…

JAVA学习日记(十五) 数据结构

一、数据结构概述 数据结构是计算机底层存储、组织数据的方式。 数据结构是指数据相互之间以什么方式排列在一起的。 数据结构是为了更加方便的管理和使用数据,需要结合具体的业务场景来进行选择。 二、常见的数据结构 (一)栈 特点&…

i春秋-SQLi(无逗号sql注入,-- -注释)

练习平台地址 竞赛中心 题目描述 后台有获取flag的线索应该是让我们检查源码找到后台 题目内容 空白一片 F12检查源码 发现login.php 访问login.php?id1 F12没有提示尝试sql注入 常规sql注入 //联合注入得到表格列数 1 order by 3 # 1 union select 1,2,3 #&#xff08…

sql注入之二次注入(sqlilabs-less24)

二阶注入(Second-Order Injection)是一种特殊的 SQL 注入攻击,通常发生在用户输入的数据首先被存储在数据库中,然后在后续的操作中被使用时,触发了注入漏洞。与传统的 SQL 注入(直接注入)不同&a…

nginx 部署2个相同的vue

起因: 最近遇到一个问题,在前端用nginx 部署 vue, 发现如果前端有改动,如果不适用热更新,而是直接复制项目过去,会404 因此想到用nginx 负载两套相同vue项目,然后一个个复制vue项目就可以了。…

MySQL:CRUD

MySQL表的增删改查(操作的是表中的记录) CRUD(增删改查) C-Create新增R-Retrieve检查,查询U-Update更新D-Delete删除 新增(Create) 语法: 单行数据全列插入 insert into 表名[字段一,字段…

centos7 node升级到node18

使用jenkins发布vue3项目提示node18安装失败 错误日志: /var/lib/jenkins/tools/jenkins.plugins.nodejs.tools.NodeJSInstallation/Node18/bin/node: /lib64/libm.so.6: version GLIBC_2.27 not found (required by /var/lib/jenkins/tools/jenkins.plugins.node…

万字长文解读深度学习——ViT、ViLT、DiT

文章目录 🌺深度学习面试八股汇总🌺ViT1. ViT的基本概念2. ViT的结构与工作流程1. 图像分块(Image Patch Tokenization)2. 位置编码(Positional Encoding)3. Transformer 编码器(Transformer En…