Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的dec_inputs: [batch_size, tgt_len]enc_inpus: [batch_size, src_len]enc_outputs: [batch_size, src_len, d_model]'''dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵# 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为Falsedec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +dec_self_attn_subsequence_mask), 0).cuda()# 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]for layer in self.layers:dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]],[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]],[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的attn_mask: [batch_size, n_heads, seq_len, seq_len]'''# 1) 计算注意力分数QK^T/sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]# 2)  进行 mask 和 softmax# mask为True的位置会被设为-1e9scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]# 3) 乘V得到最终的加权和context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]'''得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来'''return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/677166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VUEX项目场景

VUEX项目场景 一、登录状态存储 登录页面代码 <template><div><input v-model"username" type"text" placeholder"Username"><input v-model"password" type"password" placeholder"Password&…

【OrangePi Zero2 智能家居】需求及项目准备

一、需求及项目准备 二、系统框图 三、硬件接线 四、语音模块配置 五、模块测试 一、需求及项目准备 语音接入控制各类家电&#xff0c;如客厅灯、卧室灯、风扇Socket网络编程&#xff0c;实现Sockect发送指令远程控制各类家电烟雾警报监测&#xff0c; 实时检查是否存在煤气…

高级FPGA开发之基础协议PCIe(二)

高级FPGA开发之基础协议之PCIe&#xff08;二&#xff09; 一、TLP报文类型 在PCIe总线中&#xff0c;存储器读写、I/O读写和配置读写请求TLP主要由以下几类报文组成&#xff1a; 1.1 存储器读请求TLP和读完成TLP 当PCIe主设备&#xff08;RC或者EP&#xff09;访问目标设备…

Java图形化界面编程——菜单组件 笔记

2.7 菜单组件 ​ 前面讲解了如果构建GUI界面&#xff0c;其实就是把一些GUI的组件&#xff0c;按照一定的布局放入到容器中展示就可以了。在实际开发中&#xff0c;除了主界面&#xff0c;还有一类比较重要的内容就是菜单相关组件&#xff0c;可以通过菜单相关组件很方便的使用…

机器学习系列——(十二)线性回归

导言 在机器学习领域&#xff0c;线性回归是最基础且重要的算法之一。它用于建立输入特征与输出目标之间的线性关系模型&#xff0c;为我们解决回归问题提供了有效的工具。本文将详细介绍线性回归的原理、应用和实现方法&#xff0c;帮助读者快速了解和上手这一强大的机器学习…

回调地狱 与 Promise(JavaScript)

目录捏 前言一、异步编程二、回调函数三、回调地狱四、Promise1. Promise 简介2. Promise 语法3. Promise 链式 五、总结 前言 想要学习Promise&#xff0c;我们首先要了解异步编程、回调函数、回调地狱三方面知识&#xff1a; 一、异步编程 异步编程技术使你的程序可以在执行一…

除夕日的每日一题(字符个数统计,多数元素)

字符个数统计_牛客题霸_牛客网 (nowcoder.com) #include <stdio.h> #include <string.h> #include <stdlib.h>int num0,len,i,j,k,asc; int tmp[128]{0}; char str[400]; int main() {gets(str);//gets(str) 用于从标准输入读取一个字符串并存储在 str 中len…

梯度提升树系列7——深入理解GBDT的参数调优

目录 写在开头1. GBDT的关键参数解析1.1 学习率(learning rate)1.2 树的数量(n_estimators)1.3 树的最大深度(max_depth)1.4 叶子节点的最小样本数(min_samples_leaf)1.5 特征选择的比例(max_features)1.6 最小分裂所需的样本数(min_samples_split)1.7 子采样比例(…

算法笔记P67

#include <cstdio>void swap(int*a,int*b){int temp *a;*a *b;*b temp; }int main() {int a1,b2;int *p1 &a,*p2 &b;swap(p1,p2);printf("a %d,b %d",a,b);return 0; }PS&#xff1a;对*和&的理解 1. * &#xff1a;例如int* p的理解&#…

AWD-Test2

1.已知账号密码&#xff0c;可SSH连接进行代码审计。2.登录可万能密码进入&#xff0c;也可注册后登录。3.修改url参数&#xff0c;发现报错。确定为Linux系统4.写入一句话&#xff0c;并提交。&#xff08;也可以文件上传&#xff0c;这里采用简洁的方法&#xff09; <?p…

修改SpringBoot中默认依赖版本

例如SpringBoot2.7.2中ElasticSearch版本是7.17.4 我希望把它变成7.6.1

电商商城系统网站

文章目录 电商商城系统网站一、项目演示二、项目介绍三、系统部分功能截图四、部分代码展示五、底部获取项目&#xff08;9.9&#xffe5;带走&#xff09; 电商商城系统网站 一、项目演示 商城系统 二、项目介绍 基于SpringBootVue的前后端分离商城系统网站 运行环境:idea或…

【开源计算机视觉库OpencV详解——超详细】

开源计算机视觉库OpencV详解 1. 介绍2. 核心功能3. 安装OpenCV4. 示例&#xff1a;使用Python读取和显示图像5. 示例&#xff1a;使用Python捕捉视频6. 获取帮助和文档 1. 介绍 OpenCV&#xff08;Open Source Computer Vision Library&#xff09;是一个开放源码的计算机视觉…

M1 Mac使用SquareLine-Studio进行LVGL开发

背景 使用Gui-Guider开发遇到一些问题,比如组件不全。使用LVGL官方的设计软件开发 延续上一篇使用的基本环境。 LVGL项目 新建项目 选择Arduino的项目,设定好分辨率及颜色。 设计UI 导出代码 Export -> Create Template Project 导出文件如图 将libraries下的ui文…

【原理图PCB专题】Cadence17.4 PCB位号重排与反标

在文章:【原理图专题】Cadence 16.6如何把PCB元件位号重排并反标到原理图 中我们讲到了Cadence16.6版本对原理图进行反标的操作。 对于反标之前我们是通过如下所示的绘制流程来讲的,一般在首板或是大改板操作器件里有很多不同的很大的位号,这时我们可以通过Backannotate功能…

C++初阶篇----新手进村

目录 一、什么是C二、C关键字三、命名空间3.1命名空间的定义3.2命名空间的使用 四、C输入和输出五、缺省参数5.1缺省参数的概念5.2缺省参数的分类 六、函数重载6.1函数重载的概念6.2函数重载的原理----名字修饰 七、引用7.1引用概念7.2引用特性7.3常引用7.4引用的使用7.5传值、…

如何使用python网络爬虫批量获取公共资源数据实践技术应用

要使用Python网络爬虫批量获取公共资源数据&#xff0c;你需要遵循以下步骤&#xff1a; 确定目标网站和数据结构&#xff1a;首先&#xff0c;你需要明确你要爬取的网站以及该网站的数据结构。了解目标网站的数据结构和API&#xff08;如果有的话&#xff09;是关键。选择合适…

考研高数(导数的定义)

总结&#xff1a; 导数的本质就是极限。 函数在某点可导就必连续&#xff0c;连续就有极限且等于该点的函数值。 例题1&#xff1a;&#xff08;归结原则的条件是函数可导&#xff09; 例题2&#xff1a; 例题3&#xff1a;

使用 Elasticsearch 和 OpenAI 构建生成式 AI 应用程序

本笔记本演示了如何&#xff1a; 将 OpenAI Wikipedia 向量数据集索引到 Elasticsearch 中使用 Streamlit 构建一个简单的 Gen AI 应用程序&#xff0c;该应用程序使用 Elasticsearch 检索上下文并使用 OpenAI 制定答案 安装 安装 Elasticsearch 及 Kibana 如果你还没有安装好…

Python爬虫Xpath库详解#4

爬虫专栏&#xff1a;http://t.csdnimg.cn/WfCSx 前言 前面&#xff0c;我们实现了一个最基本的爬虫&#xff0c;但提取页面信息时使用的是正则表达式&#xff0c;这还是比较烦琐&#xff0c;而且万一有地方写错了&#xff0c;可能导致匹配失败&#xff0c;所以使用正则表达式…