pytorch中模型训练的学习率动态调整

pytorch动态调整学习率

  • 背景
  • 手动设置自动衰减的学习率
  • pytorch中的torch.optim.lr_scheduler
    • torch.optim.lr_scheduler.ExponentialLR
    • torch.optim.lr_scheduler.StepLR
    • torch.optim.lr_scheduler.MultiStepLR
    • torch.optim.lr_scheduler.ReduceLROnPlateau

背景

  在神经网络模型的训练过程中,一般采取梯度下降法来对模型的参数进行更新,其中,学习率 α \alpha α控制着梯度更新的步长(step), α \alpha α越大,意味着下降的越快,到达最优点的速度也越快。学习率较大时,会加速学习,使得模型更容易接近局部或全局最优解。但是在后期会有较大波动,始终难以达到最优。
  因此,我们引入学习率衰减的概念,就是在模型训练初期,使用较大的学习率进行优化,随着迭代次数增加,学习率会逐渐进行减小,保证模型在训练后期不会有太大的波动,从而更加接近最优解,那么,在pytorch中,学习率衰减应该如何实现?

手动设置自动衰减的学习率

  根据进行的epoch的数量,在每一轮对优化器的学习率进行更新。

def adjust_learning_rate(optimizer, epoch, start_lr):#每三个epoch衰减一次lr = start_lr * (0.1 ** (epoch // 3))for param_group in optimizer.param_groups:param_group['lr'] = lr

  这种方法根据自己的逻辑和epoch的数量对学习率进行调整,使用举例:

optimizer = torch.optim.SGD(net.parameters(),lr = start_lr)
for epoch in range(100):#手动调整学习率adjust_learning_rate(optimizer,epoch,start_lr)#查看每一轮的学习率情况print("Epoch:{}  Lr:{:.2E}".format(epoch,optimizer.state_dict()['param_groups'][0]['lr']))for data,label in traindataloader :output = net(data)loss = myloss(output,label)optimizer.zero_grad()loss.backward()optimizer.step()

pytorch中的torch.optim.lr_scheduler

  torch.optim.lr_scheduler是pytorch提供的自动调整学习率的方法,基于当前epoch的数值,封装了几种相应的动态学习率调整方法,官方文档optim.lr_scheduler。需要注意的是这种方法对学习率的调整需要应用在优化器参数更新之后,应用方法示例:

optimizer = torch.optim.XXXXXXX()#具体optimizer的初始化
scheduler = torch.optim.lr_scheduler.XXXXXXXXXX()#具体学习率变更策略的初始化
for i in range(epoch):for data,label in dataloader:out = net(data)output_loss = loss(out,label)optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()

  下面我们介绍其中几种常用的学习率更新策略。

torch.optim.lr_scheduler.ExponentialLR

  torch.optim.lr_scheduler.ExponentialLR是最简单学习率调整方法,即每一次epoch,lr都乘gamma:

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False)

  其中,optimizer(optimizer):之前定义好的需要优化的优化器的实例名;gamma(float):学习率衰减的乘法因子,默认为0.1,即每次将学习率乘以0.1;last_epoch(int):默认为-1,为-1时表示将人为设置的学习率设定为调整学习率的基础值lr;verbose:如果为True,每一次更新都会打印一个标准的输出信息,默认为False。

torch.optim.lr_scheduler.StepLR

  torch.optim.lr_scheduler.StepLR是比较常用的等间隔动态调整方法,每经过step_size个epoch,做一次学习率衰减,以gamma值为缩小倍数:

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

  相比于ExponentialLR方法,多了一个step_size(int)参数,即学习率衰减的周期,每经过step_size 个epoch,做一次学习率衰减。

torch.optim.lr_scheduler.MultiStepLR

  torch.optim.lr_scheduler.StepLR根据自己设定的训练阶段调整学习率的方法,一旦达到某一阶段(milestones)时,就可以通过gamma系数降低每个参数组的学习率。可以按照milestones列表中给定的值,进行分阶段式调整学习率:

torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)

  相比于ExponentialLR方法,多了一个milestones(list)参数,这是一个关于epoch数值的list,表示在达到哪个epoch范围内开始变化,必须是升序排列,使用例子:

optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
#在第2,6,15个epoch调整学习率
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,6,15], gamma=0.1)

torch.optim.lr_scheduler.ReduceLROnPlateau

  与上述几种基于epoch数目调整学习率的方法不同,该方法根据验证指标的变化的调整学习率。它的原理是:当指标停止改善时,降低学习率。当模型的学习停滞时,训练过程通常会受益于将学习率降低2~10倍。该种调整方法读取一个度量指标,如果在“耐心”期间内没有发现它有所改善,那么就会降低学习率:

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode= 'rel', cooldown=0, min_1r=0, eps=1e-08)

  其中,optimizer(Optimizer):之前定义好的需要优化的优化器的实例名;mode(str):设置为min或max。当选择min时,代表当度量指标停止下降时,开始减小学习率;当选择max时,代表当度量指标停止上升时,开始减小学习率;factor(float):学习率调整的乘法因子,默认值为0.1;patience(int):可容忍的度量指标没有提升的epoch数目,默认为10。举例说明,当其设置为10时,我们可以容忍10个epoch内没有提升,如果在第11个epoch依然没有提升,那么就开始降低学习率;verbose(bool):如果设置为True,输出每一次更新的信息,默认为False;threshold(float):float类型数据,衡量新的最佳阈值,仅关注重大变化,默认为0.0001;threshold_mode(str):可选str字符串数据,为rel或abs,默认为rel。在rel模式下,如果mode参数为max,则动态阈值(dynamic_threshold)为best*(1+threshold),如果mode参数为min,则动态阈值为best+threshold,如果mode参数为min,则动态阈值为best-threshold;cooldown(int):减少lr后恢复正常操作之前要等待的epoch数,默认为0;min_lr(float):学习率的下界,默认为0;eps(float):学习率的最小变化值。如果调整后的学习率和调整前的差距小于eps的话,那么就不做任何调整,默认为1e-8。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/1081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

项目实践:贪吃蛇

引言 贪吃蛇作为一项经典的游戏,想必大家应该玩过。贪吃蛇所涉及的知识也不是很难,涉及到一些C语言函数、枚举、结构体、动态内存管理、预处理指令、链表、Win32 API等。这里我会介绍贪吃蛇的一些思路。以及源代码也会给大家放到文章末尾。 我们最终的…

优雅的最大公约数函数

记录一个极其优雅的最大公约数方法 // 递归形式 int gcd(int a, int b) {return b 0 ? a : gcd(b, a % b); }这里求最大公约数的方法使用了辗转相除法,只是比循环求最大公约数的方法更加优雅与简洁: // 迭代形式 int gcd(int a, int b) {while(b ! 0…

电大搜题微信公众号:福建开放大学学子的学习新篇章

在当今信息化时代,学习已经成为每个人不可或缺的一部分。福建开放大学,作为广播电视大学的重要一员,始终致力于为学生提供优质、灵活的教育资源。而电大搜题微信公众号的推出,更是为福建开放大学的学子们带来了全新的学习体验&…

【数学】常用等价无穷小及其注意事项示例

常用极限 lim ⁡ x → 0 sin ⁡ x x 1 \lim_{x \to 0} {\frac{\sin x}{x}}1 limx→0​xsinx​1 lim ⁡ x → 0 ( x 1 ) 1 x e \lim_{x \to 0} {(x1)^\frac{1}{x}}e limx→0​(x1)x1​e lim ⁡ n → ∞ a n 1 \lim_{n \to \infty} {\sqrt[n]{a}}1 limn→∞​na ​1 lim ⁡ n…

数组中两个字符串的最短距离---一题多解(贪心/二分)

点击跳转到题目 方法:贪心 / 二分 目录 贪心: 二分: 贪心: 要找出字符串数组中指定两个字符串的最小距离,即找出指定字符串对应下标之差的最小值 思考:如果是直接暴力求解,需要两层for循环…

VLOOKUP函数使用,为什么会报错“引用有问题”?

VLOOKUP函数的使用非常广泛,在excel2007版之后的软件中,使用VLOOKUP函数也许会遇到这样的场景,明明公式是没有问题的,公式还会报错“引用有问题”。 一、报错场景 输入公式后,回车确认,显示如下报错&…

xilinx cpri ip 开发记录

CPRI是无线通信里的一个标准协议,连接REC和RE的通信。 Xilinx有提供CPRI IP核。 区别于其它通信协议,如以太网等,CPRI是一个同步系统。 这就意味着两端的Master和Slave应当是同源时钟的,两边不存在频差,并且内部延时…

mysql 行锁,间隙锁,临键锁,锁范围和死锁实际例子实战

文章目录 背景锁介绍表默认数据测试唯一键记录存在事务1事务2结论 唯一键记录不存在事务1事务2结论 范围查询事务1事务2结论 普通索引存在事务1事务2总结 普通索引不存在事务A事务B结论 死锁例子 背景 想了解下RR事务如何防止幻读的,以及一个实际的死锁例子 锁介绍…

【计算机网络】面经

1.TCP&UDP 1.1TCP与UDP的区别 TCP传输数据稳定可靠,适用于对网络通信质量要求较高的场景。 面向连接。 每一条TCP有且只有两个端点,为一对一关系。 提供可靠交付。 全双工通信,全双工为即可传输又可接收。 面向字节流。 UDP的优点是速…

客户端动态降级系统

本文字数:4576字 预计阅读时间:20分钟 01 背景 无论是iOS还是Android系统的设备,在线上运行时受硬件、网络环境、代码质量等多方面因素影响,可能会导致性能问题,这一类问题有些在开发阶段是发现不了的。如何在线上始终…

微服务架构中的业务完整性验证设计

目录 1.概要设计 1.1 功能完整性与正确性验证 1.2 性能与响应速度验证 1.3 安全性验证 1.4 容错性与恢复能力验证 1.5 监控与日志记录验证 2.技术实现 2.1 测试策略与工具选择 2.2 身份验证与授权 2.3 数据一致性与事务管理 2.4 监控与日志 2.5 容错与恢复 2.6 数…

【linux kernel】 一文总结linux内核中的kobject、kset和ktype

文章目录 一、kobject、kset、ktype(1-1)kobject(1-2)ktype(1-3)kset 二、kobject操作API(2-1)kobject_init()(2-2)kobject_add()(2-3&#xff09…

【命名空间详解】c++入门

目录 命名空间的定义 1.命名空间的正常定义 2.命名空间还可以嵌套 3. 命名空间可以合并 命名空间的使用 1.加命名空间名称及作用域限定符 2.使用using将命名空间中某个成员引入 3.使用using namespace 命名空间名称 引入 输入,输出 输出 命名空间的定义 …

linux命令ar使用说明

ar 建立或修改备存文件,或是从备存文件中抽取文件 补充说明 ar命令 是一个建立或修改备存文件,或是从备存文件中抽取文件的工具,ar可让您集合许多文件,成为单一的备存文件。在备存文件中,所有成员文件皆保有原来的属…

Java技术学习|Git

学习材料声明 尚硅谷Git入门到精通全套教程(涵盖GitHub\Gitee码云\GitLab) GIt Git 是一个免费的、开源的分布式版本控制系统,可以快速高效地处理从小型到大型的各种项目。Git 易于学习,占地面积小,性能极快。 它具有…

ARM_day8:基于iic总线的通信

一、IIC总线的基本概念: iic总线是一种带应答的同步的、串行、半双工的通信方式,支持一个主机对应多个从机。它有一根SCL(时钟线)和一根SDA(数据线)组成,由于只有一根数据线,所以它是…

英伟达大跳水!一夜暴跌10%,市值蒸发2000亿

相信大家已经在各大社交平台上看到了,英伟达一夜蒸发了2000亿美元! GPT-3.5研究测试: https://hujiaoai.cn GPT-4研究测试: https://higpt4.cn Claude-3研究测试(全面吊打GPT-4): https://hic…

大语言模型隐私防泄漏:差分隐私、参数高效化

大语言模型隐私防泄漏:差分隐私、参数高效化 写在最前面题目6:大语言模型隐私防泄漏Differentially Private Fine-tuning of Language Models其他初步和之前的基线微调模型1微调模型2通过低秩自适应进行微调( 实例化元框架1) 在隐…

Vue2 —— 学习(九)

目录 一、全局事件总线 (一)全局总线介绍 关系图 对图中的中间商 x 的要求 1.所有组件都能看到 2.有 $on $off $emit (二)案例 发送方 student 接收方 二、消息订阅和发布 (一)介绍 &#xff08…

虚拟机中的打印机,无法打印内容,打印的是白纸或英文和数字,打印不了中文

原因:打印机驱动设置不正确 解决方案: 打开打印机属性 -> 高级 -> 新驱动程序 下一页 -> Windows 更新 耐心等待,时间较长。 选择和打印机型号匹配的驱动,我选择的是: 虽然虚拟机和主机使用的驱动不…