文章目录
- 解释
- 代码举例
解释
torch.max
是 PyTorch 中的一个函数,用于在张量中沿指定维度计算最大值。它有两种用法:
① 如果只提供一个输入张量,则返回该张量中的最大值和对应的索引。
② 如果提供两个输入张量,则返回两个张量中对应位置的较大值。
深度学习中主要使用第一种用法,下面对该用法举例说明:
代码举例
import torch# 创建一个张量
# tensor = torch.rand(1, 4, 3, 3)
tensor = torch.tensor([[[[2, 2, 0.7944],[2, 0.6368, 0.6928],[0.9620, 0.5716, 0.3827]],[[0.6216, 0, 1],[0.0588, 1, 0.0718],[1, 0.1084, 0.0462]],[[0.3117, 0.3333, 0.655],[0.8207, 0.5918, 3],[0.6565, 3, 0.2866]],[[0.6613, 0.1222, 0.0590],[0.4555, 0.0166, 0.0838],[0.3797, 0.6666, 4]]]])
# print(tensor)
print("原张量的shape为:", tensor.shape, '\n')# 计算整个张量中的最大值和对应的索引
max_value, max_indices = torch.max(tensor, dim=1)print("max_value:\n", max_value) # 输出第二个维度上的最大值
print("max_indices:\n", max_indices, '\n') # 输出第二个维度上最大值的索引print("max_value.shape为:", max_value.shape) # 输出每行的最大值
print("max_indices.shape为:", max_indices.shape) # 输出每行最大值的索引
运行结果: