人工智能-优化算法之学习率调度器

学习率调度器

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。

  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。

  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 (Izmailov et al, 2018)来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
from d2l import torch as d2ldef net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return modelloss = nn.CrossEntropyLoss()
device = d2l.try_gpu()batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同
def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler=None):net.to(device)animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):metric = d2l.Accumulator(3)  # train_loss,train_acc,num_examplesfor i, (X, y) in enumerate(train_iter):net.train()trainer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()trainer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])train_loss = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % 50 == 0:animator.add(epoch + i / len(train_iter),(train_loss, train_acc, None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch+1, (None, None, test_acc))if scheduler:if scheduler.__module__ == lr_scheduler.__name__:# UsingPyTorchIn-Builtschedulerscheduler.step()else:# Usingcustomdefinedschedulerfor param_group in trainer.param_groups:param_group['lr'] = scheduler(epoch)print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')

让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为0.3并训练30次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。

lr, num_epochs = 0.3, 30
net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device)

train loss 0.128, train acc 0.951, test acc 0.885

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/189348.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Todesk 无法登录,无法联网

前言 我习惯用todesk远程ubuntu,但是突然发现掉线了,但是ssh还能连接 问题查找 1.ping 一下主机ip 2.ssh连接后,ping 一下百度,查看是否外网正常 3.输入一下命令 ps -ef | grep todesk #查看todesk 进程 sudo kill -9 ....…

HTML5 的全局属性 hidden 和 display:none 的关系

目录 1,hidden 和 display:none 的关系2,其他隐藏元素的方式2.1,语意上的隐藏2.2,视觉上的隐藏 1,hidden 和 display:none 的关系 hidden - MDN 参考 一句话总结:hidden 是HTML5 新增的全局布尔属性&…

Centos7使用阿里云镜像加速服务安装Docker

文章目录 一、前提说明二、安装docker1、创建docker文件夹2、安装所需的软件包3、设置Docker仓库4、安装docker5、启动验证使用阿里云镜像加速服务 三、卸载docker 一、前提说明 需要先安装好虚拟机,可以查看这篇https://blog.csdn.net/qq_36433289/article/detail…

Python批处理PDF文件,PDF附件轻松批量提取

PDF附件是指在PDF文档中嵌入的其他文件,如图像、表格、音频、视频或其他文档。这些附件可以与PDF文档一起存储、传输和共享,为文档提供了更丰富的内容和更多的功能。通过添加附件,我们可以将相关文件和信息捆绑在一起,使其更易于管…

Verilog 入门(五)数据流模型化

文章目录 连续赋值语句时延 连续赋值用于数据流行为建模;相反,过程赋值用于顺序行为建模。组合逻辑电路的行为最好使用连续赋值语句建模。 连续赋值语句 连续赋值语句将值赋给线网(连续赋值不能为寄存器赋值),它的格式…

Python+Requests模拟发送GET请求

模拟发送GET请求 前置条件:导入requests库 一、发送不带参数的get请求 代码如下: 以百度首页为例 import requests# 发送get请求 response requests.get(url"http://www.baidu.com") print(response.content.decode("utf-8"))…

Drift plus penalty 漂移加惩罚Part2——性能分析

文章目录 正文Performance analysisAverage penalty analysis 平均惩罚分析Average queue size analysis 平均队列大小分析Probability 1 convergenceApplication to queues with finite capacityTreatment of queueing systemsConvex functions of time averages Delay tradeo…

SSR是什么?Vue中怎么实现?

一、是什么 概念 SSR是指服务器端渲染(Server-Side Rendering),是一种将客户端和服务器端合并的 Web 应用程序渲染技术。在 SSR 中,应用程序的 UI 在服务器端渲染完成后,再将整个渲染好的 HTML、CSS 和 JavaScript 发…

使用WalletConnect Web3Modal v3 链接钱包基础教程

我使用的是vueethers 官方文档:WalletConnect 1.安装 yarn add web3modal/ethers ethers 或者 npm install web3modal/ethers ethers2.引用 新建一个js文件,在main.js中引入,初始化配置sdk import {createWeb3Modal,defaultConfig, } from…

CMMI认证含金量高吗

一、CMMI认证含金量解答 CMMI,即能力成熟度模型集成,是由美国卡内基梅隆大学软件工程研究所开发的一种评估企业软件开发过程成熟度的模型。CMMI认证的含金量究竟高不高呢?答案是肯定的。CMMI认证被誉为软件开发行业的“金牌标准”&#xff0…

力扣题:字符的统计-12.2

力扣题-12.2 [力扣刷题攻略] Re:从零开始的力扣刷题生活 力扣题1:423. 从英文中重建数字 解题思想:有的单词通过一个字母就可以确定,依次确定即可 class Solution(object):def originalDigits(self, s):""":typ…

okhttp系列-拦截器的执行顺序

1.将拦截器添加到ArrayList final class RealCall implements Call {Response getResponseWithInterceptorChain() throws IOException {//将Interceptor添加到ArrayListList<Interceptor> interceptors new ArrayList<>();interceptors.addAll(client.intercept…

03-IDEA集成Git,初始化本地库,添加远程仓库,提交,拉取,推送,分支的快捷操作

IDEA集成Git 创建Git忽略文件 不同的IDE开发工具有不同的特点文件,这些文件与项目的实际功能无关且不参与服务器上的部署运行, 把它们忽略掉能够屏蔽之间的差异 局部忽略配置文件: 在本地仓库的根目录即项目根目录下直接创建.gitignore文件, 以文件后缀或目录名的方式忽略指定…

双远心镜头:让视觉检测更精准、高效!

工业镜头是视觉系统中的重要组件&#xff0c;工业镜头的选型影响着整个系统的成像效果。在做视觉检测时&#xff0c;会遇到无法检测空间物体、无法控制视场变化、无法控制图像扭曲、对比度低、畸变大、反光等问题&#xff0c;这时普通的工业镜头并不能有效地解决问题&#xff0…

校园门禁可视化系统解决方案

随着科技的持续进步&#xff0c;数字化校园在教育领域中的地位日益上升&#xff0c;各种智能门禁、安防摄像头等已遍布校园各个地方&#xff0c;为师生提供安全便捷的通行体验。然而数据收集分散、缺乏管理、分析困难等问题也逐渐出现&#xff0c;在这个数字化环境中&#xff0…

《opencv实用探索·六》简单理解图像膨胀

1、图像膨胀原理简单理解 膨胀是形态学最基本的操作&#xff0c;都是针对白色部分&#xff08;高亮部分&#xff09;而言的。膨胀就是使图像中高亮部分扩张&#xff0c;效果图拥有比原图更大的高亮区域。 2、图像膨胀的作用 注意一般情况下图像膨胀和腐蚀是联合使用的。 &…

scrapy介绍,并创建第一个项目

一、scrapy简介 scrapy的概念 Scrapy是一个Python编写的开源网络爬虫框架。它是一个被设计用于爬取网络数据、提取结构性数据的框架。 Scrapy 使用了Twisted异步网络框架&#xff0c;可以加快我们的下载速度。 Scrapy文档地址&#xff1a;http://scrapy-chs.readthedocs.io/z…

【.net core 7】新建net core web api并引入日志、处理请求跨域以及发布

效果图&#xff1a; 1.新建.net core web api项目 选择src文件夹》添加》新建项目 输入框搜索&#xff1a;web api 》选择ASP.NET Core Web API 输入项目名称、选择位置为项目的 src文件夹下 我的项目是net 7.0版本&#xff0c;实际选择请看自己的项目规划 2.处理Progr…

SpringBoot Bean解析

Bean解析 IOC介绍 松耦合灵活性可维护 注解方式配置Bean 实现方式1: Component声明,直接类上进行添加注解, 同时保证包扫描能扫到即可实现方式2: 配置类中使用Bean Configuration public class BeanConfiguration implements SuperConfiguration{Bean("dog")Ani…

基于DigiThread的仿真模型调参功能

仿真模型调参是指通过调整模型内部的参数值&#xff0c;使仿真模型的输出更符合实际系统的行为或者预期结果的过程。 仿真过程中&#xff0c;往往需要频繁对模型参数进行调整&#xff0c;通过观察不同参数下系统整体的运行情况&#xff0c;实现系统的性能、可靠性和效率的优化…