LSTM简介
一个目前很火的特殊的RNN结构, 有效解决了RNN的梯度爆炸和长序列记忆问题
优势
LSTM 通过引入遗忘门、输入门、输出门, 来实现对特殊特征的记忆和遗忘,来达到更好的对序列数据的处理和记忆效果。
原理图:
总结公式:
大概就是这样的一个公式
简单来说就是,LSTM一共有三个门,输入门,遗忘门,输出门,
分别为三个门的程度参数,
g 是对输入的常规RNN操作。
公式里可以看到LSTM的输出有两个,细胞状态C’
和隐状态 h’
c’是经输入、遗忘门的产物,也就是当前cell本身的内容,经过输出门得到h’,就是想输出什么内容给下一单元
那么实际应用时,我们并不关心细胞本身的状态,而是要拿到它呈现出的状态
h’作为最终输出.
实现
利用pytorch 手动实现lstm
构建公式
class myLstm(nn.Module):def __init__(self,input_sz,hidden_sz):super().__init__()self.input_size=input_szself.hidden_size=hidden_szself.U_i=nn.Parameter(torch.Tensor(input_sz,hidden_sz))self.V_i = nn.Parameter(torch.Tensor(hidden_sz,hidden_sz))self.b_i = nn.parameter(torch.Tensor(hidden_sz))#f_tself.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))self.b_f = nn.Parameter(torch.Tensor(hidden_sz))#c_tself.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))self.b_c = nn.Parameter(torch.Tensor(hidden_sz))#o_tself.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))self.b_o = nn.Parameter(torch.Tensor(hidden_sz))self.init_weights()def forward(self,x,init_states=None):bs,seq_sz,_=x.size()hidden_seq=[]if init_states is None:h_t,c_t=(torch.zeros(bs,self.hidden_size).to(x.device),torch.zeros(bs,self.hidden_size).to(x.device))else:h_t, c_t = init_statesfor t in range(seq_sz):x_t = x[:, t, :]i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)c_t = f_t * c_t + i_t * g_th_t = o_t * torch.tanh(c_t)hidden_seq.append(h_t.unsqueeze(0))hidden_seq = torch.cat(hidden_seq, dim=0)hidden_seq = hidden_seq.transpose(0, 1).contiguous()return hidden_seq, (h_t, c_t)