【深度学习笔记】7_7 AdaDelta算法

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.7 AdaDelta算法

除了RMSProp算法以外,另一个常用优化算法AdaDelta算法也针对AdaGrad算法在迭代后期可能较难找到有用解的问题做了改进 [1]。有意思的是,AdaDelta算法没有学习率这一超参数

Adadelta是一种自适应学习率的方法,用于神经网络的训练过程中。 它的基本思想是避免使用手动调整学习率的方法来控制训练过程,而是自动调整学习率,使得训练过程更加顺畅。

7.7.1 算法

AdaDelta算法也像RMSProp算法一样,使用了小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方的指数加权移动平均变量 s t \boldsymbol{s}_t st。在时间步0,它的所有元素被初始化为0。给定超参数 0 ≤ ρ < 1 0 \leq \rho < 1 0ρ<1(对应RMSProp算法中的 γ \gamma γ),在时间步 t > 0 t>0 t>0,同RMSProp算法一样计算

s t ← ρ s t − 1 + ( 1 − ρ ) g t ⊙ g t . \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t. stρst1+(1ρ)gtgt.

与RMSProp算法不同的是,AdaDelta算法还维护一个额外的状态变量 Δ x t \Delta\boldsymbol{x}_t Δxt,其元素同样在时间步0时被初始化为0。我们使用 Δ x t − 1 \Delta\boldsymbol{x}_{t-1} Δxt1来计算自变量的变化量:

g t ′ ← Δ x t − 1 + ϵ s t + ϵ ⊙ g t , \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, gtst+ϵΔxt1+ϵ gt,

其中 ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 5 10^{-5} 105。接着更新自变量:

x t ← x t − 1 − g t ′ . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t. xtxt1gt.

最后,我们使用 Δ x t \Delta\boldsymbol{x}_t Δxt来记录自变量变化量 g t ′ \boldsymbol{g}'_t gt按元素平方的指数加权移动平均:

Δ x t ← ρ Δ x t − 1 + ( 1 − ρ ) g t ′ ⊙ g t ′ . \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t. ΔxtρΔxt1+(1ρ)gtgt.

可以看到,如不考虑 ϵ \epsilon ϵ的影响,AdaDelta算法跟RMSProp算法的不同之处在于使用 Δ x t − 1 \sqrt{\Delta\boldsymbol{x}_{t-1}} Δxt1 来替代学习率 η \eta η

7.7.2 从零开始实现

AdaDelta算法需要对每个自变量维护两个状态变量,即 s t \boldsymbol{s}_t st Δ x t \Delta\boldsymbol{x}_t Δxt。我们按AdaDelta算法中的公式实现该算法。

%matplotlib inline
import torch
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2lfeatures, labels = d2l.get_data_ch7()def init_adadelta_states():s_w, s_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)delta_w, delta_b = torch.zeros((features.shape[1], 1), dtype=torch.float32), torch.zeros(1, dtype=torch.float32)return ((s_w, delta_w), (s_b, delta_b))def adadelta(params, states, hyperparams):rho, eps = hyperparams['rho'], 1e-5for p, (s, delta) in zip(params, states):s[:] = rho * s + (1 - rho) * (p.grad.data**2)g =  p.grad.data * torch.sqrt((delta + eps) / (s + eps))p.data -= gdelta[:] = rho * delta + (1 - rho) * g * g

使用超参数 ρ = 0.9 \rho=0.9 ρ=0.9来训练模型。

d2l.train_ch7(adadelta, init_adadelta_states(), {'rho': 0.9}, features, labels)

输出:

loss: 0.243728, 0.062991 sec per epoch

在这里插入图片描述

7.7.3 简洁实现

通过名称为Adadelta的优化器方法,我们便可使用PyTorch提供的AdaDelta算法。它的超参数可以通过rho来指定。

d2l.train_pytorch_ch7(torch.optim.Adadelta, {'rho': 0.9}, features, labels)

输出:

loss: 0.242104, 0.047702 sec per epoch

在这里插入图片描述

小结

  • AdaDelta算法没有学习率超参数,它通过使用有关自变量更新量平方的指数加权移动平均的项来替代RMSProp算法中的学习率。

参考文献

[1] Zeiler, M. D. (2012). ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/741710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式驱动学习第三周——linux内核链表

前言 在 Linux 内核中使用最多的数据结构就是链表了&#xff0c;其中就包含了许多高级思想。 比如面向对象、类似C模板的实现、堆和栈的实现。 嵌入式驱动学习专栏将详细记录博主学习驱动的详细过程&#xff0c;未来预计四个月将高强度更新本专栏&#xff0c;喜欢的可以关注本博…

C#无法给PLC写入数据原因分析

一、背景 1.1 概述 C#中无法给PLC写入数据的原因有很多&#xff0c;这里分享网络端口号被占用导致无法写入的确认方法 1.2 环境 ①使用三菱PLC ②C#通过网口与PLC进行通讯 二、现象 1.1 代码 通过HslCommunication连接PLC时&#xff0c;连接返回成功&#xff0c;写入返回失败 …

snakeflow的springboot项目

Gitee搜索“liuxz/snakerflow”&#xff0c;它是spring boot集成了一款国产工作流引擎snakerflow。 下面是安装步骤&#xff1a; 创建数据库snaker-web&#xff0c;字符集设置成utf8mb4和utf8mb4_generic。不然的话&#xff0c;中文插入不进去。 运行sql命令 CREATE TABLE …

Kotlin:泛型

点击查看泛型中文文档 点击查看泛型英文文档 简介 与 Java 类似&#xff0c;Kotlin 中的类也可以有类型参数&#xff1a; class Box<T>(t: T) {var value t }一般来说&#xff0c;要创建这样类的实例&#xff0c;我们需要提供类型参数&#xff1a; val box: Box<…

调试安卓 gles性能瓶颈

目录 下载Arm Performance Studio编译Unity Shader运行malios调试用处和限制 原文请见&#xff1a;参考地址 使用mali offline shader compiler分析shader的性能瓶颈。 下载Arm Performance Studio 下载地址 编译Unity Shader 通常选择GLES3x。 You might need to select G…

智能控制:物联网智能插座对接文档

介绍 一开始买的某米的插座&#xff0c;但是好像接口不开放&#xff0c;所以找到了这个插座&#xff0c;然后自己开发了下&#xff0c;用接口控制插座开关。wifi的连接方式&#xff0c;通电后一般几秒后就会连接上wifi&#xff0c;这个时候通过接口发送命令给他。 产品图片 通…

idea配置自定义注释模版和其他模板

项目场景&#xff1a; idea配置自定义模版 自定义注释模版其他模板&#xff0c;包括syso快捷键&#xff0c;swith快捷键等 自定义注释模版 1、File and Code Templates 第一种类创建完后头部自动生成注释模板 打开idea&#xff0c;选择 Settings--> Editor--> File a…

nvm安装不同版本的node

在项目开发过程中&#xff0c;不同项目依赖的node版本不同&#xff0c;但频繁的卸载和安装很麻烦&#xff0c;这篇文章介绍nvm安装过程 1.nvm安装 这个网上随便找一篇跟着安装即可 nvm安装教程 2.nvm安装不同版本的node 网上普遍的方式是&#xff1a; 找到nvm安装目录下的s…

浅谈LockBit勒索病毒

在数字时代&#xff0c;随着科技的飞速发展&#xff0c;网络安全问题愈发凸显。恶意软件和勒索软件等网络威胁正不断演变&#xff0c;其中一款备受关注的勒索软件就是LockBit。 LockBit是一种高度复杂且具有破坏性的勒索软件。与传统的勒索软件相比&#xff0c;LockBit在其攻击…

NVMFS5A160PLZT1G汽车级功率MOSFET P沟道60 V 15A 满足AEC-Q101标准

关于汽车电子AEC Q101车规认证&#xff1f; 是一种针对分立半导体的可靠性测试认证程序&#xff0c;由汽车电子协会发布。这个认证程序主要是为了确保汽车电子产品在各种严苛的条件下能够正常工作和可靠运行。它包括了对分立半导体的可靠性、环境适应性、温度循环和湿度变化等…

新建项目module,但想归到一个目录下面

1. 想建几个module, 例如 component-base-service,component-config-service, 但是module多了会在CloudAction下面显示很多目录, 所以想把它们归到components模块下面去, 类似于下图的效果 2. 创建过程 右击CloudAction 新建 module -> 选maven类型 输入components, 建成后删…

Capture One 23:光影魔术师,细节掌控者mac/win版

Capture One 23&#xff0c;不仅仅是一款摄影后期处理软件&#xff0c;它更是摄影师们的得力助手和创意伙伴。这款软件凭借其卓越的性能、丰富的功能和前沿的技术&#xff0c;为摄影师们带来了前所未有的影像处理体验。 Capture One 23 软件获取 Capture One 23以其强大的色彩…

【C语言】Infiniband驱动mlx4_load_one函数

一、中文注释 以下是针对mlx4_load_one函数的主要代码路径的中文注释。该函数是用于加载并初始化Mellanox网络设备的驱动函数。通过注释&#xff0c;可以了解函数在初始化过程中执行的关键步骤。 /* mlx4_load_one函数&#xff1a;用于加载并初始化PCI设备&#xff08;例如网…

效果图代渲多少钱一张?带你详细了解它的计费规则!

不知道有没有朋友遇到过渲着渲着就崩溃的情况发生&#xff0c;不然也不会去找代渲染的平台/某宝等渠道 也就是为了图能够顺利的跑出来&#xff0c;做了后期处理后&#xff0c;及时交付给客户。 我们以渲染100云渲染来举例&#xff0c;它成立2015年&#xff0c;是一家效果图代…

接口自动化测试框架:Pytest+Allure+Excel

1. Allure 简介 简介 Allure 框架是一个灵活的、轻量级的、支持多语言的测试报告工具&#xff0c;它不仅以 Web 的方式展示了简介的测试结果&#xff0c;而且允许参与开发过程的每个人可以从日常执行的测试中&#xff0c;最大限度地提取有用信息。 Allure 是由 Java 语言开发…

Unity DropDown 组件 详解

Unity版本 2022.3.13f1 Dropdown下拉菜单可以快速创建大量选项 一、 Dropwon属性详解 属性&#xff1a;功能&#xff1a;Interactable此组件是否接受输入&#xff1f;请参阅 Interactable。Transition确定控件以何种方式对用户操作进行可视化响应的属性。请参阅过渡选项。Nav…

Titanic数据分析项目——Kaggle数据分析项目实战1

目前预测准确度达到77.511%, 会持续优化并且更新。 一、特征工程&#xff1a; 1、先对缺失值进行填充&#xff0c;先找到缺失值的位置&#xff0c;数值型数据填充众数&#xff0c;字符数据或者是离散型数据则填充出现最多的数据。 2、标准化数值型数据&#xff0c; 根据标准化…

Vue使用L2Dwidget

1、在根文件index.html中引入live2dw/lib/L2Dwidget.min.js 下载模型的文件&#xff0c;放在本地或者cdn 切换不同的模型 模型地址&#xff1a;https://github.com/xiazeyu/live2d-widget-models showLive2d(name: String) {var live2dWidget document.querySelector("…

专升本 C语言笔记-01 printf 占位符 转义符

目录 一.printf()函数简介 1.1作用 将格式化后的字符串输出(打印东西) 1.2函数原型 1.3返回值 二.常见占位符 2.1.占位符的使用 2.2.格式修饰符 2.3.输出格式说明 三.转义字符 一.printf()函数简介 1.1作用 将格式化后的字符串输出(打印东西) printf…