一、前言
本次要介绍的函数为Tensor.scatter_函数,也是PyTorch中常用的函数之一,但遗憾的是,我想在网络上查询该函数的用法时,大部分的文章都是直接给出一个示例,看完之后,其中的原理我还是无法理解,因此,秉着靠别人不如靠自己的原则,决定自己彻底弄明白!话不多说,Let’s go!🫵希望能够对你有所帮助🫵!
二、方法解析
首先,要明白的是:scatter_函数是在做怎样的事情?实际上,scatter_函数是一个填值操作。假设有一个元素全为0的tensor A,有一个tensor B,还有一个索引tensor C,当执行 A.scatter_(dim, index=C, src=B) 时,实际上是将B按照C的规则,将B中对应位置的值赋值给A对应位置元素。
嗯……解释了一通,貌似还是有点绕,没关系,现在你只要知道scatter_函数的作用就是给A赋值就完事了!接下来,通过大量的案例彻底弄明白***scatter_***函数到底在干什么!
三、案例分析
3.1 tensor.shape=2D
- dim = 0
import torch
# 创建一个tensor, name为score
score = torch.arange(1,12).reshape((3,4))
# 创建一个元素全为0的tensor, name为mask
mask = torch.zeros(scroe.shape, dtype=score.dtype)
# 创建一个tensor, name为index
index = torch.tensor([[0,1,2,1]])
# 给mask赋值
mask.scatter_(dim=0, index=index, src=score)
# print mask
print(mask)
👋👋👋👋👋👋👋重要!接下来,详细介绍一下scatter_的工作原理!👋👋👋👋👋👋👋
案例为tensor的维度为2D,dim=0的场景。
当我们执行***mask.scatter_(dim=0, index=index, src=score)***这行代码时,流程如下:
可以发现,当dim=0时,第2维度的值是不会发生变化的;同理,当dim=1时,第1维度的值是不会发生变化的!
- dim=1
import torch
# 创建一个tensor, name为score
score = torch.arange(1,12).reshape((3,4))
# 创建一个元素全为0的tensor, name为mask
mask = torch.zeros(scroe.shape, dtype=score.dtype)
# 创建一个tensor, name为index
index = torch.tensor([[0,1,2,1]])
# 给mask赋值
mask.scatter_(dim=1, index=index, src=score)
# print mask
print(mask)
参考文献
[1] https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_