使用DeepSpeed进行单机多卡训练

这是你提供的DeepSpeed单机多卡训练步骤的Markdown格式:

使用 DeepSpeed 进行单机多卡训练的主要步骤

1. 安装 DeepSpeed

确保你已经安装了 DeepSpeed 及其依赖:

pip install deepspeed
  1. 设置模型并集成 DeepSpeed

在模型的定义和训练循环中集成 DeepSpeed:


import deepspeed

假设你有一个 PyTorch 模型

model = MyModel()

配置 DeepSpeed 参数,例如优化器、梯度累积等

ds_config = {
“train_batch_size”: 32,
“gradient_accumulation_steps”: 1,
“fp16”: {
“enabled”: True,
“initial_scale_power”: 16
},
“zero_optimization”: {
“stage”: 2
}
}

使用 deepspeed.initialize 初始化模型

model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=ds_config
)

  1. 多卡并行训练

DeepSpeed 自动管理数据并行和模型优化。使用 torch.distributed.launch 来启动单机多卡训练。通常,你可以使用以下命令:

deepspeed --num_gpus=4 train.py

其中,train.py 是你编写的训练脚本,–num_gpus=4 表示使用 4 张 GPU 进行训练。

  1. 修改训练脚本以支持多卡训练

在训练循环中,将模型的 forward 和 backward 操作交给 DeepSpeed 管理。例如:

for step, batch in enumerate(data_loader):
inputs, labels = batch

# 前向传播
outputs = model_engine(inputs)# 计算损失
loss = loss_fn(outputs, labels)# 反向传播和优化
model_engine.backward(loss)
model_engine.step()
  1. DeepSpeed 配置文件

你还可以通过一个 JSON 配置文件来管理 DeepSpeed 的设置。通常可以创建一个 ds_config.json 文件,里面包含优化器、调度器、fp16 等的配置:

{
“train_batch_size”: 32,
“gradient_accumulation_steps”: 1,
“fp16”: {
“enabled”: true,
“initial_scale_power”: 16
},
“zero_optimization”: {
“stage”: 2
}
}

然后在训练时通过 --deepspeed_config ds_config.json 来引用此文件。

  1. 优化和调优

DeepSpeed 提供了许多优化选项,例如 Zero Redundancy Optimizer(ZeRO)可以减少 GPU 显存的占用,同时支持混合精度(FP16)来加快训练速度。

通过这些步骤,你就能在一台机器上使用多个 GPU 进行分布式训练,同时享受 DeepSpeed 带来的优化。

这样你可以很方便地在Markdown中展示DeepSpeed的使用步骤。

你可以使用一些常见的深度学习模型进行单机多卡训练,尤其是那些已经在主流库中实现的模型。例如,Hugging Face 的 transformers 库和 PyTorch 的 torchvision 模型都非常适合进行分布式训练。

  1. Hugging Face Transformers 模型

Hugging Face 提供了很多预训练模型,如 BERT、GPT 等,非常适合用来练习单机多卡训练。你可以使用它们并集成 DeepSpeed。

下面是一个使用 Hugging Face 库中 BERT 模型进行训练的示例:

安装依赖:

pip install transformers datasets deepspeed

训练脚本示例:

import torch
import deepspeed
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset

加载数据集

dataset = load_dataset(‘glue’, ‘mrpc’)

加载预训练模型和 tokenizer

model = BertForSequenceClassification.from_pretrained(‘bert-base-uncased’, num_labels=2)
tokenizer = BertTokenizer.from_pretrained(‘bert-base-uncased’)

数据处理

def preprocess_function(examples):
return tokenizer(examples[‘sentence1’], examples[‘sentence2’], truncation=True)

encoded_dataset = dataset.map(preprocess_function, batched=True)

DeepSpeed 配置

ds_config = {
“train_batch_size”: 32,
“fp16”: {
“enabled”: True
},
“zero_optimization”: {
“stage”: 2
}
}

DeepSpeed 初始化

model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=ds_config
)

模型训练

for epoch in range(3):
for batch in encoded_dataset[‘train’]:
inputs = {k: torch.tensor(v).to(model_engine.local_rank) for k, v in batch.items() if k != ‘label’}
labels = torch.tensor(batch[‘label’]).to(model_engine.local_rank)

    outputs = model_engine(**inputs, labels=labels)loss = outputs.lossmodel_engine.backward(loss)model_engine.step()

运行命令:

deepspeed --num_gpus=4 train.py

这个示例使用了 Hugging Face 提供的 BERT 模型以及 GLUE 数据集中的 MRPC 任务,进行文本分类。

  1. PyTorch torchvision 模型

如果你对计算机视觉模型感兴趣,可以使用 PyTorch 的 torchvision 提供的模型,如 ResNet。

安装依赖:

pip install torch torchvision deepspeed

训练脚本示例:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import deepspeed

超参数设置

batch_size = 32
num_epochs = 5

使用 torchvision 加载 CIFAR-10 数据集

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.CIFAR10(root=‘./data’, train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

加载预训练的 ResNet 模型

model = torchvision.models.resnet18(pretrained=False, num_classes=10)

DeepSpeed 配置

ds_config = {
“train_batch_size”: batch_size * torch.cuda.device_count(),
“fp16”: {
“enabled”: True
},
“zero_optimization”: {
“stage”: 2
}
}

DeepSpeed 初始化

model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=ds_config
)

损失函数和优化器

criterion = nn.CrossEntropyLoss()

模型训练

for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(model_engine.local_rank)labels = labels.to(model_engine.local_rank)# 前向传播outputs = model_engine(images)loss = criterion(outputs, labels)# 反向传播和优化model_engine.backward(loss)model_engine.step()if i % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

运行命令:

deepspeed --num_gpus=4 train.py

其他模型获取途径

•	Hugging Face Model Hub: https://huggingface.co/models
•	PyTorch Model Zoo: https://pytorch.org/vision/stable/models.html

这些模型和数据集可以让你快速开始训练,并能帮助你熟悉 DeepSpeed 的单机多卡训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/57537.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

solana phantom NFT图片显示不出来?

solana phantom NFT图片显示不出来? 问题 同样是jpeg格式图片,一个phatom可以显示,一个不可以显示为什么,nft图片格式大小有要求吗? 问题分析 Phantom 官网有一些关于 NFT 集成的文档,其中可能会有关于图片大小限制…

049_python基于Python的热门微博数据可视化分析

目录 系统展示 开发背景 代码实现 项目案例 获取源码 博主介绍:CodeMentor毕业设计领航者、全网关注者30W群落,InfoQ特邀专栏作家、技术博客领航者、InfoQ新星培育计划导师、Web开发领域杰出贡献者,博客领航之星、开发者头条/腾讯云/AW…

@tarojs/components 和 taro-ui 中的组件之间的区别

1. 来源与用途: tarojs/components:Taro 官方提供的基础组件库,包含了微信小程序、H5 等不同平台的通用组件(如 View, Input, Button, Form 等)。这些组件是跨平台的,并提供了与微信小程序等平台原生组件类…

15分钟学Go 第7天:控制结构 - 条件语句

第7天:控制结构 - 条件语句 在Go语言中,控制结构是程序逻辑的重要组成部分。通过条件语句,我们可以根据不同的条件采取不同的行动。今天我们将详细探讨Go语言中的两种主要条件结构:if语句和switch语句。理解这些控制结构对于编写…

CTA-GAN:基于生成对抗网络对颈动脉和主动脉的非增强CT影像进行血管增强

写在前面 目前只分析了文章的大体内容和我个人认为的比较重要的细节,代码实现还没仔细看,后续有时间会补充代码细节部分。 文章地址:Generative Adversarial Network-based Noncontrast CT Angiography for Aorta and Carotid Arteries 代…

JAVA基础面试题准备

一些常见的JAVA基础题,面试中遇到过的会加*显示。 JAVA基础 1.Java中重载和重写的区别?* 2.int 和Integer类型这两个区别吗? 为什么需要有Integer类型: int和Integer类型的区别: 3.遍历list有那些方式吗?…

python如何提取MYSQL数据,并在完成数据处理后保存?

在现代数据驱动的世界中,数据分析已成为企业决策的重要组成部分。 Python作为一种强大的编程语言,因其丰富的库和简单的语法,广泛应用于数据分析、数据清洗和数据可视化等领域。 本文将详细介绍如何使用Python提取MySQL数据库中的数据,并进行数据分析、数据清洗、汇总等操…

【Linux】进程信号(下)

目录 一、信号的阻塞 1.1 信号在内核中的保存方式 1.2 sigset_t信号集 (1)信号集操作 (2)sigprocmask函数 (3)sigpending函数 二、信号的处理 2.1 用户态和内核态 2.2 重谈进程地址空间 三、信号…

盘点2024年4款高清稳定的Windows10录屏工具。

Windows10电脑录屏在生活当中还是挺重要的,无论是教育领域的制作教程,还是游戏玩家记录精彩瞬间,亦或是商务人士进行演示,录屏都能发挥巨大作用。如果设备自带的一些工具无法完成录屏需求的话,这里帮大家找了几款好用到…

AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务

AI大模型应用(3)开源框架Vanna: 利用RAG方法做Text2SQL任务 RAG(Retrieval-Augmented Generation,如下图所示)检索增强生成,即大模型LLM在回答问题时,会先从大量的文档中检索出相关信息,然后基于这些检索出…

W25Q64的学习

24位地址意味着系统有24根地址线,每根地址线可以取两种状态(0或1),所以系统可以形成 2242^{24}224 个不同的地址组合。每个地址对应一个存储单元,通常是1字节。 在大多数现代计算机体系结构中,地址指向的…

万家数科:零售业务信息化融合的探索|OceanBase案例

本文作者:马琳,万家数科数据库专家。 万家数科商业数据有限公司,作为华润万家旗下的信息技术企业,专注于零售行业,在为华润万家提供服务的同时,也积极面向市场,为零售商及其生态系统提供全面的核…

挖矿病毒来势汹汹

病毒来了, 我的个人站点使用了 wordpress, 它的不知哪个漏洞让黑客攻入了我的站点 使用 top 命令看到了有不明进程始终占据了 100% 的 CPU snapshot 1 snapshot 2 通过以下 "三板斧"可以查杀这个进程 先用 top (shiftp) 查找占据 CPU 最多的进程根据其进程号 pid 查看…

【数据结构】宜宾大学-计院-实验四

栈和队列之(栈的基本操作) 实验目的:实验内容:实验结果:实验报告:(及时撰写实验报告):实验测试结果:代码实现1.0:(C/C)【含注释】代码…

QGIS之三十二DEM地形导出三维模型gltf

效果 1、准备数据 (1)dem.tif (2)dom.tif 2、qgis加载dem和dom数据 3、安装插件 插件步骤可以参考这篇文章 QGIS之二十四安装插件 安装了Qgis2threejs插件,结果

无人机之自主降落系统篇

一、定义与功能 无人机自主降落系统是指无人机在无需人工干预的情况下,按照预先设定好的程序或基于实时感知的环境信息,自主完成降落过程的技术系统。该系统能够确保无人机在完成任务后安全、准确地降落到指定位置。 二、系统组成 无人机自主降落系统主…

二十、行为型(访问者模式)

访问者模式(Visitor Pattern) 概念 访问者模式是一种行为型设计模式,允许你在不修改被访问对象的前提下,定义新的操作。它通过将操作封装在访问者类中,从而将操作与对象结构分离。访问者模式非常适合于需要对一组对象…

对“一个中心,三重防护”中安全管理中心的理解

安全管理中心 本控制项为网络安全等级保护标准的技术部分。本项主要包括系统管理、审计管理、安全管理和集中管控四个控制点,其中的集中管控可以说是重中之重,主要都是围绕它来展开的。 28448基本要求中安全管理中心 8.1.5 安全管理中心 8.1.5.1 系统…

ELK之路第二步——可视化界面Kibana

Kibana 1.安装2.解压3.修改配置4.启动 这部分内容就比较简单了,水一片文章。 1.安装 需要梯子 官网下载链接:https://www.elastic.co/cn/downloads/past-releases/kibana-7-3-0 如果你去官网下载页面,点击下载是404报错,记得切换…

redis的zset实现下滑滚动分页查询思路

常规zset查询 我们redis的数据为 我们知道 我们常规查询的话 我们假如 zset 表中 有7个元素,然后我们进行分页查询的话,我们一次查3个元素,然后查出来元素 和元素的分数 我们redis的语法应该这样写 zrevrangebyscore wang 1000 0 withsc…