Word2Vec的CBOW模型

Word2Vec中的CBOW(Continuous Bag of Words)模型是一种用于学习词向量的神经网络模型。CBOW的核心思想是根据上下文中的周围单词来预测目标单词。

例如,对于句子“The cat climbed up the tree”,如果窗口大小为5,那么当中心单词为“climbed”时,上下文单词为“The”、“cat”、“up”和“the”。CBOW模型要求根据这四个上下文单词,计算出“climbed”的概率分布。

一个简单的CBOW模型

import torch
import torch.nn as nn
import torch.optim as optim# 定义CBOW模型
class CBOWModel(nn.Module):def __init__(self, vocab_size, embed_size):super(CBOWModel, self).__init__()self.embeddings = nn.Embedding(vocab_size, embed_size)self.linear = nn.Linear(embed_size, vocab_size)def forward(self, context):embedded = self.embeddings(context)embedded_sum = torch.sum(embedded, dim=1)output = self.linear(embedded_sum)return output# 定义训练函数
def train_cbow(data, target, model, criterion, optimizer):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()return loss.item()# 假设有一个简单的语料库和单词到索引的映射
corpus = ["I like deep learning", "I enjoy NLP", "I love PyTorch"]
word_to_index = {"I": 0, "like": 1, "deep": 2, "learning": 3, "enjoy": 4, "NLP": 5, "love": 6, "PyTorch": 7}# 将语料库转换为训练数据
context_size = 3
data = []
target = []
for sentence in corpus:tokens = sentence.split()for i in range(context_size, len(tokens) - context_size):context = [word_to_index[tokens[j]] for j in range(i - context_size, i + context_size + 1) if j != i]target_word = word_to_index[tokens[i]]data.append(torch.tensor(context, dtype=torch.long))target.append(torch.tensor(target_word, dtype=torch.long))# 超参数
vocab_size = len(word_to_index)
embed_size = 10
learning_rate = 0.01
epochs = 100# 初始化模型、损失函数和优化器
cbow_model = CBOWModel(vocab_size, embed_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cbow_model.parameters(), lr=learning_rate)# 开始训练
for epoch in range(epochs):total_loss = 0for i in range(len(data)):loss = train_cbow(data[i], target[i], cbow_model, criterion, optimizer)total_loss += lossprint(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss}')# 获取词向量
word_embeddings = cbow_model.embeddings.weight.detach().numpy()
print("Word Embeddings:\n", word_embeddings)
  1. CBOW模型定义(class CBOWModel):

    • __init__ 方法:在初始化过程中定义了两个层,一个是nn.Embedding用于获取词向量,另一个是nn.Linear用于将词向量求和后映射到词汇表大小的空间
    • forward 方法:定义了模型的前向传播过程。给定一个上下文,首先通过Embedding层获取词向量,然后对词向量进行求和,最后通过Linear层进行映射。
  2. 训练函数(train_cbow):

    • train_cbow 函数用于训练CBOW模型。接受训练数据、目标、模型、损失函数和优化器作为输入,并执行前向传播、计算损失、反向传播和优化器更新权重的过程。
  3. 语料库和单词到索引的映射:

    • corpus 包含了三个简单的句子。
    • word_to_index 是单词到索引的映射。
  4. 将语料库转换为训练数据:

    • 对每个句子进行分词,然后构建上下文和目标。上下文是目标词的上下文词的索引列表,目标是目标词的索引。
  5. 超参数和模型初始化:

    • vocab_size 是词汇表大小。
    • embed_size 是词向量的维度。
    • learning_rate 是优化器的学习率。
    • epochs 是训练迭代次数。
    • CBOWModel 实例化为 cbow_model
    • 使用交叉熵损失函数和随机梯度下降(SGD)优化器。
  6. 训练过程:

    • 使用嵌套的循环对训练数据进行多次迭代。
    • 对每个训练样本调用 train_cbow 函数,计算损失并更新模型权重。
  7. 获取词向量:

    • 通过 cbow_model.embeddings.weight 获取训练后的词向量矩阵,并将其转换为 NumPy 数组。

需要注意的是,代码中的训练过程比较简单,通常在实际应用中可能需要更复杂的数据集、更大的模型和更多的训练策略。此处的代码主要用于展示CBOW模型的基本实现。

在CBOW(Continuous Bag of Words)模型中,神经网络的输入和输出数据的构造方式如下:

  1. 输入数据:

    • 对于每个训练样本,输入数据是上下文窗口内的单词的独热编码(one-hot encoding)向量的拼接。
    • 上下文窗口大小为3,因此对于每个目标词,上下文窗口内有3个单词。这3个单词的独热编码向量会被拼接在一起作为输入。
    • 对于语料库中的每个目标词,都会生成一个对应的训练样本。

    以 "I like deep learning" 为例:

    • "deep" 是目标词,上下文窗口为["like", "I", "learning"]。
    • 对应的独热编码向量分别是 [0, 1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0]。
    • 这三个向量拼接在一起作为神经网络的输入。

    对于整个语料库,这个过程会生成一组输入数据。

  2. 输出数据:

    • 输出数据是目标词的独热编码向量,表示模型要预测的词。
    • 对于 "I like deep learning" 中的 "deep",其对应的独热编码向量是 [0, 0, 0, 1, 0, 0, 0, 0]。
    • 整个语料库中,为每个目标词生成相应的输出数据。

综上所述,CBOW模型的神经网络输入数据是上下文窗口内单词的拼接独热编码向量,输出数据是目标词的独热编码向量。在训练过程中,模型通过学习输入与输出之间的映射关系,逐渐调整权重以更好地捕捉语境信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/626331.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序(四)页面跳转

注释很详细&#xff0c;直接上代码 新增内容 1.相对路径页面跳转 2. 绝对路径页面跳转 index.wxml <!-- navigator是块级元素&#xff0c;占一整行 --> <!-- 页面跳转url&#xff0c;相对路径 --> <navigator url"../logs/logs"><button type&…

赋值运算符和关系运算符

赋值运算符和关系运算符 赋值运算符 分类 符号作用说明赋值int a 10&#xff0c; 将10赋值给变量a加后赋值a b&#xff0c;将a b的值赋值给a-减后赋值a - b&#xff0c;将a - b的值赋值给a*乘后赋值a * b&#xff0c;将a b的值赋值给a/除后赋值a / b&#xff0c;将a b的…

运维知识点-Sqlite

Sqlite 引入 依赖 引入 依赖 <dependency><groupId>org.xerial</groupId><artifactId>sqlite-jdbc</artifactId><version>3.36.0.3</version></dependency>import javafx.scene.control.Alert; import java.sql.*;public clas…

第二证券:抢占技术前沿 中国光伏企业结伴“走出去”

2024年新年前后&#xff0c;光伏职业分外忙碌。据证券时报记者不完全统计&#xff0c;晶澳科技、华晟新动力、高测股份、华民股份等多家企业宣告新建项目投产&#xff0c;安徽皇氏绿能等企业的项目也迎来设备安装的重要节点。 证券时报记者采访多家企业的负责人后了解到&#…

tessreact训练字库

tessreact主要用于字符识别&#xff0c;除了使用软件自带的中英文识别库&#xff0c;还可以使用Tesseract OCR训练属于自己的字库。 一、软件环境搭建 使用Tesseract OCR训练自己的字库&#xff0c;需要安装Tesseract OCR和jTessBoxEditor(配套训练工具)。jTessBoxEditor需要…

基于SSM的社区老年人关怀服务系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

【23种设计模式应用场景汇总】

23种设计模式应用场景汇总 设计模式是一种在软件开发中解决特定问题的通用解决方案。下面我将尝试将23种设计模式融入到一个场景中&#xff1a; 假设我们正在开发一个在线购物系统&#xff0c;我们可以使用以下设计模式&#xff1a; 1. 工厂方法模式&#xff1a;当用户在网站上…

力扣hot100 只出现一次的数字 位运算

Problem: 136. 只出现一次的数字 文章目录 思路复杂度Code 思路 复杂度 时间复杂度: O ( n ) O(n) O(n) 空间复杂度: O ( n ) O(n) O(n) Code class Solution {public int singleNumber(int[] nums) {int res 0;for(int x : nums)res ^ x;return res;} }

UI自动化测试框架

文章目录 UI自动化基础什么是UI自动化测试框架UI自动化测试框架的模式数据驱动测试框架关键字驱动测试框架行为驱动测试框架 UI自动化测试框架的作用UI自动化测试框架的核心思想UI自动化测试框架的步骤UI自动化测试框架的构成UtilsLog.javaReadProperties.Java coreBaseTest.ja…

【分布式技术】监控技术zabbix实操

目录 一、脚本监控nginx的连接状态 步骤一&#xff1a;做好nginx的配置 步骤二&#xff1a;完成监控数据脚本编写&#xff0c;并使用zabbix_get测试 步骤三&#xff1a;在zabbix agent配置目录中&#xff0c;编写以conf结尾的用户参数文件 步骤四&#xff1a;在zabbix web…

Python 网络编程之TCP详细讲解

【一】传输层 【1】概念 传输层是OSI五层模型中的第四层&#xff0c;负责在网络中的两个端系统之间提供数据传输服务主要协议包括**TCP&#xff08;传输控制协议&#xff09;和UDP&#xff08;用户数据报协议&#xff09;** 【2】功能 **端到端通信&#xff1a;**传输层负责…

HackerGPTWhiteRabbitNeo的使用及体验对比

1. 简介 WhiteRabbitNeo&#xff08;https://www.whiterabbitneo.com/&#xff09;是基于Meta的LLaMA 2模型进行特化的网络安全AI模型。通过专门的数据训练&#xff0c;它在理解和生成网络安全相关内容方面具有深入的专业能力&#xff0c;可广泛应用于教育、专业培训和安全研究…

什么是非电离辐射与电离辐射?

摘要: 非电离辐射和电离辐射是两种不同类型的辐射&#xff0c;它们主要区别在于能量水平和与物质相互作用的方式。 非电离辐射 非电离辐射是指能量较低&#xff0c;不足以使原子或分子的电子脱离其原子核束缚而产生电离现象的电磁波。这类辐射不 ... 非电离辐射和电离辐射是两…

Centos 更换内核

文章目录 一、查看/更换系统内核1.1 查看当前运行环境的内核1.2 查看系统上所有可用内核1.3 切换内核方法一&#xff1a;通过启动菜单更换内核方法二&#xff1a;更换默认启动内核 二、安装内核2.1 使用ELRepo安装2.2 安装指定内核版本参考资料 一、查看/更换系统内核 1.1 查看…

docker搭建SSH镜像、systemctl镜像、nginx镜像、tomcat镜像

目录 一、SSH镜像 二、systemctl镜像 三、nginx镜像 四、tomcat镜像 五、mysql镜像 一、SSH镜像 1、开启ip转发功能 vim /etc/sysctl.conf net.ipv4.ip_forward 1sysctl -psystemctl restart docker 2、 cd /opt/sshd/vim Dockerfile 3、生成镜像 4、启动容器并修改ro…

面试题:你知道 Spring lazy-init 懒加载的原理吗?

文章目录 前言一、先睹为快二、原理分析三、总结 前言 普通的bean的初始化是在容器启动初始化阶段执行的&#xff0c;而被lazy-init修饰的bean 则是在从容器里第一次进行context.getBean(“”)时进行触发。 Spring 启动的时候会把所有bean信息(包括XML和注解)解析转化成Spring…

这可能是最全面的Java并发编程八股文了

内容摘自我的学习网站&#xff1a;topjavaer.cn 分享50道Java并发高频面试题。 线程池 线程池&#xff1a;一个管理线程的池子。 为什么平时都是使用线程池创建线程&#xff0c;直接new一个线程不好吗&#xff1f; 嗯&#xff0c;手动创建线程有两个缺点 不受控风险频繁创…

SpringBoot基础:一步步创建SpringBoot工程

摘要 本文介绍了&#xff0c;从零开始创建SpringBoot工程&#xff0c;且在每一步给出分析和原因。创建maven – 转Springboot – 引入jdbc – 引入数据库操作框架&#xff0c;最后给出了不同场景指定不同配置文件的方案。 背景 为什么要使用SpringBoot工程&#xff1f; 使用Sp…

Python 网络编程之粘包问题

【一】粘包问题介绍 【1】粘包和半包 粘包&#xff1a; 定义&#xff1a; 粘包指的是发送方发送的若干个小数据包被接收方一次性接收&#xff0c;形成一个大的数据包。原因&#xff1a; 通常是因为网络底层对数据传输的优化&#xff0c;将多个小数据包组合成一个大的数据块一次…

Linux搭建和使用redis

官网地址&#xff1a;http://redis.io/download 文件上传到服务器 tar包解压 tar zxvf redis-5.0.14.tar.gz安装 进入解压目录下&#xff0c;找到Makefile所在目录&#xff0c;执行make命令 make执行之后&#xff0c;会产生src等目录&#xff0c;进入执行make install命令…