文章目录
- 为什么需要位置编码?
- 预备知识
- 三角函数求和公式
- 旋转矩阵
- 逆时针旋转
- 顺时针旋转
- 旋转矩阵的性质
- 原始Transformer中的位置编码
- 论文中的介绍
- 具体计算过程
- 为什么是线性变换?
- 大模型常用的旋转位置编码RoPE
- 基本原理
- 最简实现形式
- Llama3中的代码实现
- 两种位置编码的区别
- 编码方式
- 实现方式
- 参考资料
为什么需要位置编码?
众所周知,老生常谈,Transformer模型的核心是自注意力机制(Self-Attention),这一机制的特点是输入序列中的所有元素都是同时被处理的,而不是像RNN那样按顺序处理。这种并行处理的方式虽然具有很高的效率,但也导致了模型无法自然地获取输入序列中元素的位置信息。
比如,自注意力机制在处理 AI 好 难 学
和 难 学 好 AI
这两个元素相同,但是位置不同的序列时,得到的每个元素对应的attention值是相同的,也没办法区分。
因此,Positional Encoding 的作用,就是在把 Word Embedding 送入 Attention 之前,把位置信息给带上,使得模型能够在进行自注意力计算时感知到输入元素的相对和绝对位置。
网络社区中对 Positional Encoding 分类的方法很多,按照不同的分类方法划分,大致可以分为:
- 绝对位置编码和相对位置编码:
- 绝对位置编码,为输入序列中的每个位置提供一个唯一的表示,通常是通过预定义的方法生成,并直接添加到输入表示中
- 相对位置编码,是对两个单词之间的相对位置进行建模,并且将相对位置信息加入到Self-Attention结构中,形如Transformer-XL,DeBERTa等采用的就是相对位置编码。Self-Attention的本质是两个单词信息的内积操作,相对位置编码的思想是对内积的计算方式进行改进,在内积中注入两个单词的相对位置因素
- 固定式位置编码和可学习式位置编码:
- 这种分类方式,说的是 绝对位置编码 的不同实现方式
- 固定位置编码,主要是 Transformer论文中提出的正弦和余弦位置编码(Sinusoidal Positional Encoding)方法,使用正弦和余弦函数生成不同频率的编码
- 可学习式位置编码,没有固定的位置编码公式,通过初始化位置向量让模型根据上下文数据自适应地学习出来,Bert和GPT采用的就是可学习式
- 绝对位置编码添加的位置不同:
- 绝对位置编码加在 Transformer 的输入端,典型代表是绝对位置编码( Sinusoidal 位置编码和可学习位置编码 )
- 绝对位置编码乘在 q , k , v q, k, v q,k,v,典型代表是 RoPE 位置编码
- 相对位置编码加在注意力权重 q T k q^{T}k qTk,典型代表是 ALiBi 位置编码
根据本人面试经历,只要是和Positional Encoding相关的问题,基本都是 Transformer论文中提出的正弦和余弦位置编码,以及目前大模型常用的RoPE,这两个方法。因此,本文主要以这两个方法为例来深入讨论。
预备知识
三角函数求和公式
s i n ( α + β ) = s i n α ∗ C o s β + c o s α ∗ s i n β \rm{sin}(\alpha+\beta) = sin\alpha*Cos\beta + cos\alpha * sin\beta sin(α+β)=sinα∗Cosβ+cosα∗sinβ
s i n ( α − β ) = s i n α ∗ C o s β − c o s α ∗ s i n β \rm{sin}(\alpha-\beta) = sin\alpha*Cos\beta - cos\alpha * sin\beta sin(α−β)=sinα∗Cosβ−cosα∗sinβ
c o s ( α + β ) = c o s α ∗ c o s β − s i n α ∗ s i n β \rm{cos}(\alpha+\beta) = cos\alpha*cos\beta - sin\alpha * sin\beta cos(α+β)=cosα∗cosβ−sinα∗sinβ
c o s ( α − β ) = c o s α ∗ c o s β + s i n α ∗ s i n β \rm{cos}(\alpha-\beta) = cos\alpha*cos\beta + sin\alpha * sin\beta cos(α−β)=cosα∗cosβ+sinα∗sinβ
旋转矩阵
逆时针旋转
假设向量 a , b \bold{a}, \bold{b} a,b的长度均为1,将 a \bold{a} a逆时针旋转 θ \theta θ角度,变成 b \bold{b} b的过程如下:
a = [ c o s μ , s i n μ \bold{a} = [\rm{cos}\mu, sin\mu a=[cosμ,sinμ]
b = [ c o s ( μ + θ ) , s i n ( μ + θ ) \bold{b} = [\rm{cos}(\mu+\theta), sin(\mu+\theta) b=[cos(μ+θ),sin(μ+θ)]
根据上面的三角函数求和公式可得:
b = [ c o s θ c o s μ − s i n θ s i n μ , s i n μ c o s θ + c o s μ s i n θ \bold{b} = [\rm{cos}\theta cos\mu - sin\theta sin\mu,sin\mu cos\theta + cos\mu sin\theta b=[cosθcosμ−sinθsinμ,sinμcosθ+cosμsinθ]
这里我们用矩阵乘来简化计算:
b T = [ c o s ( μ + θ ) s i n ( μ + θ ) ] = [ c o s θ − s i n θ s i n θ c o s θ ] [ c o s μ s i n μ ] = M n a \bold{b}^{\rm{T}} =\begin{bmatrix} \rm{cos}(\mu+\theta) \\ \rm{sin}(\mu+\theta) \end{bmatrix}= \begin{bmatrix} \rm{cos}\theta & -\rm{sin}\theta \\ \rm{sin}\theta & \rm{cos}\theta \end{bmatrix} \begin{bmatrix} \rm{cos}\mu \\ \rm{sin}\mu \end{bmatrix} =\bold{M_{n}}\bold{a} bT=[cos(μ+θ)sin(μ+θ)]=[cosθsinθ−sinθcosθ][cosμsinμ]=Mna
因此,逆时针的旋转矩阵为: M n = [ c o s θ − s i n θ s i n θ c o s θ ] \bold{M_{n}}=\begin{bmatrix} \rm{cos}\theta & -\rm{sin}\theta \\ \rm{sin}\theta & \rm{cos}\theta \end{bmatrix} Mn=[cosθsinθ−sinθcosθ]
顺时针旋转
假设向量 a , b \bold{a}, \bold{b} a,b的长度均为1,将 a \bold{a} a顺时针旋转 θ \theta θ角度,变成 b \bold{b} b的过程如下:
a = [ c o s μ , s i n μ \bold{a} = [\rm{cos}\mu, sin\mu a=[cosμ,sinμ]
b = [ c o s ( μ − θ ) , s i n ( μ − θ ) \bold{b} = [\rm{cos}(\mu-\theta), sin(\mu-\theta) b=[cos(μ−θ),sin(μ−θ)]
根据上面的三角函数求和公式可得:
b = [ c o s θ c o s μ + s i n θ s i n μ , s i n μ c o s θ − c o s μ s i n θ \bold{b} = [\rm{cos}\theta cos\mu + sin\theta sin\mu,sin\mu cos\theta - cos\mu sin\theta b=[cosθcosμ+sinθsinμ,sinμcosθ−cosμsinθ]
这里我们用矩阵乘来简化计算:
b T = [ c o s ( μ − θ ) s i n ( μ − θ ) ] = [ c o s θ s i n θ − s i n θ c o s θ ] [ c o s μ s i n μ ] = M s a \bold{b}^{\rm{T}} =\begin{bmatrix} \rm{cos}(\mu-\theta) \\ \rm{sin}(\mu-\theta) \end{bmatrix} = \begin{bmatrix} \rm{cos}\theta & \rm{sin}\theta \\ -\rm{sin}\theta & \rm{cos}\theta \end{bmatrix} \begin{bmatrix} \rm{cos}\mu \\ \rm{sin}\mu \end{bmatrix} = \bold{M_{s}}\bold{a} bT=[cos(μ−θ)sin(μ−θ)]=[cosθ−sinθsinθcosθ][cosμsinμ]=Msa
因此,顺时针的旋转矩阵为: M s = [ c o s θ s i n θ − s i n θ c o s θ ] \bold{M_{s}}=\begin{bmatrix} \rm{cos}\theta & \rm{sin}\theta \\ -\rm{sin}\theta & \rm{cos}\theta \end{bmatrix} Ms=[cosθ−sinθsinθcosθ]
旋转矩阵的性质
R ( α ) R ( β ) = R ( α + β ) R(\alpha) R(\beta) =R(\alpha+\beta) R(α)R(β)=R(α+β)
R ( θ ) T = R ( − θ ) R(\theta)^{\rm{T}} =R(-\theta) R(θ)T=R(−θ)
原始Transformer中的位置编码
论文中的介绍
首先贴上Transformer论文中,对于Positional Encoding部分的全部介绍:
我真的服了,这么重要的位置编码,论文里就写了这么一点??现在看来,内容虽然少,但是句句都是关键,每一句都是面试官想要考你的点,蚌埠住了!
回到正题,论文里面对Positional Encoding的描述主要有以下几个点:
- 位置编码的维度和token的embedding的维度一致,所以可以直接add
- 位置编码的具体实现方式是:
sine and cosine functions of different frequencies
,也就是同时使用正弦函数和余弦函数来表示每个token的绝对位置 - 在
sine and cosine functions of different frequencies
中,包括两个关键变量,一个是pos
,表示 是哪个token,另一个是i
,表示token中不同embedding的位置 - 使用这种正余弦位置编码的方式,可以在计算attention时,很轻松的学习
relative positions
,也就是相对位置,理由是, P E p o s + k PE_{pos+k} PEpos+k可以表示为 P E p o s PE_{pos} PEpos的线性变换!!(其实就是旋转矩阵) - 选择正余弦位置编码方式,也是因为它可以允许模型外推到,比训练期间遇到的序列长度更长的序列长度,这个特性对于扩大模型推理时的长度非常友好!!
具体计算过程
下面,让我们通过一个具体的示例,来理解Transformer论文的正余弦位置编码,到底是怎么计算的?(参考这篇blog)
假设我们的输入如下,第一行是输入文本,第二行tokenization后的tokens,最后是每个token对应的embedding(维度是5):
首先,对于pos=0
的token5
(对应text为When
)来说,计算它的 位置编码 方式如下:
可以看到,token的每个维度,都会计算一个位置编码,维度索引的奇偶性,决定了使用sin
还是cos
函数来计算。
这里的计算方式,和原论文的公式有出入,原论文应该是维度索引
i
为偶数时,使用sin
函数来计算,为奇数时,则使用cos
函数来计算。
但是这里的计算方式却是反过来,奇数时,使用sin
函数来计算,偶数时,则使用cos
函数来计算
所以大家知道这一点就可以,不影响对Positional Encoding计算过程的理解
同理,对于所有输入tokens,分别计算他们的位置编码:
这里可以感觉出来,越靠前的token计算的位置编码,他们使用的正余弦函数的频率越大,振荡的越快,相反,越往后的tokne,在embedding维度上振荡越慢,不同的频率也就是论文中说的sine and cosine functions of different frequencies
,大概如下图所示:
为什么是线性变换?
到这里,相信大家对Transformer论文的正余弦位置编码的计算过程,有了一个清晰的理解。现在来思考论文中的一个关键点: P E p o s + k PE_{pos+k} PEpos+k如何表示为 P E p o s PE_{pos} PEpos的线性变换?
我们假设,用 t t t来表示不同token得pos
(其实这里是把位置,类比为时间,第i
个位置和第i
个时刻是一致的),那么论文中的PE
计算公式就变成了:
P E ( t , 2 i ) = s i n ( t 1000 0 2 i / d m o d e l ) P E ( t , 2 i + 1 ) = c o s ( t 1000 0 2 i / d m o d e l ) \begin{aligned} PE_{(t, 2i)} &= \rm{sin}(\frac{\it{t}}{10000^{2\it{i}/\it{d}_{model}}}) \\ PE_{(t, 2i+1)} &= \rm{cos}(\frac{\it{t}}{10000^{2\it{i}/\it{d}_{model}}}) \end{aligned} PE(t,2i)PE(t,2i+1)=sin(100002i/dmodelt)=cos(100002i/dmodelt)
可以看到,位置编码的过程,其实是对每个token按照维度方向,两两分组,维度索引为偶数时使用sin
函数,奇数时使用cos
函数/
所以一共有 d m o d e l / 2 d_{model}/2 dmodel/2个分组,这里用 j j j来表示分组情况,那么,第 t t t个token的第 j j j个分组可以表示为:
P E ( t , j ) = { sin ( θ j ⋅ t ) , if j = 2 i / 2 cos ( θ j ⋅ t ) , if j = ( 2 i + 1 ) / 2 P E_{(t, j)}=\left\{\begin{array}{ll} \sin \left(\theta_{j} \cdot t\right), & \text { if } j=2 i / 2 \\ \cos \left(\theta_{j} \cdot t\right), & \text { if } j=(2 i+1) / 2 \end{array}\right. PE(t,j)={sin(θj⋅t),cos(θj⋅t), if j=2i/2 if j=(2i+1)/2
其中, θ j = 1 1000 0 j / d m o d e l \theta_j=\frac{1}{10000^{j/d_{model}}} θj=10000j/dmodel1
那么,对于第 j j j个分组来说,如果 P E ( t + k , j ) PE_{(t+k, j)} PE(t+k,j)是 P E ( t , j ) PE_{(t, j)} PE(t,j)的线性变换,则存在一个矩阵 M ∈ R 2 × 2 \bold{M}\in \mathbb R^{2 \times 2} M∈R2×2,使得 P E ( t + k , j ) = M ∗ P E ( t , j ) PE_{(t+k, j)}=\bold{M} * PE_{(t, j)} PE(t+k,j)=M∗PE(t,j)成立,也就是:
[ s i n ( θ j ⋅ ( t + k ) ) c o s ( θ j ⋅ ( t + k ) ) ] = M [ s i n ( θ j ⋅ t ) c o s ( θ j ⋅ t ) ] \begin{bmatrix} \rm{sin}(\theta_{\it{j}} \cdot (\it{t} + \it{k})) \\ \rm{cos}(\theta_{\it{j}} \cdot (\it{t} + \it{k})) \end{bmatrix} = \bold{M} \begin{bmatrix} \rm{sin}(\theta_{\it{j}} \cdot \it{t}) \\ \rm{cos}(\theta_{\it{j}} \cdot \it{t}) \end{bmatrix} [sin(θj⋅(t+k))cos(θj⋅(t+k))]=M[sin(θj⋅t)cos(θj⋅t)]
聪明的小伙伴肯定可以看出来,这不就是预备知识中讲的逆时针旋转公式嘛!!!其实不然,大家注意看,这里sin和cos的顺序,和我们之前推导旋转公式的时候,是相反的,所以这里重新计算就可以得到:
M = [ c o s ( θ j ⋅ k ) s i n ( θ j ⋅ k ) − s i n ( θ j ⋅ k ) c o s ( θ j ⋅ k ) ] \bold{M}= \begin{bmatrix} \rm{cos}(\theta_{\it{j}} \cdot \it{k}) & \rm{sin}(\theta_{\it{j}} \cdot \it{k}) \\ -\rm{sin}(\theta_{\it{j}} \cdot \it{k}) & \rm{cos}(\theta_{\it{j}} \cdot \it{k}) \end{bmatrix} M=[cos(θj⋅k)−sin(θj⋅k)sin(θj⋅k)cos(θj⋅k)]
因此, P E ( t + k , j ) PE_{(t+k, j)} PE(t+k,j)就是 P E ( t , j ) PE_{(t, j)} PE(t,j)顺时针旋转 ( θ j ⋅ k ) (\theta_j \cdot k) (θj⋅k) 角度得到的,旋转角度是相对位置 k k k的线性关系,所以,论文中才说:for any fixed offset k k k, P E p o s + k PE_{pos+k} PEpos+k can be represented as a linear function of P E p o s PE_{pos} PEpos
对于所有 d m o d e l / 2 d_{model}/2 dmodel/2个分组来说, P E t + k PE_{t+k} PEt+k是 P E t PE_{t} PEt顺时针旋转得到的,可以表示为:
大模型常用的旋转位置编码RoPE
RoPE是 Rotary Position Embedding 的缩写,即旋转位置编码,源自 RoFormer 这篇论文(RoFormer: Enhanced Transformer with Rotary Position Embedding),目前已在 Llama 等各类大模型中被广泛用作默认的位置编码方法。
基本原理
核心就一句话:使用旋转矩阵对绝对位置进行编码,同时在自注意中结合了显式的相对位置依赖性
也就是说,RoPE是一种固定式的绝对位置编码策略,但是它的绝对位置编码配合 Transformer 的Self-Attention 内积注意力机制能达到相对位置编码的效果。
RoPE的本质是对两个token形成的Query和Key向量做一个变换,使得变换后的Query和Key带有位置信息,进一步使得Attention的内积操作不需要做任何更改就能自动感知到相对位置信息。换句话说,RoPE的出发点和策略用的相对位置编码思想,但是实现方式用的是绝对位置编码。
下面根据论文中的介绍,尽量直白地讲清楚RoPE的基本原理:
首先,论文在最开始讲说清楚了,RoPE方法的最终目标,就是找到一种等价的位置编码方法,使得query和key的内积结果,只和输入word embedding也就是 x m , x n \bold{x}_m, \bold{x}_n xm,xn以及他们之间的相对位置 m − n m-n m−n有关。
那么,RoPE是怎么做的呢?
根据论文中的公式14和16,我们可以知道:
- 首先对query向量逆时针旋转 m θ i m\theta_{i} mθi角度,对key向量逆时针旋转 n θ i n\theta_{i} nθi角度
- 然后,计算attention,也就是 q m T ⋅ k n = ( R m q ) T ⋅ ( R n k ) = q T R m T ⋅ R n k = q T R − m ⋅ R n k = q T R n − m k \bold{q}^{\rm{T}}_{m} \cdot \bold{k}_{n}=(R_m \bold{q})^{\rm{T}} \cdot (R_n \bold{k})=\bold{q}^{\rm{T}} R_m^{\rm{T}} \cdot R_n \bold{k}=\bold{q}^{\rm{T}} R_{-m} \cdot R_n \bold{k}=\bold{q}^{\rm{T}} R_{n-m} \bold{k} qmT⋅kn=(Rmq)T⋅(Rnk)=qTRmT⋅Rnk=qTR−m⋅Rnk=qTRn−mk
因此,可以看到,RoPE的实现方式,是通过对query和key的embedding分别进行逆时针旋转(乘以逆时针旋转矩阵),然后作用在attention内积计算中,自然的实现了相对位置信息的嵌入。
以上是非常简略的讲解,论文中以及llama3代码中,都是在复数域计算频率因子,从而进行旋转变换(相乘),最后再变换到实数域得到结果。
最简实现形式
https://zhuanlan.zhihu.com/p/684666015
# 1. 参数初始化
bs = 1
seq_len = 10
dim = 128
theta = 10000.0# 2. 初始化输入,并转换为复数
xq = torch.randn(bs, seq_len, dim) # [1, 10, 128]
xq_div_2 = xq.view(*xq.shape[:-1], -1, 2) # [1, 10, 64, 2]
xq_complex = torch.view_as_complex(xq_div_2) # [1, 10, 64]# 3. 计算旋转位置编码的复数因子
theta = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim)) # [64]
m = torch.arange(seq_len) # [10]
freqs = torch.outer(m, theta) # 外积函数,两个矩阵中的元素两两相乘,[10, 64]
freqs = torch.polar(torch.one_like(freqs), freqs) # [10, 64]# 4. 在复数空间,对原输入进行旋转位置编码转换
freqs_cis = freqs[:seq_len].view(xq_complex.shape) # 统一维度 # [1, 10, 64]
xq_rope_complex = xq_complex * freqs_cis # [1, 10, 64]# 5. 复数转换为实数
xq_rope_real = torch.view_as_real(xq_rope_complex) # [1, 10, 64, 2]
xq_rope_real = xq_rope_real.flatten(-2) # 展平最后两个维度 [1, 10, 128]
Llama3中的代码实现
https://github.com/meta-llama/llama3/blob/main/llama/model.py#L65
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
这里预先计算每个序列中位置旋转的角度,以复数形式表达。
dim:为attention输入tensor的维度,简单来讲可以理解成embedding的维度
end:为序列的最大长度,即只计算0-end的旋转角度
theta:为超参数,原论文推荐值为10000
return:输出的是一个复数向量,向量shape为(end,dim//2),第一维是是序列长度,第二维为输入tensor维度的一半,因为旋转是以2维矩阵为单位的,即每两个数有一个旋转角度。
"""freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))t = torch.arange(end, device=freqs.device) # type: ignorefreqs = torch.outer(t, freqs).float() # type: ignorefreqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64return freqs_cisdef reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""
将预先计算好的旋转角度freqs_cis的shape与输入shape统一。
freqs_cis的维数会变为和x一致,其中第二维是freqs_cis.shape[0],最后一维是freqs_cis.shape[1],其余维度为1
"""ndim = x.ndimassert 0 <= 1 < ndimassert freqs_cis.shape == (x.shape[1], x.shape[-1])shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]return freqs_cis.view(*shape)def apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
用预先计算的角度来旋转输入的xq和xk,做法是利用复数乘法的性质来完成旋转
"""xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))freqs_cis = reshape_for_broadcast(freqs_cis, xq_)xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)
两种位置编码的区别
简单来说,绝对位置编码是一个顺时针旋转的钟表系统,这个信息会与输入相加。但是旋转位置编码是一个逆时针旋转的时钟系统,并没有采用相加的方式,而是直接将输入(query和key)进行了旋转。
编码方式
- 正余弦位置编码:
- 使用固定的正弦和余弦函数来为每个token的每个维度生成位置编码
- 这种编码的每个维度使用不同的频率,使得编码具有周期性和可区分性,从而允许模型推断序列中元素的相对位置
- RoPE:
- 使用旋转变换的方式来对位置进行编码。
- RoPE通过在自注意力机制的点积计算过程中引入旋转变换来实现位置编码,而不是直接添加到输入嵌入上。
- 提供了一种位置编码与内容编码更紧密结合的方法,使得模型能够在不丢失相对位置关系的情况下处理长序列。
实现方式
- 正余弦位置编码:
- 直接将位置编码加到输入的词嵌入上,影响模型的输入表示。
- RoPE:
- 在注意力计算过程中,使用旋转操作影响注意力得分。具体而言,通过将query和key中的位置进行旋转变换,使得位置编码在自注意力的计算中以更自然的方式呈现。
- 这种方法通常不直接影响输入嵌入,而是调整注意力机制本身。
参考资料
- [1] https://note.mowen.cn/note/detail?noteUuid=Q2_oDhFEqD2pD8Iv4uSzn
- [2] https://note.mowen.cn/note/detail?noteUuid=waAeRtCgZXLO62f9RhUWa
- [3] https://www.bilibili.com/video/BV1F1421B7iv/?share_source=copy_web&vd_source=79b1ab42a5b1cccc2807bc14de489fa7
- [4] https://www.jianshu.com/p/e8be3dbfb4c5
- [5] https://blog.csdn.net/BIT_Legend/article/details/137042032
- [6] https://medium.com/@fareedkhandev/understanding-transformers-a-step-by-step-math-example-part-1-a7809015150a
- [7] https://zhuanlan.zhihu.com/p/684666015
- [8] https://kexue.fm/archives/8265