深度学习优化算法:RMSProp算法

RMSProp算法

在AdaGrad算法中,因为调整学习率时分母上的变量st\boldsymbol{s}_tst一直在累加按元素平方的小批量随机梯度,所以目标函数自变量每个元素的学习率在迭代过程中一直在降低(或不变)。因此,当学习率在迭代早期降得较快且当前解依然不佳时,AdaGrad算法在迭代后期由于学习率过小,可能较难找到一个有用的解。为了解决这一问题,RMSProp算法对AdaGrad算法做了一点小小的修改。

算法内容

之前说过指数加权移动平均。不同于AdaGrad算法里状态变量st\boldsymbol{s}_tst是截至时间步ttt所有小批量随机梯度gt\boldsymbol{g}_tgt按元素平方和,RMSProp算法将这些梯度按元素平方做指数加权移动平均

具体来说,给定超参数0≤γ<10 \leq \gamma < 10γ<1,RMSProp算法在时间步t>0t>0t>0计算

st←γst−1+(1−γ)gt⊙gt.\boldsymbol{s}_t \leftarrow \gamma \boldsymbol{s}_{t-1} + (1 - \gamma) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stγst1+(1γ)gtgt.

RMSProp算法将目标函数自变量中每个元素的学习率通过按元素运算重新调整,然后更新自变量,这是和AdaGrad算法一样的梯度下降公式:

xt←xt−1−ηst+ϵ⊙gt,\boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \frac{\eta}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, xtxt1st+ϵηgt,

其中

  • η\etaη是学习率
  • ϵ\epsilonϵ是为了维持数值稳定性而添加的常数,如10−610^{-6}106

因为RMSProp算法的状态变量st\boldsymbol{s}_tst是对平方项gt⊙gt\boldsymbol{g}_t \odot \boldsymbol{g}_tgtgt的指数加权移动平均,所以可以看作是最近1/(1−γ)1/(1-\gamma)1/(1γ)个时间步的小批量随机梯度平方项的加权平均。如此一来,自变量每个元素的学习率在迭代过程中就不再一直降低(或不变)。

还是使用相同的例子

  • 目标函数f(x)=0.1x12+2x22f(\boldsymbol{x})=0.1x_1^2+2x_2^2f(x)=0.1x12+2x22
  • 学习率为0.4

在AdaGrad算法中,自变量在迭代后期的移动幅度较小。但在同样的学习率下,RMSProp算法可以更快逼近最优解。

from matplotlib import pyplot as pltdef show_trace_2d(f, results):  plt.plot(*zip(*results), '-o', color='#ff7f0e')x1, x2 = np.meshgrid(np.arange(-5.5, 1.0, 0.1), np.arange(-3.0, 1.0, 0.1))plt.contour(x1, x2, f(x1, x2), colors='#1f77b4')plt.xlabel('x1')plt.ylabel('x2')def train_2d(trainer):  x1, x2, s1, s2 = -5, -2, 0, 0  # s1和s2是自变量状态,本章后续几节会使用results = [(x1, x2)]for i in range(20):x1, x2, s1, s2 = trainer(x1, x2, s1, s2)results.append((x1, x2))print('epoch %d, x1 %f, x2 %f' % (i + 1, x1, x2))return results
%matplotlib inline
import math
import torchdef rmsprop_2d(x1, x2, s1, s2):g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6s1 = gamma * s1 + (1 - gamma) * g1 ** 2s2 = gamma * s2 + (1 - gamma) * g2 ** 2x1 -= eta / math.sqrt(s1 + eps) * g1x2 -= eta / math.sqrt(s2 + eps) * g2return x1, x2, s1, s2def f_2d(x1, x2):return 0.1 * x1 ** 2 + 2 * x2 ** 2eta, gamma = 0.4, 0.9
show_trace_2d(f_2d, train_2d(rmsprop_2d))

实现RMSProp算法

def get_data_ch7():  data = np.genfromtxt('data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)
features, labels = get_data_ch7()def init_rmsprop_states():s_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)s_b = torch.zeros(1, dtype=torch.float32)return (s_w, s_b)def rmsprop(params, states, hyperparams):gamma, eps = hyperparams['gamma'], 1e-6for p, s in zip(params, states):s.data = gamma * s.data + (1 - gamma) * (p.grad.data)**2p.data -= hyperparams['lr'] * p.grad.data / torch.sqrt(s + eps)

将初始学习率设为0.01,并将超参数γ\gammaγ设为0.9。此时,变量st\boldsymbol{s}_tst可看作是最近1/(1−0.9)=101/(1-0.9) = 101/(10.9)=10个时间步的平方项gt⊙gt\boldsymbol{g}_t \odot \boldsymbol{g}_tgtgt的加权平均。

def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = linreg, squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))set_figsize()plt.plot(np.linspace(0, num_epochs, len(ls)), ls)plt.xlabel('epoch')plt.ylabel('loss')
train_ch7(rmsprop, init_rmsprop_states(), {'lr': 0.01, 'gamma': 0.9}, features, labels)

亦可以使用pytorch内置的optim.RMSProp算法来实现:

train_pytorch_ch7(torch.optim.RMSprop, {'lr': 0.01, 'alpha': 0.9},features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/507959.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习-参数与超参数

参数(parameters)/模型参数 由模型通过学习得到的变量比如权重、偏置 超参数(hyperparameters)/算法参数 根据经验进行设定&#xff0c;影响到权重和偏置的大小比如迭代次数、隐藏层的层数、每层神经元的个数、学习速率等

深度学习优化算法-AdaDelta算法

AdaDelta算法 除了RMSProp算法以外&#xff0c;另一个常用优化算法AdaDelta算法也针对AdaGrad算法在迭代后期可能较难找到有用解的问题做了改进 [1]。 不一样的是&#xff0c;AdaDelta算法没有学习率这个超参数。 它通过使用有关自变量更新量平方的指数加权移动平均的项来替代…

深度学习优化算法-Adam算法

Adam算法 Adam算法在RMSProp算法基础上对小批量随机梯度也做了指数加权移动平均。Adam算法可以看做是RMSProp算法与动量法的结合。 算法内容 Adam算法使用了动量变量vt\boldsymbol{v}_tvt​和RMSProp算法中小批量随机梯度按元素平方的指数加权移动平均变量st\boldsymbol{s}_…

pytorch命令式和符号式混合编程

命令式和符号式编程 命令式编程 命令式编程使用编程语句改变程序状态&#xff0c;如下&#xff1a; def add(a, b):return a bdef fancy_func(a, b, c, d):e add(a, b)f add(c, d)g add(e, f)return gfancy_func(1, 2, 3, 4) # 10在运行语句e add(a, b)时&#xff0c;P…

深度学习-自动并行计算

自动并行计算 异步计算 默认情况下&#xff0c;PyTorch中的 GPU 操作是异步的。当调用一个使用 GPU 的函数时&#xff0c;这些操作会在特定的设备上排队但不一定会在稍后立即执行。这就使我们可以并行更多的计算&#xff0c;包括 CPU 或其他 GPU 上的操作。 一般情况下&…

pytorch多GPU计算

pytorch多GPU计算 如果正确安装了NVIDIA驱动&#xff0c;我们可以通过在命令行输入nvidia-smi命令来查看当前计算机上的全部GPU 定义一个模型&#xff1a; import torch net torch.nn.Linear(10, 1).cuda() netoutput: Linear(in_features10, out_features1, biasTrue)要想…

深度学习-计算机视觉--图像增广

图像增广 大规模数据集是成功应用深度神经网络的前提。图像增广&#xff08;image augmentation&#xff09;技术通过对训练图像做一系列随机改变&#xff0c;来产生相似但又不同的训练样本&#xff0c;从而扩大训练数据集的规模。 图像增广的另一种解释是&#xff0c;随机改…

pytorch深度学习-微调(fine tuning)

微调&#xff08;fine tuning&#xff09; 首先举一个例子&#xff0c;假设我们想从图像中识别出不同种类的椅子&#xff0c;然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子&#xff0c;为每种椅子拍摄1,000张不同角度的图像&#xff0c;然后在收集到的图像…

c语言封闭曲线分割平面_高手的平面课堂:8种常用的设计排版方式,告别通宵加班...

重复、对比、对齐以及亲密性是传统平面排版的四大原则&#xff0c;即将元素重复运用(包括颜色、形状、材质、字体、空间关系等)以增加画面的条理性和整体性&#xff1b;避免页面上的元素形态与关系构建过于相似&#xff1b;画面上的每一元素都应该与另一个元素存在某种视觉联系…

我的世界java版和基岩版对比_基岩版Beta1.11.0.1发布

本帖来自好游快爆-我的世界精选推荐原帖作者:好游快爆用户3302482我的世界基岩版1.11.0.1测试版发布了&#xff0c;Minecraft基岩版1.11仍未发布&#xff0c;1.11.0.1为测试版本&#xff0c;Beta版本可能不稳定&#xff0c;并不代表最终版本质量&#xff0c;请在加入测试版之前…

机器人电焊电流电压怎么调_【华光】HG1000型电焊机现场校准仪

机器简介HG-1000型电焊机现场校准仪是依据检定规程JJG124&#xff0d;2005《电流表、电压表、功率表和电阻表检定规程》、JJG(航天)38-1987《直流标准电流源检定规程》、JJG(航天)51-1999《交流标准电流源检定规程》的要求而设计的校准设备。主要用来校验各种用电焊机(如交流手…

循环机换变速箱油教程_变速箱油用循环机换还是重力换更好?一次讲清楚,新手司机学学...

现在换变速箱油有些只要几百块钱&#xff0c;有些要一两千&#xff0c;之所以差价这么大是因为这里面涉及到换变速箱油时用什么方法去换油的问题。目前比较常见换油法是重力换油法和循环换油法。重力换油法就跟平时换机油是一样的&#xff0c;把变速箱底部的螺丝拧开之后让油滴…

mongodb python 存文件_Python保存MongoDB上的文件到本地的方法介绍

本文实例讲述了Python保存MongoDB上的文件到本地的方法。分享给大家供大家参考&#xff0c;具体如下&#xff1a;MongoDB上的文档通过GridFS来操作&#xff0c;Python也可以通过pymongo连接MongoDB数据库&#xff0c;使用pymongo模块的gridfs方法操作文档。以下示例是把MongoDB…

mongodb 监控权限_MongoDB - 监控

随着MongoDB中保存的数据越来越多&#xff0c;对MongoDB服务状态的监控也越来越重要&#xff0c;经常关注服务是否健康&#xff0c;才能防止故障以及优化。1.静态监控db.serverStatus()使用mongo命令进入shell客户端后输入以下命令可以查看MongoDB服务的状态&#xff0c;有助于…

pytorch深度学习-机器视觉-目标检测和边界框简介

机器视觉之目标检测和边界框简介 在图像分类任务里&#xff0c;我们假设图像里只有一个主体目标&#xff0c;并关注如何识别该目标的类别。然而&#xff0c;很多时候图像里有多个我们感兴趣的目标&#xff0c;我们不仅想知道它们的类别&#xff0c;还想得到它们在图像中的具体…

消防荷载楼板按弹性还是塑性计算_第二节 消防登高面、消防救援场地和灭火救援窗...

一、定义1、消防登高面&#xff1a;登高消防车能够靠近高层主体建筑&#xff0c;便于消防车作业和消防人员进入高层建筑进行抢救人员和扑救火灾的建筑立面称为该建筑的消防登高面&#xff0c;也称建筑的消防扑救面。2、消防救援场地&#xff1a;在高层建筑的消防登高面一侧&…

深度学习-词嵌入(word2vec)

词嵌入&#xff08;word2vec&#xff09; 自然语言是一套用来表达含义的复杂系统。在这套系统中&#xff0c;词是表义的基本单元。顾名思义&#xff0c;词向量是用来表示词的向量&#xff0c;也可被认为是词的特征向量或表征。把词映射为实数域向量的技术也叫词嵌入&#xff0…

ggplot2箱式图两两比较_作图技巧024篇ggplot2在循环中的坑

“ggplot2在循环中的输出”生活科学哥-R语言科学 2020-12-23 8:28ggplot2用过之后&#xff0c;你肯定会爱上它&#xff1b;结合一些不错的包&#xff0c;可以得到非常有展现力的图片&#xff0c;但是呢&#xff0c;有时也会碰到一些奇怪的情况。今天来们来看看&#xff0c;其中…

深度学习-自然语言处理中的近似训练

自然语言处理中的近似训练 跳字模型的核心在于使用softmax运算得到给定中心词wcw_cwc​来生成背景词wow_owo​的条件概率 P(wo∣wc)exp(uo⊤vc)∑i∈Vexp(ui⊤vc).P(w_o \mid w_c) \frac{\text{exp}(\boldsymbol{u}_o^\top \boldsymbol{v}_c)}{ \sum_{i \in \mathcal{V}} \te…

pytorch-word2vec的实例实现

word2vec的实例实现 实现词嵌入word2vec中的跳字模型和近似训练中的负采样以及二次采样&#xff08;subsampling&#xff09;&#xff0c;在语料库上训练词嵌入模型的实现。 首先导入实验所需的包或模块。 import collections import math import random import sys import …