用 PyTorch Lightning 监控和串流 PyTorch 的训练进度 TensorBoard MNISTDataModule 训练 查看训练进度


tags:

  • AI
  • 开发/云原生/Kubernetes
  • 开发/Python/Notebook
  • 开发/Python/PyTorch
  • 开发/Python/PyTorchLightning
  • AI/PyTorch
  • AI/训练
  • AI/Tensorflow
  • AI/TensorBoard
  • 开发/Python/TensorBoard
  • 开发/Python/HuggingFace
  • AI/Trainer
  • AI/HuggingFace
  • AI/数据集/MNIST

用 PyTorch Lightning 监控和串流 PyTorch 的训练进度

这里以 PyTorch Lightning 为例子,跑 pytorch/examples 里面的 MNIST 这样的通用训练为例。

PyTorch Lightning 是一个非常有用的库,它提供了一种更简洁、更模块化的方式来构建和训练 PyTorch 模型。此外,它还提供了许多有用的工具和特性,用于监控和记录训练进度。

以下是如何使用 PyTorch Lightning 来监控和串流 PyTorch 的训练进度的一些建议:

使用 Trainer 类的 callbacks
PyTorch Lightning 的 Trainer 类有许多内置的回调函数(callbacks),可以在训练的不同阶段触发。例如,ModelCheckpoint 可以在模型性能改善时保存模型,EarlyStopping 可以在性能不再改善时停止训练。

你也可以创建自定义的回调函数,以在训练过程中执行特定的操作。例如,你可以创建一个回调来在每个 epoch 结束时记录并打印训练损失和验证损失。

from pytorch_lightning import Trainer, LightningModule, Callback  class MyCustomCallback(Callback):  def on_validation_end(self, trainer, pl_module):  print(f'Validation loss: {trainer.callback_metrics["val_loss"]}')  model = MyLightningModel()  
trainer = Trainer(callbacks=[MyCustomCallback()])  
trainer.fit(model)

使用 TensorBoard
PyTorch Lightning 提供了与 TensorBoard 的无缝集成,你可以使用它来可视化训练进度。首先,你需要在创建 Trainer 对象时启用 TensorBoard 日志:


trainer = Trainer(logger=TensorBoardLogger('tb_logs/'))然后,你可以在训练过程中使用 PyTorch Lightning 的 self.log 方法来记录指标。这些指标将在 TensorBoard 中显示。```python
class MyLightningModel(LightningModule):  def training_step(self, batch, batch_idx):  # your training step code here  loss = ...  self.log('train_loss', loss)  return loss  def validation_step(self, batch, batch_idx):  # your validation step code here  loss = ...  self.log('val_loss', loss)

使用 WandB
Weights & Biases (WandB) 是一个强大的实验跟踪工具,它可以让你记录和比较不同的模型训练运行。PyTorch Lightning 支持 WandB,你可以在创建 Trainer 对象时启用它:

import wandb  
from pytorch_lightning.loggers import WandbLogger  wandb.init(project="my-project")  
logger = WandbLogger(name="my-model", project="my-project")  trainer = Trainer(logger=logger)

然后,你可以在模型中使用 self.log 方法来记录指标,这些指标将在 WandB 的 UI 中显示。

自定义进度条
PyTorch Lightning 的 Trainer 还提供了一个自定义进度条的功能,你可以通过 progress_bar_refresh_rate 参数来设置进度条的更新频率。如果你需要更详细的进度信息,你可以考虑使用 tqdm 或其他进度条库来手动实现。

请注意,这些只是监控和串流 PyTorch 训练进度的一些基本方法。根据你的具体需求,你可能需要使用更复杂的工具或策略。

环境准备

克隆 [pytorch/examples](https://github.com

/pytorch/examples)

git clone https://github.com/pytorch/examples

构建环境

conda create -n demo-1 python=3.10

安装依赖

pip install lightning torch torchvision

::: details 当前的 Python 依赖

absl-py==2.0.0
aiohttp==3.9.1
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
filelock==3.13.1
frozenlist==1.4.1
fsspec==2023.12.2
google-auth==2.26.1
google-auth-oauthlib==1.2.0
grpcio==1.60.0
idna==3.6
Jinja2==3.1.2
lightning==2.1.3
lightning-utilities==0.10.0
Markdown==3.5.1
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
packaging==23.2
pillow==10.2.0
protobuf==4.23.4
pyasn1==0.5.1
pyasn1-modules==0.3.0
pytorch-lightning==2.1.3
PyYAML==6.0.1
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
six==1.16.0
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
torch==2.1.2
torchmetrics==1.2.1
torchvision==0.16.2
tqdm==4.66.1
triton==2.1.0
typing_extensions==4.9.0
urllib3==2.1.0
Werkzeug==3.0.1
yarl==1.9.4

:::

修改代码

examples/mnist/main.py 里面的代码修改成使用 PyTorch Lightning 的 Trainer 的代码来跑。

::: code-group

import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import pytorch_lightning as pl # [!code ++]
from torch.utils.data import DataLoader # [!code ++]# [!code --]
class Net(nn.Module): # [!code --]def __init__(self): # [!code --]super(Net, self).__init__() # [!code --]
class LitNet(pl.LightningModule): # [!code ++]def __init__(self, lr): # [!code ++]super(LitNet, self).__init__() # [!code ++]self.save_hyperparameters() # [!code ++]self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout(0.25)output = F.log_softmax(x, dim=1)return output# [!code --]
def train(args, model, device, train_loader, optimizer, epoch):# [!code --]model.train()# [!code --]for batch_idx, (data, target) in enumerate(train_loader):# [!code --]data, target = data.to(device), target.to(device)# [!code --]optimizer.zero_grad()# [!code --]output = model(data)# [!code --]def training_step(self, batch, batch_idx): # [!code ++]data, target = batch # [!code ++]output = self(data) # [!code ++]loss = F.nll_loss(output, target)loss.backward() # [!code --]optimizer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/822521.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

灯光4-利用光照探头模拟局部实时光影效果

在Unity中,可以使用光照探头来模拟局部实时光影效果。光照探头是一种用于捕捉场景中光照信息的特殊组件。通过将光照探头放置在场景中的某个位置,它会记录下该位置的光照信息,并将其应用于周围的物体上,从而实现局部实时光影效果。…

1.5MHz,1.2A COT 架构同步降压变换器只要0.16元,型号:LN3435

推荐原因 1.5MHZ的开关频率,可以使用小电感,1.2A满足多数应用,价格感人,只要0.16元 产品概述 LN3435是一款电流模COT架构同步降压开关稳压器。 输入范围为 2.7V-6.0V,可提供 1.2A 的连续输出电流。 内部集成了低内阻…

学习Rust的第4天:常见编程概念

基于Steve Klabnik的《The Rust Programming Language》一书。昨天我们做了一个猜谜游戏 ,今天我们将探讨常见的编程概念,例如: Variables 变量Constants 常数Shadowing 阴影Data Types 数据类型Functions 功能 Variables 变量 In layman ter…

C语言入门第四天(数组)

一、C语言数组的基本语法 1.数组的定义 数组是 C 语言中的一种数据结构,用于存储一组具有相同数据类型的数据。数组中的每个元素可以通过一个索引(下标)来访问,索引从 0 开始,最大值为数组长度减 1。 2.定义语法格式 …

【鸿蒙开发】动画

1. 属性动画 animation放在其他属性的后面才有过渡效果 组件的某些通用属性变化时,可以通过属性动画实现渐变过渡效果,提升用户体验。支持的属性包括width、height、backgroundColor、opacity、scale、rotate、translate等。 接口: animati…

4个步骤:如何使用 SwiftSoup 和爬虫代理获取网站视频

摘要/导言 在本文中,我们将探讨如何使用 SwiftSoup 库和爬虫代理技术来获取网站上的视频资源。我们将介绍一种简洁、可靠的方法,以及实现这一目标所需的步骤。 背景/引言 随着互联网的迅速发展,爬虫技术在今天的数字世界中扮演着越来越重要…

Python也可以合并和拆分PDF,批量高效!

PDF是最方便的文档格式,可以在任何设备原样且无损的打开,但因为PDF不可编辑,所以很难去拆分合并。 知乎上也有人问,如何对PDF进行合并和拆分? 看很多回答推荐了各种PDF编辑器或者网站,确实方法比较多。 …

支持向量机模型pytorch

通过5个条件判定一件事情是否会发生,5个条件对这件事情是否发生的影响力不同,计算每个条件对这件事情发生的影响力多大,写一个支持向量机模型pytorch程序,最后打印5个条件分别的影响力。 示例一 支持向量机(SVM)是一种…

【原创】springboot+mysql理发会员管理系统设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

c++中虚函数、纯虚函数以及虚函数的实现原理

c中虚函数、纯虚函数以及虚函数的实现原理 什么是虚函数和纯虚函数 虚函数(Virtual Functions)和纯虚函数(Pure Virtual Functions)是 C 中用于实现多态性的重要概念。 虚函数(Virtual Functions) 虚函…

算法课程笔记——常用库函数

memset初始化 设置成0是可以每个设置为0 而1时会特别大 -1的补码是11111111 要先排序 unique得到的是地址 地址减去得到下标 结果会放到后面 如果这样非相邻 会出错 要先用sort排序 O(n)被O(nlogn)覆盖

服务器数据恢复—xfs文件系统节点、目录项丢失的数据恢复案例

服务器数据恢复环境: EMC某型号存储,该存储内有一组由12块磁盘组建的raid5阵列,划分了两个lun。 服务器故障: 管理员为服务器重装操作系统后,发现服务器的磁盘分区发生改变,原来的sdc3分区丢失。由于该分区…

photoshop基础学习笔记

学习 Photoshop 的基础知识是掌握图像处理和设计的关键。以下是一份基础学习笔记,帮助你开始学习 Photoshop: 1. Photoshop 界面导览 工具栏(Tool Bar):包含了各种工具,如选择工具、画笔工具、橡皮擦工具…

Linux命令学习—DHCP 服务器

1.1、DHCP 服务器 ①、DHCP(dynamic host configure protocol)动态主机配置协议 最大的功能就是向客户端提供 TCP/IP 信息,使用的是 UDP:67 端口 ②、手动设定适合:适用小型网络 ③、手动输入 IP 地址和自动获取比较优缺点 ④…

攻防演练,作为红方的步骤应该是那些

在执行合法的攻防演练中,对目标服务器如 http://XXXXX/ 进行漏洞扫描和评估需要遵循严格的步骤来确保所有活动都是安全、合法且有效的。以下是一些基本步骤和技术指南,以及使用 nmap 进行初始扫描的示例。 1. 获取授权 确保你有明确的书面授权来进行漏…

问,由于java存在性能上,以及部分功能上的缺点,请问如何正确使用C,C++,Go,这三个语言,提升Java Web项目的性能?

拓展阅读:版本任你发,我用java8 我明白Java虽然在许多方面表现出色,但在某些特定场景下可能会遇到性能瓶颈或功能限制。为了提升Java Web项目的性能,可以考虑将C、C和Go这三种语言用于特定的组件或服务。以下是如何正确使用这些语…

葡萄书--深度学习基础

卷积神经网络 卷积神经网络具有的特性: 平移不变性(translation invariance):不管检测对象出现在图像中的哪个位置,神经网络的前面几层应该对相同的图像区域具有相似的反应,即为“平移不变性”。图像的平移…

设置Linux命令行tab补全不区分大小写

root权限编辑文件 sudo vim /etc/inputrc加入新配置 [按下i键开始输入] 文件末尾加入新配置 set completion-ignore-case on保存 [按下esc键,再输入:wq确定保存] 重启 reboot

web自动化系列-selenium 的鼠标操作(十)

对于鼠标操作 ,我们可以通过click()方法进行点击操作 ,但是有些特殊场景下的操作 ,click()是无法完成的 ,比如 :我想进行鼠标悬停 、想进行鼠标拖拽 ,怎么办 ? 这个时候你用click()是无法完成的…

渲染技术如何改变影视制作的面貌

随着科技的飞速发展,影视制作领域也迎来了翻天覆地的变化。其中,渲染技术的不断革新,更是对影视制作产生了深远的影响。渲染作为影视制作中的关键环节,渲染技术的提升,不仅提升了画面的质量,还为创作者提供…