双向RNN和双向LSTM

双向RNN和双向LSTM

一、双向循环神经网络BiRNN

1、为什么要用BiRNN

双向RNN,即可以从过去的时间点获取记忆,又可以从未来的时间点获取信息,也就是说具有以下两个特点:

捕捉前后文信息:传统的单向 RNN 只能利用先前的上下文信息,而 BiRNN 同时利用了输入序列的前后文信息。在很多任务中,如自然语言处理中的命名实体识别、机器翻译等,理解一个词的前后文语境至关重要。

例如:img

判断句子中Teddy是否是人名,如果只从前面两个词是无法得知Teddy是否是人名,如果能有后面的信息就很好判断了,这就需要用的双向循环神经网络。

提高精度:在处理某些序列数据时,单向 RNN 可能无法充分捕捉整个序列中的重要信息,导致性能欠佳。BiRNN 能够通过双向处理,提高模型的表达能力和准确度。

2、BiRNN的架构

双向循环神经网络(BRNN)的基本思想是提出每一个训练序列向前和向后分别是两个循环神经网络(RNN),而且这两个都连接着一个输出层。这个结构提供给输出层输入序列中每一个点的完整的过去和未来的上下文信息。下图展示的是一个沿着时间展开的双向循环神经网络。六个独特的权值在每一个时步被重复的利用,六个权值分别对应:输入到向前和向后隐含层(w1, w3),隐含层到隐含层自己(w2, w5),向前和向后隐含层到输出层(w4, w6)。值得注意的是:向前和向后隐含层之间没有信息流,这保证了展开图是非循环的。每一个输出都是综合考虑两个方向获得的结果再输出,如下图所示:

在这里插入图片描述

H → t = ϕ ( X t W x h ( f ) + H → t − 1 W h h ( f ) + b h ( f ) ) , H ← t = ϕ ( X t W x h ( b ) + H ← t + 1 W h h ( b ) + b h ( b ) ) , \begin{array}{l} \overrightarrow{\mathbf{H}}_{t}=\phi\left(\mathbf{X}_{t} \mathbf{W}_{x h}^{(f)}+\overrightarrow{\mathbf{H}}_{t-1} \mathbf{W}_{h h}^{(f)}+\mathbf{b}_{h}^{(f)}\right), \\ \overleftarrow{\mathbf{H}}_{t}=\phi\left(\mathbf{X}_{t} \mathbf{W}_{x h}^{(b)}+\overleftarrow{\mathbf{H}}_{t+1} \mathbf{W}_{h h}^{(b)}+\mathbf{b}_{h}^{(b)}\right), \end{array} H t=ϕ(XtWxh(f)+H t1Whh(f)+bh(f)),H t=ϕ(XtWxh(b)+H t+1Whh(b)+bh(b)),
拼接得到结果:
H t = [ H → t H ← t ] \mathbf{H}_{t}=\left[\overrightarrow{\mathbf{H}}_{t} \overleftarrow{\mathbf{H}}_{t}\right] Ht=[H tH t]
至于网络单元到底是标准的RNN还是GRU或者是LSTM是没有关系的

(GRU:把遗忘门和输入门合并成一个更新门(Update Gate),并且把Cell State和Hidden State也合并成一个Hidden State,它的计算如下图所示)

在这里插入图片描述

对于整个双向循环神经网络(BRNN)的计算过程如下:

向前推算(Forward pass):

对于双向循环神经网络(BRNN)的隐含层,向前推算跟单向的循环神经网络(RNN)一样,除了输入序列对于两个隐含层是相反方向的,输出层直到两个隐含层处理完所有的全部输入序列才更新:

img

向后推算(Backward pass):

双向循环神经网络(BRNN)的向后推算与标准的循环神经网络(RNN)通过时间反向传播相似,除了所有的输出层δ项首先被计算,然后返回给两个不同方向的隐含层:

img

3、代码(用IMDB进行情感分析)

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import spacy
from torchtext.vocab import GloVe# 加载数据集
base_csv = r'E:\kaggle\情感分析\archive (1)\IMDB Dataset.csv'
df = pd.read_csv(base_csv)# 分割数据集
train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)# 保存为CSV文件
train_data.to_csv('train.csv', index=False)
test_data.to_csv('test.csv', index=False)# 加载Spacy分词器
spacy_en = spacy.load('en_core_web_sm')
tokenizer = spacy_en.tokenizerclass TextDataset(Dataset):def __init__(self, dataframe, text_field, label_field, vocab):self.dataframe = dataframeself.text_field = text_fieldself.label_field = label_fieldself.vocab = vocabdef __len__(self):return len(self.dataframe)def __getitem__(self, idx):text = self.dataframe.iloc[idx][self.text_field]label = self.dataframe.iloc[idx][self.label_field]tokens = tokenizer(text)token_ids = [self.vocab.get(token.text, self.vocab['<unk>']) for token in tokens]label = 1 if label == "positive" else 0return torch.tensor(token_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)# 加载预训练词向量
glove_embeddings = GloVe(name='6B', dim=100)
vocab = glove_embeddings.stoi
vocab['<unk>'] = len(vocab)  # 添加 <unk> 标记
glove_embeddings.vectors = torch.cat((glove_embeddings.vectors, torch.zeros(1, glove_embeddings.dim)), 0)# 创建数据集
train_dataset = TextDataset(train_data, 'review', 'sentiment', vocab)
test_dataset = TextDataset(test_data, 'review', 'sentiment', vocab)# 创建数据加载器
batch_size = 32def collate_fn(batch):texts, labels = zip(*batch)text_lengths = [len(text) for text in texts]padded_texts = nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=0)return padded_texts, torch.tensor(labels, dtype=torch.long)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)# 定义模型
class BiRNN(nn.Module):def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_classes, pad_idx):super(BiRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.num_directions = 2self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)self.rnn = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True, bidirectional=True)self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)def forward(self, x):h0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)x = self.embedding(x)out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :])return out# 设置参数
vocab_size = len(vocab)
embed_size = 100
hidden_size = 256
num_layers = 2
num_classes = 2
pad_idx = 0model = BiRNN(vocab_size, embed_size, hidden_size, num_layers, num_classes, pad_idx)
model.embedding.weight.data.copy_(glove_embeddings.vectors)
model.embedding.weight.data[pad_idx] = torch.zeros(embed_size)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = criterion.to(device)# 训练模型
num_epochs = 5for epoch in range(num_epochs):model.train()epoch_loss = 0for texts, labels in train_loader:texts = texts.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(texts)loss = criterion(outputs, labels)loss.backward()optimizer.step()epoch_loss += loss.item()print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(train_loader):.4f}')# 评估模型
model.eval()
with torch.no_grad():correct = 0total = 0for texts, labels in test_loader:texts = texts.to(device)labels = labels.to(device)outputs = model(texts)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f'Accuracy: {accuracy * 100:.2f}%')# 保存模型
torch.save(model.state_dict(), 'model.ckpt')# 预测函数
def predict(text, model, vocab, device):model.eval()tokens = tokenizer(text)token_ids = [vocab.get(token.text, vocab['<unk>']) for token in tokens]token_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)with torch.no_grad():output = model(token_tensor)_, predicted = torch.max(output.data, 1)return 'positive' if predicted.item() == 1 else 'negative'# 示例使用
sample_text = "This movie was fantastic! I really enjoyed it."
prediction = predict(sample_text, model, vocab, device)
print(f'Prediction: {prediction}')

4、多层RNN

将许多RNN层堆叠,构成一个多层RNN网络。RNN每读取一个新的 x t \mathrm{x}_{\mathrm{t}} xt输就会生成状态向量 h t \mathrm{h}_{\mathrm{t}} ht作为当前时刻的输出和下一时刻的输入状态。 将T 个输入 x 0 ∼ x T \mathrm{x}_{0} \sim \mathrm{x}_{\mathrm{T}} x0xT 依次输入RNN,相应地会产生个输出。第一层RNN输出的T TT个状态向量可以作为第二层RNN的输入,第二层RNN拥有独立的参数,依次读取T个来自第一层RNN的状态向量,产生T个新的输出。第二层RNN输出的T个状态向量可以作为第三层RNN的输入,依此类推,构成一个多层RNN网络。
在这里插入图片描述

RNN自带层数和是否双向,只要把层数设置成你想要的层数就可以

class RNN(nn.Module):def __init__(self,vocab_size, embedding_dim, hidden_size, num_classes, num_layers,bidirectional):super(RNN, self).__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.hidden_size = hidden_sizeself.num_classes = num_classesself.num_layers = num_layersself.embedding = nn.Embedding(self.vocab_size, embedding_dim, padding_idx=word2idx['<PAD>'])# 这里用了torch的embeding函数self.rnn = nn.RNN(input_size=self.embedding_dim, hidden_size=self.hidden_size,batch_first=True,num_layers=self.num_layers)# input_size:表示输入 xt 的特征维度,从embeding层的size# hidden_size:表示输出的特征维度# num_layers:表示网络的层数# nonlinearity:表示选用的非线性激活函数,默认是 ‘tanh’# bias:表示是否使用偏置,默认使用# batch_first:表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位# dropout:表示是否在输出层应用(随机丢掉一些特征,重要的调参)# bidirectional:表示是否使用双向的 rnn,默认是 False。def forward(self, x):batch_size, seq_len = x.shapeh0 = torch.randn(self.num_layers, batch_size, self.hidden_size).to(device)#初始化一个h0,也即c0,在RNN中一个Cell输出的ht和Ct是相同的,而LSTM的一个cell输出的ht和Ct是不同的#维度[layers, batch, hidden_len]x = self.embedding(x)   # 这里用的是nn的embedingout,_ = self.rnn(x, h0)  # 输入x,h0是初始化的特征,output = self.fc(out[:,-1,:]).squeeze(0) #因为有max_seq_len个时态,所以取最后一个时态即-1层return output   

二、双向LSTM----Bi-LSTM

BiLSTM无非就是把cell变成了LSTM的基本单元,其他同理如上所示

需要修改的内容:

  1. 替换RNN层为LSTM层:将nn.RNN替换为nn.LSTM
  2. 设置双向LSTM:将LSTM层的bidirectional参数设置为True
  3. 修改隐藏状态的初始化:LSTM的隐藏状态包含两个张量(隐藏状态和细胞状态),因此需要初始化这两个张量。
  4. 处理双向LSTM的输出:双向LSTM的输出维度会加倍(因为有两个方向)

在这里插入图片描述

class RNN(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes, num_layers, bidirectional):super(RNN, self).__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.hidden_size = hidden_sizeself.num_classes = num_classesself.num_layers = num_layersself.bidirectional = bidirectionalself.num_directions = 2 if bidirectional else 1self.embedding = nn.Embedding(self.vocab_size, embedding_dim, padding_idx=word2idx['<PAD>'])self.lstm = nn.LSTM(input_size=self.embedding_dim, hidden_size=self.hidden_size, num_layers=self.num_layers, bidirectional=self.bidirectional, batch_first=True)#替换了原来的nn.RNN层,使其变为LSTM,并添加了双向参数# 全连接层的输入大小需要考虑双向LSTM的输出维度self.fc = nn.Linear(self.hidden_size * self.num_directions, self.num_classes)def forward(self, x):batch_size, seq_len = x.shape# 初始化隐藏状态和细胞状态,双向LSTM需要初始化2*num_layers的hidden statesh0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(device)c0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(device)x = self.embedding(x)  # Embedding layer# LSTM layerout, (hn, cn) = self.lstm(x, (h0, c0))# 取最后一个时间步的输出,双向LSTM会拼接两个方向的输出output = self.fc(out[:, -1, :])return output

三、两者异同

双向RNN(Bidirectional Recurrent Neural Network)和双向LSTM(Bidirectional Long Short-Term Memory)都是用于处理序列数据的神经网络模型。它们在处理时间序列、自然语言处理等任务中非常有用,因为它们能够考虑到序列中的前后信息。尽管它们有一些相似之处,但也有一些关键的不同点。

1、相同点

  1. 双向结构:无论是双向RNN还是双向LSTM,它们都有双向结构,这意味着它们有两个隐藏层,一个从前向后处理序列,另一个从后向前处理序列。这种结构允许模型在每个时间步上同时考虑前面和后面的信息,从而提高预测性能。

  2. 序列处理:两者都是为序列数据设计的,适用于自然语言处理、时间序列预测等任务。

  3. 隐层状态的结合:在每个时间步,它们都会结合来自两个方向的隐藏状态(通常是通过连接或求和)来产生输出。

2、不同点

  1. 基本单元

    • 双向RNN:基本单元是RNN,即简单的循环神经网络。它们依赖于隐藏状态将信息从一个时间步传播到下一个时间步。然而,RNN在处理长序列时容易遇到梯度消失和梯度爆炸问题,这限制了它们捕捉长距离依赖的能力。
    • 双向LSTM:基本单元是LSTM,即长短期记忆网络。LSTM通过引入遗忘门、输入门和输出门来控制信息的流动,从而更好地解决了梯度消失和梯度爆炸问题。因此,LSTM在处理长序列时比普通RNN更有效。
  2. 记忆能力

    • 双向RNN:由于其结构的限制,普通RNN在处理长期依赖关系时表现不佳。它们更适合短期依赖任务。
    • 双向LSTM:LSTM单元通过其门控机制能够有效地捕捉和保持长期依赖信息,因此在需要长时间记忆的任务中表现更优。
  3. 复杂性和计算成本

    • 双向RNN:结构相对简单,计算成本较低。但由于其局限性,可能需要更多层次的堆叠或者更复杂的架构来提升性能。
    • 双向LSTM:由于引入了门机制,LSTM单元比RNN单元复杂,计算成本也更高。但其增强的记忆能力通常能带来更好的性能。

参考https://zhuanlan.zhihu.com/p/519965073

https://fancyerii.github.io/books/rnn-intro/

https://www.cnblogs.com/Lee-yl/p/10066531.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/13205.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3+ts(<script setup lang=“ts“>)刷新页面后保持下拉框选中效果

效果图&#xff1a; 代码&#xff1a; <template><div class"app-layout"><div class"app-box"><div class"header"><div class"header-left"></div><div class"title">室外智…

scratch列表排序 2024年3月中国电子学会图形化编程 少儿编程 scratch编程等级考试四级真题和答案解析

目录 scratch列表排序 一、题目要求 1、准备工作 2、功能实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 2、详细过程 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、p…

Python函数之旅专栏(导航)

Python内置函数(参考版本:3.11.8)AELRabs( )enumerate( )len( )range( )aiter( )eval( )list( )repr( )all( )exec( )locals( )reversed( )anext( )round( )any( ) ascii( )FM  filter( )map( )S float( )max( )set( )Bformat( )memoryview( )setattr( )bin( )frozenset( )…

ArcGI基本技巧-科研常用OLS, GWR, GTWR模型实现

ArcGI基本技巧-科研常用OLS, GWR, GTWR模型实现 OLS,GWR,GTWR回归模型均可以揭示解释变量对被解释变量的影响且可以进行预测。Ordinary Least Squares (OLS)是最小二乘法&#xff0c;Geographically Weighted Regression (GWR)是地理加权回归&#xff0c;Geographically and T…

Unity射击游戏开发教程:(18)添加弹药计数+补充弹药

添加简单的弹药计数 我将讨论如何向游戏中添加简单的弹药计数。这将包括在 HUD 中添加弹药计数器,当弹药达到 0 时,文本会将颜色更改为红色以提醒玩家。另外,当弹药数为0时,玩家将无法再射击。让我们深入了解吧! 在播放器脚本中我们需要添加一些变量。我们将创建两个公共整…

详细分析Python中的win32com(附Demo)

目录 前言1. 基本知识2. Excel3. Word 前言 对于自动化RPA比较火热&#xff0c;相应的库也比较多&#xff0c;此文分析win32com这个库&#xff0c;用于操作office 1. 基本知识 Win32com 是一个 Python 模块&#xff0c;是 pywin32 扩展的一部分&#xff0c;允许 Python 代码…

C语言如何删除表中指定位置的结点?

一、问题 如何删除链表中指定位置的结点&#xff1f; 二、解答 删除链表中指定的结点&#xff0c;就像是排好队的⼩朋友⼿牵着⼿&#xff0c;将其中⼀个⼩朋友从队伍中分出来&#xff0c;只需将这个⼩朋友的双⼿从两边松开。 删除结点有两种情况&#xff1a; &#xff08;1&am…

HTML静态网页成品作业(HTML+CSS)——我的家乡江永网页设计制作(1个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有1个页面。 二、作品演示 三、代…

Pencils Protocol Season 2 收官在即,Season 3 携系列重磅权益来袭

此前Scroll生态LaunchPad &聚合收益平台Pencils Protocol&#xff08;原Penpad&#xff09;&#xff0c;推出了首个资产即其生态代币PDD的Launch&#xff0c;Season 2活动主要是用户通过质押ETH代币、组件战队等方式&#xff0c;来获得Point奖励&#xff0c;并以该Point为依…

2024 Google I/O大会:全方位解读最新AI技术和产品

引言&#xff1a; 2024年的Google I/O大会如期举行&#xff0c;作为技术圈的年度盛事之一&#xff0c;谷歌展示了其在人工智能领域的最新进展。本次大会尤其引人注目&#xff0c;因为它紧随着OpenAI昨天发布GPT-4o的脚步。让我们详细解析Google此次公布的各项新技术和产品&…

svn如何远程访问?

svn&#xff08;Subversion&#xff09;是一种版本控制系统&#xff0c;广泛应用于软件开发领域。它能够追踪文件和目录的变化&#xff0c;记录每个版本的修改内容&#xff0c;并允许多人协同开发。svn的远程访问功能允许开发人员可以在不同的地点访问和管理代码&#xff0c;提…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-15.7讲 GPIO中断实验-编写按键中断驱动

前言&#xff1a; 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM&#xff08;MX6U&#xff09;裸机篇”视频的学习笔记&#xff0c;在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

EasyCVR智慧校园建设中的关键技术:视频汇聚智能管理系统应用

一、引言 随着信息技术的迅猛发展&#xff0c;智慧校园作为教育信息化建设的重要组成部分&#xff0c;对于提升校园安全、教学效率和管理水平具有重要意义。本文旨在介绍智慧校园视频管理系统的架构设计&#xff0c;为构建高效、智能的校园视频监控系统提供参考。 二、系统整…

Vitis HLS 学习笔记--资源绑定-使用URAM(1)

目录 1. 简介 2. 代码分析 2.1 存储器代码 2.2 Implementation报告 2.3 存储器类型指定 2.4 存储器初始化 3. 总结 1. 简介 在博文《Vitis HLS 学习笔记--资源绑定-使用URAM-CSDN博客》中&#xff0c;介绍了如何在Vitis HLS环境下设计一个简易的存储器模型。 通过以下…

gin自定义验证器+中文翻译

gin自定义验证器中文翻译 1、说明2、global.go3、validator.go4、eg&#xff1a;main.go5、调用接口测试 1、说明 gin官网自定义验证器给的例子相对比较简单&#xff0c;主要是语法级别&#xff0c;便于入门学习&#xff0c;并且没有给出翻译相关的处理&#xff0c;因此在这里记…

红黑树底层封装map、set C++

目录 一、框架思考 三个问题 问题1的解决 问题2的解决&#xff1a; 问题3的解决&#xff1a; 二、泛型编程 1、仿函数的泛型编程 2、迭代器的泛型编程 3、typename&#xff1a; 4、/--重载 三、原码 红黑树 map set 一、框架思考 map和set都是使用红黑树底层&…

新人学习笔记值(初始JavaScript)

一、Java Script是什么 1.Java Script是世界上最流行的语言之一&#xff0c;是一种运行在客户端的脚本语言&#xff08;script是脚本的意思&#xff09; 2.脚本语言&#xff1a;不需要编译&#xff0c;运行过程中由js解释器&#xff08;js引擎&#xff09;进行解释并运行 3.现在…

Vue原理学习:vdom 和 diff算法(基于snabbdom)

vdom 和 diff 背景 基于组件化&#xff0c;数据驱动视图。只需关心数据&#xff0c;无需关系 DOM &#xff0c;好事儿。 但是&#xff0c;JS 运行非常快&#xff0c;DOM 操作却非常慢&#xff0c;如何让“数据驱动视图”能快速响应&#xff1f; 引入 vdom 用 vnode 表示真实…

联合新能源汽车有限公司出席2024年7月8日杭州快递物流展

参展企业介绍 青岛联合新能源汽车有限公司&#xff08;简称&#xff1a;联合汽车&#xff09;&#xff0c;是一家专注于纯电动汽车领域创新的科技公司&#xff0c;在国内率先提出车电分离&#xff0c;电池标准化并共享的方案&#xff0c;研发了包含标准电池、电池仓、可换电纯电…

Bootstrap Studio for Mac:打造专业级网页设计软件

对于追求高效与品质的设计师和开发者来说&#xff0c;Bootstrap Studio for Mac无疑是最佳选择。它建立在广受欢迎的Bootstrap框架之上&#xff0c;输出干净、语义化的HTML代码。同时&#xff0c;强大的CSS和SASS编辑器&#xff0c;支持自动建议和规则验证&#xff0c;让您的设…