Pytorch lr_scheduler 调整学习率

Pytorch lr_scheduler 调整学习率

背景

上篇文章连接

在运行 VGG 代码的时候有这么几行代码:

# 定义模型进行训练
model = VGG16()
# model.load_state_dict(torch.load('./my-VGG16.pth'))
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-3)
loss_func = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1) # todo 了解这个操作!

定义优化器,损失函数我都知道要做,但是突然出现一个 scheduler 我就看不懂是什么了。

这里来了解一下

在PyTorch中,lr_scheduler 是用于调整学习率(Learning Rate)的一个模块,它可以在训练过程中动态地改变学习率,有助于改善模型的训练效果,避免陷入局部最优解,或者加速收敛过程。StepLRlr_scheduler 模块中的一个类,它按照固定的步长(step_size)来降低学习率。

在你给出的代码片段中:

python复制代码scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
  • optimizer:这是你需要调整学习率的优化器对象。在你的例子中,它是通过 optim.SGD 创建的,用于VGG16模型的参数优化。
  • step_size:这个参数定义了学习率更新的周期。在你的例子中,step_size=5 意味着每经过5个epoch(训练周期),学习率就会更新一次。
  • gamma:这个参数定义了学习率更新的乘法因子。在你的例子中,gamma=0.4 意味着每次学习率更新时,新的学习率将是旧学习率的0.4倍。这有助于在训练过程中逐渐减小学习率,以便在接近最优解时进行更细致的调整。
  • last_epoch:这个参数用于指示在调用 scheduler.step() 之前,已经模拟了多少个epoch的更新。在初次设置时,如果希望从头开始计算(即从epoch 0开始),通常将其设置为 -1。这样,第一次调用 scheduler.step() 时,将认为是从epoch 0结束后的状态开始,从而按照 step_size 的设置来更新学习率。

在训练循环中,你需要在每个epoch结束后调用 scheduler.step() 来更新学习率。例如:

for epoch in range(num_epochs):  # 训练代码...  # 在每个epoch结束后更新学习率  scheduler.step()

通过这种方式,你可以根据训练过程的需要,灵活地调整学习率,以期获得更好的训练效果。

实际调用:

# 定义训练步骤
total_times = 40
total = 0
accuracy_rate = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")for epoch in range(total_times):model.train()model.to(device)running_loss = 0.0total_correct = 0total_trainset = 0print("epoch: ",epoch)for i, (data,labels) in enumerate(train_loader):data = data.to(device)outputs = model(data).to(device)labels = labels.to(device)loss = loss_func(outputs,labels).to(device)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()_,pred = outputs.max(1)correct = (pred == labels).sum().item()total_correct += correcttotal_trainset += data.shape[0]if i % 100 == 0 and i > 0:print(f"正在进行第{i}次训练, running_loss={running_loss}".format(i, running_loss))running_loss = 0.0test()scheduler.step()

其他调整学习率的方法

在PyTorch中,除了StepLR之外,还有多种其他方法用于调整学习率。这些方法可以帮助在训练过程中更灵活地控制学习率,以适应不同的训练需求和数据集特性。以下是一些常见的PyTorch学习率调整方法:

  1. MultiStepLR:
    • 功能:按给定间隔调整学习率。
    • 主要参数:
      • milestones:一个列表,包含需要调整学习率的epoch数。在每个milestones指定的epoch结束时,学习率会按照给定的gamma进行调整。
      • gamma:调整系数,与StepLR中的gamma相同,用于计算新的学习率。
    • 使用场景:当你知道在特定的epoch点需要调整学习率时,可以使用此方法。
  2. ExponentialLR:
    • 功能:按指数衰减调整学习率。
    • 主要参数:
      • gamma:指数的底,通常设置为小于1的数(如0.9),用于计算学习率的衰减。
    • 使用场景:当希望学习率随着训练的进行逐渐减小,且减小速度呈指数级变化时,可以使用此方法。
  3. CosineAnnealingLR:
    • 功能:余弦周期调整学习率。
    • 主要参数:
      • T_max:下降周期,表示余弦周期的一半。学习率会在每个T_max周期内按照余弦函数变化。
      • eta_min:学习率下限,学习率变化过程中不会低于此值。
    • 使用场景:当希望学习率在一个周期内先下降后上升,模拟退火过程时,可以使用此方法。
  4. ReduceLROnPlateau:
    • 功能:监控某个指标(如loss或accuracy),当指标不再改善时调整学习率。
    • 主要参数:
      • mode'min''max',表示监控的指标是应该最小化还是最大化。
      • factor:调整系数,用于计算新的学习率。
      • patience:“耐心”参数,表示在调整学习率之前,指标可以接受连续多少次不改善。
      • cooldown:“冷却时间”,在调整学习率后,暂停监控一段时间。
      • min_lr:学习率下限。
    • 使用场景:当你想根据模型的实际表现(而非固定的epoch数)来调整学习率时,可以使用此方法。
  5. LambdaLR
    • 功能:使用自定义的lambda函数来调整学习率。
    • 主要参数:
      • lr_lambda:一个函数或函数列表,用于计算新的学习率。如果传入函数列表,则列表中的每个函数都会独立地应用于每个参数组的学习率。
    • 使用场景:当需要实现更复杂的学习率调整策略时,可以使用此方法。
  6. WarmupLR(注意:这不是PyTorch官方直接提供的一个类,但可以通过自定义实现):
    • 功能:在训练初期逐渐增加学习率,以达到预热模型的效果。
    • 实现方式:可以通过自定义一个lr_scheduler或使用LambdaLR结合预热函数来实现。
    • 使用场景:当模型在训练初期容易因为学习率过大而不稳定时,可以使用此方法。

这些学习率调整方法各有特点,适用于不同的训练场景和需求。在实际应用中,可以根据数据集的特性、模型的复杂度以及训练目标来选择合适的调整方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/871488.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3中谷歌地图+外网申请-原生-实现地址输入搜索+点击地图获取地址回显 +获取国外的geoJson实现省市区级联选择

一. 效果&#xff1a;输入后显示相关的地址列表&#xff0c;选中后出现标示图标和居中定位 1.初始化谷歌地图 在index.html加上谷歌api请求库 <script src"https://maps.googleapis.com/maps/api/js?key申请到的谷歌地图密钥&vweekly&librariesgeometry,place…

基于TCP的在线词典系统(分阶段实现)(阻塞io和多路io复用(select)实现)

1.功能说明 一共四个功能&#xff1a; 注册 登录 查询单词 查询历史记录 单词和解释保存在文件中&#xff0c;单词和解释只占一行, 一行最多300个字节&#xff0c;单词和解释之间至少有一个空格。 2.功能演示 3、分阶段完成各个功能 3.1 完成服务器和客户端的连接 servic…

Vue el-input 限制输入内容

&#x1f914;日常项目中经常遇到既要el-input的样式&#xff0c;又要el-input-number限制&#xff0c;所以需要绑定input事件进行约束输入限制。 以下使用自定义指令进行约束el-input输入的值&#xff0c;便于后期统一管理和拓展。 预览 代码 <!DOCTYPE html> <ht…

【机器学习】精准农业新纪元:机器学习引领的作物管理革命

&#x1f4dd;个人主页&#x1f339;&#xff1a;Eternity._ &#x1f339;&#x1f339;期待您的关注 &#x1f339;&#x1f339; ❀目录 &#x1f50d;1. 引言&#x1f4d2;2. 精准农业的背景与现状&#x1f341;精准农业的概念与发展历程&#x1f342;国内外精准农业实践案…

【数据结构】手写堆 HEAP

heap【堆】掌握 手写上浮、下沉、建堆函数 对一组数进行堆排序 直接使用接口函数heapq 什么是堆&#xff1f;&#xff1f;&#xff1f;堆是一个二叉树。也就是有两个叉。下面是一个大根堆&#xff1a; 大根堆的每一个根节点比他的子节点都大 有大根堆就有小根堆&#xff1…

(南京观海微电子)——二极管应用及选取

二极管是 用半导体材料(硅、硒、锗等)制成的一种电子器件。二极管有两个电极&#xff0c;正极&#xff0c;又叫阳极&#xff1b;负极&#xff0c;又叫阴极&#xff0c;给二极管两极间加上正向电压时&#xff0c;二极管导通&#xff0c; 加上反向电压时&#xff0c;二极管截止。…

Vue1-Vue核心

目录 Vue简介 官网 介绍与描述 Vue的特点 与其它 JS 框架的关联 Vue周边库 初识Vue Vue模板语法 数据绑定 el与data的两种写法 MVVM模型 数据代理 回顾Object.defineProperty方法 何为数据代理 Vue中的数据代理 数据代理图示 事件处理 事件的基本使用 事件修…

【UE5.1】Chaos物理系统基础——06 子弹破坏石块

前言 在前面我们已经完成了场系统的制作&#xff08;【UE5.1】Chaos物理系统基础——02 场系统的应用_ue5&#xff09;以及子弹的制作&#xff08;【UE5.1 角色练习】16-枪械射击——瞄准&#xff09;&#xff0c;现在我们准备实现的效果是&#xff0c;角色发射子弹来破坏石柱。…

STM32智能空气质量监测系统教程

目录 引言环境准备智能空气质量监测系统基础代码实现&#xff1a;实现智能空气质量监测系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;空气质量监测与优化问题解决方案与优化收尾与总结 1. 引言 智能空…

基于Java+SpringMvc+Vue技术的药品进销存仓库管理系统设计与实现系统(源码+LW+部署讲解)

注&#xff1a;每个学校每个老师对论文的格式要求不一样&#xff0c;故本论文只供参考&#xff0c;本论文页数达到60页以上&#xff0c;字数在6000及以上。 基于JavaSpringMvcVue技术的在线学习交流平台设计与实现 目录 第一章 绪论 1.1 研究背景 1.2 研究现状 1.3 研究内容…

卸载wps office的几种方法收录

​ 第一种方法: 1.打开【任务管理器】&#xff0c;找到相关程序&#xff0c;点击【结束任务】。任务管理器可以通过左下角搜索找到。 2.点击【开始】&#xff0d;【设置】&#xff0d;【应用】&#xff0d;下拉找到WPS应用&#xff0c;右键卸载&#xff0c;不保留软件配置 …

Git学习1_Git安装(CSDN_20240714)

git下载 git下载官网如下&#xff1a; Git - Downloads (git-scm.com)https://git-scm.com/downloads 根据机器操作系统&#xff0c;下载对应的安装包 git安装 1. 点击安装程序&#xff0c;进入安装界面&#xff0c;如下图所示&#xff0c;点击next。 2. 选择安装路径&…

护网HW面试常问——组件中间件框架漏洞(包含流量特征)

apache&iis&nginx中间件解析漏洞 参考我之前的文章&#xff1a;护网HW面试—apache&iis&nginx中间件解析漏洞篇-CSDN博客 log4j2 漏洞原理&#xff1a; 该漏洞主要是由于日志在打印时当遇到${后&#xff0c;以:号作为分割&#xff0c;将表达式内容分割成两部…

Leetcode(经典题)day2

H指数 274. H 指数 - 力扣&#xff08;LeetCode&#xff09; 先对数组排序&#xff0c;然后从大的一头开始遍历&#xff0c;只要数组当前的数比现在的h指数大就给h指数1&#xff0c;直到数组当前的数比现在的h指数小的时候结束&#xff0c;这时h的值就是要返回的结果。 排序…

下载安装nodejs npm jarn笔记

下载安装nodejs npm jarn笔记 下载 Node.js安装Node.js修改node全局路径安装yarn 下载 Node.js 下载Node.js 安装Node.js 双击下载的下来的.msi文件运行并安装一直点next。安装路径可以是默认也可自定义。安装完成后Node.js和npm就安装完成了 命令行输入&#xff1a; nod…

LeetCode 面试题02.04.分割链表

LeetCode 面试题02.04.分割链表 C写法 思路&#x1f914;&#xff1a; ​ 将x分为两段&#xff0c;一段放小于x的值&#xff0c;另一段放大于x的值。开辟四个指针lesshead、lesstail、greaterhead、greatertail&#xff0c;head为哨兵位&#xff0c;防止链表为空时情况过于复杂…

推荐一款 uniapp Vaptcha 手势验证码插件

插件地址&#xff1a;VAPTCHA手势验证码 - DCloud 插件市场 具体使用方式可访问插件地址自行查阅

Vue从零到实战

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 非常期待和您一起在这个小…

WEB前端03-CSS3基础

CSS3基础 1.CSS基本概念 CSS是Cascading Style Sheets&#xff08;层叠样式表&#xff09;的缩写&#xff0c;它是一种对Web文档添加样式的简单机制&#xff0c;是一种表现HTML或XML等文件外观样式的计算机语言&#xff0c;是一种网页排版和布局设计的技术。 CSS的特点 纯C…

R语言安装devtools包失败过程总结

R语言安装devtools包时&#xff0c;遇到usethis包总是安装失败&#xff0c;现总结如下方法&#xff0c;亲测可有效 一、usethis包及cli包安装问题 首先&#xff0c;Install.packages("usethis")出现如下错误&#xff0c;定位到是这个cli包出现问题 载入需要的程辑包…