位置编码公式
偶数位置用sin,奇数位置用cos. d_model 表示token的维度;pos表示token在序列中的位置;i表示每个token编码的第i个位置,属于[0,d_model)。
torch实现
import math
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as pltclass PositionalEncoder(nn.Module):def __init__(self, max_seq_len=50, d_model=128):super().__init__()self.d_model = d_model # d_model 表示token的维度pe = torch.zeros(max_seq_len, d_model) # max_seq_len * d_model 的二维张量 例如: 50*128for pos in range(max_seq_len): # 重新初始化for i in range(0, d_model, 2):pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))pe = pe.unsqueeze(0) # 1*50*128self.register_buffer('pe', pe)def forward(self, x):x = x * math.sqrt(self.d_model)seq_len = x.size(1)x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()return xif __name__ == '__main__':positional_encoder = PositionalEncoder(50, 128)plt.pcolormesh(positional_encoder.pe.numpy()[0], cmap='RdBu')plt.xlabel('Depth') # 50plt.xlim((0, 128))plt.ylabel('Position') # 128plt.colorbar()plt.show()