Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的dec_inputs: [batch_size, tgt_len]enc_inpus: [batch_size, src_len]enc_outputs: [batch_size, src_len, d_model]'''dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵# 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为Falsedec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +dec_self_attn_subsequence_mask), 0).cuda()# 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]for layer in self.layers:dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]],[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]],[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的attn_mask: [batch_size, n_heads, seq_len, seq_len]'''# 1) 计算注意力分数QK^T/sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]# 2)  进行 mask 和 softmax# mask为True的位置会被设为-1e9scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]# 3) 乘V得到最终的加权和context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]'''得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来'''return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/677166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【OrangePi Zero2 智能家居】需求及项目准备

一、需求及项目准备 二、系统框图 三、硬件接线 四、语音模块配置 五、模块测试 一、需求及项目准备 语音接入控制各类家电,如客厅灯、卧室灯、风扇Socket网络编程,实现Sockect发送指令远程控制各类家电烟雾警报监测, 实时检查是否存在煤气…

高级FPGA开发之基础协议PCIe(二)

高级FPGA开发之基础协议之PCIe(二) 一、TLP报文类型 在PCIe总线中,存储器读写、I/O读写和配置读写请求TLP主要由以下几类报文组成: 1.1 存储器读请求TLP和读完成TLP 当PCIe主设备(RC或者EP)访问目标设备…

Java图形化界面编程——菜单组件 笔记

2.7 菜单组件 ​ 前面讲解了如果构建GUI界面,其实就是把一些GUI的组件,按照一定的布局放入到容器中展示就可以了。在实际开发中,除了主界面,还有一类比较重要的内容就是菜单相关组件,可以通过菜单相关组件很方便的使用…

机器学习系列——(十二)线性回归

导言 在机器学习领域,线性回归是最基础且重要的算法之一。它用于建立输入特征与输出目标之间的线性关系模型,为我们解决回归问题提供了有效的工具。本文将详细介绍线性回归的原理、应用和实现方法,帮助读者快速了解和上手这一强大的机器学习…

回调地狱 与 Promise(JavaScript)

目录捏 前言一、异步编程二、回调函数三、回调地狱四、Promise1. Promise 简介2. Promise 语法3. Promise 链式 五、总结 前言 想要学习Promise,我们首先要了解异步编程、回调函数、回调地狱三方面知识: 一、异步编程 异步编程技术使你的程序可以在执行一…

除夕日的每日一题(字符个数统计,多数元素)

字符个数统计_牛客题霸_牛客网 (nowcoder.com) #include <stdio.h> #include <string.h> #include <stdlib.h>int num0,len,i,j,k,asc; int tmp[128]{0}; char str[400]; int main() {gets(str);//gets(str) 用于从标准输入读取一个字符串并存储在 str 中len…

AWD-Test2

1.已知账号密码&#xff0c;可SSH连接进行代码审计。2.登录可万能密码进入&#xff0c;也可注册后登录。3.修改url参数&#xff0c;发现报错。确定为Linux系统4.写入一句话&#xff0c;并提交。&#xff08;也可以文件上传&#xff0c;这里采用简洁的方法&#xff09; <?p…

修改SpringBoot中默认依赖版本

例如SpringBoot2.7.2中ElasticSearch版本是7.17.4 我希望把它变成7.6.1

电商商城系统网站

文章目录 电商商城系统网站一、项目演示二、项目介绍三、系统部分功能截图四、部分代码展示五、底部获取项目&#xff08;9.9&#xffe5;带走&#xff09; 电商商城系统网站 一、项目演示 商城系统 二、项目介绍 基于SpringBootVue的前后端分离商城系统网站 运行环境:idea或…

M1 Mac使用SquareLine-Studio进行LVGL开发

背景 使用Gui-Guider开发遇到一些问题,比如组件不全。使用LVGL官方的设计软件开发 延续上一篇使用的基本环境。 LVGL项目 新建项目 选择Arduino的项目,设定好分辨率及颜色。 设计UI 导出代码 Export -> Create Template Project 导出文件如图 将libraries下的ui文…

【原理图PCB专题】Cadence17.4 PCB位号重排与反标

在文章:【原理图专题】Cadence 16.6如何把PCB元件位号重排并反标到原理图 中我们讲到了Cadence16.6版本对原理图进行反标的操作。 对于反标之前我们是通过如下所示的绘制流程来讲的,一般在首板或是大改板操作器件里有很多不同的很大的位号,这时我们可以通过Backannotate功能…

C++初阶篇----新手进村

目录 一、什么是C二、C关键字三、命名空间3.1命名空间的定义3.2命名空间的使用 四、C输入和输出五、缺省参数5.1缺省参数的概念5.2缺省参数的分类 六、函数重载6.1函数重载的概念6.2函数重载的原理----名字修饰 七、引用7.1引用概念7.2引用特性7.3常引用7.4引用的使用7.5传值、…

考研高数(导数的定义)

总结&#xff1a; 导数的本质就是极限。 函数在某点可导就必连续&#xff0c;连续就有极限且等于该点的函数值。 例题1&#xff1a;&#xff08;归结原则的条件是函数可导&#xff09; 例题2&#xff1a; 例题3&#xff1a;

使用 Elasticsearch 和 OpenAI 构建生成式 AI 应用程序

本笔记本演示了如何&#xff1a; 将 OpenAI Wikipedia 向量数据集索引到 Elasticsearch 中使用 Streamlit 构建一个简单的 Gen AI 应用程序&#xff0c;该应用程序使用 Elasticsearch 检索上下文并使用 OpenAI 制定答案 安装 安装 Elasticsearch 及 Kibana 如果你还没有安装好…

FPGA_简单工程_状态机

一 理论 fpga是并行执行的&#xff0c;当处理需要顺序解决的事时&#xff0c;就要引入状态机。 状态机&#xff1a; 简写FSM&#xff0c;也称同步有限状态机。 分为&#xff1a;more型状态机&#xff0c;mealy型状态机。 功能&#xff1a;执行该事件&#xff0c;然后跳转到下…

相机图像质量研究(11)常见问题总结:光学结构对成像的影响--像差

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结&#xff1a;光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结&#xff1a;光学结构对成…

Linux 36.2@Jetson Orin Nano之Hello AI World!

Linux 36.2Jetson Orin Nano之Hello AI World&#xff01; 1. 源由2. Hello AI World&#xff01;3. 步骤3.1 准备阶段3.2 获取代码3.3 Python环境3.4 重点环节3.5 软件配置3.6 PyTorch安装3.7 编译链接3.8 安装更新 4. 测试4.1 video-viewer4.2 detectnet4.3 演示命令 5. 参考…

【OrangePi Zero2 智能家居】阿里云人脸识别方案

一、接入阿里云 二、C语言调用阿里云人脸识别接口 三、System V消息队列和POSIX 消息队列 一、接入阿里云 在之前树莓派的人脸识别方案采用了翔云平台的方案去1V1上传比对两张人脸比对&#xff0c;这种方案是可行&#xff0c;可 以继续采用。但为了接触更多了云平台方案&…

跟着pink老师前端入门教程-day23

苏宁网首页案例制作 设置视口标签以及引入初始化样式 <meta name"viewport" content"widthdevice-width, user-scalableno, initial-scale1.0, maximum-scale1.0, minimum-scale1.0"> <link rel"stylesheet" href"css/normaliz…

Mybatis是否支持延迟加载?

前言 随着互联网应用的不断发展&#xff0c;数据库访问成为了应用开发中的一个重要环节。在这个背景下&#xff0c;MyBatis作为一种优秀的持久层框架&#xff0c;提供了灵活的SQL映射配置和强大的功能&#xff0c;为开发者提供了便捷的数据库访问解决方案。本文将深入探讨MyBat…