python-pytorch实现CBOW 0.5.000

python-pytorch实现CBOW 0.5.000

    • 数据加载、切词
    • 准备训练数据
    • 准备模型和参数
    • 训练
    • 保存模型
    • 加载模型
    • 简单预测
    • 获取词向量
    • 降维显示图
    • 使用词向量计算相似度
    • 参考

数据加载、切词

按照链接https://blog.csdn.net/m0_60688978/article/details/137538274操作后,可以获得的数据如下

  1. wordList 文本中所有的分词,放入这个数组中
  2. raw_text 这个可以忽略,相当于wordlist的备份,防止数据污染了
  3. vocab 将wordList转变为set,即set(wordList)
  4. vocab_size 所有分词的个数
  5. word_to_idx 字典格式,汉字对应索引
  6. idx_to_word 字典格式,索引对应汉字

准备训练数据

data3 = []
for i in range(2, len(raw_text) - 2):context = [raw_text[i - 2], raw_text[i - 1],raw_text[i + 1], raw_text[i + 2]]target = raw_text[i]data3 .append((context, target))print(data3 [:5])
"""
[(['从零开始', 'Zookeeper', '高', '可靠'], '开源'), (['Zookeeper', '开源', '可靠', '分布式'], '高'), (['开源', '高', '分布式', '一致性'], '可靠'), (['高', '可靠', '一致性', '协调'], '分布式'), (['可靠', '分布式', '协调', '服务'], '一致性')]
"""

准备模型和参数

# 超参数
learning_rate = 0.003
device = torch.device('cpu')
embedding_dim = 100
epoch = 10
class CBOW(nn.Module):def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))out = self.output(out)nll_prob = F.log_softmax(out, dim=-1)return nll_probmodel = CBOW(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练

losses = []
loss_function = nn.NLLLoss()for epoch in trange(3000):total_loss = 0for context, target in data1:context_vector = make_context_vector(context, word_to_idx)target = torch.tensor([word_to_idx[target]])# 梯度清零model.zero_grad()# 开始前向传播train_predict = model(context_vector) loss = loss_function(train_predict, target)# 反向传播loss.backward()# 更新参数optimizer.step()total_loss += loss.item()if epoch % 100 ==0:print("loss is ",total_loss,"echo is ",epoch)losses.append(total_loss)
print("losses-=", losses)
"""97%|███████████████████████████████████████████████████████████████████████████▍  | 2902/3000 [07:07<00:13,  7.17it/s]
loss is  0.18700819212244824 echo is  2900
100%|██████████████████████████████████████████████████████████████████████████████| 3000/3000 [07:21<00:00,  6.79it/s]
"""

保存模型

torch.save(model.state_dict(),"model.pth")

加载模型

model = CBOW(vocab_size, embedding_dim).to(device)
model.load_state_dict(torch.load("model.pth"))
print(model)
"""
CBOW((embeddings): Embedding(179, 100)(proj): Linear(in_features=100, out_features=128, bias=True)(output): Linear(in_features=128, out_features=179, bias=True)
)
"""

简单预测

def cut_sentense(str):stop_words = load_stop_words()with open('data/zh.txt', encoding='utf8') as f:allData = f.readlines()result = []c_words = jieba.lcut(str)for word in c_words:if word not in stop_words and word != "\n":result.append(word)return resultcontext_vector = make_context_vector(cut_sentense("在Master节点使用客户端"), word_to_idx).to(device)
print(context_vector,type(context_vector))
predict = model(context_vector).data.cpu().numpy()
max_idx = np.argmax(predict)
# 输出预测的值
print('Prediction: {}'.format(idx_to_word[max_idx]))"""
输出中心词语,看上去不怎么样
tensor([120,  37,  49]) <class 'torch.Tensor'>
Prediction: 除主

获取词向量

trained_vector_dic={}
for word, idx in word_to_idx.items(): # 输出每个词的嵌入向量trained_vector_dic[word]=model.embedding.weight[idx]
"""
trained_vector_dic内容类似于
{'参数值': tensor([-3.6921e+00, -1.3388e+00,  2.4545e-03, -1.1352e+00, -1.8306e-04,-6.3501e-01, -1.4372e-01, -8.2283e-01, -1.6009e+00, -7.4731e-01,-1.3509e-01, -2.5100e-01, -1.0037e+00,  9.0061e-01,  1.7794e-01,-8.6344e-03, -1.2831e+00, -2.1400e+00,  2.7457e-01,  1.8157e-01,2.1480e-01, -2.2192e-02, -3.8433e-01,  1.3575e+00,  1.8483e+00,-6.6326e-01, -2.0239e+00, -1.9854e+00,  4.0531e-01, -1.5659e-01,-2.7774e+00, -8.2578e-02,  1.5725e+00, -9.9693e-01,  6.0748e-01,-6.4992e-01,  8.5653e-01, -1.1889e+00,  1.1657e-04, -3.3866e-01,8.2302e-02,  1.0612e-02, -8.8592e-01, -1.9495e-01, -1.2271e-01,-4.1997e+00,  1.3430e+00, -6.6779e-01, -1.7927e-01,  3.0450e-01,8.4677e-02, -9.5100e-01,  2.5847e-01,  1.1187e+00,  3.1471e+00,2.4095e+00, -1.0612e-01,  2.1663e+00, -8.5172e-01, -2.1438e-01,2.3635e-01,  4.7740e-01, -2.8115e+00, -1.5964e-01,  4.9957e-02,1.6154e-01, -7.0892e-01, -5.6724e-01, -2.2594e-01, -1.2353e+00,8.9448e-01, -1.7034e-01, -6.5750e-01,  9.8126e-01, -1.7088e+00,-1.9967e-01,  2.6574e-01, -1.3275e-01,  6.1529e-01, -3.6684e-01,1.7341e-02,  1.5207e-03, -4.8425e-01, -2.2761e-01, -2.2298e+00,-5.5302e-01,  4.4864e-01, -2.5363e-01,  3.4734e-01, -4.4062e-02,-1.3769e+00,  1.6567e-01, -7.3674e-01, -8.4163e-01,  2.9937e-01,2.3714e+00,  1.2883e+00,  1.2383e-01,  7.5008e-01, -1.3516e-01],grad_fn=<SelectBackward0>),'05': tensor([ 1.1536e+00, -2.2545e-01, -9.9584e-01,  2.0407e-02,  1.9062e+00,-5.5870e-01, -6.1779e-04,  2.7210e-01, -1.9126e+00, -8.1227e-02,-6.0733e-02, -3.3426e-03,  9.4838e-01,  3.1968e-01,  1.1331e+00,1.9320e-01,  9.8004e-01,  1.3209e-01,  3.9876e-01,  1.9894e-01,9.6364e-01, -2.9291e-01, -1.4829e+00,  1.9647e+00, -1.2805e-01,1.7458e+00,  9.1834e-02,  7.3453e-01, -1.4541e-01, -1.5197e+00,2.5946e-01,  1.1071e+00,  2.3167e-02, -9.9457e-01, -6.4125e-02,-2.1326e-01, -2.1815e+00, -8.3949e-02, -3.8223e-01,  2.0616e+00,-7.3382e-02,  2.6695e-01,  9.4765e-02, -3.2757e-01, -4.8486e-01,-3.0599e-01,  8.8235e-01,  3.1940e-01, -1.3256e-01, -6.0862e-01,4.4978e-01, -3.0902e+00,  1.6898e+00,  5.7821e-01, -5.2478e-02,4.9577e-01,  4.5494e-01,  5.6485e-04, -2.5271e+00,  3.1652e+00,-4.2832e-02, -9.9416e-02,  3.1775e-01, -1.9758e+00, -1.2955e-02,-1.6038e+00,  5.3717e-02,  2.9455e-03, -3.6091e-01, -5.7126e-01,1.6538e+00, -2.0648e+00, -3.1718e-01, -1.0939e+00,  2.4513e+00,-3.5226e-03,  8.0853e-01,  4.0330e-01,  5.2394e-01,  2.7201e+00,-2.4086e-01, -3.3241e-01,  2.9677e+00, -2.2749e-01,  3.1172e+00,7.8760e-02, -1.0339e+00,  1.4011e+00,  5.2701e-01,  8.9391e-01,2.2373e-01,  1.3236e+00, -6.5663e-02,  8.7556e-01,  2.3522e+00,-2.2826e-01, -1.4658e-01, -1.8229e+00, -6.5210e-01,  4.1831e-04],grad_fn=<SelectBackward0>),'HOME': tensor([-1.2881e+00,  9.8371e-01, -1.7626e+00,  6.8964e-02, -1.2208e+00,-7.2041e-01,  1.6493e+00,  2.4161e-01,  3.0407e-01,  1.0450e+00,-3.7338e-02,  1.2912e+00, -7.8684e-01, -8.1084e-02,  3.1615e+00,1.1677e+00, -2.7518e-01,  1.2211e+00,  5.5950e-01, -2.1043e+00,5.2210e-01, -1.7408e-01,  5.1499e-02,  7.7797e-01, -1.4519e-03,-3.4803e-02, -4.3894e-01, -3.7840e+00,  1.8685e+00,  5.1014e-01,2.8481e-04,  7.3540e-01,  4.0983e-02,  1.9889e-01,  2.2323e-01,-1.2719e+00,  9.0170e-01, -1.7608e+00,  1.2378e-04,  3.6426e-01,-2.3393e-01,  3.9977e-01,  4.6494e-01, -2.2011e+00, -2.1913e-02,-2.4567e-04, -2.4916e-01, -9.5079e-01, -2.0207e-01, -7.1489e-02,-3.2497e-02, -2.0102e-01,  5.9411e-02, -7.5153e-01, -5.1971e-01,2.7858e-01, -1.7449e-01, -2.4816e-02,  6.8960e-01,  1.3359e+00,1.4179e+00,  2.1634e-02,  4.1195e-01, -2.4597e+00, -2.2374e+00,4.7058e-01, -3.2053e-01,  1.0844e+00, -8.6147e-01,  1.6927e+00,-1.0051e-01, -2.3251e+00, -1.3552e+00, -1.3862e+00,  4.0486e-01,4.2523e-02, -8.1515e-01,  2.9837e-01, -1.6220e-02,  1.0755e-01,3.7893e-01, -1.4399e+00, -2.8273e-01, -1.4445e-01,  3.2650e-01,2.5101e+00,  2.7584e-01,  2.6028e-01,  4.5515e-03, -1.3406e+00,-6.2879e-02, -3.8538e-01, -1.9729e+00, -1.1987e+00, -1.7349e-01,-2.0273e+00,  9.5012e-01,  3.1583e-02,  1.2475e+00,  1.7564e-01],grad_fn=<SelectBackward0>)}
"""

降维显示图

这里是参考另外一篇文章见最后的章节

"""待转换类型的PyTorch Tensor变量带有梯度,直接将其转换为numpy数据将破坏计算图,因此numpy拒绝进行数据转换,实际上这是对开发者的一种提醒。如果自己在转换数据时不需要保留梯度信息,可以在变量转换之前添加detach()调用。
"""pca = PCA(n_components=2)
principalComponents = pca.fit_transform(W)# 降维后在生成一个词嵌入字典,即即{单词1:(维度一,维度二),单词2:(维度一,维度二)...}的格式
word2ReduceDimensionVec = {}
for word in word_to_idx.keys():word2ReduceDimensionVec[word] = principalComponents[word_to_idx[word], :]# 将生成的字典写入到文件中,字符集要设定utf8,不然中文乱码
with open("CBOW_ZH_wordvec.txt", 'w', encoding='utf-8') as f:for key in word_to_idx.keys():f.write('\n')f.writelines('"' + str(key) + '":' + str(word_2_vec[key]))f.write('\n')# 将词向量可视化
plt.figure(figsize=(20, 20))
# 只画出1000个,太多显示效果很差
count = 0
for word, wordvec in word2ReduceDimensionVec.items():if count < 1000:plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号,否则负号会显示成方块plt.scatter(wordvec[0], wordvec[1])plt.annotate(word, (wordvec[0], wordvec[1]))count += 1
plt.show()

在这里插入图片描述

使用词向量计算相似度

参照链接https://blog.csdn.net/m0_60688978/article/details/137535717,第五点

参考

https://blog.csdn.net/Metal1/article/details/132886936
https://blog.csdn.net/L_goodboy/article/details/136347947

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/804180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

由近期 RAGFlow 的火爆看 RAG 的现状与未来

4 月 1 日&#xff0c;InfiniFlow &#xff08;英飞流&#xff09;的端到端 RAG 解决方案 RAGFlow 正式开源&#xff0c;首日即获得了 github 千星&#xff0c;目前已接近 3000 star。在这之前&#xff0c;InfiniFlow 还开源了专门用于 RAG 场景的 AI 原生数据库 Infinity&…

用 ElementPlus 的日历组件 Calendar 自定义渲染

文章目录 需求分析1. 英文改为中文2. 修改样式3. 自定义头部4. 增删改功能接入需求 使用 ElementPlus中的 Calendar 组件完成自定义渲染 分析 1. 英文改为中文 转为中文的方式:用 ElementPlus的日历组件如何改为中文 2. 修改样式 附源码<template><el-calendar&…

linux查看硬盘空间使用情况

df &#xff08;1&#xff09;查看磁盘空间的占用情况 -h是给大小带上单位 df -h 总空间不一定等于已用未用&#xff0c;系统可能留出来一点空间另做他用 &#xff08;2&#xff09;查看INode的使用情况 df -idu du命令比df命令复杂一点&#xff0c;是查看文件和目录占用的…

部署GlusterFS群集

目录 一、部署GlusterFS群集 1. 服务器节点分配 2. 服务器环境&#xff08;所有node节点上操作&#xff09; 2.1 关闭防火墙 2.2 磁盘分区&#xff0c;并挂载 2.3 修改主机名&#xff0c;配置/etc/hosts文件 3. 安装、启动GlusterFS&#xff08;所有node节点上操作&…

51单片机入门_江协科技_25~26_OB记录的笔记_蜂鸣器教程

25. 蜂鸣器 25.1. 蜂鸣器介绍 •蜂鸣器是一种将电信号转换为声音信号的器件&#xff0c;常用来产生设备的按键音、报警音等提示信号 •蜂鸣器按驱动方式可分为有源蜂鸣器和无源蜂鸣器&#xff08;开发板上用的无源蜂鸣器&#xff09; •有源蜂鸣器&#xff1a;内部自带振荡源&a…

二:什么是RocketMQ

RocketMQ是阿里开源的消息中间件产品&#xff0c;纯Java开发&#xff0c;具有高吞吐量、高可用性、适合大规模分布式系统应用的特点,性能强劲(零拷贝技术)&#xff0c;支持海量堆积,在阿里内部进行大规模使用&#xff0c;适合在互联网与高并发系统中应用。 官方文档&#xff1a…

【Linux】虚拟化技术docker搭建SuitoCRM系统及汉化

CRM系统 CRM&#xff08;Customer Relationship Management&#xff0c;客户关系管理&#xff09;系统是一种用于管理和优化企业与客户关系的软件工具。在商业竞争激烈的现代社会中&#xff0c;CRM系统已成为许多企业提高销售、增强客户满意度和实现持续增长的重要工具。本文将…

计算器(C语言)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 介绍关键代码运行代码&#xff08;3种&#xff09; 介绍 标准计数器&#xff1a;执行加减乘除等等科学计算器&#xff1a;执行分数、统计学、指数函数、对数、三角…

BoostCompass(数据准备预处理模块)

阅读导航 一、网页数据下载二、编写数据去标签与数据清洗的模块 Parser✅boost 开发库的安装1. 基本思路2. 详细讲解&#xff08;1&#xff09;程序递归遍历目录&#xff0c;收集所有HTML文件的路径&#xff08;2&#xff09;对每个HTML文件进行解析&#xff0c;提取出文档标题…

【资源分享】书籍:现代统计学:使用Python的计算方法

::: block-1 “时问桫椤”是一个致力于为本科生到研究生教育阶段提供帮助的不太正式的公众号。我们旨在在大家感到困惑、痛苦或面临困难时伸出援手。通过总结广大研究生的经验&#xff0c;帮助大家尽早适应研究生生活&#xff0c;尽快了解科研的本质。祝一切顺利&#xff01;—…

【微服务】------微服务架构技术栈

目前微服务早已火遍大江南北&#xff0c;对于开发来说&#xff0c;我们时刻关注着技术的迭代更新&#xff0c;而项目采用什么技术栈选型落地是开发、产品都需要关注的事情&#xff0c;该篇博客主要分享一些目前普遍公司都在用的技术栈&#xff0c;快来分享一下你当前所在用的技…

Java每日一题(三道同一类型的题)

前言 本文一共有三道题:1.两数之和 2.三数之和 3. 四数之和 为什么把这三道题放一起呢&#xff0c;因为三数之和是可以根据两数之和进行推导&#xff0c;四数之和可以根据三数之和进行推导。 两数之和 思路分析: 我的思路: 1.排序 2.使用左右指针 3.处理细节问题 先让数组…

生活中的数学 --- 等额本息贷款和等额本金贷款的月供应该怎么算?

等额本息贷款和等额本金贷款的月供应该怎么算&#xff1f; 从一个例子开始&#xff0c;假设我要从银行贷款36万(即&#xff0c;本金)&#xff0c;银行给出的贷款年利率是12%(月利率为年利率除以12)&#xff0c;贷款半年(6个月)&#xff0c;按月还款&#xff0c;分6期还完。 问分…

电池二次利用走向可持续大循环周期的潜力和挑战(第二篇)

一、二次利用风险 电动汽车的当前电池信息&#xff0c;如年份、容量和制造商&#xff0c;通常是相互关联和不完整的。再加上电池内部的电化学变化&#xff0c;SLB在包括安全和环境在内的一些领域存在很大的风险&#xff0c;这表明短期内梯次利用仍然是一个不成熟的方案。 1.1 安…

在mysql中如何更新数据呢?

如何更新一条数据&#xff1f; 在 MySQL 中&#xff0c;更新一条数据可以使用 UPDATE 语句。以下是更新一条数据的基本语法&#xff1a; UPDATE table_name SET column1 value1, column2 value2,... WHERE condition;其中&#xff1a; table_name&#xff1a;要更新的表的…

Linux 系统下对于 MySQL 的初级操作

由于公司老板想把早已封存的服务器陈年老码捣鼓一下&#xff0c;所以找了一个外援&#xff0c;我则是配合提供支持。但是过程并不顺利。至少 5 年以上的间隔&#xff0c;导致外援查看的时候发现很多代码和配置是缺失的&#xff0c;目前卡在数据库部分&#xff0c;而我这边就帮忙…

libVLC 提取视频帧使用QGraphicsView渲染

在前面章节中&#xff0c;我们讲解了如何使用QWidget渲染每一帧视频数据&#xff0c;这种方法对 CPU 负荷较高。 libVLC 提取视频帧使用QWidget渲染-CSDN博客 后面又讲解了使用OpenGL渲染每一帧视频数据&#xff0c;使用 OpenGL去绘制&#xff0c;利用 GPU 减轻 CPU 计算负荷…

亚马逊AWS永久免费数据库

Amazon DynamoDB 是一项无服务器的 NoSQL 数据库服务&#xff0c;您可以通过它来开发任何规模的现代应用程序。作为无服务器数据库&#xff0c;您只需按使用量为其付费&#xff0c;DynamoDB 可以扩展到零&#xff0c;没有冷启动&#xff0c;没有版本升级&#xff0c;没有维护窗…

交换机与队列的介绍

1.流程 首先先介绍一个简单的一个消息推送到接收的流程&#xff0c;提供一个简单的图 黄色的圈圈就是我们的消息推送服务&#xff0c;将消息推送到 中间方框里面也就是 rabbitMq的服务器&#xff0c;然后经过服务器里面的交换机、队列等各种关系&#xff08;后面会详细讲&…

RabbitMQ如何保证消息的幂等性???

在RabbitMQ中&#xff0c;保证消费者的幂等性主要依赖于业务设计和实现&#xff0c;而非RabbitMQ本身提供的一种直接功能。 在基于Spring Boot整合RabbitMQ的场景下&#xff0c;要保证消费者的幂等性&#xff0c;通常需要结合业务逻辑设计以及额外的技术手段来实现。以下是一个…