分类目录:《深入浅出Pytorch函数》总目录
按照值沿给定维度对输入张量的元素进行排序。如果未给定dim
,则选择输入的最后一个维度。若descending
被指定为True
,则元素按值降序排列,否则为升序。如果stable
为True
,则排序例程变为稳定,从而保持等价元素的顺序。
语法
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
参数
input
:[Tensor
] 输入张量dim
:[可选,int
] 待排序的维度- descending :[可选,
bool
] 排序顺序,False
为升序,True
为降序,默认为升序 - stable :[可选,
bool
] 指示排序是否稳定,如果为True
,则等价元素的顺序得到保留
返回值
元组(values, indices)
,其中values
是排序的值,indices
是原始输入张量中元素的索引。
实例
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162, 0.0608, 0.6719, 2.3332],[-0.5793, 0.0061, 0.6058, 0.9497],[-0.5071, 0.3343, 0.9553, 1.0960]])
>>> indices
tensor([[ 1, 0, 2, 3],[ 3, 1, 0, 2],[ 0, 3, 1, 2]])>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162, 0.6719, -0.5793],[ 0.0608, 0.0061, 0.9497, 0.3343],[ 0.6058, 0.9553, 1.0960, 2.3332]])
>>> indices
tensor([[ 2, 0, 0, 1],[ 0, 1, 1, 2],[ 1, 2, 2, 0]])>>> x = torch.tensor([0, 1] * 9)
>>> x.sort()
torch.return_types.sort(values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1]))
>>> x.sort(stable=True)
torch.return_types.sort(values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))