常见Transformer位置编码

文章目录

    • 概述
    • 绝对位置编码
    • 相对位置编码
      • T5 Bias
      • ALiBi
      • RoPE
    • 参考资料

概述

相对于RNN这样的序列模型来说,Transformer可并行是一个很大的优势,但可并行性带来一个问题,由于不是从前到后,所以模型对于位置信息是不敏感的。于是在Transformer最早提出时就定义了位置编码(Positional Encodings)的概念,本文章旨在介绍常见位置编码方式。

绝对位置编码

绝对位置编码(Absolute Positional Encodings)也在早期的Transformer架构中被广泛使用。
Transformer中的位置编码如下图所示:

最初Transformer中的位置编码采用三角式绝对位置编码,计算方式如下:
P E ( k , 2 i ) = s i n ( k 1000 0 2 i / d ) PE_{(k, 2i)}=sin(\frac{k}{10000^{2i/d}}) PE(k,2i)=sin(100002i/dk)
P E ( k , 2 i + 1 ) = c o s ( k 1000 0 2 i / d ) PE_{(k, 2i+1)}=cos(\frac{k}{10000^{2i/d}}) PE(k,2i+1)=cos(100002i/dk)
其中, 2 i 2i 2i 2 i + 1 2i+1 2i+1代表向量的第几维, k k k代表token在序列中的位置, d d d是向量的维度。
实现代码:

import torch
import math
seq_len = 128 # 序列长度
d_model = 8   # 向量维度
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1) # shape:[128, 1, 8]

相对位置编码

绝对位置编码的缺点很明显:

  1. 可扩展性不足,只能处理训练数据中看过的长度,如基础BERT模型只能够处理512个token,这显然是不合理的。所以越来越多的模型使用相对位置编码的方法来增加模型对于长度的泛化性。
  2. 每一个位置编码都是独立的。换而言之,位置编码和距离没有关系,不管距离多远都是一样的。

在正式介绍相对位置编码之前回顾一下原始论文中 self-attention的计算方式:
q i = ( x i + p i ) W q q_i=(x_i+p_i)W_q qi=(xi+pi)Wq
k j = ( x j + p j ) W k k_j=(x_j+p_j)W_k kj=(xj+pj)Wk
v k = ( x k + p k ) W v v_k=(x_k+p_k)W_v vk=(xk+pk)Wv
α i , j = s o f t m a x ( q i k j T ) \alpha_{i,j}=softmax(q_ik_j^T) αi,j=softmax(qikjT)
o i = ∑ j α i , j v j o_{i}=\sum_j\alpha_{i,j}v_j oi=jαi,jvj

我们把 q i q_i qi k j k_j kj中与位置编码的项去掉,再引入一个相对位置编码项。 对 q i k j T q_ik_j^T qikjT进行展开,得到以下等式:
q i k j T = ( x i W q ) ( x j W k ) T ⊕ R i , j q_ik_j^T=(x_iW_q)(x_jW_k)^T\oplus R_{i,j} qikjT=(xiWq)(xjWk)TRi,j

其中 R i , j R_{i,j} Ri,j代表相对位置的编码方式, ⊕ \oplus 代表融入相对位置编码的方式(如:相乘或相加)。

T5 Bias

T5模型直接使用一个可学习矩阵来进行相对位置编码的学习:
q i k j T = ( x i W q ) ( x j W k ) T + β i , j q_ik_j^T=(x_iW_q)(x_jW_k)^T+ \beta_{i,j} qikjT=(xiWq)(xjWk)T+βi,j
其中 β i , j \beta_{i,j} βi,j为一个可学习的、共享的bias值,这种方法的缺点很明显,需要训练矩阵的参数,训练会变慢。

ALiBi

基于T5 bias的缺点,Press[3]提出了ALiBi算法。该方法通过预定义相对位置编码,通过在 q i k j T q_ik_j^T qikjT后增加一个偏置项来达到相对位置编码的目的:

其中 m m m是和注意力头有关的斜率,对于 n n n个注意力头,其斜率集合从 2 − 8 n 2^{\frac{-8}{n}} 2n8开始,计算示例如下图所示:

从原始论文可以看到,对比T5 Bias,输入长度越长优势明显。

RoPE

RoPE[5]提出的出发点是“通过绝对位置编码的方式实现相对位置编码”。如前所述,self-attention核心是计算是内积,所以问题就变成了内积前带有绝对位置编码信息,内积之后带有相对位置编码信息。假设存在这样的位置函数记为 f ( x , m ) f(\boldsymbol {x},m) f(x,m),则有:
⟨ f ( q , m ) , f ( k , m ) ⟩ = g ( q , k , m − n ) \langle{f(\boldsymbol{q},m), f(\boldsymbol {k},m)}\rangle=g(\boldsymbol {q},\boldsymbol {k}, m-n) f(q,m),f(k,m)=g(q,k,mn)

所以现在的目标变成了找到这样一个函数使得上述等式成立。论文就提出了一个可以满足上述等式的函数。为了简化推理,我们假设向量的维度为2,对于第 m m m个位置的token,对应的旋转矩阵为:
R θ , m = ( cos ⁡ m θ − sin ⁡ m θ sin ⁡ m θ cos ⁡ m θ ) R_{\theta,m}=\begin{pmatrix} \cos{m\theta} & -\sin{m\theta} \\ \sin{m\theta} & \cos{m\theta} \end {pmatrix} Rθ,m=(cosmθsinmθsinmθcosmθ)

其中 θ \theta θ是一个预设值,经过旋转之后的矩阵我们可以写成:
q ′ = R θ , m q m = ( cos ⁡ m θ − sin ⁡ m θ sin ⁡ m θ cos ⁡ m θ ) ( q m ( 1 ) q m ( 2 ) ) q' = R_{\theta,m}q_m=\begin{pmatrix} \cos{m\theta} & -\sin{m\theta} \\ \sin{m\theta} & \cos{m\theta} \end {pmatrix} \begin{pmatrix} q_m^{(1)} \\ q_m^{(2)} \end {pmatrix} q=Rθ,mqm=(cosmθsinmθsinmθcosmθ)(qm(1)qm(2))

对于二维向量,原始论文给出的RoPE实现过程如下图所示:

RoPE一般性实现如下:

q m T k n = ( R θ , m q m ) T ( R θ , n k n ) = q T R θ , m T R θ , n k n ( R θ , m T R θ , n = R θ , n − m ) = q T R ( θ , n − m ) k n \begin{align} q_m^Tk_n &= (R_{\theta,m}q_m)^T(R_{\theta,n}k_n) \\ &=q^TR_{\theta,m}^TR_{\theta,n}k_n \quad\quad\quad(R_{\theta,m}^TR_{\theta,n}=R_{\theta, n-m})\\ &=q^TR_{(\theta,n-m)}k_n \end{align} qmTkn=(Rθ,mqm)T(Rθ,nkn)=qTRθ,mTRθ,nkn(Rθ,mTRθ,n=Rθ,nm)=qTR(θ,nm)kn
可以看出RoPE巧妙地结合了绝对位置编码和相对位置编码:
θ \theta θ只与token当前的位置有关,且对于所有token都是一样的
n − m n-m nm和两个token之间的距离有关,token距离越小,旋转的角度越小,反之,旋转角度越大。
RoPE实现代码[8]如下:

import torchfrom torch import nndef get_rotary_matrix(context_len: int, embedding_dim: int) -> torch.Tensor:"""Generate the Rotary Matrix for ROPEArgs:context_len (int): context lenembedding_dim (int): embedding dimReturns:torch.Tensor: the rotary matrix of dimension context_len x embedding_dim x embedding_dim"""R = torch.zeros((context_len, embedding_dim, embedding_dim), requires_grad=False)positions = torch.arange(1, context_len+1).unsqueeze(1)# Create matrix theta (shape: context_len  x embedding_dim // 2)slice_i = torch.arange(0, embedding_dim // 2)# 原始论文旋转角度计算方式theta = 10000. ** (-2.0 * (slice_i.float()) / embedding_dim) m_theta = positions * theta# Create sin and cos valuescos_values = torch.cos(m_theta)sin_values = torch.sin(m_theta)# 按照公式进行赋值R[:, 2*slice_i, 2*slice_i] = cos_valuesR[:, 2*slice_i, 2*slice_i+1] = -sin_valuesR[:, 2*slice_i+1, 2*slice_i] = sin_valuesR[:, 2*slice_i+1, 2*slice_i+1] = cos_valuesreturn Rbatch_size = 1
context_len = 8
embedding_dim = 2
ff_q = nn.Linear(embedding_dim, embedding_dim, bias=False)
ff_k = nn.Linear(embedding_dim, embedding_dim, bias=False)
x = torch.randn((batch_size, context_len, embedding_dim))# 初始化矩阵:Q、K
queries = ff_q(x)
keys = ff_k(x)# 获取旋转矩阵
R_matrix = get_rotary_matrix(context_len, embedding_dim)queries_rot = (queries.transpose(0,1) @ R_matrix).transpose(0,1)
keys_rot = (keys.transpose(0,1) @ R_matrix).transpose(0,1)# ... Compute the score in the attention mechanism using the rotated queries and keys

目前RoPE技术已经在LLAMA系列、GLM系列、Baichuan系列、Qwen系列等广泛使用。

参考资料

  1. Zhao et al. Length Extrapolation of Transformers: A Survey from the Perspective of Positional Encoding.
  2. Su etal. ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING.
  3. 科学空间《让研究人员绞尽脑汁的Transformer位置编码》
  4. Press et al. Train short, test long: Attention with linear biases enables input length extrapolation.
  5. Transformer升级之路:2、博采众长的旋转式位置编码
  6. Su et al. RoFormer: Enhanced Transformer with Rotary Position Embedding
  7. https://github.com/meta-llama/llama3/blob/main/llama/model.py
  8. https://afterhoursresearch.hashnode.dev/rope-rotary-positional-embedding

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/58269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【IEEE出版 | EI稳定检索】2024智能机器人与自动控制国际学术会议 (IRAC 2024,11月29-12月1日)

2024智能机器人与自动控制国际学术会议 (IRAC 2024) 2024 International Conference on Intelligent Robotics and Automatic Control 官方信息 会议官网:www.icirac.org 2024 International Conference on Intelligent Robotics and Autom…

Golang | Leetcode Golang题解之第535题TinyURL的加密与解密

题目: 题解: import "math/rand"type Codec map[int]stringfunc Constructor() Codec {return Codec{} }func (c Codec) encode(longUrl string) string {for {key : rand.Int()if c[key] "" {c[key] longUrlreturn "http:/…

影响神经网络速度的因素- FLOPs、MAC、并行度以及计算平台

影响神经网络速度的四个主要因素分别是 FLOPs(浮点操作数)、MAC(内存访问成本)、并行度以及计算平台。这些因素共同作用,直接影响到神经网络的计算速度和资源需求。 1. FLOPs(Floating Point Operations&a…

【北京迅为】《STM32MP157开发板嵌入式开发指南》-第七十六章 C++入门

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

小菜家教平台:基于SpringBoot+Vue打造一站式学习管理系统

前言 现在已经学习了很多与Java相关的知识,但是迟迟没有进行一个完整的实践(之前这个项目开发到一半,很多东西没学搁置了,同时原先的项目中也有很多的问题),所以现在准备从零开始做一个基于SpringBootVue的…

基于Matlab的语音识别

一、引言 语音识别技术是让计算机识别一些语音信号,并把语音信号转换成相应的文本或者命令的一种高科技技术。语音识别技术所涉及的领域非常广泛,包括信号处理、模式识别、人工智能等技术。近年来已经从实验室开始走向市场,渗透到家电、通信…

如何在 IntelliJ IDEA 中调整 `Ctrl+/` 快捷键生成注释的位置

前言 在使用 IntelliJ IDEA 编写代码时,注释是代码可读性和维护性的重要组成部分。IDEA 提供了快捷键 Ctrl/ 用于快速生成单行注释。然而,默认情况下,使用此快捷键生成的注释会出现在行首,导致注释与代码之间存在较大的空格&…

源鲁杯 2024 web(部分)

[Round 1] Disal F12查看: f1ag_is_here.php 又F12可以发现图片提到了robots 访问robots.txt 得到flag.php<?php show_source(__FILE__); include("flag_is_so_beautiful.php"); $a$_POST[a]; $keypreg_match(/[a-zA-Z]{6}/,$a); $b$_REQUEST[b];if($a>99999…

使用 ADB 在某个特定时间点点击 Android 设备上的某个按钮

前提条件 安装 ADB&#xff1a;确保你已经在计算机上安装了 Android SDK&#xff08;或单独的 ADB&#xff09;。并将其添加到系统环境变量中&#xff0c;以便你可以在命令行中运行 adb。 USB调试&#xff1a;确保 Android 设备已启用 USB 调试模式。这可以在设备的“设置” -…

数据库的使用02:SQLServer的连接字符串、备份、还原、SQL监视相关设置

目录 一、连接字符串 【本地连接字符串】 【远程连接字符串】 二、备份 三、还原 &#xff08;1&#xff09;还原数据库-bak、btn文件 &#xff08;2&#xff09;附加数据库mdf文件 四、SQL监视器的使用 一、连接字符串 【本地连接字符串】 server DESKTOP-FTH2P3S; Da…

Oracle视频基础1.3.6练习

1.3.6 以下是您的需求清单&#xff08;不含解决方案&#xff09;&#xff1a; 检查数据库启动情况等待会话结束&#xff0c;进行正常关机等待事务全部提交后再关机查看 alert 日志文件查看后台跟踪文件查看用户跟踪文件 检查数据库启动情况 ps -ef | grep oracle ipcs clear…

【大数据学习 | HBASE】hbase的原理与组成结构

1. hbase的简述 hbase作为google的大数据三篇比较重要的论文之一&#xff0c;它的起源叫做bigtable&#xff0c;意思非常简单就是大表的意思&#xff0c;是一个分布式存储很多数据的大型表格系统&#xff0c;它是对于hdfs中的数据不能直观查询和随机读写的病痛的一个补充和完善…

苍穹外卖Bug集合

初始化后端项目运行出现以下问题 以上报错是因为maven和jdk版本不符合&#xff0c;需要将jdk改成17&#xff0c;mavne改成3.9.9

【C++篇】在秩序与混沌的交响乐中: STL之map容器的哲学探寻

文章目录 C map 容器详解&#xff1a;高效存储与快速查找前言第一章&#xff1a;C map 的概念1.1 map 的定义1.2 map 的特点 第二章&#xff1a;map 的构造方法2.1 常见构造函数2.1.1 示例&#xff1a;不同构造方法 2.2 相关文档 第三章&#xff1a;map 的常用操作3.1 插入操作…

太空旅游:科技能否让星辰大海变为现实?

内容概要 在这个快速变化的时代&#xff0c;太空旅游成为了一个让人热血沸腾的话题。想象一下&#xff0c;坐在一颗漂浮的太空舱里&#xff0c;手中端着饮料&#xff0c;眺望着无尽的星辰大海&#xff0c;简直就像科幻电影中的情节一样。不过&#xff0c;这不仅仅是一个空洞的…

程序中怎样用最简单方法实现写excel文档

很多开发语言都能找到excel文档读写的库&#xff0c;但是在资源极其受限的环境下开发&#xff0c;引入这些库会带来兼容性问题。因为一个小功能引入一堆库&#xff0c;我始终觉得划不来。看到有项目引用的jar包有一百多个&#xff0c;看着头麻&#xff0c;根本搞不清谁依赖谁。…

【春秋云镜】CVE-2023-23752

目录 CVE-2023-23752漏洞细节漏洞利用示例修复建议 春秋云镜&#xff1a;解法一&#xff1a;解法二&#xff1a; CVE-2023-23752 是一个影响 Joomla CMS 的未授权路径遍历漏洞。该漏洞出现在 Joomla 4.0.0 至 4.2.7 版本中&#xff0c;允许未经认证的远程攻击者通过特定 API 端…

解决虚拟机启动报:此主机支持AMD-V,但AMD-V处于禁用状态

首先要知道你自己使用的主板型号&#xff0c;如果是京东购买的&#xff0c;可以直接上京东去问客服。如果没有订单号&#xff0c;如果能提供正确的主板型号&#xff0c;他们应该也是会帮忙解答的。 您好&#xff0c;AMD 平台与 Intel 平台以及部分新老主板开启虚拟化的步骤和细…

【EI会议推荐】抢先掌握学术前沿!快来参加EI学术会议投稿,展示你的研究成果,开启科研新高度!

【EI会议推荐】抢先掌握学术前沿&#xff01;快来参加EI学术会议投稿&#xff0c;展示你的研究成果&#xff0c;开启科研新高度&#xff01; 【EI会议推荐】抢先掌握学术前沿&#xff01;快来参加EI学术会议投稿&#xff0c;展示你的研究成果&#xff0c;开启科研新高度&#…

2.若依vue表格数据根据不同状态显示不同颜色style

例如国标显示蓝色&#xff0c;超标是红色 使用是蓝色&#xff0c;未使用是绿色 <el-table-column label"外卖配送是否完成评价" align"center" prop"isOverFlag"> <template slot-scope"scope"> …