基于MNIST的手写数字识别

上次我们基于CIFAR-10训练一个图像分类器,梳理了一下训练模型的全过程,并且对卷积神经网络有了一定的理解,我们再在GPU上搭建一个手写的数字识别cnn网络,加深巩固一下

步骤

  1. 加载数据集
  2. 定义神经网络
  3. 定义损失函数
  4. 训练网络
  5. 测试网络

MNIST数据集简介

MINIST是一个手写数字数据库(官网地址:http://yann.lecun.com/exdb/mnist/),它有6w张训练样本和1w张测试样本,每张图的像素尺寸为28*28,如下图一共4个图片,这些图片文件均被保存为二进制格式

训练全过程

1.加载数据集

import torch
import torchvision
from torchvision import transforms
trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

展示一些训练图片

import numpy as np
import matplotlib.pyplot as plt
def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
# 得到batch中的数据
dataiter = iter(train_loader)
images, labels = dataiter.next()imshow(torchvision.utils.make_grid(images))

2.定义卷积神经网络

import torch
import torch.nn as nn
import torch.nn.functional as F#可以调用一些常见的函数,例如非线性以及池化等
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# input image channel, 6 output channels, 5x5 square convolutionself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# 全连接 从16 * 4 * 4的维度转成120self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)#(2,2)也可以直接写成数字2x = x.view(-1, self.num_flat_features(x))#将维度转成以batch为第一维 剩余维数相乘为第二维x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # 第一个维度batch不考虑num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
print(net)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

3.定义损失和优化器

criterion = nn.CrossEntropyLoss()
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

这里设置了 momentum=0.9 ,训练一轮的准确率由90%提到了98%

4.训练网络

def train(epochs):net.train()for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(trainloader):# 得到输入 和 标签inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 消除梯度optimizer.zero_grad()# 前向传播 计算损失 后向传播 更新参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印日志running_loss += loss.item()if i % 100 == 0:    # 每100个batch打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 100))running_loss = 0.0
torch.save(net, 'mnist.pth')

net.train():调用方法时,模型将进入训练模式。在训练模式下,一些特定的模块,例如Dropout和Batch Normalization,将被启用。这是因为在训练过程中,我们需要使用Dropout来防止过拟合,并使用Batch Normalization来加速收敛

net.eval():调用方法时,模型将进入评估模式。在评估模式下,一些特定的模块,例如Dropout和Batch Normalization,将被禁用。这是因为在评估过程中,我们不需要使用Dropout来防止过拟合,并且Batch Normalization的统计信息应该是固定的。

5.测试网络

在其它地方导入模型测试时需要将类的定义添加到加载模型的这个py文件中

from mnist.py import Net  # 导入会运行mnist.py
net = torch.load('mnist.pth')testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)correct = 0
total = 0
net.to('cpu') 
print(net)with torch.no_grad():  # 或者model.eval()for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

训练一轮速度

GPU:10s

CPU:10s

训练三轮速度

GPU:24.5s

CPU:28.6s

得出结论:训练数据计算量少的时候,无论在CPU上还是GPU,性能几乎都是接近的,而当训练数据计算量达到一定多的时候,GPU的优势就比较显著直观了

小小实验:

(1)加载并测试一张图片,正确则输出True

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import cv2
import numpy as npclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)  x = x.view(-1, self.num_flat_features(x))  x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  num_features = 1for s in size:num_features *= sreturn num_featurescorrect = 0
total = 0
net = torch.load('mnist.pth')
net.to('cpu')
# print(net)with torch.no_grad(): imgdir = '3.jpeg'img = cv2.imread(imgdir, 0)img = cv2.resize(img, (28, 28))trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])image = trans(img)image = image.unsqueeze(0)label = torch.tensor([int(imgdir.split('.')[0])])outputs = net(image)_, predicted = torch.max(outputs.data, 1)print(predicted)print((predicted == label).item())

拿刚刚训练的模型试了6张数字图片,只有一张2是预测对的....

unsuqeeze:通过unsuqeeze(int)中的int整数,增加一个维度,int整数表示维度增加到哪儿去,且维度为1,参数:【0, 1, 2】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/1752.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode-寻找重复数

287-寻找重复数 https://leetcode.cn/problems/find-the-duplicate-number/description/?envTypestudy-plan-v2&envIdtop-100-liked给定一个包含 n 1 个整数的数组 nums ,其数字都在 [1, n] 范围内(包括 1 和 n),可知至少存…

小扎万字深度访谈:最强开源大模型Llama 3发布,Meta的AGI路径和开源哲学

今天Meta发布了史上最强开源大模型Llama 3,一口气发布了 8B 和 70B 2个预训练和指令微调模型,对比同级别的参数模型,性能上均达到了最佳。 此外,Meta还发布了基于Llama 3的AI助手Meta AI,可以在Facebook、Instagram、W…

Java使用腾讯翻译api开发app

//这是使用腾讯翻译接口的代码 package com.example.simpleocr; import com.tencentcloudapi.common.Credential; import com.tencentcloudapi.common.profile.ClientProfile; import com.tencentcloudapi.common.profile.HttpProfile; import com.tencentcloudapi.common.exce…

一举颠覆Transformer!最新Mamba结合方案刷新多个SOTA,单张GPU即可处理140k

还记得前段时间爆火的Jamba吗? Jamba是世界上第一个生产级的Mamba大模型,它将基于结构化状态空间模型 (SSM) 的 Mamba 模型与 transformer 架构相结合,取两种架构之长,达到模型质量和效率兼得的效果。 在吞吐量和效率等关键衡量指…

基于函数计算FC3.0 部署AI数字绘画stable-diffusion自定义模型

基于函数计算FC3.0 部署AI数字绘画stable-diffusion自定义模型 部署AI数字绘画stable-diffusion曲线救国授权github账号 部署ffmpeg-app-v3总结 在讲述了函数计算FC3.0和函数计算FC2.0的操作界面UI改版以及在函数管理、函数执行引擎、自定义域名、函数授权及弹性伸缩规则方面进…

【管理咨询宝藏82】麦肯锡某化工企业战略咨询报告

本报告首发于公号“管理咨询宝藏”,如需阅读完整版报告内容,请查阅公号“管理咨询宝藏”。 【管理咨询宝藏82】麦肯锡某化工企业战略咨询报告 【格式】PPT版本,可以编辑 【关键词】战略咨询、MBB、业务规划 【核心观点】 - 打造面向客户的…

【格式化日期】在Vue3中如何格式化日期

使用第三方库date-fns格式化处理日期 使用步骤&#xff1a; ① 安装 date-fns&#xff1a; npm install date-fns② 在 Vue 组件中使用 date-fns 来格式化日期&#xff1a; <script setup> import { ref } from vue; // 引入date-fns import { format } from date-fn…

opencv的高斯滤波函数

//1、高斯滤波器 GaussianBlur(NormalX, res1, Size(Ksize, Ksize), Sigma); //2、高斯分离卷积 Mat v getGaussianKernel(Ksize, Sigma); sepFilter2D(NormalX, res2, -1, v.t(), v); //3、普通卷积 filter2D(NormalX, res3, -1, v*v.t()); …

spring注解整理

spring注解整理 Configuration 使用Configuration注解来标注的类为配置类&#xff0c;配置类就相当于applicationContext.xml配置文件&#xff0c;可以在配置类中来配置bean Configurationpublic class MainConfig { /** * bean的类型是返回类型&#xff0c;bean的id默认…

2023-2024年人形机器人行业报告合集(精选397份)

人形机器人行业报告&#xff08;精选397份&#xff09; 2023-2024年 【以下是资料目录】 报告来源&#xff1a;下载教程&#xff08;海选智库&行业资源智库&#xff09; 2024流程工业智能制造机器人业务开启新增长曲线 2024电子皮肤行业深度研究报告&#xff1a;赋予机…

vue用法示例(一)

1、v-html html 插入&#xff0c;可以插入文本&#xff0c;也可以插入元素&#xff0c;如 message:"<h1>xxx</h1>" <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>Vue 测试实例 - 菜鸟教程(runo…

mysql基础18——权限管理

权限管理 根据不同的用户进行横向和纵向的分组 横向的分组 用户可以接触到的数据的范围 纵向的分组 用户对接触到的数据能访问到什么程度 把具有相同数据访问范围和程度的用户分为不同的类别 这种类别叫做角色 通过角色对相同权限的用户进行分组管理 可以使权限管理更加简单…

ROS2 仿真学习02 Gazebo导入官方示例模型

1.下载模型 git clone https://gitee.com/bingda-robot/gazebo_models.git将gazebo_models拖到到.gazebo当中&#xff08;如果没看到.gazebo文件请按住CTRLh&#xff09; 2.添加模型到gazebo的Insert 这就将官方示例的模型都导入到Gazebo 了 随便试试一个模型

SLS 查询新范式:使用 SPL 对日志进行交互式探索

作者&#xff1a;无哲 引言 在构建现代数据和业务系统的过程中&#xff0c;可观测性已经变得至关重要&#xff0c;日志服务&#xff08;SLS&#xff09;为 Log/Trace/Metric 数据提供了大规模、低成本、高性能的一站式平台服务&#xff0c;并提供数据采集、加工、投递、分析、…

海外平台运营为什么需要静态住宅IP?

在世界经济高度全球化的今天&#xff0c;许多企业家和电子商务卖家纷纷转向海外平台进行业务扩展。像亚马逊、eBay这样的跨国电商平台为卖家提供了巨大的机会&#xff0c;来接触到世界各地的顾客。然而&#xff0c;在这些平台上成功运营&#xff0c;尤其是维持账号的健康和安全…

算法刷题记录 Day51

算法刷题记录 Day51 Date: 2024.04.19 lc 42. 接雨水 // 单调栈 class Solution { public:int trap(vector<int>& height) {// 思路2&#xff1a;单调栈。当有个元素要入栈时。若该元素小于等于栈顶&#xff0c;则直接入栈&#xff1b;// 若该元素大于栈顶&#x…

脚本开发与自动化运维

shell脚本开发 grep搜索工具 参数&#xff1a; -A<显示行数>&#xff1a;-A NUM, --after-context NUM&#xff0c;除了显示符合范本样式的那一行之 外&#xff0c;并显示该行之后的内容。 -B<显示行数>&#xff1a;--before-context NUM&#xff0c;除了显示…

使用51单片机控制T0和T1分别间隔1秒2秒亮灭逻辑

#include <reg51.h>sbit LED1 P1^0; // 设置LED1灯的接口 sbit LED2 P1^1; // 设置LED2灯的接口unsigned int cnt1 0; // 设置LED1灯的定时器溢出次数 unsigned int cnt2 0; // 设置LED2灯的定时器溢出次数// 定时器T0 void Init_Timer0() {TMOD | 0x01;; // 定时器…

Leetcode 1047:删除字符串中的所有相邻重复项

给出由小写字母组成的字符串 S&#xff0c;重复项删除操作会选择两个相邻且相同的字母&#xff0c;并删除它们。 在 S 上反复执行重复项删除操作&#xff0c;直到无法继续删除。 在完成所有重复项删除操作后返回最终的字符串。答案保证唯一。 import java.util.Stack;public…

数据分析师平均薪资18322,这11个行业需求量最大!

2024年&#xff0c;是一个被数据深刻影响的时代。数据&#xff0c;如同无形的燃料&#xff0c;驱动着现代社会的运转。从全球互联网用户每天产生的2.5亿TB数据&#xff0c;到制造业的传感器、金融交易、医疗病历等各个领域的海量信息&#xff0c;数据的量级每年都在呈指数级增长…