Pytorch-Lighting使用教程(MNIST为例)

一、pytorch-lighting简介

1.1 pytorch-lighting是什么

pytorch-lighting(简称pl),基于 PyTorch 的框架。它的核心思想是,将学术代码模型定义、前向 / 反向、优化器、验证等)与工程代码for-loop,保存、tensorboard 日志、训练策略等)解耦开来,使得代码更为简洁清晰。

工程代码经常会出现在深度学习代码中,PyTorch Lightning 对这部分逻辑进行了封装,只需要在 Trainer 类中简单设置即可调用,无需重复造轮子。

1.2 pytorch-lighting优势

  • 通过抽象出样板工程代码,可以更容易地识别和理解ML代码;
  • Lightning的统一结构使得在现有项目的基础上进行构建和理解变得非常容易;
  • Lightning 自动化的代码是用经过全面测试、定期维护并遵循ML最佳实践的高质量代码构建的;

pytorch-lighting最大的好处:

(1)是摆脱了硬件依赖,不需要在程序中显式设置.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换

(2)支持分布式训练,自动分配资源,能够很好的进行大规模的DL训练

(3)代码量较少,只需要关心关键的逻辑代码,而框架性的东西,pytorch-lighting已经帮你解决(如自动训练,自动debug)


二、基于Pytorch-Lighting框架训练MNIST模型

1、仅仅训练

下载的所有的数据集都用于训练(没有评估和测试过程,不清楚模型的好与坏)。

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):# 3.1 加载基础模型def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoder# 3.2 训练过程设置def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)return loss# 3.3 优化器设置def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# 4. 定义训练数据
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# 6. 开始训练
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

class LitAutoEncoder(pl.LightningModule):

  • 将模型定义代码写在__init__
  • 定义前向传播逻辑
  • 将优化器代码写在 configure_optimizers 钩子中
  • 训练代码写在 training_step 钩子中,可使用 self.log 随时记录变量的值,会保存在 tensorboard 中
  • 验证代码写在 validation_step 钩子中
  • 移除硬件调用.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换
  • 根据需要,重写其他钩子函数,例如 validation_epoch_end,对 validation_step 的结果进行汇总;train_dataloader,定义训练数据的加载逻辑
  • 实例化 Lightning Module 和 Trainer 对象,传入数据集
  • 定义训练参数和回调函数,例如训练设备、数量、保存策略,Early Stop、半精度等

运行结果:

2、添加验证和测试模块

在训练之后,加入了测试和评估功能,能更好的指导模型的性能。

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as plimport torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transformsfrom torch.utils.data import DataLoader# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):# 3.1 加载基础模型def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoder# 3.2 训练过程设置def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)return loss# 3.3 测试过程设置def test_step(self, batch, batch_idx):# this is the test loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss)# 3.4 验证过程设置def validation_step(self, batch, batch_idx):# this is the validation loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss)# 3.5 优化器设置def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# 4. 定义训练数据
'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
'''# 4.1 分别下载并加载训练集和测试集
transform = transforms.ToTensor()
train_set = datasets.MNIST(os.getcwd(), download=False, train=True, transform=transform)
test_set = datasets.MNIST(os.getcwd(), download=False, train=False, transform=transform)# 4.2 将训练集中的20%用于验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size# 4.3 设置种子
seed = torch.Generator().manual_seed(42)# 4.4 从训练集中随机拿到80%的测试集和20%的验证集
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)# 4.5 分别加载训练集和测试集
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# 6. 实例化Trainer
trainer = pl.Trainer(max_epochs=10)# 7. 开始训练和评估
trainer.fit(autoencoder, train_loader, valid_loader)# 8.开始测试
trainer.test(model=autoencoder, dataloaders=DataLoader(test_set))

3、权重 & 超参的保存和加载

当模型正在训练时,性能会随着它继续看到更多数据而发生变化。

1)训练完成后,使用在训练过程中发现的最佳性能相对应的权重;

2)权重可以让训练在训练过程中断的情况下从原来的位置恢复。

保存权重:Lightning 会自动为你在当前工作目录下保存一个权重,其中包含上一次训练的状态。这能确保在训练中断的情况下恢复训练。

3.1 自动在当前目录下保存checkpoint

# simply by using the Trainer you get automatic checkpointing
trainer = Trainer()

3.2 指定checkpoint保存的目录

# saves checkpoints to 'some/path/' at every epoch end
trainer = Trainer(default_root_dir="some/path/")

3.3 加载checkpoint

# trainer.fit(autoencoder, train_loader, valid_loader, ckpt_path="/home/gvlib_ljh/class/Lightning_mnist/lightning_logs/version_25/checkpoints/epoch=9-step=160000.ckpt")

4、可视化

在模型开发中,我们跟踪感兴趣的值,例如validation_loss,以可视化模型的学习过程。模型开发就像驾驶一辆没有窗户的汽车,图表和日志提供了了解汽车行驶方向的窗口。借助 Lightning,可以可视化任何您能想到的东西:数字、文本、图像、音频。

要跟踪指标,只需使用 LightningModule 内可用的 self.log 方法。

class LitModel(pl.LightningModule):def training_step(self, batch, batch_idx):value = ...self.log("some_value", value)

同时记录多个指标:

values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values)

4.1 命令行查看

要在命令行进度栏中查看指标,请将 prog_bar 参数设置为 True。

self.log(..., prog_bar=True)

4.2 浏览器查看

默认情况下,Lightning 使用 Tensorboard(如果可用)和一个简单的 CSV 记录器

在命令行中输入(注意:一定是lightning_logs所在的目录):

tensorboard --logdir=lightning_logs/

Tensorboard界面:

Tensorboard输出分析:

完整的代码:

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as plimport torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transformsfrom torch.utils.data import DataLoaderfrom pytorch_lightning.loggers import TensorBoardLogger# 设置浮点矩阵乘法精度为 'medium'
torch.set_float32_matmul_precision('medium')# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):# 3.1 加载基础模型def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoder# 3.2 训练过程设置def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)batch_idx_value = batch_idx + 1print(" ")values = {"loss": loss, "batch_idx_value": batch_idx_value}  # add more items if neededself.log_dict(values)# 在命令行界面显示log'''sync_dist=True:分布式计算,数据同步标志prog_bar=True:在控制台上显示'''self.log("train_loss", loss, sync_dist=True, prog_bar=True)return loss# 3.3 测试过程设置def test_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss, sync_dist=True, prog_bar=True)# 3.4 验证过程设置def validation_step(self, batch, batch_idx):# this is the validation loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss, sync_dist=True, prog_bar=True)# 3.5 优化器设置def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# 4. 定义训练数据
'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
'''# 4.1 分别下载并加载训练集和测试集
transform = transforms.ToTensor()
train_set = datasets.MNIST(os.getcwd(), download=False, train=True, transform=transform)
test_set = datasets.MNIST(os.getcwd(), download=False, train=False, transform=transform)# 4.2 将训练集中的20%用于验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size# 4.3 设置种子
seed = torch.Generator().manual_seed(42)# 4.4 从训练集中随机拿到80%的测试集和20%的验证集
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)# 4.5 分别加载训练集和测试集
train_loader = DataLoader(train_set, batch_size=256, num_workers=5)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=5)# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# 6. 实例化Trainer
trainer = pl.Trainer(max_epochs=1000)# 7. 开始训练和评估
trainer.fit(autoencoder, train_loader, valid_loader)
# 7. 从checkpoint恢复状态
# trainer.fit(autoencoder, train_loader, valid_loader, ckpt_path="/home/gvlib_ljh/class/Lightning_mnist/lightning_logs/version_25/checkpoints/epoch=9-step=160000.ckpt")# 8.开始测试
trainer.test(model=autoencoder, dataloaders=DataLoader(test_set))

参考:

https://zhuanlan.zhihu.com/p/659631467

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/845725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++学习之避免使用野指针

现有结构体struct_a和struct_b,其中struct_b中包含struct_a指针作为成员变量。先基于struct_a定义一个变量a_ptr,之后定义一个struct_b指针变量b_ptr,并将a_ptr赋值给b_ptr中的struct_a类型变量。之后释放b_ptr, 那么a_ptr指向的内…

红队内网攻防渗透:内网渗透之windows内网权限提升技术:工具篇

红队内网攻防渗透 1. 内网权限提升技术1.1 windows内网权限提升技术--工具篇1.1.1 Web到Win系统提权-平台&语言&用户1.1.1.1 Web搭建平台差异1.1.1.2 Web语言权限差异1.1.1.3 系统用户权限差异1.1.2 Web到Win系统提权-Windows 2012宝塔面板-哥斯拉1.1.2.1 环境条件:1.…

Anthropic公司CEO谈AI发展:Cluade安全超过商业利益

Anthropic公司今年3月发布的超越GPT-4模型Claude3 opus,成功吸引了大量GPT-4用户“叛变”。 作为OpenAI的头号劲敌,Claude3发布方Anthropic公司的联合创始人兼CEO,达里奥阿莫迪(DarioAmodei)承诺:在能够制…

生信分析进阶4 - 比对结果的FLAG和CIGAR信息含义与BAM文件指定区域提取

BAM文件时存储比对数据的常用格式,可用于短reads和长reads数据。BAM是二进制压缩格式,SAM文件为其纯文本格式,CRAM为BAM的高压缩格式,IO效率相比于BAM略差,但是占用存储空间更小。 1. BAM文件的比对信息 BAM的核心信…

用c语言实现通讯录

目录 静态简易通讯录 代码: 功能模块展示: 设计思路: 动态简易通讯录(本质顺序表) 代码: 扩容模块展示: 设计思路: 文件版本通讯录 代码: 文件模块展示&#x…

pas编程语言:深度剖析与实用技巧

pas编程语言:深度剖析与实用技巧 在编程的浩瀚海洋中,pas编程语言以其独特的魅力和深厚的内涵吸引着众多编程爱好者。然而,其复杂性和深度也常常让人望而生畏。本文将从四个方面、五个方面、六个方面和七个方面对pas编程语言进行深入剖析&am…

SJ705C安全帽高温预处理箱

一、仪器用途 安全帽高温预处理箱是我公司根据安全帽新国家标准检测试验要求而自主设计研发制造。是安全帽检测前做高温预处理的专用设备。 二、仪器特征 1、有PID自整定温度控制仪,控制准确。 2、数显计时、计温器。 3、石英灯管加热系统;。 …

【数据结构】单链表-->详细讲解,后赋源码

欢迎来到我的Blog,点击关注哦💕 前面已经介绍顺序表,顺序表存在一定的局限性,空间管理上存在一定的缺陷,今天介绍新的存储结构单链表。 前言: 单链表是一种基本的数据结构,它由一系列节点组成&a…

HTML网页滚动条使用整理_网页滚动条使用详解

一、HTML 网页滚动条 HTML Document 滚动条,自动出现; 当网页内容超出浏览器可视宽度或者高度,滚动条自动出现; 不同浏览器滚动条样式效果不同。 二、Css 修改滚动条样式 Css 伪元素控制进度条_Css控制滚动条_Css ::-webkit-scrollbar整理 三、Js监听滚动条,触底加载事…

Android 调试桥_ADB命令

Android 调试桥 ADB全称 【Android Debug Bridge】 是Android SDK中的一个命令行工具,adb命令可以直接操作管理Android模拟器或真实的Android设备(手机) ADB的工作原理 启动一个 adb 客户端时,此客户端首先检查是否有已运行的 …

python zip()函数(将多个可迭代对象的元素配对,创建一个元组的迭代器)zip_longest()

文章目录 Python zip() 函数深入解析基本用法函数原型基础示例 处理不同长度的迭代器高级用法多个迭代器使用 zip() 与 dict()解压序列 注意事项内存效率:zip() 返回的是一个迭代器,这意味着直到迭代发生前,元素不会被消耗。这使得 zip() 特别…

自然语言处理基础知识入门(六) GPT模型详解

GPT 前言一、GPT模型1.1 为什么采用Decoder模块?1.2 为什么不使用Encoder模块? 二、 模型训练2.1 预训练阶段2.2 半监督微调 总结 前言 在之前的章节中,深入探究了预训练ELMo模型的架构与实现原理。通过采用双向LSTM架构在大规模文本数据上进…

[数据集][目标检测][数据集][目标检测]智能手机检测数据集VOC格式5447张

数据集格式:Pascal VOC格式(不包含分割的txt文件,仅仅包含jpg图片和对应的xml) 图片数量(jpg文件个数):5447 标注数量(xml文件个数):5447 标注类别数:1 标注类别名称:["phone"] 每个类别标注的框数&#xff…

2024年华为OD机试真题-执行时长-Python-OD统一考试(C卷D卷)

2024年OD统一考试(D卷)完整题库:华为OD机试2024年最新题库(Python、JAVA、C++合集) 题目描述: 为了充分发挥GPU算力,需要尽可能多的将任务交给GPU执行,现在有一个任务数组,数组元素表示在这1秒内新增的任务个数且每秒都有新增任务,假设GPU最多一次执行n个任务,一次执…

Qt程序错误“QObject::connect: Cannot queue arguments of type ‘QTextCursor’”的解决方法

背景: 在Qt的线程中调用QTexiEdit控件的append(QString)或insertPlainText(QString),线程首次执行会报错 “QObject::connect: Cannot queue arguments of type ‘QTextCursor”,销毁该线程&a…

Pytorch中Tensor的类型对应表

Data typedtypeCPU tensorGPU tensor32位浮点数torch.float32 or torch.floattorch.FloatTensortorch.cuda.FloatTensor64位浮点数torch.float64 or torch.doubletorch.DoubleTensortorch.cuda.DoubleTensor16位浮点数torch.float16 or torch.halftorch.HalfTensortorch.cuda.H…

Flutter 中的 SliverWithKeepAliveWidget 小部件:全面指南

Flutter 中的 SliverWithKeepAliveWidget 小部件:全面指南 Flutter 是一个由 Google 开发的跨平台 UI 框架,它允许开发者使用 Dart 语言构建高性能、美观的移动、Web 和桌面应用。在 Flutter 的丰富组件库中,SliverWithKeepAliveWidget 是一…

宝塔Linux面板-Docker管理(2024详解)

上一篇文章《宝塔Linux可视化运维面板-详细教程2024》,详细介绍了宝塔Linux面板的详细安装和配置方法。本文详细介绍使用Linux面板管理服务器Docker环境。 目录 1、安装Docker 1.1 在线安装 ​编辑 1.2 手动安装 1.3 运行状态 1.4 镜像加速 2 应用商店 3 总览 4 容器 …

高德地图 JS API用于绘画船舶轨迹

文章目录 引言I 2.0升级指南1.1 修改 JSAPI 引用中的版本号到 2.01.2 相应修改II 1.4.15 文档引言 地图 JS API 2.0 是高德开放平台免费提供的第四代 Web 地图渲染引擎, 以 WebGL 为主要绘图手段,本着“更轻、更快、更易用”的服务原则,广泛采用了各种前沿技术,交互体验、…

从CSV到数据库(简易)

需求:客户上传CSV文档,要求CSV文档内容查重/插入/更新相关数据。 框架:jdbcTemplate、commons-io、 DB:oracle 相关依赖: 这里本来打算用的2.11.0,无奈正式项目那边用老版本1.3.1,新版本对类型…