自注意力简介

在注意力机制中,每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q,键k和值v都来自于同一组输入,那么这个注意力就被称为是自注意力(self-attention)。自注意力这部分理论,我觉得台大李宏毅老师的课程讲得最好。

自注意力就是输入一堆向量,假设称为a1,a2,a3,a4,那么这四个向量都会参与自注意力机制的运算,得到的结果仍然是四个输出,这四个输出再去做全连接运算。而每一个自注意力机制的输出都用到了a1~a4四个向量来进行运算,也就是说每个输出都是观察了所有的输入之后才得到的。

首先,输入a1需要和a2,a3,a4分别计算相关性,这个相关性可以由缩放点积方式来计算,也就称作缩放点积注意力,也可以由两个两个输入向量相加后再做非线性处理得到,称为加性注意力。

缩放点积的计算方法如下:

输入一个向量v1和一个向量v2,v1去乘上一个可训练矩阵Wq得到q,v2去乘上一个可训练矩阵Wk得到k,再把这个q和k做一个点积运算,得到的就是α,类似于相似度。

回到前面的例子中,a1这里既作为q,又作为k,又作为v。其中,Wq*a1就是q1,Wk*a1就是k1,q1和k1的点积就是α11,相当于a1自己和自己的相似度,同样的,a1和a2,a3,a4分别计算得到α12,α13,α14,然后将α11,α12,α13,α14经过softmax得到最终的四个输出,如下图所示:

然后再用一个可训练矩阵Wv去乘以a1得到v1,用计算得到的相似度α'11去乘以v1,得到一个值temp11;同样的,用可训练矩阵Wv去乘以a2得到v2,用计算得到的相似度α‘12去乘以v2,得到temp12;类似的得到temp13,temp14,然后把temp11+temp12+temp13+temp14得到b1,这个b1就是自注意力机制的第一个输出。

我们刚刚是以a1的视角做的运算,得到b1,同样可以以a2,a3,a4的视角做运算,得到b2,b3,b4。这次就得到了自注意力机制的输出。光看最后这个结构图,有点类似全连接,只是里面的运算过程比全连接要复杂。

下面,我们来看一下如何用代码实现自注意力的计算。

import torch  
import torch.nn as nn  
import torch.nn.functional as F  class SelfAttention(nn.Module):# embed_size代表输入的向量维度,heads代表多头注意力机制中的头数量def __init__(self, embed_size, heads): super(SelfAttention,self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads # 每个头的维度# 用assert断言机制判断assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads" # 没有偏置项,其实这个线性层本质上就是为了计算值Wv*a = Vself.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)# 最后的全连接操作,输出仍是输入的向量维度,也就是说大小是不变的self.fc_out = nn.Linear(heads*self.head_dim, self.embed_size)def forward(self, values, keys, query, mask):# 这个mask也很关键,它用于控制模型在处理序列数据时应该关注哪些部分,以及忽略哪些部分N = query.shape[0] # 获取输入的批量个数print("N:",N)value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 获取输入序列的长度# Split the embedding into self.heads different pieces  # 把k,q,v都切分为多个组values = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)# 计算k,q,vvalues = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk",[queries, keys]) # 格式转化print("queries.shape:", queries.shape)print("keys.shape:", keys.shape)print("energy.shape:", energy.shape)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20"))attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3) # softmax内部是缩放点积print("attention.shape:", attention.shape)print("values.shape:", values.shape)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)out = self.fc_out(out)return out
embed_size = 512
heads = 8
attention = SelfAttention(embed_size, heads)# batch size 1, seq length 60
values = torch.rand(1,60,embed_size)
keys = torch.rand(1,60,embed_size)
queries = torch.rand(1,60,embed_size)
mask = None # 假设没有maskout = attention(values, keys, queries, mask)
print(out.shape)# 输出
N: 1
queries.shape: torch.Size([1, 60, 8, 64])
keys.shape: torch.Size([1, 60, 8, 64])
energy.shape: torch.Size([1, 8, 60, 60])
attention.shape: torch.Size([1, 8, 60, 60])
values.shape: torch.Size([1, 60, 8, 64])
torch.Size([1, 60, 512])

通过这个程序,我们可以看到,自注意力机制是不改变输入和输出的形状的,输入的Q,K,V格式是[1,60,512],输出的结果的仍然是[1,60,512]。

下面是几点说明:

1. 这里的embed_size代表的是输入到自注意力层中的每个元素的向量维度。在Transformer模型中,输入数据首先会被转换成一个固定长度的向量,这个向量的长度就称为embed_size。

2. mask表示的是模型在处理序列数据时,应该忽略掉哪部分,我这里设置为None,也就是全部参与计算。

3. einsum,称为爱因斯坦求和,起源是爱因斯坦在研究广义相对论时,需要处理大量求和运算,为了简化这种繁复的运算,提出了求和约定,推动了张量分析的发展。einsum 可以计算向量、矩阵、张量运算,如果利用得当,sinsum可完全代替其他的矩阵计算方法。

例如,C = einsum('ij,jk->ik', A, B),就相当于两个矩阵求内积:cik = Σj AijBjk。

通过输出可以看到,在计算前queries的形状是[1,60,8,64],keys的形状是[1,60,8,64],在表达式"nqhd,nkhd->nhqk"中,n=1,q=60,h=8,d=64,k=60,两个矩阵进行内积,因此得到的结果是nhqk,也就是[1,8,60,60]。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44895.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新兴商业模式如何破局?市场策划专家的实战指南

在这个融合了传统市场、互联网和信息技术的大潮中,想要在市场营销策划上玩得转,咱们得有超凡的全局思维和跨界的协作精神。 下面,我就来和大家聊聊如何在这样一个复杂环境下搞定市场营销策划,让你在竞争激烈的市场中脱颖而出。 …

Nodejs 第八十二章(Nacos动态配置)

什么是动态配置? 在Nacos中,动态配置是指可以在运行时动态修改应用程序的配置信息,而不需要重新启动或重新部署应用程序。通过Nacos的动态配置功能,开发人员可以将应用程序的配置信息存储在Nacos服务器中,并在需要时进…

PTK是如何加密WLAN单播数据帧的?

1. References WLAN 4-Way Handshake如何生成PTK?-CSDN博客 2. 概述 在Wi-Fi网络中,单播、组播和广播帧的加密算法是由AP决定的。其中单播帧的加密使用PTK密钥,其PTK的密钥结构如下图所示: PTK的组成如上图所示,由K…

做一个专业的声音分析系统,需要对声音那些评判标准进行计算

为了构建一个专业的声音分析系统,需要对以下评判标准进行计算。每个标准需要相应的算法和技术指标来实现。下面是一些关键的评判标准和如何计算这些标准的具体方法: 1. 音质 清晰度 信噪比(SNR):计算音频信号中的信…

win11下部署Jenkins,build c#项目

一个c#的项目,由于项目经理总要新版本测试,以前每次都是手动出包,现在改成jenkins自动生成,节省时间。 一、下载Jenkins, 可以通过清华镜像下载Index of /jenkins/windows-stable/ | 清华大学开源软件镜像站 | Tsingh…

前端使用pinia中存入的值

导入pinia,创建pinia实例 使用pinia中的值

mysql8多值索引

MySQL8新出了一个多值索引,我还没体验过呢,今天试一试。 建表 我先建个表试一试多值索引的效果。我粗略地看了下多值索引的介绍,发现是只适用于数组类型的。所以我建一个含有数组字段的表试一试。语法还是挺麻烦的: create tabl…

2.电容(常见元器件及电路基础知识)

一.电容种类 1.固态电容 这种一般价格贵一些,ESR,ESL比较低,之前项目400W电源用的就是这个,温升能够很好的控制 2.铝电解电容 这种一般很便宜,ESR,ESL相对大一些,一般发热量比较大,烫手。 这种一般比上一个贵一点&am…

开源公司网站源码系统,降低成本,提升效率 附带完整的安装代码包以及搭建教程

系统概述 开源公司网站源码系统是一个基于开源技术的网站建设解决方案。它提供了完整的网站框架和功能模块,允许企业快速搭建起一个功能齐全、设计美观的企业网站。该系统不仅降低了网站开发的成本,还大大提高了建设效率,使企业能够更快地将…

泛微开发修炼之旅--37通过js实现监听下拉框,并触发后端接口,改变其他控件内容的实现方法与源码(含pc端和移动端实现)

文章链接:37通过js实现监听下拉框,并触发后端接口,改变其他控件内容的实现方法与源码(含pc端和移动端实现)

游戏AI的创造思路-技术基础-决策树(2)

上一篇写了决策树的基础概念和一些简单例子,本篇将着重在实际案例上进行说明 目录 8. 决策树应用的实际例子 8.1. 方法和过程 8.1.1. 定义行为 8.1.2. 确定属性 8.1.3. 构建决策树 8.1.4. 实施行为 8.1.5. 实时更新 8.2. Python代码 8. 决策树应用的实际例子…

滑动窗口,最长子序列最好的选择 -> O(N)

最近在学校上短学期课程,做程序设计题,一下子回忆起了大一学数据结构与算法的日子! 这十天我会记录一些做题的心得,今天带来的是对于最长子序列长度题型的解题框架:滑动窗口 本质就是双指针算法: 通过le…

模拟生成高斯随机数序列

模拟和生成高斯随机数序列(服从标准正态分布的随机变量) Box-Muller 法 & Marsaglia 极坐标法 Box-Muller:使两个独立的均匀分布生成一个高斯分布。 Box-Muller方法的基本思想是利用两个独立的均匀分布随机变量的关系来生成高斯分布的…

Ubuntu与Windows通过WIFI与以太网口共享网络,Ubuntu与Windows相互ping通,但ping百度失败

Linux开发板(正点原子阿尔法_IMX6U)与Ubuntu的文件传输SCP 报错 SSH: no matching host key type found. Their offer: ssh-rsa-CSDN博客 前面的文章提到了如何将Ubuntu与Windows通过WIFI共享网络给以太网,从而实现Linux开发板、Ubuntu、Win…

香港优才计划续签难吗?一次性说清楚优才续签要求,不在香港居住也能续签成功!

香港优才计划续签难吗?这个问题对考虑申请优才的人来说,还是挺重要的。我们申请优才,最关注的2个问题,一个是获批,还有一个就是续签了。 毕竟我们费那么大功夫申请优才,可不只是为了一个为期3年的香港临时…

如何分析软件测试中发现的Bug!

假如你是一名软件测试工程师,每天面对的就是那些“刁钻”的Bug,它们像是隐藏在黑暗中的敌人,时不时跳出来给你一个“惊喜”。那么,如何才能有效地分析和处理这些Bug,让你的测试工作变得高效且有趣呢?今天我…

MongoDB - 集合和文档的增删改查操作

文章目录 1. MongoDB 运行命令2. MongoDB CRUD操作1. 新增文档1. 新增单个文档 insertOne2. 批量新增文档 insertMany 2. 查询文档1. 查询所有文档2. 指定相等条件3. 使用查询操作符指定条件4. 指定逻辑操作符 (AND / OR) 3. 更新文档1. 更新操作符语法2. 更新单个文档 updateO…

web安全及内网安全知识

本文来源无问社区(wwlib.cn)更多详细内容可前往观看http://www.wwlib.cn/index.php/artread/artid/7506.html Web安全 1、sql注入 Web程序中对于用户提交的参数未做过滤直接拼接到SQL语句中执行,导致参数中的特殊字符破坏了SQL语句原有逻…

简述多云互联原理,客户公私云池互通的产品是什么?

多云互联原理基于企业或组织使用多个不同的云服务提供商的基础设施和服务,以实现最佳的运营效率、弹性和成本效益。这种策略允许用户避免供应商锁定,分散风险,并利用不同云服务商的特定优势,例如价格、地理位置、功能或性能。 多云…

jvm 07 GC算法,内存池,对象内存分配

01 垃圾判断算法 1.1引用计数算法 最简单的垃圾判断算法。在对象中添加一个属性用于标记对象被引用的次数,每多一个其他对象引用,计数1, 当引用失效时,计数-1,如果计数0,表示没有其他对象引用,…