【卷积神经网络】MNIST 手写体识别

LeNet-5 是经典卷积神经网络之一,1998 年由 Yann LeCun 等人在论文 《Gradient-Based Learning Applied to Document Recognition》中提出。LeNet-5 网络使用了卷积层、池化层和全连接层,实现可以应用于手写体识别的卷积神经网络。TensorFlow 内置了 MNIST 手写体数据集,可以很方便地读取数据集,并应用于后续的模型训练过程中。本文主要记录了如何使用 TensorFlow 2.0 实现 MNIST 手写体识别模型。

目录

1 数据集准备

2 模型建立

3 模型训练


1 数据集准备

        TensorFlow 内置了 MNIST 手写体数据集,安装 TensorFlow 之后,使用如下代码就可以加载 MNIST 数据集:

import tensorflow as tfmnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()

        使用 Matplotlib 查看前 25 张图片,并打印对应的标签。

from matplotlib import pyplot as plt# 查看训练集
plt.figure(figsize=(3,3))
for i in range(25):plt.subplot(5,5,i+1)plt.imshow(train_x[i], cmap=plt.cm.binary)plt.xticks([])plt.yticks([])
plt.show()

        接着,使用 tf.one_hot() 函数,对图像的标签进行独热码编码。

# 预处理
train_y = tf.one_hot(train_y, depth=10)
test_y = tf.one_hot(test_y, depth=10)

2 模型建立

        MNIST 手写体数据集中,每张图像的大小是 28 × 28 × 1,按照 LeNet-5 模型的思路,构建卷积神经网络模型。选择 5 × 5 的卷积核,卷积层之后是 2 × 2 的平均池化,激活函数选择 sigmoid(除了最后一层)。

# the first layer can receive an 'input_shape' argument
model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(filters=6,kernel_size=5,padding='valid',activation='sigmoid',input_shape=(28,28,1)),tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=2,padding='valid'),tf.keras.layers.Conv2D(filters=16,kernel_size=5,padding='valid',activation='sigmoid'),tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=2,padding='valid'),tf.keras.layers.Flatten(),tf.keras.layers.Dense(120,activation='sigmoid'),tf.keras.layers.Dense(84,activation='sigmoid'),tf.keras.layers.Dense(10,activation='softmax')
])

        使用 model.summary() 查看模型信息。

model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 24, 24, 6)         156       
                                                                 
 average_pooling2d (AverageP  (None, 12, 12, 6)        0         
 ooling2D)                                                       
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 16)          2416      
                                                                 
 average_pooling2d_1 (Averag  (None, 4, 4, 16)         0         
 ePooling2D)                                                     
                                                                 
 flatten (Flatten)           (None, 256)               0         
                                                                 
 dense (Dense)               (None, 120)               30840     
                                                                 
 dense_1 (Dense)             (None, 84)                10164     
                                                                 
 dense_2 (Dense)             (None, 10)                850       
                                                                 
=================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
_________________________________________________________________

3 模型训练

        使用 compile() 函数配置模型,优化算法为 Adam 算法,学习率为 0.001,损失函数为交叉熵损失函数。

# 模型配置
model.compile(optimizer=tf.keras.optimizer.Adam(learning_rate=1e-3),loss=tf.keras.losses.CategoricalCrossentropy(),metrics=['accuracy']
)# 模型训练
model.fit(x=train_x,y=train_y,validation_split=0.0,epochs=10
)

Epoch 1/10
1875/1875 [==============================] - 72s 38ms/step - loss: 0.5806 - accuracy: 0.8206
Epoch 2/10
1875/1875 [==============================] - 70s 37ms/step - loss: 0.1254 - accuracy: 0.9620
Epoch 3/10
1875/1875 [==============================] - 75s 40ms/step - loss: 0.0870 - accuracy: 0.9735
Epoch 4/10
1875/1875 [==============================] - 82s 43ms/step - loss: 0.0699 - accuracy: 0.9785
Epoch 5/10
1875/1875 [==============================] - 69s 37ms/step - loss: 0.0604 - accuracy: 0.9809
Epoch 6/10
1875/1875 [==============================] - 68s 36ms/step - loss: 0.0530 - accuracy: 0.9833
Epoch 7/10
1875/1875 [==============================] - 72s 38ms/step - loss: 0.0477 - accuracy: 0.9854
Epoch 8/10
1875/1875 [==============================] - 70s 38ms/step - loss: 0.0436 - accuracy: 0.9863
Epoch 9/10
1875/1875 [==============================] - 70s 37ms/step - loss: 0.0399 - accuracy: 0.9873
Epoch 10/10
1875/1875 [==============================] - 68s 36ms/step - loss: 0.0357 - accuracy: 0.9883
<keras.callbacks.History at 0x20a56b65660>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/58966.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023年智慧政务一网通办云平台顶层设计与建设方案PPT

导读:原文《2023年智慧政务一网通办云平台顶层设计与建设方案PPT》(获取来源见文尾),本文精选其中精华及架构部分,逻辑清晰、内容完整,为快速形成售前方案提供参考。 部分内容:

汽车3D HMI图形引擎选型指南【2023】

推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 2002年&#xff0c;电影《少数派报告》让观众深入了解未来。 除了情节的核心道德困境之外&#xff0c;大多数人都对它的技术着迷。 我们看到了自动驾驶汽车、个性化广告和用户可以无缝交互的 3D 计算机界面。 令人惊讶的是…

基于PID优化和矢量控制装置的四旋翼无人机(MatlabSimulink实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

使用GoLand进行远程调试

对部署进行配置 在此配置远程服务器地址&#xff0c;映射&#xff0c;是否自动上传(更新)等 选择SFTP类型 选择上传 另外给自动上传选项打钩 此时在本地修改某个文件&#xff0c;远程机器相应目录的文件&#xff0c;也会被同步修改 对远程调试进行配置 远程机器需要安装delve 而…

Linux命令(74)之unzip

linux命令之unzip 1.unzip介绍 linux命令是用来解压缩名称后缀为".zip"的文件 2.unzip用法 unzip [参数] filenname.zip unzip常用参数 参数说明-l显示压缩文件内所包含的文件-t检查备份文件是否正确无误-v显示命令执行详细过程-q不显示命令执行过程-P<密码&g…

时间范围选择时选中日期所使用的当日内具体时刻 如00:00:00= 23:59:59

<el-form-item label"审核时间&#xff1a;"><el-date-pickerv-model"auditTime"type"datetimerange"range-separator"至"value-format"yyyy-MM-dd HH:mm:ss"start-placeholder"开始日期"end-placeholde…

无意间发现这款可以免费制作3D翻页电子画册的网站

在博主努力的搜寻下&#xff0c;无意间发现这个网站&#xff0c;可以免费制作3D翻页电子画册。使用这个网站非常简单&#xff0c;只需上传你想要展示的图片和添加相应的文字&#xff0c;然后选择合适的模板和风格。接下来&#xff0c;就会自动转化成漂亮的3D翻页画册 工具嘛&am…

用idea查看sqlite数据库idea sqlite

1、安装Database Navigator插件 2、导入数据库并查看 3、删除数据库连接 在此做个笔记

深度掌握Python lxml库:高级篇

在Python的世界中&#xff0c;lxml是处理XML和HTML的一款强大且易用的库。在前面的初级和中级篇章中&#xff0c;我们介绍了如何解析、创建、修改XML文档&#xff0c;如何使用XPath查询&#xff0c;以及如何解析大型XML文档。在这篇高级篇章中&#xff0c;我们将继续深入研究lx…

STM32 物联网 4G CAT1 SIMCOM A7680C 源码

基于状态机编写4G模块驱动函数 #include "bsp.h" char LTE_TX[512],LTE_RX[512]; int LTE_TX_length,LTE_RX_length; char U1_TX_data[512],U1_RX_data[512]; char LTE_DATA_buf[512]; char LTE_COM_buf[512]; char LTE_SEND_buf[512];unsigned char U1_TX_flag,U1…

一文总结Redis知识点

目录 为什么基于MySQL又出现Redis&#xff1f;Redis的优点&#xff1f;Redis支持的基本命令Redis支持的数据结构1 String2 List3 Set4 Sorted Set5 Hash6 Stream 消息队列7 Geospatial 地理空间8 Bitmap 位图9 Bitfield 位域10 HyperLogLog Redis是单线程还是多线程&#xff1f…

SpringBoot Thymeleaf iText7 生成 PDF(2023/08/29)

SpringBoot Thymeleaf iText7 生成 PDF&#xff08;2023/08/29&#xff09; 文章目录 SpringBoot Thymeleaf iText7 生成 PDF&#xff08;2023/08/29&#xff09;1. 前言2. 技术思路3. 实现过程4. 测试 1. 前言 近期在项目种遇到了实时生成复杂 PDF 的需求&#xff0c;经过一番…

ceph中PGLog处理流程

ceph的PGLog是由PG来维护&#xff0c;记录了该PG的所有操作&#xff0c;其作用类似于数据库里的undo log。PGLog通常只保存近千条的操作记录(默认是3000条&#xff0c; 由osd_min_pg_log_entries指定)&#xff0c;但是当PG处于降级状态时&#xff0c;就会保存更多的日志&#x…

地铁+铁路系统防雷接地应用解决方案

地铁作为城市轨道交通的一种&#xff0c;是一种高效、安全、环保的公共交通方式。然而&#xff0c;地铁也面临着雷电灾害的威胁&#xff0c;尤其是在雷暴多发的地区。 雷电对地铁系统的影响主要有以下几个方面&#xff1a; 直接雷击&#xff1a;雷电直接击中地铁系统的设备或…

RISC-V公测平台发布 · 在SG2042上配置Jupiter+Octave科学计算环境

简介 JupyterHub是一个开源的共享计算平台&#xff0c;它为每个用户管理一个单独的 Jupyter 环境&#xff0c; 可以用于学生班级、企业数据科学小组或科学研究小组。它是一个多用户中心&#xff0c;可以生成、管理和代理多个单用户Jupyter笔记本服务器的实例。 GNU Octave是一…

WebSocket- 前端篇

官网代码 // 为了浏览器兼容websocketconst WebSocket window.WebSocket || window.MozWebSocket// 创建连接 this.socket new WebSocket(ws://xxx)// 连接成功this.socket.onopen (res)>{console.log(websocket 连接成功)this.socket.send(入参字段) // 传递的参数字段}…

zookeeper集群 数据一致性测试

搭建zk集群 准备 3台ubuntu 20机器每台机器提前安装配置好jdk-8apache-zookeeper-3.8.2-bin.tar.gz 开始 # 上传 scp -P 22 -r C:\Users\xcrj\Downloads\apache-zookeeper-3.8.2-bin.tar.gz root192.168.1.102:/root/zk/ # 解压 tar -zxvf apache-zookeeper-3.8.2-bin.tar.…

【前端】 Layui点击图片实现放大、关闭效果

实现效果&#xff1a;点击图片实现放大&#xff0c;点击空白处关闭效果。下图。 实现逻辑&#xff1a;二维码是使用JQ插件生成的&#xff0c;点击二维码&#xff0c;获取图片路径&#xff0c;通过Layui的弹窗显示放大后的图片。 Html <div id"qrcode" class&quo…

Java学习笔记31——字符流

字符流 字符流为什么出现字符流编码表字符串中的编码解码问题字符流写数据的5中方式字符流读数据的两种方式字符流复制Java文件 字符流 为什么出现字符流 汉字的存储如果是GBK编码占用2个字节&#xff0c;如果是UTF-8占用三个字节 用字节流复制文本文件时&#xff0c;文本文…

驱动开发错误汇编

本博文将会不定期更新。以便记录我的驱动开发生涯中的一些点点滴滴的技术细节和琐事。 生成时link找不到导出函数&#xff0c;比如"LNK2019 无法解析的外部符号 _FltCreateCommunicationPort32"。出现这种情况的原因是&#xff0c;驱动的编译环境忽略了所有的默认库&…