【Python】学习率调整策略详解和示例

学习率调整得当将有助于算法快速收敛和获取全局最优,以获得更好的性能。本文对学习率调度器进行示例介绍。

  • 学习率调整的意义
  • 基础示例
    • 无学习率调整方法
    • 学习率调整方法一
    • 多因子调度器
    • 余弦调度器
  • 结论

学习率调整的意义

首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果(陷入局部最优)。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 简而言之,我们希望速率衰减,但要比慢,这样能成为解决凸问题的不错选择

另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。本文将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

基础示例

我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。

无学习率调整方法

import math
import torch
from torch import nn
from torch.optim import lr_scheduler, SGD
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as pltdef load_data_fashion_mnist(batch_size):# 定义数据预处理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 加载训练集和测试集train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader
def net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return modeldef train(net, train_loader, test_loader, num_epochs, loss, optimizer, device, scheduler=None):net.to(device)running_loss = 0.0train_losses = []test_losses = []test_accuracies = []for epoch in range(num_epochs):for i, (inputs, labels) in enumerate(train_loader):inputs, labels = inputs.to(device), labels.to(device)# Zero the parameter gradientsoptimizer.zero_grad()# Forward passoutputs = net(inputs)loss_value = loss(outputs, labels)# Backward and optimizeloss_value.backward()optimizer.step()# Print statisticsrunning_loss += loss_value.item()# if i % 200 == 199:  # print every 200 mini-batches#     print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200}')#     running_loss = 0.0train_losses.append(running_loss / len(train_loader))# Evaluate the model on the test datasettest_loss, test_acc = evaluate(net, test_loader, device)test_losses.append(test_loss)test_accuracies.append(test_acc)print(f'Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Test Acc: {test_accuracies[-1]:.2f}')if scheduler:if scheduler.__module__ == lr_scheduler.__name__:scheduler.step()else:for param_group in  optimizer.param_groups:param_group['lr'] = scheduler(epoch)plt.figure(figsize=(10, 6))plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')plt.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')plt.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')plt.title('Training, Test Losses and Test Accuracy')plt.xlabel('Epoch')plt.ylabel('Loss / Accuracy')plt.legend()plt.grid(True)plt.savefig("1.jpg")plt.show()def evaluate(model, data_loader, device):model.eval()test_loss = 0correct = 0with torch.no_grad():for inputs, labels in data_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)test_loss += nn.CrossEntropyLoss(reduction='sum')(outputs, labels).item()_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()test_loss /= len(data_loader.dataset)accuracy = correct / len(data_loader.dataset)#accuracy = 100. * correct / len(data_loader.dataset)return test_loss, accuracy# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Define the model
model = net_fn()# Define the loss function
loss = nn.CrossEntropyLoss()# Define the optimizer
lr=0.3
optimizer = SGD(model.parameters(), lr=lr)# Load the dataset
batch_size=128
train_loader, test_loader=load_data_fashion_mnist(batch_size)
num_epochs=30
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device)

这里没有使用学习率调整策略。训练过程和结果如下图所示:

.
.
.
.
Epoch 23, Train Loss: 0.1247, Test Loss: 0.3939, Test Acc: 0.90
Epoch 24, Train Loss: 0.1236, Test Loss: 0.4370, Test Acc: 0.89
Epoch 25, Train Loss: 0.1167, Test Loss: 0.4117, Test Acc: 0.89
Epoch 26, Train Loss: 0.1169, Test Loss: 0.4440, Test Acc: 0.89
Epoch 27, Train Loss: 0.1163, Test Loss: 0.4336, Test Acc: 0.89
Epoch 28, Train Loss: 0.1055, Test Loss: 0.4312, Test Acc: 0.90
Epoch 29, Train Loss: 0.1065, Test Loss: 0.4942, Test Acc: 0.89
Epoch 30, Train Loss: 0.1051, Test Loss: 0.4763, Test Acc: 0.89

在这里插入图片描述

学习率调整方法一

设置在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。 例如,以动态的方式来响应优化的进展情况。

在代码最后添加SquareRootScheduler类,并更新train()函数参数,其它内容不变。

class SquareRootScheduler:def __init__(self, lr=0.1):self.lr = lrdef __call__(self, num_update):return self.lr * pow(num_update + 1.0, -0.5)scheduler = SquareRootScheduler(lr=0.1)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行代码,可得相应参数值和变化过程,如下所示。

Epoch 23, Train Loss: 0.1823, Test Loss: 0.2811, Test Acc: 0.90
Epoch 24, Train Loss: 0.1801, Test Loss: 0.2800, Test Acc: 0.90
Epoch 25, Train Loss: 0.1767, Test Loss: 0.2819, Test Acc: 0.90
Epoch 26, Train Loss: 0.1747, Test Loss: 0.2800, Test Acc: 0.91
Epoch 27, Train Loss: 0.1720, Test Loss: 0.2818, Test Acc: 0.90
Epoch 28, Train Loss: 0.1689, Test Loss: 0.2856, Test Acc: 0.90
Epoch 29, Train Loss: 0.1669, Test Loss: 0.2907, Test Acc: 0.90
Epoch 30, Train Loss: 0.1641, Test Loss: 0.2813, Test Acc: 0.90

在这里插入图片描述
我们可以看出曲线比没有策略时平滑了很多,效果有所提升。

多因子调度器

多因子调度器。
在这里插入图片描述
在这里插入图片描述
代码部分修改:

scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)

运行结果为:
在这里插入图片描述
可见效果不理想,出现过拟合现象。

余弦调度器

余弦调度器是 (Loshchilov and Hutter, 2016)提出的一种启发式算法。 它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。 这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在
之间。
在这里插入图片描述
代码中添加CosineScheduler类和修改scheduler。

class CosineScheduler:def __init__(self, max_update, base_lr=0.01, final_lr=0,warmup_steps=0, warmup_begin_lr=0):self.base_lr_orig = base_lrself.max_update = max_updateself.final_lr = final_lrself.warmup_steps = warmup_stepsself.warmup_begin_lr = warmup_begin_lrself.max_steps = self.max_update - self.warmup_stepsdef get_warmup_lr(self, epoch):increase = (self.base_lr_orig - self.warmup_begin_lr) \* float(epoch) / float(self.warmup_steps)return self.warmup_begin_lr + increasedef __call__(self, epoch):if epoch < self.warmup_steps:return self.get_warmup_lr(epoch)if epoch <= self.max_update:self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.max_steps)) / 2return self.base_lr#scheduler = SquareRootScheduler(lr=0.1)
#scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)
scheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行结果如下:

在这里插入图片描述
过拟合现象消失,效果提升。

结论

在开发时应根据自己需要,选择合适的学习率调整策略。优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。

注:部分内容摘选子书籍《动手学深度学习》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/773848.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java 面向对象入门

类的创建 右键点击对应的包&#xff0c;点击新建选择java类 填写名称一般是名词&#xff0c;要知道大概是什么的名称&#xff0c;首字母一般大写 下面是创建了一个Goods类&#xff0c;里面的成员变量有&#xff1a;1.编号&#xff08;id&#xff09;&#xff0c;2.名称&#x…

Android 性能优化(六):启动优化的详细流程

书接上文&#xff0c;Android 性能优化&#xff08;一&#xff09;&#xff1a;闪退、卡顿、耗电、APK 从用户体验角度有四个性能优化方向&#xff1a; 追求稳定&#xff0c;防止崩溃追求流畅&#xff0c;防止卡顿追求续航&#xff0c;防止耗损追求精简&#xff0c;防止臃肿 …

【IT之家】IT之家网站的资讯文章资源,实时数据抓取检索软件免费下载NO.65

简介&#xff1a;IT之家是业内领先的IT资讯和数码产品类网站。IT之家快速精选泛科技新闻&#xff0c;分享即时的IT业界动态和紧跟潮流的数码产品资讯&#xff0c;提供给力的PC和手机技术文章、丰富的系统应用美化资源&#xff0c;以及享不尽的智能阅读。 本软件基于C#实现的win…

苹果 WWDC 24 将举行;高通、谷歌、英特尔等联合开发 AI 软件;艺术家谈及使用 Sora 创作视频体验

▶ 苹果WWDC 24 将于当地时间 6 月 10 日召开 3 月 27 日凌晨&#xff0c;苹果官宣将于当地时间 6 月 10 日举行今年的全球开发者发布大会。 苹果全球营销高级副总裁 Greg Joswiak 在社交媒体上表示&#xff1a;「在您的日历标记上 WWDC24 吧。这场活动无疑会令人惊喜&#xf…

数字化转型核心:实现业务与技术深度融合的运维数字化管理之道

写在前面 数字化转型已经成为大势所趋&#xff0c;各行各业正朝着数字化方向转型&#xff0c;利用数字化转型方法论和前沿科学技术实现降本、提质、增效&#xff0c;从而提升竞争力。 数字化转型是一项长期工作&#xff0c;包含的要素非常丰富&#xff0c;如数字化转型顶层设…

Spring:面试八股

文章目录 参考Spring模块CoreContainerAOP 参考 JavaGuide Spring模块 CoreContainer Spring框架的核心模块&#xff0c;主要提供IoC依赖注入功能的支持。内含四个子模块&#xff1a; Core&#xff1a;基本的核心工具类。Beans&#xff1a;提供对bean的创建、配置、管理功能…

同城双活:交易链路的稳定性与可靠性探索

知易行难&#xff0c;双活过程中遇到了非常多的问题&#xff0c;但是回过头看很难完美的表述出来&#xff0c;之所以这么久才行文也是这个原因&#xff0c;总是希望可以尽可能的复现当时的思考、问题细节及解决方案&#xff0c;但是写出来才发现能给出的都是多次打磨、摸索之后…

阿里云安装宝塔后面板打不开

前言 按理来说装个宝塔面板应该很轻松的&#xff0c;我却装了2天&#xff0c;真挺恼火的&#xff0c;网上搜的教程基本上解决不掉我的问题点&#xff0c;问了阿里云和宝塔客服&#xff0c;弄了将近2天&#xff0c;才找出问题出在哪里&#xff0c;在此记录一下问题的处理。 服…

MySQL索引优化二

分页查询优化 很多时候我们的业务系统实现分页功能可能会用如下sql实现 select * from employees limit 10000,10;表示从表employees中取出从10001行开始的10条记录.看似只查询了10条记录,实际这条sql是先读取10010条记录,然后抛弃前10000条记录,然后读到后面10条想要的数据,…

Pillow教程07:调整图片的亮度+对比度+色彩+锐度

---------------Pillow教程集合--------------- Python项目18&#xff1a;使用Pillow模块&#xff0c;随机生成4位数的图片验证码 Python教程93&#xff1a;初识Pillow模块&#xff08;创建Image对象查看属性图片的保存与缩放&#xff09; Pillow教程02&#xff1a;图片的裁…

JVM本地方法

本地方法接口 NAtive Method就是一个java调用非java代码的接口 本地方法栈&#xff08;Native Method Statck&#xff09; Java虚拟机栈用于管理Java方法的调用&#xff0c;而本地方法栈用于管理本地方法的调用。 本地方法栈&#xff0c;也是线程私有的。 允许被实现成固定或…

【机器学习之---数学】统计学基础概念

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 统计学基础 1. 频率派 频率学派&#xff08;传统学派&#xff09;认为样本信息来自总体&#xff0c;通过对样本信息的研究可以合理地推断和估计总体信息…

Transformer 模型中增加一个 Token 对计算量的影响

Transformer 模型中增加一个 Token 对计算量的影响 Transformer 模型中增加一个 Token 对计算量的影响1. Transformer 模型简介2. Token 对计算量的影响3. 增加一个 Token 的计算量估算4. 应对策略5. 结论 Transformer 模型中增加一个 Token 对计算量的影响 Transformer 模型作…

【无标题】C高级325

练习1&#xff1a;输入一个数&#xff0c;实现倒叙123-》321 练习2&#xff1a;输入一个&#xff0c;判断是否是素数 练习3&#xff1a;输入一个文件名&#xff0c; 判断是否在家目录下存在, 如果是一个目录&#xff0c;则直接输出是目录下的sh文件的个数 如果存在则判断是否是…

ELF 1技术贴|应用层更改引脚复用的方法

在嵌入式系统设计中&#xff0c;引脚复用功能通常是通过设备树(Device Tree)预先配置设定的。出厂的设备树中UART2_TX_DATA和UART2_RX_DATA两个引脚被复用成了UART2功能&#xff0c;如果想要在不更换系统镜像的情况下&#xff0c;将这两个引脚的功能转换为GPIO&#xff0c;并作…

深入探讨iOS开发:从创建第一个iOS程序到纯代码实现全面解析

iOS开发作为移动应用开发的重要领域之一&#xff0c;对于开发人员具有重要意义。本文将深入探讨iOS开发的各个方面&#xff0c;从创建第一个iOS程序到纯代码实现iOS开发&#xff0c;带领读者全面了解iOS应用程序的开发流程和技术要点。 &#x1f4f1; 第一个iOS程序 在创建第…

基于springboot+vue+Mysql的超市进销存系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

工地污水处理一体化成套设备如何选型

工地污水处理一体化成套设备的选型是确保工地污水处理效果的关键。在选择合适的设备前&#xff0c;我们需要考虑几个重要因素。 首先&#xff0c;我们需要评估工地的实际污水处理需求。包括污水产生量、水质特征、处理要求等。通过了解工地的情况&#xff0c;我们能够确定适合处…

《探索移动开发的未来之路》

移动开发作为当今科技领域中最为炙手可热的领域之一&#xff0c;正以惊人的速度不断迭代和发展。从技术进展到应用案例&#xff0c;再到面临的挑战与机遇以及未来的趋势&#xff0c;移动开发都呈现出了令人瞩目的发展前景。本文将围绕移动开发的技术进展、行业应用案例、面临的…

服务运营 | 印第安纳大学翟成成:改变生活的水井选址

编者按&#xff1a; 作者于2023年4月在“Production and Operations Management”上发表的“Improving drinking water access and equity in rural Sub-Saharan Africa”探讨了欠发达地区水资源供应中的可达性和公平性问题。作者于2020年1月去往非洲埃塞俄比亚提格雷地区进行…