使用 SwanLab 进行可视化 MNIST 手写体识别训练

使用 SwanLab 进行可视化 MNIST 手写体识别训练

在线演示demo

本案例主要:

  • 使用pytorch进行CNN(卷积神经网络)的构建、模型训练与评估
  • 使用swanlab跟踪超参数、记录指标和可视化监控整个训练周期

一、相关简介

SwanLab

SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能,让您可以快速跟踪ML实验、可视化过程、分享给同伴。

SwanLab提供了一套云端AI实验跟踪方案,面向训练过程,提供了训练可视化、实验跟踪、超参数记录、日志记录、多人协同等功能,研究者能轻松通过直观的可视化图表找到迭代灵感,并且通过在线链接的分享与基于组织的多人协同训练,打破团队沟通的壁垒。

可视化界面截图:

在这里插入图片描述

MNIST

MNIST手写体识别是深度学习最经典的入门任务之一,由 LeCun 等人提出。
该任务基于MNIST数据集,研究者通过构建机器学习模型,来识别10个手写数字(0~9)。

二、环境配置

本案例基于Python>=3.8,请在您的计算机上安装好Python。
环境依赖:

torch
torchvision
swanlab

快速安装命令:

pip install torch torchvision swanlab

MNIST 数据集已经被 torch 自动集成了,所以不需要额外下载,很方便。

三、训练代码

复制以下代码,创建 app.py 并粘贴代码,保存后直接使用 python 或 IDE 运行:python app.py

import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.models import ResNet18_Weights
import swanlab# CNN网络构建
class ConvNet(nn.Module):def __init__(self):super().__init__()# 1,28x28self.conv1 = nn.Conv2d(1, 10, 5)  # 10, 24x24self.conv2 = nn.Conv2d(10, 20, 3)  # 128, 10x10self.fc1 = nn.Linear(20 * 10 * 10, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):in_size = x.size(0)out = self.conv1(x)  # 24out = F.relu(out)out = F.max_pool2d(out, 2, 2)  # 12out = self.conv2(out)  # 10out = F.relu(out)out = out.view(in_size, -1)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)out = F.log_softmax(out, dim=1)return out# 捕获并可视化前20张图像
def log_images(loader, num_images=16):images_logged = 0logged_images = []for images, labels in loader:# images: batch of images, labels: batch of labelsfor i in range(images.shape[0]):if images_logged < num_images:# 使用swanlab.Image将图像转换为wandb可视化格式logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))images_logged += 1else:breakif images_logged >= num_images:breakswanlab.log({"MNIST-Preview": logged_images})def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):model.train()# 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签for iter, (inputs, labels) in enumerate(train_dataloader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 2. 传入到resnet18模型中得到预测结果outputs = model(inputs)# 3. 将结果和标签传入损失函数中计算交叉熵损失loss = criterion(outputs, labels)# 4. 根据损失计算反向传播loss.backward()# 5. 优化器执行模型参数更新optimizer.step()print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),loss.item()))# 6. 每20次迭代,用SwanLab记录一下loss的变化if iter % 20 == 0:swanlab.log({"train/loss": loss.item()})def test(model, device, val_dataloader, epoch):model.eval()correct = 0total = 0with torch.no_grad():# 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签for inputs, labels in val_dataloader:inputs, labels = inputs.to(device), labels.to(device)# 2. 传入到resnet18模型中得到预测结果outputs = model(inputs)# 3. 获得预测的数字_, predicted = torch.max(outputs, 1)total += labels.size(0)# 4. 计算与标签一致的预测结果的数量correct += (predicted == labels).sum().item()# 5. 得到最终的测试准确率accuracy = correct / total# 6. 用SwanLab记录一下准确率的变化swanlab.log({"val/accuracy": accuracy}, step=epoch)if __name__ == "__main__":#检测是否支持mpstry:use_mps = torch.backends.mps.is_available()except AttributeError:use_mps = False#检测是否支持cudaif torch.cuda.is_available():device = "cuda"elif use_mps:device = "mps"else:device = "cpu"# 初始化swanlabrun = swanlab.init(project="MNIST-example",experiment_name="PlainCNN",config={"model": "ResNet18","optim": "Adam","lr": 1e-4,"batch_size": 256,"num_epochs": 10,"device": device,},)# 设置MNIST训练集和验证集dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)# (可选)看一下数据集的前16张图像log_images(train_dataloader, 16)# 初始化模型model = ConvNet()model.to(torch.device(device))# 打印模型print(model)# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=run.config.lr)# 开始训练和测试循环for epoch in range(1, run.config.num_epochs+1):swanlab.log({"train/epoch": epoch}, step=epoch)train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)if epoch % 2 == 0: test(model, device, val_dataloader, epoch)# 保存模型# 如果不存在checkpoint文件夹,则自动创建一个if not os.path.exists("checkpoint"):os.makedirs("checkpoint")torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')

四、注意事项

在这里插入图片描述
在运行代码的时候,可能会出现如上提示,需要输入一个凭证,这个时候我们只需要去 SwanLab 云端版登录并获取,复制后粘贴到终端,回车后继续运行即可:

在这里插入图片描述

当然,有云端版肯定也有本地版。

上面的训练会将训练数据上传到云端,让我们可以直接通过在线链接的方式访问自己的实验数据和实验进度 。但是还可以选择不上传,而通过本地命令在本机开启一个面板服务,其前端界面与云端版基本一致,同样能查看实验数据和详细信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/16999.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux中ftp配置

一、ftp协议 1、端口 ftp默认使用20、21端口 20端口用于建立数据连接 21端口用于建立控制连接 2、ftp数据连接模式 主动模式&#xff1a;服务器主动发起数据连接 被动模式&#xff1a;服务器被动等待数据连接 二、ftp安装 yum install -y vsftpd #---下…

使用httpx异步获取高校招生信息:一步到位的代理配置教程

概述 随着2024年中国高考的临近&#xff0c;考生和家长对高校招生信息的需求日益增加。了解各高校的专业、课程设置和录取标准对于高考志愿填报至关重要。通过爬虫技术&#xff0c;可以高效地从各高校官网获取这些关键信息。然而&#xff0c;面对大量的请求和反爬机制的挑战&a…

蓝桥杯物联网竞赛_STM32L071KBU6_字符串处理

前言&#xff1a; 个人感觉国赛相较于省赛难度上升的点在于对于接收的字符串的处理&#xff0c;例如串口发送的字符串一般包含字母字符串 数字字符串&#xff0c;亦或者更复杂&#xff0c;对于LORA也是如此&#xff0c;传递的字符串如#9#1亦或者#1a#90,#1#12&#xff0c;如何…

剖析【C++】——类与对象(上)超详解——小白篇

目录 1.面向过程和面向对象的初步认识 1.面向过程&#xff08;Procedural Programming&#xff09; 2.面向对象&#xff08;Object-Oriented Programming&#xff09; 概念&#xff1a; 特点&#xff1a; 总结 2.C 类的引入 1.从 C 语言的结构体到 C 的类 2.C 中的结构…

调用萨姆索诺夫函数:深入探索函数的参数与返回值

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、萨姆索诺夫函数的引入与调用 二、如何获取函数的返回值 三、无参数与无返回值的函数调…

帝国CMS验证码不显示怎么回事呢?

帝国CMS验证码有时候会不显示或打叉&#xff0c;总结自己的解决方法。 1、检查服务器是否开启GD库 测试GD库是否开启的方法&#xff1a;浏览器访问&#xff1a;/e/showkey/index.php&#xff0c;如果出现一堆乱码或报错&#xff0c;证明GD库没有开启&#xff0c;开启即可。 2…

[随笔] 在CSDN的6周年纪念日随笔

纪念 转眼已过6年&#xff0c;大一的时候学习编程&#xff0c;潜水 CSDN 学习各类博文&#xff0c;才学浅薄就没有主动写博文记录自己的学习历程。 过了段时间刚刚到了大二&#xff0c;很喜欢 Todolist&#xff0c;意气风发的写下《一份清爽的编程计划》&#xff0c;哈哈。 …

数据结构-队列(带图详解)

目录 队列的概念 画图理解队列 代码图理解 代码展示(注意这个队列是单链表的结构实现) Queue.h(队列结构) Queue.c(函数/API实现) main.c(测试文件) 队列的概念 队列&#xff08;Queue&#xff09;是一种基础的数据结构&#xff0c;它遵循先进先出&#xff08;First In …

二十八、openlayers官网示例Data Tiles解析——自定义绘制DataTile源数据

官网demo地址&#xff1a; https://openlayers.org/en/latest/examples/data-tiles.html 这篇示例讲解的是自定义加载DataTile源格式的数据。 先来看一下什么是DataTile&#xff0c;这个源是一个数组&#xff0c;与我们之前XYZ切片源有所不同。DataTile主要适用于需要动态生成…

算法简单笔记2

5月26号&#xff0c;之前学了两天算法烦了&#xff0c;去学了几天鸿蒙&#xff0c;今天又回来看一下算法&#xff0c;距离6月1日国赛还有6天&#xff0c;哈哈真是等死咯...... 一、蓝桥杯第13届国赛第1题填空题&#xff1a;重合次数 &#xff08;半难不难&#xff0c;写编程难…

探索演进:了解IPv4和IPv6之间的区别

探索演进&#xff1a;了解IPv4和IPv6之间的区别 在广阔的互联网领域中&#xff0c;设备之间的通信依赖于一组独特的协议来促进连接。前景协议中&#xff0c;IPv4&#xff08;Internet 协议版本 4&#xff09;和 IPv6&#xff08;Internet 协议版本 6&#xff09;是数字基础设施…

内存泄漏案例分享3-view的内存泄漏

案例3——view内存泄漏 前文提到&#xff0c;profile#Leaks视图无法展示非Activity、非Fragment的内存泄漏&#xff0c;换言之&#xff0c;除了Activity、Fragment的内存泄漏外&#xff0c;其他类的内存问题我们只能自己检索hprof文件查询了。 下面有一个极佳的view内存泄漏例子…

OrangePi AIpro开箱测评

OrangePi AIpro(8T) 香橙派联合华为精心打造&#xff0c;建设人工智能新生态 章节一&#xff1a;引言 1.1 背景 香橙派&#xff08;OrangePi&#xff09;是深圳市迅龙软件有限公司旗下开源产品品牌&#xff0c;迅龙软件成立于2005年&#xff0c;是全球领先的开源硬件和开源软…

初识C语言——第二十九天

数组 本章重点 1.一维数组的创建和初始化 数组的创建 注意事项&#xff1a; 1.一维由低数组在内存中是连续存放的&#xff01; 2.随着数组下标的增长&#xff0c;地址是由低到高变化的 2.二维数组的创建和初始化 注意事项&#xff1a; 1.二维数组在内存中也是连续存放的&am…

YOLOv8+PyQt5面部表情检测系统完整资源集合(yolov8模型,从图像、视频和摄像头三种路径识别检测,包含登陆页面、注册页面和检测页面)

1.资源包含可视化的面部表情检测系统&#xff0c;基于最新的YOLOv8训练的面部表情检测模型&#xff0c;和基于PyQt5制作的可视化面部表情检测系统&#xff0c;包含登陆页面、注册页面和检测页面&#xff0c;该系统可自动检测和识别图片或视频当中出现的八类面部表情&#xff1a…

211大学计算机专业不考408,新增的交叉专业却考408!南京农业大学计算机考研考情分析!

南京农业大学信息科技学院可追溯至1981年成立的计算中心和1985年筹建的农业图书情报专业。1987年设立了农业图书情报系&#xff0c;1993 年农业图书情报系更名为信息管理系&#xff0c;本科专业名称也于1999年更名为信息管理与信息系统专业。1994年计算中心开始招收计算机应用专…

开源网页视频会议,WebRTC音视频功能比较

1. 概述 OpenAI 发布了新一代旗舰生成模型 GPT-4o,这是一款真正的多模态大模型,可以「实时对音频、视觉和文本进行推理」。 支持与 AI 实时语音对话,且响应时间达到毫秒级;交互中可识别人类情绪并以相应的情感做出回应;多语言能力的提升,WebRTC 成为大模型关键能力。 视频会议…

theharvester一键收集域名信息(KALI工具系列十)

目录 1、KALI LINUX简介 2、theharvester工具简介 3、在KALI中使用theharvester 3.1 用搜索引擎扫描 3.2 扫描并输出结果 3.3 扫描某域名下的所有账号 3.4 使用所有的搜索引擎扫描 4、总结 1、KALI LINUX简介 Kali Linux 是一个功能强大、多才多艺的 Linux 发行版&…

【Docker学习】详细讲解docker ps

docker ps是我们操作容器次数最多的命令之一&#xff0c;但我们往往使用docker ps或是docker ps -a&#xff0c;对于该命令的其它选项&#xff0c;我们关注比较少。那么这一讲&#xff0c;我给大家详细讲讲该命令的全部方法。 命令&#xff1a; docker container ls 描述&am…