《动手学深度学习 Pytorch版》 10.4 Bahdanau注意力

10.4.1 模型

Bahdanau 等人提出了一个没有严格单向对齐限制的可微注意力模型。在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。这是通过将上下文变量视为注意力集中的输出来实现的。

新的基于注意力的模型与 9.7 节中的模型相同,只不过 9.7 节中的上下文变量 c \boldsymbol{c} c 在任何解码时间步 t ′ \boldsymbol{t'} t 都会被 c t ′ \boldsymbol{c}_{t'} ct 替换。假设输入序列中有 T \boldsymbol{T} T 个词元,解码时间步 t ′ \boldsymbol{t'} t 的上下文变量是注意力集中的输出:

c t ′ = ∑ t = 1 T α ( s t ′ − 1 , h t ) h t \boldsymbol{c}_{t'}=\sum^T_{t=1}{\alpha{(\boldsymbol{s}_{t'-1},\boldsymbol{h}_t)\boldsymbol{h}_t}} ct=t=1Tα(st1,ht)ht

参数字典:

  • 遵循与 9.7 节中的相同符号表达

  • 时间步 t ′ − 1 \boldsymbol{t'-1} t1 时的解码器隐状态 s t ′ − 1 \boldsymbol{s}_{t'-1} st1 是查询

  • 编码器隐状态 h t \boldsymbol{h}_t ht 既是键,也是值

  • 注意力权重 α \alpha α 是使用上节所定义的加性注意力打分函数计算的

043

从图中可以看到,加入注意力机制后:

  • 将编码器对每次词的输出作为 key 和 value

  • 将解码器对上一个词的输出作为 querry

  • 将注意力的输出和下一个词的词嵌入合并作为解码器输入

import torch
from torch import nn
from d2l import torch as d2l

10.4.2 定义注意力解码器

AttentionDecoder 类定义了带有注意力机制解码器的基本接口

#@save
class AttentionDecoder(d2l.Decoder):"""带有注意力机制解码器的基本接口"""def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError

在 Seq2SeqAttentionDecoder 类中实现带有 Bahdanau 注意力的循环神经网络解码器。初始化解码器的状态,需要下面的输入:

  • 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;

  • 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;

  • 编码器有效长度(排除在注意力池中填充词元)。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):  # 新增 enc_valid_lens 表示有效长度# outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# 输出X的形状为(num_steps,batch_size,embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:# query的形状为(batch_size,1,num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)  # 解码器最终隐藏层的上一个输出添加querry个数的维度后作为querry# context的形状为(batch_size,1,num_hiddens)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)  # 编码器的输出作为key和value# 在特征维度上连结x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)  # 并起来当解码器输入# 将x变形为(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)  # 存一下注意力权重# 全连接层变换后,outputs的形状为 (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

10.4.3 训练

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
loss 0.020, 7252.9 tokens/sec on cuda:0

在这里插入图片描述

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ',f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => va !,  bleu 1.000
i lost . => j'ai perdu .,  bleu 1.000
he's calm . => il est mouillé .,  bleu 0.658
i'm home . => je suis chez moi .,  bleu 1.000

训练结束后,下面通过可视化注意力权重会发现,每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')


在这里插入图片描述

练习

(1)在实验中用LSTM替换GRU。

class Seq2SeqEncoder_LSTM(d2l.Encoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqEncoder_LSTM, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.lstm = nn.LSTM(embed_size, num_hiddens, num_layers,  # 更换为 LSTMdropout=dropout)def forward(self, X, *args):X = self.embedding(X)X = X.permute(1, 0, 2)output, state = self.lstm(X)return output, stateclass Seq2SeqAttentionDecoder_LSTM(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder_LSTM, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.LSTM(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):enc_outputs, hidden_state, enc_valid_lens = stateX = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1][0], dim=1)  # 解码器最终隐藏层的上一个输出添加querry个数的维度后作为querrycontext = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights
embed_size_LSTM, num_hiddens_LSTM, num_layers_LSTM, dropout_LSTM = 32, 32, 2, 0.1
batch_size_LSTM, num_steps_LSTM = 64, 10
lr_LSTM, num_epochs_LSTM, device_LSTM = 0.005, 250, d2l.try_gpu()train_iter_LSTM, src_vocab_LSTM, tgt_vocab_LSTM = d2l.load_data_nmt(batch_size_LSTM, num_steps_LSTM)
encoder_LSTM = Seq2SeqEncoder_LSTM(len(src_vocab_LSTM), embed_size_LSTM, num_hiddens_LSTM, num_layers_LSTM, dropout_LSTM)
decoder_LSTM = Seq2SeqAttentionDecoder_LSTM(len(tgt_vocab_LSTM), embed_size_LSTM, num_hiddens_LSTM, num_layers_LSTM, dropout_LSTM)
net_LSTM = d2l.EncoderDecoder(encoder_LSTM, decoder_LSTM)
d2l.train_seq2seq(net_LSTM, train_iter_LSTM, lr_LSTM, num_epochs_LSTM, tgt_vocab_LSTM, device_LSTM)
loss 0.021, 7280.8 tokens/sec on cuda:0

在这里插入图片描述

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq_LSTM = d2l.predict_seq2seq(net_LSTM, eng, src_vocab_LSTM, tgt_vocab_LSTM, num_steps_LSTM, device_LSTM, True)print(f'{eng} => {translation}, ',f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => va !,  bleu 1.000
i lost . => j'ai perdu .,  bleu 1.000
he's calm . => puis-je <unk> <unk> .,  bleu 0.000
i'm home . => je suis chez moi .,  bleu 1.000
attention_weights_LSTM = torch.cat([step[0][0][0] for step in dec_attention_weight_seq_LSTM], 0).reshape((1, 1, -1, num_steps_LSTM))# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights_LSTM[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')


在这里插入图片描述


(2)修改实验以将加性注意力打分函数替换为缩放点积注意力,它如何影响训练效率?

class Seq2SeqAttentionDecoder_Dot(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.DotProductAttention(  # 替换为缩放点积注意力num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):enc_outputs, hidden_state, enc_valid_lens = stateX = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1], dim=1)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights
embed_size_Dot, num_hiddens_Dot, num_layers_Dot, dropout_Dot = 32, 32, 2, 0.1
batch_size_Dot, num_steps_Dot = 64, 10
lr_Dot, num_epochs_Dot, device_Dot = 0.005, 250, d2l.try_gpu()train_iter_Dot, src_vocab_Dot, tgt_vocab_Dot = d2l.load_data_nmt(batch_size_Dot, num_steps_Dot)
encoder_Dot = Seq2SeqEncoder_LSTM(len(src_vocab_Dot), embed_size_LSTM, num_hiddens_Dot, num_layers_Dot, dropout_Dot)
decoder_Dot = Seq2SeqAttentionDecoder_LSTM(len(tgt_vocab_Dot), embed_size_Dot, num_hiddens_Dot, num_layers_Dot, dropout_Dot)
net_Dot = d2l.EncoderDecoder(encoder_Dot, decoder_Dot)
d2l.train_seq2seq(net_Dot, train_iter_Dot, lr_Dot, num_epochs_Dot, tgt_vocab_Dot, device_Dot)
loss 0.021, 7038.8 tokens/sec on cuda:0

在这里插入图片描述

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq_Dot = d2l.predict_seq2seq(net_Dot, eng, src_vocab_Dot, tgt_vocab_Dot, num_steps_Dot, device_Dot, True)print(f'{eng} => {translation}, ',f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
go . => va !,  bleu 1.000
i lost . => j'ai perdu .,  bleu 1.000
he's calm . => il est riche .,  bleu 0.658
i'm home . => je suis chez moi .,  bleu 1.000
attention_weights_Dot = torch.cat([step[0][0][0] for step in dec_attention_weight_seq_Dot], 0).reshape((1, 1, -1, num_steps_Dot))# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights_Dot[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')


在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/118171.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Elasticsearch】es脚本编程使用详解

目录 一、es脚本语言介绍 1.1 什么是es脚本 1.2 es脚本支持的语言 1.3 es脚本语言特点 1.4 es脚本使用场景 二、环境准备 2.1 docker搭建es过程 2.1.1 拉取es镜像 2.1.2 启动容器 2.1.3 配置es参数 2.1.4 重启es容器并访问 2.2 docker搭建kibana过程 2.2.1 拉取ki…

LSKA(大可分离核注意力):重新思考CNN大核注意力设计

文章目录 摘要1、简介2、相关工作3、方法4、实验5、消融研究6、与最先进方法的比较7、ViTs和CNNs的鲁棒性评估基准比较8、结论 摘要 https://arxiv.org/pdf/2309.01439.pdf 大型可分离核注意力&#xff08;LSKA&#xff09;模块的视觉注意力网络&#xff08;VAN&#xff09;已…

Linux CentOS 8(firewalld的配置与管理)

Linux CentOS 8&#xff08;firewalld的配置与管理&#xff09; 目录 一、firewalld 简介二、firewalld 工作概念1、预定义区域&#xff08;管理员可以自定义修改&#xff09;2、预定义服务 三、firewalld 配置方法1、通过firewall-cmd配置2、通过firewall图形界面配置 四、配置…

利用Jpom在线构建Spring Boot项目

1 简介 前面介绍了运用Jpom构建部署Vue项目&#xff0c;最近研究了怎么部署Spring Boot项目&#xff0c;至此&#xff0c;一套简单的前后端项目就搞定了。 2 基本步骤 因为就是一个简单的自研测试项目&#xff0c;所以构建没有使用docker容器&#xff0c;直接用java -jar命令…

Java程序设计进阶

Java异常处理机制 异常 异常的最高父类是 Throwable&#xff0c;在 java.lang 包下。 Throwable 类的方法主要有&#xff1a; 方法说明public String getMessage()返回对象的错误信息public void printStackTrace()输出对象的跟踪信息到标准错误输出流public void printSta…

【项目设计】网络对战五子棋(下)

我不再装模作样地拥有很多朋友&#xff0c;而是回到了孤单之中&#xff0c;以真正的我开始了独自的生活。有时我也会因为寂寞而难以忍受空虚的折磨&#xff0c;但我宁愿以这样的方式来维护自己的自尊&#xff0c;也不愿以耻辱为代价去换取那种表面的朋友。 文章目录 一、项目设…

Postman笔记

文章目录 1.安装2.简介和使用流程3 postman使用3.1 测试集与HTTP请求发送HTTP请求和分析响应数据 3.2 发送HTTP请求和分析响应数据3.3 Postman中请求体提交方式3.4 Postman使用之接口测试3.5 使用Postman新建一个mock服务3.6 请求数据的参数化3.7 断言与脚本导出 1.安装 官网地…

DP读书:《openEuler操作系统》(五)进程与线程

进程与线程 进程的概念程序&#xff1a;从源码到执行1. 编译阶段:2. 加载阶段:3. 执行阶段: 程序的并发执行与进程抽象 进程的描述进程控制块1. 描述信息2. 控制信息3. CPU上下文4. 资源管理信息 进程状态1.就绪状态2.运行状态3.阻塞状态4.终止状态 进程的控制进程控制源语1.创…

CrossOver23.6软件激活码怎么获取 CrossOver软件2023怎么激活

CrossOver一款类虚拟机&#xff0c;它的主要功能是在mac系统中安装windows应用程序。其工作原理是将exe格式的windows应用程序安装包安装至CrossOver容器中&#xff0c;并将运行该exe文件所需的配置文件下载至容器中&#xff0c;便能在mac正常运行windows应用程序了。下面就让我…

如何构建一个外卖微信小程序

随着外卖行业的不断发展&#xff0c;越来越多的商家开始关注外卖微信小程序的开发。微信小程序具有使用方便、快速上线、用户覆盖广等优势&#xff0c;成为了商家们的首选。 那么&#xff0c;如何快速开发一个外卖微信小程序呢&#xff1f;下面就让我们来看看吧&#xff01; 首…

【C++入门:C++世界的奇幻之旅】

1. 什么是C 2. C发展史 3. C的重要性 4. C关键字 5. 命名空间 6. C输入&输出 7. 缺省参数 8. 函数重载 9. 引用 10. 内联函数 11. auto关键字(C11) 12. 基于范围的for循环(C11) 13. 指针空值---nullptr(C11)05. 1. 什么是C C语言是结构化和模块化的语言&…

什么是web3.0?

Web 3.0&#xff0c;也常被称为下一代互联网&#xff0c;代表着互联网的下一个重大演变。尽管关于Web 3.0的确切定义尚无共识&#xff0c;但它通常被认为是一种更分散、更开放且更智能的互联网。 以下是Web 3.0的一些主要特征和概念&#xff1a; 1. 去中心化 Web 3.0旨在减少…

人工智能:CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)的知识梳理

卷积神经网络&#xff08;CNN&#xff09; 卷积神经网络&#xff08;CNN&#xff09;&#xff0c;也被称为ConvNets或Convolutional Neural Networks&#xff0c;是一种深度学习神经网络架构&#xff0c;主要用于处理和分析具有网格状结构的数据&#xff0c;特别是图像和视频数…

CSS中 通过自定义属性(变量)动态修改元素样式(以 el-input 为例)

传送门&#xff1a;CSS中 自定义属性&#xff08;变量&#xff09;详解 1. 需求及解决方案 需求&#xff1a;通常我们动态修改 div 元素的样式&#xff0c;使用 :style 和 :class 即可&#xff1b;但想要动态修改 如&#xff1a;Element-ui 中输入框&#xff08;input&#x…

Windows与Linux服务器互传文件

使用winscp实现图形化拖动的方式互传文件. 1.下载winscp软件并安装&#xff0c;官方地址&#xff1a; https://winscp.net/eng/index.php 2.打开软件&#xff1a; 文件协议选择scp&#xff0c;输入linux服务器的IP和端口号&#xff0c;然后输入你的用户名和密码就可以登陆了。…

postman打开后,以前的接口记录不在,问题解决

要不这些文件保存在C:\Users\{用户名}\AppData\Roaming\Postman 比如&#xff0c;你目前使用的window登录用户是abc&#xff0c;那么地址便是C:\Users\abc\AppData\Roaming\Postman 打开后&#xff0c;这个目录下会有一些命名为backup-yyyy-MM-ddThh-mm-ss.SSSZ.json类似的文…

渗透攻击漏洞——原型链污染

背景 2019年初&#xff0c;Snyk的安全研究人员披露了流行的JavaScript库Lodash中一个严重漏洞的详细信息&#xff0c;该漏洞使黑客能够攻击多个Web应用程序&#xff0c;这个安全漏洞就是一个“原型污染漏洞”&#xff08;JavaScript Prototype Pollution&#xff09;&#xff…

【分布式】大模型分布式训练入门与实践 - 04

大模型分布式训练 数据并行-Distributed Data Parallel1.1 背景1.2 PyTorch DDP1&#xff09; DDP训练流程2&#xff09;DistributedSampler3&#xff09;DataLoader: Parallelizing data loading4&#xff09;Data-parallel&#xff08;DP&#xff09;5&#xff09;DDP原理解析…

图论06-【无权无向】-图的遍历并查集Union Find-力扣695为例

文章目录 1. 代码仓库2. 思路2.1 UF变量设计2.2 UF合并两个集合2.3 查找当前顶点的父节点 find(element) 3. 完整代码 1. 代码仓库 https://github.com/Chufeng-Jiang/Graph-Theory 2. 思路 2.1 UF变量设计 parent数组保存着每个节点所指向的父节点的索引&#xff0c;初始值为…

Java IDEA controller导出CSV,excel

Java IDEA controller导出CSV&#xff0c;excel 导出excel/csv&#xff0c;亲测可共用一个方法&#xff0c;代码逻辑里判断设置不同的表头及contentType&#xff1b;导出excel导出csv 优化&#xff1a;有数据时才可以导出参考 导出excel/csv&#xff0c;亲测可共用一个方法&…