GPU集群上分布式训练大模型

总结一下如何在超算系统上进行预训练大模型的分布式训练 / 微调,文中代码已上传至 github

实验环境

集群1:国家广州超算 星逸A800智能AI集群
GPU:8 * Nvdia Tesla-A800 80G显存
CPU:2 * 28核 Intel Xeon Gold 6348
内存:1024GB

集群2:并行科技 中国国家网格 N12 区(cngrid12)
GPU:4 * Nvdia Tesla-V100 16G显存
CPU:20 核 Intel® Xeon® CPU E5-2640 v4
内存:128GB

在超算分布式环境上和本地训练有几点不同:

  1. 超算环境无法科学上网,需要手动下载并上传:数据tokenizer模型模型参数,并在代码中作相应修改。
  2. 通过 slurm 进行作业管理,编写并提交 sbatch 脚本来运行作业。
  3. 每个集群的环境各不相同,移植时需要注意配置环境和预加载相关的库。
  4. 训练超大模型时,单个GPU显存有限,仅使用 torch.nn.parallel 数据并行常常无法加载完整模型无法完成训练,或只能小批次训练。因此需要用到分布式训练框架,常见的分布式训练框架有Horovod,Megatron-LM,DeepSpeed等。

1 举例:bert-large

1.1 本地单卡训练bert-large(假设GPU显存足够大)

Step1 编写训练代码 run_bert.py

import torch
from transformers import BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm# 加载预训练的tokenizer和模型
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
model = BertForMaskedLM.from_pretrained("bert-large-uncased")# 加载WikiText数据集
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")# 数据预处理
def tokenize_function(examples):return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])# 设置数据加载器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
train_loader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)  # 根据显存调整batch_size# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 设置优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)# 设置训练参数
num_epochs = 3
gradient_accumulation_steps = 8  # 梯度累积步数,根据显存调整
model.train()# 手动实现训练循环
for epoch in range(num_epochs):print(f"Epoch {epoch + 1}/{num_epochs}")epoch_loss = 0for step, batch in enumerate(tqdm(train_loader)):# 将数据移到GPUbatch = {k: v.to(device) for k, v in batch.items()}# 前向传播outputs = model(**batch)loss = outputs.lossloss = loss / gradient_accumulation_steps  # 梯度累积# 反向传播loss.backward()# 更新参数并清空梯度if (step + 1) % gradient_accumulation_steps == 0:optimizer.step()optimizer.zero_grad()# 记录损失epoch_loss += loss.item()avg_loss = epoch_loss / len(train_loader)print(f"Epoch {epoch + 1} finished with average loss: {avg_loss:.4f}")print("Training complete.")

Step2 运行:

python run_bert.py

1.2 迁移到分布式训练(微调大模型)

按如下步骤转换为分布式训练代码,并移植到超算平台上完成训练(单节点多卡):

Step1 下载数据、tokenizer、模型
从huggingface官网或镜像网站(国内)下载对应文件,模型和tokenizer搜索bert-large-uncased,数据集搜索wikitext:

模型:config.json,pytorch_model.bin
数据:train-00000-of-00001.parquet
tokenizer:tokenizer.json,vocab.txt,tokenizer_config.json(可选)

Step2 修改训练代码 run_bert.py

  • 分布式训练框架我用微软提供的deepspeed框架进行训练。deepspeed支持3D并行训练,同时集成了一些优化,如ZeRO、CPU-offload、混合精度训练等,能够提供比原生pytorch更加高效的训练。
  • 修改内容:
  1. tokenizer、model、dataset 分别修改为从本地加载
  2. 添加 import deepspeed 和 import torch.distributed as dist
  3. 添加 deepspeed.init_distributed() 初始化分布式环境
  4. 添加 deepspeed.initialize() 将模型转换为deepspeed模型
  5. 获取 world_size 和 rank ,并将 .to(device) 替换为 .to(rank)
  6. 删除 device, optimizer, gradient_accumulation_steps 配置,转移到 ds_config.json 中
  7. 将所有的 model 替换为 model_engine ,即deepspeed模型
  8. 训练循环中只保留 model_engine(input_ids), model_engine.backward(), model_engine.step() 三个函数
  9. 添加 time.time() 函数用于计时
  10. 修改数据集加载和预处理逻辑,增加sampler,删除shuffle选项,使其符合分布式训练
  11. 对模型和数据添加 .contiguous() 以保证 tensor 的连续性
import torch
from transformers import BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
import deepspeed
import torch.distributed as dist
import time# 初始化分布式环境
deepspeed.init_distributed()
world_size = dist.get_world_size()
rank = dist.get_rank()# 加载预训练的tokenizer和模型
tokenizer = BertTokenizer.from_pretrained("./tokenizer/")
model = BertForMaskedLM.from_pretrained("./model")# 加载WikiText数据集
dataset = load_dataset("parquet", data_files="./data/train-00000-of-00001.parquet")# 数据预处理
def tokenize_function(examples):return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
train_dataset = tokenized_dataset["train"]# 使用分布式采样器
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)# 设置数据加载器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)train_loader = DataLoader(train_dataset, batch_size=4, collate_fn=data_collator, sampler=sampler)  # 设置训练参数
epochs = 3# 确保模型参数的连续性
for param in model.parameters():param.data = param.data.contiguous()# 初始化deepspeed模型,将model转换成model_engine
model_engine, optimizer, _, _ = deepspeed.initialize(args=None,model=model,model_parameters=model.parameters(),config='./ds_config.json'
)model_engine.train()start_time = time.time()if rank == 0: print("Training start...")# 手动实现训练循环
for epoch in range(epochs):epoch_loss = 0for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", disable=(rank != 0), mininterval=20)):  # 设置只有0号进程打印进度条且打印间隔为20s# 将数据移到GPUbatch = {k: v.to(rank) for k, v in batch.items()}# 前向传播outputs = model_engine(**batch)loss = outputs.loss# 反向传播model_engine.backward(loss)# 更新参数model_engine.step()# 记录损失epoch_loss += loss.item()end_time = time.time()elapsed_time = end_time - start_timeprint(f"\nRank {rank}, Epoch {epoch+1}/{epochs}, Batch {step+1}/{len(train_loader)}, Loss: {epoch_loss/len(train_loader)}, total_time_used: {elapsed_time / 60:.2f} mins")if rank == 0: print("Training complete.")
  • 上面代码是加载了预训练模型权重文件 pytorch_model.bin 然后对模型进行微调,若需要从头训练,只需要将 model = BertForMaskedLM.from_pretrained("./model") 改为:
from transformers import BertConfigconfig = BertConfig.from_pretrained('./model/config.json')
model = BertForMaskedLM(config)

Step3 编写运行脚本 sbatch.sh

#!/bin/bash#SBATCH --nodes=1                   # 节点数
#SBATCH --ntasks=4                  # 任务数
#SBATCH --partition=gpu             # 分区名称(根据集群修改)
#SBATCH --gres=gpu:4                # 设置使用的GPU数module load nvidia/cuda/12.2
module load mpich/3.4.1-gcc9.3      # 加载gcc-5版本以上deepspeed   --num_nodes=1          \--num_gpus=4           \--launcher slurm       \run_bert.py 

Step4 创建deepspeed的配置文件 ds_config.json

{"train_batch_size": 4,                // batch_size,必须等于 train_micro_batch_size_per_gpu * gradient_accumulation_steps * GPU数,且和训练代码中设置相同"train_micro_batch_size_per_gpu": 1,  // 每个GPU上micro_batch的数量"gradient_accumulation_steps": 1,     // 梯度累积多少个batch同步一次// 设置使用ZeRO-3优化"zero_allow_untested_optimizer": true,"zero_optimization":{"stage": 3},// 配置优化器"optimizer": {"type": "Adam","params": {"lr": 1e-4,"betas": [0.9, 0.999],"eps": 1e-8}} 
}

Step5 上传到服务器

  • 将所有文件打包上传到服务器,其中模型文件 pytorch_model.bin 比较大,可能需要单独上传。上传后解压,文件结构如下:
bert-large
│   run_bert.py
│   sbatch.sh   
│   ds_config.json  
│
└───data
│   │   train-00000-of-00001.parquet
│   
└───model│   config.json│   pytorch_model.bin
│
└───tokenizer
│   │   tokenizer.json
│   │   vocab.txt
│   

Step6 配置服务器运行环境

# 创建虚拟环境
$ conda create -n bert-large python==3.10
$ conda activate bert-large# 安装必要的库
$ pip install -r requirement.txt -i https://pypi.tuna.tsinghua.edu.cn/simple# 使用 conda 安装mpi4py,因为这个库需要的依赖太多了,pip很容易报错
$ conda install mpi4py

其中,requirement.txt

torch==2.4.1
transformers==4.46.0
deepspeed==0.15.2
datasets
tensorboard
fire==0.4.0
pytz==2021.1
loguru==0.5.3
sh==1.14.2
pytest==6.2.5
tqdm==4.62.3

Step7 命令行运行:

$ cd bert-large
$ sbatch sbatch.sh

运行结果(部分)
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/885350.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python爬虫自动库DrissionPage保存网页快照mhtml/pdf/全局截图/打印机另存pdf

目录 零一、保存网页快照的三种方法二、利用打印机保存pdf的方法 零 最近星球有人问如何使用页面打印功能,另存为pdf 一、保存网页快照的三种方法 解决方案已经放在星球内:https://articles.zsxq.com/id_55mr53xahr9a.html当然也可以看如下代码&…

Redis 中 Bitmap 原理和应用

Bitmap Redis中的Bitmap(位图)是一种较为特殊数据类型,它以最小单位bit来存储数据,我们知道一个字节由 8个 bit 组成,和传统数据结构用字节存储相比,这使得它在处理大量二值状态(true、false 或…

elementUI 点击弹出时间 date-picker

elementUI的日期组件,有完整的UI样式及弹窗,但是我的页面不要它的UI样式,点击的时候却要弹出类似的日期选择器,那怎么办呢? 以下是elementUI自带的UI风格,一定要一个输入框来触发。 这是我的项目中要用到的…

微软日志丢失事件敲响安全警钟

NEWS | 事件回顾 最近,全球最大的软件公司之一——微软,遭遇了一场罕见的日志丢失危机。据报告,从9月2日至9月19日,持续长达两周的时间里,微软的多项核心云服务,包括身份验证平台Microsoft Entra、安全信息…

2021-04-22 51单片机玩转点阵

理论就不赘述了,网络上多得很,直接从仿真软件感性上操作认识点阵,首先打开ISIS仿真软件,放置一个点阵和电源与地线就可以开始了;由点阵任何一脚连线到地线,另一边对应的引脚就连接到电源,如图:点击运行看是否点亮?看到蓝色与红色的点表示电源正常但是没有任何亮点,这时对调一下…

(十三)JavaWeb后端开发——MySQL2

目录 1.DQL数据查询语言 1.1基本查询 1.2条件查询 where关键字 1.3分组查询 1.4排序查询 1.5分页查询 2.多表设计 3.多表查询——联查 4.多表查询——子查询​ 5.MySQL 事务 6.事务管理(事务进阶) 7.MySQL 索引 1.DQL数据查询语言 分为五大…

恭喜!2024年度大连市科技人才创新、科技人才创业项目拟立项公示!

精选SCI/SSCI/EI SCI&EI ●IEEE 1区TOP 计算机类(含CCF); ●EI快刊:最快1周录用! 知网(CNKI)、谷歌学术期刊 ●7天录用-检索(100%录用),1周上线; 免费稿件评估 …

【前端】-音乐播放器(源代码和结构讲解,大家可以将自己喜欢的歌曲添加到数据当中,js实现页面动态显示音乐)

前言:音乐播放器是前端开发中的一个经典项目,通过它可以掌握很多核心技术,如音频处理、DOM操作、事件监听、动画效果等。这个项目不仅能提升前端开发的技能,还能让开发者深入理解JavaScript与HTML的协同作用。 页面展示&#xff1…

虚拟机linux7.9下安装mysql

1.MySQL官网下载安装包: MySQL :: Download MySQL Community Server https://cdn.mysql.com/archives/mysql-5.7/mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz 2.解压文件: #tar xvzf mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz 3.移动文件&#…

03_CC2530基于定时器3的Delay_ms函数

CC2530定时器3与Delay_ms延时函数 前言 ​ Delay函数是开发中常用到的函数,可以用于按键消抖,LED闪烁,生成一定频率信号等(软件模拟通讯协议)。由于利用循环执行一定次数的空指令实现的延时函数在精度上并不能让人满意,而用定时…

【系统面试篇】其他相关题目——虚拟内存、局部性原理、分页、分块、页面置换算法

目录 一、相关问题 1. 什么是虚拟内存?为什么需要虚拟内存? (1)内存扩展 (2)内存隔离 (3)物理内存管理 (4)页面交换 (5)内存映…

43.第二阶段x86游戏实战2-提取游戏里面的lua

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 本人写的内容纯属胡编乱造,全都是合成造假,仅仅只是为了娱乐,请不要…

容器内pip安装Apache Airflow的经历:如何重置初始密码

背景 Apache Airflow™https://github.com/apache/airflow 是一个开源平台,用于开发、调度和监控面向批处理的工作流程。Airflow 可扩展的 Python 框架使您能够构建几乎可以连接任何技术的工作流程。Web 界面有助于管理工作流程的状态。Airflow 可以通过多种方式部…

Java爬虫 爬取某招聘网站招聘信息

Java爬虫 爬取某招聘网站招聘信息 一、系统介绍二、功能展示1.需求爬取的网站内容2.实现流程2.1数据采集2.2页面解析2.3数据存储 三、其它1.其他系统实现 一、系统介绍 系统主要功能:本项目爬取的XX招聘网站 二、功能展示 1.需求爬取的网站内容 2.实现流程 爬虫…

stm32不小心把SWD和JTAG都给关了,程序下载不进去,怎么办?

因为想用STM32F103的PA15引脚,调试程序的时候不小心把SWD和JTAD接口都给关了,先看下罪魁祸首 GPIO_PinRemapConfig(GPIO_Remap_SWJ_JTAGDisable,ENABLE);//关掉JTAG,不关SWGPIO_PinRemapConfig(GPIO_Remap_SWJ_Disable, ENABLE);//关掉SW&am…

雷军-2022.8小米创业思考-11-新零售:用电商思维做新零售,极致的效率+极致的体验。也有弯路,重回极致效率的轨道上。

第十一章 新零售 当我们说到小米模式的时候,其实我们说的是两件东西: 一是小米模式的本质,即高效率的商业模式; 另一件是小米这家公司具象的商业模式,这是小米在实践中摸索、建立的一整套业务模型。 从2015年到202…

C语言实现数据结构之堆

文章目录 堆一. 树概念及结构1. 树的概念2. 树的相关概念3. 树的表示4. 树在实际中的运用(表示文件系统的目录树结构) 二. 二叉树概念及结构1. 概念2. 特殊的二叉树3. 二叉树的性质4. 二叉树的存储结构 三. 二叉树的顺序结构及实现1. 二叉树的顺序结构2.…

知识课堂之域名系统中实现动态代理

怎么在域名系统中解析动态ip,这一直是一个需要解决的问题,人们对与网络的稳定连接与灵活运用已经成为生活和工作中不可或缺的一部分,因此这样的问题的解决迫在眉睫。 大家对于动态ip是什么,应该都有所了解了,所谓的动…

5G周边知识笔记

这里写目录标题 3GPP 5G标准路径图5G协议规范5G新空口关键指标4G LTE和5G NR新空口技术对比5G新频段FR1FR2 信道带宽上下行解耦新频点规划,信道栅格FR1各频段实际信道栅格和NR-ARFCN范围定义 同步栅格大规模天线阵列新型调制编码技术大规模载波聚合设备到设备直接通…

uniapp配置h5路由模式为history时404

为了不让URL中出现#,让uniapp项目配置h5路由模式为hisory 然而本地好好的,放到服务器上却404了。 解决方法是给nginx配置一个伪静态: location /xxx-html/ {alias /home/nginx_web/xxx_new_html/;try_files $uri $uri/ /xxx-html/index.ht…