pytorch nn.Embedding 用法和原理

nn.Embedding 是 PyTorch 中的一个模块,用于将离散的输入(通常是词或子词的索引)映射到连续的向量空间。它在自然语言处理和其他需要处理离散输入的任务中非常常用。以下是 nn.Embedding 的用法和原理。

用法

初始化 nn.Embedding
nn.Embedding 的初始化需要两个主要参数:

  1. num_embeddings:字典的大小,即输入的最大索引值 + 1。
  2. embedding_dim:每个嵌入向量的维度。

此外,还有一些可选参数,如 padding_idx、max_norm、norm_type、scale_grad_by_freq 和 sparse。

import torch
import torch.nn as nn# 创建一个 Embedding 层
num_embeddings = 10  # 词汇表大小
embedding_dim = 3    # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)

输入和输出
nn.Embedding 的输入是一个包含索引的长整型张量,输出是对应的嵌入向量。

# 示例输入
input_indices = torch.LongTensor([1, 2, 3, 4])
output_vectors = embedding_layer(input_indices)
print(output_vectors)

示例代码
以下是一个完整的示例代码,展示了如何使用 nn.Embedding 层:

import torch
import torch.nn as nn# 创建 Embedding 层
num_embeddings = 10  # 词汇表大小
embedding_dim = 3    # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)# 示例输入
input_indices = torch.LongTensor([1, 2, 3, 4])# 获取嵌入向量
output_vectors = embedding_layer(input_indices)
print("Input indices:", input_indices)
print("Output vectors:", output_vectors)

原理

nn.Embedding 层的本质是一个查找表,它将输入的每个索引映射到一个固定大小的向量。这个映射表在初始化时会随机生成,然后在训练过程中通过反向传播进行优化。
主要步骤

  1. 初始化:在初始化时,nn.Embedding 会创建一个大小为 (num_embeddings, embedding_dim)的权重矩阵。这些权重是嵌入层的参数,会在训练过程中更新。
  2. 前向传播:在前向传播过程中,nn.Embedding 层会将输入的索引映射到权重矩阵的相应行,从而得到对应的嵌入向量。
  3. 反向传播:在训练过程中,嵌入层的权重矩阵会根据损失函数的梯度进行更新。这使得嵌入向量能够捕捉到输入的语义信息。

参数解释

  • padding_idx:如果指定了 padding_idx,则该索引的嵌入向量在训练过程中不会被更新。通常用于处理填充(padding)标记。
  • max_norm:如果指定了 max_norm,则会对每个嵌入向量的范数进行约束,使其不超过 max_norm。
  • norm_type:用于指定范数的类型,默认是2范数。
  • scale_grad_by_freq:如果设置为 True,则会根据输入中每个词的频率缩放梯度。
  • sparse:如果设置为 True,则使用稀疏梯度更新,适用于大词汇表的情况。

原理解释

  1. 查找表:nn.Embedding 的核心是一个查找表,其大小为 (num_embeddings,embedding_dim),每一行代表一个词或索引的嵌入向量。
  2. 前向传播:在前向传播中,输入的索引被用来查找嵌入向量。假设输入是 [1, 2, 3],则输出是权重矩阵中第1、第2和第3行的向量。
  3. 反向传播:在反向传播中,嵌入向量的梯度会根据损失函数进行计算,并用于更新权重矩阵。

通过这种方式,嵌入向量能够在训练过程中不断调整,使得相似的输入索引(例如语义相似的词)在向量空间中更接近,从而捕捉到输入的语义信息。

总结
nn.Embedding 是 PyTorch 中处理离散输入的一个非常强大且常用的工具。通过将离散索引映射到连续向量空间,并在训练过程中优化这些向量,nn.Embedding 能够捕捉到输入的丰富语义信息。这对于自然语言处理等任务来说是非常重要的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/864278.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 1321, 209, 102

目录 1321. 餐馆营业额变化增长题目链接表要求知识点思路代码 209. 长度最小的子数组题目链接标签暴力法思路代码 滑动窗口思路代码 102. 二叉树的层序遍历题目链接标签思路代码 1321. 餐馆营业额变化增长 题目链接 1321. 餐馆营业额变化增长 表 表Customer的字段为custome…

使用Python实现学生管理系统

文章目录 1. 系统概述2. 系统功能3. 实现细节3.1 初始化学生列表3.2 添加学生3.3 显示所有学生3.4 查找学生3.5 删除学生3.6 主菜单 4. 运行系统 在本文中,我们将使用Python编程语言来开发一个简单的学生管理系统。该系统将允许用户执行基本的学生信息管理操作&…

嵌入式UI开发-lvgl+wsl2+vscode系列:5、事件(Events)

一、前言 这节进行事件的总结,通过事件回调方式将用户和ui的交互行为绑定组合起来。 二、事件示例 1、示例1(点击事件) #include "../lv_examples.h" #if LV_BUILD_EXAMPLES && LV_USE_SWITCHstatic void event_cb(lv_…

Chapter8 透明效果——Shader入门精要学习笔记

一、基本概念 在Unity中通常使用两种方法来实现透明效果 透明度测试(无法达到真正的半透明效果)透明度混合(关闭了深度写入) 透明度测试 基本原理:设置一个阈值,只要片元的透明度小于阈值,就…

全球AI新闻速递7.1

全球AI新闻速递 1.科大讯飞发布讯飞星火 V4.0。 2.成都人形机器人创新中心:基于视觉扩散架构的人形机器人任务生成式模型 R-DDPRM。 3.安徽省人形机器人产业创新中心获批,将打造国内首创、世界领先研究基地。 4.亳州牵手华为打造华佗中医药大模型。 …

[论文精读]Variational Graph Auto-Encoders

论文网址:[1611.07308] Variational Graph Auto-Encoders (arxiv.org) 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎…

DL/T 645与modbus协议是否兼容,有何关系?

不兼容。645是电表协议,modbus是通用控制协议,两个是平行关系,两个协议都可以使用485通信协议(物理接口协议)进行传输,或传输介质与物理接口相同,软件协议不同。 Modbus有以下三种通信模式 在…

ARM功耗管理软件之时钟电源树

安全之安全(security)博客目录导读 思考:功耗管理软件栈及示例?WFI&WFE?时钟&电源树?DVFS&AVS? 目录 一、时钟&电源树简介 二、时钟树示例 三、电源树示例 一、时钟&电源树简介 时钟门控与自…

人工智能与机器学习原理精解【1】

文章目录 Rosenblatt感知器基础收敛算法算法概述算法步骤关键点说明总结 C实现要点代码 参考文献 Rosenblatt感知器 基础 感知器,也可翻译为感知机,是一种人工神经网络。它可以被视为一种最简单形式的前馈式人工神经网络,是一种二元线性分类…

【技术路线选择】:Qt or macOS/iOS ?

【技术路线选择】:Qt or macOS/iOS ? 【Question 1】: I have more than two years of experience developing with the following skills: Qt C and macOS/iOS development. Im interested in pursuing a software engineering career and would …

Victor CMS v1.0 SQL 注入漏洞(CVE-2022-28060)

前言 CVE-2022-28060 是 Victor CMS v1.0 中的一个SQL注入漏洞。该漏洞存在于 /includes/login.php 文件中的 user_name 参数。攻击者可以通过发送特制的 SQL 语句,利用这个漏洞执行未授权的数据库操作,从而访问或修改数据库中的敏感信息。 漏洞详细信…

论文阅读_优化RAG系统的检索

英文名称: The Power of Noise: Redefining Retrieval for RAG Systems 中文名称: 噪声的力量:重新定义RAG系统的检索 链接: https://arxiv.org/pdf/2401.14887.pdf 作者: Florin Cuconasu, Giovanni Trappolini, Federico Siciliano, Simone Filice, Cesare Campag…

半导体中名词“wafer”“chip”“die”中文名字和用途

①wafer——晶圆 wafer 即为图片所示的晶圆,由纯硅(Si)构成。一般分为6英寸、8英寸、12英寸规格不等,晶片就是基于这个wafer上生产出来的。晶圆是指硅半导体集成电路制作所用的硅晶片,由于其形状为圆形,故称为晶圆;在硅晶片上可加…

App测试技术(纯理论)

之前我们也学习过一些普通用例的设计, 如功能, 性能, 安全性, 兼容性, 易用性, 界面的测试用例设计, 之前我们讲的基本都是对于Web应用而言的, 这里我们来讲一下移动端的App测试用例设计. 功能方面 安装&卸载测试 这是只属于App的一类测试, 再平常我们使用移动设备(手机…

php 命令行模式详解

PHP 的命令行模式(Command Line Interface, CLI)是 PHP 的一个特定版本或运行时配置,它允许 PHP 脚本在没有 Web 服务器的情况下直接在命令行环境中执行。CLI 版本的 PHP 通常不包含 CGI 或者其他 web server 接口,因此更轻量级&a…

Redis的使用(一)概述

1.绪论 redis是一款用c编写的kv数据库,它具有丰富的数据类型,并且执行原子操作,自带数持久化,并且实现了集群部署等功能,我们来看看它有哪些特点: 1.提供了丰富的数据结构,比如string,list&am…

【第11章】MyBatis-Plus条件构造器(上)

文章目录 前言一、功能详解1. allEq2. eq3. ne4. gt5. ge6. lt7. le8. between9. notBetween10. like11. notLike12. likeLeft13. likeRight14. notLikeLeft15. notLikeRight16. isNull17. in18. notIn19. inSql20. notInSql21. eqSqlSince 3.5.622. gtSql Since 3.4.3.223. ge…

Linux4(Docker)

目录 一、Docker介绍 二、Docker结构 三、Docker安装 四、Docker 镜像 五、Docker 容器 六、Docker 安装nginx 七、Docker 中的MySQL部署 一、Docker介绍 Docker:是给予Go语言实现的开源项目。 Docker的主要目标是“Build,Ship and Run Any App,Anywhere” 也…

jenkins配置git

参考: 容器化部署 Jenkins,并配置SSH远程操作服务器_jenkins ssh-CSDN博客