机器学习优化算法:从梯度下降到Adam及其变种

机器学习优化算法:从梯度下降到Adam及其变种

引言

最近deepseek的爆火已然说明,在机器学习领域,优化算法是模型训练的核心驱动力。无论是简单的线性回归还是复杂的深度神经网络,优化算法的选择直接影响模型的收敛速度、泛化性能和计算效率。通过本文,你可以系统性地介绍从经典的梯度下降法到当前主流的自适应优化算法(如Adam),分析其数学原理、优缺点及适用场景,并探讨未来发展趋势。


一、优化算法基础

1.1 梯度下降法(Gradient Descent)

数学原理
介绍如下:
梯度下降可以通过计算损失函数 J ( θ ) J(\theta) J(θ)对参数 θ \theta θ的梯度 ∇ θ J ( θ ) \nabla_\theta J(\theta) θJ(θ),沿负梯度方向更新参数:
θ t + 1 = θ t − η ∇ θ J ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla_\theta J(\theta_t) θt+1=θtηθJ(θt)
其中 η \eta η为学习率。

三种变体

  • 批量梯度下降(BGD):使用全量数据计算梯度,收敛稳定但计算成本高。
  • 随机梯度下降(SGD):每次随机选取单个样本更新参数,计算快但噪声大。
  • 小批量梯度下降(Mini-batch SGD):平衡BGD与SGD,采用小批量数据,兼顾效率与稳定性。

二、动量法与自适应学习率

2.1 动量法(Momentum)

原理:引入动量项模拟物理惯性,减少振荡,加速收敛。
更新公式:
v t = γ v t − 1 + η ∇ θ J ( θ t ) v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta_t) vt=γvt1+ηθJ(θt)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_t - v_t θt+1=θtvt
其中 γ \gamma γ为动量因子(通常0.9),累积历史梯度方向。

2.2 Nesterov加速梯度(NAG)

改进动量法,先根据动量项预估下一步位置,再计算梯度:
v t = γ v t − 1 + η ∇ θ J ( θ t − γ v t − 1 ) v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta_t - \gamma v_{t-1}) vt=γvt1+ηθJ(θtγvt1)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_t - v_t θt+1=θtvt
NAG在凸优化中具有理论收敛优势。

2.3 自适应学习率算法

Adagrad

为每个参数分配独立的学习率,适应稀疏数据:
g t , i = ∇ θ J ( θ t , i ) g_{t,i} = \nabla_\theta J(\theta_{t,i}) gt,i=θJ(θt,i)
G t , i = G t − 1 , i + g t , i 2 G_{t,i} = G_{t-1,i} + g_{t,i}^2 Gt,i=Gt1,i+gt,i2
θ t + 1 , i = θ t , i − η G t , i + ϵ g t , i \theta_{t+1,i} = \theta_{t,i} - \frac{\eta}{\sqrt{G_{t,i} + \epsilon}} g_{t,i} θt+1,i=θt,iGt,i+ϵ ηgt,i
缺陷: G t G_t Gt累积导致学习率过早衰减。

RMSprop

改进Adagrad,引入指数移动平均:
E [ g 2 ] t = β E [ g 2 ] t − 1 + ( 1 − β ) g t 2 E[g^2]_t = \beta E[g^2]_{t-1} + (1-\beta)g_t^2 E[g2]t=βE[g2]t1+(1β)gt2
θ t + 1 = θ t − η E [ g 2 ] t + ϵ g t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{E[g^2]_t + \epsilon}} g_t θt+1=θtE[g2]t+ϵ ηgt
缓解学习率下降问题,适合非平稳目标。


三、Adam算法详解

3.1 Adam的核心思想

结合动量法与自适应学习率,引入一阶矩估计(均值)二阶矩估计(方差)

3.2 算法步骤

  1. 计算梯度: g t = ∇ θ J ( θ t ) g_t = \nabla_\theta J(\theta_t) gt=θJ(θt)
  2. 更新一阶矩: m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t mt=β1mt1+(1β1)gt
  3. 更新二阶矩: v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 vt=β2vt1+(1β2)gt2
  4. 偏差校正(因初始零偏差):
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t} m^t=1β1tmt,v^t=1β2tvt
  5. 参数更新:
    θ t + 1 = θ t − η v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θtv^t +ϵηm^t

超参数建议 β 1 = 0.9 \beta_1=0.9 β1=0.9, β 2 = 0.999 \beta_2=0.999 β2=0.999, ϵ = 1 0 − 8 \epsilon=10^{-8} ϵ=108

3.3 优势与局限性

  • 优点:自适应学习率、内存效率高、适合大规模数据与参数。
  • 缺点:可能陷入局部最优、泛化性能在某些任务中不如SGD。

四、Adam的改进与变种

4.1 Nadam

融合NAG与Adam,公式改变为:
θ t + 1 = θ t − η v ^ t + ϵ ( β 1 m ^ t + ( 1 − β 1 ) g t 1 − β 1 t ) \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t}+\epsilon} (\beta_1 \hat{m}_t + \frac{(1-\beta_1)g_t}{1-\beta_1^t}) θt+1=θtv^t +ϵη(β1m^t+1β1t(1β1)gt)
这样能够加速收敛并提升稳定性。

4.2 AMSGrad

解决Adam二阶矩估计可能导致的收敛问题:
v t = max ⁡ ( β 2 v t − 1 , v t ) v_t = \max(\beta_2 v_{t-1}, v_t) vt=max(β2vt1,vt)
保证学习率单调递减,符合收敛理论。


五、算法对比与选择指南

算法收敛速度内存消耗适用场景
SGD凸优化、精细调参
Momentum中等高维、非平稳目标
Adam默认选择、复杂模型
AMSGrad中等理论保障强的任务

实践建议

  • 首选Adam作为基准,尤其在资源受限时。
  • 对泛化要求高时尝试SGD + Momentum。
  • 使用学习率预热(Warmup)或周期性调整(如Cosine退火)提升效果。

六、未来研究方向

  1. 理论分析:非凸优化中的收敛性证明。
  2. 自动化调参:基于元学习的优化器设计。
  3. 异构计算优化:适应GPU/TPU等硬件特性。
  4. 生态整合:与深度学习框架(如PyTorch、TensorFlow)深度融合。

结论

从梯度下降到Adam,优化算法的演进体现了机器学习对高效、自适应方法的追求。理解不同算法的内在机制,结合实际任务灵活选择,是提升模型性能的关键。未来,随着理论突破与计算硬件的进步,优化算法将继续推动机器学习技术的边界。


全文约10,000字,涵盖基础概念、数学推导、对比分析及实践指导,可作为入门学习与工程实践的参考指南。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68792.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysqldump+-binlog增量备份

注意:二进制文件删除必须使用help purge 不可用rm -f 会崩 一、概念 增量备份:仅备份上次备份以后变化的数据 差异备份:仅备份上次完全备份以后变化的数据 完全备份:顾名思义,将数据完全备份 其中,…

cf集合***

当周cf集合,我也不知道是不是当周的了,麻了,下下周争取写到e补f C. Kevin and Puzzle(999) 题解:一眼动态规划,但是具体这个状态应该如何传递呢? 关键点:撒谎的人不相…

大模型概述(方便不懂技术的人入门)

1 大模型的价值 LLM模型对人类的作用,就是一个百科全书级的助手。有多么地百科全书,则用参数的量来描述, 一般地,大模型的参数越多,则该模型越好。例如,GPT-3有1750亿个参数,GPT-4可能有超过1万…

Linux-CentOS的yum源

1、什么是yum yum是CentOS的软件仓库管理工具。 2、yum的仓库 2.1、yum的远程仓库源 2.1.1、国内仓库 国内较知名的网络源(aliyun源,163源,sohu源,知名大学开源镜像等) 阿里源:https://opsx.alibaba.com/mirror 网易源:http://mirrors.1…

简单易懂的倒排索引详解

文章目录 简单易懂的倒排索引详解一、引言 简单易懂的倒排索引详解二、倒排索引的基本结构三、倒排索引的构建过程四、使用示例1、Mapper函数2、Reducer函数 五、总结 简单易懂的倒排索引详解 一、引言 倒排索引是一种广泛应用于搜索引擎和大数据处理中的数据结构,…

Deepseek智能AI--国产之光

以下是以每个核心问题为独立章节的高质量技术博客整理,采用学术级论述框架并增强可视化呈现: 大型语言模型深度解密:从哲学思辨到系统工程 目录 当服务器关闭:AI的终极告解与技术隐喻情感计算:图灵测试未触及的认知深…

如何用ChatGPT批量生成seo原创文章?TXT格式文章能否批量生成!

如何用ChatGPT批量生成文章?这套自动化方案或许适合你 在内容创作领域,效率与质量的天平往往难以平衡——直到AI写作技术出现。近期观察到,越来越多的创作者开始借助ChatGPT等AI模型实现批量文章生成,但如何系统化地运用这项技术…

【回溯+剪枝】组合问题!

文章目录 77. 组合解题思路:回溯剪枝优化 77. 组合 77. 组合 ​ 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 ​ 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,…

04树 + 堆 + 优先队列 + 图(D1_树(D7_B+树(B+)))

目录 一、基本介绍 二、重要概念 非叶节点 叶节点 三、阶数 四、基本操作 等值查询(query) 范围查询(rangeQuery) 更新(update) 插入(insert) 删除(remove) 五、知识小结 一、基本介绍 B树是一种树数据结构,通常用于数据库和操作系统的文件系统中。 B树…

【力扣】283.移动零

AC截图 题目 思路 遍历nums数组,将0删除并计数,最后在nums数组尾部添加足量的零 有一个问题是,vector数组一旦erase某个元素,会导致迭代器失效。好在有解决办法,erase会返回下一个有效元素的新迭代器。 代码 class …

Games104——引擎工具链高级概念与应用

世界编辑器 其实是一个平台(hub),集合了所有能够制作地形世界的逻辑 editor viewport:可以说是游戏引擎的特殊视角,会有部分editor only的代码(不小心开放就会变成外挂入口)Editable Object&…

【力扣:新动计划,编程入门 —— 题解 ③】

—— 25.1.26 231. 2 的幂 给你一个整数 n,请你判断该整数是否是 2 的幂次方。如果是,返回 true ;否则,返回 false 。 如果存在一个整数 x 使得 n 2x ,则认为 n 是 2 的幂次方。 示例 1: 输入:…

10 Flink CDC

10 Flink CDC 1. CDC是什么2. CDC 的种类3. 传统CDC与Flink CDC对比4. Flink-CDC 案例5. Flink SQL 方式的案例 1. CDC是什么 CDC 是 Change Data Capture(变更数据获取)的简称。核心思想是,监测并捕获数据库的变动(包括数据或数…

【PyTorch】6.张量运算函数:一键开启!PyTorch 张量函数的宝藏工厂

目录 1. 常见运算函数 个人主页:Icomi 专栏地址:PyTorch入门 在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术&…

Python-基于PyQt5,wordcloud,pillow,numpy,os,sys等的智能词云生成器

前言:日常生活中,我们有时后就会遇见这样的情形:我们需要将给定的数据进行可视化处理,同时保证呈现比较良好的量化效果。这时候我们可能就会用到词云图。词云图(Word cloud)又称文字云,是一种文…

DeepSeek-R1论文研读:通过强化学习激励LLM中的推理能力

DeepSeek在朋友圈,媒体,霸屏了好长时间,春节期间,研读一下论文算是时下的回应。论文原址:[2501.12948] DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning 摘要: 我们…

【深度分析】DeepSeek大模型技术解析:从架构到应用的全面探索

深度与创新:AI领域的革新者 DeepSeek,这个由幻方量化创立的人工智能公司推出的一系列AI模型,不仅在技术架构上展现出了前所未有的突破,更在应用领域中开启了无限可能的大门。从其混合专家架构(MoE)到多头潜…

万物皆有联系:驼鸟和布什

布什?一块布十块钱吗?不是,大家都知道,美国有两个总统,叫老布什和小布什,因为两个布什总统(父子俩),大家就这么叫来着,目的是为了好区分。 布什总统的布什&a…

Leetcode:350

1,题目 2,思路 首先判断那个短为什么呢因为我们用短的数组去挨个点名长的数组主要用map装长的数组max判断map里面有几个min数组的元素,list保存交集最后用数组返回list的内容 3,代码 import java.util.*;public class Leetcode…

Spring Boot 热部署实现指南

在开发 Spring Bot 项目时,热部署功能能够显著提升开发效率,让开发者无需频繁重启服务器就能看到代码修改后的效果。下面为大家详细介绍一种实现 Spring Boot 热部署的方法,同时也欢迎大家补充其他实现形式。 步骤一、开启 IDEA 自动编译功能…