python-pytorch实现CBOW 0.5.000
- 数据加载、切词
- 准备训练数据
- 准备模型和参数
- 训练
- 保存模型
- 加载模型
- 简单预测
- 获取词向量
- 降维显示图
- 使用词向量计算相似度
- 参考
数据加载、切词
按照链接https://blog.csdn.net/m0_60688978/article/details/137538274操作后,可以获得的数据如下
- wordList 文本中所有的分词,放入这个数组中
- raw_text 这个可以忽略,相当于wordlist的备份,防止数据污染了
- vocab 将wordList转变为set,即set(wordList)
- vocab_size 所有分词的个数
- word_to_idx 字典格式,汉字对应索引
- idx_to_word 字典格式,索引对应汉字
准备训练数据
data3 = []
for i in range(2, len(raw_text) - 2):context = [raw_text[i - 2], raw_text[i - 1],raw_text[i + 1], raw_text[i + 2]]target = raw_text[i]data3 .append((context, target))print(data3 [:5])
"""
[(['从零开始', 'Zookeeper', '高', '可靠'], '开源'), (['Zookeeper', '开源', '可靠', '分布式'], '高'), (['开源', '高', '分布式', '一致性'], '可靠'), (['高', '可靠', '一致性', '协调'], '分布式'), (['可靠', '分布式', '协调', '服务'], '一致性')]
"""
准备模型和参数
# 超参数
learning_rate = 0.003
device = torch.device('cpu')
embedding_dim = 100
epoch = 10
class CBOW(nn.Module):def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))out = self.output(out)nll_prob = F.log_softmax(out, dim=-1)return nll_probmodel = CBOW(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
训练
losses = []
loss_function = nn.NLLLoss()for epoch in trange(3000):total_loss = 0for context, target in data1:context_vector = make_context_vector(context, word_to_idx)target = torch.tensor([word_to_idx[target]])# 梯度清零model.zero_grad()# 开始前向传播train_predict = model(context_vector) loss = loss_function(train_predict, target)# 反向传播loss.backward()# 更新参数optimizer.step()total_loss += loss.item()if epoch % 100 ==0:print("loss is ",total_loss,"echo is ",epoch)losses.append(total_loss)
print("losses-=", losses)
"""97%|███████████████████████████████████████████████████████████████████████████▍ | 2902/3000 [07:07<00:13, 7.17it/s]
loss is 0.18700819212244824 echo is 2900
100%|██████████████████████████████████████████████████████████████████████████████| 3000/3000 [07:21<00:00, 6.79it/s]
"""
保存模型
torch.save(model.state_dict(),"model.pth")
加载模型
model = CBOW(vocab_size, embedding_dim).to(device)
model.load_state_dict(torch.load("model.pth"))
print(model)
"""
CBOW((embeddings): Embedding(179, 100)(proj): Linear(in_features=100, out_features=128, bias=True)(output): Linear(in_features=128, out_features=179, bias=True)
)
"""
简单预测
def cut_sentense(str):stop_words = load_stop_words()with open('data/zh.txt', encoding='utf8') as f:allData = f.readlines()result = []c_words = jieba.lcut(str)for word in c_words:if word not in stop_words and word != "\n":result.append(word)return resultcontext_vector = make_context_vector(cut_sentense("在Master节点使用客户端"), word_to_idx).to(device)
print(context_vector,type(context_vector))
predict = model(context_vector).data.cpu().numpy()
max_idx = np.argmax(predict)
# 输出预测的值
print('Prediction: {}'.format(idx_to_word[max_idx]))"""
输出中心词语,看上去不怎么样
tensor([120, 37, 49]) <class 'torch.Tensor'>
Prediction: 除主
获取词向量
trained_vector_dic={}
for word, idx in word_to_idx.items(): # 输出每个词的嵌入向量trained_vector_dic[word]=model.embedding.weight[idx]
"""
trained_vector_dic内容类似于
{'参数值': tensor([-3.6921e+00, -1.3388e+00, 2.4545e-03, -1.1352e+00, -1.8306e-04,-6.3501e-01, -1.4372e-01, -8.2283e-01, -1.6009e+00, -7.4731e-01,-1.3509e-01, -2.5100e-01, -1.0037e+00, 9.0061e-01, 1.7794e-01,-8.6344e-03, -1.2831e+00, -2.1400e+00, 2.7457e-01, 1.8157e-01,2.1480e-01, -2.2192e-02, -3.8433e-01, 1.3575e+00, 1.8483e+00,-6.6326e-01, -2.0239e+00, -1.9854e+00, 4.0531e-01, -1.5659e-01,-2.7774e+00, -8.2578e-02, 1.5725e+00, -9.9693e-01, 6.0748e-01,-6.4992e-01, 8.5653e-01, -1.1889e+00, 1.1657e-04, -3.3866e-01,8.2302e-02, 1.0612e-02, -8.8592e-01, -1.9495e-01, -1.2271e-01,-4.1997e+00, 1.3430e+00, -6.6779e-01, -1.7927e-01, 3.0450e-01,8.4677e-02, -9.5100e-01, 2.5847e-01, 1.1187e+00, 3.1471e+00,2.4095e+00, -1.0612e-01, 2.1663e+00, -8.5172e-01, -2.1438e-01,2.3635e-01, 4.7740e-01, -2.8115e+00, -1.5964e-01, 4.9957e-02,1.6154e-01, -7.0892e-01, -5.6724e-01, -2.2594e-01, -1.2353e+00,8.9448e-01, -1.7034e-01, -6.5750e-01, 9.8126e-01, -1.7088e+00,-1.9967e-01, 2.6574e-01, -1.3275e-01, 6.1529e-01, -3.6684e-01,1.7341e-02, 1.5207e-03, -4.8425e-01, -2.2761e-01, -2.2298e+00,-5.5302e-01, 4.4864e-01, -2.5363e-01, 3.4734e-01, -4.4062e-02,-1.3769e+00, 1.6567e-01, -7.3674e-01, -8.4163e-01, 2.9937e-01,2.3714e+00, 1.2883e+00, 1.2383e-01, 7.5008e-01, -1.3516e-01],grad_fn=<SelectBackward0>),'05': tensor([ 1.1536e+00, -2.2545e-01, -9.9584e-01, 2.0407e-02, 1.9062e+00,-5.5870e-01, -6.1779e-04, 2.7210e-01, -1.9126e+00, -8.1227e-02,-6.0733e-02, -3.3426e-03, 9.4838e-01, 3.1968e-01, 1.1331e+00,1.9320e-01, 9.8004e-01, 1.3209e-01, 3.9876e-01, 1.9894e-01,9.6364e-01, -2.9291e-01, -1.4829e+00, 1.9647e+00, -1.2805e-01,1.7458e+00, 9.1834e-02, 7.3453e-01, -1.4541e-01, -1.5197e+00,2.5946e-01, 1.1071e+00, 2.3167e-02, -9.9457e-01, -6.4125e-02,-2.1326e-01, -2.1815e+00, -8.3949e-02, -3.8223e-01, 2.0616e+00,-7.3382e-02, 2.6695e-01, 9.4765e-02, -3.2757e-01, -4.8486e-01,-3.0599e-01, 8.8235e-01, 3.1940e-01, -1.3256e-01, -6.0862e-01,4.4978e-01, -3.0902e+00, 1.6898e+00, 5.7821e-01, -5.2478e-02,4.9577e-01, 4.5494e-01, 5.6485e-04, -2.5271e+00, 3.1652e+00,-4.2832e-02, -9.9416e-02, 3.1775e-01, -1.9758e+00, -1.2955e-02,-1.6038e+00, 5.3717e-02, 2.9455e-03, -3.6091e-01, -5.7126e-01,1.6538e+00, -2.0648e+00, -3.1718e-01, -1.0939e+00, 2.4513e+00,-3.5226e-03, 8.0853e-01, 4.0330e-01, 5.2394e-01, 2.7201e+00,-2.4086e-01, -3.3241e-01, 2.9677e+00, -2.2749e-01, 3.1172e+00,7.8760e-02, -1.0339e+00, 1.4011e+00, 5.2701e-01, 8.9391e-01,2.2373e-01, 1.3236e+00, -6.5663e-02, 8.7556e-01, 2.3522e+00,-2.2826e-01, -1.4658e-01, -1.8229e+00, -6.5210e-01, 4.1831e-04],grad_fn=<SelectBackward0>),'HOME': tensor([-1.2881e+00, 9.8371e-01, -1.7626e+00, 6.8964e-02, -1.2208e+00,-7.2041e-01, 1.6493e+00, 2.4161e-01, 3.0407e-01, 1.0450e+00,-3.7338e-02, 1.2912e+00, -7.8684e-01, -8.1084e-02, 3.1615e+00,1.1677e+00, -2.7518e-01, 1.2211e+00, 5.5950e-01, -2.1043e+00,5.2210e-01, -1.7408e-01, 5.1499e-02, 7.7797e-01, -1.4519e-03,-3.4803e-02, -4.3894e-01, -3.7840e+00, 1.8685e+00, 5.1014e-01,2.8481e-04, 7.3540e-01, 4.0983e-02, 1.9889e-01, 2.2323e-01,-1.2719e+00, 9.0170e-01, -1.7608e+00, 1.2378e-04, 3.6426e-01,-2.3393e-01, 3.9977e-01, 4.6494e-01, -2.2011e+00, -2.1913e-02,-2.4567e-04, -2.4916e-01, -9.5079e-01, -2.0207e-01, -7.1489e-02,-3.2497e-02, -2.0102e-01, 5.9411e-02, -7.5153e-01, -5.1971e-01,2.7858e-01, -1.7449e-01, -2.4816e-02, 6.8960e-01, 1.3359e+00,1.4179e+00, 2.1634e-02, 4.1195e-01, -2.4597e+00, -2.2374e+00,4.7058e-01, -3.2053e-01, 1.0844e+00, -8.6147e-01, 1.6927e+00,-1.0051e-01, -2.3251e+00, -1.3552e+00, -1.3862e+00, 4.0486e-01,4.2523e-02, -8.1515e-01, 2.9837e-01, -1.6220e-02, 1.0755e-01,3.7893e-01, -1.4399e+00, -2.8273e-01, -1.4445e-01, 3.2650e-01,2.5101e+00, 2.7584e-01, 2.6028e-01, 4.5515e-03, -1.3406e+00,-6.2879e-02, -3.8538e-01, -1.9729e+00, -1.1987e+00, -1.7349e-01,-2.0273e+00, 9.5012e-01, 3.1583e-02, 1.2475e+00, 1.7564e-01],grad_fn=<SelectBackward0>)}
"""
降维显示图
这里是参考另外一篇文章见最后的章节
"""待转换类型的PyTorch Tensor变量带有梯度,直接将其转换为numpy数据将破坏计算图,因此numpy拒绝进行数据转换,实际上这是对开发者的一种提醒。如果自己在转换数据时不需要保留梯度信息,可以在变量转换之前添加detach()调用。
"""pca = PCA(n_components=2)
principalComponents = pca.fit_transform(W)# 降维后在生成一个词嵌入字典,即即{单词1:(维度一,维度二),单词2:(维度一,维度二)...}的格式
word2ReduceDimensionVec = {}
for word in word_to_idx.keys():word2ReduceDimensionVec[word] = principalComponents[word_to_idx[word], :]# 将生成的字典写入到文件中,字符集要设定utf8,不然中文乱码
with open("CBOW_ZH_wordvec.txt", 'w', encoding='utf-8') as f:for key in word_to_idx.keys():f.write('\n')f.writelines('"' + str(key) + '":' + str(word_2_vec[key]))f.write('\n')# 将词向量可视化
plt.figure(figsize=(20, 20))
# 只画出1000个,太多显示效果很差
count = 0
for word, wordvec in word2ReduceDimensionVec.items():if count < 1000:plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号,否则负号会显示成方块plt.scatter(wordvec[0], wordvec[1])plt.annotate(word, (wordvec[0], wordvec[1]))count += 1
plt.show()
使用词向量计算相似度
参照链接https://blog.csdn.net/m0_60688978/article/details/137535717,第五点
参考
https://blog.csdn.net/Metal1/article/details/132886936
https://blog.csdn.net/L_goodboy/article/details/136347947