今天在学习代码时,发现有些深度学习的项目中使用到torch.sort()函数,在此记录一下,方便自己的查阅.
torch.sort()
官网给出了非常详细的介绍,但是为了更进一步掌握这一用法,在此记录一下。
具体官网链接如下:https://pytorch.org/docs/stable/generated/torch.sort.html#torch.sort
首先看到sort大家都清楚的知道是排序的功能,但是和python的sorted函数用法还不一样,torch中的排序封装的性能更加实用,不仅可以返回排序后的值,而且还可以返回排序后的下表索引。
一般我们的输入有三项
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)->(Tensor, LongTensor)
input:形式上与 numpy.narray 类似,可以是一个列表形式或者是一个数组形式。
dim:维度,对于二维数据:dim=0 按列排序,dim=1 按行排序,默认 dim=1。以此类推。
descending:降序,descending=True 从大到小排序,descending=False 从小到大排序,默认 descending=Flase
该函数返回的是一个元组,分别是一个排序后的张量,一个是下表索引的张量。
具体的代码如下,方便理解
# 可以使用jupyter编辑器直接复制运行结果
# 随机生成一个3行4列的数组x 这里是随机生成的数组,每次都不相同,不用纠结里面的数字,如何想看到相同的效果,可以使用固定数组
import torch
x = torch.randn(3,4)
# 初始值,始终不变
print(x)
tensor([[-0.9950, -0.6175, -0.1253, 1.3536],[ 0.1208, -0.4237, -1.1313, 0.9022],[-1.1995, -0.0699, -0.4396, 0.8043]])
sorted, indices = torch.sort(x) #按行从小到大排序
print(sorted)
tensor([[-0.9950, -0.6175, -0.1253, 1.3536],[-1.1313, -0.4237, 0.1208, 0.9022],[-1.1995, -0.4396, -0.0699, 0.8043]])
print(indices)
tensor([[0, 1, 2, 3],[2, 1, 0, 3],[0, 2, 1, 3]])
sorted, indices = torch.sort(x, descending=True) #按行从大到小排序 (即反序)
print(sorted)
tensor([[ 1.3536, -0.1253, -0.6175, -0.9950],[ 0.9022, 0.1208, -0.4237, -1.1313],[ 0.8043, -0.0699, -0.4396, -1.1995]])
print(indices)
tensor([[3, 2, 1, 0],[3, 0, 1, 2],[3, 1, 2, 0]])
sorted, indices = torch.sort(x, dim=0) #按列从小到大排序
print(sorted)
tensor([[-1.1995, -0.6175, -1.1313, 0.8043],[-0.9950, -0.4237, -0.4396, 0.9022],[ 0.1208, -0.0699, -0.1253, 1.3536]])
print(indices)
tensor([[2, 0, 1, 2],[0, 1, 2, 1],[1, 2, 0, 0]])
sorted, indices = torch.sort(x, dim=0, descending=True) #按列从大到小排序
print(sorted)
tensor([[ 0.1208, -0.0699, -0.1253, 1.3536],[-0.9950, -0.4237, -0.4396, 0.9022],[-1.1995, -0.6175, -1.1313, 0.8043]])
print(indices)
tensor([[1, 2, 0, 0],[0, 1, 2, 1],[2, 0, 1, 2]])