深度学习:Pytorch分布式训练

深度学习:Pytorch分布式训练

  • 简介
  • 模型并行
  • 数据并行
  • 参考文献

简介

在深度学习领域,模型越来越庞大、数据量不断增加,训练这些大型模型越来越耗时。通过在多个GPU或多个节点上并行地训练模型,我们可以显著减少训练时间。此外,某些模型因为巨大的参数量,单个设备可能无法容纳其整个模型和数据。在这种情况下,分布式训练不仅能提高训练速度,更是必要的手段来训练大模型。为此,PyTorch 分布式训练提供了两种基本的并行方法:

  • 模型并行(Model Parallel):模型并行是指将模型的不同部分放到不同的设备上。这种方式通常用于当一个单独的模型太大而无法放到单个GPU上时。

  • 数据并行(Data Parallel):数据并行是将训练数据分割并在多个设备上同时训练的方法。PyTorch提供了 torch.nn.DataParallel torch.nn.parallel.DistributedDataParallel 用于在多个GPU上并行化模型训练。

模型并行

在这里插入图片描述

模型并行主要利用to(device)函数将模型和数据(Tensor张量)放置在适当设备上,其余代码基本无需额外改动。
以下是一个简单的模型并行的代码示例:

import torch
import torch.nn as nn
import torch.optim as optimclass DemoModel(nn.Module):def __init__(self):super(DemoModel, self).__init__()self.net1 = torch.nn.Linear(10, 10).to('cuda:0')self.relu = torch.nn.ReLU()self.net2 = torch.nn.Linear(10, 5).to('cuda:1')def forward(self, x):x = self.relu(self.net1(x.to('cuda:0')))return self.net2(x.to('cuda:1'))model = DemoModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

注意调用损失函数时,您只需要确保标签与输出位于同一设备上。不难看出,此模型并行的方法效率相对较低,因为在任何时间点,两个 GPU 中只有一个在工作,而另一个则处于闲置状态。而且中间过程变量从cuda:0复制到cuda:1,又会需要额外的开销。因此可以引入流水线并行来进行加速。

在以下代码示例中,采取将输入数据批次划分为 20 组。由于 PyTorch 异步启动 CUDA 操作,因此可以不需要生成多个线程来实现并发。值得注意的是,使用较小的结果split_size会导致许多微小的 CUDA 内核启动,而使用较大的split_size会导致在第一次和最后一次数据划分期间存在相对较长的空闲时间。因此split_size对于特定实验可能有一个最佳配置,可以多次尝试最佳的超参数。

class PipelineParallelResNet50(ModelParallelResNet50):def __init__(self, split_size=20, *args, **kwargs):super(PipelineParallelResNet50, self).__init__(*args, **kwargs)self.split_size = split_sizedef forward(self, x):splits = iter(x.split(self.split_size, dim=0))s_next = next(splits)s_prev = self.seq1(s_next).to('cuda:1')ret = []for s_next in splits:# A. ``s_prev`` runs on ``cuda:1``s_prev = self.seq2(s_prev)ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))# B. ``s_next`` runs on ``cuda:0``, which can run concurrently with As_prev = self.seq1(s_next).to('cuda:1')s_prev = self.seq2(s_prev)ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))return torch.cat(ret)

数据并行

在这里插入图片描述

DataParallel是单进程、多线程,仅适用于单机,而是DistributedDataParallel多进程,适用于单机和多机训练。由于跨线程的 GIL 争用、每次迭代复制模型以及分散输入和收集输出带来的额外开销,DataParallel通常比DistributedDataParallel在单台机器上更慢。

一般地,数据并行的流程为:

  1. 在使用 distributed 包的任何其他函数之前,需要使用 init_process_group 初始化进程组,同时初始化 distributed 包。
  2. 如果需要进行组内集体通信,用 new_group 创建子分组
  3. 创建分布式并行模型 DDP(model, device_ids=device_ids)
  4. 为数据集创建 Sampler
  5. 使用启动工具 torch.distributed.launch 在每个主机上执行一次脚本,开始训练
  6. 使用 destory_process_group() 销毁进程组

以下是一个简单的数据并行的代码示例:

# demo_ddp.py
# 在init_process_group()时,一般可设置为Gloo、NCCL或mpi后端,Gloo目前在GPU上运行速度比 NCCL慢。所以经验法则是:
# 分布式GPU训练使用 NCCL 后端
# 分布式CPU训练使用 Gloo 后端import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optimfrom torch.nn.parallel import DistributedDataParallel as DDPclass DemoModel(nn.Module):def __init__(self):super(DemoModel, self).__init__()self.net1 = nn.Linear(10, 10)self.relu = nn.ReLU()self.net2 = nn.Linear(10, 5)def forward(self, x):return self.net2(self.relu(self.net1(x)))def demo_basic():dist.init_process_group("nccl")rank = dist.get_rank()print(f"Start running basic DDP example on rank {rank}.")# create model and move it to GPU with id rankdevice_id = rank % torch.cuda.device_count()model = DemoModel().to(device_id)ddp_model = DDP(model, device_ids=[device_id])loss_fn = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)optimizer.zero_grad()outputs = ddp_model(torch.randn(20, 10))labels = torch.randn(20, 5).to(device_id)loss_fn(outputs, labels).backward()optimizer.step()dist.destroy_process_group()if __name__ == "__main__":demo_basic()

然后使用torchrun命令进行启动,其中,nnodes表示总节点数,nproc_per_node表示每个节点运行的进程数,rdzv_id表示用户定义的ID,唯一标识作业的工作组, rdzv_backend表示集合点的后端,rdzv_endpoint表示rendezvous后端运行的地址

# 需要应用 slurm 等集群管理工具来实际在 2 个节点上运行此命令。
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 demo_ddp.py

此命令表示在两台服务器上运行 DDP 脚本,每台服务器运行 8 个进程,即在 16 个 GPU 上运行。

参考文献

  1. https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
  2. https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  3. https://medium.com/deelvin-machine-learning/model-parallelism-vs-data-parallelism-in-unet-speedup-1341bc74ff9e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1137.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Canvas与艺术】绘制黑白山间野营Camping徽章

【说明】 中间的山月图是借用的网上的成图&#xff0c;不是用Canvas绘制的。 【成果图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head>…

Ocr识别

https://blog.csdn.net/qq_47571357/article/details/132017514 Tesserocr 的安装 https://cuiqingcai.com/31102.html https://digi.bib.uni-mannheim.de/tesseract/ https://zhuanlan.zhihu.com/p/642903270 https://segmentfault.com/a/1190000039929696

fdisk使用的MBR分区

MBR和GPT分区 MBR分区 MBR分区一般在分区的时候 &#xff0c;MBR分区格式只能支持2TB以下的硬盘容量。 分区最多为4个主分区 或 3个主分区和1个扩展分区&#xff0c;而创建扩展分区后可以分无数个逻辑分区&#xff0c;当然跟磁盘容量有关&#xff0c; 逻辑分区在扩展分区上…

Windows 下 bat 脚本调用 Git bash 环境 sh 脚本

1、先找到 Git 安装目录 D:\Install\Git 2、Git bash 编写 sh 脚本 start.sh脚本 3、编写 start.bat脚本 echo offcd /d %~dp0 "D:\Install\Git\bin\sh.exe" --login -i -c "./test/start.sh"pause4、执行 bat 脚本 双击 start.bat 我们下期见&#xf…

运算符介绍

运算符介绍 运算符是一种特殊的符号&#xff0c; 用以表示数据的运算、 赋值和比较等。 算术运算符赋值运算符关系运算符 [比较运算符]逻辑运算符位运算符 [需要二进制基础]三元运算符 算术运算符 算术运算符是对数值类型的变量进行运算的&#xff0c; 在 Java 程序中使用的…

Meta Llama 3 简介

文章目录 要点我们对 Llama 3 的目标最先进的性能模型架构训练数据扩大预训练规模指令微调与 Llama 3 一起建造系统级责任方法大规模部署 Llama 3Llama 3 的下一步是什么&#xff1f;立即尝试 Meta Llama 3 本文翻译自&#xff1a;https://ai.meta.com/blog/meta-llama-3/ 要点…

vue实现文字转语音的组件,class类封装,实现项目介绍文字播放,不需安装任何包和插件(2024-04-17)

1、项目界面截图 2、封装class类方法&#xff08;实例化调用&#xff09; // 语音播报的函数 export default class SpeakVoice {constructor(vm, config) {let that thisthat._vm vmthat.config {text: 春江潮水连海平&#xff0c;海上明月共潮生。滟滟随波千万里&#xf…

Android Gradle插件对应的Gradle脚本所需版本

gradle/wrapper目录 gradle-wrapper.properties 文件对应的是脚本版本 distributionUrlhttps\://services.gradle.org/distributions/gradle-7.5-bin.zip根目录中 build.gradle 文件中 对应的 插件版本 如 7.4.2 buildscript {ext.kotlin_version 1.7.0dependencies {classp…

5G网络架构;6G网络架构

目录 5G和6G架构 6G网络架构 5G和6G架构 在设计和功能上有显著的区别,这主要体现在它们各自的核心特点、优势和应用场景上。 5G技术架构的核心特点包括高速率与低延迟、大容量与高密度以及网络切片。高速率与低延迟极大地提升了用户体验,支持更多实时应用和大规模数据传输…

PS-ZB转座子分析流程2-重新分析并总结

数据处理 数据质控 随机挑出九个序列进行比对&#xff0c;结果如下&#xff1a; 所有序列前面的部分序列均完全相同&#xff0c;怀疑是插入的转座子序列&#xff0c;再随机挑选9个序列进行比对&#xff0c;结果如下&#xff1a; 结果相同&#xff0c;使用cutadapt将该段序列修…

mybatis(5)参数处理+语句查询

参数处理&#xff0b;语句查询 1、简单单个参数2、Map参数3、实体类参数4、多参数5、Param注解6、语句查询6.1 返回一个实体类对象6.2 返回多个实体类对象 List<>6.3 返回一个Map对象6.4 返回多个Map对象 List<Map>6.5 返回一个大Map6.6 结果映射6.6.1 使用resultM…

Windows10安装配置nodejs环境

一、下载 下载地址&#xff1a;https://nodejs.cn/download/ ​ 二、安装 1、找到node-v16.17.0-x64.msi安装包, 根据默认提示安装, 过程中间的弹窗不勾选 2、安装完成后, 打开powershell(管理员身份) ​ 3、命令行输入 node -v 和 npm -v 如下图所示则nodejs安装成功 ​ 三…

A27 STM32_HAL库函数 之 IRDA通用驱动 -- B -- 所有函数的介绍及使用

A27 STM32_HAL库函数 之 IRDA通用驱动 -- B -- 所有函数的介绍及使用 1 该驱动函数预览1.11 HAL_IRDA_DMAPause1.12 HAL_IRDA_DMAResume1.13 HAL_IRDA_DMAStop1.14 HAL_IRDA_Abort1.15 HAL_IRDA_AbortTransmit1.16 HAL_IRDA_AbortReceive1.17 HAL_IRDA_Abort_IT1.18 HAL_IRDA_A…

JavaScript Promise与async/await

Promise与async/await 为什么要使用他们如何使用.then() 和.catch()如何将相同的代码转换成sync和Await关键字 为什么要使用他们 前面学习了JavaScript的简单类型&#xff08;例如 数字和字符串&#xff09;&#xff0c;我们的代码会按顺序从上往下执行 console.log(1111); c…

并发编程之ConcurrentHashMap源码分析

1. 主源码逻辑 final V putVal(K key, V value, boolean onlyIfAbsent) {if (key null || value null) throw new NullPointerException();// 1.计算key对应的hashint hash spread(key.hashCode());int binCount 0;// 2. 进行自旋 for (Node<K,V>[] tab table;;) {N…

Win11启用HyperV

Win11启用HyperV 编辑一个txt&#xff0c;输入下面的指令 pushd "%~dp0"dir /b %SystemRoot%\servicing\Packages\*Hyper-V*.mum >hyper-v.txtfor /f %%i in (findstr /i . hyper-v.txt 2^>nul) do dism /online /norestart /add-package:"%SystemRoot%…

PLC中连接外部现场设备和CPU的桥梁——输入/输出(I/O)模块

输入&#xff08;Input&#xff09;模块和输出&#xff08;Output&#xff09;模块简称为I/O模块&#xff0c;数字量&#xff08;Digital&#xff0c;又称为开关量&#xff09;输入模块和数字量输出模块简称为DI模块和DQ模块&#xff0c;模拟量&#xff08;Analog&#xff09;输…

第3章 数据

第3章 数据 学习笔记书后练习问题3问题7问题10问题11问题21 学习笔记 value value - 0; 通常用于将字符转换为其对应的整数值enum Jar_Type { CUP, PINT, QUART, HALF_GALLON, GALLON }; 这些符号名的实际值都是整型值&#xff0c;例如&#xff0c;CUP 是0&#xff0c;PINT …

Android安卓写入WIFI热点自动连接NDEF标签

本示例使用的发卡器&#xff1a;Android Linux RFID读写器NFC发卡器WEB可编程NDEF文本/网址/海报-淘宝网 (taobao.com) package com.usbreadertest;import android.os.Bundle; import android.view.MenuItem; import android.view.View; import android.widget.EditText; impo…

JAVA算法训练营打卡总结

目录 初心 目标 挑战 总结 初心 过完年后&#xff0c;突然发现自毕业后到现在已经工作将近两年&#xff0c;在这段时间中除了工作和备考软考外&#xff0c;也就是算法偶尔的刷几道&#xff0c;其它没有什么实际上的提升。 抱着现在的时间不去提升那以后就更没时间提升的心…