前言
这是对上一篇WordEmbedding的续篇PositionEmbedding。
视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili
上一篇链接:Transformer模型:WordEmbedding实现-CSDN博客
正文
先回顾一下原论文中对Position Embedding的计算公式:pos表示位置,i表示维度索引,d_model表示嵌入向量的维度,position分奇数列和偶数列。
Position Embedding也是二维的,行数是训练的序列最大长度,列是d_model。首先定义position的最大长度,这里定为12,也就是训练中的长度最大值都是12。
max_position_len = 12
这里先循环遍历得到pos,构造Pos序列,pos是从0到最大长度的遍历,决定行:
pos_mat = torch.arange(max_position_len)
但是此时得到的是一维的,我们要将它转为二维矩阵的,也就是得到目标行数,使用.reshape()函数,这样就构造好了行矩阵:
pos_mat = torch.arange(max_position_len).reshape((-1,1))
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11]])
接下来要构造列矩阵,构造 i 序列,首先是是2i/d_model部分,这里的8是因为我们设定的d_model=8,2是步长:
i_mat = torch.arange(0, 8, 2)/model_dim
这时候再把分母的完整形式实现,幂次使用pow()函数:
i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)
tensor([ 1., 10., 100., 1000.])
此时就得到了列向量,这时候就有疑问了为什么列只有4列,我们的d_model不是8吗,应该有8列才对啊。这是因为区分了奇数列跟偶数列的计算,所以这里才要求步长为2生成的只有4列。
先初始化一个max_position_len*model_dim的零矩阵(12*8),然后再分别使用sin和cos填充偶数列和奇数列:
pe_embedding_table = torch.zeros(max_position_len, model_dim)pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat) # 从第0列到结束,步长为2,也就是填充偶数列
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat) # 从第1列到结束,步长为2,也就是填充奇数列
得到的就是Position Embedding的权重矩阵了:
这下面采用的是使用nn.Embedding()的方法,得到的跟上面的结果还是一样的,只不过这里的pe_embedding是可以传入位置的,之后的调用就是这样得到的:
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)
这里就要构造位置索引了:
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)
然后传入位置索引,就得到了src跟tgt的Position Embedding:
src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
这里我很疑惑的点是生成的结果src_pe_embedding跟tgt_pe_embedding内容是一样的,并且单个里面的一个内容也就是position embedding,刚入门听得我还是有点不太能理解。
src_pos is:
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], dtype=torch.int32)
tgt_pos is:
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], dtype=torch.int32)
src_pe_embedding is:
tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00],
[-7.5680e-01, -6.5364e-01, 3.8942e-01, 9.2106e-01, 3.9989e-02,
9.9920e-01, 4.0000e-03, 9.9999e-01],
[-9.5892e-01, 2.8366e-01, 4.7943e-01, 8.7758e-01, 4.9979e-02,
9.9875e-01, 5.0000e-03, 9.9999e-01],
[-2.7942e-01, 9.6017e-01, 5.6464e-01, 8.2534e-01, 5.9964e-02,
9.9820e-01, 6.0000e-03, 9.9998e-01],
[ 6.5699e-01, 7.5390e-01, 6.4422e-01, 7.6484e-01, 6.9943e-02,
9.9755e-01, 6.9999e-03, 9.9998e-01],
[ 9.8936e-01, -1.4550e-01, 7.1736e-01, 6.9671e-01, 7.9915e-02,
9.9680e-01, 7.9999e-03, 9.9997e-01],
[ 4.1212e-01, -9.1113e-01, 7.8333e-01, 6.2161e-01, 8.9879e-02,
9.9595e-01, 8.9999e-03, 9.9996e-01],
[-5.4402e-01, -8.3907e-01, 8.4147e-01, 5.4030e-01, 9.9833e-02,
9.9500e-01, 9.9998e-03, 9.9995e-01],
[-9.9999e-01, 4.4257e-03, 8.9121e-01, 4.5360e-01, 1.0978e-01,
9.9396e-01, 1.1000e-02, 9.9994e-01]],[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00],
[-7.5680e-01, -6.5364e-01, 3.8942e-01, 9.2106e-01, 3.9989e-02,
9.9920e-01, 4.0000e-03, 9.9999e-01],
[-9.5892e-01, 2.8366e-01, 4.7943e-01, 8.7758e-01, 4.9979e-02,
9.9875e-01, 5.0000e-03, 9.9999e-01],
[-2.7942e-01, 9.6017e-01, 5.6464e-01, 8.2534e-01, 5.9964e-02,
9.9820e-01, 6.0000e-03, 9.9998e-01],
[ 6.5699e-01, 7.5390e-01, 6.4422e-01, 7.6484e-01, 6.9943e-02,
9.9755e-01, 6.9999e-03, 9.9998e-01],
[ 9.8936e-01, -1.4550e-01, 7.1736e-01, 6.9671e-01, 7.9915e-02,
9.9680e-01, 7.9999e-03, 9.9997e-01],
[ 4.1212e-01, -9.1113e-01, 7.8333e-01, 6.2161e-01, 8.9879e-02,
9.9595e-01, 8.9999e-03, 9.9996e-01],
[-5.4402e-01, -8.3907e-01, 8.4147e-01, 5.4030e-01, 9.9833e-02,
9.9500e-01, 9.9998e-03, 9.9995e-01],
[-9.9999e-01, 4.4257e-03, 8.9121e-01, 4.5360e-01, 1.0978e-01,
9.9396e-01, 1.1000e-02, 9.9994e-01]]])
tgt_pe_embedding is:
tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00],
[-7.5680e-01, -6.5364e-01, 3.8942e-01, 9.2106e-01, 3.9989e-02,
9.9920e-01, 4.0000e-03, 9.9999e-01],
[-9.5892e-01, 2.8366e-01, 4.7943e-01, 8.7758e-01, 4.9979e-02,
9.9875e-01, 5.0000e-03, 9.9999e-01],
[-2.7942e-01, 9.6017e-01, 5.6464e-01, 8.2534e-01, 5.9964e-02,
9.9820e-01, 6.0000e-03, 9.9998e-01],
[ 6.5699e-01, 7.5390e-01, 6.4422e-01, 7.6484e-01, 6.9943e-02,
9.9755e-01, 6.9999e-03, 9.9998e-01],
[ 9.8936e-01, -1.4550e-01, 7.1736e-01, 6.9671e-01, 7.9915e-02,
9.9680e-01, 7.9999e-03, 9.9997e-01],
[ 4.1212e-01, -9.1113e-01, 7.8333e-01, 6.2161e-01, 8.9879e-02,
9.9595e-01, 8.9999e-03, 9.9996e-01],
[-5.4402e-01, -8.3907e-01, 8.4147e-01, 5.4030e-01, 9.9833e-02,
9.9500e-01, 9.9998e-03, 9.9995e-01],
[-9.9999e-01, 4.4257e-03, 8.9121e-01, 4.5360e-01, 1.0978e-01,
9.9396e-01, 1.1000e-02, 9.9994e-01]],[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
1.0000e+00, 0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 9.9833e-02, 9.9500e-01, 9.9998e-03,
9.9995e-01, 1.0000e-03, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 1.9867e-01, 9.8007e-01, 1.9999e-02,
9.9980e-01, 2.0000e-03, 1.0000e+00],
[ 1.4112e-01, -9.8999e-01, 2.9552e-01, 9.5534e-01, 2.9995e-02,
9.9955e-01, 3.0000e-03, 1.0000e+00],
[-7.5680e-01, -6.5364e-01, 3.8942e-01, 9.2106e-01, 3.9989e-02,
9.9920e-01, 4.0000e-03, 9.9999e-01],
[-9.5892e-01, 2.8366e-01, 4.7943e-01, 8.7758e-01, 4.9979e-02,
9.9875e-01, 5.0000e-03, 9.9999e-01],
[-2.7942e-01, 9.6017e-01, 5.6464e-01, 8.2534e-01, 5.9964e-02,
9.9820e-01, 6.0000e-03, 9.9998e-01],
[ 6.5699e-01, 7.5390e-01, 6.4422e-01, 7.6484e-01, 6.9943e-02,
9.9755e-01, 6.9999e-03, 9.9998e-01],
[ 9.8936e-01, -1.4550e-01, 7.1736e-01, 6.9671e-01, 7.9915e-02,
9.9680e-01, 7.9999e-03, 9.9997e-01],
[ 4.1212e-01, -9.1113e-01, 7.8333e-01, 6.2161e-01, 8.9879e-02,
9.9595e-01, 8.9999e-03, 9.9996e-01],
[-5.4402e-01, -8.3907e-01, 8.4147e-01, 5.4030e-01, 9.9833e-02,
9.9500e-01, 9.9998e-03, 9.9995e-01],
[-9.9999e-01, 4.4257e-03, 8.9121e-01, 4.5360e-01, 1.0978e-01,
9.9396e-01, 1.1000e-02, 9.9994e-01]]])
代码
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F# 句子数
batch_size = 2# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12# 模型的维度
model_dim = 8# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)
print(src_len)
print(tgt_len)#单词索引构成的句子
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)),(0, max_src_seg_len-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)),(0, max_tgt_seg_len-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)# 构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words+1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)
print(src_embedding)
print(tgt_embedding)# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, 8, 2)/model_dim)# 构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat/i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat/i_mat)
print("pe_embedding_table is:\n",pe_embedding_table)pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len),0) for _ in tgt_len]).to(torch.int32)
print("src_pos is:\n",src_pos)
print("tgt_pos is:\n",tgt_pos)src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print("src_pe_embedding is:\n",src_pe_embedding)
print("tgt_pe_embedding is:\n",tgt_pe_embedding)
参考
Python的reshape的用法:reshape(1,-1)-CSDN博客