python-pytorch编写transformer模型实现问答0.5.00--训练和预测
- 背景
- 代码
- 训练
- 预测
- 效果
背景
代码写不了这么长,接上一篇
https://blog.csdn.net/m0_60688978/article/details/139360270
代码
# 定义解码器类
n_layers = 6 # 设置 Decoder 的层数
class Decoder(nn.Module):def __init__(self, corpus):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(vocab_size, d_embedding) # 词嵌入层self.pos_emb = nn.Embedding.from_pretrained( \get_sin_enc_table(vocab_size+1, d_embedding), freeze=True) # 位置嵌入层 self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) # 叠加多层def forward(self, dec_inputs, enc_inputs, enc_outputs): #------------------------- 维度信息 --------------------------------# dec_inputs 的维度是 [batch_size, target_len]# enc_inputs 的维度是 [batch_size, source_len]# enc_outputs 的维度是 [batch_size, source_len, embedding_dim]#----------------------------------------------------------------- # 创建一个从 1 到 source_len 的位置索引序列pos_indices = torch.arange(1, dec_inputs.size(1) + 1).unsqueeze(0).to(dec_inputs)#------------------------- 维度信息 --------------------------------# pos_indices 的维度是 [1, target_len]#----------------------------------------------------------------- # 对输入进行词嵌入和位置嵌入相加dec_outputs = self.tgt_emb(dec_inputs)