TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络)

源代码/数据集已上传到 Github - tensorflow-tutorial-samples

卷积神经网络gif动图

大白话讲解卷积神经网络工作原理,推荐一个bilibili的讲卷积神经网络的视频,up主从youtube搬运过来,用中文讲了一遍。

这篇文章是 TensorFlow 2.0 Tutorial 入门教程的第五篇文章,介绍如何使用卷积神经网络(Convolutional Neural Network, CNN)来提高mnist手写数字识别的准确性。之前使用了最简单的784x10的神经网络,达到了 0.91 的正确性,而这篇文章在使用了卷积神经网络后,正确性达到了0.99

卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。

卷积神经网络由一个或多个卷积层和顶端的全连通层(对应经典的神经网络)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网络能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网络在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网络,卷积神经网络需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。

——维基百科

1. 安装TensorFlow 2.0

Google与2019年3月发布了TensorFlow 2.0,TensorFlow 2.0 清理了废弃的API,通过减少重复来简化API,并且通过Keras能够轻松地构建模型,从这篇文章开始,教程示例采用TensorFlow 2.0版本。

1
pip install tensorflow==2.0.0-beta0

或者在这里下载whl包安装:https://pypi.tuna.tsinghua.edu.cn/simple/tensorflow/

2. 代码目录结构

1
2
3
4
5
6
7
8
9
10
11
12
13
data_set_tf2/  # TensorFlow 2.0的mnist数据集|--mnist.npz  
test_images/   # 预测所用的图片|--0.png|--1.png|--4.png
v4_cnn/|--ckpt/   # 模型保存的位置|--checkpoint|--cp-0005.ckpt.data-00000-of-00001|--cp-0005.ckpt.index|--predict.py  # 预测代码|--train.py    # 训练代码

3. CNN模型代码(train.py)

模型定义的前半部分主要使用Keras.layers提供的Conv2D(卷积)与MaxPooling2D(池化)函数。

CNN的输入是维度为 (image_height, image_width, color_channels)的张量,mnist数据集是黑白的,因此只有一个color_channel(颜色通道),一般的彩色图片有3个(R,G,B),熟悉Web前端的同学可能知道,有些图片有4个通道(R,G,B,A),A代表透明度。对于mnist数据集,输入的张量维度就是(28,28,1),通过参数input_shape传给网络的第一层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, modelsclass CNN(object):def __init__(self):model = models.Sequential()# 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))model.add(layers.MaxPooling2D((2, 2)))# 第2层卷积,卷积核大小为3*3,64个model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.MaxPooling2D((2, 2)))# 第3层卷积,卷积核大小为3*3,64个model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.Flatten())model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(10, activation='softmax'))model.summary()self.model = model

model.summary()用来打印我们定义的模型的结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________

我们可以看到,每一个Conv2DMaxPooling2D层的输出都是一个三维的张量(height, width, channels)。height和width会逐渐地变小。输出的channel的个数,是由第一个参数(例如,32或64)控制的,随着height和width的变小,channel可以变大(从算力的角度)。

模型的后半部分,是定义输出张量的。layers.Flatten会将三维的张量转为一维的向量。展开前张量的维度是(3, 3, 64) ,转为一维(576)的向量后,紧接着使用layers.Dense层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。

后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。

最大值的下标即可代表对应的数字,使用numpy很容易计算出来:

1
2
3
4
5
6
import numpy as npy1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4

4. mnist数据集预处理(train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
class DataSource(object):def __init__(self):# mnist数据集存储的位置,如何不存在将自动下载data_path = os.path.abspath(os.path.dirname(__file__)) + '/../data_set_tf2/mnist.npz'(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)# 6万张训练图片,1万张测试图片train_images = train_images.reshape((60000, 28, 28, 1))test_images = test_images.reshape((10000, 28, 28, 1))# 像素值映射到 0 - 1 之间train_images, test_images = train_images / 255.0, test_images / 255.0self.train_images, self.train_labels = train_images, train_labelsself.test_images, self.test_labels = test_images, test_labels

因为mnist数据集国内下载不稳定,因此数据集也同步到了Github仓库。

对mnist数据集的介绍,大家可以参考这个系列的第一篇文章TensorFlow入门(一) - mnist手写数字识别(网络搭建)。

5. 开始训练并保存训练结果(train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Train:def __init__(self):self.cnn = CNN()self.data = DataSource()def train(self):check_path = './ckpt/cp-{epoch:04d}.ckpt'# period 每隔5epoch保存一次save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, period=5)self.cnn.model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])self.cnn.model.fit(self.data.train_images, self.data.train_labels, epochs=5, callbacks=[save_model_cb])test_loss, test_acc = self.cnn.model.evaluate(self.data.test_images, self.data.test_labels)print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))if __name__ == "__main__":app = Train()app.train()

在执行python train.py后,会得到以下的结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 45s 749us/sample - loss: 0.1477 - accuracy: 0.9536
Epoch 2/5
60000/60000 [==============================] - 45s 746us/sample - loss: 0.0461 - accuracy: 0.9860
Epoch 3/5
60000/60000 [==============================] - 50s 828us/sample - loss: 0.0336 - accuracy: 0.9893
Epoch 4/5
60000/60000 [==============================] - 50s 828us/sample - loss: 0.0257 - accuracy: 0.9919
Epoch 5/5
59968/60000 [============================>.] - ETA: 0s - loss: 0.0210 - accuracy: 0.9930
Epoch 00005: saving model to ./ckpt/cp-0005.ckpt
60000/60000 [==============================] - 51s 848us/sample - loss: 0.0210 - accuracy: 0.9930
10000/10000 [==============================] - 3s 290us/sample - loss: 0.0331 - accuracy: 0.9901
准确率: 0.9901,共测试了10000张图片

可以看到,在第一轮训练后,识别准确率达到了0.9536,5轮之后,使用测试集验证,准确率达到了0.9901

在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt。接下来我们就可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测了。

6. 图片预测(predict.py)

为了将模型的训练和加载分开,预测的代码写在了predict.py中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tensorflow as tf
from PIL import Image
import numpy as npfrom train import CNN'''
python 3.7
tensorflow 2.0.0b0
pillow(PIL) 4.3.0
'''class Predict(object):def __init__(self):latest = tf.train.latest_checkpoint('./ckpt')self.cnn = CNN()# 恢复网络权重self.cnn.model.load_weights(latest)def predict(self, image_path):# 以黑白方式读取图片img = Image.open(image_path).convert('L')img = np.reshape(img, (28, 28, 1)) / 255.x = np.array([1 - img])# API refer: https://keras.io/models/model/y = self.cnn.model.predict(x)# 因为x只传入了一张图片,取y[0]即可# np.argmax()取得最大值的下标,即代表的数字print(image_path)print(y[0])print('        -> Predict digit', np.argmax(y[0]))if __name__ == "__main__":app = Predict()app.predict('../test_images/0.png')app.predict('../test_images/1.png')app.predict('../test_images/4.png')

最终,执行predict.py,可以看到:

1
2
3
4
5
6
7
8
9
10
$ python predict.py
../test_images/0.png
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]-> Predict digit 0
../test_images/1.png
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]-> Predict digit 1
../test_images/4.png
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]-> Predict digit 4

与TensorFlow1.0的区别总结

  1. 数据集从tensorflow.examples.tutorials.mnist切换到了tensorflow.keras.datasets
  2. Keras的接口成为了主力,datasets, layers, models都是从Keras引入的,而且在网络的搭建上,代码更少,更为简洁。

附: 推荐

  • 一篇文章入门 Python

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/547430.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

职业梦想是计算机的英语作文,理想职业英语作文2篇

篇一:大学英语作文之我理想的工作my ideal jobMy Ideal JobAs college students, we will step into the society, and now we need to prepare for our future and arrange for our future career life, we need to take into consideration what to do in the fut…

C语言二维数组中的指针问题

#include "stdio.h" void main() {int a[5][5];int i,j;for (i0;i<5;i){for (j0;j<5;j){a[i][j] i;}} for (i0;i<5;i){for (j0;j<5;j){printf("%d ",a[i][j]);}printf("\n");} }转载于:https://blog.51cto.com/shamrock/12…

爬取微信小程序源码

爬取微信小程序源码 想知道爬取微信小程序有多简单吗&#xff1f;一张图、三个步骤&#xff0c;拿到你想要的任何微信小程序源码。

马老师 生产环境mysql主从复制、架构优化方案

Binlog日志(主服务器) > 中继日志(从服务器 运行一遍,保持一致)。从服务器是否要二进制日志取决于架构设计。如果二进制保存足够稳定&#xff0c;从性能上来说&#xff0c;从服务器不需要二进制日志。默认情况下&#xff0c;mysql主从复制是异步的。 异步&#xff1a;命令写…

10分钟带你学会微信小程序的反编译

以xxxxx小程序为例10分钟带你学会微信小程序的反编译 2019-11-28 12:59:26 以一个简单的例子介绍下小程序反编译操作流程 实验环境前置准备模拟器内软件安装获取小程序包开始解包导入开发者工具补充注意事项技术交流群有偿解包uniapp 逆向服务逆向教程小程序分包教程#实验环境…

反编译Android APK详细操作指南

早在4年前我曾发表过一篇关于《Android开发之反编译与防止反编译》的文章&#xff0c;在该文章中我对如何在Windows平台反编译APK做了讲解&#xff0c;如今用Mac系统的同学越来越多&#xff0c;也有很多朋友问我能否出一篇关于如何在Mac平台上反编译APK的文章&#xff0c;今天呢…

用idea新建springboot项目遇到的@Restcontroller不能导入的问题

我个人的解决方法如下&#xff1a; 1.springboot默认有 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency><dependency><groupId>o…

Ext1.X的CheckboxSelectionModel默认全选之后不允许编辑的BUG解决方案

Ext1.X的CheckboxSelectionModel默认全选之后不允许编辑的BUG解决方案&#xff0c;ext 的CheckboxSelectionModel在后台默认选中之后&#xff0c;前台就不允许编辑的bug是存在的&#xff0c;因为CheckboxSelectionModel没有Disabled"true"的设置&#xff0c;只能想办…

广州海珠区计算机学校,2019广州海珠区电脑派位和对口直升表

点击即可领取期末各科试卷预约课程还可获赠免费的学习复习诊断— — 学而思爱智康课程优势 — —12年本地化教研沉淀个性化学习方式专属教学服务优质的教学系统2019广州海珠区电脑派位和对口直升表&#xff0c;各位的爸爸妈妈们看过来&#xff01;&#xff01;看看目标学校都招…

android的消息队列机制

android下的线程&#xff0c;Looper线程&#xff0c;MessageQueue&#xff0c;Handler&#xff0c;Message等之间的关系&#xff0c;以及Message的send/post及Message dispatch的过程。 Looper线程 我们知道&#xff0c;线程是进程中某个单一顺序的控制流&#xff0c;它是内核…

KNN算法检测手势动作

KNN算法原理&#xff1a; KNN&#xff08;k-nearest neighbor&#xff09;是一个简单而经典的机器学习分类算法&#xff0c;通过度量”待分类数据”和”类别已知的样本”的距离&#xff08;通常是欧氏距离&#xff09;对样本进行分类。 这话说得有些绕口&#xff0c;且来分解…

IIS负载均衡(转)

在大型Web应用系统中&#xff0c;由于请求的数据量过大以及并发的因素&#xff0c;导致Web系统会出现宕机的现象&#xff0c;解决这一类问题的方法我个人觉得主要在以下几个方面&#xff1a; 1.IIS 负载均衡。 2.数据库 负载均衡。 3.系统架构优化&#xff0c;比如报表服务器…

maven报错Non-resolvable parent POM for com.wpbxin:springboot2-first-example:0.0.1-SNAPSHOT: Could not

文章目录 一、maven报错二、一些说明三、出现问题的原因和几种解决方法 忽略SSL证书检查生成证书并导入到 JRE security 中使用默认的 maven 中央仓库使用 http 的镜像库四、参考链接 记录使用 maven 时遇到的问题。第一种方法最方便&#xff0c;亲测能用。 一、maven报错 mav…

计算机科技与技术对应岗位,计算机技术与软件专业技术资格名称及岗位基本任职条件...

超越梦想&#xff1a;计算机技术与软件专业技术资格名称及岗位基本任职条件专业资格名称对应专业技术职务等级学位或学历要求资历要求信息系统项目管理师、系统分析师、系统架构设计师、网络规划设计师、系统规划与管理师高级博士学位聘任工程师满2年硕士学位聘任工程师满4年本…

Intellij idea 出现错误 error:java: 无效的源发行版: 11解决方法

Select the project, then File > ProjectStructure > ProjectSettings > Modules -> sources You probably have the Language Level set at 9: Just change it to 8 借用下别人的图片&#xff0c; 我的默认是11报的错&#xff0c; 改成8后就好了

计算机房机柜标准尺寸,有哪些参数和尺寸符合机房机柜的安装要求

现如今服务器机房越来越多&#xff0c;不管是自己托管服务器还是租用服务器&#xff0c;机房机柜的选择是很重要的。机房机柜也会跟其高度厚度尺寸等相关。对于需要运行环境和要求高的机房&#xff0c;还需要选择有智能系统的机柜才行。还需要根据一些机房机柜的参数来考虑是否…

9月第1周国内IT技术类网站:CSDN覆盖数持续走低

根据国际统计机构Alexa公布的最新数据显示&#xff0c;9月第1周&#xff08;2013-09-02至2013-09-08&#xff09;&#xff0c;国内IT技术类网站排行榜中&#xff0c;CSDN以1710居于榜首&#xff0c;第二位是1170的博客园&#xff0c;第三位是670的51CTO。下面是具体情况&#x…

YOLOv3改进方法增加特征尺度和训练层数

YOLOv3改进方法 YOLOv3的改进方法有很多&#xff0c;本文讲述的是增加一个特征尺度。 以YOLOv3-darknet53&#xff08;ALexeyAB版本&#xff09;为基础&#xff0c;增加了第4个特征尺度&#xff1a;104*104。原版YOLOv3网络结构&#xff1a; YOLOv3-4l网络结构&#xff1a; 即…

uva 610(tarjan的应用)

题目链接&#xff1a;http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id23727 思路&#xff1a;首先是Tarjan找桥&#xff0c;对于桥&#xff0c;只能是双向边&#xff0c;而对于同一个连通分量而言&#xff0c;只要重新定向为同一个方向即可。 1 #include<ios…

Win7搭建NodeJs开发环境以及HelloWorld展示—图解

Windows 7系统下搭建NodeJs开发环境&#xff08;NodeJsWebStrom&#xff09;以及Hello World&#xff01;展示&#xff0c;大体思路如下&#xff1a;第一步&#xff1a;安装NodeJs运行环境。第二步&#xff1a;安装WebStrom开发工具。第三步&#xff1a;创建并运行NodeJs项目展…