torch.nn.Embedding
flyfish
This module is typically used to store word embeddings and retrieve them by index. The module's input is a list of indices, and its output is the corresponding word embeddings.
import torch
import torch.nn as nn
torch.manual_seed(0)
embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
print(embedding.weight)
# tensor([[-0.7588, -0.0094, -0.8549],
# [-1.9320, -0.1008, 1.1125],
# [-0.7327, 0.5621, 0.2356],
# [-1.6812, -0.2477, 0.1624],
# [ 0.5170, 0.0979, -0.3463],
# [ 0.4478, -1.3857, 1.8448],
# [-1.2102, -0.5387, -1.8527],
# [-0.1349, -0.6765, -2.4845],
# [-1.5077, 0.4549, -0.9425],
# [ 1.9715, 0.9959, 0.0415]], requires_grad=True)
# 10 is num_embeddings (the size of the embedding table); 3 is embedding_dim (the size of each embedding vector)
input = torch.LongTensor([[1, 2, 4, 5]]) # a batch of 1 sample with 4 indices
e = embedding(input)
print(e)
# Indexing starts at 0: index 1 is [-1.9320, -0.1008,  1.1125],
# index 2 is [-0.7327,  0.5621,  0.2356],
# index 4 is [ 0.5170,  0.0979, -0.3463],
# index 5 is [ 0.4478, -1.3857,  1.8448]
# tensor([[[-1.9320, -0.1008,  1.1125],
#          [-0.7327,  0.5621,  0.2356],
#          [ 0.5170,  0.0979, -0.3463],
#          [ 0.4478, -1.3857,  1.8448]]], grad_fn=<EmbeddingBackward0>)
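The lookup is just row selection from the weight matrix, so the same result can be reproduced by indexing embedding.weight directly. A minimal sketch continuing the example above; torch.nn.functional.embedding is the functional form behind the module:
import torch.nn.functional as F
manual = embedding.weight[input]               # direct row indexing, shape (1, 4, 3)
functional = F.embedding(input, embedding.weight)
print(torch.allclose(e, manual))               # True
print(torch.allclose(e, functional))           # True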
Where do the values of embedding.weight come from? They are initialized with nn.init.normal_, which fills the tensor with values drawn from a normal distribution. With the random seed fixed, you get the same values:
torch.manual_seed(0)
w = torch.empty(10, 3)
print(nn.init.normal_(w, 0, 1)) # same values as embedding.weight above
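To verify the claim explicitly, a small check sketch (it assumes nothing else consumes the random number generator between the seed call and the initialization):
torch.manual_seed(0)
embedding = nn.Embedding(10, 3)    # weight filled by nn.init.normal_ internally

torch.manual_seed(0)
w = torch.empty(10, 3)
nn.init.normal_(w, 0, 1)           # same normal(0, 1) draw as the Embedding init

print(torch.allclose(embedding.weight, w))  # True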
Reference
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html