25. 深度学习进阶 - 权重初始化,梯度消失和梯度爆炸

文章目录

    • 权重初始化
    • 梯度消失与梯度爆炸

在这里插入图片描述

Hi,你好。我是茶桁。

咱们这节课会讲到权重初始化、梯度消失和梯度爆炸。咱们先来看看权重初始化的内容。

权重初始化

机器学习在我们使用的过程中的初始值非常的重要。就比如最简单的wx+b,现在要拟合成一个yhat,w如果初始的过大或者初始的过小其实都会比较有影响。

假设举个极端情况,就是w拟合的时候刚刚就拟合到了离x很近的地方,我们想象一下,这个时候是不是学习起来就会很快?所以对于深度学习模型权重的初始化是一个非常重要的事情,甚至有人就说把初始化做好了,其实绝大部分事情就已经解决了。

那么我们怎么样获得一个比较好的初始化的值?首先有这么几个原则

  • 我们的权重值不能设置为0。
  • 尽量将权重变成一个随机化的正态分布。而且有更大的X输入,那我们的权重就应该更小。

l o s s = ∑ ( y ^ − y i ) 2 = ∑ ( ∑ w i x i − y i ) 2 \begin{align*} loss & = \sum(\hat y - y_i)^2 \\ & = \sum(\sum w_ix_i - y_i)^2 \end{align*} loss=(y^yi)2=(wixiyi)2

我们看上面的式子,yhat就是w_i*x_i, 这个时候x_i可能是几百万,也可能是几百。我们w_i取值在(-n, n)之间,那当x_i维度特别大的时候,那yhat值算出来的也就会特别大。所以,x_i的维度特别大的时候,我们期望w_i值稍微小一些,否则加出来的yhat可能就会特别大,那最后求出来的loss也会特别大。

如果loss值特别大,可能就会得到一个非常的梯度。那我们知道,学习的梯度特别大的话,就会发生比较大的震荡。

所以有一个原则,就是当x的dimension很大的时候, 我们期望的它的权重越小。

那后来就有人提出来了一个比较重要的初始化方法,Xavier初始化。这个方法特别适用于sigmoid激活函数或反正切tanh激活函数,它会根据前一层和当前层的神经元数量来选择初始化的范围,以确保权重不会过大或过小。
均值为 0 和标准差的正态分布 : σ = 2 n i n p u t s + n o u t p u t s − r 和 + r 之间的均匀分布: r = 6 n i n p u t s + n o u t p u t s \begin{align*} 均值为0和标准差的正态分布: \sigma & = \sqrt{\frac{2}{n_{inputs}+n_{outputs}}} \\ -r和+r之间的均匀分布:r & = \sqrt{\frac{6}{n_{inputs}+n_{outputs}}} \end{align*} 均值为0和标准差的正态分布:σr+r之间的均匀分布:r=ninputs+noutputs2 =ninputs+noutputs6

然后W的均匀分布就会是这样:
W ∼ U ∣ − 6 n j + n j + 1 , 6 n j + n j + 1 W \sim U \Bigg \vert -\frac{\sqrt 6}{\sqrt{n_j + n_{j+1}}}, \frac{\sqrt 6}{\sqrt{n_j + n_{j+1}}} WU nj+nj+1 6 ,nj+nj+1 6

这个是一个比较有名的初始化方法,如果要做函数的初始化的话,PyTorch在init里面有一个方法:

torch.nn.init.xavier_uniform_(tensor, gain=1.0)

比如,我们看这样例子:

w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=nn.calculate_gain('relu'))

注意: init方法里还有其他的一些方法,大家可以查阅PyTorch的相关文档:https://pytorch.org/docs/stable/nn.init.html

梯度消失与梯度爆炸

当我们的模型层数特别多的时候

Alt text

就比如我们上节课用到的Sequential,我们可以在里面写如非常多的一个函数:

model = nn.Sequential(nn.Linear(in_features=10, out_features=5).double(),nn.Sigmoid(),nn.Linear(in_features=5, out_features=8).double(),nn.Sigmoid(),nn.Linear(in_features=8, out_features=8).double(),nn.Sigmoid(),...nn.Linear(in_features=8, out_features=8).double(),nn.Softmax(),
)

Alt text

这样,在做偏导的时候我们其中几个值特别小,那两个一乘就会乘出来一个特别特别小的数字。最后可能会导致一个结果, ∂ l o s s ∂ w i \frac{\partial loss}{\partial wi} wiloss的值就会极小,它的更新就会特别的慢。我们把这种东西就叫做梯度消失,也有人叫梯度弥散。

以Sigmoid函数为例,其导数为

σ ′ ( x ) = σ ( x ) ( 1 − σ ( x ) ) \begin{align*} \sigma '(x) = \sigma(x)(1-\sigma(x)) \end{align*} σ(x)=σ(x)(1σ(x))

在x趋近正无穷或者负无穷时,导数接近0。当这种小梯度在多层网络中相乘的时候,梯度会迅速减小,导致梯度消失。

除此之外还有一种情况叫梯度爆炸,剃度爆炸类似,当模型的层很多的时候,如果其中某两个值很大,例如两个102,当这两个乘起来就会变成104。乘下来整个loss很大,又会产生一个结果,我们来看这样一个场景:

Alt text

假如说对于上图中这个函数来说,横轴为x, 竖轴为loss,对于这个xi来说,这个地方 ∂ l o s s ∂ x i \frac{\partial loss}{\partial xi} xiloss已经是一个特别大的数字了。

假设咱们举个极端的情况(忽略图中竖轴上的数字),我们现在loss等于x^4: l o s s = x 4 loss=x^4 loss=x4,然后现在 ∂ l o s s ∂ x 4 \frac{\partial loss}{\partial x^4} x4loss就等于 4 x 3 4x^3 4x3,我们假设x在A点,当x=10的时候,那 4 × x 3 = 4000 4\times x^3 = 4000 4×x3=4000, 那我们计算新的xi,就是 x i = x i − α ⋅ ∂ l o s s ∂ x i x_i = x_i - \alpha \cdot \frac{\partial loss}{\partial x_i} xi=xiαxiloss,现在给alpha一个比较小的数,我们假设是0.1,那式子就变成 10 − 0.1 × 4000 10 - 0.1 \times 4000 100.1×4000,结果就是-390。

我们把它变到-390之后,本来我们本来做梯度下降更新完,xi期望的是loss要下降,但是我们结合图像来看,xi=-390的时候,loss就变得极其的巨大了,然后我们在继续,(-390)^4, 这个loss就已经爆炸了。

再继续的时候,会发现会在极值上跳来跳去,loss就无法进行收敛了。所以我们也要拒绝这种情况的发生。

那梯度消失和梯度爆炸这两个问题该如何解决呢?我们来看第一种解决方法: Batch normalization,批量归一化。

那这个方法的核心思想是对神经网络的每一层的输入进行归一化,使其具有零均值和单位方差。

那么首先,对于每个mini-batch中的输入数据,计算均值和方差。 B = { x 1 . . . m } B = \{x_1...m\} B={x1...m}; 要学习的参数: γ , β \gamma,\beta γ,β

μ B = 1 m ∑ i = 1 m x i σ B 2 = 1 m ∑ i = 1 m ( x i − μ B ) 2 μ 为均值 m e a n , σ 为方差 \begin{align*} \mu_B & = \frac{1}{m}\sum^m_{i=1}x_i \\ \sigma ^2_B & = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2 \\ & \mu 为均值mean, \sigma为方差 \end{align*} μBσB2=m1i=1mxi=m1i=1m(xiμB)2μ为均值meanσ为方差

这里和咱们之前讲x做normalization的时候其实是特别相似,基本上就是一件事。

然后我们使用均值和方差对输入进行归一化,使得其零均值和单位方差,即将输入标准化为xhat。

x ^ i = x i − μ B σ B 2 + ε \begin{align*} \hat x_i = \frac{x_i - \mu_B}{\sqrt{\sigma ^2_B + \varepsilon}} \end{align*} x^i=σB2+ε xiμB

接着我们对归一化后的输入应用缩放和平移操作,以允许网络学习最佳的变换。

y i = γ x ^ i + β ≡ B N γ , β ( x i ) \begin{align*} y_i = \gamma \hat x_i + \beta \equiv BN_{\gamma,\beta}(x_i) \end{align*} yi=γx^i+βBNγ,β(xi)

输出为 { y i = B N γ , β ( x i ) } \{y_i = BN_{\gamma,\beta}(x_i)\} {yi=BNγ,β(xi)}

最后将缩放和平移后的数据传递给激活函数进行非线性变换。

它会输入一个小批量的x值,

经过反复的梯度下降,会得到一个gamma和beta,能够知道在这一步x要怎么样进行缩放,在缩放之前会经历刚开始的时候那个normalization一样,把把过小值会变大,把过大值会变小。

我们在之前的课程中演示过,没看过和忘掉的同学可以往前翻看一下。

然后在经过这两个可学习的参数进行一个变化,这样它可以做到在每一层x变化不会极度的增大或者极度的缩小,可以让我们的权值保持的比较稳定。

那除了Batch normalization之外,还有一个方法叫Gradient clipping, 它是可以直接将过大的梯度值变小。

Alt text

它其实很简单,也叫做梯度减脂。

如果我们求解出来 ∂ l o s s ∂ w i \frac{\partial loss}{\partial w_i} wiloss很大,假设原来等于400,我们定义了一个100,那超过100的部分,就全部设置成100。

train_loss.backward()
pt.nn.units.clip_grad_value_(model.parameters(), 100)
optimizer.step()

简单粗暴。那其实梯度爆炸还是比较容易解决的,比较复杂的其实是梯度消失的问题。

梯度爆炸为什么比较容易解决?梯度爆炸起码是有导数的,只要把这个导数给它放的特别小就行了,有导数起码保证wi可以更新。

假设alpha,我们的learning_rate等于0.01,乘上一个100,可以保证每次可以有个变化。但是每次这个梯度特别小,假如都快接近于0了,那么1e-10, 就算乘上100倍,最后还是一个特别小的数字。所以相较而言,梯度爆炸就更好解决一些,方法更粗暴一些。

补充一个知识点,这个虽然现在已经用不到了,但是对我们的理解还是有帮助的。方法比较古老。

就是当我们发现梯度有问题的时候, 大概在10年前,那个时候神经网络的模块也不太丰富,很多新出的model,做神经网络的人,一些导数,传播什么的都需要自己写,就我们前几节课写那个神经网络框架的时候做的事。

有的时候导数写错了,就有一种方法叫做gradient checking,梯度检查。

这个使用场景非常的少,当你自己发明了一个新的模块,加到这个模型里面的时候会遇到。

其实很简单,就是把最终的 ∂ l o s s ∂ w i \frac{\partial loss}{\partial w_i} wiloss,求解出来的偏导总是不收敛,可能是这个偏导有问题,那么有可能求导的函数写错了。

那在这个时候就可以做个简单的变化:

∂ l o s s ( θ + ε ) − ∂ l o s s ( θ − ε ) 2 ε \begin{align*} \frac{\partial loss(\theta+\varepsilon)-\partial loss(\theta - \varepsilon)}{2\varepsilon} \end{align*} 2εloss(θ+ε)loss(θε)

这其中 ∂ l o s s ( θ + ε ) \partial loss(\theta + \varepsilon) loss(θ+ε) ∂ l o s s ( θ − ε ) \partial loss(\theta - \varepsilon) loss(θε)是在参数 θ \theta θ, 其实也就是我们的wi上添加和减去微小扰动theta后的损失函数值。

然后我们计算数值梯度和反向传播计算得到的梯度之间的差异。通常这是通过计算它们之间的差异来完成,然后将其与一个小的阈值,比如1e-7进行比较。如果差异非常小(小于阈值),则可以认为梯度计算是正确的,否则可能就需要从新写一下偏导函数了。

这个比较难,但不是一个重点,当且仅当自己要发明一个模型的时候。

那接下来我们来看一下关于Learning_rate和Early Stopping的问题。

理论上,如果深度学习效果不好,那么我们可以将learning rate调小,可以让所有模型效果变得更好,它可以让所有的loss下降。

Alt text

但是如果你的learning rate变得特别小,假如说是1e-9,那这样的结果就是w的变化会非常的慢,训练时间就变得很长。为了解决这个问题,就有一些比较简单的方法。

第一个,我们可以把learning rate和loss设置成一个相关的函数,例如说loss越小的时候,Learning rate越小,或者随着epoch的增大,loss越小。这个就叫learning rate的decay。

将learning rate或者训练次数和loss设置成一个相关的函数,那么越到后面效果越好的时候,learning rate就会越小。

还有,我们可能会发现loss连续k次不下降,那我们就可以提前结束训练过程,这个就是Early Stopping。

也就是当你发现loss连续k次不下降,或者甚至于在上升,那么这个时候,就可以将最优的这个值给它记录下来。

咱们可能会经常出现的情况就是值在那里震荡,本来呢已经快接近于最优点了,可是震荡了几次之后,还可能震荡出去了,loss变大了。或者就一直在这个震荡里边出不去,这个时候多学习也没有用,所以就可以早点停止,这个就是Early Stopping,中文有人称呼它为早停方法。

好,下节课,咱们要讲一个重点,也是一个难点。就是咱们做机器学习的时候,不同的优化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/184086.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【计算机网络】虚拟路由冗余(VRRP)协议原理与配置

目录 1、VRRP虚拟路由器冗余协议 1.1、协议作用 1.2、名词解释 1.3、简介 1.4、工作原理 1.5、应用实例 2、 VRRP配置 2.1、配置命令 1、VRRP虚拟路由器冗余协议 1.1、协议作用 虚拟路由冗余协议(Virtual Router Redundancy Protocol,简称VRRP)是由IETF…

Linux学习笔记 CenOS6.3 yum No package xxx available

环境CenOS [roothncuc ~]# cat /etc/issue CentOS release 6.2 (Final) Kernel \r on an \m安装gcc的时候提示没有包 [roothncuc ~]# sudo yum install gcc gcc-c libstdc-devel Loaded plugins: refresh-packagekit, security Setting up Install Process No package gcc a…

temu的产品发布后在哪里显示

temu是一款备受瞩目的产品,其发布后引起了广泛的关注。但是,很多人对于temu产品发布后在哪里显示存在疑惑。本文将深入探讨temu产品的展示方式和关键特点,帮助读者更好地了解temu产品在发布后的展示位置。 先给大家推荐一款拼多多/temu运营工…

【报错栏】(Vue) Invalid handler for event “click“: got undefined

Property or method "add" is not defined on the instance but referenced during render. 翻译: 属性或方法“add”未在实例上定义,但在渲染期间引用。 Invalid handler for event "click": got undefined 翻译: …

用bat制作图片马——一句话木马

效果图 代码 ECHO OFF TITLE PtoR MODE con COLS55 LINES25 color 0A:main cls echo.当前时间:%date% %time% echo.欢迎使用图片马制作工具 echo.请确保图片和php在同一路径下 echo.echo 请将图像文件拖放到此窗口并按 Enter: set /p "imagefile&q…

肖sir__搭建环境报错:com.alibaba.druid:type=DruidDataSourceStat异常

报错现象: 解决方案: 同一个服务器配置多个tomcat,而这些tomcat里边的项目配置的数据库连接池都是用alibaba.druid。下面说下我的解决过程,首先,修改tomcat bin目录下的catalina.sh,添加如下代码: 代码如…

Siemens S7-300主站Profibus网络设定以及OMRON设定

1.100L流量秤,历史值,D3426,D3427,7位 2.次数,D166,D177,5位 3.PROFIBUS地址03# 1.FA1,历史值,D3426,D3427,6位 2.包数区,D166,D177,5位 3.PROFIB…

前端:实现二级菜单(点击实现二级菜单展开)

效果 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-width, i…

【趣味篇】Scratch之windows11系统

【作品展示】windows11系统 操作&#xff1a;点击小绿旗进入windows11主页面&#xff0c;不仅是能打开浏览器&#xff0c;还可以进行背景切换等功能。

大数据——一文详解数据仓库概念(数据仓库的分层概念和维度建模详解)

1、ods是什么&#xff1f; ods层最好理解&#xff0c;基本上就是数据从源表拉过来&#xff0c;进行etl&#xff0c;比如MySQL映射到Hive&#xff0c;那么到了Hive里面就是ods层。ods全称是 Operational Data Store&#xff0c;操作数据存储——“面向主题的”&#xff0c;数据…

突破界限:R200科研无人车,开辟研究新天地

提到科研无人车&#xff0c;大家可能首先想到的是其在自动驾驶和其他先进技术领域的应用。然而&#xff0c;随着科技的不断进步&#xff0c;科研无人车已经在智慧城市建设、商业服务、地质勘探、环境保护、农业技术革新、灾害应急和自动化服务等多个领域发挥着至关重要的作用。…

Linux MTR(My TraceRoute)command

Internet上有许多小型网络测试工具:Ping、Traceroute、Dig、Host等。 但是&#xff0c;这些工具的功能都比较单一。今天会给大家分享一个包含ping和traceroute功能的工具&#xff1a;MTR 文章目录 什么是MTR&#xff1f;MTR可以提供哪些功能Linux MTR可用选项Linux MTR用法推荐…

【UGUI】事件侦听EventSystem系统0学

前言介绍 EventSystem是Unity UGUI中的一个重要组件&#xff0c;用于处理用户输入事件&#xff0c;如点击、拖拽、滚动等。它负责将用户输入事件传递给合适的UI元素&#xff0c;并触发相应的事件回调函数&#xff08;就是你想要做的事情&#xff0c;自定义函数&#xff09;。 …

FPGA程序执行相关知识点

1.目前&#xff0c;大多数FPGA芯片是基于 SRAM 的结构的&#xff0c; 而 SRAM 单元中的数据掉电就会丢失&#xff0c;因此系统上电后&#xff0c;必须要由配置电路将正确的配置数据加载到 SRAM 中&#xff0c;此后 FPGA 才能够正常的运行。 常见的配置芯片有EPCS 芯片 &#x…

最新报告!11月美国市场的“遥遥领先”来了,该爆的单总会来!

今年周期最长的大促节点已接近尾声&#xff0c;美区市场的11月份的商品销售战绩已全面来袭&#xff1a; 保健类目竟弯道超车&#xff0c;交出了将近翻倍的成绩单&#xff1b;美妆个护、女装与女士内衣等“她经济”类目持续高涨且“辣眼”单品不断&#xff1b;家居大类目下的市…

JenKins快速安装与使用,Gitlab自动触发Jenkins

一、JenKins 0.准备&#xff0c;配置好环境 1&#xff09;Git&#xff08;yum安装&#xff09; 2&#xff09;JDK&#xff08;自行下载&#xff09; 3&#xff09;Jenkins&#xff08;自行下载&#xff09; 1.下载安装包 进官网&#xff0c;点Download下方即可下载。要下…

使用netconf配置华为设备

实验目的&#xff1a; 公司有一台CE12800的设备&#xff0c;管理地址位172.16.1.2&#xff0c;现在需要编写自动化脚本&#xff0c;通过SSH登陆到设备上配置netconf协议的用户名&#xff0c;密码以及netconf服务&#xff0c;并且通过netconf协议将设备的loopback0接口IP地址配…

一文读懂Asyncio

什么是Asyncio asyncio 是用来编写并发代码的库&#xff0c;使用async/await语法。 asyncio 被用作多个提供高性能 Python 异步框架的基础&#xff0c;包括网络和网站服务&#xff0c;数据库连接库&#xff0c;分布式任务队列等等。 asyncio 往往是构建 IO 密集型和高层级结构化…

Linux创建与编辑视图

本博客将会详细讲解如何在Linux中如何编辑配置文件 输出重定向 对于一台设备而言&#xff0c;存在着两种设备&#xff0c;分别负责输入与输出&#xff1a; 显示器&#xff08;输出设备>&#xff09; 与 键盘&#xff08;输入设备<&#xff09; 对于Linux系统而言&#…

深入理解 Vue 中的指针操作(二)

文章目录 ☘️引言☘️基本用法&#x1f342;v-for指令&#x1f342;v-model指令&#x1f331;v-model适用表单控件&#x1f331;修饰符&#x1f9c4;.lazy 修饰符&#x1f9c4;.number 修饰符&#x1f9c4;.trim 修饰符 ☘️结论 ☘️引言 Vue.js 是一款非常流行且功能强大的…