进阶必看,3种灵活操作PyTorch张量的高级方法

大家好,在PyTorch中进行高级张量操作时,开发者经常面临这样的问题,如何根据一个索引张量从另一个张量中选取元素。

例如有一个包含数千个特征的大规模数据集,需要根据特定的索引模式快速提取信息。本文将介绍三种索引选择方法来解决这类问题。

torch.index_select

torch.index_select函数通过在指定的维度上进行元素选择,同时在其他维度上保持元素不变。也就是说,在目标维度上根据索引张量来挑选元素,而其他维度的元素则原封不动。为了更直观地理解这一概念,来看一个2D张量的示例,这里将沿着维度1进行元素的选择:

num_picks = 2values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)

由此得到的张量形状为[len_dim_0, num_picks]:对于维度0上的每个元素,都从维度1中选取了相同的元素。将其形象化:

现在迈入三维张量的世界,这样更贴近机器学习与数据科学的实际需求。

设想一个三维张量,其维度为[batch_size, num_elements, num_features]:num_elements表示每个批次中的项目数,每个项目具有num_features个特征。这种张量结构所有元素都是以批量方式处理的。

import torchbatch_size = 16
num_elements = 64
num_features = 1024
num_picks = 2values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(num_picks,))
# [batch_size, num_picks, num_features]
picked = torch.index_select(values, 1, indices)

若更倾向于通过代码来理解index_select的功能,以下是使用简单的for循环来模拟该功能实现的示例:

picked_manual = torch.zeros_like(picked)
for i in range(batch_size):for j in range(num_picks):for k in range(num_features):picked_manual[i, j, k] = values[i, indices[j], k]assert torch.all(torch.eq(picked, picked_manual))

torch.gather

torch.gather函数在功能上与torch.index_select相似,但提供了更为灵活的元素选择方式。

torch.gather中,选择的元素不仅取决于索引张量,还受到其他维度的影响。以机器学习项目为例,可以针对每个批次和每个特征,根据条件从元素维度中选取不同的元素,实现这一点是通过使用另一个张量来指定索引。

在实际应用中,这种用法非常普遍,比如在决策树中根据特定条件选择节点。

每个节点由一组特征定义,可以创建一个索引矩阵,将选定的元素放置在批次维度上,并在特征维度上复制这些值。这样,对于每个批次索引,都可以基于特定条件选择不同的元素,尽管在我们的示例中,这些条件仅与批次索引相关,但也可以根据特征索引来确定。

为了更清楚地理解这一点,再次从二维(2D)示例开始,逐步展示如何使用torch.gather来实现这种灵活的索引选择。

num_picks = 2values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(len_dim_0, num_picks))
# [len_dim_0, num_picks]
picked = torch.gather(values, 1, indices)

直观来看,torch.gather的元素选择呈现出与torch.index_select不同的模式。不同于后者沿直线进行选择,torch.gather根据维度0上的每个索引,在维度1中挑选出不同的元素:

接下来进入三维世界,并展示如何用Python代码来实现类似的选择机制:

import torchbatch_size = 16
num_elements = 64
num_features = 1024
num_picks = 5
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(batch_size, num_picks, num_features))
picked = torch.gather(values, 1, indices)picked_manual = torch.zeros_like(picked)
for i in range(batch_size):for j in range(num_picks):for k in range(num_features):picked_manual[i, j, k] = values[i, indices[i, j, k], k]assert torch.all(torch.eq(picked, picked_manual))

torch.take

在三个函数中,torch.take的工作原理最为简单明了。它首先将输入张量视为一维数组,然后根据指定的索引从中选取元素。

例如,对于一个4行5列的张量,如果使用torch.take并选取索引6和19,实际上获取的是这个张量在一维化之后位于第6个位置和第19个位置的元素,分别对应于原始二维结构中的第2行第2列和最后一行最后一列的元素。

2D示例:

num_picks = 2values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_0 * len_dim_1, size=(num_picks,))
# [num_picks]
picked = torch.take(values, indices)

现在得到了两个元素:

接下来探讨三维张量的索引选择及其实现。索引张量不受固定形状的限制,可以是任意形状。根据这个索引张量进行的元素选择,其结果也将遵循这种形状,确保输出与索引张量的维度结构一致。

import torchbatch_size = 16
num_elements = 64
num_features = 1024
num_picks = (2, 5, 3)values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, batch_size * num_elements * num_features, size=num_picks)
# [2, 5, 3]
picked = torch.take(values, indices)picked_manual = torch.zeros(num_picks)
for i in range(num_picks[0]):for j in range(num_picks[1]):for k in range(num_picks[2]):picked_manual[i, j, k] = values.flatten()[indices[i, j, k]]assert torch.all(torch.eq(picked, picked_manual))

本文介绍了Pytorch中的三种常见选择方法:torch.index_selecttorch.gathertorch.take。可以使用这些方法,根据不同的条件从张量中选取或索引特定的元素。

对于每种方法,都先通过简单的二维(2D)示例引入,并直观地展示了选择结果。接着,进入更为复杂且实际的三维(3D)应用场景,演示了如何在形状为[batch_size, num_elements, num_features]的张量中进行元素选择——这种情况在机器学习项目中十分常见。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/30319.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

matlab编写微分方程椭圆型方程(一维形式)

文章目录 理论编程实例原代码 理论 椭圆型方程一维格式即常微分方程,边值问题,方程如下所示: 截断误差: 当 h → ∞ h\rightarrow\infty h→∞时,截断误差趋于零,离散方程组成立, 写成矩阵&…

鸿蒙小案例-短视频

参加泡泡玛特写了个小demo,然后给它稍微完善了一下 基于API11 演示效果 hfvideo演示视频 主要功能集中在4个tab页内 1.首页-视频播放页 2.朋友-关注、朋友、粉丝聚合 3.消息-聊天列表 4.我的-当前用户信息展示 主页页面 1.用户主页 2.聊天页面 3.朋友页面 4.视频播放页 因为不…

闲置资源共享平台

摘 要 随着共享经济的高速发展以及人们对物品的需求方面也越来也丰富,而且各大高校的大学生们的购买力也越来越强,随之而来的问题就是身边的闲置资源也越来越多,但是也有许多的大学生对物品的要求方面不是很高,也愿意买下经济实惠…

【计算机网络体系结构】计算机网络体系结构实验-DNS模拟器实验

一、DNS模拟器实验 拓扑图 1. 服务器ip 2. 服务器填写记录 3. 客户端ip以及连接到DNS服务器 4. ping测试

hadoop Yarn资源调度器

概述 Yarn是一个资源调度平台,负责为运算程序提供服务器资源,相当于一个分布式的操作系统平台,而MapReduce等运算程序相当于操作系统之上的应用程序 Yarn基本架构 YARN 主要由ResourceManager、NodeManager、ApplicationMaster、Container …

C++基础编程100题-010 OpenJudge-1.3-08 温度表达转化

更多资源请关注纽扣编程微信公众号 http://noi.openjudge.cn/ch0103/08/ 描述 利用公式 C 5 * (F-32) / 9 (其中C表示摄氏温度,F表示华氏温度) 进行计算转化。 输入 输入一行,包含一个实数f,表示华氏温度。&…

06 PXE高效批量网络装机

目录 6.1 部署PXE远程安装服务 6.1.1 搭建PXE远程安装服务器 1. 准备CentOS 7安装源 2. 安装并启用TFTP服务 3. 准备Linux内核、初始化镜像文件 4. 准备PXE引导程序 5. 安装并启用DHCP服务 6. 配置启动菜单文件 6.1.2 验证PXE网络安装 6.2 实现Kickstart无人值守安装 6.2.1 准…

STM32学习记录(八)————定时器输出PWM及舵机的控制

文章目录 前言一、PWM1.工作原理2.内部运作机制3. PWM工作模式4.PWM结构体及库函数 二、PWM控制舵机 前言 一个学习STM32的小白~ 有错误评论区或私信指出提示:以下是本篇文章正文内容,下面案例可供参考 一、PWM 1.工作原理 以向上计数为例&#xff0…

spark 整合 yarn

spark 整合 yarn 1、在master节点上停止spark集群 cd /usr/local/soft/spark-2.4.5/sbin ./stop-all.sh 2、spark整合yarn只需要在一个节点整合, 可以删除node1 和node2中所有的spark文件 分别在node1、node2 的/usr/local/soft目录运行 rm -rf spark-2.4.…

力扣469A

文章目录 1. 题目链接2. 题目代码3. 题目总结4. 代码分析 1. 题目链接 I Wanna Be the Guy 2. 题目代码 #include<iostream> #include<set> using namespace std; int main(){int highestLevelOfGame;cin >> highestLevelOfGame;set<int> levelCanPas…

Linux下Cmake安装或版本更新

下载Cmake源码 https://cmake.org/download/ 找到对应的版本和类型 放进linux环境解压 编译 安装 tar -vxvf cmake-3.13.0.tar.gz cd cmake-3.13.0 ./bootstrap make make install设置环境变量 vi ~/.bashrc在文件尾加入 export PATH/your_path/cmake-3.13.0/bin:$PAT…

多模态大模型解读

目录 1. CLIP 2. ALBEF 3. BLIP 4. BLIP2 参考文献 &#xff08;2023年&#xff09;视觉语言的多模态大模型的目前主流方法是&#xff1a;借助预训练好的LLM和图像编码器&#xff0c;用一个图文特征对齐模块来连接&#xff0c;从而让语言模型理解图像特征并进行深层次的问…

王思聪隐形女儿曝光

王思聪"隐形"女儿曝光&#xff01;黄一鸣独自面对怀孕风波&#xff0c;坚持生下爱情结晶近日&#xff0c;娱乐圈掀起了一场惊天波澜&#xff01;前王思聪绯闻女友黄一鸣在接受专访时&#xff0c;大胆揭露了她与王思聪之间的爱恨纠葛&#xff0c;并首度公开承认&#…

【C++入门(4)】引用、内联函数、auto

一、引用与类型转换 我们看下面这个例子。 用 int & 给 double 类型的变量起别名&#xff0c;编译器报错&#xff1a; int main() {double b 3.14;int a b;int& x b;return 0; } 用 const int & 给 double 类型的变量起别名&#xff0c;成功&#xff1a; in…

ROS 机器人运动控制

ROS 机器人运动控制 机器人运动 当我们拿到一台机器人&#xff0c;其配套的程序源码中&#xff0c;通常会有机器人核心节点&#xff0c;这个核心节点既能够驱动机器人的底层硬件&#xff0c;同时向上还会订阅一个速度话题。我们只需要编写一个新的节点&#xff08;速度控制节点…

白酒:中国的酒文化的传承与发扬

中国&#xff0c;一个拥有五千年文明史的国度&#xff0c;其深厚的文化底蕴孕育出了丰富多彩的酒文化。在这片广袤的土地上&#xff0c;酒不仅仅是一种产品&#xff0c;更是一种情感的寄托&#xff0c;一种文化的传承。云仓酒庄的豪迈白酒&#xff0c;正是这一文化脉络中的一颗…

文件加密软件排行榜|常用三款文件加密软件推荐

Top 1: 安秉网盾文件加密软件 加密模式多样&#xff1a;采用多种加密模式&#xff0c;对企业重要的文档、图纸进行全方位360度保护。可根据企业不同工作场景设置不同的加密模式。 全透明加密&#xff1a;通过全透明加密模式&#xff0c;对企业重要的图纸文件类型进行全盘透明…

521. 最长特殊序列 Ⅰ(Rust单百解法-脑筋急转弯)

题目 给你两个字符串 a 和 b&#xff0c;请返回 这两个字符串中 最长的特殊序列 的长度。如果不存在&#xff0c;则返回 -1 。 「最长特殊序列」 定义如下&#xff1a;该序列为 某字符串独有的最长 子序列 &#xff08;即不能是其他字符串的子序列&#xff09; 。 字符串 s …

XHS-Downloader是一款小红书图片视频下载工具

这款软件可以提取账号发布、收藏、点赞作品链接&#xff1b;提取搜索结果作品链接、用户链接&#xff1b;下载小红书作品信息&#xff1b;提取小红书作品下载地址&#xff1b;下载小红书无水印作品文件&#xff01; &#x1f4d1; 功能清单 ✅ 采集小红书图文 / 视频作品信息…

全国第四轮软件工程学科评估结果

#计算机专业好吗##高考填志愿选择专业##计算机专业还能不能报# 又到了让各位家长头疼的高考填志愿时刻。 前几天的头条&#xff0c;张雪峰直播卖卡3小时入账2亿&#xff0c;为了孩子的前途&#xff0c;家长们确实是不惜重金。 作为毕业如今18个年头一直从事软件领域的老码农&am…