torch.topk用法
- 介绍
- 使用示例
介绍
官网介绍:https://pytorch.org/docs/stable/generated/torch.topk.html
在指定维度选取k个最大(最小)的值。
使用示例
values = torch.tensor([[2, 1, 3], [1, 2, 3]])
# values
# tensor([[2, 1, 3],
# [1, 2, 3]])select_values, indices = torch.topk(values, k=2, dim=0)
# select_values
# tensor([[2, 2, 3],
# [1, 1, 3]])
# indices
# tensor([[0, 1, 0],
# [1, 0, 1]])# 根据indices取值
values_by_indices = values[indices, torch.arange(3)]
# values_by_indices
# tensor([[2, 2, 3],
# [1, 1, 3]])