python-pytorch实现CBOW 0.5.000

python-pytorch实现CBOW 0.5.000

    • 数据加载、切词
    • 准备训练数据
    • 准备模型和参数
    • 训练
    • 保存模型
    • 加载模型
    • 简单预测
    • 获取词向量
    • 降维显示图
    • 使用词向量计算相似度
    • 参考

数据加载、切词

按照链接https://blog.csdn.net/m0_60688978/article/details/137538274操作后,可以获得的数据如下

  1. wordList 文本中所有的分词,放入这个数组中
  2. raw_text 这个可以忽略,相当于wordlist的备份,防止数据污染了
  3. vocab 将wordList转变为set,即set(wordList)
  4. vocab_size 所有分词的个数
  5. word_to_idx 字典格式,汉字对应索引
  6. idx_to_word 字典格式,索引对应汉字

准备训练数据

data3 = []
for i in range(2, len(raw_text) - 2):context = [raw_text[i - 2], raw_text[i - 1],raw_text[i + 1], raw_text[i + 2]]target = raw_text[i]data3 .append((context, target))print(data3 [:5])
"""
[(['从零开始', 'Zookeeper', '高', '可靠'], '开源'), (['Zookeeper', '开源', '可靠', '分布式'], '高'), (['开源', '高', '分布式', '一致性'], '可靠'), (['高', '可靠', '一致性', '协调'], '分布式'), (['可靠', '分布式', '协调', '服务'], '一致性')]
"""

准备模型和参数

# 超参数
learning_rate = 0.003
device = torch.device('cpu')
embedding_dim = 100
epoch = 10
class CBOW(nn.Module):def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))out = self.output(out)nll_prob = F.log_softmax(out, dim=-1)return nll_probmodel = CBOW(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练

losses = []
loss_function = nn.NLLLoss()for epoch in trange(3000):total_loss = 0for context, target in data1:context_vector = make_context_vector(context, word_to_idx)target = torch.tensor([word_to_idx[target]])# 梯度清零model.zero_grad()# 开始前向传播train_predict = model(context_vector) loss = loss_function(train_predict, target)# 反向传播loss.backward()# 更新参数optimizer.step()total_loss += loss.item()if epoch % 100 ==0:print("loss is ",total_loss,"echo is ",epoch)losses.append(total_loss)
print("losses-=", losses)
"""97%|███████████████████████████████████████████████████████████████████████████▍  | 2902/3000 [07:07<00:13,  7.17it/s]
loss is  0.18700819212244824 echo is  2900
100%|██████████████████████████████████████████████████████████████████████████████| 3000/3000 [07:21<00:00,  6.79it/s]
"""

保存模型

torch.save(model.state_dict(),"model.pth")

加载模型

model = CBOW(vocab_size, embedding_dim).to(device)
model.load_state_dict(torch.load("model.pth"))
print(model)
"""
CBOW((embeddings): Embedding(179, 100)(proj): Linear(in_features=100, out_features=128, bias=True)(output): Linear(in_features=128, out_features=179, bias=True)
)
"""

简单预测

def cut_sentense(str):stop_words = load_stop_words()with open('data/zh.txt', encoding='utf8') as f:allData = f.readlines()result = []c_words = jieba.lcut(str)for word in c_words:if word not in stop_words and word != "\n":result.append(word)return resultcontext_vector = make_context_vector(cut_sentense("在Master节点使用客户端"), word_to_idx).to(device)
print(context_vector,type(context_vector))
predict = model(context_vector).data.cpu().numpy()
max_idx = np.argmax(predict)
# 输出预测的值
print('Prediction: {}'.format(idx_to_word[max_idx]))"""
输出中心词语,看上去不怎么样
tensor([120,  37,  49]) <class 'torch.Tensor'>
Prediction: 除主

获取词向量

trained_vector_dic={}
for word, idx in word_to_idx.items(): # 输出每个词的嵌入向量trained_vector_dic[word]=model.embedding.weight[idx]
"""
trained_vector_dic内容类似于
{'参数值': tensor([-3.6921e+00, -1.3388e+00,  2.4545e-03, -1.1352e+00, -1.8306e-04,-6.3501e-01, -1.4372e-01, -8.2283e-01, -1.6009e+00, -7.4731e-01,-1.3509e-01, -2.5100e-01, -1.0037e+00,  9.0061e-01,  1.7794e-01,-8.6344e-03, -1.2831e+00, -2.1400e+00,  2.7457e-01,  1.8157e-01,2.1480e-01, -2.2192e-02, -3.8433e-01,  1.3575e+00,  1.8483e+00,-6.6326e-01, -2.0239e+00, -1.9854e+00,  4.0531e-01, -1.5659e-01,-2.7774e+00, -8.2578e-02,  1.5725e+00, -9.9693e-01,  6.0748e-01,-6.4992e-01,  8.5653e-01, -1.1889e+00,  1.1657e-04, -3.3866e-01,8.2302e-02,  1.0612e-02, -8.8592e-01, -1.9495e-01, -1.2271e-01,-4.1997e+00,  1.3430e+00, -6.6779e-01, -1.7927e-01,  3.0450e-01,8.4677e-02, -9.5100e-01,  2.5847e-01,  1.1187e+00,  3.1471e+00,2.4095e+00, -1.0612e-01,  2.1663e+00, -8.5172e-01, -2.1438e-01,2.3635e-01,  4.7740e-01, -2.8115e+00, -1.5964e-01,  4.9957e-02,1.6154e-01, -7.0892e-01, -5.6724e-01, -2.2594e-01, -1.2353e+00,8.9448e-01, -1.7034e-01, -6.5750e-01,  9.8126e-01, -1.7088e+00,-1.9967e-01,  2.6574e-01, -1.3275e-01,  6.1529e-01, -3.6684e-01,1.7341e-02,  1.5207e-03, -4.8425e-01, -2.2761e-01, -2.2298e+00,-5.5302e-01,  4.4864e-01, -2.5363e-01,  3.4734e-01, -4.4062e-02,-1.3769e+00,  1.6567e-01, -7.3674e-01, -8.4163e-01,  2.9937e-01,2.3714e+00,  1.2883e+00,  1.2383e-01,  7.5008e-01, -1.3516e-01],grad_fn=<SelectBackward0>),'05': tensor([ 1.1536e+00, -2.2545e-01, -9.9584e-01,  2.0407e-02,  1.9062e+00,-5.5870e-01, -6.1779e-04,  2.7210e-01, -1.9126e+00, -8.1227e-02,-6.0733e-02, -3.3426e-03,  9.4838e-01,  3.1968e-01,  1.1331e+00,1.9320e-01,  9.8004e-01,  1.3209e-01,  3.9876e-01,  1.9894e-01,9.6364e-01, -2.9291e-01, -1.4829e+00,  1.9647e+00, -1.2805e-01,1.7458e+00,  9.1834e-02,  7.3453e-01, -1.4541e-01, -1.5197e+00,2.5946e-01,  1.1071e+00,  2.3167e-02, -9.9457e-01, -6.4125e-02,-2.1326e-01, -2.1815e+00, -8.3949e-02, -3.8223e-01,  2.0616e+00,-7.3382e-02,  2.6695e-01,  9.4765e-02, -3.2757e-01, -4.8486e-01,-3.0599e-01,  8.8235e-01,  3.1940e-01, -1.3256e-01, -6.0862e-01,4.4978e-01, -3.0902e+00,  1.6898e+00,  5.7821e-01, -5.2478e-02,4.9577e-01,  4.5494e-01,  5.6485e-04, -2.5271e+00,  3.1652e+00,-4.2832e-02, -9.9416e-02,  3.1775e-01, -1.9758e+00, -1.2955e-02,-1.6038e+00,  5.3717e-02,  2.9455e-03, -3.6091e-01, -5.7126e-01,1.6538e+00, -2.0648e+00, -3.1718e-01, -1.0939e+00,  2.4513e+00,-3.5226e-03,  8.0853e-01,  4.0330e-01,  5.2394e-01,  2.7201e+00,-2.4086e-01, -3.3241e-01,  2.9677e+00, -2.2749e-01,  3.1172e+00,7.8760e-02, -1.0339e+00,  1.4011e+00,  5.2701e-01,  8.9391e-01,2.2373e-01,  1.3236e+00, -6.5663e-02,  8.7556e-01,  2.3522e+00,-2.2826e-01, -1.4658e-01, -1.8229e+00, -6.5210e-01,  4.1831e-04],grad_fn=<SelectBackward0>),'HOME': tensor([-1.2881e+00,  9.8371e-01, -1.7626e+00,  6.8964e-02, -1.2208e+00,-7.2041e-01,  1.6493e+00,  2.4161e-01,  3.0407e-01,  1.0450e+00,-3.7338e-02,  1.2912e+00, -7.8684e-01, -8.1084e-02,  3.1615e+00,1.1677e+00, -2.7518e-01,  1.2211e+00,  5.5950e-01, -2.1043e+00,5.2210e-01, -1.7408e-01,  5.1499e-02,  7.7797e-01, -1.4519e-03,-3.4803e-02, -4.3894e-01, -3.7840e+00,  1.8685e+00,  5.1014e-01,2.8481e-04,  7.3540e-01,  4.0983e-02,  1.9889e-01,  2.2323e-01,-1.2719e+00,  9.0170e-01, -1.7608e+00,  1.2378e-04,  3.6426e-01,-2.3393e-01,  3.9977e-01,  4.6494e-01, -2.2011e+00, -2.1913e-02,-2.4567e-04, -2.4916e-01, -9.5079e-01, -2.0207e-01, -7.1489e-02,-3.2497e-02, -2.0102e-01,  5.9411e-02, -7.5153e-01, -5.1971e-01,2.7858e-01, -1.7449e-01, -2.4816e-02,  6.8960e-01,  1.3359e+00,1.4179e+00,  2.1634e-02,  4.1195e-01, -2.4597e+00, -2.2374e+00,4.7058e-01, -3.2053e-01,  1.0844e+00, -8.6147e-01,  1.6927e+00,-1.0051e-01, -2.3251e+00, -1.3552e+00, -1.3862e+00,  4.0486e-01,4.2523e-02, -8.1515e-01,  2.9837e-01, -1.6220e-02,  1.0755e-01,3.7893e-01, -1.4399e+00, -2.8273e-01, -1.4445e-01,  3.2650e-01,2.5101e+00,  2.7584e-01,  2.6028e-01,  4.5515e-03, -1.3406e+00,-6.2879e-02, -3.8538e-01, -1.9729e+00, -1.1987e+00, -1.7349e-01,-2.0273e+00,  9.5012e-01,  3.1583e-02,  1.2475e+00,  1.7564e-01],grad_fn=<SelectBackward0>)}
"""

降维显示图

这里是参考另外一篇文章见最后的章节

"""待转换类型的PyTorch Tensor变量带有梯度,直接将其转换为numpy数据将破坏计算图,因此numpy拒绝进行数据转换,实际上这是对开发者的一种提醒。如果自己在转换数据时不需要保留梯度信息,可以在变量转换之前添加detach()调用。
"""pca = PCA(n_components=2)
principalComponents = pca.fit_transform(W)# 降维后在生成一个词嵌入字典,即即{单词1:(维度一,维度二),单词2:(维度一,维度二)...}的格式
word2ReduceDimensionVec = {}
for word in word_to_idx.keys():word2ReduceDimensionVec[word] = principalComponents[word_to_idx[word], :]# 将生成的字典写入到文件中,字符集要设定utf8,不然中文乱码
with open("CBOW_ZH_wordvec.txt", 'w', encoding='utf-8') as f:for key in word_to_idx.keys():f.write('\n')f.writelines('"' + str(key) + '":' + str(word_2_vec[key]))f.write('\n')# 将词向量可视化
plt.figure(figsize=(20, 20))
# 只画出1000个,太多显示效果很差
count = 0
for word, wordvec in word2ReduceDimensionVec.items():if count < 1000:plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号,否则负号会显示成方块plt.scatter(wordvec[0], wordvec[1])plt.annotate(word, (wordvec[0], wordvec[1]))count += 1
plt.show()

在这里插入图片描述

使用词向量计算相似度

参照链接https://blog.csdn.net/m0_60688978/article/details/137535717,第五点

参考

https://blog.csdn.net/Metal1/article/details/132886936
https://blog.csdn.net/L_goodboy/article/details/136347947

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/804180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

由近期 RAGFlow 的火爆看 RAG 的现状与未来

4 月 1 日&#xff0c;InfiniFlow &#xff08;英飞流&#xff09;的端到端 RAG 解决方案 RAGFlow 正式开源&#xff0c;首日即获得了 github 千星&#xff0c;目前已接近 3000 star。在这之前&#xff0c;InfiniFlow 还开源了专门用于 RAG 场景的 AI 原生数据库 Infinity&…

用 ElementPlus 的日历组件 Calendar 自定义渲染

文章目录 需求分析1. 英文改为中文2. 修改样式3. 自定义头部4. 增删改功能接入需求 使用 ElementPlus中的 Calendar 组件完成自定义渲染 分析 1. 英文改为中文 转为中文的方式:用 ElementPlus的日历组件如何改为中文 2. 修改样式 附源码<template><el-calendar&…

[工程经验] 模块设计规范

模块设计规范 文章目录 模块设计规范1.需求2.概念与逻辑图3.主要的数据结构图4.算法5.接口定义 1.需求 根据需求文档&#xff0c;摘录模块的对应部分&#xff0c;细化到可指导开发的程度&#xff0c;并根据实现的需要进行拓展&#xff0c;落地为一份设计文档。 2.概念与逻辑图…

linux查看硬盘空间使用情况

df &#xff08;1&#xff09;查看磁盘空间的占用情况 -h是给大小带上单位 df -h 总空间不一定等于已用未用&#xff0c;系统可能留出来一点空间另做他用 &#xff08;2&#xff09;查看INode的使用情况 df -idu du命令比df命令复杂一点&#xff0c;是查看文件和目录占用的…

【排序算法】七、快速排序补充:三指针+随机数法

「前言」文章内容是对快速排序算法的补充&#xff0c;之前的算法流程细节多难处理&#xff0c;这里补充三指针随机数法&#xff08;递归&#xff09;&#xff0c;这个容易理解&#xff0c;在时间复杂度上也更优秀 「归属专栏」排序算法 「主页链接」个人主页 「笔者」枫叶先生(…

Docker-compose部署Alertmanager+Dingtalk+Prometheus+Grafana实现钉钉报警

部署监控 version: 3.7services: #dingtalkdingtalk:image: timonwong/prometheus-webhook-dingtalk:latestcontainer_name: dingtalkrestart: alwayscommand:- --config.file/etc/prometheus-webhook-dingtalk/config.ymlvolumes:- /data/monitor/dingtalk/config.yml:/etc/p…

部署GlusterFS群集

目录 一、部署GlusterFS群集 1. 服务器节点分配 2. 服务器环境&#xff08;所有node节点上操作&#xff09; 2.1 关闭防火墙 2.2 磁盘分区&#xff0c;并挂载 2.3 修改主机名&#xff0c;配置/etc/hosts文件 3. 安装、启动GlusterFS&#xff08;所有node节点上操作&…

51单片机入门_江协科技_25~26_OB记录的笔记_蜂鸣器教程

25. 蜂鸣器 25.1. 蜂鸣器介绍 •蜂鸣器是一种将电信号转换为声音信号的器件&#xff0c;常用来产生设备的按键音、报警音等提示信号 •蜂鸣器按驱动方式可分为有源蜂鸣器和无源蜂鸣器&#xff08;开发板上用的无源蜂鸣器&#xff09; •有源蜂鸣器&#xff1a;内部自带振荡源&a…

二:什么是RocketMQ

RocketMQ是阿里开源的消息中间件产品&#xff0c;纯Java开发&#xff0c;具有高吞吐量、高可用性、适合大规模分布式系统应用的特点,性能强劲(零拷贝技术)&#xff0c;支持海量堆积,在阿里内部进行大规模使用&#xff0c;适合在互联网与高并发系统中应用。 官方文档&#xff1a…

【Linux】虚拟化技术docker搭建SuitoCRM系统及汉化

CRM系统 CRM&#xff08;Customer Relationship Management&#xff0c;客户关系管理&#xff09;系统是一种用于管理和优化企业与客户关系的软件工具。在商业竞争激烈的现代社会中&#xff0c;CRM系统已成为许多企业提高销售、增强客户满意度和实现持续增长的重要工具。本文将…

Hive-生产常用操作-表操作和数据处理技巧-202404

hive语句操作 我这个只涉及到hive的对表的操作&#xff0c;包括建表&#xff0c;建分区表&#xff0c;加载数据&#xff0c;导出数据&#xff0c;查询数据&#xff0c;删除数据&#xff0c;插入数据&#xff0c;以及对hive分区表的操作&#xff0c;包括查看分区&#xff0c;添加…

【宝德PI300T G2智能小站开发教程(二)】命令行linux如何挂载移动硬盘

目录 一.前言 二.步骤 1.查找移动硬盘: 2.建立挂载点 3.挂载 4.进入硬盘 5.解除挂载 一.前言 Linux中的挂载是将存储设备(如硬盘、分区、USB驱动器等)与文件系统关联起来,以便能够访问和使用其存储空间。 二.步骤 1.查找移动硬盘:

数据检索的优化之道:B树与B+树的深度解析与应用探索

1、引言 在信息时代&#xff0c;数据检索的速度和效率对于任何依赖数据处理的系统来说都至关重要。无论是在线搜索引擎、数据库管理系统还是文件存储系统&#xff0c;快速准确地检索所需数据都是核心需求。传统的线性数据结构在处理大规模数据集时往往力不从心&#xff0c;因此…

计算器(C语言)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 介绍关键代码运行代码&#xff08;3种&#xff09; 介绍 标准计数器&#xff1a;执行加减乘除等等科学计算器&#xff1a;执行分数、统计学、指数函数、对数、三角…

封装Element-Plus表单组件

业务组件 <template><m-form ref="form":options="options" label-width="100px"@on-preview="handlePreview"@on-remove="handleRemove"@before-remove="beforeRemove"@on-exceed="handleExceed&…

如何学习JVM的知识

文章目录 1. 明确学习JVM知识的目的2. 高屋建瓴地审视知识点 1. 明确学习JVM知识的目的 为什么需要学习jvm的知识&#xff1f; jvm的知识重点是内存分配和垃圾回收&#xff0c;这些都是能更深入理解java代码运行原理的关键&#xff0c;也是求职面试中绕不过去的一个坎。 并且它…

BoostCompass(数据准备预处理模块)

阅读导航 一、网页数据下载二、编写数据去标签与数据清洗的模块 Parser✅boost 开发库的安装1. 基本思路2. 详细讲解&#xff08;1&#xff09;程序递归遍历目录&#xff0c;收集所有HTML文件的路径&#xff08;2&#xff09;对每个HTML文件进行解析&#xff0c;提取出文档标题…

【资源分享】书籍:现代统计学:使用Python的计算方法

::: block-1 “时问桫椤”是一个致力于为本科生到研究生教育阶段提供帮助的不太正式的公众号。我们旨在在大家感到困惑、痛苦或面临困难时伸出援手。通过总结广大研究生的经验&#xff0c;帮助大家尽早适应研究生生活&#xff0c;尽快了解科研的本质。祝一切顺利&#xff01;—…

【微服务】------微服务架构技术栈

目前微服务早已火遍大江南北&#xff0c;对于开发来说&#xff0c;我们时刻关注着技术的迭代更新&#xff0c;而项目采用什么技术栈选型落地是开发、产品都需要关注的事情&#xff0c;该篇博客主要分享一些目前普遍公司都在用的技术栈&#xff0c;快来分享一下你当前所在用的技…

深入理解与实践:npm常用命令全面解析

引言 npm的重要性&#xff1a;简要介绍npm&#xff08;Node Package Manager&#xff09;作为Node.js生态系统的基石&#xff0c;其在JavaScript开发中的角色和作用。npm的功能概述&#xff1a;包管理和发布、依赖管理、版本控制、脚本执行等核心功能说明。 一、npm基础操作 …