lstm需要优化的参数_使用PyTorch手写代码从头构建LSTM,更深入的理解其工作原理...

这是一个造轮子的过程,但是从头构建LSTM能够使我们对体系结构进行更加了解,并将我们的研究带入下一个层次。

LSTM单元是递归神经网络深度学习研究领域中最有趣的结构之一:它不仅使模型能够从长序列中学习,而且还为长、短期记忆创建了一个数值抽象,可以在需要时相互替换。

751f844ccf62f3619125d59300241223.png

在这篇文章中,我们不仅将介绍LSTM单元的体系结构,还将通过PyTorch手工实现它。

最后但最不重要的是,我们将展示如何对我们的实现做一些小的调整,以实现一些新的想法,这些想法确实出现在LSTM研究领域,如peephole。

LSTM体系结构

LSTM被称为门结构:一些数学运算的组合,这些运算使信息流动或从计算图的那里保留下来。因此,它能够“决定”其长期和短期记忆,并输出对序列数据的可靠预测:

1b6a8415463c8ddebb44ba41672bdee1.png

LSTM单元中的预测序列。注意,它不仅会传递预测值,而且还会传递一个c,c是长期记忆的代表

遗忘门

遗忘门(forget gate)是输入信息与候选者一起操作的门,作为长期记忆。请注意,在输入、隐藏状态和偏差的第一个线性组合上,应用一个sigmoid函数:

344929c78260868950e1737331a9dde8.png

sigmoid将遗忘门的输出“缩放”到0-1之间,然后,通过将其与候选者相乘,我们可以将其设置为0,表示长期记忆中的“遗忘”,或者将其设置为更大的数字,表示我们从长期记忆中记住的“多少”。

新型长时记忆的输入门及其解决方案

输入门是将包含在输入和隐藏状态中的信息组合起来,然后与候选和部分候选c''u t一起操作的地方:

d0dfb1555009820c429fca38a56a8402.png

在这些操作中,决定了多少新信息将被引入到内存中,如何改变——这就是为什么我们使用tanh函数(从-1到1)。我们将短期记忆和长期记忆中的部分候选组合起来,并将其设置为候选。

单元的输出门和隐藏状态(输出)

之后,我们可以收集ot作为LSTM单元的输出门,然后将其乘以候选单元(长期存储器)的tanh,后者已经用正确的操作进行了更新。网络输出为ht。

2eda49c81569ac2a102ddce6a557e602.png

LSTM单元方程

95673b00cbe1aeebef348057b550a9c7.png

在PyTorch上实现

import math
import torch
import torch.nn as nn

我们现在将通过继承nn.Module,然后还将引用其参数和权重初始化,如下所示(请注意,其形状由网络的输入大小和输出大小决定):

class NaiveCustomLSTM(nn.Module):
def __init__(self, input_sz: int, hidden_sz: int):
super().__init__()
self.input_size = input_sz
self.hidden_size = hidden_sz
#i_t
self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_i = nn.Parameter(torch.Tensor(hidden_sz))
#f_t
self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
#c_t
self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_c = nn.Parameter(torch.Tensor(hidden_sz))
#o_t
self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_o = nn.Parameter(torch.Tensor(hidden_sz))
self.init_weights()

要了解每个操作的形状,请看:

矩阵的输入形状是(批量大小、序列长度、特征长度),因此将序列的每个元素相乘的权重矩阵必须具有该形状(特征长度、输出长度)。

序列上每个元素的隐藏状态(也称为输出)都具有形状(批大小、输出大小),这将在序列处理结束时产生输出形状(批大小、序列长度、输出大小)。-因此,将其相乘的权重矩阵必须具有与单元格的参数hiddensz相对应的形状(outputsize,output_size)。

这里是权重初始化,我们将其用作PyTorch默认值中的权重初始化nn.Module:

def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)

前馈操作

前馈操作接收initstates参数,该参数是上面方程的(ht,ct)参数的元组,如果不引入,则设置为零。然后,我们对每个保留(ht,c_t)的序列元素执行LSTM方程的前馈,并将其作为序列下一个元素的状态引入。

最后,我们返回预测和最后一个状态元组。让我们看看它是如何发生的:

def forward(self,x,init_states=None):
"""
assumes x.shape represents (batch_size, sequence_size, input_size)
"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (
torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device),
)
else:
h_t, c_t = init_states
for t in range(seq_sz):
x_t = x[:, t, :]
i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)
o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
#reshape hidden_seq p/ retornar
hidden_seq = torch.cat(hidden_seq, dim=0)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)

优化版本

这个LSTM在运算上是正确的,但在计算时间上没有进行优化:我们分别执行8个矩阵乘法,这比矢量化的方式慢得多。我们现在将演示如何通过将其减少到2个矩阵乘法来完成,这将使它更快。

为此,我们设置了两个矩阵U和V,它们的权重包含在4个矩阵乘法上。然后,我们对已经通过线性组合+偏置操作的矩阵执行选通操作。

通过矢量化操作,LSTM单元的方程式为:

afc365239768ca48a6fb365d64482711.png

class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
gates = x_t @ self.W + h_t @ self.U + self.bias
i_t, f_t, g_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.tanh(gates[:, HS*2:HS*3]),
torch.sigmoid(gates[:, HS*3:]), # output
)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)

最后但并非最不重要的是,我们可以展示如何优化,以使用LSTM peephole connections。

LSTM peephole

LSTM peephole对其前馈操作进行了细微调整,从而将其更改为优化的情况:

1f1bf717535b3a64fa898c7c1ba5c12b.png

如果LSTM实现得很好并经过优化,我们可以添加peephole选项,并对其进行一些小的调整:

class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz, peephole=False):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.peephole = peephole
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
if self.peephole:
gates = x_t @ U + c_t @ V + bias
else:
gates = x_t @ U + h_t @ V + bias
g_t = torch.tanh(gates[:, HS*2:HS*3])
i_t, f_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.sigmoid(gates[:, HS*3:]), # output
)
if self.peephole:
c_t = f_t * c_t + i_t * torch.sigmoid(x_t @ U + bias)[:, HS*2:HS*3]
h_t = torch.tanh(o_t * c_t)
else:
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)

我们的LSTM就这样结束了。如果有兴趣大家可以将他与torch LSTM内置层进行比较。

代码:https://github.com/piEsposito/pytorch-lstm-by-hand

作者:Piero Esposito

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/276897.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

有哪些漂亮的中国风 LOGO 设计?

提到中国风的logo,我觉得首先登场的应该是北京故宫博物院的logo,铛!故宫博物院的logo,从颜色,到外形,到元素,无一例外,充满了中国风的味道,可谓是中国风中的典型。同一风…

大家放松下,仿《大腕》经典对白

仿《大腕》经典对白: 一定要找那最流行的框架, 用功能最强大编辑器, 做就要做最复杂的系统, 轻量级的绝对不行, 框架最简单也得是SPRING&…

MySQL-8.0.12源码安装实例

1、通过官网下载对应的版本后,通过FTP上传至云服务器的/usr/local/src 目录 2、解压缩文件 [rootJSH-01 src]# ls mysql-boost-8.0.12.tar.gz [rootJSH-01 src]# tar zxvf mysql-boost-8.0.12.tar.gz [rootJSH-01 src]# ls mysql-8.0.12 mysql-boost-8.0.12.tar.gz…

python3常用模块_Python3 常用模块

一、time与datetime模块 在Python中,通常有这几种方式来表示时间: 时间戳(timestamp):通常来说,时间戳表示的是从1970年1月1日00:00:00开始按秒计算的偏移量。我们运行“type(time.time())”,返回的是float类型。 格式…

Windows下的HEAP溢出及其利用

Windows下的HEAP溢出及其利用 作者: isno 一、概述 前一段时间ASP的溢出闹的沸沸扬扬,这个漏洞并不是普通的堆栈溢出,而是发生在HEAP中的溢出,这使大家重新认识到了Windows下的HEAP溢出的可利用性。其实WIN下的HEAP溢出比Linux和SOLARIS下面的…

地方政府不愿房价下跌 救市或化解房地产调控

地方政府不愿房价下跌 "救市"或化解房地产调控 2008年05月09日 07:29:38  来源:上海证券报 漫画 刘道伟 由于房地产业与地方政府利益攸关,地方政府最不愿意看到房价下跌。中央房地产调控政策刚刚导致部分城市的房价步入调整,一些…

App移动端性能工具调研

使用GT的差异化场景平台描述release版本development版本Android在Android平台上,如果希望使用GT的高级功能,如“插桩”等,就必须将GT的SDK嵌入到被调测的应用的工程里,再配合安装好的GT使用。支持AndroidiOS在iOS平台上&#xff0…

UITabBar Contoller

。UITabBar中的UIViewController获得控制权:在TabBar文件中添加:IBOutlet UITabBar *myTabBar; //在xib中连接tabBar;(void)tabBarController:(UITabBarController *)tabBarController didSelectViewController:      (UIViewControlle…

python3.5安装pip_win10上python3.5.2第三方库安装(运用pip)

1 首先在python官网下载并安装python。我这儿用的是python3.5.2,其自带了pip。如果你选择的版本没有自带pip,那么请查找其他的安装教程。 2 python安装好以后,我在其自带的命令提示符窗口中输入了pip,结果尴尬了,提示我…

C语言程序设计 练习题参考答案 第八章 文件(2)

/* 8.8从文件ex88_1.txt中取出成绩,排序后,按降序存放EX88_2.TXT中 */ #include "stdio.h" #define N 10 struct student { int num; char name[20]; int score[3]; /*不能使用float*/ float average; }; void sort(struc…

语法上的小trick

语法上的小trick 构造函数 虽然不写构造函数也是可以的,但是可能会开翻车,所以还是写上吧。: 提供三种写法: ​ 使用的时候只用: 注意,这里的A[i]gg(3,3,3)的“gg”不能打括号,否则就是强制转换…

Ubuntu18.04如何让桌面软件默认root权限运行?

什么是gksu? 什么是gksu:Linxu中的gksu是系统中的su/sudo工具,如果安装了gksu,在终端中键入gksu会弹出一个对话框. 安装gksu: 在Ubuntu之前的版本中是继承gksu工具的,但是在Ubutu18.04中并没有集成, 在Elementary OS中连gksu的APT源都没有. Ubuntu18.04 安装和使用gksu: seven…

win10诊断启动后联网_小技巧:win10网络共享文件夹出现错误无法访问如何解决?...

win10系统共享文件夹时在资源管理器中的网络里能够看到所共享的文件夹,但在打开文件夹时却出现 Windows无法访问 Desktop-r8ceh55新建文件夹 请检查名称的拼写。否则,网络可能有问题。要尝试识别并解决网络问题,请单击“诊断”的错误提示&…

两段关于统计日期的sql语句

统计月份:selectleft(convert(char(10),[Article_TimeDate],102),7) as月份, count(*) as数量from[hdsource].[dbo].[article]groupbyleft(convert(char(10),[Article_TimeDate],102),7)orderby1统计年份: selectleft(convert(char(10),[Article_TimeDat…

apache配置文件详解与优化

apache配置文件详解与优化 一、总结 一句话总结&#xff1a;结合apache配置文件中的英文说明和配置详解一起看 1、apache模块配置用的什么标签&#xff1f; IfModule 例如&#xff1a; <IfModule dir_module>DirectoryIndex index.html 索引文件 首页文件&#xff08;首页…

帆软报表(finereport)单元格函数,OP参数

单元格模型&#xff1a;单元格数据和引用&#xff1a;数据类型、实际值与显示值、单元格支持的操作单元格样式&#xff1a;行高列宽、隐藏行列、自动换行、上下标、文字竖排、大文本字段分页时断开、标识说明、格式刷单元格Web属性&#xff1a;web显示、web编辑风格、控件实际值…

sklearn 安装_sklearn-classification_report

原型sklearn.metrics.classification_report(y_true, y_pred, labelsNone, target_namesNone, sample_weightNone, digits2)参数y_true&#xff1a;1维数组或标签指示数组/离散矩阵&#xff0c;样本实际类别值列表y_pred&#xff1a;1维数组或标签指示数组/离散矩阵&#xff0c…

effective c++条款11扩展——关于拷贝构造函数和赋值运算符

effective c条款11扩展——关于拷贝构造函数和赋值运算符 作者&#xff1a;冯明德重点:包含动态分配成员的类 应提供拷贝构造函数,并重载""赋值操作符。 以下讨论中将用到的例子: class CExample { public: CExample(){pBufferNULL; nSize0;} ~CExample(){delete pB…

SparkSQL 之 Shuffle Join 内核原理及应用深度剖析-Spark商业源码实战

本套技术专栏是作者&#xff08;秦凯新&#xff09;平时工作的总结和升华&#xff0c;通过从真实商业环境抽取案例进行总结和分享&#xff0c;并给出商业应用的调优建议和集群环境容量规划等内容&#xff0c;请持续关注本套博客。版权声明&#xff1a;禁止转载&#xff0c;欢迎…

Python标准库之csv(1)

1.Python处理csv文件之csv.writer() import csvdef csv_write(path,data):with open(path,w,encodingutf-8,newline) as f:writer csv.writer(f,dialectexcel)for row in data:writer.writerow(row)return True 调用上面的函数 data [[Name,Height],[Keys,176cm],[HongPing,1…