练习利用LSTM实现手写数字分类任务

练习利用LSTM实现手写数字分类任务

MNIST数据集中图片大小为28*28.

按照行进行展开成28维的特征向量。

考虑到这28个的向量之间存在着顺序依赖关系,我们可以将他们看成是一个长为28的输入序列,将其输入到LSTM中,LSTM可以从中提取到序列特征,再将此序列特征用一层全联接作为分类器,分类器输出10种分类类别。

综合代码

import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import fully_connectedimport input_data
mnist = input_data.read_data_sets('MNIST_data/',one_hot = True)
#one_hot = True 独热编码,类似[0,0,0,1,0,0,0,0,0,0]这种形式,等价于class=3n_inputs  = 28  #表示输入神经元的个数
n_steps   = 28  #表示序列长度
n_neurons = 150 #表示LSTM中隐藏层和输出层神经元呢个数
n_outputs = 10  #是最终分类器输出的类别数,mnist数据集是10分类任务learning_rate = 0.01 #优化方法的学习率X = tf.placeholder(tf.float32,[None,n_steps,n_inputs])
Y_labels = tf.placeholder(tf.int32,[None,n_outputs])basic_cell = tf.contrib.rnn.BasicLSTMCell(n_neurons,forget_bias = 1.0, state_is_tuple = True) 
#获取一层LSTM网络,参数1是每个cell的输出神经元个数,参数2是遗忘的偏置,参数3表示双状态outneurons, states = tf.nn.dynamic_rnn(basic_cell,X,dtype = tf.float32) 
#outneurons得到了输出序列logits = fully_connected(tf.transpose(outneurons,perm = [1,0,2])[-1], n_outputs,activation_fn = None)
#在这里由于outneurons的维度为[batch_size,n_steps,n_inputs]的形式,而我们只需要最后一个cell对于所有batch的输出,因此把前两个维度调换一下,再取用[-1]取到最后一个cell对于所有batch的输出。shape为[batch_size,n_inputs]
#将其接到一层全连接网络作为分类器得到logitscross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels = Y_labels,logits = logits)
loss = tf.reduce_mean(cross_entropy)
#对logits用softmax做归一化,计算其对于样本labels的交叉熵的和,取均值作为损失函数lossoptimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
trainop = optimizer.minimize(loss)
#申请一个优化器,用来最后小化损失函数losscorrect = tf.equal(tf.argmax(logits,1),tf.argmax(Y_labels,1))
#分析正确率accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))batch_size = 64 
init = tf.global_variables_initializer()
with tf.Session() as sess:init.run()for i in range(10000):x_batch, y_batch = mnist.train.next_batch(batch_size)x_batch = x_batch.reshape([-1,n_steps,n_inputs])sess.run(trainop,feed_dict = {X : x_batch,Y_labels : y_batch})if i % 200 == 0:print('train accuracy =',sess.run(accuracy,feed_dict = {X : x_batch,Y_labels : y_batch}))X_test = mnist.test.images.reshape((-1,n_steps,n_inputs))Y_test = mnist.test.labelsprint('test accuracy =',sess.run(accuracy,feed_dict = {X : X_test,Y_labels : Y_test}))

评估

实验表明求得得准确率可达到99%。

疑问

我将BasicLSTMCell换成BasicRNNCell就无法训练,这是为什么呢?难道跟LSTM有遗忘们相关吗?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/320935.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【图论】最短路上的统计(ssl 1500)

最短路上的统计 ssl 1500 题目大意: 求一个图中,从a到b的所有最短路所经过的点数之和 原题: 题目描述 一个无向图上,没有自环,所有边的权值均为1,对于一个点对(a,b),我们要把所…

Service Mesh新成员:Consul 1.2

本文译自 HashiCorp 官网关于 Consul 1.2 支持 Service Mesh 发布的博客文章。原文链接:https://www.hashicorp.com/blog/consul-1-2-service-mesh作者:Mitchell Hashimoto 翻译:董干 转载自:https://blog.idevfun.io/consul-1-2-…

P4564-[CTSC2018]假面【期望dp】

正题 题目大意:https://www.luogu.com.cn/problem/P4564 题目大意 nnn个人第iii个有mim_imi​点血,每次有操作 有ppp的概率对一个人造成111点伤害(如果死了就不算,ppp每次都不同)给出若干个人,对里面存活的人随机选择…

VAE(变分自编码器)学习笔记

VAE学习笔记 普通的编码器可以将图像这类信息编码成为特征向量. 但通常这些特征向量不具有空间上的连续性. VAE(变分自编码器)可以将图像信息编码成为具有空间连续性的特征向量. 方法是向编码器和解码器中加入统计信息,即特征向量代表的的是一个高斯分布,强迫特征向量服从高…

小 X 的 AK 计划

小 X 的 AK 计划 题目大意: 有n个点,到一个点(时间为距离)并花一些时间可以A掉此点,问最多可以A多少个点 原题: 解题思路: 先按位置从小到大排序,然后到每一个点并A掉的时间加在…

.NET Core 2.1的重大缺陷延长了.NET Core 2.0的寿命

微软近日宣布,.NET Core 2.0 即将 "寿终正寝",对它的支持将在2018年10月1日结束。.NET Core 2.0 是一个非长期支持(LTS)的版本,因此微软的承诺是在下一个版本发布的三个月之后结束对它的支持。由于 .NET Cor…

P4782-[模板]2-SAT问题【tarjan】

正题 题目链接:https://www.luogu.com.cn/problem/P4782 题目大意 给若干个条件限定为xi为a或xj为bx_i为a或x_j为bxi​为a或xj​为b。求构造一个序列xxx满足所有条件 解题思路 我们对于每个xix_ixi​构造两个点分别表示xix_ixi​为0/10/10/1。然后就开始对能够确定的条件关系…

区间dp专题

区间dp专题 基本思想 区间dp一类的问题往往子问题具有很明显的区间性质,也就是说我们可以通过将子问题定义为整个区间的一个子区间.因为一个大区间可以切分成两段相邻的子区间.从这点出发,我们便可以找到递推关系. 1.纸牌游戏 蜘蛛牌游戏规则是这样的:只能将牌拖…

.Net Core开发日志——Global Tools

.Net Core 2.1引入了一个新的功能,Global Tools,其本质是包含控制台应用程序的nuget包,目前而言,还没有特别有用的工具,不过相信随着时间的推移,各种有创意或者实用性强的Global Tools会出现在大家的视野里…

【DP】回文词 (ssl 1813)

回文词 ssl 1813 题目大意: 给出一个式子,最少要加多少个字符才能让这个式子是一个“回文词” 原题: 题目描述 回文词是一种对称的字符串,也就是说:一个回文词,从左向右读和从右向左读的结果都是一样的.任意给定一个字符串,通过插入若干…

POJ3678-Katu Puzzle【2-SAT】

正题 题目链接:http://poj.org/problem?id3678 题目大意 nnn个xix_ixi​为0/10/10/1。有mmm个条件表示xiandxjax_i\ and\ x_jaxi​ and xj​a或xiorxjax_i\ or\ x_jaxi​ or xj​a或xixorxjax_i\ xor\ x_jaxi​ xor xj​a。 求构造一组合法的xix_ixi​。 解题思路 讨论一下 …

Simple-Faster-RCNN源码学习笔记

Simple-Faster-RCNN 源码学习 项目github地址: https://github.com/chenyuntc/simple-faster-rcnn-pytorch 源码 源文件: model/utils/bbox_tools.py 方法: loc2bbox(src_bbox, loc) 参数含义: src_bbox描述的是bbox的坐标.loc表示的偏移(offsets)和缩放尺度(scales). 给…

API网关模式

什么是网关网关一词来源于计算机网络中的定义,网关(Gateway)又称网间连接器、协议转换器。网关的准确定义是: 两个计算机程序或系统之间的连接,网关作为两个程序之间的门户,允许它们通过不同计算机之间的协议通信来共享信息。顾名…

楼层

楼层 题目大意: 有两个数m和t,问1~m之间去掉有数字t的数之后还有多少个数 原题: 题目描述 mxy 感觉新世界的大门打开了。 ta 决定要在新世界的旅馆中找间房住。已知新世界每天都有一个高能的数字 t,这个数字在楼层中是不会出…

P3825-[NOI2017]游戏【2-SAT】

正题 题目链接:https://www.luogu.com.cn/problem/P3825 题目大意 nnn场比赛,对于场地aaa不能用赛车AAA(b,cb,cb,c以此类推),对于场地xxx可以用任何赛车。然后给定mmm条条件形如iIjJi\ I\ j\ Ji I j J表示在第iii场比赛使用赛车I…

CVPR19 基于图卷积网络的多标签图像识别模型 论文笔记

笔记 旷视研究院的研究员提出了如下模型,用于图像的多标签分类. 该模型与一般模型不一样的一点是,它的分类器是生成的,因此它有一个专门生成分类器的子网络. 网络主要由两部分构成 特征表示子网络,该网络由ResNet-101构成,即蓝色框圈出的部分.分类器生成子网络,该网络由3个…

日行千里,全凭“车”况,为什么我们要升级平台

历经一个半月的时间,不管是叫工业互联网平台还是叫工业大数据平台,从1.0版本升级到2.0版本,升级部分包括:客户端(网关)、服务端(数据接收、数据处理、计算服务)、底层数据库结构、WE…

朋友

朋友 题目大意: 有两堆数,只有第一堆数会和第二堆数中比自己小的数交“朋友”,问有多少对朋友 原题: 题目描述 mxy 即将前往新世界。 在前往新世界的过程中,ta 遇见了两种人。一种是只和 lowb 做朋友,…

好代码是管出来的——.Net Core中的单元测试与代码覆盖率

测试对于软件来说,是保证其质量的一个重要过程,而测试又分为很多种,单元测试、集成测试、系统测试、压力测试等等,不同的测试的测试粒度和测试目标也不同,如单元测试关注每一行代码,集成测试关注的是多个模…

P3694-邦邦的大合唱站队【状压dp】

正题 题目链接:https://www.luogu.com.cn/problem/P3694 题目大意 nnn个人,有mmm个队伍,每个人都属于一个队伍。要求叫出一些人来,然后任意插入出来的空隙中使得同一队的人在一起。求最少出列人数。 解题思路 如果知道最终的队列就可以十分…