Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.pyimport torchdef random_masking(x, mask_ratio=0.75):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoreif __name__ == '__main__':x = torch.arange(64).reshape(1, 16, 4)random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4,  4,  4,  4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834136.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JumpServer发布web应用

项目背景: 由于防火墙密码安全没有达到审计要求,需要加固防火墙用户安全,通过JumpServer发布防火墙登录页面,提供远程访问 认证要求: 1、密码记忆多次 2、密码大小写 3、密码字符 4、密码数字 加固前密码策略&…

详解BOM编程

华子目录 BOM编程window对象常见的window对象的属性常见的window对象的方法注意 history对象history对象的属性history对象的方法 screen 对象navigator 对象属性方法 location对象属性方法示例 BOM编程 JavaScript本质是在浏览器中运行,所以JavaScript提供了BOM&a…

初学java

注意点 1.使用关键字long的时候,在其赋值的时候要在后面加上大写或者小写的l,个人推荐大写,小写与数‘1’难区分。 2.函数的名字要与文件夹的名字相同,并且文件夹后面一定要有.java。例如这个的名字是Main,函数就得用这个&#x…

docker学习笔记(四)制作镜像

目录 第1步:编辑Dockerfile 第2步:编辑requirements.txt文件 第3步:编辑app.py文件,我们的程序文件 第4步:生成镜像文件 第5步:使用镜像,启动容器 第6步: 启动redis容器、将容器…

短信群发平台:全功能SDK短信接口解决方案

SDK短信接口介绍: 为了满足不同企业的需求,我们提供了一站式SDK短信接口解决方案。这些接口不仅功能强大,而且易于集成到现有的企业系统中,以提供更加安全、高效和便捷的服务。 1.短信验证码接口:用于用户注册、密码修…

动态代理,案例理解

动态代理:代理就是被代理者没有能力或者不愿意去完成某件事情,需要找个人代替自己去完成这件事,动态代理就是用来对业务功能(方法)进行代理的。 步骤: 1.必须有接口,实现类要实现接口&#xf…

Windows下,基于Gradle用Docker发布自己的程序

方案1: windows下打包程序,然后,上传到linux下,生成docker镜像,然后执行。 首先: 由于是采用Gradle管理的项目,打包的时候需要执行build任务。执行完成后,再build\libs目录下应该…

Python——Fastapi管理平台(打包+优化)

目录 一、配置多个表 1、后端项目改造 2、导包报错——需要修改(2个地方) 3、启动后端(查看是否有问题) 4、配置前端 二、打包——成exe文件(不包含static文件)简单 1、后端修改 2、前端修改 3、运行打包命…

Hive SQL-DML-Load加载数据

Hive SQL-DML-Load加载数据 在 Hive 中,可以使用 SQL DML(Data Manipulation Language)语句中的 LOAD 命令来加载数据到表中。LOAD 命令用于将本地文件系统或 HDFS(Hadoop 分布式文件系统)中的数据加载到 Hive 表中。 …

【ITK配准】第九期 基匹配Metric的配准样例

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享ITK配准中的基匹配Metric的配准样例,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 基匹配Me…

idea已配置的git仓库地址 更换新的Git仓库地址 教程

文章目录 目录 文章目录 更改流程 小结 概要更改流程技术细节小结 概要 先在idea控制台走一下流程 先将本地的git仓库删除 1. 查看当前远程仓库地址: 在终端或命令行中,导航到你的项目目录,并运行以下命令查看当前的远程仓库地址&#xff…

JMeter断言介绍

JMeter是一个功能强大的性能测试工具,它不仅可以模拟用户的行为,还可以对web应用程序的响应进行检测。其中断言就是JMeter中非常实用的功能之一。 断言是用于验证服务器响应是否正确的测试元素。它会检查服务器响应中的部分或全部内容,并在响…

正点原子Linux学习笔记(七)在 LCD 上显示 png 图片

在 LCD 上显示 png 图片 21.1 PNG 简介21.2 libpng 简介21.3 zlib 移植下载源码包编译源码安装目录下的文件夹介绍移植到开发板 21.4 libpng 移植下载源码包编译源码安装目录下的文件夹介绍移植到开发板 21.5 libpng 使用说明libpng 的数据结构创建和初始化 png_struct 对象创建…

cookie、session、token

cookie 纳入标准文档,标准浏览器需要遵守的协议之一,作为标准浏览器必须支持的。 WEB应用都是基于HTTP协议,标准的HTTP协议是无状态的。 什么是无状态? 不管是谁,不管是从哪个地方发起的请求。只要你的请求&#xff08…

openssl 生成证书步骤

本地测试RSA非对称加密功能时,需要用到签名证书。本文记录作者使用openssl本地生成证书的步骤,并没有深入研究openssl,难免会有错误,欢迎指出!!! 生成证书标准流程: 1、生成私钥&am…

【Linux】Linux——Centos7安装RabbitMQ

目录 安装包准备socaterlang 安装rabbitmq安装命令启动rabbitmq,两种方式查看rabbitmq 启动后的情况配置并开启网页插件关闭防火墙或开放端口测试登录问题配置web端访问账号密码和权限添加用户,后面两个参数分别是用户名和密码.添加权限修改用户角色再次…

单片机的中断

1. 中断系统是为使CPU具有对外界紧急事件的实时处理能力而设置 当中央处理机CPU正在处理某件事的时候外界发生紧急事件请求,要CPU暂停当前的工作,转而去处理这个紧急事件,处理完以后,再回到原来中断的地方,继续原…

C语言22行代码,让你的朋友以为中了病毒

1 **C语言介绍 ** C语言是一种计算机编程语言,由丹尼斯里奇(Dennis Ritchie)在1972年左右为UNIX操作系统设计并开发。它具有高效、可移植、灵活和强大的特点,在计算机科学领域中具有广泛的应用。C语言是一种结构化语言&#xff0…

电子硬件设计-Xilinx FPGA/SoC前期功耗评估方法(1)

目录 1. 简介 2. 使用方法 2.1 设计输入 2.2 查看结果 3. 额外说明 4. 总结 1. 简介 XPE (Xilinx Power Estimator, 功耗估算器) 电子表格是一种功耗估算工具,用于项目的预设计和预实现阶段。 该工具可以帮助工程师进行架构评估、器件选择、合适的电源组件以…

官方文档k8s1.30安装部署高可用集群,kubeadm安装Kubernetes1.30最新版本

文章目录 节点架构一、准备开始(每一台机器都执行)1️⃣ 检查所需端口(可以直接关闭防火墙放开所有端口)端口和协议控制面工作节点 关闭防火墙关闭 SELinux 2️⃣ 安装containerd容器containerd部署containerd切换为国内源 3️⃣ 设置/etc/hosts 二、安装 kubeadm、kubelet 和 …