【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)


文章目录

  • 【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)
  • 1. 加性注意力的起源与提出
  • 2. 加性注意力的原理
  • 3. 发展
  • 4. 代码实现
  • 5. 代码逐句解释


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 加性注意力的起源与提出

加性注意力(Additive Attention)是由Bahdanau et al. 在其2015年关于机器翻译的论文中提出的。这一注意力机制被应用于神经机器翻译(NMT)模型中,旨在提高翻译任务中序列对序列(Seq2Seq)模型的性能,尤其是解决长距离依赖问题。传统的Seq2Seq模型仅依赖于编码器的最终隐藏状态来生成翻译,这在处理长文本时容易丢失输入的细节信息。加性注意力通过在解码过程中对编码器隐藏状态进行加权求和,显著提升了模型性能

加性注意力是一种较早提出的注意力机制,与随后流行的点积注意力不同,加性注意力通过一个可学习的网络计算注意力分数,而不是直接计算向量之间的点积。加性注意力的提出标志着注意力机制在深度学习领域中的广泛应用,尤其是在处理长序列数据时的应用。

2. 加性注意力的原理

加性注意力的核心思想是通过学习一个函数来计算查询(Query)和键(Key)之间的相似性,然后根据相似性对值(Value)进行加权。

具体步骤如下:

1) 输入:

  • Query:解码器中的当前隐藏状态。
  • Key 和 Value:编码器中的隐藏状态(通常是一系列时间步的隐藏状态序列)。

2) 计算注意力分数: 通过将Query和Key进行非线性变换,再经过加性函数求得注意力分数。这个过程使用了一个可学习的权重矩阵,将查询和键分别映射到一个共同的表示空间,计算它们的相似性。

3) softmax归一化: 将上述得到的注意力分数通过softmax函数进行归一化,得到注意力权重。

4) 加权求和: 使用得到的注意力权重对值(Value)进行加权求和,生成最终的加权上下文向量。

公式如下:
在这里插入图片描述
这里, W q W_q Wq W k W_k Wk是可学习的权重矩阵, e i j e_{ij} eij 是注意力分数, v j v_j vj是Value。

3. 发展

加性注意力是最早被提出的注意力机制之一,并在神经机器翻译中取得了显著的成果。后来,随着注意力机制的发展,点积注意力(如Transformer中的缩放点积注意力)因其更高效的计算方式而逐渐取代了加性注意力。然而,加性注意力仍然在某些场景中被使用,尤其是在需要更细致的相似性计算的任务中。

在性能方面,加性注意力与点积注意力的主要区别在于计算复杂度。加性注意力通过一个可学习的神经网络计算注意力分数,计算复杂度为 O ( d ) O(d) O(d),而点积注意力直接计算点积,复杂度为 O ( d 2 ) O(d^2) O(d2),这使得加性注意力在某些场景下具有优势。

4. 代码实现

下面是一个使用加性注意力机制的简化实现,基于PyTorch框架。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass AdditiveAttention(nn.Module):def __init__(self, query_dim, key_dim, hidden_dim):super(AdditiveAttention, self).__init__()# 定义线性层,用于将查询和键映射到同一空间self.query_layer = nn.Linear(query_dim, hidden_dim)self.key_layer = nn.Linear(key_dim, hidden_dim)# 定义一个线性层,用于计算注意力分数self.energy_layer = nn.Linear(hidden_dim, 1)def forward(self, query, keys, values):# query: [batch_size, query_dim]# keys: [batch_size, seq_len, key_dim]# values: [batch_size, seq_len, value_dim]# 计算查询和键的投影query_proj = self.query_layer(query)  # [batch_size, hidden_dim]keys_proj = self.key_layer(keys)  # [batch_size, seq_len, hidden_dim]# 将查询扩展到和键的时间步相同的维度query_proj = query_proj.unsqueeze(1).expand_as(keys_proj)  # [batch_size, seq_len, hidden_dim]# 计算 e_ij = tanh(W_q q + W_k k)energy = torch.tanh(query_proj + keys_proj)  # [batch_size, seq_len, hidden_dim]# 计算注意力分数,并去掉最后一维attention_scores = self.energy_layer(energy).squeeze(-1)  # [batch_size, seq_len]# 通过softmax得到注意力权重attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]# 加权求和值context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)  # [batch_size, value_dim]return context, attention_weights# 测试加性注意力
batch_size = 2
query_dim = 5
key_dim = 5
value_dim = 6
seq_len = 10
hidden_dim = 20# 随机生成查询、键和值
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_len, key_dim)
values = torch.randn(batch_size, seq_len, value_dim)# 实例化加性注意力
additive_attention = AdditiveAttention(query_dim, key_dim, hidden_dim)# 前向传播
context, attention_weights = additive_attention(query, keys, values)print("上下文向量:", context)
print("注意力权重:", attention_weights)

5. 代码逐句解释

1. 导入库:

import torch
import torch.nn as nn
import torch.nn.functional as F

导入PyTorch库,其中torch用于张量操作,nn包含神经网络模块,F提供常用函数如softmax。

2. 定义加性注意力类:

class AdditiveAttention(nn.Module):def __init__(self, query_dim, key_dim, hidden_dim):super(AdditiveAttention, self).__init__()# 定义线性层,用于将查询和键投影到同一维度self.query_layer = nn.Linear(query_dim, hidden_dim)self.key_layer = nn.Linear(key_dim, hidden_dim)# 定义计算注意力能量的线性层self.energy_layer = nn.Linear(hidden_dim, 1)

这里定义了AdditiveAttention类,继承自nn.Modulequery_layerkey_layer分别是将查询和键投影到同一维度的线性层,energy_layer用于计算注意力能量分数。

3. 前向传播函数:

def forward(self, query, keys, values):query_proj = self.query_layer(query)  # [batch_size, hidden_dim]keys_proj = self.key_layer(keys)  # [batch_size, seq_len, hidden_dim]# 扩展查询的维度,使其与键对齐query_proj = query_proj.unsqueeze(1).expand_as(keys_proj)# 计算注意力能量:e_ij = tanh(W_q q + W_k k)energy = torch.tanh(query_proj + keys_proj)  # [batch_size, seq_len, hidden_dim]# 通过线性层计算注意力分数,并去掉最后一维attention_scores = self.energy_layer(energy).squeeze(-1)  # [batch_size, seq_len]# 使用softmax归一化得到注意力权重attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]# 计算上下文向量,通过加权求和值context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)  # [batch_size, value_dim]return context, attention_weights
  • forward函数负责计算加性注意力的前向传播过程。首先,将查询和键分别通过线性层映射到相同的维度。
  • 然后,计算注意力能量,并使用softmax进行归一化,得到注意力权重。
  • 最后,使用这些注意力权重对值进行加权求和,生成上下文向量。
    4. 测试模型:
# 测试加性注意力
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_len, key_dim)
values = torch.randn(batch_size, seq_len, value_dim)# 实例化加性注意力
additive_attention = AdditiveAttention(query_dim, key_dim, hidden_dim)# 前向传播
context, attention_weights = additive_attention(query, keys, values)print("上下文向量:", context)
print("注意力权重:", attention_weights)

在这里,使用随机生成的张量querykeysvalues来测试加性注意力的输出。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/57293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C#】调用本机AI大模型流式返回

【python】AI Navigator的使用及搭建本机大模型_anaconda ai navigator-CSDN博客 【Python】AI Navigator对话流式输出_python ai流式返回-CSDN博客 前两章节我们讲解了使用AI Navigator软件搭建本机大模型,并使用python对大模型api进行调用,使其流式返…

“智能科研写作:结合AI与ChatGPT提升SCI论文和基金申请质量“

基于AI辅助下的高效高质量SCI论文撰写及投稿实践 科学研究的核心在于将复杂的思想和实验成果通过严谨的写作有效地传递给学术界和工业界。对于研究生、青年学者及科研人员,如何高效撰写和发表SCI论文,成为提升学术水平和科研成果的重要环节。系统掌握从…

SAP_FICO模块-资产减值功能对折旧和残值的影响

一、业务背景 由于财务同事没注意,用总账给资产多做了一笔凭证,导致该资产金额虚增,每个月的折旧金额也虚增;现在财务的需求是怎么操作可以进行资产减值,并且减少每个月计提的折旧; 二、实现方式 通过事务码…

qt EventFilter用途详解

一、概述 EventFilter是QObject类的一个事件过滤器,当使用installEventFilter方法为某个对象安装事件过滤器时,该对象的eventFilter函数就会被调用。通过重写eventFilter方法,开发者可以在事件处理过程中进行拦截和处理,实现对事…

go 语言 Gin Web 框架的实现原理探究

Gin 是一个用 Go (Golang) 编写的 Web 框架,性能极优,具有快速、支持中间件、crash处理、json验证、路由组、错误管理、内存渲染、可扩展性等特点。 官网地址:https://gin-gonic.com/ 源码地址:https://github.com/gin-gonic/gi…

Shell重定向输入输出

我的后端学习大纲 我的Linux学习大纲 重定向介绍 标准输入介绍 从键盘读取用户输入的数据,然后再把数据拿到Shell程序中使用; 标准输出介绍 Shell程序产生的数据,这些数据一般都是呈现到显示器上供用户浏览查看; 默认输入输出文件 每个…

前OpenAI首席技术官为新AI初创公司筹资;我国发布首个应用临床眼科大模型 “伏羲慧眼”|AI日报

文章推荐 2024人工智能报告.zip |一文迅速了解今年的AI界都发生了什么? 今日热点 据报道,前OpenAI首席技术官Mira Murati正在为一家新的AI初创公司筹集资金 据路透社报道,上个月宣布离职的OpenAI首席技术官Mira Murati正在为一…

栈和队列(一)

栈和队列的定义和特点 栈和队列是一种特殊的线性表,只能在表的端点进行操作 栈的定义和特点 这就是栈的结构,是一个特殊的线性表,只能在栈顶(或者说是表尾)进行操作。其中top为栈顶,base为栈底 栈s的存储…

华为:高级ACL 特定ip访问特定ip命令

网络拓扑图: 网络环境: 全网互通即可 1.创建一个名为test的高级ACL acl name test advance 2.添加规则 ##拒绝所有ip访问 rule 10 deny ip source any destination 192.168.1.10 0.0.0.0 只允许特定ip访问特定ip rule 5 permit ip source 192.168.2.10…

【Vulnhub靶场】Kioptrix Level 5

目标 本地IP:192.168.118.128 目标IP:192.168.118.0/24 信息收集 nmap探测存活主机,扫全端口,扫服务 首先探测到目标ip为:192.168.118.136 nmap -sP 192.168.118.0/24nmap -p- 192.168.118.136nmap -sV -A 192.168.…

BurpSuite渗透工具的简单使用

BurpSuite渗透工具 用Burp Suite修改请求 step1: 安装Burp Suite。官网链接:Burp Suite官网 step2: 设置代理 step3: 如果要拦截https请求,还需要在客户端安装证书 step4: 拦截到请求可以在Proxy ->…

【嵌入式实时操作系统开发】智能家居入门4(FreeRTOS、MQTT服务器、MQTT协议、STM32、微信小程序)

前面已经发了智能家居入门的1、2、3了,在实际开发中一般都会使用到实时操作系统,这里就以FreeRTOS为例子,使用标准库。记录由裸机转到实时操作系统所遇到的问题以及总体流程。相较于裸机,系统实时性强了很多,小程序下发…

opencv环境配置-适配b站阿童木的opencv教程

首先,opencv作为一个库文件,目的是为了让更多人不需要学习底层像素操作就能上手视觉技术,所以他适配很多环境,目前电脑端我知道的就可以适配C语言 C Python MCU端就是openmv跟他最类似,还有个k210 canmv 阿童木教的…

Unity 山水树木

本章节内容 1. Unity对3D游戏物体的简单操作; 2. 构建山水树木的场景 1. Unity 简易操作 1.1 新建3D游戏场景 1. 打开Unity Hub,点击 New Project (新建项目)按键,选择第二项 3D(Built-In Render Pipeline)&#xf…

Linux之实战命令41:lshw应用实例(七十五)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…

虚拟滚动是怎么做性能优化的?

前言 一个简单的情景模拟(千万别被带入): A: 假设现在有 10 万条数据,你作为前端该怎么优化这种大数据的列表? B: 针对大数据列表一般不会依次性加载,会采用上拉加载、分页加载等…

如何用数据字典提升数据质量和决策效率?

在前面的文章中我们谈到过数据字典的概念,本文将继续探讨如何用数据字典提升数据质量和决策效率。 一、数据字典 数据字典:一种对数据的定义和描述的集合,它包含了数据的名称、类型、长度、取值范围、业务含义、数据来源等详细信息。 数据字…

Java中的一些名词概念

**函数式接口:** 概念&#xff1a;一个接口中的抽象方法只有一个&#xff0c;那么这个接口就是一个函数式接口。形参: 形参变量是**功能函数里的变量**&#xff0c;只有<u>在被调用的时候才分配内存单元</u>&#xff0c;<u>调用结束后立即释放</u>。…

AUTOSAR_EXP_ARAComAPI的6章笔记(2)

☞返回总目录 相关总结&#xff1a;AutoSar AP CM实例说明符的使用方法总结 6.2 实例说明符的使用方法 一、InstanceSpecifier 的概念 InstanceSpecifier 是在 [3] 中定义的一个核心概念&#xff0c;它由符合特定模型元素绝对路径的模型元素 shortName 组成&#xff0c;表现…

【10月最新】植物大战僵尸杂交版即将新增【植物】内容介绍预告(附最新版本下载链接)

新增植物 玉米旋转机 玉米旋转机是一种支持性植物&#xff0c;每4秒可散射6颗油炸玉米或黄油&#xff08;概率20%&#xff09;&#xff0c;油炸玉米经过火炬可变为爆米花&#xff0c;造成范围爆炸伤害。其价格为325&#xff0c;并在每种植一颗后&#xff0c;价格增加50。玉米旋…