跟代码执行流程,读Megatron源码(二)训练入口pretrain_gpt.py

  Megatron-LM默认支持GPT、T5、BERT等多个常见模型的预训练,当下大模型流行,故以pretrain_gpt.py为例做源码的走读。

一. 启动pretrain_gpt.py

  pretrain_gpt.py为GPT类模型的训练入口,它通过命令行形式被调用,其精确执行路径位于Megatron-LM框架的examples/gpt3目录下。具体而言,启动过程依赖于train_gpt3_175b_distributed.sh这一脚本,该脚本专为部署GPT-3模型在分布式环境下训练而设计(当然,也可以参照编写自定义的启动脚本)。

  在train_gpt3_175b_distributed.sh脚本内部,核心操作是通过torchrun命令实现的,该命令是PyTorch分布式训练的一部分,用于在多个计算节点上高效并行地执行pretrain_gpt.py。此过程确保了模型训练任务能够充分利用集群资源,加速训练过程,代码如下图:

二. torchrun简介

  trochrun是PyTorch官方推荐用于替代torch.distributed.launch的分布式数据并行训练模块。它旨在提供一种更灵活、更健壮的方式来启动和管理分布式训练任务。

  trochrun启动并行训练任务的原理如下:

1. 初始化分布式环境

  trochrun首先负责初始化分布式训练所需的环境。这包括设置通信后端(如NCCL、GLOO等)、分配工作进程的RANK和WORLD_SIZE(即参与训练的总进程数),以及处理其他与分布式训练相关的配置。

2. 分配工作进程

  trochrun会根据指定的参数(如--nnodes、--nproc-per-node等)来分配工作进程。这些进程可以是同一台机器上的多个 GPU,也可以是跨多台机器的GPU。每个进程都会加载相同的训练脚本(如pretrain_gpt.py),但会处理不同的数据子集,以实现并行训练。

3. 同步与通信

  在训练过程中,torchrun 管理下的各个工作进程需要频繁地进行同步和通信。这包括梯度同步(在反向传播后同步各GPU上的梯度)、参数更新(使用同步后的梯度更新模型参数)等。PyTorch提供了丰富的API(如torch.distributed.all_reduce、torch.distributed.barrier等)来支持这些操作。

4. 优雅处理故障

  trochrun相比torch.distributed.launch的一大改进是它能够更优雅地处理工作进程的故障。例如,如果某个工作进程因为某种原因崩溃了,torchrun可以尝试重新启动该进程,以确保训练任务的连续性。此外,torchrun还支持弹性训练(elastic training),即允许在训练过程中动态地增加或减少工作进程的数量。

5. 简化配置与启动

  trochrun通过提供命令行接口和配置文件选项来简化分布式训练的配置和启动过程。用户只需指定少量的参数(如节点数、每节点进程数等),即可启动复杂的分布式训练任务。此外,torchrun还支持从环境变量中读取配置信息,这使得在不同环境中部署训练任务变得更加灵活。

6. 自动化资源分配

  在某些情况下,torchrun还可以与资源管理器(如Kubernetes、Slurm等)集成,以自动化地分配和管理训练所需的计算资源。这包括GPU、CPU、内存和存储等资源。通过集成资源管理器,torchrun可以进一步提高分布式训练的可扩展性和灵活性。

  总之,torchrun通过以上机制共同作用,使得使用PyTorch进行分布式训练变得更加高效、可靠和易于管理。

三. 主要函数

  pretrain_gpt.py脚本封装了多个核心功能组件,具体包括model_provider(),forward_step(),train_valid_test_datasets_provider(),以及pretrain()等主要函数。

  其中,model_provider()负责提供预训练所需的模型实例对象;forward_step()定义了模型前向传播的具体步骤,包括输入处理、模型计算等;train_valid_test_datasets_provider()则负责准备训练、验证及测试所需的数据集,确保数据的有效供给。

  值得注意的是,前三个函数model_provider(),forward_step(),train_valid_test_datasets_provider()更是作为pretrain()函数的入参,共同构成了GPT模型训练入口。

  这种设计确保了预训练过程的模块化、灵活性与可扩展性。下面会从model_provider()开始逐行解析源码。

四. 源码分析

1. model_provider

  def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:"""Builds the model.If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.Args:pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True.post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True.Returns:Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model"""args = get_args()use_te = args.transformer_impl == "transformer_engine"print_rank_0('building GPT model ...')# Experimental loading arguments from yamlif args.yaml_cfg is not None:config = core_transformer_config_from_yaml(args, "language_model")else:config = core_transformer_config_from_args(args)if args.use_legacy_models:model = megatron.legacy.model.GPTModel(config,num_tokentypes=0,parallel_output=True,pre_process=pre_process,post_process=post_process,)else: # using core modelsif args.spec is not None:transformer_layer_spec = import_module(args.spec)else:if use_te:transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)else:transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)model = GPTModel(config=config,transformer_layer_spec=transformer_layer_spec,vocab_size=args.padded_vocab_size,max_sequence_length=args.max_position_embeddings,pre_process=pre_process,post_process=post_process,fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,parallel_output=True,share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,position_embedding_type=args.position_embedding_type,rotary_percent=args.rotary_percent,)return model

  model_provider函数是用于构建GPT(生成预训练Transformer)模型实例的函数,它会以类似函数指针的形式,作为pretrain()的入参传递到后续的训练代码中,供训练过程调用。

  该函数主要代码流程包含以下几个步骤:

  a. 获取参数和配置:

  通过 get_args() 函数获取命令行参数和配置文件中的参数。根据 args.transformer_impl 的值确定是否使用 Transformer Engine (use_te)。

  b. 配置模型:

  如果指定了 YAML 配置文件 (args.yaml_cfg),则从 YAML 文件中加载模型结构。否则,根据命令行参数 (args) 加载。

  c. 选择模型类型:

  如果args.use_legacy_models为True,则使用megatron.legacy.model.GPTModel构建模型。这通常用于向后兼容或测试旧版本的模型。

  如果不使用旧版模型,则直接运行构建GPTModel,如图红框部分。

  其中,GPTModel的参数包括配置 (config)、词汇表大小 (vocab_size)、最大序列长度 (max_sequence_length)、是否进行前处理和后处理、是否使用 FP16 进行语言模型交叉熵计算、是否并行输出等。这些参数基本上都来源于配置文件,关于配置文件的内容和解析将于下文详述。

  注:此处的model_provider是作为函数指针传递到pretrain()中,函数指针只有在调用时才会真正执行,故,GPTModel的具体实现待到执行时再具体分析。

2. forward_step

def forward_step(data_iterator, model: GPTModel):"""Forward training step.Args:data_iterator : Input data iteratormodel (GPTModel): The GPT Model"""args = get_args()timers = get_timers()# Get the batch.timers('batch-generator', log_level=2).start()global stimerwith stimer(bdata=True):tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)timers('batch-generator').stop()with stimer:output_tensor = model(tokens, position_ids, attention_mask,labels=labels)return output_tensor, partial(loss_func, loss_mask)

  forward_step顾名思义,这个函数是GPT模型训练过程中前向处理函数,负责处理一批输入数据并通过模型进行前向传播。

  该函数的核心实现,仍然是对model(forward_step函数的入参)的forward的调用,只是在调用之前封装了计时器计时逻辑以及批次数据获取的逻辑(这部分逻辑会根据不同的业务场景变化而变化,故不能直接封装到model的forward函数中,而是应该在pretrain脚本中实现),具体代码流程如下:

  a. 获取参数和计时器:

  通过get_args()和get_timers()函数分别获取训练参数和计时器对象,用于控制训练过程和记录时间消耗。

  b. 获取批次数据:

  使用timers对象记录获取批次数据的时间(可选,通过log_level=2控制)。

  调用get_batch函数从data_iterator中获取一批数据,包括tokens(输入文本对应的token IDs)、labels(训练标签,通常用于计算损失,对于语言模型任务,labels通常是tokens的右移一位版本)、loss_mask(损失掩码,用于忽略某些位置的损失计算,如填充位置)、attention_mask(注意力掩码,用于指示哪些位置需要参与注意力计算)和position_ids(位置ID,用于模型中的位置编码)。

  c. 模型前向传播:

  使用stimer(可能是一个自定义的计时器)记录模型前向传播的时间。

  将获取到的数据(tokens, position_ids, attention_mask, labels)传递给模型model进行前向传播。这里labels是可选的,用于计算损失,但在前向传播阶段不一定需要。

  模型输出output_tensor,通常包含模型的预测结果(如logits)。

  d. 返回输出和损失函数:

  返回output_tensor和partial函数,该函数需要loss_mask作为参数来计算损失。这种方式允许延迟损失的计算,直到所有相关的数据都已准备好。

3. train_valid_test_datasets_provider

def train_valid_test_datasets_provider(train_val_test_num_samples):"""Build the train test and validation datasets.Args:train_val_test_num_samples : A list containing the number of samples in train test and validation."""args = get_args()config = core_gpt_dataset_config_from_args(args)if args.mock_data:dataset_type = MockGPTDatasetelse:dataset_type = GPTDatasetprint_rank_0("> building train, validation, and test datasets for GPT ...")train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(dataset_type,train_val_test_num_samples,is_dataset_built_on_rank,config).build()print_rank_0("> finished creating GPT datasets ...")return train_ds, valid_ds, test_ds

  该函数接收一个参数train_val_test_num_samples,这是一个列表,包含了训练集、验证集和测试集的样本数量。函数的目的是根据提供的参数和配置,构建GPT模型的训练、验证和测试数据集。主要代码流程如下:

  a. 获取参数和配置:

  使用get_args()函数获取训练过程中的全局参数,并通过core_gpt_dataset_config_from_args根据这些参数生成数据集配置对象config。

  b. 确定数据集类型:

  根据args.mock_data的值决定使用哪种数据集类型。如果mock_data为True,则使用MockGPTDataset,这是一种模拟数据集,可能用于测试或快速原型开发。如果mock_data为False,则使用GPTDataset,这是实际的数据集类型,包含真实的训练数据。

  c. 构建数据集:

  使用BlendedMegatronDatasetBuilder类来构建数据集,传递给BlendedMegatronDatasetBuilder的参数包括数据集类型dataset_type、训练/验证/测试集的样本数量train_val_test_num_samples、is_dataset_built_on_rank(用于检查当前处理单元是否负责构建数据集),以及配置对象config。

  调用build()方法实际构建数据集,该方法返回三个数据集对象:训练集train_ds、验证集valid_ds和测试集test_ds。

  其中BlendedMegatronDatasetBuilder来源于包“megatron.core.datasets.blended_megatron_dataset_builder”,由于数据集构建逻辑比较简单,故,在此不做详述,有兴趣的同学可以自行查看。

  d. 返回值

  函数返回三个数据集对象:训练集train_ds、验证集valid_ds和测试集test_ds,这些对象可以用于后续的训练、验证和测试过程。

4. pretrain

  pretrain函数是megatron/pretrain_gpt.py文件中的一个执行入口,通常会将该函数写于文件的末尾。该函数被第一章中的启动脚本调用,进而开启训练流程。

  pretrain函数的入参如下:

  train_valid_test_datasets_provider:这是第3小节分析的函数指针,负责提供训练、验证和测试数据集。

  model_provider:这是第1小节分析的函数指针,负责提供GPT模型的实例。它可能根据传入的配置或参数来初始化模型。

  ModelType.encoder_or_decoder:这个参数指定了模型的类型,这里是编码器或解码器(对于GPT模型,它实际上是一个解码器)。

  forward_step:这是第2小节分析的函数指针,定义了模型训练过程中的一个前向传播步骤,包括数据的前向传递和损失的计算。

  args_defaults:这是一个字典,包含了预训练过程中一些默认参数的键值对。在这个例子中,它指定了默认的tokenizer_type为GPT2BPETokenizer,这意味着在文本预处理时将使用基于BPE(Byte Pair Encoding)的GPT-2分词器。

  至此,pretrain_gpt.py的源码基本解析完毕,下一篇文章将以pretrain函数为入口,跟随代码运行流程,深入其内部实现,详细解析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络通信基础概念

目录 1、网络通信的本质 2、网络的发展 3、网络协议(TCP\IP协议) 3.1 协议实现通信的原理 3.2 协议的具体概念 3.3 协议的模型 4、数据链路层 5、网络协议栈和操作系统的关系 6、网络协议通信过程 6.1 通信过程的封装与解包 7、以太网通信…

Ai绘画变现的14种途径 学习Stablediffusion midjourney用途

AIGC,一个在当代社会中不可忽视的词汇,指的是利用人工智能技术生成创作内容。近年来,全球范围内涌现出50个热门的AI工具,其中,以140亿次访问量雄踞榜首的“GBT”,无疑是AI领域的领头羊。在这些工具中&#…

DETR目标检测模型训练自己的数据集

前言 基础环境:ubuntu20.04、python3.8、pytorch:1.10.0、CUDA:11.3 代码地址:https://github.com/facebookresearch/detr 目录 一、训练准备1、预训练模型下载2、txt文件转为coco模式 二、修改训练模型参数三、开始训练四、实现DETR的推理 一、训练准备…

【RT摩拳擦掌】RT600 4路音频同步输入1路TDM输出方案

【RT摩拳擦掌】RT600 4路音频同步输入1路TDM输出方案 一, 文章简介二,硬件平台构建2.1 音频源板2.2 音频收发板2.3 双板硬件连接 三,软件方案与软件实现3.1 方案实现3.2 软件代码实现3.2.1 4路I2S接收3.2.2 I2S DMA pingpong配置3.2.3 音频数…

Python自动化批量下载ECWMF和GFS最新预报数据脚本

一、白嫖EC和GFS预报数据 EC的openData部分公开了一部分预报数据,作为普通用户只能访问这些免费预报数据,具体位置在这 可以发现,由于是Open Data,我们只能获得临近四天的预报结果,虽然时间较短,但是我们…

vue3前端开发-小兔鲜项目-二级页面面包屑导航和跳转

vue3前端开发-小兔鲜项目-二级页面面包屑导航和跳转!这一次,做两件事。第一件事是把二级分类页面的跳转(也就是路由)设计一下。第二件事是把二级页面的面包屑导航设计一下。 第一件事,二级页面的跳转路由设计一下。 如…

Python爬虫(4) --爬取网页图片

文章目录 爬虫爬取图片指定url发送请求获取想要的数据数据解析定位想要内容的位置存放图片 完整代码实现总结 爬虫 Python 爬虫是一种自动化工具,用于从互联网上抓取网页数据并提取有用的信息。Python 因其简洁的语法和丰富的库支持(如 requests、Beaut…

科普文:后端性能优化的实战小结

一、背景与效果 ICBU的核心沟通场景有了10年的“积累”,核心场景的界面响应耗时被拉的越来越长,也让性能优化工作提上了日程,先说结论,经过这一波前后端齐心协力的优化努力,两个核心界面90分位的数据,FCP平…

Day05-readinessProbe探针,startupProbe探针,Pod生命周期,静态Pod,初始化容器,rc控制器的升级和回滚,rs控制器精讲

Day05-readinessProbe探针,startupProbe探针,Pod生命周期,静态Pod,初始化容器,rc控制器的升级和回滚,rs控制器精讲 0、昨日内容回顾1、readinessProbe可用性检查探针之exec案例2、可用性检查之httpGet案例3…

[数据集][目标检测]躺坐站识别检测数据集VOC+YOLO格式9488张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):9488 标注数量(xml文件个数):9488 标注数量(txt文件个数):9488 标注…

C语言 | Leetcode C语言题解之第242题有效的字母异位词

题目&#xff1a; 题解&#xff1a; bool isAnagram(char* s, char* t) {int len_s strlen(s), len_t strlen(t);if (len_s ! len_t) {return false;}int table[26];memset(table, 0, sizeof(table));for (int i 0; i < len_s; i) {table[s[i] - a];}for (int i 0; i &…

EMQX 跨域集群:增强可扩展性,打破地域限制

跨域集群的概念 提到 EMQX&#xff0c;人们通常首先会想到它的可扩展性。尽管 EMQX 能随着硬件数量的增加几乎实现线性扩展&#xff0c;但在单个计算实例上的扩展能力终究有限&#xff1a;资源总会耗尽&#xff0c;升级成本也会急剧上升。这时&#xff0c;分布式部署就显得尤为…

JavaScript(11)——对象

对象 声明&#xff1a; let 对象名 { 属性名&#xff1a;属性值, 方法名&#xff1a;函数 } let 对象名 new Object() 对象的操作 先创建一个对象 let op {name:jvav,id:4,num:1001} 查 对象名.属性 console.log(op.name) 对象名[属性名] 改 对象名.属性 新值 op.name …

Pytorch学习笔记day4——训练mnist数据集和初步研读

该来的还是来了hhhhhhhhhh&#xff0c;基本上机器学习的初学者都躲不开这个例子。开源&#xff0c;数据质量高&#xff0c;数据尺寸整齐&#xff0c;问题简单&#xff0c;实在太适合初学者食用了。 今天把代码跑通&#xff0c;趁着周末好好的琢磨一下里面的各种细节。 代码实…

Spring MVC的高级功能——拦截器(三)拦截器的执行流程

一、单个拦截器的执行流程 如果在项目中只定义了一个拦截器&#xff0c;单个拦截器的执行流程如图所示。 二、单个拦截器的执行流程分析 从单个拦截器的执行流程图中可以看出&#xff0c;程序收到请求后&#xff0c;首先会执行拦截器中的preHandle()方法&#xff0c;如果preHa…

bug诞生记——动态库加载错乱导致程序执行异常

大纲 背景问题发生问题猜测和分析过程是不是编译了本工程中的其他代码是不是有缓存是不是编译了非本工程的文件是不是调用了其他可执行文件查看CMakefiles分析源码检查正在运行程序的动态库 解决方案 这个案例发生在我研究ROS 2的测试Demo时发生的。 整体现象是&#xff1a;修改…

聊一聊前端动画的种类,以及动画的触发方式有哪些?

引言 动画在前端开发中扮演着重要的角色。它不仅可以提升用户体验&#xff0c;还可以使界面更加生动和有趣。在这篇文章中&#xff0c;我们将深入探讨前端动画的各种实现方式&#xff0c;包括 CSS 动画、JavaScript 动画、SVG 动画等。我们还将讨论一些触发动画的方式和动画在…

【MQTT(2)】开发一个客户端,ubuntu版本

基本流程如下&#xff0c;先生成Mosquitto的库&#xff0c;然后qt调用库进行开发界面。 文章目录 0 生成库1 有界面的QT版本2 无界面版本 0 生成库 下载源码&#xff1a;https://github.com/eclipse/mosquitto.git 编译ubuntu 版本很简单&#xff0c;安装官方说明直接make&am…

rk3568 OpenHarmony4.1 Launcher定制开发—桌面壁纸替换

Launcher 作为系统人机交互的首要入口&#xff0c;提供应用图标的显示、点击启动、卸载应用&#xff0c;并提供桌面布局设置以及最近任务管理等功能。本文将介绍如何使用Deveco Studio进行单独launcher定制开发、然后编译并下载到开发板&#xff0c;以通过Launcher修改桌面背景…

记录|如何打包C#项目

参考文章&#xff1a; c#窗体应用程序怎么打包 经过检验确实有效 Step1. 生成发布文件 在Visual Studio的菜单中&#xff0c;找到“生成”->“发布” 第一次会有个向导&#xff0c;基本上一路next下来既可以 最后&#xff0c;点击完成即可以 Step2. 获得publish文件 自…