Llama改进之——RoPE旋转位置编码

引言

旋转位置编码(Rotary Position Embedding, RoPE)将绝对相对位置依赖纳入自注意力机制中,以增强Transformer架构的性能。目前很火的大模型LLaMA、QWen等都应用了旋转位置编码。

之前在[论文笔记]ROFORMER中对旋转位置编码的原始论文进行了解析,重点推导了旋转位置编码的公式,本文侧重实现,同时尽量简化数学上的推理,详细内容可见最后的参考文章。

复数与极坐标

复数由两个部分组成:实部(real part)和虚部(imaginary part)。实部就是一个普通的数字,可以是零、正数或负数。虚部是另一个实数与 i i i相乘。比如 2 + 3 i 2+3i 2+3i是一个复数,其中 2 2 2是实部; 3 i 3i 3i是虚部。下面这些数字都是复数:
2 , 2 + 2 i , 1 − 3 i , − 4 i , 17 i 2, \quad 2+2i,\quad 1-3i,\quad -4i,\quad 17i 2,2+2i,13i,4i,17i
可以看到复数是实数的扩展,包含了实数,比如 2 2 2可以看成是虚部为 0 0 0

通常实数放前面,然后是 i i i。但当 i i i与三角函数( sin ⁡ , cos ⁡ \sin,\cos sin,cos)在一起通常把 i i i放在前面: i sin ⁡ θ , i cos ⁡ θ i \sin \theta, i\cos \theta isinθ,icosθ​​。

i i i我们可以理解为就是一个简单的数学对象,满足 i 2 = − 1 i^2=-1 i2=1

image-20240406094033599

极坐标系是一个二维坐标系统。该坐标系统中任意位置可由一个夹角和一段相对原点——极点的距离来表示。如上图(来自百度百科)所示。

给定极坐标系内的任意一个复数 x + y i x+yi x+yi(对应二维向量 [ x , y ] [x,y] [x,y]),要将其(逆时针)旋转 θ \theta θ度,只需要乘上旋转子:
R θ = cos ⁡ θ + i sin ⁡ θ ( sin ⁡ 2 θ + cos ⁡ 2 θ = 1 ) (1) \pmb R_\theta = \cos \theta + i \sin \theta \qquad(\sin^2 \theta + \cos^2 \theta = 1) \tag 1 RRRθ=cosθ+isinθ(sin2θ+cos2θ=1)(1)
可以相乘再展开,然后利用 i 2 = − 1 i^2=-1 i2=1可得:
x ′ + y ′ i = ( cos ⁡ θ + i sin ⁡ θ ) ( x + y i ) = ( x cos ⁡ θ − y sin ⁡ θ ) + ( x sin ⁡ θ + y cos ⁡ θ ) i \begin{aligned} x^\prime + y^\prime i &= (\cos \theta + i\sin \theta)(x + yi) \\ &= (x \cos \theta - y \sin \theta)+(x \sin \theta + y \cos \theta)i \end{aligned} x+yi=(cosθ+isinθ)(x+yi)=(xcosθysinθ)+(xsinθ+ycosθ)i
对应二维平面中点 [ x , y ] [x,y] [x,y]关于原点的逆时针旋转:
[ x ′ y ′ ] = [ cos ⁡ θ − sin ⁡ θ sin ⁡ θ cos ⁡ θ ] [ x y ] \begin{bmatrix} x^\prime \\ y^\prime \end{bmatrix} = \begin{bmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{bmatrix} \begin{bmatrix} x \\ y \end{bmatrix} [xy]=[cosθsinθsinθcosθ][xy]
其中包含 θ \theta θ的矩阵是一个旋转矩阵。

旋转位置编码

x i ∈ R d \pmb x_i \in \Bbb R^d xxxiRd是无位置信息的标记 w i w_i wi d d d维词嵌入向量。自注意力首先将位置信息与单词嵌入相结合,并将其转化为query、key和value的表示形式。
q m = f q ( x m , m ) k n = f k ( x n , n ) v n = f v ( x n , n ) (2) \begin{aligned} \pmb q_m &= f_q(\pmb x_m, m) \\ \pmb k_n &= f_k(\pmb x_n, n) \\ \pmb v_n &= f_v(\pmb x_n, n) \\ \end{aligned} \tag 2 qqqmkkknvvvn=fq(xxxm,m)=fk(xxxn,n)=fv(xxxn,n)(2)
其中 q m , k n \pmb q_m,\pmb k_n qqqm,kkkn v n \pmb v_n vvvn分别通过 f q , f k f_q,f_k fq,fk f v f_v fv整合了第m和第n个位置信息。query和key然后用于计算注意力权重,而输出为value的加权和。
$$
\begin{aligned}
a_{m,n} &= \frac{\exp(\frac{\pmb q^T_m \pmb k_n}{\sqrt d})}{\sum_{j=1}^N \exp \frac{\pmb q^T_m \pmb k_j}{\sqrt d}} \
\pmb o_m &= \sum_{n=1}^N a_{m,n}\pmb v_n \

\end{aligned} \tag 3
$$

Transformer通过自注意机制利用各个标记的位置信息,如等式(3)中所见, q m T k n \pmb q_m^T \pmb k_n qqqmTkkkn通常可以在不同位置的标记之间传递知识。为了融入相对位置信息,我们需要将查询 q m \pmb q_m qqqm和键 k n \pmb k_n kkkn的内积公式转化为一个函数 g g g,该函数只接受词嵌入 x m , x n \pmb x_m,\pmb x_n xxxm,xxxn以及它们的相对位置 m − n m-n mn​作为输入变量。换句话说,我们希望内积只以相对形式编码位置信息:

⟨ f q ( x m , m ) , f k ( x n , n ) ⟩ = g ( x m , x n , m − n ) (4) \langle f_q(\pmb x_m,m) , f_k(\pmb x_n,n) \rangle = g(\pmb x_m,\pmb x_n, m-n) \tag 4 fq(xxxm,m),fk(xxxn,n)=g(xxxm,xxxn,mn)(4)
最终目标是找到一个等价的编码方式来求解函数 f q ( x m , m ) f_q(\pmb x_m, m) fq(xxxm,m) f k ( x n , n ) f_k(\pmb x_n, n) fk(xxxn,n)​,以符合上等式。

从简单的维度 d = 2 d=2 d=2的情况开始,这样可以利用二维平面上向量的几何特性及其复数形式来证明公式(4)的一个解是:
f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ g ( x m , x n , m − n ) = Re [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ ] (5) \begin{aligned} f_q(\pmb x_m,m) &= (\pmb W_q\pmb x_m) e^{im\theta} \\ f_k(\pmb x_n,n) &= (\pmb W_k\pmb x_n) e^{in\theta} \\ g(\pmb x_m,\pmb x_n,m-n) &= \text{Re}[(\pmb W_q\pmb x_m)(\pmb W_k\pmb x_n)^*e^{i(m-n)\theta}] \end{aligned} \tag {5} fq(xxxm,m)fk(xxxn,n)g(xxxm,xxxn,mn)=(WWWqxxxm)eimθ=(WWWkxxxn)einθ=Re[(WWWqxxxm)(WWWkxxxn)ei(mn)θ](5)
这里 Re [ ⋅ ] \text{Re}[\cdot] Re[]表示复数的实部; ( W k x n ) ∗ (\pmb W_k\pmb x_n)^* (WWWkxxxn)表示 ( W k x n ) (\pmb W_k\pmb x_n) (WWWkxxxn)的共轭复数; θ ∈ R \theta \in \Bbb R θR表示一个非零常数。

可以进一步将 f { q , k } f_{\{q,k\}} f{q,k}写成矩阵乘法形式:
f { q , k } ( x m , m ) = ( cos ⁡ m θ − sin ⁡ m θ sin ⁡ m θ cos ⁡ m θ ) ( W { q , k } ( 11 ) W { q , k } ( 12 ) W { q , k } ( 21 ) W { q , k } ( 22 ) ) ( x m ( 1 ) x m ( 2 ) ) (6) f_{\{q,k\}} (\pmb x_m,m) =\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}\begin{pmatrix} W_{\{q,k\}}^{(11)} & W_{\{q,k\}}^{(12)} \\ W_{\{q,k\}}^{(21)} & W_{\{q,k\}}^{(22)} \end{pmatrix} \begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix} \tag{6} f{q,k}(xxxm,m)=(cosmθsinmθsinmθcosmθ)(W{q,k}(11)W{q,k}(21)W{q,k}(12)W{q,k}(22))(xm(1)xm(2))(6)
这里的 { q , k } \{q,k\} {q,k}表示 q q q k k k的集合,比如上式对 f q f_q fq f k f_k fk​都成立;包含 sin ⁡ m θ \sin m\theta sinmθ cos ⁡ m θ \cos m\theta cosmθ的矩阵是上面介绍的旋转矩阵。

其中$ (x^{(1)}_m, x^{(2)}_m) 为 为 x_m$ 在二维坐标中的表示。类似地, g g g 可以被视为一个矩阵,从而能够在二维情况下求解等式 ( 4 ) (4) (4)。具体来说,结合相对位置嵌入是很直接的:只需将仿射变换后的词嵌入向量旋转一定角度乘位置索引(旋转 m θ m\theta mθ​),从而解释了旋转位置嵌入背后的直觉。

我们进行直观理解,假设两个向量 q \pmb q qqq k \pmb k kkk它们的夹角为 θ \theta θ,根据向量夹角的余弦我们知道 q ⋅ k = ∣ q ∣ ∣ k ∣ cos ⁡ θ \pmb q \cdot \pmb k = |\pmb q||\pmb k| \cos \theta qqqkkk=qqqkkkcosθ​。

image-20240408173339571

q \pmb q qqq(逆时针)旋转 α \alpha α角度后,与 k \pmb k kkk的夹角变成了 θ + α \theta + \alpha θ+α

image-20240408173856558

k \pmb k kkk旋转 β \beta β角度后,与 q \pmb q qqq的夹角变成了 θ − β \theta - \beta θβ

image-20240408174209956

当两个向量同时旋转后,它们的夹角变成了 θ + α − β \theta + \alpha -\beta θ+αβ。内积表达式为:
q ⋅ k = ∣ q ∣ ∣ k ∣ cos ⁡ ( θ + α − β ) \pmb q \cdot \pmb k = |\pmb q||\pmb k| \cos (\theta + \alpha - \beta) qqqkkk=qqqkkkcos(θ+αβ)
特殊地,当 α − β = 0 \alpha - \beta =0 αβ=0​​时,即两个向量旋转的角度相同,它们的内积不变。通过这两个向量的夹角来影响内积的值。通过这种直觉,公式(4)是成立的。

为了将我们在二维空间中的结果推广到任意 x i ∈ R d \pmb x_i ∈ \R^d xxxiRd,其中 d d d 是偶数。我们可以将 d d d 维空间划分为 $d/2 $个子空间(分块矩阵),并结合内积的线性特性进行组合,将 f { q , k } f_{\{q,k\}} f{q,k}​ 转化为:
f { q , k } = ( x m , m ) = R Θ , m d W { q , k } x m (7) f_{\{q,k\}} = (\pmb x_m,m) = \pmb R_{\Theta,m}^d \pmb W_{\{q,k\}} \pmb x_m \tag{7} f{q,k}=(xxxm,m)=RRRΘ,mdWWW{q,k}xxxm(7)

这里说的特性是指线性叠加性:

  1. 定义:内积的定义是两个向量对应分量相乘后再相加。假设有两个向量 v ⃗ = ( v 1 , v 2 , . . . , v n ) \vec{v} = (v_1, v_2, ..., v_n) v =(v1,v2,...,vn) w ⃗ = ( w 1 , w 2 , . . . , w n ) \vec{w} = (w_1, w_2, ..., w_n) w =(w1,w2,...,wn),它们的内积可以表示为 v ⃗ ⋅ w ⃗ = v 1 w 1 + v 2 w 2 + . . . + v n w n \vec{v} \cdot \vec{w} = v_1w_1 + v_2w_2 + ... + v_nw_n v w =v1w1+v2w2+...+vnwn

  2. 线性性质:内积满足线性叠加性,即对于任意标量 a a a 和向量 v ⃗ , w ⃗ , u ⃗ \vec{v}, \vec{w}, \vec{u} v ,w ,u ,有以下性质:

    • 可加性: v ⃗ ⋅ ( w ⃗ + u ⃗ ) = v ⃗ ⋅ w ⃗ + v ⃗ ⋅ u ⃗ \vec{v} \cdot (\vec{w} + \vec{u}) = \vec{v} \cdot \vec{w} + \vec{v} \cdot \vec{u} v (w +u )=v w +v u
    • 齐次性: ( a v ⃗ ) ⋅ w ⃗ = a ( v ⃗ ⋅ w ⃗ ) (a\vec{v}) \cdot \vec{w} = a(\vec{v} \cdot \vec{w}) (av )w =a(v w )

其中
R Θ , m d = ( cos ⁡ m θ 1 − sin ⁡ m θ 1 0 0 ⋯ 0 0 sin ⁡ m θ 1 cos ⁡ m θ 1 0 0 ⋯ 0 0 0 0 cos ⁡ m θ 2 − sin ⁡ m θ 2 ⋯ 0 0 0 0 sin ⁡ m θ 2 cos ⁡ m θ 2 ⋯ 0 0 ⋮ ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ 0 0 0 0 ⋯ cos ⁡ m θ d / 2 − sin ⁡ m θ d / 2 0 0 0 0 ⋯ sin ⁡ m θ d / 2 cos ⁡ m θ d / 2 ) (8) \pmb R_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \\ \end{pmatrix} \tag{8} RRRΘ,md=cosmθ1sinmθ10000sinmθ1cosmθ1000000cosmθ2sinmθ20000sinmθ2cosmθ2000000cosmθd/2sinmθd/20000sinmθd/2cosmθd/2(8)
是一个带有预定义参数 Θ = { θ i = 1000 0 − 2 ( i − 1 ) / d , i ∈ [ 1 , 2 , . . . , d / 2 ] } Θ = \{θ_i = 10000^{−2(i−1)/d}, i ∈ [1, 2, ..., d/2]\} Θ={θi=100002(i1)/d,i[1,2,...,d/2]}​ 的旋转矩阵。RoPE的图示如原论文中的图(1)所示。将RoPE应用于等式(3)中的自注意力机制,我们可以得到:
q m ⊤ k n = ( R Θ , m d W q x m ) ⊤ ( R Θ , n d W k x n ) = x m ⊤ W q R Θ , n − m d W k x n (9) \pmb q_m^\top \pmb k_n = (\pmb R_{\Theta,m}^d \pmb W_{q}\pmb x_m)^\top (\pmb R_{\Theta,n}^d \pmb W_{k}\pmb x_n) = \pmb x_m^\top \pmb W_q \pmb R_{\Theta,n-m}^d \pmb W_k \pmb x_n \tag{9} qqqmkkkn=(RRRΘ,mdWWWqxxxm)(RRRΘ,ndWWWkxxxn)=xxxmWWWqRRRΘ,nmdWWWkxxxn(9)
其中 R Θ , n − m d = ( R Θ , m d ) ⊤ R Θ , n d \pmb R_{\Theta,n-m}^d=(\pmb R_{\Theta,m}^d)^\top \pmb R_{\Theta,n}^d RRRΘ,nmd=(RRRΘ,md)RRRΘ,nd。值得指出的是, R Θ \pmb R_{\Theta} RRRΘ​是一个正交矩阵,它不会改变向量的模长,因此通常来说它不会改变原模型的稳定性。

我们可以增大 θ \theta θ的base以支持更长的上下文,这里是10000。

image-20240413084948720

上图所说的是一个长度为6的序列,在进行自注意力计算时,Query和Key向量经过旋转位置编码变换的过程。首先对于位置1来说,记为 m m m。然后仅考虑第一个二维子空间,即 ( x 1 , x 2 ) (x_1,x_2) (x1,x2)向量,旋转 m θ 1 m\theta_1 mθ1后得到的增强表示。

由于公式(8)中 R Θ , m d \pmb R^d_{\Theta,m} RRRΘ,md的稀疏性,可以通过下述等价方式来实现 R Θ , m d \pmb R^d_{\Theta,m} RRRΘ,md x ∈ R d \pmb x \in \R^d xxxRd的乘法:
KaTeX parse error: No such environment: equation at position 37: …\pmb x = \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲\begin{pmatrix}…
其中 ⊗ \otimes ​是逐位对应相乘。

为什么可以简化成这样子,把乘 x \pmb x xxx带入公式(8)得到:
R Θ , m d x = ( cos ⁡ m θ 1 − sin ⁡ m θ 1 0 0 ⋯ 0 0 sin ⁡ m θ 1 cos ⁡ m θ 1 0 0 ⋯ 0 0 0 0 cos ⁡ m θ 2 − sin ⁡ m θ 2 ⋯ 0 0 0 0 sin ⁡ m θ 2 cos ⁡ m θ 2 ⋯ 0 0 ⋮ ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ 0 0 0 0 ⋯ cos ⁡ m θ d / 2 − sin ⁡ m θ d / 2 0 0 0 0 ⋯ sin ⁡ m θ d / 2 cos ⁡ m θ d / 2 ) ( x 1 x 2 x 3 x 4 ⋮ x d − 1 x d ) \pmb R_{\Theta,m}^d \pmb x= \begin{pmatrix}\begin{array}{cc:cc:cc:cc} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \hdashline 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \hdashline \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ \hdashline 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \\ \end{array}\end{pmatrix} \begin{pmatrix}x_1 \\ x_2 \\ \hdashline x_3 \\ x_4 \\ \hdashline\vdots \\ \hdashline x_{d-1} \\ x_{d}\end{pmatrix} RRRΘ,mdxxx=cosmθ1sinmθ10000sinmθ1cosmθ1000000cosmθ2sinmθ20000sinmθ2cosmθ2000000cosmθd/2sinmθd/20000sinmθd/2cosmθd/2x1x2x3x4xd1xd
根据分块矩阵的乘法,我们仅考虑左右两边矩阵的第一块,其得到(10)中向量的第1和第2个元素:
( cos ⁡ m θ 1 − sin ⁡ m θ 1 sin ⁡ m θ 1 cos ⁡ m θ 1 ) ( x 1 x 2 ) = ( x 1 cos ⁡ m θ 1 − x 2 sin ⁡ m θ 1 x 1 sin ⁡ m θ 1 + x 2 cos ⁡ m θ 1 ) \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1\\ \sin m\theta_1 & \cos m\theta_1 \end{pmatrix} \begin{pmatrix} x_1\\ x_2 \end{pmatrix} = \begin{pmatrix}x_1 \cos m\theta_1 - x_2 \sin m\theta_1 \\ x_1 \sin m\theta_1+x_2 \cos m\theta_1 \end{pmatrix} (cosmθ1sinmθ1sinmθ1cosmθ1)(x1x2)=(x1cosmθ1x2sinmθ1x1sinmθ1+x2cosmθ1)
因此这是成立的。

代码实现

本节参考LLaMA源码来实现旋转位置编码,同时底层实现逻辑进行一个解释。

首先定义一个函数生成旋转矩阵:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):"""给定维度预计算频率(\theta) Tensor的复指数(complex exponentials,cis)Args:dim (int): dimension of the frequency tensorend (int): end index for precomputing frequenciestheta (float, optional): scaling factor for frequency computation. Defaults to 10000.0.Returns:torch.Tensor: Precomputed frequency tensor with complex exponentials."""# freqs (dim/2, )# theta_i = 10000 ** (-2(i-1)/dim) for i = [1,2,...,dim / 2]# theta_i# we start from 0 dont need to do i-1freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))# generate token sequence m = [0, 1, ..., seq_len - 1]# m (end, )m = torch.arange(end, device=freqs.device)# compute m * \theta# freqs (end, dim / 2)freqs = torch.outer(m, freqs).float()# freqs_cis (end, dim / 2)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cis

这个函数用于生成公式(8)中的旋转矩阵。

首先计算预定义参数 Θ = { θ i = 1000 0 − 2 ( i − 1 ) / d , i ∈ [ 1 , 2 , . . . , d / 2 ] } Θ = \{θ_i = 10000^{−2(i−1)/d}, i ∈ [1, 2, ..., d/2]\} Θ={θi=100002(i1)/d,i[1,2,...,d/2]} ,我们的 i i i 0 0 0开始因此不需要 i − 1 i-1 i1,对应上面的Line 17。

然后考虑所有的位置,生成一个m = (seq_len, )形状的向量,Line 20。

计算m和Line 17计算出来的freqs的外积,即m中的每个位置 m i m_i mi都会乘上 Θ Θ Θ的每个元素,得到一个(seq_len, dim / 2)形状的矩阵。假设序列的长度

假设 m = [ m 1 , m 2 , ⋯ , m T ] = [ 1 , 2 , ⋯ , N ] m=[m_1,m_2,\cdots,m_T] =[1,2,\cdots, N] m=[m1,m2,,mT]=[1,2,,N]​,这里 N N N表示序列长度。

它们的乘积是一个矩阵:
( m 1 θ 1 m 1 θ 2 ⋯ m 1 θ d / 2 m 2 θ 1 m 2 θ 2 ⋯ m 2 θ d / 2 ⋮ ⋮ ⋱ ⋮ m N θ 1 m N θ 2 ⋯ m N θ d / 2 ) \begin{pmatrix} m_1 \theta_1 & m_1 \theta_2 & \cdots & m_1 \theta_{d/2} \\ m_2 \theta_1 & m_2 \theta_2 & \cdots & m_2 \theta_{d/2} \\ \vdots & \vdots &\ddots &\vdots \\ m_N \theta_1 & m_N \theta_2 & \cdots & m_N \theta_{d/2} \end{pmatrix} m1θ1m2θ1mNθ1m1θ2m2θ2mNθ2m1θd/2m2θd/2mNθd/2
最后在Line 25通过torch.polar将它们转换为复数形式:
( cos ⁡ ( m 1 θ 1 ) + i ⋅ sin ⁡ ( m 1 θ 1 ) cos ⁡ ( m 1 θ 2 ) + i ⋅ sin ⁡ ( m 1 θ 2 ) ⋯ cos ⁡ ( m 1 θ d / 2 ) + i ⋅ sin ⁡ ( m 1 θ d / 2 ) cos ⁡ ( m 2 θ 1 ) + i ⋅ sin ⁡ ( m 2 θ 1 ) cos ⁡ ( m 2 θ 2 ) + i ⋅ sin ⁡ ( m 2 θ 2 ) ⋯ cos ⁡ ( m 2 θ d / 2 ) + i ⋅ sin ⁡ ( m 2 θ d / 2 ) ⋮ ⋮ ⋱ ⋮ cos ⁡ ( m N θ 1 ) + i ⋅ sin ⁡ ( m N θ 1 ) cos ⁡ ( m N θ 2 ) + i ⋅ sin ⁡ ( m N θ 2 ) ⋯ cos ⁡ ( m N θ d / 2 ) + i ⋅ sin ⁡ ( m N θ d / 2 ) ) \begin{pmatrix} \cos(m_1 \theta_1) + i\cdot \sin(m_1 \theta_1) & \cos(m_1 \theta_2) + i\cdot \sin(m_1 \theta_2) & \cdots & \cos(m_1 \theta_{d/2}) + i\cdot \sin(m_1 \theta_{d/2}) \\ \cos(m_2 \theta_1) + i\cdot \sin(m_2 \theta_1) & \cos(m_2 \theta_2) + i\cdot \sin(m_2 \theta_2) & \cdots & \cos(m_2 \theta_{d/2}) + i\cdot \sin(m_2 \theta_{d/2}) \\ \vdots & \vdots &\ddots &\vdots \\ \cos(m_N \theta_1) + i\cdot \sin(m_N \theta_1) & \cos(m_N \theta_2) + i\cdot \sin(m_N \theta_2) & \cdots & \cos(m_N \theta_{d/2}) + i\cdot \sin(m_N \theta_{d/2}) \\ \end{pmatrix} cos(m1θ1)+isin(m1θ1)cos(m2θ1)+isin(m2θ1)cos(mNθ1)+isin(mNθ1)cos(m1θ2)+isin(m1θ2)cos(m2θ2)+isin(m2θ2)cos(mNθ2)+isin(mNθ2)cos(m1θd/2)+isin(m1θd/2)cos(m2θd/2)+isin(m2θd/2)cos(mNθd/2)+isin(mNθd/2)
torch.polar(abs, angle)基于absangle计算出一个极坐标系中的复数表示:

image-20240524170711764

那如何达到公式(10)的结果呢,为了简单,这里只展示 d = 4 d=4 d=4​的情况,考虑某个Token x \pmb x xxx
x = [ x 1 x 2 x 3 x 4 ] \pmb x=\begin{bmatrix} x_1 & x_2 & x_3 & x_4 \end{bmatrix} xxx=[x1x2x3x4]
第一步把 x \pmb x xxx的元素两两分组:
x = [ [ x 1 , x 2 ] [ x 3 , x 4 ] ] \pmb x=\begin{bmatrix} [x_1 ,x_2 ] & [x_3 ,x_4] \end{bmatrix} xxx=[[x1,x2][x3,x4]]
也不考虑批次维度,形状由(1,4)变成(1,2,2)。然后把新的 x \pmb x xxx转换成复数的形式,形状变成了(1, 2)
x = [ x 1 + i ⋅ x 2 x 3 + i ⋅ x 4 ] \pmb x=\begin{bmatrix} x_1 + i\cdot x_2 & x_3 + i \cdot x_4 \end{bmatrix} xxx=[x1+ix2x3+ix4]
即每个二维向量变成了一个复数。然后我们把这个向量矩阵和freqs_cis对应的向量对应位置相乘(分别旋转 m θ 1 , m θ 2 m\theta_1,m\theta_2 mθ1,mθ2角度: d / 2 = 4 / 2 = 2 d/2=4/2=2 d/2=4/2=2),这里假设当前位置为 m m m​,然后有:
x = [ x 1 + i ⋅ x 2 x 3 + i ⋅ x 4 ] ⊗ [ cos ⁡ ( m θ 1 ) + i ⋅ sin ⁡ ( m θ 1 ) cos ⁡ ( m θ 2 ) + i ⋅ sin ⁡ ( m θ 2 ) ] = [ ( x 1 + i ⋅ x 2 ) [ cos ⁡ ( m θ 1 ) + i ⋅ sin ⁡ ( m θ 1 ) ] ( x 3 + i ⋅ x 4 ) [ cos ⁡ ( m θ 2 ) + i ⋅ sin ⁡ ( m θ 2 ) ] ] = [ x 1 cos ⁡ m θ 1 + i ⋅ x 1 sin ⁡ m θ 1 + i ⋅ x 2 cos ⁡ m θ 1 − x 2 sin ⁡ m θ 1 x 3 cos ⁡ m θ 2 + i ⋅ x 3 sin ⁡ m θ 2 + i ⋅ x 4 cos ⁡ m θ 2 − x 4 sin ⁡ m θ 2 ] = [ x 1 cos ⁡ m θ 1 − x 2 sin ⁡ m θ 1 + i ( x 1 sin ⁡ m θ 1 + x 2 cos ⁡ m θ 1 ) x 3 cos ⁡ m θ 2 − x 4 sin ⁡ m θ 2 + i ( x 3 sin ⁡ m θ 2 + x 4 cos ⁡ m θ 2 ) ] \begin{aligned} \pmb x &=\begin{bmatrix} x_1 + i\cdot x_2 & x_3 + i \cdot x_4 \end{bmatrix} \otimes \begin{bmatrix} \cos(m \theta_1) + i\cdot \sin(m \theta_1) & \cos(m \theta_2) + i\cdot \sin(m \theta_2)\end{bmatrix} \\ &= \begin{bmatrix} (x_1 + i\cdot x_2) [\cos(m \theta_1) + i\cdot \sin(m \theta_1)] & (x_3 + i \cdot x_4) [\cos(m \theta_2) + i\cdot \sin(m \theta_2)] \end{bmatrix} \\ &= \begin{bmatrix} x_1 \cos m \theta_1 +i\cdot x_1 \sin m \theta_1 + i \cdot x_2 \cos m \theta_1 - x_2 \sin m \theta_1 & x_3 \cos m \theta_2 +i\cdot x_3 \sin m \theta_2 + i \cdot x_4 \cos m \theta_2 - x_4 \sin m \theta_2 \end{bmatrix} \\ &= \begin{bmatrix} x_1 \cos m \theta_1 - x_2 \sin m \theta_1+ i(x_1 \sin m \theta_1 + x_2 \cos m \theta_1) & x_3 \cos m \theta_2 -x_4 \sin m \theta_2 +i(x_3 \sin m \theta_2 +x_4 \cos m \theta_2) \end{bmatrix} \\ \end{aligned} xxx=[x1+ix2x3+ix4][cos(mθ1)+isin(mθ1)cos(mθ2)+isin(mθ2)]=[(x1+ix2)[cos(mθ1)+isin(mθ1)](x3+ix4)[cos(mθ2)+isin(mθ2)]]=[x1cosmθ1+ix1sinmθ1+ix2cosmθ1x2sinmθ1x3cosmθ2+ix3sinmθ2+ix4cosmθ2x4sinmθ2]=[x1cosmθ1x2sinmθ1+i(x1sinmθ1+x2cosmθ1)x3cosmθ2x4sinmθ2+i(x3sinmθ2+x4cosmθ2)]

得到一个形状为(1,2)的复数项链。

然后我们把里面的复数变为二维向量:
x = [ [ x 1 cos ⁡ m 1 θ 1 − x 2 sin ⁡ m 1 θ 1 x 1 sin ⁡ m 1 θ 1 + x 2 cos ⁡ m 1 θ 1 ] [ x 3 cos ⁡ m 1 θ 2 − x 4 sin ⁡ m 1 θ 2 x 3 sin ⁡ m 1 θ 2 + x 4 cos ⁡ m 1 θ 2 ] ] \pmb x= \begin{bmatrix} \begin{bmatrix} x_1 \cos m_1 \theta_1 - x_2 \sin m_1 \theta_1 \\ x_1 \sin m_1 \theta_1 + x_2 \cos m_1 \theta_1 \end{bmatrix} & \begin{bmatrix} x_3 \cos m_1 \theta_2 -x_4 \sin m_1 \theta_2 \\ x_3 \sin m_1 \theta_2 +x_4 \cos m_1 \theta_2 \end{bmatrix} \end{bmatrix} xxx=[[x1cosm1θ1x2sinm1θ1x1sinm1θ1+x2cosm1θ1][x3cosm1θ2x4sinm1θ2x3sinm1θ2+x4cosm1θ2]]
最后拉平其中的二维向量:
x = [ x 1 cos ⁡ m θ 1 − x 2 sin ⁡ m θ 1 x 1 sin ⁡ m θ 1 + x 2 cos ⁡ m θ 1 x 3 cos ⁡ m θ 2 − x 4 sin ⁡ m θ 2 x 3 sin ⁡ m θ 2 + x 4 cos ⁡ m 1 θ 2 ] \pmb x= \begin{bmatrix} x_1 \cos m \theta_1 - x_2 \sin m \theta_1 & x_1 \sin m \theta_1 + x_2 \cos m \theta_1 & x_3 \cos m \theta_2 -x_4 \sin m \theta_2 & x_3 \sin m \theta_2 +x_4 \cos m_1 \theta_2 \end{bmatrix} xxx=[x1cosmθ1x2sinmθ1x1sinmθ1+x2cosmθ1x3cosmθ2x4sinmθ2x3sinmθ2+x4cosm1θ2]
比较公式(10)中前4行的结果,可以发现是一样的,只不过列向量变成了行向量。

基于上面的过程我们就不难理解下面的代码:

def apply_rotary_emb(xq: Tensor, xk: Tensor, freq_cis: Tensor):"""使用给定的频率Tensor将旋转嵌入应用到输入张量中。该函数使用提供的频率使用给定的频率Tensor将旋转嵌入应用到输入张量中。freqs_cis将旋转嵌入应用到给定的查询xq和键xk张量上。输入张量被重塑为复数,并且频率张量被重塑以匹配广播兼容性。生成的张量包含旋转嵌入,并作为实张量返回。Args:xq (torch.Tensor): Query tensor to apply rotary embeddings.xk (torch.Tensor): Key tensor to apply rotary embeddings.freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.Returns:Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings."""# xq (batch_size, seq_len, n_head, head_dim)# xq_ (batch_size, seq_len, n_head, head_dim // 2, 2)xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)# turn to complex# xq_ (batch_size, seq_len, n_head, head_dim // 2)xq_ = torch.view_as_complex(xq_)xk_ = torch.view_as_complex(xk_)# 应用旋转操作,然后将结果转回实数# xq_out (batch_size, seq_len, n_head, head_dim)xq_out = torch.view_as_real(xq_ * freq_cis).flatten(2)xk_out = torch.view_as_real(xk_ * freq_cis).flatten(2)return xq_out.type_as(xq), xk_out.type_as(xk)

下篇文章我们会探讨如何应用旋转位置编码到自注意力上。

参考

  1. [论文笔记]ROFORMER
  2. 复数与二维空间旋转

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/19118.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据与结构——红黑树

目录 红黑树的概念 性质 结点的定义 插入 验证 查找 删除 红黑树与AVL树的比较 红黑树的概念 红黑树是一种自平衡二叉搜索树(Binary Search Tree, BST),其每个节点带有颜色属性,可以是红色或黑色。红黑树通过约束节点颜色…

未来已来:Facebook的数字革命与社交转型

在当今数字化时代,Facebook作为全球最大的社交网络之一,不仅扮演着连接人们的桥梁,更是引领着社交行业的数字革命与转型。本文将深入探讨Facebook如何通过创新技术、改变用户体验以及应对挑战,塑造了未来社交的面貌,以…

ozon卖家精灵,ozon卖家怎么使用

在跨境电商的浪潮中,OZON作为俄罗斯领先的电商平台,吸引了众多卖家争相入驻。然而,面对日益激烈的市场竞争,如何提升店铺的运营效果,成为卖家们迫切需要解决的问题。而OZON卖家精灵作为一款专为OZON卖家打造的辅助工具…

java高级——Collection集合之List探索(包含ArrayList、LinkedList、Vector底层实现及区别,非常详细哦)

java高级——Collection集合之List探索 前情提要文章介绍提前了解的知识点1. 数组2. 单向链表3. 双向链表4. 为什么单向链表使用的较多5. 线程安全和线程不安全的概念 ArrayList介绍1. 继承结构解析1.1 三个标志性接口1.2 AbstractList和AbstractCollection 2. ArrayList底层代…

民国漫画杂志《时代漫画》第32期.PDF

时代漫画32.PDF: https://url03.ctfile.com/f/1779803-1248635561-0ae98a?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps: 资源来源网络!

去除字符串中的空格和特殊字符

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 用户在输入数据时,可能会无意中输入多余的空格,或在一些情况下,字符串前后不允许出现空格和特殊字符,…

Beego 使用教程 7:Web 文件上传下载和错误处理

beego 是一个用于Go编程语言的开源、高性能的 web 框架 beego 被用于在Go语言中企业应用程序的快速开发,包括RESTful API、web应用程序和后端服务。它的灵感来源于Tornado, Sinatra 和 Flask beego 官网:http://beego.gocn.vip/ 上面的 bee…

「清新题精讲」Skiers

更好的阅读体验 Skiers Description 给定 n n n 个点的有向无环平面图,求最少多少条从 1 1 1 到 n n n 的路径能覆盖原图的所有边? 1 ≤ n ≤ 5 1 0 3 1\le n\le 5\times10^3 1≤n≤5103 Solution 考虑从 1 1 1 到 n n n 的路径其实是边的链覆…

如何让你的网站能通过域名访问

背景 当我们租一台云服务器,并在上面运行了一个Web服务,我们可以使用云服务器的公网IP地址进行访问,如下: 本文主要记录如何 实现让自己的网站可以通过域名访问。 买域名 可以登录腾讯云等主流公有云平台的,购买域名…

设计模式21——命令模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用,主要是下面的UML图可以起到大作用,在你学习过一遍以后可能会遗忘,忘记了不要紧,只要看一眼UML图就能想起来了。同时也请大家多多指教。 命令模式(Command&…

失落的方舟 命运方舟台服怎么下载游戏客户端 游戏账号怎么注册

《失落的方舟》(Lost Ark)是韩国Smilegate公司精心打造的一款大型多人在线角色扮演游戏(MMORPG),以其精美的画面、沉浸式的剧情、类似动作游戏的战斗体验和广阔的开放世界设定,自面世以来便深受全球玩家喜爱…

计算机毕业设计 | SpringBoot+vue仓库管理系统(附源码)

1,绪论 1.1 项目背景 随着电子计算机技术和信息网络技术的发明和应用,使着人类社会从工业经济时代向知识经济时代发展。在这个知识经济时代里,仓库管理系统将会成为企业生产以及运作不可缺少的管理工具。这个仓库管理系统是由:一…

一款高级管理控制面板主题!【送源码】

AdminLTE是一个完全响应的管理模板。基于Bootstrap5框架和JavaScript插件。高度可定制,易于使用。适用于从小型移动设备到大型桌面的多种屏幕分辨率。AdminLTE 是一个基于Bootstrap 3.x的免费高级管理控制面板主题。 https://github.com/almasaeed2010/AdminLTE —…

操作系统真象还原:完善MBR

第3章-完善MBR 这是一个网站有所有小节的代码实现,同时也包含了Bochs等文件 编译器给程序中各符号(变量名或函数名等)分配的地址,就是各符号相对于文件开头的偏移量 。 section 称为节,在有的编译器中,同…

STM32的时钟介绍

目录 前言1. 简介1.1 时钟是用来做什么的1.2 时钟产生的方式 2. 时钟树的组成2.1 时钟源2.1.1 内部时钟2.1.2 外部时钟 2.2 PLL锁相环2.3 SYSCLK2.4 AHB和HCLK2.5 APB和PCLK2.6 总结 3. STM32时钟的如何进行工作4.我的疑问4.1 使用MSI和HSI有什么区别吗?4.2 MSI的频…

Linux系统编程(五)多线程

目录 一、基本知识点二、线程的编译三、 线程相关函数1. 线程的创建2. 线程的退出3. 线程的等待补充 四、综合举例 一、基本知识点 线程(Thread)是操作系统能够进行运算调度的最小单位。它被包含在进程之中,是进程中的实际运作单位。一个标准…

【4.vi编辑器使用(下)】

一、vi编辑器的光标移动 二、vi编辑器查找命令 1、命令::/string 查找字符串 n:继续查找 N:反向继续查找 /^the 查找以the开头的行 /end 查找以 查找以 查找以结尾的行 三、vi编辑器替换命令 1、语法: : s[范围,范围]str1/str2[g] g表示全…

可视化大屏:随意堆数据,错!要主次分明、重点突出,动静结合。

可视化大屏是一种展示数据的方式,它的设计应该遵循一些原则,以确保信息的传递和理解效果最佳。以下是一些关键点,可以帮助设计出主次分明、重点突出、动静结合的可视化大屏: 定义目标和重点: 在开始设计可视化大屏之前…

C语言数据结构堆排序、向上调整和向下调整的时间复杂度的计算、TopK问题等的介绍

文章目录 前言一、堆排序1. 排升序(1). 建堆(2). 排序 2. 拍降序(1). 建堆(2). 排序 二、建堆时间复杂度的计算1. 向上调整时间复杂度2. 向下调整时间复杂度 三、TopK问题总结 前言 …

Java事务入门:从基础概念到初步实践

Java事务入门:从基础概念到初步实践 引言1. Java事务基础概念1.1 什么是事务?1.2 为什么需要事务? 2. Java事务管理2.1 JDBC 的事务管理2.2 Spring 事务管理2.2.1 Spring JDBC2.2.1.1 添加 Spring 配置2.2.1.2 添加业务代码并测试验证 2.2.2…