Megatron-LM源码系列(六):Distributed-Optimizer分布式优化器实现Part1

1. 使用说明

在megatron中指定--use-distributed-optimizer就能开启分布式优化器, 参数定义在megatron/arguments.py中。分布式优化器的思路是将训练中的优化器状态均匀地分布到不同数据并行的rank结点上,相当于开启ZERO-1的训练。

    group.add_argument('--use-distributed-optimizer', action='store_true',help='Use distributed optimizer.')

在使用--use-distributed-optimizer, 同时会check两个参数 args.DDP_impl == 'local'(默认开启)和args.use_contiguous_buffers_in_local_ddp(默认开启)。

    # If we use the distributed optimizer, we need to have local DDP# and we should make sure use-contiguous-buffers-in-local-ddp is on.if args.use_distributed_optimizer:assert args.DDP_impl == 'local'assert args.use_contiguous_buffers_in_local_ddp

分布式优化器节省的理论显存值依赖参数类型和梯度类型,以下是每一个parameter对应占用的理论字节数(d表示数据并行的size大小,也就是一个数据并行中的卡数, 等于 T P × P P TP \times PP TP×PP ):

训练数据类型Non-distributed optim(单位Byte)Distributed optim(单位Byte)
float16 param, float16 grads204 + 16/d
float16 param, fp32 grads186 + 12/d
fp32 param, fp32 grads168 + 8/d

2. 实现介绍

  • Distributed-Optimizer分布式优化器的主要实现是通过连续的grad buffer来进行的,grad buffer中用于模型状态和优化器状态之间进行parameter参数和grad梯度的通信。grad buffer中使用reduce-scatter和all-gather进行通信。

  • 数据流如下:
    在这里插入图片描述

    1. 在每个dp的rank上计算完grad后,组成待更新的grad buffer数组
    2. 更新的时候通过reduce-scatter将grad buffer切分到各个rank上
    3. 在每个rank上完成优化器的step操作
    4. 最后将所有结果执行allgather操作得到更新后的grad buffer。
  • 以fp16类型grad为例,grad buffer分片说明如下:
    在这里插入图片描述

    • 一共有4个参数,分别用绿/黄/蓝/红表示;总参数大小为16个fp16类型数据
    • 按DP中rank的个数对总数据均匀切分
    • 如果参数过大,每个rank可能会只包含部分参数的数据,所以要考虑参数的偏移
    • 每个DP rank中的每个param参数都对应有3个偏移,一个是world_index表示总的数据偏移,一个是local_index表示在当前rank中的数据偏移,一个是param_index相对于param来说,表示当前rank结点存的数据的偏移。
    • 以黄色参数Param1为例,在rank0存了Param1的一个元素,rank1存了Param1的4个元素;world_index来说rank0上黄色部分的元素是总数据的[3,4], rank1上黄色部分的4个元素是总数据的[4,8]; local_index来说在rank0上表示[3,4],rank1表示当前结点全部的4个元素,范围也就是[0,4];param_index来说,对于rank0上的Param1的param_index就是[0,1],在rank2上的param_index就是[1,5];
  • 关键步骤详解:

    1. 上图中每个方块看成是一个grad buffer中的一个fp16类型元素,在反向结束以后,grad buffer中有16个fp16类型的元素
    2. 在每一个DP rank上调用reduce-scatter操作
    3. 每个DP rank的grad buffer中都有4个fp16类型元素经过了reduce-scatter操作更新,没更新的12个fp16类型元素等待后续垃圾回收
    4. 每个DP rank从grad buffer中拷贝更新后的4个fp16类型元素到fp32类型的main grad buffer中,准备开始后续的更新操作,例如
      • DP rank0拷贝[0:4]个元素
      • DP rank1拷贝[4:8]个元素
      • DP rank2拷贝[8:12]个元素
      • DP rank3拷贝[12:16]个元素
    5. 执行Optimizer.step(), step()操作必须通过fp32类型来进行计算
    6. 每个DP rank从main grad buffer中拷贝step()更新后的4个fp32类型元素到fp16类型的grad buffer中
    7. 执行allgather操作, 这样每个grad buffer就都是最新更新后的数据了
    8. 基于grad buffer来更新各个模型的fp16类型的参数
    9. 开始进行下一轮的更新

3. 源码实现

3.1 程序入口

  • 初始化的入口在文件megatron/training.pyget_model函数中,在创建LocalDDP的实例中会传入args.use_contiguous_buffers_in_local_ddp
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDPdef get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):...if wrap_with_ddp:if args.DDP_impl == 'torch':...elif args.DDP_impl == 'local':model = [LocalDDP(model_module,args.accumulate_allreduce_grads_in_fp32,args.use_contiguous_buffers_in_local_ddp)for model_module in model]...
  • 训练的入口定义在train_step函数中, 基本流程如下:
def train_step(forward_step_func, data_iterator,model, optimizer, opt_param_scheduler):...# 清除gradif args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:for partition in model:partition.zero_grad_buffer()optimizer.zero_grad()...# 执行前反向计算losses_reduced = forward_backward_func(...)...# 对梯度执行Reduce-Scatter操作optimizer.reduce_model_grads(args, timers)...# 更新梯度timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)timers('optimizer').stop()...# 对更新后的param执行gather操作if update_successful:optimizer.gather_model_params(args, timers)...# 通过scheduler更新学习率if update_successful:increment = get_num_microbatches() * \args.micro_batch_size * \args.data_parallel_sizeopt_param_scheduler.step(increment=increment)skipped_iter = 0else:skipped_iter = 1...

3.2 grad buffer初始化(DistributedDataParallel类)

  • grad buffer初始化是在类DistributedDataParallel的init函数中, 源码定义在megatron/optimizer/distrib_optimizer.py文件中。
class DistributedDataParallel(DistributedDataParallelBase):def __init__(self, module,accumulate_allreduce_grads_in_fp32,use_contiguous_buffers):
  • 创建grad buffer和index map
            self._grad_buffers = {}self._grad_buffer_param_index_map = {}data_parallel_world_size = mpu.get_data_parallel_world_size()
  • 按类型分别计算每个类型元素的个数,使用type_num_elements map进行存储,key是元素类型,value是类型出现的元素个数
            # First calculate total number of elements per type.type_num_elements = {}for param in self.module.parameters():if param.requires_grad:dtype = _get_buffer_type(param)type_num_elements[dtype] = type_num_elements.get(dtype, 0) \+ param.data.nelement()
  • 实际开始分配grad buffer, 为了支持被DP并行数正好切分,需要先对每个类型出现的个数进行padding操作;然后通过MemoryBuffer进行存储的分配
            # Allocate the buffer.for dtype, num_elements in type_num_elements.items():# If using distributed optimizer, pad memory buffer to be# multiple of data_parallel_world_size. (This padding is done# due to a constraint with the reduce_scatter op, which requires# all tensors have equal size. See: optimizer.py.)num_elements_padded = data_parallel_world_size * \int(math.ceil(num_elements / data_parallel_world_size))# Allocate grad buffer.self._grad_buffers[dtype] = MemoryBuffer(num_elements,num_elements_padded,dtype)
  • 从grad buffer中给每一个param参数分配对应的main_grad空间,在分配main_grad时根据每个param参数的类型从对应的self._grad_buffers[dtype]中得到跟param.data.shape一样的tensor,这里的tensor与grad buffer共享存储。同时grad buffer的分配是按倒序来分配的,比如self.module.parameters()中有三个参数分别是[p1, p2, p3], 在grad buffer中存储则是[p3_grad, p2_grad, p1_grad]_grad_buffer_param_index_map用来记录每个param的梯度在grad buffer中存储的起始和结束位置。
            ...# Assume the back prop order is reverse the params order,# store the start index for the gradients.for param in self.module.parameters():if param.requires_grad:dtype = _get_buffer_type(param)type_num_elements[dtype] -= param.data.nelement()# get的第二个参数是start_index,这里的start_index是从grad_buffer从大到小来算的param.main_grad = self._grad_buffers[dtype].get(param.data.shape, type_num_elements[dtype])if dtype not in self._grad_buffer_param_index_map:self._grad_buffer_param_index_map[dtype] = {}self._grad_buffer_param_index_map[dtype][param] = (type_num_elements[dtype],type_num_elements[dtype] + param.data.nelement(),)
  • 遍历每一个参数,对于每一个参数的grad_fn的下一个function累加grad_acc函数进行改写,由于param本身没有grad_fn,通过trick方式使用param.expand_as给param加上了grad_fn函数。
            ...# Backward hook.# Accumalation function for the gradients. We need# to store them so they don't go out of scope.self.grad_accs = []# Loop over all the parameters in the model.for param in self.module.parameters():if param.requires_grad:# 使用expand_as使param具有grad_fn.param_tmp = param.expand_as(param)# 获取梯度累加函数,并注册hook改写grad_acc = param_tmp.grad_fn.next_functions[0][0]grad_acc.register_hook(self._make_param_hook(param))self.grad_accs.append(grad_acc)def _make_param_hook(self, param):"""Create the all-reduce hook for backprop."""# Hook used for back-prop.def param_hook(*unused):# Add the gradient to the buffer.if param.grad is not None:# The gradient function of linear layers is fused with GEMMsparam.main_grad.add_(param.grad.data)# Now we can deallocate grad memory.param.grad = Nonereturn param_hook

4. 参考

  • Megatron-LM源码系列(六):Distributed-Optimizer分布式优化器实现Part1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/588601.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringIOC之ClassPathXmlApplicationContext

博主介绍:✌全网粉丝5W,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战,博主也曾写过优秀论文,查重率极低,在这方面有丰富的经验…

PostgreSQL16.1(Windows版本)

1、卸载原有的PostgreSQL   点击Next即可。  点击OK即可。 卸载完成。 2、安装 (1) 前两部直接Next,第二部可以换成自己想要安装的路径。 (2) 直接点击Next。…

雪花算法(Snowflake)介绍和Java实现

1、雪花算法介绍 (1) 雪花算法(SnowFlake)是分布式微服务下生成全局唯一ID,并且可以做到去中心化的常用算法,最早是Twitter公司在其内部的分布式环境下生成ID的方式。 雪花算法的名字可以这么理解,世界上没有两片完全相同的雪花,…

Select缺点及代码示例

一、Select缺点 二、服务器端 #include <stdio.h> #include <arpa/inet.h> #include <unistd.h> #include <stdlib.h> #include <string.h> #include <sys/select.h>int main() {// 创建socketint lfd socket(PF_INET, SOCK_STREAM, 0)…

006、函数

1. 一个小技巧 在前面文章中&#xff0c;我们提到&#xff0c;在黑窗口中输入 code . 命令可以快速在 Visual Studio Code 中打开新建的项目&#xff0c;这个是你刚刚新建了项目&#xff0c;并且黑窗口正好是打开的情况下。 如果是之前创建的项目&#xff0c;用上面的方法就会有…

git(安装,常用命令,分支操作,gitee,IDEA集成git,IDEA集成gitee,IDEA集成github,远程仓库操作)

文章目录 1. Git概述1.1 何为版本控制1.2 为什么需要版本控制1.3 版本控制工具1.4 Git简史1.5 Git工作机制1.6 Git和代码托管中心 2. Git安装3. Git常用命令3.1 设置用户签名3.1.1 说明3.1.2 语法3.1.3 案例实操 3.2 初始化本地库3.2.1 基本语法3.2.2 案例实操3.2.3 结果查看 3…

【Java】log4j和slf4j区别

log4j&#xff1a;Apache Software Foundation 开源 slf4j&#xff1a;不支持日志滚动等高级功能 在开源库或内部库中使用 SLF4J&#xff0c;将使其独立于任何特定的日志记录实现&#xff0c;这意味着无需为多个库管理多个日志记录配置&#xff0c;您的客户端将会很需要这一点…

【k8s】deamonset文件和说明

目录 deamonset的相关命令 deamonset的定义 deamonset的使用场景 deamonset的例子 deamonset字段说明 serviceAccountName DaemonSet的结构及其各个部分的作用 deamonset的相关命令 #查看<name-space>空间内有哪些deamonset kubectl get DaemonSet -n <na…

Django 学习教程- Django 入门案例

Django学习教程系列 Django学习教程-介绍与安装 前言 本教程是为 Django 5.0 编写的&#xff0c;它支持 Python 3.10 至以上。如果 Django 版本不匹配&#xff0c;可以参考教程 使用右下角的版本切换器来获取你的 Django 版本 &#xff0c;或将 Django 更新到最新版本。如果…

Winclone Pro 10 for Mac:轻松备份和还原你的Windows系统

Winclone Pro 10 for Mac是一款专为Mac用户设计的备份和还原软件&#xff0c;旨在帮助用户轻松管理和保护他们的Windows系统。无论是为了数据安全还是系统的稳定性&#xff0c;Winclone Pro 10都能提供全面的解决方案。 这款软件具备强大的备份功能&#xff0c;能够快速而准确…

Java流程控制语句(if语句,switch语句,for循环,while循环,do...while循环,三种循环的区别)

文章目录 第一章 流程控制语句1.1 流程控制语句分类1.2 顺序结构 第二章 判断语句&#xff1a;if语句2.1 if语句格式1练习1&#xff1a;老丈人选女婿练习2&#xff1a;考试奖励第一种格式的细节&#xff1a; 2.2 if语句格式2练习1&#xff1a;吃饭练习2&#xff1a;影院选座 2.…

AI产品经理 - 如何做一款软硬协同AI产品

【背景】从0做一款软硬协同的AI产品&#xff0c;以智能医药保温箱 1.以智能医药保温箱 2.调研定义市场方向 地点&#xff1a;医药、实验室 场景&#xff1a;长宽高/装箱/运输/实验室 3.需求挖掘 4.如何进行软硬件AI产品工作 软硬件产品设计&#xff1a;功能/硬件外观设计、…

SetWindowsHookEx: 全局钩子实现键盘记录器

简介 SetWindowsHookEx 钩子(Hook)&#xff0c;是Windows消息处理机制的一个平台&#xff0c;应用程序可以在上面设置子程以监视指定窗口的某种消息&#xff0c;而且所监视的窗口可以是其他进程所创建的。当消息到达后&#xff0c;在目标窗口处理函数之前处理它。钩子机制允许应…

Hive生产调优介绍

1.Fetch抓取 Fetch抓取是指&#xff0c;Hive中对某些情况的查询可以不必使用MapReduce计算。例如&#xff1a;SELECT * FROM employees;在这种情况下&#xff0c;Hive可以简单地读取employee对应的存储目录下的文件&#xff0c;然后输出查询结果到控制台。 在hive-default.xml…

云卷云舒:构建业务型电信智能运维方法

1 引言 智能运维&#xff08;AIOps-Algorithmic IT Operations基于算法的IT运维&#xff09;是人工智能技术在IT运维领域的运用&#xff0c;引用Gartner 的报告的一段话“未来几年&#xff0c;将近50%的企业将会在他们的业务和IT运维方面采用AIOps&#xff0c;远远高于今天的10…

php的laravel权限问题

1.这是我新建的一个路由&#xff0c;然后就是说每新建一个路由都要给他开个权限&#xff01;&#xff01;&#xff01;&#xff01; 2.这个是组内大佬写的&#xff1a; 我们也可以在里面加&#xff0c;也可以在浏览器的页面手动加&#xff08;对我们新手来说还是浏览器的页面…

matlab导出高清图片,须经修改后放入latex(例如添加文字说明,matlab画图不易操作)

一、背景 我们在写文章时&#xff0c;使用matlab画图后&#xff0c;如果不需要对图片进行额外修改或调整&#xff0c;例如添加文字说明&#xff0c;即可直接从matlab导出eps格式图片&#xff0c;然后插入到latex使用。 通常latex添加图片&#xff0c;是需要eps格式的。 但很…

微服务实战系列之Dubbo(下)

前言 眼看着2023即将走远&#xff0c;心里想着似乎还有啥&#xff0c;需要再跟各位盆友叨叨。这不说曹操&#xff0c;曹操就来了。趁着上一篇Dubbo博文的余温尚在&#xff0c;博主兴匆匆地“赶制”了Dubbo的下集&#xff0c;以飨读者。 上一篇博主依然从Dubbo的内核出发&#…

原型链补充

1.什么是原型对象 函数的独有属性,他用prototype来表示,可以在函数的prototype上挂载一些公用的属性和方法,供实例化对象来访问。 2.__proto__属性 这个属性每一个对象都有,实例化对象就是通过这个属性,来访问原型对象上的属性和方法的。 3.三者之间的关系 1.在构造函数的原型…

PTA——计算火车运行时间

本题要求根据火车的出发时间和达到时间&#xff0c;编写程序计算整个旅途所用的时间。 输入格式&#xff1a; 输入在一行中给出2个4位正整数&#xff0c;其间以空格分隔&#xff0c;分别表示火车的出发时间和到达时间。每个时间的格式为2位小时数&#xff08;00-23&#xff0…