深度学习 - 梯度下降优化方法

梯度下降的基本概念

梯度下降(Gradient Descent)是一种用于优化机器学习模型参数的算法,其目的是最小化损失函数,从而提高模型的预测精度。梯度下降的核心思想是通过迭代地调整参数,沿着损失函数下降的方向前进,最终找到最优解。

生活中的背景例子:寻找山谷的最低点

想象你站在一个山谷中,眼睛被蒙住,只能用脚感受地面的坡度来找到山谷的最低点(即损失函数的最小值)。你每一步都想朝着坡度下降最快的方向走,直到你感觉不到坡度,也就是你到了最低点。这就好比在优化一个模型时,通过不断调整参数,使得模型的预测误差(损失函数)越来越小,最终找到最佳参数组合。

梯度下降的具体方法及其优化

1. 批量梯度下降(Batch Gradient Descent)

生活中的例子
你决定每次移动之前,都要先测量整个山谷的坡度,然后再决定移动的方向和步幅。虽然每一步的方向和步幅都很准确,但每次都要花很多时间来测量整个山谷的坡度。

公式
θ : = θ − η ⋅ ∇ θ J ( θ ) \theta := \theta - \eta \cdot \nabla_{\theta} J(\theta) θ:=θηθJ(θ)
其中:

  • θ \theta θ是模型参数
  • η \eta η是学习率
  • ∇ θ J ( θ ) \nabla_{\theta} J(\theta) θJ(θ)是损失函数 J ( θ ) J(\theta) J(θ)关于 θ \theta θ的梯度

API
TensorFlow

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

PyTorch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
2. 随机梯度下降(Stochastic Gradient Descent, SGD)

生活中的例子
你决定每一步都只根据当前所在位置的坡度来移动。虽然这样可以快速决定下一步怎么走,但由于只考虑当前点,可能会导致路径不稳定,有时候会走过头。

公式
θ : = θ − η ⋅ ∇ θ J ( θ ; x ( i ) , y ( i ) ) \theta := \theta - \eta \cdot \nabla_{\theta} J(\theta; x^{(i)}, y^{(i)}) θ:=θηθJ(θ;x(i),y(i))
其中 ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i))是当前样本的数据

API
TensorFlowPyTorch 中的API与批量梯度下降相同,具体行为取决于数据的加载方式。例如在训练时可以一批数据包含一个样本。

3. 小批量梯度下降(Mini-Batch Gradient Descent)

生活中的例子
你决定每次移动之前,只测量周围一小部分区域的坡度,然后根据这小部分区域的平均坡度来决定方向和步幅。这样既不需要花太多时间测量整个山谷,也不会因为只看一个点而导致路径不稳定。

公式
θ : = θ − η ⋅ ∇ θ J ( θ ; B ) \theta := \theta - \eta \cdot \nabla_{\theta} J(\theta; \mathcal{B}) θ:=θηθJ(θ;B)
其中 B \mathcal{B} B是当前小批量的数据

API
TensorFlowPyTorch 中的API与批量梯度下降相同,但在数据加载时使用小批量。

4. 动量法(Momentum)

生活中的例子
你在移动时,不仅考虑当前的坡度,还考虑之前几步的移动方向,就像带着惯性一样。如果前几步一直往一个方向走,那么你会倾向于继续往这个方向走,减少来回震荡。

公式
v : = β v + ( 1 − β ) ∇ θ J ( θ ) v := \beta v + (1 - \beta) \nabla_{\theta} J(\theta) v:=βv+(1β)θJ(θ)
θ : = θ − η v \theta := \theta - \eta v θ:=θηv
其中:

  • v v v是动量项
  • β \beta β是动量系数(通常接近1,如0.9)

API
TensorFlow

optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)

PyTorch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
5. RMSProp

生活中的例子
你在移动时,会根据最近一段时间内每一步的坡度情况,动态调整步幅。比如,当坡度变化剧烈时,你会迈小步,当坡度变化平缓时,你会迈大步。

公式
s : = β s + ( 1 − β ) ( ∇ θ J ( θ ) ) 2 s := \beta s + (1 - \beta) (\nabla_{\theta} J(\theta))^2 s:=βs+(1β)(θJ(θ))2
θ : = θ − η s + ϵ ∇ θ J ( θ ) \theta := \theta - \frac{\eta}{\sqrt{s + \epsilon}} \nabla_{\theta} J(\theta) θ:=θs+ϵ ηθJ(θ)
其中:

  • s s s是梯度平方的加权平均值
  • ϵ \epsilon ϵ是一个小常数,防止除零错误

API
TensorFlow

optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)

PyTorch

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
6. Adam(Adaptive Moment Estimation)

生活中的例子
你在移动时,结合动量法和RMSProp的优点,不仅考虑之前的移动方向(动量),还根据最近一段时间内的坡度变化情况(调整步幅),从而使移动更加平稳和高效。

公式
m : = β 1 m + ( 1 − β 1 ) ∇ θ J ( θ ) m := \beta_1 m + (1 - \beta_1) \nabla_{\theta} J(\theta) m:=β1m+(1β1)θJ(θ)
v : = β 2 v + ( 1 − β 2 ) ( ∇ θ J ( θ ) ) 2 v := \beta_2 v + (1 - \beta_2) (\nabla_{\theta} J(\theta))^2 v:=β2v+(1β2)(θJ(θ))2
m ^ : = m 1 − β 1 t \hat{m} := \frac{m}{1 - \beta_1^t} m^:=1β1tm
v ^ : = v 1 − β 2 t \hat{v} := \frac{v}{1 - \beta_2^t} v^:=1β2tv
θ : = θ − η m ^ v ^ + ϵ \theta := \theta - \eta \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} θ:=θηv^ +ϵm^
其中:

  • m m m v v v分别是梯度的一阶和二阶动量
  • β 1 \beta_1 β1 β 2 \beta_2 β2是动量系数(通常分别取0.9和0.999)
  • m ^ \hat{m} m^ v ^ \hat{v} v^是偏差校正后的动量项
  • t t t是时间步

API
TensorFlow

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

PyTorch

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

综合应用示例

假设我们在使用TensorFlow和PyTorch训练一个简单的神经网络,以下是如何应用这些优化方法的示例代码。

TensorFlow 示例

import tensorflow as tf# 定义模型
model = tf.keras.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dense(10, activation='softmax')
])# 编译模型并选择优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 准备数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 训练模型
model.fit(x_train, y_train, epochs=10, batch_size=32)

PyTorch 示例

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNN()# 选择优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()# 准备数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 训练模型
for epoch in range(10):for batch in train_loader:x_train, y_train = batchx_train = x_train.view(x_train.size(0), -1)  # Flatten the imagesoptimizer.zero_grad()outputs = model(x_train)loss = criterion(outputs, y_train)loss.backward()optimizer.step()

更多问题咨询

CosAI

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/24360.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

人体感应提醒 大声公+微波模块

文章目录 模块简介接线程序示例 模块简介 微波感应开关模块 RCWL-0516是一款采用多普勒雷达技术,专门检测物体移动的微波感应模块。采用 2.7G 微波信号检测,该模块具有灵敏度高,感应距离远,可靠性强,感应角度大&#…

Ruoyi-Vue-Plus 下载启动后菜单无法点击展开,

1.Ruoyi-Vue-Plus框架下载后运行 2.使用mock数据 3.进入页面后无法点击菜单 本以为是动态路由或者菜单逻辑出了问题,最后发现是websocket的问题 解决办法 把这两行代码注释 页面菜单即可点击。 以上。

【ROS使用记录】—— ros使用过程中的rosbag录制播放和ros话题信息相关的指令与操作记录

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、rosbag的介绍二、rosbag的在线和离线录制三、rosbag的播放相关的指令四、其他rosbag和ros话题相关的指令总结 前言 rosbag是ROS(机器人操作系统…

Suse Linux ssh配置免密后仍需要输入密码

【问题描述】 Suse Linux已经配置了ssh免密,但无法ssh到目标服务器。 对自身的ssh登陆也需要输入密码。 系统–Suse 15 SP5 【重现步骤】 1.使用ssh-keygen -t rsa生产key文件 2.使用ssh-copy-id拷贝public key到目标机器(或者自身) 3.配置成功后ssh 目标时仍需要输…

电商API在维护数据安全与合规性中的重要性

摘要 在数字化时代,数据安全和合规性是电商企业不可忽视的重大议题。本文将探讨电商API如何在保护敏感数据、遵守法律法规和防范网络威胁方面发挥关键作用。 引言 随着大量敏感数据的电子化处理和存储,电商企业面临的安全挑战日益严峻。API接口技术成为…

手机模拟操作进阶:1.某团获取附近商店情况

0.以超市便利为例分析: 超市便利的xp (//android.widget.ImageView[@resource-id="com.sankuai.meituan:id/channel_icon"])[5] 附近的xp //android.widget.TextView[@text="全部200+店"] 商家信息列表区: //android.support.v7.widget.RecyclerView[@…

《青少年编程与数学》课程方案:2、课程内容 4_4

《青少年编程与数学》课程方案:2、课程内容 4_4 十四、数学(三)高中数学(四)微机分(五)线性代数(六)概率论与数理统计(七)离散数学(八…

娛閑放鬆篇1

最近在B站看了挺多的動漫,挺小說化的,我這個人比較哲學,故和大家分享一下 B站娛閑 1.蘇老大的動漫 1.<<人類清除計劃>> 本來看的過癮,但沒想到,連小說也停更了..... 2.黑山羊遊戲 挺劇本的 3.顧毅 一個小說的主人公,第一個能力是無限推演... 崇山醫…

[C#]使用OpenCvSharp图像滤波中值滤波均值滤波高通滤波双边滤波锐化滤波自定义滤波

在使用OpenCvSharp进行图像滤波处理时&#xff0c;各种滤波方法都有其特定的用途和效果。以下是对中值滤波、均值滤波、高通滤波、双边滤波、锐化滤波和自定义滤波的详细解释和归纳&#xff1a; 中值滤波&#xff08;MedianBlur&#xff09; 原理与作用&#xff1a;中值滤波是…

Stable diffusion采样器详解

在我们使用SD web UI的过程中&#xff0c;有很多采样器可以选择&#xff0c;那么什么是采样器&#xff1f;它们是如何工作的&#xff1f;它们之间有什么区别&#xff1f;你应该使用哪一个&#xff1f;这篇文章将会给你想要的答案。 什么是采样&#xff1f; Stable Diffusion模…

UI学习--导航控制器

导航控制器 导航控制器基础基本概念具体使用 导航控制器切换演示具体使用注意 导航栏与工具栏基本概念具体使用&#xff1a; 总结 导航控制器基础 基本概念 根视图控制器&#xff08;Root View Controller&#xff09;&#xff1a;导航控制器的第一个视图控制器&#xff0c;通…

压缩大文件消耗电脑CPU资源达到33%以上

今天用7-Zip压缩一个大文件&#xff0c;文件大小是9G多&#xff0c;这时能听到电脑风扇声音&#xff0c;查看了一下电脑资源使用情况&#xff0c;确实增加了不少。 下面是两张图片&#xff0c;图片上有电脑资源使用数据。

Spring系统学习 -Spring IOC 的XML管理Bean之bean的获取、依赖注入值的方式

在Spring框架中&#xff0c;XML配置是最传统和最常见的方式之一&#xff0c;用于管理Bean的创建、依赖注入和生命周期等。这个在Spring中我们使用算是常用的&#xff0c;我们需要根据Spring的基于XML管理Bean了解相关Spring中常用的获取bean的方式、依赖注入值的几种方式等等。…

c++ namespace以及使用建议

命名空间就是用来区分你使用的这个变量和函数是属于那一块的。用来防止不同的人所写函数和变量&#xff0c;名字相同产生冲突。 在写c代码的时候&#xff0c;经常会使用标准库中的函数&#xff0c;使用之前我们必须在前面添加一个std::&#xff0c;因为c标准库的函数是在命名空…

关闭Cloudflare Pages的访问策略

curl API 获取相应的 uid curl -X GET "https://api.cloudflare.com/client/v4/accounts/账户标识符/access/apps" \-H "X-Auth-Email: 邮箱" \-H "X-Auth-Key: Global API KEY" \-H "Content-Type: application/json"账户标识符是登…

Dubbo面试题甄选及参考答案

目录 Dubbo是什么? Dubbo的主要使用场景有哪些? Dubbo的核心功能有哪些? Dubbo与Spring框架的集成方式是什么? Dubbo的RPC调用原理是什么? Dubbo的架构中包含哪些核心组件? Provider、Consumer、Registry、Monitor在Dubbo中分别承担什么角色? Container在Dubbo中…

Maven项目打包成jar项目后运行报错误: 找不到或无法加载主类 Main.Main 和 jar中没有主清单属性解决方案

已经用maven工程的package功能进行了打包 找不到或无法加载主类 Main.Main 规定主类 主要在maven的配置文件当中 这边一定要绑定自己的启动类 jar中没有主清单属性 删掉这一行就行哈 正确的插件代码 <plugin><groupId>org.springframework.boot</groupId&…

毫米波SDK使用1

本文档是AM273x等毫米波雷达处理器SDK的配置和使用&#xff0c;主要参考TI的官方文档《mmwave mcuplus sdk user guide》。这里仅摘取其中重要的部分&#xff0c;其余枝节可参考原文。 2 系统概览 mmWave SDK分为两个主要组件:mmWave套件和mmWave演示。 2.1. mmWave套件 mmWa…

AXI Quad SPI IP核基于AXI-Lite接口的标准SPI设计指南

在标准SPI配置下&#xff0c;SPI设备除了包含基本的SPI特性外&#xff0c;还具备以下一些标准功能&#xff0c;这些功能如下所示&#xff1a; 支持FPGA内部的多主设备配置&#xff0c;其中使用单独的_I&#xff08;输入&#xff09;、_O&#xff08;输出&#xff09;、_T&…

FM148A,FM146B运行备件

FM148A,FM146B运行备件。电源保险丝仓主控底座的保险丝仓示意图底座上共有两个保险丝&#xff08;800mA&#xff09;&#xff0c;FM148A,FM146B运行备件。&#xff08;10&#xff5e;73&#xff09;30/195主控单元2.K-CUT014槽底座地址接口主控站地址拨开关从上到下为二进制数的…