图解PyTorch中的torch.gather函数和 scatter 函数

前言

torch.gather在目前基于 transformer or query based 的目标检测中,在最后获取目标结果时,经常用到。

这里记录下用法,防止之后又忘了。

介绍

torch.gather

在这里插入图片描述
官方文档对torch.gather()的定义非常简洁

定义:从原tensor中获取指定dim和指定index的数据
看到这个核心定义,我们很容易想到gather()的基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的

示例

我们找个3x3的二维矩阵做个实验

import torchtensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

2.1 输入行向量index,并替换行索引(dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

输出结果

tensor([[9, 7, 5]])

过程如图所示
在这里插入图片描述

2.2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5, 4, 3]])

过程如图所示
在这里插入图片描述

2.3 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5],[7],[9]])

过程如图所示
在这里插入图片描述

scatter

基本是 gather 的反过程,是将数据添加进去,
doc:https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

example:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],[0, 2, 0, 0, 0],[0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],[6, 7, 0, 0, 8],[0, 0, 0, 0, 0]])>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],[2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],[2.0000, 2.0000, 2.0000, 3.2300]])

具体过程见 gather 的就好~一摸一样,一个获取,一个填入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/786520.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu joystick 测试手柄 xbox

Ubuntu joystick 测试手柄 xbox 测试使用Ubuntu20.04 测试环境在工控机 安装测试 实际测试使用的手柄是北通阿修罗2pro 兼容xbox Ubuntu20.04主机 连接手柄或者无线接收器后查看是否已经检测到: ls /dev/input找到输入中的 js0 即为手柄输入 需要安装joysti…

什么函数不能声明为虚函数?

常见的不能声明为虚函数的有普通函数(非成员函数)、静态成员函数、内联成员函数、 构造函数和友元函数。以下将分别对这几种情况进行分析。 普通函数(非成员函数) 普通函数(非成员函数)只能overload &…

sql注入详解

ps:简单说下这里只写了我能理解的明白的,后面的二阶注入,堆叠注入没写 手工sql注入 1.存在sql注入本质上就是数据库过滤的不严格或者未进行过滤,1 and 11,返回正常,1 and 12 返回不正常,说明带到数据库里面…

注意力机制篇 | YOLOv8改进之添加DAT注意力机制

前言:Hello大家好,我是小哥谈。DAT(Vision Transformer with Deformable Attention)是一种引入了可变形注意力机制的视觉Transformer。在训练算法模型的时候,通过引入可变形注意力机制,改进了视觉Transformer的效率和性能,使其在处理复杂的视觉任务时更加高效和准确。�…

C51-- 蓝牙,WIFI模块

HC-08蓝牙模块: 蓝牙 -- 最好用的 串口透传 模块 透传 -- 透明传送,指的是在数据传输的过程中,通过无线的方式这组数据不发生任何形式的改变, 仿佛传输过程是透明的,同时保证传输质量,最终原封不动的传送到接收者手…

css酷炫边框

边框一 .leftClass {background: #000;/* -webkit-animation: twinkling 1s infinite ease-in-out; 1秒钟的开始结束都慢的无限次动画 */ } .leftClass::before {content: "";width: 104%;height: 102%;border-radius: 8px;background-image: linear-gradient(var(…

正则表达式引擎库汇合

1.总览表格 一些正则表达式库的对比 index库名编程语言说明代码示例编译指令1Posix正则C语言是C标准库中用于编译POSIX风格的正则表达式库 posix-re.cgcc posix-re.c 2PCRE库C语言提供类似Perl语言的一个正则表达式引擎库。 一般系统上对应/usr/lib64/libpcre.so这个库文件&am…

柔性数组详细讲解

动态内存函数的使用和综合实践-malloc,free,realloc,calloc-CSDN博客https://blog.csdn.net/Jason_from_China/article/details/137075045 柔性数组存在的意义 柔性数组在编程语言中指的是可以动态调整大小的数组。相比固定大小的数组&#…

Redis中惰性策略的启发和流量包应用设计

引言 在技术领域,许多中间件之所以获得巨大成功,部分原因在于它们所采用的思想之先进。这些思想解决了一个个世纪难题,接下来我将讲述一个我学习到的思想,并将其应用至工作中的案例。 惰性策略在日常编码中随处可见,但…

STL容器的一些操作(常用的,不全)

目录 string 1.string的一些创建 2.string 的读入和输出: 3.string的一些操作 4.彻底清空string 容器的函数 vector 1.vector的一些创建: 2.vector的一些操作: 3.vector的彻底清空并释放内存: queue 循环队列&#xff1…

兑换码生成算法

兑换码生成算法 兑换码生成算法1.兑换码的需求2.算法分析2.重兑校验算法3.防刷校验算法 3.算法实现 兑换码生成算法 兑换码生成通常涉及在特定场景下为用户提供特定产品或服务的权益或礼品,典型的应用场景包括优惠券、礼品卡、会员权益等。 1.兑换码的需求 要求如…

【DevOps工具篇】安装 LDAP 管理 GUI PhpLdapAdmin

【DevOps工具篇】安装 LDAP 管理 GUI PhpLdapAdmin 目录 【DevOps工具篇】安装 LDAP 管理 GUI PhpLdapAdmin启用远程管理功能安装 phpLDAPadmin对 phpLDAPadmin 进行补丁Apache 的配置编辑配置文件推荐超级课程: Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速…

Pointnet++分类和分割数据集准备和实验复现

5.分类数据集Modelnet40及可视化 Modelnet40分类数据集 原始的modelnet40是off文件,是cad模型 OFF文件是一种用于存储三维对象信息的文件格式,全称为"Object File Format"。它主要用于存储几何体的顶点、边和面信息,以及可能的颜…

面对复杂多变的网络攻击,企业应如何守护网络安全

企业上云,即越来越多的企业把业务和数据,迁移到云端。随着云计算、大数据、物联网、人工智能等技术的发展,用户、应用程序和数据无处不在,企业之间的业务边界逐渐被打破,网络攻击愈演愈烈,手段更为多。 当前…

未来的发展趋势-无服务架构-即将到来-让我们欢呼吧

无服务架构(Serverless Architecture)是一种颠覆性的云计算架构范式,旨在简化应用程序开发和部署过程,提高开发效率和降低成本。在传统的基础设施即服务(IaaS)和平台即服务(PaaS)模型…

uni app 扫雷

闲来无聊。做个扫雷玩玩吧&#xff0c;点击打开&#xff0c;长按标记&#xff0c;标记的点击两次或长按取消标记。所有打开结束 <template><view class"page_main"><view class"add_button" style"width: 100vw; margin-bottom: 20r…

Pytorch实用教程:torch.from_numpy(X_train)和torch.from_numpy(X_train).float()的区别

在PyTorch中&#xff0c;torch.from_numpy()函数和.float()方法被用来从NumPy数组创建张量&#xff0c;并可能改变张量的数据类型。两者之间的区别主要体现在数据类型的转换上&#xff1a; torch.from_numpy(X_train)&#xff1a;这行代码将NumPy数组X_train转换为一个PyTorch张…

Docker容器监控之CAdvisor+InfluxDB+Granfana

介绍&#xff1a;CAdvisor监控收集InfluxDB存储数据Granfana展示图表 目录 1、新建3件套组合的docker-compose.yml 2、查看三个服务容器是否启动 3、浏览cAdvisor收集服务&#xff0c;http://ip:8080/ 4、浏览influxdb存储服务&#xff0c;http://ip:8083/ 5、浏览grafan…

如何利用CSS实现文字滚动效果

1. 使用CSS3的animation属性 CSS3的animation属性可以让元素在一段时间内不停地播放某个动画效果。我们可以利用这个特性来实现文字滚动效果。 我们需要定义一个包含所有需要滚动的文本的容器元素。比如&#xff1a; <div class"scroll-container"><p>…

【每日一道算法题】有序数组的平方、长度最小的子数组

文章目录 有序数组的平方写在前面题目思路解析暴力解法双指针法 我的代码暴力解法双指针法 参考答案解法暴力方法双指针法 长度最小的子数组原题思路解析暴力法滑动窗口法 我的代码官方题解滑动窗口法 有序数组的平方 写在前面 本人是一名在java后端寻路的小白&#xff0c;希…