Tensorflow实现LSTM详解

关于什么是 LSTM 我就不详细阐述了,吴恩达老师视频课里面讲的很好,我大概记录了课上的内容在吴恩达《序列模型》笔记一,网上也有很多写的好的解释,比如:LSTM入门、理解LSTM网络

然而,理解挺简单,上手写的时候还是遇到了很多的问题,网上大部分的博客都没有讲清楚 cell 参数的设置,在我看了N多篇文章后终于搞明白了,写出来让大家少走一些弯路吧!
在这里插入图片描述
如上图是一个LSTM的单元,可以应用到多种RNN结构中,常用的应该是 one-to-manymany-to-many
在这里插入图片描述
在这里插入图片描述
下面介绍 many-to-many 这种结构:

  1. batch_size:批度训练大小,即让 batch_size 个句子同时训练。
  2. time_steps:时间长度,即句子的长度
  3. embedding_size:组成句子的单词的向量长度(embedding size)
  4. hidden_size:隐藏单元数,一个LSTM结构是一个神经网络(如上图就是一个LSTM单元),每个小黄框是一个神经网络,小黄框的隐藏单元数就是hidden_size,那么这个LSTM单元就有 4*hidden_size 个隐藏单元。
  5. 每个LSTM单元的输出 C、h,都是向量,他们的长度都是当前 LSTM 单元的 hidden_size。
  6. n_words:语料库中单词个数。

实现方式一:

import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnndef add_layer(inputs, in_size, out_size, activation_function=None):  # 单层神经网络weights = tf.Variable(tf.random_normal([in_size, out_size]))baises = tf.Variable(tf.zeros([1, out_size]) + 0.1)wx_b = tf.matmul(inputs, weights) + baisesif activation_function is None:outputs = wx_belse:outputs = activation_function(wx_b)return outputsn_words = 15
embedding_size = 8
hidden_size = 8  # 一般hidden_size和embedding_size是相同的
batch_size = 3
time_steps = 5w = tf.Variable(tf.random_normal([n_words, embedding_size], stddev=0.01))  # 模拟参数 W
sentence = tf.Variable(np.arange(15).reshape(batch_size, time_step, 1))    # 模拟训练的句子:3条句子,每个句子5个单词  shape(3,5,1)
input_s = tf.nn.embedding_lookup(w, sentence)  # 将单词映射到向量:每个单词变成了size为8的向量  shape=(3,5,1,8)
input_s = tf.reshape(input_s, [-1, 5, 8])        # shape(3,5,8)with tf.name_scope("LSTM"):  # trustlstm_cell = rnn.BasicLSTMCell(hidden_size, state_is_tuple=True, name='lstm_layer') h_0 = tf.zeros([batch_size, embedding_size])  # shape=(3,8)c_0 = tf.zeros([batch_size, embedding_size])  # shape=(3,8)state = rnn.LSTMStateTuple(c=c_0, h=h_0)      # 设置初始状态outputs = []for i in range(time_steps):  # 句子长度if i > 0: tf.get_variable_scope().reuse_variables()  # 名字相同cell使用的参数w就一样,为了避免重名引起别的的问题,设置一下变量重用output, state = lstm_cell(input_s[:, i, :], state)     # output:[batch_size,embedding_size]  shape=(3,8)outputs.append(output)     # outputs:[TIME_STEP,batch_size,embedding_size]  shape=(5,3,8)path = tf.concat(outputs, 1)   # path:[batch_size,embedding_size*TIME_STEP]   shape=(3, 40)path_embedding = add_layer(path, time_step * embedding_size, embedding_size)  # path_embedding:[batch_size, embedding_size]with tf.Session() as s:s.run(tf.global_variables_initializer())# 因为使用的参数数量都还比较小,打印一些变量看看就能明白是怎么操作的print(s.run(outputs))print(s.run(path_embedding))

比如一批训练64句话,每句话20个单词,每个词向量长度为200,隐藏层单元个数为128
那么训练一批句子,输入的张量维度是[64,20,200],ht,ct​ 的维度是[128],那么LSTM单元参数矩阵的维度是[128+200,4x128],
在时刻1,把64句话的第一个单词作为输入,即输入一个[64,200]的矩阵,由于会和 ht 进行concat,输入矩阵变成了[64,200+128],输入矩阵会和参数矩阵[200+128,4x128]相乘,输出为[64,4x128],也就是每个黄框的输出为[64,128],黄框之间会进行一些操作,但不改变维度,输出依旧是[64,128],即每个句子经过LSTM单元后,输出的维度是128,所以每个LSTM输出的都是向量,包括Ct,ht,所以它们的长度都是当前LSTM单元的hidden_size 。那么我们就知道cell_output的维度为[64,128]
之后的时刻重复刚才同样的操作,那么outputs的维度是[20,64,128].
softmax相当于全连接层,将outputs映射到vocab_size个单词上,进行交叉熵误差计算。
然后根据误差更新LSTM参数矩阵和全连接层的参数。

实现方式二:

测试数据链接:https://pan.baidu.com/s/1j9sgPmWUHM5boM5ekj3Q2w 提取码:go3f

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tfdata = pd.read_excel("seq_data.xlsx")  # 读取序列数据
data = data.values[1:800]   # 取前800个
normalize_data = (data - np.mean(data)) / np.std(data)  # 标准化数据
s = np.std(data)
m = np.mean(data)
time_step = 96   # 序列段长度
rnn_unit = 8     # 隐藏层节点数目
lstm_layers = 2  # cell层数
batch_size = 7   # 序列段批处理数目
input_size = 1   # 输入维度
output_size = 1  # 输出维度
lr = 0.006       # 学习率train_x, train_y = [], []
for i in range(len(data) - time_step - 1):x = normalize_data[i:i + time_step]y = normalize_data[i + 1:i + time_step + 1]train_x.append(x.tolist())train_y.append(y.tolist())
X = tf.placeholder(tf.float32, [None, time_step, input_size])  # shape(?,time_step, input_size)
Y = tf.placeholder(tf.float32, [None, time_step, output_size])  # shape(?,time_step, out_size)
weights = {'in': tf.Variable(tf.random_normal([input_size, rnn_unit])),'out': tf.Variable(tf.random_normal([rnn_unit, 1]))}
biases = {'in': tf.Variable(tf.constant(0.1, shape=[rnn_unit, ])),'out': tf.Variable(tf.constant(0.1, shape=[1, ]))}
def lstm(batch):w_in = weights['in']b_in = biases['in']input = tf.reshape(X, [-1, input_size])input_rnn = tf.matmul(input, w_in) + b_ininput_rnn = tf.reshape(input_rnn, [-1, time_step, rnn_unit])cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(rnn_unit) for i in range(lstm_layers)])init_state = cell.zero_state(batch, dtype=tf.float32)output_rnn, final_states = tf.nn.dynamic_rnn(cell, input_rnn, initial_state=init_state, dtype=tf.float32)output = tf.reshape(output_rnn, [-1, rnn_unit])w_out = weights['out']b_out = biases['out']pred = tf.matmul(output, w_out) + b_outreturn pred, final_statesdef train_lstm():global batch_sizewith tf.variable_scope("sec_lstm"):pred, _ = lstm(batch_size)loss = tf.reduce_mean(tf.square(tf.reshape(pred, [-1]) - tf.reshape(Y, [-1])))train_op = tf.train.AdamOptimizer(lr).minimize(loss)saver = tf.train.Saver(tf.global_variables())loss_list = []with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(100):  # We can increase the number of iterations to gain better result.start = 0end = start + batch_sizewhile (end < len(train_x)):_, loss_ = sess.run([train_op, loss], feed_dict={X: train_x[start:end], Y: train_y[start:end]})start += batch_sizeend = end + batch_sizeloss_list.append(loss_)if i % 10 == 0:print("Number of iterations:", i, " loss:", loss_list[-1])if i > 0 and loss_list[-2] > loss_list[-1]:saver.save(sess, 'model_save1\\modle.ckpt')# I run the code in windows 10,so use  'model_save1\\modle.ckpt'# if you run it in Linux,please use  'model_save1/modle.ckpt'print("The train has finished")train_lstm()def prediction():with tf.variable_scope("sec_lstm", reuse=tf.AUTO_REUSE):pred, _ = lstm(1)saver = tf.train.Saver(tf.global_variables())with tf.Session() as sess:saver.restore(sess, 'model_save1\\modle.ckpt')# I run the code in windows 10,so use  'model_save1\\modle.ckpt'# if you run it in Linux,please use  'model_save1/modle.ckpt'predict = []for i in range(0, np.shape(train_x)[0]):next_seq = sess.run(pred, feed_dict={X: [train_x[i]]})predict.append(next_seq[-1])plt.figure()plt.plot(list(range(len(data))), data, color='b')plt.plot(list(range(time_step + 1, np.shape(train_x)[0] + 1 + time_step)), [value * s + m for value in predict],color='r')plt.show()prediction()

参考文章:

基于TensorFlow构建LSTM
TensorFlow实战:LSTM的结构与cell中的参数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480022.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文浅尝 | 一个模型解决所有问题:实体和事件的神经联合模型

笔记整理&#xff1a;康矫健&#xff0c;浙江大学计算机科学与技术系&#xff0c;硕士研究生。论文链接&#xff1a;https://arxiv.org/pdf/1812.00195.pdf发表会议&#xff1a;AAAI 2019摘要 近来&#xff0c;针对事件抽取的工作大都集中在预测事件的triggers和arguments r…

AutoPep8-----Pycharm自动排版工具

查找pycharm中的external tool的步骤&#xff1a; https://jingyan.baidu.com/article/84b4f565bd39a060f6da3211.html 今天从 PyCharm 入手&#xff0c;写一些可以明显改善开发效率的使用技巧&#xff0c;一旦学会&#xff0c;受用一生。以下代码演示是在 Mac 环境下&#xf…

阿里P8架构师谈:java架构师面试技能24全点

1,JAVA基础扎实&#xff0c;理解io、多线程、集合等基础框架&#xff0c;对JVM原理有一定的了解&#xff0c;熟悉常见类库,常见java api不仅会用更能知其所以然&#xff1b; 2,对Spring,MyBatis/Hibernate&#xff0c;Struts2,SpringMVC等开源框架熟悉并且了解到它的基本原理和…

百度自然语言处理部招人啦!正式、实习都要!研究、落地都有!

星标/置顶小屋&#xff0c;带你解锁最萌最前沿的NLP、搜索与推荐技术2010年&#xff0c;百度自然语言处理部正式成立。十年来&#xff0c;百度NLP聚集了一大批兼具扎实技术实力和实践经验的AI人才&#xff0c;获得数十项国内外权威奖项&#xff0c;申请专利上千件&#xff0c;发…

我与导师的聊天记录

虽然导师远在马来西亚&#xff0c;但是每次都是很耐心的回答我的问题&#xff0c;真的是非常感激啦&#xff01; 我就想记录下来&#xff0c;自己提出的问题&#xff0c;老师给我的解答&#xff0c;算是我研究生生涯的很大一部分生活了吧&#xff01; 噢~ 还有就是&#xff0c;…

论文浅尝 | 面向知识图谱补全的共享嵌入神经网络模型

论文笔记整理&#xff1a;谭亦鸣&#xff0c;东南大学博士生&#xff0c;研究方向为跨语言知识图谱问答。来源&#xff1a;CIKM’2018链接&#xff1a;http://delivery.acm.org/10.1145/3280000/3271704/p247-guan.pdf?ip121.249.15.96&id3271704&accACTIVE%20SERVICE…

PyTorch常用代码段合集

文 | Jack Stark知乎编 | 极市平台来源 | https://zhuanlan.zhihu.com/p/104019160导读本文是PyTorch常用代码段合集&#xff0c;涵盖基本配置、张量处理、模型定义与操作、数据处理、模型训练与测试等5个方面&#xff0c;还给出了多个值得注意的Tips&#xff0c;内容非常全面。…

想成长为一名实战型架构师?7大实战技能经验分享

很多同学想成为一名架构师,但是对于其中的技能掌握程度&#xff0c;以及编程功底的要求&#xff1f;设计能力的要求有哪些&#xff1f; 我简要从以下7点经验来谈&#xff0c;从技能的角度抛砖引玉。 编程基本功&#xff1a;数据结构和算法 1.数据结构相关的哈希表、链表、二叉…

LeetCode 70. 爬楼梯(动态规划)

题目链接&#xff1a;https://leetcode-cn.com/problems/climbing-stairs/ 之前在递归中讲过这个问题&#xff0c;现在用动态规划求解。 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 注意&…

技术动态 | 藏经阁计划发布一年,阿里知识引擎有哪些技术突破?

本文转载自公众号&#xff1a;阿里技术。导读&#xff1a;2018年4月阿里巴巴业务平台事业部——知识图谱团队联合清华大学、浙江大学、中科院自动化所、中科院软件所、苏州大学等五家机构&#xff0c;联合发布藏经阁&#xff08;知识引擎&#xff09;研究计划。藏经阁计划依赖阿…

python中模块、函数与各个模块之间的调用

1 针对一个模块的函数调用 a &#xff1a; import 模块名 模块名.函数名 b&#xff1a; from 模块名 import 函数名 &#xff08;as 别名&#xff09; python调用另一个.py文件中的类和函数 同一文件夹下的调用 1.调用函数 A.py文件如下&#xff1a; def add(x,y): print(‘和…

模拟退火算法求解TSP问题

前言&#xff1a;模拟退火&#xff08;simulated annealing&#xff09;技术&#xff0c;在每一步都以一定的概率接受比当前结果更差的结果&#xff0c;从而有助于“跳出”局部极小。在每次迭代过程中&#xff0c;接受“次优解”的概率要随着时间的推移而逐渐降低&#xff0c;从…

一篇文章彻底搞懂“分布式事务”

在如今的分布式盛行的时代&#xff0c;分布式事务永远都是绕不开的一个话题&#xff0c;今天就谈谈分布式事务相关的一致性与实战解决方案。 01 为什么需要分布式事务 由于近十年互联网的发展非常迅速&#xff0c;很多网站的访问越来越大&#xff0c;集中式环境已经不能满足业…

C++很难学?这个ACM金牌大佬可不这么认为!

C作为一门底层可操作性很强的语言&#xff0c;广泛应用于游戏开发、工业和追求性能、速度的应用。比如腾讯&#xff0c;无论游戏&#xff0c;还是微信&#xff0c;整个鹅厂后台几乎都是 C 开发&#xff0c;对 C 开发者的需求非常大。但问题是C入门和精通都比较困难&#xff0c;…

数据结构--位图 BitMap

文章目录1. 位图2. 位图代码3. 布隆过滤器 Bloom Filter4. 总结1. 位图 我们有1千万个整数&#xff0c;整数的范围在1到1亿之间。如何快速查找某个整数是否在这1千万个整数中呢&#xff1f; 当然&#xff0c;这个问题可以用散列表来解决。可以使用一种特殊的散列表&#xff0…

领域应用 | 企业效益最大化的秘密:知识图谱

本文转载自公众号&#xff1a;TigerGraph。凡是有关系的地方都可以用知识图谱。知识图谱知识图谱是用节点和关系所组成的图谱&#xff0c;为真实世界的各个场景直观地建模&#xff0c;运用“图”这种基础性、通用性的“语言”&#xff0c;“高保真”地表达这个多姿多彩世界的各…

国家一级职业资格证书 计算机类有哪些

当前bai&#xff0c;计算机证书考试多种du多样&#xff0c;水平参差不齐。比较正规且得到社会zhi认可的dao计算机证书考试有以下几种&#xff1a;全国计算机应用软件人员水平考试、计算机等级考试、计算机及信息高新技术考试、计算机应用水平测试和各种国外著名大计算机公司组织…

阿里P8架构师谈:分布式系统全局唯一ID简介、特点、5种生成方式

什么是分布式系统唯一ID 在复杂分布式系统中&#xff0c;往往需要对大量的数据和消息进行唯一标识。 如在金融、电商、支付、等产品的系统中&#xff0c;数据日渐增长&#xff0c;对数据分库分表后需要有一个唯一ID来标识一条数据或消息&#xff0c;数据库的自增ID显然不能满足…

朴素贝叶斯算法--过滤垃圾短信

文章目录1. 基于黑名单过滤2. 基于规则过滤3. 基于概率统计过滤4. 总结上一节我们讲到&#xff0c;如何用位图、布隆过滤器&#xff0c;来 过滤重复数据。今天&#xff0c;我们再讲一个跟过滤相关的问题&#xff0c;如何过滤垃圾短信&#xff1f;1. 基于黑名单过滤 可以维护一…

2020深度文本匹配最新进展:精度、速度我都要!

文 | QvQ编 | 兔子酱在过去的几年里&#xff0c;信息检索(IR)领域见证了一系列神经排序模型的引入&#xff0c;这些模型多是基于表示或基于交互的&#xff0c;亦或二者的融合。然鹅&#xff0c;模型虽非常有效&#xff0c;尤其是基于 PLMs 的排序模型更是增加了几个数量级的计算…