【AI基础】pytorch lightning 基础学习

传统pytorch工作流是首先定义模型框架,然后写训练和验证,测试循环代码。训练,验证,测试代码写起来比较繁琐。这里介绍使用pytorch lightning 部署模型,加速模型训练和验证,记录。

准备工作

1 安装pytorch lightning 检查版本

$ conda create -n lightning python=3.9 -y
$ conda activate lightning
import lightning as L
import torchprint("Lightning version:", L.__version__)
print("Torch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())

2 加载基本库函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import lightning as L
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

3 设置随机种子(可复现性)

L.seed_everything(1121218)

4 数据集下载和增强变换

这里以CIFAR10数据集为例子,该数据集包含 10 个类的 6 万张 32x32 彩色图像,每个类 6000 张图像。

from torchvision import datasets, transforms# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train
)
val_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test
)
# Data augmentation and normalization for training
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),],
)
transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

上面的增强变换包括以下四种基本变换: 

  • 裁剪(需要指定图像大小,在本例中为 32x32)。
  • 水平翻转。
  • 转换为张量数据类型,这是 PyTorch 所必需的。
  • 对图像的每个颜色通道进行归一化处理。

传统pytorch模型训练流

定义一个CNN模型

class CIFAR10CNN(nn.Module):def __init__(self):super(CIFAR10CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 64 * 4 * 4)x = torch.relu(self.fc1(x))x = self.fc2(x)return x

编写训练、验证循环代码

  • 需要初始化模型,损失函数和优化器
  • 管理模型和数据在机器上的运行(CPU 与 GPU)
  • 训练步骤:前向传播、损失计算、反向传播和优化
  • 验证步骤:计算准确性和损失
  • tensorboard日志记录,训练损失,准确率,其他相关指标记录等
  • 模型保存
  • # Initialize the model, loss function, and optimizer
    model = CIFAR10CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)# TensorBoard setup
    writer = SummaryWriter('runs/cifar10_cnn_experiment')# Training loop
    total_step = len(train_loader)
    for epoch in range(num_epochs):model.train()train_loss = 0.0for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')# Calculate average training loss for the epochavg_train_loss = train_loss / len(train_loader)writer.add_scalar('training loss', avg_train_loss, epoch)# Validationmodel.eval()with torch.no_grad():correct = 0total = 0val_loss = 0.0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalavg_val_loss = val_loss / len(test_loader)print(f'Validation Accuracy: {accuracy:.2f}%')writer.add_scalar('validation loss', avg_val_loss, epoch)writer.add_scalar('validation accuracy', accuracy, epoch)# Learning rate schedulingscheduler.step(avg_val_loss)# Final test
    model.eval()
    with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')writer.close()# Save the model
    torch.save(model.state_dict(), 'cifar10_cnn.pth')

     在上面的代码示例,有一些需要特别注意繁琐的细节:

    训练和验证模式之间可以手动切换。
    有梯度计算的手动规范。
    使用较差的 SummaryWriter 类进行日志记录。
    有一个学习率调度程序。

Pytorch lightning 工作流

1 使用LightningModule 类定义模型结构

class CIFAR10CNN(L.LightningModule):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = x.view(-1, 64 * 4 * 4)x = F.relu(self.fc1(x))x = self.fc2(x)return x

唯一的区别是,我们是从LightningModule类继承,而不是从继承nn.Module。是类LightningModule的扩展nn.Module。它将 PyTorch 工作流的训练、验证、测试、预测和优化步骤组合到一个没有循环的单一界面中。 当你开始使用时LightningModule,它被组织成六个部分:

  • 初始化(__init__和setup()方法)
  • 训练循环(training_step()方法)
  • 验证循环(validation_step()方法)
  • 测试循环(test_step()方法)
  • 预测循环(prediction_step()方法)
  • 优化器和 LR 调度程序(configure_optimizers())

我们已经看到了初始化部分。让我们继续进行训练步骤。

2 编写训练过程代码

在模型类中,复写training_step()方法

# Add the method inside the class
def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log('train_loss', loss)return loss

此方法将整个训练循环压缩为几行代码。首先,从数据batch中读取模型输入和模型输出。然后,我们运行前向传递self(x)并计算损失。然后,我们只需使用内置的 Lightning 记录器函数记录训练损失即可self.log()。

还可以在此方法中记录其他指标,例如训练准确性:

def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log("train_loss", loss)self.log("train_acc", acc)return loss

log()方法可以自动计算每个epoch的模型的各个指标,比如准确性,F1-score等等。该方法里面有一些参数是可以额外设置的,比如记录每个batch和epoch下的模型指标,模型训练和验证时创建进度条,还有将模型的各个指标输出到本地文件中。

# Log the loss at each training step and epoch, create a progress bar
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

3 编写验证和测试步骤代码

def validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log('val_loss', loss)self.log('val_acc', acc)
def test_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log('test_loss', loss)self.log('test_acc', acc)

唯一的区别是不需要返回计算出的指标。Lightning模块会自动将正确的数据加载器分配给验证和测试步骤,并在后台创建循环。

尽管validation_step()和test_step()看起来相同,但它们有一个关键的区别:

  • validation_step()在训练期间,直接参与模型验证。
  • test_step()在测试期间,需要调用训练器对象的.test()方法,才能执行此操作。

4 配置优化器和优化器scheduler程序

为了定义优化器和学习率调度器,需要重写configure_optimizers()类的方法。

def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)return {"optimizer": optimizer,"lr_scheduler": {"scheduler": scheduler,"monitor": "val_loss",},}

上面,创建了一个Adam优化器,传入超参数和学习率。还定义了一个ReduceLROnPlateau调度函数,用于在验证损失稳定时降低学习率。返回对象字典是最灵活的选项,因为它允许定义需要额外参数的scheduler。

https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers

5 定义callbacks和记录器

模型类和附带的训练,验证,优化器,学习率调度器和指标计算都已经完成,模型可以实现前向和反向传播,模型更新,验证,记录模型的各个指标。此时,还需要定义一系列的callbacks和记录器类型。这里定义一个checkpoint callback和记录器。

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints",monitor="val_loss",filename="cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",save_top_k=3,mode="min",
)

ModelCheckpoint是一个强大的回调,用于在监控给定指标的同时定期保存模型。每个模型检查点都记录到dirpath中。

定义一个tensorboardlogger() 记录方法

logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")

定义一个early_stopping callback

early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min", verbose=False)

6 创建一个trainer类

在将模型LightningModule类和callback, 记录器全部定义完以后,就可以定义一个Trainer 类来实现模型的数据读取,自动训练,验证,模型自动保存,比较简洁。可以定义最大epoch数,使用gpu训练和gpu个数,记录器,callback,训练精度,训练数据比例(默认100%),验证数据比例(默认100%),多少个epoch 模型做一次验证,多少个epoch后记录一次模型指标,记录和模型地址,单gpu训练还是分布式训练。

# Initialize the Trainer
trainer = L.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stopping],logger=logger,accelerator="gpu" if torch.cuda.is_available() else "cpu",devices="auto",
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

7 训练和测试模型

# Train and test the modeltrainer.fit(model, train_loader, test_loader)trainer.test(model, test_loader)

8 pytorch lightning 训练模型的基本流程总结

  •   创建应用转换的训练、验证和测试数据加载器。
  • 将代码组织到一个LightningModule类中:
  • 定义初始化。
  • 定义训练、验证和(可选)测试步骤。
  • 定义优化器和学习率调度器。
  • 定义回调和记录器。
  • 创建一个训练类trainer
  • 初始化模型类。
  • 拟合并测试模型。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/55080.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【分布式微服务云原生】使用Docker体验不同Linux发行版

Docker 允许用户在同一个宿主机上运行多种不同的Linux发行版,而共享同一个宿主机内核。这种方式不仅节省资源,还非常方便进行环境测试和开发。 1. Docker与Linux发行版 Linux内核 发行版 Linux内核与各种发行版软件包组合,构成了一个完整的…

JAVA红娘婚恋相亲交友系统源码全面解析

在数字化时代,红娘婚恋相亲交友系统成为了连接单身男女的重要桥梁。JAVA作为一种流行的编程语言,为开发这样的系统提供了强大的支持。编辑h17711347205以下是对JAVA红娘婚恋相亲交友系统源码的全面解析,以及三段示例代码的展示。 系统概述 …

[产品管理-33]:实验室技术与商业化产品的距离,实验室技术在商业化过程中要越过多少道“坎”?

目录 一、实验室技术 1.1 实验室研究性技术 1.2 技术发展的S曲线 技术发展S曲线的主要阶段和特点 技术发展S曲线的意义和应用 二、实验室技术商业化的路径 2.1 实验室技术与商业化产品的距离 1、技术成熟度与稳定性 - 技术自身 2、市场需求与适应性 - 技术是满足需求 …

Visual Studio 2022

VS(Visual Studio)是一款由微软开发的集成开发环境(IDE),用于开发应用程序、网站以及移动应用等。VS的历史可以追溯到1997年,当时发布了第一个版本的VS。以下是VS的一些重要历史里程碑: Visual …

ArcEngine C#二次开发图层处理:根据属性分割图层(Split)

需求:仅根据某一属性,分割图层,并以属性值命名图层名保存。 众所周知,ArcGIS ArcToolbox中通过Split可以实现图形分割一个图层,以属性值命名图层,如下图所示。 本功能仅依据属性值,将一个shp图…

MATLAB中的模型预测控制(MPC)实现详解

模型预测控制(MPC)是一种基于模型的优化控制策略,广泛应用于工业过程控制、无人驾驶、机器人等领域。MPC通过预测未来的系统行为,优化控制输入以达到预期的控制目标。本文将详细介绍如何在MATLAB中实现MPC,包括基本原理…

Socket【C#】Demo

字段: Socket RJ45;//以太网 属性: public Socket socket { get > RJ45; set > RJ45 value; } 构造: //实例化Socket RJ45 new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); 连接: //封装…

后台监控中的云边下控耗时、边缘采集耗时 、云边下控量

云边下控耗时:指云端控制边缘设备的时间,从云端下发指令到边缘设备响应完成的时间。该指标反映了云端控制边缘设备的效率和响应速度。 边缘采集耗时:指边缘设备采集数据到云端处理完成的时间,包括数据采集、传输、处理等环节。该…

NETTY 是什么

Netty netty 很多都已经封装好了,比如客户端和服务端的连接后只要实现channelActive 方法 客户端给服务端发送send数据的时候,只要实现channgelRead方法 netty 就是一个nio程序,对nio做了很多封装,优化,用netty的时…

转行大模型的必要性与未来前景:迎接智能时代的浪潮

随着人工智能(AI)技术的迅猛发展,特别是大型语言模型(LLM, Large Language Models)的崛起,各行各业正迎来一场前所未有的技术革命。对于普通程序员而言,转行进入大模型领域不仅是对个人职业发展…

【Web】Electron:第一个桌面程序

Electron 是一个开源框架,使开发者能够使用 HTML、CSS 和 JavaScript 构建跨平台的桌面应用程序。通过 Electron,开发者可以将网页技术应用于桌面软件开发,从而利用现有的网页技术栈构建功能强大的桌面应用。 下载 Electron 虽然 Electron …

【第十五章:Sentosa_DSML社区版-机器学习之关联规则】

目录 15.1 频繁模式增长 15.2 PrefixSpan 【第十五章:Sentosa_DSML社区版-机器学习之关联规则】 机器学习关联规则是一种用于发现数据集中项之间有趣关系的方法。它基于统计和概率理论,通过分析大量数据来识别项之间的频繁共现模式。 15.1 频繁模式增…

Python 爬虫 根据ID获得UP视频信息

思路: 用selenium库对网页进行获取,然后用bs4进行分析,拿到bv号,标题,封面,时长,播放量,发布时间 先启动webdriver.,进入网页之后,先等几秒,等加…

CMake 中 add_definitions() 使用的注意事项及替代方案

CMake 中 add_definitions() 使用的注意事项及替代方案 在 CMake 中使用 add_definitions() 函数时,虽然其作用范围是全局的,但在实际应用中可能会遇到一些问题,导致其对子目录的影响不如预期。理解和避免这些问题可以帮助更高效地使用 CMak…

python中序列化和反序列化

在 Python 编程中,序列化 是指将一个 Python 对象转换为一种可以存储或传输的格式的过程。通过序列化,可以将对象的数据结构转化为诸如 JSON、XML、YAML 等格式,以便将其存储到文件、数据库,或者通过网络进行传输。与之对应的过程…

lvm管理磁盘过程记录

lvm管理磁盘过程记录.md 0.参考文章一、使用lvm在Linux系统上进行磁盘管理1.安装 LVM 工具2.创建物理卷(PV)3.创建卷组(VG)4.创建逻辑卷(LV)5.格式化逻辑卷6.挂载使用7.开机自动挂载(可选&#…

Unity 的Event的Use()方法

对于Event的Use方法,其在调用后将不会再判断同类型的事件 这种情况下,第二个MosueDown不会进入,因为已经Use 如果把Use注释掉 依旧能进入第二个MosueDown 也就是说当使用了Use方法,相同的事件类型不会进第二遍

【反素数】

题目 思路 首先分析 的性质 一定是 中约数最大的一定是约数同是最大的数字中值中最小的进一步挖掘性质,紧贴枚举的做法 约数最大值最小(也决定了层数、其它约束),是枚举的比较条件实现上述目的,枚举的质数种类在大小…

Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

cifar10 准确率只有0.1 问题描述踩坑解决办法 问题描述 如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题 import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Fla…

VS Code breadcrumbs view 是什么

VS Code breadcrumbs view 是什么 正文 正文 breadcrumbs view:中文翻译,面包屑视图,乍听起来感觉十分抽象。这里我们来解释一下这个视图的含义? 如下图所示,红色框标记的部分就是 这个视图可以显示出当前打卡文件所…