深度学习优化算法-Adam算法

Adam算法

Adam算法在RMSProp算法基础上对小批量随机梯度也做了指数加权移动平均。Adam算法可以看做是RMSProp算法与动量法的结合

算法内容

Adam算法使用了动量变量vt\boldsymbol{v}_tvt和RMSProp算法中小批量随机梯度按元素平方的指数加权移动平均变量st\boldsymbol{s}_tst,并在时间步0将它们中每个元素初始化为0。

  • 给定超参数0≤β1<10 \leq \beta_1 < 10β1<1(算法作者建议设为0.9)

时间步ttt的动量变量vt\boldsymbol{v}_tvt即小批量随机梯度gt\boldsymbol{g}_tgt的指数加权移动平均:

vt←β1vt−1+(1−β1)gt.\boldsymbol{v}_t \leftarrow \beta_1 \boldsymbol{v}_{t-1} + (1 - \beta_1) \boldsymbol{g}_t. vtβ1vt1+(1β1)gt.

和RMSProp算法中一样,给定超参数0≤β2<10 \leq \beta_2 < 10β2<1(算法作者建议设为0.999), 将小批量随机梯度按元素平方后的项gt⊙gt\boldsymbol{g}_t \odot \boldsymbol{g}_tgtgt做指数加权移动平均得到st\boldsymbol{s}_tst

st←β2st−1+(1−β2)gt⊙gt.\boldsymbol{s}_t \leftarrow \beta_2 \boldsymbol{s}_{t-1} + (1 - \beta_2) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stβ2st1+(1β2)gtgt.

由于我们将v0\boldsymbol{v}_0v0s0\boldsymbol{s}_0s0中的元素都初始化为0, 在时间步ttt我们得到

vt=(1−β1)∑i=1tβ1t−igi\boldsymbol{v}_t = (1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} \boldsymbol{g}_ivt=(1β1)i=1tβ1tigi

将过去各时间步小批量随机梯度的权值相加,得到
(1−β1)∑i=1tβ1t−i=1−β1t(1-\beta_1) \sum_{i=1}^t \beta_1^{t-i} = 1 - \beta_1^t(1β1)i=1tβ1ti=1β1t

需要注意的是,当ttt较小时,过去各时间步小批量随机梯度权值之和会较小。

例如,当β1=0.9\beta_1 = 0.9β1=0.9时,v1=0.1g1\boldsymbol{v}_1 = 0.1\boldsymbol{g}_1v1=0.1g1。为了消除这样的影响,对于任意时间步ttt,我们可以将vt\boldsymbol{v}_tvt再除以1−β1t1 - \beta_1^t1β1t,从而使过去各时间步小批量随机梯度权值之和为1。这也叫作偏差修正。在Adam算法中,我们对变量vt\boldsymbol{v}_tvtst\boldsymbol{s}_tst均作偏差修正:

v^t←vt1−β1t,\hat{\boldsymbol{v}}_t \leftarrow \frac{\boldsymbol{v}_t}{1 - \beta_1^t}, v^t1β1tvt,

s^t←st1−β2t.\hat{\boldsymbol{s}}_t \leftarrow \frac{\boldsymbol{s}_t}{1 - \beta_2^t}. s^t1β2tst.

接下来,Adam算法使用以上偏差修正后的变量v^t\hat{\boldsymbol{v}}_tv^ts^t\hat{\boldsymbol{s}}_ts^t,将模型参数中每个元素的学习率通过按元素运算重新调整:

gt′←ηv^ts^t+ϵ,\boldsymbol{g}_t' \leftarrow \frac{\eta \hat{\boldsymbol{v}}_t}{\sqrt{\hat{\boldsymbol{s}}_t} + \epsilon},gts^t+ϵηv^t,

其中η\etaη是学习率,ϵ\epsilonϵ是为了维持数值稳定性而添加的常数,如10−810^{-8}108。和AdaGrad算法、RMSProp算法以及AdaDelta算法一样,目标函数自变量中每个元素都分别拥有自己的学习率。最后,使用gt′\boldsymbol{g}_t'gt迭代自变量:

xt←xt−1−gt′.\boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}_t'. xtxt1gt.

实现Adam优化算法

def get_data_ch7():  data = np.genfromtxt('data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)
%matplotlib inline
import torch
import sysfeatures, labels = get_data_ch7()def init_adam_states():v_w, v_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):v[:] = beta1 * v + (1 - beta1) * p.grad.datas[:] = beta2 * s + (1 - beta2) * p.grad.data**2v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p.data -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)hyperparams['t'] += 1

使用学习率为0.01的Adam算法来训练模型。

def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = linreg, squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))set_figsize()plt.plot(np.linspace(0, num_epochs, len(ls)), ls)plt.xlabel('epoch')plt.ylabel('loss')
train_ch7(adam, init_adam_states(), {'lr': 0.01, 't': 1}, features, labels)

也可以使用pytorch内置的optim.Adam实现:

train_pytorch_ch7(torch.optim.Adam, {'lr': 0.01}, features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/507956.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytorch命令式和符号式混合编程

命令式和符号式编程 命令式编程 命令式编程使用编程语句改变程序状态&#xff0c;如下&#xff1a; def add(a, b):return a bdef fancy_func(a, b, c, d):e add(a, b)f add(c, d)g add(e, f)return gfancy_func(1, 2, 3, 4) # 10在运行语句e add(a, b)时&#xff0c;P…

深度学习-自动并行计算

自动并行计算 异步计算 默认情况下&#xff0c;PyTorch中的 GPU 操作是异步的。当调用一个使用 GPU 的函数时&#xff0c;这些操作会在特定的设备上排队但不一定会在稍后立即执行。这就使我们可以并行更多的计算&#xff0c;包括 CPU 或其他 GPU 上的操作。 一般情况下&…

pytorch多GPU计算

pytorch多GPU计算 如果正确安装了NVIDIA驱动&#xff0c;我们可以通过在命令行输入nvidia-smi命令来查看当前计算机上的全部GPU 定义一个模型&#xff1a; import torch net torch.nn.Linear(10, 1).cuda() netoutput: Linear(in_features10, out_features1, biasTrue)要想…

深度学习-计算机视觉--图像增广

图像增广 大规模数据集是成功应用深度神经网络的前提。图像增广&#xff08;image augmentation&#xff09;技术通过对训练图像做一系列随机改变&#xff0c;来产生相似但又不同的训练样本&#xff0c;从而扩大训练数据集的规模。 图像增广的另一种解释是&#xff0c;随机改…

pytorch深度学习-微调(fine tuning)

微调&#xff08;fine tuning&#xff09; 首先举一个例子&#xff0c;假设我们想从图像中识别出不同种类的椅子&#xff0c;然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子&#xff0c;为每种椅子拍摄1,000张不同角度的图像&#xff0c;然后在收集到的图像…

c语言封闭曲线分割平面_高手的平面课堂:8种常用的设计排版方式,告别通宵加班...

重复、对比、对齐以及亲密性是传统平面排版的四大原则&#xff0c;即将元素重复运用(包括颜色、形状、材质、字体、空间关系等)以增加画面的条理性和整体性&#xff1b;避免页面上的元素形态与关系构建过于相似&#xff1b;画面上的每一元素都应该与另一个元素存在某种视觉联系…

我的世界java版和基岩版对比_基岩版Beta1.11.0.1发布

本帖来自好游快爆-我的世界精选推荐原帖作者:好游快爆用户3302482我的世界基岩版1.11.0.1测试版发布了&#xff0c;Minecraft基岩版1.11仍未发布&#xff0c;1.11.0.1为测试版本&#xff0c;Beta版本可能不稳定&#xff0c;并不代表最终版本质量&#xff0c;请在加入测试版之前…

机器人电焊电流电压怎么调_【华光】HG1000型电焊机现场校准仪

机器简介HG-1000型电焊机现场校准仪是依据检定规程JJG124&#xff0d;2005《电流表、电压表、功率表和电阻表检定规程》、JJG(航天)38-1987《直流标准电流源检定规程》、JJG(航天)51-1999《交流标准电流源检定规程》的要求而设计的校准设备。主要用来校验各种用电焊机(如交流手…

循环机换变速箱油教程_变速箱油用循环机换还是重力换更好?一次讲清楚,新手司机学学...

现在换变速箱油有些只要几百块钱&#xff0c;有些要一两千&#xff0c;之所以差价这么大是因为这里面涉及到换变速箱油时用什么方法去换油的问题。目前比较常见换油法是重力换油法和循环换油法。重力换油法就跟平时换机油是一样的&#xff0c;把变速箱底部的螺丝拧开之后让油滴…

mongodb python 存文件_Python保存MongoDB上的文件到本地的方法介绍

本文实例讲述了Python保存MongoDB上的文件到本地的方法。分享给大家供大家参考&#xff0c;具体如下&#xff1a;MongoDB上的文档通过GridFS来操作&#xff0c;Python也可以通过pymongo连接MongoDB数据库&#xff0c;使用pymongo模块的gridfs方法操作文档。以下示例是把MongoDB…

mongodb 监控权限_MongoDB - 监控

随着MongoDB中保存的数据越来越多&#xff0c;对MongoDB服务状态的监控也越来越重要&#xff0c;经常关注服务是否健康&#xff0c;才能防止故障以及优化。1.静态监控db.serverStatus()使用mongo命令进入shell客户端后输入以下命令可以查看MongoDB服务的状态&#xff0c;有助于…

pytorch深度学习-机器视觉-目标检测和边界框简介

机器视觉之目标检测和边界框简介 在图像分类任务里&#xff0c;我们假设图像里只有一个主体目标&#xff0c;并关注如何识别该目标的类别。然而&#xff0c;很多时候图像里有多个我们感兴趣的目标&#xff0c;我们不仅想知道它们的类别&#xff0c;还想得到它们在图像中的具体…

消防荷载楼板按弹性还是塑性计算_第二节 消防登高面、消防救援场地和灭火救援窗...

一、定义1、消防登高面&#xff1a;登高消防车能够靠近高层主体建筑&#xff0c;便于消防车作业和消防人员进入高层建筑进行抢救人员和扑救火灾的建筑立面称为该建筑的消防登高面&#xff0c;也称建筑的消防扑救面。2、消防救援场地&#xff1a;在高层建筑的消防登高面一侧&…

深度学习-词嵌入(word2vec)

词嵌入&#xff08;word2vec&#xff09; 自然语言是一套用来表达含义的复杂系统。在这套系统中&#xff0c;词是表义的基本单元。顾名思义&#xff0c;词向量是用来表示词的向量&#xff0c;也可被认为是词的特征向量或表征。把词映射为实数域向量的技术也叫词嵌入&#xff0…

ggplot2箱式图两两比较_作图技巧024篇ggplot2在循环中的坑

“ggplot2在循环中的输出”生活科学哥-R语言科学 2020-12-23 8:28ggplot2用过之后&#xff0c;你肯定会爱上它&#xff1b;结合一些不错的包&#xff0c;可以得到非常有展现力的图片&#xff0c;但是呢&#xff0c;有时也会碰到一些奇怪的情况。今天来们来看看&#xff0c;其中…

深度学习-自然语言处理中的近似训练

自然语言处理中的近似训练 跳字模型的核心在于使用softmax运算得到给定中心词wcw_cwc​来生成背景词wow_owo​的条件概率 P(wo∣wc)exp(uo⊤vc)∑i∈Vexp(ui⊤vc).P(w_o \mid w_c) \frac{\text{exp}(\boldsymbol{u}_o^\top \boldsymbol{v}_c)}{ \sum_{i \in \mathcal{V}} \te…

pytorch-word2vec的实例实现

word2vec的实例实现 实现词嵌入word2vec中的跳字模型和近似训练中的负采样以及二次采样&#xff08;subsampling&#xff09;&#xff0c;在语料库上训练词嵌入模型的实现。 首先导入实验所需的包或模块。 import collections import math import random import sys import …

pytorch-LSTM的输入和输出尺寸

LSTM的输入和输出尺寸 CLASS torch.nn.LSTM(*args, **kwargs)Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. For each element in the input sequence, each layer computes the following function: 对于一个输入序列实现多层长短期记忆的…

python中的[-1]、[:-1]、[::-1]、[n::-1]

import numpy as np anp.random.rand(4) print(a)[0.48720333 0.67178384 0.65662903 0.40513918]print(a[-1]) #取最后一个元素 0.4051391774882336print(a[:-1]) #去除最后一个元素 [0.48720333 0.67178384 0.65662903]print(a[::-1]) #逆序 [0.40513918 0.65662903 0.67178…

torchtext.data.Field

torchtext.data.Field 类接口 class torchtext.data.Field(sequentialTrue, use_vocabTrue, init_tokenNone, eos_tokenNone, fix_lengthNone, dtypetorch.int64, preprocessingNone, postprocessingNone, lowerFalse, tokenizeNone, tokenizer_languageen, include_lengthsF…