使用 SwanLab 进行可视化 MNIST 手写体识别训练

使用 SwanLab 进行可视化 MNIST 手写体识别训练

在线演示demo

本案例主要:

  • 使用pytorch进行CNN(卷积神经网络)的构建、模型训练与评估
  • 使用swanlab跟踪超参数、记录指标和可视化监控整个训练周期

一、相关简介

SwanLab

SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能,让您可以快速跟踪ML实验、可视化过程、分享给同伴。

SwanLab提供了一套云端AI实验跟踪方案,面向训练过程,提供了训练可视化、实验跟踪、超参数记录、日志记录、多人协同等功能,研究者能轻松通过直观的可视化图表找到迭代灵感,并且通过在线链接的分享与基于组织的多人协同训练,打破团队沟通的壁垒。

可视化界面截图:

在这里插入图片描述

MNIST

MNIST手写体识别是深度学习最经典的入门任务之一,由 LeCun 等人提出。
该任务基于MNIST数据集,研究者通过构建机器学习模型,来识别10个手写数字(0~9)。

二、环境配置

本案例基于Python>=3.8,请在您的计算机上安装好Python。
环境依赖:

torch
torchvision
swanlab

快速安装命令:

pip install torch torchvision swanlab

MNIST 数据集已经被 torch 自动集成了,所以不需要额外下载,很方便。

三、训练代码

复制以下代码,创建 app.py 并粘贴代码,保存后直接使用 python 或 IDE 运行:python app.py

import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.models import ResNet18_Weights
import swanlab# CNN网络构建
class ConvNet(nn.Module):def __init__(self):super().__init__()# 1,28x28self.conv1 = nn.Conv2d(1, 10, 5)  # 10, 24x24self.conv2 = nn.Conv2d(10, 20, 3)  # 128, 10x10self.fc1 = nn.Linear(20 * 10 * 10, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):in_size = x.size(0)out = self.conv1(x)  # 24out = F.relu(out)out = F.max_pool2d(out, 2, 2)  # 12out = self.conv2(out)  # 10out = F.relu(out)out = out.view(in_size, -1)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)out = F.log_softmax(out, dim=1)return out# 捕获并可视化前20张图像
def log_images(loader, num_images=16):images_logged = 0logged_images = []for images, labels in loader:# images: batch of images, labels: batch of labelsfor i in range(images.shape[0]):if images_logged < num_images:# 使用swanlab.Image将图像转换为wandb可视化格式logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))images_logged += 1else:breakif images_logged >= num_images:breakswanlab.log({"MNIST-Preview": logged_images})def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):model.train()# 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签for iter, (inputs, labels) in enumerate(train_dataloader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 2. 传入到resnet18模型中得到预测结果outputs = model(inputs)# 3. 将结果和标签传入损失函数中计算交叉熵损失loss = criterion(outputs, labels)# 4. 根据损失计算反向传播loss.backward()# 5. 优化器执行模型参数更新optimizer.step()print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),loss.item()))# 6. 每20次迭代,用SwanLab记录一下loss的变化if iter % 20 == 0:swanlab.log({"train/loss": loss.item()})def test(model, device, val_dataloader, epoch):model.eval()correct = 0total = 0with torch.no_grad():# 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签for inputs, labels in val_dataloader:inputs, labels = inputs.to(device), labels.to(device)# 2. 传入到resnet18模型中得到预测结果outputs = model(inputs)# 3. 获得预测的数字_, predicted = torch.max(outputs, 1)total += labels.size(0)# 4. 计算与标签一致的预测结果的数量correct += (predicted == labels).sum().item()# 5. 得到最终的测试准确率accuracy = correct / total# 6. 用SwanLab记录一下准确率的变化swanlab.log({"val/accuracy": accuracy}, step=epoch)if __name__ == "__main__":#检测是否支持mpstry:use_mps = torch.backends.mps.is_available()except AttributeError:use_mps = False#检测是否支持cudaif torch.cuda.is_available():device = "cuda"elif use_mps:device = "mps"else:device = "cpu"# 初始化swanlabrun = swanlab.init(project="MNIST-example",experiment_name="PlainCNN",config={"model": "ResNet18","optim": "Adam","lr": 1e-4,"batch_size": 256,"num_epochs": 10,"device": device,},)# 设置MNIST训练集和验证集dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)# (可选)看一下数据集的前16张图像log_images(train_dataloader, 16)# 初始化模型model = ConvNet()model.to(torch.device(device))# 打印模型print(model)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=run.config.lr)# 开始训练和测试循环for epoch in range(1, run.config.num_epochs+1):swanlab.log({"train/epoch": epoch}, step=epoch)train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)if epoch % 2 == 0: test(model, device, val_dataloader, epoch)# 保存模型# 如果不存在checkpoint文件夹,则自动创建一个if not os.path.exists("checkpoint"):os.makedirs("checkpoint")torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')

四、注意事项

在这里插入图片描述
在运行代码的时候,可能会出现如上提示,需要输入一个凭证,这个时候我们只需要去 SwanLab 云端版登录并获取,复制后粘贴到终端,回车后继续运行即可:

在这里插入图片描述

当然,有云端版肯定也有本地版。

上面的训练会将训练数据上传到云端,让我们可以直接通过在线链接的方式访问自己的实验数据和实验进度 。但是还可以选择不上传,而通过本地命令在本机开启一个面板服务,其前端界面与云端版基本一致,同样能查看实验数据和详细信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/16999.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gRPC - Protocol Buffer 编译器安装

文章目录 Protocol Buffer 编译器安装如何安装 Protocol Buffer 编译器使用包管理器安装Linux 上&#xff0c;使用 apt 或 apt-get&#xff0c;例如&#xff1a;macOS 上&#xff0c;使用 Homebrew&#xff1a; 安装预编译的二进制文件&#xff08;任何操作系统&#xff09;其他…

凭借这点高效备考HCIA,一般人我不告诉他

在这个信息爆炸的时代&#xff0c;备考HCIA认证时&#xff0c;利用好在线资源可以大大提升学习效率。 以下是一些只有内行人才知道的高效备考策略&#xff0c;通过这些方法&#xff0c;你可以事半功倍&#xff0c;轻松拿下HCIA认证。 01 找到最适合你的学习方式 每个人的学习方…

Linux中ftp配置

一、ftp协议 1、端口 ftp默认使用20、21端口 20端口用于建立数据连接 21端口用于建立控制连接 2、ftp数据连接模式 主动模式&#xff1a;服务器主动发起数据连接 被动模式&#xff1a;服务器被动等待数据连接 二、ftp安装 yum install -y vsftpd #---下…

使用httpx异步获取高校招生信息:一步到位的代理配置教程

概述 随着2024年中国高考的临近&#xff0c;考生和家长对高校招生信息的需求日益增加。了解各高校的专业、课程设置和录取标准对于高考志愿填报至关重要。通过爬虫技术&#xff0c;可以高效地从各高校官网获取这些关键信息。然而&#xff0c;面对大量的请求和反爬机制的挑战&a…

蓝桥杯物联网竞赛_STM32L071KBU6_字符串处理

前言&#xff1a; 个人感觉国赛相较于省赛难度上升的点在于对于接收的字符串的处理&#xff0c;例如串口发送的字符串一般包含字母字符串 数字字符串&#xff0c;亦或者更复杂&#xff0c;对于LORA也是如此&#xff0c;传递的字符串如#9#1亦或者#1a#90,#1#12&#xff0c;如何…

4.每日LeetCode-数组类,斐波那契数(Go,Java,Python)

题目 题号&#xff1a;509斐波那契数 &#xff08;通常用 F(n) 表示&#xff09;形成的序列称为 斐波那契数列 。该数列由 0 和 1 开始&#xff0c;后面的每一项数字都是前面两项数字的和。也就是&#xff1a; F(0) 0&#xff0c;F(1) 1 F(n) F(n - 1) F(n - 2)&#xff0…

剖析【C++】——类与对象(上)超详解——小白篇

目录 1.面向过程和面向对象的初步认识 1.面向过程&#xff08;Procedural Programming&#xff09; 2.面向对象&#xff08;Object-Oriented Programming&#xff09; 概念&#xff1a; 特点&#xff1a; 总结 2.C 类的引入 1.从 C 语言的结构体到 C 的类 2.C 中的结构…

调用萨姆索诺夫函数:深入探索函数的参数与返回值

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、萨姆索诺夫函数的引入与调用 二、如何获取函数的返回值 三、无参数与无返回值的函数调…

帝国CMS验证码不显示怎么回事呢?

帝国CMS验证码有时候会不显示或打叉&#xff0c;总结自己的解决方法。 1、检查服务器是否开启GD库 测试GD库是否开启的方法&#xff1a;浏览器访问&#xff1a;/e/showkey/index.php&#xff0c;如果出现一堆乱码或报错&#xff0c;证明GD库没有开启&#xff0c;开启即可。 2…

聊聊2024上半年软考高项

今年和去年机考区别&#xff1a; 去年高项分了三个批次&#xff0c;其中有一部分人不是在周末考的试&#xff1b;今年分了两个批次&#xff0c;时间是25、26号&#xff1b;仍然是联考的形式。去年是综合知识在上午&#xff0c;案例分析和论文在下午&#xff1b;今年是综合知识…

[随笔] 在CSDN的6周年纪念日随笔

纪念 转眼已过6年&#xff0c;大一的时候学习编程&#xff0c;潜水 CSDN 学习各类博文&#xff0c;才学浅薄就没有主动写博文记录自己的学习历程。 过了段时间刚刚到了大二&#xff0c;很喜欢 Todolist&#xff0c;意气风发的写下《一份清爽的编程计划》&#xff0c;哈哈。 …

一文读懂npm i的命令以及作用

目录 1. 基本知识2. 常见用法 1. 基本知识 npm i 是 Node Package Manager (npm) 的一个命令&#xff0c;用于安装 Node.js 项目依赖的包 是 npm install 的简写形式&#xff0c;功能完全相同 详细解析 npm&#xff1a; npm 是 Node.js 的包管理工具&#xff0c;用于安装、共…

数据结构-队列(带图详解)

目录 队列的概念 画图理解队列 代码图理解 代码展示(注意这个队列是单链表的结构实现) Queue.h(队列结构) Queue.c(函数/API实现) main.c(测试文件) 队列的概念 队列&#xff08;Queue&#xff09;是一种基础的数据结构&#xff0c;它遵循先进先出&#xff08;First In …

二十八、openlayers官网示例Data Tiles解析——自定义绘制DataTile源数据

官网demo地址&#xff1a; https://openlayers.org/en/latest/examples/data-tiles.html 这篇示例讲解的是自定义加载DataTile源格式的数据。 先来看一下什么是DataTile&#xff0c;这个源是一个数组&#xff0c;与我们之前XYZ切片源有所不同。DataTile主要适用于需要动态生成…

经典面试题:MySQL如何调优?

目录 前言1. SQL查询优化2. 索引优化3. 表结构设计4. 硬件与配置优化5. 日常维护6. 性能测试与基准测试 前言 MySQL如何进行调优&#xff1f;这是面试中容易被问到的高频问题。 1. SQL查询优化 避免使用select* &#xff1a;只选取需要的列&#xff0c;减少数据传输量。使用…

Host头攻击-使用安全的Web服务器配置

Nginx配置示例 在Nginx中&#xff0c;你可以通过修改配置文件来验证HTTP Host头&#xff0c;确保它符合预期的值。以下是一个简单的配置示例&#xff1a; 1.添加HTTP Host头验证规则&#xff1a; 在Nginx的配置文件中&#xff0c;找到针对目标URL的相关配置块&#xff0c;并…

算法简单笔记2

5月26号&#xff0c;之前学了两天算法烦了&#xff0c;去学了几天鸿蒙&#xff0c;今天又回来看一下算法&#xff0c;距离6月1日国赛还有6天&#xff0c;哈哈真是等死咯...... 一、蓝桥杯第13届国赛第1题填空题&#xff1a;重合次数 &#xff08;半难不难&#xff0c;写编程难…

通过JavaScript本地存储数据

文章目录 本地存储本地存储分类 - localStorage本地存储分类 - sessionStorage存储复杂数据类型解决方法 本地存储 数据存储在用户浏览器中设置、读取方便、甚至页面刷新都不丢失数据容量较大&#xff0c;sessionStorage和localStorage约5M左右 本地存储分类 - localStorage …

探索演进:了解IPv4和IPv6之间的区别

探索演进&#xff1a;了解IPv4和IPv6之间的区别 在广阔的互联网领域中&#xff0c;设备之间的通信依赖于一组独特的协议来促进连接。前景协议中&#xff0c;IPv4&#xff08;Internet 协议版本 4&#xff09;和 IPv6&#xff08;Internet 协议版本 6&#xff09;是数字基础设施…