【单机多卡】torch改造代码为DDP单机多卡分布式并行

torch分布式数据并行DDPtorch.nn.parallel.DistributedDataParallel代码修改记录。(要求pytorch_version>1.0)

目录

1.🍄🍄要修改的地方概览

2.✏️✏️初始化

3.✏️✏️设置当前进程GPU

4.✏️✏️设置sampler

5.✏️✏️sampler传入dataloader

6.✏️✏️数据放GPU

7.✏️✏️模型放GPU

8.✏️✏️load模型

9.✏️✏️save模型

10.✏️✏️执行命令

整理不易,欢迎一键三连!!!



1.🍄🍄要修改的地方概览

2.✏️✏️初始化

在代码最开始的地方设置初始化参数,即训练和数据组织之前。

n_gpus = args.n_gpus   #自行传入
#local_rank = args.local_rank   #自行传入
local_rank = int(os.environ['LOCAL_RANK'])   #代码计算torch.distributed.init_process_group('nccl', world_size=n_gpus, rank=local_rank)#初始化进程组
  • 指定GPU之间的通信方式'nccl'
  • world_size:当前这个节点上要用多少GPU卡;(当前节点就是当前机器)
  • rank: 当前进程在哪个GPU卡上,通过args.local_rank来获取,local_rank变量是通过外部指令传入的;(也可以通过环境变量来接收)

注意:自行传入的变量需要通过argparse第三方库写入,示例如下:

import argparseparser = argparse.ArgumentParser()
parser.add_argument("--n_gpus", help="num of gpus")
parser.add_argument("-p", "--project", help="project name")
parser.add_argument('-s', '--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
parser.add_argument('-d', '--seed', type=int, default=72, help='Random seed.')
parser.add_argument('-e', '--epochs', type=int, default=10000, help='Number of epochs to train.')args = parser.parse_args()
print(args.n_gpus)
print(args.sparse)
print(args.seed)
print(args.epochs)

3.✏️✏️设置当前进程GPU

在初始化之后紧接着设置当前进程的GPU

torch.cuda.set_device(local_rank)

上述指令作用相当于设置CUDA_VISBLE_DEVICES环境变量,设置当前进程要用第几张卡;

4.✏️✏️设置sampler

from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(dataset_train)
...
for epoch in range(start_epoch, total_epochs):train_sampler.set_epoch(epoch)  #为了让每张卡在每个周期中得到的数据是随机的...

此处的train_dataset为load数据的Dataset类,根据数据地址return出每个image和队形的mask,DistributedSampler返回一堆数据的索引train_sampler,根据索引去dataloader中拿数据,并且在每次epoch训练之前,加上train_sampler.set_epoch(epoch)这句,达到shuffle=True的目的。

5.✏️✏️sampler传入dataloader

from torch.utils.data import DataLoader
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size,sampler = train_sampler)
dataloader_val = DataLoader(dataset_val,batch_size=1,)

通过将train_sampler传入dataloader达到数据传入模型的数据格式。

6.✏️✏️数据放GPU

在每次训练过程中,设置数据放GPU里。

for img,label in dataloader_train:inputs = img.cuda(local_rank)  #数据放GPUlabels = label.cuda(local_rank)  #数据放GPU...

7.✏️✏️模型放GPU

在定义模型的地方,设置将模型放入GPU

model = XXNet()
net = torch.nn.parallel.DistributedDataParallel(model.cuda(local_rank),device_ids=[local_rank])  #模型拷贝,放入DistributedDataParallel

8.✏️✏️load模型

torch.load(model_file_path, map_location = local_rank)

设置 map_location指定将模型传入哪个GPU上

9.✏️✏️save模型

torch.save(net.module.state_dict(), os.path.join(ckp_savepath, ckp_name))

注意,此处保存的net是net.module.state_dict

10.✏️✏️执行命令

python -m torch.distributed.launch --nproc_per_node=n_gpus --master_port 29502 train.py
  • nproc_per_node:等于GPU数量
  • master_port:默认为29501,如果出现address already in use,可以将其修改为其他值,比如29502

参考:视频讲解

整理不易,欢迎一键三连!!!

送你们一条美丽的--分割线--


🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/15784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据开发面试必问:Hive调优技巧系列一

Hive必问调优 Hive 调优拆解:Hive SQL 几乎是每一位互联网分析师的必备技能,相信很多小伙伴都有被面试官问到 Hive 优化问题的经历。所以掌握扎实的 HQL 基础尤为重要,hive优化也是小伙伴应该掌握的一项技能,本篇文章具体从hive建表优化、HQ…

数据结构-链表结构-单向链表

链表结构 说到链表结构就不得不提起数据结构,什么是数据结构?就是用来组织和存储数据的某种结构。那么到底是某种结构呢? 数据结构分为: 线性结构 数组,链表,栈,队列 树形结构 二叉树&#x…

QWidget窗口类

QWidget窗口类 设置父对象窗口位置窗口尺寸窗口标题和图标信号槽函数例子1例子3例子3 设置父对象 // 构造函数 QWidget::QWidget(QWidget *parent nullptr, Qt::WindowFlags f Qt::WindowFlags());// 公共成员函数 // 给当前窗口设置父对象 void QWidget::setParent(QWidget…

Linux系统下MySQL读写分离

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 一、基于Amoeba读写分离 1.基于程序代码内部实现 2.基于中间代理层实现 三、操作步骤 1.在主机Amoeba上安装java环境 2.安装并配置Amoeba 3.配置Amoeba读写分离…

【大数据趋势】7月30日 汇率,恒指期货的大数据趋势概率分析。

1. 数据源头之一 : 汇率变化 从程序模拟趋势来看,美元在持续弱势状态,周线上正在构建一个新的下跌趋势,而且正在反抽过程中,即将完成,如果没有外部干预,会顺势往下。从月线来看,高点逐步降低&a…

线性代数的学习和整理2:线性代数的基础知识(整理ing)

目录 0 写在前面的话 网上推荐的线性代数的课程 1 线性代数和矩阵的各种概念 1.1 各种逻辑图 2 关于线性代数入门的各种灵魂发问 2.1 什么是线性,什么是线性相关 ? 为什么叫线性变换? 为什么叫线性代数? 2.2 线性代数是人造…

Spark性能调优指南来了!

1、什么是Spark Spark 是一种基于内存的快速、通用、可扩展的大数据分析计算引擎。 Spark Core:实现了Spark的基本功能,包含任务调度、内存管理、错误恢复、与存储系统交互等模块。Spark Core中还包含了对弹性分布式数据集(Resilient Distributed Dat…

安科瑞智慧空开微型断路器在银行的应用-安科瑞黄安南

应用场景 智能微型断路器与智能网关组合应用于末端回路 功能 1.计量功能:实时上报电压、电流、功率、电能、漏电、温度、频率等电参量; 2.报警功能:过压报警、欠压报警、过流报警、过载报警、漏电报警、超温报警、三相电缺相报警&#xff…

论文笔记:Adjusting for Autocorrelated Errors in Neural Networks for Time Series

2021 NIPS 原来的时间序列预测任务是根据预测论文提出用一阶自回归误差预测 一阶差分,类似于ResNet的残差思路?记为pred,最终的预测结果

【蓝桥杯备考资料】如何进入国赛?

目录 写在前面注意事项数组、字符串处理BigInteger日期问题DFS 2013年真题Java B组世纪末的星期马虎的算式振兴中华黄金连分数有理数类(填空题)三部排序(填空题)错误票据幸运数字带分数连号区间数 2014年真题蓝桥杯Java B组03猜字…

维护电脑,让“战友”保持长寿命

目录 维护电脑,让“战友”保持长寿命介绍你的电脑介绍一下你的日常维护措施给出一些你觉得有用的维护技巧不推荐做些什么其他补充总结 无论是学习还是工作,电脑都是IT人必不可少的重要武器,一台好电脑除了自身配置要经得起考验,后…

Linux内核的I2C驱动框架详解------这应该是我目前600多篇博客中耗时最长的一篇博客

目录 1 I2C驱动整体框架图 2 I2C控制器 2.1 I2C控制器设备--I2C控制器在内核中也被看做一个设备 2.2 i2c控制器驱动程序 2.3 platform_driver结构体中的probe函数做了什么 2.3.1 疑问: i2cdev_notifier_call函数哪里来的 2.3.2 疑问:为什么有两…

2023 ISSE观察:智能遮阳窗帘行业蓬勃发展,AI设计引热议

7月31日,上海国际智能遮阳与建筑节能展览会落下帷幕。作为智能遮阳的行业展会,展会三天,现场热闹非凡,参展商和观展者络绎不绝。 作为一大行业盛事,2023 ISSE展会方打造了五大展区,除了提供系统门窗装修方案…

二、SQL-6.DCL-1).用户管理

一、DCL介绍 Data Control Language 数据控制语言 用来管理数据库 用户、控制数据库的 访问权限。 二、语法 1、管理用户 管理用户在系统数据库mysql中的user表中创建、删除一个用户,需要Host(主机名)和User(用户名&#xff0…

openGauss学习笔记-26 openGauss 高级数据管理-约束

文章目录 openGauss学习笔记-26 openGauss 高级数据管理-约束26.1 NOT NULL约束26.2 UNIQUE约束26.3 PRIMARY KEY26.4 FOREIGN KEY26.5 CHECK约束 openGauss学习笔记-26 openGauss 高级数据管理-约束 约束子句用于声明约束,新行或者更新的行必须满足这些约束才能成…

基于SHARC+®单核的ADSP-21567KBCZ6、ADSP-21566BBCZ4、ADSP-21566KBCZ4高性能DSP处理器产品

ADSP-2156x 处理器的速度高达 1 GHz,属于 SHARC 系列产品。ADSP-2156x 处理器基于 SHARC 单核。ADSP-2156x SHARC 处理器是 SIMD SHARC 系列数字信号处理器 (DSP) 中的一款产品,采用 ADI 的超级哈佛架构。这些 32 位/40 位/64 位浮点处理器已针对高性能音…

Rust vs Go:常用语法对比(九)

题图来自 Golang vs Rust - The Race to Better and Ultimate Programming Language 161. Multiply all the elements of a list Multiply all the elements of the list elements by a constant c 将list中的每个元素都乘以一个数 package mainimport ( "fmt")func …

Android Unit Test

一、测试基础知识 1.1 测试级别 测试金字塔(如图 2 所示)说明了应用应如何包含三类测试(即小型、中型和大型测试): 小型测试是指单元测试,用于验证应用的行为,一次验证一个类。 中型测试是指…

创造自己的宠物医院预约服务小程序,步骤详解

在现代社会,越来越多的人开始养宠物,而宠物的健康管理也成为了一个重要的话题。为了方便宠物主人随时随地进行宠物医院的管理和服务,开发一个宠物医院管理小程序是很有必要的。今天我们将分享一些制作宠物医院管理小程序的技巧,帮…

Vue没有node_modules怎么办

npm install 一下 然后再npm run serve 就可以运行了