In PyTorch you can compute aggregate values of a tensor, such as min, max, mean, and sum.
Straight to the code:
```python
import torch

x = torch.arange(0, 100, 10)
print(x)

# Each aggregation is available both as a tensor method and as a torch function
print(f"Minimum: {x.min()}")
print(f"Minimum: {torch.min(x)}")
print(f"Maximum: {x.max()}")
print(f"Maximum: {torch.max(x)}")
# mean() needs a floating-point dtype; x is int64, so convert it first
print(f"Mean: {x.type(torch.float32).mean()}")
print(f"Mean: {torch.mean(x.type(torch.float32))}")
print(f"Sum: {x.sum()}")
print(f"Sum: {torch.sum(x)}")
```

Output:

```
tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
Minimum: 0
Minimum: 0
Maximum: 90
Maximum: 90
Mean: 45.0
Mean: 45.0
Sum: 450
Sum: 450
```
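The same aggregation functions also take a dim argument that reduces along a single axis. The sketch below uses a small 2-D tensor m, which is made up for illustration. It also shows why the code above converts to float32 before calling mean(): calling mean() on an integer tensor raises an error.

```python
import torch

# Hypothetical 2x3 example tensor (integer dtype, torch.int64)
m = torch.tensor([[1, 5, 3],
                  [4, 2, 6]])

# Sum along dim=0 collapses the rows -> one value per column: 5, 7, 9
col_sums = m.sum(dim=0)

# Sum along dim=1 collapses the columns -> one value per row: 9, 12
row_sums = m.sum(dim=1)

# With dim given, torch.max returns a (values, indices) named tuple
row_max = torch.max(m, dim=1)
print(row_max.values, row_max.indices)   # tensor([5, 6]) tensor([1, 2])

# keepdim=True keeps the reduced dimension with size 1
print(m.sum(dim=1, keepdim=True).shape)  # torch.Size([2, 1])

# mean() is only defined for floating-point (or complex) dtypes,
# so calling it on an int64 tensor raises a RuntimeError
try:
    m.mean()
    mean_failed = False
except RuntimeError:
    mean_failed = True

# Converting to float first works: (1+5+3+4+2+6) / 6 = 3.5
float_mean = m.float().mean()
```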
You can also find where the maximum and minimum values are located (their indices) with torch.argmax() and torch.argmin().
```python
# Index positions of the maximum and minimum values in x
print(f"Index where max value occurs: {x.argmax()}")
print(f"Index where min value occurs: {x.argmin()}")
```

Output:

```
Index where max value occurs: 9
Index where min value occurs: 0
```
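For tensors with more than one dimension, argmax() and argmin() without a dim argument return an index into the flattened tensor. With dim, they return one index per slice. The sketch below uses a hypothetical 2x3 tensor. It converts the flat index back to (row, col) with divmod, which assumes a contiguous row-major layout.

```python
import torch

# Hypothetical 2x3 example tensor
m = torch.tensor([[3, 9, 1],
                  [7, 2, 8]])

# Without dim: index into the flattened tensor [3, 9, 1, 7, 2, 8] -> 1
flat_idx = m.argmax()

# Convert the flat index back to (row, col) for a row-major tensor -> (0, 1)
row, col = divmod(flat_idx.item(), m.shape[1])

# With dim: one index per row (dim=1) or per column (dim=0)
per_row_max = m.argmax(dim=1)   # positions of 9 and 8 -> [1, 2]
per_col_min = m.argmin(dim=0)   # positions of 3, 2, 1 -> [0, 1, 0]
```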
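In deep learning, a common source of errors is tensors with the wrong data type. For example, if one tensor is torch.float64 and another is torch.float32, operations such as matrix multiplication between them fail at runtime.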
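Not every operation fails on mixed dtypes. Element-wise arithmetic applies PyTorch's type promotion, but matrix multiplication requires both operands to have the same dtype. The sketch below shows both cases with two random tensors a and b, which are made up for illustration.

```python
import torch

a = torch.rand(2, 3)                          # default dtype: torch.float32
b = torch.rand(3, 2, dtype=torch.float64)     # explicitly float64

# Element-wise ops apply type promotion: float32 + float64 -> float64
promoted = a + b.T
print(promoted.dtype)     # torch.float64

# Matrix multiplication does not promote and raises a RuntimeError
try:
    a @ b
    matmul_failed = False
except RuntimeError as err:
    matmul_failed = True
    print(err)

# Casting one operand to the other's dtype fixes it
fixed = a @ b.type(torch.float32)
print(fixed.shape, fixed.dtype)   # torch.Size([2, 2]) torch.float32
```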
To change a tensor's data type, use torch.Tensor.type(dtype=None), where the dtype argument is the data type you want.
The code:
```python
tensor = torch.arange(10., 100., 10.)
print(tensor.dtype)

# Cast to half precision
tensor_float16 = tensor.type(torch.float16)
print(tensor_float16)

# Cast to 8-bit integers (range -128..127, so these values fit)
tensor_int8 = tensor.type(torch.int8)
print(tensor_int8)
```

Output:

```
torch.float32
tensor([10., 20., 30., 40., 50., 60., 70., 80., 90.], dtype=torch.float16)
tensor([10, 20, 30, 40, 50, 60, 70, 80, 90], dtype=torch.int8)
```
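Here are a few more details about casting, shown on a small tensor t made up for illustration. Called with no argument, .type() returns the current type as a string. The .to(dtype) method performs the same cast. Float-to-integer casts truncate toward zero instead of rounding. Casting returns a new tensor and leaves the original unchanged.

```python
import torch

t = torch.tensor([1.7, -1.7, 2.5])

# Called with no argument, .type() returns the type as a string
print(t.type())                   # torch.FloatTensor

# .type(dtype) and .to(dtype) both return a casted copy
as_int_type = t.type(torch.int32)
as_int_to = t.to(torch.int32)

# Float -> integer casts truncate toward zero: values become 1, -1, 2
print(as_int_to)

# The original tensor keeps its dtype
print(t.dtype)                    # torch.float32
```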
Sometimes we need to change a tensor's shape or dimensions with reshape, view, stack, squeeze, unsqueeze, or permute:

| Method | One-line description |
|---|---|
| torch.reshape(input, shape) | Reshapes input to shape (if the number of elements matches); also available as torch.Tensor.reshape(). Returns a view when possible, otherwise a copy |
| Tensor.view(shape) | Returns a view of the original tensor in a different shape; the view shares the same data as the original tensor |
| torch.stack(tensors, dim=0) | Concatenates a sequence of tensors along a new dimension (dim); all tensors must be the same size |
| torch.squeeze(input) | Removes all dimensions of size 1 from input |
| torch.unsqueeze(input, dim) | Returns input with a dimension of size 1 inserted at dim |
| torch.permute(input, dims) | Returns a view of input with its dimensions permuted (rearranged) to dims |
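The table says view() shares data with the original, while reshape() returns a view only when possible. The sketch below shows where the two differ, using a small transposed tensor made up for illustration. A transpose is non-contiguous, so view() cannot flatten it, while reshape() falls back to a copy.

```python
import torch

t = torch.arange(6).reshape(2, 3)   # contiguous: [[0, 1, 2], [3, 4, 5]]

# On a contiguous tensor, view() shares storage with the original
v = t.view(3, 2)
shares_memory = v.data_ptr() == t.data_ptr()

# Transposing gives a non-contiguous tensor (only the strides change)
tt = t.T                            # [[0, 3], [1, 4], [2, 5]]
print(tt.is_contiguous())           # False

# view() cannot flatten it: no stride pattern fits the new shape
try:
    tt.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape() copies when a view is impossible: values 0, 3, 1, 4, 2, 5
flat = tt.reshape(6)
```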
Code example:
```python
x = torch.arange(1., 8.)
print(f"x: {x}")
print(f"x.shape: {x.shape}")

# reshape: add an extra leading dimension
x_reshaped = x.reshape(1, 7)
print(f"x_reshaped: {x_reshaped}")
print(f"x_reshaped.shape: {x_reshaped.shape}")

# view: same shape change, but z shares memory with x
z = x.view(1, 7)
print(f"z: {z}")
print(f"z.shape: {z.shape}")

# changing z changes x
z[:, 1] = 5
print(f"z: {z}")
print(f"x: {x}")

# Stack tensors on top of each other
x_stacked = torch.stack([x, x, x, x], dim=0)
print(f"x_stack: {x_stacked}")

# Remove extra dimension from x_reshaped
# (x_reshaped is also a view of x here, so it sees the 5 written above)
x_squeezed = x_reshaped.squeeze()
print(f"New tensor: {x_squeezed}")
print(f"New shape: {x_squeezed.shape}")

# Add an extra dimension with unsqueeze
x_unsqueezed = x_squeezed.unsqueeze(dim=0)
print(f"\nNew tensor unsqueezed: {x_unsqueezed}")
print(f"New shape unsqueezed: {x_unsqueezed.shape}")

# Rearrange dimensions with permute, e.g. an image from HWC to CHW
x_original = torch.rand(size=(224, 224, 3))  # dim 0 = 224, dim 1 = 224, dim 2 = 3
x_permuted = x_original.permute(2, 0, 1)     # new order: old dim 2 (3), old dim 0 (224), old dim 1 (224)
print(f"x_permuted: {x_permuted}")
print(f"Previous shape: {x_original.shape}")
print(f"New shape: {x_permuted.shape}")
```
Output:
```
x: tensor([1., 2., 3., 4., 5., 6., 7.])
x.shape: torch.Size([7])
x_reshaped: tensor([[1., 2., 3., 4., 5., 6., 7.]])
x_reshaped.shape: torch.Size([1, 7])
z: tensor([[1., 2., 3., 4., 5., 6., 7.]])
z.shape: torch.Size([1, 7])
z: tensor([[1., 5., 3., 4., 5., 6., 7.]])
x: tensor([1., 5., 3., 4., 5., 6., 7.])
x_stack: tensor([[1., 5., 3., 4., 5., 6., 7.],
        [1., 5., 3., 4., 5., 6., 7.],
        [1., 5., 3., 4., 5., 6., 7.],
        [1., 5., 3., 4., 5., 6., 7.]])
New tensor: tensor([1., 5., 3., 4., 5., 6., 7.])
New shape: torch.Size([7])

New tensor unsqueezed: tensor([[1., 5., 3., 4., 5., 6., 7.]])
New shape unsqueezed: torch.Size([1, 7])
x_permuted: tensor([[[0.3225, 0.8588, 0.1680, ..., 0.0337, 0.5035, 0.5198],
         [0.8601, 0.8189, 0.3540, ..., 0.1257, 0.7823, 0.3571],
         [0.6149, 0.8713, 0.3548, ..., 0.2796, 0.6624, 0.9844],
         ...,
         [0.6460, 0.5896, 0.6126, ..., 0.6501, 0.2514, 0.2283],
         [0.7159, 0.3523, 0.6296, ..., 0.4082, 0.5447, 0.5778],
         [0.2686, 0.9415, 0.7950, ..., 0.0317, 0.6215, 0.4071]],

        [[0.9712, 0.8914, 0.0946, ..., 0.7424, 0.7330, 0.5440],
         [0.1387, 0.5177, 0.2111, ..., 0.4829, 0.2734, 0.2656],
         [0.2806, 0.8434, 0.4510, ..., 0.2843, 0.2676, 0.0669],
         ...,
         [0.8408, 0.8022, 0.8112, ..., 0.7236, 0.3939, 0.8946],
         [0.9174, 0.6701, 0.5786, ..., 0.1829, 0.7117, 0.5937],
         [0.3836, 0.1485, 0.7292, ..., 0.2435, 0.5428, 0.8280]],

        [[0.9295, 0.9307, 0.9878, ..., 0.1073, 0.8325, 0.4217],
         [0.0976, 0.2211, 0.1686, ..., 0.6174, 0.0807, 0.1583],
         [0.7492, 0.9756, 0.6296, ..., 0.0263, 0.0264, 0.7566],
         ...,
         [0.6900, 0.5780, 0.3770, ..., 0.6371, 0.4390, 0.3228],
         [0.8862, 0.4170, 0.3777, ..., 0.0735, 0.0238, 0.2450],
         [0.8991, 0.6936, 0.9514, ..., 0.7649, 0.4279, 0.3810]]])
Previous shape: torch.Size([224, 224, 3])
New shape: torch.Size([3, 224, 224])
```
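Three details from the example are worth checking directly, shown below with tensors made up for illustration (the 224x224x3 zeros tensor stands in for an image). First, stack() creates a new dimension, while torch.cat() joins along an existing one. Second, squeeze(dim) removes only the given dimension, and only if its size is 1. Third, permute() returns a view, so writing through it changes the original tensor.

```python
import torch

x = torch.arange(1., 8.)            # shape [7]

# stack inserts a NEW dimension; cat joins along an EXISTING one
stacked_0 = torch.stack([x, x, x], dim=0)   # shape [3, 7]
stacked_1 = torch.stack([x, x, x], dim=1)   # shape [7, 3]
concatenated = torch.cat([x, x, x], dim=0)  # shape [21]

# squeeze(dim) removes only that dimension, and only if its size is 1
y = torch.zeros(1, 7, 1)
print(y.squeeze().shape)            # torch.Size([7])
print(y.squeeze(0).shape)           # torch.Size([7, 1])
print(y.squeeze(1).shape)           # torch.Size([1, 7, 1]): size 7, so unchanged

# permute returns a view, so writes through it show up in the original
img = torch.zeros(224, 224, 3)      # hypothetical height x width x channels image
chw = img.permute(2, 0, 1)          # channels x height x width
chw[0, 0, 0] = 1.0                  # same element as img[0, 0, 0]
print(img[0, 0, 0])                 # tensor(1.)
```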
Made it this far? Give it a like~