机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property    #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层  (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播   (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], []  #创建空列表,用于存储数据for x in X:  #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围,  range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score

 4.预测

import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/54452.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

XML/HTML:深入解析与比较

XML/HTML:深入解析与比较 引言 XML(可扩展标记语言)和HTML(超文本标记语言)是两种广泛应用于网络和文档处理的标记语言。它们在结构和用途上有着显著的区别,但同时也存在着紧密的联系。本文将深入探讨XML和HTML的定义、特点、用途及其相互关系,旨在帮助读者更好地理解…

元学习的简单示例

代码功能 模型结构:SimpleModel是一个简单的两层全连接神经网络。 元学习过程:在maml_train函数中,每个任务由支持集和查询集组成。模型先在支持集上进行训练,然后在查询集上进行评估,更新元模型参数。 任务生成&…

3DMAX乐高积木插件LegoBlocks使用方法

3DMAX乐高积木插件LegoBlocks,用户可以通过控件调整和自定义每个乐高积木的外观和大小。 【适用版本】 3dMax2009或更高版本(不仅限于此范围) 【安装方法】 3DMAX乐高积木插件无需安装,使用时直接拖动插件脚本文件到3dMax视口中…

NLP 主要语言模型分类

文章目录 ngram自回归语言模型TransformerGPTBERT(2018年提出)基于 Transformer 架构的预训练模型特点应用基于 transformer(2017年提出,attention is all you need)堆叠层数与原transformer 的差异bert transformer 层…

Packet Tracer - 配置编号的标准 IPv4 ACL(两篇)

Packet Tracer - 配置编号的标准 IPv4 ACL(第一篇) 目标 第 1 部分:计划 ACL 实施 第 2 部分:配置、应用和验证标准 ACL 背景/场景 标准访问控制列表 (ACL) 为路由器 配置脚本,基于源地址控制路由器 是允许还是拒绝数据包。本练习的主要内…

leetcode练习 二叉树的最大深度

给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:3提示: 树中节点的数量在 [0, 104] 区间内。-100 …

[模板]树的最长路径

[模板]树的最长路径 题目描述 给定一棵树,树中包含 n 个结点(编号1~n)和 n-1 条无向边,每条边都有一个权值。 现在请你找到树中的一条最长路径。 换句话说,要找到一条路径,使得使得路径两端的点的距离最远…

python学习第十节:爬虫基于requests库的方法

python学习第十节:爬虫基于requests库的方法 requests模块的作用: 发送http请求,获取响应数据,requests 库是一个原生的 HTTP 库,比 urllib 库更为容易使用。requests 库发送原生的 HTTP 1.1 请求,无需手动…

Linux:login shell和non-login shell以及其配置文件

相关阅读 Linuxhttps://blog.csdn.net/weixin_45791458/category_12234591.html?spm1001.2014.3001.5482 shell是Linux与外界交互的程序,登录shell有两种方式,login shell与non-login shell,它们的区别是读取的配置文件不同,本…

NPM如何切换淘宝镜像进行加速

什么是淘宝镜像NPM? 淘宝镜像NPM和官方NPM的主要区别在于服务器的地理位置和网络访问速度。淘宝镜像NPM是由淘宝团队维护的一个npm镜像源,主要服务于中国大陆用户,提供了一个国内的npm镜像源,地址为 https://registry.npmmirror.…

解决Tez报错问题

在启动hive的时候,发现该报错 1、检测HADOOP_PATH环境变量 echo $HADOOP_CLASSPATH 如果没有输出,说明我们的配置文件没有生效,这时候需要重写source一下 2、刷新配置文件生效 source /etc/profile 有输出,环境生效 3、再次运…

【数据结构初阶】链式二叉树接口实现超详解

文章目录 1. 节点定义2. 前中后序遍历2. 1 遍历规则2. 2 遍历实现2. 3 结点个数2. 3. 1 二叉树节点个数2. 3. 2 二叉树叶子节点个数2. 3. 3 二叉树第k层节点个数 2. 4 二叉树查找值为x的节点2. 5 二叉树层序遍历2. 6 判断二叉树是否是完全二叉树 3. 二叉树性质 1. 节点定义 用…

laravel public 目录获取

在Laravel框架中,public目录是用来存放公共资源的,如CSS、JS、图片等。你可以通过多种方式获取public目录的路径。 方法一:使用helper函数public_path() $path public_path(); 方法二:使用Request类 $path Request::root().…

SpringCloud从零开始简单搭建 - JDK17

文章目录 SpringCloud Nacos从零开始简单搭建 - JDK17一、创建父项目二、创建子项目三、集成Nacos四、集成nacos配置中心 SpringCloud Nacos从零开始简单搭建 - JDK17 环境要求:JDK17、Spring Boot3、maven。 那么,如何从零开始搭建一个 SpringCloud …

Qt构建JSON及解析JSON

目录 一.JSON简介 JSON对象 JSON数组 二.Qt中JSON介绍 QJsonvalue Qt中JSON对象 Qt中JSON数组 QJsonDocument 三.Qt构建JSON数组 四.解析JSON数组 一.JSON简介 一般来讲C类和对象在java中是无法直接直接使用的,因为压根就不是一个规则。但是他们在内存中…

Git清除某文件所有历史提交记录

一、软件要求 1.1 软件版本要求 git > 2.22.0python3 > 3.5 1.2 辅助插件 git filter-repo Linux/macOS # Debian/Ubuntu 系统 # 或使用 pip 安装pip install git-filter-repo sudo apt install git-filter-repo Windows pip install git-filter-repo二、操作步骤…

【flex-shrink】计算 flex弹性盒子的子元素的宽度大小

计算以下两个子div的宽度大小&#xff1a; 代码如下&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">…

使用 Internet 共享 (ICS) 方式分配ip

设备A使用dhcp的情况下&#xff0c;通过设备B分配ip并共享网络的方法。 启用网络共享&#xff08;ICS&#xff09;并配置 NAT Windows 自带的 Internet Connection Sharing (ICS) 功能可以简化 NAT 设置&#xff0c;允许共享一个网络连接给其他设备。 打开网络设置&#xff1…

灵当CRM系统index.php存在SQL注入漏洞

文章目录 免责申明漏洞描述搜索语法漏洞复现nuclei修复建议 免责申明 本文章仅供学习与交流&#xff0c;请勿用于非法用途&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任 漏洞描述 灵当CRM系统是一款功能全面、易于使用的客户关系管理&#xff08;C…

jacoco生成单元测试覆盖率报告

前言 单元测试是日常编写代码中常用的&#xff0c;用于测试业务逻辑的一种方式&#xff0c;单元测试的覆盖率可以用来衡量我们的业务代码经过测试覆盖的比例。 目前市场上开源的单元测试覆盖率的java插件&#xff0c;主要有Emma&#xff0c;Cobertura&#xff0c;Jacoco。具体…