使用FP8加速PyTorch训练

现代的人工智能硬件架构(例如,Nvidia Hopper, Nvidia Ada Lovelace和Habana Gaudi2)中,FP8张量内核能够显著提高每秒浮点运算(FLOPS),以及为人工智能训练和推理工作负载提供内存优化和节能的机会。

在这篇文章中,我们将介绍如何修改PyTorch训练脚本,利用Nvidia H100 GPU的FP8数据类型的内置支持。这里主要介绍由Transformer Engine库公开的fp8特定的PyTorch API,并展示如何将它们集成到一个简单的训练脚本中。(我们这里只介绍如何使用FP8,不会介绍FP8具体的理论知识)

随着人工智能模型变得越来越复杂,训练它们所需的机器也越来越复杂。Nvidia H100 GPU据称支持“前所未有的性能和可扩展性”。

在AWS中,H100 gpu是作为AWS EC2 p5实例的一个组件提供的。这些实例声称“与上一代基于gpu的EC2实例相比,可将解决方案的时间加快4倍,并将训练ML模型的成本降低高达40%”。

当涉及到机器学习训练实例时,并不总是越大越好。p5实例族尤其如此。p5可能会比其他实例要快很多,因为H100是无可争议的性能野兽。但是一旦考虑到p5的成本(8-GPU p5.48xlarge实例的成本为每小时98.32美元),你可能会发现其他实例类型更适合。

下面我们将在p5.48xlarge上训练一个相对较大的计算机视觉模型,并将其性能与p4d进行比较。p4d.24xlarge包含8个Nvidia A100 gpu。

模型

我们定义了一个Vision Transformer (ViT)支持的分类模型(使用流行的timm Python包版本0.9.10)以及一个随机生成的数据集。ViT主干有多种形状和大小。我们选择了通常被称为ViT-Huge的配置-具有6.32亿个参数-这样能够更好地利用H100对大型模型的容量。

 import torch, timeimport torch.optimimport torch.utils.dataimport torch.distributed as distfrom torch.nn.parallel.distributed import DistributedDataParallel as DDPimport torch.multiprocessing as mp# modify batch size according to GPU memorybatch_size = 64from timm.models.vision_transformer import VisionTransformerfrom torch.utils.data import Dataset# use random dataclass FakeDataset(Dataset):def __len__(self):return 1000000def __getitem__(self, index):rand_image = torch.randn([3, 224, 224], dtype=torch.float32)label = torch.tensor(data=[index % 1000], dtype=torch.int64)return rand_image, labeldef mp_fn(local_rank, *args):# configure processdist.init_process_group("nccl",rank=local_rank,world_size=torch.cuda.device_count())torch.cuda.set_device(local_rank)device = torch.cuda.current_device()# create dataset and dataloadertrain_set = FakeDataset()train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,num_workers=12, pin_memory=True)# define ViT-Huge modelmodel = VisionTransformer(embed_dim=1280,depth=32,num_heads=16,).cuda(device)model = DDP(model, device_ids=[local_rank])# define loss and optimizercriterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)model.train()t0 = time.perf_counter()summ = 0count = 0for step, data in enumerate(train_loader):# copy data to GPUinputs = data[0].to(device=device, non_blocking=True)label = data[1].squeeze(-1).to(device=device, non_blocking=True)# use mixed precision to take advantage of bfloat16 supportwith torch.autocast(device_type='cuda', dtype=torch.bfloat16):outputs = model(inputs)loss = criterion(outputs, label)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()# capture step timebatch_time = time.perf_counter() - t0if step > 10:  # skip first stepssumm += batch_timecount += 1t0 = time.perf_counter()if step > 50:breakprint(f'average step time: {summ/count}')if __name__ == '__main__':mp.spawn(mp_fn,args=(),nprocs=torch.cuda.device_count(),join=True)

我们使用专用PyTorch 2.1 AWS深度学习容器(763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-ec2)在p5.48xlarge和p4d上都训练了这个模型。

p5的性能远远超过了p4d的性能——每步0.199秒比0.41秒——快了两倍多!!这意味着训练大型机器学习模型的时间将减少一半。但是当你考虑到成本的差异(p4d每小时32.77美元,p5每小时98.32美元),p5的性价比比p4d差30% !!

在这一点上,可能会得出两个可能的结论之一。第一种可能性是,尽管有这么多宣传,但p5根本不适合您。第二个是p5仍然是可行的,但是需要对模型进行调整,充分利用它的潜力。

FP8与Transformer Engine的集成

PyTorch(版本2.1)不包括FP8数据类型。为了将我们的脚本编程为使用FP8,我们将使用Transformer Engine (TE),这是一个用于在NVIDIA gpu上加速Transformer模型的专用库。TE(版本0.12)预装在AWS PyTorch 2.1 DL容器中。

使用FP8的机制比16位(float16和bfloat16)要复杂得多。TE库实现向用户隐藏了所有杂乱的细节。有关如何使用TE api的说明(请参阅官方文档)。

为了修改我们的模型以使用TE,我们将TE的专用Transformer层,所以需要我们自己写一个包装器:

 import transformer_engine.pytorch as tefrom transformer_engine.common import recipeclass TE_Block(te.transformer.TransformerLayer):def __init__(self,dim,num_heads,mlp_ratio=4.,qkv_bias=False,qk_norm=False,proj_drop=0.,attn_drop=0.,init_values=None,drop_path=0.,act_layer=None,norm_layer=None,mlp_layer=None):super().__init__(hidden_size=dim,ffn_hidden_size=int(dim * mlp_ratio),num_attention_heads=num_heads,hidden_dropout=proj_drop,attention_dropout=attn_drop)

然后修改VisionTransformer初始化自定义块:

   model = VisionTransformer(embed_dim=1280,depth=32,num_heads=16,block_fn=TE_Block).cuda(device)

到目前为止,还没有做任何针对h100特定的更改-相同的代码可以在我们的a100的p4d实例类型上运行。最后一个修改是用te包裹模型前向传递。Fp8_autocast上下文管理器。此更改需要支持FP8的GPU:

 with torch.autocast(device_type='cuda', dtype=torch.bfloat16):with te.fp8_autocast(enabled=True):outputs = model(inputs)loss = criterion(outputs, label)

关于使用FP8的一些注意事项

使用8位浮点表示(相对于16位或32位表示)意味着较低的精度和较低的动态范围。这些可以对模型收敛的可达性和/或速度产生有意义的影响,但不能保证这将适用于所有的模型。所以可能需要调整底层FP8机制(例如,使用TEapi),调整一些超参数,和/或将FP8的应用限制在模型的子模型(一部分)。最坏的可能是尽管进行了所有尝试,模型还是无法与FP8兼容。

结果

在下表中总结了在两个p4d上的实验结果。24xlarge和p5.48xlarge EC2实例类型,使用和不使用TE库。对于p5.48xlarge实验,我们将批处理大小加倍,这样提高80 GB GPU内存的利用率。使用FP8可以减少GPU内存消耗,从而进一步增加批处理大小。

可以看到,使用TE提高了p4d(19%)和p5(32%)的性价比。使用FP8可将p5上的性能额外提高约20%。在TE和FP8优化之后,基于h100的p5.48large的性价比优于基于a100的p4d.xlarge——虽然差距不大(2%)。考虑到训练速度提高了3倍,我们可以有把握地得出结论,p5将是训练优化模型的更好的实例类型。

但是我们也看到了,这是相对较小的性价比提升(远低于p5公告中提到的40%),所以可能还有更多的优化方案,我们需要继续研究。

总结

在这篇文章中,我们演示了如何编写PyTorch训练脚本来使用8位浮点类型。展示了FP8的使用是如何从Nvidia H100中获得最佳性能的关键因素。FP8的可行性及其对训练性能的影响可以根据模型的细节而变化很大。

https://avoid.overfit.cn/post/541a04c656db474d91ee5eb1fa5bc5f8

作者:Chaim Rand

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/147012.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Arduino驱动LM35线性温度传感器(温湿度传感器)

目录 1、传感器特性 2、控制器和传感器连线图 3、驱动程序 LM35半导体的温度传感器,可以用来对环境温度进行定性的检测。LM35半导体温度传感器是美国国家半导体公司生产的线性温度传感器。其测温范围是-40℃到150℃,灵敏度为10mV/℃,输出电压与温度成正比。

<C++> 反向迭代器

我们知道正向迭代器的设计:begin迭代器指向第一个数据,end迭代器指向最后一个数据的下一个位置 。移向下一个数据,解引用得到数据的值,并根据容器储存方式的不同,容器有不同类型的迭代器。 注意:rbegin迭代…

c语言:模拟实现qsort函数

qsort函数的功能: qsort相较于冒泡排序法,不仅效率更快,而且能够比较不同类型的元素,如:浮点数,结构体等等。这里我们来模拟下qsort是如何实现这一功能的,方便我们对指针数组有一个更深层次的理…

龙芯 操作系统选择和安装

龙芯3a5000及之后的cpu底层架构已经从mips64el改为了loongarch64 所以这里分了2种来说明,分别对应3a4000之前的和3a5000之后的 龙芯的系统安装难点在于操作系统的选取和引导 一、烧录工具 制作安装盘使用常规的烧录工具是不行滴,会提示没有\boot\initrd…

webpack的安全保障是怎么做的?

文章目录 前言Webpack 内容安全策略后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:webpack 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不断努力填补技术短板。(如果出现错误,感…

阿里云linux升级新版本npm、nodejs

在阿里云服务器上编译部署NextJS工程发现 alibaba linux默认yum install npm安装的版本太低, 使用以下方式升级node、npm新版本。 1、卸载现有版本 yum remove nodejs npm -y2、安装新版本 sudo yum install https://rpm.nodesource.com/pub_21.x/nodistro/repo/nodesource-…

STM32的启动流程

1、STM32上电启动的主要步骤 a、初始化堆栈指针sp_initial_sp,初始化PC指针pcReset_Handler。 b、初始化中断向量表。 c、配置系统时钟。 d、调用 C 库函数_main 初始化用户堆栈,然后进入 main 函数。 2、STM32的三种启动模式 复位后,在 S…

从底层原理看Android的序列化是如何实现的

对于Java的序列化,我们可以认为是在数据传输的时候的一套协议或者是一个标准,因为Java存在自己特定的一个数据结构(class),举个例子 data class User(val name: String,val age: Int )User是一个对象,我们…

产品经理必备技能:如何快速锁定种子用户群体?

大家好,我是小米,一名热爱技术、热衷分享的90后小青年。今天我们要探讨的话题是一个在产品经理面试中经常被问到的问题:“产品上线后的种子用户该如何获取?”作为一个热爱挑战、乐于探讨的小伙伴,我将和大家分享一些我…

一、MySQL.pratice.search

MySQL是一种常用的关系型数据库管理系统,广泛应用于各种Web应用程序中。在编程中,使用MySQL进行数据操作是非常常见的操作。在MySQL中,查询是最常用的操作之一,可以查询整个表或者根据特定的条件查询数据。 文章目录 一、查询&am…

第七部分:Maven(项目管理工具)

目录 Maven简介 7.1:为什么学习Maven? 7.1.1、Maven是一个依赖管理工具 7.1.2:Maven是一个构建工具 7.1.3:结论 7.2:Maven介绍 7.3:Maven的优点 Maven安装和配置 7.4:安装教程及环境配置 …

Linux给根目录扩容

需求:Linux系统挂载到根目录的磁盘空间满了,如何扩容? 一、添加磁盘并分区 [rootcdn ~]# fdisk /dev/sdbWelcome to fdisk (util-linux 2.37.2). Changes will remain in memory only, until you decide to write them. Be careful before u…

什么是Java伪随机数,基础打牢。 #程序员 #Java #编程

你一定听说过这样一个词,伪随机数,你有没有这样的疑惑,为什么不用真随机,要用的个假的? 先说一个结论: Java Random英/ˈrndəm/ 随机数生成不安全,如果同时泄漏第一个和第二个随机数&#xf…

uniapp自定义组件

在UniApp中,你可以使用自定义组件来拓展应用程序的功能和界面。自定义组件是由多个Vue组件构成的,可以在应用程序中重复使用。 要创建一个自定义组件,你需要在UniApp项目中的components目录下创建一个新的文件夹,并在该文件夹中创…

最大似然估计的介绍

最大似然估计(Maximum Likelihood Estimation,简称MLE)是一种用于估计概率分布中参数的方法。该方法的核心思想是选择使得观察到的数据在给定模型下出现的概率最大的参数值作为估计值。 最大似然估计具有很好的性质,包括渐进正态性…

SystemVerilog学习 (9)——随机化

目录 一、概述 二、随机化 2.1、如何简单地产生一个随机数 2.1.1 利用系统函数产生随机数 2.1.2 urandom() 2.2、什么需要随机化 2.3、随机约束 2.3.1 rand 和 randc 2.3.2 随机约束的使用 2.3.3 约束块 三、总结 一、概述 随着设计变得越来越大,要产生一个完整的激…

面试资料快速复习 Git常用命令(简单实用)

Git-command Git常用命令、面试复习、简单实用命令 ​ 一、概念理解 (一)工作区、暂存区、本地仓库、远程仓库 workspace:工作区staging area:暂存区/缓存区local repository:本地仓库remote repository&#xff…

Apache Airflow (九) :Airflow Operators及案例之BashOperator及调度Shell命令及脚本

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹…

03_SHELL编程之嵌套循环+随机数及综合案例

###课程目标 掌握for循环语句的基本语法结构 掌握while和until循环语句的基本语法结构 能会使用RANDOM产生随机数 理解嵌套循环 一、随机数 bash默认有一个$RANDOM的变量 默认是0~32767。使用set |grep RANDOM 查看上一次产生的随机数 echo $RANDOM ​ 产生0~1之间…

C#单例模式懒汉式与饿汉式

单例模式一般分为懒汉模式和饿汉模式,懒汉式单例在第一次引用时创建实例,不是在类加载时;饿汉式单例模式是一种在类加载时就创建实例的方式,因此也称为静态初始化。 单例模式实现的技巧时构造私有,向外提供静态实例。…