重新审视MHA与Transformer

本文将基于PyTorch源码重新审视MultiheadAttention与Transformer。事实上,早在一年前博主就已经分别介绍了两者:各种注意力机制的PyTorch实现、从零开始手写一个Transformer,但当时的实现大部分是基于d2l教程的,这次将基于PyTorch源码重新实现一遍。

目录

  • 1. MultiheadAttention
    • 1.1 思路
    • 1.2 源码
    • 1.3 极简版MHA(面试用)
  • 2. Transformer
  • 3. Q&A
    • 1. MHA的参数量?时间复杂度?FLOPs?

1. MultiheadAttention

1.1 思路

回顾多头注意力,其公式如下:

MHA ( Q , K , V ) = Concat ( head 1 , ⋯ , head h ) W O head i = Attn ( Q W i Q , K W i K , V W i V ) \text{MHA}(Q,K,V)=\text{Concat}(\text{head}_1,\cdots,\text{head}_h)W^O \\ \text{head}_i=\text{Attn}(QW_i^Q,KW_i^K,VW_i^V) MHA(Q,K,V)=Concat(head1,,headh)WOheadi=Attn(QWiQ,KWiK,VWiV)

其中 W i Q ∈ R d m o d e l × d k W_i^Q\in \mathbb{R}^{d_{model}\times d_k} WiQRdmodel×dk W i K ∈ R d m o d e l × d k W_i^K\in \mathbb{R}^{d_{model}\times d_k} WiKRdmodel×dk W i V ∈ R d m o d e l × d v W_i^V\in \mathbb{R}^{d_{model}\times d_v} WiVRdmodel×dv W O ∈ R h d v × d m o d e l W^O\in \mathbb{R}^{hd_v\times d_{model}} WORhdv×dmodel,且 d k = d v = d m o d e l / h d_k=d_v=d_{model}/h dk=dv=dmodel/h

如果记 d h e a d = d m o d e l / h d_{head}=d_{model}/h dhead=dmodel/h,则 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV 的形状均为 ( d m o d e l , d h e a d ) (d_{model},d_{head}) (dmodel,dhead) W O W^O WO 的形状为 ( d m o d e l , d m o d e l ) (d_{model},d_{model}) (dmodel,dmodel)

先不考虑batch和mask的情形,在只有一个头的情况下( h = 1 h=1 h=1),MHA的计算方式为

class MHA(nn.Module):def __init__(self, d_model):super().__init__()self.w_q = nn.Parameter(torch.empty(d_model, d_model))self.w_k = nn.Parameter(torch.empty(d_model, d_model))self.w_v = nn.Parameter(torch.empty(d_model, d_model))self.w_o = nn.Parameter(torch.empty(d_model, d_model))self._reset_parameters()def _reset_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self, query, key, value):"""Args:query: (n, d_model),n是query的个数,m是key-value的个数key: (m, d_model)value: (m, d_model)"""q = query @ self.w_qk = key @ self.w_kv = value @ self.w_vattn_logits = q @ k.transpose(0, 1) / math.sqrt(q.size(1))  # attn_logits: (n, m)attn_probs = F.softmax(attn_logits, dim=-1)attn_output = attn_probs @ v  # attn_output: (n, d_model)return attn_output, attn_probs

现在考虑 h = 2 h=2 h=2 的情形,此时一共需要 3 ⋅ 2 + 1 = 7 3\cdot2+1=7 32+1=7 个参数矩阵

class MHA(nn.Module):def __init__(self, d_model):super().__init__()self.w_q_1 = nn.Parameter(torch.empty(d_model, d_model // 2))self.w_k_1 = nn.Parameter(torch.empty(d_model, d_model // 2))self.w_v_1 = nn.Parameter(torch.empty(d_model, d_model // 2))self.w_q_2 = nn.Parameter(torch.empty(d_model, d_model // 2))self.w_k_2 = nn.Parameter(torch.empty(d_model, d_model // 2))self.w_v_2 = nn.Parameter(torch.empty(d_model, d_model // 2))self.w_o = nn.Parameter(torch.empty(d_model, d_model))self._reset_parameters()def _reset_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self, query, key, value):"""Args:query: (n, d_model),n是query的个数,m是key-value的个数key: (m, d_model)value: (m, d_model)"""q_1 = query @ self.w_q_1k_1 = key @ self.w_k_1v_1 = value @ self.w_v_1q_2 = query @ self.w_q_2k_2 = key @ self.w_k_2v_2 = value @ self.w_v_2attn_logits_1 = q_1 @ k_1.transpose(0, 1) / math.sqrt(q_1.size(1))attn_probs_1 = F.softmax(attn_logits_1, dim=-1)attn_output_1 = attn_probs_1 @ v_1attn_logits_2 = q_2 @ k_2.transpose(0, 1) / math.sqrt(q_2.size(1))attn_probs_2 = F.softmax(attn_logits_2, dim=-1)attn_output_2 = attn_probs_2 @ v_2attn_output = torch.cat([attn_output_1, attn_output_2], dim=-1) @ self.w_o  # attn_output: (n, d_model)attn_probs = torch.stack([attn_probs_1, attn_probs_2], dim=0)  # attn_probs: (2, n, m),其中2是头数return attn_output, attn_probs

可以看到代码量已经增加了不少,如果扩展到 h h h 个头的情形,则需要 3 h + 1 3h+1 3h+1 个参数矩阵。手动去一个个声明显然不现实,因为 h h h 是动态变化的,而用for循环创建又略显笨拙,有没有更简便的方法呢?

在上面的代码中,我们用小写 q q q 来代表查询 Q Q Q 经过投影后的结果( k , v k,v k,v 同理),即

q i = Q W i Q , i = 1 , 2 , ⋯ , h q_i=QW_i^Q,\quad i =1,2,\cdots,h qi=QWiQ,i=1,2,,h

其中 Q Q Q 的形状为 ( n , d m o d e l ) (n,d_{model}) (n,dmodel) q i q_i qi 的形状为 ( n , d h e a d ) (n,d_{head}) (n,dhead),且有

h e a d i = softmax ( q i k i T d h e a d ) v i head_i=\text{softmax}\left(\frac{q_ik_i^{T}}{\sqrt{d_{head}}}\right)v_i headi=softmax(dhead qikiT)vi

注意到

[ q 1 , q 2 , ⋯ , q h ] = Q [ W 1 Q , W 2 Q , ⋯ , W h Q ] (1) [q_1,q_2,\cdots,q_h]=Q[W_1^Q,W_2^Q,\cdots,W_h^Q]\tag{1} [q1,q2,,qh]=Q[W1Q,W2Q,,WhQ](1)

如果记 q ≜ [ q 1 , q 2 , ⋯ , q h ] q\triangleq [q_1,q_2,\cdots,q_h] q[q1,q2,,qh] W Q ≜ [ W 1 Q , W 2 Q , ⋯ , W h Q ] W^Q\triangleq [W_1^Q,W_2^Q,\cdots,W_h^Q] WQ[W1Q,W2Q,,WhQ],则 W Q W^Q WQ 的形状为 ( d m o d e l , d m o d e l ) (d_{model},d_{model}) (dmodel,dmodel) h h h 无关 q q q 的形状为 ( n , d m o d e l ) (n,d_{model}) (n,dmodel)。这样一来,我们就不需要一个个声明 W i Q W_i^Q WiQ 了,并且可以一次性存储所有的 q i q_i qi

要计算 h e a d 1 head_1 head1,我们需要能够从 q q q 中取出 q 1 q_1 q1 k , v k,v k,v 同理),所以我们期望 q q q 的形状是 ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead),从而 q [ 1 ] q[1] q[1] 就是 q 1 q_1 q1(这里下标从 1 1 1 开始)。

📝 当然也可以是 ( n , h , d h e a d ) (n,h,d_{head}) (n,h,dhead) 等形状,但必须要确保形状里含且只含这三个数字。之所以把 h h h 放在第一个维度是为了方便索引和后续计算。

同理可知 k , v k,v k,v 的形状均为 ( h , m , d h e a d ) (h,m,d_{head}) (h,m,dhead)。我们可以视 h h h 所在的维度为批量维,从而可以执行批量乘法 torch.bmm 来一次性算出 h h h 个头的结果。

q = torch.randn(h, n, d_head)
k = torch.randn(h, m, d_head)
v = torch.randn(h, m, d_head)# @和torch.bmm的效果相同,但写法更简洁
attn_logits = q @ k.transpose(1, 2) / math.sqrt(q.size(2))
attn_probs = F.softmax(attn_logits, dim=-1)
attn_output = attn_probs @ v  # attn_output: (h, n, d_head)

h h h 个头的结果存储在形状为 ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead) 的张量中,那我们如何把这 h h h 个结果concat在一起呢?注意到我们实际上是将 h h h 个形状为 ( n , d h e a d ) (n,d_{head}) (n,dhead) 的张量横向concat为一个形状为 ( n , d m o d e l ) (n,d_{model}) (n,dmodel) 的张量,因此只需执行如下的形状变换:

( h , n , d h e a d ) → ( n , h , d h e a d ) → ( n , h ⋅ d h e a d ) = ( n , d m o d e l ) (2) (h,n,d_{head})\to(n,h,d_{head})\to(n,h\cdot d_{head})=(n,d_{model}) \tag{2} (h,n,dhead)(n,h,dhead)(n,hdhead)=(n,dmodel)(2)

n = attn_output.size(1)
attn_output = attn_output.transpose(0, 1).reshape(n, -1)

⚠️ 注意,切勿直接将 ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead) reshape成 ( n , d m o d e l ) (n,d_{model}) (n,dmodel)

之前我们只讨论了 q q q 的形状应当是 ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead),但并没有讨论它是如何变换得来的。这是因为, Q Q Q 在经过投影后得到的 q q q 只具有 ( n , d m o d e l ) (n,d_{model}) (n,dmodel) 的形状,要进行形状变换,一种做法是对 q q q 沿纵向切 h h h 刀再堆叠起来,这样从直观上来看也比较符合公式 ( 1 ) (1) (1)

q = torch.randn(n, d_model)
q = torch.stack(torch.split(q, d_head, dim=-1), dim=0)

但由于 W Q W^Q WQ 初始时是随机的,所以我们不需要严格按照公式 ( 1 ) (1) (1) 那样操作,直接执行 ( 2 ) (2) (2) 的逆变换即可

( n , d m o d e l ) = ( n , h ⋅ d h e a d ) → ( n , h , d h e a d ) → ( h , n , d h e a d ) (n,d_{model})=(n,h\cdot d_{head})\to(n,h,d_{head})\to(h,n,d_{head}) (n,dmodel)=(n,hdhead)(n,h,dhead)(h,n,dhead)

现考虑有batch的情形,设批量大小为 b b b,则 Q Q Q 的形状为 ( b , n , d m o d e l ) (b,n,d_{model}) (b,n,dmodel) ( n , b , d m o d e l ) (n,b,d_{model}) (n,b,dmodel),具体是哪一个要看 batch_first 是否为 True。接下来均假设 batch_first = False

在以上的假设下, q q q 的形状也为 ( n , b , d m o d e l ) (n,b,d_{model}) (n,b,dmodel),我们将 b b b h h h 看成同一维度(都是批量维),从而 ( 2 ) (2) (2) 式改写为

( n , b , d m o d e l ) → ( n , b , h , d h e a d ) → ( n , b ⋅ h , d h e a d ) → ( b ⋅ h , n , d h e a d ) (n,b,d_{model})\to(n,b,h,d_{head})\to(n,b\cdot h,d_{head})\to(b\cdot h,n,d_{head}) (n,b,dmodel)(n,b,h,dhead)(n,bh,dhead)(bh,n,dhead)

关于 key_padding_maskattn_mask 这里不再介绍,如有需要可阅读博主之前的文章,这里主要讲解如何合并两种mask。

前者的形状为 ( b , m ) (b,m) (b,m),用来mask掉key中的 [PAD],防止query注意到它。而后者的形状可以是 ( n , m ) (n,m) (n,m) 也可以是 ( b ⋅ h , n , m ) (b\cdot h,n,m) (bh,n,m)。在实际合并两种mask的时候,我们均需要按照 ( b ⋅ h , n , m ) (b\cdot h,n,m) (bh,n,m) 这个形状去计算。也就是说,如果是 key_padding_mask,我们需要进行形状变换 ( b , m ) → ( b , 1 , 1 , m ) → ( b , h , 1 , m ) → ( b ⋅ h , 1 , m ) (b,m)\to(b,1,1,m)\to(b,h,1,m)\to(b\cdot h,1,m) (b,m)(b,1,1,m)(b,h,1,m)(bh,1,m);如果是 attn_mask,我们需要进行形状变换 ( n , m ) → ( 1 , n , m ) (n,m)\to(1,n,m) (n,m)(1,n,m)

1.2 源码

本节将遵循以下记号:

记号说明
b b bbatch size
h h hnum heads
d d dhead dim
n n nnum queries
m m mnum key-value pairs

首先实现一个MHA的基类:

class MultiheadAttentionBase_(nn.Module):def __init__(self, embed_dim, num_heads, dropout=0., bias=True):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.dropout = dropoutself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dimself.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))if bias:self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))else:self.register_parameter('in_proj_bias', None)self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)self._reset_parameters()def _reset_parameters(self):nn.init.xavier_uniform_(self.in_proj_weight)if self.in_proj_bias is not None:nn.init.constant_(self.in_proj_bias, 0.)nn.init.constant_(self.out_proj.bias, 0.)def forward(self,query,key,value,key_padding_mask,attn_mask,need_weights=True,):"""Args:query: (n, b, h * d)key: (m, b, h * d)value: (m, b, h * d)key_padding_mask: (b, m), bool typeattn_mask: (n, m) or (b * h, n, m), bool typeReturns:attn_output: (n, b, h * d)attn_weights: (b, h, n, m)"""w_q, w_k, w_v = self.in_proj_weight.chunk(3)if self.in_proj_bias is not None:b_q, b_k, b_v = self.in_proj_bias.chunk(3)else:b_q = b_k = b_v = Noneq = F.linear(query, w_q, b_q)k = F.linear(key, w_k, b_k)v = F.linear(value, w_v, b_v)b, h, d = q.size(1), self.num_heads, self.head_dimq, k, v = map(lambda x: x.reshape(-1, b, h, d), [q, k, v])attn_mask = self.merge_masks(key_padding_mask, attn_mask, q)attn_output, attn_weights = self.attention(q, k, v, attn_mask, out_proj=self.out_proj, dropout=self.dropout, training=self.training)if not need_weights:attn_weights = Nonereturn attn_output, attn_weightsdef merge_masks(self, key_padding_mask, attn_mask, q):"""Args:key_padding_mask: (b, m), bool typeattn_mask: (n, m) or (b * h, n, m), bool typeq: only used to confirm the dtype of attn_maskReturns:attn_mask: (b * h, n, m), float type"""assert key_padding_mask is not None and key_padding_mask.dtype == torch.boolb, m = key_padding_mask.size()key_padding_mask = key_padding_mask.view(b, 1, 1, m).expand(-1, self.num_heads, -1, -1).reshape(b * self.num_heads, 1, m)if attn_mask is not None:assert attn_mask.dtype == torch.boolif attn_mask.dim() == 2:attn_mask = attn_mask.unsqueeze(0)attn_mask = attn_mask.logical_or(key_padding_mask)else:attn_mask = key_padding_maskattn_mask = torch.zeros_like(attn_mask, dtype=q.dtype).masked_fill_(attn_mask, -1e28)return attn_maskdef attention(self, q, k, v, attn_mask, out_proj, dropout, training):"""Args:q: (n, b, h, d)k: (m, b, h, d)v: (m, b, h, d)attn_mask: (b * h, n, m), float typeout_proj: nn.Linear(h * d, h * d)Returns:attn_output: (n, b, h * d), is the result of concating h heads.attn_weights: (b, h, n, m)"""raise NotImplementedError

接下来,只需要重写 attention 方法就可以实现普通版的MHA了

class MultiheadAttention(MultiheadAttentionBase_):def attention(self, q, k, v, attn_mask, out_proj, dropout, training):if not training:dropout = 0n, b, h, d = q.size()q, k, v = map(lambda x: x.reshape(-1, b * h, d).transpose(0, 1), [q, k, v])attn_logits = q @ k.transpose(-2, -1) / math.sqrt(d) + attn_maskattn_probs = F.softmax(attn_logits, dim=-1)attn_weights = F.dropout(attn_probs, p=dropout)attn_output = attn_weights @ vattn_output = attn_output.transpose(0, 1).reshape(n, b, h * d)attn_output = out_proj(attn_output)return attn_output, attn_weights

1.3 极简版MHA(面试用)

不少面试会让现场手写MHA,这里提供了一份模版,略去了很多细节。

相比原版,极简版做了如下改动:

  • 略去了参数初始化。
  • 去掉了mask
class MultiheadAttention(nn.Module):def __init__(self, embed_dim, num_heads, dropout=0., bias=True):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.dropout = nn.Dropout(dropout)self.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dimself.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))if bias:self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))else:self.register_parameter('in_proj_bias', None)self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)def forward(self, query, key, value):"""Args:query: (n, b, h * d)key: (m, b, h * d)value: (m, b, h * d)"""w_q, w_k, w_v = self.in_proj_weight.chunk(3)if self.in_proj_bias is not None:b_q, b_k, b_v = self.in_proj_bias.chunk(3)else:b_q = b_k = b_v = Noneq, k, v = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)b, h, d = q.size(1), self.num_heads, self.head_dimq, k, v = map(lambda x: x.reshape(-1, b * h, d).transpose(0, 1), [q, k, v])attn_logits = q @ k.transpose(-2, -1) / math.sqrt(d)attn_probs = F.softmax(attn_logits, dim=-1)attn_weights = self.dropout(attn_probs)attn_output = attn_weights @ vattn_output = attn_output.transpose(0, 1).reshape(-1, b, h * d)attn_output = self.out_proj(attn_output)return attn_output, attn_weights

注意,如果尝试直接输出的话,会得到一堆 nan,这是因为没有xavier初始化,需要 _reset_parameters() 一下。

具体需要哪种mask可根据面试官的要求去实现。

2. Transformer

接下来基于PyTorch官方的MHA来实现Transformer。

首先需要实现一个基础函数,它可以用来复制一个 Module N次。

def _get_clones(module, n):return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

EncoderLayer的实现

class TransformerEncoderLayer(nn.Module):def __init__(self,d_model,n_head,d_ffn,dropout=0.1,activation=F.relu,norm_first=False,):super().__init__()self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, dropout=dropout)self.dropout1 = nn.Dropout(dropout)self.linear1 = nn.Linear(d_model, d_ffn)self.activation = activationself.dropout2 = nn.Dropout(dropout)self.linear2 = nn.Linear(d_ffn, d_model)self.dropout3 = nn.Dropout(dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm_first = norm_firstdef forward(self, src, src_mask, src_key_padding_mask):x = srcif self.norm_first:x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)x = x + self._ff_block(self.norm2(x))else:x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))x = self.norm2(x + self._ff_block(x))return xdef _sa_block(self, x, attn_mask, key_padding_mask):x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]return self.dropout1(x)def _ff_block(self, x):x = self.linear2(self.dropout2(self.activation(self.linear1(x))))return self.dropout3(x)

这里的 norm_first 用来决定是Pre-LN还是Post-LN,如下图所示

DecoderLayer的实现

class TransformerDecoderLayer(nn.Module):def __init__(self,d_model,n_head,d_ffn,dropout=0.1,activation=F.relu,norm_first=False,):super().__init__()self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, dropout=dropout)self.dropout1 = nn.Dropout(dropout)self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, dropout=dropout)self.dropout2 = nn.Dropout(dropout)self.linear1 = nn.Linear(d_model, d_ffn)self.activation = activationself.dropout3 = nn.Dropout(dropout)self.linear2 = nn.Linear(d_ffn, d_model)self.dropout4 = nn.Dropout(dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.norm_first = norm_firstdef forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask):x = tgtif self.norm_first:x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)x = x + self._ca_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)x = x + self._ff_block(self.norm3(x))else:x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))x = self.norm2(x + self._ca_block(x, memory, memory_mask, memory_key_padding_mask))x = self.norm3(x + self._ff_block(x))return xdef _sa_block(self, x, attn_mask, key_padding_mask):x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]return self.dropout1(x)def _ca_block(self, x, mem, attn_mask, key_padding_mask):x = self.cross_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]return self.dropout2(x)def _ff_block(self, x):x = self.linear2(self.dropout3(self.activation(self.linear1(x))))return self.dropout4(x)

根据EncoderLayer搭建Encoder。需要注意的是,PyTorch源码中还提供了 encoder_norm 这一参数,即决定是否在Encoder最后放一个LN。

class TransformerEncoder(nn.Module):def __init__(self, encoder_layer, num_layers, encoder_norm=None):super().__init__()self.layers = _get_clones(encoder_layer, num_layers)self.num_layers = num_layersself.encoder_norm = encoder_normdef forward(self, src, src_mask, src_key_padding_mask):output = srcfor mod in self.layers:output = mod(output, src_mask, src_key_padding_mask)if self.encoder_norm is not None:output = self.encoder_norm(output)return output

DecoderLayer同理

class TransformerDecoder(nn.Module):def __init__(self, decoder_layer, num_layers, decoder_norm=None):super().__init__()self.layers = _get_clones(decoder_layer, num_layers)self.num_layers = num_layersself.decoder_norm = decoder_normdef forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask):output = tgtfor mod in self.layers:output = mod(output, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)if self.decoder_norm is not None:output = self.decoder_norm(output)return output

PyTorch官方的Transformer默认添加 encoder_normdecoder_norm,然而这对于Post-LN的情形,无疑是多余的,所以这里我们做个简单修改,即如果是Post-LN情形,就不在最后添加LN了。

class Transformer(nn.Module):def __init__(self,d_model=512,n_head=8,num_encoder_layers=6,num_decoder_layers=6,d_ffn=2048,dropout=0.1,activation=F.relu,norm_first=False,):super().__init__()if norm_first:encoder_norm, decoder_norm = nn.LayerNorm(d_model), nn.LayerNorm(d_model)else:encoder_norm = decoder_norm = Noneencoder_layer = TransformerEncoderLayer(d_model, n_head, d_ffn, dropout, activation, norm_first)self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)decoder_layer = TransformerDecoderLayer(d_model, n_head, d_ffn, dropout, activation, norm_first)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)self._reset_parameters()def _reset_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self,src,tgt,src_mask=None,tgt_mask=None,memory_mask=None,src_key_padding_mask=None,tgt_key_padding_mask=None,memory_key_padding_mask=None,):memory = self.encoder(src, src_mask, src_key_padding_mask)output = self.decoder(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)return output

截止到目前,我们实现的Transfomer并不是完整的,还缺少embedding层和Decoder后面的Linear层,这里只介绍前者,因为后者仅仅是简单的 nn.Linear(d_model, tgt_vocab_size)

Transformer的embedding层分为token embedding和Positional Encoding,前者是可学习的 nn.Embedding,后者是固定的Sinusoidal编码。

PE的公式为

P [ i , 2 j ] = sin ⁡ ( i 1000 0 2 j / d m o d e l ) P [ i , 2 j + 1 ] = cos ⁡ ( i 1000 0 2 j / d m o d e l ) 0 ≤ i < m a x _ l e n , 0 ≤ j < d m o d e l P[i,2j]=\sin\left(\frac{i}{10000^{2j/d_{model}}}\right)\\ P[i,2j+1]=\cos\left(\frac{i}{10000^{2j/d_{model}}}\right) \\ 0\leq i < max\_len,\;0\leq j<d_{model} P[i,2j]=sin(100002j/dmodeli)P[i,2j+1]=cos(100002j/dmodeli)0i<max_len,0j<dmodel

class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout=0.1, max_len=5000):super().__init__()self.dropout = nn.Dropout(dropout)position = torch.arange(max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))pe = torch.zeros(max_len, 1, d_model)  # 1是batch size维度pe[:, 0, 0::2] = torch.sin(position * div_term)pe[:, 0, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(0)]return self.dropout(x)

3. Q&A

1. MHA的参数量?时间复杂度?FLOPs?

只考虑自注意力情形。为简便起见,令 h ≜ d m o d e l h\triangleq d_{model} hdmodel

MHA模块一共包含四个参数矩阵: W Q , W K , W V , W O W^Q,W^K,W^V,W^O WQ,WK,WV,WO,形状均为 ( h , h ) (h,h) (h,h),因此weight部分的参数量是 4 ⋅ h 2 4\cdot h^2 4h2。每个参数矩阵都会带有一个长度为 h h h 的bias,因此总共的参数量为 4 h 2 + 4 h 4h^2+4h 4h2+4h

📝 注意FLOPs和FLOPS的含义不同。前者是floating point operations,指浮点运算数,可以理解为计算量,用来衡量模型/算法的复杂度;后者是floating point operations per second,指每秒浮点运算次数,可以理解为计算速度,用来衡量衡量硬件的性能。

在计算形状为 ( m , n ) (m,n) (m,n) ( n , k ) (n,k) (n,k) 矩阵的乘积时,每计算一次内积都要执行 n n n 次乘法和 n n n 次加法,而最终输出矩阵的形状为 ( m , k ) (m,k) (m,k),所以总共的浮点运算次数为 ( n + n ) ⋅ m ⋅ k = 2 m n k (n+n)\cdot m\cdot k=2mnk (n+n)mk=2mnk

回到MHA,只考虑矩阵乘法

  • 首先会对形状为 ( l , b , h ) (l,b,h) (l,b,h) 的embedding进行投影,执行的矩阵乘法为 ( l , b , h ) × ( h , h ) → ( l , b , h ) (l,b,h)\times (h, h)\to(l,b,h) (l,b,h)×(h,h)(l,b,h),这一步的计算量为 2 l b h 2 2lbh^2 2lbh2。由于会分别投影到 Q , K , V Q,K,V Q,K,V 三个矩阵,因此这一步的总计算量为 6 l b h 2 6lbh^2 6lbh2
  • 接下来是 Q K T QK^T QKT 相乘,执行的矩阵乘法为 ( b ⋅ n h , l , h d ) × ( b ⋅ n h , h d , l ) → ( b ⋅ n h , l , l ) (b\cdot nh,l,hd)\times(b\cdot nh,hd,l)\to(b\cdot nh,l,l) (bnh,l,hd)×(bnh,hd,l)(bnh,l,l),其中 n h nh nh 代表 num_heads h d hd hd 代表 head_dim。计算量为 2 l 2 b h 2l^2bh 2l2bh
  • 然后是对 V V V 进行加权,执行的矩阵乘法为 ( b ⋅ n h , l , l ) × ( b ⋅ n h , l , h d ) → ( b ⋅ n h , l , h d ) (b\cdot nh,l,l)\times(b\cdot nh,l,hd)\to(b\cdot nh,l,hd) (bnh,l,l)×(bnh,l,hd)(bnh,l,hd),计算量为 2 l 2 b h 2l^2bh 2l2bh
  • 最后的投影中,执行的矩阵乘法为 ( l , b , h ) × ( h , h ) → ( l , b , h ) (l,b,h)\times(h,h)\to(l,b,h) (l,b,h)×(h,h)(l,b,h),计算量为 2 l b h 2 2lbh^2 2lbh2

由上述步骤可知,MHA的FLOPs约为 6 l b h 2 + 2 l 2 b h + 2 l 2 b h + 2 l b h 2 = 4 l b h ( 2 h + l ) 6lbh^2+2l^2bh+2l^2bh+2lbh^2=4lbh(2h+l) 6lbh2+2l2bh+2l2bh+2lbh2=4lbh(2h+l)

再来看MHA的复杂度,依然只考虑矩阵乘法。在计算形状为 ( m , n ) (m,n) (m,n) ( n , k ) (n,k) (n,k) 矩阵的乘积时,计算内积的时间复杂度为 O ( n ) O(n) O(n),而输出矩阵的形状为 ( m , k ) (m,k) (m,k),填满这个矩阵所需要的时间为 O ( m k ) O(mk) O(mk),所以总时间复杂度为 O ( m n k ) O(mnk) O(mnk)

可以发现一个不严谨的等式(仅针对矩阵乘法场景):

时间复杂度 = O ( FLOPs 2 ) 时间复杂度=O\left(\frac{\text{FLOPs}}{2}\right) 时间复杂度=O(2FLOPs)

由此可得到MHA的时间复杂度为 O ( 2 l b h ( 2 h + l ) ) = O ( l b h 2 + l 2 b h ) O(2lbh(2h+l))=O(lbh^2+l^2bh) O(2lbh(2h+l))=O(lbh2+l2bh)。特别地,当 b = 1 b=1 b=1 h ≪ l h\ll l hl 时,MHA的复杂度退化为 O ( l 2 h ) O(l^2h) O(l2h),这就是Transformer那篇论文里提到的复杂度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/15562.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

opencv顺时针,逆时针旋转视频并保存视频

原视频 代码 import cv2# 打开视频文件 video cv2.VideoCapture(inference/video/lianzhang.mp4)# 获取原视频的宽度和高度 width int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) height int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))# 创建视频编写器并设置输出视频参数 fourcc …

【C++】类和对象(下)

1、初始化列表 初始化列表&#xff1a;以一个冒号开始&#xff0c;接着是一个以逗号分隔的数据成员列表&#xff0c;每个"成员变量"后面跟一个放在括号中的初始值或表达式。 class Date { public:Date(int year, int month, int day): _year(year), _month(month), _…

OSPF协议RIP协议+OSPF实验(eNSP)

本篇博客主要讲解单区域的ospf&#xff0c;多区域的仅作了解。 目录 一、OSPF路由协议概述 1.内部网关协议和外部网关协议 二、OSPF的应用环境 1.从以下几方面考虑OSPF的使用 2.OSPF的特点 三、OSPF重要基本概念 3.1&#xff0c;辨析邻居和邻接关系以及七种邻居状态 3…

【MySQL】索引与B+树

【MySQL】索引与B树 索引概念前导硬件软件方面 索引的理解单个page多个page引入B树B树的特征为什么B树做索引优于其他数据结构&#xff1f;聚簇索引与非聚簇索引辅助索引 索引的创建主键索引的创建和查看唯一键索引的创建和查看普通索引的创建和查看复合索引全文索引索引的其他…

js全端支持的深拷贝structuredClone

Jul 7, 2023 经过一年半的试用&#xff0c;structuredClone转正了&#xff0c;全端可以正式使用。 https://developer.mozilla.org/en-US/docs/Web/API/structuredClone

OpenHarmony开源鸿蒙学习入门 - 基于3.2Release 应用开发环境安装

OpenHarmony开源鸿蒙学习入门 - 基于3.2Release 应用开发环境安装 基于目前官方master主支&#xff0c;最新文档版本3.2Release&#xff0c;更新应用开发环境安装文档。 一、安装IDE&#xff1a; 1.IDE安装的系统要求 2.IDE下载官网链接&#xff08;IDE下载链接&#xff09; …

Modbus tcp转ETHERCAT在Modbus软件中的配置方法

Modbus tcp和ETHERCAT是两种不同的协议&#xff0c;这给工业生产带来了很大的麻烦&#xff0c;因为这两种设备之间无法通讯。但是&#xff0c;捷米JM-ECT-TCP网关的出现&#xff0c;却为这个难题提供了解决方案。 JM-ECT-TCP网关能够连接到Modbus tcp总线和ETHERCAT总线中&…

网络面试合集

传输层的数据结构是什么&#xff1f; 就是在问他的协议格式&#xff1a;UDP&TCP 2.1.1三次握手 通信前&#xff0c;要先建立连接&#xff0c;确保双方都是在线&#xff0c;具有数据收发的能力。 2.1.2四次挥手 通信结束后&#xff0c;会有一个断开连接的过程&#xff0…

Qsys介绍

文章目录 前言一、为什么需要Qsys1、简化了系统的设计流程2、Qsys涉及的技术 二、Qsys真身1、一种系统集成工具2、何为Nios II1、内核架构2、Nios II选型 三、Qsys设计涉及到的软件&工具四、总结五、参考资料 前言 Qsys是Altera下的一个系统集成工具&#xff0c;可用于搭建…

APP自动化测试-Python+Appium+Pytest+Allure框架实战封装(详细)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 pytest只是单独的…

JVM入门篇-JVM的概念与学习路线

JVM入门篇-JVM的概念与学习路线 什么是 JVM 定义 Java Virtual Machine - java 程序的运行环境&#xff08;java 二进制字节码的运行环境&#xff09; 好处 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收功能数组下标越界检查多态 比较 jvm jre jdk 常…

单片机第一季:零基础12——I2C和EEPROM

目录 1&#xff0c;EEPROM 2&#xff0c;I2C 2.1&#xff0c;I2C物理层 2.2&#xff0c;I2C协议层 3&#xff0c;AT24C02介绍 4&#xff0c;代码 1&#xff0c;EEPROM 为什么需要EEPROM&#xff1f; 单片机内部的ROM只能在程序下载时进行擦除和改写&#xff0c;但是…

护眼灯全光谱和减蓝光哪个好?推荐五款好用护眼台灯

如今&#xff0c;面临视力下降的问题越来越重视&#xff0c;护眼灯越来越成为人们日常生活中不可或缺的一部分&#xff0c;特别是在工作和学习中使用电脑、手机等电子设备时间较长的人群中。对于护眼灯来说&#xff0c;全光谱和减蓝光都是其主要功能之一&#xff0c;那么哪一种…

aws中opensearch 日志通(Centralized Logging with OpenSearch)2.0(一)

aws日志通2.0 实现全面的日志管理和分析功能 一体化日志摄取 &#xff1a;把aws服务器日志和应用日志传输到opensearch域中无代码日志处理 &#xff1a;在网页控制台中就可以实现数据处理开箱即用 &#xff1a;提供可视化模版&#xff08;nginx、HTTP server &#xff09; 架构…

使用 CSS 自定义属性

我们常见的网站日夜间模式的变化&#xff0c;其实用到了 css 自定义属性。 CSS 自定义属性&#xff08;也称为 CSS 变量&#xff09;是一种在 CSS 中预定义和使用的变量。它们提供了一种简洁和灵活的方式来通过多个 CSS 规则共享相同的值&#xff0c;使得样式更易于维护和修改。…

【LeetCode每日一题】——566.重塑矩阵

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时间频度】九【代码实现】十【提交结果】 一【题目类别】 矩阵 二【题目难度】 简单 三【题目编号】 566.重塑矩阵 四【题目描述】 在 MATLAB 中&…

小红书运营推广方法分享

大家好&#xff0c;我是网媒智星&#xff0c;今天跟大家讨论一下小红书的运营推广方法&#xff0c;总结了七点经验分享给大家。 首先&#xff0c;让我们了解一下什么是热门文案。热门文案可从以下三个方面来定义&#xff1a; 1. 阅读量&#xff1a;如果一篇小红书的阅读量达到上…

【RabbitMQ】golang客户端教程1——HelloWorld

一、介绍 本教程假设RabbitMQ已安装并运行在本机上的标准端口&#xff08;5672&#xff09;。如果你使用不同的主机、端口或凭据&#xff0c;则需要调整连接设置。如果你未安装RabbitMQ&#xff0c;可以浏览我上一篇文章Linux系统服务器安装RabbitMQ RabbitMQ是一个消息代理&…

《MySQL 实战 45 讲》课程学习笔记(四)

深入浅出索引 索引的出现其实就是为了提高数据查询的效率&#xff0c;就像书的目录一样。 索引的常见模型 哈希表 哈希表是一种以键 - 值&#xff08;key-value&#xff09;存储数据的结构&#xff0c;我们只要输入待查找的值即 key&#xff0c;就可以找到其对应的值即 Val…

docker中涉及的挂载点总结

文章目录 1.场景描述2. 容器信息在主机上位置3. 通过docker run 命令4、通过Dockerfile创建挂载点5、容器共享卷&#xff08;挂载点&#xff09;6、最佳实践&#xff1a;数据容器 1.场景描述 在介绍VOLUME指令之前&#xff0c;我们来看下如下场景需求&#xff1a; 1&#xff…