图解PyTorch中的torch.gather函数和 scatter 函数

前言

torch.gather在目前基于 transformer or query based 的目标检测中,在最后获取目标结果时,经常用到。

这里记录下用法,防止之后又忘了。

介绍

torch.gather

在这里插入图片描述
官方文档对torch.gather()的定义非常简洁

定义:从原tensor中获取指定dim和指定index的数据
看到这个核心定义,我们很容易想到gather()的基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的

示例

我们找个3x3的二维矩阵做个实验

import torchtensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

2.1 输入行向量index,并替换行索引(dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

输出结果

tensor([[9, 7, 5]])

过程如图所示
在这里插入图片描述

2.2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5, 4, 3]])

过程如图所示
在这里插入图片描述

2.3 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5],[7],[9]])

过程如图所示
在这里插入图片描述

scatter

基本是 gather 的反过程,是将数据添加进去,
doc:https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

example:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],[0, 2, 0, 0, 0],[0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],[6, 7, 0, 0, 8],[0, 0, 0, 0, 0]])>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],[2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],[2.0000, 2.0000, 2.0000, 3.2300]])

具体过程见 gather 的就好~一摸一样,一个获取,一个填入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/786520.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu joystick 测试手柄 xbox

Ubuntu joystick 测试手柄 xbox 测试使用Ubuntu20.04 测试环境在工控机 安装测试 实际测试使用的手柄是北通阿修罗2pro 兼容xbox Ubuntu20.04主机 连接手柄或者无线接收器后查看是否已经检测到: ls /dev/input找到输入中的 js0 即为手柄输入 需要安装joysti…

注意力机制篇 | YOLOv8改进之添加DAT注意力机制

前言:Hello大家好,我是小哥谈。DAT(Vision Transformer with Deformable Attention)是一种引入了可变形注意力机制的视觉Transformer。在训练算法模型的时候,通过引入可变形注意力机制,改进了视觉Transformer的效率和性能,使其在处理复杂的视觉任务时更加高效和准确。�…

css酷炫边框

边框一 .leftClass {background: #000;/* -webkit-animation: twinkling 1s infinite ease-in-out; 1秒钟的开始结束都慢的无限次动画 */ } .leftClass::before {content: "";width: 104%;height: 102%;border-radius: 8px;background-image: linear-gradient(var(…

正则表达式引擎库汇合

1.总览表格 一些正则表达式库的对比 index库名编程语言说明代码示例编译指令1Posix正则C语言是C标准库中用于编译POSIX风格的正则表达式库 posix-re.cgcc posix-re.c 2PCRE库C语言提供类似Perl语言的一个正则表达式引擎库。 一般系统上对应/usr/lib64/libpcre.so这个库文件&am…

柔性数组详细讲解

动态内存函数的使用和综合实践-malloc,free,realloc,calloc-CSDN博客https://blog.csdn.net/Jason_from_China/article/details/137075045 柔性数组存在的意义 柔性数组在编程语言中指的是可以动态调整大小的数组。相比固定大小的数组&#…

STL容器的一些操作(常用的,不全)

目录 string 1.string的一些创建 2.string 的读入和输出: 3.string的一些操作 4.彻底清空string 容器的函数 vector 1.vector的一些创建: 2.vector的一些操作: 3.vector的彻底清空并释放内存: queue 循环队列&#xff1…

兑换码生成算法

兑换码生成算法 兑换码生成算法1.兑换码的需求2.算法分析2.重兑校验算法3.防刷校验算法 3.算法实现 兑换码生成算法 兑换码生成通常涉及在特定场景下为用户提供特定产品或服务的权益或礼品,典型的应用场景包括优惠券、礼品卡、会员权益等。 1.兑换码的需求 要求如…

Pointnet++分类和分割数据集准备和实验复现

5.分类数据集Modelnet40及可视化 Modelnet40分类数据集 原始的modelnet40是off文件,是cad模型 OFF文件是一种用于存储三维对象信息的文件格式,全称为"Object File Format"。它主要用于存储几何体的顶点、边和面信息,以及可能的颜…

面对复杂多变的网络攻击,企业应如何守护网络安全

企业上云,即越来越多的企业把业务和数据,迁移到云端。随着云计算、大数据、物联网、人工智能等技术的发展,用户、应用程序和数据无处不在,企业之间的业务边界逐渐被打破,网络攻击愈演愈烈,手段更为多。 当前…

uni app 扫雷

闲来无聊。做个扫雷玩玩吧&#xff0c;点击打开&#xff0c;长按标记&#xff0c;标记的点击两次或长按取消标记。所有打开结束 <template><view class"page_main"><view class"add_button" style"width: 100vw; margin-bottom: 20r…

Docker容器监控之CAdvisor+InfluxDB+Granfana

介绍&#xff1a;CAdvisor监控收集InfluxDB存储数据Granfana展示图表 目录 1、新建3件套组合的docker-compose.yml 2、查看三个服务容器是否启动 3、浏览cAdvisor收集服务&#xff0c;http://ip:8080/ 4、浏览influxdb存储服务&#xff0c;http://ip:8083/ 5、浏览grafan…

如何利用CSS实现文字滚动效果

1. 使用CSS3的animation属性 CSS3的animation属性可以让元素在一段时间内不停地播放某个动画效果。我们可以利用这个特性来实现文字滚动效果。 我们需要定义一个包含所有需要滚动的文本的容器元素。比如&#xff1a; <div class"scroll-container"><p>…

JAV八股--redis

如何保证Redis和数据库数据一致性 关于异步通知中消息队列和Canal的内容。 redisson实现的分布式锁的主从一致性 明天继续深入看这个系列问题 介绍IO复用模型

【机器学习300问】59、计算图是如何帮助人们理解反向传播的?

在学习神经网络的时候&#xff0c;势必会学到误差反向传播&#xff0c;它对于神经网络的意义极其重大&#xff0c;它是训练多层前馈神经网络的核心算法&#xff0c;也是机器学习和深度学习领域中最为重要的算法之一。要正确理解误差反向传播&#xff0c;不妨借助一个工具——计…

代码随想录算法训练营第24天|理论基础 |77. 组合

理论基础 jia其实在讲解二叉树的时候&#xff0c;就给大家介绍过回溯&#xff0c;这次正式开启回溯算法&#xff0c;大家可以先看视频&#xff0c;对回溯算法有一个整体的了解。 题目链接/文章讲解&#xff1a;代码随想录 视频讲解&#xff1a;带你学透回溯算法&#xff08;理…

深入理解数据结构——堆

前言&#xff1a; 在前面我们已经学习了数据结构的基础操作&#xff1a;顺序表和链表及其相关内容&#xff0c;今天我们来学一点有些难度的知识——数据结构中的二叉树&#xff0c;今天我们先来学习二叉树中堆的知识&#xff0c;这部分内容还是非常有意思的&#xff0c;下面我们…

前端秘法番外篇----学完Web API,前端才能算真正的入门

目录 一.引言 二.元素的获取和事件 1.获取元素 2.各种事件 2.1点击事件 2.2键盘事件 三.获取&修改操作 1.获取修改元素属性 2.修改表单属性 2.1暂停播放键的转换 2.2计数器的实现 2.3全选的实现 3.样式操作 3.1行内样式操作 3.2类名样式操作 四.节点 1.创…

记录Xshell使用ed25519公钥免密链接SSH

试了半天&#xff0c;Xshell好像没办法导入linux生成的ssh公钥,因此需要以下步骤实现免密登录 结论&#xff0c;在linux公钥文件中&#xff0c;将客户端生成的ed25519公钥加上去即可(一个公钥单独一行) 1.使用Linux生成秘钥文件(不需要输入私钥密码passphrase)或者直接创建一…

【Servlet】继承关系以及service方法

文章目录 一、继承关系二、相关方法 一、继承关系 Servlet接口下有一个GenericServlet抽象类。在GenericServlet下有一个子类HttpServlet&#xff0c;它是基于http协议。 继承关系 javax.servlet.Servlet接口​ javax.GenericServlet抽象类​ javax.servlet.http.HttpServ…

生产制造园区数字孪生3D大屏展示提升运营效益

在智慧园区的建设中&#xff0c;3D可视化管理平台成为必不可少的工具&#xff0c;数字孪生公司深圳华锐视点打造的智慧园区3D可视化综合管理平台&#xff0c;致力于将园区的人口、经济、应急服务等各项业务进行3D数字化、网络化处理&#xff0c;从而实现决策支持的优化和管理的…