深度学习 | Transformer模型及代码实现

        Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

        


 

一、Transformer模型

 

1、模型结构

        首先介绍 Transformer 的整体结构,下图是 Transformer 用于中英文翻译的整体结构。

        

        可以看到 Transformer 由 Encoder 和 Decoder 两个部分组成,Encoder 和 Decoder 都包含 6 个 block。 6是随机选择的数字,也可以是其他的数字。我们可以将这个结构看成是串联在一起的电池组,彼此之间通过多次的非线性变换在不同的空间提取更多的信息。

        

        编码器的结构相同,但不共享权重。每个编码器具体来说由两个子层组成。

        自注意力层能够帮助编码器在编码特定单词时考虑句子中其他的单词,输出会输入到一个独立的前馈神经网络中,每个编码器都有相同的前馈神经网络运行。

        解码器又嵌入了一个 Encoder-Decoder注意力层,帮助解码器专注于输入句子的相关部分,类似于seq2seq中的注意力机制。

        在自然语言处理中,注意首先需要使用嵌入算法将每个单词转换为向量,每个词都会嵌入到一个512维的向量中,512是一个超参数,代表训练数据集中句子的最大长度。

        注意 每个词的流动路径都是独立的。词语之间的依赖关系是通过自注意力层来表达的,前向反馈层没有相互之间的计算,所以前向反馈层可以并行计算。

         

        

2、编码器

1)、位置嵌入 Embedding

        Transformer 中单词的输入表示 x 由单词 Embedding 和位置 Embedding 相加得到。

        

        其中单词的 Embedding (嵌入向量) 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。

        而位置Embedding(位置编码向量)的则是通过定义的正余弦函数来得到的。

        

        其中,pos 表示单词在句子中的位置,d 表示 PE的维度 (与词 Embedding 一样),2i 表示偶数的维度,2i+1 表示奇数维度 (即 2i≤d, 2i+1≤d)。

        使用这种公式计算 PE 有以下的好处:

        使 PE 能够适应比训练集里面所有句子更长的句子,假设训练集里面最长的句子是有 20 个单词,突然来了一个长度为 21 的句子,则使用公式计算的方法可以计算出第 21 位的 Embedding。可以让模型容易地计算出相对位置,对于固定长度的间距 k,PE(pos+k) 可以用 PE(pos) 计算得到。因为 Sin(A+B) = Sin(A)Cos(B) + Cos(A)Sin(B), Cos(A+B) = Cos(A)Cos(B) - Sin(A)Sin(B)。将单词的词 Embedding 和位置 Embedding 相加,就可以得到单词的表示向量 x,x 就是Transformer 的输入。

        我们之前学习的RNN模型中,并没有使用过位置Embedding,那为什么Transformer中要引入位置信息呢?

        这是因为 Transformer 不采用 RNN 的结构,而是使用全局信息,不能利用单词的顺序信息,而这部分信息对于NLP来说非常重要 所以 Transformer 中使用位置 Embedding 保存单词在序列中的相对或绝对位置。

2)Transformer 中的多头注意力机制

        在 Transformer 论文中,通过添加多头注意力机制,进一步完善了自注意力层。

        也就是说一个输入向量会分别生成不同的 Q K V 的组合,从而得到不同的注意力权重 Z,再拼接到一起,这样的话扩展了模型关注不同位置的能力,为注意力层提供了表示子空间。有点类似CNN中不同的卷积核,用于捕捉输入数据不同维度特征,这样在句子比较长 容易产生歧义或一词多义的情况下也能更好的提取特征信息。

        这些 Q K V 变换矩阵都是通过训练得到的,Transformer中就是用了八个头。

        

        举个例子:当我们在对下面句子中的 it 进行编码时

        不同的颜色代表不同头,颜色深浅代表自注意力权重,

        以it为例,编码时关注重点The animal、tired。

         

        下面是一个演示,不同颜色表示不同的头Q K V,类似CNN中不同的卷积核,捕获不同中的注意力权重。

        

        论文中给出的模型架构如下:

        

        左侧为编码器块,右侧为解码器块。橙色框中的部分就是Multi-Head Attention,它是由多个Self-Attention组成的,可以看到 Encoder block 包含一个Multi-Head Attention,而Decoder block 包含两个Multi-Head Attention (其中有一个用到 Masked)。Multi-Head Attention 上方还包括一个Add & Norm 层,这里Add 表示残差连接 (Residual Connection) 用于防止网络退化,Norm 表示 Layer Normalization,用于对每一层的激活值进行归一化。

        关于多头注意力和自注意力的内容,我们上一节已经介绍过,这里不再赘述。

3)残差结构

         在每个编码器子层和解码器子层中都使用了残差连接和归一化,他们可以让网络更容易学习复杂特征,从而避免梯度消失和爆炸的问题,同时训练更稳定,层归一化可以加速模型的收敛过程,有助于提高模型的泛化能力和稳定性。

         

3、解码器

        解码器需要同时链接编码器的输出。就像RNN一样,换句话说每一步解码都要使用编码器的输出来生成序列中下一个单词的表示。

        通过连接编码器和解码器模型可以有效的利用编码器对输入序列的理解从而生成更加准确的输出序列,同时也可以避免信息丢失的问题,从而提高模型的整体性能和稳定性。

        

 

4、编解码器协同工作

         编码器首先处理输入序列,然后将顶部的输出转换为 一组注意力向量 K 和 V,这些向量被每个解码器在他的编码器-解码器注意力层中使用,用于帮助解码器将注意力集中在输入序列中的恰当位置。

        在解码器阶段,每一步都会从输出序列中输出一个元素,这个元素的生成既依赖于之前的输出同时也依赖于编码器-解码器注意力层中的注意力向量,通过多次迭代计算,解码器可以逐渐生成完整的输出序列。

        整个过程中,编码器和解码器的协同工作是通过多头注意力机制和残差链接等技术实现的,这使得 Transformer 模型在各种NLP任务中取得了很好的性能。

        重复上述步骤,直到一个特殊的到达符号表示解码器已经完成了输出。可以说解码器的自注意力层和编码器的自注意力层是非常类似的,但是运行方式不同,解码器的自注意力层只允许关注输出序列中之前的位置,以避免信息泄露和信息未来化的问题,在每个解码器中,输入序列要经过多头注意力机制和前馈神经网络进行编码,然后通过编码器-解码器注意力层与编码器的输出 再进行交互,最后生成解码器最后的输出序列。整个过程中位置编码向量也被用来保留单词在序列中的位置信息。

        

 

5、线性层和softmax层

         对于解码器的输出,他是一个浮点向量,如何把他转换成一个单词呢?这就是线性层和softmax层的工作了。

        线性层:一个全连接神经网络,将解码器堆叠生成的向量映射到一个更大的向量,通常称为logits向量。

        每个单元格对应一个单词分数。

        Softmax层:将单词分数转化为概率,选择具有最高概率的单词作输出。

        

 

6、工作流程

        Transformer模型的工作流程主要包含三个步骤:

        第一步:获取输入句子的每一个单词的表示向量 X,X 由单词的 Embedding 和单词位置的 Embedding 相加得到。

         

        第二步:将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x ,传入 Encoder 中,经过 6 个 Encoder block 后可以得到句子所有单词的编码信息矩阵 C,如下图 2。

        单词向量矩阵用 X(n×d)表示, n 是句子中单词个数,d 是表示向量的维度 (论文中 d=512)。每一个 Encoder block 输出的矩阵维度与输入完全一致。

                

        第三步:将 Encoder 输出的编码信息矩阵 C传递到 Decoder 中,Decoder 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。

        需要特别说明的是Transformer中使用了多头注意力机制。

        

        上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 Begin,预测第一个单词 I;然后输入翻译开始符 Begin 和单词 I,预测单词 have,以此类推。这是 Transformer 使用时候的大致流程,接下来是里面各个部分的细节。 

 

7、优缺点总结

        Transformer 与 RNN 不同,可以比较好地并行训练。

        Transformer 本身是不能利用单词的顺序信息的,因此需要在输入中添加位置 Embedding,否则 Transformer 就是一个词袋模型了。

        Transformer 的重点是 Self-Attention 结构,其中用到的 Q, K, V矩阵通过输出进行线性变换得到。

        Transformer 中 Multi-Head Attention 中有多个 Self-Attention,可以捕获单词之间多种维度上的相关系数 attention score。

        

 



 

二、Transformer模型代码实现

 

         

1、数据准备

(1)代码包引入

import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
from torch import optim
import random
from tqdm import *
import matplotlib.pyplot as plt

(2)数据集生成

# 数据集生成
soundmark = ['ei',  'bi:',  'si:',  'di:',  'i:',  'ef',  'dʒi:',  'eit∫',  'ai', 'dʒei', 'kei', 'el', 'em', 'en', 'əu', 'pi:', 'kju:','ɑ:', 'es', 'ti:', 'ju:', 'vi:', 'd∧blju:', 'eks', 'wai', 'zi:']alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']t = 1000 #总条数
r = 0.9   #扰动项
seq_len = 6
src_tokens, tgt_tokens = [],[] #原始序列、目标序列列表for i in range(t):src, tgt = [],[]for j in range(seq_len):ind = random.randint(0,25)src.append(soundmark[ind])if random.random() < r:tgt.append(alphabet[ind])else:tgt.append(alphabet[random.randint(0,25)])src_tokens.append(src)tgt_tokens.append(tgt)
src_tokens[:2], tgt_tokens[:2]
([['kju:', 'kei', 'em', 'i:', 'vi:', 'pi:'],['bi:', 'kju:', 'eit∫', 'eks', 'ef', 'di:']],[['q', 'k', 'm', 'e', 'v', 'p'], ['b', 'q', 'h', 'x', 'f', 'd']])
from collections import Counter  # 计数类flatten = lambda l: [item for sublist in l for item in sublist]  # 展平数组
# 构建词表
class Vocab:def __init__(self, tokens):self.tokens = tokens  # 传入的tokens是二维列表self.token2index = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}  # 先存好特殊词元# 将词元按词频排序后生成列表self.token2index.update({token: index + 4for index, (token, freq) in enumerate(sorted(Counter(flatten(self.tokens)).items(), key=lambda x: x[1], reverse=True))})# 构建id到词元字典self.index2token = {index: token for token, index in self.token2index.items()}def __getitem__(self, query):# 单一索引if isinstance(query, (str, int)):if isinstance(query, str):return self.token2index.get(query, 3)elif isinstance(query, (int)):return self.index2token.get(query, '<unk>')# 数组索引elif isinstance(query, (list, tuple)):return [self.__getitem__(item) for item in query]def __len__(self):return len(self.index2token)

(3)数据集构造

from torch.utils.data import DataLoader, TensorDataset#实例化source和target词表
src_vocab, tgt_vocab = Vocab(src_tokens), Vocab(tgt_tokens)
src_vocab_size = len(src_vocab)  # 源语言词表大小
tgt_vocab_size = len(tgt_vocab)  # 目标语言词表大小#增加开始标识<bos>和结尾标识<eos>
encoder_input = torch.tensor([src_vocab[line + ['<pad>']] for line in src_tokens])
decoder_input = torch.tensor([tgt_vocab[['<bos>'] + line] for line in tgt_tokens])
decoder_output = torch.tensor([tgt_vocab[line + ['<eos>']] for line in tgt_tokens])# 训练集和测试集比例8比2,batch_size = 16
train_size = int(len(encoder_input) * 0.8)
test_size = len(encoder_input) - train_size
batch_size = 16# 自定义数据集函数
class MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):super(MyDataSet, self).__init__()self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return self.enc_inputs.shape[0]def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]train_loader = DataLoader(MyDataSet(encoder_input[:train_size], decoder_input[:train_size], decoder_output[:train_size]), batch_size=batch_size)
test_loader = DataLoader(MyDataSet(encoder_input[-test_size:], decoder_input[-test_size:], decoder_output[-test_size:]), batch_size=1)

2、模型构建

(1)位置编码

def get_sinusoid_encoding_table(n_position, d_model):def cal_angle(position, hid_idx):return position / np.power(10000, 2 * (hid_idx // 2) / d_model)def get_posi_angle_vec(position):return [cal_angle(position, hid_j) for hid_j in range(d_model)]sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # 偶数位用正弦函数sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # 奇数位用余弦函数return torch.FloatTensor(sinusoid_table)
print(get_sinusoid_encoding_table(30, 512))
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,0.0000e+00,  1.0000e+00],[ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,1.0366e-04,  1.0000e+00],[ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,2.0733e-04,  1.0000e+00],...,[ 9.5638e-01, -2.9214e-01,  7.9142e-01,  ...,  1.0000e+00,2.7989e-03,  1.0000e+00],[ 2.7091e-01, -9.6261e-01,  9.5325e-01,  ...,  1.0000e+00,2.9026e-03,  1.0000e+00],[-6.6363e-01, -7.4806e-01,  2.9471e-01,  ...,  1.0000e+00,3.0062e-03,  1.0000e+00]])

(2)掩码操作

# mask掉没有意义的占位符
def get_attn_pad_mask(seq_q, seq_k):                       # seq_q: [batch_size, seq_len] ,seq_k: [batch_size, seq_len]batch_size, len_q = seq_q.size()batch_size, len_k = seq_k.size()pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)          # 判断 输入那些含有P(=0),用1标记 ,[batch_size, 1, len_k]return pad_attn_mask.expand(batch_size, len_q, len_k)# mask掉未来信息
def get_attn_subsequence_mask(seq):                               # seq: [batch_size, tgt_len]attn_shape = [seq.size(0), seq.size(1), seq.size(1)]subsequence_mask = np.triu(np.ones(attn_shape), k=1)          # 生成上三角矩阵,[batch_size, tgt_len, tgt_len]subsequence_mask = torch.from_numpy(subsequence_mask).byte()  #  [batch_size, tgt_len, tgt_len]return subsequence_mask 

(3)注意力计算函数

# 缩放点积注意力计算
class ScaledDotProductAttention(nn.Module):def __init__(self):super(ScaledDotProductAttention, self).__init__()def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v]attn_mask: [batch_size, n_heads, seq_len, seq_len]'''scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.attn = nn.Softmax(dim=-1)(scores)context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]return context, attn#多头注意力计算
class MultiHeadAttention(nn.Module):def __init__(self):super(MultiHeadAttention, self).__init__()self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)def forward(self, input_Q, input_K, input_V, attn_mask):'''input_Q: [batch_size, len_q, d_model]input_K: [batch_size, len_k, d_model]input_V: [batch_size, len_v(=len_k), d_model]attn_mask: [batch_size, seq_len, seq_len]'''residual, batch_size = input_Q, input_Q.size(0)# (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # Q: [batch_size, n_heads, len_q, d_k]K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # K: [batch_size, n_heads, len_k, d_k]V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # V: [batch_size, n_heads, len_v(=len_k), d_v]attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]# context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]output = self.fc(context) # [batch_size, len_q, d_model]return nn.LayerNorm(d_model)(output + residual), attn

(4)构建前馈网络

class PoswiseFeedForwardNet(nn.Module):def __init__(self):super(PoswiseFeedForwardNet, self).__init__()self.fc = nn.Sequential(nn.Linear(d_model, d_ff, bias=False),nn.ReLU(),nn.Linear(d_ff, d_model, bias=False))def forward(self, inputs):                             # inputs: [batch_size, seq_len, d_model]residual = inputsoutput = self.fc(inputs)return nn.LayerNorm(d_model)(output + residual)   # 残差 + LayerNorm

(5)编码器模块

# 编码器层
class EncoderLayer(nn.Module):def __init__(self):super(EncoderLayer, self).__init__()self.enc_self_attn = MultiHeadAttention()  # 多头注意力self.pos_ffn = PoswiseFeedForwardNet()  # 前馈网络def forward(self, enc_inputs, enc_self_attn_mask):'''enc_inputs: [batch_size, src_len, d_model]enc_self_attn_mask: [batch_size, src_len, src_len]'''# enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,Venc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]return enc_outputs, attn# 编码器模块
class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Embedding(src_vocab_size, d_model)self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_vocab_size, d_model), freeze=True)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):'''enc_inputs: [batch_size, src_len]'''word_emb = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]pos_emb = self.pos_emb(enc_inputs) # [batch_size, src_len, d_model]enc_outputs = word_emb + pos_embenc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:# enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns

(6)解码器模块

# 解码器层
class DecoderLayer(nn.Module):def __init__(self):super(DecoderLayer, self).__init__()self.dec_self_attn = MultiHeadAttention()self.dec_enc_attn = MultiHeadAttention()self.pos_ffn = PoswiseFeedForwardNet()def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):'''dec_inputs: [batch_size, tgt_len, d_model]enc_outputs: [batch_size, src_len, d_model]dec_self_attn_mask: [batch_size, tgt_len, tgt_len]dec_enc_attn_mask: [batch_size, tgt_len, src_len]'''# dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)# dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len]dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]return dec_outputs, dec_self_attn, dec_enc_attn# 解码器模块
class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_vocab_size, d_model),freeze=True)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''dec_inputs: [batch_size, tgt_len]enc_intpus: [batch_size, src_len]enc_outputs: [batsh_size, src_len, d_model]'''word_emb = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]pos_emb = self.pos_emb(dec_inputs) # [batch_size, tgt_len, d_model]dec_outputs = word_emb + pos_embdec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) # [batch_size, tgt_len, tgt_len]dec_self_attn_subsequent_mask = get_attn_subsequence_mask(dec_inputs) # [batch_size, tgt_len]dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) # [batch_size, tgt_len, tgt_len]dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batc_size, tgt_len, src_len]dec_self_attns, dec_enc_attns = [], []for layer in self.layers:# dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len,src_len]dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)dec_self_attns.append(dec_self_attn)dec_enc_attns.append(dec_enc_attn)return dec_outputs, dec_self_attns, dec_enc_attns

(7)Transformer模型

class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.encoder = Encoder()self.decoder = Decoder()self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)def forward(self, enc_inputs, dec_inputs):'''enc_inputs: [batch_size, src_len]dec_inputs: [batch_size, tgt_len]'''# tensor to store decoder outputs# outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)# enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]enc_outputs, enc_self_attns = self.encoder(enc_inputs)# dec_outpus: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [n_layers, batch_size, tgt_len, src_len]dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

3、模型训练

d_model = 512   # 字 Embedding 的维度
d_ff = 2048     # 前向传播隐藏层维度
d_k = d_v = 64  # K(=Q), V的维度 
n_layers = 6    # 有多少个encoder和decoder
n_heads = 8     # Multi-Head Attention设置为8
num_epochs = 50 # 训练50轮
# 记录损失变化
loss_history = []model = Transformer()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.99)for epoch in tqdm(range(num_epochs)):total_loss = 0for enc_inputs, dec_inputs, dec_outputs in train_loader:'''enc_inputs: [batch_size, src_len]dec_inputs: [batch_size, tgt_len]dec_outputs: [batch_size, tgt_len]'''# enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)# outputs: [batch_size * tgt_len, tgt_vocab_size]outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)loss = criterion(outputs, dec_outputs.view(-1))optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss/len(train_loader)loss_history.append(avg_loss)print('Epoch:', '%d' % (epoch + 1), 'loss =', '{:.6f}'.format(avg_loss))
  2%|▏         | 1/50 [00:21<17:16, 21.15s/it]Epoch: 1 loss = 2.6337624%|▍         | 2/50 [00:42<17:06, 21.39s/it]Epoch: 2 loss = 2.0630026%|▌         | 3/50 [01:04<16:48, 21.47s/it]Epoch: 3 loss = 1.8669448%|▊         | 4/50 [01:25<16:16, 21.23s/it]Epoch: 4 loss = 1.80278310%|█         | 5/50 [01:45<15:49, 21.10s/it]Epoch: 5 loss = 1.64321712%|█▏        | 6/50 [02:07<15:27, 21.07s/it]Epoch: 6 loss = 1.80347114%|█▍        | 7/50 [02:27<15:04, 21.02s/it]Epoch: 7 loss = 1.51879416%|█▌        | 8/50 [02:48<14:41, 20.99s/it]Epoch: 8 loss = 1.63284018%|█▊        | 9/50 [03:10<14:23, 21.06s/it]Epoch: 9 loss = 1.44673020%|██        | 10/50 [03:31<14:02, 21.06s/it]Epoch: 10 loss = 1.34034822%|██▏       | 11/50 [03:52<13:40, 21.04s/it]Epoch: 11 loss = 1.36691724%|██▍       | 12/50 [04:13<13:20, 21.06s/it]Epoch: 12 loss = 1.49971526%|██▌       | 13/50 [04:34<13:01, 21.12s/it]Epoch: 13 loss = 1.37144628%|██▊       | 14/50 [04:55<12:41, 21.14s/it]Epoch: 14 loss = 1.38049830%|███       | 15/50 [05:16<12:19, 21.14s/it]Epoch: 15 loss = 1.29818332%|███▏      | 16/50 [05:37<11:53, 20.99s/it]Epoch: 16 loss = 1.10751234%|███▍      | 17/50 [05:57<11:27, 20.85s/it]Epoch: 17 loss = 1.01535536%|███▌      | 18/50 [06:18<11:04, 20.76s/it]Epoch: 18 loss = 0.89157338%|███▊      | 19/50 [06:39<10:41, 20.69s/it]Epoch: 19 loss = 1.03515740%|████      | 20/50 [06:59<10:19, 20.64s/it]Epoch: 20 loss = 1.05994342%|████▏     | 21/50 [07:20<09:58, 20.64s/it]Epoch: 21 loss = 0.99534744%|████▍     | 22/50 [07:40<09:38, 20.65s/it]Epoch: 22 loss = 0.82873046%|████▌     | 23/50 [08:01<09:18, 20.68s/it]Epoch: 23 loss = 0.71740348%|████▊     | 24/50 [08:22<08:59, 20.77s/it]Epoch: 24 loss = 0.76887050%|█████     | 25/50 [08:43<08:39, 20.80s/it]Epoch: 25 loss = 0.71392752%|█████▏    | 26/50 [09:04<08:18, 20.75s/it]Epoch: 26 loss = 0.79791854%|█████▍    | 27/50 [09:24<07:57, 20.74s/it]Epoch: 27 loss = 0.68024656%|█████▌    | 28/50 [09:45<07:36, 20.76s/it]Epoch: 28 loss = 0.61177058%|█████▊    | 29/50 [10:06<07:16, 20.77s/it]Epoch: 29 loss = 0.81035560%|██████    | 30/50 [10:27<06:57, 20.86s/it]Epoch: 30 loss = 0.53748762%|██████▏   | 31/50 [10:48<06:37, 20.93s/it]Epoch: 31 loss = 0.48465064%|██████▍   | 32/50 [11:09<06:15, 20.86s/it]Epoch: 32 loss = 0.44703366%|██████▌   | 33/50 [11:30<05:54, 20.83s/it]Epoch: 33 loss = 0.39907268%|██████▊   | 34/50 [11:51<05:34, 20.90s/it]Epoch: 34 loss = 0.37964970%|███████   | 35/50 [12:12<05:13, 20.92s/it]Epoch: 35 loss = 0.27082372%|███████▏  | 36/50 [12:32<04:52, 20.91s/it]Epoch: 36 loss = 0.33787874%|███████▍  | 37/50 [12:53<04:30, 20.81s/it]Epoch: 37 loss = 0.23544076%|███████▌  | 38/50 [13:14<04:09, 20.77s/it]Epoch: 38 loss = 0.33739378%|███████▊  | 39/50 [13:35<03:49, 20.85s/it]Epoch: 39 loss = 0.26019180%|████████  | 40/50 [13:56<03:28, 20.89s/it]Epoch: 40 loss = 0.21008482%|████████▏ | 41/50 [14:17<03:09, 21.03s/it]Epoch: 41 loss = 0.16861684%|████████▍ | 42/50 [14:38<02:47, 20.97s/it]Epoch: 42 loss = 0.21360786%|████████▌ | 43/50 [14:58<02:25, 20.82s/it]Epoch: 43 loss = 0.11055188%|████████▊ | 44/50 [15:19<02:04, 20.74s/it]Epoch: 44 loss = 0.18356290%|█████████ | 45/50 [15:39<01:43, 20.62s/it]Epoch: 45 loss = 0.09517292%|█████████▏| 46/50 [16:00<01:22, 20.57s/it]Epoch: 46 loss = 0.13238794%|█████████▍| 47/50 [16:20<01:01, 20.52s/it]Epoch: 47 loss = 0.16380596%|█████████▌| 48/50 [16:41<00:40, 20.49s/it]Epoch: 48 loss = 0.15219598%|█████████▊| 49/50 [17:01<00:20, 20.49s/it]Epoch: 49 loss = 0.086681
100%|██████████| 50/50 [17:22<00:00, 20.84s/it]Epoch: 50 loss = 0.085496
plt.plot(loss_history)
plt.ylabel('train loss')
plt.show()

4、模型预测

model.eval()
translation_results = []correct = 0
error = 0for enc_inputs, dec_inputs, dec_outputs in test_loader:'''enc_inputs: [batch_size, src_len]dec_inputs: [batch_size, tgt_len]dec_outputs: [batch_size, tgt_len]'''# enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)# outputs: [batch_size * tgt_len, tgt_vocab_size]outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)# pred形状为 (seq_len, batch_size, vocab_size) = (1, 1, vocab_size)# dec_outputs, dec_self_attns, dec_enc_attns = model.decoder(dec_inputs, enc_inputs, enc_output)outputs = outputs.squeeze()pred_seq = []for output in outputs:next_token_index = output.argmax().item()if next_token_index == tgt_vocab['<eos>']:breakpred_seq.append(next_token_index)pred_seq = tgt_vocab[pred_seq]tgt_seq = dec_outputs.squeeze().tolist()# 需要注意在<eos>之前截断if tgt_vocab['<eos>'] in tgt_seq:eos_idx = tgt_seq.index(tgt_vocab['<eos>'])tgt_seq = tgt_vocab[tgt_seq[:eos_idx]]else:tgt_seq = tgt_vocab[tgt_seq]translation_results.append((' '.join(tgt_seq), ' '.join(pred_seq)))for i in range(len(tgt_seq)):if i >= len(pred_seq) or pred_seq[i] != tgt_seq[i]:error += 1else:correct += 1print(correct/(correct+error))
0.3333333333333333
translation_results
[('h x n y e k', 'h y y y k'),('y l z k i t', 't i t j i t y'),('t s x e e v', 's s v e e v'),('e g a m t h', 'f i h h h'),...................

 


参考

Chapter-11/11.7 Transformer代码实现.ipynb · 梗直哥/Deep-Learning-Code - Gitee.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/590266.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一年百模大战下来,有哪些技术趋势和行业真相逐渐浮出水面?

介绍 本人是独立开源软件开发者&#xff0c;参与很多项目建设&#xff0c;谈下感受。 ChatGPT开始AI生成元年&#xff0c;经历一年依然是第一。 LLaMA的巧合开启开源大模型浪潮。 名词解释 AIGC : AI-Generated Content 指利用人工智能技术&#xff08;生成式AI路径&#x…

类和接口

内容大部分来源于学习笔记&#xff0c;随手记录笔记内容以及个人笔记 对象Object java是面向对象的语言&#xff0c;一个对象包含状态和行为 可以这样理解&#xff0c;我眼前的石头&#xff0c;手里水杯&#xff0c;这些具体到某一个个体&#xff0c;这就是对象&#xff1b;…

非科班,培训出身,怎么进大厂?

今天分享一下我是怎么进大厂的经历&#xff0c;希望能给大家带来一点点启发&#xff01; 阿七毕业于上海一所大学的管理学院&#xff0c;在读期间没写过一行 Java 代码。毕业之后二战考研失利。 回过头来看&#xff0c;也很庆幸这次考研失利&#xff0c;因为这个时候对社会一…

OpenOCD简介和下载安装(Ubuntu)

文章目录 OpenOCD简介OpenOCD软件模块OpenOCD源码下载OpenOCD安装 OpenOCD简介 OpenOCD&#xff08;Open On-Chip Debugger&#xff09;开放式片上调试器 OpenOCD官网 https://openocd.org/&#xff0c;进入官网点击 About 可以看到OpenOCD最初的设计是由国外一个叫Dominic Ra…

红队打靶练习:SAR: 1

目录 信息收集 1、arp 2、netdiscover 3、nmap 4、nikto 5、whatweb 小结 目录探测 1、gobuster 2、dirsearch WEB CMS 1、cms漏洞探索 2、RCE漏洞利用 提权 get user.txt 本地提权 信息收集 1、arp ┌──(root㉿ru)-[~/kali] └─# arp-scan -l Interface:…

写在2024年初,软件测试面试笔记总结与分享

大家好&#xff0c;最近有不少小伙伴在后台留言&#xff0c;得准备年后面试了&#xff0c;又不知道从何下手&#xff01;为了帮大家节约时间&#xff0c;特意准备了一份面试相关的资料&#xff0c;内容非常的全面&#xff0c;真的可以好好补一补&#xff0c;希望大家在都能拿到…

LanceDB:在对抗数据复杂性战役中,您可信赖的坐骑

LanceDB 建立在 Lance&#xff08;一种开源列式数据格式&#xff09;之上&#xff0c;具有一些有趣的功能&#xff0c;使其对 AI/ML 具有吸引力。例如&#xff0c;LanceDB 支持显式和隐式矢量化&#xff0c;能够处理各种数据类型。LanceDB 与 PyTorch 和 TensorFlow 等领先的 M…

24届春招实习必备技能(一)之MyBatis Plus入门实践详解

MyBatis Plus入门实践详解 一、什么是MyBatis Plus? MyBatis Plus简称MP&#xff0c;是mybatis的增强工具&#xff0c;旨在增强&#xff0c;不做改变。MyBatis Plus内置了内置通用 Mapper、通用 Service&#xff0c;仅仅通过少量配置即可实现单表大部分 CRUD 操作&#xff0…

【LMM 003】生物医学领域的垂直类大型多模态模型 LLaVA-Med

论文标题&#xff1a;LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day 论文作者&#xff1a;Chunyuan Li∗, Cliff Wong∗, Sheng Zhang∗, Naoto Usuyama, Haotian Liu, Jianwei Yang Tristan Naumann, Hoifung Poon, Jianfeng Gao 作…

LeetCode二叉树路径和专题:最大路径和与路径总和计数的策略

目录 437. 路径总和 III 深度优先遍历 前缀和优化 124. 二叉树中的最大路径和 437. 路径总和 III 给定一个二叉树的根节点 root &#xff0c;和一个整数 targetSum &#xff0c;求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始&#xf…

简单FTP客户端软件开发——VMware安装Linux虚拟机(命令行版)

VMware安装包和Linux系统镜像&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1UwF4DT8hNXp_cV0NpSfTww?pwdxnoh 提取码&#xff1a;xnoh 这个学期做计网课程设计【简单FTP客户端软件开发】需要在Linux上配置 ftp服务器&#xff0c;故此用VMware安装了Linux虚拟机&…

burpsuite模块介绍之compare

导语 Burp Comparer是Burp Suite中的一个工具&#xff0c;主要提供一个可视化的差异比对功能&#xff0c;可以用于分析比较两次数据之间的区别。它的应用场景包括但不限于&#xff1a; 枚举用户名过程中&#xff0c;对比分析登陆成功和失败时&#xff0c;服务器端反馈结果的区…

编程式导航传参

(通过js代码实现跳转) 按照path进行跳转 第一步&#xff1a; 在app.vue中(前提是规则已经配置好) <template><div id"app">App组件<button clicklogin>跳转</button><!--路由出口-将来匹配的组件渲染地方--><router-view>&l…

【嵌入式学习笔记-01】什么是UC,操作系统历史介绍,计算机系统分层,环境变量(PATH),错误

【嵌入式学习笔记】什么是UC&#xff0c;操作系统历史介绍&#xff0c;计算机系统分层&#xff0c;环境变量&#xff08;PATH&#xff09;&#xff0c;错误 文章目录 什么是UC?计算机系统分层什么是操作系统&#xff1f; 环境变量什么是环境变量&#xff1f;环境变量的添加&am…

简写英语单词

题目&#xff1a; 思路&#xff1a; 这段代码的主要思路是读取一个字符串&#xff0c;然后将其中每个单词的首字母大写输出。具体来说&#xff0c;程序首先使用 fgets 函数读取一个字符串&#xff0c;然后遍历该字符串中的每个字符。当程序遇到一个字母时&#xff0c;如果此时…

基于图论的图像分割 python + PyQt5

数据结构大作业&#xff0c;基于图论中的最小生成树的图像分割。一个很古老的算法&#xff0c;精度远远不如深度学习算法&#xff0c;但是对于代码能力是一个很好的锻炼。 课设要求&#xff1a; &#xff08; 1 &#xff09;输入&#xff1a;图像&#xff08;例如教室场景图&a…

47、激活函数 - sigmoid

今天在看一个比较常见的激活函数,叫作 sigmoid 激活函数,它的数学表达式为: 其中,x 为输入,画出图来看更直观一些。 Sigmoid 函数的图像看起来像一个 S 形曲线,我们先分析一下这个函数的特点。 Sigmoid 函数的输出范围在 (0, 1) 之间,并且不等于0或1。 Sigmoid 很明显是…

Codeforces Round 900 (Div. 3)(A-F)

比赛链接 : Dashboard - Codeforces Round 900 (Div. 3) - Codeforces A. How Much Does Daytona Cost? 题面 : 思路 : 在序列中只要找到k&#xff0c;就返回true ; 代码 : #include<bits/stdc.h> #define IOS ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)…

spring 之 事务

1、JdbcTemplate Spring 框架对 JDBC 进行封装&#xff0c;使用 JdbcTemplate 方便实现对数据库操作 1.1 准备工作 ①搭建子模块 搭建子模块&#xff1a;spring-jdbc-tx ②加入依赖 <dependencies><!--spring jdbc Spring 持久化层支持jar包--><dependency&…

性能优化(CPU优化技术)-ARM Neon详细介绍

本文主要介绍ARM Neon技术&#xff0c;包括SIMD技术、SIMT、ARM Neon的指令、寄存器、意图为读者提供对ARM Neon的一个整体理解。 &#x1f3ac;个人简介&#xff1a;一个全栈工程师的升级之路&#xff01; &#x1f4cb;个人专栏&#xff1a;高性能&#xff08;HPC&#xff09…