示例代码:
plt.imshow(np.transpose(tensor_denorm.numpy(), (1, 2, 0)))
它的作用是:把一个 PyTorch 的图像张量转换成 NumPy 格式,并按照正确的维度顺序显示出来。
🚀 一步步解释:
✅ tensor_denorm
这是一个形状为 (3, H, W)
的 PyTorch Tensor,表示一个图像:
- 3:表示三个颜色通道(RGB)
- H:图像高度
- W:图像宽度
PyTorch 中的图像张量格式是 (C, H, W)
✅ .numpy()
这一步把 PyTorch Tensor 转换成 NumPy 数组(前提是 Tensor 在 CPU 上):
tensor_denorm.numpy()
得到一个 NumPy 数组,形状依然是 (3, H, W)
✅ np.transpose(..., (1, 2, 0))
NumPy 默认显示图像的格式是 (H, W, C)
,也就是:
- 高度(H)
- 宽度(W)
- 通道(C)
所以要把 (3, H, W)
转换成 (H, W, 3)
,需要换维度顺序:
np.transpose(tensor_denorm.numpy(), (1, 2, 0))
✅ plt.imshow(...)
这是 matplotlib.pyplot
的图像显示函数。它接收一个 (H, W, 3)
的数组并显示出来:
plt.imshow(...)
📌 举个例子:
假设我们有这个张量:
tensor = torch.rand(3, 150, 150) # 随机图像,3通道 150x150
执行这一步:
plt.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))
就能把这个随机图像展示出来了。
✅ 总结一句话:
plt.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))
等价于:
“把 PyTorch 中格式为
(C, H, W)
的图像转成(H, W, C)
并显示出来”