Pytorch 学习率衰减 之 余弦退火与余弦warmup 自定义学习率衰减scheduler

学习率衰减,通常我们英文也叫做scheduler。本文学习率衰减自定义,通过2种方法实现自定义,一是利用lambda,另外一个是继承pytorch的lr_scheduler

import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import *
from torchvision import models
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.fc = nn.Linear(1, 10)def forward(self,x):return self.fc(x)

余弦退火

  1. 当T_max=20
lrs = []
model = Net()
LR = 0.01
epochs = 100
optimizer = Adam(model.parameters(),lr = LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-9)
for epoch in range(epochs): optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()plt.figure(figsize=(10, 6))   
plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

请添加图片描述

  1. 当T_max = epochs,这就是我们经常用到的弦退火的 scheduler,下面再来看看带Warm-up的
lrs = []
model = Net()
LR = 0.01
epochs = 100
optimizer = Adam(model.parameters(),lr = LR)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-9)
for epoch in range(epochs): optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()plt.figure(figsize=(10, 6))   
plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

请添加图片描述

WarmUp

下面来看看 Pytorch定义的余弦退货的公式如下
ηt=ηmin+12(ηmax−ηmin)(1+cos⁡(TcurTmaxπ)),Tcur≠(2k+1)Tmax;ηt+1=ηt+12(ηmax−ηmin)(1−cos⁡(1Tmaxπ)),Tcur=(2k+1)Tmax.\begin{aligned} \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), & T_{cur} \neq (2k+1)T_{max}; \\ \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), & T_{cur} = (2k+1)T_{max}. \end{aligned}ηtηt+1=ηmin+21(ηmaxηmin)(1+cos(TmaxTcurπ)),=ηt+21(ηmaxηmin)(1cos(Tmax1π)),Tcur=(2k+1)Tmax;Tcur=(2k+1)Tmax.

实际上是用下面的公式做为更新的, 当Tcur=TmaxT_{cur} = T_{max}Tcur=Tmax是,coscoscos部分为0,所以就等于ηmin\eta_{min}ηmin

ηt=ηmin+12(ηmax−ηmin)(1+cos⁡(TcurTmaxπ))\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)ηt=ηmin+21(ηmaxηmin)(1+cos(TmaxTcurπ))

这里直接根据公式的定义来画个图看看

etas = []
epochs = 100
eta_max = 1e-4
eta_min = 1e-9
t_max = epochs / 1
for i in range(epoch):t_cur = ieta = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t_cur / t_max))etas.append(eta)plt.figure(figsize=(10, 6))    
plt.plot(range(len(etas)), etas, color='r')
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

请添加图片描述
从图上来看,跟上面的余弦退化是一样的,眼尖的都会发现lr_min 不等于eta_min=1e-9

利用Lambda来定义的

有个较小的bug(也不算,在description里有指出)

def warm_up_cosine_lr_scheduler(optimizer, epochs=100, warm_up_epochs=5, eta_min=1e-9):"""Description:- Warm up cosin learning rate scheduler, first epoch lr is too smallArguments:- optimizer: input optimizer for the training- epochs: int, total epochs for your training, default is 100. NOTE: you should pass correct epochs for your training- warm_up_epochs: int, default is 5, which mean the lr will be warm up for 5 epochs. if warm_up_epochs=0, means no needto warn up, will be as cosine lr scheduler- eta_min: float, setup ConsinAnnealingLR eta_min while warm_up_epochs = 0Returns:- scheduler"""if warm_up_epochs <= 0:scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=eta_min)else:warm_up_with_cosine_lr = lambda epoch: eta_min + (epoch / warm_up_epochs) if epoch <= warm_up_epochs else 0.5 * (np.cos((epoch - warm_up_epochs) / (epochs - warm_up_epochs) * np.pi) + 1)scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)return scheduler
# warm up consin lr scheduler
lrs = []
model = Net()
LR = 1e-4
warm_up_epochs = 30
epochs = 100
optimizer = SGD(model.parameters(), lr=LR)scheduler = warm_up_cosine_lr_scheduler(optimizer, warm_up_epochs=warm_up_epochs, eta_min=1e-9)for epoch in range(epochs):optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()plt.figure(figsize=(10, 6))  ![请添加图片描述](https://img-blog.csdnimg.cn/566b2c036b4a44598ae2a5a0548f2550.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBAamFzbmVpaw==,size_20,color_FFFFFF,t_70,g_se,x_16)plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

请添加图片描述
从图上看,第一个lr非常非常小,导致训练时的,第一个epoch基本上不更新

继承lr_scheduler的类

class WarmupCosineLR(lr_scheduler._LRScheduler):def __init__(self, optimizer, lr_min, lr_max, warm_up=0, T_max=10, start_ratio=0.1):"""Description:- get warmup consine lr schedulerArguments:- optimizer: (torch.optim.*), torch optimizer- lr_min: (float), minimum learning rate- lr_max: (float), maximum learning rate- warm_up: (int),  warm_up epoch or iteration- T_max: (int), maximum epoch or iteration- start_ratio: (float), to control epoch 0 lr, if ratio=0, then epoch 0 lr is lr_minExample:<<< epochs = 100<<< warm_up = 5<<< cosine_lr = WarmupCosineLR(optimizer, 1e-9, 1e-3, warm_up, epochs)<<< lrs = []<<< for epoch in range(epochs):<<<     optimizer.step()<<<     lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])<<<     cosine_lr.step()<<< plt.plot(lrs, color='r')<<< plt.show()"""self.lr_min = lr_minself.lr_max = lr_maxself.warm_up = warm_upself.T_max = T_maxself.start_ratio = start_ratioself.cur = 0    # current epoch or iterationsuper().__init__(optimizer, -1)def get_lr(self):if (self.warm_up == 0) & (self.cur == 0):lr = self.lr_maxelif (self.warm_up != 0) & (self.cur <= self.warm_up):if self.cur == 0:lr = self.lr_min + (self.lr_max - self.lr_min) * (self.cur + self.start_ratio) / self.warm_upelse:lr = self.lr_min + (self.lr_max - self.lr_min) * (self.cur) / self.warm_up# print(f'{self.cur} -> {lr}')else:            # this works finelr = self.lr_min + (self.lr_max - self.lr_min) * 0.5 *\(np.cos((self.cur - self.warm_up) / (self.T_max - self.warm_up) * np.pi) + 1)self.cur += 1return [lr for base_lr in self.base_lrs]
# class
epochs = 100
warm_up = 5
cosine_lr = WarmupCosineLR(optimizer, 1e-9, 1e-3, warm_up, epochs, 0.1)
lrs = []
for epoch in range(epochs):optimizer.step()lrs.append(optimizer.state_dict()['param_groups'][0]['lr'])cosine_lr.step()plt.figure(figsize=(10, 6))   
plt.plot(lrs, color='r')
plt.text(0, lrs[0], str(lrs[0]))
plt.text(epochs, lrs[-1], str(lrs[-1]))
plt.show()

请添加图片描述
从图上看出,第一个epoch的lr也不至于非常非常小了,达到了所需预期,当然,如果你说first epoch的lr,你也需要非常非常小(<1e-8),你也可以自己尝试其它值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/260505.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++ 字符串赋给另一个_7.2 C++字符串处理函数

点击上方“C语言入门到精通”&#xff0c;选择置顶第一时间关注程序猿身边的故事作者闫小林白天搬砖&#xff0c;晚上做梦。我有故事&#xff0c;你有酒么&#xff1f;C字符串处理函数C语言和C提供了一些字符串函数&#xff0c;使得用户能很方便地对字符串进行处理。这些是放在…

如何检测远程主机上的某个端口是否开启

有时候我们要测试远程主机上的某个端口是否开启&#xff0c;无需使用太复杂的工作&#xff0c;windows下就自带了工具&#xff0c;那就是telnet。怎么检测呢&#xff0c;按下面的步骤&#xff1a; 1、安装telnet。我的win7下就没有telnet&#xff0c;在cmd下输入telnet提示没有…

Windows10 + WSL (Ubuntu) + Anaconda + vscode 手把手配置python运行环境(含虚拟环境)

配置WSL windows桌面下&#xff0c;按下面顺序可以找到 "启动或关闭windows功能” &#xff0c; 开始 -> 设置 -> 应用 -> 应用和功能 -> 可选功能 -> 相关设置下 更多Windows功能&#xff08;滚动鼠标到底部&#xff09;点击后&#xff0c;会弹出 启动或…

2019编译ffepeg vs_如何在windows10下使用vs2017编译最新版本的FFmpeg和ffplay

该文章描述了如何在windows10 64位系统下面编译出FFmpeg的库及其自带的ffplay播放器&#xff0c;而且全部采用最新的版本&#xff0c;这样我们可以在vs2017的ide下调试ffplay&#xff0c;能使我们更容易学习FFmpeg的架构以及音视频播放器的原理。步骤&#xff1a;1.安装vs2017在…

训练集山准确率高测试集上准确率很低_推荐算法改版前的AB测试

编辑导语&#xff1a;所谓推荐算法就是利用用户的一些行为&#xff0c;通过一些数学算法&#xff0c;推测出用户可能喜欢的东西&#xff1b;如今很多软件都有这样的操作&#xff0c;对于此系统的设计也会进行测试&#xff1b;本文作者分享了关于推荐算法改版前的AB测试&#xf…

C#实现渐变颜色的Windows窗体控件

C#实现渐变颜色的Windows窗体控件! 1,定义一个BaseFormGradient,继承于System.Windows.Forms.Form2,定义三个变量: privateColor _Color1 Color.Gainsboro; privateColor _Color2 Color.White; privatefloat_ColorAngle 0f;3,重载OnPaintBackground方法 protecte…

Windows下 jupyter notebook 运行multiprocessing 报错的问题与解决方法

文章目录测试用的代码错误解决方法测试用的代码 下面每一个对应一个jupyter notebook的单元格 import time from multiprocessing import Process, Queuedef generator():c 0while True:time.sleep(1.0) # read somethingyield cc 1%%timeds generator() for i in range(3…

vc mysql_vc6.0连接mysql数据库

一、MySQL的安装Mysql的安装去官网下载就可以。。。最新的是5.7版本。。二、VC6.0的设置(1)打开VC6.中选0 工具栏Tools菜单下的Options选项&#xff0c;在Directories的标签页中右边的“Show directories for:”下拉列表中“Includefiles”&#xff0c;然后在中间列表框中添加你…

python class用法_python原类、类的创建过程与方法

【小宅按】今天为大家介绍一下python中与class 相关的知识……获取对象的类名python是一门面向对象的语言&#xff0c;对于一切接对象的python来说&#xff0c;咱们有必要深入的学习与了解一些知识首先大家都知道&#xff0c;要获取一个对象所对应的类&#xff0c;需要使用clas…

深度学习中的一些常见的激活函数集合(含公式与导数的推导)sigmoid, relu, leaky relu, elu, numpy实现

文章目录Sigmoid(x)双曲正切线性整流函数 rectified linear unit &#xff08;ReLu&#xff09;PReLU(Parametric Rectified Linear Unit) Leaky ReLu指数线性单元 Exponential Linear Units &#xff08;ELU&#xff09;感知机激活%matplotlib inline %config InlineBackend.f…

最牛X的GCC 内联汇编

正如大家知道的&#xff0c;在C语言中插入汇编语言&#xff0c;其是Linux中使用的基本汇编程序语法。本文将讲解 GCC 提供的内联汇编特性的用途和用法。对于阅读这篇文章&#xff0c;这里只有两个前提要求&#xff0c;很明显&#xff0c;就是 x86 汇编语言和 C 语言的基本认识。…

mysql的告警日志_MySQL Aborted connection告警日志的分析

前言&#xff1a;有时候&#xff0c;连接MySQL的会话经常会异常退出&#xff0c;错误日志里会看到"Got an error reading communication packets"类型的告警。本篇文章我们一起来讨论下该错误可能的原因以及如何来规避。1.状态变量Aborted_clients和Aborted_connects…

hosts多个ip对应一个主机名_一个简单的Web应用程序,用作连接到ssh服务器的ssh客户端...

WebSSH一个简单的Web应用程序&#xff0c;用作连接到ssh服务器的ssh客户端。它是用Python编写的&#xff0c;基于tornado&#xff0c;paramiko和xterm.js。特征支持SSH密码验证&#xff0c;包括空密码。支持SSH公钥认证&#xff0c;包括DSA RSA ECDSA Ed25519密钥。支持加密密钥…

Shell Notes(1)

> vi复制粘贴 光标移动到要复制的部分的开头&#xff0c;Esc退出插入模式&#xff0c;按v进入Visual模式&#xff0c;用hjkl选中要复制的部分 按Y或者yy&#xff0c;复制 移动光标到目标位置&#xff0c;按p&#xff0c;粘贴 > echo –e 参数 –e 可以使echo解释由反斜杠…

mysql多表查询语句_mysql查询语句 和 多表关联查询 以及 子查询

1.查询一张表&#xff1a;select * from 表名&#xff1b;2.查询指定字段&#xff1a;select 字段1&#xff0c;字段2&#xff0c;字段3….from 表名&#xff1b;3.where条件查询&#xff1a;select字段1&#xff0c;字段2&#xff0c;字段3 frome表名 where 条件表达式&#x…

Pytorch 自定义激活函数前向与反向传播 sigmoid

文章目录Sigmoid公式求导过程优点&#xff1a;缺点&#xff1a;自定义Sigmoid与Torch定义的比较可视化import matplotlib import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F%matplotlib inlineplt.rcPa…

js高级编程_这位设计师用Processing把创意编程玩到了极致!

Processing作为新媒体从业者的必备工具&#xff0c;近来却越来越成为设计师们的新宠&#xff01;今天小编将介绍以为用Processing把创意编程玩到极致的设计师Tim Rodenbrker。“我们的世界正在以惊人的速度变化。新技术为创作带来了根本性的转变。编程是我们这个时代最宝贵的技…

微软.NET Framework 4.5.2 RTM正式版

今天&#xff0c;微软.NET开发团队发布.NET Framework 4.5.2 RTM正式版。新版框架继续高度兼容现有的.NET Framework 4、4.5、4.5.1等版本&#xff0c;该版本框架与旧版的.NET Framework 3.5 SP1和早期版本采取不同的处理方式&#xff0c;但与.NET Framework 4、4.5相比&#x…

Pytorch 自定义激活函数前向与反向传播 Tanh

看完这篇&#xff0c;你基本上可以自定义前向与反向传播&#xff0c;可以自己定义自己的算子 文章目录Tanh公式求导过程优点&#xff1a;缺点&#xff1a;自定义Tanh与Torch定义的比较可视化import matplotlib import matplotlib.pyplot as plt import numpy as np import torc…

HDU ACM 1181 变形课 (广搜BFS + 动态数组vector)-------第一次使用动态数组vector

http://acm.hdu.edu.cn/showproblem.php?pid1181 题意&#xff1a;给我若干个单词,若单词A的结尾与单词B的开头相同,则表示A能变成B,判断能不能从b开头变成m结尾. 如: big-got-them 第一次使用动态数组vector View Code 1 #include <iostream>2 #include <vector>…