PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还引入checkpoint和早停策略,以得到模型最佳权重。

相关链接:https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html

训练集、验证集、测试集的使用

1.添加依赖,获取训练集和测试集

添加相应的依赖,同时使用MNIST数据集,获取训练和测试集

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 加载数据(测试集,train=False)
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

2.实现并调用test_step

在定义LightningModule中,实现test_step方法;在外部,调用test方法

class LitAutoEncoder(pl.LightningModule):def training_step(self, batch, batch_idx):...def test_step(self, batch, batch_idx): # 测试,该方法与training_step相似x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss)# 初始化Trainer
trainer = Trainer()# 执行test方法
trainer.test(model, dataloaders=DataLoader(test_set))

3.实现并调用验证集

通常使用torch.utils.data中的方法,将训练集中的一部分数据化为验证集

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

与测试集一样,需要在定义LightningModule中,实现validation_step方法;在外部,调用fit方法

class LitAutoEncoder(pl.LightningModule):def training_step(self, batch, batch_idx):...def validation_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss)def test_step(self, batch, batch_idx):...
# 调用torch.utils.data中的DataLoader对训练和测试集进行封装
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)# 在fit方法中,引入valid_loader,即验证集
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)

checkpoint

checkpoint有两个作用,一是能得到每一次epoch后的模型权重,能得到最佳表现的权重;二是能够在中断或停止后,继续在当前checkpoint处,继续训练。在Lightning中的checkpoint,包含模型的整个内部状态,这与普通的PyTorch不同,即使在最复杂的分布式训练环境中,Lightning也可以保存恢复模型所需的一切。包含以下状态:

  • 16-bit scaling factor (若使用16精度训练)
  • Current epoch
  • Global step
  • LightningModule’s state_dict
  • State of all optimizers
  • State of all learning rate schedulers
  • State of all callbacks (for stateful callbacks)
  • State of datamodule (for stateful datamodules)
  • The hyperparameters (init arguments) with which the model was created
  • The hyperparameters (init arguments) with which the datamodule was created
  • State of Loops

保存与调用方法

# 保存方法,可自定义default_root_dir路径,若不设置路径,将会自动保存
trainer = Trainer(default_root_dir="some/path/")# 调用方法
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model.eval()	# disable randomness, dropout, etc...
y_hat = model(x)

调用,还可以使用torch的方法

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}

也可以实现重现,例如模型LitModel(in_dim=32, out_dim=10)

# 使用 in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# 使用 in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Lightning和PyTorch完全兼容

checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["encoder"]
decoder_weights = checkpoint["decoder"]

设置checkpoint不可见

trainer = Trainer(enable_checkpointing=False)

如果想全部重新恢复

model = LitModel()
trainer = Trainer()

自动恢复所有相关参数 model, epoch, step, LR schedulers, etc…

trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

早停策略

EarlyStopping Callback

在Lightning中,早停回调步骤如下:

  • Import EarlyStopping callback. 载入EarlyStopping回调方法
  • Log the metric you want to monitor using log() method. 加载日志方法
  • Init the callback, and set monitor to the logged metric of your choice. 设置monitor
  • Set the mode based on the metric needs to be monitored. 设置mode
  • Pass the EarlyStopping callback to the Trainer callbacks flag. 调入EarlyStropping
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingclass LitModel(LightningModule):def validation_step(self, batch, batch_idx):loss = ...self.log("val_loss", loss)model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)# 也可以使用自定义的EarlyStopping策略
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
# EarlyStopping的文档链接https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping

注意

  • EarlyStopping默认在一次Validation后调用,但是Validation可以自定义多少次epoch后进行一次验证,例如check_val_every_n_epoch and val_check_interval

完整代码

# coding:utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L# --------------------------------
# Step 1: 定义模型
# --------------------------------
class LitAutoEncoder(L.LightningModule):def __init__(self):super().__init__()self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))def training_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)self.log("train_loss", loss)return lossdef test_step(self, batch, batch_idx):  # 测试,该方法与training_step相似x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss)def validation_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss)def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizerdef forward(self, x):# forward 定义了一次 预测/推理 行为embedding = self.encoder(x)return embedding
# --------------------------------
# Step 2: 加载数据+模型
# --------------------------------
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)autoencoder = LitAutoEncoder()
# --------------------------------
# Step 3: 训练+验证+测试
# --------------------------------
# 训练+验证
trainer = L.Trainer(default_root_dir="some/path/")	# 这里自定义需要保存的路径
trainer.fit(autoencoder, train_loader, valid_loader)# 测试
trainer.test(autoencoder, dataloaders=DataLoader(test_set))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/12247.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Visual Studio】无法打开包括文件: “dirent.h”: No such file or directory

VS2017/2019 无法打开包括文件: “dirent.h”: No such file or directory 1 “dirent.h”: No such file or directory 在windows下的VS2017/2019编译器中,发现无法打开“dirent.h”,主要是MSVC并没有实现这个头文件,但是在Linux这个头文件…

如何以管理员权限安装某个msi

介绍 如何以管理员权限安装某个msi 方法 要以管理员权限在控制台中安装一个 MSI 文件,你可以按照以下步骤操作: 打开命令提示符(或 PowerShell):按下 Win R 键,在运行窗口中输入 “cmd”(或 …

Chapter 8: Files | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介FilesPersistenceOpening filesText files and linesReading filesSearching through a fileLetting the user choose the file nameUsing try, except, and openWriting filesDebuggingGlossary Python for Everybody Exploring Data Us…

安全初级—正则表达式、This关键字、闭包

文章目录 正则表达式字面量字符元字符转义符特殊字符字符类预定义模式重复类量词符贪婪模式修饰符 This关键字使用场合使用注意点避免多层 this避免数组处理方法中的 this避免回调函数中的 this 绑定 this 的方法Function.prototype.call()Function.prototype.apply()Function.…

LLaMA 2: Open Foundation and Fine-Tuned Chat Models

LLaMA 2: Open Foundation and Fine-Tuned Chat Models Pre-trainingFine-tuningReward modelRLHF参考 Pre-training 数据层面: 预训练语料比LLaMA1多了40%,一共2T tokens,更关注了高质量数据的清洗。 其中数据不包含Meta产品与服务&#xf…

PHP8的注释-PHP8知识详解

欢迎你来到PHP服务网,学习《PHP8知识详解》系列教程,本文学习的是《PHP8的注释》。 什么是注释? 注释是在程序代码中添加的文本,用于解释和说明代码的功能、逻辑或其他相关信息。注释通常不会被编译器或解释器处理,而…

论文笔记:Fine-Grained Urban Flow Prediction

2021 WWW 1 intro 细粒度城市流量预测 两个挑战 细粒度数据中观察到的网格间的转移动态使得预测变得更加复杂 需要在全局范围内捕获网格单元之间的空间依赖性单独学习外部因素(例如天气、POI、路段信息等)对大量网格单元的影响非常具有挑战性——>论…

draw up a plan

爱情是美好的,却不是唯一的。爱情只是属于个人化的感情。 推荐一篇关于爱情的美文: 在一个小镇上,有一家以制作精美巧克力而闻名的手工巧克力店,名叫“甜蜜之爱”。这家巧克力店是由一位名叫艾玛的年轻女性经营的,她对…

iOS - Apple开发者账户添加新测试设备

获取UUID 首先将设备连接XCode,打开Window -> Devices and Simulators,通过下方位置查看 之后登录(苹果开发者网站)[https://developer.apple.com/account/] ,点击设备 点击加号添加新设备 填写信息之后点击Continue,并一路继续…

靶机精讲之Brainpan1

nmap扫描 主机发现 端口扫描 服务扫描 -sT 说明用tcp协议(三次握手)扫描 -sV扫描版本 O扫描系统 NULL是图片 10000端口是个python服务 UDP扫描 脚本扫描 web渗透 目录爆破 显示/bin/目录有东西 gobuster dir -w /usr/share/dirbuster/wordlists/di…

蓝桥杯单片机第十届国赛 真题+代码

iic.c /* # I2C代码片段说明1. 本文件夹中提供的驱动代码供参赛选手完成程序设计参考。2. 参赛选手可以自行编写相关代码或以该代码为基础&#xff0c;根据所选单片机类型、运行速度和试题中对单片机时钟频率的要求&#xff0c;进行代码调试和修改。 */ #include <STC1…

【【51单片机DA转换模块】】

爆改直流电机&#xff0c;DA转换器 main.c #include <REGX52.H> #include "Delay.h" #include "Timer0.h"sbit DAP2^1;unsigned char Counter,Compare; //计数值和比较值&#xff0c;用于输出PWM unsigned char i;void main() {Timer0_Init();whil…

后台管理系统中刷新业务功能的实现

实现 下载vueuse npm i vueuse/core在header组件中引入并给全屏按钮绑定点击事件 <el-button type"default" click"toggle" icon"FullScreen" circle></el-button>import { useFullscreen } from vueuse/coreconst { toggle } u…

【Java】分支结构习题

【Java】分支结构 文章目录 【Java】分支结构题1 &#xff1a;数字9 出现的次数题2 &#xff1a;计算1/1-1/21/3-1/41/5 …… 1/99 - 1/100 的值。题3 &#xff1a;猜数字题4 &#xff1a;牛客BC110 X图案题5 &#xff1a;输出一个整数的每一位题6 &#xff1a; 模拟三次密码输…

[SQL挖掘机] - 内连接: inner join

介绍: 内连接是一种多表连接方式&#xff0c;用于将两个或多个表中的数据通过共同的列值进行匹配&#xff0c;并返回满足连接条件的匹配行。简单来说&#xff0c;内连接能够将相关联的数据组合在一起&#xff0c;以便进行更复杂和全面的数据分析。 内连接的工作原理如下&…

bash: 睡觉的冒号;是不是两个点?

文章目录 简介躺着的冒号是两个点正常冒号总结简介 在bash里冒号和躺着的冒号的用法不一样一定要注意别用错。 躺着的冒号是两个点 难道正常的不是两个点)的作用: A sequence expression takes the form {x…y[…incr]}, where x and y are either integers or single cha…

排序算法、

描述 由小到大输出成一行&#xff0c;每个数字后面跟一个空格。 输入 三个整数 输出 输入三个整数&#xff0c;按由小到大的顺序输出。 输入样例 1 2 3 1 输出样例 1 1 2 3 输入样例 2 4 5 2 输出样例 2 2 4 5 代码一&#xff08;如下&#xff09;&#xff1…

pytest+allure运行出现乱码的解决方法

pytestallure运行出现乱码的解决方法 报错截图&#xff1a; 这里的截图摘自 悟翠人生 小伙伴的https://blog.csdn.net/weixin_45435918/article/details/107601721一文。 这是因为没有安装allure运行环境或者没有配置allure的环境变量导致&#xff0c;解决方案&#xff1a; 1…

各系统的目录信息路径

Windows系统: 查看系统版本——C:\boot.ini IIS配置文件——C:\windows\system32\inetsrv\MetaBase.xml 存储Windows系统初次安装的密码——C:\windows\repair\sam Mysql配置——C:\ProgramFiles\mysql\my.ini MySQL root密码——C:\P…

LaTex使用技巧20:LaTex修改公式的编号和最后一行对齐

写论文发现公式编号的格式不对&#xff0c;要求是如果是多行的公式&#xff0c;公式编号和公式的最后一行对齐。 我原来使用的是{equation}环境。 \begin{equation} \begin{aligned} a&bc\\ &cd \end{aligned} \end{equation}公式的编号没有和最后一行对齐。 查了一…