《动手学深度学习 Pytorch版》 10.6 自注意力和位置编码

在注意力机制中,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention),也被称为内部注意力(intra-attention)。本节将使用自注意力进行序列编码,以及使用序列的顺序作为补充信息。

import math
import torch
from torch import nn
from d2l import torch as d2l

10.6.1 自注意力

在这里插入图片描述

给定一个由词元组成的输入序列 x 1 , … , x n \boldsymbol{x}_1,\dots,\boldsymbol{x}_n x1,,xn,其中任意 x i ∈ R d ( 1 ≤ i ≤ n ) \boldsymbol{x}_i\in\R^d\quad(1\le i\le n) xiRd(1in) 。该序列的自注意力输出为一个长度相同的序列 y 1 , … , y n \boldsymbol{y}_1,\dots,\boldsymbol{y}_n y1,,yn,其中:

y i = f ( x i , ( x 1 , x 1 ) , … , ( x n , x n ) ) ∈ R d \boldsymbol{y}_i=f(\boldsymbol{x}_i,(\boldsymbol{x}_1,\boldsymbol{x}_1),\dots,(\boldsymbol{x}_n,\boldsymbol{x}_n))\in\R^d yi=f(xi,(x1,x1),,(xn,xn))Rd

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,  # 基于多头注意力对一个张量完成自注意力的计算num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.5, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))  # 张量的形状为(批量大小,时间步的数目或词元序列的长度,d)。
attention(X, X, X, valid_lens).shape  # 输出与输入的张量形状相同
torch.Size([2, 4, 100])

10.6.2 比较卷积神经网络、循环神经网络和自注意力

在这里插入图片描述

  • 卷积神经网络

    • 计算复杂度为 O ( k n d 2 ) O(knd^2) O(knd2)

      • k k k 为卷积核大小

      • n n n 为序列长度是

      • d d d 为输入和输出的通道数量

    • 并行度为 O ( n ) O(n) O(n)

    • 最大路径长度为 O ( n / k ) O(n/k) O(n/k)

  • 循环神经网络

    • 计算复杂度为 O ( n d 2 ) O(nd^2) O(nd2)

      d × d d\times d d×d 权重矩阵和 d d d 维隐状态的乘法计算复杂度为 O ( d 2 ) O(d^2) O(d2),由于序列长度为 n n n,因此循环神经网络层的计算复杂度为 O ( n d 2 ) O(nd^2) O(nd2)

    • 并行度为 O ( 1 ) O(1) O(1)

      O ( n ) O(n) O(n) 个顺序操作无法并行化。

    • 最大路径长度也是 O ( n ) O(n) O(n)

  • 自注意力

    • 计算复杂度为 O ( n 2 d ) O(n^2d) O(n2d)

      查询、键和值都是 n × d n\times d n×d 矩阵

    • 并行度为 O ( n ) O(n) O(n)

      每个词元都通过自注意力直接连接到任何其他词元。因此有 O ( 1 ) O(1) O(1) 个顺序操作可以并行计算

    • 最大路径长度也是 O ( 1 ) O(1) O(1)

总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

10.6.3 位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加 位置编码(positional encoding) 来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。

基于正弦函数和余弦函数的固定位置编码的矩阵第 i i i 行、第 2 j 2j 2j 列和 2 j + 1 2j+1 2j+1 列上的元素为:

p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) \begin{align} p_{i,2j}&=\sin{\left(\frac{i}{10000^{2j/d}}\right)}\\ p_{i,2j+1}&=\cos{\left(\frac{i}{10000^{2j/d}}\right)} \end{align} pi,2jpi,2j+1=sin(100002j/di)=cos(100002j/di)

#@save
class PositionalEncoding(nn.Module):"""位置编码"""def __init__(self, num_hiddens, dropout, max_len=1000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(dropout)# 创建一个足够长的Pself.P = torch.zeros((1, max_len, num_hiddens))X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)self.P[:, :, 0::2] = torch.sin(X)self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):X = X + self.P[:, :X.shape[1], :].to(X.device)return self.dropout(X)

在位置嵌入矩阵 P \boldsymbol{P} P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。从下面的例子中可以看到位置嵌入矩阵的第 6 列和第 7 列的频率高于第 8 列和第 9 列。第 6 列和第 7 列之间的偏移量(第 8 列和第 9 列相同)是由于正弦函数和余弦函数的交替。

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])


在这里插入图片描述

10.6.3.1 绝对位置信息

打印出 0 , 1 , … , 7 0,1,\dots,7 0,1,,7 的二进制表示形式即可明白沿着编码维度单调降低的频率与绝对位置信息的关系。

每个数字、每两个数字和每四个数字上的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替。

for i in range(8):print(f'{i}的二进制是:{i:>03b}')
0的二进制是:000
1的二进制是:001
2的二进制是:010
3的二进制是:011
4的二进制是:100
5的二进制是:101
6的二进制是:110
7的二进制是:111

在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。

P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')


在这里插入图片描述

10.6.3.2 相对位置信息

除了捕获绝对位置信息之外,上述的位置编码还允许模型学习得到输入序列中相对位置信息。这是因为对于任何确定的位置偏移 δ \delta δ,位置 i + δ i+\delta i+δ 处的位置编码可以线性投影位置 i i i 处的位置编码来表示。

这种投影的数学解释是,令 ω j = 1 / 1000 0 2 j / d \omega_j=1/10000^{2j/d} ωj=1/100002j/d,对于任何确定的位置偏移 δ \delta δ,上个式子中的任何一对 ( p i , 2 j , p i , 2 j + 1 ) (p_{i,2j},p_{i,2j+1}) (pi,2j,pi,2j+1) 都可以线性投影到 ( p i + δ , 2 j , p i + δ , 2 j + 1 ) (p_{i+\delta,2j},p_{i+\delta,2j+1}) (pi+δ,2j,pi+δ,2j+1)

[ cos ⁡ ( δ ω j ) sin ⁡ ( δ ω j ) − sin ⁡ ( δ ω j ) cos ⁡ ( δ ω j ) ] [ p i , 2 j p i , 2 j + 1 ] = [ cos ⁡ ( δ ω j ) sin ⁡ ( i ω j ) + sin ⁡ ( δ ω j ) cos ⁡ ( i ω j ) − sin ⁡ ( δ ω j ) sin ⁡ ( i ω j ) + cos ⁡ ( δ ω j ) cos ⁡ ( i ω j ) ] = [ sin ⁡ ( ( i + δ ) ω j ) cos ⁡ ( ( i + δ ) ω j ) ] = [ p i , 2 j p i , 2 j + 1 ] \begin{align} &\begin{bmatrix} \cos{(\delta\omega_j)} & \sin{(\delta\omega_j)}\\ -\sin{(\delta\omega_j)} & \cos{(\delta\omega_j)} \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix}\\ =&\begin{bmatrix} \cos{(\delta\omega_j)}\sin{(i\omega_j)}+\sin{(\delta\omega_j)}\cos{(i\omega_j)}\\ -\sin{(\delta\omega_j)}\sin{(i\omega_j)}+\cos{(\delta\omega_j)}\cos{(i\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} \sin{((i+\delta)\omega_j)}\\ \cos{((i+\delta)\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix} \end{align} ===[cos(δωj)sin(δωj)sin(δωj)cos(δωj)][pi,2jpi,2j+1][cos(δωj)sin(iωj)+sin(δωj)cos(iωj)sin(δωj)sin(iωj)+cos(δωj)cos(iωj)][sin((i+δ)ωj)cos((i+δ)ωj)][pi,2jpi,2j+1]

2 × 2 2\times 2 2×2 投影矩阵不依赖于任何位置的索引 i i i

练习

(1)假设设计一个深度架构,通过堆叠基于位置编码的自注意力层来表示序列。可能会存在什么问题?


(2)请设计一种可学习的位置编码方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/120439.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

竞赛 深度学习人体跌倒检测 -yolo 机器视觉 opencv python

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 **基于深度学习的人体跌倒检测算法研究与实现 ** 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🥇学长这里给一个题目综合评分(每项满…

npm改变npm缓存路径和改变环境变量

在安装nodejs时,系统会自动安装在系统盘C, 时间久了经常会遇到C盘爆满,有时候出现红色,此时才发现很多时候是因为npm 缓存保存在C盘导致的,下面就介绍下如何改变npm缓存路径。 1、首先找到安装nodejs的路径&#xff0c…

JVM(Java Virtual Machine)G1收集器篇

前言 本文参考《深入理解Java虚拟机》,本文主要介绍G1收集器的收集思想和具体过程(填上一篇文章留下的坑) 本系列其他文章链接: JVM(Java Virtual Machine)内存模型篇 JVM(Java Virtual Machi…

SQL sever中函数(2)

目录 一、函数分类及应用 1.1标量函数(Scalar Functions): 1.1.1格式 1.1.2示例 1.1.3作用 1.2表值函数(Table-Valued Functions): 1.2.1内联表值函数(Inline Table-Valued Functions&am…

Linux shell编程学习笔记15:定义数组、获取数组元素值和长度

一、 Linux shell 脚本编程中的数组概述 数组是一种常见的数据结构。跟大多数编程语言一样,大多数Linux shell脚本支持数组,但对数组的支持程度各不相同,比如数组的维度,是支持一维数组还是多维数组?再如,…

Redis为什么变慢了

一、Redis为什么变慢了 1.Redis真的变慢了吗? 对 Redis 进行基准性能测试 例如,我的机器配置比较低,当延迟为 2ms 时,我就认为 Redis 变慢了,但是如果你的硬件配置比较高,那么在你的运行环境下,可能延迟是 0.5ms 时就可以认为 Redis 变慢了。 所以,你只有了解了你的…

蓝桥杯每日一题2023.10.27

题目描述 快速排序 - 蓝桥云课 (lanqiao.cn) #include <stdio.h>int quick_select(int a[], int l, int r, int k) {int p rand() % (r - l 1) l;int x a[p];{int t a[p]; a[p] a[r]; a[r] t;}int i l, j r;while(i < j) {while(i < j && a[i] &…

Python轮廓追踪【OpenCV形态学操作】

文章目录 概要代码运行结果 概要 一些理论知识 OpenCV形态学操作理论1 OpenCV形态学操作理论2 OpenCV轮廓操作|轮廓类似详解 代码 代码如下&#xff0c;可以直接运行 import cv2 as cv# 定义结构元素 kernel cv.getStructuringElement(cv.MORPH_RECT, (3, 3)) # print kern…

【Linux】rpm和yum的使用

不知道是不是有和我一样的宝子们&#xff0c;在rpm上卡了老久老久&#xff0c;但其实搞通了&#xff0c;理解了原理之后&#xff0c;不难的&#xff0c;所以不管你现在遇到的困难是什么&#xff0c;都不要放弃&#xff0c;一定要坚持&#xff0c;加油。 一、rpm 1.rpm rpm的…

On Moving Object Segmentation from Monocular Video with Transformers 论文阅读

论文信息 标题&#xff1a;On Moving Object Segmentation from Monocular Video with Transformers 作者&#xff1a; 来源&#xff1a;ICCV 时间&#xff1a;2023 代码地址&#xff1a;暂无 Abstract 通过单个移动摄像机进行移动对象检测和分割是一项具有挑战性的任务&am…

使用vue-cli搭建spa项目,vue项目结构说明,开发示例,如何修改端口号

目录 1. vue-cli安装 1.1 安装前提 1.2 什么是vue-cli 1.3 安装vue-cli 2. 使用vue-cli构建项目 2.1 使用脚手架创建项目骨架 2.2 到新建项目目录&#xff0c;安装需要的模块 2.3 如何修改端口号 2.4 添加element-ui模块 2.5 package.json详解 3. install命令中的-g…

目标检测技术概述

什么是目标检测&#xff1f; 在计算机视觉众多的技术领域中&#xff0c;目标检测&#xff08;Object Detection&#xff09;也是一项非常基础的任务&#xff0c;图像分割、物体追踪、关键点检测等通常都要依赖于目标检测。在目标检测时&#xff0c;由于每张图像中物体的数量、…

云游数智农业世界,体验北斗时空智能

今日&#xff0c;2023年中国国际农业机械展览会在武汉正式拉开帷幕&#xff0c;众多与会者云集&#xff0c;各类农机产品纷呈&#xff0c;盛况空前。 千寻位置作为国家北斗地基增强系统的建设与运营方&#xff0c;在中国国际农业机械展览会上亮相&#xff0c;以「北斗时空智能 …

代码审计及示例

简介&#xff1a; 代码安全测试是从安全的角度对代码进行的安全测试评估。 结合丰富的安全知识、编程经验、测试技术&#xff0c;利用静态分析和人工审核的方法寻找代码在架构和编码上的安全缺陷&#xff0c;在代码形成软件产品前将业务软件的安全风险降到最低。 方法&#x…

合成数据的好处和用途

在不断变化的数据科学和人工智能环境中&#xff0c;合成数据集的概念成为具有多种用途的强大工具。 假设您是一名数据科学家&#xff0c;并分配了为电子商务网站创建尖端推荐系统的任务。为此&#xff0c;您需要大量的用户交互数据。但是&#xff0c;您面临着保护用户隐私和处…

Lua入门使用与基础语法

文章目录 目的基础说明开发环境基础语法注释数据类型变量流程控制函数 总结 目的 Lua是一种非常小巧的脚本语言&#xff0c;基于C构建并且完全开源&#xff0c;可以方便的嵌入到各种项目中&#xff0c;当然也可以单独使用。Lua经常被用在很多非脚本语言的项目中&#xff0c;用…

设计模式—创建型模式之单例模式

设计模式—创建型模式之单例模式 介绍 单例模式说明&#xff1a;一个单一的类&#xff0c;负责创建自己的对象&#xff0c;同时确保系统中只有单个对象被创建。 单例模式特点&#xff1a; 某个类只能有一个实例&#xff1b;&#xff08;构造器私有&#xff09;它必须自行创…

Redis 主从

目录 ​编辑一、构建主从架构 1、集群结构 2、准备实例和配置 &#xff08;1&#xff09;创建目录 &#xff08;2&#xff09;修改原始配置 &#xff08;3&#xff09;拷贝配置文件到每个实例目录 &#xff08;4&#xff09;修改每个实例的端口&#xff0c;工作目录 &a…

虹科分享 | 买车无忧?AR带来全新体验!

文章来源&#xff1a;虹科数字化与AR 阅读原文&#xff1a;https://mp.weixin.qq.com/s/XsUFCTsiI4bkEMBHcGUT7w 新能源汽车的蓬勃发展&#xff0c;推动着汽车行业加速进行数字化变革。据数据显示&#xff0c;全球新能源汽车销售额持续上升&#xff0c;预计到2025年&#xff0…