【深度学习笔记】优化算法——Adam算法

Adam算法

🏷sec_adam

本章我们已经学习了许多有效优化的技术。
在本节讨论之前,我们先详细回顾一下这些技术:

  • 在 :numref:sec_sgd中,我们学习了:随机梯度下降在解决优化问题时比梯度下降更有效。
  • 在 :numref:sec_minibatch_sgd中,我们学习了:在一个小批量中使用更大的观测值集,可以通过向量化提供额外效率。这是高效的多机、多GPU和整体并行处理的关键。
  • 在 :numref:sec_momentum中我们添加了一种机制,用于汇总过去梯度的历史以加速收敛。
  • 在 :numref:sec_adagrad中,我们通过对每个坐标缩放来实现高效计算的预处理器。
  • 在 :numref:sec_rmsprop中,我们通过学习率的调整来分离每个坐标的缩放。

Adam算法 :cite:Kingma.Ba.2014将所有这些技术汇总到一个高效的学习算法中。
不出预料,作为深度学习中使用的更强大和有效的优化算法之一,它非常受欢迎。
但是它并非没有问题,尤其是 :cite:Reddi.Kale.Kumar.2019表明,有时Adam算法可能由于方差控制不良而发散。
在完善工作中, :cite:Zaheer.Reddi.Sachan.ea.2018给Adam算法提供了一个称为Yogi的热补丁来解决这些问题。
下面我们了解一下Adam算法。

算法

Adam算法的关键组成部分之一是:它使用指数加权移动平均值来估算梯度的动量和二次矩,即它使用状态变量

v t ← β 1 v t − 1 + ( 1 − β 1 ) g t , s t ← β 2 s t − 1 + ( 1 − β 2 ) g t 2 . \begin{aligned} \mathbf{v}_t & \leftarrow \beta_1 \mathbf{v}_{t-1} + (1 - \beta_1) \mathbf{g}_t, \\ \mathbf{s}_t & \leftarrow \beta_2 \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2. \end{aligned} vtstβ1vt1+(1β1)gt,β2st1+(1β2)gt2.

这里 β 1 \beta_1 β1 β 2 \beta_2 β2是非负加权参数。
常将它们设置为 β 1 = 0.9 \beta_1 = 0.9 β1=0.9 β 2 = 0.999 \beta_2 = 0.999 β2=0.999
也就是说,方差估计的移动远远慢于动量估计的移动。
注意,如果我们初始化 v 0 = s 0 = 0 \mathbf{v}_0 = \mathbf{s}_0 = 0 v0=s0=0,就会获得一个相当大的初始偏差。
我们可以通过使用 ∑ i = 0 t β i = 1 − β t 1 − β \sum_{i=0}^t \beta^i = \frac{1 - \beta^t}{1 - \beta} i=0tβi=1β1βt来解决这个问题。
相应地,标准化状态变量由下式获得

v ^ t = v t 1 − β 1 t and  s ^ t = s t 1 − β 2 t . \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_1^t} \text{ and } \hat{\mathbf{s}}_t = \frac{\mathbf{s}_t}{1 - \beta_2^t}. v^t=1β1tvt and s^t=1β2tst.

有了正确的估计,我们现在可以写出更新方程。
首先,我们以非常类似于RMSProp算法的方式重新缩放梯度以获得

g t ′ = η v ^ t s ^ t + ϵ . \mathbf{g}_t' = \frac{\eta \hat{\mathbf{v}}_t}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}. gt=s^t +ϵηv^t.

与RMSProp不同,我们的更新使用动量 v ^ t \hat{\mathbf{v}}_t v^t而不是梯度本身。
此外,由于使用 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon} s^t +ϵ1而不是 1 s ^ t + ϵ \frac{1}{\sqrt{\hat{\mathbf{s}}_t + \epsilon}} s^t+ϵ 1进行缩放,两者会略有差异。
前者在实践中效果略好一些,因此与RMSProp算法有所区分。
通常,我们选择 ϵ = 1 0 − 6 \epsilon = 10^{-6} ϵ=106,这是为了在数值稳定性和逼真度之间取得良好的平衡。

最后,我们简单更新:

x t ← x t − 1 − g t ′ . \mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \mathbf{g}_t'. xtxt1gt.

回顾Adam算法,它的设计灵感很清楚:
首先,动量和规模在状态变量中清晰可见,
它们相当独特的定义使我们移除偏项(这可以通过稍微不同的初始化和更新条件来修正)。
其次,RMSProp算法中两项的组合都非常简单。
最后,明确的学习率 η \eta η使我们能够控制步长来解决收敛问题。

实现

从头开始实现Adam算法并不难。
为方便起见,我们将时间步 t t t存储在hyperparams字典中。
除此之外,一切都很简单。

%matplotlib inline
import torch
from d2l import torch as d2ldef init_adam_states(feature_dim):v_w, v_b = torch.zeros((feature_dim, 1)), torch.zeros(1)s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):with torch.no_grad():v[:] = beta1 * v + (1 - beta1) * p.grads[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)+ eps)p.grad.data.zero_()hyperparams['t'] += 1

现在,我们用以上Adam算法来训练模型,这里我们使用 η = 0.01 \eta = 0.01 η=0.01的学习率。

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam, init_adam_states(feature_dim),{'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.244, 0.015 sec/epoch

在这里插入图片描述

此外,我们可以用深度学习框架自带算法应用Adam算法,这里我们只需要传递配置参数。

trainer = torch.optim.Adam
d2l.train_concise_ch11(trainer, {'lr': 0.01}, data_iter)
loss: 0.254, 0.015 sec/epoch

在这里插入图片描述

Yogi

Adam算法也存在一些问题:
即使在凸环境下,当 s t \mathbf{s}_t st的二次矩估计值爆炸时,它可能无法收敛。
:cite:Zaheer.Reddi.Sachan.ea.2018 s t \mathbf{s}_t st提出了的改进更新和参数初始化。
论文中建议我们重写Adam算法更新如下:

s t ← s t − 1 + ( 1 − β 2 ) ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \left(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\right). stst1+(1β2)(gt2st1).

每当 g t 2 \mathbf{g}_t^2 gt2具有值很大的变量或更新很稀疏时, s t \mathbf{s}_t st可能会太快地“忘记”过去的值。
一个有效的解决方法是将 g t 2 − s t − 1 \mathbf{g}_t^2 - \mathbf{s}_{t-1} gt2st1替换为 g t 2 ⊙ s g n ( g t 2 − s t − 1 ) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}) gt2sgn(gt2st1)
这就是Yogi更新,现在更新的规模不再取决于偏差的量。

s t ← s t − 1 + ( 1 − β 2 ) g t 2 ⊙ s g n ( g t 2 − s t − 1 ) . \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}). stst1+(1β2)gt2sgn(gt2st1).

论文中,作者还进一步建议用更大的初始批量来初始化动量,而不仅仅是初始的逐点估计。

def yogi(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-3for p, (v, s) in zip(params, states):with torch.no_grad():v[:] = beta1 * v + (1 - beta1) * p.grads[:] = s + (1 - beta2) * torch.sign(torch.square(p.grad) - s) * torch.square(p.grad)v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr)+ eps)p.grad.data.zero_()hyperparams['t'] += 1data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi, init_adam_states(feature_dim),{'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.245, 0.015 sec/epoch

在这里插入图片描述

小结

  • Adam算法将许多优化算法的功能结合到了相当强大的更新规则中。
  • Adam算法在RMSProp算法基础上创建的,还在小批量的随机梯度上使用EWMA。
  • 在估计动量和二次矩时,Adam算法使用偏差校正来调整缓慢的启动速度。
  • 对于具有显著差异的梯度,我们可能会遇到收敛性问题。我们可以通过使用更大的小批量或者切换到改进的估计值 s t \mathbf{s}_t st来修正它们。Yogi提供了这样的替代方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736610.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣--动态规划5.最长回文子串

class Solution { public:string longestPalindrome(string s) {// 获取输入字符串的长度int n s.size();// 如果字符串长度为1,直接返回原字符串,因为任何单个字符都是回文串if (n 1)return s;// 创建一个二维数组dp,用于记录子串是否为回…

React-路由小知识

1.默认路由 说明:当访问的是一级路由时,默认的二级路由组件可以得到渲染,只需要在二级路由的位置去掉path,设置index.属性为true。 2.404路由 说明:当浏览器输入ul的路径在整个路由配置中都找不到对应的pth,为了用户体验&#x…

《农商网》商业计划书(附模板下载)

在当今互联网高速发展的时代,农业与电子商务的结合成为了新的经济增长点。《农商网》商业计划书详细阐述了一个以大学生创业为核心的创新项目,旨在通过打造一个全新的农产品在线交易平台,实现农产品的高效流通和价值最大化。该计划书首先对市…

amv是什么文件格式?如何播放amv视频?

AMV文件格式源自于中国公司Actions Semiconductor,最初作为其MP4播放器中使用的专有视频格式。产生于数码媒体发展的需求下,AMV格式为小屏幕便携设备提供了一种高度压缩的视频存储方案。 AMV文件格式的主要特性与使用场景 AMV格式以其独特的特性在小尺寸…

【活动】探索人工智能的“迷惑瞬间”:真实体验与技术挑战

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 标题:探索人工智能的“迷惑瞬间”:真实体验与技术挑战引言…

Elasticsearch:dense vector 数据类型及标量量化

密集向量(dense_vector)字段类型存储数值的密集向量。 密集向量场主要用于 k 最近邻 (kNN) 搜索。 dense_vector 类型不支持聚合或排序。 默认情况下,你可以基于 element_type 添加一个 dend_vector 字段作为 float 数值数组: …

学习Java的第七天

目录 一、什么是数组 二、作用 三、如何使用数组 1、声明数组变量 2、创建数组 示例: 3、数组的使用 示例: 4、数组的遍历 for循环示例(不知道for循环的可以查看我之前发的文章) for-each循环(也就是增强for…

Unity基础学习

目录 基础知识点3D数学——基础Mathf三角函数坐标系 3D数学——向量向量模长和单位向量向量的加减乘除向量点乘向量叉乘向量插值运算 3D数学——四元数为何使用四元数四元数是什么四元数常用方法四元数计算 MonoBehavior中的重要内容延迟函数协同程序协同程序原理 Resources资源…

STM32CubeIDE基础学习-STM32CubeIDE软件工程文件拷贝粘贴

STM32CubeIDE基础学习-STM32CubeIDE软件工程文件拷贝粘贴 前言 在后面开发程序时,往往不需要再重新新建工程的了,可以直接在原有的工程基础上直接复制粘贴新增功能就可以了。 具体的操作方法步骤如下介绍: 第一步:找到一个原有的…

力扣中档题的简单写法:在链表中插入最大公约数

其实暴力遍历开数组也可以,但不如以下新建链表块的方法简单 int FindCommDivisor(int num1, int num2) {int n;int i;n fmin(num1, num2);for (i n; i > 1; i--) {if (num1 % i 0 && num2 % i 0) {return i;}}return 0; }struct ListNode *insertGr…

Mock.js 基本语法与应用笔记

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

python 导入excel空间三维坐标 生成三维曲面地形图 5-3、线条平滑曲面且可通过面观察柱体变化(三)

环境 python:python-3.12.0-amd64 包: matplotlib 3.8.2 pandas 2.1.4 openpyxl 3.1.2 scipy 1.12.0 import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from scipy.interpolate import griddata from matplotlib.c…

【SSM】整合原理和配置实战

文章目录 SSM整合是什么?SSM整合核心问题第一问:SSM整合需要几个IoC容器?第二问:每个IoC容器对应哪些类型组件?第三问:IoC容器之间关系和调用方向?第四问:具体多少配置类以及对应容器…

力扣hot100:22.括号生成(回溯)

复习一下: 回溯法解决的问题都可以抽象为树形结构。回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度,都构成的树的深度。 对于同一层而言,其儿子都是等价的不同情况,因此当儿…

【Poe】保姆级注册教程

AI聊天机器人已成为技术界的热点。Quora推出了其全新的AI聊天机器人应用——poe,为用户提供了一种新的与人工智能进行互动的方式。与其他常见的AI聊天机器人不同,poe支持多家公司的AI系统,例如OpenAI的ChatGPT和Anthropic的聊天机器人。本教程…

【零基础学习01】嵌入式linux驱动中pinctrl和gpio子系统实现

大家好,为了进一步提升大家对实验的认识程度,每个控制实验将加入详细控制思路与流程,欢迎交流学习。 今天给大家分享一下,linux系统里面pinctrl和gpio子系统控制实验,操作硬件为I.MX6ULL开发板。 第一:pinctrl和gpio子系统简介 Linux系统是一个庞大又完善的系统,如果采用…

Window部署Oracle并实现公网环境远程访问本地数据库

文章目录 前言1. 数据库搭建2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射 3. 公网远程访问4. 配置固定TCP端口地址4.1 保留一个固定的公网TCP端口地址4.2 配置固定公网TCP端口地址4.3 测试使用固定TCP端口地址远程Oracle 前言 Oracle,是甲骨文公司的一款关系…

基于单片机的机动车智能远光灯系统设计

目 录 摘 要 I Abstract II 引 言 1 1 主要研究内容及总体设计方案 3 1.1 主要研究内容 3 1.2 系统总体方案选择 3 1.3 系统功能的确定 4 2 硬件电路的设计 5 2.1 单片机控制模块设计 5 2.2 液晶显示模块电路设计 7 2.3 远近灯光电路设计 9 2.4 按键电路设计 9 2.5 超声波电路…

5G与智慧文旅的融合发展:推动旅游业转型升级与可持续发展

随着5G技术的飞速发展和广泛应用,其与智慧文旅的融合发展正成为推动旅游业转型升级与可持续发展的重要力量。5G技术以其高速率、低时延、大连接的特性,为智慧文旅注入了新的活力,助力旅游业实现更高效、更智能、更绿色的发展。本文将深入探讨…

保持长期高效的七个法则(一)7 Rules for Staying Productive Long-Term(1)

Easily the best habit I’ve ever started was to use a productivity system.The idea is simple:organizing all the stuff you need to do (and how you’re going to do it) prevents a lot of internal struggle to get things done. 无疑,我曾经建立过的最好…