pytorch DistributedDataParallel 分布式训练踩坑记录

目录

    • 一、几个比较常见的概念:
    • 二、踩坑记录
      • 2.1 dist.init_process_group初始化
      • 2.2 spawn启动(rank怎么来的)
      • 2.3 loss backward
      • 2.4 model cuda设置
      • 2.5 数据加载

一、几个比较常见的概念:

  • rank: 多机多卡时代表某一台机器,单机多卡时代表某一块GPU
  • world_size: 多机多卡时代表有几台机器,单机多卡时代表有几块GPU
    world_size = torch.cuda.device_count()
    
  • local_rank: 多机多卡时代表某一块GPU, 单机多卡时代表某一块GPU
    单机多卡的情况要比多机多卡的情况常见的多。
  • DP:适用于单机多卡(=多进程)训练。算是旧版本的DDP
  • DDP:适用于单机多卡训练、多机多卡。

二、踩坑记录

2.1 dist.init_process_group初始化

这一步就是设定一个组,这个组里面设定你有几个进程(world_size),现在是卡几(rank)。让pycharm知道你要跑几个进程,包装在组内,进行通讯这样模型参数会自己同步,不需要额外操作了。

import os
import torch.distributed as distdef ddp_setup(rank,world_size):os.environ['MASTER_ADDR'] = 'localhost' #rank0 对应的地址os.environ['MASTER_PORT'] = '6666' #任何空闲的端口dist.init_process_group(backend='nccl',  #nccl Gloo #nvidia显卡的选择ncclworld_size=world_size, init_method='env://',rank=rank) #初始化默认的分布进程组dist.barrier() #等到每块GPU运行到这再继续往下走

2.2 spawn启动(rank怎么来的)

rank是自动分配的。怎么分配呢?这里用的是spawn也就一行代码。

import torch.multiprocessing as mp
def main (rank:int,world_size:int,args):pass#训练代码 主函数mp.spawn(main,args=(args.world_size,args), nprocs=args.world_size)

注意,调用spawn的时候,没有输入main的其中一个参数rank,rank由代码自动分配。将代码复制两份在两张卡上同时跑,你可以print(rank),会发现输出 0 1。两份代码并行跑。

另外,nprocs=args.world_size。如果不这么写,代码会卡死在这,既不报错,也不停止。

2.3 loss backward

one of the variables needed for gradient computation has been modified by an inplace operationRuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

经过调试发现,当使用nn.DataParallel并行训练或者单卡训练均可正常运行;另外如果将两次模型调用集成到model中,即通过out1, out2 = model(input0, input1) 的方式在分布式训练下也不会报错。

在分布式训练中,如果对同一模型进行多次调用则会触发以上报错,即nn.parallel.DistributedDataParallel方法封装的模型,forword()函数和backward()函数必须交替执行,如果执行多个(次)forward()然后执行一次backward()则会报错。

解决此问题可以聚焦到nn.parallel.DistributedDataParallel接口上,通过查询PyTorch官方文档发现此接口下的两个参数:

  • find_unused_parameters: 如果模型的输出有不需要进行反向传播的,此参数需要设置为True;若你的代码运行后卡住不动,基本上就是该参数的问题。
  • broadcast_buffers: 该参数默认为True,设置为True时,在模型执行forward之前,gpu0会把buffer中的参数值全部覆盖到别的gpu上。
model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False, find_unused_parameters=True)

2.4 model cuda设置

RuntimeError: NCCL error in: ../torch/lib/c10d/ProcessGroupNCCL.cpp:859, invalid usage, NCCL version 21.1.1
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).

*这是因为model和local_rank所指定device不一致引起的错误。

model.cuda(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],broadcast_buffers=False,find_unused_parameters=True)

2.5 数据加载

使用distributed加载数据集,需要使用DistributedSampler自动为每个gpu分配数据,但需要注意的是sampler和shuffle=True不能并存。

train_sampler = DistributedSampler(trainset)
train_loader = torch.utils.data.DataLoader(trainset,batch_size=args.train_batch_size,num_workers=args.train_workers,sampler=train_sampler)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/138633.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

浅谈Elasticsearch监控和日志分析

Elasticsearch 监控和日志分析 Elasticsearch 是一个分布式搜索引擎,它提供了全文搜索、结构化搜索、分析等功能。在实际应用中,监控和日志分析是确保 Elasticsearch 集群稳定、高效运行的关键。本文将详细讲解 Elasticsearch 的监控和日志分析功能&…

rabbitmq 交换机相关实例代码

1.扇形交换机 定义扇形交换机和队列 package com.macro.mall.portal.config;import org.springframework.amqp.core.Binding; import org.springframework.amqp.core.BindingBuilder; import org.springframework.amqp.core.FanoutExchange; import org.springframework.amqp.…

macOS Sonoma 14.2beta2(23C5041e)发布(附黑白苹果镜像地址)

系统介绍 黑果魏叔11 月 10 日消息,今日向 Mac 电脑用户推送了 macOS 14.2 开发者预览版 Beta 2 更新(内部版本号:23C5041e),本次更新距离上次发布隔了 14 天。 macOS Sonoma 14.2 添加了 Music 收藏夹播放列表&…

2000-2022年上市公司数字化转型同群效应数据

2000-2022年上市公司数字化转型同群效应数据 1、时间:2000-2022年 2、指标:股票代码、年份、行业代码、行政区划代码、数字化转型程度-A、数字化转型程度-B、同行业同群-数字化转型程度-A_均值、同行业同群-数字化转型程度-A_中位数、同省份同群-数字化…

专业128分总分390+上岸中山大学884信号与系统电通院考研经验分享

专业课884 信号系统 过年期间开始收集报考信息,找到了好几个上岸学姐和学长,都非常热情,把考研的准备,复习过程中得与失,都一一和我分享,非常感谢。得知这两年专业课难度提高很多,果断参加了学长…

transformers安装避坑

1.4 下载rust编辑器 看到这里你肯定会疑惑了,我们不是要用python的吗? 这个我也不知道,你下了就对了,不然后面的transformers无法安装 因为是windows到官网选择推荐的下载方式https://www.rust-lang.org/tools/install。 执行文…

写一下关于部署项目到服务器的心得(以及遇到的难处)

首先要买个服务器(本人的是以下这个) 这里我买的是宝塔面板的,没有宝塔面板的也可以自行安装 点击登录会去到以下页面 在这个界面依次执行下面命令会看到账号和密码和宝塔面板内外网地址 sudo -s bt 14点击地址就可以跳转宝塔对应的内外网页面 然后使用上述命令提供的账号密…

RK3568平台 查看内存的基本命令

一.free命令 free命令显示系统使用和空闲的内存情况,包括物理内存、交互区内存(swap)和内核缓冲区内存。共享内存将被忽略。 Mem 行(第二行)是内存的使用情况。 Swap 行(第三行)是交换空间的使用情况。 total 列显示系统总的可用物理内存和交换空间大小。 used 列显…

SpringBoot中的桥接模式

桥接模式是一种结构型设计模式,它的主要目的是通过将抽象部分与实现部分分离,提高系统的灵活性和可扩展性。在桥接模式中,有四个主要参与者:抽象类、具体抽象类、桥接类和具体类。 抽象类是定义了抽象方法的基类,这些…

内核移植笔记 Cortex-M移植

常用寄存器 PRIMASK寄存器 为1位宽的中断屏蔽寄存器。在置位时,它会阻止不可屏蔽中断(NMI)和HardFault异常之外的所有异常(包括中断)。 实际上,它是将当前异常优先级提升为0,这也是可编程异常/…

49.批处理命令(1/2)

目录 一批处理。 (1)批处理定义。 (2)常见命令。 (2.1)rem和:: (2.2)echo和。 (2.3)pause。 (2.4)errorlevel。 (…

Halcon WPF 开发学习笔记(2):Halcon导出c#脚本

文章目录 前言HalconC#教学简单说明如何二开机器视觉 前言 我目前搜了一下我了解的机器视觉软件,有如下特点 优点缺点兼容性教学视频(B站前三播放量)OpenCV开源,免费,因为有源码所以适合二次开发学习成本极高,卡学历。研究生博士…

【Java】Netty创建网络服务端客户端(TCP/UDP)

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍Netty创建网络服务端客户端示例。 学其所用,用其所学。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下,下次更…

再聊canal的FlatMessage和事务之间的关系

背景 近期得益于项目上的技改推动,我对canal又有了进一步的认识。 最开始我以为flatMassage的id代表的是binlog的唯一id,随着对canal的使用深度不断加深,我逐渐认识到原先认为的是错误的。 你是否有过这样的疑问: 事务与FlatMessage的id是什么关系?不同事务的FlatMessa…

Java线程池——Executor框架

文章目录 一、Executor接口二、ExecutorService接口三、ThreadPoolExecutor类1、状态2、Worker3、扩展 四、ForkJoinPool类1、工作窃取算法2、Fork/Join的设计3、执行原理 五、ScheduledThreadPool类1、ScheduledExecutorService2、比较Timer 六、Executors类 Executor 框架是 …

Windows系统安装2个版本得的MySQL

一、MySQL官网下载对应版本的zip文件 最新版本8.0.34下载链接:https://dev.mysql.com/downloads/mysql/ MySQL 5.7下载链接:https://downloads.mysql.com/archives/community/ 二、将下载到的压缩包解压到指定目录 使用解压工具将下载到的压缩包解…

智安网络|探索人机交互的未来:自然语言处理的前沿技术

自然语言处理是人工智能领域中研究人类语言和计算机之间交互的一门学科。它涉及了语言的理解、生成、翻译、分类和摘要等多个方面。随着人们对自然语言处理的重视和需求不断增长,成为了热门的研究方向。 首先,我们需要了解自然语言处理的基本概念。自然…

Learning an Animatable Detailed 3D Face Model from In-The-Wild Images论文笔记

Learning an Animatable Detailed 3D Face Model from In-The-Wild Images论文笔记 论文目标:提出一个端到端的框架,可以从非受控的图片中学习高质量、可动画的3D人脸模型。论文方法:论文结果:论文意义: 论文目标:提出一个端到端的框架,可以从非受控的图片中学习高质量、可动画…

js实现定时刷新,并设置定时器上限

定时器 在js中,有两种定时器: 倒计时定时器 倒计时定时器,也叫延时定时器或一次性定时器 功能:倒计时多长时间后执行某个动作 语法:setTimeout(function, timeout); 返回值:int类型,当前定时器…

Transforme原理--全局解读

文章目录 作用全局解读 作用 Transformer最初设计用于处理序列数据,特别在NLP(自然语言处理)领域取得了巨大成功 全局解读 Transformer来源于谷歌的一篇经典论文Attention is All you Need 在此使用Transformer在机器翻译中的运用来讲解Transformer。 其中Tran…