无脑入门pytorch系列(一)—— nn.embedding

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo
  • 练习1——改变**embedding_dim**
  • 练习2——index越界
  • 练习3——sequence长度不一致
  • 练习4——改变输入

官方定义

nn.embedding就是一个简单的查找表,存储固定字典和大小的嵌入。

该模块通常用于存储词嵌入并使用索引检索它们。模块的输入是索引列表,输出是相应的词嵌入。

个人理解:

  • nn.embedding就是一个字典映射表,比如它的大小是128,0~127每个位置都存储着一个长度为3的数组,那么我们外部输入的值可以通过index (0~127)映射到每个对应的数组上,所以不管外部的值是如何都能在该nn.embedding中找到对应的数组。想想哈希表,就很好理解了。
  • 既然是映射表,那么外部的输入的值肯定不能超过最大长度,比如128,同时下限也是。

官方的文档如下,torch.nn.embedding:

image-20230802145811801

从官方的定义来看实在是非常复杂,下面看个例子:

demo

下面是一个官方文档给出的例子:

import torch
import torch.nn as nnembedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出的结果:

image-20230802150024797

我们一步步理解代码:

  1. 首先,embedding = nn.Embedding(10, 3)即定义一个embedding模块,包含了一个长度为10的张量,每个张量的大小是3。举个例子,[-1.0556, -0.2404, -0.4578]就是一个tensor,那么如何取该tensor?使用下标index去取,注意,理解这点非常重要。
  2. 其次,input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])即输入一个我们需要embedding的变量,输入的每个值最终映射到张量空间中。
  3. 最后,我们发现输出e变成了[2, 4, 3]的张量,那么没有学习过的同学自然是一脸懵逼。我们需要,说说怎么看张量的维度,从最外层的**[]开始,计算里面的独立个体,发现是2;接着从第二维度的[]**开始数,发现是4;依次类推就可以得到张量的维度是[2, 4, 3]。

仍然十分迷茫,但是没关系,我们看看embedding的weight:

embedding.weight

输出:

image-20230802150606779

我们发现embedding.weight是个[10, 3]的向量,那么embedding.weight的值是怎么被我们input取到的呢?
比如index = 1,那么我们取[-1.0556, -0.2404, -0.4578]; index = 2, 取[ 1.3328, 2.5743, -0.7375]; index = 4, 取[-0.0584, -0.6458, 0.8236]。
这时候,聪明的小伙伴已经发现了,这不就刚好对应了e的输入为1/2/4的值吗?只是我们把输入1作为index去embedding.weight取对应的值去填充新的张量e。

所以说,我们待输入的张量[[1,2,4,5],[4,3,2,9]],在经过nn.embedding后,从[2, 4]维度变换为[2, 4, 3],其实就是[2, 4]中的每个值作为索引去nn.embedding中取对应的权重。

练习1——改变embedding_dim

embedding = nn.Embedding(10, 4) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出:

image-20230802152757460

很明显,当embedding是个[10, 4]的张量时,映射出的张量为[2, 4, 4]

练习2——index越界

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,10]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

报错:IndexError: index out of range in self

输出会报错,那是因为我们的embedding的维度是[10, 3],所以index的取值从0~9,那么我们取10肯定就出现问题了。如果出现对应的问题时,就可以大致猜到输入的值越界了。

练习3——sequence长度不一致

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

报错:ValueError: expected sequence of length 3 at dim 1 (got 4)

将第一维[1, 2, 4, 5]减去5变成[1,2,4],出现ValueError: expected sequence of length 3 at dim 1 (got 4)的问题,所以需要每个维度的长度都一致。

练习4——改变输入

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[[1,2],[2,3],[4,5],[5,7]],[[4,5],[3,4],[2,3],[8,9]]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出:

image-20230802153045211

当输入的的维度为[2,4,2]时,经过embedding得到[2,4,2,3]的张量,也是很好理解的。

喜欢的朋友可以点赞三连一下,谢谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/20367.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手游联运系统有什么优势?

手游联运系统有许多优势,其中一些重要的优势包括: 1、游戏联运 手游联运系统可以将多款手游整合在一起,形成游戏联运平台,提供给玩家更多游戏选择,增加游戏的多样性和趣味性。 2、推广渠道管理 手游联运系统可以集…

Vuex:Vue.js应用程序的状态管理模式

介绍 在Vue.js应用程序中,随着项目复杂度的增加,组件之间的数据共享和管理变得困难。为了解决这个问题,Vue.js提供了一个名为Vuex的状态管理模式。Vuex可以帮助我们更有效地组织、管理和共享应用程序的状态。 什么是Vuex? Vuex…

网络安全策略应包含哪些?

网络安全策略是保护组织免受网络威胁的关键措施。良好的网络安全策略可以确保数据和系统的保密性、完整性和可用性。以下是一个典型的网络安全策略应包含的几个重要方面: 1. 强化密码策略:采用强密码,要求定期更换密码,并使用多因…

Java类集框架(一)

目录 1.Collection集合接口 2.List 接口 (常用子类 ArrayList ,LinkedList,Vector) 3.Set 集合 接口(常用子类 HashSet LinkedHashSet,TreeSet) 4.集合输出(iterator , Enumeration) 1.Collection集合接口 Collection是集合中最大父接口,在接口中定义了核心的…

Java设计模式-解释器模式

解释器模式 1.解释器模式 解释器模式,给定一个语言,定义它的文法的一种表示,并定义一个解释器,这个解释器使用该表示来解释语言中的句子。 其实解释器模式很简单的,就是一个翻译的过程,就像翻译软件&…

vue中各种混淆用法汇总

✨在生成、导出、导入、使用 Vue 组件的时候,像我这种新手就会常常被位于不同文件的 new Vue() 、 export default{} 搞得晕头转向。本文对常见用法汇总区分 new Vue() 💦Vue()就是一个构造函数,new Vue()是创建一个 vue 实例。该实例是一个…

Redis - 缓存的双写一致性

概念: 当修改了数据库的数据也要同时更新缓存的数据,缓存和数据库的数据要保持一致 那为什么会有不一致的情况呢? 如果不追求一致性,正常有两种做法 先修改数据库 后删除旧的缓存先删除旧的缓存 再修改数据库 我们以先删除旧的…

html学习9(脚本)

1、<script>标签用于定义客户端脚本&#xff0c;比如JavaScript&#xff0c;既可包含脚本语句&#xff0c;也可通过src属性指向外部文件。 2、JavaScript最常用于图片操作、表单验证及内容动图更新。 3、<noscript>标签用于在浏览器禁用脚本或浏览器不支持脚本&a…

华为数通HCIP-PIM原理与配置

组播网络概念 组播网络由组播源&#xff0c;组播组成员与组播路由器组成。 组播源的主要作用是发送组播数据。 组播组成员的主要作用是接收组播数据&#xff0c;因此需要通过IGMP让组播网络感知组成员位置与加组信息。 组播路由器的主要作用是将数据从组播源发送到组播组成员。…

基于MATLAB长时间序列遥感数据分析(以MODIS数据处理为例)

MATLAB MATLAB是美国MathWorks公司出品的商业数学软件&#xff0c;用于数据分析、无线通信、深度学习、图像处理与计算机视觉、信号处理、量化金融与风险管理、机器人&#xff0c;控制系统等领域。 [1] MATLAB是matrix&laboratory两个词的组合&#xff0c;意为矩阵工厂&a…

第七篇:k8s集群使用helm3安装Prometheus Operator

安装Prometheus Operator 目前网上主要有两种安装方式&#xff0c;分别为&#xff1a;1. 使用kubectl基于manifest进行安装 2. 基于helm3进行安装。第一种方式比较繁琐&#xff0c;需要手动配置yaml文件&#xff0c;特别是需要配置pvc相关内容时&#xff0c;涉及到的yaml文件太…

Python网站页面开发HTML总结

Python网站页面开发HTML总结 一、HTML基础语法 1.HTML是什么&#xff1f; ●HTML是HyperText Mark-up Language的首字母简写&#xff0c;即超文本标记语言。 ●HTML不是一种编程语言&#xff0c;而是一种标记语言。 ●超文本指的是超链接&#xff0c;标记指的是标签&#xf…

[QT编程系列-38]:数据存储 - SQLite数据库存储与操作

目录 1. SQLite数据库概述 1.1 简介 1.2 SQLite不支持网络连接 1.3 SQLite不需要安装MySQL Server数据库 1.4. SQLite性能 1.5 SQLite支持的数据条目 2. SQLite操作示例 3. QSqlDatabase 4. QSqlQuery 1. SQLite数据库概述 1.1 简介 QT 提供了对 SQLite 数据库的支…

软件测试面试真题 | 什么是PO设计模式?

面试官问&#xff1a;UI自动化测试中有使用过设计模式吗&#xff1f;了解什么是PO设计模式吗&#xff1f; 考察点 《page object 设计模式》&#xff1a;PageObject设计模式的设计思想、设计原则 《web自动化测试实战》&#xff1a;结合PageObject在真实项目中的实践与应用情…

Shell脚本学习-MySQL单实例和多实例启动脚本

已知MySQL多实例启动命令为&#xff1a; mysqld_safe --defaults-file/data/3306/my.cnf & 停止命令为&#xff1a; mysqladmin -uroot -pchang123 -S /data/3306/mysql.sock shutdown 请完成mysql多实例的启动脚本的编写&#xff1a; 问题分析&#xff1a; 要想写出脚…

mybatis-plus 用法

目录 1 快速开始 1.1 依赖准备 1.2 配置准备 1.3 启动服务 2 使用 2.1 实体类注解 2.2 CRUD 2.3 分页 2.4 逻辑删除配置 2.5 通用枚举配置 2.6 自动填充 2.7 多数据源 3 测试 本文主要介绍 mybatis-plus 这款插件&#xff0c;针对 springboot 用户。包括引入&…

【C语言】你不知道的隐式类型转换规则

【C语言】你不知道的隐式类型转换规则 一、隐式类型转换的规则二、整型提升1.整型提升的意义2.如何进行整体提升&#xff1f;2.1正数的整形提升2.2负数的整型提升2.3无符号整型提升 3.整型提升实例 三、算术转换 &#x1f388;个人主页&#xff1a;库库的里昂&#x1f390;CSDN…

Redis 高可用:主从复制、哨兵模式、集群模式

文章目录 一、redis高可用性概述二、主从复制2.1 主从复制2.2 数据同步的方式2.2.1 全量数据同步2.2.2 增量数据同步 2.3 实现原理2.3.1 服务器 RUN ID2.3.2 复制偏移量 offset2.3.3 环形缓冲区 三、哨兵模式3.1 原理3.2 配置3.3 流程3.4 使用3.5 缺点 四、cluster集群4.1 原理…

带头单链表,附带完整测试程序

&#x1f354;链表基础知识 1.概念&#xff1a;链表是由多个节点链接构成的&#xff0c;节点包含数据域和指针域&#xff0c;指针域上存放的指针指向下一个节点 2.链表的种类&#xff1a;按单向或双向、带头或不带头、循环或不循环分为多个种类 3.特点&#xff1a;无法直接找到…

最近写了10篇Java技术博客【SQL和画图组件】

&#xff08;1&#xff09;Java获取SQL语句中的表名 &#xff08;2&#xff09;Java SQL 解析器实践 &#xff08;3&#xff09;Java SQL 格式化实践 &#xff08;4&#xff09;Java 画图 画图组件jgraphx项目整体介绍&#xff08;一&#xff09; 画图组件jgraphx项目导出…