Large Model Fundamentals: Implementing a Transformer from Scratch (1) - CSDN Blog
Large Model Fundamentals: Implementing a Transformer from Scratch (2) - CSDN Blog
Large Model Fundamentals: Implementing a Transformer from Scratch (3) - CSDN Blog
Large Model Fundamentals: Implementing a Transformer from Scratch (4) - CSDN Blog
1. Preface
The previous article implemented both the Encoder module and the Decoder module.
Next, we assemble them into a complete Transformer.
2. Transformer
The overall architecture of the Transformer is shown above. We directly import the Encoder and Decoder modules we implemented earlier and stack them together:
import torch
from torch import nn, Tensor
from torch.nn import Embedding

# Import the modules implemented in the previous parts of this series
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_mask


class Transformer(nn.Module):
    def __init__(self,
                 source_vocab_size: int,
                 target_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 max_positions: int = 5000,
                 pad_idx: int = 0,
                 norm_first: bool = False) -> None:
        '''
        :param source_vocab_size: size of the source vocabulary.
        :param target_vocab_size: size of the target vocabulary.
        :param d_model: dimension of embeddings. Defaults to 512.
        :param n_heads: number of heads. Defaults to 8.
        :param num_encoder_layers: number of encoder blocks. Defaults to 6.
        :param num_decoder_layers: number of decoder blocks. Defaults to 6.
        :param d_ff: dimension of inner feed-forward network. Defaults to 2048.
        :param dropout: dropout ratio. Defaults to 0.1.
        :param max_positions: maximum sequence length for positional encoding. Defaults to 5000.
        :param pad_idx: pad index. Defaults to 0.
        :param norm_first: if True, layer norm is done prior to attention and feed-forward
            operations (Pre-Norm). Otherwise it is done afterwards (Post-Norm). Defaults to False.
        '''
        super().__init__()
        # Token embeddings
        self.src_embeddings = Embedding(source_vocab_size, d_model)
        self.target_embeddings = Embedding(target_vocab_size, d_model)
        # Positional embeddings
        self.encoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        self.decoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        # Encoder stack
        self.encoder = Encoder(d_model, num_encoder_layers, n_heads, d_ff, dropout, norm_first)
        # Decoder stack
        self.decoder = Decoder(d_model, num_decoder_layers, n_heads, d_ff, dropout, norm_first)
        self.pad_idx = pad_idx

    def encode(self,
               src: Tensor,
               src_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Encoding pass.
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, src_seq_length, d_model) encoder output
        '''
        src_embedding_tensor = self.src_embeddings(src)
        src_embedded = self.encoder_pos(src_embedding_tensor)
        return self.encoder(src_embedded, src_mask, keep_attentions)

    def decode(self,
               target_tensor: Tensor,
               memory: Tensor,
               target_mask: Tensor = None,
               memory_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Decoding pass.
        :param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param memory: (batch_size, src_seq_length, d_model) the output of the last encoder layer.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output
        '''
        target_embedding_tensor = self.target_embeddings(target_tensor)
        target_embedded = self.decoder_pos(target_embedding_tensor)
        # output: (batch_size, tgt_seq_length, d_model)
        output = self.decoder(target_embedded, memory, target_mask, memory_mask, keep_attentions)
        return output

    def forward(self,
                src: Tensor,
                target: Tensor,
                src_mask: Tensor = None,
                target_mask: Tensor = None,
                keep_attention: bool = False) -> Tensor:
        '''
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param target: (batch_size, tgt_seq_length) the sequence to the decoder
        :param src_mask: the padding mask for the source sequence. Defaults to None.
        :param target_mask: the mask for the target sequence. Defaults to None.
        :param keep_attention: whether to keep attention weights. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model)
        '''
        memory = self.encode(src, src_mask, keep_attention)
        return self.decode(target, memory, target_mask, src_mask, keep_attention)
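The masks passed to encode and decode are built outside the model. The author's `make_target_mask` (imported from `llm_base.mask.target_mask` and implemented in an earlier part of this series) is not reproduced in this post; the sketch below is only an assumption about how such a helper typically works, combining a padding mask with a lower-triangular causal mask. The exact broadcast shapes must match whatever the attention modules from the previous parts expect.

import torch
from torch import Tensor

def make_source_mask(src: Tensor, pad_idx: int = 0) -> Tensor:
    # (batch_size, 1, src_seq_length): True where the token is NOT padding
    return (src != pad_idx).unsqueeze(1)

def make_target_mask(target: Tensor, pad_idx: int = 0) -> Tensor:
    # Assumed behaviour (not the series' actual code): padding mask AND causal mask.
    # Padding part: (batch_size, 1, 1, tgt_seq_length)
    pad_mask = (target != pad_idx).unsqueeze(1).unsqueeze(2)
    # Causal part: (tgt_seq_length, tgt_seq_length), lower triangular
    seq_length = target.size(1)
    causal_mask = torch.tril(
        torch.ones(seq_length, seq_length, device=target.device)).bool()
    # Combined: position i may attend to position j only if j <= i and j is not padding
    return pad_mask & causal_mask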
3. Testing
Let's write a simple main function to check that the whole network runs as expected:
if __name__ == '__main__':
    source_vocab_size = 300
    target_vocab_size = 300
    # index used for padding, usually 0
    pad_idx = 0
    batch_size = 1
    max_positions = 20

    model = Transformer(source_vocab_size=source_vocab_size,
                        target_vocab_size=target_vocab_size)

    src_tensor = torch.randint(0, source_vocab_size, (batch_size, max_positions))
    target_tensor = torch.randint(0, target_vocab_size, (batch_size, max_positions))
    # the last 5 positions are padding
    src_tensor[:, -5:] = 0
    # the last 10 positions are padding
    target_tensor[:, -10:] = 0

    src_mask = (src_tensor != pad_idx).unsqueeze(1)
    target_mask = make_target_mask(target_tensor)

    output = model(src_tensor, target_tensor, src_mask, target_mask)
    print(output.shape)  # torch.Size([1, 20, 512])
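Note that the printed shape ends in 512, i.e. d_model, not target_vocab_size: the Transformer assembled above stops at the decoder output and has no projection onto the vocabulary. To train the model or sample tokens, one would typically add a linear head on top. The sketch below is our own illustration, not part of the series code (the `lm_head` name and the random labels are purely for demonstration).

import torch
from torch import nn

d_model, target_vocab_size, pad_idx = 512, 300, 0

# Hypothetical output projection; not defined inside the Transformer class above
lm_head = nn.Linear(d_model, target_vocab_size)

decoder_output = torch.randn(1, 20, d_model)   # e.g. the tensor printed above
vocab_logits = lm_head(decoder_output)         # (1, 20, target_vocab_size)

# Cross-entropy over the vocabulary, ignoring padding positions in the labels
labels = torch.randint(0, target_vocab_size, (1, 20))
loss = nn.CrossEntropyLoss(ignore_index=pad_idx)(
    vocab_logits.reshape(-1, target_vocab_size), labels.reshape(-1))
print(vocab_logits.shape, loss.item())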