Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor([[[0, 0, 0, 0],[2, 2, 2, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.pyimport torchdef random_masking(x, mask_ratio=0.75):"""Perform per-sample random masking by per-sample shuffling.Per-sample shuffling is done by argsort random noise.x: [N, L, D], sequence"""N, L, D = x.shape  # batch, length, dimlen_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sample【核心代码】ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]# 利用torch.gather()从源tensor中获取25%的unmasked tokensx_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)return x_masked, mask, ids_restoreif __name__ == '__main__':x = torch.arange(64).reshape(1, 16, 4)random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15],[16, 17, 18, 19], # 4[20, 21, 22, 23],[24, 25, 26, 27],[28, 29, 30, 31],[32, 33, 34, 35],[36, 37, 38, 39],[40, 41, 42, 43], # 10[44, 45, 46, 47],[48, 49, 50, 51], # 12[52, 53, 54, 55], # 13[56, 57, 58, 59],[60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],[12, 12, 12, 12],[ 4,  4,  4,  4],[13, 13, 13, 13]]])# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],[48, 49, 50, 51],[16, 17, 18, 19],[52, 53, 54, 55]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834136.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JumpServer发布web应用

项目背景: 由于防火墙密码安全没有达到审计要求,需要加固防火墙用户安全,通过JumpServer发布防火墙登录页面,提供远程访问 认证要求: 1、密码记忆多次 2、密码大小写 3、密码字符 4、密码数字 加固前密码策略&…

详解BOM编程

华子目录 BOM编程window对象常见的window对象的属性常见的window对象的方法注意 history对象history对象的属性history对象的方法 screen 对象navigator 对象属性方法 location对象属性方法示例 BOM编程 JavaScript本质是在浏览器中运行,所以JavaScript提供了BOM&a…

初学java

注意点 1.使用关键字long的时候,在其赋值的时候要在后面加上大写或者小写的l,个人推荐大写,小写与数‘1’难区分。 2.函数的名字要与文件夹的名字相同,并且文件夹后面一定要有.java。例如这个的名字是Main,函数就得用这个&#x…

GPIO基础知识学习

前言: 本文记录了我自己学习最基本的单片机电路知识的学习笔记。本文参考了引用链接中的大量内容,并加上了自己的很少的一点思考(因为我本人的电路知识基本没有)。 引用: 上拉电阻与下拉电阻总结 与 GPIO框图分析_上…

docker学习笔记(四)制作镜像

目录 第1步:编辑Dockerfile 第2步:编辑requirements.txt文件 第3步:编辑app.py文件,我们的程序文件 第4步:生成镜像文件 第5步:使用镜像,启动容器 第6步: 启动redis容器、将容器…

短信群发平台:全功能SDK短信接口解决方案

SDK短信接口介绍: 为了满足不同企业的需求,我们提供了一站式SDK短信接口解决方案。这些接口不仅功能强大,而且易于集成到现有的企业系统中,以提供更加安全、高效和便捷的服务。 1.短信验证码接口:用于用户注册、密码修…

动态代理,案例理解

动态代理:代理就是被代理者没有能力或者不愿意去完成某件事情,需要找个人代替自己去完成这件事,动态代理就是用来对业务功能(方法)进行代理的。 步骤: 1.必须有接口,实现类要实现接口&#xf…

Windows下,基于Gradle用Docker发布自己的程序

方案1: windows下打包程序,然后,上传到linux下,生成docker镜像,然后执行。 首先: 由于是采用Gradle管理的项目,打包的时候需要执行build任务。执行完成后,再build\libs目录下应该…

Python——Fastapi管理平台(打包+优化)

目录 一、配置多个表 1、后端项目改造 2、导包报错——需要修改(2个地方) 3、启动后端(查看是否有问题) 4、配置前端 二、打包——成exe文件(不包含static文件)简单 1、后端修改 2、前端修改 3、运行打包命…

解决Rust Cargo报错

1.查看操作系统信息 [root@localhost ~]# cat /etc/.kyinfo [dist] name=Kylin milestone=Server-V10-GFB-Release-ZF9_01-2204-Build03 arch=arm64 beta=False time=2023-01-09 11:04:36 dist_id=Kylin-Server-V10-GFB-Release-ZF9_01-2204-Build03-arm64-2023-01-09 11:04:…

Hive SQL-DML-Load加载数据

Hive SQL-DML-Load加载数据 在 Hive 中,可以使用 SQL DML(Data Manipulation Language)语句中的 LOAD 命令来加载数据到表中。LOAD 命令用于将本地文件系统或 HDFS(Hadoop 分布式文件系统)中的数据加载到 Hive 表中。 …

互联网摸鱼日报(2024-05-09)

互联网摸鱼日报(2024-05-09) 36氪新闻 编剧和观众吵翻了天,到底谁不懂女性 OpenAI的搜索引擎真要来了:开启灰度测试,微软Bing加持 沾过光、踩过坑的小米生态链企业,给品牌指了什么出路? 中国县城里的热辣滚烫 前销…

java识别word段落和Java识别pdf端口整理

首先理解word与xml的关系 word文档与xml关系_docx xml-CSDN博客 Word和XML之间有密切的关系,因为Word文档实际上是XML文件的一种。从Word 2003开始,Microsoft Word文档的默认格式是XML,即.docx。XML是一种可扩展的标记语言,它允…

edge 的使用心得与深度搜索

关于 **Edge 浏览器** 的使用心得,用户普遍提到了以下几个方面: 1. **性能和效率**:Edge 浏览器在微软重新构建它基于 Chromium 内核后,性能得到了显著的提升。用户通常会注意到它在启动速度、页面加载速度和资源管理方面相对更优…

【ITK配准】第九期 基匹配Metric的配准样例

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享ITK配准中的基匹配Metric的配准样例,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 基匹配Me…

idea已配置的git仓库地址 更换新的Git仓库地址 教程

文章目录 目录 文章目录 更改流程 小结 概要更改流程技术细节小结 概要 先在idea控制台走一下流程 先将本地的git仓库删除 1. 查看当前远程仓库地址: 在终端或命令行中,导航到你的项目目录,并运行以下命令查看当前的远程仓库地址&#xff…

JMeter断言介绍

JMeter是一个功能强大的性能测试工具,它不仅可以模拟用户的行为,还可以对web应用程序的响应进行检测。其中断言就是JMeter中非常实用的功能之一。 断言是用于验证服务器响应是否正确的测试元素。它会检查服务器响应中的部分或全部内容,并在响…

MATLAB--Sequences Series II

Problem 2575. Sum of series I What is the sum of the following sequence:(这个序列的和是多少:) Σ(2k-1) for k1...n for different n?(对于不同的 ( n )?) 在MATLAB中,可以使用循环来计算…

docker 安装镜像及使用命令

目录 1. Mysql2. Redis3. Nginx4. Elasticsearch官网指导 docker pull 容器名:版本号 拉取容器, 不指定版本号默认最新的 run 运行 -d 后台运行 -p 3306:3306 -p是port 对外端口:对内端口 –name xyy_mysql 容器名称 -e MYSQL_ROOT_PASSWORD123456 环境变量 -v 系统地址:docker…

js强大的运算符:??、??=

学习目标: js中强大的运算符 ?? 非空运算符 学习内容: ?? 非空运算符 注意:?? 运算符被称为非空运算符。如果第一个参数不是 null/undefined 将返回第一个参数,否则返回第二个参数 之前: 给变量设置默认值时…