PyTorch核心函数详解:gather与where的实战指南

PyTorch中的torch.gathertorch.where是处理张量数据的关键工具,前者实现基于索引的灵活数据提取,后者完成条件筛选与动态生成。本文通过典型应用场景和代码演示,深入解析两者的工作原理及使用技巧,帮助开发者提升数据处理的灵活性与效率。

在深度学习中,我们经常需要根据特定规则提取或生成数据。例如:

  • 从预测概率中提取Top-K类别索引
  • 根据掩码筛选有效数据点
  • 动态生成条件化张量

torch.gathertorch.where正是解决这类问题的核心函数。下文将结合图像处理、数据筛选等场景,详解它们的用法与差异。
在这里插入图片描述

一、torch.gather:基于索引的精准提取

功能描述

torch.gather(input, dim, index) 沿指定维度dim,根据index张量中的索引值,从input中提取对应元素,输出形状与index一致。

参数说明
  • input:源张量
  • dim:指定操作的维度
  • index:索引张量,其值必须为整数类型

核心规则

  • 索引穿透性:索引值直接映射源张量的位置,不改变维度
  • 广播机制:当index维度小于input时,会自动广播到匹配形状
  • 多维索引:支持通过多维索引张量提取复杂结构的数据

应用场景与示例

场景1:图像数据批量提取

假设需要从批量图像中提取特定位置的像素值:

# 假设images是形状为(2,3,3)的图像批次 (批次大小2,通道3,分辨率3x3)
images = torch.tensor([[[1,2,3],[4,5,6],[7,8,9]],  # 第一张图像[[10,11,12],[13,14,15],[16,17,18]]  # 第二张图像
])# 提取所有图像的第0行第1列像素 (shape: (2,))
pixels = torch.gather(images, dim=2, index=torch.tensor([[[0,1,0],[0,1,0]], [[0,1,0],[0,1,0]]]))
print(pixels)
# 输出: tensor([[1, 2, 1],
#                [10, 11, 10]])
场景2:从概率分布提取Top-K结果

在NLP任务中提取预测词ID:

logits = torch.tensor([[0.1, 0.4, 0.5], [0.3, 0.6, 0.1]])  # 2个样本的3个类别的概率
topk_indices = logits.topk(k=2, dim=1).indices  # 获取Top-2索引# 使用gather提取Top-2概率值
topk_probs = torch.gather(logits, dim=1, index=topk_indices)
print(topk_probs)
# 输出:
# tensor([[0.5, 0.4],
#         [0.6, 0.3]])

二、torch.where:条件驱动的动态生成

功能描述

torch.where(condition, x, y) 根据布尔条件condition,从张量xy中选择元素,生成与输入同形状的新张量。

参数说明
  • condition:布尔型张量,决定元素来源
  • x:满足条件时选择的元素来源
  • y:不满足条件时选择的元素来源

核心特性

  • 自动广播:支持不同形状的条件与输入张量
  • 元素级操作:逐元素比较生成动态结果
  • 类型转换:输出类型由xy决定

应用场景与示例

场景1:数据清洗与过滤

筛选出温度超过30℃且湿度低于60%的记录:

temperature = torch.tensor([25.0, 32.5, 28.0, 35.0])
humidity = torch.tensor([55.0, 58.0, 70.0, 50.0])# 生成布尔掩码
mask = (temperature > 30) & (humidity < 60)# 根据条件生成标签
labels = torch.where(mask, torch.tensor("High Risk"), torch.tensor("Normal"))
print(labels)
# 输出: tensor(['Normal', 'High Risk', 'Normal', 'Normal'], dtype=string)
场景2:图像二值化处理

将灰度图像转换为二值掩码:

gray_image = torch.tensor([[0.1, 0.8], [0.6, 0.3]], dtype=torch.float32)
threshold = 0.5# 生成二值掩码
binary_mask = torch.where(gray_image > threshold, torch.tensor(1.0), torch.tensor(0.0))
print(binary_mask)
# 输出:
# tensor([[0., 1.],
#         [1., 0.]])

三、函数对比与选择指南

特性torch.gathertorch.where
核心功能基于索引精确提取元素条件驱动动态生成元素
输入要求需显式提供索引张量需条件张量及候选值张量
维度匹配严格匹配索引与源张量维度自动广播兼容不同形状
典型应用多维数据查询、Top-K提取条件筛选、数据转换、掩码生成
性能消耗较高(涉及索引计算)较低(基于原生条件判断)

四、综合实战:图像语义分割后处理

任务需求

将模型输出的概率图转换为二值掩码,并提取连通区域标签。

解决方案

# 假设prob_map是模型输出的概率图 (H,W)
prob_map = torch.rand(256, 256) > 0.5  # 二值化处理# 使用where生成掩码
mask = torch.where(prob_map, torch.tensor(1), torch.tensor(0))# 使用gather提取连通区域标签(假设labels是预测的类别索引)
labels = torch.randint(0, 10, (256, 256))
selected_labels = torch.gather(labels, dim=0, index=mask.nonzero(as_tuple=True)[0])

五、注意事项与最佳实践

  1. 索引越界预防

    # 错误示例:索引超出范围会导致错误
    valid_indices = torch.clamp(indices, min=0, max=max_dim-1)
    
  2. 类型一致性

    # 确保index张量为整型
    index = index.long()  
    
  3. 内存优化

    # 优先使用in-place操作减少显存占用
    mask.masked_fill_(condition, value)
    

结语

torch.gathertorch.where作为PyTorch生态中的基石函数,在数据工程与模型开发中扮演着不可替代的角色。理解它们的底层逻辑与适用场景,能够帮助您:

  • 更高效地实现复杂数据操作
  • 优化模型推理与训练流程
  • 解决各类条件化数据处理难题

掌握这两把利器,您将在PyTorch开发中如鱼得水!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/76965.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

声学测温度原理解释

已知声速&#xff0c;就可以得到温度。 不同温度下的胜诉不同。 25度的声速大约346m/s 绝对温度-273度 不同温度下的声速。 FPGA 通过测距雷达测温度&#xff0c;固定测量距离&#xff0c;或者可以测出当前距离。已知距离&#xff0c;然后雷达发出声波到接收到回波的时间&a…

【网络篇】UDP协议的封装分用全过程

大家好呀 我是浪前 今天讲解的是网络篇的第二章&#xff1a;UDP协议的封装分用 我们的协议最开始是OSI七层网络协议 这个OSI 七层网络协议 是计算机的大佬写的&#xff0c;但是这个协议一共有七层&#xff0c;太多了太麻烦了&#xff0c;于是我们就把这个七层网络协议就简化为…

spring-ai-alibaba使用Agent实现智能机票助手

示例目标是使用 Spring AI Alibaba 框架开发一个智能机票助手&#xff0c;它可以帮助消费者完成机票预定、问题解答、机票改签、取消等动作&#xff0c;具体要求为&#xff1a; 基于 AI 大模型与用户对话&#xff0c;理解用户自然语言表达的需求支持多轮连续对话&#xff0c;能…

嵌入式C语言高级编程:OOP封装、TDD测试与防御性编程实践

一、面向对象编程(OOP) 尽管 C 语言并非面向对象编程语言&#xff0c;但借助一些编程技巧&#xff0c;也能实现面向对象编程&#xff08;OOP&#xff09;的核心特性&#xff0c;如封装、继承和多态。 1.1 封装 封装是把数据和操作数据的函数捆绑在一起&#xff0c;对外部隐藏…

蓝桥杯 web 常考到的一些知识点

filter&#xff1a;filter方法创建一个新数组&#xff0c;其包含通过所提供函数实现的测试的所有元素。这个 方法不会改变原数组&#xff0c;而是返回一个新的数组。 map&#xff1a;map方法创建一个新数组&#xff0c;其结果是该数组中的每个元素都调用一个提供的函数后的 返回…

音视频小白系统入门笔记-0

本系列笔记为博主学习李超老师课程的课堂笔记&#xff0c;仅供参阅 音视频小白系统入门课 音视频基础ffmpeg原理 绪论 ffmpeg推流 ffplay/vlc拉流 使用rtmp协议 ffmpeg -i <source_path> -f flv rtmp://<rtmp_server_path> 为什么会推流失败&#xff1f; 默认…

mysql按条件三表并联查询

下面为你呈现一个 MySQL 按条件三表并联查询的示例。假定有三个表&#xff1a;students、courses 和 enrollments&#xff0c;它们的结构和关联如下&#xff1a; students 表&#xff1a;包含学生的基本信息&#xff0c;有 student_id 和 student_name 等字段。courses 表&…

UML之序列图的消息

序列图表现各参与者之间为完成某个行为而发生的交互及其时间顺序&#xff0c;序列图中的交互通过消息实现。消息是从一条生命线到另一条生命线的通信&#xff0c;它们通常是水平或倾斜向下的箭头&#xff0c;从发送方生命线离开&#xff0c;到达接收方生命线。如果需要&#xf…

UniAD:自动驾驶的统一架构 - 创新与挑战并存

引言 自动驾驶技术正经历一场架构革命。传统上&#xff0c;自动驾驶系统采用模块化设计&#xff0c;将感知、预测和规划分离为独立组件。而上海人工智能实验室的OpenDriveLab团队提出的UniAD&#xff08;Unified Autonomous Driving&#xff09;则尝试将这些任务整合到一个统一…

如何写好合同管理系统需求分析

引言 在当今企业数字化转型的浪潮中&#xff0c;合同管理系统作为企业法律合规和商业运营的重要支撑工具&#xff0c;其需求分析的准确性和完整性直接关系到系统建设的成败。本文基于Volere需求过程方法论&#xff0c;结合江铃汽车集团合同管理系统需求规格说明书实践案例&…

libevent服务器附带qt界面开发(附带源码)

本章是入门章节&#xff0c;讲解如何实现一个附带界面的服务器&#xff0c;后续会完善与优化 使用qt编译libevent源码演示视频qt的一些知识 1.主要功能有登录界面 2.基于libevent实现的服务器的业务功能 使用qt编译libevent 下载这个&#xff0c;其他版本也可以 主要是github上…

八、自动化函数

1.元素的定位 web自动化测试的操作核心是能够找到页面对应的元素&#xff0c;然后才能对元素进行具体的操作。 常见的元素定位方式非常多&#xff0c;如id,classname,tagname,xpath,cssSelector 常用的主要由cssSelector和xpath 1.1 cssSelector选择器 选择器的功能&#x…

Web三漏洞学习(其二:sql注入)

靶场&#xff1a;NSSCTF 、云曦历年考核题 二、sql注入 NSSCTF 【SWPUCTF 2021 新生赛】easy_sql 这题虽然之前做过&#xff0c;但为了学习sql&#xff0c;整理一下就再写一次 打开以后是杰哥的界面 注意到html网页标题的名称是 “参数是wllm” 那就传参数值试一试 首先判…

单片机非耦合业务逻辑框架

在小型单片机项目开发初期&#xff0c;由于业务逻辑相对简单&#xff0c;我们往往较少关注程序架构层面的设计。 然而随着项目经验的积累&#xff0c;开发者会逐渐意识到模块间的耦合问题&#xff1a;当功能迭代时&#xff0c;一处修改可能引发连锁反应。 此时&#xff0c;构…

Zookeeper三台服务器三节点集群部署(docker-compose方式)

1. 准备工作 - 服务器:3 台服务器,IP 地址分别为 `10.10.10.11`、`10.10.10.12`、`10.10.10.13`。 - 安装 Docker:确保每台服务器已安装 Docker 和 Docker Compose。 - 网络通信:确保三台服务器之间可以通过 IP 地址互相访问,并开放以下端口: - `2181`:Zookeeper 客户…

Mac关闭sip方法

Mac关闭sip方法 导航 文章目录 Mac关闭sip方法导航完整操作流程图详细步骤 完整操作流程图 这东西是我在网上搬运下来的&#xff0c;但是我在为业务实操过程中&#xff0c;根据实操情况还是有新的注意点的 详细步骤 1.在「关于本机」-「系统报告」-「软件」;查看SIP是否开启…

C++| 深入剖析std::list底层实现:链表结构与内存管理机制

引言 std::list的底层实现基于双向链表&#xff0c;其设计哲学与std::vector截然不同。本文将深入探讨其节点结构、内存分配策略及迭代器实现原理&#xff0c;揭示链表的性能优势和潜在代价。 1. 底层数据结构&#xff1a;双向链表 每个std::list节点包含&#xff1a; 数据域…

汉诺塔问题——用贪心算法解决

目录 一&#xff1a;起源 二&#xff1a;问题描述 三&#xff1a;规律 三&#xff1a;解决方案 递归算法 四&#xff1a;代码实现 复杂度分析 一&#xff1a;起源 汉诺塔&#xff08;Tower of Hanoi&#xff09;问题起源于一个印度的古老传说。在世界中心贝拿勒斯&#…

【Python】Python 100题 分类入门练习题 - 新手友好

Python 100题 分类入门练习题 - 新手友好篇 - 整合篇 一、数学问题题目1&#xff1a;组合数字题目2&#xff1a;利润计算题目3&#xff1a;完全平方数题目4&#xff1a;日期天数计算题目11&#xff1a;兔子繁殖问题题目18&#xff1a;数列求和题目19&#xff1a;完数判断题目21…

【linux】--- 进程概念

进程概念 1.认识冯诺依曼结构2. 操作系统&#xff08;Operator system)2.1 概念2.2 设计OS的目的2.3 理解操作系统2.4 如何理解管理2.5 理解系统调用和库函数 3. 进程3.1 基本概念和基本操作3.1.1 描述进程 - PCB3.1.2 task_struct3.1.3 查看进程 3.2 进程状态3.2.1 运行&&…