神经网络优化器-从SGD到AdamW

优化器准则

凸优化基本概念

  • 先定义凸集,集合中的两个点连接的线还在集合里面,就是凸集,用数学语言来表示就是:对于集合中的任意两个元素x,y以及任意实数 λ ∈ ( 0 , 1 ) \lambda \in (0,1) λ(0,1),有 λ x + ( 1 − λ ) y ∈ C \lambda x + (1 - \lambda) y \in C λx+(1λ)yC,则称为凸集。
  • 再定义凸函数: f ( λ x + ( 1 − λ ) y ) ≤ λ f ( x ) + ( 1 − λ ) f ( y ) f(\lambda x + (1 - \lambda) y) \leq \lambda f(x) + (1 - \lambda) f(y) f(λx+(1λ)y)λf(x)+(1λ)f(y)其中, λ \lambda λ是一个满足 0 ≤ λ ≤ 1 0 \leq \lambda \leq 1 0λ1的实数参数。
  • 可以看出,凸函数的定义域必须是凸集。直观上,凸函数的图像不会在任何地方凹陷,这使得凸函数的局部最小值也是全局最小值,这使得优化问题更容易解决。

现在再定义凸优化问题:
凸优化是数学优化理论中的一个重要分支,它研究的是凸函数的优化问题。用数学语言来表示就是:
minimize f ( x ) subject to g i ( x ) ≤ 0 , i = 1 , … , m and h j ( x ) = 0 , j = 1 , … , p \begin{align*} \text{minimize} \quad & f(x) \\ \text{subject to} \quad & g_i(x) \leq 0, \quad i = 1, \ldots, m \\ \text{and} \quad & h_j(x) = 0, \quad j = 1, \ldots, p \\ \end{align*} minimizesubject toandf(x)gi(x)0,i=1,,mhj(x)=0,j=1,,p
其中, f ( x ) f(x) f(x) 是目标函数, g i ( x ) g_i(x) gi(x)是不等式约束函数, h j ( x ) h_j(x) hj(x)是等式约束函数, x x x是决策变量。如果目标函数 f ( x ) f(x) f(x) 和所有约束函数 g i ( x ) g_i(x) gi(x) h j ( x ) h_j(x) hj(x)都是凸函数,并且可行域(满足所有约束的 x x x的集合)也是凸集,那么这个问题就是一个凸优化问题。

凸优化问题有以下特点:

  • 局部最优即全局最优:如果一个点是局部最小点,那么它也是全局最小点。这使得寻找最优解变得更加容易。
  • 对偶性:凸优化问题具有良好的对偶性质,即原问题的对偶问题也是一个凸优化问题。
  • 存在性:如果目标函数和约束函数都是下闭的,并且可行域非空,那么凸优化问题总是有解的。
  • 稳定性:凸优化问题的解对问题的微小变化是稳定的。

研究方法有:

  • 梯度下降法:通过迭代地沿着目标函数的负梯度方向移动来找到最小点。
  • 牛顿法:利用目标函数的二阶导数(Hessian)来加速梯度下降法。
  • 内点法:一种专门用于解决有约束凸优化问题的算法。
  • 次梯度法:对于非光滑的凸函数,使用次梯度而不是梯度来优化。
  • 对偶方法:通过解决对偶问题来找到原问题的解。

神经网络优化问题定义

现在我们可以开始讨论优化器了:
深度学习模型的训练就是一个优化问题,模型权重就是我们上面提到的决策变量 x x x,目标函数就是我们所设计的损失函数(所以我们将损失函数设计成凸函数,以满足凸优化的条件),模型本身就是一个等式或者不等式约束,输出结果必须满足事先知道的label。优化器,就是解决这个凸优化问题的实现方案。
我们的优化问题用数学来表示就是:
f ( W ) = min ⁡ w 1 N ∑ i = 1 N L ( y i , F ( x i ) ) + ∑ j = 1 n λ ∥ w j ∥ f(W) = \min_{w} {\frac{1}{N}\sum_{i=1}^{N} L(y_i,F(x_i)) + \sum_{j=1 }^{n}\lambda \left \| w_{j} \right \| } f(W)=wminN1i=1NL(yi,F(xi))+j=1nλwj
W W W是模型的所有参数
前一项是损失,其中 w w w是参数, N N N是样本总数, y i y_i yi是样本标签, F ( x i ) F(x_i) F(xi)是模型结果,L是损失函数,
后一项是正则损失,用于避免过拟合现象的, λ \lambda λ是正则化系数, ∥ w j ∥ \left \| w_{j} \right \| wj是参数的范数,常见有L1,L2等

这就是深度学习优化的数学定义,这样我们就可以去使用数学方法来解决这个问题了。

  • 使用梯度下降法就是: W t = W t − 1 − α ∗ ▽ f ( W t − 1 ) W_t = W_{t-1}-\alpha *\bigtriangledown f(W_{t-1}) Wt=Wt1αf(Wt1)其中 ▽ f ( W t − 1 ) \bigtriangledown f(W_{t-1}) f(Wt1)是函数的梯度向量。
  • 使用牛顿法就是: W t = W t − 1 − α ∗ H t − 1 − 1 ∗ ▽ f ( W t − 1 ) W_t = W_{t-1}-\alpha*H_{t-1}^{-1} *\bigtriangledown f(W_{t-1}) Wt=Wt1αHt11f(Wt1)其中 H t − 1 − 1 H_{t-1}^{-1} Ht11为Hessian矩阵的逆矩阵即二阶偏导矩阵的逆矩阵。

这也是深度学习中随机梯度下降的由来,从最优化的梯度下降借鉴过来的。
优化器可以做的事情,就是对解决问题方法中的:梯度gt,学习率,参数正则项,参数初始化这几个因素进行调整。
不同的优化器,他们的区别就是这四项的不同。

优化器分类与发展

随机梯度

  • SGD:梯度计算的变种,主要区别在于gt的计算方式,原始梯度下降算法叫做GD,计算所有梯度然后更新,SGD叫做随机梯度下降,因为它每次只采用一小批训练样本作为梯度更新参数,然后根据这个梯度更新模型参数。这种方法的优点是计算效率高,因为不需要计算整个训练集上的梯度,这在数据量很大时尤其有用。
  • 动量SGD:mSGD,gt不光包括计算出的梯度,还包括了部分过去的梯度信息,好处是会加速收敛,并且跳过一些局部最优
    RMS等。

自适应梯度

算法的核心思想是根据参数的历史更新信息来调整每个参数的学习率,从而提高收敛速度并减少训练时间。

  • Adaptive Gradient:自适应梯度算法,它通过为每个参数维护一个累积的梯度平方和来调整学习率。AdaGrad 的更新规则如下:
    θ i = θ i − η G i i + ϵ ⋅ ∇ θ L ( θ ) i \theta_i = \theta_i - \frac{\eta}{\sqrt{G_{ii} + \epsilon}} \cdot \nabla_\theta L(\theta)_i θi=θiGii+ϵ ηθL(θ)i
    其中, G G G 是一个对角矩阵, G i i G_{ii} Gii是参数 θ i \theta_i θi的累积梯度平方和, ϵ \epsilon ϵ是一个很小的常数,用来保证数值稳定性。这种算法的缺点是因为下面的累计梯度平方和越来越大,越往后训练的效果越弱,如果有出现异常梯度值,那直接后面的训练就约等于无效了。
  • RMSProp(均方根传播):
    RMSProp 是一种指数加权的移动平均算法,用于计算梯度的平方的指数衰减平均。它与 AdaGrad 类似,但是使用了梯度平方的指数衰减平均而不是累积和,避免了学习率变得过小的问题。更新规则如下:
    G i i = γ G i i + ( 1 − γ ) ⋅ ( ∇ θ L ( θ ) i ) 2 G_{ii} = \gamma G_{ii} + (1 - \gamma) \cdot (\nabla_\theta L(\theta)_i)^2 Gii=γGii+(1γ)(θL(θ)i)2
    θ i = θ i − η G i i + ϵ ⋅ ∇ θ L ( θ ) i \theta_i = \theta_i - \frac{\eta}{\sqrt{G_{ii} + \epsilon}} \cdot \nabla_\theta L(\theta)_i θi=θiGii+ϵ ηθL(θ)i。其中, γ \gamma γ是衰减率。

    历史的梯度只占一部分,避免了因为历史梯度导致G不断增大,进而出现无法更新的情况。

Adam & AdamW

  • Adam(自适应矩估计):
    14年提出,Adam 结合了 AdaGrad 和 RMSProp 的优点,同时计算了梯度的一阶矩(均值)和二阶矩(方差)的指数加权移动平均。Adam 的更新规则较为复杂,涉及两个时刻的估计量:
    m t = β 1 m t − 1 + ( 1 − β 1 ) ⋅ ∇ θ L ( θ ) m_t = \beta_1 m_{t-1} + (1 - \beta_1) \cdot \nabla_\theta L(\theta) mt=β1mt1+(1β1)θL(θ)

    就是上面提到的一阶动量部分,借鉴Momentum部分

    v t = β 2 v t − 1 + ( 1 − β 2 ) ⋅ ( ∇ θ L ( θ ) ) 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) \cdot (\nabla_\theta L(\theta))^2 vt=β2vt1+(1β2)(θL(θ))2

    二阶动量部分,也就是借鉴RMSProp部分

    这样虽然避免了后期无法更新的问题,但是引入了一个新的问题,那就是因为有衰弱因数,导致在刚开始训练的时候梯度信息积累太慢,因此在更新的时候设一个无偏估计,使用该无偏估计来进行更新
    m ^ t = m t 1 − β 1 t \hat{m}_t = \frac{m_t}{1-\beta_1^t} m^t=1β1tmt
    v ^ t = v t 1 − β 1 t \hat{v}_t = \frac{v_t}{1-\beta_1^t} v^t=1β1tvt
    θ t + 1 = θ t − η v ^ t + ϵ ⋅ m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t θt+1=θtv^t +ϵηm^t
    其中, m t m_t mt v t v_t vt 分别是梯度的一阶矩和二阶矩的估计, m ^ t \hat{m}_t m^t v ^ t \hat{v}_t v^t 是它们的无偏估计量,( \beta_1 ) 和 ( \beta_2 ) 是超参数。

    无偏估计的意思是:在大量数据的时候,估计量(estimator)的期望值(或平均值)等于被估计的参数的真实值。

在transformer模型中常用,因为transformer的Lipschitz常量很大,每一层的Lipschitz常量差异又很大,学习率很难估计,而且学习完表现也比较差。所以mSGD基本不用,都是用Adam。

mSGD在卷积网络的时候效果还是不错的,能够和Adam打个平手
Lipschitz常量是指在Lipschitz连续中的一个量,能够体现凸函数的变化率。Lipschitz常量差异大就表示不同函数间相同的自变量变化导致的因变量变化差异大,简而言之也就是学习率需要设置的不同。

AdamW

AdamW和Adam基本一致,只有对正则项的处理不一致。Adam和前面是其他的一样,都是在损失函数里面加一个正则项,但是当训练时,前期梯度太大,会把正则项淹没掉,后期梯度太小,正则项又会把梯度信息淹没掉,AdamW的目的是为了平衡这两项。
AdamW中W 代表权重衰减(Weight Decay),将原本的正则项改为weight decay,将原本在损失函数中的项,放到了权重更新公式中:
θ t + 1 = θ t − η v ^ t + ϵ ⋅ m ^ t − λ θ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t - \lambda \theta_{t} θt+1=θtv^t +ϵηm^tλθt
在 AdamW 中,权重衰减不是直接从参数更新中减去,而是作为参数更新的一部分。这样做的好处是:

  • 保持自适应学习率的一致性:权重衰减与自适应学习率相结合,确保了不同参数的学习率保持一致。
  • 提高收敛性和稳定性:调整后的权重衰减有助于算法更快地收敛,并提高了训练过程的稳定性。

优化器内存占用

在进行小模型训练时,对于优化器的内存占用不是很关注,但是在进行大模型训练时,优化器的内存占用非常大,就需要专门考虑了,大模型常用的优化器为AdamW。
AdamW算法的内存占用相对较高,因为它需要同时保存一阶和二阶矩。具体来说,AdamW算法在优化过程中需要存储以下内容:

  1. 参数的当前值 θ \theta θ
  2. 梯度的一阶矩估计(即一阶动量) m \mathbf{m} m
  3. 梯度的二阶矩估计(即二阶动量) v \mathbf{v} v

每个参数 θ \theta θ 都需要额外存储两个与其尺寸相同的向量 m \mathbf{m} m v \mathbf{v} v,这导致内存占用大约是原始参数内存的两倍。此外,还需要存储超参数。在大规模训练或参数量非常大的模型中,这种内存占用可能会成为一个问题。例如,在训练具有数百万参数的模型时,使用AdamW可能会导致显著的内存需求增加,这可能限制了模型的大小或训练并行度。
对于参数量为 Φ \Phi Φ的模型,使用混合精度进行训练,模型参数本身使用fp16存储,占用 2 Φ 2\Phi 个字节,同样模型梯度占用 2 Φ 2\Phi 个字节,Adam状态(fp32的模型参数备份,fp32的momentum和fp32的variance)一共要占用 12 Φ 12\Phi 12Φ个字节,这两个统称模型状态,共占用 16 Φ 16\Phi 16Φ个字节

混合精度训练时,前向传播和反向传播都是fp16,但是参数更新时使用fp32。

针对显存这个问题,微软提出了ZeRO技术,将模型状态进行分片,对于N个GPU,每个GPU中保存 1 N \frac{1}{N} N1的模型状态量

实现

Pytorch中优化器:官方教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/13800.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【NLP】词性标注

词 词是自然语言处理的基本单位,自动词法分析就是利用计算机对词的形态进行分析,判断词的结构和类别。 词性(Part of Speech)是词汇最重要的特性,链接词汇和句法 词的分类 屈折语:形态分析 分析语&#…

k8s 1.24.x之后如果rest 访问apiserver

1.由于 在 1.24 (还是 1.20 不清楚了)之后,下面这两个apiserver的配置已经被弃用 了,简单的说就是想不安全的访问k8s是不可能了,所以只能走安全的访问方式也就是 https://xx:6443了,所以需要证书。 - --ins…

Git系列:git rm 的高级使用技巧

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

【go项目01_学习记录15】

重构MVC 1 Article 模型1.1 首先创建 Article 模型文件1.2 接下来创建获取文章的方法1.3 新增 types.StringToUint64()函数1.4 修改控制器的调用1.5 重构 route 包1.6 通过 SetRoute 来传参对象变量1.7 新增方法:1.8 控制器将 Int64ToString 改为 Uint64ToString1.9…

【数据结构】栈和队列的相互实现

欢迎浏览高耳机的博客 希望我们彼此都有更好的收获 感谢三连支持! 1.用栈实现队列 当队列中进入这些元素时,相应的栈1中元素出栈顺序与出队列相反,因此我们可以使用两个栈来使元素的出栈顺序相同; 通过将栈1元素出栈,再…

Databend 倒排索引的设计与实现

倒排索引是一种用于全文搜索的数据结构。它的主要功能是将文档中的单词作为索引项,映射到包含该单词的文档列表。通过倒排索引,可以快速准确地定位到与查询词相匹配的文档列表,从而大幅提高查询性能。倒排索引在搜索引擎、数据库和信息检索系…

matlab实现绘制烟花代码

下面是一个简化的示例,它使用MATLAB的绘图功能来模拟烟花爆炸的视觉效果。请注意,这个示例是概念性的,并且可能需要根据您的具体需求进行调整。 % 初始化参数 num_fireworks 5; % 烟花数量 num_particles_per_firework 200; % 每个烟花…

前端 CSS 经典:3D 渐变轮播图

前言&#xff1a;无论什么样式的轮播图&#xff0c;核心 JS 实现原理都差不多。所以小伙伴们&#xff0c;还是需要了解一下核心 JS 实验原理的。 效果图&#xff1a; 实现代码&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta chars…

MySQL —— 复合查询

一、基本的查询回顾练习 前面两章节整理了许多关于查询用到的语句和关键字&#xff0c;以及MySQL的内置函数&#xff0c;我们先用一些简单的查询练习去回顾之前的知识 1. 前提准备 同样是前面用到的用于测试的表格和数据&#xff0c;一张学生表和三张关于雇员信息表 雇员信息…

优化数据查询性能:StarRocks 与 Apache Iceberg 的强强联合

Apache Iceberg 是一种开源的表格格式&#xff0c;专为在数据湖中存储大规模分析数据而设计。它与多种大数据生态系统组件高度兼容&#xff0c;相较于传统的 Hive 表格格式&#xff0c;Iceberg 在设计上提供了更高的性能和更好的可扩展性。它支持 ACID 事务、Schema 演化、数据…

leetcode-设计LRU缓存结构-112

题目要求 思路 双链表哈希表 代码实现 struct Node{int key, val;Node* next;Node* pre;Node(int _key, int _val): key(_key), val(_val), next(nullptr), pre(nullptr){} };class Solution { public: unordered_map<int, Node*> hash; Node* head; Node* tail; int …

普源DHO924示波器OFFSET设置

一、简介 示波器是电子工程师常用的测量工具之一&#xff0c;能够直观地显示电路信号的波形和参数。普源DHO924是一款优秀的数字示波器&#xff0c;具有优异的性能和易用性。其中OFFSET功能可以帮助用户调整信号的垂直位置&#xff0c;使波形更清晰易读。本文将详细介绍DHO924…

专注于运动控制芯片、运动控制产品研发、生产与销售为一体的技术型芯片代理商、方案商——青牛科技

深圳市青牛科技实业有限公司,是专注于运 动控制芯片、运动控制产品研发、生产与销售为一体的技术型 芯片代理商、方案商。现今代理了国产品牌GLOBALCHIP&#xff0c;芯谷&#xff0c;矽普&#xff0c;TOPPOWER等品牌。其中代理品牌TOPPOWER为电源模块&#xff0c;他们公司通过了…

cherry-pick的强大之处在于哪里

git cherry-pick 的强大之处在于它提供了一种灵活的方式来应用特定的提交到不同的分支上&#xff0c;而无需合并整个分支或拉取其他不需要的提交。以下是 git cherry-pick 的几个主要优点和强大之处&#xff1a; 选择性应用提交&#xff1a;你可以挑选一个或多个特定的提交&…

声音转文本(免费工具)

声音转文本&#xff1a;解锁语音技术的无限可能 在当今这个数字化时代&#xff0c;信息的传递方式正以前所未有的速度进化。从手动输入到触控操作&#xff0c;再到如今的语音交互&#xff0c;技术的发展让沟通变得更加自然与高效。声音转文本&#xff08;Speech-to-Text, STT&…

Plant Simulation验证AGV算法

Plant Simulation验证算法也是非常高效且直观的&#xff0c;一直以来波哥在迭代算法的时候图形显示这块都是使用Openframeworks去做&#xff0c;效果还是非常不错的。 这里简要介绍一下openFrameworks&#xff0c;openFrameworks是一个开源的、跨平台的 C 工具包。旨在开发实时…

LeetCode hot100-49-N

236. 二叉树的最近公共祖先 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;…

爬虫学习--12.MySQL数据库的基本操作(下)

MySQL查询数据 MySQL 数据库使用SQL SELECT语句来查询数据。 语法&#xff1a;在MySQL数据库中查询数据通用的 SELECT 语法 SELECT 字段1&#xff0c;字段2&#xff0c;……&#xff0c;字段n FROM table_name [WHERE 条件] [LIMIT N] 查询语句中你可以使用一个或者多个表&…

uni-app项目在微信开发者工具打开时报错[ app.json 文件内容错误] app.json: 在项目根目录未找到 app.json

uni-app项目在微信开发者工具打开时报错[ app.json 文件内容错误] app.json: 在项目根目录未找到 app.json 出现这个问题是因为打开的文件地址不对&#xff0c;解决这个问题首先我们要查看是否有unpackage文件夹&#xff0c;如果有&#xff0c;项目直接指向unpackage\dist\dev\…

vue3使用mitt.js进行各种组件间通信

我们在vue工程中&#xff0c;除开vue自带的什么父子间&#xff0c;祖孙间通信&#xff0c;还有一个非常方便的通信方式&#xff0c;类似Vue2.x 使用 EventBus 进行组件通信&#xff0c;而 Vue3.x 推荐使用 mitt.js。可以实现各个组件间的通信 优点&#xff1a;首先它足够小&…