Pytorch单机多卡分布式训练

Pytorch单机多卡分布式训练

数据并行:

DP和DDP

这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。

  1. DataParallel(DP)

    • 只支持单进程多线程,单一机器上进行训练。
    • 模型训练开始的时候,先把模型复制到四个GPU上面,然后把数据分配给四个GPU进行前向传播,前向传播之后再汇总到卡0上面,然后在卡0上进行反向传播,参数更新,再将更新好的模型复制到其他几张卡上。

    在这里插入图片描述

  2. DistributedDataParallel(DDP)

    • 支持多线程多进程,单一或者多个机器上进行训练。通常DDP比DP要快。

    • 先把模型载入到四张卡上,每个GPU上都分配一些小批量的数据,再进行前向传播,反向传播,计算完梯度之后再把所有卡上的梯度汇聚到卡0上面,卡0算完梯度的平均值之后广播给所有的卡,所有的卡更新自己的模型,这样传输的数据量会少很多。

      在这里插入图片描述

DDP代码写法

  1. 初始化

    import torch.distributed as dist
    import torch.utils.data.distributed# 进行初始化,backend表示通信方式,可选择的有nccl(英伟达的GPU2GPU的通信库,适用于具有英伟达GPU的分布式训练)、gloo(基于tcp/ip的后端,可在不同机器之间进行通信,通常适用于不具备英伟达GPU的环境)、mpi(适用于支持mpi集群的环境)
    # init_method: 告知每个进程如何发现彼此,默认使用env://
    dist.init_process_group(backend='nccl', init_method="env://")
    
  2. 设置device

    device = torch.device(f'cuda:{args.local_rank}')	# 设置device,local_rank表示当前机器的进程号,该方式为每个显卡一个进程
    torch.cuda.set_device(device)	# 设定device
    
  3. 创建dataloader之前要加一个sampler

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)	# 加一个sampler
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)
    
  4. torch.nn.parallel.DistributedDataParallel包裹模型(先to(device)再包裹模型)

    net = torchvision.models.resnet101(num_classes=10)
    net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])	# 包裹模型
    
  5. 真正训练之前要set_epoch(),否则将不会shuffer数据

    for epoch in range(10):train_sampler.set_epoch(epoch)		# set_epochfor step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))
    
  6. 模型保存

    if args.local_rank == 0:		# local_rank为0表示master进程torch.save(net, "my_net.pth")
    
  7. 运行

    if __name__ == "__main__":parser = argparse.ArgumentParser()# local_rank参数是必须的,运行的时候不必自己指定,DDP会自行提供parser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)
    
  8. 运行命令

    python -m torch.distributed.launch --nproc_per_node=2 多卡训练.py	# --nproc_per_node=2表示当前机器上有两个GPU可以使用
    

完整代码

import os
import argparse
import torch
import torchvision
import torch.distributed as dist
import torch.utils.data.distributedfrom torchvision import transforms
from torch.multiprocessing import Processdef main(args):# nccl: 后端基于NVIDIA的GPU-to-GPU通信库,适用于具有NVIDIA GPU的分布式训练# gloo: 后端是一个基于TCP/IP的后端,可在不同机器之间进行通信,通常适用于不具备NVIDIA GPU的环境。# mpi: 后端使用MPI实现,适用于具备MPI支持的集群环境。# init_method: 告知每个进程如何发现彼此,如何使用通信后端初始化和验证进程组。 默认情况下,如果未指定 init_method,PyTorch 将使用环境变量初始化方法 (env://)。dist.init_process_group(backend='nccl', init_method="env://") # nccl比较推荐device = torch.device(f'cuda:{args.local_rank}')torch.cuda.set_device(device)trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)net = torchvision.models.resnet101(num_classes=10)net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)net = net.to(device)net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])criterion = torch.nn.CrossEntropyLoss()opt = torch.optim.Adam(params=net.parameters(), lr=0.001)for epoch in range(10):train_sampler.set_epoch(epoch)for step, data in enumerate(data_loader_train):images, labels = dataimages, labels = images.to(device), labels.to(device)opt.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()opt.step()if step % 10 == 0:print("loss: {}".format(loss.item()))if args.local_rank == 0:torch.save(net, "my_net.pth")if __name__ == "__main__":parser = argparse.ArgumentParser()# must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDPparser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()main(args)

参考:

https://zhuanlan.zhihu.com/p/594046884
https://zhuanlan.zhihu.com/p/358974461

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/91417.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

openGauss学习笔记-83 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT使用内存和存储规划

文章目录 openGauss学习笔记-83 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT使用内存和存储规划83.1 MOT内存规划83.2 存储IO83.3 容量需求 openGauss学习笔记-83 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT使用内存和存储规划 本节描述了为满足特定…

完整的 pixel 6a 刷入 AOSP 源码过程记录

基础环境 虚拟机:VMware Workstation 16 Pro 16.0.0 build-16894299 Linux版本:ubuntu-16.04.7-desktop-amd64 设备:pixel 6a;代号:bluejay; 基础软件安装 安装 Git 命令:sudo apt install git …

金融生产存储亚健康治理:升级亚健康 3.0 ,应对万盘规模的挑战

随着集群规模的不断扩大,硬盘数量指数级上升,信创 CPU 和操作系统、硬盘多年老化、物理搬迁等多种复杂因素叠加,为企业的存储亚健康管理增加了新的挑战。 在亚健康 2.0 的基础上,星辰天合在 XSKY SDS V6.2 实现了亚健康 3.0&#…

【C++入门到精通】C++入门 —— set multiset (STL)

阅读导航 前言一、set简介二、std::set1. std::set简介2. std::set的使用- 基本使用- std::set的模板参数列表- std::set的构造函数- std::set的迭代器- std::set容量与元素访问函数 3. set的所有函数(表) 三、std::multiset1. std::multiset简介 四、st…

嵌入式学习笔记(35)外部中断

6.9.1什么是外部中断 (1)内部中断就是指中断源来自于SoC内部(一般是内部外设),譬如串口、定时器等部件产生的中断;外部中断是SoC外部的设备,通过外部中断对应的GPIO引脚产生的中断。 (2)按键在SoC中就使用了外部中断…

【每日一题】1498. 满足条件的子序列数目

1498. 满足条件的子序列数目 - 力扣(LeetCode) 给你一个整数数组 nums 和一个整数 target 。 请你统计并返回 nums 中能满足其最小元素与最大元素的 和 小于或等于 target 的 非空 子序列的数目。 由于答案可能很大,请将结果对 109 7 取余后…

stm32无人机-飞行力学原理

惯性导航,是一种无源导航,不需要向外部辐射或接收信号源,就能自主进行确定自己在什么地方的一种导航方法。 惯性导航主要由惯性器件计算实现,惯性器件包括陀螺仪和加速度计。一般来说,惯性器件与导航物体固连&#xf…

CTFSHOW SSTI

目录 web361 【无过滤】 subprocess.Popen os._wrap_close url_for lipsum cycler web362 【过滤数字】 第一个通过 计算长度来实现 第二个使用脚本输出另一个数字来绕过 使用没有数字的payload web363 【过滤引号】 使用getitem 自定义变量 web364 【过…

数据集笔记:OpenCelliD(手机基站开放数据库)

下载数据的方式可见:【数据获取】全球最大手机基站开源数据库 1 读取数据 import pandas as pdpd.read_csv(C:/Users/16000/Downloads/454.csv/454.csv,headerNone,names[Radio,MCC,MNC,LAC/TAC/NID,CID,Longitude,Latitude,Range,Samples,Changeable1,Changeable…

设计模式6、适配器模式 Adapter

解释说明:将一个类的接口转换成客户希望的另一个接口。适配器模式让那些接口不兼容的类可以一起工作 目标接口(Target):当前系统所期待的接口,它可以是抽象类或接口 适配者(Adaptee)&#xff1a…

AIGC(生成式AI)试用 7 -- 桌面小程序

生成式AI,别人用来写作,我先用来写个桌面小程序。 桌面小程序:计算器 需求 Python开发图形界面,标题:计算器 - * / 基本运算计算范围:-999999999 ~ 999999999** 乘方计算(例,2*…

LLM之Colossal-LLaMA-2:Colossal-LLaMA-2的简介、安装、使用方法之详细攻略

LLM之Colossal-LLaMA-2:Colossal-LLaMA-2的简介、安装、使用方法之详细攻略 导读:2023年9月25日,Colossal-AI团队推出了开源模型Colossal-LLaMA-2-7B-base。Colossal-LLaMA-2项目的技术细节,主要核心要点总结如下: >> 数据处…

分布式事务-TCC异常-幂等性

1、幂等性问题: 二阶段提交时,如果二阶段执行成功通知TC时出现网路或其他问题中断,那么TC没有收到执行成功的通知,TC内部有定时器不断的重试二阶段方法,导致接口出现幂等性问题。 2、解决方法 和空回滚问题一样也是…

Kotlin只截取Float小数点后数值DecimalFormat

Kotlin只截取Float小数点后数值DecimalFormat import java.text.DecimalFormatfun main(args: Array<String>) {val pi 3.141516Fvar p pi - pi.toInt()println(p)val decimalFormat DecimalFormat("00.0000")val format decimalFormat.format(p)println(…

UE5屏幕适配

一、本程序设计发布在手机上&#xff0c;首先确定屏幕的设计分辨率&#xff0c;这里我们选择iphone6s&#xff0c;750x1334。 二、设置DPI Scale为1.0的比例&#xff0c;点击齿轮标志 因为我们这个程序是手机竖屏使用的&#xff0c;所以DPI Scale Rule选择Shortest Side&#…

c语言常用语法,长时间不用容易忘。

关键字 auto 声明自动变量const 定义常量&#xff0c;如果一个变量被 const 修饰&#xff0c;那么它的值就不能再被改变extern 声明变量或函数是在其它文件或本文件的其他位置定义register 声明寄存器变量signed 声明有符号类型变量或函数static 声明静态变量&#xff0c;修饰…

APA技术架构与说明

1.自动泊车的硬件架构 2.APA自动泊车辅助系统 1&#xff09;APA主要包括以下典型功能 &#xff08;1&#xff09;泊车入库&#xff1a;利用超声波雷达或环视摄像头实现车位识别&#xff0c;并计算出合适行驶轨迹&#xff0c;对车辆进行横向/纵向控制使车辆驶入车位&#xff1…

在MyBatisPlus中添加分页插件

开发过程中&#xff0c;数据量大的时候&#xff0c;查询效率会有所下降&#xff0c;这时&#xff0c;我们往往会使用分页。 具体操作入下&#xff1a; 1、添加分页插件&#xff1a; package com.zhang.config;import com.baomidou.mybatisplus.extension.plugins.Pagination…

基于微信小程序的停车场预约收费小程序设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言系统主要功能&#xff1a;具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09;有保障的售后福利 代码参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计…

CCF CSP认证历年题目自练 Day15

CCF CSP认证历年题目自练 Day15 题目一 试题编号&#xff1a; 201709-1 试题名称&#xff1a; 打酱油 时间限制&#xff1a; 1.0s 内存限制&#xff1a; 256.0MB 问题描述&#xff1a; 问题描述   小明带着N元钱去买酱油。酱油10块钱一瓶&#xff0c;商家进行促销&#xf…