[论文地址] [代码] [ICLR 22]
阅前须知:本博文可能有描述不准确/过度简化/出错的地方,仅供参考。
网络结构
其中,原有模型的参数是直接冻结的,可训练参数只有额外引入的LoRA参数(由nn.Parameter实现)。
模型微调的本质
记网络原有预训练参数为 W 0 ∈ R d × k W_0 \in R^{d \times k} W0∈Rd×k。在新的下游任务微调后,参数变为 W ∈ R d × k W \in R^{d \times k} W∈Rd×k。可以发现,参数的变化量 Δ W = W − W 0 \Delta W = W - W_0 ΔW=W−W0。换而言之,有: W = W 0 + Δ W W=W_0+\Delta W W=W0+ΔW 也就是说,对模型微调,其实可以将原有参数 W 0 W_0 W0直接给冻结了,只学这个变化量 Δ W = W − W 0 \Delta W = W - W_0 ΔW=W−W0即可。
为什么要进行低秩分解
LoRA文中指出,现有的预训练模型通常是过参数化的(the learned over-parametrized models in fact reside on a low intrinsic dimension),在对这些模型进行微调时,参数的更新主要在低维子空间中。换而言之,很多高维子空间的参数在微调前后根本就没动。基于这一点,微调所学的 Δ W \Delta W ΔW其实也就不需要那么高的维度(秩),我们可以将其降低到一个更低的维度进行优化。当然从这里也可以注意到,如果参数的更新也会大量发生在高维子空间中,此时进行低秩分解会遗漏信息,导致LoRA失效。
如何理解低维子空间/高维子空间特征
这里笔者给出一个可能不正确的类比。比如在计算机视觉中,无论是做分割,检测,医学等各种不同下游任务,都可以基于ImageNet上的预训练模型(如ResNet)进行微调。预训练模型中的纹理,边缘,轮廓等特征,一般是无论做哪种任务都需要的,那么这种任务无关特征就类似于上面所提到的高维子空间特征,在下游任务微调时基本上不发生变化。反之,对于一些下游任务中自有的先验特征(比如特有的光照条件,目标位置分布),则可以被视为上面所提到的低维子空间特征。模型想要刷点到SOTA则必须对这些任务相关特征进行有效的利用。
以数学形式描述低秩分解
LoRA将参数变化量矩阵 Δ W \Delta W ΔW分解成了两个更低秩的矩阵相乘,有: Δ W = B A \Delta W=BA ΔW=BA其中 B ∈ R d × r B \in R^{d \times r} B∈Rd×r, A ∈ R r × k A \in R^{r \times k} A∈Rr×k。
为什么矩阵B被初始化为0,而矩阵A正常高斯初始化
这里讨论另外两种设置的缺点:
- 如果B,A全都初始化为0,那么缺点与深度网络全0初始化一样,很容易导致梯度消失(因为此时初始所有神经元的功能都是等价的)。
- 如果B,A全部高斯初始化,那么在网络训练刚开始就会有概率为得到一个过大的偏移值 Δ W \Delta W ΔW从而引入太多噪声,导致难以收敛。
因此,一部分初始为0,一部分正常初始化是为了在训练开始时维持网络的原有输出(初始偏移为0),但同时也保证在真正开始学习后能够更好的收敛。
低秩分解到底有多低
哪怕降到8也是高度可用的,甚至能降到1:
注意这里r=64时性能甚至降低了。按照先前的结论来解释,这是因为参数的更新大多在低秩空间内;使用高秩矩阵允许对高维空间进行更新,反而可能会导致额外的非必要参数变化(引入了噪声)。
LoRA最终被插入在网络的哪些地方
只加在了Self Attention层的Q,K,V,O矩阵上,其余部分诸如MLP等位置则没有添加。当然,后续也有一些实验[1]表明,在其他任务中只添加在Q和K上会更好,如下图所示。因此这也可以算实际应用LoRA中一个可调节的点了。
LoRA与Adapter的区别
其实从结构上讲,更早出现的Adapter也是引入了少量可训练参数,并且也具有先降维再升维的"BottleNeck"型结构,如下所示:
主要的区别个人认为有如下几点:
- 插入位置。LoRA是以残差连接的形式"并联"在Transformer的Q,K,V,O矩阵上,而Adapter是插入在Feed-forward Layer后面。
- 推理延迟。LoRA在训练完后其参数可以与原有预训练模型直接合并,变回单分支结构,不会引入额外的延迟;而Adapter由于引入了额外的串联网络层,因此会带来额外的延迟。
- 参数存储。使用LoRA进行微调,在训练完毕后只需要保存LoRA本身的参数;而使用Adapter则要保存整个原有模型的参数。
参考文献
[1] Customized Segment Anything Model for Medical Image Segmentation