解决deepspeed框架的bug:不保存调度器状态,模型训练重启时学习率从头开始

deepspeed存在一个bug,即在训练时不保存调度器状态,因此如果训练中断后再重新开始训练,调度器还是会从头开始而不是接着上一个checkpoint的调度器状态来训练。这个bug在deepspeed的github中也有其他人提出:https://github.com/microsoft/DeepSpeed/issues/3875
因此我们需要写一个保存调度器状态的代码,才可以解决这个问题。
具体方法是加一个callback类,专门负责保存调度器的状态以及在训练重新开始时加载调度器的状态:
先在训练文件中给trainer加一个callback

from smoe.callbacks.save_model import SchedulerStateCallback
trainer.add_callback(SchedulerStateCallback)
class SchedulerStateCallback(TrainerCallback):def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):if os.environ.get("RANK", "0") == "0":#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = scheduler.state_dict()#save_path = os.path.join(args.output_dir, SCHEDULER_NAME)# 使用 PREFIX_CHECKPOINT_DIR 和 global_step 创建检查点目录名checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"# 完整的检查点目录路径checkpoint_path = os.path.join(args.output_dir, checkpoint_folder)# 如果目录不存在,则创建它if not os.path.exists(checkpoint_path):os.makedirs(checkpoint_path)# 完整的保存路径save_path = os.path.join(checkpoint_path, SCHEDULER_NAME)# 保存scheduler状态torch.save(scheduler_state, save_path)def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):# 如果resume_from_checkpoint设置了有效路径if args.resume_from_checkpoint is not None:load_path = os.path.join(args.resume_from_checkpoint, SCHEDULER_NAME)# 如果该路径下有保存的调度器状态,则加载它if os.path.exists(load_path):#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = torch.load(load_path)scheduler.load_state_dict(scheduler_state)

解决效果如下,我们可以看到,在chaeckpoint10重新开始训练的时候,学习率是接着之前的学习率开始的(5.5e-7),而不是从头开始(0.5e-7):
在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【广州华锐互动】AR技术在配电系统运维中的应用

随着科技的不断发展,AR(增强现实)技术逐渐走进了我们的生活。在电力行业,AR技术的应用也为巡检工作带来了许多新突破,提高了巡检效率和安全性。本文将从以下几个方面探讨AR配电系统运维系统的新突破。 首先,AR技术可以实现虚拟巡检…

Android Jetpack架构组件库:Hilt

一、开发者官网关于Hilt库使用链接如下 使用 Hilt 实现依赖项注入 Hilt版本说明 二、工程目录图 请点击下面工程名称,跳转到代码的仓库页面,将工程 下载下来 Demo Code 里有详细的注释 代码:LearnJetpack-hilt:hilt版本2.48 代…

Redis集群3.2.11离线安装详细版本(使用Ruby)

1.安装软件准备 1.Redis版本下载 Index of /releases/http://download.redis.io/releases/ 1.2gcc环境准备 GCC(GNU Compiler Collection,GNU编译器套件)是一套用于编译程序代码的开源编译器工具集。它的主要用途是将高级编程语言(如C、C++、Fortran等)编写的源代码转换…

【项目 计网12】4.32UDP通信实现 4.33广播 4.34组播 4.35本地套接字通信

文章目录 4.32UDP通信实现udp_client.cudp_server.c 4.33广播bro_server.cbro_client.c 4.34组播multi_server.cmulti_client.c 4.35本地套接字通信ipc_server.cipc_client.c 4.32UDP通信实现 udp_client.c #include <stdio.h> #include <stdlib.h> #include <…

vmware网卡(网络适配器)桥接、NAT、仅主机3种模式解析

Bridged&#xff08;桥接模式&#xff09;、NAT&#xff08;网络地址转换模式&#xff09;、Host-Only&#xff08;仅主机模式&#xff09; Windows系统安装好vmware后&#xff0c;在网络连接中会生成VMnet1和VMnet8两个虚拟网卡。 VMnet1作用于仅主机模式&#xff0c;VMnet8作…

目标检测笔记(十五): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)

文章目录 一、目标检测介绍二、YOLOX介绍三、源码获取四、环境搭建4.1 环境检测 五、数据集准备六、模型训练七、模型验证八、模型测试 一、目标检测介绍 目标检测&#xff08;Object Detection&#xff09;是计算机视觉领域的一项重要技术&#xff0c;旨在识别图像或视频中的…

单元测试与自测

单元测试在百度百科的定义&#xff1a; 自测在百度百科的定义&#xff1a; 单元测试是测一个类或一个函数&#xff0c;自立门第main函数&#xff0c;不依赖于项目&#xff0c;预期的是这个类或函数是没有问题的。程序编码完成之后至各种测试再到用户使用一二十年出现的任何bug都…

娱乐时间 —— 用python将图片转为excel十字绘

最近看蛮多朋友在玩&#xff0c;要么只能画比较简单的&#xff0c;要么非常花时间。想了下本质上就是把excel对应的单元格涂色&#xff0c;如果能知道哪些格子要上什么颜色&#xff0c;用编程来实现图片转为excel十字绘应该是很方便的。 图片的每一个像素点都可以数值化&#x…

Jmeter如何设置中文版

第一步&#xff1a;找到 apache-jmeter-5.4.3\bin目录下的 jmeter.properties 第二步:打开 三&#xff0c;ctrf 输入languageen&#xff0c;注释掉&#xff0c;增加以行修改如下 四&#xff0c;ctrs 保存修改内容&#xff0c;重新打开jmeter就可以了

微信小程序Day2笔记

1、WXML模板语法 1. 数据绑定 数据绑定的基本原则 在data中定义数据在WXML中使用数据 2. 在data中定义页面的数据 在页面对应的.js文件中&#xff0c;把数据定义到data对象中。 3. Mustache语法的格式 把data中的数据绑定到页面中渲染&#xff0c;使用Mustache语法&…

PHP8中获取并删除数组中最后一个元素-PHP8知识详解

在php8中&#xff0c;array_pop()函数将返回数组的最后一个元素&#xff0c;并且将该元素从数组中删除。语法格式如下&#xff1a; array_pop(目标数组) 获取并删除数组中最后一个元素&#xff0c;参考代码&#xff1a; <?php $stu array(s001>明明,s002>亮亮,s…

【FPGA】通俗理解从VGA显示到HDMI显示

注&#xff1a;大部分参考内容来自“征途Pro《FPGA Verilog开发实战指南——基于Altera EP4CE10》2021.7.10&#xff08;上&#xff09;” 贴个下载地址&#xff1a; 野火FPGA-Altera-EP4CE10征途开发板_核心板 — 野火产品资料下载中心 文档 hdmi显示器驱动设计与验证 — …

前端list.push,封装多个对象

js var fruit [apple, banana];fruit.push(pear);console.log(fruit); // [apple, banana, pear]现在为对象 data1:{addUser: 1,editUser: 1,addTime: null,editTime: 1527410579000,userId: 3,systemNo: mc,userName: zengzhuo,userPassword: e10adc3949ba59abbe56e057f20f88…

【广州华锐互动】AR远程协助技术提供实时远程协作和指导

随着科技的不断发展&#xff0c;企业的运营管理模式也在不断地进行创新和升级。在这个过程中&#xff0c;AR&#xff08;增强现实&#xff09;技术的应用逐渐成为了企业运维管理的新兴趋势。AR远程协助平台作为一种结合了AR技术和远程协助理念的技术手段&#xff0c;为企业运维…

Netty-NIO

文章目录 一、NIO-Selector1.处理accept2.cancel3.处理read4.处理客户端断开5. 处理消息的边界6. 写入内容过多的问题7. 处理可写事件 二、多线程优化三、NIO概念剖析1. stream 和 channel2. IO模型2.1 阻塞IO2.2 非阻塞IO2.3多路复用2.4 同步异步 3. 零拷贝3.1 NIO优化3.2 sen…

hive葵花宝典:hive函数大全

文章目录 版权声明函数1 函数分类2 查看函数列表3 数学函数取整函数: round指定精度取整函数: round向下取整函数: floor向上取整函数: ceil取随机数函数: rand幂运算函数: pow绝对值函数: abs 4 字符串函数字符串长度函数&#xff1a;length字符串反转函数&#xff1a;reverse…

autoware.ai感知随笔--地面滤波

autwoware.ai中点云预处理–points_preprocessor points_preprocessor cloud_transformer: 点云坐标转换,将输入的点云转化为velodyne坐标系下的点云。 compare_map_filter: 对比激光雷达点云和点云地图&#xff0c;然后提取&#xff08;或去除&#xff09;一致的点。 |input_…

机器学习实战-系列教程7:SVM分类实战2线性SVM(鸢尾花数据集/软间隔/线性SVM/非线性SVM/scikit-learn框架)项目实战、代码解读

&#x1f308;&#x1f308;&#x1f308;机器学习 实战系列 总目录 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 SVM分类实战1之简单SVM分类 SVM分类实战2线性SVM SVM分类实战3非线性SVM 3、不同软间隔C值 3.1 数据标准化的影响 如图左边是没…

css 左右宽固定,中间自适应——双飞翼布局

最近面试的时候遇到一个提问说&#xff0c;如何做到一个左右宽度固定&#xff0c;中间自适应的布局&#xff0c;我的答案不重要&#xff0c;重要的是不是面试官想听到的答案&#xff0c;这样问大概率他想听到的答案一定是双飞翼布局&#xff0c;所以今天就手敲一个双飞翼布局让…

ES-索引管理

前言 数据类型 ​ 搜索引擎是对数据的检索&#xff0c;所以我们先从生活中的数据说起。我们生活中的数据总体分为两种&#xff1a; 结构化数据非结构化数据 结构化数据&#xff1a; 也称作行数据&#xff0c;是由二维表结构来逻辑表达和实现的数据&#xff0c;严格地遵循数…