PyTorch使用Tricks:学习率衰减 !!

文章目录

前言

1、指数衰减

2、固定步长衰减

3、多步长衰减

4、余弦退火衰减

5、自适应学习率衰减

6、自定义函数实现学习率调整:不同层不同的学习率


前言

在训练神经网络时,如果学习率过大,优化算法可能会在最优解附近震荡而无法收敛;如果学习率过小,优化算法的收敛速度可能会非常慢。因此,一种常见的策略是在训练初期使用较大的学习率来快速接近最优解,然后逐渐减小学习率,使得优化算法可以更精细地调整模型参数,从而找到更好的最优解。

通常学习率衰减有以下的措施:

  • 指数衰减:学习率按照指数的形式衰减,每次乘以一个固定的衰减系数,可以使用 torch.optim.lr_scheduler.ExponentialLR 类来实现,需要指定优化器和衰减系数。
  • 固定步长衰减:学习率每隔一定步数(或者epoch)就减少为原来的一定比例,可以使用 torch.optim.lr_scheduler.StepLR 类来实现,需要指定优化器、步长和衰减比例。
  • 多步长衰减:学习率在指定的区间内保持不变,在区间的右侧值进行一次衰减,可以使用 torch.optim.lr_scheduler.MultiStepLR 类来实现,需要指定优化器、区间列表和衰减比例。
  • 余弦退火衰减:学习率按照余弦函数的周期和最值进行变化,可以使用 torch.optim.lr_scheduler.CosineAnnealingLR 类来实现,需要指定优化器、周期和最小值。
  • 自适应学习率衰减:这种策略会根据模型的训练进度自动调整学习率,可以使用 torch.optim.lr_scheduler.ReduceLROnPlateau 类来实现。例如,如果模型的验证误差停止下降,那么就减小学习率;如果模型的训练误差上升,那么就增大学习率。
  • 自适应函数实现学习率调整:不同层不同的学习率。

1、指数衰减

指数衰减是一种常用的学习率调整策略,其主要思想是在每个训练周期(epoch)结束时,将当前学习率乘以一个固定的衰减系数(gamma),从而实现学习率的指数衰减。这种策略可以帮助模型在训练初期快速收敛,同时在训练后期通过降低学习率来提高模型的泛化能力。

在PyTorch中,可以使用 torch.optim.lr_scheduler.ExponentialLR 类来实现指数衰减。该类的构造函数需要两个参数:一个优化器对象和一个衰减系数。在每个训练周期结束时,需要调用ExponentialLR 对象的 step() 方法来更新学习率。

以下是一个使用 ExponentialLR代码示例:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR# 假设有一个模型参数
model_param = torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))# 使用SGD优化器,初始学习率设置为0.1
optimizer = SGD([model_param], lr=0.1)# 创建ExponentialLR对象,衰减系数设置为0.9
scheduler = ExponentialLR(optimizer, gamma=0.9)# 在每个训练周期结束时,调用step()方法来更新学习率
for epoch in range(100):# 这里省略了模型的训练代码# ...# 更新学习率scheduler.step()

在这个例子中,初始的学习率是0.1,每个训练周期结束时,学习率会乘以0.9,因此学习率会按照指数的形式衰减。

2、固定步长衰减

固定步长衰减是一种学习率调整策略,它的原理是每隔一定的迭代次数(或者epoch),就将学习率乘以一个固定的比例,从而使学习率逐渐减小。这样做的目的是在训练初期使用较大的学习率,加快收敛速度,而在训练后期使用较小的学习率,提高模型精度。

PyTorch提供了 torch.optim.lr_scheduler.StepLR 类来实现固定步长衰减,它的参数有:

  • optimizer:要进行学习率衰减的优化器,例如 torch.optim.SGD torch.optim.Adam等。
  • step_size:每隔多少隔迭代次数(或者epoch)进行一次学习率衰减,必须是正整数。
  • gamma:学习率衰减的乘法因子,必须是0到1之间的数,表示每次衰减为原来的 gamma倍。
  • last_epoch:最后一个epoch的索引,用于恢复训练的状态,默认为-1,表示从头开始训练。
  • verbose:是否打印学习率更新的信息,默认为False。

下面是一个使用  torch.optim.lr_scheduler.StepLR 类的具体例子,假设有一个简单的线性模型,使用 torch.optim.SGD 作为优化器,初始学习率为0.1,每隔5个epoch就将学习率乘以0.8,训练100个epoch:

import torch
import matplotlib.pyplot as plt# 定义一个简单的线性模型
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.fc = torch.nn.Linear(1, 1)def forward(self, x):return self.fc(x)# 创建模型和优化器
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 创建固定步长衰减的学习率调度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)# 记录学习率变化
lr_list = []# 模拟训练过程
for epoch in range(100):# 更新学习率scheduler.step()# 记录当前学习率lr_list.append(optimizer.param_groups[0]['lr'])# 绘制学习率曲线
plt.plot(range(100), lr_list, color='r')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.show()

学习率在每个5个epoch后都会下降为原来的0.8倍,直到最后接近0。

固定步长衰减指数衰减都是学习率衰减的策略,但它们在衰减的方式和速度上有所不同:

  • 固定步长衰减:在每隔固定的步数(或epoch)后,学习率会减少为原来的一定比例。这种策略的衰减速度是均匀的,不会随着训练的进行而改变。
  • 指数衰减:在每个训练周期(或epoch)结束时,学习率会乘以一个固定的衰减系数,从而实现学习率的指数衰减。这种策略的衰减速度是逐渐加快的,因为每次衰减都是基于当前的学习率进行的。

3、多步长衰减

多步长衰减是一种学习率调整策略,它在指定的训练周期(或epoch)达到预设的里程碑时,将学习率减少为原来的一定比例。这种策略可以在模型训练的关键阶段动态调整学习率。

在PyTorch中,可以使用 torch.optim.lr_scheduler.MultiStepLR 类来实现多步长衰减。以下是一个使用 MultiStepLR 的代码示例:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR# 假设有一个简单的模型
model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)# 创建 MultiStepLR 对象,设定在第 30 和 80 个 epoch 时学习率衰减为原来的 0.1 倍
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)# 在每个训练周期结束时,调用 step() 方法来更新学习率
for epoch in range(100):# 这里省略了模型的训练代码# ...# 更新学习率scheduler.step()

在这个例子中,初始的学习率是0.1,当训练到第30个epoch时,学习率会变为0.01(即0.1*0.1),当训练到第80个epoch时,学习率会再次衰减为0.001(即0.01*0.1)。

4、余弦退火衰减

余弦退火衰减是一种学习率调整策略,它按照余弦函数的周期和最值来调整学习率。在PyTorch中,可以使用 torch.optim.lr_scheduler.CosineAnnealingLR 类来实现余弦退火衰减。

以下是一个使用 CosineAnnealingLR 的代码示例:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt# 假设有一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)model = SimpleModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 创建 CosineAnnealingLR 对象,周期设置为 10,最小学习率设置为 0.01
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.01)lr_list = []
# 在每个训练周期结束时,调用 step() 方法来更新学习率
for epoch in range(100):# 这里省略了模型的训练代码# ...# 更新学习率scheduler.step()lr_list.append(optimizer.param_groups[0]['lr'])# 绘制学习率变化曲线
plt.plot(range(100), lr_list)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title("Learning rate schedule: CosineAnnealingLR")
plt.show()

在这个例子中,初始的学习率是0.1,学习率会按照余弦函数的形式在0.01到0.1之间变化,周期为10个epoch。

5、自适应学习率衰减

自适应学习率衰减是一种学习率调整策略,它会根据模型的训练进度自动调整学习率。例如,如果模型的验证误差停止下降,那么就减小学习率;如果模型的训练误差上升,那么就增大学习率。在PyTorch中,可以使用 torch.optim.lr_scheduler.ReduceLROnPlateau 类来实现自适应学习率衰减。

以下是一个使用 ReduceLROnPlateau 的代码示例:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau# 假设有一个简单的模型
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 创建 ReduceLROnPlateau 对象,当验证误差在 10 个 epoch 内没有下降时,将学习率减小为原来的 0.1 倍
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.1)# 模拟训练过程
for epoch in range(100):# 这里省略了模型的训练代码# ...# 假设有一个验证误差val_loss = ...# 在每个训练周期结束时,调用 step() 方法来更新学习率scheduler.step(val_loss)

6、自定义函数实现学习率调整:不同层不同的学习率

可以通过为优化器提供一个参数组列表来实现对不同层使用不同的学习率。每个参数组是一个字典,其中包含一组参数和这组参数的学习率。

以下是一个具体的例子,假设有一个包含两个线性层的模型,想要对第一层使用学习率0.01,对第二层使用学习率0.001:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个包含两个线性层的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.layer1 = nn.Linear(10, 2)self.layer2 = nn.Linear(2, 10)model = SimpleModel()# 创建一个参数组列表,每个参数组包含一组参数和这组参数的学习率
params = [{'params': model.layer1.parameters(), 'lr': 0.01},{'params': model.layer2.parameters(), 'lr': 0.001}]# 创建优化器
optimizer = optim.SGD(params)# 现在,当调用 optimizer.step() 时,第一层的参数会使用学习率 0.01 进行更新,第二层的参数会使用学习率 0.001 进行更新

在这个例子中,首先定义了一个包含两个线性层的模型。然后,创建一个参数组列表,每个参数组都包含一组参数和这组参数的学习率。最后创建了一个优化器SGD,将这个参数组列表传递给优化器。这样,当调用 optimizer.step() 时,第一层的参数会使用学习率0.01进行更新,第二层的参数会使用学习率0.001进行更新。

参考:深度图学习与大模型LLM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/687415.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【开源】SpringBoot框架开发智能教学资源库系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 课程档案模块2.3 课程资源模块2.4 课程作业模块2.5 课程评价模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 课程档案表3.2.2 课程资源表3.2.3 课程作业表3.2.4 课程评价表 四、系统展示五、核心代…

超级楼梯(动态规划)

#include <iostream> using namespace std; int main (){int n;cin >> n;int f[41];f[1] 0;f[2] 1;f[3] 2;for (int i 4;i < 42;i){f[i] f[i-1] f[i-2];}while (n--){int stair;cin>> stair;cout << f[stair]<<endl;}return 0; }

【ChatIE】论文解读:Zero-Shot Information Extraction via Chatting with ChatGPT

文章目录 介绍ChatIEEntity-Relation Triple Extration (RE)Named Entity Recognition (NER)Event Extraction (EE) 实验结果结论 论文&#xff1a;Zero-Shot Information Extraction via Chatting with ChatGPT 作者&#xff1a;Xiang Wei, Xingyu Cui, Ning Cheng, Xiaobin W…

嵌入式——EEPROM(AT24C02)

目录 一、初识AT24C02 1. 介绍 2. 引脚功能 补&#xff1a; 二、AT24C02组成 1. 存储结构 2. AT24C02通讯地址 3. AT24C02寻址方式 &#xff08;1&#xff09;芯片寻址 &#xff08;2&#xff09;片内子地址寻址 三、AT24C02读写时序 1. 写操作 &#xff08;1&…

Linux|centos7下的编译|ffmpeg的二进制安装

Windows版本的ffmpeg&#xff1a; ###注意&#xff0c;高版本可能必须要windows10以及以上才支持&#xff0c;win7估计是用不了的 下载地址&#xff1a;Builds - CODEX FFMPEG gyan.dev 或者这个下载地址&#xff1a;https://github.com/BtbN/FFmpeg-Builds/releases 这两个…

ClickHouse--12-可视化工具操作

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 可视化工具操作1 tabixhttp://ui.tabix.io/ 2 DBeaverhttps://dbeaver.io/download/ 可视化工具操作 1 tabix tabix 支持通过浏览器直接连接 ClickHouse&#xff…

【制作100个unity游戏之25】3D背包、库存、制作、快捷栏、存储系统、砍伐树木获取资源、随机战利品宝箱12(附带项目源码)

效果演示 文章目录 效果演示系列目录前言悬停显示物品详情源码完结 系列目录 前言 欢迎来到【制作100个Unity游戏】系列&#xff01;本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第25篇中&#xff0c;我们将探索如何用unity制作一个3D背包、库存、制作、快…

11.【CPP】模版(深入理解模版的实例化,从编译链接的原理理解模版为何无法分离编译)

非类型模版参数 1.模版参数分为类型模版参数和非类型模版参数&#xff0c;非类型模版参数一般都是整形常量&#xff08;整形&#xff1a;size_t,int,char等&#xff09; 2.浮点数、类对象以及字符串是不允许作为非类型模版参数的。非类型模版的参数必须在编译的时候就能确定结…

leetcode hot100 拆分整数

在本题目中&#xff0c;我们需要拆分一个整数n&#xff0c;让其拆分的整数积最大。因为每拆分一次都和之前上一次拆分有关系&#xff0c;比如拆分6可以拆成2x4&#xff0c;还可以拆成2x2x2&#xff0c;那么我们可以采用动态规划来做。 首先确定dp数组的含义&#xff0c;这里dp…

第13章 网络 Page744~746 asio核心类 ip::tcp::endPoint

2. ip::tcp::endpoint ip::tcp::socket用于连接TCP服务端的 async_connect()方法的第一个入参是const endpoint_type& peer_endpoint. 此处的类型 endpoint_type 是 ip::tcp::endpoint 在 在 ip::tcp::socket 类内部的一个别名。 libucurl 库采用字符串URL表达目标的地…

LeetCode 100题目(python版本)待续...

一.哈希 1.两数之和 题目 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是&#xff0c;数组中同一个元素在答案里不能重复…

Acwing---846. 树的重心

树的重心 1.题目2.基本思想3.代码实现 1.题目 给定一颗树&#xff0c;树中包含 n n n 个结点&#xff08;编号 1 ∼ n 1∼n 1∼n&#xff09;和 n − 1 n−1 n−1 条无向边。 请你找到树的重心&#xff0c;并输出将重心删除后&#xff0c;剩余各个连通块中点数的最大值。 …

【C Primer Plus第六版 学习笔记】 第十七章 高级数据表示

有基础&#xff0c;进阶用&#xff0c;个人查漏补缺 链表&#xff1a;假设要编写一个程序&#xff0c;让用户输入一年内看过的所有电影&#xff0c;要储存每部影片的片名和评级。 #include <stdio.h> #include <stdlib.h> /* 提供malloc()的原型 */ #include <s…

el-date-picker 选择年后输出的是Wed Jan 01 2025 00:00:00 GMT+0800 (中国标准时间)

文章目录 问题分析 问题 在使用 el-date-picker 做只选择年份的控制器时&#xff0c;出现如下问题&#xff1a;el-date-picker选择年后输出的是Wed Jan 01 2025 00:00:00 GMT0800 (中国标准时间)&#xff0c;输出了两次如下 分析 在 el-date-picker 中&#xff0c;我们使用…

数学建模【非线性规划】

一、非线性规划简介 通过分析问题判断是用线性规划还是非线性规划 线性规划&#xff1a;模型中所有的变量都是一次方非线性规划&#xff1a;模型中至少一个变量是非线性 非线性规划在形式上与线性规划非常类似&#xff0c;但在数学上求解却困难很多 线性规划有通用的求解准…

计算机网络之网络安全

文章目录 1. 网络安全概述1.1 安全威胁1.1.1 被动攻击1.1.2 主动攻击 1.2 安全服务 2. 密码学与保密性2.1 密码学相关基本概念2.2 对称密钥密码体制2.2.1 DES的加密方法2.2.2.三重DES 2.3 公钥密码体制 3. 报文完整性与鉴别3.1 报文摘要和报文鉴别码3.1.1 报文摘要和报文鉴别码…

从零开始手写mmo游戏从框架到爆炸(十二)— 角色设定

导航&#xff1a;从零开始手写mmo游戏从框架到爆炸&#xff08;零&#xff09;—— 导航-CSDN博客 写了这么多的框架&#xff0c;说好的mmo游戏呢&#xff1f;所以我们暂时按下框架不表&#xff0c;这几篇我们设计英雄角色、怪物、技能和地图。本篇我们来对游戏角色…

【BUG】段错误

1. 问题 8核工程&#xff0c;核4在运行了20分钟以上&#xff0c;发生了段错误。 [C66xx_4] A00x53 A10x53 A20x4 A30x167e A40x1600 A50x850e2e A60x845097 A70xbad9f5e0 A80x0 A90x33 A100x53535353 A110x0 A120x0 A130x0 A140x0 A150x0 A160x36312e35 A170x20 A180x844df0 …

没有PFMEA分析的检测过程会有什么风险?

随着科技的快速发展&#xff0c;产品复杂度不断提升&#xff0c;检测过程的重要性日益凸显。然而&#xff0c;在这个过程中&#xff0c;如果没有进行PFMEA分析&#xff0c;将会带来怎样的风险呢&#xff1f;本文将对此进行深入探讨。 众所周知&#xff0c;检测是确保产品质量的…

openGauss学习笔记-222 openGauss性能调优-系统调优-操作系统参数调优

文章目录 openGauss学习笔记-222 openGauss性能调优-系统调优-操作系统参数调优222.1 前提条件222.2 内存相关参数设置222.3 网络相关参数设置222.4 I/O相关参数设置 openGauss学习笔记-222 openGauss性能调优-系统调优-操作系统参数调优 在性能调优过程中&#xff0c;可以根据…