使用 SwanLab 进行可视化 MNIST 手写体识别训练
在线演示demo
本案例主要:
- 使用pytorch进行CNN(卷积神经网络)的构建、模型训练与评估
- 使用swanlab跟踪超参数、记录指标和可视化监控整个训练周期
一、相关简介
SwanLab
SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能,让您可以快速跟踪ML实验、可视化过程、分享给同伴。
SwanLab提供了一套云端AI实验跟踪方案,面向训练过程,提供了训练可视化、实验跟踪、超参数记录、日志记录、多人协同等功能,研究者能轻松通过直观的可视化图表找到迭代灵感,并且通过在线链接的分享与基于组织的多人协同训练,打破团队沟通的壁垒。
可视化界面截图:
MNIST
MNIST手写体识别是深度学习最经典的入门任务之一,由 LeCun 等人提出。
该任务基于MNIST数据集,研究者通过构建机器学习模型,来识别10个手写数字(0~9)。
二、环境配置
本案例基于Python>=3.8,请在您的计算机上安装好Python。
环境依赖:
torch
torchvision
swanlab
快速安装命令:
pip install torch torchvision swanlab
MNIST 数据集已经被 torch 自动集成了,所以不需要额外下载,很方便。
三、训练代码
复制以下代码,创建 app.py
并粘贴代码,保存后直接使用 python 或 IDE 运行:python app.py
import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.models import ResNet18_Weights
import swanlab# CNN网络构建
class ConvNet(nn.Module):def __init__(self):super().__init__()# 1,28x28self.conv1 = nn.Conv2d(1, 10, 5) # 10, 24x24self.conv2 = nn.Conv2d(10, 20, 3) # 128, 10x10self.fc1 = nn.Linear(20 * 10 * 10, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):in_size = x.size(0)out = self.conv1(x) # 24out = F.relu(out)out = F.max_pool2d(out, 2, 2) # 12out = self.conv2(out) # 10out = F.relu(out)out = out.view(in_size, -1)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)out = F.log_softmax(out, dim=1)return out# 捕获并可视化前20张图像
def log_images(loader, num_images=16):images_logged = 0logged_images = []for images, labels in loader:# images: batch of images, labels: batch of labelsfor i in range(images.shape[0]):if images_logged < num_images:# 使用swanlab.Image将图像转换为wandb可视化格式logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))images_logged += 1else:breakif images_logged >= num_images:breakswanlab.log({"MNIST-Preview": logged_images})def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):model.train()# 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签for iter, (inputs, labels) in enumerate(train_dataloader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 2. 传入到resnet18模型中得到预测结果outputs = model(inputs)# 3. 将结果和标签传入损失函数中计算交叉熵损失loss = criterion(outputs, labels)# 4. 根据损失计算反向传播loss.backward()# 5. 优化器执行模型参数更新optimizer.step()print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),loss.item()))# 6. 每20次迭代,用SwanLab记录一下loss的变化if iter % 20 == 0:swanlab.log({"train/loss": loss.item()})def test(model, device, val_dataloader, epoch):model.eval()correct = 0total = 0with torch.no_grad():# 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签for inputs, labels in val_dataloader:inputs, labels = inputs.to(device), labels.to(device)# 2. 传入到resnet18模型中得到预测结果outputs = model(inputs)# 3. 获得预测的数字_, predicted = torch.max(outputs, 1)total += labels.size(0)# 4. 计算与标签一致的预测结果的数量correct += (predicted == labels).sum().item()# 5. 得到最终的测试准确率accuracy = correct / total# 6. 用SwanLab记录一下准确率的变化swanlab.log({"val/accuracy": accuracy}, step=epoch)if __name__ == "__main__":#检测是否支持mpstry:use_mps = torch.backends.mps.is_available()except AttributeError:use_mps = False#检测是否支持cudaif torch.cuda.is_available():device = "cuda"elif use_mps:device = "mps"else:device = "cpu"# 初始化swanlabrun = swanlab.init(project="MNIST-example",experiment_name="PlainCNN",config={"model": "ResNet18","optim": "Adam","lr": 1e-4,"batch_size": 256,"num_epochs": 10,"device": device,},)# 设置MNIST训练集和验证集dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)# (可选)看一下数据集的前16张图像log_images(train_dataloader, 16)# 初始化模型model = ConvNet()model.to(torch.device(device))# 打印模型print(model)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=run.config.lr)# 开始训练和测试循环for epoch in range(1, run.config.num_epochs+1):swanlab.log({"train/epoch": epoch}, step=epoch)train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)if epoch % 2 == 0: test(model, device, val_dataloader, epoch)# 保存模型# 如果不存在checkpoint文件夹,则自动创建一个if not os.path.exists("checkpoint"):os.makedirs("checkpoint")torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
四、注意事项
在运行代码的时候,可能会出现如上提示,需要输入一个凭证,这个时候我们只需要去 SwanLab 云端版登录并获取,复制后粘贴到终端,回车后继续运行即可:
当然,有云端版肯定也有本地版。
上面的训练会将训练数据上传到云端,让我们可以直接通过在线链接的方式访问自己的实验数据和实验进度 。但是还可以选择不上传,而通过本地命令在本机开启一个面板服务,其前端界面与云端版基本一致,同样能查看实验数据和详细信息。