tags:
- AI
- 开发/云原生/Kubernetes
- 开发/Python/Notebook
- 开发/Python/PyTorch
- 开发/Python/PyTorchLightning
- AI/PyTorch
- AI/训练
- AI/Tensorflow
- AI/TensorBoard
- 开发/Python/TensorBoard
- 开发/Python/HuggingFace
- AI/Trainer
- AI/HuggingFace
- AI/数据集/MNIST
用 PyTorch Lightning 监控和串流 PyTorch 的训练进度
这里以 PyTorch Lightning 为例子,跑 pytorch/examples 里面的 MNIST 这样的通用训练为例。
PyTorch Lightning 是一个非常有用的库,它提供了一种更简洁、更模块化的方式来构建和训练 PyTorch 模型。此外,它还提供了许多有用的工具和特性,用于监控和记录训练进度。
以下是如何使用 PyTorch Lightning 来监控和串流 PyTorch 的训练进度的一些建议:
使用 Trainer 类的 callbacks
PyTorch Lightning 的 Trainer 类有许多内置的回调函数(callbacks),可以在训练的不同阶段触发。例如,ModelCheckpoint 可以在模型性能改善时保存模型,EarlyStopping 可以在性能不再改善时停止训练。
你也可以创建自定义的回调函数,以在训练过程中执行特定的操作。例如,你可以创建一个回调来在每个 epoch 结束时记录并打印训练损失和验证损失。
from pytorch_lightning import Trainer, LightningModule, Callback class MyCustomCallback(Callback): def on_validation_end(self, trainer, pl_module): print(f'Validation loss: {trainer.callback_metrics["val_loss"]}') model = MyLightningModel()
trainer = Trainer(callbacks=[MyCustomCallback()])
trainer.fit(model)
使用 TensorBoard
PyTorch Lightning 提供了与 TensorBoard 的无缝集成,你可以使用它来可视化训练进度。首先,你需要在创建 Trainer 对象时启用 TensorBoard 日志:
trainer = Trainer(logger=TensorBoardLogger('tb_logs/'))然后,你可以在训练过程中使用 PyTorch Lightning 的 self.log 方法来记录指标。这些指标将在 TensorBoard 中显示。```python
class MyLightningModel(LightningModule): def training_step(self, batch, batch_idx): # your training step code here loss = ... self.log('train_loss', loss) return loss def validation_step(self, batch, batch_idx): # your validation step code here loss = ... self.log('val_loss', loss)
使用 WandB
Weights & Biases (WandB) 是一个强大的实验跟踪工具,它可以让你记录和比较不同的模型训练运行。PyTorch Lightning 支持 WandB,你可以在创建 Trainer 对象时启用它:
import wandb
from pytorch_lightning.loggers import WandbLogger wandb.init(project="my-project")
logger = WandbLogger(name="my-model", project="my-project") trainer = Trainer(logger=logger)
然后,你可以在模型中使用 self.log 方法来记录指标,这些指标将在 WandB 的 UI 中显示。
自定义进度条
PyTorch Lightning 的 Trainer 还提供了一个自定义进度条的功能,你可以通过 progress_bar_refresh_rate 参数来设置进度条的更新频率。如果你需要更详细的进度信息,你可以考虑使用 tqdm 或其他进度条库来手动实现。
请注意,这些只是监控和串流 PyTorch 训练进度的一些基本方法。根据你的具体需求,你可能需要使用更复杂的工具或策略。
环境准备
克隆 [pytorch/examples](https://github.com
/pytorch/examples)
git clone https://github.com/pytorch/examples
构建环境
conda create -n demo-1 python=3.10
安装依赖
pip install lightning torch torchvision
::: details 当前的 Python 依赖
absl-py==2.0.0
aiohttp==3.9.1
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
filelock==3.13.1
frozenlist==1.4.1
fsspec==2023.12.2
google-auth==2.26.1
google-auth-oauthlib==1.2.0
grpcio==1.60.0
idna==3.6
Jinja2==3.1.2
lightning==2.1.3
lightning-utilities==0.10.0
Markdown==3.5.1
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
networkx==3.2.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
packaging==23.2
pillow==10.2.0
protobuf==4.23.4
pyasn1==0.5.1
pyasn1-modules==0.3.0
pytorch-lightning==2.1.3
PyYAML==6.0.1
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
six==1.16.0
sympy==1.12
tensorboard==2.15.1
tensorboard-data-server==0.7.2
torch==2.1.2
torchmetrics==1.2.1
torchvision==0.16.2
tqdm==4.66.1
triton==2.1.0
typing_extensions==4.9.0
urllib3==2.1.0
Werkzeug==3.0.1
yarl==1.9.4
:::
修改代码
把 examples/mnist/main.py
里面的代码修改成使用 PyTorch Lightning 的 Trainer 的代码来跑。
::: code-group
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import pytorch_lightning as pl # [!code ++]
from torch.utils.data import DataLoader # [!code ++]# [!code --]
class Net(nn.Module): # [!code --]def __init__(self): # [!code --]super(Net, self).__init__() # [!code --]
class LitNet(pl.LightningModule): # [!code ++]def __init__(self, lr): # [!code ++]super(LitNet, self).__init__() # [!code ++]self.save_hyperparameters() # [!code ++]self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout(0.25)output = F.log_softmax(x, dim=1)return output# [!code --]
def train(args, model, device, train_loader, optimizer, epoch):# [!code --]model.train()# [!code --]for batch_idx, (data, target) in enumerate(train_loader):# [!code --]data, target = data.to(device), target.to(device)# [!code --]optimizer.zero_grad()# [!code --]output = model(data)# [!code --]def training_step(self, batch, batch_idx): # [!code ++]data, target = batch # [!code ++]output = self(data) # [!code ++]loss = F.nll_loss(output, target)loss.backward() # [!code --]optimizer