使用 PyTorch FSDP 微调 Llama 2 70B

引言

通过本文,你将了解如何使用 PyTorch FSDP 及相关最佳实践微调 Llama 2 70B。在此过程中,我们主要会用到 Hugging Face Transformers、Accelerate 和 TRL 库。我们还将展示如何在 SLURM 中使用 Accelerate。

完全分片数据并行 (Fully Sharded Data Parallelism,FSDP) 是一种训练范式,在该范式中优化器状态、梯度和模型参数都会被跨设备分片。前向传播时,每个 FSDP 单元执行 all gather 以获取完整的权重,然后用它们进行计算并在计算后丢弃掉其他设备的分片。随后是反向传播,然后就是损失计算。反向传播时,每个 FSDP 单元执行 all gather 操作以获取完整的权重,并执行计算以获得本地 batch 的梯度。这些梯度通过 reduce scatter 在设备上进行均值计算并分片,这样每个设备都可以更新其对应分片的参数。有关 PyTorch FSDP 的更多信息,请参阅此博文: ‍使用 PyTorch 完全分片数据并行技术加速大模型训练‍。

6507f00f7c6dc346723d063afcdc9f88.png
FSDP 工作流

使用的硬件

节点数: 2,至少 1 个节点
每节点 GPU 数: 8
GPU 类型: A100
GPU 显存: 80GB
节点内互联: NVLink
每节点内存: 1TB
每节点 CPU 核数: 96
节点间互联: AWS 的 Elastic Fabric Adapter (EFA)

微调 LLaMa 2 70B 面临的挑战

在尝试使用 FSDP 微调 LLaMa 2 70B 时,我们主要遇到了三个挑战:

  1. FSDP 会先加载整个预训练模型,然后再对模型进行分片。这样就意味着节点内的每个进程 (即 rank) 都会加载整个 Llama-70B 模型,因此需要 7048 GB ~ 2TB 的 CPU 内存,这个算式中 4 是每个参数所需字节数,8 是每个节点的 GPU 数。这会导致 CPU 内存不足,进而导致进程终止。

  2. 使用 FULL_STATE_DICT 来保存完整中间检查点并将其卸载至 rank 0 的 CPU 内存中需要花费大量时间,且由于在此期间通信库需要无限期挂起等待保存完成,因此经常会导致 NCCL 超时错误。然而,完全关掉这个选项也不好,因为在训练结束时我们需要保存完整的模型状态字典,而不是 FSDP 式分片的状态字典。

  3. 我们需要提高速度并减少显存使用,以加快训练并节约计算成本。

下文,我们主要讨论如何一一解决上述挑战,最终微调出一个 70B 的模型!

先列出重现结果所需的所有资源:

  1. 代码库: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training,代码中包含了使能 flash 注意力 V2 的热补丁

  2. FSDP 配置文件: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml

  3. SLURM 启动脚本 - launch.slurm: https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25

  4. 模型: meta-llama/Llama-2-70b-chat-hf

  5. 数据集: smangrul/code-chat-assistant-v1 (混合了 LIMA 和 GUANACO 数据集,且已转换为训练所需的格式)

准备工作

首先按照 此步骤 安装 Flash Attention V2。然后,安装最新的 PyTorch nightly (CUDA ≥11.8)。接着,根据 此文件 安装其余依赖软件。在本文中,我们是从主分支安装 🤗 Accelerate 和 🤗 Transformers 的。

微调

应对挑战 1

PR 25107 和 PR 1777 解决了第一个挑战,且无需用户侧更改任何代码。主要做的事情如下:

  1. 在所有 rank 上创建无权重的空模型 (使用 meta 设备)

  2. 仅在 rank 0 上将状态字典加载至模型

  3. 其他 rank 仅对 meta 设备上的参数执行 torch.empty(*param.size(), dtype=dtype)

  4. 因此,只有 rank 0 上加载了完整的模型及权重,而所有其他 rank 上的权重是空的

  5. 设置 sync_module_states=True ,以便 FSDP 实例在训练开始之前将权重广播到各 rank

下面是在 2 个 GPU 上加载 7B 模型的输出日志片段,它测量了各个阶段内存的消耗及其加载的模型参数量。我们可以观察到,在加载预训练模型时,rank 0 和 rank 1 的 CPU 峰值内存分别为 32744 MB1506 MB 。因此可知,仅有 rank 0 加载了预训练模型,这就实现了 CPU 内存的有效利用。你可在 此处 找到完整日志。

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506

应对挑战 2

该挑战可以通过在配置 FSDP 时将状态字典类型设为 SHARDED_STATE_DICT 来解决。设为 SHARDED_STATE_DICT 后,每个 rank 各自保存各自 GPU 所需要的分片,这使得用户可以快速保存中间检查点并快速从其恢复训练。而当使用 FULL_STATE_DICT 时,第一个进程 (rank 0) 会用 CPU 收集整个模型,然后将其保存为标准格式。

我们可以用以下命令创建相应的 accelerte 配置文件:

accelerate config --config_file "fsdp_config.yaml"
492f627df72008ebc6c4a183bc4ddf30.jpeg
fsdp 配置

你可以从此处获取生成的配置文件: fsdp_config.yaml。在该配置文件中,分片策略是 FULL_SHARD 。我们使用 TRANSFORMER_BASED_WRAP 作为自动模型包装策略,它使用 _no_split_module 来搜索 transformer 块名并自动进行嵌套 FSDP 包装。我们使用 SHAARDED_STATE_DICT 把中间检查点和优化器状态保存为 PyTorch 官方推荐的格式。同时,如上一节 应对挑战 1 中所述,我们还需要确保训练开始时用 rank 0 来广播参数。从配置文件中你还可以看到我们用的是 bf16 混合精度训练。

那么,在保存最终检查点时,如果将其保存成单个文件呢?我们使用的是以下代码段:

if trainer.is_fsdp_enabled:trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")trainer.save_model(script_args.output_dir) # 或者 , 如果整个模型小于 50 GB (即 LFS 单文件的最大尺寸),你还可以使用 trainer.push_to_hub() 把模型推到 hub 上去。

应对挑战 3

为了加快训练速度并减少显存占用,我们可以使用 flash 注意力并开启梯度检查点优化,从而在微调的同时节省计算成本。当前,我们用了一个热补丁来实现 flash 注意力,具体代码可见 这儿。

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 一文基于对底层硬件 (即 GPU) 的内存层次结构的深刻理解而引入了一种更快、更节省内存的无损注意力加速算法。底层硬件在设计内存层次结构时,遵循的实践原则是: 带宽/速度越高的内存,其容量越小,因为它更贵。

根据博文 根据第一性原理让深度学习性能起飞,我们可以发现,当前硬件上的注意力模块是 内存带宽受限 的。原因是注意力机制 主要由逐元素操作 组成,如下左图所示。我们可以观察到,掩码、softmax 和 dropout 操作占用了大部分时间,而非需要大量 FLOP 的矩阵乘法。

e456acc938f8afcf9847f4c6b937bf7b.png
注意力机制的性能瓶颈

这正是 flash 注意力解决的问题,其想法是 去除冗余的 HBM 读/写操作。该算法通过将所有内容保留在 SRAM 中,待执行完所有中间步骤后再将最终结果写回到 HBM,即 算子融合 来实现这一目的。下图简要描述了算子融合是如何克服内存瓶颈的。

7a5ae60952890b09293795e2a6b1fba9.jpeg
算子融合

在前向和反向传播过程中我们还使用了 平铺 (Tiling) 优化技巧,将 NxN 大小的 softmax 分数计算切成块,以克服 SRAM 内存大小的限制。在使用平铺技巧时,我们会使用在线 softmax 算法。同时,我们还在反向传播中使用了 重计算 技巧,以大大降低在前向传播过程中存储整个 NxN softmax 分数矩阵所带来的内存消耗。

如欲深入理解 flash 注意力,请参考博文 ELI5: FlashAttention、根据第一性原理让深度学习性能起飞 以及原始论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness。

综合运用所有手段

你可参考 此脚本,以在 SLURM 中用 Accelerate 启动器运行训练。下面还给出了一个等效命令,展示了如何使用 Accelerate 启动器来运行训练。请注意,该命令会覆盖 fsdp_config.yaml 中的 main_process_ipmain_process_portmachine_ranknum_processes 以及 num_machines 配置。另一个需要重点注意的是,这里的存储是所有节点共享的。

accelerate launch \--config_file configs/fsdp_config.yaml \--main_process_ip $MASTER_ADDR \--main_process_port $MASTER_PORT \--machine_rank \$MACHINE_RANK \--num_processes 16 \--num_machines 2 \train.py \--model_name "meta-llama/Llama-2-70b-chat-hf" \--dataset_name "smangrul/code-chat-assistant-v1" \--max_seq_len 2048 \--max_steps 500 \--logging_steps 25 \--eval_steps 100 \--save_steps 250 \--bf16 True \--packing True \--output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \--per_device_train_batch_size 1 \--gradient_accumulation_steps 1 \--dataset_text_field "content" \--use_gradient_checkpointing True \--learning_rate 5e-5 \--lr_scheduler_type "cosine" \--weight_decay 0.01 \--warmup_ratio 0.03 \--use_flash_attn True

整个微调过程需要约 13.5 小时,下图给出了训练损失曲线。

7f372c289cd5d16b8f6ca3439640b325.png
训练损失曲线

下例给出了使用上述模型完成的一段对话:

System Prompt: You are a helpful, respectful and honest assistant. Always answer as helpfully \
as possible, while being safe. Your answers should not include any harmful, \
unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that \
your responses are socially unbiased and positive in nature.If a question does not make any sense, or is not factually coherent, explain why \
instead of answering something not correct. If you don’t know the answer to a \
question, please don’t share false information.- Human: What is Deep Learning? Explain like a Pirate.+ Assistant: Arrr Grumete! Are ye lookin' fer a proper explanation o' Deep Learnin'?
+ Alright then, listen up me hearty! Deep Learnin' be a type o' machine learnin' that uses artificial neural networks to analyze data.
+ Aye, ye heard that right, mateys! Artificial neural networks be models o' the human brain, but instead o' using neurons, they use interconnected nodes or "artificial neurons" to process information.
+ These nodes work together to form layers o' processing units that can learn from large amounts o' data.
+ The more data these systems receive, the better they get at recognizing patterns and makin' decisions on their own.
+ So hoist the sails and set course for adventure with deep learnin', me hearties! It's a treasure trove o' knowledge just waitin' to be discovered.
+ Fair winds and following seas!- Human: Now explain it like a chef.+ Assistant: Certainly! Here's an explanation of deep learning from a culinary perspective:+ Think of a recipe as a sequence of steps used to transform raw ingredients into a delicious dish.
+ Similarly, in deep learning, there are multiple layers of "ingredients" (or features) that are combined and transformed through various operations to produce a final output or prediction.
+ Each layer builds upon the previous one, refining and enhancing the representation of the input data until it reaches the desired outcome.
+ Just as different spices enhance the flavors of food, each layer adds depth and complexity to the model's understanding of the input, allowing it to make more accurate predictions.
+ And much like how chefs must carefully balance the amount of salt, sugar, and other seasonings in a recipe, the training process in deep learning involves adjusting the weights and biases of the network to achieve the perfect blend of accuracy and efficiency.
+ With enough practice and experimentation, anyone can become a master chef... I mean, a master of deep learning! Bon appétit!

整个对话使用的格式如下:

<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...

总结

我们在多节点多 GPU 上使用 PyTorch FSDP 成功微调了一个 70B Llama 模型,并在此过程中解决了各种挑战。我们看到了当前在 🤗 Transformers 和 🤗 Accelerates 中应如何初始化大模型从而有效克服 CPU 内存不足的问题。我们还给出了如何高效地保存/加载中间检查点,同时又能以易于使用的方式保存最终模型的最佳实践。为了加速训练并减少 GPU 显存使用,我们还强调了 flash 注意力和梯度检查点机制的重要性。最后,我们向大家展示了在 🤗 Accelerate 上仅需要简单的配置就可以在多节点多 GPU 上微调大模型。

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/ram-efficient-pytorch-fsdp

原文作者: Sourab Mangrulkar,Sylvain Gugger,Lewis Tunstall,Philipp Schmid

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

FSDP MFU (Model FLOPS Utilization) 相关讨论:https://github.com/huggingface/blog/issues/1649

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222762.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入学习 C++编程,数据结构与算法关系

数据结构是计算机科学中非常重要的概念之一。它是一种组织和存储数据的方式&#xff0c;能够有效地操作和管理数据&#xff0c;以便提高算法的效率。 以下是一些为什么要有数据结构的原因&#xff1a; (1) 数据组织&#xff1a;数据结构可以帮助我们组织和管理大量的数据。通过…

【elementui笔记:el-table表格的输入校验】

之前做得比较多的校验是在el-form表单里做的&#xff0c;但有时也遇到&#xff0c;需要在table内输入数据&#xff0c;然后校验输入的数据是否符合要求的情况。因此记录一下。 思路&#xff1a; 1.需要借助el-form的校验&#xff0c;el-table外层嵌套一层el-form&#xff0c;使…

世界5G大会

会议名称:世界 5G 大会 时间:2023 年 12 月 5 日-12 月 8 日 地点:河南郑州 一、会议简介 世界 5G 大会,是由国务院批准,国家发展改革委、科技部、工 信部与地方政府共同主办,未来移动通信论坛联合属地主管厅局联合 承办,邀请全球友好伙伴共同打造的全球首个 5G 领域…

CleanMyMac X2024值不值得下载?

macOS已经成为最受欢迎的桌面操作系统之一&#xff0c;它提供了直观、简洁的用户界面&#xff0c;使用户可以轻松使用和管理系统。macOS拥有丰富的应用程序生态系统&#xff1b;还可以与其他苹果产品和服务紧密协作&#xff0c;如iPhone、iPad&#xff0c;用户可以通过iCloud同…

# 和 $ 的区别②

上节博客说了使用 # 的时候,如果参数为 String ,会自动加上单引号 但是当参数为String 类型的时候,也有不需要加单引号的情况,这时候用 # 那就会出问题 比如根据 升序(asc) 或者 降序(desc) 查找的时候,加了单引号那就会报错 这个时候我们就只能使用 $ 如果看不懂代码,就去…

jmeter,同一线程组内,调用cookie实现接口关联

取cookie方式参考上一篇&#xff1a;jemeter&#xff0c;取“临时重定向的登录接口”响应头中的cookie-CSDN博客 元件结构 登录后要执行的接口为“api/get_event_list/”&#xff0c;在该HTTP请求下创建HTTP信息头管理器&#xff0c;配置如下&#xff1a; 执行测试后&#xff0…

Docker网络模式:深度理解与容器网络配置

Docker 的网络模式是容器化应用中一个关键而复杂的方面。本文将深入讨论 Docker 的网络模式&#xff0c;包括基本概念、常用网络模式以及高级网络配置&#xff0c;并通过更为丰富和实际的示例代码&#xff0c;帮助读者全面掌握如何理解和配置容器网络。 Docker网络基础 1 Doc…

Android--Jetpack--Navigation详解

须知少日拏云志&#xff0c;曾许人间第一流 一&#xff0c;定义 Navigation 翻译成中文就是导航的意思。它是谷歌推出的Jetpack的一员&#xff0c;其目的主要就是来管理页面的切换和导航。 Activity 嵌套多个 Fragment 的 UI 架构模式已经非常普遍&#xff0c;但是对 Fragmen…

Sql标准梳理

SQL&#xff08;Structured Query Language&#xff09;是一种用于管理关系型数据库管理系统&#xff08;RDBMS&#xff09;的标准化语言。SQL标准由国际标准化组织&#xff08;ISO&#xff09;和美国国家标准化组织&#xff08;ANSI&#xff09;制定和维护&#xff0c;旨在提供…

SpringIOC之FilterType

博主介绍&#xff1a;✌全网粉丝5W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

spring 笔记六 SpringMVC 获得请求数据

文章目录 SpringMVC 获得请求数据获得请求参数获得基本类型参数获得POJO类型参数获得数组类型参数获得集合类型参数请求数据乱码问题参数绑定注解requestParam获得Restful风格的参数获得Restful风格的参数自定义类型转换器获得Servlet相关API获得请求头RequestHeaderCookieValu…

js解析.shp文件

效果图 原理与源码 本文采用的是shapefile.js工具 这里是他的npm地址 https://www.npmjs.com/package/shapefile 这是他的unpkg地址&#xff0c;可以点开查看源码 https://unpkg.com/shapefile0.6.6/dist/shapefile.js 这个最关键的核心问题是如何用这个工具&#xff0c;网上…

Vue3-18-侦听器watch()、watchEffect() 的基本使用

什么是侦听器 个人理解&#xff1a;当有一个响应式状态&#xff08;普通变量 or 一个响应式对象&#xff09;发生改变时&#xff0c;我们希望监听到这个改变&#xff0c;并且能够进行一些逻辑处理。那么侦听器就是来帮助我们实现这个功能的。侦听器 其实就是两个函数&#xff…

〖大前端 - 基础入门三大核心之JS篇(54)〗- 原型和原型链

说明&#xff1a;该文属于 大前端全栈架构白宝书专栏&#xff0c;目前阶段免费&#xff0c;如需要项目实战或者是体系化资源&#xff0c;文末名片加V&#xff01;作者&#xff1a;哈哥撩编程&#xff0c;十余年工作经验, 从事过全栈研发、产品经理等工作&#xff0c;目前在公司…

曹操出行集成:无代码API连接广告推广与用户运营

曹操出行集成的必要性 随着科技的不断进步&#xff0c;无代码API集成已经成为企业提升效率、优化营销策略的重要手段。对于新能源汽车共享服务领导者曹操出行而言&#xff0c;将其服务集成至企业营销系统中&#xff0c;不仅可以提升客户体验&#xff0c;还能加强品牌的市场竞争…

微信小程序 全局共享数据 mobx

前言 全局数据共享&#xff08;又叫做&#xff1a;状态管理&#xff09;是为了解决组件之间数据共享的问题。开发中常用的全局数据共享方案有&#xff1a;Vuex、Redux、MobX 等。 一. 安装 npm install --save mobx-miniprogram4.13.2 mobx-miniprogram-bindings2.1.5 安装完…

基于主动安全的AIGC数据安全建设

面对AIGC带来的数据安全新问题&#xff0c;是不是就应该一刀切禁止AIGC的研究利用呢&#xff1f;答案是否定的。要发展AIGC&#xff0c;也要主动积极地对AIGC的数据安全进行建设。让AIGC更加安全、可靠的为用户服务。为达到此目的&#xff0c;应该从三个方面来开展AIGC的数据安…

【WebRTC】用WebRTC做即时视频聊天应用

【配套项目源码】 打开即用,设置一个免费的Agora账户就可以实现视频电话。非常好的WebRTC学习和应用项目。 用VSCode打开即可。 https://download.csdn.net/download/weixin_41697242/88630069 【什么是WebRTC?】 WebRTC是一套基于JS的API,能够建立端对端的直接通信,实…

后端接口开发-web前台请求接口对后台数据库增删改查-实例

一、后端接口开发的逻辑是&#xff1a; 1.Application项目启动 2.前台接口Url请求后台 3.Controller控制拿到前台请求参数&#xff0c;传递给中间组件Service 4.Service调用Mapper.java 5. mapper.java映射到mapper.xml中的mybatis语句&#xff0c;类似Sql语句操作数据库 6.其…

GDPU 数据结构 天码行空14

实验十四 查找算法的实现 一、【实验目的】 1、掌握顺序排序&#xff0c;二叉排序树的基本概念 2、掌握顺序排序&#xff0c;二叉排序树的基本算法&#xff08;查找算法、插入算法、删除算法&#xff09; 3、理解并掌握二叉排序数查找的平均查找长度。 二、【实验内容】 …