【单机多卡】torch改造代码为DDP单机多卡分布式并行

torch分布式数据并行DDPtorch.nn.parallel.DistributedDataParallel代码修改记录。(要求pytorch_version>1.0)

目录

1.🍄🍄要修改的地方概览

2.✏️✏️初始化

3.✏️✏️设置当前进程GPU

4.✏️✏️设置sampler

5.✏️✏️sampler传入dataloader

6.✏️✏️数据放GPU

7.✏️✏️模型放GPU

8.✏️✏️load模型

9.✏️✏️save模型

10.✏️✏️执行命令

整理不易,欢迎一键三连!!!



1.🍄🍄要修改的地方概览

2.✏️✏️初始化

在代码最开始的地方设置初始化参数,即训练和数据组织之前。

n_gpus = args.n_gpus   #自行传入
#local_rank = args.local_rank   #自行传入
local_rank = int(os.environ['LOCAL_RANK'])   #代码计算torch.distributed.init_process_group('nccl', world_size=n_gpus, rank=local_rank)#初始化进程组
  • 指定GPU之间的通信方式'nccl'
  • world_size:当前这个节点上要用多少GPU卡;(当前节点就是当前机器)
  • rank: 当前进程在哪个GPU卡上,通过args.local_rank来获取,local_rank变量是通过外部指令传入的;(也可以通过环境变量来接收)

注意:自行传入的变量需要通过argparse第三方库写入,示例如下:

import argparseparser = argparse.ArgumentParser()
parser.add_argument("--n_gpus", help="num of gpus")
parser.add_argument("-p", "--project", help="project name")
parser.add_argument('-s', '--sparse', action='store_true', default=False, help='GAT with sparse version or not.')
parser.add_argument('-d', '--seed', type=int, default=72, help='Random seed.')
parser.add_argument('-e', '--epochs', type=int, default=10000, help='Number of epochs to train.')args = parser.parse_args()
print(args.n_gpus)
print(args.sparse)
print(args.seed)
print(args.epochs)

3.✏️✏️设置当前进程GPU

在初始化之后紧接着设置当前进程的GPU

torch.cuda.set_device(local_rank)

上述指令作用相当于设置CUDA_VISBLE_DEVICES环境变量,设置当前进程要用第几张卡;

4.✏️✏️设置sampler

from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(dataset_train)
...
for epoch in range(start_epoch, total_epochs):train_sampler.set_epoch(epoch)  #为了让每张卡在每个周期中得到的数据是随机的...

此处的train_dataset为load数据的Dataset类,根据数据地址return出每个image和队形的mask,DistributedSampler返回一堆数据的索引train_sampler,根据索引去dataloader中拿数据,并且在每次epoch训练之前,加上train_sampler.set_epoch(epoch)这句,达到shuffle=True的目的。

5.✏️✏️sampler传入dataloader

from torch.utils.data import DataLoader
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size,sampler = train_sampler)
dataloader_val = DataLoader(dataset_val,batch_size=1,)

通过将train_sampler传入dataloader达到数据传入模型的数据格式。

6.✏️✏️数据放GPU

在每次训练过程中,设置数据放GPU里。

for img,label in dataloader_train:inputs = img.cuda(local_rank)  #数据放GPUlabels = label.cuda(local_rank)  #数据放GPU...

7.✏️✏️模型放GPU

在定义模型的地方,设置将模型放入GPU

model = XXNet()
net = torch.nn.parallel.DistributedDataParallel(model.cuda(local_rank),device_ids=[local_rank])  #模型拷贝,放入DistributedDataParallel

8.✏️✏️load模型

torch.load(model_file_path, map_location = local_rank)

设置 map_location指定将模型传入哪个GPU上

9.✏️✏️save模型

torch.save(net.module.state_dict(), os.path.join(ckp_savepath, ckp_name))

注意,此处保存的net是net.module.state_dict

10.✏️✏️执行命令

python -m torch.distributed.launch --nproc_per_node=n_gpus --master_port 29502 train.py
  • nproc_per_node:等于GPU数量
  • master_port:默认为29501,如果出现address already in use,可以将其修改为其他值,比如29502

参考:视频讲解

整理不易,欢迎一键三连!!!

送你们一条美丽的--分割线--


🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/15784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【设计模式】简单工厂模式

C语言实现简单的工厂模式 #include <stdio.h> #include <stdlib.h>// 图形类型枚举 typedef enum {CIRCLE,SQUARE,RECTANGLE } ShapeType;// 图形结构体 typedef struct {ShapeType type;float area; } Shape;// 创建圆形 Shape* createCircle() {Shape* circle …

大数据开发面试必问:Hive调优技巧系列一

Hive必问调优 Hive 调优拆解:Hive SQL 几乎是每一位互联网分析师的必备技能&#xff0c;相信很多小伙伴都有被面试官问到 Hive 优化问题的经历。所以掌握扎实的 HQL 基础尤为重要&#xff0c;hive优化也是小伙伴应该掌握的一项技能&#xff0c;本篇文章具体从hive建表优化、HQ…

数据结构-链表结构-单向链表

链表结构 说到链表结构就不得不提起数据结构&#xff0c;什么是数据结构&#xff1f;就是用来组织和存储数据的某种结构。那么到底是某种结构呢&#xff1f; 数据结构分为&#xff1a; 线性结构 数组&#xff0c;链表&#xff0c;栈&#xff0c;队列 树形结构 二叉树&#x…

QWidget窗口类

QWidget窗口类 设置父对象窗口位置窗口尺寸窗口标题和图标信号槽函数例子1例子3例子3 设置父对象 // 构造函数 QWidget::QWidget(QWidget *parent nullptr, Qt::WindowFlags f Qt::WindowFlags());// 公共成员函数 // 给当前窗口设置父对象 void QWidget::setParent(QWidget…

Linux系统下MySQL读写分离

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 一、基于Amoeba读写分离 1.基于程序代码内部实现 2.基于中间代理层实现 三、操作步骤 1.在主机Amoeba上安装java环境 2.安装并配置Amoeba 3.配置Amoeba读写分离…

【大数据趋势】7月30日 汇率,恒指期货的大数据趋势概率分析。

1. 数据源头之一 : 汇率变化 从程序模拟趋势来看&#xff0c;美元在持续弱势状态&#xff0c;周线上正在构建一个新的下跌趋势&#xff0c;而且正在反抽过程中&#xff0c;即将完成&#xff0c;如果没有外部干预&#xff0c;会顺势往下。从月线来看&#xff0c;高点逐步降低&a…

没有 telnet 不能测试端口?容器化部署最佳的端口测试方式

写在前面 生产中遇到&#xff0c;整理笔记在容器中没有 telnet &#xff0c;如何测试远程端口理解不足小伙伴帮忙指正 他的一生告诉我们&#xff0c;不能自爱就不能爱人&#xff0c;憎恨自己也必憎恨他人&#xff0c;最后也会像可恶的自私一样&#xff0c;使人变得极度孤独和悲…

线性代数的学习和整理2:线性代数的基础知识(整理ing)

目录 0 写在前面的话 网上推荐的线性代数的课程 1 线性代数和矩阵的各种概念 1.1 各种逻辑图 2 关于线性代数入门的各种灵魂发问 2.1 什么是线性&#xff0c;什么是线性相关 &#xff1f; 为什么叫线性变换&#xff1f; 为什么叫线性代数&#xff1f; 2.2 线性代数是人造…

【Linux】进程篇(补):简易 shell 的实现(进程深刻理解、内建命令的使用)

文章目录 makefilemybash.c 代码逻辑框架&#xff08;重要的是&#xff0c;边写边查&#xff01;&#xff09; 命令行提示符&#xff0c;fflush 刷新显示获取 输入的 有效字符串&#xff0c;定义一个字符数组&#xff0c;用 fgets 从键盘上获取&#xff08;注意处理命令行输入…

【疑难杂症专辑】【jlink 关闭了调试接口/进入休眠/停机不可下载】

做开发调试器的正常使用是基础&#xff0c;但有时候会人为造成一些问题。如下场景&#xff1a; 使用四线SW接口&#xff0c;进入低功耗停机模式后不能下载 先点下载&#xff0c;等keil在查找设备时短接复位然后断开把单片机唤醒&#xff0c;看自己的唤醒条件是什么&#xff0c…

Spark性能调优指南来了!

1、什么是Spark Spark 是一种基于内存的快速、通用、可扩展的大数据分析计算引擎。 Spark Core&#xff1a;实现了Spark的基本功能&#xff0c;包含任务调度、内存管理、错误恢复、与存储系统交互等模块。Spark Core中还包含了对弹性分布式数据集(Resilient Distributed Dat…

安科瑞智慧空开微型断路器在银行的应用-安科瑞黄安南

应用场景 智能微型断路器与智能网关组合应用于末端回路 功能 1.计量功能&#xff1a;实时上报电压、电流、功率、电能、漏电、温度、频率等电参量&#xff1b; 2.报警功能&#xff1a;过压报警、欠压报警、过流报警、过载报警、漏电报警、超温报警、三相电缺相报警&#xff…

论文笔记:Adjusting for Autocorrelated Errors in Neural Networks for Time Series

2021 NIPS 原来的时间序列预测任务是根据预测论文提出用一阶自回归误差预测 一阶差分&#xff0c;类似于ResNet的残差思路&#xff1f;记为pred&#xff0c;最终的预测结果

【蓝桥杯备考资料】如何进入国赛?

目录 写在前面注意事项数组、字符串处理BigInteger日期问题DFS 2013年真题Java B组世纪末的星期马虎的算式振兴中华黄金连分数有理数类&#xff08;填空题&#xff09;三部排序&#xff08;填空题&#xff09;错误票据幸运数字带分数连号区间数 2014年真题蓝桥杯Java B组03猜字…

维护电脑,让“战友”保持长寿命

目录 维护电脑&#xff0c;让“战友”保持长寿命介绍你的电脑介绍一下你的日常维护措施给出一些你觉得有用的维护技巧不推荐做些什么其他补充总结 无论是学习还是工作&#xff0c;电脑都是IT人必不可少的重要武器&#xff0c;一台好电脑除了自身配置要经得起考验&#xff0c;后…

1.0 python环境安装

1 python环境安装 python安装教程原文 2 PyCharm安装教程 PyCharm安装教程

异常的使用

异常 异常的概念 指的是程序在执行的过程中&#xff0c;出现的非正常的情况&#xff0c;最后会导致JVM的非正常停止。在java等面向对象的语言当中&#xff0c;异常本身是一个类&#xff0c;产生异常就是创建异常对象并且抛出一个异常对象。java处理异常的方式就是中断处理。异常…

Linux内核的I2C驱动框架详解------这应该是我目前600多篇博客中耗时最长的一篇博客

目录 1 I2C驱动整体框架图 2 I2C控制器 2.1 I2C控制器设备--I2C控制器在内核中也被看做一个设备 2.2 i2c控制器驱动程序 2.3 platform_driver结构体中的probe函数做了什么 2.3.1 疑问&#xff1a; i2cdev_notifier_call函数哪里来的 2.3.2 疑问&#xff1a;为什么有两…

Python爬虫-快手photoId

前言 本文是该专栏的第49篇,后面会持续分享python爬虫干货知识,记得关注。 笔者在本专栏的上一篇,有详细介绍平台视频播放量的爬取方法。与该平台相关联的文章,笔者已整理在下方,感兴趣的同学可查看翻阅。 1. Python如何解决“快手滑块验证码”(4) 2. 快手pcursor 3. …

2023 ISSE观察:智能遮阳窗帘行业蓬勃发展,AI设计引热议

7月31日&#xff0c;上海国际智能遮阳与建筑节能展览会落下帷幕。作为智能遮阳的行业展会&#xff0c;展会三天&#xff0c;现场热闹非凡&#xff0c;参展商和观展者络绎不绝。 作为一大行业盛事&#xff0c;2023 ISSE展会方打造了五大展区&#xff0c;除了提供系统门窗装修方案…