【问题解决】DDP | 如何使用 DDP 模式来训练模型

在训练 pytorch 模型时,多卡并行训练能够很大程度上提升模型的训练效率

一般来说,有两种多卡并行的训练方式:DP 和 DDP

一、DP 和 DDP 的区别

DP(Data Parallelism)和DDP(Distributed Data Parallelism)是深度学习中用于训练模型的两种并行计算策略。

1、Data Parallelism (DP):

  • DP通常指在单台机器的多个GPU上实施的数据并行处理。在这种设置中,模型的每个副本都放在不同的GPU上,每个GPU都会处理输入数据的不同批次。
  • 所有GPU完成自己的前向和反向传播后,它们会将梯度发送到主GPU,主GPU会对这些梯度进行平均,然后更新模型的权重。
  • 然后这些更新的权重会被复制到其他所有GPU上,以保持模型的一致性。
  • DP的缺点是它受限于单个机器的资源,因此当模型或数据集非常大时,可能会受到内存或带宽的限制。

2、Distributed Data Parallelism (DDP):

  • DDP是一种更加高级的并行计算策略,它允许跨多台机器的多个GPU进行模型训练。
  • 在DDP中,每个节点(机器)都可能有一个或多个GPU,每个节点都在其GPU上运行模型的一个副本,并处理数据的不同部分。
  • DDP的关键在于它使用了一种更加高效的梯度聚合策略。每个节点都独立地完成前向和反向传播,并计算出梯度。梯度不是首先发送到主节点,而是在所有节点之间直接同步,通常是通过一种称为“All-Reduce”的操作。
  • 这种方法减少了通信瓶颈,因为它不需要所有数据都通过单个主节点,并且可以更有效地扩展到大规模的计算资源。

总结来说,DP是单机多GPU的并行策略,而DDP是跨多台机器的多GPU并行策略。DDP通常在大规模分布式训练场景中更为有效,因为它可以更好地利用分布式系统的计算和存储资源。

二、如何将模型包装进 DDP

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import datetime
import os# 初始化进程组
dist.init_process_group(backend='nccl', init_method='env://')# 设置本地设备
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device("cuda")# 创建模型
model = ...  # 替换为你的模型
device = torch.device("cuda:{}".format(rank))
model.to(device)# 重点一:包装数据集
# load training data
train_data = Dataset(...)
train_sampler = DistributedSampler(train_data)
training_data_loader = DataLoader(dataset=train_data, batch_size=batch_size, sampler=train_sampler, drop_last=True, num_workers=opt.num_workers)# 重点二:包装模型
ddp_model = DDP(model, device_ids=[rank])# 定义损失函数和优化器
criterion = ...
optimizer = ...# 训练循环
for epoch in range(num_epochs):# 重点三:设置 sampler 的 epoch,DistributedSampler 需要这个来维持各个进程之间的随机种子,也就是保证所有进程在数据洗牌时使用的随机种子是一致的,这样每个进程就会得到不同的数据子集,但整个训练集上的采样是一致的,train_sampler.sampler.set_epoch(epoch)for iteration, data in enumerate(training_data_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = ddp_model(inputs)loss = criterion(outputs, labels)loss.backward()  # DDP将在这里同步梯度optimizer.step()# 日志和其他操作if rank == 0:# 打印日志或写入日志文件...# 清理
dist.destroy_process_group()

如果你在使用DDP训练时没有调用 train_sampler.set_epoch(epoch),那么在每个epoch中所有的进程都将以相同的方式从数据集中采样数据。这意味着每个进程将获得相同的数据子集,导致以下几个问题:

  • 数据重复:所有的模型副本将在每个epoch中学习相同的数据,这样就减少了模型训练的总体数据多样性。

  • 并行效率降低:数据的重复使用降低了并行训练的效率,因为模型的不同副本不再是在不同的数据子集上训练,而是在复制的数据上训练。

  • 收敛问题:由于数据的多样性降低,模型可能更难收敛到一个好的解,或者可能导致过拟合,因为模型只看到了数据的一个子集。

  • 泛化能力下降:模型的泛化能力可能会受到影响,因为它没有在整个数据集的不同采样上进行训练。

  • 因此,为了确保有效的分布式训练和数据的多样性,在每个epoch开始时调用 train_sampler.set_epoch(epoch) 是很重要的。这将为每个进程提供一个不同的数据视图,从而使整个模型能够从整个数据集中学习,并提高训练的效率和最终模型的性能。

  • 如果不使用这个 train_sampler.set_epoch(epoch),也可能会导致 8 卡训练结果没有单卡训练结果好,因为不使用 train_sampler.set_epoch(epoch) 的话,即使有多个进程,每个进程在每个 epoch 中采样的数据将会是一样的,这意味着所有 gpu 卡在每个 epoch 都在训练相同的数据

三、如何训练

# 方法一:老方法
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env --node_rank=0 --master_port=12348 train.py --aug xxx
  • python -m torch.distributed.launch:这部分是告诉Python运行torch.distributed.launch模块。这个模块是PyTorch的一部分,用于帮助用户启动多个工作进程进行分布式训练。

  • --nproc_per_node=8:这个参数指定每个节点(node)上要启动的进程数。在这个例子中,你将在单个节点上启动8个进程。

  • --nnodes=1:这个参数指定了参与分布式训练的节点总数。在这个例子中,你只使用了一个节点。

  • --node_rank=0:这个参数指定了当前节点的排名。在分布式训练中,每个节点都有一个唯一的排名,用于在节点之间通信。在只有一个节点的情况下,这个排名通常是0。

  • --master_port=12348:这个参数指定了主节点上用于分布式训练的通信端口。所有的节点都会连接到这个端口来进行通信。

# 方法二:新方法
OMP_NUM_THREADS=8 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_port=11347 train.py --aug xxx
  • OMP_NUM_THREADS=8: 这是一个环境变量,用于设置OpenMP使用的线程数。OpenMP是一个支持多平台共享内存并行编程的API。在这里,设置OMP_NUM_THREADS=8意味着每个进程将尝试使用8个线程进行计算,这有助于优化CPU上的并行计算性能。

  • torchrun: 这是PyTorch 1.9版本引入的一个新工具,用来替代python -m torch.distributed.launch。torchrun是一个简化的命令行工具,用于启动分布式训练。

  • --nproc_per_node=8: 这个参数指定了每个节点(node)上要启动的进程数。在这里,它设置为8,意味着在当前节点上会启动8个训练进程。

  • --nnodes=1: 这个参数指定了参与分布式训练的节点数。这里设置为1,表示只有一个节点参与训练。

  • --node_rank=0: 这个参数指定了当前节点的排名。因为只有一个节点,所以排名是0。在多节点训练中,每个节点会有一个唯一的排名。

  • --master_port=11347: 这个参数指定了主节点用于通信的端口。在分布式训练中,各个节点需要通过网络进行通信,这个端口就是用于这种通信的。

  • train.py: 这是你的训练脚本,torchrun会运行这个脚本作为分布式训练的一部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/3683.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jira搭建过程

看到很多小伙伴对jira有兴趣,我们今天就来分享一下jira的搭建吧 首先要明白jira是什么? 看来搭建jira也是我们测试人员需要具备的技能之一了.下面是详细的大家步骤: 1.系统环境准备 Centos 7.5 Mysql 5.6 Java1.8 2.软件安装包 atlassian-jira-software-7.13.0-x64.bin …

【Android 实现AES-CMAC加密】

1. 概述 CMAC(Cipher Block Chaining-Message Authentication Code),也简称为CBC_MAC,它是一种基于对称秘钥分组加密算法的消息认证码。由于其是基于“对称秘钥分组算法”的,故可以将其当做是对称算法的一种操作模式。…

Linux_环境变量

目录 1、查询所有环境变量 2、常见的环境变量 2.1 PATH 2.2 HOME 2.3 PWD 3、增加新的环境变量 4、删除环境变量 5、main函数的三个形参 5.1 argv字符串数组 5.2 env字符串数组 6、系统调用接口 6.1 getenv 6.2 putenv 7、全局变量environ 结语 前言&…

SpringBoot + kotlin 协程小记

前言: Kotlin 协程是基于 Coroutine 实现的,其设计目的是简化异步编程。协程提供了一种方式,可以在一个线程上写起来像是在多个线程中执行。 协程的基本概念: 协程是轻量级的,不会创建新的线程。 协程会挂起当前的协…

中颖51芯片学习9. PWM(12bit脉冲宽度调制)

中颖51芯片学习9. PWM(12bit脉冲宽度调制) 一、资源简介二、PWM工作流程三、寄存器介绍1. PWMx控制寄存器PWMxCON2. PWM0周期寄存器PWM0PH/L3. PWM1周期寄存器PWM1PH/L4. PWM0占空比控制寄存器PWM0DH/L5. PWM1占空比控制寄存器 PWM1DH/L6. 占空比寄存器与…

MacOS 12安装V8Js

一、环境 V8引擎(https://github.com/v8/v8)是Google的开源JavaScript引擎,性能很高,NodeJs就是采用了V8引擎。V8的作用就解析、运行JavaScript脚本,可以简单理解为JavaScript的解析器。 V8Js(https://git…

跨语言指令调优深度探索

目录 I. 介绍II. 方法与数据III. 结果与讨论1. 跨语言迁移能力2. 问题的识别3. 提高跨语言表现的可能方向 IV. 结论V. 参考文献 I. 介绍 在大型语言模型的领域,英文数据由于其广泛的可用性和普遍性,经常被用作训练模型的主要语料。尽管这些模型可能在英…

CDN引入Vue3

选择CDN版本 vue.global.prod.js > 在head中使用 引入后&#xff0c;在后续根组件和子组件中可以通过全局的Vue,来引入对应ref、createApp等方法&#xff0c;如下&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF…

ESLlint重大更新后,使用旧版ESLint搭配Prettier的配置方式

概要 就在前几天&#xff0c;ESLint迎来了一次重大更新&#xff0c;9.0.0版本&#xff0c;根据官方文档介绍&#xff0c;使用新版的先决条件是Node.js版本必须是18.18.0、20.9.0&#xff0c;或者是>21.1.0的版本&#xff0c;新版ESLint将不再直接支持以下旧版配置(非扁平化…

二、OSPF协议基础

基于SPF算法&#xff08;Dijkstra算法&#xff09;的链路状态路由协议OSPF&#xff08;Open Shortest Path First&#xff0c;开放式最短路径优先&#xff09; 目录 1.RIP在大型网络中部署所面临的问题 2.Router ID 3.OSPF的报文 4.OSPF邻居建立过程 5.OSPF报文的确认机制…

SAP的生成式AI

这是一篇openSAP中关于SAP生成式AI课程的笔记,原地址https://open.sap.com/courses/genai1/ 文章目录 Unit 1: Approaches to artificial intelligence概念三种范式监督学习非监督学习强化学习Unit 2: Introduction to generative AI生成式AI基础模型关系基础模型有哪些能力呢…

怎么通过isinstance(Obj,Class)验证?【isinstance】

最近有这样一个项目&#xff0c;这个项目可以用一个成熟的项目的构造树&#xff0c;读取树&#xff0c;再检索的过程&#xff0c;现在有新的需求&#xff0c;另一个逻辑构造同样节点结构的树&#xff0c;pickle序列化保存&#xff0c;再使用原来项目的读取、检索函数&#xff0…

线程、线程组、线程池、锁、事务、分布式

1.线程 Thread类 &#xff0c;可以继承他&#xff0c;复写run方法&#xff0c;然后new一个对象&#xff0c;调用start方法启动。 2.runnable接口&#xff0c;他单独把run方法定义出来了&#xff0c;可以自己实现一个runnable接口&#xff0c;然后new一个runnable对象给到threa…

一年期免费SSL证书申请方法

免费SSL证书的申请已经成为当今互联网安全实践中的重要环节&#xff0c;它不仅有助于保护网站数据传输的隐私性和完整性&#xff0c;还能提升用户信任度&#xff0c;因为现代浏览器会明确标识出未使用HTTPS&#xff08;即未部署SSL证书&#xff09;的网站为“不安全”。以下是一…

vue项目的Husky、env、editorconfig、eslintrc、tsconfig.json配置文件小聊

一、Git配置文件 1、Husky Husky 是一款管理 git hooks 的工具&#xff0c;可以帮助我们触发git提交的各个阶段&#xff1a;pre-commit、commit-msg、pre-push&#xff0c; 有助于我们在项目开发中的git规范和团队协作。 .husky文件通常包含以下内容&#xff1a; pre-commi…

互联网安全面临的全新挑战

前言 当前移动互联网安全形势严峻&#xff0c;移动智能终端漏洞居高不下、修复缓慢&#xff0c;移动互联网恶意程序持续增长&#xff0c;同时影响个人和企业安全。与此同时&#xff0c;根据政策形势移动互联网安全监管重心从事前向事中事后转移&#xff0c;需加强网络安全态势感…

玩转必应bing国内广告投放,正确的攻略方式!

搜索引擎广告作为精准触达潜在客户的重要渠道&#xff0c;一直是众多企业营销策略中的关键一环&#xff0c;在国内市场&#xff0c;虽然百度占据主导地位&#xff0c;但必应Bing凭借其独特的用户群体、高质量的搜索体验以及与微软生态的紧密集成&#xff0c;为广告主提供了不可…

相关运算及实现

本文介绍相关运算及实现。 相关运算在相关检测及数字锁相放大中经常用到&#xff0c;其与卷积运算又有一定的联系&#xff0c;本文简要介绍其基本运算及与卷积运算的联系&#xff0c;并给出实现。 1.定义 这里以长度为N的离散时间序列x(n),y(n)为例&#xff0c;相关运算定义如…

nvm管理多个node版本,快速来回切换node版本

前言 文章基于 windows环境 使用nvm安装多版本nodejs。 最近公司有的项目比较老需要降低node版本才能运行&#xff0c;由于来回进行卸载不同版本的node比较麻烦&#xff1b;所以需要使用node工程多版本管理&#xff0c;后面自己就简单捯饬了一下nvm来管理node&#xff0c;顺便…

VTK----VTK数据结构详解2(计算机篇)

在VTK中&#xff0c;属性数据和点都用数据数组&#xff08;data arrays&#xff09;表示。某些属性数据&#xff08;例如法线、张量&#xff09;需要具有与其定义一致的元组&#xff08;在计算机编程中&#xff0c;元组&#xff08;tuple&#xff09;用来表示存储多种数据类型的…