文章目录
- 1. 举例说明
- 2. pytorch 代码
1. 举例说明
torch.argsort 的作用是可以将矩阵中的元素进行从小到大排序,得到对应的序号。假设我们有一个向量a表示如下
a = [ 8 , 7 , 6 , 9 , 7 ] \begin{equation} a=[8,7,6,9,7] \end{equation} a=[8,7,6,9,7]
那么从小到大可以得到排序向量为b
b = [ 2 , 1 , 4 , 0 , 3 ] \begin{equation} b=[2,1,4,0,3] \end{equation} b=[2,1,4,0,3]
如果我想通过序号向量b来直接从小到大排序的向量c,那么就需要torch.gather函数
c = [ 6 , 7 , 7 , 8 , 9 ] \begin{equation} c=[6,7,7,8,9] \end{equation} c=[6,7,7,8,9]
2. pytorch 代码
- python 代码描述:
import torch
torch.manual_seed(23231)torch.set_printoptions(precision=3, sci_mode=False)
# torch.seed()
if __name__ == "__main__":run_code = 0a_vector =torch.randint(low=1,high=10,size=(5,))print(f"a_vector=\n{a_vector}")a_argsort = torch.argsort(input=a_vector)print(f"a_argsort=\n{a_argsort}")a_restore = torch.argsort(a_argsort)print(f"a_restore=\n{a_restore}")a_gather = torch.gather(input=a_vector, dim=0, index=a_argsort)print(f"a_gather={a_gather}")a_matrix = torch.randint(0, 10, (3, 4))matrix_argsort = torch.argsort(input=a_matrix, dim=1)print(f"a_matrix=\n{a_matrix}")print(f"matrix_argsort=\n{matrix_argsort}")matrix_gather = torch.gather(input=a_matrix,dim=1,index=matrix_argsort)print(f"matrix_gather=\n{matrix_gather}")
- result:
a_vector=
tensor([8, 7, 6, 9, 7])
a_argsort=
tensor([2, 1, 4, 0, 3])
a_restore=
tensor([3, 1, 0, 4, 2])
a_gather=tensor([6, 7, 7, 8, 9])
a_matrix=
tensor([[0, 2, 9, 5],[0, 6, 8, 5],[0, 8, 3, 7]])
matrix_argsort=
tensor([[0, 1, 3, 2],[0, 3, 1, 2],[0, 2, 3, 1]])
matrix_gather=
tensor([[0, 2, 5, 9],[0, 5, 6, 8],[0, 3, 7, 8]])