torch.gather是PyTorch中的一个函数,用于从源张量中按照指定的索引张量来收集数据。
基本语法如下,
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
- input:输入源张量
- dim:要收集数据的维度
- index:索引
- sparse_grad:如果为True,则gather()在反向传播时会返回稀疏梯度
- out:输出张量,形状与index相同
用法讲解
假设有以下输入张量x,
x = torch.tensor([[[ 1, 2],[ 3, 4]],[[ 5, 6],[ 7, 8]],[[ 9, 10],[11, 12]]
])
假设有以下索引index,
index = torch.tensor([[[0, 1],[1, 0]],[[1, 0],[0, 1]],[[0, 1],[1, 0]]
])
index的索引及里面的元素的对应关系如下,
index[0, 0, 0] = 0
index[0, 0, 1] = 1
index[0, 1, 0] = 1
index[0, 1, 1] = 0
index[1, 0, 0] = 1
index[1, 0, 1] = 0
index[1, 1, 0] = 0
index[1, 1, 1] = 1
index[2, 0, 0] = 0
index[2, 0, 1] = 1
index[2, 1, 0] = 1
index[2, 1, 1] = 0
接下来,有3种情况出现,分别是dim=0、dim=1、dim=2
dim=0
拿index里的元素值去替换对应索引中第1个维度的数值,得到新索引,
(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [1, 0, 1]
[0, 1, 0], 1 -> [1, 1, 0]
[0, 1, 1], 0 -> [0, 1, 1]
[1, 0, 0], 1 -> [1, 0, 0]
[1, 0, 1], 0 -> [0, 0, 1]
[1, 1, 0], 0 -> [0, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [0, 0, 0]
[2, 0, 1], 1 -> [1, 0, 1]
[2, 1, 0], 1 -> [1, 1, 0]
[2, 1, 1], 0 -> [0, 1, 1]
有了新索引后,便可根据新索引从输入张量中获取输出张量,
result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[1, 0, 1]],[x[1, 1, 0], x[0, 1, 1]],[[x[1, 0, 0], x[0, 0, 1],[x[0, 1, 0], x[1, 1, 1]],[[x[0, 0, 0], x[1, 0, 1], [x[1, 1, 0], x[0, 1, 1]]]]=[[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]]
”“”
打印输出张量,
print(result)
"""
实际值:
tensor([[[1, 6],[7, 4]],[[5, 2],[3, 8]],[[1, 6],[7, 4]]])
"""
dim=1
拿index里的元素值去替换对应索引中第2个维度的数值,得到新索引,
(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 1, 1]
[0, 1, 0], 1 -> [0, 1, 0]
[0, 1, 1], 0 -> [0, 0, 1]
[1, 0, 0], 1 -> [1, 1, 0]
[1, 0, 1], 0 -> [1, 0, 1]
[1, 1, 0], 0 -> [1, 0, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 1, 1]
[2, 1, 0], 1 -> [2, 1, 0]
[2, 1, 1], 0 -> [2, 0, 1]
有了新索引后,便可根据新索引从输入张量中获取输出张量,
result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 1, 1]],[x[0, 1, 0], x[0, 0, 1]],[[x[1, 1, 0], x[1, 0, 1],[x[1, 0, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 1, 1], [x[2, 1, 0], x[2, 0, 1]]]]=[[[1, 4],[3, 2]],[[7, 6],[5, 8]],[[9, 12],[11, 10]]]
”“”
打印输出张量,
print(result)
"""
实际值:
tensor([[[ 1, 4],[ 3, 2]],[[ 7, 6],[ 5, 8]],[[ 9, 12],[11, 10]]])
"""
dim=3
拿index里的元素值去替换对应索引中第3个维度的数值,得到新索引,
(索引, 索引对应的元素) -> 新索引
[0, 0, 0], 0 -> [0, 0, 0]
[0, 0, 1], 1 -> [0, 0, 1]
[0, 1, 0], 1 -> [0, 1, 1]
[0, 1, 1], 0 -> [0, 1, 0]
[1, 0, 0], 1 -> [1, 0, 1]
[1, 0, 1], 0 -> [1, 0, 0]
[1, 1, 0], 0 -> [1, 1, 0]
[1, 1, 1], 1 -> [1, 1, 1]
[2, 0, 0], 0 -> [2, 0, 0]
[2, 0, 1], 1 -> [2, 0, 1]
[2, 1, 0], 1 -> [2, 1, 1]
[2, 1, 1], 0 -> [2, 1, 0]
有了新索引后,便可根据新索引从输入张量中获取输出张量,
result = torch.gather(x, 0, index)
“”“
预测值:
result = [[[x[0, 0, 0], x[0, 0, 1]],[x[0, 1, 1], x[0, 1, 0]],[[x[1, 0, 1], x[1, 0, 0],[x[1, 1, 0], x[1, 1, 1]],[[x[2, 0, 0], x[2, 0, 1], [x[2, 1, 1], x[2, 1, 0]]]]=[[[1, 2],[4, 3]],[[6, 5],[7, 8]],[[9, 10],[12, 11]]]
”“”
打印输出张量,
print(result)
"""
实际值:
tensor([[[ 1, 2],[ 4, 3]],[[ 6, 5],[ 7, 8]],[[ 9, 10],[12, 11]]])
"""