优化器算法SGD、Adam、AdamW等

文章目录

    • SGD
    • SGD with momentum
    • SGD with Nesterov Acceleration
    • AdaGrad
    • RMSprop
    • AdaDelta
    • Adam
    • AdamW
    • 参考资料

假设有:

  • 待优化的目标函数为 f ( w ) f(w) f(w),使用优化算法来最小化目标函数 f ( w ) : a r g m i n w f ( w ) f(w):argmin_wf(w) f(w):argminwf(w)

  • 在时间步t的梯度 g t = ∇ f ( w t ) g_t= \nabla f(w_t) gt=f(wt)

  • 模型参数为 w w w w t w_t wt为时刻t的参数, w t + 1 w_{t+1} wt+1​​是时刻t+1的参数

  • 在时刻t的学习率为 α t \alpha_t αt

  • 平滑项 ϵ \epsilon ϵ

SGD

SGD(Stochastic gradient descent)只考虑当前时间步的梯度,其更新方式为
w t + 1 = w t − α t g t w_{t+1} = w_t - \alpha_t g_t wt+1=wtαtgt
pytorch 对应的类为torch.optim.SGD

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

SGD with momentum

对于非凸目标函数,可能会存在多个局部极小值,使用SGD求解时,在这些局部极小值附近的梯度很小,使得优化算法陷入到局部最优解。

而带动量的SGD算法不仅仅使用当前梯度,也会考虑到历史梯度,设动量参数为 μ \mu μ,其参数更新方式为:
b t = μ b t − 1 + g t w t + 1 = w t − α t b t b_t = \mu b_{t-1} + g_t \\ w_{t+1} = w_t - \alpha_t b_t bt=μbt1+gtwt+1=wtαtbt
pytorch 对应的类也为torch.optim.SGD,可以设置momentum参数。

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

SGD with Nesterov Acceleration

SGD with Nesterov Acceleration是对SGD with momentum的改进,先根据累积梯度进行一次参数更新。
g t = ∇ f ( w t − μ b t − 1 ) b t = μ b t − 1 + g t w t + 1 = w t − α t b t g_t = \nabla f(w_{t} - \mu b_{t-1}) \\ b_t = \mu b_{t-1} + g_t \\ w_{t+1} = w_t - \alpha_t b_t gt=f(wtμbt1)bt=μbt1+gtwt+1=wtαtbt
pytorch 对应的类也为torch.optim.SGD,在设置momentum参数后,设置nesterov参数为True。

AdaGrad

AdaGrad(Adaptive Gradient Algorithm)是在每次迭代时自适应地调整每个参数的学习率,出自2021年的论文《Adaptive Subgradient Methods for Online Learning and Stochastic Optimization》。

若有d个参数
v t = d i a g ( ∑ i = 1 t g i , 1 2 , ∑ i = 1 t g i , 2 2 , ⋯ , ∑ i = 1 t g i , d 2 ) w t + 1 = w t − α t g t v t + ϵ v_t = diag(\sum^t_{i=1}g^2_{i,1}, \sum^t_{i=1}g^2_{i,2}, \cdots,\sum^t_{i=1}g^2_{i,d} ) \\ w_{t+1} = w_t - \alpha_t \frac{g_t}{\sqrt{v_t} + \epsilon} vt=diag(i=1tgi,12,i=1tgi,22,,i=1tgi,d2)wt+1=wtαtvt +ϵgt
相比于SGD,每个参数的学习率是会随时间变化的,即对于第j个参数,学习率为 α t v t , j + ϵ = α t s u m i = t t g i , j 2 + ϵ \frac{\alpha_t}{\sqrt{v_{t,j}} + \epsilon} = \frac{\alpha_t}{\sqrt{sum^t_{i=t}g_{i,j}^2} + \epsilon} vt,j +ϵαt=sumi=ttgi,j2 +ϵαt。并且AdaGrad使用了二阶动量。

pytorch 对应的类为torch.optim.Adagrad

RMSprop

AdaGrad考虑过去所有时间的梯度累加和,所以学习率可能会趋近于零,从而使模型在没有找到最优解时就终止了学习。 RMSprop对AdaGrad进行了改进,该算法出自 G. Hinton的 lecture notes 。RMSprop相比于AdaGrad只关注过去一段时间窗口的梯度平方和:
v t = β 2 v t − 1 + ( 1 − β 2 ) d i a g ( g t 2 ) w t + 1 = w t − α t g t v t + ϵ v_t = \beta_2 v_{t-1} + (1- \beta_2) diag(g^2_t) \\ w_{t+1} = w_t - \alpha_t \frac{g_t}{\sqrt{v_t} + \epsilon} vt=β2vt1+(1β2)diag(gt2)wt+1=wtαtvt +ϵgt
pytorch对应的类为torch.optim.RMSprop

AdaDelta

AdaGrad考虑过去所有时间的梯度累加和,所以学习率可能会趋近于零,从而使模型在没有找到最优解时就终止了学习。 AdaDelta对AdaGrad进行了改进,该算法出自论文《ADADELTA: An Adaptive Learning Rate Method》。AdaDelta相比于AdaGrad有两个改进:

  • 只关注过去一段时间窗口的梯度平方和: v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ d i a g ( g t 2 ) v_t = \beta_2 \cdot v_{t-1} + (1- \beta_2) \cdot diag(g^2_t) vt=β2vt1+(1β2)diag(gt2)(指数移动平均) ,一般取 β 2 = 0.9 \beta_2 = 0.9 β2=0.9​(相当于关注过去10个时间步的梯度平方和)
  • 引入每次参数更新差值 Δ θ \Delta \theta Δθ的平方的指数移动平均: Δ X t − 1 2 = β 1 Δ X t − 2 2 + ( 1 − β 1 ) Δ θ t − 1 ⊙ Δ θ t − 1 \Delta X^2_{t-1} = \beta_1 \Delta X^2_{t-2} + (1-\beta_1) \Delta \theta_{t-1} \odot \Delta \theta_{t-1} ΔXt12=β1ΔXt22+(1β1)Δθt1Δθt1

v t = β 2 v t − 1 + ( 1 − β 2 ) d i a g ( g t 2 ) Δ X t − 1 2 = β 1 Δ X t − 2 2 + ( 1 − β 1 ) Δ θ t − 1 ⊙ Δ θ t − 1 w t + 1 = w t − α t Δ X t − 1 2 + ϵ v t + ϵ g t v_t = \beta_2 v_{t-1} + (1- \beta_2) diag(g^2_t) \\ \Delta X^2_{t-1} = \beta_1 \Delta X^2_{t-2} + (1-\beta_1) \Delta \theta_{t-1} \odot \Delta \theta_{t-1} \\ w_{t+1} = w_t - \alpha_t \frac{\sqrt{\Delta X^2_{t-1} + \epsilon}}{\sqrt{v_t} + \epsilon}g_t vt=β2vt1+(1β2)diag(gt2)ΔXt12=β1ΔXt22+(1β1)Δθt1Δθt1wt+1=wtαtvt +ϵΔXt12+ϵ gt

pytorch对应的类为torch.optim.AdaDelta

Adam

Adam出自论文《Adam: A Method for Stochastic Optimization》,它同时考虑了一阶动量和二阶动量。(公式中的纠正项 m ^ t \hat m_t m^t v ^ t \hat v_t v^t只在初始阶段校正)
m t = β 1 m t − 1 + ( 1 − β 1 ) g t v t = β 2 v t − 1 + ( 1 − β 2 ) d i a g ( g t 2 ) m ^ t = m t 1 − β 1 t v ^ t = v t 1 − β 2 t w t + 1 = w t − α t m ^ t v ^ t + ϵ m_t = \beta_1m_{t-1} + (1-\beta_1)g_t \\ v_t = \beta_2 v_{t-1} + (1- \beta_2) diag(g^2_t) \\ \hat m_t = \frac{m_t}{1-\beta_1^t} \\ \hat v_t = \frac{v_t}{1-\beta_2^t} \\ w_{t+1} = w_t - \alpha_t \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} mt=β1mt1+(1β1)gtvt=β2vt1+(1β2)diag(gt2)m^t=1β1tmtv^t=1β2tvtwt+1=wtαtv^t +ϵm^t
pytorch对应的类为torch.optim.Adam

AdamW

AdamW是对Adam的改进,出自论文《Decoupled Weight Decay Regularization》,现在大模型训练基本上都是使用AdamW优化器。

AdamW改进的主要出发点是 L 2 L_2 L2正则和权重衰减(weight decay)对于自适应梯度如Adam是不一样的,所以作者们对Adam做了如下图的修改。

在这里插入图片描述

pytorch对应的类为torch.optim.AdamW

参考资料

  1. pytorch优化算法
  2. 知乎文章:从 SGD 到 AdamW —— 优化算法的演化
  3. https://www.fast.ai/posts/2018-07-02-adam-weight-decay.html
  4. Cornell University Computational Optimization Open Textbook
  5. 神经网络与深度学习
  6. 视频:从SGD到AdamW(后面的两个视频还讲了为什么transformer用SGD的效果不好)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/756441.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

获取指定路径下,所有指定后缀文件列表

要获取指定路径下所有指定后缀的文件列表,你可以使用Python的os和glob模块。下面是一个简单的示例,展示了如何获取指定路径下所有.txt后缀的文件列表: import os import globdef get_files_with_extension(directory, extension):"&quo…

vivado 布线、路线_设计

路由 Vivado路由器对放置的设计执行路由,并对路由设计,以解决保留时间冲突。Vivado路由器从放置的设计开始,并尝试路由所有网络。它可以从已放置的未布线、部分布线或完全布线的设计。对于部分路由的设计,Vivado路由器使用现有的…

Unittest框架及自动化测试实现流程

🍅 视频学习:文末有免费的配套视频可观看 🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 Unittest框架介绍 Unittest框架是Python中一个标准的库中的一个模块,该模块包括许多…

红与黑(c++题解)

题目描述 有一间长方形的房子,地上铺了红色、黑色两种颜色的正方形瓷砖。你站在其中一块黑色的瓷砖上,只能向相邻的黑色瓷砖移动。请写一个程序,计算你总共能够到达多少块黑色的瓷砖。 输入格式 包括多个数据集合。每个数据集合的第一行是…

【Mysql】面试题汇总

1. 存储引擎 1-1. MySQL 支持哪些存储引擎?默认使用哪个? 答: MySQL 支持的存储引擎包括 InnoDB、MyISAM、Memory 等。 Mysql 5.5 之前默认的是MyISAM,Mysql 5.5 之后默认的是InnoDB。 可以通过 show engines 查看 Mysql 支持…

外包2月,技术退步惊现!大专生逆袭大厂,全靠这份神秘资料!

大家好,我是一名大专生,19年通过校招进入湖南某软件公司,从事功能测试工作已近4年。今年8月,我意识到长期舒适的环境让我变得不思进取,技术停滞不前,甚至因此失去了谈了2年的女朋友。我下定决心&#xff0c…

金蝶云星空——插件dll重新发布报错:鏃犳硶鏄剧ず椤甸潰锛屽洜涓哄彂鐢熷唴閮ㄦ湇鍔″櫒閿欒銆�

项目场景: 金蝶插件开发 问题描述 今天更新了插件dll然后重启IIS金蝶就报如下错误: 解决方案: 折腾了一天结果发现是给自己挖坑了,这次更新我担心插件代码有问题就把原dll重命名了然后把最新dll更新到金蝶bin文件中&#xff0c…

vue实现双向绑定原理深度解析

1. vue双向绑定应用场景 Vue的双向绑定机制主要体现在以下几个方面: 表单输入:在表单输入中,Vue的双向绑定机制非常有用。通过v-model指令,可以将表单元素的值与Vue实例中的数据进行双向绑定,当用户在表单输入框中输入内容时,数据会自动更新,反之,当数据发生变化时,输…

【DBC专题】-11-使用Cantools将CAN/CANFD DBC自动生成C语言代码

目录 1 安装Python和Cantools 1.1 查看Python已安装的Package包 1.2 在Python中安装Cantools插件包 1.3 获取更多Cantools工具的更新动态 2 经典CAN/CANFD DBC自动生成C语言代码 2.1 批处理文件CAN_DBC_To_C.bat内容说明 2.2 经典CAN/CANFD DBC文件要求 2.3 如何使用生…

网站引用图片但它域名被墙了或者它有防盗链,我们想引用但又不能显示,本文附详细的解决方案非常简单!

最好的办法就是直接读取图片文件&#xff0c;用到php中一个常用的函数file_get_contents(图片地址)&#xff0c;意思是读取远程的一张图片&#xff0c;在输出就完事。非常简单&#xff5e;话不多说&#xff0c;直接上代码 <?php header("Content-type: image/jpeg&quo…

clipboard好用的复制剪切库

clipboard是现代复制到剪贴板的工具&#xff0c;其 gzip 压缩后只有 3kb&#xff0c;能够减少选择文本的重复操作&#xff0c;点击按钮就可以复制指定内容&#xff0c;支持原生HTMLjs&#xff0c;vue3和vue2。使用方法参照官方文档&#xff0c;so easy&#xff01;&#xff01;…

装X神器,装X图片生成器,高富帅模拟器

先展示两张效果 基金装X图 短信存款图 神器功能展示 总共有12大类可供用户选择 还有一些美感的&#xff1a; 总结 总之种类非常多&#xff0c;有了这个神器你懂的&#xff5e; 关注下方公众号&#xff0c;回复【zzsq】即可获取。

YS/T 429.2-2012 有机聚合物喷涂幕墙铝单板检测

有机聚合物喷涂幕墙铝单板是指以氟碳漆或粉末做表面涂层的幕墙用铝及铝合金单层形成的铝单板。 YS/T 429.2-2012有机聚合物喷涂幕墙铝单板检测项目&#xff1a; 测试项目 测试方法 力学性能 GB/T 16865 尺寸偏差 GB/T 3880.3 光泽 GB 5237 颜色和色差 GB 5237 厚度 …

2、鸿蒙学习-申请调试证书和调试Profile文件

申请发布证书 发布证书由AGC颁发的、为HarmonyOS应用配置签名信息的数字证书&#xff0c;可保障软件代码完整性和发布者身份真实性。证书格式为.cer&#xff0c;包含公钥、证书指纹等信息。 说明 请确保您的开发者帐号已实名认证。每个帐号最多申请1个发布证书。 1、登录AppGa…

Linux软件管理(1)

软件管理 下载 wget Linux wget是一个下载文件的工具&#xff0c;它用在命令行下。 wget工具体积小但功能完善&#xff0c;它支持断点下载功能&#xff0c;同时支持FTP和HTTP下载方式&#xff0c;支持代理服务器和设置起来方便简单。 1.语法 wget [选项]……[URL]…… 2、…

阅读基础知识1

一 网络 1. 三次握手四次挥手 三次握手&#xff1a;为了建立长链接进行交互即建立一个会话&#xff0c;使用 http/https 协议 ① 客户端产生初始化序列号 Seqx &#xff0c;向服务端发送建立连接的请求报文&#xff0c;将 SYN1 同步序列号&#xff1b; ② 服务端接收建立连接…

5.1.4、【AI技术新纪元:Spring AI解码】Amazon Bedrock

Amazon Bedrock是一个托管服务,通过统一的 API 提供来自各种 AI 提供商的基础模型。 Spring AI 通过实现 Spring 接口 ChatClient、StreamingChatClient 和 EmbeddingClient 来支持所有通过Amazon Bedrock可用的聊天和嵌入式 AI 模型。 此外,Spring AI 为所有客户端提供了 …

洛谷_P1068 [NOIP2009 普及组] 分数线划定_python写法

P1068 [NOIP2009 普及组] 分数线划定 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 知识点&#xff1a; 这道题用到了自定义排序规则 n, m map(int,input().split()) data [] for i in range(n):l list(map(int,input().split()))data.append(l)import functoolsdef my_cm…

软件测试:C++ Google Test单元测试框架GTest

目录 编译和安装框架使用AssertionsGoogle TestingGoogle MockingMatchersActions 运行结果 最近在写项目的时候&#xff0c;学到了许多关于软件测试的知识&#xff0c;也不断的使用新的测试框架和测试工具&#xff0c;每次总是机械式的拼接其他人的代码&#xff0c;代码发生错…

香港科技大学广州|智能制造学域博士招生宣讲会—同济大学专场

时间&#xff1a;2024年3月28日&#xff08;星期四&#xff09;10:00 地点&#xff1a;同济大学嘉定校区济人楼310 报名链接&#xff1a;https://www.wjx.top/vm/mmukLPC.aspx# 宣讲嘉宾&#xff1a;崔华晨 助理教授 跨学科重点研究领域 •工业4.0 •智能传感器、自动光学检…