进阶必看,3种灵活操作PyTorch张量的高级方法

大家好,在PyTorch中进行高级张量操作时,开发者经常面临这样的问题,如何根据一个索引张量从另一个张量中选取元素。

例如有一个包含数千个特征的大规模数据集,需要根据特定的索引模式快速提取信息。本文将介绍三种索引选择方法来解决这类问题。

torch.index_select

torch.index_select函数通过在指定的维度上进行元素选择,同时在其他维度上保持元素不变。也就是说,在目标维度上根据索引张量来挑选元素,而其他维度的元素则原封不动。为了更直观地理解这一概念,来看一个2D张量的示例,这里将沿着维度1进行元素的选择:

num_picks = 2values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)

由此得到的张量形状为[len_dim_0, num_picks]:对于维度0上的每个元素,都从维度1中选取了相同的元素。将其形象化:

现在迈入三维张量的世界,这样更贴近机器学习与数据科学的实际需求。

设想一个三维张量,其维度为[batch_size, num_elements, num_features]:num_elements表示每个批次中的项目数,每个项目具有num_features个特征。这种张量结构所有元素都是以批量方式处理的。

import torchbatch_size = 16
num_elements = 64
num_features = 1024
num_picks = 2values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(num_picks,))
# [batch_size, num_picks, num_features]
picked = torch.index_select(values, 1, indices)

若更倾向于通过代码来理解index_select的功能,以下是使用简单的for循环来模拟该功能实现的示例:

picked_manual = torch.zeros_like(picked)
for i in range(batch_size):for j in range(num_picks):for k in range(num_features):picked_manual[i, j, k] = values[i, indices[j], k]assert torch.all(torch.eq(picked, picked_manual))

torch.gather

torch.gather函数在功能上与torch.index_select相似,但提供了更为灵活的元素选择方式。

torch.gather中,选择的元素不仅取决于索引张量,还受到其他维度的影响。以机器学习项目为例,可以针对每个批次和每个特征,根据条件从元素维度中选取不同的元素,实现这一点是通过使用另一个张量来指定索引。

在实际应用中,这种用法非常普遍,比如在决策树中根据特定条件选择节点。

每个节点由一组特征定义,可以创建一个索引矩阵,将选定的元素放置在批次维度上,并在特征维度上复制这些值。这样,对于每个批次索引,都可以基于特定条件选择不同的元素,尽管在我们的示例中,这些条件仅与批次索引相关,但也可以根据特征索引来确定。

为了更清楚地理解这一点,再次从二维(2D)示例开始,逐步展示如何使用torch.gather来实现这种灵活的索引选择。

num_picks = 2values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(len_dim_0, num_picks))
# [len_dim_0, num_picks]
picked = torch.gather(values, 1, indices)

直观来看,torch.gather的元素选择呈现出与torch.index_select不同的模式。不同于后者沿直线进行选择,torch.gather根据维度0上的每个索引,在维度1中挑选出不同的元素:

接下来进入三维世界,并展示如何用Python代码来实现类似的选择机制:

import torchbatch_size = 16
num_elements = 64
num_features = 1024
num_picks = 5
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(batch_size, num_picks, num_features))
picked = torch.gather(values, 1, indices)picked_manual = torch.zeros_like(picked)
for i in range(batch_size):for j in range(num_picks):for k in range(num_features):picked_manual[i, j, k] = values[i, indices[i, j, k], k]assert torch.all(torch.eq(picked, picked_manual))

torch.take

在三个函数中,torch.take的工作原理最为简单明了。它首先将输入张量视为一维数组,然后根据指定的索引从中选取元素。

例如,对于一个4行5列的张量,如果使用torch.take并选取索引6和19,实际上获取的是这个张量在一维化之后位于第6个位置和第19个位置的元素,分别对应于原始二维结构中的第2行第2列和最后一行最后一列的元素。

2D示例:

num_picks = 2values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_0 * len_dim_1, size=(num_picks,))
# [num_picks]
picked = torch.take(values, indices)

现在得到了两个元素:

接下来探讨三维张量的索引选择及其实现。索引张量不受固定形状的限制,可以是任意形状。根据这个索引张量进行的元素选择,其结果也将遵循这种形状,确保输出与索引张量的维度结构一致。

import torchbatch_size = 16
num_elements = 64
num_features = 1024
num_picks = (2, 5, 3)values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, batch_size * num_elements * num_features, size=num_picks)
# [2, 5, 3]
picked = torch.take(values, indices)picked_manual = torch.zeros(num_picks)
for i in range(num_picks[0]):for j in range(num_picks[1]):for k in range(num_picks[2]):picked_manual[i, j, k] = values.flatten()[indices[i, j, k]]assert torch.all(torch.eq(picked, picked_manual))

本文介绍了Pytorch中的三种常见选择方法:torch.index_selecttorch.gathertorch.take。可以使用这些方法,根据不同的条件从张量中选取或索引特定的元素。

对于每种方法,都先通过简单的二维(2D)示例引入,并直观地展示了选择结果。接着,进入更为复杂且实际的三维(3D)应用场景,演示了如何在形状为[batch_size, num_elements, num_features]的张量中进行元素选择——这种情况在机器学习项目中十分常见。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/30319.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

评估 RAG?只要大模型框架 LlamaIndex 就足够了

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。 针对大模型技术趋势、算法项目落地经验分享、新手如何入门算法岗、该如何准备面试攻略、面试常考点等热门话题进行了深入的讨论。 总结链接如…

matlab编写微分方程椭圆型方程(一维形式)

文章目录 理论编程实例原代码 理论 椭圆型方程一维格式即常微分方程,边值问题,方程如下所示: 截断误差: 当 h → ∞ h\rightarrow\infty h→∞时,截断误差趋于零,离散方程组成立, 写成矩阵&…

鸿蒙小案例-短视频

参加泡泡玛特写了个小demo,然后给它稍微完善了一下 基于API11 演示效果 hfvideo演示视频 主要功能集中在4个tab页内 1.首页-视频播放页 2.朋友-关注、朋友、粉丝聚合 3.消息-聊天列表 4.我的-当前用户信息展示 主页页面 1.用户主页 2.聊天页面 3.朋友页面 4.视频播放页 因为不…

Clickhouse集群_ 双副本配置下Replicatedmergetree引擎的表在一个节点被删除后会自动恢复吗

2分片双副本的配置:1,2,3,4节点,分片1落在1,2节点, 1,2节点互为对方的副本,分片2落在3,4节点, 3,4节点互为对方的副本 replicatedmergetree引擎的表在一个节点被删除后,虽然另一个节点还有它的副本,但是这个副本不过同…

闲置资源共享平台

摘 要 随着共享经济的高速发展以及人们对物品的需求方面也越来也丰富,而且各大高校的大学生们的购买力也越来越强,随之而来的问题就是身边的闲置资源也越来越多,但是也有许多的大学生对物品的要求方面不是很高,也愿意买下经济实惠…

【计算机网络体系结构】计算机网络体系结构实验-DNS模拟器实验

一、DNS模拟器实验 拓扑图 1. 服务器ip 2. 服务器填写记录 3. 客户端ip以及连接到DNS服务器 4. ping测试

hadoop Yarn资源调度器

概述 Yarn是一个资源调度平台,负责为运算程序提供服务器资源,相当于一个分布式的操作系统平台,而MapReduce等运算程序相当于操作系统之上的应用程序 Yarn基本架构 YARN 主要由ResourceManager、NodeManager、ApplicationMaster、Container …

Gone框架介绍29 - 在Gone中使用gRPC通信

gone是可以高效开发Web服务的Golang依赖注入框架 github地址:https://github.com/gone-io/gone 文档地址:https://goner.fun/zh/ 文章目录 使用gRPC通信编写proto文件,生成golang代码编写服务端代码注册客户端编写配置文件测试总结 使用gRPC通…

C++基础编程100题-010 OpenJudge-1.3-08 温度表达转化

更多资源请关注纽扣编程微信公众号 http://noi.openjudge.cn/ch0103/08/ 描述 利用公式 C 5 * (F-32) / 9 (其中C表示摄氏温度,F表示华氏温度) 进行计算转化。 输入 输入一行,包含一个实数f,表示华氏温度。&…

Linux常见的压缩文件种类与对应的压缩解压方法

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

超详细的selenium使用指南

🍅 视频学习:文末有免费的配套视频可观看 🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 概述 selenium是网页应用中最流行的自动化测试工具,可以用来做自动化测试或者浏览器…

06 PXE高效批量网络装机

目录 6.1 部署PXE远程安装服务 6.1.1 搭建PXE远程安装服务器 1. 准备CentOS 7安装源 2. 安装并启用TFTP服务 3. 准备Linux内核、初始化镜像文件 4. 准备PXE引导程序 5. 安装并启用DHCP服务 6. 配置启动菜单文件 6.1.2 验证PXE网络安装 6.2 实现Kickstart无人值守安装 6.2.1 准…

STM32学习记录(八)————定时器输出PWM及舵机的控制

文章目录 前言一、PWM1.工作原理2.内部运作机制3. PWM工作模式4.PWM结构体及库函数 二、PWM控制舵机 前言 一个学习STM32的小白~ 有错误评论区或私信指出提示:以下是本篇文章正文内容,下面案例可供参考 一、PWM 1.工作原理 以向上计数为例&#xff0…

spark 整合 yarn

spark 整合 yarn 1、在master节点上停止spark集群 cd /usr/local/soft/spark-2.4.5/sbin ./stop-all.sh 2、spark整合yarn只需要在一个节点整合, 可以删除node1 和node2中所有的spark文件 分别在node1、node2 的/usr/local/soft目录运行 rm -rf spark-2.4.…

力扣469A

文章目录 1. 题目链接2. 题目代码3. 题目总结4. 代码分析 1. 题目链接 I Wanna Be the Guy 2. 题目代码 #include<iostream> #include<set> using namespace std; int main(){int highestLevelOfGame;cin >> highestLevelOfGame;set<int> levelCanPas…

Linux下Cmake安装或版本更新

下载Cmake源码 https://cmake.org/download/ 找到对应的版本和类型 放进linux环境解压 编译 安装 tar -vxvf cmake-3.13.0.tar.gz cd cmake-3.13.0 ./bootstrap make make install设置环境变量 vi ~/.bashrc在文件尾加入 export PATH/your_path/cmake-3.13.0/bin:$PAT…

Java_Optional 类

文章目录 一、Optional1.1 常用方法 一、Optional 到目前为止&#xff0c;臭名昭著的空指针异常是导致 Java 应用程序失败的最常见原因。以前&#xff0c;为了解决空指针异常&#xff0c;Google 在著名的 Guava 项目引入了 Optional类&#xff0c;通过检查空值的方式避免空指针…

多模态大模型解读

目录 1. CLIP 2. ALBEF 3. BLIP 4. BLIP2 参考文献 &#xff08;2023年&#xff09;视觉语言的多模态大模型的目前主流方法是&#xff1a;借助预训练好的LLM和图像编码器&#xff0c;用一个图文特征对齐模块来连接&#xff0c;从而让语言模型理解图像特征并进行深层次的问…

王思聪隐形女儿曝光

王思聪"隐形"女儿曝光&#xff01;黄一鸣独自面对怀孕风波&#xff0c;坚持生下爱情结晶近日&#xff0c;娱乐圈掀起了一场惊天波澜&#xff01;前王思聪绯闻女友黄一鸣在接受专访时&#xff0c;大胆揭露了她与王思聪之间的爱恨纠葛&#xff0c;并首度公开承认&#…

PostgreSQL源码分析——基础备份

进行基础备份有2中方式&#xff0c;可使用pg_basebackup工具或其他备份工具进行备份&#xff0c;另一种是使用底层命令进行基础备份。pg_basebackup等工具其实是封装了底层命令&#xff0c;所以&#xff0c;为了更好的理解基础备份的过程&#xff0c;这里我们使用底层命令进行备…