优化器算法SGD、Adam、AdamW等

文章目录

    • SGD
    • SGD with momentum
    • SGD with Nesterov Acceleration
    • AdaGrad
    • RMSprop
    • AdaDelta
    • Adam
    • AdamW
    • 参考资料

假设有:

  • 待优化的目标函数为 f ( w ) f(w) f(w),使用优化算法来最小化目标函数 f ( w ) : a r g m i n w f ( w ) f(w):argmin_wf(w) f(w):argminwf(w)

  • 在时间步t的梯度 g t = ∇ f ( w t ) g_t= \nabla f(w_t) gt=f(wt)

  • 模型参数为 w w w w t w_t wt为时刻t的参数, w t + 1 w_{t+1} wt+1​​是时刻t+1的参数

  • 在时刻t的学习率为 α t \alpha_t αt

  • 平滑项 ϵ \epsilon ϵ

SGD

SGD(Stochastic gradient descent)只考虑当前时间步的梯度,其更新方式为
w t + 1 = w t − α t g t w_{t+1} = w_t - \alpha_t g_t wt+1=wtαtgt
pytorch 对应的类为torch.optim.SGD

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

SGD with momentum

对于非凸目标函数,可能会存在多个局部极小值,使用SGD求解时,在这些局部极小值附近的梯度很小,使得优化算法陷入到局部最优解。

而带动量的SGD算法不仅仅使用当前梯度,也会考虑到历史梯度,设动量参数为 μ \mu μ,其参数更新方式为:
b t = μ b t − 1 + g t w t + 1 = w t − α t b t b_t = \mu b_{t-1} + g_t \\ w_{t+1} = w_t - \alpha_t b_t bt=μbt1+gtwt+1=wtαtbt
pytorch 对应的类也为torch.optim.SGD,可以设置momentum参数。

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

SGD with Nesterov Acceleration

SGD with Nesterov Acceleration是对SGD with momentum的改进,先根据累积梯度进行一次参数更新。
g t = ∇ f ( w t − μ b t − 1 ) b t = μ b t − 1 + g t w t + 1 = w t − α t b t g_t = \nabla f(w_{t} - \mu b_{t-1}) \\ b_t = \mu b_{t-1} + g_t \\ w_{t+1} = w_t - \alpha_t b_t gt=f(wtμbt1)bt=μbt1+gtwt+1=wtαtbt
pytorch 对应的类也为torch.optim.SGD,在设置momentum参数后,设置nesterov参数为True。

AdaGrad

AdaGrad(Adaptive Gradient Algorithm)是在每次迭代时自适应地调整每个参数的学习率,出自2021年的论文《Adaptive Subgradient Methods for Online Learning and Stochastic Optimization》。

若有d个参数
v t = d i a g ( ∑ i = 1 t g i , 1 2 , ∑ i = 1 t g i , 2 2 , ⋯ , ∑ i = 1 t g i , d 2 ) w t + 1 = w t − α t g t v t + ϵ v_t = diag(\sum^t_{i=1}g^2_{i,1}, \sum^t_{i=1}g^2_{i,2}, \cdots,\sum^t_{i=1}g^2_{i,d} ) \\ w_{t+1} = w_t - \alpha_t \frac{g_t}{\sqrt{v_t} + \epsilon} vt=diag(i=1tgi,12,i=1tgi,22,,i=1tgi,d2)wt+1=wtαtvt +ϵgt
相比于SGD,每个参数的学习率是会随时间变化的,即对于第j个参数,学习率为 α t v t , j + ϵ = α t s u m i = t t g i , j 2 + ϵ \frac{\alpha_t}{\sqrt{v_{t,j}} + \epsilon} = \frac{\alpha_t}{\sqrt{sum^t_{i=t}g_{i,j}^2} + \epsilon} vt,j +ϵαt=sumi=ttgi,j2 +ϵαt。并且AdaGrad使用了二阶动量。

pytorch 对应的类为torch.optim.Adagrad

RMSprop

AdaGrad考虑过去所有时间的梯度累加和,所以学习率可能会趋近于零,从而使模型在没有找到最优解时就终止了学习。 RMSprop对AdaGrad进行了改进,该算法出自 G. Hinton的 lecture notes 。RMSprop相比于AdaGrad只关注过去一段时间窗口的梯度平方和:
v t = β 2 v t − 1 + ( 1 − β 2 ) d i a g ( g t 2 ) w t + 1 = w t − α t g t v t + ϵ v_t = \beta_2 v_{t-1} + (1- \beta_2) diag(g^2_t) \\ w_{t+1} = w_t - \alpha_t \frac{g_t}{\sqrt{v_t} + \epsilon} vt=β2vt1+(1β2)diag(gt2)wt+1=wtαtvt +ϵgt
pytorch对应的类为torch.optim.RMSprop

AdaDelta

AdaGrad考虑过去所有时间的梯度累加和,所以学习率可能会趋近于零,从而使模型在没有找到最优解时就终止了学习。 AdaDelta对AdaGrad进行了改进,该算法出自论文《ADADELTA: An Adaptive Learning Rate Method》。AdaDelta相比于AdaGrad有两个改进:

  • 只关注过去一段时间窗口的梯度平方和: v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ d i a g ( g t 2 ) v_t = \beta_2 \cdot v_{t-1} + (1- \beta_2) \cdot diag(g^2_t) vt=β2vt1+(1β2)diag(gt2)(指数移动平均) ,一般取 β 2 = 0.9 \beta_2 = 0.9 β2=0.9​(相当于关注过去10个时间步的梯度平方和)
  • 引入每次参数更新差值 Δ θ \Delta \theta Δθ的平方的指数移动平均: Δ X t − 1 2 = β 1 Δ X t − 2 2 + ( 1 − β 1 ) Δ θ t − 1 ⊙ Δ θ t − 1 \Delta X^2_{t-1} = \beta_1 \Delta X^2_{t-2} + (1-\beta_1) \Delta \theta_{t-1} \odot \Delta \theta_{t-1} ΔXt12=β1ΔXt22+(1β1)Δθt1Δθt1

v t = β 2 v t − 1 + ( 1 − β 2 ) d i a g ( g t 2 ) Δ X t − 1 2 = β 1 Δ X t − 2 2 + ( 1 − β 1 ) Δ θ t − 1 ⊙ Δ θ t − 1 w t + 1 = w t − α t Δ X t − 1 2 + ϵ v t + ϵ g t v_t = \beta_2 v_{t-1} + (1- \beta_2) diag(g^2_t) \\ \Delta X^2_{t-1} = \beta_1 \Delta X^2_{t-2} + (1-\beta_1) \Delta \theta_{t-1} \odot \Delta \theta_{t-1} \\ w_{t+1} = w_t - \alpha_t \frac{\sqrt{\Delta X^2_{t-1} + \epsilon}}{\sqrt{v_t} + \epsilon}g_t vt=β2vt1+(1β2)diag(gt2)ΔXt12=β1ΔXt22+(1β1)Δθt1Δθt1wt+1=wtαtvt +ϵΔXt12+ϵ gt

pytorch对应的类为torch.optim.AdaDelta

Adam

Adam出自论文《Adam: A Method for Stochastic Optimization》,它同时考虑了一阶动量和二阶动量。(公式中的纠正项 m ^ t \hat m_t m^t v ^ t \hat v_t v^t只在初始阶段校正)
m t = β 1 m t − 1 + ( 1 − β 1 ) g t v t = β 2 v t − 1 + ( 1 − β 2 ) d i a g ( g t 2 ) m ^ t = m t 1 − β 1 t v ^ t = v t 1 − β 2 t w t + 1 = w t − α t m ^ t v ^ t + ϵ m_t = \beta_1m_{t-1} + (1-\beta_1)g_t \\ v_t = \beta_2 v_{t-1} + (1- \beta_2) diag(g^2_t) \\ \hat m_t = \frac{m_t}{1-\beta_1^t} \\ \hat v_t = \frac{v_t}{1-\beta_2^t} \\ w_{t+1} = w_t - \alpha_t \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} mt=β1mt1+(1β1)gtvt=β2vt1+(1β2)diag(gt2)m^t=1β1tmtv^t=1β2tvtwt+1=wtαtv^t +ϵm^t
pytorch对应的类为torch.optim.Adam

AdamW

AdamW是对Adam的改进,出自论文《Decoupled Weight Decay Regularization》,现在大模型训练基本上都是使用AdamW优化器。

AdamW改进的主要出发点是 L 2 L_2 L2正则和权重衰减(weight decay)对于自适应梯度如Adam是不一样的,所以作者们对Adam做了如下图的修改。

在这里插入图片描述

pytorch对应的类为torch.optim.AdamW

参考资料

  1. pytorch优化算法
  2. 知乎文章:从 SGD 到 AdamW —— 优化算法的演化
  3. https://www.fast.ai/posts/2018-07-02-adam-weight-decay.html
  4. Cornell University Computational Optimization Open Textbook
  5. 神经网络与深度学习
  6. 视频:从SGD到AdamW(后面的两个视频还讲了为什么transformer用SGD的效果不好)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/756441.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Mysql】面试题汇总

1. 存储引擎 1-1. MySQL 支持哪些存储引擎?默认使用哪个? 答: MySQL 支持的存储引擎包括 InnoDB、MyISAM、Memory 等。 Mysql 5.5 之前默认的是MyISAM,Mysql 5.5 之后默认的是InnoDB。 可以通过 show engines 查看 Mysql 支持…

外包2月,技术退步惊现!大专生逆袭大厂,全靠这份神秘资料!

大家好,我是一名大专生,19年通过校招进入湖南某软件公司,从事功能测试工作已近4年。今年8月,我意识到长期舒适的环境让我变得不思进取,技术停滞不前,甚至因此失去了谈了2年的女朋友。我下定决心&#xff0c…

金蝶云星空——插件dll重新发布报错:鏃犳硶鏄剧ず椤甸潰锛屽洜涓哄彂鐢熷唴閮ㄦ湇鍔″櫒閿欒銆�

项目场景: 金蝶插件开发 问题描述 今天更新了插件dll然后重启IIS金蝶就报如下错误: 解决方案: 折腾了一天结果发现是给自己挖坑了,这次更新我担心插件代码有问题就把原dll重命名了然后把最新dll更新到金蝶bin文件中&#xff0c…

【DBC专题】-11-使用Cantools将CAN/CANFD DBC自动生成C语言代码

目录 1 安装Python和Cantools 1.1 查看Python已安装的Package包 1.2 在Python中安装Cantools插件包 1.3 获取更多Cantools工具的更新动态 2 经典CAN/CANFD DBC自动生成C语言代码 2.1 批处理文件CAN_DBC_To_C.bat内容说明 2.2 经典CAN/CANFD DBC文件要求 2.3 如何使用生…

网站引用图片但它域名被墙了或者它有防盗链,我们想引用但又不能显示,本文附详细的解决方案非常简单!

最好的办法就是直接读取图片文件&#xff0c;用到php中一个常用的函数file_get_contents(图片地址)&#xff0c;意思是读取远程的一张图片&#xff0c;在输出就完事。非常简单&#xff5e;话不多说&#xff0c;直接上代码 <?php header("Content-type: image/jpeg&quo…

clipboard好用的复制剪切库

clipboard是现代复制到剪贴板的工具&#xff0c;其 gzip 压缩后只有 3kb&#xff0c;能够减少选择文本的重复操作&#xff0c;点击按钮就可以复制指定内容&#xff0c;支持原生HTMLjs&#xff0c;vue3和vue2。使用方法参照官方文档&#xff0c;so easy&#xff01;&#xff01;…

装X神器,装X图片生成器,高富帅模拟器

先展示两张效果 基金装X图 短信存款图 神器功能展示 总共有12大类可供用户选择 还有一些美感的&#xff1a; 总结 总之种类非常多&#xff0c;有了这个神器你懂的&#xff5e; 关注下方公众号&#xff0c;回复【zzsq】即可获取。

2、鸿蒙学习-申请调试证书和调试Profile文件

申请发布证书 发布证书由AGC颁发的、为HarmonyOS应用配置签名信息的数字证书&#xff0c;可保障软件代码完整性和发布者身份真实性。证书格式为.cer&#xff0c;包含公钥、证书指纹等信息。 说明 请确保您的开发者帐号已实名认证。每个帐号最多申请1个发布证书。 1、登录AppGa…

Linux软件管理(1)

软件管理 下载 wget Linux wget是一个下载文件的工具&#xff0c;它用在命令行下。 wget工具体积小但功能完善&#xff0c;它支持断点下载功能&#xff0c;同时支持FTP和HTTP下载方式&#xff0c;支持代理服务器和设置起来方便简单。 1.语法 wget [选项]……[URL]…… 2、…

阅读基础知识1

一 网络 1. 三次握手四次挥手 三次握手&#xff1a;为了建立长链接进行交互即建立一个会话&#xff0c;使用 http/https 协议 ① 客户端产生初始化序列号 Seqx &#xff0c;向服务端发送建立连接的请求报文&#xff0c;将 SYN1 同步序列号&#xff1b; ② 服务端接收建立连接…

洛谷_P1068 [NOIP2009 普及组] 分数线划定_python写法

P1068 [NOIP2009 普及组] 分数线划定 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 知识点&#xff1a; 这道题用到了自定义排序规则 n, m map(int,input().split()) data [] for i in range(n):l list(map(int,input().split()))data.append(l)import functoolsdef my_cm…

香港科技大学广州|智能制造学域博士招生宣讲会—同济大学专场

时间&#xff1a;2024年3月28日&#xff08;星期四&#xff09;10:00 地点&#xff1a;同济大学嘉定校区济人楼310 报名链接&#xff1a;https://www.wjx.top/vm/mmukLPC.aspx# 宣讲嘉宾&#xff1a;崔华晨 助理教授 跨学科重点研究领域 •工业4.0 •智能传感器、自动光学检…

web攻防——csrf,ssrf

csrf 当我们在访问自己的管理员系统的时候&#xff0c;打开别人发的钓鱼连接就会自动增加管理员&#xff08;前提&#xff0c;后台在登录状态&#xff09;当我们打开别人发的网站&#xff0c;就会触发增加管理员的数据包 假设我们要测试这个网站 看到这个&#xff0c;就得下载一…

计算机组成原理-3-系统总线

3. 系统总线 文章目录 3. 系统总线3.1 总线的基本概念3.2 总线的分类3.3 总线特性及性能指标3.4 总线结构3.5 总线控制3.5.1 总线判优控制3.5.2 总线通信控制 本笔记参考哈工大刘宏伟老师的MOOC《计算机组成原理&#xff08;上&#xff09;_哈尔滨工业大学》、《计算机组成原理…

Positive Technologies 专家发现的漏洞已在 ABB 控制器中得到修复

&#x1f31f; 我们的同事一如既往地表现出色&#xff1a;应用分析专家 Natalia Tlyapova 和 Denis Goryushev 因发现 Freelance AC 900F 和 AC 700F 控制器中的两个漏洞而受到 ABB 的表彰。 这些设备用于自动化大规模连续循环生产设施和构建企业配送控制系统。利用这些漏洞的…

Codeforces Round 925 (Div. 3) G. One-Dimensional Puzzle【推公式+组合数学+隔板法】

原题链接&#xff1a;https://codeforces.com/problemset/problem/1931/G 题目描述&#xff1a; 有 4 种拼图&#xff0c;其中第 i 种拼图有 ci​ 张。 两张拼图可以连结当且仅当它们相邻的卡槽中一个凹陷一个突出。 我们希望将所有的拼图从左往右拼起来&#xff0c;求总方案…

js 中文乱码解决、乱码对照

1、js iso-8859-1转utf-8 在JavaScript中&#xff0c;可以使用内置的TextEncoder和TextDecoderAPI来实现ISO-8859-1编码和UTF-8编码之间的转换。以下是一个将ISO-8859-1编码的字符串转换为UTF-8编码的示例代码&#xff1a; function convertISO88591ToUTF8(isoString) {// 将…

C语言数据结构基础——二叉树学习笔记(二)topk问题

1.top-k问题 1.1思路分析 TOP-K 问题&#xff1a;即求数据结合中前 K 个最大的元素或者最小的元素&#xff0c;一般情况下数据量都比较大 。 比如&#xff1a;专业前 10 名、世界 500 强、富豪榜、游戏中前 100 的活跃玩家等。 对于 Top-K 问题&#xff0c;能想到的最简单直…

Gradle v8.5 笔记 - 从入门到进阶(基于 Kotlin DSL)

目录 一、前置说明 二、Gradle 启动&#xff01; 2.1、安装 2.2、初始化项目 2.3、gradle 项目目录介绍 2.4、Gradle 项目下载慢&#xff1f;&#xff08;万能解决办法&#xff09; 2.5、Gradle 常用命令 2.6、项目构建流程 2.7、设置文件&#xff08;settings.gradle.…

什么是web组态?Web组态软件哪个好用?

随着工业4.0的到来&#xff0c;物联网、大数据、人工智能等技术的融合应用&#xff0c;使得工业领域正在经历一场深刻的变革。在这个过程中&#xff0c;Web组态技术以其独特的优势&#xff0c;正在逐渐受到越来越多企业的关注和认可。那么&#xff0c;什么是Web组态&#xff1f…