Tensorflow 神经网络作业手写数字识别 训练、回测准确率

大白话讲解卷积神经网络工作原理,推荐一个bilibili的讲卷积神经网络的视频,up主从youtube搬运过来,用中文讲了一遍。

这篇文章是 TensorFlow 2.0 Tutorial 入门教程的第五篇文章,介绍如何使用卷积神经网络(Convolutional Neural Network, CNN)来提高mnist手写数字识别的准确性。之前使用了最简单的784x10的神经网络,达到了 0.91 的正确性,而这篇文章在使用了卷积神经网络后,正确性达到了0.99

卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。

卷积神经网络由一个或多个卷积层和顶端的全连通层(对应经典的神经网络)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网络能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网络在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网络,卷积神经网络需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。

——维基百科

1. 安装TensorFlow 2.0

Google与2019年3月发布了TensorFlow 2.0,TensorFlow 2.0 清理了废弃的API,通过减少重复来简化API,并且通过Keras能够轻松地构建模型,从这篇文章开始,教程示例采用TensorFlow 2.0版本。

1

pip install tensorflow==2.0.0-beta0

或者在这里下载whl包安装:Links for tensorflow

2. 代码目录结构

1
2
3
4
5
6
7
8
9
10
11
12
13

data_set_tf2/  # TensorFlow 2.0的mnist数据集
    |--mnist.npz  
test_images/   # 预测所用的图片
    |--0.png
    |--1.png
    |--4.png
v4_cnn/
    |--ckpt/   # 模型保存的位置
        |--checkpoint
        |--cp-0005.ckpt.data-00000-of-00001
        |--cp-0005.ckpt.index
    |--predict.py  # 预测代码
    |--train.py    # 训练代码

3. CNN模型代码(train.py)

模型定义的前半部分主要使用Keras.layers提供的Conv2D(卷积)与MaxPooling2D(池化)函数。

CNN的输入是维度为 (image_height, image_width, color_channels)的张量,mnist数据集是黑白的,因此只有一个color_channel(颜色通道),一般的彩色图片有3个(R,G,B),熟悉Web前端的同学可能知道,有些图片有4个通道(R,G,B,A),A代表透明度。对于mnist数据集,输入的张量维度就是(28,28,1),通过参数input_shape传给网络的第一层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models


class CNN(object):
    def __init__(self):
        model = models.Sequential()
        # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
        model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第2层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第3层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))

        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))

        model.summary()

        self.model = model

model.summary()用来打印我们定义的模型的结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________

我们可以看到,每一个Conv2DMaxPooling2D层的输出都是一个三维的张量(height, width, channels)。height和width会逐渐地变小。输出的channel的个数,是由第一个参数(例如,32或64)控制的,随着height和width的变小,channel可以变大(从算力的角度)。

模型的后半部分,是定义输出张量的。layers.Flatten会将三维的张量转为一维的向量。展开前张量的维度是(3, 3, 64) ,转为一维(576)的向量后,紧接着使用layers.Dense层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。

后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。

最大值的下标即可代表对应的数字,使用numpy很容易计算出来:

1
2
3
4
5
6

import numpy as np

y1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4

4. mnist数据集预处理(train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13

class DataSource(object):
    def __init__(self):
        # mnist数据集存储的位置,如何不存在将自动下载
        data_path = os.path.abspath(os.path.dirname(__file__)) + '/../data_set_tf2/mnist.npz'
        (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)
        # 6万张训练图片,1万张测试图片
        train_images = train_images.reshape((60000, 28, 28, 1))
        test_images = test_images.reshape((10000, 28, 28, 1))
        # 像素值映射到 0 - 1 之间
        train_images, test_images = train_images / 255.0, test_images / 255.0

        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels

5. 开始训练并保存训练结果(train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

class Train:
    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self):
        check_path = './ckpt/cp-{epoch:04d}.ckpt'
        # period 每隔5epoch保存一次
        save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, period=5)

        self.cnn.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        self.cnn.model.fit(self.data.train_images, self.data.train_labels, epochs=5, callbacks=[save_model_cb])

        test_loss, test_acc = self.cnn.model.evaluate(self.data.test_images, self.data.test_labels)
        print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))


if __name__ == "__main__":
    app = Train()
    app.train()

在执行python train.py后,会得到以下的结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 45s 749us/sample - loss: 0.1477 - accuracy: 0.9536
Epoch 2/5
60000/60000 [==============================] - 45s 746us/sample - loss: 0.0461 - accuracy: 0.9860
Epoch 3/5
60000/60000 [==============================] - 50s 828us/sample - loss: 0.0336 - accuracy: 0.9893
Epoch 4/5
60000/60000 [==============================] - 50s 828us/sample - loss: 0.0257 - accuracy: 0.9919
Epoch 5/5
59968/60000 [============================>.] - ETA: 0s - loss: 0.0210 - accuracy: 0.9930
Epoch 00005: saving model to ./ckpt/cp-0005.ckpt
60000/60000 [==============================] - 51s 848us/sample - loss: 0.0210 - accuracy: 0.9930
10000/10000 [==============================] - 3s 290us/sample - loss: 0.0331 - accuracy: 0.9901
准确率: 0.9901,共测试了10000张图片

可以看到,在第一轮训练后,识别准确率达到了0.9536,5轮之后,使用测试集验证,准确率达到了0.9901

在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt。接下来我们就可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测了。

6. 图片预测(predict.py)

为了将模型的训练和加载分开,预测的代码写在了predict.py中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

import tensorflow as tf
from PIL import Image
import numpy as np

from train import CNN

'''
python 3.7
tensorflow 2.0.0b0
pillow(PIL) 4.3.0
'''

class Predict(object):
    def __init__(self):
        latest = tf.train.latest_checkpoint('./ckpt')
        self.cnn = CNN()
        # 恢复网络权重
        self.cnn.model.load_weights(latest)

    def predict(self, image_path):
        # 以黑白方式读取图片
        img = Image.open(image_path).convert('L')
        img = np.reshape(img, (28, 28, 1)) / 255.
        x = np.array([1 - img])

        # API refer: https://keras.io/models/model/
        y = self.cnn.model.predict(x)

        # 因为x只传入了一张图片,取y[0]即可
        # np.argmax()取得最大值的下标,即代表的数字
        print(image_path)
        print(y[0])
        print('        -> Predict digit', np.argmax(y[0]))


if __name__ == "__main__":
    app = Predict()
    app.predict('../test_images/0.png')
    app.predict('../test_images/1.png')
    app.predict('../test_images/4.png')

最终,执行predict.py,可以看到:

1
2
3
4
5
6
7
8
9
10

$ python predict.py
../test_images/0.png
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
        -> Predict digit 0
../test_images/1.png
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
        -> Predict digit 1
../test_images/4.png
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
        -> Predict digit 4

任何程序错误,以及技术疑问或需要解答的,请添加

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/547074.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React Native绑定微信分享/登录/支付(演示+实现步骤+注意事项)

React Native(以下简称RN)绑定微信分享/微信登录/微信支付的实现演示源码注意事项!微信的调用大同小异,本文实现了微信的分享功能,其他功能可以在链接文档里面找到具体的方法。 本文分文三个部分:一、效果…

open×××+Mysql+PAM构建强大的***系统

openMysqlPAM构建强大的***系统本次为新的生产环境部署系统而采用了这个方案,陆续会将实际的生产架构整理出来.由于涉及到公司的各种敏感信息,已经将IP做了替换中途可能有出入 敬请谅解。等我找时间画图出来一并奉上。如果有根本上的问题,请大…

Linux 下 Qt 5.12无法切换中文输入法

无法切换中文输入的原因是当前下载的QtCreator中没有适配当前输入法框架(ibus、fcitx)的动态库 解决方法: 一、安装对应的输入法插件 1、如果是fcitx: ubuntu18.04:sudo apt-get install libfcitx-qt5-dev 拷贝系统路…

微信中通过页面(H5)直接打开本地app的解决方案

简述 微信中通过页面直接打开app分为安卓版和IOS版,两个的实现方式是完全不同的。 安卓版实现:使用腾讯的应用宝,只要配置了“微下载”之后,打开链接腾讯会帮你判断本地是否已经安装了app,如果本地安装就直接打开&am…

GIMP 基本教程

本文主要记录笔者使用GIMP的心得,有些具体操作内容会省略,读者可以酌情阅读,内容较多,建议通过右边目录查看。 GIMP 是高级图片编辑器。 您可以使用它来编辑,增强和修饰照片 和扫描,创建工程图以及制作自己的图像。 它具有大量的专…

iOS通用链接(Universal Links)突然点击无效的解决方案

接上文《微信中通过页面(H5)直接打开本地app的解决方案》已经把iOS搞定并且已经正常能跑了,突然就再也用不了了... 问题描述 测试告诉我,如果从微信打开App之后,点击App右上角的应用网址之后,iOS通用链接就费了,在也…

如何利用shell脚本和client-go实现自己的k8s调度器

调度器介绍 scheduler 是k8s master的一部分,作为插件存在于k8s生态体系。 自定义调度器方式 添加功能重新编译实现自己的调度器(multi-scheduler)scheduler调用扩展程序实现最终调度(Kubernetes scheduler extender&#xff09…

Linux ubuntu安装搜狗输入法

1.下载搜狗输入法的安装包 下载地址为:http://pinyin.sogou.com/linux/,如下图,要选择与自己系统位数一致的安装包, 我的系统是 64 位,所以我下载 64 位的安装包 sogoupinyin_2.2.0.0108_amd64.deb 安装方法: 1.打开命令终端,输入: sudo apt-get install xxx.deb 路径 2.重启电…

React Native顶|底部导航使用小技巧

导航一直是App开发中比较重要的一个组件,ReactNative提供了两种导航组件供我们使用,分别是:NavigatorIOS和Navigator,但是前者只能用于iOS平台,后者在ReactNative0.44版本以后已经被移除了。 好在有人提供了更好的导航…

Linux QT5.12 一种整体界面字体设置的方法及设置PlainTextEdit组件的字体大小方法

1.在Linux QT5.12开发界面时,经常会涉及到界面字体大小的设置,默认字体一般比较小,解决方法如下: 在main函数中添加代码: // // 一种整体界面字体设置的方法: QFont font a.font(); font.setPointSize(14); a.setFont(font); // 2.在L…

Win7电脑,无法把文件保存到桌面上?

今天有用户反映重装了Win7后&#xff0c;文件无法另存到桌面上&#xff0c;解决方法如下&#xff1a;1、在任何地方打开资源管理器&#xff0c;按<Alt><F>键打开资源管理器的菜单&#xff1b;2、选择“工具”的“文件夹选项”&#xff0c;在“导航窗格”里选上“显…

ReactNative常用组件汇总

导航组件react-navigation: https://github.com/react-community/react-navigation 网络请求asios: https://github.com/mzabriskie/axios 设备信息react-native-device-info: https://github.com/rebeccahughes/react-native-device-info 缓存使用react-native-storage: https…

Yolov5训练自己的数据集之制作数据集

在VOC 2018文件夹下有五个文件夹&#xff0c;搜集好的图片放在JPEGImages文件夹下&#xff1b;标注后数据保存在Annotations文件夹下&#xff1b;labels文件夹在数据集的训练时用到&#xff1b;在ImageSets文件夹下有下面三个文件夹&#xff0c;在Main文件夹中有一个train.txt文…

ReactNative布局样式总结

flex number 用于设置或检索弹性盒模型对象的子元素如何分配空间 flexDirection enum(row, row-reverse ,column,column-reverse) flexDirection属性决定主轴的方向&#xff0c;默认是“column”&#xff1a; row&#xff1a;主轴为水平方向&#xff0c;起点在左端row-rev…

Android 线程池对象-ThreadPoolExecutor浅析

本人最近在已经在91&#xff0c;百度应用等渠道上线的个人应用——铃声酷的代码里用到了ThreadPoolExecutor这一线程池对象去处理并发&#xff0c;个人感觉相当的给力啊&#xff01;它是并发实用程序开放源码库 util.concurrent&#xff0c;它包括互斥、信号量、诸如在并发访问…

Linux kubuntu x64系统下解决QT5.12编辑菜单和工具栏不显示图标问题

Linux kubuntu x64系统下发现QT5.12在设计视图下编辑菜单和工具栏显示图标,但是编译运行后发现菜单和工具栏不显示图标,如下图: 我的解决办法是: 1.在QT项目中,菜单和工具栏图标一定要添加到项目资源文件中(在资源编辑器中Add Prefix后,再添加文件,关闭资源编辑器后自动将图标…

PyTorch系列 (二): pytorch数据读取自制数据集并

PyTorch系列 (二): pytorch数据读取 PyTorch 1: How to use data in pytorch Posted by WangW on February 1, 2019 参考&#xff1a; PyTorch documentationPyTorch 码源 本文首先介绍了有关预处理包的源码&#xff0c;接着介绍了在数据处理中的具体应用&#xff1b; 1 P…

nodejs+nginx获取真实ip

nodejs nginx获取真实ip分为两部分&#xff1a; 第一、配置nginx&#xff1b;第二、通过nodejs代码获取&#xff1b; 其他语言也是一样的&#xff0c;都是配置nginx之后&#xff0c;在http头里面获取“x-forwarded-for”. 第一、配置nginx location / {proxy_set_header Ho…

【OSChina-MoPaaS应用开发大赛】豪美创新后台业务管理系统

2019独角兽企业重金招聘Python工程师标准>>> 应用名称&#xff1a;豪美创新后台业务管理系统 应用URL地址&#xff1a;http://tyz.sturgeon.mopaas.com/admin/index.html 登录&#xff1a;admin/admin 投票地址&#xff1a;http://www.oschina.net/mopaas-app-co…

QT5更改应用程序图标

1.准备好.ico的图片放在工程目录下&#xff0c;并添加到项目的资源文件中 2.在项目配置.pro文件中添加一下内容 RC_ICONS AppIcon.icoAppIcon为你的ico图片名字 3.在可视化设计文件.ui中选择主窗口&#xff0c;将其属性中的windowIcon一项右侧下三角单击&#xff0c;从“选择…