【深度学习笔记】7_4 动量法momentum

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.4 动量法

在7.2节(梯度下降和随机梯度下降)中我们提到,目标函数有关自变量的梯度代表了目标函数在自变量当前位置下降最快的方向。因此,梯度下降也叫作最陡下降(steepest descent)。在每次迭代中,梯度下降根据自变量当前位置,沿着当前位置的梯度更新自变量。然而,如果自变量的迭代方向仅仅取决于自变量当前位置,这可能会带来一些问题。

7.4.1 梯度下降的问题

让我们考虑一个输入和输出分别为二维向量 x = [ x 1 , x 2 ] ⊤ \boldsymbol{x} = [x_1, x_2]^\top x=[x1,x2]和标量的目标函数 f ( x ) = 0.1 x 1 2 + 2 x 2 2 f(\boldsymbol{x})=0.1x_1^2+2x_2^2 f(x)=0.1x12+2x22。与7.2节中不同,这里将 x 1 2 x_1^2 x12系数从 1 1 1减小到了 0.1 0.1 0.1。下面实现基于这个目标函数的梯度下降,并演示使用学习率为 0.4 0.4 0.4时自变量的迭代轨迹。

%matplotlib inline
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
import torcheta = 0.4 # 学习率def f_2d(x1, x2):return 0.1 * x1 ** 2 + 2 * x2 ** 2def gd_2d(x1, x2, s1, s2):return (x1 - eta * 0.2 * x1, x2 - eta * 4 * x2, 0, 0)d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))

输出:

epoch 20, x1 -0.943467, x2 -0.000073

在这里插入图片描述

可以看到,同一位置上,目标函数在竖直方向( x 2 x_2 x2轴方向)比在水平方向( x 1 x_1 x1轴方向)的斜率的绝对值更大。因此,给定学习率,梯度下降迭代自变量时会使自变量在竖直方向比在水平方向移动幅度更大。那么,我们需要一个较小的学习率从而避免自变量在竖直方向上越过目标函数最优解。然而,这会造成自变量在水平方向上朝最优解移动变慢。

下面我们试着将学习率调得稍大一点,此时自变量在竖直方向不断越过最优解并逐渐发散。

eta = 0.6
d2l.show_trace_2d(f_2d, d2l.train_2d(gd_2d))

输出:

epoch 20, x1 -0.387814, x2 -1673.365109

在这里插入图片描述

7.4.2 动量法

动量法的提出是为了解决梯度下降的上述问题。由于小批量随机梯度下降比梯度下降更为广义,本章后续讨论将沿用7.3节(小批量随机梯度下降)中时间步 t t t的小批量随机梯度 g t \boldsymbol{g}_t gt的定义。设时间步 t t t的自变量为 x t \boldsymbol{x}_t xt,学习率为 η t \eta_t ηt
在时间步 0 0 0,动量法创建速度变量 v 0 \boldsymbol{v}_0 v0,并将其元素初始化成0。在时间步 t > 0 t>0 t>0,动量法对每次迭代的步骤做如下修改:

v t ← γ v t − 1 + η t g t , x t ← x t − 1 − v t , \begin{aligned} \boldsymbol{v}_t &\leftarrow \gamma \boldsymbol{v}_{t-1} + \eta_t \boldsymbol{g}_t, \\ \boldsymbol{x}_t &\leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{v}_t, \end{aligned} vtxtγvt1+ηtgt,xt1vt,

其中,动量超参数 γ \gamma γ满足 0 ≤ γ < 1 0 \leq \gamma < 1 0γ<1。当 γ = 0 \gamma=0 γ=0时,动量法等价于小批量随机梯度下降。

在解释动量法的数学原理前,让我们先从实验中观察梯度下降在使用动量法后的迭代轨迹。

def momentum_2d(x1, x2, v1, v2):v1 = gamma * v1 + eta * 0.2 * x1v2 = gamma * v2 + eta * 4 * x2return x1 - v1, x2 - v2, v1, v2eta, gamma = 0.4, 0.5
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))

输出:

epoch 20, x1 -0.062843, x2 0.001202

在这里插入图片描述

可以看到使用较小的学习率 η = 0.4 \eta=0.4 η=0.4和动量超参数 γ = 0.5 \gamma=0.5 γ=0.5时,动量法在竖直方向上的移动更加平滑,且在水平方向上更快逼近最优解。下面使用较大的学习率 η = 0.6 \eta=0.6 η=0.6,此时自变量也不再发散。

eta = 0.6
d2l.show_trace_2d(f_2d, d2l.train_2d(momentum_2d))

输出:

epoch 20, x1 0.007188, x2 0.002553

在这里插入图片描述

7.4.2.1 指数加权移动平均

为了从数学上理解动量法,让我们先解释一下指数加权移动平均(exponentially weighted moving average)。给定超参数 0 ≤ γ < 1 0 \leq \gamma < 1 0γ<1,当前时间步 t t t的变量 y t y_t yt是上一时间步 t − 1 t-1 t1的变量 y t − 1 y_{t-1} yt1和当前时间步另一变量 x t x_t xt的线性组合:

y t = γ y t − 1 + ( 1 − γ ) x t . y_t = \gamma y_{t-1} + (1-\gamma) x_t. yt=γyt1+(1γ)xt.

我们可以对 y t y_t yt展开:

y t = ( 1 − γ ) x t + γ y t − 1 = ( 1 − γ ) x t + ( 1 − γ ) ⋅ γ x t − 1 + γ 2 y t − 2 = ( 1 − γ ) x t + ( 1 − γ ) ⋅ γ x t − 1 + ( 1 − γ ) ⋅ γ 2 x t − 2 + γ 3 y t − 3 … \begin{aligned} y_t &= (1-\gamma) x_t + \gamma y_{t-1}\\ &= (1-\gamma)x_t + (1-\gamma) \cdot \gamma x_{t-1} + \gamma^2y_{t-2}\\ &= (1-\gamma)x_t + (1-\gamma) \cdot \gamma x_{t-1} + (1-\gamma) \cdot \gamma^2x_{t-2} + \gamma^3y_{t-3}\\ &\ldots \end{aligned} yt=(1γ)xt+γyt1=(1γ)xt+(1γ)γxt1+γ2yt2=(1γ)xt+(1γ)γxt1+(1γ)γ2xt2+γ3yt3

n = 1 / ( 1 − γ ) n = 1/(1-\gamma) n=1/(1γ),那么 ( 1 − 1 / n ) n = γ 1 / ( 1 − γ ) \left(1-1/n\right)^n = \gamma^{1/(1-\gamma)} (11/n)n=γ1/(1γ)。因为

lim ⁡ n → ∞ ( 1 − 1 n ) n = exp ⁡ ( − 1 ) ≈ 0.3679 , \lim_{n \rightarrow \infty} \left(1-\frac{1}{n}\right)^n = \exp(-1) \approx 0.3679, nlim(1n1)n=exp(1)0.3679,

所以当 γ → 1 \gamma \rightarrow 1 γ1时, γ 1 / ( 1 − γ ) = exp ⁡ ( − 1 ) \gamma^{1/(1-\gamma)}=\exp(-1) γ1/(1γ)=exp(1),如 0.9 5 20 ≈ exp ⁡ ( − 1 ) 0.95^{20} \approx \exp(-1) 0.9520exp(1)。如果把 exp ⁡ ( − 1 ) \exp(-1) exp(1)当作一个比较小的数,我们可以在近似中忽略所有含 γ 1 / ( 1 − γ ) \gamma^{1/(1-\gamma)} γ1/(1γ)和比 γ 1 / ( 1 − γ ) \gamma^{1/(1-\gamma)} γ1/(1γ)更高阶的系数的项。例如,当 γ = 0.95 \gamma=0.95 γ=0.95时,

y t ≈ 0.05 ∑ i = 0 19 0.9 5 i x t − i . y_t \approx 0.05 \sum_{i=0}^{19} 0.95^i x_{t-i}. yt0.05i=0190.95ixti.

因此,在实际中,我们常常将 y t y_t yt看作是对最近 1 / ( 1 − γ ) 1/(1-\gamma) 1/(1γ)个时间步的 x t x_t xt值的加权平均。例如,当 γ = 0.95 \gamma = 0.95 γ=0.95时, y t y_t yt可以被看作对最近20个时间步的 x t x_t xt值的加权平均;当 γ = 0.9 \gamma = 0.9 γ=0.9时, y t y_t yt可以看作是对最近10个时间步的 x t x_t xt值的加权平均。而且,离当前时间步 t t t越近的 x t x_t xt值获得的权重越大(越接近1)。

7.4.2.2 由指数加权移动平均理解动量法

现在,我们对动量法的速度变量做变形:

v t ← γ v t − 1 + ( 1 − γ ) ( η t 1 − γ g t ) . \boldsymbol{v}_t \leftarrow \gamma \boldsymbol{v}_{t-1} + (1 - \gamma) \left(\frac{\eta_t}{1 - \gamma} \boldsymbol{g}_t\right). vtγvt1+(1γ)(1γηtgt).

由指数加权移动平均的形式可得,速度变量 v t \boldsymbol{v}_t vt实际上对序列 { η t − i g t − i / ( 1 − γ ) : i = 0 , … , 1 / ( 1 − γ ) − 1 } \{\eta_{t-i}\boldsymbol{g}_{t-i} /(1-\gamma):i=0,\ldots,1/(1-\gamma)-1\} {ηtigti/(1γ):i=0,,1/(1γ)1}做了指数加权移动平均。换句话说,相比于小批量随机梯度下降,动量法在每个时间步的自变量更新量近似于将最近 1 / ( 1 − γ ) 1/(1-\gamma) 1/(1γ)个时间步的普通更新量(即学习率乘以梯度)做了指数加权移动平均后再除以 1 − γ 1-\gamma 1γ。所以,在动量法中,自变量在各个方向上的移动幅度不仅取决当前梯度,还取决于过去的各个梯度在各个方向上是否一致。在本节之前示例的优化问题中,所有梯度在水平方向上为正(向右),而在竖直方向上时正(向上)时负(向下)。这样,我们就可以使用较大的学习率,从而使自变量向最优解更快移动。

7.4.3 从零开始实现

相对于小批量随机梯度下降,动量法需要对每一个自变量维护一个同它一样形状的速度变量,且超参数里多了动量超参数。实现中,我们将速度变量用更广义的状态变量states表示。

features, labels = d2l.get_data_ch7()def init_momentum_states():v_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)v_b = torch.zeros(1, dtype=torch.float32)return (v_w, v_b)def sgd_momentum(params, states, hyperparams):for p, v in zip(params, states):v.data = hyperparams['momentum'] * v.data + hyperparams['lr'] * p.grad.datap.data -= v.data

我们先将动量超参数momentum设0.5,这时可以看成是特殊的小批量随机梯度下降:其小批量随机梯度为最近2个时间步的2倍小批量梯度的加权平均。

注:个人认为这里不应该是“加权平均”而应该是“加权和”,因为根据7.4.2.2节分析,加权平均最后除以了 1 − γ 1-\gamma 1γ,所以就相当于没有进行平均。

d2l.train_ch7(sgd_momentum, init_momentum_states(),{'lr': 0.02, 'momentum': 0.5}, features, labels)

输出:

loss: 0.245518, 0.042304 sec per epoch

在这里插入图片描述

将动量超参数momentum增大到0.9,这时依然可以看成是特殊的小批量随机梯度下降:其小批量随机梯度为最近10个时间步的10倍小批量梯度的加权平均。我们先保持学习率0.02不变。

同理,这里不应该是“加权平均”而应该是“加权和”。

d2l.train_ch7(sgd_momentum, init_momentum_states(),{'lr': 0.02, 'momentum': 0.9}, features, labels)

输出:

loss: 0.252046, 0.095708 sec per epoch

在这里插入图片描述

可见目标函数值在后期迭代过程中的变化不够平滑。直觉上,10倍小批量梯度比2倍小批量梯度大了5倍,我们可以试着将学习率减小到原来的1/5。此时目标函数值在下降了一段时间后变化更加平滑。

这也印证了刚刚的观点。

d2l.train_ch7(sgd_momentum, init_momentum_states(),{'lr': 0.004, 'momentum': 0.9}, features, labels)

输出:

loss: 0.242905, 0.073496 sec per epoch

在这里插入图片描述

7.4.4 简洁实现

在PyTorch中,只需要通过参数momentum来指定动量超参数即可使用动量法。

d2l.train_pytorch_ch7(torch.optim.SGD, {'lr': 0.004, 'momentum': 0.9},features, labels)

输出:

loss: 0.253280, 0.060247 sec per epoch

在这里插入图片描述

小结

  • 动量法使用了指数加权移动平均的思想。它将过去时间步的梯度做了加权平均,且权重按时间步指数衰减。
  • 动量法使得相邻时间步的自变量更新在方向上更加一致。

注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/738271.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ArcGIS】栅格数据进行标准化(归一化)处理

栅格数据进行标准化&#xff08;归一化&#xff09;处理 方法1&#xff1a;栅格计算器方法2&#xff1a;模糊分析参考 栅格数据进行标准化(归一化)处理 方法1&#xff1a;栅格计算器 栅格计算器&#xff08;Raster Calculator&#xff09; 方法2&#xff1a;模糊分析 空间…

Python实现图片(合并)转PDF

在日常的工作和学习过程当中,我相信很多人遇到过这样一个很普通的需求,就是将某一个图片转为PDF或者是将多个图片合并到一个PDF文件。但是,在苦苦搜寻一圈之后发现要么要下载软件,下载了还要注册,注册了还要VIP,甚至SVIP才能实现这样的需求! 今天,我带大家把这个功能打…

2024年华为HCIA-DATACOM新增题库(H12-811)

801、[单选题]178/832、在系统视图下键入什么命令可以切换到用户视图? A quit B souter C system-view D user-view 试题答案&#xff1a;A 试题解析&#xff1a;在系统视图下键入quit命令退出到用户视图。因此答案选A。 802、[单选题]“网络管理员在三层交换机上创建了V…

Kubernetes | 起源 | 组件详解

起源 起源&#xff1a; Kubernetes&#xff08;常简称为K8s&#xff09;起源于Google内部的Borg项目&#xff0c;是一个开源的容器编排引擎&#xff0c;于2014年首次对外发布。 Google Borg Google Borg 是 Google 内部开发和使用的大规模集群管理系统&#xff0c;用于管理和运…

Jmeter+Ant+Git/SVN+Jenkins实现持续集成接口测试,一文精通(二)

前言 上篇内容已经介绍接口测试流程以及了解如何用jmeter接口测试&#xff0c;本篇将介绍如何在实战中应用 一、Jmeter接口关联 1.使用正则表达式实现接口关联&#xff08;可以作用于任意值&#xff09; 如果说一个请求里面有多次请求服务器。 2.使用Jsonpath表达式实现接口关…

c++ primer plus笔记 第十八章 探讨c++新标准

复习前面的内容&#xff1a; 1.auto&#xff0c;可以自动识别auto本身在这种语境下是什么类型 2.decltype,让一个变量的类型和另外一个变量的类型相同 decltype(x) y;//让y的类型和x的类型相同 如何理解&#xff1f; decltype是一个关键词&#xff0c;其作用是检查括号内的…

Android studio虚拟调试出现“我的APP keeps stopping”问题

问题如图&#xff1a; 遇到这种情况&#xff0c;一看代码&#xff0c;也没有报错呀&#xff0c;怎么不能运行呢&#xff1f;不要慌&#xff01;我们一步一步来。 1、查看Logcat日志 在Android Studio中查看Logcat窗口&#xff0c;可以获取应用程序崩溃时的详细错误信息&…

【触想智能】工业触摸显示器在户外使用需要注意哪些问题?

工业显示器是智能制造领域应用比较广泛的电子产品&#xff0c;它广泛应用于工厂产线以及各种配套设备&#xff0c;在很大程度上提升了工厂的生产效率。 工业显示器按触摸方式分&#xff0c;可以分为工业触摸显示器和非触摸工业显示器两种;按使用环境分&#xff0c;又可以分为室…

几何变换 - 图像的缩放、翻转、仿射变换、透视等

1、前言 图像的几何变换是指改变图像的几何结构,大小、形状等等,让图像呈现出具备缩放、翻转、映射和透视的效果 图像的几何变换都比较复杂,计算也很复杂。 例如仿射变换,像素点的位置和灰度值都需要变换。 数字图像处理中利用后向传播的方法,将像素点变换后的位置通过…

腾讯云和阿里云4核8G云服务器多少钱一年和1个月费用对比

4核8G云服务器多少钱一年&#xff1f;阿里云ECS服务器u1价格955.58元一年&#xff0c;腾讯云轻量4核8G12M带宽价格是646元15个月&#xff0c;阿腾云atengyun.com整理4核8G云服务器价格表&#xff0c;包括一年费用和1个月收费明细&#xff1a; 云服务器4核8G配置收费价格 阿里…

案例分析篇08:Web架构设计相关20个考点(1~6)(2024年软考高级系统架构设计师冲刺知识点总结系列文章)

专栏系列文章推荐: 2024高级系统架构设计师备考资料(高频考点&真题&经验)https://blog.csdn.net/seeker1994/category_12601310.html 【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例分析篇-…

golang学习随便记16-反射

为什么需要反射 下面的例子中编写一个 Sprint 函数&#xff0c;只有1个参数&#xff08;类型不定&#xff09;&#xff0c;返回和 fmt.Fprintf 类似的格式化后的字符串。实现方法大致为&#xff1a;如果参数类型本身实现了 String() 方法&#xff0c;那调用 String() 方法即可…

钡铼技术R40工业路由器4G WiFi一体,适用于各类工业场景

钡铼技术R40工业路由器是一款集4G网络连接和WiFi功能于一体的先进设备&#xff0c;旨在满足各类工业场景对稳定、高速网络连接的需求。作为一家致力于工业互联网解决方案的领先厂商&#xff0c;钡铼技术致力于为工业企业提供可靠的网络设备&#xff0c;以支持其数字化转型和智能…

OSI七层模型TCP四层模型横向对比

OSI 理论模型&#xff08;Open Systems Interconnection Model&#xff09;和TCP/IP模型 七层每一层对应英文 应用层&#xff08;Application Layer&#xff09; 表示层&#xff08;Presentation Layer&#xff09; 会话层&#xff08;Session Layer&#xff09; 传输层&#x…

02_electron快速建立项目

一、安装 yarn 在此之前可以先安装 git&#xff1a;Git - Downloads (git-scm.com) 下面就是 yarn 安装的代码&#xff0c;在终端输入即可。 npm install --global yarn 检查是否安装成功&#xff1a; yarn --version 二、快速建立一个electron项目 其实在Getting Started - …

MYSQL Unknown column ‘appreciation.latitude‘ in ‘where clause‘

问题 笔者编写mysql语句&#xff0c;执行报错 详细问题 笔者sql代码 SELECT ap.*, su.username, wh.wheat_name FROM appreciation ap LEFT JOIN sys_user su ON su.id ap.user_id LEFT JOIN wheat wh ON wh.id ap.crop_id WHERE appreciation.latitude 1报错信息 >…

LeetCode707:设计链表

题目描述 实现 MyLinkedList 类&#xff1a; MyLinkedList() 初始化 MyLinkedList 对象。 int get(int index) 获取链表中下标为 index 的节点的值。如果下标无效&#xff0c;则返回 -1 。 void addAtHead(int val) 将一个值为 val 的节点插入到链表中第一个元素之前。在插入完…

java关键字是什么?关键字有哪些?什么是常量?

1、关键字 &#xff08;1&#xff09;关键字概述&#xff1a;被java语言赋予了特定含义的单词。 &#xff08;2&#xff09;关键字特点&#xff1a; 关键字的字母全部小写&#xff1b;常用的代码编辑器&#xff0c;针对关键字有特殊的颜色标记&#xff0c;非常直观。 以IDE…

【QT+QGIS跨平台编译】之七十一:【QGIS_Analysis跨平台编译】—【qgsrastercalclexer.cpp生成】

文章目录 一、Flex二、生成来源三、构建过程一、Flex Flex (fast lexical analyser generator) 是 Lex 的另一个替代品。它经常和自由软件 Bison 语法分析器生成器 一起使用。Flex 最初由 Vern Paxson 于 1987 年用 C 语言写成。 “flex 是一个生成扫描器的工具,能够识别文本中…

机器学习之分类回归模型(决策数、随机森林)

回归分析 回归分析属于监督学习方法的一种&#xff0c;主要用于预测连续型目标变量&#xff0c;可以预测、计算趋势以及确定变量之间的关系等。 Regession Evaluation Metrics 以下是一些最流行的回归评估指标: 平均绝对误差(MAE):目标变量的预测值与实际值之间的平均绝对差…