动手学深度学习11.9. Adadelta-笔记练习(PyTorch)

以下内容为结合李沐老师的课程和教材补充的学习笔记,以及对课后练习的一些思考,自留回顾,也供同学之人交流参考。

本节课程地址:72 优化算法【动手学深度学习v2】_哔哩哔哩_bilibili

本节教材地址:11.9. Adadelta — 动手学深度学习 2.0.0 documentation

本节开源代码:...>d2l-zh>pytorch>chapter_optimization>adadelta.ipynb


Adadelta

Adadelta是AdaGrad的另一种变体( 11.7节 ), 主要区别在于前者减少了学习率适应坐标的数量。 此外,广义上Adadelta被称为没有学习率,因为它使用变化量本身作为未来变化的校准。 Adadelta算法是在 (Zeiler, 2012)中提出的。

Adadelta算法

简而言之,Adadelta使用两个状态变量, \mathbf{s}_t 用于存储梯度二阶导数的泄露平均值, \Delta\mathbf{x}_t 用于存储模型本身中参数变化二阶导数的泄露平均值。请注意,为了与其他出版物和实现的兼容性,我们使用作者的原始符号和命名(没有其它真正理由让大家使用不同的希腊变量来表示在动量法、AdaGrad、RMSProp和Adadelta中用于相同用途的参数)。

以下是Adadelta的技术细节。鉴于参数du jour是 \rho ,我们获得了与 11.8节 类似的以下泄漏更新:

\begin{aligned} \mathbf{s}_t & = \rho \mathbf{s}_{t-1} + (1 - \rho) \mathbf{g}_t^2. \end{aligned}

与 11.8节 的区别在于,我们使用重新缩放的梯度 \mathbf{g}_t' 执行更新,即

\begin{aligned} \mathbf{x}_t & = \mathbf{x}_{t-1} - \mathbf{g}_t'. \\ \end{aligned}

那么,调整后的梯度 \mathbf{g}_t' 是什么?我们可以按如下方式计算它:

\begin{aligned} \mathbf{g}_t' & = \frac{\sqrt{\Delta\mathbf{x}_{t-1} + \epsilon}}{\sqrt{​{\mathbf{s}_t + \epsilon}}} \odot \mathbf{g}_t, \\ \end{aligned}

其中 Δxt−1 是重新缩放梯度的平方  \mathbf{g}_t' 的泄漏平均值。我们将 \Delta \mathbf{x}_{0} 初始化为0,然后在每个步骤中使用 \mathbf{g}_t' 更新它,即

\begin{aligned} \Delta \mathbf{x}_t & = \rho \Delta\mathbf{x}_{t-1} + (1 - \rho) {\mathbf{g}_t'}^2, \end{aligned}

和 \epsilon (例如 10^{-5} 这样的小值)是为了保持数字稳定性而加入的。

代码实现

Adadelta需要为每个变量维护两个状态变量,即 \mathbf{s}_t 和 \Delta\mathbf{x}_t 。这将产生以下实现。

%matplotlib inline
import torch
from d2l import torch as d2ldef init_adadelta_states(feature_dim):s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)delta_w, delta_b = torch.zeros((feature_dim, 1)), torch.zeros(1)return ((s_w, delta_w), (s_b, delta_b))def adadelta(params, states, hyperparams):rho, eps = hyperparams['rho'], 1e-5for p, (s, delta) in zip(params, states):with torch.no_grad():# In-place updates via [:] 原地更新,不会干扰计算图的结构s[:] = rho * s + (1 - rho) * torch.square(p.grad)g = (torch.sqrt(delta + eps) / torch.sqrt(s + eps)) * p.gradp[:] -= gdelta[:] = rho * delta + (1 - rho) * g * gp.grad.data.zero_()

对于每次参数更新,选择 \rho = 0.9 相当于10个半衰期。由此我们得到:

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adadelta, init_adadelta_states(feature_dim),{'rho': 0.9}, data_iter, feature_dim)

输出结果:
loss: 0.244, 0.031 sec/epoch

为了简洁实现,我们只需使用高级API中的Adadelta算法。

trainer = torch.optim.Adadelta
d2l.train_concise_ch11(trainer, {'rho': 0.9}, data_iter)

输出结果:
loss: 0.243, 0.011 sec/epoch

小结

  • Adadelta没有学习率参数。相反,它使用参数本身的变化率来调整学习率。
  • Adadelta需要两个状态变量来存储梯度的二阶导数和参数的变化。
  • Adadelta使用泄漏的平均值来保持对适当统计数据的运行估计。

练习

  1. 调整 \rho 的值,会发生什么?
    解:
    类似于RMSProp的 \gamma , \rho 之于Adadelta也是一个介于 0 和 1 之间的值,用于平衡历史信息和当前信息的权重。
    当 \rho 接近于1时,模型会更多地依赖于过去的梯度和更新量信息;
    当 \rho 接近于0时,模型会更多地依赖于当前的梯度和更新量信息。
  2. 展示如何在不使用 \mathbf{g}_t' 的情况下实现算法。为什么这是个好主意?
    解:
    \mathbf{g}_t' 是一个显式的中间变量,不使用 \mathbf{g}_t',即直接在更新变量时进行计算和累积量时进行计算,这样做可以减少内存占用,但实际计算量是重复的,所以每epoch的计算用时增加了。
def adadelta_no_g(params, states, hyperparams):rho, eps = hyperparams['rho'], 1e-5for p, (s, delta) in zip(params, states):with torch.no_grad():            s[:] = rho * s + (1 - rho) * p.grad ** 2p[:] -= (torch.sqrt(delta + eps) / torch.sqrt(s + eps)) * p.graddelta[:] = rho * delta + (1 - rho) * ((torch.sqrt(delta + eps) / torch.sqrt(s + eps)) * p.grad) ** 2p.grad.data.zero_()
d2l.train_ch11(adadelta_no_g, init_adadelta_states(feature_dim),{'rho': 0.9}, data_iter, feature_dim)

输出结果:
loss: 0.243, 0.016 sec/epoch

3. Adadelta真的是学习率为0吗?能找到Adadelta无法解决的优化问题吗?
解:
不是,Adadelta 是一种自适应学习率优化算法,它的学习率并不是固定的也不是为0,而是动态调整的,具体为: \eta = \frac{\sqrt{\Delta\mathbf{x}_{t-1} + \epsilon}}{\sqrt{​{\mathbf{s}_t + \epsilon}}}

Adadelta 在处理非凸优化问题时可能会遇到困难。非凸优化问题的损失函数可能有多个局部最小值,Adadelta 可能会陷入局部最小值而无法找到全局最小值。
在某些情况下,梯度可能非常稀疏,即大部分梯度值为零,这时,Adadelta 的累积梯度平方 st 可能会变得非常小,导致学习率变得非常大。这可能会导致参数更新过大,从而破坏训练过程。例如,在训练稀疏数据集(如某些自然语言处理任务)时,Adadelta 可能会表现不佳。

4. 将Adadelta的收敛行为与AdaGrad和RMSProp进行比较。
解:

算法AdaGradRMSPropAdadelta
收敛速度在训练初期,AdaGrad 的收敛速度较快,因为它能够快速调整学习率以适应稀疏梯度。在训练初期,RMSProp 的收敛速度相对较快,因为它能够更好地适应动态变化的梯度。在训练初期,Adadelta 的收敛速度可能较慢,因为它依赖于历史信息来调整学习率。
后期表现由于累积梯度平方不断增大,学习率逐渐减小,最终可能趋近于零,导致训练停止。由于引入了衰减率 $\gamma$,RMSProp 能够保持相对稳定的学习率,避免了学习率趋近于零的问题。Adadelta 能够自适应地调整学习率,避免了手动设置全局学习率的问题。
适用场景适用于稀疏梯度问题,如自然语言处理中的词嵌入。适用于非平稳优化问题,如深度神经网络的训练。适用于非平稳优化问题,尤其是需要长期依赖历史信息的任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/74364.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Audio基础(13)——audiomixer

在 Android 平台上,音频混合器 AudioMixer 主要用在 AudioFlinger 里,将多路音频源数据混音(包括混音、音量处理、重采样及处理声道等)。位于 framework 的音频处理模库 libaudioprocessing(frameworks/av/media/libau…

【React】使用Swiper报错`Swiper` needs at least one child

问题 聊天页面的表情面板&#xff0c;滑动效果使用了ant design mobile的Swiper。 原代码中&#xff0c;Swiper 组件在 isShow 为 false 时渲染的是 <></>&#xff08;空元素&#xff09;&#xff0c;控制台警告Swiper needs at least one child&#xff0c;Swip…

Matlab教程001:软件介绍和界面使用

1.1 软件介绍 1.1.1 Matlab的介绍 MATLAB&#xff08;MATrix LABoratory&#xff09;是一款由 MathWorks 公司开发的高级编程语言和交互式环境&#xff0c;广泛用于 科学计算、数据分析、机器学习、工程建模、仿真和信号处理 等领域。 1.1.2 主要应用领域 数据分析与可视化…

蓝桥杯算法实战分享:算法进阶之路与实战技巧

引言 蓝桥杯作为国内极具影响力的程序设计竞赛&#xff0c;为众多编程爱好者和专业人才提供了展示自我的舞台。参与蓝桥杯不仅能检验自身编程水平&#xff0c;还能拓宽技术视野&#xff0c;为未来职业发展积累宝贵经验。本文将结合历年真题与参赛经验&#xff0c;全面分享蓝桥…

Android Compose 层叠布局(ZStack、Surface)源码深度剖析(十三)

Android Compose 层叠布局&#xff08;ZStack、Surface&#xff09;源码深度剖析 一、引言 在 Android 应用开发领域&#xff0c;用户界面&#xff08;UI&#xff09;的设计与实现一直是至关重要的环节。随着技术的不断演进&#xff0c;Android Compose 作为一种全新的声明式…

MongoDB 面试备战指南

MongoDB 面试备战指南 一、基础概念 1. MongoDB是什么类型的数据库&#xff1f;和关系型数据库有什么区别&#xff1f; 答案&#xff1a; MongoDB是文档型NoSQL数据库&#xff0c;核心区别&#xff1a; 数据模型&#xff1a;存储JSON-like文档&#xff08;动态schema&#xf…

毫米波雷达标定(2)

1. 前言 前面文章中介绍了产线上毫米波雷达的标定原理和流程,这篇文章则主要介绍其在线标定方法。相对于产线标定,在线标定具备使用自然场景而不是依赖特定标靶的优点,但因此其标定精度会相对差一点。在线标定一般应用于售出产品的维护场景,如果其标定结果精度可以满足使用…

Linux fority source和__builtin_xxx

这段代码是用于启用和配置 GCC/Clang 的 Fortify Source 安全机制的预处理指令。Fortify Source 主要用于在编译时增强对缓冲区溢出等内存安全问题的检查。以下是对每一部分的详细解释&#xff1a; 1. 最外层条件编译 # if CONFIG_FORTIFY_SOURCE > 0目的&#xff1a;检查…

SQL GROUP BY 自定义排序规则

在 SQL 中&#xff0c;GROUP BY 子句用于将结果集按一个或多个列进行分组。默认情况下&#xff0c;GROUP BY 会按照列的自然顺序&#xff08;如字母顺序或数字顺序&#xff09;进行排序。如果你需要按照自定义的排序规则对结果进行分组&#xff0c;可以使用 ORDER BY 子句结合 …

语言模型理论基础-持续更新-思路清晰

1.预训练 相似的任务A、B&#xff0c;任务A已经用大数据完成了训练&#xff0c;得到模型A。 我们利用-特征提取模型的-“浅层参数通用”的特性&#xff0c;使用模型A的浅层参数&#xff0c;其他参数再通过任务B去训练&#xff08;微调&#xff09;。 2.统计语言模型 通过条件…

ResNet与注意力机制:深度学习中的强强联合

引言 在深度学习领域&#xff0c;卷积神经网络&#xff08;CNN&#xff09;一直是图像处理任务的主流架构。然而&#xff0c;随着网络深度的增加&#xff0c;梯度消失和梯度爆炸问题逐渐显现&#xff0c;限制了网络的性能。为了解决这一问题&#xff0c;ResNet&#xff08;残差…

【C++】——C++11新特性

目录 前言 1.初始化列表 2.std::initializer_list 3.auto 4.decltype 5.nullptr 6.左值引用和右值引用 6.1右值引用的真面目 6.2左值引用和右值引用比较 6.3右值引用的意义 6.3.1移动构造 6.4万能引用 6.5完美转发——forward 结语 前言 C&#xff0c;这门在系统…

【C++网络编程】第5篇:UDP与广播通信

一、UDP协议核心特性 1. UDP vs TCP ​特性 ​UDP​TCP连接方式无连接面向连接&#xff08;三次握手&#xff09;可靠性不保证数据到达或顺序可靠传输&#xff08;超时重传、顺序控制&#xff09;传输效率低延迟&#xff0c;高吞吐相对较低&#xff08;因握手和确认机制&…

MOSN(Modular Open Smart Network)是一款主要使用 Go 语言开发的云原生网络代理平台

前言 大家好&#xff0c;我是老马。 sofastack 其实出来很久了&#xff0c;第一次应该是在 2022 年左右开始关注&#xff0c;但是一直没有深入研究。 最近想学习一下 SOFA 对于生态的设计和思考。 sofaboot 系列 SOFABoot-00-sofaboot 概览 SOFABoot-01-蚂蚁金服开源的 s…

微信小程序日常开发问题整理

微信小程序日常开发问题整理 1 切换渲染模式1.1 WebView 的链接在模拟器可以打开&#xff0c;手机上无法打开。 1 切换渲染模式 1.1 WebView 的链接在模拟器可以打开&#xff0c;手机上无法打开。 在 app.json 中看到如下配置项&#xff0c;那么当前项目就是 keyline 渲染模式…

【Altium Designer】铜皮编辑

一、动态铜皮与静态铜皮的区分与切换 动态铜皮&#xff08;活铜&#xff09;&#xff1a; 通过快捷键 PG 创建&#xff0c;支持自动避让其他网络对象&#xff0c;常用于大面积铺铜&#xff08;如GND或电源网络&#xff09;。修改动态铜皮后&#xff0c;需通过 Tools → Polygo…

Java「Deque」 方法详解:从入门到实战

Java Deque 各种方法解析&#xff1a;从入门到实战 在 Java 编程中&#xff0c;Deque&#xff08;双端队列&#xff09;是一个功能强大的数据结构&#xff0c;允许开发者从队列的两端高效地添加、删除和检查元素。作为 java.util 包中的一部分&#xff0c;Deque 接口继承自 Qu…

ffmpeg+QOpenGLWidget显示视频

​一个基于 ‌FFmpeg 4.x‌ 和 QOpenGLWidget的简单视频播放器代码示例&#xff0c;实现视频解码和渲染到 Qt 窗口的功能。 1&#xff09;ffmpeg库界面&#xff0c;视频解码支持软解和硬解方式。 硬解后&#xff0c;硬件解码完成需要将数据从GPU复制到CPU。优先采用av_hwf…

深入解析嵌入式内核:从架构到实践

一、嵌入式内核概述 嵌入式内核是嵌入式操作系统的核心组件&#xff0c;负责管理硬件资源、调度任务、处理中断等关键功能。其核心目标是在资源受限的环境中提供高效、实时的控制能力。与通用操作系统不同&#xff0c;嵌入式内核通常具有高度可裁剪性、实时性和可靠性&#xff…

20250324-使用 `nltk` 的 `sent_tokenize`, `word_tokenize、WordNetLemmatizer` 方法时报错

解决使用 nltk 的 sent_tokenize, word_tokenize、WordNetLemmatizer 方法时报错问题 第 2 节的手动方法的法1可解决大部分问题&#xff0c;可首先尝试章节 2 的方法 1. nltk.download(‘punkt_tab’) LookupError: *******************************************************…