基于MNIST的手写数字识别

上次我们基于CIFAR-10训练一个图像分类器,梳理了一下训练模型的全过程,并且对卷积神经网络有了一定的理解,我们再在GPU上搭建一个手写的数字识别cnn网络,加深巩固一下

步骤

  1. 加载数据集
  2. 定义神经网络
  3. 定义损失函数
  4. 训练网络
  5. 测试网络

MNIST数据集简介

MINIST是一个手写数字数据库(官网地址:http://yann.lecun.com/exdb/mnist/),它有6w张训练样本和1w张测试样本,每张图的像素尺寸为28*28,如下图一共4个图片,这些图片文件均被保存为二进制格式

训练全过程

1.加载数据集

import torch
import torchvision
from torchvision import transforms
trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

展示一些训练图片

import numpy as np
import matplotlib.pyplot as plt
def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
# 得到batch中的数据
dataiter = iter(train_loader)
images, labels = dataiter.next()imshow(torchvision.utils.make_grid(images))

2.定义卷积神经网络

import torch
import torch.nn as nn
import torch.nn.functional as F#可以调用一些常见的函数,例如非线性以及池化等
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# input image channel, 6 output channels, 5x5 square convolutionself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# 全连接 从16 * 4 * 4的维度转成120self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)#(2,2)也可以直接写成数字2x = x.view(-1, self.num_flat_features(x))#将维度转成以batch为第一维 剩余维数相乘为第二维x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # 第一个维度batch不考虑num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
print(net)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

3.定义损失和优化器

criterion = nn.CrossEntropyLoss()
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

这里设置了 momentum=0.9 ,训练一轮的准确率由90%提到了98%

4.训练网络

def train(epochs):net.train()for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(trainloader):# 得到输入 和 标签inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 消除梯度optimizer.zero_grad()# 前向传播 计算损失 后向传播 更新参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印日志running_loss += loss.item()if i % 100 == 0:    # 每100个batch打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 100))running_loss = 0.0
torch.save(net, 'mnist.pth')

net.train():调用方法时,模型将进入训练模式。在训练模式下,一些特定的模块,例如Dropout和Batch Normalization,将被启用。这是因为在训练过程中,我们需要使用Dropout来防止过拟合,并使用Batch Normalization来加速收敛

net.eval():调用方法时,模型将进入评估模式。在评估模式下,一些特定的模块,例如Dropout和Batch Normalization,将被禁用。这是因为在评估过程中,我们不需要使用Dropout来防止过拟合,并且Batch Normalization的统计信息应该是固定的。

5.测试网络

在其它地方导入模型测试时需要将类的定义添加到加载模型的这个py文件中

from mnist.py import Net  # 导入会运行mnist.py
net = torch.load('mnist.pth')testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)correct = 0
total = 0
net.to('cpu') 
print(net)with torch.no_grad():  # 或者model.eval()for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

训练一轮速度

GPU:10s

CPU:10s

训练三轮速度

GPU:24.5s

CPU:28.6s

得出结论:训练数据计算量少的时候,无论在CPU上还是GPU,性能几乎都是接近的,而当训练数据计算量达到一定多的时候,GPU的优势就比较显著直观了

小小实验:

(1)加载并测试一张图片,正确则输出True

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import cv2
import numpy as npclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)  x = x.view(-1, self.num_flat_features(x))  x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  num_features = 1for s in size:num_features *= sreturn num_featurescorrect = 0
total = 0
net = torch.load('mnist.pth')
net.to('cpu')
# print(net)with torch.no_grad(): imgdir = '3.jpeg'img = cv2.imread(imgdir, 0)img = cv2.resize(img, (28, 28))trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])image = trans(img)image = image.unsqueeze(0)label = torch.tensor([int(imgdir.split('.')[0])])outputs = net(image)_, predicted = torch.max(outputs.data, 1)print(predicted)print((predicted == label).item())

拿刚刚训练的模型试了6张数字图片,只有一张2是预测对的....

unsuqeeze:通过unsuqeeze(int)中的int整数,增加一个维度,int整数表示维度增加到哪儿去,且维度为1,参数:【0, 1, 2】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1752.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小扎万字深度访谈:最强开源大模型Llama 3发布,Meta的AGI路径和开源哲学

今天Meta发布了史上最强开源大模型Llama 3,一口气发布了 8B 和 70B 2个预训练和指令微调模型,对比同级别的参数模型,性能上均达到了最佳。 此外,Meta还发布了基于Llama 3的AI助手Meta AI,可以在Facebook、Instagram、W…

一举颠覆Transformer!最新Mamba结合方案刷新多个SOTA,单张GPU即可处理140k

还记得前段时间爆火的Jamba吗? Jamba是世界上第一个生产级的Mamba大模型,它将基于结构化状态空间模型 (SSM) 的 Mamba 模型与 transformer 架构相结合,取两种架构之长,达到模型质量和效率兼得的效果。 在吞吐量和效率等关键衡量指…

基于函数计算FC3.0 部署AI数字绘画stable-diffusion自定义模型

基于函数计算FC3.0 部署AI数字绘画stable-diffusion自定义模型 部署AI数字绘画stable-diffusion曲线救国授权github账号 部署ffmpeg-app-v3总结 在讲述了函数计算FC3.0和函数计算FC2.0的操作界面UI改版以及在函数管理、函数执行引擎、自定义域名、函数授权及弹性伸缩规则方面进…

【管理咨询宝藏82】麦肯锡某化工企业战略咨询报告

本报告首发于公号“管理咨询宝藏”,如需阅读完整版报告内容,请查阅公号“管理咨询宝藏”。 【管理咨询宝藏82】麦肯锡某化工企业战略咨询报告 【格式】PPT版本,可以编辑 【关键词】战略咨询、MBB、业务规划 【核心观点】 - 打造面向客户的…

ROS2 仿真学习02 Gazebo导入官方示例模型

1.下载模型 git clone https://gitee.com/bingda-robot/gazebo_models.git将gazebo_models拖到到.gazebo当中(如果没看到.gazebo文件请按住CTRLh) 2.添加模型到gazebo的Insert 这就将官方示例的模型都导入到Gazebo 了 随便试试一个模型

SLS 查询新范式:使用 SPL 对日志进行交互式探索

作者:无哲 引言 在构建现代数据和业务系统的过程中,可观测性已经变得至关重要,日志服务(SLS)为 Log/Trace/Metric 数据提供了大规模、低成本、高性能的一站式平台服务,并提供数据采集、加工、投递、分析、…

海外平台运营为什么需要静态住宅IP?

在世界经济高度全球化的今天,许多企业家和电子商务卖家纷纷转向海外平台进行业务扩展。像亚马逊、eBay这样的跨国电商平台为卖家提供了巨大的机会,来接触到世界各地的顾客。然而,在这些平台上成功运营,尤其是维持账号的健康和安全…

脚本开发与自动化运维

shell脚本开发 grep搜索工具 参数&#xff1a; -A<显示行数>&#xff1a;-A NUM, --after-context NUM&#xff0c;除了显示符合范本样式的那一行之 外&#xff0c;并显示该行之后的内容。 -B<显示行数>&#xff1a;--before-context NUM&#xff0c;除了显示…

使用51单片机控制T0和T1分别间隔1秒2秒亮灭逻辑

#include <reg51.h>sbit LED1 P1^0; // 设置LED1灯的接口 sbit LED2 P1^1; // 设置LED2灯的接口unsigned int cnt1 0; // 设置LED1灯的定时器溢出次数 unsigned int cnt2 0; // 设置LED2灯的定时器溢出次数// 定时器T0 void Init_Timer0() {TMOD | 0x01;; // 定时器…

数据分析师平均薪资18322,这11个行业需求量最大!

2024年&#xff0c;是一个被数据深刻影响的时代。数据&#xff0c;如同无形的燃料&#xff0c;驱动着现代社会的运转。从全球互联网用户每天产生的2.5亿TB数据&#xff0c;到制造业的传感器、金融交易、医疗病历等各个领域的海量信息&#xff0c;数据的量级每年都在呈指数级增长…

Linux 内核设备树 ranges属性

今天有人问了我一下ranges属性&#xff0c;找了相关资料确认后&#xff0c;记录一下&#xff1a; 参考资料链接&#xff1a;让你完全理解linux内核设备树ranges属性地址转换 - vkang - 博客园 (cnblogs.com) ranges属性定义如下&#xff1a; ranges < local_address pa…

SAP专家级实施商解读:SAP S/4HANA Cloud(PCE私有云) 的五大误解

五个关于SAP S/4HANA Cloud&#xff08;PCE私有云&#xff09;的重要疑问&#xff1a; ■ SAP太贵了&#xff1f; ■ SAP S/4HANA Cloud 只适用于大型企业&#xff1f; ■ ERP项目&#xff0c;尤其是 SAP 解决方案&#xff0c;太耗时了&#xff1f; ■ ERP项目/云项目没有优势&…

JAVA学习笔记29(集合)

1.集合 ​ *集合分为&#xff1a;单列集合、双列集合 ​ *Collection 接口有两个重要子接口 List Set&#xff0c;实现子类为单列集合 ​ *Map接口实现子类为双列集合&#xff0c;存放的King–Value ​ *集合体系图 1.1 Collection接口 1.接口实现类特点 1.collection实现…

PL_to_PS中断传输数据

PL_to_PS中断传输数据 实验功能&#xff1a;将PL端的数据存入BRAM&#xff0c;然后在PS端读出数据&#xff0c;用串口打印。通过中断来触发 参考文章&#xff1a; https://www.cnblogs.com/fhyfhy/p/11760986.html [ZYNQ_PS与PL通过BRAM交互&#xff08;三&#xff1a;PSPL读…

MyBatis 框架学习(II)

MyBatis 框架学习(II) 文章目录 MyBatis 框架学习(II)1. 介绍2. 准备&测试2.1 配置数据库连接字符串和MyBatis2.2 编写持久层代码 3. MyBatis XML基础操作3.1 Insert 操作3.2 Delete 操作3.3 Update 操作3.4 Select 操作 4. #{} 与 ${}的使用5. 动态SQL操作5.1 < if >…

去除图像周围的0像素,调整大小

在做分割任务时&#xff0c;经常需要处理图像&#xff0c;如果图像周围有一圈0像素&#xff0c;需要去除掉&#xff0c;重新调整大小 数组的处理 如果图像的最外一圈为0&#xff0c;我们将图像最外圈的图像0去除掉。 import numpy as npdef remove_outer_zeros(arr):# 获取数…

纠正对CAN的错误认识

STM32CUBEMX系列——CAN通讯的配置_stm32cubemx 配置103 can-CSDN博客 STM32之CAN通信_stm32 can通信-CSDN博客 在回环模式下&#xff0c;发送的数据帧会在控制器内部被立即接收&#xff0c;而不会通过总线传播到其他节点。这种模式可以确保在没有其他节点干扰的情况下&#…

AI边缘计算盒子+ThingSense管理平台,推动明厨亮灶智慧监管新篇章

背景随着“互联网”时代的浪潮汹涌而至&#xff0c;国家及各地政府纷纷在“十四五”规划中明确指出&#xff0c;强化食品安全管理&#xff0c;利用技术手段实现智慧监管是刻不容缓的任务。为此&#xff0c;各地正加速推进“互联网明厨亮灶”的建设步伐&#xff0c;实现系统对接…

C# 使用 ThoughtWorks.QRCode 生成二维码

目录 关于 ThoughtWorks.QRCode 开发运行环境 方法设计 代码实现 调用示例 Logo图标透明化 小结 关于 ThoughtWorks.QRCode 二维码是用某种特定的几何图形按一定规律在平面分布的、黑白相间的、记录数据符号信息的图形&#xff0c;在应用程序开发中也被广泛使用&#x…

vue+node使用RSA非对称加密,实现登录接口加密密码

背景 登录接口&#xff0c;密码这种重要信息不可以用明文传输&#xff0c;必须加密处理。 这里就可以使用RSA非对称加密&#xff0c;后端生成公钥和私钥。 公钥&#xff1a;给前端&#xff0c;公钥可以暴露出来&#xff0c;没有影响&#xff0c;因为公钥加密的数据只有私钥才…