DeepLearning - 余弦退火热重启学习率 CosineAnnealingWarmRestartsLR

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134249925

CosineAnnealingWarmRestartsLR,即 余弦退火热重启学习率,周期性修改学习率的下降和上升,间隔幅度逐渐增大,避免模型的性能抖动。其中核心参数:

  • optimizer 的参数,lr 学习率,默认学习率是 lr * GPU 数量,例如 lr 设置成 0.00001,32卡实际是 0.00032。
  • T_0,衰减的 global step 数,即单卡的运行次数,根据运行时间确定,例如 step 是 28.5 秒一次,(28.5 * 2000) / 3600 = 15.8 小时。
  • T_mult,周期间隔,逐渐加大,例如 T_mult 是 2,则表示,第n次是 T 0 ∗ T m u l t n T_0*T_{mult}^{n} T0Tmultn 步。
  • eta_min,从 LR 衰减的最小步数,可以设置成0。

源码:

optimizer = deepspeed.ops.adam.FusedAdam(self.model.parameters(), lr=learning_rate, eps=eps)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=lr_t_0, T_mult=2, eta_min=0, last_epoch=-1)

LR 曲线如下:

GitHub - SevenZhan/Pytorch: self-used pytorch utilities

源码:CosineAnnealingWarmRestarts

class CosineAnnealingWarmRestarts(LRScheduler):r"""Set the learning rate of each parameter group using a cosine annealingschedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`is the number of epochs since the last restart and :math:`T_{i}` is the numberof epochs between two warm restarts in SGDR:.. math::\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.It has been proposed in`SGDR: Stochastic Gradient Descent with Warm Restarts`_.Args:optimizer (Optimizer): Wrapped optimizer.T_0 (int): Number of iterations for the first restart.T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.eta_min (float, optional): Minimum learning rate. Default: 0.last_epoch (int, optional): The index of last epoch. Default: -1.verbose (bool): If ``True``, prints a message to stdout foreach update. Default: ``False``... _SGDR\: Stochastic Gradient Descent with Warm Restarts:https://arxiv.org/abs/1608.03983"""def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):if T_0 <= 0 or not isinstance(T_0, int):raise ValueError(f"Expected positive integer T_0, but got {T_0}")if T_mult < 1 or not isinstance(T_mult, int):raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")if not isinstance(eta_min, (float, int)):raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")self.T_0 = T_0self.T_i = T_0self.T_mult = T_multself.eta_min = eta_minself.T_cur = last_epochsuper().__init__(optimizer, last_epoch, verbose)def get_lr(self):if not self._get_lr_called_within_step:warnings.warn("To get the last learning rate computed by the scheduler, ""please use `get_last_lr()`.", UserWarning)return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2for base_lr in self.base_lrs][docs]    def step(self, epoch=None):"""Step could be called after every batch updateExample:>>> # xdoctest: +SKIP("Undefined vars")>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)>>> iters = len(dataloader)>>> for epoch in range(20):>>>     for i, sample in enumerate(dataloader):>>>         inputs, labels = sample['inputs'], sample['labels']>>>         optimizer.zero_grad()>>>         outputs = net(inputs)>>>         loss = criterion(outputs, labels)>>>         loss.backward()>>>         optimizer.step()>>>         scheduler.step(epoch + i / iters)This function can be called in an interleaved way.Example:>>> # xdoctest: +SKIP("Undefined vars")>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)>>> for epoch in range(20):>>>     scheduler.step()>>> scheduler.step(26)>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)"""if epoch is None and self.last_epoch < 0:epoch = 0if epoch is None:epoch = self.last_epoch + 1self.T_cur = self.T_cur + 1if self.T_cur >= self.T_i:self.T_cur = self.T_cur - self.T_iself.T_i = self.T_i * self.T_multelse:if epoch < 0:raise ValueError(f"Expected non-negative epoch, but got {epoch}")if epoch >= self.T_0:if self.T_mult == 1:self.T_cur = epoch % self.T_0else:n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)self.T_i = self.T_0 * self.T_mult ** (n)else:self.T_i = self.T_0self.T_cur = epochself.last_epoch = math.floor(epoch)class _enable_get_lr_call:def __init__(self, o):self.o = odef __enter__(self):self.o._get_lr_called_within_step = Truereturn selfdef __exit__(self, type, value, traceback):self.o._get_lr_called_within_step = Falsereturn selfwith _enable_get_lr_call(self):for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):param_group, lr = dataparam_group['lr'] = lrself.print_lr(self.verbose, i, lr, epoch)self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

WandB 测试效果:

WandB

参考:

  • 知乎 - PyTorch中学习率调度器可视化介绍

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/133839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

K7系列FPGA进行FLASH读写1——CCLK控制(STARTUPE2原语)

最近的工作涉及对 FPGA 进行远程更新&#xff0c;也就是通过远程通信接口将 .bin 文件送到 FPGA&#xff0c;然后写入 FLASH&#xff0c;这样当 FPGA 重新上电后就可以执行更新后的程序了。因此第一步工作就是进行 FLASH 的读写控制。 然而如果尝试配置 FLASH 管脚时&#xff0…

Android Datastore 动态创建与源码解析

涉及到的知识点 1、协程原理---->很好的博客介绍&#xff0c;一个小故事讲明白进程、线程、Kotlin 协程到底啥关系&#xff1f; 2、Channel知识点---->Android—kotlin-Channel超详细讲解 3、Coroutines : CompletableDeferred and structured concurrency 封装的DataS…

数学建模比赛中常用的建模提示词(数模prompt)

以下为数学建模比赛中常用的建模提示词&#xff0c;希望对你有所帮助&#xff01; 帮我总结一下数学建模有哪些预测类算法&#xff1f; 灰色预测模型级比检验是什么意思? 描述一下BP神经网络算法的建模步骤 对于分类变量与分类变量相关性分析用什么算法 前10年的数据分别是1&a…

javascript自定义事件的观察者模式写法和用法以及继承

<html><head><meta http-equiv"Context-Type:text/html;charsetutf-8"/><title>自定义事件之观察者模式</title><script type"text/javascript" src"common.js"></script></head><body>&…

avue中 curd的列表配置

说明&#xff1a; avue-crud组件中添加查询条件或者新增的时候&#xff0c;条件为下拉框且接口在curd组件中配置 1. html代码 <template><basic-container><avue-crud:data"dataList":option"option"search-change"searchChange&quo…

代码随想录 Day38 完全背包问题 LeetCode T70 爬楼梯 T322 零钱兑换 T279 完全平方数

前言 在今天的题目开始之前,让我们来回顾一下之前的知识,动规五部曲 1.确定dp数组含义 2.确定dp数组的递推公式 3.初始化dp数组 4.确定遍历顺序 5.打印dp数组来排错 tips: 1.当求取物品有限的时候用0-1背包,求取物品无限的时候用完全背包 结果是排列还是组合也有说法,当结果是组…

【PostgreSql基础语法 】1、增删改查、where、limit、like模糊查询

Shell命令框和Navicat联合使用 一、数据库层面&#xff08;shell命令行&#xff09;二、表格层面&#xff08;Navicat&#xff09;三、增删改查1. 增insert into2. 查询select3. UPDATE 改4. DELETE 删除 四、 关键字1. AND2.OR3. NOT NULL 和 NULL4. LIKE 模糊查询4.1 like查找…

linux boot阶段内存分配(x86)

x86中没有boot memory allocator&#xff0c;是用 memblock 来分配的。 memblock有memory 与reserved两种类型&#xff0c;它们的内存是静态内存&#xff0c;不需要用memblock本身去维护&#xff0c;它们被标记为__initdata_memblock&#xff0c;会在boot结束后&#xff08;fre…

渗透实战靶机2wp

0x00 简介 1、测试环境 目标IP&#xff1a;10.xxxx 测试IP&#xff1a;192.168.139.128 测试环境&#xff1a;win10、kali等 测试时间&#xff1a;2021.7.22-2021.7.22 测试人员&#xff1a;ruanruan 2、测试过程 本次实战主要通过对收集到的端口、目录等信息进行持续整…

润和软件HopeStage与奇安信网神终端安全管理系统、可信浏览器完成产品兼容性互认证

近日&#xff0c;江苏润和软件股份有限公司&#xff08;以下简称“润和软件”&#xff09;HopeStage 操作系统与奇安信网神信息技术&#xff08;北京&#xff09;股份有限公司&#xff08;以下简称“奇安信”&#xff09;终端安全管理系统、可信浏览器完成产品兼容性测试。 测试…

SQLSugar查询返回DataTable

SQLSugar是一个用于执行SQL查询的C#库&#xff0c;它提供了简单易用的API接口来执行SQL查询。要查询返回DataTable&#xff0c;可以使用SQLSugar的QueryHelper类。 以下是一个示例代码&#xff0c;展示了如何使用SQLSugar的QueryHelper类查询返回DataTable&#xff1a; 首先&…

SICP01(待续)

一、Lisp概览 语言&#xff1a;规则本身计算机科学的任务&#xff1a;形式化有关”怎么做“的指令性知识&#xff0c;并付诸实践问题产生&#xff1a;构建大型系统的时候难以管理解决方法&#xff1a;在大系统中控制复杂度的方法也是计算机所关注的注意&#xff1a;计算机中的…

阿里云二级域名绑定与宝塔Nginx反向代理配置

在阿里或者腾讯...各大域名商买好域名&#xff0c;备案解析好&#xff0c;目标URL&#xff0c;是真正的地址&#xff0c;比如一些端口&#xff0c;后者会自动填写。 注意ssl配置好&#xff0c;这里不要带反代端口

vue中异步更新$nextTick

1.需求 编辑标题, 编辑框自动聚焦 点击编辑&#xff0c;显示编辑框让编辑框&#xff0c;立刻获取焦点 2.代码实现 <template><div class"app"><div v-if"isShowEdit"><input type"text" v-model"editValue"…

王道p18 第12题假设 A中的 n个元素保存在一个一维数组中,请设计一个尽可能高效的算法,找出A的主元素。若存在主元素,则输出该元素:否则输出-1

视频讲解在&#xff1a;&#x1f447; p18 第12题 c语言实现王道数据结构课后习题_哔哩哔哩_bilibili 从前向后扫描数组元素&#xff0c;标记出一个可能成为主元素的元素 Num。然后重新计数&#xff0c;确认 Num 是否是主元素。 我们可分为以下两步: 1.选取候选的主元素。依…

【软考:系统集成项目管理】之 项目成本管理

1. 成本管理的过程 制订成本管理计划成本估算成本预算成本控制 2. 过程的输入、输出、工具与技术 过程输入工具与技术输出1. 制订成本管理计划1. 项目管理计划 2. 项目章程 3. 事业环境因素 4. 组织过程资产1. 专家判断 2. 分析技术 3. 会议1. 成本管理计划2. 成本估算1. 成…

YOLO目标检测——汽车头部尾部检测数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;用于训练自动驾驶系统中的车辆感知模块&#xff0c;以实现对周围车辆头部和尾部的准确检测和识别数据集说明&#xff1a;汽车头部尾部检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富标签说明&#xff1a;使用lableimg标注软…

【JMeter】后置处理器的分类以及场景介绍

1.常用后置处理器的分类 Json提取器 针对响应体的返回结果是json格式的会自动生成新的变量名为【提取器中变量名_MatchNr】,取到的个数由jsonpath expression取到的个数决定 可以当作普通变量调用,调用语法:${提取器中变量名_MatchNr}正则表达式提取器 返回结果是任何数据格…

一款好用的PDF转翻页电子书网站

​你是否曾经遇到过PDF文件无法翻页或者阅读不便的问题&#xff1f;今天给大家推荐一款好用的PDF转翻页电子书网站&#xff0c;让你轻松阅读PDF文件&#xff0c;不再烦恼翻页问题&#xff01; 一、网站介绍 这款FLBOOK在线制作电子杂志网站支持多种电子文件格式转换&#xff0…

C# 9.0 record和with的定义及使用

C# 9 引入record&#xff0c;它一种可以创建的新引用类型&#xff0c;而不是类或结构。 C# 10 添加了 record structs&#xff0c;以便可以将记录定义为值类型。 记录与类不同&#xff0c;区别在于record类型使用基于值的相等性。 两个记录类型的变量在它们的类型和值都相同时&…