用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑。一个简单的RNN如下图所示:

将这个循环展开得到下图:

上一时刻的状态会传递到下一时刻。这种链式特性决定了RNN能够很好的处理序列化的数据,RNN 在语音识别,语言建模,翻译,图片描述等问题上已经取得了很到的结果。
根据输入、输出的不同和是否有延迟等一些情况,RNN在应用中有如下一些形态:

RNN存在的问题

RNN能够把状态传递到下一时刻,好像对一部分信息有记忆能力一样,如下图:

h3

的值可能会由x1,x2的值来决定。
但是,对于一些复杂场景

由于距离太远,中间间隔了太多状态,x1,x2对ht+1

的值几乎起不到任何作用。(梯度消失和梯度爆炸)

LSTM(Long Short Term Memory)

由于RNN不能很好地处理这种问题,于是出现了LSTM(Long Short Term Memory)一种加强版的RNN(LSTM可以改善梯度消失问题)。简单来说就是原始RNN没有长期的记忆能力,于是就给RNN加上了一些记忆控制器,实现对某些信息能够较长期的记忆,而对某些信息只有短期记忆能力。
如上图所示,LSTM中存在Forget Gate,Input Gate,Output Gate来控制信息的流动程度。
RNN:

LSTN:

加号圆圈表示线性相加,乘号圆圈表示用gate来过滤信息。

Understanding LSTM中对LSTM有非常详细的介绍。(对应的中文翻译)

LSTM MNIST手写数字辨识

实际上,图片文字识别这类任务用CNN来做效果更好,但是这里想要强行用LSTM来做一波。
MNIST_data中每一个image的大小是28*28,以行顺序作为序列输入,即第一行的28个像素作为$x_{0}
,第二行为

x_1,...,第28行的28个像素作为

x_28$输入,一个网络结构总共的输入是28个维度为28的向量,输出值是10维的向量,表示的是0-9个数字的概率值。这是一个many to one的RNN结构。
下面直接上代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)# 参数设置
BATCH_SIZE = 100        # BATCH的大小,相当于一次处理50个image
TIME_STEP = 28          # 一个LSTM中,输入序列的长度,image有28行
INPUT_SIZE = 28         # x_i 的向量长度,image有28列
LR = 0.01               # 学习率
NUM_UNITS = 100         # 多少个LTSM单元
ITERATIONS=8000         # 迭代次数
N_CLASSES=10            # 输出大小,0-9十个数字的概率# 定义 placeholders 以便接收x,y
train_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])       # 维度是[BATCH_SIZE,TIME_STEP * INPUT_SIZE]
image = tf.reshape(train_x, [-1, TIME_STEP, INPUT_SIZE])                   # 输入的是二维数据,将其还原为三维,维度是[BATCH_SIZE, TIME_STEP, INPUT_SIZE]
train_y = tf.placeholder(tf.int32, [None, N_CLASSES])                     # 定义RNN(LSTM)结构
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=NUM_UNITS) 
outputs,final_state = tf.nn.dynamic_rnn(cell=rnn_cell,              # 选择传入的cellinputs=image,               # 传入的数据initial_state=None,         # 初始状态dtype=tf.float32,           # 数据类型time_major=False,           # False: (batch, time step, input); True: (time step, batch, input),这里根据image结构选择False
)
output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES)      

这里outputs,final_state = tf.nn.dynamic_rnn(...).
final_state包含两个量,第一个为c保存了每个LSTM任务最后一个cell中每个神经元的状态值,第二个量h保存了每个LSTM任务最后一个cell中每个神经元的输出值,所以c和h的维度都是[BATCH_SIZE,NUM_UNITS]。
outputs的维度是[BATCH_SIZE,TIME_STEP,NUM_UNITS],保存了每个step中cell的输出值h。
由于这里是一个many to one的任务,只需要最后一个step的输出outputs[:, -1, :],output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES) 通过一个全连接层将输出限制为N_CLASSES。

loss = tf.losses.softmax_cross_entropy(onehot_labels=train_y, logits=output)      # 计算loss
train_op = tf.train.AdamOptimizer(LR).minimize(loss)      #选择优化方法correct_prediction = tf.equal(tf.argmax(train_y, axis=1),tf.argmax(output, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))  #计算正确率sess = tf.Session()
sess.run(tf.global_variables_initializer())     # 初始化计算图中的变量for step in range(ITERATIONS):    # 开始训练x, y = mnist.train.next_batch(BATCH_SIZE)   test_x, test_y = mnist.test.next_batch(5000)_, loss_ = sess.run([train_op, loss], {train_x: x, train_y: y})if step % 500 == 0:      # test(validation)accuracy_ = sess.run(accuracy, {train_x: test_x, train_y: test_y})print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)

训练过程输出:

train loss: 2.2990 | test accuracy: 0.13
train loss: 0.1347 | test accuracy: 0.96
train loss: 0.0620 | test accuracy: 0.97
train loss: 0.0788 | test accuracy: 0.98
train loss: 0.0160 | test accuracy: 0.98
train loss: 0.0084 | test accuracy: 0.99
train loss: 0.0436 | test accuracy: 0.99
train loss: 0.0104 | test accuracy: 0.98
train loss: 0.0736 | test accuracy: 0.99
train loss: 0.0154 | test accuracy: 0.98
train loss: 0.0407 | test accuracy: 0.98
train loss: 0.0109 | test accuracy: 0.98
train loss: 0.0722 | test accuracy: 0.98
train loss: 0.1133 | test accuracy: 0.98
train loss: 0.0072 | test accuracy: 0.99
train loss: 0.0352 | test accuracy: 0.98

可以看到,虽然RNN是擅长处理序列类的任务,在MNIST手写数字图片辨识这个任务上,RNN同样可以取得很高的正确率。

参考:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://yjango.gitbooks.io/superorganism/content/lstmgru.html
参考代码

https://www.cnblogs.com/sandy-t/p/6930608.html

有些人,一辈子都没有得到过自己想要的,因为他们总是半途而废

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/567931.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RNN入门

雷锋网 AI科技评论按:本文作者何之源,原文载于知乎专栏AI Insight,雷锋网(公众号:雷锋网) AI科技评论获其授权发布。 上周写的文章《完全图解RNN、RNN变体、Seq2Seq、Attention机制》介绍了一下RNN的几种结构,今天就来…

Java二进制小数表示_《Java编程的逻辑》笔记9--小数的二进制表示

小数计算为什么会出错?简要答案实际上,不是运算本身会出错,而是计算机根本就不能精确的表示很多数,比如0.1这个数。计算机是用一种二进制格式存储小数的,这个二进制格式不能精确表示0.1,它只能表示一个非常…

『TensorFlow』模型保存和载入方法汇总

一、TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称功能说明默认值var_listSaver中存储变量集合全局变量集合reshape加载时是否恢复变量形状Truesharded是否将变量轮循放在所有设备上Truemax_to_keep保留最近检…

STL13-list容器(链表)

链表是由一系列的结点组成,结点包括两个域:一个数据域,一个指针域 1、链表内存是非连续的,添加删除元素效率较高,时间复杂度都是常数项,不需要移动元素 2、链表只有在需要的时候才会分配内存 3、链表只要…

php 前往页面,PHP实现网页截图?

如何使用PHP实现网页截图PHP实现网页截图是一个在日常开发中不常见的需求,但是如果实现还是非常有意思的。目前业界有很多成熟的方案,下面我推荐使用一个很稳定的第三方服务来直接实现,该服务有如下特点:支持多线路支持登录截图支…

STL14-set/multiset容器

set只有一个方法就是insert #include<iostream> #include<set> //set和multiset是一个头文件 //set内部实现机制 红黑色&#xff08;平衡二叉树的一种&#xff09; //关联式容器 //set不允许有重复元素 //multiset运行有重复元素 //容器查找效率高 //容器根据元素的…

普通的java类型是指,String是一个很普通的类 - Java那些事儿

上一篇我们讲了Java中的数组&#xff0c;其实是为本章的内容做准备的&#xff0c;String这个类是我们在写Java代码中用得最多的一个类&#xff0c;没有之一&#xff0c;今天我们就讲讲它&#xff0c;我们打开String这个类的源码&#xff1a;声明了一个char[]数组&#xff0c;变…

STL15-map/multimap容器

map的key值不可以重复 multimap的key值可以重复 #if 1 #include<iostream> #include<map> using namespace std; //初始化 void test01() {//map容器参数 第一个参数key的类型 第二个参数value类型map<int, int> mymap;//插入元素 pair.first key值 pair.se…

php nginx日志分析,如何通过NGINX的log日志来分析网站的访问情况,试试这些命令...

想知道你的网站每天的访问情况吗&#xff1f;有多少人访问了&#xff1f;访问最多的页面是哪个&#xff1f;哪个时段访问的人最多&#xff1f;哪个地方访问的最多&#xff1f;每秒有多少请求&#xff1f;很好奇吧&#xff0c;只要你是使用了nginx进行请求抓发&#xff0c;那么就…

php带来互联网的影响,网络对我们的影响有哪些?

影响有&#xff1a;1、丰富了我们的业余生活&#xff1b;2、降低了获取知识的成本&#xff0c;降低了提升工作的能力的成本&#xff0c;提高了工作的效率&#xff0c;可以快速建立良好的人脉关系&#xff1b;3、让购物变得更加简单便捷&#xff1b;4、朋友间深度沟通与交流越来…

STL17-函数对象

仿函数&#xff1a; #include<iostream> #include<vector> #include<algorithm> using namespace std; //仿函数&#xff08;函数对象&#xff09;重载“&#xff08;&#xff09;”操作符 使类对象可以像函数那样调用 //仿函数是一个类&#xff0c;不是一个…

STL18常用算法

#include<iostream> #include<algorithm> #include<vector> using namespace std; //transform 将一个容器中的元素搬运在另一个容器中 #if 0 //错误 struct PrintVector {void operator()(int v) {cout << v << " ";} }; void test0…

php中页面平滑回到顶部代码,原生JS实现平滑回到顶部组件

返回顶部组件是一种极其常见的网页功能&#xff0c;需求简单&#xff1a;页面滚动一定距离后&#xff0c;显示返回顶部的按钮&#xff0c;点击该按钮可以将滚动条滚回至页面开始的位置。实现思路也很容易&#xff0c;只要改变document.documentElement.scrollTop或document.bod…

C++基础01-C++对c的增强

所谓namespace&#xff0c;是指标识符的各种可见范围。C标准程序库中的所 有标识符都被定义于一个名为std的namespace中。 一 &#xff1a;<iostream>和<iostream.h>格式不一样&#xff0c;前者没有后缀&#xff0c;实际上&#xff0c; 在你的编译器include文件夹…

C++基础02-C++对c的拓展

变量名实质上是一段连续存储空间的别名&#xff0c;是一个标号(门牌号) 通过变量来申请并命名内存空间. 通过变量的名字可以使用存储空间. 变量名&#xff0c;本身是一段内存的引用&#xff0c;即别名(alias). 引用可以看作一个已定义变量的别名。 引用的语法&#xff…

php小程序onload,微信小程序 loading 组件实例详解

这篇文章主要介绍了微信小程序 loading 组件实例详解的相关资料,需要的朋友可以参考下loading通常使用在请求网络数据时的一种方式&#xff0c;通过hidden属性设置显示与否主要属性&#xff1a;wxml显示loading正在加载jsPage({data:{// text:"这是一个页面"hiddenLo…

C++基础04-类基础

一、类和对象 面向对象三大特点&#xff1a;封装、继承、多态。 struct 中所有行为和属性都是 public 的(默认)。C中的 class 可以指定行为和属性的访问方式。 封装,可以达到,对内开放数据,对外屏蔽数据,对外提供接口。达到了信息隐蔽的功能。 class 封装的本质,在于将数…

C++基础05-类构造函数与析构函数

总结&#xff1a; 1、类对象的作用域为两个{}之间。在遇到}后开始执行析构函数 2、当没有任何显式的构造函数&#xff08;无参&#xff0c;有参&#xff0c;拷贝构造&#xff09;时&#xff0c;默认构造函数才会发挥作用 一旦提供显式的构造函数&#xff0c;默认构造函数不复…

PHP网站配置项,Thinkphp5通用网站后台配置项的动态添加及更新

一、引入无论平时我们自己制作&#xff0c;还是浏览别人的网站&#xff0c;它都具有其相应的一些共用的、通用的属性&#xff0c;比如&#xff1a;网站的名字&#xff0c;关键字、备案号、分页数量、是否开启缓存等信息。一些网站可能将配置项写死在后台&#xff0c;无法动态更…

oracle 查询cpu 100%,Oracle 11g中查询CPU占有率高的SQL

oracle版本&#xff1a;oracle11g背景&#xff1a;今天在Linux中的oracle服务上&#xff0c;运用top命令发现许多进程的CPU占有率是100%。操作步骤&#xff1a;以进程PID:7851为例执行以下语句&#xff1a;方法一&#xff1a;(1)通过PID&#xff0c;查得相对应的系统进程对应的…