nn.embedding函数详解(pytorch)

提示:文章附有源码!!!

文章目录

  • 前言
  • 一、nn.embedding函数解释
  • 二、nn.embedding函数使用方法
  • 四、模型训练与预测的权重变化探讨


前言

最近发现prompt工程(如sam模型),也有transform的detr模型等都使用了nn.Embedding函数,对points、boxes或learn query进行编码或解码。因此,我想写一篇文章作为记录,本想简单对其 介绍,但写着写着就想把所有与它相关东西作为记录。本文章探讨了nn.Embedding参数、使用方法、模型训练与预测的变化,并附有列子源码作为支撑 ,呈现一个较为完善的理解内容。

一、nn.embedding函数解释

Embedding实际是一个索引表或查找表,它是符合随机初始化生成的正太分布的表,将输入向量化,其结构如下:

nn.Embedding(num_embeddings, embedding_dim)

第1个参数 num_embeddings 就是生成num_embeddings个嵌入向量。
第2个参数 embedding_dim 就是嵌入向量的维度,即用embedding_dim值的维数来表示一个基本单位。

当然,该函数还有很多其它参数,解释如下:

参数源码注释如下:

num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;therefore, the embedding vector at :attr:`padding_idx` is not updated during training,i.e. it remains as a fixed "pad". For a newly constructed Embedding,the embedding vector at :attr:`padding_idx` will default to all zeros,but can be updated to another value to be used as the padding vector.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency ofthe words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.See Notes for more details regarding sparse gradients.

参数中文解释:

num_embeddings (python:int) – 词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999embedding_dim (python:int) – 嵌入向量的维度,即用多少维来表示一个符号。
padding_idx (python:int, optional) – 填充id,比如,输入长度为100,但是每次的句子长度并不一样,后面就需要用统一的数字填充,而这里就是指定这个数字,这样,网络在遇到填充id时,就不会计算其与其它符号的相关性。(初始化为0max_norm (python:float, optional) – 最大范数,如果嵌入向量的范数超过了这个界限,就要进行再归一化。
norm_type (python:float, optional) – 指定利用什么范数计算,并用于对比max_norm,默认为2范数。
scale_grad_by_freq (boolean, optional) – 根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False.
sparse (bool, optional) – 若为True,则与权重矩阵相关的梯度转变为稀疏张量

注:该函数服从正太分布,该函数可参与训练,我将在后面做解释。

二、nn.embedding函数使用方法

该函数实际是对词的编码,假如你有2句话,每句话有四个词,那么你想对每个词使用6个维度表达,其代码如下:

import torch.nn as nn
import torch
if __name__ == '__main__':embedding = nn.Embedding(100, 6)  # 我设置100个索引,每个使用6个维度表达。input = torch.LongTensor([[1, 2, 4, 5],[4, 3, 2, 3]])  # a batch of 2 samples of 4 indices eache = embedding(input)print('输出尺寸', e.shape)print('输出值:\n',e)weights=embedding.weightprint('embed权重输出值:\n', weights[:6])

输出结果:
在这里插入图片描述

从图上可看出,输入编码是通过索引查找已编号embedding的权重,并将其赋值替换表达。换句话说,nn.Embedding(100, 6)生成正太分布100行6列数据,行必须超过输入句子词语长度,而句子每个词使用整数编码成索引,该索引对应之前embedding行寻找,得到对应行
维度,即可转为表达该词的特征向量。

四、模型训练与预测的权重变化探讨

之前已说过nn.Embedding()在训练过程中会发生变化,但在预测中将不在变化,应该是被训练成最佳词的向量维度表达,也就是说每个词唯一对应索引,被Embedding特征表达训练成最佳特征表达,也可说训练词索引特征表达固定。为探讨此过程,我写了对应示列,如下:

import torch
from torch.nn import Embeddingclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = Embedding(5, 3)def forward(self,vec):input = torch.tensor([0, 1, 2, 3, 4])emb_vec1 = self.emb(input)# print(emb_vec1)  ### 输出对同一组词汇的编码output = torch.einsum('ik, kj -> ij', emb_vec1, vec)return output
def simple_train():model = Model()vec = torch.randn((3, 1))label = torch.Tensor(5, 1).fill_(3)loss_fun = torch.nn.MSELoss()opt = torch.optim.SGD(model.parameters(), lr=0.015)print('初始化emebding参数权重:\n',model.emb.weight)for iter_num in range(100):output = model(vec)loss = loss_fun(output, label)opt.zero_grad()loss.backward(retain_graph=True)opt.step()# print('第{}次迭代emebding参数权重{}:\n'.format(iter_num, model.emb.weight))print('训练后emebding参数权重:\n',model.emb.weight)torch.save(model.state_dict(),'./embeding.pth')return modeldef simple_test():model = Model()ckpt = torch.load('./embeding.pth')model.load_state_dict(ckpt)model=model.eval()vec = torch.randn((3, 1))print('加载emebding参数权重:\n', model.emb.weight)for iter_num in range(100):output = model(vec)print('n次预测后emebding参数权重:\n', model.emb.weight)if __name__ == '__main__':simple_train()  # 训练与保存权重simple_test()

结果如下:

在这里插入图片描述
训练代码参考博客:点击这里

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/134054.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c语言经典算法—二分查找,冒泡,选择,插入,归并,快排,堆排

一、二分查找 1、前提条件&#xff1a;数据有序&#xff0c;随机访问&#xff1b; 2、实现&#xff1a;递归实现&#xff0c;非递归实现 3、注意事项&#xff1a; 循环退出条件:low <high,low high.说明还有一个元素&#xff0c;该元素还要与key进行比较 mid的取值&#xf…

UE5 新特性 Nanite 开启

啥也不说&#xff0c;只能说&#xff0c;真的牛&#xff0c;在自己的项目上&#xff0c;从10几20的帧数&#xff0c;直接彪到了70 适用场景&#xff1a; 大场景&#xff0c;三角面足够多 在Project Setting里面 将这几个勾未true 勾上这个&#xff0c;放入场景即可

【电子通识】USB Logo的标识含义

USB 图标的设计灵感是来自罗马神话中的海神尼普顿(Neptune)&#xff08;也是海王星的名字&#xff09;的武器「三叉戟」&#xff0c;一支强有力的三齿鱼叉。不过&#xff0c;为了避免鱼叉形状的设计暗示人们拿着自己的USB 存储设备到处乱插&#xff08;叉&#xff09;。设计师对…

机器学习模型,超级全面总结!

机器学习是一种通过让计算机自动从数据中学习规律和模式&#xff0c;从而完成特定任务的方法。按照模型类型&#xff0c;机器学习可以分为两大类&#xff1a;监督学习模型和无监督学习模型。 附注&#xff1a;除了以上两大类模型&#xff0c;还有半监督学习和强化学习等其他类…

Texlive安装

下载4.8G的iso文件 解压 或 装载后&#xff0c;以管理员身份运行(.bat)文件。 运行以下两句代码进行Texlive相关升级 tlmgr option repository otan tlmgr update --self --all 运行以下三行代码&#xff0c;检查是否安装成功 latex -v xelatex -v pdflatex -v 如果有异常…

基于单片机的智能扫地机设计

概要 本文主要设计一个简单的智能扫地机。该扫地机的核心控制元器件是stc89c52&#xff0c;具有编写程序简单&#xff0c;成本普遍较低&#xff0c;功能较多&#xff0c;效率特别高等优点&#xff0c;因此在市场上得到很大的应用。除此之外&#xff0c;该扫地机能够自动避开障碍…

【Java 进阶篇】JSP EL 详解

在 Java Web 开发中&#xff0c;JavaServer Pages&#xff08;JSP&#xff09;是一种强大的技术&#xff0c;用于创建动态 Web 应用程序。JSP 的一个关键方面是 Expression Language&#xff08;EL&#xff09;表达语言&#xff0c;它允许您在 JSP 页面中嵌入 Java 代码&#x…

关于卷积神经网络的多通道

多通道输入 当输入的数据包含多个通道时&#xff0c;我们需要构造一个与输入通道数相同通道数的卷积核&#xff0c;从而能够和输入数据做卷积运算。 假设输入的形状为n∗n&#xff0c;通道数为ci​&#xff0c;卷积核的形状为f∗f&#xff0c;此时&#xff0c;每一个输入通道都…

记CVE-2022-39227-Python-JWT漏洞

文章目录 前言影响版本漏洞分析Newstar2023 Week5总结 前言 在Asal1n师傅的随口一说之下&#xff0c;说newstar week5出了一道祥云杯一样的CVE&#xff0c;于是自己也是跑去看了一下&#xff0c;确实是自己不知道的一个CVE漏洞&#xff0c;于是就从这道题学习到了python-jwt库…

机器视觉 opencv 深度学习 驾驶人脸疲劳检测系统 -python 计算机竞赛

文章目录 0 前言1 课题背景2 Dlib人脸识别2.1 简介2.2 Dlib优点2.3 相关代码2.4 人脸数据库2.5 人脸录入加识别效果 3 疲劳检测算法3.1 眼睛检测算法3.2 打哈欠检测算法3.3 点头检测算法 4 PyQt54.1 简介4.2相关界面代码 5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#x…

在MacBook上实现免费的PDF文件编辑

之前我想对PDF文件进行简单处理&#xff08;比如删页面、添空白页、调整页面顺序&#xff09;&#xff0c;要么是开wps会员【花钱贵】&#xff0c;下载&#xff08;盗版&#xff09;Adobe Acrobat【macOS不好下载】&#xff0c;要么用福昕阅览器登陆学生账号&#xff08;学校买…

[React] React-Redux 快速入门

文章目录 1.安装 Redux Toolkit 和 React Redux2.创建 Redux Store3.为 React 提供 Redux Store​4.创建 Redux State Slice5.添加 Slice Reducers 到 Store6.在 React 组件中使用 Redux State 和 Actions​7.总结 1.安装 Redux Toolkit 和 React Redux npm install reduxjs/t…

KaiOS APN配置文件apn.json调试验证方法(无需项目全编)

1、KaiOS 的应用就类似web应用&#xff0c;结合文件夹路径webapp字面意思理解。 2、KaiOS APN配置文件源代码在apn.json&#xff0c; &#xff08;1&#xff09;apn.json可以自定义路径&#xff0c;通过配置脚本实现拷贝APN在编译时动态选择路径在机器中生效。 &#xff08;…

集合框架:List系列集合:特点、方法、遍历方式、ArrayList,LinkList的底层原理

目录 List集合 特有方法 遍历方式 1. 使用普通 for 循环&#xff1a; 2. 使用增强型 for 循环&#xff08;foreach&#xff09;&#xff1a; 3. 使用迭代器&#xff08;Iterator&#xff09;&#xff1a; 4. 使用 Java 8 的流&#xff08;Stream&#xff09;API&#xff…

Softing新版HART多路复用器现支持图尔克excom和西门子ET 200iSP等远程I/O

Softing工业自动化最近升级了用于访问配置和诊断数据的smartLink SW-HT软件&#xff0c;现在该软件可支持访问图尔克excom和西门子ET 200iSP等远程I/O。 &#xff08;smartLink SW-HT支持访问配置和诊断数据&#xff09; 越来越多的新型远程I/O选择使用以太网来替代PROFIBUS连接…

系列十一、拦截器(二)#案例演示

一、案例演示 说明&#xff1a;如下案例通过springboot的方式演示拦截器是如何使用的&#xff0c;以获取Controller中的请求参数为切入点进行演示 1.1、前置准备工作 1.1.1、pom <dependencies><!-- spring-boot --><dependency><groupId>org.spring…

分享一下怎么做小程序营销活动

小程序营销活动已经成为现代营销的必备利器&#xff0c;它能够帮助企业提高品牌知名度、促进产品销售&#xff0c;以及加强与用户的互动。然而&#xff0c;要想成功地策划和执行一个小程序营销活动&#xff0c;需要精心设计和全面规划。本文将为您介绍小程序营销活动的策划和执…

OpenSign 开源 PDF 电子签名解决方案

OpenSign 是一个开源文档电子签名解决方案&#xff0c;旨在为 DocuSign、PandaDoc、SignNow、Adobe Sign、Smartwaiver、SignRequest、HelloSign 和 Zoho Sign 等商业平台提供安全、可靠且免费的替代方案。 特性&#xff1a; 安全签名&#xff1a;利用最先进的加密算法来确保…

easyHttp -- 轻量级的 HTTP 客户端工具包

easyHttp gitte地址:easy-http 介绍 easyHttp 是一个轻量级的 HTTP 客户端工具包&#xff0c;专为 Java 设计&#xff0c;使得基本的 HTTP 请求变得异常简单。该库主要针对常见的 HTTP 请求提供了简洁的 API&#xff0c;使得开发者无需面对复杂的设置。当前版本已支持基本的请…

私有化部署大模型:5个.Net开源项目

从零构建.Net前后端分离项目 今天一起盘点下&#xff0c;10月份推荐的5个.Net开源项目&#xff08;点击标题查看详情&#xff09;。 1、BootstrapBlazor企业级组件库&#xff1a;前端开发的革新之路 BootstrapBlazor是一个用于构建现代Web应用程序的开源框架&#xff0c;它基…