深度学习:Pytorch分布式训练

深度学习:Pytorch分布式训练

  • 简介
  • 模型并行
  • 数据并行
  • 参考文献

简介

在深度学习领域,模型越来越庞大、数据量不断增加,训练这些大型模型越来越耗时。通过在多个GPU或多个节点上并行地训练模型,我们可以显著减少训练时间。此外,某些模型因为巨大的参数量,单个设备可能无法容纳其整个模型和数据。在这种情况下,分布式训练不仅能提高训练速度,更是必要的手段来训练大模型。为此,PyTorch 分布式训练提供了两种基本的并行方法:

  • 模型并行(Model Parallel):模型并行是指将模型的不同部分放到不同的设备上。这种方式通常用于当一个单独的模型太大而无法放到单个GPU上时。

  • 数据并行(Data Parallel):数据并行是将训练数据分割并在多个设备上同时训练的方法。PyTorch提供了 torch.nn.DataParallel torch.nn.parallel.DistributedDataParallel 用于在多个GPU上并行化模型训练。

模型并行

在这里插入图片描述

模型并行主要利用to(device)函数将模型和数据(Tensor张量)放置在适当设备上,其余代码基本无需额外改动。
以下是一个简单的模型并行的代码示例:

import torch
import torch.nn as nn
import torch.optim as optimclass DemoModel(nn.Module):def __init__(self):super(DemoModel, self).__init__()self.net1 = torch.nn.Linear(10, 10).to('cuda:0')self.relu = torch.nn.ReLU()self.net2 = torch.nn.Linear(10, 5).to('cuda:1')def forward(self, x):x = self.relu(self.net1(x.to('cuda:0')))return self.net2(x.to('cuda:1'))model = DemoModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

注意调用损失函数时,您只需要确保标签与输出位于同一设备上。不难看出,此模型并行的方法效率相对较低,因为在任何时间点,两个 GPU 中只有一个在工作,而另一个则处于闲置状态。而且中间过程变量从cuda:0复制到cuda:1,又会需要额外的开销。因此可以引入流水线并行来进行加速。

在以下代码示例中,采取将输入数据批次划分为 20 组。由于 PyTorch 异步启动 CUDA 操作,因此可以不需要生成多个线程来实现并发。值得注意的是,使用较小的结果split_size会导致许多微小的 CUDA 内核启动,而使用较大的split_size会导致在第一次和最后一次数据划分期间存在相对较长的空闲时间。因此split_size对于特定实验可能有一个最佳配置,可以多次尝试最佳的超参数。

class PipelineParallelResNet50(ModelParallelResNet50):def __init__(self, split_size=20, *args, **kwargs):super(PipelineParallelResNet50, self).__init__(*args, **kwargs)self.split_size = split_sizedef forward(self, x):splits = iter(x.split(self.split_size, dim=0))s_next = next(splits)s_prev = self.seq1(s_next).to('cuda:1')ret = []for s_next in splits:# A. ``s_prev`` runs on ``cuda:1``s_prev = self.seq2(s_prev)ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))# B. ``s_next`` runs on ``cuda:0``, which can run concurrently with As_prev = self.seq1(s_next).to('cuda:1')s_prev = self.seq2(s_prev)ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))return torch.cat(ret)

数据并行

在这里插入图片描述

DataParallel是单进程、多线程,仅适用于单机,而是DistributedDataParallel多进程,适用于单机和多机训练。由于跨线程的 GIL 争用、每次迭代复制模型以及分散输入和收集输出带来的额外开销,DataParallel通常比DistributedDataParallel在单台机器上更慢。

一般地,数据并行的流程为:

  1. 在使用 distributed 包的任何其他函数之前,需要使用 init_process_group 初始化进程组,同时初始化 distributed 包。
  2. 如果需要进行组内集体通信,用 new_group 创建子分组
  3. 创建分布式并行模型 DDP(model, device_ids=device_ids)
  4. 为数据集创建 Sampler
  5. 使用启动工具 torch.distributed.launch 在每个主机上执行一次脚本,开始训练
  6. 使用 destory_process_group() 销毁进程组

以下是一个简单的数据并行的代码示例:

# demo_ddp.py
# 在init_process_group()时,一般可设置为Gloo、NCCL或mpi后端,Gloo目前在GPU上运行速度比 NCCL慢。所以经验法则是:
# 分布式GPU训练使用 NCCL 后端
# 分布式CPU训练使用 Gloo 后端import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optimfrom torch.nn.parallel import DistributedDataParallel as DDPclass DemoModel(nn.Module):def __init__(self):super(DemoModel, self).__init__()self.net1 = nn.Linear(10, 10)self.relu = nn.ReLU()self.net2 = nn.Linear(10, 5)def forward(self, x):return self.net2(self.relu(self.net1(x)))def demo_basic():dist.init_process_group("nccl")rank = dist.get_rank()print(f"Start running basic DDP example on rank {rank}.")# create model and move it to GPU with id rankdevice_id = rank % torch.cuda.device_count()model = DemoModel().to(device_id)ddp_model = DDP(model, device_ids=[device_id])loss_fn = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)optimizer.zero_grad()outputs = ddp_model(torch.randn(20, 10))labels = torch.randn(20, 5).to(device_id)loss_fn(outputs, labels).backward()optimizer.step()dist.destroy_process_group()if __name__ == "__main__":demo_basic()

然后使用torchrun命令进行启动,其中,nnodes表示总节点数,nproc_per_node表示每个节点运行的进程数,rdzv_id表示用户定义的ID,唯一标识作业的工作组, rdzv_backend表示集合点的后端,rdzv_endpoint表示rendezvous后端运行的地址

# 需要应用 slurm 等集群管理工具来实际在 2 个节点上运行此命令。
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 demo_ddp.py

此命令表示在两台服务器上运行 DDP 脚本,每台服务器运行 8 个进程,即在 16 个 GPU 上运行。

参考文献

  1. https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
  2. https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  3. https://medium.com/deelvin-machine-learning/model-parallelism-vs-data-parallelism-in-unet-speedup-1341bc74ff9e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1137.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Canvas与艺术】绘制黑白山间野营Camping徽章

【说明】 中间的山月图是借用的网上的成图&#xff0c;不是用Canvas绘制的。 【成果图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head>…

fdisk使用的MBR分区

MBR和GPT分区 MBR分区 MBR分区一般在分区的时候 &#xff0c;MBR分区格式只能支持2TB以下的硬盘容量。 分区最多为4个主分区 或 3个主分区和1个扩展分区&#xff0c;而创建扩展分区后可以分无数个逻辑分区&#xff0c;当然跟磁盘容量有关&#xff0c; 逻辑分区在扩展分区上…

Meta Llama 3 简介

文章目录 要点我们对 Llama 3 的目标最先进的性能模型架构训练数据扩大预训练规模指令微调与 Llama 3 一起建造系统级责任方法大规模部署 Llama 3Llama 3 的下一步是什么&#xff1f;立即尝试 Meta Llama 3 本文翻译自&#xff1a;https://ai.meta.com/blog/meta-llama-3/ 要点…

vue实现文字转语音的组件,class类封装,实现项目介绍文字播放,不需安装任何包和插件(2024-04-17)

1、项目界面截图 2、封装class类方法&#xff08;实例化调用&#xff09; // 语音播报的函数 export default class SpeakVoice {constructor(vm, config) {let that thisthat._vm vmthat.config {text: 春江潮水连海平&#xff0c;海上明月共潮生。滟滟随波千万里&#xf…

PS-ZB转座子分析流程2-重新分析并总结

数据处理 数据质控 随机挑出九个序列进行比对&#xff0c;结果如下&#xff1a; 所有序列前面的部分序列均完全相同&#xff0c;怀疑是插入的转座子序列&#xff0c;再随机挑选9个序列进行比对&#xff0c;结果如下&#xff1a; 结果相同&#xff0c;使用cutadapt将该段序列修…

mybatis(5)参数处理+语句查询

参数处理&#xff0b;语句查询 1、简单单个参数2、Map参数3、实体类参数4、多参数5、Param注解6、语句查询6.1 返回一个实体类对象6.2 返回多个实体类对象 List<>6.3 返回一个Map对象6.4 返回多个Map对象 List<Map>6.5 返回一个大Map6.6 结果映射6.6.1 使用resultM…

Windows10安装配置nodejs环境

一、下载 下载地址&#xff1a;https://nodejs.cn/download/ ​ 二、安装 1、找到node-v16.17.0-x64.msi安装包, 根据默认提示安装, 过程中间的弹窗不勾选 2、安装完成后, 打开powershell(管理员身份) ​ 3、命令行输入 node -v 和 npm -v 如下图所示则nodejs安装成功 ​ 三…

JavaScript Promise与async/await

Promise与async/await 为什么要使用他们如何使用.then() 和.catch()如何将相同的代码转换成sync和Await关键字 为什么要使用他们 前面学习了JavaScript的简单类型&#xff08;例如 数字和字符串&#xff09;&#xff0c;我们的代码会按顺序从上往下执行 console.log(1111); c…

并发编程之ConcurrentHashMap源码分析

1. 主源码逻辑 final V putVal(K key, V value, boolean onlyIfAbsent) {if (key null || value null) throw new NullPointerException();// 1.计算key对应的hashint hash spread(key.hashCode());int binCount 0;// 2. 进行自旋 for (Node<K,V>[] tab table;;) {N…

PLC中连接外部现场设备和CPU的桥梁——输入/输出(I/O)模块

输入&#xff08;Input&#xff09;模块和输出&#xff08;Output&#xff09;模块简称为I/O模块&#xff0c;数字量&#xff08;Digital&#xff0c;又称为开关量&#xff09;输入模块和数字量输出模块简称为DI模块和DQ模块&#xff0c;模拟量&#xff08;Analog&#xff09;输…

Android安卓写入WIFI热点自动连接NDEF标签

本示例使用的发卡器&#xff1a;Android Linux RFID读写器NFC发卡器WEB可编程NDEF文本/网址/海报-淘宝网 (taobao.com) package com.usbreadertest;import android.os.Bundle; import android.view.MenuItem; import android.view.View; import android.widget.EditText; impo…

LabVIEW频谱感知实验平台

LabVIEW频谱感知实验平台 在当前的通信网络中&#xff0c;频谱资源的高效利用成为了研究和实践的重要方向之一。随着无线通信技术的快速发展&#xff0c;传统的固定频谱分配策略已无法满足日益增长的通信需求&#xff0c;因此&#xff0c;频谱感知技术作为认知无线电的核心组成…

Zed 捕获图像+测距

Zed 捕获图像测距 1. 导入相关库2. 相机初始化设置3. 获取中心点深度数据4. 计算中心点深度值5. 完整代码5. 实验效果 此代码基于官方代码基础上进行改写&#xff0c;主要是获取zed相机深度画面中心点的深度值&#xff0c;为yolo测距打基础。 1. 导入相关库 import pyzed.sl …

iview中基于upload源代码组件封装更为完善的上传组件

业务背景 最近接了一个用iview为基础搭建的vue项目&#xff0c;在开需求研讨会议的时候&#xff0c;我个人提了一个柑橘很合理且很常规的建议&#xff0c;upload上传文件支持同时上传多个并且可限制数量。当时想的是这不应该很正常吗&#xff0c;但是尴尬的是&#xff1a;只有…

【Proteus】蜂鸣器播放音乐

按键按一次&#xff0c;蜂鸣器响一次 &#xff0c;LCD1602同步。 #include <REGX52.H> #include <INTRINS.H>unsigned int keynum; sbit RSP3^0; //** sbit RWP3^1; //** sbit EP3^2; //** sbit buzzerP1^5; void delay(unsigned int n)//1ms {unsigned char a,…

虹科Pico汽车示波器 | 免拆诊断案例 | 2016款保时捷911 GT3 RS车发动机异响

一、故障现象 一辆2016款保时捷911 GT3 RS车&#xff0c;搭载4.0 L水平对置发动机&#xff08;型号为MA176&#xff09;&#xff0c;累计行驶里程约为4.2万km。车主反映&#xff0c;1星期前上过赛道&#xff0c;现在发动机有“哒哒”异响。 二、故障诊断 接车后试车&#xff…

51.基于SpringBoot + Vue实现的前后端分离-校园志愿者管理系统(项目 + 论文)

项目介绍 本站是一个B/S模式系统&#xff0c;采用SpringBoot Vue框架&#xff0c;MYSQL数据库设计开发&#xff0c;充分保证系统的稳定性。系统具有界面清晰、操作简单&#xff0c;功能齐全的特点&#xff0c;使得基于SpringBoot Vue技术的校园志愿者管理系统设计与实现管理工…

正则表达式中 “$” 并不是表示 “字符串结束”

△△请给“Python猫”加星标 &#xff0c;以免错过文章推送 作者&#xff1a;Seth Larson 译者&#xff1a;豌豆花下猫Python猫 英文&#xff1a;Regex character “$” doesnt mean “end-of-string” 转载请保留作者及译者信息&#xff01; 这篇文章写一写我最近在用 Python …

C++ 内存分区管理

一、栈区&#xff08;Stack&#xff09; 栈区用来存储函数的参数值、局部变量的值等数据。栈区是自动分配和释放的&#xff0c;函数执行时会在栈区分配空间&#xff0c;函数执行结束时会自动释放这些空间。栈区的数据是连续分配的&#xff0c;由系统自动管理。 注意事项&…

普通人赚钱途径大盘点:从搬砖到玩转智慧,财富之路任你探索

在生活的大舞台上&#xff0c;每个人都在以自己的方式演绎着赚钱的故事。作为普通人&#xff0c;我们或许没有显赫的财富背景&#xff0c;但赚钱的途径却是多种多样&#xff0c;等待我们去发掘。今天&#xff0c;就让我来为大家盘点一下普通人赚钱的常见途径&#xff0c;看看哪…