2学习率调整_学习率衰减

之前我们的优化,主要是聚焦于对梯度下降运动方向的调整,而在参数迭代更新的过程中,除了梯度,还有一个重要的参数是学习率α,对于学习率的调整也是优化的一个重要方面。

3a4b16727334d131400cc0477326ee3a.png

01

学习率衰减

首先我们以一个例子,来说明一下我们为什么需要学习率α衰减(learning rate decay)。如果学习率不衰减的话,如下图蓝线所示,由于噪音影响,代价函数更新路径相对不规则,但总体朝着最低点方向移动,但是移动到最低点附近时,由于学习率较大,每一次会移动相对较远的距离,容易直接跨过最低点,导致代价函数在更新完毕后距离最低点仍然相对较远;当我们随着迭代的次数逐渐降低学习率,那么便如绿线所示,一开始学习率较大,前进速度较快,到达最低点附近后,学习率降低到更小的值,于是最终更新完成后代价函数离最低点更近,也就是模型更加优化,预测值与实际值差距更小。

315c87570c823913ddb72fa05a5bc745.png

进行学习率衰减的一种方式如下所示,需要设置如下学习率更新规则:

d5eefa45499c7eb6da0e7fa6d4f014d2.png

1个epoch是指将所有mini-batch全部迭代一遍,即遍历一遍。假设α0=0.2,decay_rate=1,那么随着epoch增加,学习率α会如下图变化:

7ad8684e8c18cf5cb1d3fec8498c7d48.png

在应用这个公式时,我们需要选择合适的超参数α0和decay_rate。除了这种学习率衰减方式,还有一些其他方式来进行学习率衰减:

21c6b610741c2bd926cbab1cdc02eb69.png

3afbf291b9a71e2a9d832672a747dc8f.png

此外还有离散衰减,经过一段时间衰减一半:

1cbacb2baa5d51af92c42ca7a4bb2e5e.png

02

学习率衰减的pytorch实现

指数衰减

我们首先需要确定需要针对哪个优化器执行学习率动态调整策略,也就是首先定义一个优化器:
optimizer_ExpLR = torch.optim.SGD(net.parameters(), lr=0.1)
定义好优化器以后,就可以给这个优化器绑定一个指数衰减学习率控制器:
ExpLR = torch.optim.lr_scheduler.ExponentialLR(optimizer_ExpLR, gamma=0.98)
参数gamma表示衰减的底数,也就是decay_rate,选择不同的gamma值可以获得幅度不同的衰减曲线。

ce1b8775feda673a80af1632239792b7.png

固定步长衰减

即离散型衰减,学习率每隔一定步数(或者epoch)就减少为原来的gamma分之一,使用固定步长衰减依旧先定义优化器,再给优化器绑定StepLR对象:
optimizer_StepLR = torch.optim.SGD(net.parameters(), lr=0.1)StepLR = torch.optim.lr_scheduler.StepLR(optimizer_StepLR, step_size=step_size, gamma=0.65)
其中gamma参数表示衰减的程度,step_size参数表示每隔多少个step进行一次学习率调整,下面对比了不同gamma值下的学习率变化情况:

55bcba827dba94e404bb51e13cf04620.png

多步长衰减

有时我们希望不同的区间采用不同的更新频率,或者是有的区间更新学习率,有的区间不更新学习率,这就需要使用MultiStepLR来实现动态区间长度控制:
optimizer_MultiStepLR = torch.optim.SGD(net.parameters(), lr=0.1)torch.optim.lr_scheduler.MultiStepLR(optimizer_MultiStepLR,\                milestones=[200, 300, 320, 340, 200], gamma=0.8)
其中milestones参数为表示学习率更新的起止区间,在区间[0. 200]内学习率不更新,而在[200, 300]、[300, 320].....[340, 400]的右侧值都进行一次更新;gamma参数表示学习率衰减为上次的gamma分之一。其图示如下:

91aaccc7315e1c13be18d784e8d3980e.png

从图中可以看出,学习率在区间[200, 400]内快速的下降,这就是milestones参数所控制的,在milestones以外的区间学习率始终保持不变。

余弦退火衰减

严格的说,余弦退火策略不应该算是学习率衰减策略,因为它使得学习率按照周期变化,其定义方式如下:
optimizer_CosineLR = torch.optim.SGD(net.parameters(), lr=0.1)CosineLR = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_CosineLR, T_max=150, eta_min=0)
参数T_max表示余弦函数周期;eta_min表示学习率的最小值,默认它是0表示学习率至少为正值。确定一个余弦函数需要知道最值和周期,其中周期就是T_max,最值是初试学习率。下图展示了不同周期下的余弦学习率更新曲线:

117eecb78e2d4ef31906b3ea97864f2d.png

为网络的不同层设置不同的学习率

定义一个简单的网络结构:
class net(nn.Module):    def __init__(self):        super(net, self).__init__()        self.conv1 = nn.Conv2d(3, 64, 1)        self.conv2 = nn.Conv2d(64, 64, 1)        self.conv3 = nn.Conv2d(64, 64, 1)        self.conv4 = nn.Conv2d(64, 64, 1)        self.conv5 = nn.Conv2d(64, 64, 1)    def forward(self, x):        out = conv5(conv4(conv3(conv2(conv1(x)))))        return out
我们希望conv5学习率是其他层的100倍,我们可以:
net = net()lr = 0.001conv5_params = list(map(id, net.conv5.parameters())) # 1 base_params = filter(lambda p: id(p) not in conv5_params,                     net.parameters()) # 2,3optimizer = torch.optim.SGD([            {'params': base_params},            {'params': net.conv5.parameters(), 'lr': lr * 100}], lr=lr, momentum=0.9)
1. conv5_params = list(map(id,net.conv5.parameters()))中id()函数用于获取网络参数的内存地址,map()函数用于将id()函数作用于net.conv5.parameters()得到的每个参数上。2.lambda p: id(p)中lamda表达式是python中用于定义匿名函数的方式,其后面定义的是一个函数操作,冒号前的符号是函数的形式参数,用于接收参数,符合的个数表示需要接收的参数个数,冒号右边是具体的函数操作。3.filter()函数的作用是过滤掉不符合条件(False)的元素,返回一个迭代器对象。该函数接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判断,然后返回True或False,最后返回值为True的元素

Reference

深度学习课程 --吴恩达

https://zhuanlan.zhihu.com/p/93624972

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/470079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Codeforces Round #299 (Div. 2) D. Tavas and Malekas kmp

题目链接: http://codeforces.com/problemset/problem/535/DD. Tavas and Malekastime limit per test2 secondsmemory limit per test256 megabytes问题描述 Tavas is a strange creature. Usually "zzz" comes out of peoples mouth while sleeping, bu…

可怕的乖孩子_当今的中国,有句很可怕的话:所有的乖孩子注定不幸福!

来自soogif▼/01不知道从什么时候起,乖孩子被贴上了一个不幸福的标签, 一个表现很乖的孩子总是会被认为是因为缺乏爱和安全感,才表现的很乖,很懂事的样子的。事实真的是这样子的吗?No!爸爸去哪儿5里的Jaspe…

js获取当前日期星期几

var str "今天是星期" "日一二三四五六".charAt(new Date().getDay());alert(str); 转载于:https://www.cnblogs.com/lccnblog/p/5902525.html

适用于VS C++环境的注释代码段,可以让你的代码被使用时有高可读性的注释

编码时,在对高级语言(C#/VB etc)函数的访问时,经常会有很明确的函数功能提示,参数提示,与返回值提示。微软的VisualStudio C集成开发环境同样有这样的功能,只是常见开源的代码很少按照VS的注释格…

mysql 用户管理表_Mysql—用户表详解(mysql.user)

MySQL数据库Mysql—用户表详解(mysql.user)MySQL是一个多用户管理的数据库,可以为不同用户分配不同的权限,分为root用户和普通用户,root用户为超级管理员,拥有所有权限,而普通用户拥有指定的权限。MySQL是通过权限表来…

Orchard商城模块(Commerce)设计与后台部分

前言:使用CMS开发网站为目标,编写一个扩展性比较好的商城模块。 首先是整体流程图,大概介绍功能与设计。 接下来我们逐个模块功能介绍。 一。商品管理模块 商品模块中可发布需要在线售卖的商品 (套餐商品) 1.1 添加一个商品 1. 商品正常价&…

mysql数据库架构_MySQL数据库之互联网常用架构方案

一、数据库架构原则高可用高性能可扩展一致性二、常见的架构方案方案一:主备架构,只有主库提供读写服务,备库冗余作故障转移用jdbc:mysql://vip:3306/xxdb高可用分析:高可用,主库挂了,keepalive(只是一种工…

mysql数据库恢复策略_MySQL 备份和恢复策略(一)

在数据库表丢失或损坏的情况下,备份你的数据库是很重要的。如果发生系统崩溃,你肯定想能够将你的表尽可能丢失最少的数据恢复到崩溃发生时的状态。本文主要对MyISAM表做备份恢复。备份策略一:直接拷贝数据库文件(不推荐)备份策略二&#xff1…

laravel方法汇总详解

1.whereRaw() 用原生的SQL语句来查询,whereRaw(select * from user) 就和 User::all()方法是一样的效果 2.whereBetween() 查询时间格式 whereBetween(problem_date, [2016-10-05 19:00:00, 2016-10-05 20:35:10]) 这种可以查到,时间格式类似这种, 查询日…

输入输出优化

被各种变态的出题者出的数据坑到了这里/sad 1 int read() 2 { 3 int num0; char chgetchar(); 4 while(ch<0&&ch>9) chgetchar(); //过滤前面非数字字符 5 while(ch>0&&ch<9) {num*10;numch-0;chgetchar();} 6 return num…

mysql整数索引没用到_MYSQL 索引无效和索引有效的详细介绍

1、WHERE字句的查询条件里有不等于号(WHERE column!...)&#xff0c;MYSQL将无法使用索引2、类似地&#xff0c;如果WHERE字句的查询条件里使用了函数(如&#xff1a;WHERE DAY(column)...)&#xff0c;MYSQL将无法使用索引3、在JOIN操作中(需要从多个数据表提取数据时)&#x…

Qt词典搜索

Qt词典搜索 采用阿凡达数据-API数据接口及爱词霸API数据接口实现词典搜索功能&#xff0c;实例字符串搜索接口分别为&#xff1a;中文词组采用“词典”&#xff0c;中文单个字采用“中华字典”&#xff0c;英文或其他字符采用“爱词霸”&#xff1b; 对应的API接口&#xff1a;…

mysql8.0.13 rpm_Centos7 安装mysql 8.0.13(rpm)的教程详解

yum or rpm&#xff1f;yum安装方式很方便&#xff0c;但是下载mysql的时候从官网下载&#xff0c;速度较慢。rpm安装方式可以从国内镜像下载mysql的rpm包&#xff0c;比较快。rpm也适合离线安装。环境说明•操作系统&#xff1a;Centos7.4 (CentOS-7-x86_64-Minimal-1804.iso)…

如何参与一个GitHub开源项目

Github作为开源项目的著名托管地&#xff0c;可谓无人不知&#xff0c;越来越多的个人和公司纷纷加入到Github的大家族里来&#xff0c;为开源尽一份绵薄之力。对于个人来讲&#xff0c;你把自己的项目托管到Github上并不表示你参与了Github开源项目&#xff0c;只能说你开源了…

mysql数据库的多实例_MySQL数据库多实例应用实战 - 橙子柠檬's Blog

本文采用的是/data目录作为mysql多实例总的根目录&#xff0c;然后规划不同 的MySQL实例端口号来作为/data下面的二级目录&#xff0c;不同的端口号就是不同实例目录&#xff0c;以区别不同的实例&#xff0c;二级目录下包含mysql数据文件&#xff0c;配置文件以及启动文件的目…

微信企业号开发[二]——获取用户信息

注&#xff1a;文中绿色部分为摘自微信官方文档 在《微信企业号开发[一]——创建应用》介绍了如何创建应用&#xff0c;但是当用户点击应用跳转到我们设定的URL时&#xff0c;其实并没有带上用户的任何信息&#xff0c;为了获取用户信息&#xff0c;我们需要借助微信提供的OAut…

mysql全套基础知识_Mysql基础知识整理

MySQL的查询过程 (一条sql语句在MySQL中如何执行)&#xff1a;客户端请求 ---> 连接器(验证用户身份&#xff0c;给予权限) ---> 查询缓存(存在缓存则直接返回&#xff0c;不存在则执行后续操作) ---> 分析器(对SQL进行词法分析和语法分析操作) ---> 优化器(主要对…

渗透思维导图

转载于:https://www.cnblogs.com/DonAndy/p/5914747.html

mysql数据库用户的创建_mysql创建用户及数据库

登陆mysql[rootdn210120 conf]# mysql -uroot创建用户及密码mysql> grant usage on *.* to hive14localhost identified by 123456 with grant option;创建数据库mysql> create database hive14 DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;赋予新用户操作新数据…

【计算机视觉】论文笔记:Ten years of pedestrian detection, what have we learned?

最近正在研究行人检测&#xff0c;学习了一篇2014年发表在ECCV上的一篇综述性的文章&#xff0c;是对行人检测过去十年的一个回顾&#xff0c;从dataset&#xff0c;main approaches的角度分析了近10年的40多篇论文提出的方法&#xff0c;发现有三种方法&#xff08;DPM变体&am…