深度学习-学习率调整策略

在深度学习中,学习率调整策略(Learning Rate Scheduling)用于在训练过程中动态调整学习率,以实现更快的收敛和更好的模型性能。选择合适的学习率策略可以避免模型陷入局部最优、震荡不稳定等问题。下面介绍一些常见的学习率调整策略:

1. Step Decay(分步衰减)

原理

Step Decay 是一种分段衰减策略,每隔一定的训练周期或步骤,学习率会缩减一个固定的因子。这可以在训练中途降低学习率,从而让模型在训练末期更加稳定地收敛。

公式

其中:

  • initial_lr 是初始学习率
  • factor 是每次衰减的因子,一般小于 1(例如 0.1)
  • k 是衰减次数
适用场景

适合训练中需要逐步收敛的模型,如卷积神经网络。在一定训练轮次后,降低学习率有助于模型以更稳定的步伐接近最优解。

优缺点
  • 优点:可以逐步收敛,适合比较平稳的优化任务。
  • 缺点:由于步长是固定的,可能会导致过早或过晚调整学习率。
代码示例

在 PyTorch 中实现 Step Decay 可以使用 StepLR

import torch
import torch.optim as optim
import torch.nn as nn# 假设我们有一个简单的模型
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)  # 初始学习率 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # 每10个epoch衰减一半# 模拟训练过程
for epoch in range(30):# 假设进行前向和后向传播optimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()# 更新学习率scheduler.step()print(f"Epoch {epoch+1}, Learning Rate: {scheduler.get_last_lr()[0]}")

2. Exponential Decay(指数衰减)

原理

指数衰减策略的学习率会以指数方式逐渐减少,公式为:

其中 decay_rate 控制学习率下降的速度。指数衰减适合需要平稳下降的任务,因为这种衰减是连续的且平滑。

适用场景

适合长时间训练或训练数据复杂的模型,能让模型在训练后期继续保持较好的收敛效果。

优缺点
  • 优点:平滑衰减,适合长时间训练。
  • 缺点:如果 decay_rate 设置不当,可能会导致过早或过晚下降。
代码示例

在 PyTorch 中使用 ExponentialLR 实现:

scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)  # 指数衰减因子 0.9for epoch in range(30):optimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()# 更新学习率scheduler.step()print(f"Epoch {epoch+1}, Learning Rate: {scheduler.get_last_lr()[0]}")

3. Cosine Annealing(余弦退火)

原理

Cosine Annealing 利用余弦函数使学习率周期性下降,并在周期末期快速降低至接近 0 的值。

其中 T_max 是控制余弦周期的最大步数,周期性下降可以使模型在接近全局最优时表现更稳定。

适用场景

适合在有一定噪声的数据集上进行多轮次训练,使模型在每个周期内都能充分探索损失函数的不同区域。

优缺点
  • 优点:自然周期下降,易于模型在训练中后期稳定收敛。
  • 缺点:周期设置需要与任务匹配,否则可能在全局最优时过早结束。
代码示例

在 PyTorch 中使用 CosineAnnealingLR 实现:

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)for epoch in range(30):optimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()# 更新学习率scheduler.step()print(f"Epoch {epoch+1}, Learning Rate: {scheduler.get_last_lr()[0]}")

4. Reduce on Plateau(基于验证集表现动态调整)

原理

当验证集的损失在一段时间(耐心期)内没有显著下降,则将学习率按一定因子减少。这样可以防止模型陷入局部最优。

适用场景

特别适合那些验证集损失不稳定或在收敛后期趋于平稳的模型,比如需要细致调整的分类任务。

优缺点
  • 优点:自适应调整学习率,使训练在收敛后期更稳定。
  • 缺点:依赖验证集表现,调整耐心期参数复杂。
代码示例

在 PyTorch 中使用 ReduceLROnPlateau 实现:

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)for epoch in range(30):optimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()# 模拟验证损失val_loss = loss.item() + (epoch % 10) * 0.1  # 可调节该值scheduler.step(val_loss)print(f"Epoch {epoch+1}, Learning Rate: {optimizer.param_groups[0]['lr']}")

5. Cyclical Learning Rate (CLR)

原理

CLR 设定了上下限值,让学习率在两者之间循环,探索损失空间不同区域,防止陷入局部最优。

适用场景

适合包含复杂损失结构的任务,如图像分类中的较大卷积网络。

优缺点
  • 优点:避免陷入局部最优,提高全局搜索能力。
  • 缺点:调整范围较难控制,适用性有限。
代码示例

可以使用 torch.optim 库实现自定义的 CLR:

import numpy as np# 计算CLR的函数
def cyclical_learning_rate(step, base_lr=0.001, max_lr=0.006, step_size=2000):cycle = np.floor(1 + step / (2 * step_size))x = np.abs(step / step_size - 2 * cycle + 1)lr = base_lr + (max_lr - base_lr) * np.maximum(0, (1 - x))return lr# 训练过程
for step in range(10000):lr = cyclical_learning_rate(step)optimizer.param_groups[0]['lr'] = lroptimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()if step % 1000 == 0:print(f"Step {step}, Learning Rate: {optimizer.param_groups[0]['lr']}")

6. One Cycle Policy(单周期策略)

原理

One Cycle Policy 从较低学习率开始,逐渐增加到最大值,然后再逐步减小到较低值。适合需要快速探索和稳定收敛的任务。

适用场景

适合迁移学习和较小数据集。

优缺点
  • 优点:适合迁移学习,能快速稳定收敛。
  • 缺点:对于较长训练任务效果一般。
代码示例

在 PyTorch 中实现 One Cycle Policy:

from torch.optim.lr_scheduler import OneCycleLRscheduler = OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=100, epochs=10)for epoch in range(10):for i in range(100):  # 假设一个 epoch 有 100 个 batchoptimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()scheduler.step()  # 每步更新学习率print(f"Epoch {epoch+1}, Step {i+1}, Learning Rate: {scheduler.get_last_lr()[0]}")

7.如何选择合适的学习率调整策略?

(1). 数据规模和训练时长

  • 小数据集训练时间短
    使用 One Cycle PolicyCyclical Learning Rate。这类策略能够快速调整学习率,在有限的时间内加速训练并避免局部最优。

  • 中等数据集训练时间适中
    可以选择 Step DecayExponential Decay。这些策略在收敛过程中平稳下降,适合中等规模的任务。

  • 大数据集长时间训练
    选择 Cosine AnnealingReduce on Plateau。这类策略能够适应较长的训练周期,避免学习率下降过快,从而保持稳定的收敛效果。


(2). 模型类型和复杂度

  • 简单模型(如浅层神经网络):
    使用 Step DecayExponential Decay。这些简单的衰减策略适合训练时间不长且模型复杂度低的情况。

  • 深度模型(如卷积神经网络、递归神经网络):
    选择 Cosine AnnealingReduce on PlateauOne Cycle Policy。这些策略在后期能平滑衰减,有助于复杂模型更好地探索损失函数的不同区域。

  • 预训练模型的微调
    One Cycle Policy 是一个理想选择。它从较低的学习率开始,快速升至最大,再衰减回较小值,适合在微调过程中稳定调整参数。


(3). 任务类型

  • 分类任务
    分类任务中常用 Step DecayCosine AnnealingCyclical Learning Rate,特别是在图像分类任务中,余弦退火可以在训练后期更好地收敛,CLR 则有助于探索不同的损失空间。

  • 回归任务
    Exponential DecayReduce on Plateau,回归任务通常要求模型在后期保持较稳定的收敛效果,因此指数衰减和基于验证集表现的动态调整策略更为合适。

  • 时间序列预测
    使用 Reduce on PlateauExponential Decay,因为时间序列预测中数据较为复杂,不同时间段的学习率需求变化大,可以使用验证集损失表现来决定学习率的动态调整。


(4). 模型对学习率敏感性

  • 学习率敏感模型
    对学习率波动敏感的模型适合使用 Cosine AnnealingReduce on Plateau。这类模型需要学习率逐步下降的过程来平稳收敛,不易受到过大的学习率波动影响。

  • 对学习率不敏感的模型
    使用 Cyclical Learning RateOne Cycle Policy,这两种策略适合让学习率在一个范围内波动,从而让模型更快跳出局部最优,快速找到全局最优解。


(5). 损失函数表现与收敛性

  • 损失波动较大(不稳定收敛):
    选择 Reduce on Plateau,让模型在验证集损失长时间不下降时再降低学习率,避免过早或频繁地调整学习率。

  • 损失逐渐收敛(平稳下降):
    使用 Step DecayExponential Decay,这些策略更适合平稳下降的场景,且能在训练后期提供更小的学习率。

任务场景推荐学习率调整策略
小数据集,快速训练One Cycle Policy,CLR
大数据集,长时间训练Cosine Annealing,Reduce on Plateau
微调预训练模型One Cycle Policy
简单模型Step Decay,Exponential Decay
深层复杂模型Cosine Annealing,Reduce on Plateau
分类任务Step Decay,Cosine Annealing,CLR
时间序列或自然语言处理Exponential Decay,Reduce on Plateau
高波动的验证集损失Reduce on Plateau

以下是一个综合示例,展示了如何在 PyTorch 中动态选择并应用学习率调整策略:

import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, OneCycleLR# 假设我们有一个简单的模型
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)  # 初始学习率 0.1# 根据需求选择合适的学习率调整策略
def get_scheduler(optimizer, strategy='step_decay'):if strategy == 'step_decay':return StepLR(optimizer, step_size=10, gamma=0.5)elif strategy == 'exponential_decay':return ExponentialLR(optimizer, gamma=0.9)elif strategy == 'cosine_annealing':return CosineAnnealingLR(optimizer, T_max=30)elif strategy == 'reduce_on_plateau':return ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)elif strategy == 'one_cycle':return OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=100, epochs=10)else:raise ValueError("Unknown strategy type")# 选择策略
scheduler = get_scheduler(optimizer, strategy='cosine_annealing')# 模拟训练过程
for epoch in range(30):optimizer.zero_grad()output = model(torch.randn(10))loss = nn.MSELoss()(output, torch.randn(2))loss.backward()optimizer.step()# 调整学习率if isinstance(scheduler, ReduceLROnPlateau):# 如果是 Reduce on Plateau,使用验证集的损失作为依据val_loss = loss.item() + (epoch % 10) * 0.1  # 模拟验证损失scheduler.step(val_loss)else:scheduler.step()print(f"Epoch {epoch+1}, Learning Rate: {optimizer.param_groups[0]['lr']}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/59128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Caffeine 手动策略缓存 put() 方法源码解析

BoundedLocalManualCache put() 方法源码解析 先看一下BoundedLocalManualCache的类图 com.github.benmanes.caffeine.cache.BoundedLocalCache中定义的BoundedLocalManualCache静态内部类。 static class BoundedLocalManualCache<K, V> implements LocalManualCache&…

《Qwen2-VL》论文精读【上】:发表于2024年10月 Qwen2-VL 迅速崛起 | 性能与GPT-4o和Claude3.5相当

1、论文地址Qwen2-VL: Enhancing Vision-Language Model’s Perception of the World at Any Resolution 2、Qwen2-VL的Github仓库地址 该论文发表于2024年4月&#xff0c;是Qwen2-VL的续作&#xff0c;截止2024年11月&#xff0c;引用数24 文章目录 1 论文摘要2 引言3 实验3.…

StandardThreadExecutor源码解读与使用(tomcat的线程池实现类)

&#x1f3f7;️个人主页&#xff1a;牵着猫散步的鼠鼠 &#x1f3f7;️系列专栏&#xff1a;Java源码解读-专栏 &#x1f3f7;️个人学习笔记&#xff0c;若有缺误&#xff0c;欢迎评论区指正 目录 目录 1.前言 2.线程池基础知识回顾 2.1.线程池的组成 2.2.工作流程 2…

前端埋点与监控最佳实践:从基础到全流程实现.

前端埋点与监控最佳实践&#xff1a;从基础到全流程实现 大纲 我们会从以下三个方向来讲解埋点与监控的知识&#xff1a; 什么是埋点&#xff1f;什么是监控&#xff1f; JS 中实现监控的核心方案 写一个“相对”完整的监控实例 一、什么是埋点&#xff1f;什么是监控&am…

rom定制系列------红米k30_4G版澎湃os安卓13批量线刷固件

&#x1f49d;&#x1f49d;&#x1f49d;红米k30 4G版&#xff0c;机型代码;phoenix.此机型官方固件最后一版为稳定版13.0.6安卓12的固件。客户的软件需运行在至少安卓13的系统至少。测试原生适配有bug。最终测试在第三方澎湃os安卓13的固件可以完美运行。 &#x1f49d;&am…

钉钉平台开发小程序

一、下载小程序开发者工具 官网地址&#xff1a;小程序开发工具 - 钉钉开放平台 客户端类型 下载链接 MacOS x64 https://ur.alipay.com/volans-demo_MiniProgramStudio-x64.dmg MacOS arm64 https://ur.alipay.com/volans-demo_MiniProgramStudio-arm64.dmg Windows ht…

android——渐变色

1、xml的方式实现渐变色 效果图&#xff1a; xml的代码&#xff1a; <?xml version"1.0" encoding"utf-8"?> <shape xmlns:android"http://schemas.android.com/apk/res/android"xmlns:tools"http://schemas.android.com/tools…

微信小程序生成二维码

目前是在开发小程序端 --> 微信小程序。然后接到需求&#xff1a;根据 form 表单填写内容生成二维码&#xff08;第一版&#xff1a;表单目前需要客户进行自己输入&#xff0c;然后点击生成按钮实时生成二维码&#xff0c;不需要向后端请求&#xff0c;不存如数据库&#xf…

rhce:web服务器

web服务器简介 服务器端&#xff1a;此处使用 nginx 提供 web 服务&#xff0c; RPM 包获取&#xff1a; http://nginx.org/packages/ /etc/nginx/ ├── conf.d #子配置文件目录 ├── default.d ├── fastcgi.conf ├── fastcgi.conf.default ├── fastcgi_params #用…

解决使用netstat查看端口显示FIN_WAIT的问题

解决使用netstat查看端口显示FIN_WAIT的问题 1. 理解`FIN_WAIT`状态2. 检查应用程序3. 检查网络延迟和稳定性4. 更新和修补系统5. 调整TCP参数6. 使用更详细的工具进行分析7. 咨询开发者或技术支持8. 定期监控和评估结论在使用 netstat查看网络连接状态时,如果发现大量连接处…

01LangChain 实战课开篇——AI奇点时刻

LangChain 实战课开篇——AI奇点时刻 课程简介 课程背景&#xff1a;随着ChatGPT和GPT-4的出现&#xff0c;AI技术与实际应用之间的距离变得前所未有的近。LangChain作为基于大模型的应用开发框架&#xff0c;为程序员提供了开发智能应用的新工具。 LangChain 概述 定义&am…

【java】java的基本程序设计结构06-运算符

运算符 一、分类 算术运算符关系运算符位运算符逻辑运算符赋值运算符其他运算符 1.1 算术运算符 操作符描述例子加法 - 相加运算符两侧的值A B 等于 30-减法 - 左操作数减去右操作数A – B 等于 -10*乘法 - 相乘操作符两侧的值A * B等于200/除法 - 左操作数除以右操作数B /…

Spring Cloud Sleuth(Micrometer Tracing +Zipkin)

分布式链路追踪 分布式链路追踪技术要解决的问题&#xff0c;分布式链路追踪&#xff08;Distributed Tracing&#xff09;&#xff0c;就是将一次分布式请求还原成调用链路&#xff0c;进行日志记录&#xff0c;性能监控并将一次分布式请求的调用情况集中展示。比如各个服务节…

关于我的编程语言——C/C++——第四篇(深入1)

&#xff08;叠甲&#xff1a;如有侵权请联系&#xff0c;内容都是自己学习的总结&#xff0c;一定不全面&#xff0c;仅当互相交流&#xff08;轻点骂&#xff09;我也只是站在巨人肩膀上的一个小卡拉米&#xff0c;已老实&#xff0c;求放过&#xff09; 字符类型介绍 char…

一台手机可以登录运营多少个TikTok账号?

很多TikTok内容创作者和商家通过运营多个账号来实现品牌曝光和产品销售&#xff0c;这种矩阵运营方式需要一定的技巧和设备成本&#xff0c;那么对于很多新手来说&#xff0c;一台手机可以登录和运营多少个TikTok账号呢&#xff1f; 一、运营TikTok账号的数量限制 TikTok的官…

DNS服务器部署

一、要求 1.搭建dns服务器能够对自定义的正向或者反向域完成数据解析查询。 2.配置从DNS服务器&#xff0c;对主dns服务器进行数据备份。 二、配置 1.搭建dns服务器能够对自定义的正向或者反向域完成数据解析查询。 &#xff08;1&#xff09;首先需要安装bind服务 &#xf…

三周精通FastAPI:28 构建更大的应用 - 多个文件

官方文档&#xff1a;https://fastapi.tiangolo.com/zh/tutorial/bigger-applications 更大的应用 - 多个文件 如果你正在开发一个应用程序或 Web API&#xff0c;很少会将所有的内容都放在一个文件中。 FastAPI 提供了一个方便的工具&#xff0c;可以在保持所有灵活性的同时…

【react使用AES对称加密的实现】

react使用AES对称加密的实现 前言使用CryptoJS库密钥存放加密方法解密方法结语 前言 项目中要求敏感信息怕被抓包泄密必须进行加密传输处理&#xff0c;普通的md5加密虽然能解决传输问题&#xff0c;但是项目中有权限的用户是需要查看数据进行查询的&#xff0c;所以就不能直接…

【STM32】INA3221三通道电压电流采集模块,HAL库

一、简单介绍 芯片的datasheet地址&#xff1a; INA3221 三通道、高侧测量、分流和总线电压监视器&#xff0c;具有兼容 I2C 和 SMBUS 的接口 datasheet (Rev. B) 笔者所使用的INA3221是淘宝买的模块 原理图 模块的三个通道的电压都是一样&#xff0c;都是POWER。这个芯片采用…

《机器人SLAM导航核心技术与实战》第1季:第10章_其他SLAM系统

视频讲解 【第1季】10.第10章_其他SLAM系统-视频讲解 【第1季】10.1.第10章_其他SLAM系统_RTABMAP算法-视频讲解 【第1季】10.2.第10章_其他SLAM系统_VINS算法-视频讲解 【第1季】10.3.第10章_其他SLAM系统_机器学习与SLAM-视频讲解 第1季&#xff1a;第10章_其他SLAM系统 …