Word2Vec的CBOW模型

Word2Vec中的CBOW(Continuous Bag of Words)模型是一种用于学习词向量的神经网络模型。CBOW的核心思想是根据上下文中的周围单词来预测目标单词。

例如,对于句子“The cat climbed up the tree”,如果窗口大小为5,那么当中心单词为“climbed”时,上下文单词为“The”、“cat”、“up”和“the”。CBOW模型要求根据这四个上下文单词,计算出“climbed”的概率分布。

一个简单的CBOW模型

import torch
import torch.nn as nn
import torch.optim as optim# 定义CBOW模型
class CBOWModel(nn.Module):def __init__(self, vocab_size, embed_size):super(CBOWModel, self).__init__()self.embeddings = nn.Embedding(vocab_size, embed_size)self.linear = nn.Linear(embed_size, vocab_size)def forward(self, context):embedded = self.embeddings(context)embedded_sum = torch.sum(embedded, dim=1)output = self.linear(embedded_sum)return output# 定义训练函数
def train_cbow(data, target, model, criterion, optimizer):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()return loss.item()# 假设有一个简单的语料库和单词到索引的映射
corpus = ["I like deep learning", "I enjoy NLP", "I love PyTorch"]
word_to_index = {"I": 0, "like": 1, "deep": 2, "learning": 3, "enjoy": 4, "NLP": 5, "love": 6, "PyTorch": 7}# 将语料库转换为训练数据
context_size = 3
data = []
target = []
for sentence in corpus:tokens = sentence.split()for i in range(context_size, len(tokens) - context_size):context = [word_to_index[tokens[j]] for j in range(i - context_size, i + context_size + 1) if j != i]target_word = word_to_index[tokens[i]]data.append(torch.tensor(context, dtype=torch.long))target.append(torch.tensor(target_word, dtype=torch.long))# 超参数
vocab_size = len(word_to_index)
embed_size = 10
learning_rate = 0.01
epochs = 100# 初始化模型、损失函数和优化器
cbow_model = CBOWModel(vocab_size, embed_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cbow_model.parameters(), lr=learning_rate)# 开始训练
for epoch in range(epochs):total_loss = 0for i in range(len(data)):loss = train_cbow(data[i], target[i], cbow_model, criterion, optimizer)total_loss += lossprint(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss}')# 获取词向量
word_embeddings = cbow_model.embeddings.weight.detach().numpy()
print("Word Embeddings:\n", word_embeddings)
  1. CBOW模型定义(class CBOWModel):

    • __init__ 方法:在初始化过程中定义了两个层,一个是nn.Embedding用于获取词向量,另一个是nn.Linear用于将词向量求和后映射到词汇表大小的空间
    • forward 方法:定义了模型的前向传播过程。给定一个上下文,首先通过Embedding层获取词向量,然后对词向量进行求和,最后通过Linear层进行映射。
  2. 训练函数(train_cbow):

    • train_cbow 函数用于训练CBOW模型。接受训练数据、目标、模型、损失函数和优化器作为输入,并执行前向传播、计算损失、反向传播和优化器更新权重的过程。
  3. 语料库和单词到索引的映射:

    • corpus 包含了三个简单的句子。
    • word_to_index 是单词到索引的映射。
  4. 将语料库转换为训练数据:

    • 对每个句子进行分词,然后构建上下文和目标。上下文是目标词的上下文词的索引列表,目标是目标词的索引。
  5. 超参数和模型初始化:

    • vocab_size 是词汇表大小。
    • embed_size 是词向量的维度。
    • learning_rate 是优化器的学习率。
    • epochs 是训练迭代次数。
    • CBOWModel 实例化为 cbow_model
    • 使用交叉熵损失函数和随机梯度下降(SGD)优化器。
  6. 训练过程:

    • 使用嵌套的循环对训练数据进行多次迭代。
    • 对每个训练样本调用 train_cbow 函数,计算损失并更新模型权重。
  7. 获取词向量:

    • 通过 cbow_model.embeddings.weight 获取训练后的词向量矩阵,并将其转换为 NumPy 数组。

需要注意的是,代码中的训练过程比较简单,通常在实际应用中可能需要更复杂的数据集、更大的模型和更多的训练策略。此处的代码主要用于展示CBOW模型的基本实现。

在CBOW(Continuous Bag of Words)模型中,神经网络的输入和输出数据的构造方式如下:

  1. 输入数据:

    • 对于每个训练样本,输入数据是上下文窗口内的单词的独热编码(one-hot encoding)向量的拼接。
    • 上下文窗口大小为3,因此对于每个目标词,上下文窗口内有3个单词。这3个单词的独热编码向量会被拼接在一起作为输入。
    • 对于语料库中的每个目标词,都会生成一个对应的训练样本。

    以 "I like deep learning" 为例:

    • "deep" 是目标词,上下文窗口为["like", "I", "learning"]。
    • 对应的独热编码向量分别是 [0, 1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0]。
    • 这三个向量拼接在一起作为神经网络的输入。

    对于整个语料库,这个过程会生成一组输入数据。

  2. 输出数据:

    • 输出数据是目标词的独热编码向量,表示模型要预测的词。
    • 对于 "I like deep learning" 中的 "deep",其对应的独热编码向量是 [0, 0, 0, 1, 0, 0, 0, 0]。
    • 整个语料库中,为每个目标词生成相应的输出数据。

综上所述,CBOW模型的神经网络输入数据是上下文窗口内单词的拼接独热编码向量,输出数据是目标词的独热编码向量。在训练过程中,模型通过学习输入与输出之间的映射关系,逐渐调整权重以更好地捕捉语境信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/626331.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

字符串与数组的异同

Java 中的字符串(String)和数组(Array)是两种不同类型的数据结构,它们有一些相似之处,同时也有一些显著的区别。 相同之处: 存储多个元素: 字符串和数组都用于存储多个元素。 使用…

微信小程序(四)页面跳转

注释很详细&#xff0c;直接上代码 新增内容 1.相对路径页面跳转 2. 绝对路径页面跳转 index.wxml <!-- navigator是块级元素&#xff0c;占一整行 --> <!-- 页面跳转url&#xff0c;相对路径 --> <navigator url"../logs/logs"><button type&…

赋值运算符和关系运算符

赋值运算符和关系运算符 赋值运算符 分类 符号作用说明赋值int a 10&#xff0c; 将10赋值给变量a加后赋值a b&#xff0c;将a b的值赋值给a-减后赋值a - b&#xff0c;将a - b的值赋值给a*乘后赋值a * b&#xff0c;将a b的值赋值给a/除后赋值a / b&#xff0c;将a b的…

运维知识点-Sqlite

Sqlite 引入 依赖 引入 依赖 <dependency><groupId>org.xerial</groupId><artifactId>sqlite-jdbc</artifactId><version>3.36.0.3</version></dependency>import javafx.scene.control.Alert; import java.sql.*;public clas…

第二证券:抢占技术前沿 中国光伏企业结伴“走出去”

2024年新年前后&#xff0c;光伏职业分外忙碌。据证券时报记者不完全统计&#xff0c;晶澳科技、华晟新动力、高测股份、华民股份等多家企业宣告新建项目投产&#xff0c;安徽皇氏绿能等企业的项目也迎来设备安装的重要节点。 证券时报记者采访多家企业的负责人后了解到&#…

AUTOSAR OS详细介绍及配置说明(更新版20240115)

前言 AUTOSAR OS扩展了OSEK/VDX标准中的操作系统,所以本文结合OSEK/VDX的标准来介绍AUTOSAR OS,并借助Vector Configurator讲解AUTOSAR OS的配置。 OSEK源于德语,英文意思是:“车载电子设备的开发系统和接口”,它是一个标准,用来产生嵌入式操作系统的规范,通讯协议栈,…

宝塔面板打不开,记录一下解决办法

由于在服务器宝塔内安装Apache&#xff0c;提示需要卸载nginx&#xff0c;卸载过后宝塔通过网址访问不了&#xff0c;特此记录一下问题。 1、检查宝塔端口会不会被占用 面板默许使用8888端口&#xff0c;使用命令查看8888端口会不会被占用&#xff1a; netstat -apn|grep 88…

tessreact训练字库

tessreact主要用于字符识别&#xff0c;除了使用软件自带的中英文识别库&#xff0c;还可以使用Tesseract OCR训练属于自己的字库。 一、软件环境搭建 使用Tesseract OCR训练自己的字库&#xff0c;需要安装Tesseract OCR和jTessBoxEditor(配套训练工具)。jTessBoxEditor需要…

接口以及多态

什么是接口 接口是一种抽象的数据类型&#xff0c;它定义了一组方法的规范&#xff0c;但没有具体的实现。接口可以被类实现&#xff0c;一个类实现了接口后&#xff0c;必须实现接口中定义的所有方法。接口可以被多个类实现&#xff0c;用以实现多重继承。 接口的定义使用关键…

基于SSM的社区老年人关怀服务系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

【Vue自定义指令详细介绍】

Vue自定义指令详细介绍 1. 自定义指令1.1 局部1.2 全局 1. 自定义指令 在 Vue.js 中&#xff0c;除了默认提供的核心指令&#xff08;如 v-model、v-show、v-if 等&#xff09;&#xff0c;Vue.js 也允许注册自定义指令&#xff0c;自定义指令给你提供了一种方法来扩展 Vue 的…

【23种设计模式应用场景汇总】

23种设计模式应用场景汇总 设计模式是一种在软件开发中解决特定问题的通用解决方案。下面我将尝试将23种设计模式融入到一个场景中&#xff1a; 假设我们正在开发一个在线购物系统&#xff0c;我们可以使用以下设计模式&#xff1a; 1. 工厂方法模式&#xff1a;当用户在网站上…

力扣hot100 只出现一次的数字 位运算

Problem: 136. 只出现一次的数字 文章目录 思路复杂度Code 思路 复杂度 时间复杂度: O ( n ) O(n) O(n) 空间复杂度: O ( n ) O(n) O(n) Code class Solution {public int singleNumber(int[] nums) {int res 0;for(int x : nums)res ^ x;return res;} }

UI自动化测试框架

文章目录 UI自动化基础什么是UI自动化测试框架UI自动化测试框架的模式数据驱动测试框架关键字驱动测试框架行为驱动测试框架 UI自动化测试框架的作用UI自动化测试框架的核心思想UI自动化测试框架的步骤UI自动化测试框架的构成UtilsLog.javaReadProperties.Java coreBaseTest.ja…

【分布式技术】监控技术zabbix实操

目录 一、脚本监控nginx的连接状态 步骤一&#xff1a;做好nginx的配置 步骤二&#xff1a;完成监控数据脚本编写&#xff0c;并使用zabbix_get测试 步骤三&#xff1a;在zabbix agent配置目录中&#xff0c;编写以conf结尾的用户参数文件 步骤四&#xff1a;在zabbix web…

Python 网络编程之TCP详细讲解

【一】传输层 【1】概念 传输层是OSI五层模型中的第四层&#xff0c;负责在网络中的两个端系统之间提供数据传输服务主要协议包括**TCP&#xff08;传输控制协议&#xff09;和UDP&#xff08;用户数据报协议&#xff09;** 【2】功能 **端到端通信&#xff1a;**传输层负责…

HackerGPTWhiteRabbitNeo的使用及体验对比

1. 简介 WhiteRabbitNeo&#xff08;https://www.whiterabbitneo.com/&#xff09;是基于Meta的LLaMA 2模型进行特化的网络安全AI模型。通过专门的数据训练&#xff0c;它在理解和生成网络安全相关内容方面具有深入的专业能力&#xff0c;可广泛应用于教育、专业培训和安全研究…

MongoDB系统性能调优(持续更新)

cache_size 指定WT存储引擎内部cache的内存用量上限。 需要注意的是&#xff0c;仅作用于WiredTiger cache&#xff0c;而非mongod进程的内存用量上限。MongoDB同时使用WT cache和文件系统cache&#xff0c;往往mongod进程的内存用量高于该值。cache_size相对于物理内存总量不要…

聊聊PowerJob的TransportServiceAware

序 本文主要研究一下PowerJob的TransportServiceAware TransportServiceAware tech/powerjob/server/remote/aware/TransportServiceAware.java public interface TransportServiceAware extends PowerJobAware {void setTransportService(TransportService transportServi…

什么是非电离辐射与电离辐射?

摘要: 非电离辐射和电离辐射是两种不同类型的辐射&#xff0c;它们主要区别在于能量水平和与物质相互作用的方式。 非电离辐射 非电离辐射是指能量较低&#xff0c;不足以使原子或分子的电子脱离其原子核束缚而产生电离现象的电磁波。这类辐射不 ... 非电离辐射和电离辐射是两…