Python 梯度下降法(六):Nadam Optimize

文章目录

  • Python 梯度下降法(六):Nadam Optimize
    • 一、数学原理
      • 1.1 介绍
      • 1.2 符号定义
      • 1.3 实现流程
    • 二、代码实现
      • 2.1 函数代码
      • 2.2 总代码
    • 三、优缺点
      • 3.1 优点
      • 3.2 缺点
    • 四、相关链接

Python 梯度下降法(六):Nadam Optimize

一、数学原理

1.1 介绍

Nadam(Nesterov-accelerated Adaptive Moment Estimation)优化算法是 Adam 优化算法的改进版本,结合了 Nesterov 动量(Nesterov Momentum)和 Adam 算法的优点。

Nadam 在 Adam 算法的基础上引入了 Nesterov 动量的思想。Adam 算法通过计算梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差)来自适应地调整每个参数的学习率。而 Nesterov 动量则是在计算梯度时,考虑了参数在动量作用下未来可能到达的位置的梯度,从而让优化过程更具前瞻性。

1.2 符号定义

设置一下超参数:

参数说明
η \eta η学习率,控制参数更新的步长
m m m一阶矩估计,梯度均值
β 1 \beta_{1} β1一阶矩指数衰减率,通常取 0.9 0.9 0.9
v v v二阶矩估计,梯度未中心化方差
β 2 \beta_{2} β2二阶矩指数衰减率,通常取 0.999 0.999 0.999
ϵ \epsilon ϵ无穷小量,用于避免分母为零, 1 0 − 8 10^{-8} 108
g t g_{t} gt t t t时刻位置的梯度
θ \theta θ需要进行拟合的参数

1.3 实现流程

  1. 初始化参数: θ n × 1 \theta_{n\times 1} θn×1 m 0 ⃗ n × 1 = 0 \vec{m_{0}}_{n\times 1}=0 m0 n×1=0 v 0 ⃗ n × 1 = 0 \vec{v_{0}}_{n\times 1}=0 v0 n×1=0
  2. 更新一阶矩估计 m t m_{t} mt m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t} mt=β1mt1+(1β1)gt
  3. 更新二阶矩估计 v t v_{t} vt v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} vt=β2vt1+(1β2)gt2
  4. 偏差修正:由于 m 0 , v 0 = 0 m_{0},v_{0}=0 m0,v0=0,在训练初期会存在偏差,需要进行修正: m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_{t}=\frac{m_{t}}{1-\beta_{1}^{t}},\hat{v}_{t}=\frac{v_{t}}{1-\beta_{2}^{t}} m^t=1β1tmt,v^t=1β2tvt
  5. 计算预估一阶矩: m ~ t = β 1 m ^ t + ( 1 − β 1 ) g t 1 − β 1 t \widetilde{m}_{t}=\beta_{1}\hat{m}_{t}+\frac{(1-\beta_{1})g_{t}}{1-\beta_{1}^{t}} m t=β1m^t+1β1t(1β1)gt
  6. 更新模型参数 θ t \theta_{t} θt θ t = θ t − 1 − η v t ^ + ϵ ⊙ m ~ t \theta_{t}=\theta_{t-1}-\frac{\eta}{\sqrt{ \hat{v_{t}} }+\epsilon}\odot\widetilde{m}_{t} θt=θt1vt^ +ϵηm t

二、代码实现

2.1 函数代码

# 定义 Nadam 函数
def nadam_optimizer(X, y, eta, num_iter=1000, beta1=0.9, beta2=0.999, epsilon=1e-8, threshold=1e-8):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  mx1eta: 学习率num_iter: 迭代次数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1))  # 初始化数据loss_ = []for t in range(1, num_iter + 1):# 计算梯度h = X.dot(theta)err = h - yloss_.append(np.mean(err ** 2) / 2)g = (1 / m) * X.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 先计算偏差修正,后面需要使用到,并且去除负数m_hat, v_hat = mt / (1 - pow(beta1, t)), np.maximum(vt / (1 - pow(beta2, t)), 0)# 计算预估一阶矩m_pre = beta1 * m_hat + (1 - beta1) * g / (1 - pow(beta1, t))# 更新参数theta = theta - np.multiply((eta / (np.sqrt(v_hat) + epsilon)), m_pre)# 检查是否收敛if t > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {t}")breakreturn theta.flatten(), loss_

2.2 总代码

import numpy as np
import matplotlib.pyplot as plt# 设置 matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False# 定义 Nadam 函数
def nadam_optimizer(X, y, eta, num_iter=1000, beta1=0.9, beta2=0.999, epsilon=1e-8, threshold=1e-8):"""X: 数据 x  mxn,可以在传入数据之前进行数据的归一化y: 数据 y  mx1eta: 学习率num_iter: 迭代次数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1))  # 初始化数据loss_ = []for t in range(1, num_iter + 1):# 计算梯度h = X.dot(theta)err = h - yloss_.append(np.mean(err ** 2) / 2)g = (1 / m) * X.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 先计算偏差修正,后面需要使用到,并且去除负数m_hat, v_hat = mt / (1 - pow(beta1, t)), np.maximum(vt / (1 - pow(beta2, t)), 0)# 计算预估一阶矩m_pre = beta1 * m_hat + (1 - beta1) * g / (1 - pow(beta1, t))# 更新参数theta = theta - np.multiply((eta / (np.sqrt(v_hat) + epsilon)), m_pre)# 检查是否收敛if t > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {t}")breakreturn theta.flatten(), loss_# 生成一些示例数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 超参数
eta = 0.1# 运行 Nadam 优化器
theta, loss_ = nadam_optimizer(X_b, y, eta)
print("最优参数 theta:")
print(theta)# 绘制损失函数图像
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.legend()  # 显示图例
plt.grid(True)  # 显示网格线
plt.show()

1738389513_xunsrs0jxa.png1738389512232.png

三、优缺点

3.1 优点

自适应学习率:NAdam 继承了 Adam 的自适应学习率特性,能够根据梯度的一阶矩(均值)和二阶矩(方差)动态调整每个参数的学习率。这使得 NAdam 在处理不同尺度的参数时更加高效,尤其适合稀疏梯度问题。

Nesterov 动量:NAdam 引入了 Nesterov 动量,能够在更新参数时先根据当前动量预测参数的未来位置,再计算梯度。这种“前瞻性”的更新方式使得 NAdam 能够更准确地调整参数,从而加速收敛。

快速收敛:由于结合了 Adam 的自适应学习率和 Nesterov 动量的前瞻性更新,NAdam 在大多数优化问题中能够比 Adam 和传统梯度下降法更快地收敛。特别是在非凸优化问题中,NAdam 的表现通常优于其他优化算法。

鲁棒性:NAdam 对超参数的选择相对鲁棒,尤其是在学习率和动量参数的选择上。这使得 NAdam 在实际应用中更容易调参。

适合大规模数据:NAdam 能够高效处理大规模数据集和高维参数空间,适合深度学习中的大规模优化问题。

3.2 缺点

计算复杂度较高:由于 NAdam 需要同时维护一阶矩和二阶矩估计,并计算 Nesterov 动量,其计算复杂度略高于传统的梯度下降法。虽然现代深度学习框架(如 PyTorch、TensorFlow)已经对 NAdam 进行了高效实现,但在某些资源受限的场景下,计算开销仍然是一个问题。

对初始学习率敏感:尽管 NAdam 对超参数的选择相对鲁棒,但初始学习率的选择仍然对性能有较大影响。如果初始学习率设置不当,可能会导致收敛速度变慢或无法收敛。

可能陷入局部最优:在某些复杂的非凸优化问题中,NAdam 可能会陷入局部最优解,尤其是在损失函数存在大量鞍点或平坦区域时。

内存占用较高:NAdam 需要存储一阶矩和二阶矩估计,这会增加内存占用。对于非常大的模型(如 GPT-3 等),内存占用可能成为一个瓶颈。

理论分析较少:相比于 Adam 和传统的梯度下降法,NAdam 的理论分析相对较少。虽然实验结果表明 NAdam 在大多数任务中表现优异,但其理论性质仍需进一步研究。

四、相关链接

Python 梯度下降法合集:

  • Python 梯度下降法(一):Gradient Descent-CSDN博客
  • Python 梯度下降法(二):RMSProp Optimize-CSDN博客
  • Python 梯度下降法(三):Adagrad Optimize-CSDN博客
  • Python 梯度下降法(四):Adadelta Optimize-CSDN博客
  • Python 梯度下降法(五):Adam Optimize-CSDN博客
  • Python 梯度下降法(六):Nadam Optimize-CSDN博客
  • Python 梯度下降法(七):Summary-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/70057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【狂热算法篇】探秘图论之Dijkstra 算法:穿越图的迷宫的最短路径力量(通俗易懂版)

羑悻的小杀马特.-CSDN博客羑悻的小杀马特.擅长C/C题海汇总,AI学习,c的不归之路,等方面的知识,羑悻的小杀马特.关注算法,c,c语言,青少年编程领域.https://blog.csdn.net/2401_82648291?typebbshttps://blog.csdn.net/2401_82648291?typebbshttps://blog.csdn.net/2401_8264829…

MySQL(Undo日志)

后面也会持续更新&#xff0c;学到新东西会在其中补充。 建议按顺序食用&#xff0c;欢迎批评或者交流&#xff01; 缺什么东西欢迎评论&#xff01;我都会及时修改的&#xff01; 大部分截图和文章采用该书&#xff0c;谢谢这位大佬的文章&#xff0c;在这里真的很感谢让迷茫的…

全面剖析 XXE 漏洞:从原理到修复

目录 前言 XXE 漏洞概念 漏洞原理 XML 介绍 XML 结构语言以及语法 XML 结构 XML 语法规则 XML 实体引用 漏洞存在原因 产生条件 经典案例介绍分析 XXE 漏洞修复方案 结语 前言 网络安全领域暗藏危机&#xff0c;各类漏洞威胁着系统与数据安全。XXE 漏洞虽不常见&a…

初级数据结构:栈和队列

目录 一、栈 (一)、栈的定义 (二)、栈的功能 (三)、栈的实现 1.栈的初始化 2.动态扩容 3.压栈操作 4.出栈操作 5.获取栈顶元素 6.获取栈顶元素的有效个数 7.检查栈是否为空 8.栈的销毁 9.完整代码 二、队列 (一)、队列的定义 (二)、队列的功能 (三&#xff09…

登录认证(5):过滤器:Filter

统一拦截 上文我们提到&#xff08;登录认证&#xff08;4&#xff09;&#xff1a;令牌技术&#xff09;&#xff0c;现在大部分项目都使用JWT令牌来进行会话跟踪&#xff0c;来完成登录功能。有了JWT令牌可以标识用户的登录状态&#xff0c;但是完整的登录逻辑如图所示&…

Python 网络爬虫实战:从基础到高级爬取技术

&#x1f4dd;个人主页&#x1f339;&#xff1a;一ge科研小菜鸡-CSDN博客 &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; 1. 引言 网络爬虫&#xff08;Web Scraping&#xff09;是一种自动化技术&#xff0c;利用程序从网页中提取数据&#xff0c;广泛…

MySQL锁类型(详解)

锁的分类图&#xff0c;如下&#xff1a; 锁操作类型划分 读锁 : 也称为共享锁 、英文用S表示。针对同一份数据&#xff0c;多个事务的读操作可以同时进行而不会互相影响&#xff0c;相互不阻塞的。 写锁 : 也称为排他锁 、英文用X表示。当前写操作没有完成前&#xff0c;它会…

93,【1】buuctf web [网鼎杯 2020 朱雀组]phpweb

进入靶场 页面一直在刷新 在 PHP 中&#xff0c;date() 函数是一个非常常用的处理日期和时间的函数&#xff0c;所以应该用到了 再看看警告的那句话 Warning: date(): It is not safe to rely on the systems timezone settings. You are *required* to use the date.timez…

51单片机 01 LED

一、点亮一个LED 在STC-ISP中单片机型号选择 STC89C52RC/LE52RC&#xff1b;如果没有找到hex文件&#xff08;在objects文件夹下&#xff09;&#xff0c;在keil中options for target-output- 勾选 create hex file。 如果要修改编程 &#xff1a;重新编译-下载/编程-单片机重…

【Rust自学】19.2. 高级trait:关联类型、默认泛型参数和运算符重载、完全限定语法、supertrait和newtype

喜欢的话别忘了点赞、收藏加关注哦&#xff08;加关注即可阅读全文&#xff09;&#xff0c;对接下来的教程有兴趣的可以关注专栏。谢谢喵&#xff01;(&#xff65;ω&#xff65;) 19.2.1. 在trait定义中使用关联类型来指定占位类型 我们首先在第10章的10.3. trait Pt.1&a…

Elasticsearch:如何搜索含有复合词的语言

作者&#xff1a;来自 Elastic Peter Straer 复合词在文本分析和标记过程中给搜索引擎带来挑战&#xff0c;因为它们会掩盖词语成分之间的有意义的联系。连字分解器标记过滤器等工具可以通过解构复合词来帮助解决这些问题。 德语以其长复合词而闻名&#xff1a;Rindfleischetik…

web-SQL注入-CTFHub

前言 在众多的CTF平台当中&#xff0c;作者认为CTFHub对于初学者来说&#xff0c;是入门平台的不二之选。CTFHub通过自己独特的技能树模块&#xff0c;可以帮助初学者来快速入门。具体请看官方介绍&#xff1a;CTFHub。 作者更新了CTFHub系列&#xff0c;希望小伙伴们多多支持…

WPS动画:使图形平移、围绕某个顶点旋转一定角度

1、平移 案例三角形如下图&#xff0c;需求&#xff1a;该三角形的A点平移至原点 &#xff08;1&#xff09;在预想动画结束的位置绘制出图形 &#xff08;2&#xff09;点击选中原始图像&#xff0c;插入/动画/绘制自定义路径/直线 &#xff08;3&#xff09;十字星绘制的直线…

xmind使用教程

xmind使用教程 前言xmind版本信息“xmind使用教程”的xmind思维导图 前言 首先xmind是什么&#xff1f;XMind 是一款思维导图和头脑风暴工具&#xff0c;用于帮助用户组织和可视化思维、创意和信息。它允许用户通过图形化的方式来创建、整理和分享思维导图&#xff0c;可以用于…

KNIME:开源 AI 数据科学

KNIME&#xff08;Konstanz Information Miner&#xff09;是一款开源且功能强大的数据科学平台&#xff0c;由德国康斯坦茨大学的软件工程师团队开发&#xff0c;自2004年推出以来&#xff0c;广泛应用于数据分析、数据挖掘、机器学习和可视化等领域。以下是对KNIME的深度介绍…

2025年01月27日Github流行趋势

项目名称&#xff1a;onlook项目地址url&#xff1a;https://github.com/onlook-dev/onlook项目语言&#xff1a;TypeScript历史star数&#xff1a;5340今日star数&#xff1a;211项目维护者&#xff1a;Kitenite, drfarrell, iNerdStack, abhiroopc84, apps/dependabot项目简介…

TCL C++开发面试题及参考答案

进程和线程的区别 进程和线程都是操作系统中重要的概念,它们在很多方面存在着明显的区别。 从概念上来说,进程是资源分配的基本单位,每个进程都有自己独立的地址空间、内存、文件描述符等资源。例如,当我们在计算机上同时运行多个应用程序,像浏览器、文本编辑器等,每个应…

本地部署DeepSeek-R1模型(新手保姆教程)

背景 最近deepseek太火了&#xff0c;无数的媒体都在报道&#xff0c;很多人争相着想本地部署试验一下。本文就简单教学一下&#xff0c;怎么本地部署。 首先大家要知道&#xff0c;使用deepseek有三种方式&#xff1a; 1.网页端或者是手机app直接使用 2.使用代码调用API …

VS Code 复制正确格式的文件路径/文件夹路径 (绝对路径,相对路径, 斜杠 /, 反斜杠\\ 等)

VS Code 搜索 : baincd.copy-path-unixstyle Github : https://github.com/baincd/vscode-copy-path-unixstyle 插件市场: https://marketplace.visualstudio.com/items?itemNamebaincd.copy-path-unixstyle 支持复制各种格式的路径 格式 GitBash /c/chris/project-name/sr…

每天学点小知识之设计模式的艺术-策略模式

行为型模式的名称、定义、学习难度和使用频率如下表所示&#xff1a; 1.如何理解模板方法模式 模板方法模式是结构最简单的行为型设计模式&#xff0c;在其结构中只存在父类与子类之间的继承关系。通过使用模板方法模式&#xff0c;可以将一些复杂流程的实现步骤封装在一系列基…