PyTorch 稀疏函数解析:embedding 、one_hot详解

目录

PyTorch子模块Sparse functions详解

embedding

参数

输出形状

示例

带有 padding_idx 的示例

embedding_bag

参数

输出形状

示例

使用 padding_idx 的示例

one_hot

参数

返回

示例

总结


PyTorch子模块Sparse functions详解

embedding

torch.nn.functional.embedding 是 PyTorch 中的一个函数,用于从固定字典和大小的简单查找表中检索嵌入(embeddings)。这个函数通常用于使用索引检索词嵌入。其输入是一个索引列表和嵌入矩阵,输出是相应的词嵌入。

参数

  • input (LongTensor):包含嵌入矩阵索引的张量。
  • weight (Tensor):嵌入矩阵,行数等于最大可能索引 + 1,列数等于嵌入大小。
  • padding_idx (int, 可选):如果指定,padding_idx 处的条目不会对梯度产生贡献;因此,在训练期间,padding_idx 处的嵌入向量不会更新,即它保持为固定的“填充”。
  • max_norm (float, 可选):如果给定,每个嵌入向量的范数大于 max_norm 时将被重新规范化为 max_norm。注意:这将就地修改 weight。
  • norm_type (float, 可选):用于计算 max_norm 选项的 p-范数的 p。默认为 2。
  • scale_grad_by_freq (bool, 可选):如果给定,将按照小批量中单词频率的倒数来缩放梯度。默认为 False。
  • sparse (bool, 可选):如果为 True,weight 相对于的梯度将是一个稀疏张量。有关稀疏梯度的更多细节,请参阅 torch.nn.Embedding

输出形状

  • 输入:任意形状的 LongTensor,包含要提取的索引。
  • 权重:浮点类型的嵌入矩阵,形状为 (V, embedding_dim),其中 V = 最大索引 + 1,embedding_dim = 嵌入大小。
  • 输出:(*, embedding_dim),其中 * 是输入的形状。

示例

import torch
import torch.nn.functional as F# 两个样本的批次,每个样本有 4 个索引
input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])# 包含 10 个大小为 3 的张量的嵌入矩阵
embedding_matrix = torch.rand(10, 3)# 使用 F.embedding 获取嵌入
output = F.embedding(input, embedding_matrix)

此例中,input 包含两个样本的索引,embedding_matrix 是一个随机初始化的嵌入矩阵。F.embedding 函数返回了这些索引对应的嵌入向量。

带有 padding_idx 的示例

weights = torch.rand(10, 3)
weights[0, :].zero_()  # 将索引 0 的嵌入向量设置为零
embedding_matrix = weights
input = torch.tensor([[0, 2, 0, 5]])# 使用 padding_idx
output = F.embedding(input, embedding_matrix, padding_idx=0)

在这个例子中,索引 0 被用作填充索引(padding_idx),因此它的嵌入向量在训练过程中不会更新,且初始化为零。这对于处理可变长度的序列数据特别有用,其中某些位置可能需要被忽略。 

embedding_bag

torch.nn.functional.embedding_bag 是 PyTorch 中的一个函数,它计算嵌入向量的“包”(bag)的和、均值或最大值,而无需实例化中间的嵌入向量。这个函数对于处理文本数据特别有用,特别是在处理变长序列或者需要聚合嵌入表示时。

参数

  • input (LongTensor):包含嵌入矩阵索引的包的张量。
  • weight (Tensor):嵌入矩阵,行数等于最大可能索引 + 1,列数等于嵌入大小。
  • offsets (LongTensor, 可选):仅当输入是1D时使用。offsets确定每个包(序列)在输入中的起始索引位置。
  • max_norm (float, 可选):如果给定,每个嵌入向量的范数大于 max_norm 时将被重新规范化为 max_norm。注意:这将就地修改 weight。
  • norm_type (float, 可选):用于计算 max_norm 选项的 p-范数的 p。默认为 2。
  • scale_grad_by_freq (bool, 可选):如果给定,将按照小批量中单词频率的倒数来缩放梯度。默认为 False。
  • mode (str, 可选):可选 "sum", "mean" 或 "max"。指定聚合包的方式。默认为 "mean"。
  • sparse (bool, 可选):如果为 True,weight 相对于的梯度将是一个稀疏张量。
  • per_sample_weights (Tensor, 可选):浮点/双精度权重的张量,或 None 表示所有权重应视为 1。如果指定,per_sample_weights 的形状必须与 input 完全相同。
  • include_last_offset (bool, 可选):如果为 True,offsets 的大小等于包的数量 + 1。最后一个元素是输入的大小,或最后一个包(序列)的结束索引位置。
  • padding_idx (int, 可选):如果指定,padding_idx 处的条目不会对梯度产生贡献;因此,训练期间不会更新 padding_idx 处的嵌入向量,即它保持为固定的“填充”。

输出形状

  • 输入:LongTensor,和可选的 offsets (LongTensor)

    • 如果输入是二维的,形状为 (B, N),它将被视为 B 个固定长度 N 的包(序列),这将根据 mode 返回 B 个聚合值。在这种情况下,offsets 被忽略并且要求为 None。
    • 如果输入是一维的,形状为 (N),它将被视为多个包(序列)的串联。offsets 是一个一维张量,包含输入中每个包的起始索引位置。因此,对于形状为 (B) 的 offsets,输入将被视为有 B 个包。空包(即长度为0)将返回由零填充的向量。
  • 权重 (Tensor):可学习的模块权重,形状为 (num_embeddings, embedding_dim)。

  • per_sample_weights (Tensor, 可选):具有与输入相同形状的张量。

  • 输出:聚合后的嵌入值,形状为 (B, embedding_dim)。

示例

import torch
import torch.nn.functional as F# 包含 10 个大小为 3 的张量的嵌入矩阵
embedding_matrix = torch.rand(10, 3)# 一个样本的批次,包含 4 个索引
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])# 指定每个包(序列)的开始位置
offsets = torch.tensor([0, 4])# 使用 F.embedding_bag 获取嵌入
output = F.embedding_bag(input, embedding_matrix, offsets)

在此示例中,input 包含 8 个索引,表示两个序列(或“包”),由 offsets 指定其开始位置。embedding_matrix 是一个随机初始化的嵌入矩阵。F.embedding_bag 函数将返回这些索引对应的嵌入向量的聚合(默认为均值)。

使用 padding_idx 的示例

embedding_matrix = torch.rand(10, 3)
input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])# 使用 padding_idx
output = F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')

在这个例子中,索引 2 被用作填充索引(padding_idx),这意味着在聚合时,索引为 2 的嵌入向量不会对结果产生贡献,并且在训练过程中不会更新这个嵌入向量。这在处理变长序列时特别有用,其中某些位置可能需要被忽略。示例中使用的 mode='sum' 表示对每个包内的嵌入向量进行求和操作。

embedding_bag 函数的这种处理方式比标准的 embedding 函数更高效,因为它避免了创建大量中间嵌入向量的步骤,特别是在处理包含许多短序列的大批量数据时。此外,它允许不同长度的序列共存于同一个批次中,这在处理自然语言处理任务时尤其有价值。

one_hot

torch.nn.functional.one_hot 是 PyTorch 中的一个函数,它用于将长整型张量(LongTensor)转换为一种称为 one-hot 编码的形式。在 one-hot 编码中,每个类别的索引将被转换为一个向量,该向量中除了对应类别索引处的值为 1 外,其余位置均为 0。

参数

  • tensor (LongTensor):任何形状的类别值。
  • num_classes (int):总类别数。如果设置为 -1,则类别数将推断为输入张量中最大类别值加1。

返回

  • 返回一个多了一维的 LongTensor,其中在最后一个维度的索引处,由输入指定的位置为 1,其他位置为 0。

示例

import torch
import torch.nn.functional as F# 示例 1:基本用法
output = F.one_hot(torch.arange(0, 5) % 3)
print(output)
# tensor([[1, 0, 0],
#         [0, 1, 0],
#         [0, 0, 1],
#         [1, 0, 0],
#         [0, 1, 0]])# 示例 2:指定 num_classes
output = F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
print(output)
# tensor([[1, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [1, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0]])# 示例 3:使用多维张量
output = F.one_hot(torch.arange(0, 6).view(3, 2) % 3)
print(output)
# tensor([[[1, 0, 0],
#          [0, 1, 0]],
#         [[0, 0, 1],
#          [1, 0, 0]],
#         [[0, 1, 0],
#          [0, 0, 1]]])

在这些示例中,torch.arange(0, 5) % 3 生成一个周期为 3 的序列,然后 F.one_hot 将这些值转换为 one-hot 编码形式。在第二个示例中,通过指定 num_classes=5,可以控制 one-hot 编码向量的长度。在第三个示例中,展示了如何对多维张量进行 one-hot 编码。 

总结

本篇博客探讨了 PyTorch 框架中几个关键的稀疏函数,包括 embeddingembedding_bagone_hot。这些函数在处理自然语言处理(NLP)任务和其他需要高效、灵活处理大量类别或序列数据的应用中至关重要。embedding 函数用于从预定义的嵌入矩阵中检索指定索引的嵌入向量,支持自定义嵌入矩阵大小、填充索引和范数限制。embedding_bag 提供了一种高效的方法来处理变长序列,通过聚合(如求和、均值或最大值)嵌入向量,而无需单独处理每个序列。one_hot 函数则用于将类别标签转换为 one-hot 编码形式,适用于处理分类任务中的标签数据。这些函数的灵活性和高效性使它们成为深度学习模型设计和实现中的重要工具。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/630702.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Git怎么将文件夹上传至github,全过程

小白建议参考github文件上传全流程-新手入门系列(超详细!!!) 中间可能会有报错 $ ssh -T gitgithub.com ssh: connect to host github.com port 22: Connection timed out 这时,参考,如何解决&a…

React 基于Ant Degisn 实现table表格列表拖拽排序

效果图: 代码: myRow.js import { MenuOutlined } from ant-design/icons; import { DndContext } from dnd-kit/core; import { restrictToVerticalAxis } from dnd-kit/modifiers; import {arrayMove,SortableContext,useSortable,verticalListSorti…

std::for_each 简单使用

首先简单使用一下 std::for_each 是C标准库中的一个算法&#xff0c;用于对指定范围内的元素执行指定的操作。它通常用于迭代容器中的元素&#xff0c;对每个元素执行某种操作。下面是一个简单的例子&#xff0c;说明如何使用 std::for_each&#xff1a; #include <iostre…

.Net 全局过滤,防止SQL注入

问题背景&#xff1a;由于公司需要整改的老系统的漏洞检查&#xff0c;而系统就是没有使用参数化SQL即拼接查询语句开发的程序&#xff0c;导致漏洞扫描出现大量SQL注入问题。 解决方法&#xff1a;最好的办法就是不写拼接SQL&#xff0c;改用参数化SQL&#xff0c;推荐新项目…

2023年的年度总结PPT不一样了?

添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; 到了年终&#xff0c;需要撰写年度总结和制定计划了吗&#xff1f; 找不到合适的 PPT 模板&#xff1f; 感到缺乏灵感&#xff1f; 为做 PPT 绞尽脑汁&#xff1f; 为何不试试 AI 写 PPT 呢&#xff1f…

SQL笔记 -- 范式(第一范式、第二范式、第三范式、巴斯范式、反范式)及数据库设计原则

1.范式 1.1 范式简介 在关系型数据库中&#xff0c;关于数据表设计的基本原则、规则就称为范式。可以理解为&#xff0c;一张数据表的设计结构需要满足的某种设计标准的级别 。要想设计一个结构合理的关系型数据库&#xff0c;必须满足一定的范式。 目前关系型数据库有六种常…

Spring MVC学习之——自定义日期转化器

日期转换器 在数据库中的日期数据是date类型&#xff0c;而如何我们想在页面自己添加数据&#xff0c;一般是使用年-月-日的形式&#xff0c;这种形式不仅date类型接收不到&#xff0c;而且传来的是String类型&#xff0c;此时&#xff0c;我们就可以自定义日期转换器来接收数…

【MySQL】权限控制

DCL-权限控制 查询权限 show grants for 用户名主机名;授予权限 grant 权限列表 on 数据库名.表名 to 用户名主机名;grant all on test.* to user%; %是通配符&#xff0c;表示任意主机。撤销权限 revoke 权限列表 on 数据库名.表名 from 用户名主机名;revoke all on test.*…

顶顶通呼叫中心中间件自动外呼来电转人工显示被叫号码而不是显示路由条件 :一步步配置(mod_cti基于FreeSWITCH)

介绍 顶顶通呼叫中心中间件自动外呼来电转人工显示被叫号码而不是显示自动外呼的路由条件&#xff0c;可以是默认的被叫号码也可以改为显示指定的号码 一、显示默认被叫 1、配置拨号方案 打开ccadmin-》点击拨号方案-》找到进入排队-》配置跟图中一样的通道变量。修改了拨号…

关于KT6368A双模蓝牙芯片的BLE在ios的lightblue大数量数据测试

测试简介 关于KT6368A双模蓝牙芯片的BLE在ios的lightblue app大数量数据测试 测试环境&#xff1a;iphone7 。KT6368A双模程序96B6 App&#xff1a;lightblue ios端 可以打开log日志查看通讯流程 测试数据&#xff1a;长度是1224个字节&#xff0c;单次直接发给KT6368A&a…

Pixels:重新定义游戏体验的区块链农场游戏

数据源&#xff1a;Pixels Dashboard 作者&#xff1a;lesleyfootprint.network 最近&#xff0c;Pixels 通过从 Polygon 转移到 Sky Mavis 旗下的 Ronin 网络&#xff0c;完成了一次战略性的转变。 Pixels 每日交易量 Pixels 在 Ronin 网络上的受欢迎程度急剧上升&#xf…

073:vue+mapbox 加载here地图(影像瓦片图 v3版)

第073个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+mapbox中加载here地图的影像瓦片图。 直接复制下面的 vue+mapbox源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源代码(共77行)相关API参考:专栏目标示例效果

模型之掷骰子问题

掷骰子问题 假设我要掷一对骰子&#xff0c;想要了解它们的行为如何。经验告诉我&#xff0c;问某些问题根本是不现实的。例如&#xff0c;不可能期待有人能预先告诉我某一次掷骰子的结果&#xff0c;即便是他掌握了很高超的科技&#xff0c;并且用机器来掷骰子。与此相反的是…

Jxls 实现动态导出功能

目录 引言前端页面后端代码excel模板导出效果 引言 在实际做项目的过程中&#xff0c;导出报表时需要根据每个人所关注的点不一样&#xff0c;所需导出的字段也不一样&#xff0c;这时后端就需要根据每个所选的字段去相应的报表&#xff0c;这就是本文要讲的动态导出报表。 前端…

LabVIEW图像识别检测机械零件故障

项目背景&#xff1a; 在工业生产中&#xff0c;零件尺寸的准确检测对保证产品质量至关重要。传统的人工测量方法不仅耗时费力&#xff0c;精度低&#xff0c;还容易导致零件的接触磨损。为了解决这些问题&#xff0c;开发了一套基于LabVIEW和机器视觉的机械零件检测系统。该系…

【Java实战项目】基于ssm的数据结构课程网络学习平台

&#x1f64a;作者简介&#xff1a;多年一线开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

Modbus协议学习第三篇之协议通信规则

导语 本篇博客将深入介绍Modbus协议的一些内容&#xff0c;主要包括通讯方式和通讯模型的介绍 Modbus通讯方式 Modbus协议是单主机、多从机的通信协议&#xff0c;即同一时间&#xff0c;总线上只能有一个主设备&#xff0c;但可以有一个或者多个从设备&#xff08;最多好像是2…

CSS实现图片放大缩小的几种方法

参考 方法一&#xff1a; 常用使用img标签&#xff0c;制定width或者height的任意一个&#xff0c;图片会自动等比例缩小 <div><img src"https://avatar.csdn.net/8/5/D/1_u012941315.jpg"/> </div> <!-- CSS--> <style> img {widt…

在Windows 10的PowerShell上实现对Linux机器,vscode同样可登录

在Windows 10的PowerShell上实现对Linux机器&#xff08;如 test192.168.10.13&#xff09;的SSH免密登录 1.检查SSH客户端&#xff1a;确保你的Windows 10系统已安装SSH客户端。 如果看到相关的命令说明&#xff0c;那么SSH客户端已安装。 在PowerShell中输入: ssh2.生成SSH…

Canvas和Three.js区别

Canvas&#xff1a; Canvas和Three.js都是用于在网页上创建和显示图形的工具&#xff0c;但它们的重点不同。Canvas是一个HTML5定义的标签&#xff0c;通过Canvas&#xff0c;你可以直接使用JavaScript来绘制线条、形状、文本和图像等。它有一套丰富的API&#xff0c;允许进行…