pytorch花式索引提取topk的张量

文章目录

  • pytorch花式索引提取topk的张量
    • 问题设定
    • 代码实现
      • 索引方法
      • gather方法
      • 验证
    • 补充知识
      • expand方法
      • gather方法
      • randint

pytorch花式索引提取topk的张量

问题设定

在这里插入图片描述
或者说,有一个(bs, dim, L)的大张量,索引的index形状为(bs, X),想得到一个(bs, dim, X)的reduced向量。我们在进行topk操作(以减少计算量)的时候经常碰到这种情况。
给出如下两种实现方法,分别使用花式索引(参考informer的代码)以及pytorch的gather方法

代码实现

索引方法

参考https://blog.csdn.net/qq_36560894/article/details/122005808

feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 结果和indices一致,说明在第二个channel上,每个样本的索引是一样的
bs,dim=feature.shape[:2]
bs,dim 
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape

在这里插入图片描述
在这里插入图片描述

gather方法

reduce_feature = torch.gather(feature, 2, indices_expand)

验证

两种方法得到的结果完全相同
在这里插入图片描述

补充知识

expand方法

在 PyTorch 中,expand() 方法用于扩展张量的大小。它会在不实际复制数据的情况下,重复张量的元素以填充新的形状。这个方法可以用于广播操作,以便在执行一些需要相同形状的张量之间的数学运算时,使它们具有相同的形状。

下面是使用 expand() 方法的基本用法:

import torch# 创建一个原始张量
x = torch.tensor([[1, 2, 3],[4, 5, 6]])# 使用 expand 扩展张量的大小
expanded_x = x.expand(2, 3, 4)  # 扩展成维度为(2, 3, 4)的张量print(expanded_x)

在上面的例子中,我们首先创建了一个形状为 (2, 3) 的原始张量 x。然后,我们使用 expand() 方法将其扩展成一个维度为 (2, 3, 4) 的新张量 expanded_x,该张量的形状是在原始张量形状的基础上每个维度都扩展了一倍。

需要注意的是,expand() 方法只能用于增加张量的大小,不能减小。另外,扩展后的张量与原始张量共享底层数据,因此在原始张量上进行的任何修改都会反映在扩展后的张量上,反之亦然。

gather方法

在 PyTorch 中,gather() 方法用于从输入张量中按照指定索引提取元素。这个方法通常用于根据索引收集特定的元素,例如根据类别索引从分类得分张量中获取对应类别的得分。

下面是使用 gather() 方法的基本用法:

import torch# 创建一个输入张量
input_tensor = torch.tensor([[1, 2],[3, 4],[5, 6]])# 创建一个索引张量
indices = torch.tensor([[0, 0],[1, 0]])# 使用 gather 方法根据索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)print(output_tensor)

在上面的例子中,我们首先创建了一个形状为 (3, 2) 的输入张量 input_tensor,以及一个形状为 (2, 2) 的索引张量 indices。然后,我们使用 gather() 方法从输入张量 input_tensor 中按照索引张量 indices 收集元素。

gather() 方法中,参数 dim 指定了在哪个维度上进行收集操作,而 index 参数指定了收集元素所使用的索引张量。

需要注意的是,索引张量 indices 的形状必须与输出张量的形状一致,或者是可以广播成与输出张量形状一致的形状。

randint

torch.randint() 是 PyTorch 中用于生成随机整数张量的函数。它可以生成一个张量,其中的元素是在指定范围内随机抽样的整数。

下面是 torch.randint() 的基本用法示例:

import torch# 生成一个形状为 (3, 3) 的随机整数张量,范围是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))print(random_integers)

在上面的示例中,我们使用了 torch.randint() 函数来生成一个形状为 (3, 3) 的随机整数张量,其中的元素取值范围在闭区间 [low, high) 内,即从 0 到 9。

torch.randint() 函数的主要参数包括:

  • low:生成的随机整数的最小值(包含)。
  • high:生成的随机整数的最大值(不包含)。
  • size:生成的张量的形状。

你也可以不指定 low 参数,默认情况下它为 0。此外,还可以使用其他参数来控制生成的随机整数张量的设备类型、数据类型等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/681382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML世界之第二重天

目录 一、HTML 格式化 1.HTML 文本格式化标签 2.HTML "计算机输出" 标签 3.HTML 引文, 引用, 及标签定义 二、HTML 链接 1.HTML 链接 2.HTML 超链接 3.HTML 链接语法 4.文本链接 5.图像链接 6.锚点链接 7.下载链接 8.Target 属性 9.Id 属性 三、HTML …

王树森《RNN Transformer》系列公开课

本课程主要介绍NLP相关,包括RNN、LSTM、Attention、Transformer、BERT等模型,以及情感识别、文本生成、机器翻译等应用 ShusenWang的个人空间-ShusenWang个人主页-哔哩哔哩视频 (bilibili.com) (一)NLP基础 1、数据处理基础 数…

Spring Boot 笔记 007 创建接口_登录

1.1 登录接口需求 1.2 JWT令牌 1.2.1 JWT原理 1.2.2 引入JWT坐标 1.2.3 单元测试 1.2.3.1 引入springboot单元测试坐标 1.2.3.2 在单元测试文件夹中创建测试类 1.2.3.3 运行测试类中的生成和解析方法 package com.geji;import com.auth0.jwt.JWT; import com.auth0.jwt.JWTV…

【Spring】公司为什么禁止在SpringBoot项目中使用@Autowired注解

目录 前言 说明 依赖注入的类型 2.1 基于构造器的依赖注入 2.2 基于 Setter 的依赖注入 2.3 基于属性的依赖注入 基于字段的依赖注入缺陷 3.1 不允许声明不可变域 3.2 容易违反单一职责设计原则 3.3 与依赖注入容器紧密耦合 3.4 隐藏依赖关系 总结 参考文档 前言 …

网络协议与攻击模拟_16HTTP协议

1、HTTP协议结构 2、在Windows server去搭建web扫描器 3、分析HTTP协议流量 一、HTTP协议 1、概念 HTTP(超文本传输协议)用于在万维网服务器上传输超文本(HTML)到本地浏览器的传输协议 基于TCP/IP(HTML文件、图片、查询结构等&…

数字IC实践项目(9)— Tang Nano 20K: I2C OLED Driver

Tang Nano 20K: I2C OLED Driver 写在前面的话硬件模块RTL电路和相关资源报告SSD1306 OLED 驱动芯片SSD1306 I2C协议接口OLED 驱动模块RTL综合实现 总结 写在前面的话 之前在逛淘宝的时候偶然发现了Tang Nano 20K,十分感慨国产FPGA替代方案的进步之快;被…

自动生成测试用例_接口测试用例自动生成工具

前言 写用例之前,我们应该熟悉API的详细信息。建议使用抓包工具Charles或AnyProxy进行抓包。 har2case 我们先来了解一下另一个项目har2case 他的工作原理就是将当前主流的抓包工具和浏览器都支持将抓取得到的数据包导出为标准通用的 HAR 格式(HTTP A…

【AI】安装ubuntu20.04教程(未完待续)

目录 1 制作ubuntu20.04系统盘1.1 下载ubuntu镜像1.2 使用ultraiso写入镜像 2 安装Ubuntu系统 1 制作ubuntu20.04系统盘 1.1 下载ubuntu镜像 在清华镜像站https://mirrors.tuna.tsinghua.edu.cn/下载ubuntu20.04镜像 路径为/ubuntu-releases/20.04/,下载ubuntu-20…

操作系统(16)----磁盘相关

目录 一.磁盘相关概念 1.磁盘 2.磁道 3.扇区 4.盘面、柱面 5.磁盘的分类 二.磁盘调度算法 1.一次磁盘读/写操作需要的时间 2.先来先服务算法(FCFS) 3.最短寻找时间优先(SSTF) 4.扫描算法(SCAN) 5.LOOK调度算法 6.循环扫描算法(C-SCAN) 7.C-LOOK调度算法 三.减少…

9.【CPP】List (迭代器的模拟实现||list迭代器失效||list的模拟实现)

介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中,在节点中通过指针指向其前一个元素和后一个元素。list与forward_…

华为机考入门python3--(12)牛客12-字符串反转

分类:字符串 知识点: 字符串是否为空 if not my_str 字符串逆序 my_str[::-1] 题目来自【牛客】 def reverse_string(s): # 判断字符串是否为空或只包含空格 if not s.strip(): return "" # 使用Python的切片语法反转字符串 re…

Pytorch的可视化

1 使用 wandb进行可视化训练过程 本文章将从wandb的安装、wandb的使用、demo的演示进行讲解。 1.1 如何安装wandb? wandb的安装比较简单,在终端中执行如下的命令即可: pip install wandb在安装完成之后,我们需要,去…

matlab入门,在线编辑,无需安装matab

matlab相关教程做的很完善,除了B站看看教程,官方教程我觉得更加高效。跟着教程一步一步编辑,非常方便。 阅读 MATLAB 官方教程: MATLAB 官方教程提供了从基础到高级的教学内容,内容包括 MATLAB 的基本语法、数据处理…

Vue3高频知识点和写法

一 Vue插件 二 vue3项目创建 创建完成后npm install npm run dev 三 setup 一 响应式数据 setup函数是用来代替data和methods的写法的,在setup函数中声明的数据和函数,导出后可以在页面中使用。 但是暂时不是响应式数据,如果要响应式数据的…

C++笔记1:操纵符输入输出

C操纵符用来控制输出控制,一是输出的形式,二是控制补白的数量和位置。本文记录一下,在一些笔试的ACM模式可能有用。其中1-4节的部分是关于格式化输入输出操作,5-6节的部分是关于未格式化输入输出操作。 1. 控制布尔值的格式 一般…

C语言—基础数据类型(含进制转换)

进制转换不多,但我觉得适合小白(我爱夸自己嘿嘿) 练习 1. 确认基础类型所占用的内存空间(提示:使用sizeof 运算符): 在这里我说一下,long 类型通常占用 4 字节。在 64 位系统上,long 类型通常也可为 8 字节。 格式…

LeetCode、208. 实现 Trie (前缀树)【中等,自定义数据结构】

文章目录 前言LeetCode、208. 实现 Trie (前缀树)【中等,自定义数据结构】题目链接与分类思路 资料获取 前言 博主介绍:✌目前全网粉丝2W,csdn博客专家、Java领域优质创作者,博客之星、阿里云平台优质作者、专注于Java后端技术领…

低资源学习与知识图谱:构建与应用

目录 前言1 低资源学习方法1.1 数据增强1.2 特征增强1.3 模型增强 2 低资源知识图谱构建与推理2.1 元关系学习2.2 对抗学习2.3 零样本关系抽取2.4 零样本学习与迁移学习2.5 零样本学习与辅助信息 3 基于知识图谱的低资源学习应用3.1 零样本图像分类3.2 知识增强的零样本学习3.3…

云原生介绍与容器的基本概念

云原生介绍 1、云原生的定义 云原生为用户指定了一条低心智负担的、敏捷的、能够以可扩展、可复制的方式最大化地利用云的能力、发挥云的价值的最佳路径。 2、云原生思想两个理论 第一个理论基础是:不可变基础设施。 第二个理论基础是:云应用编排理…

备战蓝桥杯---图论基础理论

图的存储&#xff1a; 1.邻接矩阵&#xff1a; 我们用map[i][j]表示i--->j的边权 2.用vector数组&#xff08;在搜索专题的游戏一题中应用过&#xff09; 3.用邻接表&#xff1a; 下面是用链表实现的基本功能的代码&#xff1a; #include<bits/stdc.h> using nam…