Activation Beacon出自智源与人大在2024年1月放在arxiv上的论文《Long Context Compression with Activation Beacon》(v1版的题目:Soaring from 4K to 400K: Extending LLM’s Context with Activation Beacon)。它引入了Beacon token将上下文信息蒸馏到其激活(activations);在压缩时将文本切分成固定大小的块(chunk),并根据压缩比 α \alpha α进一步将chunk分成更小的单元,beacon token插入在每个单元后面;LLM每次编码一个chunk,在自注意力机制执行过程中将chunk的信息蒸馏到beacon token的激活信息(activation)中,逐步地对整个长文本完成压缩过程,论文实验结果表明此方法可以有效加速推理过程并节省KV cache内存占用。
实现思路
如论文图1所示意,对输入文本 X = [ x 1 , … , x n ] X = [x_1, \ldots, x_n] X=[x1,…,xn],将其划分为相同尺寸w(如1024)的chunk:
[ x 1 , … , x n ] → Partition [ X 1 , … X ⌈ n / w ⌉ ] , X i = [ x ( i − 1 ) w + 1 , … , x i w ] = [ x 1 i , … , x w i ] [x_1, \ldots, x_n] \xrightarrow{\text{Partition}} [X_1, \ldots X_{\lceil n/w \rceil}], X_i=[x_{(i-1)w+1}, \ldots,x_{iw}] = [x^i_1, \ldots, x^i_w] [x1,…,xn]Partition[X1,…X⌈n/w⌉],Xi=[x(i−1)w+1,…,xiw]=[x1i,…,xwi]
对每一个chunk X i X_i Xi,使用一个压缩比 α i \alpha_i αi(w可由 α i \alpha_i αi整除),即将chunk划分到大小为 α \alpha α的更细粒度单元,一组共 k i = w / α i k_i=w/\alpha_i ki=w/αi个beacon token: B i = [ ⟨ b ⟩ 1 i , … , ⟨ b ⟩ k i i ] B_i=[\langle \mathbf{b} \rangle^i_1, \ldots, \langle \mathbf{b} \rangle^i_{k_i}] Bi=[⟨b⟩1i,…,⟨b⟩kii]被交替地插入到这些单元后。
X i → Interleave B i X i ′ = [ x 1 i , … , x α i i , ⟨ b ⟩ 1 i , … , x w − α i + 1 i , … , x w i , ⟨ b ⟩ k i i ] X_i \xrightarrow{\text{Interleave} \ B_i} X^{\prime}_i = [x^i_1, \ldots, x^i_{\alpha_i}, \langle \mathbf{b} \rangle^i_1, \ldots, x^i_{w-\alpha_i +1}, \ldots, x^i_w, \langle \mathbf{b} \rangle^i_{k_i}] XiInterleave BiXi′=[x1i,…,xαii,⟨b⟩1i,…,xw−αi+1i,…,xwi,⟨b⟩kii]
LLM逐一地编码这些chunk,在自注意力机制过程中将每个chunk的信息压缩到beacon token的激活(activations)中,在编码了 X i ′ X^{\prime}_i Xi′后,将 X i X_i Xi的所有原始token(raw tokens)的激活信息给丢弃,但一直保留并累积beacon token B i B_i Bi的激活信息;在编码下一个chunk X i + 1 ′ X^{\prime}_{i+1} Xi+1′时,LLM将累积的beacon激活作为原始上下文 X ≤ i X_{\le i} X≤i的代理。
如论文图2所示,Activation Beacon与一般的LLM相比只做少许修改。对于第i个chunk X i ′ X^{\prime}_i Xi′,编码过程可以写作:
LLM ( ⟨ b ⟩ 1 i , … , ⟨ b ⟩ k i − 1 i − 1 ⏟ beacon activations accumulated from X < i ′ , x 1 i , … , x α i i , ⟨ b ⟩ 1 i , … , x w − α i + 1 i , … , x w i , ⟨ b ⟩ k i i ⏟ the current chunk X i ′ ) , \operatorname{LLM}(\underbrace{\langle\mathbf{b}\rangle_1^i, \ldots,\langle\mathbf{b}\rangle_{k_{i-1}}^{i-1}}_{\text {beacon activations accumulated from } X_{<i}^{\prime}}, \underbrace{x_1^i, \ldots, x_{\alpha_i}^i,\langle\mathbf{b}\rangle_1^i, \ldots, x_{w-\alpha_i+1}^i, \ldots, x_w^i,\langle\mathbf{b}\rangle_{k_i}^i}_{\text {the current chunk } X_i^{\prime}}), LLM(beacon activations accumulated from X<i′ ⟨b⟩1i,…,⟨b⟩ki−1i−1,the current chunk Xi′ x1i,…,xαii,⟨b⟩1i,…,xw−αi+1i,…,xwi,⟨b⟩kii),
也就是LLM的输入是前面chunk的激活累积和当前chunk需要被编码的token的混合物。设D表示LLM的隐藏层尺寸, H ∈ R ( w + k i ) × D \boldsymbol{H} \in \mathbb{R}^{(w+k_i) \times D} H∈R(w+ki)×D表示LLM任意层的self attention的输入隐藏状态。我们会区分raw token和beacon token:
I r = { j ∣ x j i ≠ ⟨ b ⟩ } , I b = { j ∣ x j i = ⟨ b ⟩ } ; H r = H [ I r ] , H b = H [ I b ] . \mathbb{I}^r=\left\{j \mid x_j^i \neq\langle\mathbf{b}\rangle\right\}, \quad \mathbb{I}^b=\left\{j \mid x_j^i=\langle\mathbf{b}\rangle\right\} ; \quad \boldsymbol{H}^r=\boldsymbol{H}\left[\mathbb{I}^r\right], \quad \boldsymbol{H}^b=\boldsymbol{H}\left[\mathbb{I}^b\right] . Ir={j∣xji=⟨b⟩},Ib={j∣xji=⟨b⟩};Hr=H[Ir],Hb=H[Ib].
将隐状态变成query, key, value:
Q r = W Q r H r , K r = W K r H r , V r = W V r H r , Q b = W Q b H b , K b = W K b H b , V b = W V b H b , \begin{array}{lll} \boldsymbol{Q}^r=\boldsymbol{W}_Q^r \boldsymbol{H}^r, & \boldsymbol{K}^r=\boldsymbol{W}_K^r \boldsymbol{H}^r, & \boldsymbol{V}^r=\boldsymbol{W}_V^r \boldsymbol{H}^r, \\ \boldsymbol{Q}^b=\boldsymbol{W}_Q^b \boldsymbol{H}^b, & \boldsymbol{K}^b=\boldsymbol{W}_K^b \boldsymbol{H}^b, & \boldsymbol{V}^b=\boldsymbol{W}_V^b \boldsymbol{H}^b, \end{array} Qr=WQrHr,Qb=WQbHb,Kr=WKrHr,Kb=WKbHb,Vr=WVrHr,Vb=WVbHb,
上式中 W ∗ r \boldsymbol{W}^r_* W∗r是LLM原来的投影矩阵, W ∗ b \boldsymbol{W}^b_* W∗b是新引入的只处理beacon token的投影矩阵。再将raw token和beacon token的query/key/value状态来得到 Q , K , V ∈ R ( w + k i ) × D \boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V} \in \mathbb{R}^{(w+k_i) \times D} Q,K,V∈R(w+ki)×D
Q [ I r ] = Q r , Q [ I b ] = Q b , K [ I r ] = K r , K [ I b ] = K b , V [ I r ] = V r , V [ I b ] = V b \boldsymbol{Q}\left[\mathbb{I}^r\right]= \boldsymbol{Q}^r,\boldsymbol{Q}\left[\mathbb{I}^b\right]= \boldsymbol{Q}^b, \quad \boldsymbol{K}\left[\mathbb{I}^r\right]= \boldsymbol{K}^r,\boldsymbol{K}\left[\mathbb{I}^b\right]= \boldsymbol{K}^b, \quad \boldsymbol{V}\left[\mathbb{I}^r\right]= \boldsymbol{V}^r,\boldsymbol{V}\left[\mathbb{I}^b\right]= \boldsymbol{V}^b Q[Ir]=Qr,Q[Ib]=Qb,K[Ir]=Kr,K[Ib]=Kb,V[Ir]=Vr,V[Ib]=Vb
最后,用标准方法计算self attention:
A = softmax ( mask ( Q { K a c ; K } T D ) ) , V = A { V a c ; V } \boldsymbol{A} = \text{softmax}\left(\text{mask} \left( \frac{\boldsymbol{Q}\{\boldsymbol{K}^{ac}; \boldsymbol{K} \}^T }{\sqrt{D}} \right)\right), \quad \boldsymbol{V} = \boldsymbol{A}\{\boldsymbol{V}^{ac};\boldsymbol{V} \} A=softmax(mask(DQ{Kac;K}T)),V=A{Vac;V}
上式中的 { ⋅ ; ⋅ } \{ \cdot ; \cdot\} {⋅;⋅}表示矩阵连接, K a c , V a c ∈ R m i − 1 × D \boldsymbol{K}^{ac}, \boldsymbol{V}^{ac} \in \mathbb{R}^{m_{i-1} \times D} Kac,Vac∈Rmi−1×D是从之前的chunk累积得到的beacon token的激活参数, m i − 1 = ∑ j = 1 i − 1 k j m_{i-1} = \sum^{i-1}_{j=1} k_j mi−1=∑j=1i−1kj, mask就是causal attention mask。在self attention过程中,所有的token与其他token进行交互,使得beacon tokens的key和value( K b , V b \boldsymbol{K}^{b}, \boldsymbol{V}^{b} Kb,Vb)蒸馏了 X i X_i Xi的上下文信息,它们会增量累积:
K a c = { K a c ; K b } , V a c = { V a c ; V b } \boldsymbol{K}^{ac} = \{\boldsymbol{K}^{ac}; \boldsymbol{K}^{b}\}, \boldsymbol{V}^{ac} = \{\boldsymbol{V}^{ac};\boldsymbol{V}^{b} \} Kac={Kac;Kb},Vac={Vac;Vb}
### 下面代码是activation beacon在实现时,interleave插入beacon token的代码,位于model_beacon.py的Memory类的_step函数input_len = input_ids.shape[1]if beacon_size > 0:# insert beacon tokens in between raw tokens,对应论文中的式(2)input_ids_with_beacons = input_ids.new_full((input_ids.shape[0], input_len + beacon_size), self.beacon_token.item())raw_token_indices = torch.arange(input_ids_with_beacons.shape[1], device=input_ids.device)interleave_start_idx = compression_ratio - self._interleave_remainderraw_token_indices = raw_token_indices[raw_token_indices % (compression_ratio + 1) != interleave_start_idx].unsqueeze(0).expand_as(input_ids)input_ids_with_beacons = input_ids_with_beacons.scatter(dim=1, index=raw_token_indices, src=input_ids)input_ids = input_ids_with_beacons# attention mask## beacon token是参与attention的,所以默认值为1attention_mask_with_beacons = attention_mask.new_full((attention_mask.shape[0], attention_mask.shape[1] + beacon_size), 1)attention_mask_with_beacons = attention_mask_with_beacons.scatter(dim=1, index=raw_token_indices, src=attention_mask)attention_mask = attention_mask_with_beacons# labelsif labels is not None:## beacon token不参与loss的计算,所以标签为-100labels_with_beacons = labels.new_full((labels.shape[0], labels.shape[1] + beacon_size), -100)labels_with_beacons = labels_with_beacons.scatter(dim=1, index=raw_token_indices, src=labels)labels = labels_with_beacons
训练过程
Activation Beacon的学习目标是在当前chunk上下文和之前压缩信息的条件下提高生成质量,损失函数如下:
min Θ b . ∑ i = 2 ⌈ N / w ⌉ ∑ j = 1 w Pr ( x j i ∣ ⟨ b ⟩ 1 1 , … , ⟨ b ⟩ k i − 1 i − 1 , x 1 i , … x j − 1 i ; Θ , Θ b ) . \min _{\boldsymbol{\Theta}^b} . \sum_{i=2}^{\lceil N / w\rceil} \sum_{j=1}^w \operatorname{Pr}\left(x_j^i \mid\langle\mathbf{b}\rangle_1^1, \ldots,\langle\mathbf{b}\rangle_{k_{i-1}}^{i-1}, x_1^i, \ldots x_{j-1}^i ; \mathbf{\Theta}, \boldsymbol{\Theta}^b\right) . Θbmin.i=2∑⌈N/w⌉j=1∑wPr(xji∣⟨b⟩11,…,⟨b⟩ki−1i−1,x1i,…xj−1i;Θ,Θb).
上式中 Θ \mathbf{\Theta} Θ是LLM的参数,在训练过程中被冻结, Θ b \mathbf{\Theta^b} Θb是每一层中beacon token对应的投影矩阵 W ∗ b \boldsymbol{W}^b_* W∗b和beacon token 的embedding e ⟨ b ⟩ \mathbf{e}_{\langle b \rangle} e⟨b⟩(所有beacon token使用共享embedding ),训练时beacon token不参与损失计算(标签被设置为-100)因为它们仅用作压缩。
训练时第i个chunk的压缩比 α i \alpha_i αi是随机地从{2, 4, 8, 16, 32}中选取的,意在让模型灵活地支持不同的压缩粒度。而在推理时可以根据下游任务选择一个压缩比并应用到所有chunk。
训练过程分为预训练和微调,消融实验表明两个阶段对模型效果都有提升。
注意,activation beacon默认的方式是将其交替地插入在原始上下文中(代码中的interleave),论文做消融实验时尝试将beacon token全部放在chunk的最后时效果是会下降的(代码中的append)。
注:Activation Beacon与MemoRAG是同一个团队出的,理解这篇思路之后,就能更好地理解MemoRAG的记忆模型了。(对比这两篇论文对应的记忆模型的代码,几乎是一样的,有点奇怪为什么memorag没有引用这篇文章,也没有对代码做说明。因为不理解memorag的记忆模型的代码,通过搜索关键字beacon搜到了这篇论文)。