普通RNN的缺陷—梯度消失和梯度爆炸

之前的RNN,无法很好地学习到时序数据的长期依赖关系。因为BPTT会发生梯度消失和梯度爆炸的问题。

RNN梯度消失和爆炸

对于RNN来说,输入时序数据xt时,RNN 层输出ht。这个ht称为RNN 层的隐藏状态,它记录过去的信息。

语言模型的任务是根据已经出现的单词预测下一个将要出现的单词。

学习正确解标签过程中,RNN层通过向过去传递有意义的梯度,能够学习时间方向上的依赖关系。如果这个梯度在中途变弱(甚至没有包含任何信息),权重参数将不会被更新,也就是所谓的RNN层无法学习长期的依赖关系。梯度的流动如下图绿色箭头。

在这里插入图片描述

随着时间增加,RNN会产生梯度变小(梯度消失)或梯度变大(梯度爆炸)。

RNN 层在时间方向上的梯度传播,如下图。

在这里插入图片描述

反向传播的梯度流经tanh、+、MatMul(矩阵乘积)运算。

+的反向传播,将上游传来的梯度原样传给下游,梯度值不变。

tanh的计算图如下。它将上游传来的梯度乘以tanh的导数传给下游。

在这里插入图片描述

y=tanh(x)的值及其导数的值如下图。导数值小于1,x越远离0,值越小。反向传播梯度经过tanh节点要乘上tanh的导数,这就导致梯度越来越小。

如果RNN层的激活函数使用ReLU,可以抑制梯度消失,当ReLU输入为x时,输出是max(0,x)。x大于0时,反向传播将上游的梯度原样传递到下游,梯度不会退化。

在这里插入图片描述

对于MatMul(矩阵乘积)节点。仅关注RNN层MatMul节点时的梯度反向传播如下图。每一次矩阵乘积计算都使用相同的权重Wh。

在这里插入图片描述

N = 2  # mini-batch的大小
H = 3  # 隐藏状态向量的维数
T = 20  # 时序数据的长度dh = np.ones((N, H))#初始化为所有元素均为 1 的矩阵,dh是梯度np.random.seed(3)Wh = np.random.randn(H, H)#梯度的大小随时间步长呈指数级增加,发生梯度爆炸
#Wh = np.random.randn(H, H) * 0.5
#梯度的大小随时间步长呈指数级减小,发生梯度消失,权重梯度不能被更新,模型无法学习长期的依赖关系
norm_list = []
for t in range(T):dh = np.dot(dh, Wh.T)#根据反向传播的 MatMul 节点的数量更新 dh 相应次数norm = np.sqrt(np.sum(dh**2)) / N#mini-batch(N)中的平均L2 范数,L2 范数对所有元素的平方和求平方根.norm_list.append(norm)#将各步的 dh 的大小(范数)添加到 norm_list 中print(norm_list)# 绘制图形
plt.plot(np.arange(len(norm_list)), norm_list)
plt.xticks([0, 4, 9, 14, 19], [1, 5, 10, 15, 20])
plt.xlabel('time step')
plt.ylabel('norm')
plt.show()

如果Wh是标量,由于Wh被反复乘了T次,当Wh大于1时,梯度呈指数级增加;当 Wh 小于1时,梯度呈指数级减小。

如果wh是矩阵,矩阵的奇异值表示数据的离散程度,根据奇异值(多个奇异值中的最大值)是否大于1,可以预测梯度大小的变化。奇异值比1大是梯度爆炸的必要非充分条件。

在这里插入图片描述

在这里插入图片描述

梯度裁剪gradients clipping

梯度裁剪(gradients clipping)是解决解决梯度爆炸的一个方法。

将神经网络用到的所有参数的梯度整合成一个,用g表示,将阈值设置为threshold,如果梯度g的L2范数大于等于该阈值,就按如下方式修正梯度。

在这里插入图片描述

dW1 = np.random.rand(3, 3) * 10
dW2 = np.random.rand(3, 3) * 10
grads = [dW1, dW2]
max_norm = 5.0#阈值def clip_grads(grads, max_norm):total_norm = 0for grad in grads:total_norm += np.sum(grad ** 2)total_norm = np.sqrt(total_norm)#L2 范数对所有元素的平方和求平方根rate = max_norm / (total_norm + 1e-6)if rate < 1:#如果梯度的L2范数total_norm大于等于阈值max_norm,rate是小于1的,此时就需要修正梯度for grad in grads:grad *= rateprint('before:', dW1.flatten())
clip_grads(grads, max_norm)
print('after:', dW1.flatten())
before: [7.14418135 3.58857143 7.82910303 8.04057218 8.8617387  1.899638863.0606848  8.14163088 5.25490409]
after: [1.43122195 0.71891263 1.56843501 1.61079946 1.77530697 0.380562130.61315903 1.63104494 1.05273561]

解决梯度消失

为了解决梯度消失,需要从根本上改变 RNN 层的结构。

LSTM 和GRU中增加了一种门结构,可以学习到时序数据的长期依赖关系。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/560180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LSTM的结构

RNN和LSTM 简略表示RNN层&#xff1a;长方形节点中包含了矩阵乘积、偏置的和、tanh函数的变换。将下面这个公式表示成一个tanh节点。 LSTM&#xff1a;Long Short-Term Memory&#xff08;长短期记忆&#xff09;&#xff0c;长时间维持短期记忆。 LSTM与RNN的接口(输入输出)…

STM32 USART 补充

串口通讯的数据包&#xff1a;发送设备通过自身的TXD接口传输到接收设备的RXD接口。 串口通讯的协议层中&#xff0c;规定了数据包的内容&#xff0c;由起始位、主体数据、校验位、停止位组成&#xff0c;通讯双方的数据包格式要约定一致才能正常收发数据。 异步通讯&#xf…

ROS TF变换

静态坐标转换&#xff1a;机器人本体中心到雷达中心的转换。因为激光雷达可能没安装到机器人的中心。 动态坐标转换&#xff1a;机器人中心和里程计坐标的变换。机器人从起点出发后&#xff0c;里程计坐标相对于本体就会产生一个偏移&#xff0c;这个偏移随着机器人的运动不断…

ROS底盘控制节点 源码分析

先在机器人端通过launch文件启动底盘控制。 robot:~$ roslaunch base_control base_control.launch ... logging to /home/jym/.ros/log/3e52acda-914a-11ec-beaa-ac8247315e93/roslaunch-robot-8759.log Checking log directory for disk usage. This may take a while. Pres…

ROS + OpenCV

视觉节点测试 先进行一些测试。并记录数据。 圆的是节点&#xff0c;方的是话题。 1.robot_camera.launch robot:~$ roslaunch robot_vision robot_camera.launch ... logging to /home/jym/.ros/log/bff715b6-9201-11ec-b271-ac8247315e93/roslaunch-robot-8830.log Check…

ROS+雷达 运行数据记录

先测试一下雷达&#xff0c;记录数据。方便接下来分析源码。 1.roslaunch robot_navigation lidar.launch robot:~$ roslaunch robot_navigation lidar.launch ... logging to /home/jym/.ros/log/7136849a-92cc-11ec-acff-ac8247315e93/roslaunch-robot-9556.log Checking l…

ROS 找C++算法源码的方法

在gmapping的launch文件中看到&#xff0c;type“slam_gmapping”&#xff0c;这里的slam_gmapping是c编译后的可执行文件。 如果想要修改gmapping算法&#xff0c;就需要找到slam_gmapping的c源码。 但是这是用apt下载的包&#xff0c;是二进制类型的&#xff0c;没有下载出…

ros 雷达 slam 导航 文件分析

ros 雷达 slam 导航 文件分析robot_slam_laser.launchrobot_lidar.launchlidar.launchraplidar.launchkarto.launchgmapping.launchcartographer.launchrobot_navigation.launchmap.yamlmap.pgmamcl_params.yamlmove_base.launchcostmap_common_params.yamllocal_costmap_param…

Apprentissage du français partie 1

Apprentissage du franais partie 1 键盘转换图&#xff1a; 字母&#xff1a;26个 元音字母&#xff1a;a、e、i、o、u、y b浊辅音(声带)-p清辅音 d-t 音符 音符&#xff1a;改变字母发音。 &#xff1a;闭音符 [e] &#xff1a;开音符 /ε/ &#xff1a;长音符 /ε/…

stm32基本定时器

定时器分类 stm32f1系列&#xff0c;8个定时器&#xff0c;基本定时器(TIM6,7)、通用定时器(TIM2,3,4,5)、高级定时器(TIM1,8)。 基本定时器&#xff1a;16位&#xff0c;只能向上计数的定时器&#xff0c;只能定时&#xff0c;没有外部IO 通用定时器&#xff1a;16位&#…

stm32高级定时器 基础知识

stm32高级定时器 高级定时器时基单元&#xff1a; 包含一个16位自动重装载寄存器 ARR 一个16位的计数器CNT&#xff0c;可向上/下计数 一个16位可编程预分频器PSC&#xff0c;预分频器时钟源有多种可选&#xff0c;有内部的时钟、外部时钟。 一个8位的重复计数器 RCR&…

stm32 PWM互补输出

stm32高级定时器例子—stm32 PWM互补输出 定时器初始化结构体 TIM_TimeBaseInitTypeDef 时基结构体&#xff0c;用于定时器基础参数设置&#xff0c;与TIM_TimeBaseInit函数配合使用&#xff0c;完成配置。 typedef struct { TIM_Prescaler /*定时器预分频器设置&…

stm32 输入捕获 测量脉宽

选用通用定时器TIM5的CH1。 PA0接一个按键&#xff0c;默认接GND&#xff0c;当按键按下时&#xff0c;IO口被拉高&#xff0c;此时&#xff0c;可利用定时器的输入捕获功能&#xff0c;测量按键按下的这段高电平的时间。 宏定义方便程序升级、移植&#xff0c;举个例子&#…

stm32 PWM输入捕获

普通的输入捕获&#xff0c;可使用定时器的四个通道&#xff0c;一路捕获占用一个捕获寄存器. PWM输入&#xff0c;只能使用两个通道&#xff0c;通道1和通道2。 一路PWM输入占用两个捕获寄存器&#xff0c;一个捕获周期&#xff0c;一个捕获占空比。 这里&#xff0c;用通用…

直流有刷减速电机结构及其工作原理

寒假无聊拆了个直流有刷减速电机。下面介绍一下它的结构和工作原理 直流电机 直流电机和直流减速电机&#xff1a; 构造上相差的是一个减速齿轮组。 普通的直流电机当空载时&#xff0c;电机的转速由电压决定&#xff0c;直流减速电机的转速由齿轮组和电压决定。 齿轮组作…

数据库基础概念

postgreSQL设置只允许本地机器连接 在D:\program files\PostgreSQL\14\data里面设置postgresql.conf&#xff1a; listen_addresses ‘localhost’ 然后在服务窗口重新启动postgresql。 PostgreSQL执行SQL语句 PostgreSQL的psql工具可通过命令行执行SQL语句。 psql -U po…

电机和驱动的种类

电机种类 直流电机 分为普通的直流电机、直流减速电机、有刷、无刷。 直流有刷减速电机参数&#xff1a; 空载转速&#xff0c;正常工作电压&#xff0c;电机不带任何负载的转速。 空载电流&#xff0c;正常工作电压&#xff0c;电机不带任何负载的工作电流。单位mA。 负载…

Linux shell基础知识

Shell简介 Shell是一个应用程序&#xff0c;接收用户输入的命令&#xff0c;根据命令做出相应动作。 Shell负责将应用层或者用户输入的命令&#xff0c;传递给系统内核。由操作系统内核&#xff0c;来完成相应的工作。然后将结果反馈给应用层或者用户。 shell命令格式&#…

Linux APT VIM 的一些指令

APT APT下载工具&#xff0c;可以实现软件自动下载、配置、安装二进制或源码功能。 APT采用客户端/服务器模式。 sudo apt-get update 更新软件 sudo apt-get check 检查依赖关系 sudo apt-get install package-name 安装软件 apt-get负责下载软件&#xff0c;install负责安…

CATIA 界面介绍

窗口介绍 窗口主要有&#xff1a;菜单栏、工具栏、特征树、罗盘、信息栏、图形区。 菜单栏&#xff0c;开始里面有CATIA的各个功能模块。 图形区&#xff0c;进行3D、2D设计的图形创建、编辑区域。 信息栏&#xff0c;显示用户即将进行操作的文字提示。 工具栏&#xff0c;…