pytorch分布式数据训练结合学习率周期及混合精度

文章目录

  • 1、SPAWN方式
  • 2、torchrun 方式

正如标题所写,我们正常的普通训练都是单机单卡或单机多卡。而往往一个高精度的模型需要训练时间很长,所以DDP分布式数据并行和混合精度可以加速模型训练。混精可以增大batch size.

如下提供示例代码,经过官网查阅验证的。原始代码由百度文心一言提供。
问题:pytorch 分布式数据并行与学习率周期和混合精度训练的示例代码
该问题分别问了(当前2023年10月16日)百度文心一言,google bard,科大讯飞星火,chatgpt4,bing。目前看来只有文心给出最接近正确答案。当前版本 pythorch 2.0.1。关于混精 torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...),混精可以参考:

  • https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html?highlight=torch%20cuda%20amp%20autocast
  • https://pytorch.org/docs/stable/amp.html#gradient-scaling

1、SPAWN方式

在使用这些助手(都是免费的)时,发现对于不同问题,各自的优劣不同,所以可以都试一下。
在PyTorch中,分布式数据并行(Distributed Data Parallel,DDP)可以在多个GPU上并行训练模型,提高训练速度。学习率周期(Learning Rate Schedules)用于调整训练过程中学习率的变化,而混合精度训练则可以提高训练速度并减少内存使用。
以下是一个使用PyTorch的分布式数据并行、学习率周期和混合精度训练的示例代码:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocastdef setup(rank, world_size):"""Args:rank: Unique identifier of each processworld_size: Total number of processes"""os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'# initialize the process groupdist.init_process_group(backend="nccl", rank=rank, world_size=world_size)torch.cuda.set_device(rank)def cleanup():dist.destroy_process_group()
# 定义模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 10)def forward(self, x):return self.fc(x)# 训练函数
def train(rank, world_size):print(f"Running basic DDP example on rank {rank}.")gpu=ranksetup(rank=rank,world_size=world_size,)torch.manual_seed(0)model = SimpleModel().cuda(gpu)ddp_model = DDP(model, device_ids=[gpu])# 使用混合精度训练scaler = GradScaler()optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)# 使用学习率周期scheduler1 = ExponentialLR(optimizer, gamma=0.9)scheduler2 = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)# 模拟数据data = torch.randn(32, 10).to(gpu)target = torch.randn(32, 10).to(gpu)for epoch in range(100):optimizer.zero_grad()with autocast():  # 启用混合精度训练output = ddp_model(data)loss = nn.MSELoss()(output, target)scaler.scale(loss).backward()  # 使用scaler进行梯度缩放scaler.step(optimizer)  # 更新参数scaler.update()  # 更新GradScaler状态scheduler1.step()  # 使用ExponentialLR调整学习率scheduler2.step()  # 使用CosineAnnealingLR调整学习率print(f"Rank {rank}, Epoch {epoch}, Loss {loss.item()}")cleanup()
if __name__ == "__main__":world_size = torch.cuda.device_count()  # 获取可用GPU的数量mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)  # 在每个GPU上运行train函数

这个示例代码在每个GPU上并行训练一个简单的线性模型。每个进程(即每个GPU)都有自己的模型副本,并且每个进程都独立计算梯度。然后,所有进程都会聚集他们的梯度并平均,然后用于一次总体参数更新。这个过程会根据学习率周期来调整每个epoch后的学习率

本部分参考官方的:https://pytorch.org/tutorials/beginner/ddp_series_multigpu.html?highlight=torch%20multiprocessing 是写单GPU和多GPU的区别。

2、torchrun 方式

首先是写一个ddp.py,内容如下:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast# 定义模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 10)def forward(self, x):return self.fc(x)# 训练函数
def train():dist.init_process_group("nccl")rank = dist.get_rank()print(f"Start running basic DDP example on rank {rank}.")gpu = rank %  torch.cuda.device_count()torch.manual_seed(0)model = SimpleModel().to(gpu)ddp_model = DDP(model, device_ids=[gpu])# 使用混合精度训练scaler = GradScaler()optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)# 使用学习率周期scheduler1 = ExponentialLR(optimizer, gamma=0.9)scheduler2 = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)# 模拟数据data = torch.randn(32, 10).to(gpu)target = torch.randn(32, 10).to(gpu)for epoch in range(100):optimizer.zero_grad()with autocast():  # 启用混合精度训练output = ddp_model(data)loss = nn.MSELoss()(output, target)scaler.scale(loss).backward()  # 使用scaler进行梯度缩放scaler.step(optimizer)  # 更新参数scaler.update()  # 更新GradScaler状态scheduler1.step()  # 使用ExponentialLR调整学习率scheduler2.step()  # 使用CosineAnnealingLR调整学习率print(f"Rank {rank}, Epoch {epoch}, Loss {loss.item()}")dist.destroy_process_group()
if __name__ == "__main__":train()

单机多卡,执行:

torchrun --nproc_per_node=4 --standalone ddp.py

如果是多机多卡:

torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py

本部分参考:
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/110584.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring MVC(一)【什么是Spring MVC】

重点 Spring:IOC 和 AOP 。 Spring MVC :Spring MVC 的执行流程。 SSM 框架的整合! Spring 和 Mybatis 我们不建议使用太多注解,Spring MVC 建议全部使用注解开发! 1、MVC 回顾 1.1、什么是MVC MVC是模型(Model)…

chrome历史版本下载

chrome历史版本下载 windows Google Chrome all versions on Windows linux版本 Google Chrome 64bit Linux版_chrome浏览器,chrome插件,谷歌浏览器下载,谈笑有鸿儒

软件架构风格总结以及场景应用

目录 前言1. 数据流风格2. 调用/返回风格3. 独立构件风格4. 解释器风格5. 仓库风格 前言 软件架构风格可以几个大类: 数据流风格:有批处理风格、管道-过滤器调用/返回风格:有主程序/子程序、面向对象、层次结构等独立构件风格:有…

藏在超级应用背后的逻辑和哲学

众所周知,Elon Musk 想将 Twitter 重新设计定位成一款“超级应用 - X”的野心已经不再是秘密。伴随着应用商店中 Twitter 标志性的蓝鸟 Logo 被 X 取代后,赛博世界充满了对这件事情各种角度的探讨与分析。 Musk 曾经无数次通过微信这一样本来推广他的“超…

Linux:命令行参数和环境变量

文章目录 命令行参数环境变量环境变量的概念常见的环境变量PATH 环境变量表本地变量和环境变量命令分类 本篇主要解决以下问题: 什么是命令行参数命令行参数有什么用环境变量是什么环境变量存在的意义 命令行参数 在学习C语言中,对于main函数当初的写…

iOS开发UITableView的使用,区别Plain模式和Grouped模式

简单赘述一下 的创建步骤 // 创建UITableView self.tableView [[UITableView alloc] initWithFrame:self.view.bounds style:UITableViewStylePlain]; // 设置数据源和代理 self.tableView.dataSource self; self.tableView.delegate self; // 注册自定义UITableViewCe…

【数据结构】830+848真题易错题汇总(自用)

【数据结构】830848易错题汇总(10-23) 文章目录 【数据结构】830848易错题汇总(10-23)选择题填空题判断题简答题:应用题:算法填空题:算法设计题:(待补) 选择题 1、顺序栈 S 的 Pop(S, e)操作弹出元素 e,则下列(C )是正…

数组中指针不同加1的区别

#include <iostream> int main(){int a[5] {1,2,3,4,5};int* p (int*)(&a1);printf("%d",*(p-1)); // 这段代码会输出4 }原因&#xff1a; array和&array的值是一样的&#xff0c;但是他们代表的意义完全不一样&#xff0c;array是数组首元素的首地址…

Android中使用Glide加载圆形图像或给图片设置指定圆角

一、Glide加载圆形头像 效果 R.mipmap.head_icon是默认圆形头像 ImageView mImage findViewById(R.id.image);RequestOptions options new RequestOptions().placeholder(R.mipmap.head_icon).circleCropTransform(); Glide.with(this).load("图像Uri").apply(o…

最新Tuxera NTFS2024破解版mac读写NTFS磁盘工具

Tuxera NTFS for Mac是一款Mac系统NTFS磁盘读写软件。在系统默认状态下&#xff0c;MacOSX只能实现对NTFS的读取功能&#xff0c;Tuxera NTFS可以帮助MacOS 系统的电脑顺利实现对NTFS分区的读/写功能。Tuxera NTFS 2024完美兼容最新版本的MacOS 11 Big Sur&#xff0c;在M1芯片…

php文本转语音功能插件

当前插件集成了百度文本转语音功能,支持laravel 9版本以上. 下载方式: 1、通过composer下载:composer require yreborn/laravel-speech 2、在composer.json 新增 “yreborn/laravel-speech”: “dev-main”&#xff0c;在命令行使用composer install进行安装 1、创建config/…

Oracle数据库 ORA-28001: the password has expired解决方法

今天在用dbvisualizer登录数据库的时候&#xff0c;报了the password has expired的错误&#xff0c;于是上网查了一下原因&#xff0c;是因为数据库密码过期了&#xff0c;因为默认的是180天。 解决方法&#xff1a; 1&#xff09;用系统用户登录 #在cmd终端输入&#xff1…

基于晶体结构优化的BP神经网络(分类应用) - 附代码

基于晶体结构优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码 文章目录 基于晶体结构优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码1.鸢尾花iris数据介绍2.数据集整理3.晶体结构优化BP神经网络3.1 BP神经网络参数设置3.2 晶体结构算法应用 4.测试结果…

【单片机基础】使用51单片机制作函数信号发生器(DAC0832使用仿真)

文章目录 &#xff08;1&#xff09;DA转换&#xff08;2&#xff09;DAC0832简介&#xff08;3&#xff09;电路设计&#xff08;4&#xff09;参考例程&#xff08;5&#xff09;参考文献 &#xff08;1&#xff09;DA转换 单片机作为一个数字电路系统&#xff0c;当需要采集…

UE5 运行时生成距离场数据

1.背景 最近有在运行时加载模型的需求&#xff0c;使用DatasmithRuntimeActor可以实现&#xff0c;但是跟在编辑器里加载的模型对比起来&#xff0c;室内没有Lumen的光照效果。 图1 编辑器下加载模型的效果 图2 运行时下加载模型的效果 然后查看了距离场的数据&#xff0c;发现…

华为智选SF5,AITO问界的车怎么样

#华为智选 #赛力斯SF5 #aito问界m5 #aito问界m7 #华为汽车 华为的车&#xff0c;后杠焊两点&#xff0c;拉车的时候&#xff0c;拖车钩断了&#xff0c;后杠拉出来了&#xff0c;这质量可以吗&#xff1f;是否应该全部召回&#xff1f;M5&#xff0c;M7是不是也这样&#xff1f…

蓝桥杯(跳跃 C++)

思路&#xff1a; 1、根据题目很容易知道可以用深度搜索、广度搜索、动态规划的思想解题。 2、这里利用深度搜素&#xff0c;由题目可知&#xff0c;可以往九个方向走。 3、这里的判断边界就是走到终点。 #include<iostream> using namespace std; int max1 0; int …

增加并行度后,发现Flink窗口不会计算的问题。

文章目录 前言一、现象二、结论三、解决 前言 窗口没有关闭计算的问题&#xff0c;一直困扰了很久&#xff0c;经过多次验证&#xff0c;确定了问题的根源。 一、现象 Flink使用了window&#xff0c;同时使用了watermark &#xff0c;并且还设置了较高的并行度。生产是设置了…

chromium 54 chrome 各个版本发布功能列表(109-119)

chromium Features 109-119 From https://chromestatus.com/features chromium109 Features:12 Auto range support for font descriptors inside font-face rule Auto range support for variable fonts in ‘font-weight’, ‘font-style’ and ‘font-stretch’ descrip…

微服务负载均衡实践

概述 本文介绍微服务的服务调用和负载均衡&#xff0c;使用spring cloud的loadbalancer及openfeign两种技术来实现。 本文的操作是在微服务的初步使用的基础上进行。 环境说明 jdk1.8 maven3.6.3 mysql8 spring cloud2021.0.8 spring boot2.7.12 idea2022 步骤 改造Eu…