PyTorch Lightning:通过分布式训练扩展深度学习工作流

 

一、介绍

        欢迎来到我们关于 PyTorch Lightning 系列的第二篇文章!在上一篇文章中,我们向您介绍了 PyTorch Lightning,并探讨了它在简化深度学习模型开发方面的主要功能和优势。我们了解了 PyTorch Lightning 如何为组织和构建 PyTorch 代码提供高级抽象,使研究人员和从业者能够更多地关注模型设计和实验,而不是样板代码。

        在本文中,我们将深入研究 PyTorch Lightning,并探索它如何通过分布式训练实现深度学习工作流的扩展。分布式训练对于在海量数据集上训练大型模型至关重要,因为它允许我们利用多个 GPU 或机器的强大功能来加速训练过程。然而,分布式训练往往伴随着一系列挑战和复杂性。

二、安装 Pytorch Lightning & Torchvision

pip install torch torchvision pytorch-lightning 

三、实现

        首先,我们需要从 PyTorch 和 PyTorch Lightning 导入必要的模块:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transformsimport pytorch_lightning as pl

        接下来,我们使用 PyTorch 的类定义我们的神经网络架构。在这个例子中,我们使用一个简单的卷积神经网络,其中包含两个卷积层和三个全连接层:nn.Module

class Net(pl.LightningModule):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(nn.functional.relu(self.conv1(x)))x = self.pool(nn.functional.relu(self.conv2(x)))x = torch.flatten(x, 1)x = nn.functional.relu(self.fc1(x))x = nn.functional.relu(self.fc2(x))x = self.fc3(x)return x

        然后,我们为 .在该方法中,我们接收一批输入和标签,将它们通过我们的神经网络来获取 logits,计算交叉熵损失,并使用该方法记录训练损失。在该方法中,我们执行与 相同的操作,但不记录损失:LightningModuletraining_stepxyself.logvalidation_steptraining_step

    def training_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = nn.functional.cross_entropy(logits, y)self.log("train_loss", loss)return lossdef validation_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = nn.functional.cross_entropy(logits, y)self.log("val_loss", loss)return loss

        我们还在方法中定义了优化器和学习率调度器:configure_optimizers

    def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=0.001)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)return [optimizer], [scheduler]

        接下来,我们使用 PyTorch 和 定义数据加载和预处理步骤:DataLoadertransforms

    def prepare_data(self):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])CIFAR10(root='./data', train=True, download=True, transform=transform)CIFAR10(root='./data', train=False, download=True, transform=transform)def train_dataloader(self):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = CIFAR10(root='./data', train=True, download=False, transform=transform)return DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)def val_dataloader(self):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])val_dataset = CIFAR10(root='./data', train=False, download=False, transform=transform)return DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8)
  1. prepare_data(self):此函数负责在训练模型之前准备数据。它首先使用该类定义一系列转换。转换包括将数据转换为张量并对其进行规范化。定义转换后,该函数将下载用于训练和测试拆分的 CIFAR10 数据集。数据集将下载到目录,并将指定的转换应用于数据。transforms.Compose'./data'
  2. train_dataloader(self):此函数为训练数据集创建数据加载器。它首先定义与函数中相同的转换。接下来,它为训练拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用训练数据集创建一个对象。数据加载程序配置为 64 的批大小,对数据进行随机排序,并使用 8 个工作线程进行数据加载。它返回数据加载器。prepare_data'./data'DataLoader
  3. val_dataloader(self):此函数为验证数据集创建数据加载器。它遵循与函数类似的结构。它首先使用 定义转换,这些转换与前面的函数相同。然后,为验证拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用验证数据集创建一个对象。数据加载器配置为 64 的批大小,无需随机处理数据,并使用 8 个工作线程进行数据加载。它返回数据加载器。train_dataloadertransforms.Compose'./data'DataLoader

        该函数将模型作为输入,并对测试数据集执行评估。它首先对测试数据应用转换,将其转换为张量并规范化。然后,它为测试数据集创建数据加载程序。模型将移动到相应的设备(GPU,如果可用)。评估标准设置为交叉熵损失。evaluate_model

def evaluate_model(model):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)criterion = nn.CrossEntropyLoss()model.eval()test_loss = 0.0correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100.0 * correct / totalaverage_loss = test_loss / len(test_loader)print(f"Test Loss: {average_loss:.4f}")print(f"Test Accuracy: {accuracy:.2f}%")

        将模型置于评估模式,并初始化测试损失、正确预测和总数据点的变量。在无梯度上下文中,该函数遍历测试数据加载器,通过模型转发成批的输入,计算损失并累积测试损失。它还计算正确预测的数量和数据点的总数。最后,它计算并打印平均测试损失和测试精度。

        最后,我们实例化我们的模型和来自 PyTorch Lightning,指定用于分布式训练的所需数量的 GPU 或机器:NetTrainer

net = Net()trainer = pl.Trainer(num_nodes=1,  # Change to the number of machines in your distributed setupaccelerator="auto",  # Distributed Data Parallel, Available names are: auto, cpu, cuda, hpu, ipu, mps, tpu.max_epochs=5, devices=1 # Change to the desired number of GPUs or use `None` for CPU training
)trainer.fit(net)evaluate_model(net)
  • num_nodes:它指定分布式设置中的计算机数量。在这种情况下,它设置为 ,表示单台计算机设置。1
  • accelerator:它确定训练的加速器类型。该值允许 PyTorch Lightning 根据硬件和软件环境自动选择适当的加速器。其他可能的值包括 、 和 ,它们对应于特定的硬件加速器。"auto""cpu""cuda""hpu""ipu""mps""tpu"
  • max_epochs:它设置用于训练模型的最大周期数(通过训练数据集的完整遍历)。在本例中,它设置为 。5
  • devices:它指定用于训练的 GPU 数量。将其设置为 表示使用单个 GPU 进行训练。如果要在 CPU 上进行训练,可以将其设置为 。1None

        这些选项允许您控制训练过程的各个方面,例如分布式训练、加速器选择以及用于训练的周期数和设备数。

        设置好所有内容后,我们只需调用对象的方法,传入我们的模型、训练数据加载器和验证数据加载器。fitTrainerNet

四、输出

 

五、结论

        PyTorch Lightning 通过分布式训练简化了扩展深度学习工作流的过程。通过抽象化分布式训练的复杂性,PyTorch Lightning 使我们能够专注于设计和实现我们的深度学习模型,而不必担心低级细节。在本文中,我们演练了一个使用 PyTorch Lightning 进行分布式训练的示例代码实现。通过利用多个GPU或机器的强大功能,我们可以显著减少大型深度学习模型的训练时间。

六、引用

  • PyTorch Lightning: Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.1.0.rc0 documentation
  • PyTorch: PyTorch
  • torchvision.datasets.CIFAR10: Datasets — Torchvision 0.15 documentation
  • torch.utils.data.DataLoader: torch.utils.data — PyTorch 2.0 documentation
  • 火炬亚当:Adam — PyTorch 2.0 documentation
  • torch.optim.lr_scheduler。步长:StepLR — PyTorch 2.0 documentation
  • Torch.nn.CrossEntropyLoss: CrossEntropyLoss — PyTorch 2.0 documentation
  • torch.cuda.is_available:torch.cuda — PyTorch 2.0 documentation

阿奈·东格雷

皮托奇

分布式系统

深度学习
皮托奇闪电
计算机视觉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/48783.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT基础教程之二 第一个Qt小程序

QT基础教程之二 第一个Qt小程序 按钮的创建 在Qt程序中&#xff0c;最常用的控件之一就是按钮了&#xff0c;首先我们来看下如何创建一个按钮 QPushButton * btn new QPushButton; 头文件 #include <QPushButton>//设置父亲btn->setParent(this);//设置文字btn-&g…

SQL两张表数据对比

表1&#xff1a; 表2&#xff1a; 1、查询两表的数据差异&#xff1a; # 查询表1中有但表2没有的数据 SELECT DATA FROM data1 WHERE ( DATA ) NOT IN ( SELECT DATA FROM data2 );# 查询表2中有但表…

xml对象与字符串互换

很多老系统&#xff0c;特别是C的系统&#xff0c;可能数据结构采用的xml。xml对java来说没有什么&#xff0c;但是C来说&#xff0c;可能还有个顺序问题&#xff0c;毕竟c没有那么多通用类库。 2 xstream 先说依赖&#xff0c;我本来不想升级&#xff0c;但是有个问题卡者就给…

drools8尝试

drools7升级到drools8有很大很大的变更.几乎不能说是一个项目了. 或者说就是名字相同的不同项目, 初看下来变化是这样 两个最关键的东西都retired了 https://docs.drools.org/8.42.0.Final/drools-docs/drools/migration-guide/index.html business central变成了一个VS code…

C语言学习系列-->看淡指针(3)

文章目录 一、字符指针变量二、数组指针变量2.1 概述2.2 数组指针初始化 三、二维数组传参本质四、函数指针五、typedef关键字六、函数指针数组 一、字符指针变量 在指针的类型中我们知道有⼀种指针类型为字符指针 char* 一般使用&#xff1a; #include<stdio.h>int main…

钛合金为何成为iPhone 15 Pro材料首选?

多年来&#xff0c;iPhone Pro一直采用厚重的钢框架&#xff0c;但不会持续太久。 有了iPhone 15 Pro&#xff0c;苹果可能会从钢框架转向钛框架&#xff0c;这不仅仅是因为它听起来更酷。钛比钢有很多优点&#xff0c;尤其是它更轻&#xff0c;这将解决iPhone Pro与普通iPhon…

Kubernetes 安全机制 认证 授权 准入控制

客户端应用若想发送请求到 apiserver 操作管理K8S资源对象&#xff0c;需要先通过三关安全验证 认证&#xff08;Authentication&#xff09;鉴权&#xff08;Authorization&#xff09;准入控制&#xff08;Admission Control&#xff09; Kubernetes 作为一个分布式集群的管理…

ethers.js2:provider提供商

1、Provider类 Provider类是对以太坊网络连接的抽象&#xff0c;为标准以太坊节点功能提供简洁、一致的接口。在ethers中&#xff0c;Provider不接触用户私钥&#xff0c;只能读取链上信息&#xff0c;不能写入&#xff0c;这一点比web3.js要安全。 除了之前介绍的默认提供者d…

JAVA免杀学习与实验

1 认识Webshell 创建一个JSP文件&#xff1a; <% page import"java.io.InputStream" %> <% page import"java.io.BufferedReader" %> <% page import"java.io.InputStreamReader" %> <% page language"java" p…

jmeter进行业务接口并发测试,但登录接口只执行一次

业务接口性能测试&#xff0c;往往都是需要登录&#xff0c;才能请求成功&#xff0c;通常只需要登录一次&#xff0c;再对业务接口多次并发测试。 在测试计划中&#xff0c;添加setUp线程组 把登录请求放入到该线程组中&#xff0c;设置HTTP信息头&#xff0c;JSON提取(提取登…

前端基础(ES6 模块化)

目录 前言 复习 ES6 模块化导出导入 解构赋值 导入js文件 export default 全局注册 局部注册 前言 前面学习了js&#xff0c;引入方式使用的是<script s"XXX.js">&#xff0c;今天来学习引入文件的其他方式&#xff0c;使用ES6 模块化编程&#xff0c;…

电路学习+硬件每日学习十个知识点(40)23.8.20 (希腊字母读音,阶跃信号和冲激信号的关系式,信号的波形变换,信号的基本运算,卷积积分,卷积和)

文章目录 1.信号具有时间特性和频率特性。2.模拟转数字&#xff0c;抽样、量化、编码3.阶跃信号和冲激信号4.信号的波形变换&#xff08;时移、折叠、尺度变换&#xff09;5.信号的基本运算&#xff08;加减、相乘、微分与积分、差分与累加&#xff09;5.1 相加减5.2 相乘5.3 微…

基础论文学习(1)——ViT

Vision Transformer&#xff08;ViT&#xff09; 模型架构是在 ICLR 2021 上作为会议论文发表的一篇研究论文中介绍的&#xff0c;题为“An Image is Worth 16*16 Words: Transformers for Image Recognition at Scale”。它由Neil Houlsby&#xff0c;Alexey Dosovitskiy和Goo…

springMVC之视图

文章目录 前言一、ThymeleafView二、转发视图三、重定向视图四、视图控制器view-controller五、补充总结 前言 SpringMVC中的视图是View接口&#xff0c;视图的作用渲染数据&#xff0c;将模型Model中的数据展示给用户。 SpringMVC视图的种类很多&#xff0c;默认有转发视图和…

vscode远程调试

安装ssh 在vscode扩展插件搜索remote-ssh安装 如果连接失败&#xff0c;出现 Resolver error: Error: XHR failedscode 报错&#xff0c;可以看这篇帖子vscode ssh: Resolver error: Error: XHR failedscode错误_阿伟跑呀的博客-CSDN博客 添加好后点击左上角的加号&#xff0…

【Python机器学习】实验16 卷积、下采样、经典卷积网络

文章目录 卷积、下采样、经典卷积网络1. 对图像进行卷积处理2. 池化3. VGGNET4. 采用预训练的Resnet实现猫狗识别 TensorFlow2.2基本应用5. 使用深度学习进行手写数字识别 卷积、下采样、经典卷积网络 1. 对图像进行卷积处理 import cv2 path data\instance\p67.jpg input_…

Linux 线程同步——条件变量

一、条件变量的概念 如果说互斥锁是用于同步线程对共享数据的访问的话&#xff0c;那么条件变量则是用于在线程之间同步共享数据的值。条件变量提供了一种线程间的通知机制&#xff1a;当某个共享数据达到某个值的时候&#xff0c;唤醒等待这个共享数据的线程。如下图所示&…

vue3中使用第三方插件mitt实现任意组件通讯

vue3中使用第三方插件mitt实现任意组件通讯 组件通讯是vue3组合式开发的核心之一&#xff0c;现在我在写代码时&#xff0c;一个组件的代码超过了200行&#xff0c;基本都会拆分组件。组件拆分后&#xff0c;组件之间的通讯就很重要&#xff0c;总结了一下&#xff0c;目前有这…

【SQL应知应会】索引(三)• MySQL版:聚簇索引与非聚簇索引;查看索引与删除索引;索引方法

欢迎来到爱书不爱输的程序猿的博客, 本博客致力于知识分享&#xff0c;与更多的人进行学习交流 本文收录于SQL应知应会专栏,本专栏主要用于记录对于数据库的一些学习&#xff0c;有基础也有进阶&#xff0c;有MySQL也有Oracle 索引 • MySQL版 前言一、索引1.简介2.索引类型之逻…

【李沐】3.2线性回归从0开始实现

%matplotlib inline import random import torch from d2l import torch as d2l1、生成数据集&#xff1a; 看最后的效果&#xff0c;用正态分布弄了一些噪音 上面这个具体实现可以看书&#xff0c;又想了想还是上代码把&#xff1a; 按照上面生成噪声&#xff0c;其中最后那…