TensorFlow中RNN实现的正确打开方式

上周写的文章《完全图解RNN、RNN变体、Seq2Seq、Attention机制》介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为:

  • 一个完整的、循序渐进的学习TensorFlow中RNN实现的方法。这个学习路径的曲线较为平缓,应该可以减少不少学习精力,帮助大家少走弯路。

  • 一些可能会踩的坑

  • TensorFlow源码分析

  • 一个Char RNN实现示例,可以用来写诗,生成歌词,甚至可以用来写网络小说!(项目地址:https://github.com/hzy46/Char-RNN-TensorFlow

一、学习单步的RNN:RNNCell

如果要学习TensorFlow中的RNN,第一站应该就是去了解“RNNCell”,它是TensorFlow中实现RNN的基本单元,每个RNNCell都有一个call方法,使用方式是:(output, next_state) = call(input, state)。

借助图片来说可能更容易理解。假设我们有一个初始状态h0,还有输入x1,调用call(x1, h0)后就可以得到(output1, h1):


TensorFlow中RNN实现的正确打开方式

再调用一次call(x2, h1)就可以得到(output2, h2):

TensorFlow中RNN实现的正确打开方式

也就是说,每调用一次RNNCell的call方法,就相当于在时间上“推进了一步”,这就是RNNCell的基本功能。

在代码实现上,RNNCell只是一个抽象类,我们用的时候都是用的它的两个子类BasicRNNCell和BasicLSTMCell。顾名思义,前者是RNN的基础类,后者是LSTM的基础类。这里推荐大家阅读其源码实现(地址:http://t.cn/RNJrfMl),一开始并不需要全部看一遍,只需要看下RNNCell、BasicRNNCell、BasicLSTMCell这三个类的注释部分,应该就可以理解它们的功能了。

除了call方法外,对于RNNCell,还有两个类属性比较重要:

  • state_size

  • output_size

前者是隐层的大小,后者是输出的大小。比如我们通常是将一个batch送入模型计算,设输入数据的形状为(batch_size, input_size),那么计算时得到的隐层状态就是(batch_size, state_size),输出就是(batch_size, output_size)。

可以用下面的代码验证一下(注意,以下代码都基于TensorFlow最新的1.2版本):

import tensorflow as tf

import numpy as np


cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128) # state_size = 128

print(cell.state_size) # 128


inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态,形状为(batch_size, state_size)

output, h1 = cell.call(inputs, h0) #调用call函数


print(h1.shape) # (32, 128)

对于BasicLSTMCell,情况有些许不同,因为LSTM可以看做有两个隐状态h和c,对应的隐层就是一个Tuple,每个都是(batch_size, state_size)的形状:

import tensorflow as tf

import numpy as np

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = lstm_cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态

output, h1 = lstm_cell.call(inputs, h0)


print(h1.h)  # shape=(32, 128)

print(h1.c)  # shape=(32, 128)

二、学习如何一次执行多步:tf.nn.dynamic_rnn

基础的RNNCell有一个很明显的问题:对于单个的RNNCell,我们使用它的call函数进行运算时,只是在序列时间上前进了一步。比如使用x1、h0得到h1,通过x2、h1得到h2等。这样的h话,如果我们的序列长度为10,就要调用10次call函数,比较麻烦。对此,TensorFlow提供了一个tf.nn.dynamic_rnn函数,使用该函数就相当于调用了n次call函数。即通过{h0,x1, x2, …., xn}直接得{h1,h2…,hn}。

具体来说,设我们输入数据的格式为(batch_size, time_steps, input_size),其中time_steps表示序列本身的长度,如在Char RNN中,长度为10的句子对应的time_steps就等于10。最后的input_size就表示输入数据单个序列单个时间维度上固有的长度。另外我们已经定义好了一个RNNCell,调用该RNNCell的call函数time_steps次,对应的代码就是:

# inputs: shape = (batch_size, time_steps, input_size)

# cell: RNNCell

# initial_state: shape = (batch_size, cell.state_size)。初始状态。一般可以取零矩阵

outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

此时,得到的outputs就是time_steps步里所有的输出。它的形状为(batch_size, time_steps, cell.output_size)。state是最后一步的隐状态,它的形状为(batch_size, cell.state_size)。

此处建议大家阅读tf.nn.dynamic_rnn的文档(地址:https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)做进一步了解。

三、学习如何堆叠RNNCell:MultiRNNCell

很多时候,单层RNN的能力有限,我们需要多层的RNN。将x输入第一层RNN的后得到隐层状态h,这个隐层状态就相当于第二层RNN的输入,第二层RNN的隐层状态又相当于第三层RNN的输入,以此类推。在TensorFlow中,可以使用tf.nn.rnn_cell.MultiRNNCell函数对RNNCell进行堆叠,相应的示例程序如下:

import tensorflow as tf

import numpy as np


# 每调用一次这个函数就返回一个BasicRNNCell

def get_a_cell():
   return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用tf.nn.rnn_cell MultiRNNCell创建3层RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)]) # 3层RNN

# 得到的cell实际也是RNNCell的子类

# 它的state_size是(128, 128, 128)

# (128, 128, 128)并不是128x128x128的意思

# 而是表示共有3个隐层状态,每个隐层状态的大小为128

print(cell.state_size) # (128, 128, 128)

# 使用对应的call函数

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态

output, h1 = cell.call(inputs, h0)

print(h1) # tuple中含有3个32x128的向量

通过MultiRNNCell得到的cell并不是什么新鲜事物,它实际也是RNNCell的子类,因此也有call方法、state_size和output_size属性。同样可以通过tf.nn.dynamic_rnn来一次运行多步。

此处建议阅读MutiRNNCell源码(地址:http://t.cn/RNJrfMl)中的注释进一步了解其功能。

四、可能遇到的坑1:Output说明

在经典RNN结构中有这样的图:

TensorFlow中RNN实现的正确打开方式

在上面的代码中,我们好像有意忽略了调用call或dynamic_rnn函数后得到的output的介绍。将上图与TensorFlow的BasicRNNCell对照来看。h就对应了BasicRNNCell的state_size。那么,y是不是就对应了BasicRNNCell的output_size呢?答案是否定的。

找到源码中BasicRNNCell的call函数实现:

def call(self, inputs, state):
   """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
   output = self._activation(_linear([inputs, state], self._num_units, True))
   return output, output

这句“return output, output”说明在BasicRNNCell中,output其实和隐状态的值是一样的。因此,我们还需要额外对输出定义新的变换,才能得到图中真正的输出y。由于output和隐状态是一回事,所以在BasicRNNCell中,state_size永远等于output_size。TensorFlow是出于尽量精简的目的来定义BasicRNNCell的,所以省略了输出参数,我们这里一定要弄清楚它和图中原始RNN定义的联系与区别。

再来看一下BasicLSTMCell的call函数定义(函数的最后几行):

new_c = (
   c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))

new_h = self._activation(new_c) * sigmoid(o)


if self._state_is_tuple:
 new_state = LSTMStateTuple(new_c, new_h)

else:
 new_state = array_ops.concat([new_c, new_h], 1)

return new_h, new_state

我们只需要关注self._state_is_tuple == True的情况,因为self._state_is_tuple == False的情况将在未来被弃用。返回的隐状态是new_c和new_h的组合,而output就是单独的new_h。如果我们处理的是分类问题,那么我们还需要对new_h添加单独的Softmax层才能得到最后的分类概率输出。

还是建议大家亲自看一下源码实现(地址:http://t.cn/RNJsJoH)来搞明白其中的细节。

五、可能遇到的坑2:因版本原因引起的错误

在前面我们讲到堆叠RNN时,使用的代码是:

# 每调用一次这个函数就返回一个BasicRNNCell

def get_a_cell():
   return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用tf.nn.rnn_cell MultiRNNCell创建3层RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)]) # 3层RNN

这个代码在TensorFlow 1.2中是可以正确使用的。但在之前的版本中(以及网上很多相关教程),实现方式是这样的:

one_cell =  tf.nn.rnn_cell.BasicRNNCell(num_units=128)

cell = tf.nn.rnn_cell.MultiRNNCell([one_cell] * 3) # 3层RNN

如果在TensorFlow 1.2中还按照原来的方式定义,就会引起错误!

六、一个练手项目:Char RNN

上面的内容实际上就是TensorFlow中实现RNN的基本知识了。这个时候,建议大家用一个项目来练习巩固一下。此处特别推荐Char RNN项目,这个项目对应的是经典的RNN结构,实现它使用的TensorFlow函数就是上面说到的几个,项目本身又比较有趣,可以用来做文本生成,平常大家看到的用深度学习来写诗写歌词的基本用的就是它了。

Char RNN的实现已经有很多了,可以自己去Github上面找,我这里也做了一个实现,供大家参考。项目地址为:hzy46/Char-RNN-TensorFlow(地址:https://github.com/hzy46/Char-RNN-TensorFlow)。代码的部分实现来自于《安娜卡列尼娜文本生成——利用TensorFlow构建LSTM模型》

这篇专栏,在此感谢 @天雨粟 。

我主要向代码中添加了embedding层,以支持中文,另外重新整理了代码结构,将API改成了最新的TensorFlow 1.2版本。

可以用这个项目来写诗(以下诗句都是自动生成的):

何人无不见,此地自何如。
一夜山边去,江山一夜归。
山风春草色,秋水夜声深。
何事同相见,应知旧子人。
何当不相见,何处见江边。
一叶生云里,春风出竹堂。
何时有相访,不得在君心。

还可以生成代码:

static int page_cpus(struct flags *str)
{
       int rc;
       struct rq *do_init;
};

/*
* Core_trace_periods the time in is is that supsed,
*/
#endif

/*
* Intendifint to state anded.
*/
int print_init(struct priority *rt)
{       /* Comment sighind if see task so and the sections */
       console(string, &can);
}

此外生成英文更不是问题(使用莎士比亚的文本训练):

LAUNCE:
The formity so mistalied on his, thou hast she was
to her hears, what we shall be that say a soun man
Would the lord and all a fouls and too, the say,
That we destent and here with my peace.

PALINA:
Why, are the must thou art breath or thy saming,
I have sate it him with too to have me of
I the camples.

最后,如果你脑洞够大,还可以来做一些更有意思的事情,比如我用了著名的网络小说《斗破苍穹》训练了一个RNN模型,可以生成下面的文本:

闻言,萧炎一怔,旋即目光转向一旁的那名灰袍青年,然后目光在那位老者身上扫过,那里,一个巨大的石台上,有着一个巨大的巨坑,一些黑色光柱,正在从中,一道巨大的黑色巨蟒,一股极度恐怖的气息,从天空上暴射而出 ,然后在其中一些一道道目光中,闪电般的出现在了那些人影,在那种灵魂之中,却是有着许些强者的感觉,在他们面前,那一道道身影,却是如同一道黑影一般,在那一道道目光中,在这片天地间,在那巨大的空间中,弥漫而开……

“这是一位斗尊阶别,不过不管你,也不可能会出手,那些家伙,可以为了这里,这里也是能够有着一些异常,而且他,也是不能将其他人给你的灵魂,所以,这些事,我也是不可能将这一个人的强者给吞天蟒,这般一次,我们的实力,便是能够将之击杀……”

“这里的人,也是能够与魂殿强者抗衡。”

萧炎眼眸中也是掠过一抹惊骇,旋即一笑,旋即一声冷喝,身后那些魂殿殿主便是对于萧炎,一道冷喝的身体,在天空之上暴射而出,一股恐怖的劲气,便是从天空倾洒而下。

“嗤!”

还是挺好玩的吧,另外还尝试了生成日文等等。

七、学习完整版的LSTMCell

上面只说了基础版的BasicRNNCell和BasicLSTMCell。TensorFlow中还有一个“完全体”的LSTM:LSTMCell。这个完整版的LSTM可以定义peephole,添加输出的投影层,以及给LSTM的遗忘单元设置bias等,可以参考其源码(地址:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L417)了解使用方法。

八、学习最新的Seq2Seq API

Google在TensorFlow的1.2版本(1.3.0的rc版已经出了,貌似正式版也要出了,更新真是快)中更新了Seq2Seq API,使用这个API我们可以不用手动地去定义Seq2Seq模型中的Encoder和Decoder。此外它还和1.2版本中的新数据读入方式Datasets兼容。可以阅读此处的文档(地址:http://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq)学习它的使用方法。

九、总结

最后简单地总结一下,这篇文章提供了一个学习TensorFlow RNN实现的详细路径,其中包括了学习顺序、可能会踩的坑、源码分析以及一个示例项目hzy46/Char-RNN-TensorFlow(地址:https://github.com/hzy46/Char-RNN-TensorFlow),希望能对大家有所帮助。




本文作者:Non
本文转自雷锋网禁止二次转载,原文链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/287426.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【遥感物候】Hants NDVI时间序列谐波分析法数据重构,植被生长季曲线效果可佳(附Hants软件下载)

NDVI时间序列谐波分析法(Harmonic Analysis of NDVI Time-Series)(简称Hants )对时间序列数据进行平滑。该方法是一种新的物候分析方法,可用于定量化的监测植被动态变化。其核心算法是傅里叶变换和最小二乘法拟合, 即把时间波谱数据分解成许多不同频率的正弦曲线和余弦曲线,…

Android之在Java socket作为服务器里面返回数据头部怎么写入浏览器需要下载文件的文件名

1 问题 Android app里面写了一个Java socket的简单服务器,在浏览器里面输入相应的IP和端口访问服务器下载文件,Java socket怎么写返回数据的头部信息,浏览器才知道需要下载文件的名字呢? 2 关于Content-Disposition 在常规的HTTP应答中,Content-Disposition 响应头指示回…

java中hasnext的作用_java中Scanner的hasNext()的疑问

第一个问题,两段代码的区别在于阻塞的位置不同,加上一行输出代码就可以很明显地看到差别。Test.javaimport java.util.Scanner;public class Test {public static void main(String[] args) {Scanner s new Scanner(System.in);while(s.hasNext()){Syst…

《看聊天记录都学不会C语言?太菜了吧》(2)我说编程很容易你们不服?

若是大一学子或者是真心想学习刚入门的小伙伴可以私聊我,若你是真心学习可以送你书籍,指导你学习,给予你目标方向的学习路线,无套路,博客为证。 本系列文章将会以通俗易懂的对话方式进行教学,对话中将涵盖…

ABAP的自学之路 ,初步认识ABAP 一

由于工作的关系,最近需要对SAP系统进行二次开发,于是开始学习ABAP。鉴于网上对于ABAP的资料少之又少,所以自己整理一些资料。 第一章 ABAP 开发环境和总体介绍1.1 ABAP 开发环境ABAP 开发的三种环境:(1)SAP…

LCD1602,4位数据总线液晶屏时钟,STC12C5A60S2的10位ADC功能程序

/* 程序名:    LCD1602,4位数据总线液晶屏时钟,STC12C5A60S2的10位ADC功能程序 编写时间:  2015年10月4日 硬件支持:  LCD1602液晶屏 STC12C5A60S2 外部12MHZ晶振 接线定义: DB7 --> P1^7DB6…

WPF|黑暗模式的钱包支付仪表盘界面设计

收集下大家的意见,是否需要在文中贴上源码(文末会给出源码链接),请大家踊跃留言。阅读目录效果展示准备简单说明 源码结尾(视频及源码仓库)1. 效果展示欣赏效果:2. 准备创建一个WPF工程&#x…

量子计算机的现状和趋势

量子计算机概述 计算机是一种新型的运算 它具有具有强大的并行处理数据的能力,可解决现有计算机难以运算的数学问题。因此,它成为世界各国战略竞争的焦点。 量子计算机的优势 量子计算机与现有的电子计算机以及正在研究的光计算机,生物计算机…

【空间数据库】Windows操作系统PostgreSQL+PostGIS环境搭建图文安装教程

PostgreSQL是一种特性非常齐全的自由软件的对象-关系型数据库管理系统(ORDBMS),PostgreSQL支持大部分的SQL标准并且提供了很多其他现代特性,如复杂查询、外键、触发器、视图、事务完整性、多版本并发控制等。同样,PostgreSQL也可以用许多方法扩展,例如通过增加新的数据类…

Android之gravity=“center_vertical“和layout_gravity=“center“的效果

1、两控件分别加上2个下面的属性 gravity="center_vertical" android:layout_gravity="center" 代码如下 <LinearLayoutandroid:id="@+id/ll_no_love"android:layout_width="match_parent"android:layout_height="match…

《看聊天记录都学不会C语言?太菜了吧》(3)人艰不拆,代码都在谈恋爱?!

若是大一学子或者是真心想学习刚入门的小伙伴可以私聊我&#xff0c;若你是真心学习可以送你书籍&#xff0c;指导你学习&#xff0c;给予你目标方向的学习路线&#xff0c;无套路&#xff0c;博客为证。 本系列文章将会以通俗易懂的对话方式进行教学&#xff0c;对话中将涵盖…

spark java 计数_spark程序——统计包含字符a或者b的行数

本篇分析一个spark例子程序。程序实现的功能是&#xff1a;分别统计包含字符a、b的行数。java源码如下&#xff1a;package sparkTest;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import…

golang reflect

reflect包实现了运行时反射&#xff0c;允许程序操作任意类型的对象。典型用法是用静态类型interface{}保存一个值&#xff0c;通过调用TypeOf获取其动态类型信息&#xff0c;该函数返回一个Type类型值。调用ValueOf函数返回一个Value类型值&#xff0c;该值代表运行时的数据。…

DB2常用命令

查看DB2License信息 DB2基础命令 转载于:https://www.cnblogs.com/arcer/p/5573317.html

.NET7 Preview4之MapGroup

这篇是“闻(看)香(码)识(学)女(技)人(术)”。这也是一个有意思的功能&#xff0c;路由分组&#xff0c;啥也不说了&#xff0c;看代码看结果&#xff1a;using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.OpenApi;var builder WebApplication.Create…

【空间数据库】ArcGIS 10.6 Database_Server_Desktop安装、连接数据库服务、创建企业级数据库(附server10.6.ecp)

由于作者一直使用SQL Server 2008 R2开发版,之前在ArcGIS中创建企业级数据库都是基于单独安装的SQL Server 2008 R2开发版,今天我们演示安装ArcGIS10.6自带的数据库服务(SQL Server 2014 Express版本)、连接数据库服务和创建数据库。 首先,我们来看一下完整的ArcGIS10.6安…

(一)easyUI之树形网络

树形网格&#xff08;TreeGrid&#xff09;可以展示有限空间上带有多列和复杂数据电子表 一、案例一&#xff1a;按tree的数据结构来生成 前台<% page language"java" contentType"text/html; charsetUTF-8"pageEncoding"UTF-8"%> <!DO…

《看聊天记录都学不会C语言?太菜了吧》(4)零基础的我原来早就学会编程了?

若是大一学子或者是真心想学习刚入门的小伙伴可以私聊我&#xff0c;若你是真心学习可以送你书籍&#xff0c;指导你学习&#xff0c;给予你目标方向的学习路线&#xff0c;无套路&#xff0c;博客为证。 本系列文章将会以通俗易懂的对话方式进行教学&#xff0c;对话中将涵盖…

Android之华为平板打日志提示Permission denied

1 问题 $ adb logcat | grep ssfsafaf int logctl_get(): open /dev/hwlog_switch fail -1, 13. Permission deniedNote: log switch off, only log_main and log_events will have logs!2 解决办法 1&#xff09;、如果是华为手机&#xff0c;打开手机的拨号界面&#xff0c…

二叉树结构 codevs 1029 遍历问题

codevs 1029 遍历问题 时间限制: 1 s空间限制: 128000 KB题目等级 : 钻石 Diamond题目描述 Description我们都很熟悉二叉树的前序、中序、后序遍历&#xff0c;在数据结构中常提出这样的问题&#xff1a;已知一棵二叉树的前序和中序遍历&#xff0c;求它的后序遍历&#xff0c;…