解决deepspeed框架的bug:不保存调度器状态,模型训练重启时学习率从头开始

deepspeed存在一个bug,即在训练时不保存调度器状态,因此如果训练中断后再重新开始训练,调度器还是会从头开始而不是接着上一个checkpoint的调度器状态来训练。这个bug在deepspeed的github中也有其他人提出:https://github.com/microsoft/DeepSpeed/issues/3875
因此我们需要写一个保存调度器状态的代码,才可以解决这个问题。
具体方法是加一个callback类,专门负责保存调度器的状态以及在训练重新开始时加载调度器的状态:
先在训练文件中给trainer加一个callback

from smoe.callbacks.save_model import SchedulerStateCallback
trainer.add_callback(SchedulerStateCallback)
class SchedulerStateCallback(TrainerCallback):def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):if os.environ.get("RANK", "0") == "0":#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = scheduler.state_dict()#save_path = os.path.join(args.output_dir, SCHEDULER_NAME)# 使用 PREFIX_CHECKPOINT_DIR 和 global_step 创建检查点目录名checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"# 完整的检查点目录路径checkpoint_path = os.path.join(args.output_dir, checkpoint_folder)# 如果目录不存在,则创建它if not os.path.exists(checkpoint_path):os.makedirs(checkpoint_path)# 完整的保存路径save_path = os.path.join(checkpoint_path, SCHEDULER_NAME)# 保存scheduler状态torch.save(scheduler_state, save_path)def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):# 如果resume_from_checkpoint设置了有效路径if args.resume_from_checkpoint is not None:load_path = os.path.join(args.resume_from_checkpoint, SCHEDULER_NAME)# 如果该路径下有保存的调度器状态,则加载它if os.path.exists(load_path):#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = torch.load(load_path)scheduler.load_state_dict(scheduler_state)

解决效果如下,我们可以看到,在chaeckpoint10重新开始训练的时候,学习率是接着之前的学习率开始的(5.5e-7),而不是从头开始(0.5e-7):
在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s中的容器

目录 容器容器的状态容器的重启策略pause容器init容器 容器 容器的状态 Running(运行中):容器正在运行并且正常工作。Waiting(等待中):容器正在等待某些条件满足,例如等待其他容器就绪、等待网…

微信小程序——如何获取到输入框的值

在微信小程序中&#xff0c;可以通过以下几种方式来获取输入框的值&#xff1a; 使用 bindinput 绑定输入事件&#xff0c;通过 event.detail.value 获取输入框的值。具体操作如下&#xff1a; <input bindinput"onInput" placeholder"请输入内容">…

Google 开源库Guava详解(集合工具类)—Maps、Multisets、Multimaps

一、Maps Maps有许多很酷的实用程序&#xff0c;值得单独解释。 1、uniqueIndex Maps.uniqueIndex&#xff08;Iterable&#xff0c;Function&#xff09;解决了一个常见的情况&#xff0c;即有一堆对象&#xff0c;每个对象都有一些唯一的属性&#xff0c;并希望能够根据该…

Neo4j 基本语法

一、基本语法 1、新建节点 &#xff08;1&#xff09;基本语法&#xff1a; () 代表节点 示例&#xff1a; CREATE (u:User {uid:970939424 }) // 节点类型为User&#xff0c;属性值为uid970939424CREATE (u:Round {rid:7194842697444819113 }) // 节点类型为Rou…

【广州华锐互动】AR技术在配电系统运维中的应用

随着科技的不断发展&#xff0c;AR(增强现实)技术逐渐走进了我们的生活。在电力行业&#xff0c;AR技术的应用也为巡检工作带来了许多新突破&#xff0c;提高了巡检效率和安全性。本文将从以下几个方面探讨AR配电系统运维系统的新突破。 首先&#xff0c;AR技术可以实现虚拟巡检…

数据仓库-核心概念

数据仓库 数据仓库&#xff0c;英文名称为Data Warehouse&#xff0c;可简写为DW或DWH。数据仓库&#xff0c;是为企业所有级别的决策制定过程&#xff0c;提供所有类型数据支持的战略集合。它是单个数据存储&#xff0c;出于分析性报告和决策支持目的而创建。为需要业务智能的…

<图像处理> 空间滤波基础

空间滤波基础 图像滤波是一种常见的图像处理技术&#xff0c;用于平滑图像、去除噪音和边缘检测等任务。图像滤波的基本原理是在进行卷积操作时&#xff0c;通过把每个像素的值替换为该像素及其邻域的设定的函数值来修改图像。 预备知识&#xff1a;可分离滤波核、边缘填充。…

Vue知识系列(1)每天10个小知识点

目录 系列文章目录知识点**1. Vue修饰符**的概念、作用、原理、特性、优点、缺点、区别、使用场景**2. 双向数据绑定**的概念、作用、原理、特性、优点、缺点、区别、使用场景**3. MVVM、MVC、MVP** 的概念、作用、原理、特性、优点、缺点、区别、使用场景**4. slot** 的概念、…

Android Jetpack架构组件库:Hilt

一、开发者官网关于Hilt库使用链接如下 使用 Hilt 实现依赖项注入 Hilt版本说明 二、工程目录图 请点击下面工程名称&#xff0c;跳转到代码的仓库页面&#xff0c;将工程 下载下来 Demo Code 里有详细的注释 代码&#xff1a;LearnJetpack-hilt&#xff1a;hilt版本2.48 代…

Redis集群3.2.11离线安装详细版本(使用Ruby)

1.安装软件准备 1.Redis版本下载 Index of /releases/http://download.redis.io/releases/ 1.2gcc环境准备 GCC(GNU Compiler Collection,GNU编译器套件)是一套用于编译程序代码的开源编译器工具集。它的主要用途是将高级编程语言(如C、C++、Fortran等)编写的源代码转换…

【项目 计网12】4.32UDP通信实现 4.33广播 4.34组播 4.35本地套接字通信

文章目录 4.32UDP通信实现udp_client.cudp_server.c 4.33广播bro_server.cbro_client.c 4.34组播multi_server.cmulti_client.c 4.35本地套接字通信ipc_server.cipc_client.c 4.32UDP通信实现 udp_client.c #include <stdio.h> #include <stdlib.h> #include <…

vmware网卡(网络适配器)桥接、NAT、仅主机3种模式解析

Bridged&#xff08;桥接模式&#xff09;、NAT&#xff08;网络地址转换模式&#xff09;、Host-Only&#xff08;仅主机模式&#xff09; Windows系统安装好vmware后&#xff0c;在网络连接中会生成VMnet1和VMnet8两个虚拟网卡。 VMnet1作用于仅主机模式&#xff0c;VMnet8作…

目标检测笔记(十五): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)

文章目录 一、目标检测介绍二、YOLOX介绍三、源码获取四、环境搭建4.1 环境检测 五、数据集准备六、模型训练七、模型验证八、模型测试 一、目标检测介绍 目标检测&#xff08;Object Detection&#xff09;是计算机视觉领域的一项重要技术&#xff0c;旨在识别图像或视频中的…

单元测试与自测

单元测试在百度百科的定义&#xff1a; 自测在百度百科的定义&#xff1a; 单元测试是测一个类或一个函数&#xff0c;自立门第main函数&#xff0c;不依赖于项目&#xff0c;预期的是这个类或函数是没有问题的。程序编码完成之后至各种测试再到用户使用一二十年出现的任何bug都…

娱乐时间 —— 用python将图片转为excel十字绘

最近看蛮多朋友在玩&#xff0c;要么只能画比较简单的&#xff0c;要么非常花时间。想了下本质上就是把excel对应的单元格涂色&#xff0c;如果能知道哪些格子要上什么颜色&#xff0c;用编程来实现图片转为excel十字绘应该是很方便的。 图片的每一个像素点都可以数值化&#x…

DRF03-权限与分页

文章目录 1. 认证Authentication2. 权限Permissions使用提供的权限举例自定义权限3. 限流Throttling基本使用可选限流类4. 过滤Filtering5. 排序Ordering6. 分页Pagination可选分页器7. 异常处理 ExceptionsREST framework定义的异常8. 自动生成接口文档coreapi安装依赖设置接口…

Jmeter如何设置中文版

第一步&#xff1a;找到 apache-jmeter-5.4.3\bin目录下的 jmeter.properties 第二步:打开 三&#xff0c;ctrf 输入languageen&#xff0c;注释掉&#xff0c;增加以行修改如下 四&#xff0c;ctrs 保存修改内容&#xff0c;重新打开jmeter就可以了

微信小程序Day2笔记

1、WXML模板语法 1. 数据绑定 数据绑定的基本原则 在data中定义数据在WXML中使用数据 2. 在data中定义页面的数据 在页面对应的.js文件中&#xff0c;把数据定义到data对象中。 3. Mustache语法的格式 把data中的数据绑定到页面中渲染&#xff0c;使用Mustache语法&…

PHP8中获取并删除数组中最后一个元素-PHP8知识详解

在php8中&#xff0c;array_pop()函数将返回数组的最后一个元素&#xff0c;并且将该元素从数组中删除。语法格式如下&#xff1a; array_pop(目标数组) 获取并删除数组中最后一个元素&#xff0c;参考代码&#xff1a; <?php $stu array(s001>明明,s002>亮亮,s…

【FPGA】通俗理解从VGA显示到HDMI显示

注&#xff1a;大部分参考内容来自“征途Pro《FPGA Verilog开发实战指南——基于Altera EP4CE10》2021.7.10&#xff08;上&#xff09;” 贴个下载地址&#xff1a; 野火FPGA-Altera-EP4CE10征途开发板_核心板 — 野火产品资料下载中心 文档 hdmi显示器驱动设计与验证 — …