文章目录
- 基本概念
- 模型
- 小结
基本概念
我们可以用独立学习得到的h组不同的 线性投影来变换查询、键和值。 然后,这h组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这h个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力。对于h个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。
模型
每个注意力头 h i h_i hi的计算公式为
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv,
其中q-查询、k-键、v-值。 W i ( q ) W_i^(q) Wi(q)为q通过全连接层后得到的参数、 W i ( k ) W_i^(k) Wi(k)为k通过全连接层后得到的参数、 W i ( v ) W_i^(v) Wi(v)为v通过全连接层后得到的参数。
f f f为注意力汇聚函数,f内的注意力评分函数可以是加性注意力、缩放点击注意力。
多头注意力的输出需要经过另一个线性转换, 它对应着h个头连结后的结果,因此其可学习参数是 W o W_o Wo
W o [ h 1 ⋮ h h ] ∈ R p o . \begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split} Wo h1⋮hh ∈Rpo.
基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。
小结
-
多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
-
基于适当的张量操作,可以实现多头注意力的并行计算。