在Pytorch中涉及张量的操作都会涉及“dim”的设置,虽然也理解个大差不差,但是偶尔还是有点犯迷糊,究其原因还是没有形象化的理解。
首先,张量的维度排序是有固定顺序的,0,1,2,......,是遵循一个从外到内的索引顺序;张量本身的维度越高,往内延伸的维度数越高。
“dim define what operation elements is”——这是我自己的形象化理解。
看一组代码:
>>> ones = torch.ones(3,4)
>>> ones
tensor([[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]])
>>> zeros = torch.zeros(3,4)
>>> zeros
tensor([[0., 0., 0., 0.],[0., 0., 0., 0.],[0., 0., 0., 0.]])
>>> ra = torch.arange(12).view(3,4)
>>> ra
tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])>>> torch.stack((ra,zeros),dim=0)
tensor([[[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 8., 9., 10., 11.]],[[ 0., 0., 0., 0.],[ 0., 0., 0., 0.],[ 0., 0., 0., 0.]]])
>>> torch.stack((ones,zeros),dim=0)
tensor([[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[0., 0., 0., 0.],[0., 0., 0., 0.],[0., 0., 0., 0.]]])
>>> torch.stack((ones,zeros),dim=-1)
tensor([[[1., 0.],[1., 0.],[1., 0.],[1., 0.]],[[1., 0.],[1., 0.],[1., 0.],[1., 0.]],[[1., 0.],[1., 0.],[1., 0.],[1., 0.]]])
>>> torch.stack((ra,zeros),dim=-1)
tensor([[[ 0., 0.],[ 1., 0.],[ 2., 0.],[ 3., 0.]],[[ 4., 0.],[ 5., 0.],[ 6., 0.],[ 7., 0.]],[[ 8., 0.],[ 9., 0.],[10., 0.],[11., 0.]]])
>>> torch.stack((ra,zeros),dim=1)
tensor([[[ 0., 1., 2., 3.],[ 0., 0., 0., 0.]],[[ 4., 5., 6., 7.],[ 0., 0., 0., 0.]],[[ 8., 9., 10., 11.],[ 0., 0., 0., 0.]]])
>>> print("dim define what operation elements is")
dim define what operation elements is
>>>
>>>
看完代码你应该会比较形象化的理解最后一句话:dim其实定义了参与操作的元素是什么样的。对于一个batch的数据来说,dim=0上定义的是一个个样本,dim=1定义了第二个维度即每个样本的特征维度,......, dim=-1代表了从最底层的逐个数值操作。