【深度学习笔记】优化算法——学习率调度器

学习率调度器

🏷sec_scheduler

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。
然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要(有关详细信息,请参见 :numref:sec_momentum)。直观地说,这是最不敏感与最敏感方向的变化量的比率。
  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 :numref:sec_minibatch_sgd比较详细地讨论了这一点,在 :numref:sec_sgd中我们则分析了性能保证。简而言之,我们希望速率衰减,但要比 O ( t − 1 2 ) \mathcal{O}(t^{-\frac{1}{2}}) O(t21)慢,这样能成为解决凸问题的不错选择。
  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式(详情请参阅 :numref:sec_numerical_stability),又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。
  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 :cite:Izmailov.Podoprikhin.Garipov.ea.2018来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。
在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

一个简单的问题

我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。
为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。
此外,我们混合网络以提高性能。
由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论。如果需要,请参阅 :numref:chap_cnn进行复习。

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
from d2l import torch as d2ldef net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return modelloss = nn.CrossEntropyLoss()
device = d2l.try_gpu()batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同
def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler=None):net.to(device)animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):metric = d2l.Accumulator(3)  # train_loss,train_acc,num_examplesfor i, (X, y) in enumerate(train_iter):net.train()trainer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()trainer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])train_loss = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % 50 == 0:animator.add(epoch + i / len(train_iter),(train_loss, train_acc, None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch+1, (None, None, test_acc))if scheduler:if scheduler.__module__ == lr_scheduler.__name__:# UsingPyTorchIn-Builtschedulerscheduler.step()else:# Usingcustomdefinedschedulerfor param_group in trainer.param_groups:param_group['lr'] = scheduler(epoch)print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')

让我们来看看如果使用默认设置,调用此算法会发生什么。
例如设学习率为 0.3 0.3 0.3并训练 30 30 30次迭代。
留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。
两条曲线之间的间隙表示过拟合。

lr, num_epochs = 0.3, 30
net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device)
train loss 0.128, train acc 0.951, test acc 0.885

在这里插入图片描述

学习率调度器

我们可以在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。
例如,以动态的方式来响应优化的进展情况。

lr = 0.1
trainer.param_groups[0]["lr"] = lr
print(f'learning rate is now {trainer.param_groups[0]["lr"]:.2f}')
learning rate is now 0.10

更通常而言,我们应该定义一个调度器。
当调用更新次数时,它将返回学习率的适当值。
让我们定义一个简单的方法,将学习率设置为 η = η 0 ( t + 1 ) − 1 2 \eta = \eta_0 (t + 1)^{-\frac{1}{2}} η=η0(t+1)21

class SquareRootScheduler:def __init__(self, lr=0.1):self.lr = lrdef __call__(self, num_update):return self.lr * pow(num_update + 1.0, -0.5)

让我们在一系列值上绘制它的行为。

scheduler = SquareRootScheduler(lr=0.1)
d2l.plot(torch.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])


在这里插入图片描述

现在让我们来看看这对在Fashion-MNIST数据集上的训练有何影响。
我们只是提供调度器作为训练算法的额外参数。

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler)
train loss 0.270, train acc 0.901, test acc 0.876

在这里插入图片描述

这比以前好一些:曲线比以前更加平滑,并且过拟合更小了。
遗憾的是,关于为什么在理论上某些策略会导致较轻的过拟合,有一些观点认为,较小的步长将导致参数更接近零,因此更简单。
但是,这并不能完全解释这种现象,因为我们并没有真正地提前停止,而只是轻柔地降低了学习率。

策略

虽然我们不可能涵盖所有类型的学习率调度器,但我们会尝试在下面简要概述常用的策略:多项式衰减和分段常数表。
此外,余弦学习率调度在实践中的一些问题上运行效果很好。
在某些问题上,最好在使用较高的学习率之前预热优化器。

单因子调度器

多项式衰减的一种替代方案是乘法衰减,即 η t + 1 ← η t ⋅ α \eta_{t+1} \leftarrow \eta_t \cdot \alpha ηt+1ηtα其中 α ∈ ( 0 , 1 ) \alpha \in (0, 1) α(0,1)
为了防止学习率衰减到一个合理的下界之下,
更新方程经常修改为 η t + 1 ← m a x ( η m i n , η t ⋅ α ) \eta_{t+1} \leftarrow \mathop{\mathrm{max}}(\eta_{\mathrm{min}}, \eta_t \cdot \alpha) ηt+1max(ηmin,ηtα)

class FactorScheduler:def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):self.factor = factorself.stop_factor_lr = stop_factor_lrself.base_lr = base_lrdef __call__(self, num_update):self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)return self.base_lrscheduler = FactorScheduler(factor=0.9, stop_factor_lr=1e-2, base_lr=2.0)
d2l.plot(torch.arange(50), [scheduler(t) for t in range(50)])


在这里插入图片描述

接下来,我们将使用内置的调度器,但在这里仅解释它们的功能。

多因子调度器

训练深度网络的常见策略之一是保持学习率为一组分段的常量,并且不时地按给定的参数对学习率做乘法衰减。
具体地说,给定一组降低学习率的时间点,例如 s = { 5 , 10 , 20 } s = \{5, 10, 20\} s={5,10,20}
每当 t ∈ s t \in s ts时,降低 η t + 1 ← η t ⋅ α \eta_{t+1} \leftarrow \eta_t \cdot \alpha ηt+1ηtα
假设每步中的值减半,我们可以按如下方式实现这一点。

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
scheduler = lr_scheduler.MultiStepLR(trainer, milestones=[15, 30], gamma=0.5)def get_lr(trainer, scheduler):lr = scheduler.get_last_lr()[0]trainer.step()scheduler.step()return lrd2l.plot(torch.arange(num_epochs), [get_lr(trainer, scheduler)for t in range(num_epochs)])


在这里插入图片描述

这种分段恒定学习率调度背后的直觉是,让优化持续进行,直到权重向量的分布达到一个驻点。
此时,我们才将学习率降低,以获得更高质量的代理来达到一个良好的局部最小值。
下面的例子展示了如何使用这种方法产生更好的解决方案。

train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler)
train loss 0.191, train acc 0.928, test acc 0.889

在这里插入图片描述

余弦调度器

余弦调度器是 :cite:Loshchilov.Hutter.2016提出的一种启发式算法。
它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。
这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在 t ∈ [ 0 , T ] t \in [0, T] t[0,T]之间。

η t = η T + η 0 − η T 2 ( 1 + cos ⁡ ( π t / T ) ) \eta_t = \eta_T + \frac{\eta_0 - \eta_T}{2} \left(1 + \cos(\pi t/T)\right) ηt=ηT+2η0ηT(1+cos(πt/T))

这里 η 0 \eta_0 η0是初始学习率, η T \eta_T ηT是当 T T T时的目标学习率。
此外,对于 t > T t > T t>T,我们只需将值固定到 η T \eta_T ηT而不再增加它。
在下面的示例中,我们设置了最大更新步数 T = 20 T = 20 T=20

class CosineScheduler:def __init__(self, max_update, base_lr=0.01, final_lr=0,warmup_steps=0, warmup_begin_lr=0):self.base_lr_orig = base_lrself.max_update = max_updateself.final_lr = final_lrself.warmup_steps = warmup_stepsself.warmup_begin_lr = warmup_begin_lrself.max_steps = self.max_update - self.warmup_stepsdef get_warmup_lr(self, epoch):increase = (self.base_lr_orig - self.warmup_begin_lr) \* float(epoch) / float(self.warmup_steps)return self.warmup_begin_lr + increasedef __call__(self, epoch):if epoch < self.warmup_steps:return self.get_warmup_lr(epoch)if epoch <= self.max_update:self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.max_steps)) / 2return self.base_lrscheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)
d2l.plot(torch.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])


在这里插入图片描述

在计算机视觉的背景下,这个调度方式可能产生改进的结果。
但请注意,如下所示,这种改进并不一定成立。

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=0.3)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler)
train loss 0.207, train acc 0.923, test acc 0.892

在这里插入图片描述

预热

在某些情况下,初始化参数不足以得到良好的解。
这对某些高级网络设计来说尤其棘手,可能导致不稳定的优化结果。
对此,一方面,我们可以选择一个足够小的学习率,
从而防止一开始发散,然而这样进展太缓慢。
另一方面,较高的学习率最初就会导致发散。

解决这种困境的一个相当简单的解决方法是使用预热期,在此期间学习率将增加至初始最大值,然后冷却直到优化过程结束。
为了简单起见,通常使用线性递增。
这引出了如下表所示的时间表。

scheduler = CosineScheduler(20, warmup_steps=5, base_lr=0.3, final_lr=0.01)
d2l.plot(torch.arange(num_epochs), [scheduler(t) for t in range(num_epochs)])


在这里插入图片描述

注意,观察前5个迭代轮数的性能,网络最初收敛得更好。

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=0.3)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler)
train loss 0.261, train acc 0.904, test acc 0.878

在这里插入图片描述

预热可以应用于任何调度器,而不仅仅是余弦。
有关学习率调度的更多实验和更详细讨论,请参阅 :cite:Gotmare.Keskar.Xiong.ea.2018
其中,这篇论文的点睛之笔的发现:预热阶段限制了非常深的网络中参数的发散程度 。
这在直觉上是有道理的:在网络中那些一开始花费最多时间取得进展的部分,随机初始化会产生巨大的发散。

小结

  • 在训练期间逐步降低学习率可以提高准确性,并且减少模型的过拟合。
  • 在实验中,每当进展趋于稳定时就降低学习率,这是很有效的。从本质上说,这可以确保我们有效地收敛到一个适当的解,也只有这样才能通过降低学习率来减小参数的固有方差。
  • 余弦调度器在某些计算机视觉问题中很受欢迎。
  • 优化之前的预热期可以防止发散。
  • 优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/737031.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

cefsharp(winForm)调用js脚本,js脚本调用c#方法

本博文针对js-csharp交互(相互调用的应用) (一)、js调用c#方法 1.1 类名称:cs_js_obj public class cs_js_obj{//注意,js调用C#,不一定在主线程上调用的,需要用SynchronizationContext来切换到主线程//private System.Threading.SynchronizationContext context;//…

Elasticsearch 分享

一、Elasticsearch 基础介绍 ElasticSearch 是分布式实时搜索、实时分析、实时存储引擎&#xff0c;简称&#xff08;ES)&#xff0c; 成立于2012年&#xff0c;是一家来自荷兰的、开源的大数据搜索、分析服务提供商&#xff0c;为企业提供实时搜索、数据分析服务&#xff0c;…

AHU 汇编 实验四

实验名称&#xff1a;实验四 两个数的相乘 实验内容&#xff1a; 用子程序形式编写&#xff1a; A*B&#xff1a;从键盘输入a和b&#xff0c;计算A*B&#xff0c;其中乘法采用移位和累加完成 实验过程&#xff1a; 源代码&#xff1a; data segmentmul1 db 16,?,16 dup(?…

树莓派安装Nginx服务搭建web网站结合内网穿透实现公网访问本地站点

文章目录 1. Nginx安装2. 安装cpolar3.配置域名访问Nginx4. 固定域名访问5. 配置静态站点 安装 Nginx&#xff08;发音为“engine-x”&#xff09;可以将您的树莓派变成一个强大的 Web 服务器&#xff0c;可以用于托管网站或 Web 应用程序。相比其他 Web 服务器&#xff0c;Ngi…

什么是高级编程语言?——跟老吕学Python编程

什么是高级编程语言&#xff1f;——跟老吕学Python编程 高级编程语言简介高级编程语言发展历程高级编程语言特点高级编程语言分类命令式语言函数式语言逻辑式语言面向对象语言 常见的高级编程语言及其特点和应用领域高级编程语言性能分析高级编程语言的工作方式 高级编程语言简…

GPT出现Too many requests in 1 hour. Try again later.

换节点 这个就不用多说了&#xff0c;你都可以上GPT帐号了&#xff0c;哈…… 清除cooki

平面纯弯梁单元Matlab有限元编程 |欧拉梁单元| 简支梁|悬臂梁|弯矩图 |变形图| Matlab源码 | 视频教程

专栏导读 作者简介&#xff1a;工学博士&#xff0c;高级工程师&#xff0c;专注于工业软件算法研究本文已收录于专栏&#xff1a;《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现&#xff0c;并提供所有案例完整源码&#xff1b;2.单元…

容灾演练双月报|美创DRCC助力银行高效验证数据库高可用架构

了解更多灾备行业动态 守护数字化时代业务连续 目录 CONTENTS 01 灾备法规政策 02 热点安全事件 03 容灾演练典型案例 01 灾备法规政策 2月&#xff0c;工信部印发《工业领域数据安全能力提升实施方案&#xff08;2024—2026年&#xff09;》&#xff0c;要求到2026年…

专属你的时尚盛宴,尽在手机无人直播!

时尚&#xff0c;是一个永恒的话题。在这个充满活力的时代&#xff0c;时尚不仅仅是穿着打扮&#xff0c;更是一种生活态度&#xff0c;一种表达自我的方式。每个人都有自己独特的时尚理念&#xff0c;每个人都可以在时尚的世界里找到属于自己的一席之地。 手机无人直播&#…

鼠标在QTreeView、QTableView、QTableWidget项上移动,背景色改变

目录 1. 前言 2. 需求 3. 功能实现 3.1. 代码实现 3.2. 功能讲解 4. 附录 1. 前言 本博文用到了Qt的model/view framework框架,如果对Qt的“模型/视图/委托”框架不懂&#xff0c;本博文很难读懂。如果不懂这方面的知识&#xff0c;请在Qt Assistant 中输入Model/View…

力扣大厂热门面试算法题 15-17

15. 三数之和&#xff0c;16. 最接近的三数之和&#xff0c;17. 电话号码的字母组合&#xff0c;每题做详细思路梳理&#xff0c;配套Python&Java双语代码&#xff0c; 2024.03.11 可通过leetcode所有测试用例。 目录 15. 三数之和 解题思路 完整代码 Java Python ​…

Ubuntu 24.04 抢先体验换国内源 清华源 阿里源 中科大源 163源

Update 240307:Ubuntu 24.04 LTS 进入功能冻结期 预计4月25日正式发布。 Ubuntu22.04换源 Ubuntu 24.04重要升级daily版本下载换源步骤 (阿里源)清华源中科大源网易163源 Ubuntu 24.04 LTS&#xff0c;代号 「Noble Numbat」&#xff0c;即将与我们见面&#xff01; Canonica…

vue provide 与 inject使用

在vue项目中&#xff0c;如果遇到跨组件多层次传值的话&#xff0c;一般会用到vuex&#xff0c;或者其他第三方共享状态管理模式&#xff0c;如pinia等&#xff0c;但是对于父组件与多层次孙子组件时&#xff0c;建议使用provide 与 inject&#xff0c;与之其他方式相比&#x…

如何使用Everything+cpolar实现公网远程搜索下载内网储存文件资料

文章目录 前言1.软件安装完成后&#xff0c;打开Everything2.登录cpolar官网 设置空白数据隧道3.将空白数据隧道与本地Everything软件结合起来总结 前言 要搭建一个在线资料库&#xff0c;我们需要两个软件的支持&#xff0c;分别是cpolar&#xff08;用于搭建内网穿透数据隧道…

高分1、2号卫星原始遥感影像数据

高分一号 高分一号卫高分一号卫星是中国高分辨率对地观测系统的首发星&#xff0c;突破了高空间分辨率、多光谱与宽覆盖相结合的光学遥感等关键技术&#xff0c;设计寿命5至8年。 高分辨率对地观测系统工程是《国家中长期科学和技术发展规划纲要(2006&#xff5e;2020年)》确定…

StarRocks实战——欢聚集团极速的数据分析能力

目录 一、大数据平台架构 二、OLAP选型及改进 三、StarRocks 经验沉淀 3.1 资源隔离&#xff0c;助力业务推广 3.1.1 面临的挑战 3.1.2 整体效果 3.2 稳定优先&#xff0c;监控先行&#xff0c;优化运维 3.3降低门槛&#xff0c;不折腾用户 3.3.1 与现有的平台做打通 …

分库分表浅析原理

数据库存放数据大了&#xff0c;查询等操作就会存在瓶颈&#xff0c;怎么办&#xff1f; 1. 如果是单张表数据大了&#xff0c;可以在原有库上新建几张表table_0、table_1、table_2、.....table_n 写程序对数据进行分表&#xff1a; --这里提供一种一种分表策略,这里只需维护…

容器安全是什么?

容器安全定义 容器安全是指保护容器的完整性。这包括从其保管的应用到其所依赖的基础架构等全部内容。容器安全需要完整且持续。通常而言&#xff0c;企业拥有持续的容器安全涵盖两方面&#xff1a; 保护容器流水线和应用保护容器部署环境和基础架构 如何将安全内置于容器流…

Java开发从入门到精通(一):Java的基础语法项目案例

Java大数据开发和安全开发 Java 案例练习案例一:买飞机票案例二:开发验证码案例三:评委打分案例四:数字加密案例五:数组拷贝案例六: 抢红包案例七:找素数案例八:模拟双色球[拓展案例] Java 案例练习 案例一:买飞机票 分析: 方法是需要接收数据?需要接收机票原价、当前月份、舱…

新手如何快速上手学习单片机?

读者朋友能容我&#xff0c;不使博文负真心 新开专栏&#xff0c;期待与诸君共享精彩 个人主页&#xff1a;17_Kevin-CSDN博客 专栏&#xff1a;《单片机》 学习单片机是一个有趣且有挑战性的过程。单片机是一种微控制器&#xff0c;广泛应用于各种电子设备和嵌入式系统中。在这…