神经网络中的优化方法

 一、引入

在传统的梯度下降优化算法中,如果碰到平缓区域,梯度值较小,参数优化变慢 ,遇到鞍点(是指在某些方向上梯度为零而在其他方向上梯度非零的点。),梯度为 0,参数无法优化,碰到局部最小值。实践中使用的小批量梯度下降法(mini-batch SGD)因其梯度估计的噪声性质,有时能够使模型脱离这些点。

💥为了克服这些困难,研究者们提出了多种改进策略,出现了一些对梯度下降算法的优化方法:Momentum、AdaGrad、RMSprop、Adam 等。

二、指数加权平均

我们最常见的算数平均指的是将所有数加起来除以数的个数,每个数的权重是相同的。加权平均指的是给每个数赋予不同的权重求得平均数。指数加权平均是一种数据处理方式,它通过对历史数据应用不同的权重来减少过去数据的影响,并强调近期数据的重要性。

[ vt = beta * v{t-1} + (1 - beta) * theta_t] 

比如:明天气温怎么样,和昨天气温有很大关系,而和一周前的气温关系就小一些。 

vt​ 是第 𝑡 天的平均温度值,𝜃𝑡​ 是第 𝑡t 天的实际观察值,而 𝛽 是一个可调节的超参数(通常 0<𝛽<1)。这个公式表明,当前的平均值是前一天平均值与当天实际值的加权平均。

  • β 调节权重系数,该值越大平均数越平缓。

 我们接下来通过一段代码来看下结果,随机产生进 30 天的气温数据:

import torch
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'# 实际平均温度
def test01():# 固定随机数种子torch.manual_seed(0)# 产生30天的随机温度temperature = torch.randn(size=[30]) * 10print(temperature)# 绘制平均温度days = torch.arange(1, 31, 1)plt.plot(days, temperature, color='r')plt.scatter(days, temperature)plt.show()# 指数加权平均温度
def test02(beta=0.8):# 固定随机数种子torch.manual_seed(0)# torch.initial_seed()# 产生30天的随机温度temperature = torch.randn(size=[30,]) * 10print(temperature)exp_weight_avg = []for idx, temp in enumerate(temperature, 1):# 第一个元素的的 EWA 值等于自身if idx == 1:exp_weight_avg.append(temp)continue# 第二个元素的 EWA 值等于上一个 EWA 乘以 β + 当前气氛乘以 (1-β)new_temp = exp_weight_avg[idx - 2] * beta + (1 - beta) * tempexp_weight_avg.append(new_temp)days = torch.arange(1, 31, 1)plt.plot(days, exp_weight_avg, color='r')plt.scatter(days, temperature)plt.show()if __name__ == '__main__':test01()test02()

这是test01执行后产生的实际值: 

 我们再看一下指数平均后的值:

🔎指数加权平均绘制出的气氛变化曲线更加平缓; β 的值越大,则绘制出的折线越加平缓; 

三、Momentum

我们通过对指数加权平均的知识来研究Momentum优化方法💢

  • 鞍点:梯度为零的点,损失函数的梯度在所有方向上都接近或等于零。由于梯度为零,标准梯度下降法在此将无法继续优化参数。
  • 平缓区域:这些区域的梯度值较小,导致参数更新缓慢。虽然这意味着算法接近极小值点,但收敛速度会变得非常慢。

当梯度下降碰到 “峡谷” 、”平缓”、”鞍点” 区域时,参数更新速度变慢,Momentum 通过指数加权平均法,累计历史梯度值,进行参数更新,越近的梯度值对当前参数更新的重要性越大。

Momentum优化方法是对传统梯度下降法的一种改进:

Momentum优化算法的核心思想是在一定程度上积累之前的梯度信息,以此来调整当前的梯度更新方向。这种方法可以在一定程度上减少训练过程中的摆动现象,使得学习过程更加平滑,从而可能使用较大的学习率而不必担心偏离最小值太远。 

梯度计算公式:Dt = β * St-1 + (1- β) * Dt

在面对梯度消失、鞍点等问题时,Momentum能够改善SGD的表现,帮助模型跳出局部最小值或平坦区域;如果当处于鞍点位置时,由于当前的梯度为 0,参数无法更新。但是 Momentum梯度下降算法已经在先前积累了一些梯度值,很有可能使得跨过鞍点。

由于 mini-batch 普通的梯度下降算法,每次选取少数的样本梯度确定前进方向,可能会出现震荡,使得训练时间变长。Momentum 使用移动加权平均,平滑了梯度的变化,使得前进方向更加平缓,有利于加快训练过程。一定程度上有利于降低 “峡谷” 问题的影响。

Momentum方法的实现案例: 

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
model = nn.Linear(10, 1)# 定义损失函数
criterion = nn.MSELoss()# 定义优化器,并设置momentum参数为0.9
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 模拟数据
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)# 训练模型
for epoch in range(10):# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 10, loss.item()))

Momentum 算法可以理解为是对梯度值的一种调整,我们知道梯度下降算法中还有一个很重要的学习率,Momentum 并没有学习率进行优化。 

四、AdaGrad

💥Momentum 算法是对梯度值调整,使得模型可以更好的进行参数更新,AdaGrad算法则是对学习率,即每次更新走的步长,进行调整更新~

AdaGrad 通过对不同的参数分量使用不同的学习率,AdaGrad 的学习率总体会逐渐减小,这是因为 AdaGrad算法认为:在起初时,我们距离最优目标仍较远,可以使用较大的学习率,加快训练速度,随着迭代次数的增加,学习率逐渐下降。

🗨️计算步骤如下:

  1. 初始化学习率 α、初始化参数 θ、小常数 σ = 1e-6
  2. 初始化梯度累积变量 s = 0
  3. 从训练集中采样 m 个样本的小批量,计算梯度 g
  4. 累积平方梯度 s = s + g ⊙ g,⊙ 表示各个分量相乘

 

AdaGrad通过这种方式实现了对每个参数的个性化学习率调整,使得在参数空间较平缓的方向上可以取得更大的进步,而在陡峭的方向上则能够变得更加平缓,从而加快了训练速度( 如果累计梯度值s大的话,学习率就会小一点)        

使用Python实现AdaGrad算法的API代码:

import torchclass AdaGrad:def __init__(self, params, lr=0.01, epsilon=1e-8):self.params = list(params)self.lr = lrself.epsilon = epsilonself.cache = [torch.zeros_like(param) for param in self.params]def step(self):for i, param in enumerate(self.params):self.cache[i] += param.grad.data ** 2param.data -= self.lr * param.grad.data / (torch.sqrt(self.cache[i]) + self.epsilon)

💥AdaGrad 缺点是可能会使得学习率过早、过量的降低,导致模型训练后期学习率太小,较难找到最优解。

五、RMSProp

RMSProp(Root Mean Square Prop)是一种常用的自适应学习率优化算法,是对 AdaGrad 的优化,最主要的不同是,RMSProp使用指数移动加权平均梯度替换历史梯度的平方和。

  1. 初始化学习率 α、初始化参数 θ、小常数 σ = 1e-6
  2. 初始化参数 θ
  3. 初始化梯度累计变量 s
  4. 从训练集中采样 m 个样本的小批量,计算梯度 g
  5. 使用指数移动平均累积历史梯度

 

RMSProp 与 AdaGrad 最大的区别是对梯度的累积方式不同,对于每个梯度分量仍然使用不同的学习率。RMSProp 通过引入衰减系数 β,控制历史梯度对历史梯度信息获取的多少,使得学习率衰减更加合理一些。 

import numpy as npdef rmsprop(params, grads, learning_rate=0.01, decay_rate=0.9, epsilon=1e-8):cache = {}for key in params.keys():cache[key] = np.zeros_like(params[key])for key in params.keys():cache[key] = decay_rate * cache[key] + (1 - decay_rate) * grads[key] ** 2params[key] -= learning_rate * grads[key] / (np.sqrt(cache[key]) + epsilon)return params

params是一个字典,包含了模型的参数;grads是一个字典,包含了参数对应的梯度;learning_rate是学习率;decay_rate是衰减系数;epsilon是一个很小的正数,用于防止除以零。

六、Adam 

💯Adam 结合了两种优化算法的优点:RMSProp(Root Mean Square Prop)和Momentum。Adam在深度学习中被广泛使用,因为它能够自动调整学习率,特别适合处理大规模数据集和复杂模型。

Adam的关键特点:

  1. 一阶矩估计(First Moment):梯度的均值,类似于Momentum中的velocity term,用于指示梯度在何时变得非常剧烈。

  2. 二阶矩估计(Second Moment):梯度的未中心化方差,类似于RMSProp中的平方梯度的指数移动平均值,用于指示梯度变化的范围。

我们在平时使用中会经常用到次方法,在PyTorch中就是optim.Adam方法,不再是optim.SGD方法:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的线性模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 1)  # 假设输入维度是10,输出维度是1def forward(self, x):return self.linear(x)# 创建模型实例
model = SimpleModel()# 定义损失函数
criterion = nn.MSELoss()  # 均方误差损失函数# 创建优化器,设定学习率为0.001,参数beta1默认为0.9,beta2默认为0.999
optimizer = optim.Adam(model.parameters(), lr=0.001)# 假设有一个输入数据x和对应的目标y
x = torch.randn(32, 10)  # 批量大小为32,每个样本的输入维度是10
y = torch.randn(32, 1)   # 批量大小为32,每个样本的输出维度是1# 前向传播
outputs = model(x)# 计算损失
loss = criterion(outputs, y)# 清空之前所有的梯度
optimizer.zero_grad()# 反向传播
loss.backward()# 更新模型参数
optimizer.step()# 打印损失值
print("Loss: ", loss.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831401.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构-AVL树

目录 什么是 AVL 树 ASL 度量查找效率 结构体定义 平衡调整 调整类型 左旋和右旋 右旋 左旋 左、右平衡调整 左平衡调整 右平衡调整 插入数据 模拟建立 AVL 树 什么是 AVL 树 二叉排序树的形状取决于数据集&#xff0c;当二叉树的高度越小、结构越合理&#xff0c…

thinkphp家政上门预约服务小程序家政保洁师傅上门服务小程序上门服务在线派单安装教程

介绍 thinkphp家政上门预约服务小程序家政保洁师傅上门服务小程序上门服务在线派单安装教程 上门预约服务派单小程序家政小程序同城预约开源代码独立版安装教程 程序完整&#xff0c;经过安装检测&#xff0c;可放心下载安装。 适合本地的一款上门预约服务小程序&#xff0…

计算机网络——初识网络

一、局域网与广域网 1.局域网&#xff08;LAN&#xff09; 局域网&#xff1a;即Local Area Network&#xff0c;简称LAN。Local即标识了局域⽹是本地&#xff0c;局部组建的⼀种私有⽹络。局域⽹内的主机之间能⽅便的进⾏⽹络通信&#xff0c;⼜称为内⽹&#xff1b;局域⽹和…

A4的PDF按A3打印

先用办公软件打开&#xff0c;比如WPS。 选择打印-属性。 纸张选A3&#xff0c;如果是双面打印&#xff0c;选短边装订&#xff0c;然后在版面-页面排版-每张页数&#xff08;N合1&#xff09;选2。 不同打印机的具体配置可能不一样&#xff0c;但大体都是这个套路。

[NSSCTF]prize_p1

前言 之前做了p5 才知道还有p1到p4 遂来做一下 顺便复习一下反序列化 prize_p1 <META http-equiv"Content-Type" content"text/html; charsetutf-8" /><?phphighlight_file(__FILE__);class getflag{function __destruct(){echo getenv(&qu…

The Role of Subgroup Separability in Group-Fair Medical Image Classification

文章目录 The Role of Subgroup Separability in Group-Fair Medical Image Classification摘要方法实验结果 The Role of Subgroup Separability in Group-Fair Medical Image Classification 摘要 研究人员调查了深度分类器在性能上的差异。他们发现&#xff0c;分类器将个…

基于Springboot的民宿管理平台

基于SpringbootVue的民宿管理平台设计与实现 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatis工具&#xff1a;IDEA、Maven、Navicat 系统展示 用户登录 首页 民宿信息 后台登录 后台首页 用户管理 商家管理 民宿信息管理 房间类型管理 …

【正版系统】知识付费系统搭建,系统模板开发完善功能强大,支持快速部署上线

在数字化时代&#xff0c;知识付费已成为一种趋势&#xff0c;为内容创作者和求知者搭建了一个高效的交流平台。今天&#xff0c;我要为大家介绍一款功能强大的知识付费系统。 知识付费系统模板是我们一款经过精心开发、完善的系统&#xff0c;旨在为用户提供一站式的知识付费…

【Docker】如何注册Hub账号并上传镜像到Hub仓库

一、创建Hub账户 浏览器访问&#xff1a;hub.docker.com 点击【Sign up】注册账号 输入【邮箱】【用户名】【密码】 ps&#xff1a;用户名要有字母数字&#xff1b;订阅不用勾选 点击【Sign up】注册即可 点击【Sign in】登录账号 输入【邮箱】【密码】 点击【Continue】登录 二…

Unreal 编辑器工具 批量重命名资源

右键 - Editor Utilities - Editor Utility Blueprint&#xff0c;基类选择 Asset Action Utility 在类默认值内&#xff0c;可以添加筛选器&#xff0c;筛选指定的类型 然后新建一个函数&#xff0c;加上4个输入&#xff1a;ReplaceFrom&#xff0c;ReplaceTo&#xff0c;Add…

大数据学习笔记14-Hive基础2

一、数据字段类型 数据类型 &#xff1a;LanguageManual Types - Apache Hive - Apache Software Foundation 基本数据类型 数值相关类型 整数 tinyint smallint int bigint 小数 float double decimal 精度最高 日期类型 date 日期 timestamps 日期时间 字符串类型 s…

nginx--自定义日志跳转长连接文件缓存状态页

自定义日志服务 [rootlocalhost ~]# cat /apps/nginx/conf/conf.d/pc.conf server {listen 80;server_name www.fxq.com;error_log /data/nginx/logs/fxq-error.log info;access_log /data/nginx/logs/fxq-access.log main;location / {root /data/nginx/html/pc;index index…

使用STM32F103C8T6与蓝牙模块HC-05连接实现手机蓝牙控制LED灯

导言: 在现代智能家居系统中,远程控制设备变得越来越普遍和重要。本文将介绍如何利用STM32F103C8T6单片机和蓝牙模块HC-05实现远程控制LED灯的功能。通过这个简单的项目,可以学会如何将嵌入式系统与蓝牙通信技术相结合,实现远程控制的应用。 目录 导言: 准备工作: 硬…

Java零基础入门到精通_Day 11

1.继承 定义&#xff1a; 继承是面向对象三大特征之一。可以使得子类具有父类的属性和方法&#xff0c;还可以在子类中重新定义&#xff0c;追加属性和方法 格式&#xff1a; public class 子类 extends 父类{} 子类&#xff1a;也叫派生类 父类&#xff1a;基类/超类 继…

低代码技术在构建质量管理系统中的应用与优势

引言 在当今快节奏的商业环境中&#xff0c;高效的质量管理系统对于组织的成功至关重要。质量管理系统帮助组织确保产品或服务符合客户的期望、符合法规标准&#xff0c;并持续改进以满足不断变化的需求。与此同时&#xff0c;随着技术的不断进步&#xff0c;低代码技术作为一…

机器学习笔记-14

机器学习系统设计 1.导入 以垃圾邮件分类器为例子&#xff0c;当我们想要做一个能够区分邮件是否为垃圾邮件的项目的时候&#xff0c;首先在大量垃圾邮件中选出出现频次较高的10000-50000词作为词汇表&#xff0c;并为其设置特征&#xff0c;在对邮件分析的时候输出该邮件的特…

计算机网络-408考研

后续更新发布在B站账号&#xff1a;谭同学很nice http://【计算机408备考-什么是计算机网络&#xff0c;有什么特点&#xff1f;】 https://www.bilibili.com/video/BV1qZ421J7As/?share_sourcecopy_web&vd_source58c2a80f8de74ae56281305624c60b13http://【计算机408备考…

在idea中连接mysql

IDE&#xff08;集成开发环境&#xff09;是一种软件应用程序&#xff0c;它为开发者提供编程语言的开发环境&#xff0c;通常集成了编码、编译、调试和运行程序的多种功能。一个好的IDE可以大幅提高开发效率&#xff0c;尤其是在进行大型项目开发时。IDE通常包括以下几个核心组…

Docker-Compose编排lnmp(dockerfile) 完成Wordpress

目录 一、创建nginx镜像 二、创建mysql镜像 三、创建php镜像 四、启动wordpress 五、安装Compose 六、准备环境 ​编辑 七、编写docker-compose.yml 八、启动并运行 九、浏览器访问 一、创建nginx镜像 #基于基础镜像 FROM centos:7 #用户信息 MAINTAINER this is ngi…

LabVIEW换智能仿真三相电能表研制

LabVIEW换智能仿真三相电能表研制 在当前电力工业飞速发展的背景下&#xff0c;确保电能计量的准确性与公正性变得尤为重要。本文提出了一种基于LabVIEW和单片机技术&#xff0c;具有灵活状态切换功能的智能仿真三相电能表&#xff0c;旨在通过技术创新提高电能计量人员的培训…