Attention机制解析

Attention

Attention机制解析

1. 引言

Attention机制在自然语言处理(NLP)和计算机视觉(CV)等领域取得了广泛的应用。其核心思想是通过对输入数据的不同部分赋予不同的权重,使模型能够更加关注重要的信息。本文将详细介绍Attention的原理,包括Self-Attention和Cross-Attention的机制、公式解析以及代码实现,并探讨其在实际中的应用。

2. Attention机制原理

2.1 基本概念

Attention机制的基本思想是通过计算输入序列中每个元素的重要性(即注意力权重),然后对这些权重进行加权求和,从而得到输出。其公式表示如下:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V

其中:

  • Q Q Q(Query):查询矩阵
  • K K K(Key):键矩阵
  • V V V(Value):值矩阵
  • d k d_k dk:键矩阵的维度

2.2 Self-Attention

Self-Attention是Attention机制的一种特殊形式,其查询、键和值都来自同一个输入序列。Self-Attention的计算步骤如下:

  1. 将输入序列映射为查询、键和值矩阵。
  2. 计算查询和键的点积,并进行缩放。
  3. 对结果应用softmax函数,得到注意力权重。
  4. 使用这些权重对值进行加权求和,得到输出。
公式

设输入序列为 X X X,则:

Q = X W Q , K = X W K , V = X W V Q = XW_Q, \quad K = XW_K, \quad V = XW_V Q=XWQ,K=XWK,V=XWV

Self-Attention ( X ) = softmax ( Q K T d k ) V \text{Self-Attention}(X) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V Self-Attention(X)=softmax(dk QKT)V

其中 W Q W_Q WQ W K W_K WK W V W_V WV 是可学习的权重矩阵。

2.3 Cross-Attention

Cross-Attention与Self-Attention类似,但其查询、键和值来自不同的输入序列。通常用于结合来自不同来源的信息。

公式

设查询序列为 X X X,键和值序列为 Y Y Y,则:

Q = X W Q , K = Y W K , V = Y W V Q = XW_Q, \quad K = YW_K, \quad V = YW_V Q=XWQ,K=YWK,V=YWV

Cross-Attention ( X , Y ) = softmax ( Q K T d k ) V \text{Cross-Attention}(X, Y) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V Cross-Attention(X,Y)=softmax(dk QKT)V

3. 代码实现

3.1 Self-Attention的实现

import torch
import torch.nn.functional as Fclass SelfAttention(torch.nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.keys = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.queries = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.fc_out = torch.nn.Linear(heads * self.head_dim, self.embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)# Scaled dot-product attentionenergy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.nn.functional.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return out

3.2 Cross-Attention的实现

import torch
import torch.nn.functional as Fclass CrossAttention(torch.nn.Module):def __init__(self, embed_size, heads):super(CrossAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.keys = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.queries = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.fc_out = torch.nn.Linear(heads * self.head_dim, self.embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)# Scaled dot-product attentionenergy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.nn.functional.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return out

3.3 使用示例

没错,你看的没错,这两段代码一模一样,代码完全可以复用,只是 Q Q Q K K K V V V来源的序列不同而已。

# 假设我们有以下输入张量
embed_size = 256
heads = 8
seq_length = 10
N = 32  # batch sizevalues = torch.randn(N, seq_length, embed_size)
keys = torch.randn(N, seq_length, embed_size)
queries = torch.randn(N, seq_length, embed_size)
mask = None  # 这里我们没有使用mask# SelfAttention 示例
self_attention = SelfAttention(embed_size, heads)
self_attention_output = self_attention(values, keys, queries, mask)
print("SelfAttention Output Shape:", self_attention_output.shape)# CrossAttention 示例
cross_attention = CrossAttention(embed_size, heads)
cross_attention_output = cross_attention(values, keys, queries, mask)
print("CrossAttention Output Shape:", cross_attention_output.shape)

4. 应用

4.1 自然语言处理

在NLP中,Attention机制被广泛应用于各种任务,如机器翻译、文本生成和问答系统等。例如,Transformer模型通过使用多头自注意力机制实现了高效的序列到序列转换,极大地提高了翻译质量。

4.2 计算机视觉

在CV中,Attention机制用于图像识别、目标检测和图像生成等任务。自注意力机制可以帮助模型关注图像中的重要区域,从而提高识别精度。

4.3 多模态任务

在多模态任务中,Cross-Attention用于结合不同模态的数据,例如图像和文本的匹配、视频字幕生成等。

5. 结论

Attention机制通过动态地调整输入数据的权重,使得模型能够更有效地关注重要信息。Self-Attention和Cross-Attention分别在单一序列和多模态任务中发挥重要作用。随着研究的不断深入,Attention机制在各种领域的应用前景广阔。

希望本文能帮助读者理解Attention机制的原理和实现,并能在实际应用中加以利用。如果有任何问题或建议,欢迎在评论区留言交流。

6. 参考文献

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
  2. Lin, Z., Feng, M., Santos, C. N. D., Yu, M., Xiang, B., Zhou, B., & Bengio, Y. (2017). A structured self-attentive sentence embedding. arXiv preprint arXiv:1703.03130.

7. 附录

7.1 多头自注意力机制

多头自注意力机制(Multi-Head Self-Attention)通过在不同的子空间中并行执行多个自注意力操作,使得模型能够捕获不同方面的信息。其公式为:

MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,head2,,headh)WO

其中:

head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)

7.2 代码实现示例

class MultiHeadSelfAttention(torch.nn.Module):def __init__(self, embed_size, heads):super(MultiHeadSelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.keys = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.queries = torch.nn.Linear(self.head_dim, self.embed_size, bias=False)self.fc_out = torch.nn.Linear(heads * self.head_dim, self.embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)# Scaled dot-product attentionenergy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.nn.functional.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return out

以上,Enjoy your learning!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46904.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

最优控制公式推导(代数里卡提方程,李雅普诺夫方程,HJB方程)

本文探讨了线性时不变系统(LTI系统)的最优控制问题,特别是线性二次调节器(LQR)问题。通过Hamilton-Jacobi-Bellman (HJB) 方程的推导,求得了系统的最优控制律,并进一步推导了代数里卡提方程&…

Python新手必学:如何解决Python安装包下载缓慢/无法下载的问题

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 文章内容 📒📝 临时使用镜像源📝 永久修改镜像源Windows系统macOS/Linux系统📝 推荐镜像源⚓️ 相关链接 ⚓️📖 介绍 📖 你是否曾在使用Python进行项目开发时,遇到过安装包下载速度如蜗牛爬行般的窘境?尤其是在急…

焊死,这38条命令还不会?难怪你的Windows那么费劲

号主:老杨丨11年资深网络工程师,更多网工提升干货,请关注公众号:网络工程师俱乐部 下午好,我的网工朋友。 我们每天都在和各种设备打交道,而命令提示符(CMD)无疑是我们这些技术宅的得…

玩转HarmonyOS NEXT之IM应用首页布局

本文从目前流行的垂类市场中,选择即时通讯应用作为典型案例详细介绍HarmonyOS NEXT的各类布局在实际开发中的综合应用。即时通讯应用的核心功能为用户交互,主要包含对话聊天、通讯录,社交圈等交互功能。 应用首页 创建一个包含一列的栅格布…

Java进阶之路66问 | 谈谈对熔断,限流,降级的理解

熔断(Circuit Breaker) 熔断机制类似于电路中的保险丝,用于在服务或系统出现异常或超负荷时暂时关闭,防止问题进一步扩大,待问题解决后再逐步恢复。这可以有效保护系统免受过载的影响。 想象你在使用电器时&#xff0…

JAVA 异步编程(异步,线程,线程池)一

目录 1.概念 1.1 线程和进程的区别 1.2 线程的五种状态 1.3 单线程,多线程,线程池 1.4 异步与多线程的概念 2. 实现异步的方式 2.1 方式1 裸线程(Thread) 2.1 方式2 线程池(Executor) 2.1.1 源码分析 2.1.2 线程池创建…

南京邮电大学计算机考研考情分析!专业课均分127分!复试录取比例偏高近2:1!计算机类共录取543人!

南京邮电大学(Nanjing University of Posts and Telecommunications),位于南京市,简称南邮(NJUPT),是教育部、工业和信息化部、国家邮政局与江苏省共建高校,国家“双一流”建设高校&…

软考中级科目包含哪些?应该考哪个?

软考中级包含5个专业方向,分别是:计算机软件、计算机网络、计算机应用技术、信息系统、信息服务。这5个方向又对应15个软考中级科目。 信息系统包括:系统集成项目管理工程师、信息系统监理师、信息安全工程师、数据库系统工程师、信息系统管…

C# 中IEnumerable与IQuerable的区别

目的 详细理清IEnumerator、IEnumerable、IQuerable三个接口之间的联系与区别 继承关系:IEnumerator->IEnumerable->IQuerable IEnumerator:枚举器 包含了枚举器含有的方法,谁实现了IEnuemerator接口中的方法,就可以自定…

力扣Hot100之两数之和

解法一: 双层循环暴力求解,先在数组的一个位置定住然后在这个位置的后续位置进行判断,如果两个数加起来等于目标和那么就返回 class Solution:def twoSum(self, nums: List[int], target: int) -> List[int]:for i,num in enumerate(num…

Windows 系统利用 SSH 和 WSL2 子系统当服务器

由于最近组内需要将一台 Windows 系统的电脑 W A W_A WA​ 转成能通过 SSH 访问,并且能用 Linux 命令当服务器运行。忙活了一天,终于是把全部东西弄通了。 安装 SSH 首先就是 W A W_A WA​ 先要安装 OpenSSH 服务,直接按照下面的教程安装…

HCIE是什么等级的证书?

HCIE(华为认证互联网专家,Huawei Certified Internetwork Expert)是华为认证体系中的最高等级证书。它要求考生具备在复杂网络环境中规划、设计、部署、运维和优化网络的能力。HCIE认证是华为认证体系中最具挑战性和含金量的认证之一&#xf…

RocketMQ实现分布式事务

RocketMQ的分布式事务消息功能,在普通消息基础上,支持二阶段的提交。将二阶段提交和本地事务绑定,实现全局提交结果的一致性。 1、生产者将消息发送至RocketMQ服务端。 2、RocketMQ服务端将消息持久化成功之后,向生产者返回Ack确…

NDK R25b 交叉编译FFMpeg4,项目集成,附库下载地址

1.准备工作 文件下载: NDK R25b下载地址:Android NDK历史版本下载网址 - 君*邪 - 博客园 (cnblogs.com) FFmpeg4.4.4 下载地址:https://ffmpeg.org/releases/ffmpeg-4.4.4.tar.xz 环境配置: 本次编译环境是在PC虚拟机中使用U…

普通人还有必要学习 Python 之类的编程语言吗?

在开始前分享一些编程的资料需要的同学评论888即可拿走 是我根据网友给的问题精心整理的对于编程的重要性,这里就不详谈了。 未来,我们和机器的交流会越来越多,编程可以简单看作是和机器对话并分发给机器任务。机器不仅越来越强大&#xff0…

C# —— CRC16 算法

CRC16:即循环冗余校验码。数据通信当中一种常用的查错校验码 其特征信息字段和校验字段的长度可以是任意选定的,对数据进行指定多项式计算 并且将得到的结果附加在帧的后面,接受的设备也执行类似的算法,以保证数据传输的正确性和完整性 crc…

鸿蒙语言基础类库:【@system.configuration (应用配置)】

应用配置 说明: 从API Version 7 开始,该接口不再维护,推荐使用新接口[ohos.i18n]和[ohos.intl]。本模块首批接口从API version 3开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import configurati…

云服务器实际内存与购买不足量问题

君衍 一、本篇缘由二、问题研究1、dmidecode2、dmesg | grep -i memory 三、kdump四、解决方案1、卸载kdump-tools2、清理依赖包3、修改配置文件4、重新生成配置文件5、重启服务器6、再次查看 一、本篇缘由 本篇由于最近买了云服务器,之前基本在本地使用VMware进行虚…

web自动化测试selenium的基本使用

目录 初始化浏览器并打开网页 定位网页元素 定位的方法 模拟键盘操作 模拟鼠标操作 xpath方法 xpath结点 路径表达式 轴 selenium是一个很流行的自动化测试的库,主要用于模拟浏览器的运行,是web应用测试的工具。 在使用selenium时,…