# 深度学习中的优化算法详解

深度学习中的优化算法详解

优化算法是深度学习的核心组成部分,用于最小化损失函数以更新神经网络的参数。本文将详细介绍深度学习中常用的优化算法,包括其概念、数学公式、代码示例、实际案例以及图解,帮助读者全面理解优化算法的原理与应用。


一、优化算法的基本概念

在深度学习中,优化算法的目标是通过迭代更新模型参数 θ \theta θ,最小化损失函数 L ( θ ) L(\theta) L(θ)。损失函数通常表示为:

L ( θ ) = 1 N ∑ i = 1 N l ( f ( x i ; θ ) , y i ) L(\theta) = \frac{1}{N} \sum_{i=1}^N l(f(x_i; \theta), y_i) L(θ)=N1i=1Nl(f(xi;θ),yi)

其中:

  • f ( x i ; θ ) f(x_i; \theta) f(xi;θ):模型对输入 x i x_i xi 的预测;
  • y i y_i yi:真实标签;
  • l l l:单个样本的损失(如均方误差或交叉熵);
  • N N N:样本数量。

优化算法通过计算梯度 ∇ θ L ( θ ) \nabla_\theta L(\theta) θL(θ),按照一定规则更新参数 θ \theta θ,以逼近损失函数的最优解。


二、常见优化算法详解

以下是深度学习中常用的优化算法,逐一分析其原理、公式、优缺点及代码实现。

1. 梯度下降(Gradient Descent, GD)

概念

梯度下降通过计算整个训练集的梯度来更新参数,公式为:

θ t + 1 = θ t − η ∇ θ L ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t) θt+1=θtηθL(θt)

其中:

  • η \eta η:学习率,控制步长;
  • ∇ θ L ( θ t ) \nabla_\theta L(\theta_t) θL(θt):损失函数对参数的梯度。
优缺点
  • 优点:全局梯度信息准确,适合简单凸优化问题。
  • 缺点:计算全量梯度开销大,速度慢,易陷入局部极小值。
代码示例
import numpy as np# 模拟损失函数 L = (theta - 2)^2
def loss_function(theta):return (theta - 2) ** 2def gradient(theta):return 2 * (theta - 2)# 梯度下降
theta = 0.0  # 初始参数
eta = 0.1    # 学习率
for _ in range(100):grad = gradient(theta)theta -= eta * grad
print(f"优化后的参数: {theta}")  # 接近 2

在这里插入图片描述

参数沿梯度方向逐步逼近损失函数的最优解。*


2. 随机梯度下降(Stochastic Gradient Descent, SGD)

概念

SGD 每次仅基于单个样本计算梯度,更新公式为:

θ t + 1 = θ t − η ∇ θ l ( f ( x i ; θ t ) , y i ) \theta_{t+1} = \theta_t - \eta \nabla_\theta l(f(x_i; \theta_t), y_i) θt+1=θtηθl(f(xi;θt),yi)

优缺点
  • 优点:计算效率高,适合大规模数据集,随机性有助于逃离局部极小值。
  • 缺点:梯度噪声大,收敛路径不稳定。
代码示例
# 模拟 SGD
np.random.seed(42)
data = np.random.randn(100, 2)  # 模拟数据
labels = data[:, 0] * 2 + 1     # 模拟标签theta = np.zeros(2)  # 初始参数
eta = 0.01
for _ in range(100):i = np.random.randint(0, len(data))x, y = data[i], labels[i]grad = -2 * (y - np.dot(theta, x)) * x  # 均方误差梯度theta -= eta * grad
print(f"优化后的参数: {theta}")

SGD 的更新路径波动较大,但整体趋向最优解。*


3. 小批量梯度下降(Mini-Batch Gradient Descent)

概念

Mini-Batch GD 结合 GD 和 SGD 的优点,使用小批量样本计算梯度:

θ t + 1 = θ t − η 1 B ∑ i ∈ batch ∇ θ l ( f ( x i ; θ t ) , y i ) \theta_{t+1} = \theta_t - \eta \frac{1}{B} \sum_{i \in \text{batch}} \nabla_\theta l(f(x_i; \theta_t), y_i) θt+1=θtηB1ibatchθl(f(xi;θt),yi)

其中 B B B 为批量大小。

优缺点
  • 优点:平衡了计算效率和梯度稳定性,广泛应用于深度学习框架。
  • 缺点:批量大小需调优,学习率敏感。
代码示例
import torch# 模拟数据
X = torch.randn(100, 2)
y = X[:, 0] * 2 + 1
theta = torch.zeros(2, requires_grad=True)
optimizer = torch.optim.SGD([theta], lr=0.01)# Mini-Batch GD
batch_size = 16
for _ in range(100):indices = torch.randperm(100)[:batch_size]batch_X, batch_y = X[indices], y[indices]pred = batch_X @ thetaloss = ((pred - batch_y) ** 2).mean()optimizer.zero_grad()loss.backward()optimizer.step()
print(f"优化后的参数: {theta}")

4. 动量法(Momentum)

概念

动量法通过引入速度项 v t v_t vt,加速梯度下降,公式为:

v t + 1 = μ v t + ∇ θ L ( θ t ) v_{t+1} = \mu v_t + \nabla_\theta L(\theta_t) vt+1=μvt+θL(θt)
θ t + 1 = θ t − η v t + 1 \theta_{t+1} = \theta_t - \eta v_{t+1} θt+1=θtηvt+1

其中 μ \mu μ 为动量系数(通常为 0.9)。

优缺点
  • 优点:加速收敛,减少震荡。
  • 缺点:超参数需调优,可能超调。
代码示例
# 动量法
theta = 0.0
v = 0.0
eta, mu = 0.1, 0.9
for _ in range(100):grad = gradient(theta)v = mu * v + gradtheta -= eta * v
print(f"优化后的参数: {theta}")

动量法通过累积速度平滑更新路径。*


5. Adam(Adaptive Moment Estimation)

概念

Adam 结合动量法和自适应学习率,通过一阶动量(均值)和二阶动量(方差)更新参数:

m t + 1 = β 1 m t + ( 1 − β 1 ) ∇ θ L ( θ t ) m_{t+1} = \beta_1 m_t + (1 - \beta_1) \nabla_\theta L(\theta_t) mt+1=β1mt+(1β1)θL(θt)
v t + 1 = β 2 v t + ( 1 − β 2 ) ( ∇ θ L ( θ t ) ) 2 v_{t+1} = \beta_2 v_t + (1 - \beta_2) (\nabla_\theta L(\theta_t))^2 vt+1=β2vt+(1β2)(θL(θt))2
m ^ t + 1 = m t + 1 1 − β 1 t + 1 , v ^ t + 1 = v t + 1 1 − β 2 t + 1 \hat{m}_{t+1} = \frac{m_{t+1}}{1 - \beta_1^{t+1}}, \quad \hat{v}_{t+1} = \frac{v_{t+1}}{1 - \beta_2^{t+1}} m^t+1=1β1t+1mt+1,v^t+1=1β2t+1vt+1
θ t + 1 = θ t − η m ^ t + 1 v ^ t + 1 + ϵ \theta_{t+1} = \theta_t - \eta \frac{\hat{m}_{t+1}}{\sqrt{\hat{v}_{t+1}} + \epsilon} θt+1=θtηv^t+1 +ϵm^t+1

其中:

  • β 1 = 0.9 \beta_1 = 0.9 β1=0.9 β 2 = 0.999 \beta_2 = 0.999 β2=0.999
  • ϵ = 1 0 − 8 \epsilon = 10^{-8} ϵ=108,防止除零。
优缺点
  • 优点:自适应学习率,收敛快,适合复杂模型。
  • 缺点:可能过早收敛到次优解。
代码示例
import torch.optim as optim# 使用 PyTorch 的 Adam
model = torch.nn.Linear(2, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
for _ in range(100):pred = model(X)loss = ((pred - y) ** 2).mean()optimizer.zero_grad()loss.backward()optimizer.step()
print(f"优化后的参数: {model.weight}")

Adam 通过自适应步长快速逼近最优解。*


三、实际案例:优化神经网络

任务

使用 PyTorch 训练一个简单的二分类神经网络,比较 SGD 和 Adam 的性能。

代码实现
import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 生成模拟数据
X = torch.randn(1000, 2)
y = (X[:, 0] + X[:, 1] > 0).float().reshape(-1, 1)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Linear(2, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 训练函数
def train(model, optimizer, epochs=100):criterion = nn.BCELoss()losses = []for _ in range(epochs):pred = model(X)loss = criterion(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())return losses# 比较 SGD 和 Adam
model_sgd = Net()
model_adam = Net()
optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=0.01)
optimizer_adam = optim.Adam(model_adam.parameters(), lr=0.001)losses_sgd = train(model_sgd, optimizer_sgd)
losses_adam = train(model_adam, optimizer_adam)# 绘制损失曲线
plt.plot(losses_sgd, label="SGD")
plt.plot(losses_adam, label="Adam")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
结果分析

Adam 通常比 SGD 收敛更快,损失下降更平稳,但在某些任务中 SGD 配合动量可能获得更好的泛化性能。


四、优化算法选择建议

  1. 小型数据集:SGD + 动量,简单且泛化能力强。
  2. 复杂模型(如深度神经网络):Adam 或其变体(如 AdamW),收敛速度快。
  3. 超参数调优
    • 学习率:尝试 1 0 − 3 10^{-3} 103 1 0 − 5 10^{-5} 105
    • 批量大小:16、32 或 64;
    • 动量系数:0.9 或 0.99。

五、总结

优化算法是深度学习训练的基石,从简单的梯度下降到自适应的 Adam,每种算法都有其适用场景。通过理解其数学原理、代码实现和实际表现,开发者可以根据任务需求选择合适的优化策略。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/900802.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

汽车的四大工艺

文章目录 冲压工艺核心流程关键技术 焊接工艺核心流程 涂装工艺核心流程 总装工艺核心流程终检与测试静态检查动态检查四轮定位制动转鼓测试淋雨测试总结 简单总结下汽车的四大工艺(从网上找了一张图,感觉挺全面的)。 冲压工艺 将金属板材通过…

Perl 发送邮件

Perl 发送邮件 概述 Perl 是一种强大的编程语言,广泛应用于系统管理、网络编程和数据分析等领域。其中,使用 Perl 发送邮件是一项非常实用的技能。本文将详细介绍使用 Perl 发送邮件的方法,包括必要的配置、代码示例以及注意事项。 准备工…

关于柔性数组

以前确实没关注过这个问题,一直都是直接定义固定长度的数组,尽量减少指针的操作。 柔性数组主要是再结构体里面定义一个长度为0的数组,这里和定义一个指针式存在明显去别的。定义一个指针会占用内存,但是定义一个长度为0的数组不会…

NOIP2011提高组.玛雅游戏

目录 题目算法标签: 模拟, 搜索, d f s dfs dfs, 剪枝优化思路*详细注释版代码精简注释版代码 题目 185. 玛雅游戏 算法标签: 模拟, 搜索, d f s dfs dfs, 剪枝优化 思路 可行性剪枝 如果某个颜色的格子数量少于 3 3 3一定无解因为要求字典序最小, 因此当一个格子左边有…

go游戏后端开发29:实现游戏内聊天

接下来,我们再来开发一个功能,这个功能相对简单,就是聊天。在游戏里,我们会收到一个聊天请求,我们只需要做一个聊天推送即可。具体来说,就是谁发的消息,就推送给所有人,包括消息内容…

基于大数据的美团外卖数据可视化分析系统

【大数据】基于大数据的美团外卖数据可视化分析系统 (完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 该系统通过对海量外卖数据的深度挖掘与分析,能够为美团外卖平台提供运营决策支…

[ctfshow web入门] web32

前置知识 协议相关博客:https://blog.csdn.net/m0_73353130/article/details/136212770 include:include "filename"这是最常用的方法,除此之外还可以 include url,被包含的文件会被当做代码执行。 data://&#xff1a…

kotlin中const 和val的区别

在 Kotlin 中,const 和 val 都是用来声明常量的,但它们的使用场景和功能有所不同: 1. val: val 用于声明只读变量,也就是不可修改的变量(类似于 Java 中的 final 变量)。它可以是任何类型,包括…

【STM32】综合练习——智能风扇系统

目录 0 前言 1 硬件准备 2 功能介绍 3 前置配置 3.1 时钟配置 3.2 文件配置 4 功能实现 4.1 按键功能 4.2 屏幕功能 4.3 调速功能 4.4 倒计时功能 4.5 摇头功能 4.6 测距待机功能 0 前言 由于时间关系,暂停详细更新,本文章中,…

任务扩展-输入商品原价,折扣并计算促销后的价格

1.在HbuilderX软件中创建项目,把项目的路径放在xampp中的htdocs 2.创建php文件:price.php,price_from.php 3.在浏览器中,运行项目效果,通过xampp中admin进行运行浏览,在后添加文件名称即可,注意&#xff…

3D Gaussian Splatting as MCMC 与gsplat中的应用实现

3D高斯泼溅(3D Gaussian splatting)自2023年提出以后,相关研究paper井喷式增长,尽管出现了许多改进版本,但依旧面临着诸多挑战,例如实现照片级真实感、应对高存储需求,而 “悬浮的高斯核” 问题就是其中之一。浮动高斯核通常由输入图像中的曝光或颜色不一致引发,也可能…

【软件测试】Postman中如何搭建Mock服务

在 Postman 中,Mock 服务是一项非常有用的功能,允许你在没有实际后端服务器的情况下模拟 API 响应。通过创建 Mock 服务,你可以在开发阶段或测试中模拟 API 的行为,帮助团队成员进行前端开发、API 测试和集成测试等工作。 Mock 服…

Spring-MVC

Spring-MVC 1.SpringMVC简介 - SpringMVC概述 SpringMVC是一个基于Spring开发的MVC轻量级框架,Spring3.0后发布的组件,SpringMVC和Spring可以无缝整合,使用DispatcherServlet作为前端控制器,且内部提供了处理器映射器、处理器适…

关于Spring MVC中@RequestParam注解的详细说明,用于在前后端参数名称不一致时实现参数映射。包含代码示例和总结表格

以下是关于Spring MVC中RequestParam注解的详细说明,用于在前后端参数名称不一致时实现参数映射。包含代码示例和总结表格: 1. 核心作用 RequestParam用于显式绑定HTTP请求参数到方法参数,支持以下场景: 参数名不一致&#xff1…

MySQL主从复制技术详解:原理、实现与最佳实践

目录 引言:MySQL主从复制的技术基础 MySQL主从复制的实现机制 复制架构与线程模型 复制连接建立过程 数据变更与传输流程 MySQL不同复制方式的特点与适用场景 异步复制(Asynchronous Replication) 全同步复制(Fully Synch…

ROS Master多设备连接

Bash Shell Shell是位于用户与操作系统内核之间的桥梁,当用户在终端敲入命令后,这些输入首先会进入内核中的tty子系统,TTY子系统负责捕获并处理终端的输入输出流,确保数据正确无误的在终端和系统内核之中。Shell在此过程不仅仅是…

Trae + LangGPT 生成结构化 Prompt

Trae LangGPT 生成结构化 Prompt 0. 引言1. 安装 Trae2. 克隆 LangGPT3. Trae 和 LangGPT 联动4. 集成到 Dify 中 0. 引言 Github 上 LangGPT 这个项目,主要向我们介绍了写结构化Prompt的一些方法和示例,我们怎么直接使用这个项目,辅助我们…

《安富莱嵌入式周报》第352期:手持开源终端,基于参数阵列的定向扬声器,炫酷ASCII播放器,PCB电阻箱,支持1Ω到500KΩ,Pebble智能手表代码重构

周报汇总地址:嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版 https://www.bilibili.com/video/BV1DEf3YiEqE/ 《安富莱嵌入式周报》第352期:手持开源终端&#x…

python 浅拷贝copy与深拷贝deepcopy 理解

一 浅拷贝与深拷贝 1. 浅拷贝 浅拷贝只复制了对象本身(即c中的引用)。 2. 深拷贝 深拷贝创建一个新的对象,同时也会创建所有子对象的副本,因此新对象与原对象之间完全独立。 二 代码理解 1. 案例一 a 10 b a b 20 print…

day22 学习笔记

文章目录 前言一、遍历1.行遍历2.列遍历3.直接遍历 二、排序三、去重四、分组 前言 通过今天的学习,我掌握了对Pandas的数据类型进行基本操作,包括遍历,去重,排序,分组 一、遍历 1.行遍历 intertuples方法用于遍历D…