自注意力简介

在注意力机制中,每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q,键k和值v都来自于同一组输入,那么这个注意力就被称为是自注意力(self-attention)。自注意力这部分理论,我觉得台大李宏毅老师的课程讲得最好。

自注意力就是输入一堆向量,假设称为a1,a2,a3,a4,那么这四个向量都会参与自注意力机制的运算,得到的结果仍然是四个输出,这四个输出再去做全连接运算。而每一个自注意力机制的输出都用到了a1~a4四个向量来进行运算,也就是说每个输出都是观察了所有的输入之后才得到的。

首先,输入a1需要和a2,a3,a4分别计算相关性,这个相关性可以由缩放点积方式来计算,也就称作缩放点积注意力,也可以由两个两个输入向量相加后再做非线性处理得到,称为加性注意力。

缩放点积的计算方法如下:

输入一个向量v1和一个向量v2,v1去乘上一个可训练矩阵Wq得到q,v2去乘上一个可训练矩阵Wk得到k,再把这个q和k做一个点积运算,得到的就是α,类似于相似度。

回到前面的例子中,a1这里既作为q,又作为k,又作为v。其中,Wq*a1就是q1,Wk*a1就是k1,q1和k1的点积就是α11,相当于a1自己和自己的相似度,同样的,a1和a2,a3,a4分别计算得到α12,α13,α14,然后将α11,α12,α13,α14经过softmax得到最终的四个输出,如下图所示:

然后再用一个可训练矩阵Wv去乘以a1得到v1,用计算得到的相似度α'11去乘以v1,得到一个值temp11;同样的,用可训练矩阵Wv去乘以a2得到v2,用计算得到的相似度α‘12去乘以v2,得到temp12;类似的得到temp13,temp14,然后把temp11+temp12+temp13+temp14得到b1,这个b1就是自注意力机制的第一个输出。

我们刚刚是以a1的视角做的运算,得到b1,同样可以以a2,a3,a4的视角做运算,得到b2,b3,b4。这次就得到了自注意力机制的输出。光看最后这个结构图,有点类似全连接,只是里面的运算过程比全连接要复杂。

下面,我们来看一下如何用代码实现自注意力的计算。

import torch  
import torch.nn as nn  
import torch.nn.functional as F  class SelfAttention(nn.Module):# embed_size代表输入的向量维度,heads代表多头注意力机制中的头数量def __init__(self, embed_size, heads): super(SelfAttention,self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads # 每个头的维度# 用assert断言机制判断assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads" # 没有偏置项,其实这个线性层本质上就是为了计算值Wv*a = Vself.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)# 最后的全连接操作,输出仍是输入的向量维度,也就是说大小是不变的self.fc_out = nn.Linear(heads*self.head_dim, self.embed_size)def forward(self, values, keys, query, mask):# 这个mask也很关键,它用于控制模型在处理序列数据时应该关注哪些部分,以及忽略哪些部分N = query.shape[0] # 获取输入的批量个数print("N:",N)value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 获取输入序列的长度# Split the embedding into self.heads different pieces  # 把k,q,v都切分为多个组values = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)# 计算k,q,vvalues = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk",[queries, keys]) # 格式转化print("queries.shape:", queries.shape)print("keys.shape:", keys.shape)print("energy.shape:", energy.shape)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20"))attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3) # softmax内部是缩放点积print("attention.shape:", attention.shape)print("values.shape:", values.shape)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)out = self.fc_out(out)return out
embed_size = 512
heads = 8
attention = SelfAttention(embed_size, heads)# batch size 1, seq length 60
values = torch.rand(1,60,embed_size)
keys = torch.rand(1,60,embed_size)
queries = torch.rand(1,60,embed_size)
mask = None # 假设没有maskout = attention(values, keys, queries, mask)
print(out.shape)# 输出
N: 1
queries.shape: torch.Size([1, 60, 8, 64])
keys.shape: torch.Size([1, 60, 8, 64])
energy.shape: torch.Size([1, 8, 60, 60])
attention.shape: torch.Size([1, 8, 60, 60])
values.shape: torch.Size([1, 60, 8, 64])
torch.Size([1, 60, 512])

通过这个程序,我们可以看到,自注意力机制是不改变输入和输出的形状的,输入的Q,K,V格式是[1,60,512],输出的结果的仍然是[1,60,512]。

下面是几点说明:

1. 这里的embed_size代表的是输入到自注意力层中的每个元素的向量维度。在Transformer模型中,输入数据首先会被转换成一个固定长度的向量,这个向量的长度就称为embed_size。

2. mask表示的是模型在处理序列数据时,应该忽略掉哪部分,我这里设置为None,也就是全部参与计算。

3. einsum,称为爱因斯坦求和,起源是爱因斯坦在研究广义相对论时,需要处理大量求和运算,为了简化这种繁复的运算,提出了求和约定,推动了张量分析的发展。einsum 可以计算向量、矩阵、张量运算,如果利用得当,sinsum可完全代替其他的矩阵计算方法。

例如,C = einsum('ij,jk->ik', A, B),就相当于两个矩阵求内积:cik = Σj AijBjk。

通过输出可以看到,在计算前queries的形状是[1,60,8,64],keys的形状是[1,60,8,64],在表达式"nqhd,nkhd->nhqk"中,n=1,q=60,h=8,d=64,k=60,两个矩阵进行内积,因此得到的结果是nhqk,也就是[1,8,60,60]。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44895.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新兴商业模式如何破局?市场策划专家的实战指南

在这个融合了传统市场、互联网和信息技术的大潮中,想要在市场营销策划上玩得转,咱们得有超凡的全局思维和跨界的协作精神。 下面,我就来和大家聊聊如何在这样一个复杂环境下搞定市场营销策划,让你在竞争激烈的市场中脱颖而出。 …

MySQL慢查询日志详解与性能优化指南

1. 什么是慢查询日志 慢查询日志是MySQL提供的一种日志记录功能,它能够记录执行时间超过预设阈值的SQL查询语句,并将这些信息写入到日志文件中。 2. 查看慢查询日志的设置和状态 2.1 慢查询日志的开启状态和日志文件位置 通过以下命令可以查看慢查询…

Nodejs 第八十二章(Nacos动态配置)

什么是动态配置? 在Nacos中,动态配置是指可以在运行时动态修改应用程序的配置信息,而不需要重新启动或重新部署应用程序。通过Nacos的动态配置功能,开发人员可以将应用程序的配置信息存储在Nacos服务器中,并在需要时进…

PTK是如何加密WLAN单播数据帧的?

1. References WLAN 4-Way Handshake如何生成PTK?-CSDN博客 2. 概述 在Wi-Fi网络中,单播、组播和广播帧的加密算法是由AP决定的。其中单播帧的加密使用PTK密钥,其PTK的密钥结构如下图所示: PTK的组成如上图所示,由K…

做一个专业的声音分析系统,需要对声音那些评判标准进行计算

为了构建一个专业的声音分析系统,需要对以下评判标准进行计算。每个标准需要相应的算法和技术指标来实现。下面是一些关键的评判标准和如何计算这些标准的具体方法: 1. 音质 清晰度 信噪比(SNR):计算音频信号中的信…

win11下部署Jenkins,build c#项目

一个c#的项目,由于项目经理总要新版本测试,以前每次都是手动出包,现在改成jenkins自动生成,节省时间。 一、下载Jenkins, 可以通过清华镜像下载Index of /jenkins/windows-stable/ | 清华大学开源软件镜像站 | Tsingh…

模切厂如何选择合适的ERP系统?听说模切行业都是选择点晴模切ERP

选择适合模切行业的ERP系统时,应考虑系统的功能需求、供应商的选择、实施案例、用户评价和技术支持。点晴模切ERP系统因其全面的功能、成熟的架构、可扩展性、业财一体化管理、简便的开发平台和精确的刀模管理,被广泛认为是模切行业的优选。 一、功能需…

前端使用pinia中存入的值

导入pinia,创建pinia实例 使用pinia中的值

数字身份管理发展趋势:访问控制智能化

人工智能和机器学习技术正在大量应用于安全访问控制领域。这些技术可以分析用户行为,并实时监测访问中出现的异常情况,有助于主动识别潜在的安全风险。人工智能和机器学习可以显著提高业务数据和系统被访问过程中的安全性,它们还可以为用户提…

mysql8多值索引

MySQL8新出了一个多值索引,我还没体验过呢,今天试一试。 建表 我先建个表试一试多值索引的效果。我粗略地看了下多值索引的介绍,发现是只适用于数组类型的。所以我建一个含有数组字段的表试一试。语法还是挺麻烦的: create tabl…

2.电容(常见元器件及电路基础知识)

一.电容种类 1.固态电容 这种一般价格贵一些,ESR,ESL比较低,之前项目400W电源用的就是这个,温升能够很好的控制 2.铝电解电容 这种一般很便宜,ESR,ESL相对大一些,一般发热量比较大,烫手。 这种一般比上一个贵一点&am…

PS设计新手如何学习?沈阳PS设计线下培训

对于PS设计新手来说,学习之路可能既充满期待又伴有挑战。为了帮助你高效、系统地掌握Photoshop技能,以下是一些建议: 一、了解基础知识 界面熟悉:打开Photoshop,花时间熟悉工作区域,包括菜单栏、工具箱、面…

[AI Fabric] 解锁AI的未来:深入探索Fabric开源框架

今天看到一个项目,Fabric,我们一起来看下 介绍 fabric 是一个使用人工智能增强人类能力的开源框架。 为什么需要Fabric 因为作者认为,人工智能很强大,不存在能力问题,存在的是集成问题。 Fabric 的创建就是为了解…

原来没分库分表,后期如何分库分表?

MySQL 后期进行分库分表是一项复杂的任务,需要仔细规划和逐步实施。以下是一个详细的步骤指南,帮助你在现有系统上实施分库分表: 1. 分析现有系统 评估当前数据库的表和数据量:确定哪些表的数据量和访问量最大,哪些表…

开源公司网站源码系统,降低成本,提升效率 附带完整的安装代码包以及搭建教程

系统概述 开源公司网站源码系统是一个基于开源技术的网站建设解决方案。它提供了完整的网站框架和功能模块,允许企业快速搭建起一个功能齐全、设计美观的企业网站。该系统不仅降低了网站开发的成本,还大大提高了建设效率,使企业能够更快地将…

深入Scikit-learn:掌握Python最强大的机器学习库

Scikit-learn是一个基于Python的开源机器学习库,广泛用于数据挖掘和数据分析。以下是一些Scikit-learn中常用知识点的代码演示: 1. 导入库和准备数据 # 导入所需的库 from sklearn import datasets from sklearn.model_selection import train_test_sp…

ActiViz中的点放置器vtkPointPlacer

文章目录 1. vtkPointPlacer2. vtkFocalPlanePointPlacer3. vtkPolygonalSurfacePointPlacer4. vtkImageActorPointPlacer5. vtkBoundedPlanePointPlacer6. vtkTerrainDataPointPlacer1. vtkPointPlacer 概述: vtkPointPlacer是一个基类,用于确定在三维空间中放置点的最佳位…

泛微开发修炼之旅--37通过js实现监听下拉框,并触发后端接口,改变其他控件内容的实现方法与源码(含pc端和移动端实现)

文章链接:37通过js实现监听下拉框,并触发后端接口,改变其他控件内容的实现方法与源码(含pc端和移动端实现)

Java Spring 事物处理

一、定义 事务(Transaction)是指作为单个逻辑工作单元执行的一系列操作。操作要么全部成功执行,要么全部失败回滚,以确保数据的一致性和完整性。 二、特性 原子性(Atomicity):事务被视为不可分…

flutter Navigator跳转报错

Navigator operation requested with a context that does not include a Navigator. The context used to push or pop routes from the Navigator must be that of a widget that is a descendant of a Navigator widget. 这个报错是:因为你尝试使用 Navigator 操…