深度学习优化算法-Adam算法

Adam算法

Adam算法在RMSProp算法基础上对小批量随机梯度也做了指数加权移动平均。Adam算法可以看做是RMSProp算法与动量法的结合

算法内容

Adam算法使用了动量变量vt\boldsymbol{v}_tvt和RMSProp算法中小批量随机梯度按元素平方的指数加权移动平均变量st\boldsymbol{s}_tst,并在时间步0将它们中每个元素初始化为0。

  • 给定超参数0≤β1<10 \leq \beta_1 < 10β1<1(算法作者建议设为0.9)

时间步ttt的动量变量vt\boldsymbol{v}_tvt即小批量随机梯度gt\boldsymbol{g}_tgt的指数加权移动平均:

vt←β1vt−1+(1−β1)gt.\boldsymbol{v}_t \leftarrow \beta_1 \boldsymbol{v}_{t-1} + (1 - \beta_1) \boldsymbol{g}_t. vtβ1vt1+(1β1)gt.

和RMSProp算法中一样,给定超参数0≤β2<10 \leq \beta_2 < 10β2<1(算法作者建议设为0.999), 将小批量随机梯度按元素平方后的项gt⊙gt\boldsymbol{g}_t \odot \boldsymbol{g}_tgtgt做指数加权移动平均得到st\boldsymbol{s}_tst

st←β2st−1+(1−β2)gt⊙gt.\boldsymbol{s}_t \leftarrow \beta_2 \boldsymbol{s}_{t-1} + (1 - \beta_2) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stβ2st1+(1β2)gtgt.

由于我们将v0\boldsymbol{v}_0v0s0\boldsymbol{s}_0s0中的元素都初始化为0, 在时间步ttt我们得到

vt=(1−β1)∑i=1tβ1t−igi\boldsymbol{v}_t = (1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} \boldsymbol{g}_ivt=(1β1)i=1tβ1tigi

将过去各时间步小批量随机梯度的权值相加,得到
(1−β1)∑i=1tβ1t−i=1−β1t(1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} = 1 - \beta_1^t(1β1)i=1tβ1ti=1β1t

需要注意的是,当ttt较小时,过去各时间步小批量随机梯度权值之和会较小。

例如,当β1=0.9\beta_1 = 0.9β1=0.9时,v1=0.1g1\boldsymbol{v}_1 = 0.1\boldsymbol{g}_1v1=0.1g1。为了消除这样的影响,对于任意时间步ttt,我们可以将vt\boldsymbol{v}_tvt再除以1−β1t1 - \beta_1^t1β1t,从而使过去各时间步小批量随机梯度权值之和为1。这也叫作偏差修正。在Adam算法中,我们对变量vt\boldsymbol{v}_tvtst\boldsymbol{s}_tst均作偏差修正:

v^t←vt1−β1t,\hat{\boldsymbol{v}}_t \leftarrow \frac{\boldsymbol{v}_t}{1 - \beta_1^t}, v^t1β1tvt,

s^t←st1−β2t.\hat{\boldsymbol{s}}_t \leftarrow \frac{\boldsymbol{s}_t}{1 - \beta_2^t}. s^t1β2tst.

接下来,Adam算法使用以上偏差修正后的变量v^t\hat{\boldsymbol{v}}_tv^ts^t\hat{\boldsymbol{s}}_ts^t,将模型参数中每个元素的学习率通过按元素运算重新调整:

gt′←ηv^ts^t+ϵ,\boldsymbol{g}_t' \leftarrow \frac{\eta \hat{\boldsymbol{v}}_t}{\sqrt{\hat{\boldsymbol{s}}_t} + \epsilon},gts^t+ϵηv^t,

其中η\etaη是学习率,ϵ\epsilonϵ是为了维持数值稳定性而添加的常数,如10−810^{-8}108。和AdaGrad算法、RMSProp算法以及AdaDelta算法一样,目标函数自变量中每个元素都分别拥有自己的学习率。最后,使用gt′\boldsymbol{g}_t'gt迭代自变量:

xt←xt−1−gt′.\boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}_t'. xtxt1gt.

实现Adam优化算法

def get_data_ch7():  data = np.genfromtxt('data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)
%matplotlib inline
import torch
import sysfeatures, labels = get_data_ch7()def init_adam_states():v_w, v_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):v[:] = beta1 * v + (1 - beta1) * p.grad.datas[:] = beta2 * s + (1 - beta2) * p.grad.data**2v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p.data -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)hyperparams['t'] += 1

使用学习率为0.01的Adam算法来训练模型。

def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = linreg, squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))set_figsize()plt.plot(np.linspace(0, num_epochs, len(ls)), ls)plt.xlabel('epoch')plt.ylabel('loss')
train_ch7(adam, init_adam_states(), {'lr': 0.01, 't': 1}, features, labels)

也可以使用pytorch内置的optim.Adam实现:

train_pytorch_ch7(torch.optim.Adam, {'lr': 0.01}, features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/507956.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习-计算机视觉--图像增广

图像增广 大规模数据集是成功应用深度神经网络的前提。图像增广&#xff08;image augmentation&#xff09;技术通过对训练图像做一系列随机改变&#xff0c;来产生相似但又不同的训练样本&#xff0c;从而扩大训练数据集的规模。 图像增广的另一种解释是&#xff0c;随机改…

pytorch深度学习-微调(fine tuning)

微调&#xff08;fine tuning&#xff09; 首先举一个例子&#xff0c;假设我们想从图像中识别出不同种类的椅子&#xff0c;然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子&#xff0c;为每种椅子拍摄1,000张不同角度的图像&#xff0c;然后在收集到的图像…

c语言封闭曲线分割平面_高手的平面课堂:8种常用的设计排版方式,告别通宵加班...

重复、对比、对齐以及亲密性是传统平面排版的四大原则&#xff0c;即将元素重复运用(包括颜色、形状、材质、字体、空间关系等)以增加画面的条理性和整体性&#xff1b;避免页面上的元素形态与关系构建过于相似&#xff1b;画面上的每一元素都应该与另一个元素存在某种视觉联系…

我的世界java版和基岩版对比_基岩版Beta1.11.0.1发布

本帖来自好游快爆-我的世界精选推荐原帖作者:好游快爆用户3302482我的世界基岩版1.11.0.1测试版发布了&#xff0c;Minecraft基岩版1.11仍未发布&#xff0c;1.11.0.1为测试版本&#xff0c;Beta版本可能不稳定&#xff0c;并不代表最终版本质量&#xff0c;请在加入测试版之前…

机器人电焊电流电压怎么调_【华光】HG1000型电焊机现场校准仪

机器简介HG-1000型电焊机现场校准仪是依据检定规程JJG124&#xff0d;2005《电流表、电压表、功率表和电阻表检定规程》、JJG(航天)38-1987《直流标准电流源检定规程》、JJG(航天)51-1999《交流标准电流源检定规程》的要求而设计的校准设备。主要用来校验各种用电焊机(如交流手…

循环机换变速箱油教程_变速箱油用循环机换还是重力换更好?一次讲清楚,新手司机学学...

现在换变速箱油有些只要几百块钱&#xff0c;有些要一两千&#xff0c;之所以差价这么大是因为这里面涉及到换变速箱油时用什么方法去换油的问题。目前比较常见换油法是重力换油法和循环换油法。重力换油法就跟平时换机油是一样的&#xff0c;把变速箱底部的螺丝拧开之后让油滴…

pytorch深度学习-机器视觉-目标检测和边界框简介

机器视觉之目标检测和边界框简介 在图像分类任务里&#xff0c;我们假设图像里只有一个主体目标&#xff0c;并关注如何识别该目标的类别。然而&#xff0c;很多时候图像里有多个我们感兴趣的目标&#xff0c;我们不仅想知道它们的类别&#xff0c;还想得到它们在图像中的具体…

消防荷载楼板按弹性还是塑性计算_第二节 消防登高面、消防救援场地和灭火救援窗...

一、定义1、消防登高面&#xff1a;登高消防车能够靠近高层主体建筑&#xff0c;便于消防车作业和消防人员进入高层建筑进行抢救人员和扑救火灾的建筑立面称为该建筑的消防登高面&#xff0c;也称建筑的消防扑救面。2、消防救援场地&#xff1a;在高层建筑的消防登高面一侧&…

深度学习-词嵌入(word2vec)

词嵌入&#xff08;word2vec&#xff09; 自然语言是一套用来表达含义的复杂系统。在这套系统中&#xff0c;词是表义的基本单元。顾名思义&#xff0c;词向量是用来表示词的向量&#xff0c;也可被认为是词的特征向量或表征。把词映射为实数域向量的技术也叫词嵌入&#xff0…

ggplot2箱式图两两比较_作图技巧024篇ggplot2在循环中的坑

“ggplot2在循环中的输出”生活科学哥-R语言科学 2020-12-23 8:28ggplot2用过之后&#xff0c;你肯定会爱上它&#xff1b;结合一些不错的包&#xff0c;可以得到非常有展现力的图片&#xff0c;但是呢&#xff0c;有时也会碰到一些奇怪的情况。今天来们来看看&#xff0c;其中…

character-level OCR之Character Region Awareness for Text Detection(CRAFT) 论文阅读

Character Region Awareness for Text Detection 论文阅读 论文地址(arXiv) &#xff0c;pytorch版本代码地址 最近在看一些OCR的问题&#xff0c;CRAFT是在场景OCR中效果比较好的模型&#xff0c;记录一下论文的阅读 已有的文本检测工作大致如下&#xff1a; 基于回归的文…

c# wpf 面试_【远程面试】九强通信 | 九洲电器集团全资子公司

成都IT内推圈成立于2016年,专注成都IT互联网领域的招聘与求职;覆盖精准IT人群10W,通过内推圈推荐且已入职人数超过5000,合作公司均系成都知名或靠谱公司.此公众号每天7:30AM准时推送当天职位详情,敬请关注并置顶&#xff01;岗位投递一、登陆内推圈官网: www.itneituiquan.com,…

ViT(Vision Transformer)学习

ViT(Vison Transformer)学习 Paper:An image is worth 1616 words: transformers for image recognition at scale. In ICLR, 2021. Transformer 在 NLP领域大放异彩&#xff0c;并且随着模型和数据集的不断增长&#xff0c;仍然没有表现出饱和的迹象。这使得使用更大规模的数…

cpri带宽不足的解决方法_u盘容量不足怎么办 u盘容量不足解决方法【介绍】

我们在使用u盘的时候总能碰到各种各样的问题&#xff0c;其中u盘容量不足问题也是神烦&#xff0c;很多时候打开并没有发现有文件存在&#xff0c;但是在你存文件的时候又被提示u盘容量不足无法操作&#xff0c;关于这个问题u启动通过整理和大家一起分享下解决办法。1、u盘里的…

复合的赋值运算符例题_Java学习:运算符的使用与注意事项

运算符的使用与注意事项四则运算当中的加号“”有常见的三种用法&#xff1a;对于数值来&#xff0c;那就是加法。对于字符char类型来说&#xff0c;在计算之前&#xff0c;char会被提升成为int&#xff0c;然后再计算。char类型字符&#xff0c;和int类型数字之间的对照关系比…

腾讯会议如何使用讲演者模式进行汇报(nian gao)

腾讯会议如何使用讲演者模式进行汇报&#xff08;nian gao&#xff09; 首先列出步骤&#xff0c;再一一演示&#xff1a; altf5 开启讲演者模式&#xff0c;调整讲演者模式的窗口为小窗alttab 切换回腾讯会议界面&#xff0c;屏幕共享power point窗口&#xff08;注意不是“…

bulk这个词的用法_15、形容词与副词(二)比较的用法

初中英语语法——形容词与副词(二)比较的用法语法解释1、形容词与副词比较级和最高级的规则变化单音节词与部分双音节词&#xff1a;(1)一般情况加-er&#xff0c;-estlong-longer-longest strong-stronger-strongestclean-cleaner-cleanest(2)以不发音的e结尾的词&#xff0c;…

retinex 的水下图像增强算法_图像增强论文:腾讯优图CVPR2019

Underexposed Photo Enhancement using Deep Illumination Estimation基于深度学习优化光照的暗光下的图像增强论文地址&#xff1a;Underexposed Photo Enhancement using Deep Illumination Estimation暗光拍照也清晰&#xff0c;这是手机厂商目前激烈竞争的新拍照目标。提出…

python 实现 BCH 纠错码的方法

python 实现 BCH 纠错码的方法 BCH码是一类重要的纠错码&#xff0c;它把信源待发的信息序列按固定的κ位一组划分成消息组&#xff0c;再将每一消息组独立变换成长为n(n>κ)的二进制数字组&#xff0c;称为码字。如果消息组的数目为M(显然M>2),由此所获得的M个码字的全…

结构体引用_C/C++结构体完全攻略

结构体是一个由程序员定义的数据类型&#xff0c;可以容纳许多不同的数据值。在过去&#xff0c;面向对象编程的应用尚未普及之前&#xff0c;程序员通常使用这些从逻辑上连接在一起的数据组合到一个单元中。一旦结构体类型被声明并且其数据成员被标识&#xff0c;即可创建该类…