【AI基础】pytorch lightning 基础学习

传统pytorch工作流是首先定义模型框架,然后写训练和验证,测试循环代码。训练,验证,测试代码写起来比较繁琐。这里介绍使用pytorch lightning 部署模型,加速模型训练和验证,记录。

准备工作

1 安装pytorch lightning 检查版本

$ conda create -n lightning python=3.9 -y
$ conda activate lightning
import lightning as L
import torchprint("Lightning version:", L.__version__)
print("Torch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())

2 加载基本库函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import lightning as L
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

3 设置随机种子(可复现性)

L.seed_everything(1121218)

4 数据集下载和增强变换

这里以CIFAR10数据集为例子,该数据集包含 10 个类的 6 万张 32x32 彩色图像,每个类 6000 张图像。

from torchvision import datasets, transforms# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train
)
val_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test
)
# Data augmentation and normalization for training
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),],
)
transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

上面的增强变换包括以下四种基本变换: 

  • 裁剪(需要指定图像大小,在本例中为 32x32)。
  • 水平翻转。
  • 转换为张量数据类型,这是 PyTorch 所必需的。
  • 对图像的每个颜色通道进行归一化处理。

传统pytorch模型训练流

定义一个CNN模型

class CIFAR10CNN(nn.Module):def __init__(self):super(CIFAR10CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 64 * 4 * 4)x = torch.relu(self.fc1(x))x = self.fc2(x)return x

编写训练、验证循环代码

  • 需要初始化模型,损失函数和优化器
  • 管理模型和数据在机器上的运行(CPU 与 GPU)
  • 训练步骤:前向传播、损失计算、反向传播和优化
  • 验证步骤:计算准确性和损失
  • tensorboard日志记录,训练损失,准确率,其他相关指标记录等
  • 模型保存
  • # Initialize the model, loss function, and optimizer
    model = CIFAR10CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)# TensorBoard setup
    writer = SummaryWriter('runs/cifar10_cnn_experiment')# Training loop
    total_step = len(train_loader)
    for epoch in range(num_epochs):model.train()train_loss = 0.0for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')# Calculate average training loss for the epochavg_train_loss = train_loss / len(train_loader)writer.add_scalar('training loss', avg_train_loss, epoch)# Validationmodel.eval()with torch.no_grad():correct = 0total = 0val_loss = 0.0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalavg_val_loss = val_loss / len(test_loader)print(f'Validation Accuracy: {accuracy:.2f}%')writer.add_scalar('validation loss', avg_val_loss, epoch)writer.add_scalar('validation accuracy', accuracy, epoch)# Learning rate schedulingscheduler.step(avg_val_loss)# Final test
    model.eval()
    with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')writer.close()# Save the model
    torch.save(model.state_dict(), 'cifar10_cnn.pth')

     在上面的代码示例,有一些需要特别注意繁琐的细节:

    训练和验证模式之间可以手动切换。
    有梯度计算的手动规范。
    使用较差的 SummaryWriter 类进行日志记录。
    有一个学习率调度程序。

Pytorch lightning 工作流

1 使用LightningModule 类定义模型结构

class CIFAR10CNN(L.LightningModule):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = x.view(-1, 64 * 4 * 4)x = F.relu(self.fc1(x))x = self.fc2(x)return x

唯一的区别是,我们是从LightningModule类继承,而不是从继承nn.Module。是类LightningModule的扩展nn.Module。它将 PyTorch 工作流的训练、验证、测试、预测和优化步骤组合到一个没有循环的单一界面中。 当你开始使用时LightningModule,它被组织成六个部分:

  • 初始化(__init__和setup()方法)
  • 训练循环(training_step()方法)
  • 验证循环(validation_step()方法)
  • 测试循环(test_step()方法)
  • 预测循环(prediction_step()方法)
  • 优化器和 LR 调度程序(configure_optimizers())

我们已经看到了初始化部分。让我们继续进行训练步骤。

2 编写训练过程代码

在模型类中,复写training_step()方法

# Add the method inside the class
def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log('train_loss', loss)return loss

此方法将整个训练循环压缩为几行代码。首先,从数据batch中读取模型输入和模型输出。然后,我们运行前向传递self(x)并计算损失。然后,我们只需使用内置的 Lightning 记录器函数记录训练损失即可self.log()。

还可以在此方法中记录其他指标,例如训练准确性:

def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log("train_loss", loss)self.log("train_acc", acc)return loss

log()方法可以自动计算每个epoch的模型的各个指标,比如准确性,F1-score等等。该方法里面有一些参数是可以额外设置的,比如记录每个batch和epoch下的模型指标,模型训练和验证时创建进度条,还有将模型的各个指标输出到本地文件中。

# Log the loss at each training step and epoch, create a progress bar
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

3 编写验证和测试步骤代码

def validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log('val_loss', loss)self.log('val_acc', acc)
def test_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)acc = (y_hat.argmax(1) == y).float().mean()self.log('test_loss', loss)self.log('test_acc', acc)

唯一的区别是不需要返回计算出的指标。Lightning模块会自动将正确的数据加载器分配给验证和测试步骤,并在后台创建循环。

尽管validation_step()和test_step()看起来相同,但它们有一个关键的区别:

  • validation_step()在训练期间,直接参与模型验证。
  • test_step()在测试期间,需要调用训练器对象的.test()方法,才能执行此操作。

4 配置优化器和优化器scheduler程序

为了定义优化器和学习率调度器,需要重写configure_optimizers()类的方法。

def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=5)return {"optimizer": optimizer,"lr_scheduler": {"scheduler": scheduler,"monitor": "val_loss",},}

上面,创建了一个Adam优化器,传入超参数和学习率。还定义了一个ReduceLROnPlateau调度函数,用于在验证损失稳定时降低学习率。返回对象字典是最灵活的选项,因为它允许定义需要额外参数的scheduler。

https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers

5 定义callbacks和记录器

模型类和附带的训练,验证,优化器,学习率调度器和指标计算都已经完成,模型可以实现前向和反向传播,模型更新,验证,记录模型的各个指标。此时,还需要定义一系列的callbacks和记录器类型。这里定义一个checkpoint callback和记录器。

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints",monitor="val_loss",filename="cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",save_top_k=3,mode="min",
)

ModelCheckpoint是一个强大的回调,用于在监控给定指标的同时定期保存模型。每个模型检查点都记录到dirpath中。

定义一个tensorboardlogger() 记录方法

logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")

定义一个early_stopping callback

early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min", verbose=False)

6 创建一个trainer类

在将模型LightningModule类和callback, 记录器全部定义完以后,就可以定义一个Trainer 类来实现模型的数据读取,自动训练,验证,模型自动保存,比较简洁。可以定义最大epoch数,使用gpu训练和gpu个数,记录器,callback,训练精度,训练数据比例(默认100%),验证数据比例(默认100%),多少个epoch 模型做一次验证,多少个epoch后记录一次模型指标,记录和模型地址,单gpu训练还是分布式训练。

# Initialize the Trainer
trainer = L.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stopping],logger=logger,accelerator="gpu" if torch.cuda.is_available() else "cpu",devices="auto",
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

7 训练和测试模型

# Train and test the modeltrainer.fit(model, train_loader, test_loader)trainer.test(model, test_loader)

8 pytorch lightning 训练模型的基本流程总结

  •   创建应用转换的训练、验证和测试数据加载器。
  • 将代码组织到一个LightningModule类中:
  • 定义初始化。
  • 定义训练、验证和(可选)测试步骤。
  • 定义优化器和学习率调度器。
  • 定义回调和记录器。
  • 创建一个训练类trainer
  • 初始化模型类。
  • 拟合并测试模型。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/55080.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA红娘婚恋相亲交友系统源码全面解析

在数字化时代,红娘婚恋相亲交友系统成为了连接单身男女的重要桥梁。JAVA作为一种流行的编程语言,为开发这样的系统提供了强大的支持。编辑h17711347205以下是对JAVA红娘婚恋相亲交友系统源码的全面解析,以及三段示例代码的展示。 系统概述 …

[产品管理-33]:实验室技术与商业化产品的距离,实验室技术在商业化过程中要越过多少道“坎”?

目录 一、实验室技术 1.1 实验室研究性技术 1.2 技术发展的S曲线 技术发展S曲线的主要阶段和特点 技术发展S曲线的意义和应用 二、实验室技术商业化的路径 2.1 实验室技术与商业化产品的距离 1、技术成熟度与稳定性 - 技术自身 2、市场需求与适应性 - 技术是满足需求 …

ArcEngine C#二次开发图层处理:根据属性分割图层(Split)

需求:仅根据某一属性,分割图层,并以属性值命名图层名保存。 众所周知,ArcGIS ArcToolbox中通过Split可以实现图形分割一个图层,以属性值命名图层,如下图所示。 本功能仅依据属性值,将一个shp图…

转行大模型的必要性与未来前景:迎接智能时代的浪潮

随着人工智能(AI)技术的迅猛发展,特别是大型语言模型(LLM, Large Language Models)的崛起,各行各业正迎来一场前所未有的技术革命。对于普通程序员而言,转行进入大模型领域不仅是对个人职业发展…

【第十五章:Sentosa_DSML社区版-机器学习之关联规则】

目录 15.1 频繁模式增长 15.2 PrefixSpan 【第十五章:Sentosa_DSML社区版-机器学习之关联规则】 机器学习关联规则是一种用于发现数据集中项之间有趣关系的方法。它基于统计和概率理论,通过分析大量数据来识别项之间的频繁共现模式。 15.1 频繁模式增…

Python 爬虫 根据ID获得UP视频信息

思路: 用selenium库对网页进行获取,然后用bs4进行分析,拿到bv号,标题,封面,时长,播放量,发布时间 先启动webdriver.,进入网页之后,先等几秒,等加…

Unity 的Event的Use()方法

对于Event的Use方法,其在调用后将不会再判断同类型的事件 这种情况下,第二个MosueDown不会进入,因为已经Use 如果把Use注释掉 依旧能进入第二个MosueDown 也就是说当使用了Use方法,相同的事件类型不会进第二遍

【反素数】

题目 思路 首先分析 的性质 一定是 中约数最大的一定是约数同是最大的数字中值中最小的进一步挖掘性质,紧贴枚举的做法 约数最大值最小(也决定了层数、其它约束),是枚举的比较条件实现上述目的,枚举的质数种类在大小…

Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

cifar10 准确率只有0.1 问题描述踩坑解决办法 问题描述 如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题 import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Fla…

VS Code breadcrumbs view 是什么

VS Code breadcrumbs view 是什么 正文 正文 breadcrumbs view:中文翻译,面包屑视图,乍听起来感觉十分抽象。这里我们来解释一下这个视图的含义? 如下图所示,红色框标记的部分就是 这个视图可以显示出当前打卡文件所…

新手答疑 | 零基础该怎么学习嵌入式?嵌入式Linux学习路线是什么?嵌入式开发板推荐?

很多初学者想要涉足嵌入式Linux开发领域,但往往在刚入门阶段,会因为初次接触到大量复杂的概念术语和深奥的技术文档感到压力重重,面对这些内容不知从何下手,感到十分迷茫,网上的内容也纷繁复杂,没有清晰的学…

从 Kafka 到 WarpStream: 用 MinIO 简化数据流

虽然 Apache Kafka 长期以来一直是流数据的行业标准,但新的创新替代方案正在重塑生态系统。其中之一是 WarpStream,它最近在 Confluent 的所有权下进入了新的篇章。此次收购进一步增强了 WarpStream 提供高性能、云原生数据流的能力,巩固了其…

SAP Message - self-explanatory 自身说明

SAP Message 解释、创建和应用可见如下文章:SAP Abap】SE91 - SAP MESSAGE 消息类创建与应用-CSDN博客 SE91 SAP消息类型 - tongxiaohu - 博客园 这里主要想聊一下常用的SE91 中不常用的功能 - 自身说明 选项的作用。 以 VF - 004 为例: 我们都知道自…

2024双十一买啥最划算?2024双十一五款值得入手的好物入手

2024双十一购物狂欢节将至,还在为买什么而纠结吗?这里为你入手五款值得入手的好物。从生活必备到时尚单品,涵盖多个领域,让你在双十一以划算的价格买到心仪之物,开启品质生活新旅程。 一、西圣find可视挖耳勺 入手理…

毕业设计选题:基于ssm+vue+uniapp的校园订餐小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

【补充】倒易点阵定义

晶体点阵:晶体内部结构在三维空间周期平移的客观存在的数学抽象,反映晶体实际原子排列。 倒易点阵:通过对晶体的正点阵进行傅里叶变换得到的,其中正点阵中每个阵点的位置矢量方向代表晶面族的法向,位置矢量的长度是晶…

CSS04-Chrome调试工具

Chrome 浏览器提供了一个非常好用的调试工具,可以用来调试我们的 HTML结构和 CSS 样式。

我们是向量数据库的领军企业,我们只招TOP人才

我们是全球领先的向量数据库企业,业务正在快速发展,现开放大量岗位: 前端、产品经理、数据库开发工程师、C、数据库运维、数据库测试…… 我们招聘的唯一目标,寻找 TOP人才! 如果你已经有丰富的经验,那么加…

jmeter-请求参数加密-MD5加密

方法1 :使用jmeter自带的函数助手digest Tool(工具)---Function Helper Dialog(函数助手对话框) 第一个参数是要md5加密的值,第二个参数是保存加密后值的变量 ( 此处变量是从txt文件导入的,所以使用的是${wd} ) …

overlayscrollbars使用

官网 https://github.com/KingSora/OverlayScrollbars 使用 <link href"https://cdn.bootcdn.net/ajax/libs/overlayscrollbars/2.10.0/styles/overlayscrollbars.css" rel"stylesheet"> <script src"https://cdn.bootcdn.net/ajax/libs/…