【漫话机器学习系列】087.常见的神经网络最优化算法(Common Optimizers Of Neural Nets)

常见的神经网络优化算法

1. 引言

在深度学习中,优化算法(Optimizers)用于更新神经网络的权重,以最小化损失函数(Loss Function)。一个高效的优化算法可以加速训练过程,并提高模型的性能和稳定性。本文介绍几种常见的神经网络优化算法,包括随机梯度下降(SGD)、带动量的随机梯度下降(Momentum SGD)、均方根传播算法(RMSProp)以及自适应矩估计(Adam),并提供相应的代码示例。

2. 常见的优化算法

2.1 随机梯度下降(Stochastic Gradient Descent, SGD)

随机梯度下降(SGD)是最基本的优化算法,其更新规则如下:

其中:

  • w 代表网络参数(权重);
  • α 是学习率(Learning Rate),控制更新步长;
  • ∇L(w) 是损失函数相对于权重的梯度。

代码示例(使用 PyTorch 实现 SGD)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降# 训练步骤
for epoch in range(100):optimizer.zero_grad()  # 清空梯度inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 4.9142
Epoch [10/100], Loss: 2.1721
Epoch [20/100], Loss: 0.9601
Epoch [30/100], Loss: 0.4244
Epoch [40/100], Loss: 0.1876
Epoch [50/100], Loss: 0.0829
Epoch [60/100], Loss: 0.0366
Epoch [70/100], Loss: 0.0162
Epoch [80/100], Loss: 0.0072
Epoch [90/100], Loss: 0.0032


2.2 带动量的随机梯度下降(Momentum SGD)

带动量的 SGD 在 SGD 的基础上加入动量(Momentum),用于加速收敛并减少震荡:


其中:

  • 是累积的梯度,类似于物理中的动量;
  • β 是动量系数(通常取 0.9)。

代码示例(Momentum SGD)

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果 

Epoch [0/100], Loss: 3.0073
Epoch [10/100], Loss: 1.3292
Epoch [20/100], Loss: 0.5875
Epoch [30/100], Loss: 0.2597
Epoch [40/100], Loss: 0.1148
Epoch [50/100], Loss: 0.0507
Epoch [60/100], Loss: 0.0224
Epoch [70/100], Loss: 0.0099
Epoch [80/100], Loss: 0.0044
Epoch [90/100], Loss: 0.0019

优点:

  • 缓解了 SGD 震荡问题,提高收敛速度;
  • 在非凸优化问题中表现更好。

2.3 均方根传播算法(RMSProp)

RMSProp 通过自适应调整学习率来加速训练,并缓解震荡问题:


其中:

  • 是梯度平方的滑动平均;
  • β 是衰减系数(一般取 0.9);
  • ϵ 是一个很小的数,防止除零错误。

代码示例(RMSProp)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.9)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 1.1952
Epoch [10/100], Loss: 0.5887
Epoch [20/100], Loss: 0.3333
Epoch [30/100], Loss: 0.1731
Epoch [40/100], Loss: 0.0752
Epoch [50/100], Loss: 0.0239
Epoch [60/100], Loss: 0.0043
Epoch [70/100], Loss: 0.0003
Epoch [80/100], Loss: 0.0000
Epoch [90/100], Loss: 0.0000

优点:

  • 适用于非平稳目标函数;
  • 能有效处理不同特征尺度的问题;
  • 在 RNN(循环神经网络)等任务上表现较好。

2.4 自适应矩估计(Adam, Adaptive Moment Estimation)

Adam 结合了动量法(Momentum)和 RMSProp,同时考虑梯度的一阶矩(平均值)和二阶矩(方差):



其中:

  • ​ 是梯度的一阶矩估计;
  • ​ 是梯度的二阶矩估计;
  • ​ 分别控制一阶矩和二阶矩的指数衰减率(通常取 0.9 和 0.999)。

代码示例(Adam)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.Adam(model.parameters(), lr=0.01)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

输出结果 

Epoch [0/100], Loss: 3.6065
Epoch [10/100], Loss: 2.8894
Epoch [20/100], Loss: 2.2642
Epoch [30/100], Loss: 1.7359
Epoch [40/100], Loss: 1.3021
Epoch [50/100], Loss: 0.9555
Epoch [60/100], Loss: 0.6855
Epoch [70/100], Loss: 0.4805
Epoch [80/100], Loss: 0.3287
Epoch [90/100], Loss: 0.2192

优点:

  • 结合 Momentum 和 RMSProp 的优势;
  • 适用于大规模数据集和高维参数优化;
  • 具有自适应学习率,适用于不同类型的问题。

3. 选择合适的优化算法

优化算法特点适用场景
SGD计算简单,但容易震荡适用于大规模数据,适合凸优化问题
Momentum SGD增加动量,减少震荡,加速收敛适用于复杂深度神经网络
RMSProp自适应调整学习率,适用于非平稳问题适用于 RNN、强化学习等
Adam结合 Momentum 和 RMSProp,自适应学习率适用于大多数深度学习任务

4. 结论

在神经网络训练过程中,优化算法的选择对最终的模型性能有重要影响。SGD 是最基础的优化方法,而带动量的 SGD 在收敛速度和稳定性上有所提升。RMSProp 适用于非平稳目标函数,而 Adam 结合了 Momentum 和 RMSProp 的优势,成为当前最流行的优化算法之一。

不同任务可能需要不同的优化算法,通常的建议是:

  • 对于简单的凸优化问题,可以使用 SGD。
  • 对于深度神经网络,可以使用 Momentum SGD 或 Adam。
  • 对于 RNN 和强化学习问题,RMSProp 是一个不错的选择。

合理选择优化算法可以显著提升模型训练的效率和效果!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/69486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

傅里叶公式推导(一)

文章目录 三角函数系正交证明图观法数学证明法计算当 n不等于m当 n等于m(重点) 其它同理 首先要了解的一点基础知识: 三角函数系 { sin ⁡ 0 , cos ⁡ 0 , sin ⁡ x , cos ⁡ x , sin ⁡ 2 x , cos ⁡ 2 x , … , sin ⁡ n x , cos ⁡ n x ,…

1. 构建grafana(版本V11.5.1)

一、grafana官网 https://grafana.com/ 二、grafana下载位置 进入官网后点击downloads(根据自己的需求下载) 三、grafana安装(点击下载后其实官网都写了怎么安装) 注:我用的Centos,就简略的写下我的操作步…

macOS 上部署 RAGFlow

在 macOS 上从源码部署 RAGFlow-0.14.1:详细指南 一、引言 RAGFlow 作为一款强大的工具,在人工智能领域应用广泛。本文将详细介绍如何在 macOS 系统上从源码部署 RAGFlow 0.14.1 版本,无论是开发人员进行项目实践,还是技术爱好者…

快速集成DeepSeek到项目

DeepSeek API-KEY 获取 登录DeekSeek 官网,进入API 开放平台 2. 创建API-KEY 复制API-KEY进行保存,后期API调用使用 项目中集成DeepSeek 这里只展示部分核心代码,具体请查看源码orange-ai-deepseek-biz-starter Slf4j AllArgsConstructo…

保姆级教程Docker部署Zookeeper模式的Kafka镜像

目录 一、安装Docker及可视化工具 二、Docker部署Zookeeper 三、单节点部署 1、创建挂载目录 2、命令运行容器 3、Compose运行容器 4、查看运行状态 5、验证功能 四、部署可视化工具 1、创建挂载目录 2、Compose运行容器 3、查看运行状态 一、安装Docker及可视化工…

Docker容器访问外网:启动时的网络参数配置指南

在启动Docker镜像时,可以通过设置网络参数来确保容器能够访问外网。以下是几种常见的方法: 1. 使用默认的bridge网络 Docker的默认网络模式是bridge,它会创建一个虚拟网桥,将容器连接到宿主机的网络上。在大多数情况下,使用默认的bridge网络配置即可使容器访问外网。 启动…

白话文实战Nacos(保姆级教程)

前言 上一篇博客 我们创建好了微服务项目,本篇博客来体验一下Nacos作为注册中心和配置中心的功能。 注册中心 如果我们启动了一个Nacos注册中心,那么微服务比如订单服务,启动后就可以连上注册中心把自己注册上去,这过程就是服务注册。每个微服务,比如商品服务都应该注册…

C语言基础08:运算符+流程控制总结

运算符 算术运算符 结果:数值 、-、*、\、%、(正)、-(负)、、-- i和i 相同点:i自身都会增1 不同点:它们运算的最终结果是不同的。i:先使用,后计算;i&am…

Node.js开发属于自己的npm包(发布到npm官网)

在 Node.js 中开发并发布自己的 npm 包是一个非常好的练习,可以帮助我们更好地理解模块化编程和包管理工具,本篇文章主要阐述如何使用nodejs开发一个属于自己的npm包,并且将其发布在npm官网。在开始之前确保已经安装了 Node.js 和 npm。可以在…

如何在RTACAR中配置IP多播(IP Multicast)

一、什么是IP多播 IP多播(IP Multicast)是一种允许数据包从单一源地址发送到多个目标地址的技术,是一种高效的数据传输方式。 多播地址是专门用于多播通信的IP地址,范围从 224.0.0.0到239.255.255.255 与单播IP地址不同&#x…

12.翻转、对称二叉树,二叉树的深度

反转二叉树 递归写法 很简单 class Solution { public:TreeNode* invertTree(TreeNode* root) {if(rootnullptr)return root;TreeNode* tmp;tmproot->left;root->leftroot->right;root->righttmp;invertTree(root->left);invertTree(root->right);return …

网络安全行业的冬天

冬天已经来了,春天还会远吗?2022年10月28日,各个安全大厂相继发布了财报,纵观2022年前三季度9个月,三六零亏了19亿,奇安信亏了11亿,深信服亏了6亿,天融信亏了4亿,安恒亏了…

MYSQL索引与视图

一、新建数据库 mysql> create database mydb15_indexstu; mysql> use mydb15_indexstu; 二、新建表 (1)学生表Student mysql> create table Student(-> Sno int primary key auto_increment,-> Sname varchar(30) not null unique,-…

深度优先搜索(DFS)——八皇后问题与全排列问题

( ^ _ ^ ) 数据结构好难哇(哭 1.BFS和DFS 数据结构空间性质DFSstackO(h)不具有最短性质BFSqueueO(2^h)具有最短路性质 空间上DFS占优势,但是BFS具有最短性 (若所有权重都是1,则BFS一定最短)&…

Flink 内存模型各部分大小计算公式

Flink 的运行平台 如果 Flink 是运行在 yarn 或者 standalone 模式的话,其实都是运行在 JVM 的基础上的,所以首先 Flink 组件运行所需要给 JVM 本身要耗费的内存大小。无论是 JobManager 或者 TaskManager ,他们 JVM 内存的大小都是一样的&a…

Vue07

一、Vuex 概述 目标:明确Vuex是什么,应用场景以及优势 1.是什么 Vuex 是一个 Vue 的 状态管理工具,状态就是数据。 大白话:Vuex 是一个插件,可以管理 Vue 通用的数据 (多组件共享的数据)。例如:购物车数…

Linux 安装 Ollama

1、下载地址 Download Ollama on Linux 2、有网络直接执行 curl -fsSL https://ollama.com/install.sh | sh 命令 3、下载慢的解决方法 1、curl -fsSL https://ollama.com/install.sh -o ollama_install.sh 2、sed -i s|https://ollama.com/download/ollama-linux|https://…

Docker Desktop无法安装报错(求助记录中)

之前Docker Desktop无法使用,报了一个注册表的错误(忘记截图)我想着更新安装下应该就好了,结果Docker Desktop一直无法安装,花了几天都没解决。同时我的window11更新也出现下载错误 - 0x80040154异常,启动或关闭Window…

Docker入门(Windows)

视频链接:Docker | 狂神说 环境说明 Windows For Docker WSL2 概念 Docker是什么? 百度百科:百度百科 Docker 是一个开源的平台,它利用操作系统级虚拟化技术来打包和运行应用程序。通过使用容器化技术,Docker 提…

STM32 RTC亚秒

rtc时钟功能实现:rtc模块在stm32内部,由电池或者主电源供电。如下图,需注意实现时仅需设置一次初始化。 1、stm32cubemx 代码生成界面设置,仅需开启时钟源和激活日历功能。 2、生成的代码,需要对时钟进行初始化,仅需…