深度学习案例之基于 CNN 的 MNIST 手写数字识别

一、模型结构

本文只涉及利用Tensorflow实现CNN的手写数字识别,CNN的内容请参考:卷积神经网络(CNN)

MNIST数据集的格式与数据预处理代码input_data.py的讲解请参考 :Tutorial (2)

二、实验代码

# -*- coding:utf-8 -*-
"""@Time  : @Author: Feng Lepeng@File  : mnist_cnn_tf_demo.py@Desc  : 手写数字识别的CNN网络 LeNet注意:一般情况下,我们都是直接将网络结构翻译成为这个代码,最多稍微的修改一下网络中的参数(超参数、窗口大小、步长等信息)https://deeplearnjs.org/demos/model-builder/https://js.tensorflow.org/#getting-started
"""
import math
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data# 数据加载
mnist = input_data.read_data_sets('data/mnist', one_hot=True)# 手写数字识别的数据集主要包含三个部分:训练集(5.5w, mnist.train)、测试集(1w, mnist.test)、验证集(0.5w, mnist.validation)
# 手写数字图片大小是28*28*1像素的图片(黑白),也就是每个图片由784维的特征描述
train_img = mnist.train.images
train_label = mnist.train.labels
test_img = mnist.test.images
test_label = mnist.test.labels
train_sample_number = mnist.train.num_examples# 相关的参数、超参数的设置
# 学习率,一般学习率设置的比较小
learn_rate_base = 1.0
# 每次迭代的训练样本数量
batch_size = 64
# 展示信息的间隔大小
display_step = 1# 输入的样本维度大小信息
input_dim = train_img.shape[1]
# 输出的维度大小信息
n_classes = train_label.shape[1]# 模型构建
# 1. 设置数据输入的占位符
x = tf.placeholder(tf.float32, shape=[None, input_dim], name='x')
y = tf.placeholder(tf.float32, shape=[None, n_classes], name='y')
learn_rate = tf.placeholder(tf.float32, name='learn_rate')def learn_rate_func(epoch):"""根据给定的迭代批次,更新产生一个学习率的值:param epoch::return:"""return learn_rate_base * (0.9 ** int(epoch / 10))def get_variable(name, shape=None, dtype=tf.float32, initializer=tf.random_normal_initializer(mean=0, stddev=0.1)):"""返回一个对应的变量:param name::param shape::param dtype::param initializer::return:"""return tf.get_variable(name, shape, dtype, initializer)# 2. 构建网络
def le_net(x, y):# 1. 输入层with tf.variable_scope('input1'):# 将输入的x的格式转换为规定的格式# [None, input_dim] -> [None, height, weight, channels]net = tf.reshape(x, shape=[-1, 28, 28, 1])# 2. 卷积层with tf.variable_scope('conv2'):# 卷积# conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None) => 卷积的API# data_format: 表示的是输入的数据格式,两种:NHWC和NCHW,N=>样本数目,H=>Height, W=>Weight, C=>Channels# input:输入数据,必须是一个4维格式的图像数据,具体格式和data_format有关,如果data_format是NHWC的时候,input的格式为: [batch_size, height, weight, channels] => [批次中的图片数目,图片的高度,图片的宽度,图片的通道数];如果data_format是NCHW的时候,input的格式为: [batch_size, channels, height, weight] => [批次中的图片数目,图片的通道数,图片的高度,图片的宽度]# filter: 卷积核,是一个4维格式的数据,shape: [height, weight, in_channels, out_channels] => [窗口的高度,窗口的宽度,输入的channel通道数(上一层图片的深度),输出的通道数(卷积核数目)]# strides:步长,是一个4维的数据,每一维数据必须和data_format格式匹配,表示的是在data_format每一维上的移动步长,当格式为NHWC的时候,strides的格式为: [batch, in_height, in_weight, in_channels] => [样本上的移动大小,高度的移动大小,宽度的移动大小,深度的移动大小],要求在样本上和在深度通道上的移动必须是1;当格式为NCHW的时候,strides的格式为: [batch,in_channels, in_height, in_weight]# padding: 只支持两个参数"SAME", "VALID",当取值为SAME的时候,表示进行填充,"在TensorFlow中,如果步长为1,并且padding为SAME的时候,经过卷积之后的图像大小是不变的";当VALID的时候,表示多余的特征会丢弃;net = tf.nn.conv2d(input=net, filter=get_variable('w', [5, 5, 1, 20]), strides=[1, 1, 1, 1], padding='SAME')net = tf.nn.bias_add(net, get_variable('b', [20]))# 激励 ReLu# tf.nn.relu => max(fetures, 0)# tf.nn.relu6 => min(max(fetures,0), 6)net = tf.nn.relu(net)# 3. 池化with tf.variable_scope('pool3'):# 和conv2一样,需要给定窗口大小和步长# max_pool(value, ksize, strides, padding, data_format="NHWC", name=None)# avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None)# 默认格式下:NHWC,value:输入的数据,必须是[batch_size, height, weight, channels]格式# 默认格式下:NHWC,ksize:指定窗口大小,必须是[batch, in_height, in_weight, in_channels], 其中batch和in_channels必须为1# 默认格式下:NHWC,strides:指定步长大小,必须是[batch, in_height, in_weight, in_channels],其中batch和in_channels必须为1# padding: 只支持两个参数"SAME", "VALID",当取值为SAME的时候,表示进行填充,;当VALID的时候,表示多余的特征会丢弃;net = tf.nn.max_pool(value=net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')# 4. 卷积with tf.variable_scope('conv4'):net = tf.nn.conv2d(input=net, filter=get_variable('w', [5, 5, 20, 50]), strides=[1, 1, 1, 1], padding='SAME')net = tf.nn.bias_add(net, get_variable('b', [50]))net = tf.nn.relu(net)# 5. 池化with tf.variable_scope('pool5'):net = tf.nn.max_pool(value=net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')# 6. 全连接with tf.variable_scope('fc6'):# 28 -> 14 -> 7(因为此时的卷积不改变图片的大小)net = tf.reshape(net, shape=[-1, 7 * 7 * 50])net = tf.add(tf.matmul(net, get_variable('w', [7 * 7 * 50, 500])), get_variable('b', [500]))net = tf.nn.relu(net)# 7. 全连接with tf.variable_scope('fc7'):net = tf.add(tf.matmul(net, get_variable('w', [500, n_classes])), get_variable('b', [n_classes]))act = tf.nn.softmax(net)return act# 构建网络
act = le_net(x, y)# 构建模型的损失函数
# softmax_cross_entropy_with_logits: 计算softmax中的每个样本的交叉熵,logits指定预测值,labels指定实际值
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=act, labels=y))# 使用Adam优化方式比较多
# learning_rate: 要注意,不要过大,过大可能不收敛,也不要过小,过小收敛速度比较慢
train = tf.train.AdamOptimizer(learning_rate=learn_rate).minimize(cost)# 得到预测的类别是那一个
# tf.argmax:对矩阵按行或列计算最大值对应的下标,和numpy中的一样
# tf.equal:是对比这两个矩阵或者向量的相等的元素,如果是相等的那就返回True,反正返回False,返回的值的矩阵维度和A是一样的
pred = tf.equal(tf.argmax(act, axis=1), tf.argmax(y, axis=1))
# 正确率(True转换为1,False转换为0)
acc = tf.reduce_mean(tf.cast(pred, tf.float32))# 初始化
init = tf.global_variables_initializer()with tf.Session() as sess:# 进行数据初始化sess.run(init)# 模型保存、持久化saver = tf.train.Saver()epoch = 0while True:avg_cost = 0# 计算出总的批次total_batch = int(train_sample_number / batch_size)# 迭代更新for i in range(total_batch):# 获取x和ybatch_xs, batch_ys = mnist.train.next_batch(batch_size)feeds = {x: batch_xs, y: batch_ys, learn_rate: learn_rate_func(epoch)}# 模型训练sess.run(train, feed_dict=feeds)# 获取损失函数值avg_cost += sess.run(cost, feed_dict=feeds)# 重新计算平均损失(相当于计算每个样本的损失值)avg_cost = avg_cost / total_batch# DISPLAY  显示误差率和训练集的正确率以此测试集的正确率if (epoch + 1) % display_step == 0:print("批次: %03d 损失函数值: %.9f" % (epoch, avg_cost))# 这里之所以使用batch_xs和batch_ys,是因为我使用train_img会出现内存不够的情况,直接就会退出feeds = {x: batch_xs, y: batch_ys, learn_rate: learn_rate_func(epoch)}train_acc = sess.run(acc, feed_dict=feeds)print("训练集准确率: %.3f" % train_acc)feeds = {x: test_img, y: test_label, learn_rate: learn_rate_func(epoch)}test_acc = sess.run(acc, feed_dict=feeds)print("测试准确率: %.3f" % test_acc)if train_acc > 0.9 and test_acc > 0.9:saver.save(sess, './mnist/model')breakepoch += 1# 模型可视化输出writer = tf.summary.FileWriter('./mnist/graph', tf.get_default_graph())writer.close()

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/453970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

怎样获取linux命令帮助?

获得命令使用帮助:内部命令:help COMMAND外部命令:COMMAND --help (大多数命令有help选项)命令手册:manualman [章节号] COMMAND其中man数据库是分章节的,相同的COMMAND出现在不同的章节表示…

编译安装 zbar 时两次 make 带来的惊喜

为了装 php 的条形码扩展模块 php-zbarcode,先装了一天的 ImageMagick 和 zbar。也许和我装的 Ubuntu 17.10 的有版本兼容问题吧,总之什么毛病都有,apt 不行,PPA 源也不行,编译安装还有几处源代码出错,装不…

python数组的乘法_在Python中乘法非常大的2D数组

我必须在Python中将非常大的2D数组乘以大约100次.每个矩阵由3200032000元素组成.我正在使用np.dot(X,Y),但是每次乘法都需要很长时间…在我的代码实例下面:import numpy as npX Nonefor i in range(100)multiplying Trueif X None:X generate_large_2darray()mu…

0阶指数哥伦布编码

指数哥伦布编码 规定语法元素的编解码模式的描述符如下: 比特串: b(8):任意形式的8比特字节(就是为了说明语法元素是为8个比特,没有语法上的含义) f(n):n位固定模式比特串(其值固定,如forbidde…

TensorFolw 报错

1、报错1&#xff1a;ValueError: Only call softmax_cross_entropy_with_logits with named arguments (labels..., logits..., ...) 提示出错如下&#xff1a; Traceback (most recent call last):File "/MNIST/softmax.py", line 12, in <module>cross_en…

CentOS7种搭建FTP服务器

安装vsftpd 首先要查看你是否安装vsftp [rootlocalhost /]# rpm -q vsftpd vsftpd-3.0.2-10.el7.x86_64 #显示也就安装成功了&#xff01; 如果没有则安装vsftpd [rootlocalhost/]# yum install -y vsftpd 完成后再检查一遍 [rootlocalhost /]# whereis vsftpd vsf…

js循环

顺序——要加分号结束 分支&#xff1a;让程序根据条件不同执行不同的代码 if else语句用来做分支的 if&#xff08;条件&#xff09;{代码} if&#xff08;条件&#xff09;{代码}else{代码} else if&#xff08;条件&#xff09;{代码} if是嵌套。 switch...case&#xff1…

x264函数调用关系图

1 encoder 2 slice write 3 analyse FFMPEG中MPEG-2编解码函数调用关系图 1 Encoder &#xff08;函数调用从左到右&#xff0c;下同&#xff1b;图片显示不全时&#xff0c;请下载显示&#xff09; 2 P帧运动估计流程图 3 B帧运动估计流程图 4 decoder ffmpeg的mpeg2编码I帧代…

Tensorflow 加载预训练模型和保存模型

使用tensorflow过程中&#xff0c;训练结束后我们需要用到模型文件。有时候&#xff0c;我们可能也需要用到别人训练好的模型&#xff0c;并在这个基础上再次训练。这时候我们需要掌握如何操作这些模型数据。看完本文&#xff0c;相信你一定会有收获&#xff01; 一、Tensorfl…

在 ActiveReports 中嵌入 Spread 控件

Spread 是一款很出色的表格控件&#xff0c;Spread 可以使开发人员把具有兼容 Microsoft Excel 的电子表格添加到程序中。ActiveReports 提供了一个非常灵活的、简单的报表环境。下面将展示怎样在 ActiveReports 中使用 Spread for WinForm。和其他三方控件一样&#xff0c;Spr…

sort()函数、C++

Sort&#xff08;&#xff09;函数是c一种排序方法之一&#xff0c;它使用的排序方法是类似于快排的方法&#xff0c;时间复杂度为n*log2(n) &#xff08;1&#xff09;Sort函数包含在头文件为#include<algorithm>的c标准库中。 II&#xff09;Sort函数有三个参数&#x…

python waitkey_python中VideoCapture(),read(),waitKey()的使用

有以下程序import cv2cap cv2.VideoCapture(0)while cap.isOpened():ret,frame cap.read()cv2.imshow(frame,frame)c cv2.waitKey(1)if c 27:breakcap.release()cv2.destroyAllWindows()说明&#xff1a;程序段里&#xff0c;1、cv2.VideoCapture()函数&#xff1a;cap cv…

深度学习案例之 验证码识别

本项目介绍利用深度学习技术&#xff08;tensorflow&#xff09;&#xff0c;来识别验证码&#xff08;4位验证码&#xff0c;具体的验证码的长度可以自己生成&#xff0c;可以在自己进行训练&#xff09; 程序分为四个部分 1、生成验证码的程序&#xff0c;可生成数字字母大…

windows下使用pthread库

最近在看《C多核高级编程》这本书&#xff0c;收集了些有用的东西&#xff0c;方便在windows下使用POSIX标准进行Pthread开发&#xff0c;有利于跨平台。 -------------------------------------------------- windows下使用pthread库时间:2010-01-27 07:41来源:罗索工作室 作…

day 05 多行输出与多行注释、字符串的格式化输出、预设创建者和日期

msg"hello1 hello2 hello3 " print(msg) 显示结果为&#xff1a; # " "只能进行单行的字符串 多行字符串用 ,前面设置变量&#xff0c;可以用 表示多行 msghello1 hello2 hello3print(msg) 显示结果为&#xff1a; 当然如果没有设置变量&#xff0c;…

python数值计算guess_【python】猜数字game,旨在提高初学者对Python循环结构的使用...

import random #引入生成随机数的模块需求&#xff1a;程序设定生成 1-20 之间的一个随机数&#xff0c;让用户猜日期&#xff1a;2019-10-21作者&#xff1a;xiaoxiaohui目的&#xff1a;猜数字game&#xff0c;旨在提高初学者对Python 变量类型以及循环结构的使用。secretNu…

调试九法-总体规则

调试规则规则1 理解系统规则2 制造失败规则3 不要想&#xff0c;而要看规则4 分而治之规则5 一次只改一个地方规则6 保持审计跟踪规则7 检查插头规则8 获得全新观点规则9 如果你不修复bug&#xff0c;它将依然存在转载于:https://www.cnblogs.com/uetucci/p/7987805.html

深度学习之循环神经网络(Recurrent Neural Network,RNN)

递归神经网络和循环神经网络 循环神经网络&#xff08;recurrent neural network&#xff09;&#xff1a;时间上的展开&#xff0c;处理的是序列结构的信息&#xff0c;是有环图递归神经网络&#xff08;recursive neural network&#xff09;&#xff1a;空间上的展开&#…

从北京回来的年轻人,我该告诉你点什么?

前言 就在上周末&#xff0c;我与公众号里的一个当地粉丝见面了&#xff0c;一起吃了顿饭&#xff0c;顺便聊了聊。先来简单交代下我们这位粉丝&#xff08;以下简称小L&#xff09;的经历以及诉求。 小L之前在北京八维研修学院培训的PHP&#xff0c;因为家庭原因&#xff0c;没…

Linphone编译【转载】

Linphone依赖太多的库&#xff0c;以致于稍有疏失&#xff0c;就会在编译&#xff0c;运行出错&#xff0c;都是由于依赖库安装的问题。 1 基础知识 1.1 动态库的连接 很多人安装完库后&#xff0c;configure依然报告这个库没有。这是对linux动态库知识匮乏造成&#xff0c;也就…