RNN入门

雷锋网 AI科技评论按:本文作者何之源,原文载于知乎专栏AI Insight,雷锋网(公众号:雷锋网) AI科技评论获其授权发布。

上周写的文章《完全图解RNN、RNN变体、Seq2Seq、Attention机制》介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为:

  • 一个完整的、循序渐进的学习TensorFlow中RNN实现的方法。这个学习路径的曲线较为平缓,应该可以减少不少学习精力,帮助大家少走弯路。

  • 一些可能会踩的坑

  • TensorFlow源码分析

  • 一个Char RNN实现示例,可以用来写诗,生成歌词,甚至可以用来写网络小说!(项目地址:https://github.com/hzy46/Char-RNN-TensorFlow)

一、学习单步的RNN:RNNCell

如果要学习TensorFlow中的RNN,第一站应该就是去了解“RNNCell”,它是TensorFlow中实现RNN的基本单元,每个RNNCell都有一个call方法,使用方式是:(output, next_state) = call(input, state)。

借助图片来说可能更容易理解。假设我们有一个初始状态h0,还有输入x1,调用call(x1, h0)后就可以得到(output1, h1):

 

TensorFlow中RNN实现的正确打开方式

再调用一次call(x2, h1)就可以得到(output2, h2):

TensorFlow中RNN实现的正确打开方式

 

也就是说,每调用一次RNNCell的call方法,就相当于在时间上“推进了一步”,这就是RNNCell的基本功能。

在代码实现上,RNNCell只是一个抽象类,我们用的时候都是用的它的两个子类BasicRNNCell和BasicLSTMCell。顾名思义,前者是RNN的基础类,后者是LSTM的基础类。这里推荐大家阅读其源码实现(地址:http://t.cn/RNJrfMl),一开始并不需要全部看一遍,只需要看下RNNCell、BasicRNNCell、BasicLSTMCell这三个类的注释部分,应该就可以理解它们的功能了。

除了call方法外,对于RNNCell,还有两个类属性比较重要:

  • state_size

  • output_size

前者是隐层的大小,后者是输出的大小。比如我们通常是将一个batch送入模型计算,设输入数据的形状为(batch_size, input_size),那么计算时得到的隐层状态就是(batch_size, state_size),输出就是(batch_size, output_size)。

可以用下面的代码验证一下(注意,以下代码都基于TensorFlow最新的1.2版本):

import tensorflow as tf

import numpy as np

 

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128) # state_size = 128

print(cell.state_size) # 128

 

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态,形状为(batch_size, state_size)

output, h1 = cell.call(inputs, h0) #调用call函数

 

print(h1.shape) # (32, 128)

对于BasicLSTMCell,情况有些许不同,因为LSTM可以看做有两个隐状态h和c,对应的隐层就是一个Tuple,每个都是(batch_size, state_size)的形状:

import tensorflow as tf

import numpy as np

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = lstm_cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态

output, h1 = lstm_cell.call(inputs, h0)

 

print(h1.h)  # shape=(32, 128)

print(h1.c)  # shape=(32, 128)

二、学习如何一次执行多步:tf.nn.dynamic_rnn

基础的RNNCell有一个很明显的问题:对于单个的RNNCell,我们使用它的call函数进行运算时,只是在序列时间上前进了一步。比如使用x1、h0得到h1,通过x2、h1得到h2等。这样的h话,如果我们的序列长度为10,就要调用10次call函数,比较麻烦。对此,TensorFlow提供了一个tf.nn.dynamic_rnn函数,使用该函数就相当于调用了n次call函数。即通过{h0,x1, x2, …., xn}直接得{h1,h2…,hn}。

具体来说,设我们输入数据的格式为(batch_size, time_steps, input_size),其中time_steps表示序列本身的长度,如在Char RNN中,长度为10的句子对应的time_steps就等于10。最后的input_size就表示输入数据单个序列单个时间维度上固有的长度。另外我们已经定义好了一个RNNCell,调用该RNNCell的call函数time_steps次,对应的代码就是:

# inputs: shape = (batch_size, time_steps, input_size)

# cell: RNNCell

# initial_state: shape = (batch_size, cell.state_size)。初始状态。一般可以取零矩阵

outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

此时,得到的outputs就是time_steps步里所有的输出。它的形状为(batch_size, time_steps, cell.output_size)。state是最后一步的隐状态,它的形状为(batch_size, cell.state_size)。

此处建议大家阅读tf.nn.dynamic_rnn的文档(地址:https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)做进一步了解。

三、学习如何堆叠RNNCell:MultiRNNCell

很多时候,单层RNN的能力有限,我们需要多层的RNN。将x输入第一层RNN的后得到隐层状态h,这个隐层状态就相当于第二层RNN的输入,第二层RNN的隐层状态又相当于第三层RNN的输入,以此类推。在TensorFlow中,可以使用tf.nn.rnn_cell.MultiRNNCell函数对RNNCell进行堆叠,相应的示例程序如下:

import tensorflow as tf

import numpy as np

 

# 每调用一次这个函数就返回一个BasicRNNCell

def get_a_cell():
   return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用tf.nn.rnn_cell MultiRNNCell创建3层RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)]) # 3层RNN

# 得到的cell实际也是RNNCell的子类

# 它的state_size是(128, 128, 128)

# (128, 128, 128)并不是128x128x128的意思

# 而是表示共有3个隐层状态,每个隐层状态的大小为128

print(cell.state_size) # (128, 128, 128)

# 使用对应的call函数

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态

output, h1 = cell.call(inputs, h0)

print(h1) # tuple中含有3个32x128的向量

通过MultiRNNCell得到的cell并不是什么新鲜事物,它实际也是RNNCell的子类,因此也有call方法、state_size和output_size属性。同样可以通过tf.nn.dynamic_rnn来一次运行多步。

此处建议阅读MutiRNNCell源码(地址:http://t.cn/RNJrfMl)中的注释进一步了解其功能。

四、可能遇到的坑1:Output说明

在经典RNN结构中有这样的图:

TensorFlow中RNN实现的正确打开方式

在上面的代码中,我们好像有意忽略了调用call或dynamic_rnn函数后得到的output的介绍。将上图与TensorFlow的BasicRNNCell对照来看。h就对应了BasicRNNCell的state_size。那么,y是不是就对应了BasicRNNCell的output_size呢?答案是否定的。

找到源码中BasicRNNCell的call函数实现:

def call(self, inputs, state):
   """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
   output = self._activation(_linear([inputs, state], self._num_units, True))
   return output, output

这句“return output, output”说明在BasicRNNCell中,output其实和隐状态的值是一样的。因此,我们还需要额外对输出定义新的变换,才能得到图中真正的输出y。由于output和隐状态是一回事,所以在BasicRNNCell中,state_size永远等于output_size。TensorFlow是出于尽量精简的目的来定义BasicRNNCell的,所以省略了输出参数,我们这里一定要弄清楚它和图中原始RNN定义的联系与区别。

再来看一下BasicLSTMCell的call函数定义(函数的最后几行):

new_c = (
   c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))

new_h = self._activation(new_c) * sigmoid(o)

 

if self._state_is_tuple:
 new_state = LSTMStateTuple(new_c, new_h)

else:
 new_state = array_ops.concat([new_c, new_h], 1)

return new_h, new_state

我们只需要关注self._state_is_tuple == True的情况,因为self._state_is_tuple == False的情况将在未来被弃用。返回的隐状态是new_c和new_h的组合,而output就是单独的new_h。如果我们处理的是分类问题,那么我们还需要对new_h添加单独的Softmax层才能得到最后的分类概率输出。

还是建议大家亲自看一下源码实现(地址:http://t.cn/RNJsJoH)来搞明白其中的细节。

五、可能遇到的坑2:因版本原因引起的错误

在前面我们讲到堆叠RNN时,使用的代码是:

# 每调用一次这个函数就返回一个BasicRNNCell

def get_a_cell():
   return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用tf.nn.rnn_cell MultiRNNCell创建3层RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)]) # 3层RNN

这个代码在TensorFlow 1.2中是可以正确使用的。但在之前的版本中(以及网上很多相关教程),实现方式是这样的:

one_cell =  tf.nn.rnn_cell.BasicRNNCell(num_units=128)

cell = tf.nn.rnn_cell.MultiRNNCell([one_cell] * 3) # 3层RNN

如果在TensorFlow 1.2中还按照原来的方式定义,就会引起错误!

参考自https://www.leiphone.com/news/201709/QJAIUzp0LAgkF45J.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/567929.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java二进制小数表示_《Java编程的逻辑》笔记9--小数的二进制表示

小数计算为什么会出错?简要答案实际上,不是运算本身会出错,而是计算机根本就不能精确的表示很多数,比如0.1这个数。计算机是用一种二进制格式存储小数的,这个二进制格式不能精确表示0.1,它只能表示一个非常…

『TensorFlow』模型保存和载入方法汇总

一、TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称功能说明默认值var_listSaver中存储变量集合全局变量集合reshape加载时是否恢复变量形状Truesharded是否将变量轮循放在所有设备上Truemax_to_keep保留最近检…

STL13-list容器(链表)

链表是由一系列的结点组成,结点包括两个域:一个数据域,一个指针域 1、链表内存是非连续的,添加删除元素效率较高,时间复杂度都是常数项,不需要移动元素 2、链表只有在需要的时候才会分配内存 3、链表只要…

php 前往页面,PHP实现网页截图?

如何使用PHP实现网页截图PHP实现网页截图是一个在日常开发中不常见的需求,但是如果实现还是非常有意思的。目前业界有很多成熟的方案,下面我推荐使用一个很稳定的第三方服务来直接实现,该服务有如下特点:支持多线路支持登录截图支…

STL14-set/multiset容器

set只有一个方法就是insert #include<iostream> #include<set> //set和multiset是一个头文件 //set内部实现机制 红黑色&#xff08;平衡二叉树的一种&#xff09; //关联式容器 //set不允许有重复元素 //multiset运行有重复元素 //容器查找效率高 //容器根据元素的…

普通的java类型是指,String是一个很普通的类 - Java那些事儿

上一篇我们讲了Java中的数组&#xff0c;其实是为本章的内容做准备的&#xff0c;String这个类是我们在写Java代码中用得最多的一个类&#xff0c;没有之一&#xff0c;今天我们就讲讲它&#xff0c;我们打开String这个类的源码&#xff1a;声明了一个char[]数组&#xff0c;变…

STL15-map/multimap容器

map的key值不可以重复 multimap的key值可以重复 #if 1 #include<iostream> #include<map> using namespace std; //初始化 void test01() {//map容器参数 第一个参数key的类型 第二个参数value类型map<int, int> mymap;//插入元素 pair.first key值 pair.se…

php nginx日志分析,如何通过NGINX的log日志来分析网站的访问情况,试试这些命令...

想知道你的网站每天的访问情况吗&#xff1f;有多少人访问了&#xff1f;访问最多的页面是哪个&#xff1f;哪个时段访问的人最多&#xff1f;哪个地方访问的最多&#xff1f;每秒有多少请求&#xff1f;很好奇吧&#xff0c;只要你是使用了nginx进行请求抓发&#xff0c;那么就…

php带来互联网的影响,网络对我们的影响有哪些?

影响有&#xff1a;1、丰富了我们的业余生活&#xff1b;2、降低了获取知识的成本&#xff0c;降低了提升工作的能力的成本&#xff0c;提高了工作的效率&#xff0c;可以快速建立良好的人脉关系&#xff1b;3、让购物变得更加简单便捷&#xff1b;4、朋友间深度沟通与交流越来…

STL17-函数对象

仿函数&#xff1a; #include<iostream> #include<vector> #include<algorithm> using namespace std; //仿函数&#xff08;函数对象&#xff09;重载“&#xff08;&#xff09;”操作符 使类对象可以像函数那样调用 //仿函数是一个类&#xff0c;不是一个…

STL18常用算法

#include<iostream> #include<algorithm> #include<vector> using namespace std; //transform 将一个容器中的元素搬运在另一个容器中 #if 0 //错误 struct PrintVector {void operator()(int v) {cout << v << " ";} }; void test0…

php中页面平滑回到顶部代码,原生JS实现平滑回到顶部组件

返回顶部组件是一种极其常见的网页功能&#xff0c;需求简单&#xff1a;页面滚动一定距离后&#xff0c;显示返回顶部的按钮&#xff0c;点击该按钮可以将滚动条滚回至页面开始的位置。实现思路也很容易&#xff0c;只要改变document.documentElement.scrollTop或document.bod…

C++基础01-C++对c的增强

所谓namespace&#xff0c;是指标识符的各种可见范围。C标准程序库中的所 有标识符都被定义于一个名为std的namespace中。 一 &#xff1a;<iostream>和<iostream.h>格式不一样&#xff0c;前者没有后缀&#xff0c;实际上&#xff0c; 在你的编译器include文件夹…

C++基础02-C++对c的拓展

变量名实质上是一段连续存储空间的别名&#xff0c;是一个标号(门牌号) 通过变量来申请并命名内存空间. 通过变量的名字可以使用存储空间. 变量名&#xff0c;本身是一段内存的引用&#xff0c;即别名(alias). 引用可以看作一个已定义变量的别名。 引用的语法&#xff…

php小程序onload,微信小程序 loading 组件实例详解

这篇文章主要介绍了微信小程序 loading 组件实例详解的相关资料,需要的朋友可以参考下loading通常使用在请求网络数据时的一种方式&#xff0c;通过hidden属性设置显示与否主要属性&#xff1a;wxml显示loading正在加载jsPage({data:{// text:"这是一个页面"hiddenLo…

C++基础04-类基础

一、类和对象 面向对象三大特点&#xff1a;封装、继承、多态。 struct 中所有行为和属性都是 public 的(默认)。C中的 class 可以指定行为和属性的访问方式。 封装,可以达到,对内开放数据,对外屏蔽数据,对外提供接口。达到了信息隐蔽的功能。 class 封装的本质,在于将数…

C++基础05-类构造函数与析构函数

总结&#xff1a; 1、类对象的作用域为两个{}之间。在遇到}后开始执行析构函数 2、当没有任何显式的构造函数&#xff08;无参&#xff0c;有参&#xff0c;拷贝构造&#xff09;时&#xff0c;默认构造函数才会发挥作用 一旦提供显式的构造函数&#xff0c;默认构造函数不复…

PHP网站配置项,Thinkphp5通用网站后台配置项的动态添加及更新

一、引入无论平时我们自己制作&#xff0c;还是浏览别人的网站&#xff0c;它都具有其相应的一些共用的、通用的属性&#xff0c;比如&#xff1a;网站的名字&#xff0c;关键字、备案号、分页数量、是否开启缓存等信息。一些网站可能将配置项写死在后台&#xff0c;无法动态更…

oracle 查询cpu 100%,Oracle 11g中查询CPU占有率高的SQL

oracle版本&#xff1a;oracle11g背景&#xff1a;今天在Linux中的oracle服务上&#xff0c;运用top命令发现许多进程的CPU占有率是100%。操作步骤&#xff1a;以进程PID:7851为例执行以下语句&#xff1a;方法一&#xff1a;(1)通过PID&#xff0c;查得相对应的系统进程对应的…

C++基础08-this指针-const修饰成员函数-函数返回引用/值

一、this指针 1、C类对象中的成员变量和成员函数是分开存储的。C语言中的内存四区模型仍然有效&#xff01; 2、C中类的普通成员函数都隐式包含一个指向当前对象的this指针。 3、静态成员函数、成员变量属于类 4、静态成员函数与普通成员函数的区别 静态成员函数不包含指…