卷积神经网络(CNN)mnist手写数字分类识别的实现

文章目录

  • 前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
      • 我的环境:
    • 2. 导入数据
    • 3.归一化
    • 4.可视化
    • 5.调整图片格式
  • 二、构建CNN网络模型
  • 三、编译模型
  • 四、训练模型
  • 五、预测
  • 六、知识点详解
    • 1. MNIST手写数字数据集介绍
    • 2. 神经网络程序说明
    • 3. 网络结构说明

前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")

2. 导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

3.归一化

# 将像素的值标准化至0到1的区间内。
train_images, test_images = train_images / 255.0, test_images / 255.0train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

4.可视化

plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(5,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])
plt.show()

在这里插入图片描述

5.调整图片格式

#调整数据到我们需要的格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))# train_images, test_images = train_images / 255.0, test_images / 255.0train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

二、构建CNN网络模型

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10)
])model.summary()
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 26, 26, 32)        320       max_pooling2d (MaxPooling2D  (None, 13, 13, 32)       0         )                                                               conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)         0         2D)                                                             flatten (Flatten)           (None, 1600)              0         dense (Dense)               (None, 64)                102464    dense_1 (Dense)             (None, 10)                650       =================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
_________________________________________________________________

三、编译模型

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
Epoch 1/10
1875/1875 [==============================] - 15s 8ms/step - loss: 0.1429 - accuracy: 0.9562 - val_loss: 0.0550 - val_accuracy: 0.9803
Epoch 2/10
1875/1875 [==============================] - 14s 7ms/step - loss: 0.0460 - accuracy: 0.9856 - val_loss: 0.0352 - val_accuracy: 0.9883
Epoch 3/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0312 - accuracy: 0.9904 - val_loss: 0.0371 - val_accuracy: 0.9880
Epoch 4/10
1875/1875 [==============================] - 14s 7ms/step - loss: 0.0234 - accuracy: 0.9925 - val_loss: 0.0330 - val_accuracy: 0.9900
Epoch 5/10
1875/1875 [==============================] - 14s 8ms/step - loss: 0.0176 - accuracy: 0.9944 - val_loss: 0.0311 - val_accuracy: 0.9904
Epoch 6/10
1875/1875 [==============================] - 16s 9ms/step - loss: 0.0136 - accuracy: 0.9954 - val_loss: 0.0300 - val_accuracy: 0.9911
Epoch 7/10
1875/1875 [==============================] - 14s 8ms/step - loss: 0.0109 - accuracy: 0.9964 - val_loss: 0.0328 - val_accuracy: 0.9909
Epoch 8/10
1875/1875 [==============================] - 14s 7ms/step - loss: 0.0097 - accuracy: 0.9969 - val_loss: 0.0340 - val_accuracy: 0.9903
Epoch 9/10
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0078 - accuracy: 0.9974 - val_loss: 0.0499 - val_accuracy: 0.9879
Epoch 10/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0078 - accuracy: 0.9976 - val_loss: 0.0350 - val_accuracy: 0.9902

五、预测

通过下面的网络结构我们可以简单理解为,输入一张图片,将会得到一组数,这组代表这张图片上的数字为0~9中每一个数字的几率,out数字越大可能性越大。
在这里插入图片描述

plt.imshow(test_images[1])

在这里插入图片描述

输出测试集中第一张图片的预测结果

pre = model.predict(test_images)
pre[1]
313/313 [==============================] - 1s 2ms/step
array([  3.3290668 ,   0.29532072,  21.943724  ,  -7.09336   ,-15.3133955 , -28.765621  ,  -1.8459738 ,  -5.761892  ,-2.966585  , -19.222878  ], dtype=float32)

六、知识点详解

本文使用的是最简单的CNN模型- -LeNet-5,如果是第一次接触深度学习的话,可以先试着把代码跑通,然后再尝试去理解其中的代码。

1. MNIST手写数字数据集介绍

MNIST手写数字数据集来源于是美国国家标准与技术研究所,是著名的公开数据集之一。数据集中的数字图片是由250个不同职业的人纯手写绘制,数据集获取的网址为:http://yann.lecun.com/exdb/mnist/ (下载后需解压)。我们一般会采用(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()这行代码直接调用,这样就比较简单

MNIST手写数字数据集中包含了70000张图片,其中60000张为训练数据,10000为测试数据,70000张图片均是28*28,数据集样本如下:

在这里插入图片描述

如果我们把每一张图片中的像素转换为向量,则得到长度为28*28=784的向量。因此我们可以把训练集看成是一个[60000,784]的张量,第一个维度表示图片的索引,第二个维度表示每张图片中的像素点。而图片里的每个像素点的值介于0-1之间。

在这里插入图片描述

2. 神经网络程序说明

神经网络程序可以简单概括如下:
在这里插入图片描述

3. 网络结构说明

在这里插入图片描述

各层的作用

  • 输入层:用于将数据输入到训练网络
  • 卷积层:使用卷积核提取图片特征
  • 池化层:进行下采样,用更高层的抽象表示图像特征
  • Flatten层:将多维的输入一维化,常用在卷积层到全连接层的过渡
  • 全连接层:起到“特征提取器”的作用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/145667.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

servlet页面以及控制台输出中文乱码

如图: servlet首页面: servlet映射页面: 以及控制台输出打印信息: 以上页面均出现中文乱码 下面依次解决: 1、首页面中文乱码 检查你的html或者jsp页面中meta字符集 如图设置成utf-8 然后重启一下tomcat 2、servl…

企业数字化过程中数据仓库与商业智能的目标

当前环境下,各领域企业通过数字化相关的一切技术,以数据为基础、以用户为核心,创建一种新的,或对现有商业模式进行重塑就是数字化转型。这种数字化转型给企业带来的效果就像是一次重构,会对企业的业务流程、思维文化、…

【数据结构与算法】JavaScript实现树结构(一)

文章目录 一、树结构简介1.1.简单了解树结构1.2.树结构的表示方式 二、二叉树2.1.二叉树简介2.2.特殊的二叉树2.3.二叉树的数据存储 三、二叉搜索树3.1.认识二叉搜索树3.2.二叉搜索树应用举例 一、树结构简介 1.1.简单了解树结构 什么是树? 真实的树:…

VS2017 IDE 编译时的 X86、x64位 是干什么的

指定编译出的程序是x86架构下的32位程序还是64位程序 VS2017项目配置X86改配置x64位_winform:把项目由x86改为x64-CSDN博客 vs平台选项:Any CPU,x86,x64_vs anycpu-CSDN博客

Jenkins中强制停止停不下来的job

# Script console 执行脚本 Jenkins 的提供了 script console 的功能,允许你写一些脚本,来调度 Jenkins 执行一些任务。 我们就可以利用 script console 来强制停止 job 执行。 首先进入 Jenkins 的 script console 页面: script console 路…

如何利用TSINGSEE青犀智能分析网关算法从人员、设备、行为三大角度进行监狱智能化升级改造

监狱作为关押犯人的重要场所,十分需要全天候全方位无死角的监控,但由于狱警人力有限,无法达到目前的监控需求。并且在监狱中,犯人众多也极易发生口角冲突,如若没有及时处理,就会发生难以挽回的意外。如何更…

MySQL分页查询的工作原理

前言 MySQL 的分页查询在我们的开发过程中还是很常见的,比如一些后台管理系统,我们一般会有查询订单列表页、商品列表页等。 示例: SELECT * FROM goods order by create_time limit 0,10; 在了解order by和limit的工作原理之前&#xff0c…

Python | 机器学习之逻辑回归

​🌈个人主页:Sarapines Programmer🔥 系列专栏:《人工智能奇遇记》🔖少年有梦不应止于心动,更要付诸行动。 目录结构 1. 机器学习之逻辑回归概念 1.1 机器学习 1.2 逻辑回归 2. 逻辑回归 2.1 实验目的…

开发一款小程序游戏需要多少钱?

小程序游戏的开发成本因多种因素而异,无法提供具体的固定数字。以下是影响小程序游戏开发成本的一些关键因素: 游戏规模和复杂度: 小程序游戏可以是简单的休闲游戏,也可以是更复杂的策略游戏。规模和复杂度会影响开发所需的时间和…

字节跳动小程序开发:探索创新的数字化世界

在数字化时代,字节跳动小程序开发成为企业数字化转型的关键一环。通过这一平台,企业能够借助先进的技术和丰富的功能,实现创新、引领市场潮流。本文将通过一些简单的技术代码示例,带你深入了解字节跳动小程序开发的魅力。 1. 小…

科研学习|研究方法——python T检验

一、单样本T检验 目的:检验单样本的均值是否和已知总体的均值相等前提条件: (1)总体方差未知,否则就可以利用 Z ZZ 检验(也叫 U UU 检验,就是正态检验); (2&a…

Android SmartTable根据int状态格式化文字及颜色

private void initData() {List<UserInfo> list new ArrayList<>();list.add(new UserInfo("一年级", "李同学", 6, 1, 120, 1100, 450, 0));list.add(new UserInfo("一年级", "张同学", 6, 2, 120, 1100, 450, 1));list…

工业数据采集分析 数据跨层次、跨环节、跨系统大整合

在“工业4.0”、“智能制造”、“工业互联网”的大背景下&#xff0c;数据采集早就成为一个广泛关注的热点话题&#xff0c;不论智能制造发展到何种程度&#xff0c;数据采集都是生产过程中应用非常频繁的需求&#xff0c;也是工业物联网不可或缺的一环。 利用工业数据采集系统…

处理机器学习数据集中字符串列(pandas.get_dummies)

如图&#xff0c;在数据集中week列的数据不是数值型&#xff0c;会导致我们在训练过程中难以处理。 而pandas库中有一个非常好用的函数&#xff0c;独热编码pandas.get_dummies(df) 使用此函数之后&#xff0c;会在原数据中新建各列代表Fri-Sun&#xff0c;值为0或1&#xff…

【大话Presto 】- 核心概念

文章目录 前言Operator Model And Iterator Model系统组成Connector数据模型查询执行模型StatementStageTaskSplitDriverOperatorExchangePipeLine 总结 前言 Presto&#xff08;PrestoDB&#xff09;是一个FaceBook开源的分布式MPP SQL引擎&#xff0c;旨在处理大规模数据的查…

全新的FL studio21.2版支持原生中文FL studio2024官方破解版

FL studio2024官方破解版是一款非常专业的音频编辑制作软件。目前它的版本来到了全新的FL studio21.2版&#xff0c;支持原生中文&#xff0c;全面升级的EQ、母带压线器等功能&#xff0c;让你操作起来更加方便&#xff0c;该版本经过破解处理&#xff0c;用户可永久免费使用&a…

「校园 Pie」 系列活动正式启航,首站走进南方科技大学!

PieCloudDB 社区校园行系列活动「校园 Pie」已正式启动。「校园 Pie」旨在促进数据库领域的学术交流&#xff0c;提供一个平台让学生们了解最新的数据库发展趋势和相关技术应用。 在「校园 Pie」系列活动中&#xff0c;PieCloudDB 社区将携拓数派技术专家&#xff0c;社区大咖…

可以免费使用的设计素材网站分享

UI设计师最怕什么&#xff1f; 没有创意&#xff0c;没有灵感&#xff0c;没有思路&#xff01; 在哪里可以得到idea&#xff1f;别担心&#xff0c;往下看&#xff01; 你知道网络有多大&#xff0c;你想要什么吗&#xff1f;今天&#xff0c;我想和大家分享一些宝藏网页设…

用户运营:如何搭建用户分析体系

在运营的工作范畴中&#xff0c;用户运营是很重要的一个环节&#xff0c;甚至有公司会设置专门的“用户运营”岗位。 用户运营的价值体现在多个方面&#xff0c;不仅可以帮助引流、吸引更多用户使用产品&#xff0c;在用户正式使用产品之后的运营则更为重要。通过日常用户运营&…

SpringMVC调用流程

SpringMVC的调用流程 SpringMVC涉及组件理解&#xff1a; DispatcherServlet : SpringMVC提供&#xff0c;我们需要使用web.xml配置使其生效&#xff0c;它是整个流程处理的核心&#xff0c;所有请求都经过它的处理和分发&#xff01;[ CEO ] HandlerMapping : SpringMVC提供&…