transformer初探
- self-attention
- multihead-attention
- encoder
- decoder
self-attention
其实就是三个矩阵, W q W_q Wq、 W k W_k Wk、 W v W_v Wv,这三个矩阵就是需要训练的参数。分别得到每个token对应的 q q q k k k v v v,其中 q q q k k k 用来计算每个token之间的相似度,这里一般称为attention scores,然后通过一个Soft-max作一个norm。
拿到attention scores以后呢,既然已经知道了token的之间的“关联性”,再分别和 v v v作一个简单的加权求和,最后得到attention以后的输出 b b b
值得注意的是,上述操作是可以通过矩阵表示的,如下所示
multihead-attention
其实就是把前面一小节得到的 q q q k k k v v v 作一个拆分,每一个都拆成 n n n份,其中 n n n是head的数量。在 q q q k k k v v v的第 i i i个head中,都只与对应head作计算,然后将结果拼接起来就好。
encoder
encoder输入还会考虑一个位置编码,一起嵌入到Embedding表示后的token中。
整个计算过程也很直观
decoder
这里有一个很关键的点是,在encoder中只有self-attention,因为是一次性输入所有的token,计算每个token之间的关联性,得到一个编码后的输出。但是decoder是一个一个输入,每输入一个产生一个输出,虽然说这一步也可以用矩阵并行计算,其原理就是masked-attention。计算 b 1 b^1 b1的时候,我们只考虑 a 1 a^1 a1,计算 b 2 b^2 b2的时候,我们考虑 a 1 a^1 a1和 a 2 a^2 a2,依此类推。实现原理其实就是一个mask矩阵。
值得注意的是,在decoder中mask-attention后的输出,还会和encoder的输出再作一次attention,这被称为cross-attention。self-attention是同一个序列计算得分,而cross-attention是两个不同序列计算得分。
总结一下,encoder输入src序列,docoder输入target序列,最后将decoder的输出和target序列作一个cross entropy,优化的目标就是两个分布越接近越好。其实这一步可以看成一个分类,而类别就是词汇表的单词的总数?