注意力机制中三种掩码技术详解和Pytorch实现

注意力机制是许多最先进神经网络架构的基本组成部分,比如Transformer模型。注意力机制中的一个关键方面是掩码,它有助于控制信息流,并确保模型适当地处理序列。

在这篇文章中,我们将探索在注意力机制中使用的各种类型的掩码,并在PyTorch中实现它们。

在神经网络中,掩码是一种用于阻止模型使用输入数据中的某些部分的技术。这在序列模型中尤其重要,因为序列的长度可能会有所不同,且输入的某些部分可能无关紧要(例如,填充符)或需要被隐藏(例如,语言建模中的未来内容)。

掩码的类型

填充掩码 Padding Mask

在深度学习中,特别是在处理序列数据时,“填充掩码”(Padding Mask)是一个重要概念。当序列数据的长度不一致时,通常需要对短的序列进行填充(padding),以确保所有序列的长度相同,这样才能进行批处理。这些填充的部分实际上是没有任何意义的,不应该对模型的学习产生影响。

序列掩码 Sequence Mask

序列掩码用于隐藏输入序列的某些部分。比如在双向模型中,想要根据特定标准忽略序列的某些部分。

前瞻掩码 Look-ahead Mask

前瞻掩码,也称为因果掩码或未来掩码,用于自回归模型中,以防止模型在生成序列时窥视未来的符号。这确保了给定位置的预测仅依赖于该位置之前的符号。

填充掩码

填充掩码就是用来指示哪些数据是真实的,哪些是填充的。在模型处理这些数据时,掩码会用来避免在计算损失或者梯度时考虑填充的部分,确保模型的学习只关注于有效的数据。在使用诸如Transformer这样的模型时,填充掩码特别重要,因为它们可以帮助模型在进行自注意力计算时忽略掉填充的位置。

 importtorchdefcreate_padding_mask(seq, pad_token=0):mask= (seq==pad_token).unsqueeze(1).unsqueeze(2)returnmask  # (batch_size, 1, 1, seq_len)# Example usageseq=torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])padding_mask=create_padding_mask(seq)print(padding_mask)

序列掩码

在使用如Transformer模型时,序列掩码用于避免在计算注意力分数时考虑到填充位置的影响。这确保了模型的注意力是集中在实际有意义的数据上,而不是无关的填充数据。

RNNs本身可以处理不同长度的序列,但在批处理和某些架构中,仍然需要固定长度的输入。序列掩码在这里可以帮助RNN忽略掉序列中的填充部分,特别是在计算最终序列输出或状态时。

在训练模型时,序列掩码也可以用来确保在计算损失函数时,不会将填充部分的预测误差纳入总损失中,从而提高模型训练的准确性和效率。

序列掩码通常表示为一个与序列数据维度相同的二进制矩阵或向量,其中1表示实际数据,0表示填充数据

 def create_sequence_mask(seq):seq_len = seq.size(1)mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)return mask  # (seq_len, seq_len)# Example usageseq_len = 4sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))print(sequence_mask)

前瞻掩码 Look-ahead Mask

前瞻掩码通过在自注意力机制中屏蔽(即设置为一个非常小的负值,如负无穷大)未来时间步的信息来工作。这确保了在计算每个元素的输出时,模型只能使用到当前和之前的信息,而不能使用后面的信息。这种机制对于保持自回归属性(即一次生成一个输出,且依赖于前面的输出)是必要的。

在实现时,前瞻掩码通常表示为一个上三角矩阵,其中对角线及对角线以下的元素为0(表示这些位置的信息是可见的),对角线以上的元素为1(表示这些位置的信息是不可见的)。在计算注意力时,这些为1的位置会被设置为一个非常小的负数(通常是负无穷),这样经过softmax函数后,这些位置的权重接近于0,从而不会对输出产生影响。

 def create_look_ahead_mask(size):mask = torch.triu(torch.ones(size, size), diagonal=1)return mask  # (seq_len, seq_len)# Example usagelook_ahead_mask = create_look_ahead_mask(4)print(look_ahead_mask)

掩码之间的关系

填充掩码(Padding Mask)和序列掩码(Sequence Mask)都是在处理序列数据时使用的技术,它们的目的是帮助模型正确处理变长的输入序列,但它们的应用场景和功能有些区别。这两种掩码经常在深度学习模型中被一起使用,尤其是在需要处理不同长度序列的场景下。

填充掩码专门用于指示哪些数据是填充的,这主要应用在输入数据预处理和模型的输入层。其核心目的是确保模型在处理或学习过程中不会将填充部分的数据当作有效数据来处理,从而影响模型的性能。在诸如Transformer模型的自注意力机制中,填充掩码用于阻止模型将注意力放在填充的序列上。

序列掩码通常用于更广泛的上下文中,它不仅可以指示填充位置,还可以用于其他类型的掩蔽,如在序列到序列的任务中掩蔽未来的信息(如解码器的自回归预测)。序列掩码可以用于确保模型在处理过程中只关注于当前及之前的信息,而不是未来的信息,这对于保持信息的时序依赖性非常重要。

充掩码多用于模型的输入阶段或在注意力机制中排除无效数据的影响,序列掩码则可能在模型的多个阶段使用,特别是在需要控制信息流的场景中。

与填充掩码和序列掩码不同,前瞻掩码专门用于控制时间序列的信息流,确保在生成序列的每个步骤中模型只能利用到当前和之前的信息。这是生成任务中保持模型正确性和效率的关键技术。

在注意机制中应用不同的掩码

在注意力机制中,掩码被用来修改注意力得分。

 importtorch.nn.functionalasFdefscaled_dot_product_attention(q, k, v, mask=None):matmul_qk=torch.matmul(q, k.transpose(-2, -1))dk=q.size()[-1]scaled_attention_logits=matmul_qk/torch.sqrt(torch.tensor(dk, dtype=torch.float32))ifmaskisnotNone:scaled_attention_logits+= (mask*-1e9)attention_weights=F.softmax(scaled_attention_logits, dim=-1)output=torch.matmul(attention_weights, v)returnoutput, attention_weights# Example usaged_model=512batch_size=2seq_len=4q=torch.rand((batch_size, seq_len, d_model))k=torch.rand((batch_size, seq_len, d_model))v=torch.rand((batch_size, seq_len, d_model))mask=create_look_ahead_mask(seq_len)attention_output, attention_weights=scaled_dot_product_attention(q, k, v, mask)print(attention_output)

我们创建一个简单的Transformer 层来验证一下三个掩码的不同之处:

 import torchimport torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelassert d_model % num_heads == 0self.depth = d_model // num_headsself.wq = nn.Linear(d_model, d_model)self.wk = nn.Linear(d_model, d_model)self.wv = nn.Linear(d_model, d_model)self.dense = nn.Linear(d_model, d_model)def split_heads(self, x, batch_size):x = x.view(batch_size, -1, self.num_heads, self.depth)return x.permute(0, 2, 1, 3)def forward(self, v, k, q, mask):batch_size = q.size(0)q = self.split_heads(self.wq(q), batch_size)k = self.split_heads(self.wk(k), batch_size)v = self.split_heads(self.wv(v), batch_size)scaled_attention, _ = scaled_dot_product_attention(q, k, v, mask)scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()original_size_attention = scaled_attention.view(batch_size, -1, self.d_model)output = self.dense(original_size_attention)return outputclass TransformerLayer(nn.Module):def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):super(TransformerLayer, self).__init__()self.mha = MultiHeadAttention(d_model, num_heads)self.ffn = nn.Sequential(nn.Linear(d_model, dff),nn.ReLU(),nn.Linear(dff, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout_rate)self.dropout2 = nn.Dropout(dropout_rate)def forward(self, x, mask):attn_output = self.mha(x, x, x, mask)attn_output = self.dropout1(attn_output)out1 = self.layernorm1(x + attn_output)ffn_output = self.ffn(out1)ffn_output = self.dropout2(ffn_output)out2 = self.layernorm2(out1 + ffn_output)return out2

创建一个简单的模型:

 d_model = 512num_heads = 8dff = 2048dropout_rate = 0.1batch_size = 2seq_len = 4x = torch.rand((batch_size, seq_len, d_model))mask = create_padding_mask(torch.tensor([[1, 2, 0, 0], [3, 4, 5, 0]]))transformer_layer = TransformerLayer(d_model, num_heads, dff, dropout_rate)output = transformer_layer(x, mask)

然后在Transformer层上运行我们上面介绍的三个掩码。

 def test_padding_mask():seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])expected_mask = torch.tensor([[[[0, 0, 1, 1]]], [[[0, 0, 0, 1]]]])assert torch.equal(create_padding_mask(seq), expected_mask)print("Padding mask test passed!")def test_sequence_mask():seq_len = 4expected_mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]])assert torch.equal(create_sequence_mask(torch.zeros(seq_len, seq_len)), expected_mask)print("Sequence mask test passed!")def test_look_ahead_mask():size = 4expected_mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]])assert torch.equal(create_look_ahead_mask(size), expected_mask)print("Look-ahead mask test passed!")def test_transformer_layer():d_model = 512num_heads = 8dff = 2048dropout_rate = 0.1batch_size = 2seq_len = 4x = torch.rand((batch_size, seq_len, d_model))mask = create_padding_mask(torch.tensor([[1, 2, 0, 0], [3, 4, 5, 0]]))transformer_layer = TransformerLayer(d_model, num_heads, dff, dropout_rate)output = transformer_layer(x, mask)assert output.size() == (batch_size, seq_len, d_model)print("Transformer layer test passed!")test_padding_mask()test_sequence_mask()test_look_ahead_mask()test_transformer_layer()

结果和上面我们单独执行是一样的,所以得到如下结果

总结

最后我们来做个总结,在自然语言处理和其他序列处理任务中,使用不同类型的掩码来管理和优化模型处理信息的方式是非常关键的。这些掩码主要包括填充掩码、序列掩码和前瞻掩码,每种掩码都有其特定的使用场景和目的。

  1. 填充掩码(Padding Mask):- 目的:确保模型在处理填充的输入数据时不会将这些无关的数据当作有效信息处理。- 应用:主要用于处理因数据长度不一致而进行的填充操作,在模型的输入层或注意力机制中忽略这些填充数据。- 功能:帮助模型集中于实际的、有效的输入数据,避免因为处理无意义的填充数据而导致的性能下降。
  2. 序列掩码(Sequence Mask):- 目的:更广泛地控制模型应该关注的数据部分,包括但不限于填充数据。- 应用:用于各种需要精确控制信息流的场景,例如在递归神经网络和Transformer模型中管理有效数据和填充数据。- 功能:通过指示哪些数据是有效的,哪些是填充的,帮助模型更有效地学习和生成预测。
  3. 前瞻掩码(Look-ahead Mask):- 目的:防止模型在生成序列的过程中“看到”未来的信息。- 应用:主要用在自回归模型如Transformer的解码器中,确保生成的每个元素只能依赖于之前的元素。- 功能:保证模型生成信息的时序正确性,防止在生成任务中出现信息泄露,从而维持生成过程的自然和准确性。

这些掩码在处理变长序列、保持模型效率和正确性方面扮演着重要角色,是现代深度学习模型不可或缺的一部分。在设计和实现模型时,合理地使用这些掩码可以显著提高模型的性能和输出质量。

https://avoid.overfit.cn/post/2371a9ec5eca46af81dbe23d3442a383

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872430.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【瑞吉外卖 | day07】移动端菜品展示、购物车、下单

文章目录 瑞吉外卖 — day71. 导入用户地址簿相关功能代码1.1 需求分析1.2 数据模型1.3 代码开发 2. 菜品展示2.1 需求分析2.2 代码开发 3. 购物车3.1 需求分析3.2 数据模型3.3 代码开发 4. 下单4.1 需求分析4.2 数据模型4.3 代码开发 瑞吉外卖 — day7 移动端相关业务功能 —…

MySQL 一行记录是怎么存储的

文章目录 1. 文件存放目录 && 组织2. 表空间文件的结构3. InnoDB 行格式4. Compact 行格式记录的额外信息1. 变长字段长度列表2. NULL 值列表3. 记录头信息 记录的真实数据1. 定义的表字段2. 三个隐藏字段 5. varchar(n) 中 n 最大取值为多少?6. 行溢出后&a…

pnpm install安装失败

ERR_PNPM_META_FETCH_FAIL GET https://registry.npmjs.org/commitlint%2Fcli: request to https://registry.npmjs.org/commitlint%2Fcli failed, reason: connect ETIMEDOUT 2606:4700::6810:123:443 1. 检查网络连接 确保你的网络连接正常并且没有被防火墙或代理服务器阻止…

高翔【自动驾驶与机器人中的SLAM技术】学习笔记(二)——带着问题的学习;一刷感受;环境搭建

按照作者在读者寄语中的说法:我们得榨干这本书的知识。 带着问题 为了更好的学习,我们最好带着问题去探索。 第一:核心问题与基础知识 如上图:这本书介绍了SLAM相关的核心问题和基础知识。王谷博士给我们做了梳理:…

数据结构(4.1)——树的性质

结点数总度数1 结点的度——结点有几个孩子(分支) 度为m的树、m叉树的区别 度为m的树第i层至多有 个结点(i>1) 高度为h的m叉树至多有 个结点 高度为h的m叉树至少有h个结点 、高度为h,度为m叉树至多有hm-1个结点 具有n个结点的m叉树的最小高度为 总结

数据采集监控平台:挖掘数据价值 高效高速生产!

在当今数字化的时代,数据已成为企业非常宝贵的资产之一。然而,要充分发挥数据的潜力,离不开一个强大的数据采集监控平台,尤其是生产制造行业。它不仅是数据的收集者,更是洞察生产的智慧之眼,高效高速处理产…

EXCEL VBA工程密码破解 工作表保护破解

这里写目录标题 破解Excel宏工程加密方法一 新建破解宏文件方法二 修改二进制文件 破解工作表保护引用 破解Excel宏工程加密 如图所示 白料数据处理已工程被加密。 方法一 新建破解宏文件 1 创建一个XLSM文件,查看代码 ALTF11 2 新建一个模块,“插…

云计算数据中心(二)

目录 三、绿色节能技术(一)配电系统节能技术(二)空调系统节能技术(三)集装箱数据中心节能技术(四)数据中心节能策略和算法研究(五)新能源的应用(六…

新版本 idea 创建不了 spring boot 2 【没有jkd8选项】

创建新项目 将地址换成如下 https://start.aliyun.com/

Calibration相机内参数标定

1.环境依赖 本算法采用张正友相机标定法进行实现,内部对其进行了封装。 环境依赖为 ubuntu20.04 opencv4.2.0 yaml-cpp yaml-cpp安装方式: (1)git clone https://github.com/jbeder/yaml-cpp.git #将yaml-cpp下载至本地 &a…

深度解析:disableHostCheck: true引发的安全迷局与解决之道

在Web开发的浩瀚星空中,开发者们时常会遇到各种配置与调优的挑战,其中disableHostCheck: true这一选项,在提升开发效率的同时,也悄然埋下了安全隐患的伏笔。本文将深入探讨这一配置背后的原理、为何会引发报错,以及如何…

深度学习落地实战:基于GAN(生成对抗网络)生成图片

前言 大家好,我是机长 本专栏将持续收集整理市场上深度学习的相关项目,旨在为准备从事深度学习工作或相关科研活动的伙伴,储备、提升更多的实际开发经验,每个项目实例都可作为实际开发项目写入简历,且都附带完整的代…

Qt会议室项目

在Qt中编写会议室应用程序通常涉及到用户界面设计、网络通信、音频/视频处理等方面。以下是创建一个基本会议室应用程序的步骤概述: 项目设置: 使用Qt Creator创建一个新的Qt Widgets Application或Qt Quick Application项目。 用户界面设计&#xff1…

牛客TOP101:合并k个已排序的链表

文章目录 1. 题目描述2. 解题思路3. 代码实现 1. 题目描述 2. 解题思路 多个链表的合并本质上可以看成两个链表的合并,只不过需要进行多次。最简单的方法就是一个一个链表,按照合并两个有序链表的思路,循环多次就可以了。   另外一个思路&a…

(c++)virtual关键字的作用,多态的原理(详细)

1.viirtual修饰的两种函数 virtual 修饰的函数有两种,一个是虚函数,一个是纯虚函数。 2.虚函数与纯虚函数的异同之处 1.虚函数与纯虚函数的相同之处 虚函数和纯虚函数都重写的一种,什么是重写呢?重写是指在子类中写和父类中返…

《0基础》学习Python——第十四讲__封装、继承、多态

<封装、继承、多态> 一、类和实例解析 1、面向对象最重要的概念就是类&#xff08;Class&#xff09;和实例&#xff08;Instance&#xff09;&#xff0c;必须牢记类是抽象的模板 &#xff0c;比如Student类&#xff0c;而 实例是根据类创建出来的一个个具体的“对象”…

《昇思25天学习打卡营第23天|onereal》

第23天学习内容简介&#xff1a; ----------------------------------------------------------------------------- 本案例基于MindNLP和ChatGLM-6B实现一个聊天应用。 1 环境配置 配置网络线路 2 代码开发 下载权重大约需要10分钟 ------------------------------- 运…

大模型技术对学校有什么作用?

大模型技术对学校有多方面的作用&#xff0c;可以在教学、管理、决策等多个领域带来显著的改进。以下是大模型技术对学校的主要作用&#xff1a; 1. 个性化教学&#xff1a;大模型技术可以帮助教师分析学生的学习行为和历史成绩&#xff0c;从而定制个性化的教学计划和资源。这…

Linux桌面环境手动编译安装librime、librime-lua以及ibus-rime,提升中文输入法体验

Linux上的输入法有很多&#xff0c;大体都使用了Fcitx或者iBus作为输入法的引擎。相当于有了一个很不错的“地基”&#xff0c;你可以在这个“地基”上盖上自己的“小别墅”。而rime输入法&#xff0c;就是一个“毛坯别墅”&#xff0c;你可以在rime的基础上&#xff0c;再装修…

HCNA ICMP:因特网控制消息协议

ICMP&#xff1a;因特网控制消息协议 前言 Internet控制报文协议ICMP是网络层的一个重要协议。ICMP协议用来在网络设备间传递各种差错和控制信息&#xff0c;他对于手机各种网络信息、诊断和排除各种网络故障有至关重要的作用。使用基于ICMP的应用时&#xff0c;需要对ICMP的工…