无脑入门pytorch系列(一)—— nn.embedding

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo
  • 练习1——改变**embedding_dim**
  • 练习2——index越界
  • 练习3——sequence长度不一致
  • 练习4——改变输入

官方定义

nn.embedding就是一个简单的查找表,存储固定字典和大小的嵌入。

该模块通常用于存储词嵌入并使用索引检索它们。模块的输入是索引列表,输出是相应的词嵌入。

个人理解:

  • nn.embedding就是一个字典映射表,比如它的大小是128,0~127每个位置都存储着一个长度为3的数组,那么我们外部输入的值可以通过index (0~127)映射到每个对应的数组上,所以不管外部的值是如何都能在该nn.embedding中找到对应的数组。想想哈希表,就很好理解了。
  • 既然是映射表,那么外部的输入的值肯定不能超过最大长度,比如128,同时下限也是。

官方的文档如下,torch.nn.embedding:

image-20230802145811801

从官方的定义来看实在是非常复杂,下面看个例子:

demo

下面是一个官方文档给出的例子:

import torch
import torch.nn as nnembedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出的结果:

image-20230802150024797

我们一步步理解代码:

  1. 首先,embedding = nn.Embedding(10, 3)即定义一个embedding模块,包含了一个长度为10的张量,每个张量的大小是3。举个例子,[-1.0556, -0.2404, -0.4578]就是一个tensor,那么如何取该tensor?使用下标index去取,注意,理解这点非常重要。
  2. 其次,input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])即输入一个我们需要embedding的变量,输入的每个值最终映射到张量空间中。
  3. 最后,我们发现输出e变成了[2, 4, 3]的张量,那么没有学习过的同学自然是一脸懵逼。我们需要,说说怎么看张量的维度,从最外层的**[]开始,计算里面的独立个体,发现是2;接着从第二维度的[]**开始数,发现是4;依次类推就可以得到张量的维度是[2, 4, 3]。

仍然十分迷茫,但是没关系,我们看看embedding的weight:

embedding.weight

输出:

image-20230802150606779

我们发现embedding.weight是个[10, 3]的向量,那么embedding.weight的值是怎么被我们input取到的呢?
比如index = 1,那么我们取[-1.0556, -0.2404, -0.4578]; index = 2, 取[ 1.3328, 2.5743, -0.7375]; index = 4, 取[-0.0584, -0.6458, 0.8236]。
这时候,聪明的小伙伴已经发现了,这不就刚好对应了e的输入为1/2/4的值吗?只是我们把输入1作为index去embedding.weight取对应的值去填充新的张量e。

所以说,我们待输入的张量[[1,2,4,5],[4,3,2,9]],在经过nn.embedding后,从[2, 4]维度变换为[2, 4, 3],其实就是[2, 4]中的每个值作为索引去nn.embedding中取对应的权重。

练习1——改变embedding_dim

embedding = nn.Embedding(10, 4) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出:

image-20230802152757460

很明显,当embedding是个[10, 4]的张量时,映射出的张量为[2, 4, 4]

练习2——index越界

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,10]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

报错:IndexError: index out of range in self

输出会报错,那是因为我们的embedding的维度是[10, 3],所以index的取值从0~9,那么我们取10肯定就出现问题了。如果出现对应的问题时,就可以大致猜到输入的值越界了。

练习3——sequence长度不一致

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

报错:ValueError: expected sequence of length 3 at dim 1 (got 4)

将第一维[1, 2, 4, 5]减去5变成[1,2,4],出现ValueError: expected sequence of length 3 at dim 1 (got 4)的问题,所以需要每个维度的长度都一致。

练习4——改变输入

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[[1,2],[2,3],[4,5],[5,7]],[[4,5],[3,4],[2,3],[8,9]]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出:

image-20230802153045211

当输入的的维度为[2,4,2]时,经过embedding得到[2,4,2,3]的张量,也是很好理解的。

喜欢的朋友可以点赞三连一下,谢谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/20367.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络安全策略应包含哪些?

网络安全策略是保护组织免受网络威胁的关键措施。良好的网络安全策略可以确保数据和系统的保密性、完整性和可用性。以下是一个典型的网络安全策略应包含的几个重要方面: 1. 强化密码策略:采用强密码,要求定期更换密码,并使用多因…

Java类集框架(一)

目录 1.Collection集合接口 2.List 接口 (常用子类 ArrayList ,LinkedList,Vector) 3.Set 集合 接口(常用子类 HashSet LinkedHashSet,TreeSet) 4.集合输出(iterator , Enumeration) 1.Collection集合接口 Collection是集合中最大父接口,在接口中定义了核心的…

vue中各种混淆用法汇总

✨在生成、导出、导入、使用 Vue 组件的时候,像我这种新手就会常常被位于不同文件的 new Vue() 、 export default{} 搞得晕头转向。本文对常见用法汇总区分 new Vue() 💦Vue()就是一个构造函数,new Vue()是创建一个 vue 实例。该实例是一个…

Redis - 缓存的双写一致性

概念: 当修改了数据库的数据也要同时更新缓存的数据,缓存和数据库的数据要保持一致 那为什么会有不一致的情况呢? 如果不追求一致性,正常有两种做法 先修改数据库 后删除旧的缓存先删除旧的缓存 再修改数据库 我们以先删除旧的…

html学习9(脚本)

1、<script>标签用于定义客户端脚本&#xff0c;比如JavaScript&#xff0c;既可包含脚本语句&#xff0c;也可通过src属性指向外部文件。 2、JavaScript最常用于图片操作、表单验证及内容动图更新。 3、<noscript>标签用于在浏览器禁用脚本或浏览器不支持脚本&a…

华为数通HCIP-PIM原理与配置

组播网络概念 组播网络由组播源&#xff0c;组播组成员与组播路由器组成。 组播源的主要作用是发送组播数据。 组播组成员的主要作用是接收组播数据&#xff0c;因此需要通过IGMP让组播网络感知组成员位置与加组信息。 组播路由器的主要作用是将数据从组播源发送到组播组成员。…

第七篇:k8s集群使用helm3安装Prometheus Operator

安装Prometheus Operator 目前网上主要有两种安装方式&#xff0c;分别为&#xff1a;1. 使用kubectl基于manifest进行安装 2. 基于helm3进行安装。第一种方式比较繁琐&#xff0c;需要手动配置yaml文件&#xff0c;特别是需要配置pvc相关内容时&#xff0c;涉及到的yaml文件太…

软件测试面试真题 | 什么是PO设计模式?

面试官问&#xff1a;UI自动化测试中有使用过设计模式吗&#xff1f;了解什么是PO设计模式吗&#xff1f; 考察点 《page object 设计模式》&#xff1a;PageObject设计模式的设计思想、设计原则 《web自动化测试实战》&#xff1a;结合PageObject在真实项目中的实践与应用情…

Shell脚本学习-MySQL单实例和多实例启动脚本

已知MySQL多实例启动命令为&#xff1a; mysqld_safe --defaults-file/data/3306/my.cnf & 停止命令为&#xff1a; mysqladmin -uroot -pchang123 -S /data/3306/mysql.sock shutdown 请完成mysql多实例的启动脚本的编写&#xff1a; 问题分析&#xff1a; 要想写出脚…

mybatis-plus 用法

目录 1 快速开始 1.1 依赖准备 1.2 配置准备 1.3 启动服务 2 使用 2.1 实体类注解 2.2 CRUD 2.3 分页 2.4 逻辑删除配置 2.5 通用枚举配置 2.6 自动填充 2.7 多数据源 3 测试 本文主要介绍 mybatis-plus 这款插件&#xff0c;针对 springboot 用户。包括引入&…

Redis 高可用:主从复制、哨兵模式、集群模式

文章目录 一、redis高可用性概述二、主从复制2.1 主从复制2.2 数据同步的方式2.2.1 全量数据同步2.2.2 增量数据同步 2.3 实现原理2.3.1 服务器 RUN ID2.3.2 复制偏移量 offset2.3.3 环形缓冲区 三、哨兵模式3.1 原理3.2 配置3.3 流程3.4 使用3.5 缺点 四、cluster集群4.1 原理…

带头单链表,附带完整测试程序

&#x1f354;链表基础知识 1.概念&#xff1a;链表是由多个节点链接构成的&#xff0c;节点包含数据域和指针域&#xff0c;指针域上存放的指针指向下一个节点 2.链表的种类&#xff1a;按单向或双向、带头或不带头、循环或不循环分为多个种类 3.特点&#xff1a;无法直接找到…

最近写了10篇Java技术博客【SQL和画图组件】

&#xff08;1&#xff09;Java获取SQL语句中的表名 &#xff08;2&#xff09;Java SQL 解析器实践 &#xff08;3&#xff09;Java SQL 格式化实践 &#xff08;4&#xff09;Java 画图 画图组件jgraphx项目整体介绍&#xff08;一&#xff09; 画图组件jgraphx项目导出…

安防视频综合管理合平台EasyCVR可支持的视频播放协议有哪些?

EasyDarwin开源流媒体视频EasyCVR安防监控平台可提供视频监控直播、云端录像、云存储、录像检索与回看、智能告警、平台级联、云台控制、语音对讲、智能分析等能力。 视频监控综合管理平台EasyCVR具备视频融合能力&#xff0c;平台基于云边端一体化架构&#xff0c;具有强大的…

vue卡片轮播图

我的项目是vue3的&#xff0c;用的swiper8 <template><div class"tab-all"><div class"tab-four"><swiper:loop"true":autoplay"{disableOnInteraction:false,delay:3000}":slides-per-view"3":center…

[NLP]LLM高效微调(PEFT)--LoRA

LoRA 背景 神经网络包含很多全连接层&#xff0c;其借助于矩阵乘法得以实现&#xff0c;然而&#xff0c;很多全连接层的权重矩阵都是满秩的。当针对特定任务进行微调后&#xff0c;模型中权重矩阵其实具有很低的本征秩&#xff08;intrinsic rank&#xff09;&#xff0c;因…

c语言实现八大排序详细解析

首先先看排序算法的整体分类 排序&#xff1a;所谓排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&#xff1a;假定在待排序的记录序列中&#xff0c;存在多个具有相同的关键字的记录&#xff…

为Android构建现代应用——应用导航设计

在前一章节的实现中&#xff0c;Skeleton: Main structure&#xff0c;我们留下了几个 Jetpack 架构组件&#xff0c;这些组件将在本章中使用&#xff0c;例如 Composables、ViewModels、Navigation 和 Hilt。此外&#xff0c;我们还通过 Scaffold 集成了 TopAppBar 和 BottomA…

yolov3-spp 训练结果分析:网络结果可解释性、漏检误检分析

1. valid漏检误检分析 ①为了探查第二层反向找出来的目标特征在最后一层detector上的意义&#xff01;——为什么最后依然可以框出来目标&#xff0c;且mAP还不错的&#xff1f; ②如何进一步提升和改进这个数据的效果&#xff1f;可以有哪些优化数据和改进的地方&#xff1f;让…

《ChatGPT原理最佳解释,从根上理解ChatGPT》

【热点】 2022年11月30日&#xff0c;OpenAI发布ChatGPT&#xff08;全名&#xff1a;Chat Generative Pre-trained Transformer&#xff09;&#xff0c; 即聊天机器人程序 &#xff0c;开启AIGC的研究热潮。 ChatGPT是人工智能技术驱动的自然语言处理工具&#xff0c;它能够…