【tensorflow】static_rnn与dynamic_rnn的区别

static_rnn和dynamic_rnn的区别主要在于实现不同。

  • static_rnn会把RNN展平,用空间换时间。 gpu会吃不消(个人测试结果)

  • dynamic_rnn则是使用for或者while循环。

调用static_rnn实际上是生成了rnn按时间序列展开之后的图。打开tensorboard你会看到sequence_length个rnn_cell
stack在一起,只不过这些cell是share
weight的。因此,sequence_length就和图的拓扑结构绑定在了一起,因此也就限制了每个batch的sequence_length必须是一致。

调用dynamic_rnn不会将rnn展开,而是利用tf.while_loop这个api,通过Enter, Switch, Merge,
LoopCondition, NextIteration等这些control
flow的节点,生成一个可以执行循环的图(这个图应该还是静态图,因为图的拓扑结构在执行时是不会变化的)。在tensorboard上,你只会看到一个rnn_cell,
外面被一群control
flow节点包围着。对于dynamic_rnn来说,sequence_length仅仅代表着循环的次数,而和图本身的拓扑没有关系,所以每个batch可以有不同sequence_length。

static_rnn

导包、加载数据、定义变量
import tensorflow as tf
tf.reset_default_graph() #流式计算图形graph  循环神经网络 将名字相同重置了图
import datetime #打印时间
import os   #保存文件
from tensorflow.examples.tutorials.mnist import input_data# minst测试集
mnist = input_data.read_data_sets('../', one_hot=True)# 每次使用100条数据进行训练
batch_size = 100
# 图像向量
width = 28
height = 28
# LSTM隐藏神经元数量
rnn_size = 256
# 输出层one-hot向量长度的
out_size = 10

声明变量

def weight_variable(shape, w_alpha=0.01):initial = w_alpha * tf.random_normal(shape)return tf.Variable(initial)def bias_variable(shape, b_alpha=0.1):initial = b_alpha * tf.random_normal(shape)return tf.Variable(initial)# 权重及偏置
w = weight_variable([rnn_size, out_size])
b = bias_variable([out_size])

将数据转化成RNN所要求的数据

# 按照图片大小申请占位符
X = tf.placeholder(tf.float32, [None, height, width])
# 原排列[0,1,2]transpose为[1,0,2]代表前两维装置,如shape=(1,2,3)转为shape=(2,1,3)
# 这里的实际意义是把所有图像向量的相同行号向量转到一起,如x1的第一行与x2的第一行
x = tf.transpose(X, [1, 0, 2])
# reshape -1 代表自适应,这里按照图像每一列的长度为reshape后的列长度
x = tf.reshape(x, [-1, width])
# split默任在第一维即0 dimension进行分割,分割成height份,这里实际指把所有图片向量按对应行号进行重组
x = tf.split(x, height)

构建静态的循环神经网络

# LSTM
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
# 这里RNN会有与输入层相同数量的输出层,我们只需要最后一个输出
outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)#取最后一个进行矩阵乘法
y_conv = tf.add(tf.matmul(outputs[-1], w), b)
# 最小化损失优化
Y = tf.placeholder(dtype=tf.float32,shape = [None,10])
#损失使用的交叉熵
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_conv, labels=Y))
optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)
# 计算准确率
correct = tf.equal(tf.argmax(y_conv, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

模型的训练

# 启动会话.开始训练
saver = tf.train.Saver()
session = tf.Session()
session.run(tf.global_variables_initializer())
step = 0
acc_rate = 0.90
while 1:batch_x, batch_y = mnist.train.next_batch(batch_size)batch_x = batch_x.reshape((batch_size, height, width))session.run(optimizer, feed_dict={X:batch_x,Y:batch_y})# 每训练10次测试一次if step % 10 == 0:batch_x_test = mnist.test.imagesbatch_y_test = mnist.test.labelsbatch_x_test = batch_x_test.reshape([-1, height, width])acc = session.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test})print(datetime.datetime.now().strftime('%c'), ' step:', step, ' accuracy:', acc)# 偏差满足要求,保存模型if acc >= acc_rate:
#             os.sep = ‘/’model_path = os.getcwd() + os.sep + str(acc_rate) + "mnist.model"saver.save(session, model_path, global_step=step)breakstep += 1
session.close()

Wed Dec 18 10:08:45 2019 step: 0 accuracy: 0.1006
Wed Dec 18 10:08:46 2019 step: 10 accuracy: 0.1009
Wed Dec 18 10:08:46 2019 step: 20 accuracy: 0.1028

Wed Dec 18 10:08:57 2019 step: 190 accuracy: 0.9164

dynamic_rnn

加载数据,声明变量
import tensorflow as tf
tf.reset_default_graph()
from tensorflow.examples.tutorials.mnist import input_data# 载入数据
mnist = input_data.read_data_sets("../", one_hot=True)# 输入图片是28
n_input = 28
max_time = 28
lstm_size = 100  # 隐藏单元 可调
n_class = 10  # 10个分类
batch_size = 100   # 每次50个样本 可调
n_batch_size = mnist.train.num_examples // batch_size    # 计算一共有多少批次

Extracting …/train-images-idx3-ubyte.gz
Extracting …/train-labels-idx1-ubyte.gz
Extracting …/t10k-images-idx3-ubyte.gz
Extracting …/t10k-labels-idx1-ubyte.gz

占位符、权重

# 这里None表示第一个维度可以是任意长度
# 创建占位符
x = tf.placeholder(tf.float32,[None, 28*28])
# 正确的标签
y = tf.placeholder(tf.float32,[None, 10])# 初始化权重 ,stddev为标准差
weight = tf.Variable(tf.truncated_normal([lstm_size, n_class], stddev=0.1))
# 初始化偏置层
biases = tf.Variable(tf.constant(0.1, shape=[n_class]))

构建动态RNN、损失函数、准确率

# 定义RNN网络
def RNN(X, weights, biases):#  原始数据为[batch_size,28*28]# input = [batch_size, max_time, n_input]input_ = tf.reshape(X,[-1, max_time, n_input])# 定义LSTM的基本单元
#     lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)# final_state[0] 是cell state# final_state[1] 是hidden statoutputs, final_state = tf.nn.dynamic_rnn(lstm_cell, input_, dtype=tf.float32)display(final_state)results = tf.nn.softmax(tf.matmul(final_state[1],weights)+biases)return results
# 计算RNN的返回结果
prediction = RNN(x, weight, biases)
# 损失函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
# 使用AdamOptimizer进行优化
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
# 将结果存下来
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# 计算正确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

LSTMStateTuple(c=<tf.Tensor ‘rnn/while/Exit_3:0’ shape=(?, 100) dtype=float32>, h=<tf.Tensor ‘rnn/while/Exit_4:0’ shape=(?, 100) dtype=float32>)

训练数据

saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(6):for batch in range(n_batch_size):# 取出下一批次数据batch_xs,batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict={x: batch_xs,y: batch_ys})if(batch%100==0):print(str(batch)+"/" + str(n_batch_size))acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})print("Iter" + str(epoch) + " ,Testing Accuracy = " + str(acc))if acc >0.9:saver.save(sess,'./rnn_dynamic')break

0/550
100/550
200/550
300/550
400/550
500/550
Iter0 ,Testing Accuracy = 0.5903

Iter5 ,Testing Accuracy = 0.9103

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/456006.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pcie1 4 速度_太阳系行星们谁转得最快?八大行星自转速度排行榜,地球排第五...

不知道大家有没有玩儿过陀螺呢&#xff1f;玩儿陀螺的技术如果很好的话&#xff0c;它可以在地上飞快地旋转并且能够旋转很长的时间。有趣的是&#xff0c;宇宙中的很多星球就像陀螺一样绕着一个中心轴旋转着。这就是星球的自转。在太阳系中有八颗大行星&#xff0c;它们都在自…

python中时间模块

时间日期相关的模块 calendar 日历模块time   时间模块datetime 日期时间模块timeit   时间检测模块 日历模块 calendar() 功能&#xff1a;获取指定年份的日历字符串 格式&#xff1a;calendar.calendar&#xff08;年份,w2,l1&#xff0c;c6,m3&#xff09; 返回值&…

硬盘接口详细解释

硬盘是电脑主要的存储媒介之一&#xff0c;由一个或者多个铝制或者玻璃制的碟片组成。碟片外覆盖有铁磁性材料。硬盘有固态硬盘&#xff08;SSD 盘&#xff0c;新式硬盘&#xff09;、机械硬盘&#xff08;HDD 传统硬盘&#xff09;、混合硬盘&#xff08;HHD 一块基于传统机械…

【Keras】30 秒上手 Keras+实例对mnist手写数字进行识别准确率达99%以上

本文我们将学习使用Keras一步一步搭建一个卷积神经网络。具体来说&#xff0c;我们将使用卷积神经网络对手写数字(MNIST数据集)进行识别&#xff0c;并达到99%以上的正确率。 为什么选择Keras呢&#xff1f; 主要是因为简单方便。更多细节请看&#xff1a;https://keras.io/ …

分布式资本沈波:未来区块链杀手级应用将出现在“+区块链”

雷锋网5月22日报道&#xff0c;日前“区块链技术和应用峰会”在杭州国际博览中心举行。会上&#xff0c;分布式资本创始管理人沈波作了《区块链的投资现状与发展趋势》演讲。 沈波表示&#xff0c;由于区块链的共识机制和无法篡改两大特点&#xff0c;它在各行各业皆有应用潜力…

帧间预测小记

帧间预测后&#xff0c;在比特流中会有相应的信息&#xff1a;残差信息&#xff0c;运动矢量信息&#xff0c;所选的模式。 宏块的色度分量分辨率是亮度分辨率的一半&#xff08;Cr和Cb&#xff09;&#xff0c;水平和垂直均一半。色度块采用和亮度块一致的分割模式&#xff0…

ImageJ Nikon_科研论文作图之ImageJ

各位读者朋友们又见面了&#xff0c;今天给大家介绍一款图片处理软件——ImageJ&#xff0c;这是一款免费的科学图像分析工具&#xff0c;广泛应用于生物学研究领域。ImageJ软件能够对图像进行缩放、旋转、扭曲、模糊等处理&#xff0c;也可计算选定区域内分析对象的一系列几何…

python中面向对象

面向对象 Object Oriented 面向对象的学习&#xff1a; 面向对象的语法&#xff08;简单&#xff0c;记忆就可以搞定&#xff09;面向对象的思想&#xff08;稍难&#xff0c;需要一定的理解&#xff09; 面向过程和面向对象的区别 面向过程开发&#xff0c;以函数作为基本结构…

【urllib】url编码问题简述

对url编解码总结 需要用到urllib库中的parse模块 import urllib.parse # Python3 url编码 print(urllib.parse.quote("天天")) # Python3 url解码 print(urllib.parse.unquote("%E5%A4%E5%A4%")) urlparse() # urllib.parse.urlparse(urlstring,scheme,…

冷知识 —— 地理

西安1980坐标系&#xff1a; 1978 年 4 月在西安召开全国天文大地网平差会议&#xff0c;确定重新定位&#xff0c;建立我国新的坐标系。为此有了 1980 国家大地坐标系。1980 国家大地坐标系采用地球椭球基本参数为 1975 年国际大地测量与地球物理联合会第十六届大会推荐的数据…

独家| ChinaLedger白硕:区块链中的隐私保护

隐私问题一直是区块链应用落地的障碍问题之一&#xff0c;如何既能满足监管&#xff0c;又能不侵害数据隐私&#xff0c;是行业都在攻克的问题。那么&#xff0c;到底隐私问题为何难&#xff1f;有什么解决思路&#xff0c;以及实践创新呢&#xff1f;零知识证明、同态加密等技…

手机处理器排行榜2019_手机处理器AI性能排行榜出炉,高通骁龙第一,华为排在第十名...

↑↑↑击上方"蓝字"关注&#xff0c;每天推送最新科技新闻安兔兔在近日公布了今年四月份Android手机处理器AI性能排行榜&#xff0c;榜单显示高通骁龙865处理器的AI性能在Android阵营中排在第一名——该处理器的AI性能得分接近46万分&#xff0c;今年的小米10、三星G…

芯片支持的且会被用到的H.264特性 预测编码基本原理

视频压缩&#xff1a; 1.H.264基本档次和主要档次&#xff1b;2.CAVLC熵编码&#xff0c;即基于上下文的自适应变长编码&#xff1b;&#xff08;不支持CABAC&#xff0c;即基于上下文的自适应算术编码&#xff09;分辨率&#xff1a;仅用到1080p60&#xff0c;即分辨率为1920*…

MongoDB 数据库 【总结笔记】

一、MongoDB 概念解析 什么是MongoDB&#xff1f; ​ 1、MongoDB是有C语言编写的&#xff0c;是一个基于分布式文件存储的开源数据库系统&#xff0c;在高负载的情况下&#xff0c;添加更多节点&#xff0c;可以保证服务器的性能 ​ 2、MongoDB为web应用提供了高性能的数据存储…

PHP 函数截图 哈哈哈

转载于:https://www.cnblogs.com/bootoo/p/6714676.html

python中的魔术方法

魔术方法 魔术方法就是一个类/对象中的方法&#xff0c;和普通方法唯一的不同时&#xff0c;普通方法需要调用&#xff01;而魔术方法是在特定时刻自动触发。 1.__init__ 初始化魔术方法 触发时机&#xff1a;初始化对象时触发&#xff08;不是实例化触发&#xff0c;但是和实…

2016年光伏电站交易和融资的十大猜想

1领跑者计划备受关注&#xff0c;竞价上网或从试点开始 领跑者计划规模大&#xff0c;上网条件好&#xff0c;又有政府背书&#xff0c;虽说价格也不便宜&#xff0c;但省去很多隐性成本&#xff0c;对于致力于规模化发展的大型企业来说仍是首要选择。同时&#xff0c;从能源管…

loading gif 透明_搞笑GIF:有这样的女朋友下班哪里都不想去

原标题&#xff1a;搞笑GIF&#xff1a;有这样的女朋友下班哪里都不想去这样的广场舞看着不凉快吗&#xff1f;大哥慢点&#xff0c;机器经受不住你这样的速度求孩子的心里阴影面积生孩子就是用来玩的。有这样的媳妇做饭&#xff0c;下班哪里也不想去1.领导在门外用门夹核桃&am…

Redis数据库 【总结笔记】

一、NoSql&#xff08;非关系型数据库&#xff09; NoSQL&#xff1a;NoSQL Not Only SQL 非关系型数据库 ​ NoSQL&#xff0c;泛指非关系型的数据库。随着互联网web2.0网站的兴起&#xff0c;传统的关系数据库在应付web2.0网站&#xff0c;特别是超大规模和高并发的SNS类型…

基于IP的H.264关键技术

一、 引言 H.264是ITU-T最新的视频编码标准&#xff0c;被称作ISO/IEC14496-10或MPEG-4 AVC&#xff0c;是由运动图像专家组(MPEG)和ITU的视频编码专家组共同开发的新产品。H.264分两层结构&#xff0c;包括视频编码层和网络适配层。视频编码层处理的是块、宏块和片的数据&…