人工智能-优化算法之学习率调度器

学习率调度器

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。

  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。

  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 (Izmailov et al, 2018)来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
from d2l import torch as d2ldef net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return modelloss = nn.CrossEntropyLoss()
device = d2l.try_gpu()batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同
def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler=None):net.to(device)animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):metric = d2l.Accumulator(3)  # train_loss,train_acc,num_examplesfor i, (X, y) in enumerate(train_iter):net.train()trainer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()trainer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])train_loss = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % 50 == 0:animator.add(epoch + i / len(train_iter),(train_loss, train_acc, None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch+1, (None, None, test_acc))if scheduler:if scheduler.__module__ == lr_scheduler.__name__:# UsingPyTorchIn-Builtschedulerscheduler.step()else:# Usingcustomdefinedschedulerfor param_group in trainer.param_groups:param_group['lr'] = scheduler(epoch)print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')

让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为0.3并训练30次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。

lr, num_epochs = 0.3, 30
net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device)

train loss 0.128, train acc 0.951, test acc 0.885

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/189348.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Todesk 无法登录,无法联网

前言 我习惯用todesk远程ubuntu,但是突然发现掉线了,但是ssh还能连接 问题查找 1.ping 一下主机ip 2.ssh连接后,ping 一下百度,查看是否外网正常 3.输入一下命令 ps -ef | grep todesk #查看todesk 进程 sudo kill -9 ....…

快速掌握Pyqt5的20种输入控件(Input Widgets)

Pyqt5相关文章: 快速掌握Pyqt5的三种主窗口 快速掌握Pyqt5的2种弹簧 快速掌握Pyqt5的5种布局 快速弄懂Pyqt5的5种项目视图(Item View) 快速弄懂Pyqt5的4种项目部件(Item Widget) 快速掌握Pyqt5的6种按钮 快速掌握Pyqt5的10种容器&…

HTML5 的全局属性 hidden 和 display:none 的关系

目录 1,hidden 和 display:none 的关系2,其他隐藏元素的方式2.1,语意上的隐藏2.2,视觉上的隐藏 1,hidden 和 display:none 的关系 hidden - MDN 参考 一句话总结:hidden 是HTML5 新增的全局布尔属性&…

Centos7使用阿里云镜像加速服务安装Docker

文章目录 一、前提说明二、安装docker1、创建docker文件夹2、安装所需的软件包3、设置Docker仓库4、安装docker5、启动验证使用阿里云镜像加速服务 三、卸载docker 一、前提说明 需要先安装好虚拟机,可以查看这篇https://blog.csdn.net/qq_36433289/article/detail…

Python批处理PDF文件,PDF附件轻松批量提取

PDF附件是指在PDF文档中嵌入的其他文件,如图像、表格、音频、视频或其他文档。这些附件可以与PDF文档一起存储、传输和共享,为文档提供了更丰富的内容和更多的功能。通过添加附件,我们可以将相关文件和信息捆绑在一起,使其更易于管…

Verilog 入门(五)数据流模型化

文章目录 连续赋值语句时延 连续赋值用于数据流行为建模;相反,过程赋值用于顺序行为建模。组合逻辑电路的行为最好使用连续赋值语句建模。 连续赋值语句 连续赋值语句将值赋给线网(连续赋值不能为寄存器赋值),它的格式…

Linux 只能收到 SYN 包 不能回包

如果用户发现云主机不能登录,例如无法远程 22 端口或其他端口,但是更换网络环境正常,服务端抓包发现客户端发包只有 SYN,没有回包,可以执行 netstat -s |grep rejec 查看下是否是 tcp_timestamps 的问题 [roothfgo2 ~…

Java的53个关键字分类及详细说明(包含3个特殊直接量+2个保留字)

文章目录 关键字,特殊直接量,保留字关键字的详细用法说明(1)访问控制类关键字(2)修饰符类关键字(3)程序控制类关键字(4)错误处理类关键字(5)包相关…

Python+Requests模拟发送GET请求

模拟发送GET请求 前置条件:导入requests库 一、发送不带参数的get请求 代码如下: 以百度首页为例 import requests# 发送get请求 response requests.get(url"http://www.baidu.com") print(response.content.decode("utf-8"))…

Drift plus penalty 漂移加惩罚Part2——性能分析

文章目录 正文Performance analysisAverage penalty analysis 平均惩罚分析Average queue size analysis 平均队列大小分析Probability 1 convergenceApplication to queues with finite capacityTreatment of queueing systemsConvex functions of time averages Delay tradeo…

SSR是什么?Vue中怎么实现?

一、是什么 概念 SSR是指服务器端渲染(Server-Side Rendering),是一种将客户端和服务器端合并的 Web 应用程序渲染技术。在 SSR 中,应用程序的 UI 在服务器端渲染完成后,再将整个渲染好的 HTML、CSS 和 JavaScript 发…

西南科技大学信号与系统A课程设计报告(信号卷积与应用)

一、设计任务 编制MATLAB程序,实现任意两输入信号的卷积和运算,并正确显示输入信号、输出信号的波形。(程序文件名:ConvSum_本人姓名首拼小写.m,本程序不允许使用conv函数)输入信号x为频率0.1 Hz和0.3 Hz的等幅正弦波之和,利用fir1函数设计滤波器h,去除0.3 Hz的正弦信号…

Vue 3.0 组合式API Setup

文章目录 前言参数Props上下文访问组件的 property结合模板使用使用渲染函数使用 this后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:vue.js 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不…

使用WalletConnect Web3Modal v3 链接钱包基础教程

我使用的是vueethers 官方文档:WalletConnect 1.安装 yarn add web3modal/ethers ethers 或者 npm install web3modal/ethers ethers2.引用 新建一个js文件,在main.js中引入,初始化配置sdk import {createWeb3Modal,defaultConfig, } from…

CMMI认证含金量高吗

一、CMMI认证含金量解答 CMMI,即能力成熟度模型集成,是由美国卡内基梅隆大学软件工程研究所开发的一种评估企业软件开发过程成熟度的模型。CMMI认证的含金量究竟高不高呢?答案是肯定的。CMMI认证被誉为软件开发行业的“金牌标准”&#xff0…

【算法】合并K个升序链表

这道题主要考察的是归并排序,因为已经升序过了,更好理解了。 当然也可以采用分治的思路;或采用最小堆的思路;面试中校招同学写出一种即可,如果能全概览讲一下,就更加分了。 #############################…

力扣题:字符的统计-12.2

力扣题-12.2 [力扣刷题攻略] Re:从零开始的力扣刷题生活 力扣题1:423. 从英文中重建数字 解题思想:有的单词通过一个字母就可以确定,依次确定即可 class Solution(object):def originalDigits(self, s):""":typ…

okhttp系列-拦截器的执行顺序

1.将拦截器添加到ArrayList final class RealCall implements Call {Response getResponseWithInterceptorChain() throws IOException {//将Interceptor添加到ArrayListList<Interceptor> interceptors new ArrayList<>();interceptors.addAll(client.intercept…

03-IDEA集成Git,初始化本地库,添加远程仓库,提交,拉取,推送,分支的快捷操作

IDEA集成Git 创建Git忽略文件 不同的IDE开发工具有不同的特点文件,这些文件与项目的实际功能无关且不参与服务器上的部署运行, 把它们忽略掉能够屏蔽之间的差异 局部忽略配置文件: 在本地仓库的根目录即项目根目录下直接创建.gitignore文件, 以文件后缀或目录名的方式忽略指定…

双远心镜头:让视觉检测更精准、高效!

工业镜头是视觉系统中的重要组件&#xff0c;工业镜头的选型影响着整个系统的成像效果。在做视觉检测时&#xff0c;会遇到无法检测空间物体、无法控制视场变化、无法控制图像扭曲、对比度低、畸变大、反光等问题&#xff0c;这时普通的工业镜头并不能有效地解决问题&#xff0…