卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...

f44bc0601d8c56a8616bba8d154b88bb.png

LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧。

环境 Pycharm2019+Python3.7.6+tensorflow 2.0

话不多说,先放完整源码

from tensorflow.keras import layers, datasets, Sequential, losses, optimizers
import tensorflow as tf
import matplotlib.pyplot as pltdef get_data():(train_images, train_labels), (val_images, val_labels) = datasets.mnist.load_data()return train_images, train_labels, val_images, val_labelsdef model_build():network = Sequential([layers.Conv2D(6, kernel_size=3, strides=1, input_shape=(28, 28, 1)),layers.MaxPool2D(pool_size=2, strides=2),layers.ReLU(),layers.Conv2D(16, kernel_size=3, strides=1),layers.MaxPool2D(pool_size=2, strides=2),layers.ReLU(),layers.Conv2D(24, kernel_size=3, strides=1),layers.MaxPool2D(pool_size=2, strides=2),layers.ReLU(),layers.Flatten(),layers.Dense(120, activation='relu'),layers.Dropout(0.5),layers.Dense(84, activation='relu'),layers.Dense(10)  # 因为输出的是独热编码设置为10])network.summary()return networktrain_images, train_labels, val_images, val_labels = get_data()
plt.figure()
plt.imshow(train_images[0])  # 打印第一张图片检查数据
plt.colorbar()  # 色度条显示
plt.grid(False)  # 不显示网格
plt.show()
print(train_images.shape, train_labels.shape)
'''检查数据标签是否正确'''lables = ['0', '1', '2', '3', '4','5', '6', '7', '8', '9']
plt.figure(figsize=(10, 10))
for i in range(25):plt.subplot(5, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(lables[train_labels[i]])
plt.show()
train_images = tf.expand_dims(train_images, axis=3)
val_images = tf.expand_dims(val_images, axis=3)
train_labels = tf.cast(train_labels, tf.int32)
val_labels = tf.cast(val_labels, tf.int32)
train_labels = tf.one_hot(train_labels, depth=10)
val_labels = tf.one_hot(val_labels, depth=10)
train_images = tf.convert_to_tensor(train_images)
print(train_images.dtype, train_labels.dtype)
if train_images.dtype != tf.float32:train_images = tf.cast(train_images, tf.float32)
print(train_images, train_labels)model = model_build()
earlystop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_acc', min_delta=0.001, patience=112)
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), optimizer='adam', metrics=["acc"])
hist = model.fit(train_images, train_labels, epochs=20, batch_size=28, validation_data=[val_images, val_labels],callbacks=[earlystop_callback])print(hist.history.keys())
print(hist.history['acc'])
from tensorflow.keras import layers, datasets, Sequential, losses, optimizers
import tensorflow as tf
import matplotlib.pyplot as plt

先导入我们所需要的库,为了方便,我把 layers, datasets, Sequential, losses, optimizers做了特别导入。

def get_data():(train_images, train_labels), (val_images, val_labels) = datasets.mnist.load_data()return train_images, train_labels, val_images, val_labels

定义数据获取函数,从tensorflow的mnist中使用load_data()获取手写数字数据集,我们在这里会得到四个张量,(train_images, train_labels), (val_images, val_labels),分别为训练图像,训练标签,验证图像,验证标签。其中images的张量形状为(60000, 28, 28) labels张量形状为(60000, )

labels和images图像和标签索引相互对应。

def model_build():network = Sequential([layers.Conv2D(6, kernel_size=3, strides=1, input_shape=(28, 28, 1)),layers.MaxPool2D(pool_size=2, strides=2),layers.ReLU(),layers.Conv2D(16, kernel_size=3, strides=1),layers.MaxPool2D(pool_size=2, strides=2),layers.ReLU(),layers.Conv2D(24, kernel_size=3, strides=1),layers.MaxPool2D(pool_size=2, strides=2),layers.ReLU(),layers.Flatten(),layers.Dense(120, activation='relu'),layers.Dropout(0.5),layers.Dense(84, activation='relu'),layers.Dense(10)  # 因为输出的是独热编码设置为10])network.summary()return network

定义模型搭建函数

使用了Sequential封装网络

layers.Conv2D(6, kernel_size=3, strides=1, input_shape=(28, 28, 1)),

加入第一层,为Conv2D卷积层,卷积核个数为6个, 感受野为3*3,卷积步长为1,网格输入张量形状为(28, 28,1)现在我们主要讨论卷积在tf里的实现方式,卷积算法我会在未来另一篇文章中介绍,这里不再赘述。

 layers.MaxPool2D(pool_size=2, strides=2),

池化层,选用了MaxPool2D最大池化层,池化域2*2, 步长为2,pool_size=2, strides=2是一种常见的参数设置,可以使数据宽高缩小到原来的一半,算法到时候和卷积一并介绍。

layers.ReLU(),

激活函数层, 选用‘relu’函数

layers.Flatten(),
layers.Dense(120, activation='relu'),
layers.Dropout(0.5),
layers.Dense(84, activation='relu'),
layers.Dense(10)

这一部分为全连接层,在将数据输入全连接层前,要先使用flatten层对数据进行铺平处理。

可以设置dropout层“退火”防止过拟合

因为我们最后实现分类时,我们使用了独热编码来代替原来的label ,所以在最后的输出层设置了10。

train_images, train_labels, val_images, val_labels = get_data()
plt.figure()
plt.imshow(train_images[0])  # 打印第一张图片检查数据
plt.colorbar()  # 色度条显示
plt.grid(False)  # 不显示网格
plt.show()
print(train_images.shape, train_labels.shape)
'''检查数据标签是否正确'''
lables = ['0', '1', '2', '3', '4','5', '6', '7', '8', '9']
plt.figure(figsize=(10, 10))
for i in range(25):plt.subplot(5, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(lables[train_labels[i]])

从get_data()函数中加载我们所需要的数据集

打印第一张图片,检查数据集标签是否正确,正确的标签使我们训练的关键因素。

120f9ed4a6fa62fa22f6a1fe28f3a33c.png

0e12658d9ca15a90dd165120afda3911.png

我们应该能得到这样的图片,标签表,发现标签对应是正确的,我们便可以继续。

图像展示函数使用的matplotlib库,具体不再介绍。

train_images = tf.expand_dims(train_images, axis=3)
val_images = tf.expand_dims(val_images, axis=3)
train_labels = tf.cast(train_labels, tf.int32)
val_labels = tf.cast(val_labels, tf.int32)
train_labels = tf.one_hot(train_labels, depth=10)
val_labels = tf.one_hot(val_labels, depth=10)
network = model_build()
train_images = tf.convert_to_tensor(train_images)
print(train_images.dtype, train_labels.dtype)
if train_images.dtype != tf.float32:train_images = tf.cast(train_images, tf.float32)
print(train_images, train_labels)

对数据的预处理,我觉得这部分很重要,在编写这个卷积网络时,我在数据的准备上犯了很多错误,导致程序无法运行或训练效果很差等。

train_images = tf.expand_dims(train_images, axis=3)
val_images = tf.expand_dims(val_images, axis=3)

我们在上面说过,我们image的shape为(60000, 28, 28),但2d卷积层的输入要求为4个维度,我们便将所有的image数据扩充了一个维度,变为(60000, 28, 28, 1) 灰白图像。

train_labels = tf.cast(train_labels, tf.int32)
val_labels = tf.cast(val_labels, tf.int32)
train_labels = tf.one_hot(train_labels, depth=10)
val_labels = tf.one_hot(val_labels, depth=10)

我们要将标签形式转化为独热编码,比如 [1]-->[0, 1, 0, 0, 0, 0, 0 ,0 ,0, 0] 2-->[0, 0, 1, 0, 0, 0, 0 ,0 ,0, 0],这样的好处是规避了标签本身可能存在的数据比较,比如 ‘1’标签大于‘2’标签,但在分类时标签‘1’和‘2’并没有大小关系,独热编码就很好的规避了这种可能存在的比较。

但在将编码独热化前,需要将label的数据转化成int32,否则会报错

network = model_build()
train_images = tf.convert_to_tensor(train_images)
print(train_images.dtype, train_labels.dtype)
if train_images.dtype != tf.float32:train_images = tf.cast(train_images, tf.float32)
print(train_images, train_labels)

Conv2D的输入类型为 tf.float32

model = model_build()
earlystop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_acc', min_delta=0.001, patience=112)
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), optimizer='adam', metrics=["acc"])
hist = model.fit(train_images, train_labels, epochs=20, batch_size=28, validation_data=[val_images, val_labels],callbacks=[earlystop_callback])

编译训练的设置,十分重要,这里的参数对准确率有很大的影响。

我们先搭建网络

earlystop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_acc', min_delta=0.001, patience=112)

使用了一个早停callback ,监控对象为验证集准确率,监控的分度值为0.001,步数112,实现了如果模型在连续在112次内val_acc始终没有在0.001的分度上有所提升,就认为已经收敛了,训练结束。这个callback内的参数是随意设置的,根据自己的目的参数可以调整,合理范围内一般不影响准确率。

model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), optimizer='adam', metrics=["acc"])
hist = model.fit(train_images, train_labels, epochs=20, batch_size=28, validation_data=[val_images, val_labels],callbacks=[earlystop_callback])

模型编译中,我们在分类的损失函数中一般选择CategoricalCrossentropy(from_logits=True)

交叉熵损失函数,其中在最后参数选定为True,使用softmax修订结果,softmax可以将值转化为概率值,这个选项会对训练的准确度产生巨大影响。softmax()

训练模型

print(hist.history.keys())
print(hist.history['acc'])

在最后我用hist变量接受我们训练的数据

2ad2e296b900c00ff97965d0cad9dc47.png

我们可以看到在epoch 16中我们的准确率已经达到了较高的水平。

可以通过hist.history返回的字典得到我们训练的误差等信息进行误差图像等可视化的设置,我在上一篇的全连接网络实例中介绍了一种简单的可视化方法。

在tensorflow的官网中可以学习他们对数据结果可视化的样例。

比如,我大概修改后,我们可以得到这样的效果

c6434e55b3b062345626fd9c7b6708b7.png

226580c81d1a89321ca1ab51c1341a51.png

蓝色条为模型对他分类的信任度,比如第一个模型认为他有100%的概率认为这个数字是7。

这篇CNN LeNet-5的实例笔记结束了,顺便提一句,这个手写数字数据集在上一篇提到的全连接神经网络中的准确率高达99.5%。LeNet-5在简单的灰色手写数字数据集的识别效果很好,但在复杂彩色图像下,性能就会急剧下降,下一篇介绍预计为VGG13卷积神经网络,可以对更复杂图像进行识别。最近时间不太充裕,晚一些我会继续发布在tensorflow与神经网络专栏中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/349382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

glassfish_多种监视和管理GlassFish 3的方法

glassfishGlassFish 3支持多种监视和管理方法。 在本文中,我将简要介绍GlassFish提供的管理,监视和管理方法。 GlassFish管理控制台 GlassFish基于Web的管理控制台GUI可能是GlassFish管理最著名的界面。 默认情况下,运行GlassFish后&#xf…

flask-mail异步发送邮件_SpringBoot 2.0 集成 JavaMail ,实现异步发送邮件

一、JavaMail的核心API1、API功能图解2、API说明(1)、Message 类:javax.mail.Message 类是创建和解析邮件的一个抽象类子类javax.mail.internet.MimeMessage :表示一份电子邮件。 发送邮件时,首先创建出封装了邮件数据的 Message 对象, 然后把…

Java 9中什么是私有的?

在进行面试时,我发现大多数应聘者都不知道Java中的private修饰符真正意味着什么。 他们知道一些足以进行日常编码的事情,但还远远不够。 这不成问题。 足够了解就足够了。 但是,了解Java的一些内部工作仍然很有趣。 在极少数情况下&#xff0…

java switch case怎么判断范围_【转】Java期末复习攻略!

期末19年就这样要过去了,终于到了小时候作文里的未来呢!然而,期末考试也随之来临了。不知大家“预习”的怎么样呢? 期末复习资料的放送快接近尾声了下面康康学长学姐们怎么教你们打java这个boss(下面是java大佬给大家的复习建议以…

spring aop示例_Spring JpaRepository示例(内存中)

spring aop示例这篇文章描述了一个使用内存中HSQL数据库的简单Spring JpaRepository示例。 该代码示例可从GitHub的Spring-JpaRepository目录中获得。 它基于带有注释的Spring-MVC-示例和此处提供的信息 。 JPA资料库 在此示例中,我们实现了一个虚拟bean&#xff1…

python人工智能入门优达视频_机器学习:优达教你搭建Python 环境的正确姿势

原标题:机器学习:优达教你搭建Python 环境的正确姿势为机器学习搭建好 Python 环境听起来简单,但有时候坑还不少。如果此前没有配置过类似的环境,很可能会苦苦折腾各种命令好几个小时。可是我明明只是想马上搞起来我的机器学习! 在…

java ee cdi_Java EE CDI ConversationScoped示例

java ee cdi在本教程中,我们将向您展示如何在Web应用程序中创建和使用ConversationScoped Bean。 在CDI中,bean是定义应用程序状态和/或逻辑的上下文对象的源。 如果容器可以根据CDI规范中定义的生命周期上下文模型来管理其实例的生命周期,则…

js input 自动换行_深入Slate.js - 拯救 ContentEditble

我们是钉钉的文档协同团队,我们在做一些很有意义的事情,其中之一就是自研的文字编辑器。为了把自研文字编辑器做好,我们调研了开源社区各种优秀编辑器,Slate.js 是其中之一(实际上,自研文字编辑器前&#x…

printf 地址_C程序显示主机名和IP地址

查找本地计算机的主机名和IP地址的方法有很多。这是使用C程序查找主机名和IP地址的简单方法。我们将使用以下功能:gethostname() :gethostname函数检索本地计算机的标准主机名。gethostbyname() :gethostbyname函数从主机数据库中检索与主机名…

java 定义变量时 赋值与不赋值_探究Java中基本类型和部分包装类在声明变量时不赋值的情况下java给他们的默认赋值...

探究Java中基本类型和部分包装类在声明变量时不赋值的情况下java给他们的默认赋值当基本数据类型作为普通变量(八大基本类型: byte,char,boolean,short,int,long,float,double)只有开发人员对其进行初始化,java不会对其进行初始化,如果不初始…

java 字符串 移位_使用位运算、值交换等方式反转java字符串-共四种方法

在本文中,我们将向您展示几种在Java中将String类型的字符串字母倒序的几种方法。StringBuilder(str).reverse()char[]循环与值交换byte循环与值交换apache-commons-lang3如果是为了进行开发,请选择StringBuilder(str).reverse()API。出于学习的目的&…

xstream xml模板_XStream – XStreamely使用Java中的XML数据的简便方法

xstream xml模板有时候,我们不得不处理XML数据。 而且大多数时候,这不是我们一生中最快乐的一天。 甚至有一个术语“ XML地狱”描述了程序员必须处理许多难以理解的XML配置文件时的情况。 但是,不管喜欢与否,有时我们别无选择&…

python知识点智能问答_基于知识图谱的智能问答机器人

研究背景及意义 智能问答是计算机与人类以自然语言的形式进行交流的一种方式,是人工智能研究的一个分支。 知识图谱本质上是一种语义网络,其结点代表实体(entity)或者概念(concept),边代表实体/…

java会了还学什么_java都学哪些内容?学完之后可以做哪些工作?

展开全部阶段一:揭开企业开发神秘面纱 (4周32313133353236313431303231363533e78988e69d8331333431336163)1) Web开发基础:HTML语言、JavaScript、CSS、DOM等2) Oracle数据库基础:安装、配置Oracle数据库,熟练掌握SQL语句3) 操作系…

Java中的RAII

资源获取即初始化( RAII )是Bjarne Stroustrup用C 引入的一种用于异常安全资源管理的设计思想。 感谢垃圾回收,Java 没有此功能,但是我们可以使用try-with-resources实现类似的功能。 约翰哈德斯(John Huddles&#x…

eclipse juno_Eclipse Juno上带有GlassFish的JavaEE 7

eclipse junoJava EE 7很热。 前四个JSR最近通过了最终批准选票,与此同时GlassFish 4达到了升级版83。 如果您关注我的博客,那么您将了解NetBeans的大部分工作。 但是我确实认识到,那里还有其他IDE用户,他们也有权试用最新和最出色…

java 生成校验验证码_java 验证码生成与校验

java绘图相关类验证码工具类package dt2008.util;import javax.imageio.ImageIO;import javax.servlet.http.HttpServletRequest;import javax.servlet.http.HttpServletResponse;import java.awt.*;import java.awt.image.BufferedImage;import java.io.IOException;import ja…

红黑树中nil结点_什么是红黑树?程序员面试必问!

点击上方java小组,选择“置顶公众号”优质文章,第一时间送达当在10亿数据中只需要进行10几次比较就能查找到目标时,不禁感叹编程之魅力!人类之伟大呀! —— 学红黑树有感。终于,在学习了几天的红黑树相关的…

杰克逊JSON解析错误-UnrecognizedPropertyException:无法识别的字段,未标记为可忽略[已解决]...

在解析从我们的一个RESTful Web服务接收到的JSON字符串时,我收到此错误“线程“ main”中的异常com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException:无法识别的字段“人”(类Hello $ Person),不是标记…

mysql2008数据库配置_SQL Server 2008 R2 超详细安装图文教程

这篇文章主要介绍了SQL Server 2008 R2 超详细安装图文教程,需要的朋友可以参考下一、下载SQL Server 2008 R2安装文件二、将安装文件刻录成光盘或者用虚拟光驱加载,或者直接解压,打开安装文件,出现下面的界面安装SQL Server 2008 R2需要.NET…