【深度学习实验】注意力机制(四):点积注意力与缩放点积注意力之比较

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
    • 3. 打分函数——加性注意力模型
    • 3. 打分函数——点积注意力与缩放点积注意力
      • a. 缩放点积注意力模型
      • b. 点积注意力模型
      • c. 模拟实验
      • d. 模型比较与选择
      • e. 代码整合

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍将介绍带有掩码的 softmax 操作

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

在这里插入图片描述

2. 掩码Softmax 操作

【深度学习实验】注意力机制(二):掩码Softmax 操作
在这里插入图片描述

3. 打分函数——加性注意力模型

  • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

【深度学习实验】注意力机制(三):打分函数——加性注意力模型

3. 打分函数——点积注意力与缩放点积注意力

  • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

  • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

a. 缩放点积注意力模型

class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens=None):d = queries.shape[-1]self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)
  1. 初始化方法 (__init__):

    • 参数:
      • dropout: Dropout 正则化的概率。
    • 说明: 初始化方法定义了模型的组件,仅包含了一个 Dropout 正则化层。
  2. 前向传播方法 (forward):

    • 参数:
      • queries: 查询张量,形状为 (batch_size, num_queries, d)
      • keys: 键张量,形状为 (batch_size, num_kv_pairs, d)
      • values: 值张量,形状为 (batch_size, num_kv_pairs, value_size)
      • valid_lens: 有效长度张量,形状为 (batch_size,)(batch_size, num_queries)
    • 返回值: 加权平均后的值张量,形状为 (batch_size, num_queries, value_size)
  3. 实现细节:

    • 计算缩放点积得分:通过张量乘法计算 querieskeys 的点积,然后除以 d \sqrt{d} d 进行缩放,其中 d d d 是查询或键的维度。
    • 使用 masked_softmax 函数计算注意力权重,根据有效长度对注意力进行掩码。
    • 将注意力权重应用到值上,得到最终的加权平均结果。
    • 使用 Dropout 对注意力权重进行正则化。

b. 点积注意力模型

class DotProductAttention2(nn.Module):"""点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention2, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens=None):# P195:(8.3),(8.4)# 在计算得分时不进行缩放操作(即不再除以sqrt(d))self.scores = torch.bmm(queries, keys.transpose(1, 2))self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)
  • 区别:
    • 计算点积得分:通过张量乘法计算 querieskeys 的点积。

c. 模拟实验

  • 模型数据
queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
  • 模型应用
# 创建缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
# 使用模型进行前向传播
attention(queries, keys, values, valid_lens)# 创建点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
# 使用模型进行前向传播
attention2(queries, keys, values, valid_lens)

在这里插入图片描述

  • 权重可视化
      为了直观地展示模型在不同输入下的注意力分布,使用前文 show_heatmaps 函数,通过热图的形式展示注意力权重。
# 可视化缩放点积注意力权重
show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')# 可视化点积注意力权重
show_heatmaps(attention2.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')

在这里插入图片描述

d. 模型比较与选择

  • 缩放点积注意力模型

    • 适用于处理高维度的查询和键。
    • 通过缩放操作有助于防止点积得分的方差过大。
  • 点积注意力模型

    • 适用于处理相对较低维度的查询和键。
    • 更方便地利用矩阵乘积提高计算效率。
      在这里插入图片描述

e. 代码整合

# 导入必要的库
import math
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
from torch.utils import datadef masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作"""# X:3D张量,valid_lens:1D或2D张量if valid_lens is None:return nn.functional.softmax(X, dim=-1)else:shape = X.shapeif valid_lens.dim() == 1:valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:valid_lens = valid_lens.reshape(-1)# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)return nn.functional.softmax(X.reshape(shape), dim=-1)def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):"""显示矩阵热图"""d2l.use_svg_display()num_rows, num_cols = matrices.shape[0], matrices.shape[1]fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,sharex=True, sharey=True, squeeze=False)for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)if i == num_rows - 1:ax.set_xlabel(xlabel)if j == 0:ax.set_ylabel(ylabel)if titles:ax.set_title(titles[j])fig.colorbar(pcm, ax=axes, shrink=0.6)class DotProductAttention(nn.Module):"""缩放点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def forward(self, queries, keys, values, valid_lens=None):print(queries)d = queries.shape[-1]print(d)self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)print(self.scores)self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)class DotProductAttention2(nn.Module):"""点积注意力"""def __init__(self, dropout, **kwargs):super(DotProductAttention2, self).__init__(**kwargs)# 使用暂退法进行模型正则化self.dropout = nn.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def forward(self, queries, keys, values, valid_lens=None):# P195:(8.3),(8.4)# 在计算得分时不进行缩放操作(即不再除以sqrt(d))self.scores = torch.bmm(queries, keys.transpose(1, 2))print(self.scores)self.attention_weights = masked_softmax(self.scores, valid_lens)return torch.bmm(self.dropout(self.attention_weights), values)queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])# 缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
attention(queries, keys, values, valid_lens)# 点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
attention2(queries, keys, values, valid_lens)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155498.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用户增长常见分析模型

一、用户增长是什么 用户增长基本上会涉及生意场上的各行各业,你开个店面希望有更多的客户光顾,你做了个APP希望有更多的用户经常使用,你搭建了个电商平台希望有更多的人下单买东西。 用户增长,即以提升用户LTV为目的&#xff08…

Self-Supervised Exploration via Disagreement论文笔记

通过分歧进行自我监督探索 0、问题 使用可微的ri直接去更新动作策略的参数的,那是不是就不需要去计算价值函数或者critic网络了? 1、Motivation 高效的探索是RL中长期存在的问题。以前的大多数方式要么陷入具有随机动力学的环境,要么效率…

应用在金银精炼控制系统中的Modbus转Profinet网关案例

应用在金银精炼控制系统中的Modbus转Profinet网关案例 Modbus转Profinet网关(XD-MDPN100)能够支持多种通信协议和接口,满足不同设备和系统的需求。在金银精炼控制系统中使用,通过控制PID阀门的大小,将1200plc与PID控制…

Git 远程仓库(Github)

目录 添加远程库 查看当前的远程库 提取远程仓库 推送到远程仓库 删除远程仓库 Git 并不像 SVN 那样有个中心服务器。 目前我们使用到的 Git 命令都是在本地执行,如果你想通过 Git 分享你的代码或者与其他开发人员合作。 你就需要将数据放到一台其他开发人员…

【paddlepaddle】

安装paddlepaddle 报错 ImportError: /home/ubuntu/miniconda3/envs/paddle_gan/bin/../lib/libstdc.so.6: version GLIBCXX_3.4.30 not found (required by /home/ubuntu/miniconda3/envs/paddle_gan/lib/python3.8/site-packages/paddle/fluid/libpaddle.so) 替换 /home/ubu…

【python】Python生成GIF动图,多张图片转动态图,pillow

pip install pillow 示例代码: from PIL import Image, ImageSequence# 图片文件名列表 image_files [car.png, detected_map.png, base64_image_out.png]# 打开图片 images [Image.open(filename) for filename in image_files]# 设置输出 GIF 文件名 output_g…

GAMES101—Lec 05~06:光栅化

目录 概念回顾(个人理解)光栅化1.采样2.采样出现的问题:走样 反走样 概念回顾(个人理解) 屏幕:在图形学中,我们认为屏幕是一个二维数组,数组里的每一个元素为一个二维像素。 光栅化…

【Operating Systems:Three Easy Pieces 操作系统导论 】第28章 插叙:线程 API

【Operating Systems:Three Easy Pieces 操作系统导论 】 第28章 插叙&#xff1a;线程 API pthread 库介绍 线程创建 #include <pthread.h> // 头文件 int pthread_create(pthread_t * thread,const pthread_attr_t * attr,void * (*start_routine)(void*),void *…

【数据结构(四)】栈(1)

文章目录 1. 关于栈的一个实际应用2. 栈的介绍3. 栈的应用场景4. 栈的简单应用4.1. 思路分析4.2. 代码实现 5. 栈的进阶应用(实现综合计算器)5.1. 栈实现一位数计算(中缀表达式)5.1.1. 思路分析5.1.2. 代码实现 5.2. 栈实现多位数计算(中缀表达式)5.2.1. 解决思路5.2.2. 代码实…

强化学习笔记

这里写自定义目录标题 参考资料基础知识16.3 有模型学习16.3.1 策略评估16.3.2 策略改进16.3.3 策略迭代16.3.3 值迭代 16.4 免模型学习16.4.1 蒙特卡罗强化学习16.4.2 时序差分学习Sarsa算法&#xff1a;同策略算法&#xff08;on-policy&#xff09;&#xff1a;行为策略是目…

Android Studio 安装及使用

&#x1f353; 简介&#xff1a;java系列技术分享(&#x1f449;持续更新中…&#x1f525;) &#x1f353; 初衷:一起学习、一起进步、坚持不懈 &#x1f353; 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正&#x1f64f; &#x1f353; 希望这篇文章对你有所帮助,欢…

电脑游戏录屏软件,记录游戏高光时刻

电脑游戏录制是游戏爱好者分享游戏乐趣、技巧和成就的绝佳方式&#xff0c;此时&#xff0c;一款好用的录屏软件就显得尤为重要。本文将为大家介绍三款电脑游戏录屏软件&#xff0c;通过对这三款软件的分步骤详细介绍&#xff0c;让大家更加了解它们的特点及使用方法。 电脑游戏…

视频剪辑技巧:如何高效地将多个视频合并成一个新视频

在视频制作过程中&#xff0c;将多个视频合并成一个新视频是一个常见的操作。这涉及到将多个片段组合在一起&#xff0c;或者将不同的视频素材进行混剪。无论是制作一部完整的影片&#xff0c;还是为社交媒体提供短视频&#xff0c;都要掌握如何高效地将多个视频合并。现在一起…

ky10 server arm 在线编译安装openssl3.1.4

在线编译脚本 #!/bin/shOPENSSLVER3.1.4OPENSSL_Vopenssl versionecho "当前OpenSSL 版本 ${OPENSSL_V}" #------------------------------------------------ #wget https://www.openssl.org/source/openssl-3.1.4.tar.gzecho "安装OpenSSL${OPENSSLVER}...&q…

Joern安装与使用

环境准备 Joern需要在Linux环境中运行&#xff0c;所以在Windows系统中需要借助WSL或虚拟机安装。 JDK安装 Joern的运行需要JAVA环境的支持&#xff0c;本次采用的是JDK17&#xff0c;其他版本建议看一下Joern官方文档。 apt install openjdk-17-jre-headless 配置JAVA环境变…

Win11+Modelsim SE-64 10.6d搭建UVM环境

1、添加源文件及tb文件 在目录下建立文件夹&#xff0c;将DUT和Testbench添加进去&#xff0c;文件夹内容如下所示&#xff1a; 2、以《UVM实战》中的例子做简单的示例&#xff1a; 2.1 设计文件 &#xff1a;dut.sv 功能很简单&#xff0c;即将接受到的数据原封不动发送出去…

UE4基础篇十六:自定义 EQS 生成器

UE4 中的 EQS 带有一组很好的查询项生成器,但在某些情况下,您可能更喜欢根据需要创建生成器。我决定编写自己的生成器,因为我必须编写一个查询来找到查询器周围的最佳位置,但又不能太靠近它。我知道我可以添加一个距离测试来随着距离增加分数,但我什至不想考虑距查询器一定…

vulnhub5

靶机下载地址&#xff1a; https://download.vulnhub.com/boredhackerblog/hard_socnet2.ova 信息收集 第一步信息收集&#xff0c;还是老方法我习惯 fscan 和 nmap 一起用 Fscan 简单探测全局信息 ┌──(kali㉿kali)-[~/Desktop/Tools/fscan] └─$ ./fscan_amd64 -h 19…

5.4 Windows驱动开发:内核通过PEB取进程参数

PEB结构(Process Envirorment Block Structure)其中文名是进程环境块信息&#xff0c;进程环境块内部包含了进程运行的详细参数信息&#xff0c;每一个进程在运行后都会存在一个特有的PEB结构&#xff0c;通过附加进程并遍历这段结构即可得到非常多的有用信息。 在应用层下&am…

轿车5+1汽车变速器变速箱同步器操纵机构机械结构设计CAD汽车工程

wx供重浩&#xff1a;创享日记 对话框发送&#xff1a;汽车变速器 获取完整论文报告说明书工程源文件 变速器工程图 操纵机构3D图 一、机械式变速器的概述及其方案的确定 1.1 变速器的功用和要求 变速器的功用是根据汽车在不同的行驶条件下提出的要求&#xff0c;改变发动机…