目录
- 5.4 定位编辑法:ROME
- 5.4.1 知识存储位置
- 1)因果跟踪实验
- 2)阻断实验
- 5.4.2 知识存储机制
- 5.4.3 精准知识编辑
- 1)确定键向量
- 2)优化值向量
- 3)插入知识
5.4 定位编辑法:ROME
定位编辑:
-
首先定位知识存储在神经网络中的哪些参数中,
-
然后再针对这些定位到的参数进行精确的编辑。
ROME(Rank-One Model Editing)是其中的代表性方法。
.
5.4.1 知识存储位置
通过对知识进行定位,可以揭示模型内部的运作机制,这是理解和编辑模型的关键步骤。
ROME 通过因果跟踪实验和阻断实验发现知识存储于模型中间层的全连接前馈层。
.
1)因果跟踪实验
ROME通过因果跟踪实验探究模型中不同结构与知识在推理过程中的相关性。实验包含三个步骤:
-
正常推理:保存模型未受干扰时的内部状态,用于后续恢复。
-
干扰推理:干扰模型的所有内部状态,作为基准线。
-
恢复推理:逐步恢复内部状态,对比恢复前后的输出差异,评估每个模块与知识回忆的相关性。
最终目标是确定知识在模型中的具体位置。
实验中,每个知识被表示为知 识元组 t = (s, r, o),其中 s 为 主体, r 为 关系,o 为客体。输入问题为 q = (s, r),q^(i) 表示 q 的第 i 个 Token。我们期望模型在处理问题 q 时能 够输出对应的客体 o 作为答案。具体地,因果跟踪实验的步骤如下:
-
正常推理:输入问题 q=(s,r),让其预测出 o。此过程,保存模型内部所有模块的正常输出。
-
干扰推理:向 s 的嵌入层添加噪声,破坏输入向量,形成干扰状态。
-
恢复推理:在干扰状态下,逐个恢复输入问题中每个 Token q^(i) 的输出向量至“干净”状态,并记录恢复前后模型对答案预测概率的增量(称为模块的因果效应),用于评估各模块对知识回忆的贡献。
以问题“斑马的肤色是”为例,其因果跟踪过程如下:
图 5.12: 正常推理
图 5.13: 干扰推理
图 5.14: 恢复推理
ROME在1000个知识陈述上对三种模块进行因果跟踪实验,发现:
-
中间层Transformer在处理主体s的最后一个Token s⁻¹时,因果效应显著。
-
末尾层Transformer在处理输入问题q的最后一个Token q⁻¹时,因果效应也很强,但这在意料之中。
-
中间层Transformer在处理s⁻¹时的因果效应主要来自全连接前馈层。
-
注意力层主要对末尾层Transformer处理q⁻¹产生贡献。
基于这些发现,ROME认为模型中间层的全连接前馈层可能是模型中存储知识的关键位置。
.
2)阻断实验
为区分全连接前馈层和注意力层在 s^(−1) 处的因果效应中所起到的作用,并且验证全连接前馈层的主导性,ROME 对两种模型结构进行了阻断实验。
阻断实验原理
在恢复某一层Transformer处理s^(-1)的输出后,将后续的全连接前馈层(或注意力层)冻结为干扰状态,即隔离其计算,观察模型性能下降程度,从而明确各层的关键作用。
图 5.15: 阻断实验
实验分析
比较阻断前后的因果效应,ROME 发现:
-
如果没后续全连接前馈层的计算,中间层在处理 s^(−1) 时就会失去因果效应,而末尾层的因果效应几乎不受全连接前馈层缺失的影响。
-
而在阻断注意力层时,模型各层处理 s^(−1) 时的因果效应只有较小的下降。
.
基于上述,ROME 认为在大语言模型中:
-
知识存储于模型的中间层,其关键参数位于全连接前馈层,
-
而且特定中间层的全连接前馈层在处理主体的末尾 Token 时发生作用。
.
5.4.2 知识存储机制
基于在此之前研究成果,ROME 结合知识定位实验中的结论,推测知识以键值映射的形式等价地存储在任何一个中间层的全连接前馈层中,并对知识存储机制做出假设:
-
首先,起始的 Transformer 层中的注意力层收集主体 s 的信息,将其汇入至主体的最后一个 Token 的向量表示中。
-
接着,位于中间层的全连接前馈层对这个编码主体的向量表示进行查询,将查询到的相关信息融入残差流(Residual Stream)中。
-
最后,末尾的注意力层捕获并整理隐藏状态中的信息,以生成最终的输出。
.
5.4.3 精准知识编辑
与 T-Patcher 相似,ROME 同样将全连接前馈层视为键值存储体。不同的是:
-
T-patcher 将上投影矩阵的参数向量看作键向量,将下投影矩阵的参数向量看作值向量,
-
而 ROME 则是将下投影矩阵的输入向量看作键向量,将其输出向量看作值向量。
具体地,ROME 认为上投影矩阵 W f c W_{fc} Wfc 和激活函数 σ 能够计算出键向量 k∗,而下投影矩阵 W p r o j W_{proj} Wproj 会与键向量运算并输出值向量 v∗,类似信息的查询。
为了实现模型编辑,ROME 通过因果跟踪实验定位编辑位置,然后确定键向量,优化值向量,并通过插入新的键值对完成知识更新。其核心步骤包括:1. 确定键向量;2. 优化值向量;3. 插入知识。
图 5.16: ROME 模型编辑方法
.
1)确定键向量
首先,需要确定 s (−1) 在被编辑的全连接前馈层中的向量表示 k*。
键向量 k* **是通过将 s 输入模型并读取其在全连接前馈层激活函数后的向量表示来确定的。
为了提高泛化性,会在 s 前拼接随机的不同前缀文本,多次推理后计算平均向量作为 k*。
键向量的计算公式如下:
k ∗ = 1 N ∑ j = 1 N k ( x j + s ) k^* = \frac{1}{N} \sum_{j=1}^N k(x_j + s) k∗=N1j=1∑Nk(xj+s)
其中:
-
N 为样本数量,
-
j 为前缀文本索引,
-
x_j 为随机前缀文本,
-
k(x_j + s) 代表在拼接前缀文本 x_j 时,s 的末尾 Token 在被编辑的全连接前馈层中的激活函数输出,即下投影矩阵 W p r o j W_{proj} Wproj 的输入。
.
2)优化值向量
然后,需要确定一个值向量 v∗,作为下投影矩阵 W p r o j W_{proj} Wproj 与 k∗ 运算后的期望结果。ROME 通过优化全连接前馈层的输出向量获得 v∗。
训练过程中,ROME 通过设计损失函数 L(v) = L1(v) + L2(v) 以确保编辑的准确性和局部性,如图 5.18。其中 v 是优化变量,用于替换全连接前馈层的输出。
图 5.18: 优化值向量
损失函数 L ( v ) \mathcal{L}(v) L(v) 的公式如下:
L ( v ) = L 1 ( v ) + L 2 ( v ) \mathcal{L}(v) = \mathcal{L}_1(v) + \mathcal{L}_2(v) L(v)=L1(v)+L2(v)
L 1 ( v ) = 1 N ∑ j = 1 N − log P M ′ ( o ∣ x j + p ) \mathcal{L}_1(v) = \frac{1}{N} \sum_{j=1}^N -\log \mathbb{P}_{M'}(o \mid x_j + p) L1(v)=N1j=1∑N−logPM′(o∣xj+p)
L 2 ( v ) = D K L ( P M ′ ( x ∣ p ′ ) ∣ ∣ P M ( x ∣ p ′ ) ) \mathcal{L}_2(v) = D_{KL}(\mathbb{P}_{M'}(x \mid p') ||\mathbb{P}_M(x \mid p')) L2(v)=DKL(PM′(x∣p′)∣∣PM(x∣p′))
其中:
-
M 为原始模型;
-
M’ 为优化 v 时的模型;
-
o 为客体,即目标答案;
-
p 为所编辑的目标问题 prompt;
-
D K L D_{KL} DKL 为 KL 散度;
-
p’ 是有关 s 的含义的 prompt。
图 5.19: 值向量损失函数
如图 5.19, 在 L(v) 中:
-
为了确保准确性,L1(v) 旨在最大化 o 的概率,通过优化 v 使网络对所编辑的问题 prompt p 做出正确的预测,与计算 k∗ 时相同,也会在 p 之前拼接不同前缀文本;
-
为了确保局部性,L2(v) 在 p′ =“{s} 是”这种 prompt 下,最小化 M′ 与 M 输出的 KL 散度,以避免模型对 s 本身的理解发生偏移, 从而确保局部性。
.
3)插入知识
确定了知识在编辑位置的向量表示 k∗ 和 v∗ 之后,ROME 的目标是调整全连接前馈层中的下投影矩阵 W p r o j W_{proj} Wproj,使得 W p r o j k ∗ = v ∗ W_{proj} k^∗ = v^∗ Wprojk∗=v∗,从而将新知识插入到全连接前馈层中。
然而,在插入新知识的同时,需要尽量避免影响 W p r o j W_{proj} Wproj 中的原有信息。
过程可抽象为一个带约束的最小二乘问题,其形式如下:
确保最小影响: min ∥ W ^ K − V ∥ 确保最小影响:\quad \min \| \hat{W} K - V \| 确保最小影响:min∥W^K−V∥
满足 W p r o j k ∗ = v ∗ 关系: s.t. W ^ k ∗ = v ∗ 满足W_{proj} k^∗ = v^∗关系:\quad \text{s.t.} \quad \hat{W} k^* = v^* 满足Wprojk∗=v∗关系:s.t.W^k∗=v∗
该问题可推导出闭式解为:
W ^ = W + Λ ( C − 1 k ∗ ) T \hat{W} = W + \Lambda (C^{-1} k^*)^T W^=W+Λ(C−1k∗)T
其中:
-
Λ = v ∗ − W k ∗ ( C − 1 k ∗ ) T k ∗ \Lambda = \frac{v^* - W k^*}{(C^{-1} k^*)^T k^*} Λ=(C−1k∗)Tk∗v∗−Wk∗
-
W 为原始的权重矩阵
-
W ^ \hat{W} W^ 为更新后的权重矩阵
-
C = K K T C = K K^T C=KKT 是一个预先计算的常数,基于维基百科中的大量文本样本 k 的去中心化协方差矩阵进行估计
利用这一简代数方法,ROME 能直接插入代表知识元组的键值对 (k*, v*),实现对模型知识的精确编辑。
.
其他参考:【大模型基础_毛玉仁】系列文章
声明:资源可能存在第三方来源,若有侵权请联系删除!