DeepLearning - 余弦退火热重启学习率 CosineAnnealingWarmRestartsLR

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134249925

CosineAnnealingWarmRestartsLR,即 余弦退火热重启学习率,周期性修改学习率的下降和上升,间隔幅度逐渐增大,避免模型的性能抖动。其中核心参数:

  • optimizer 的参数,lr 学习率,默认学习率是 lr * GPU 数量,例如 lr 设置成 0.00001,32卡实际是 0.00032。
  • T_0,衰减的 global step 数,即单卡的运行次数,根据运行时间确定,例如 step 是 28.5 秒一次,(28.5 * 2000) / 3600 = 15.8 小时。
  • T_mult,周期间隔,逐渐加大,例如 T_mult 是 2,则表示,第n次是 T 0 ∗ T m u l t n T_0*T_{mult}^{n} T0Tmultn 步。
  • eta_min,从 LR 衰减的最小步数,可以设置成0。

源码:

optimizer = deepspeed.ops.adam.FusedAdam(self.model.parameters(), lr=learning_rate, eps=eps)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=lr_t_0, T_mult=2, eta_min=0, last_epoch=-1)

LR 曲线如下:

GitHub - SevenZhan/Pytorch: self-used pytorch utilities

源码:CosineAnnealingWarmRestarts

class CosineAnnealingWarmRestarts(LRScheduler):r"""Set the learning rate of each parameter group using a cosine annealingschedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`is the number of epochs since the last restart and :math:`T_{i}` is the numberof epochs between two warm restarts in SGDR:.. math::\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.It has been proposed in`SGDR: Stochastic Gradient Descent with Warm Restarts`_.Args:optimizer (Optimizer): Wrapped optimizer.T_0 (int): Number of iterations for the first restart.T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.eta_min (float, optional): Minimum learning rate. Default: 0.last_epoch (int, optional): The index of last epoch. Default: -1.verbose (bool): If ``True``, prints a message to stdout foreach update. Default: ``False``... _SGDR\: Stochastic Gradient Descent with Warm Restarts:https://arxiv.org/abs/1608.03983"""def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):if T_0 <= 0 or not isinstance(T_0, int):raise ValueError(f"Expected positive integer T_0, but got {T_0}")if T_mult < 1 or not isinstance(T_mult, int):raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")if not isinstance(eta_min, (float, int)):raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")self.T_0 = T_0self.T_i = T_0self.T_mult = T_multself.eta_min = eta_minself.T_cur = last_epochsuper().__init__(optimizer, last_epoch, verbose)def get_lr(self):if not self._get_lr_called_within_step:warnings.warn("To get the last learning rate computed by the scheduler, ""please use `get_last_lr()`.", UserWarning)return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2for base_lr in self.base_lrs][docs]    def step(self, epoch=None):"""Step could be called after every batch updateExample:>>> # xdoctest: +SKIP("Undefined vars")>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)>>> iters = len(dataloader)>>> for epoch in range(20):>>>     for i, sample in enumerate(dataloader):>>>         inputs, labels = sample['inputs'], sample['labels']>>>         optimizer.zero_grad()>>>         outputs = net(inputs)>>>         loss = criterion(outputs, labels)>>>         loss.backward()>>>         optimizer.step()>>>         scheduler.step(epoch + i / iters)This function can be called in an interleaved way.Example:>>> # xdoctest: +SKIP("Undefined vars")>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)>>> for epoch in range(20):>>>     scheduler.step()>>> scheduler.step(26)>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)"""if epoch is None and self.last_epoch < 0:epoch = 0if epoch is None:epoch = self.last_epoch + 1self.T_cur = self.T_cur + 1if self.T_cur >= self.T_i:self.T_cur = self.T_cur - self.T_iself.T_i = self.T_i * self.T_multelse:if epoch < 0:raise ValueError(f"Expected non-negative epoch, but got {epoch}")if epoch >= self.T_0:if self.T_mult == 1:self.T_cur = epoch % self.T_0else:n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)self.T_i = self.T_0 * self.T_mult ** (n)else:self.T_i = self.T_0self.T_cur = epochself.last_epoch = math.floor(epoch)class _enable_get_lr_call:def __init__(self, o):self.o = odef __enter__(self):self.o._get_lr_called_within_step = Truereturn selfdef __exit__(self, type, value, traceback):self.o._get_lr_called_within_step = Falsereturn selfwith _enable_get_lr_call(self):for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):param_group, lr = dataparam_group['lr'] = lrself.print_lr(self.verbose, i, lr, epoch)self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

WandB 测试效果:

WandB

参考:

  • 知乎 - PyTorch中学习率调度器可视化介绍

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/133839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

K7系列FPGA进行FLASH读写1——CCLK控制(STARTUPE2原语)

最近的工作涉及对 FPGA 进行远程更新&#xff0c;也就是通过远程通信接口将 .bin 文件送到 FPGA&#xff0c;然后写入 FLASH&#xff0c;这样当 FPGA 重新上电后就可以执行更新后的程序了。因此第一步工作就是进行 FLASH 的读写控制。 然而如果尝试配置 FLASH 管脚时&#xff0…

Android Datastore 动态创建与源码解析

涉及到的知识点 1、协程原理---->很好的博客介绍&#xff0c;一个小故事讲明白进程、线程、Kotlin 协程到底啥关系&#xff1f; 2、Channel知识点---->Android—kotlin-Channel超详细讲解 3、Coroutines : CompletableDeferred and structured concurrency 封装的DataS…

数学建模比赛中常用的建模提示词(数模prompt)

以下为数学建模比赛中常用的建模提示词&#xff0c;希望对你有所帮助&#xff01; 帮我总结一下数学建模有哪些预测类算法&#xff1f; 灰色预测模型级比检验是什么意思? 描述一下BP神经网络算法的建模步骤 对于分类变量与分类变量相关性分析用什么算法 前10年的数据分别是1&a…

代码随想录 Day38 完全背包问题 LeetCode T70 爬楼梯 T322 零钱兑换 T279 完全平方数

前言 在今天的题目开始之前,让我们来回顾一下之前的知识,动规五部曲 1.确定dp数组含义 2.确定dp数组的递推公式 3.初始化dp数组 4.确定遍历顺序 5.打印dp数组来排错 tips: 1.当求取物品有限的时候用0-1背包,求取物品无限的时候用完全背包 结果是排列还是组合也有说法,当结果是组…

渗透实战靶机2wp

0x00 简介 1、测试环境 目标IP&#xff1a;10.xxxx 测试IP&#xff1a;192.168.139.128 测试环境&#xff1a;win10、kali等 测试时间&#xff1a;2021.7.22-2021.7.22 测试人员&#xff1a;ruanruan 2、测试过程 本次实战主要通过对收集到的端口、目录等信息进行持续整…

润和软件HopeStage与奇安信网神终端安全管理系统、可信浏览器完成产品兼容性互认证

近日&#xff0c;江苏润和软件股份有限公司&#xff08;以下简称“润和软件”&#xff09;HopeStage 操作系统与奇安信网神信息技术&#xff08;北京&#xff09;股份有限公司&#xff08;以下简称“奇安信”&#xff09;终端安全管理系统、可信浏览器完成产品兼容性测试。 测试…

阿里云二级域名绑定与宝塔Nginx反向代理配置

在阿里或者腾讯...各大域名商买好域名&#xff0c;备案解析好&#xff0c;目标URL&#xff0c;是真正的地址&#xff0c;比如一些端口&#xff0c;后者会自动填写。 注意ssl配置好&#xff0c;这里不要带反代端口

vue中异步更新$nextTick

1.需求 编辑标题, 编辑框自动聚焦 点击编辑&#xff0c;显示编辑框让编辑框&#xff0c;立刻获取焦点 2.代码实现 <template><div class"app"><div v-if"isShowEdit"><input type"text" v-model"editValue"…

王道p18 第12题假设 A中的 n个元素保存在一个一维数组中,请设计一个尽可能高效的算法,找出A的主元素。若存在主元素,则输出该元素:否则输出-1

视频讲解在&#xff1a;&#x1f447; p18 第12题 c语言实现王道数据结构课后习题_哔哩哔哩_bilibili 从前向后扫描数组元素&#xff0c;标记出一个可能成为主元素的元素 Num。然后重新计数&#xff0c;确认 Num 是否是主元素。 我们可分为以下两步: 1.选取候选的主元素。依…

YOLO目标检测——汽车头部尾部检测数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用&#xff1a;用于训练自动驾驶系统中的车辆感知模块&#xff0c;以实现对周围车辆头部和尾部的准确检测和识别数据集说明&#xff1a;汽车头部尾部检测数据集&#xff0c;真实场景的高质量图片数据&#xff0c;数据场景丰富标签说明&#xff1a;使用lableimg标注软…

【JMeter】后置处理器的分类以及场景介绍

1.常用后置处理器的分类 Json提取器 针对响应体的返回结果是json格式的会自动生成新的变量名为【提取器中变量名_MatchNr】,取到的个数由jsonpath expression取到的个数决定 可以当作普通变量调用,调用语法:${提取器中变量名_MatchNr}正则表达式提取器 返回结果是任何数据格…

一款好用的PDF转翻页电子书网站

​你是否曾经遇到过PDF文件无法翻页或者阅读不便的问题&#xff1f;今天给大家推荐一款好用的PDF转翻页电子书网站&#xff0c;让你轻松阅读PDF文件&#xff0c;不再烦恼翻页问题&#xff01; 一、网站介绍 这款FLBOOK在线制作电子杂志网站支持多种电子文件格式转换&#xff0…

JWT简介 JWT结构 JWT示例 前端添加JWT令牌功能 后端程序

目录 1. JWT简述 1.1 什么是JWT 1.2 为什么使用JWT 1.3 JWT结构 1.4 验证过程 2. JWT示例 2.1 后台程序 2.2 前台加入jwt令牌功能 1. JWT简述 1.1 什么是JWT Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准&#xff08;(RFC 7…

〔001〕虚幻 UE5 安装教程

✨ 目录 🎈 下载启动程序🎈 注册个人账户🎈 选择引擎版本🎈 选择安装选项🎈 虚幻商城的使用🎈 每月免费插件🎈 安装插件🎈 下载启动程序 下载地址:https://www.unrealengine.com/zh-CN/download点击上面地址,下载 UE5 启动程序并安装🎈 注册个人账户 打开商…

Linux多虚拟主机和配置限制访问与日志

目录 一、多虚拟主机 1.配置单网卡多个ip 2.给每个主机站点设置主页 3.测试访问 二、限制访问 1.限制所有 2.放行192.168.0.0/24网段访问 三、日志与状态页 1.定义访客日志 2.状态页配置 一、多虚拟主机 1.配置单网卡多个ip ip address add 192.168.0.231/24 dev e…

案例研究|腾讯音乐娱乐集团与JumpServer共探安全运维审计解决方案

近年来&#xff0c;得益于人民消费水平的提升以及版权意识的加强&#xff0c;用户付费意愿和在线用户数量持续增长&#xff0c;中国在线音乐市场呈现出稳定增长的发展态势。随着腾讯音乐于2018年12月上市&#xff0c;进一步推动了中国在线音乐市场的发展。 腾讯音乐娱乐集团&a…

rust入门基础案例:猜数字游戏

案例出处是《Rust权威指南》&#xff0c;书中有更加详细的解释。从这个例子中&#xff0c;我们可以了解到 rust 的两个操作&#xff1a; 如何从控制台读取用户输入rust 如何生成随机数 代码格式化 编译器可在保存时对代码做格式化处理&#xff0c;底层调用 rustfmt 来实现&a…

Kubernetes Dashboard 用户名密码方式登录

Author&#xff1a;rab 前言 为了 K8s 集群安全&#xff0c;默认情况下 Dashboard 以 Token 的形式登录的&#xff0c;那如果我们想以用户名/密码的方式登录该怎么操作呢&#xff1f;其实只需要我们创建用户并进行 ClusterRoleBinding 绑定即可&#xff0c;接下来是具体的操作…

MCU常见通信总线串讲(二)—— RS232和RS485

&#x1f64c;秋名山码民的主页 &#x1f602;oi退役选手&#xff0c;Java、大数据、单片机、IoT均有所涉猎&#xff0c;热爱技术&#xff0c;技术无罪 &#x1f389;欢迎关注&#x1f50e;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; 获取源码&#xff0c;添加WX 目录 前言一…

低代码工具的常见用例与受众市场

目录 一、低代码工具的常见用例是什么&#xff1f; 1.业务流程管理&#xff08;BPM&#xff09; 2.自定义应用程序开发 3.数据管理和分析 4.移动应用程序开发 二、低代码受众和市场 1.制造商 2.个人开发者/自由职业者 3.代理商 4.小型企业和初创企业 5.中型企业 6.营销团队 7.软…