目录
- 1. 可视化过程
- 2. 代码验证
前言:开始处理视频数据,遇到了陌生又熟悉的3D卷积,但对其计算过程有点疑惑,网上也没找到什么可视化动画,所以研究明白并做个记录,方便日后复习。有点简化,但认真琢磨一下图片和代码肯定能看明白:)
1. 可视化过程
3D卷积可以用来处理视频输入,对于图片来讲,shape为:[C_in, H, W]。而视频多了时间这一维度,因此视频的shape为:[C_in, D, H, W],其中D为帧数(frame),比如一条视频有10帧,则D=10。(以上都忽略了batch size N)
假如我们现在的输入的视频shape为:[3, 7, 4, 4]。即:
input_channel | frame | H | W |
---|---|---|---|
3 | 7 | 4 | 4 |
kernel shape为:[5, 3, 2, 2, 2]
output_channel | input_channel | kernel_D | kernel_H | kernel_W |
---|---|---|---|---|
5 | 3 | 2 | 2 | 2 |
计算过程可视化如下:
output shape为:[5, 6, 3, 3],其中:
output_channel | output_D | output_H | output_W |
---|---|---|---|
5 | 6 | 3 | 3 |
2. 代码验证
import torch
import torch.nn as nnN, C_in, D, H, W = 1, 3, 7, 4, 4
C_out = 5
m = nn.Conv3d(in_channels=C_in, out_channels=C_out, kernel_size=2, stride=1, bias=False)
inputs = torch.zeros(N, C, D, H, W)m.weight = nn.Parameter(torch.ones(C_out, C_in, 2, 2, 2))inputs[:, 0, :, :, :] = torch.ones(D, H, W)
inputs[:, 1, :, :, :] = torch.ones(D, H, W) * 2
inputs[:, 2, :, :, :] = torch.ones(D, H, W) * 3output = m(inputs)
print(inputs, inputs.shape)"""
tensor([[[[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]],[[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.],[1., 1., 1., 1.]]],[[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]],[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]],[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]],[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]],[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]],[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]],[[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.],[2., 2., 2., 2.]]],[[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]],[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]],[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]],[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]],[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]],[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]],[[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.],[3., 3., 3., 3.]]]]])shape:
torch.Size([1, 3, 7, 4, 4])
"""
print(output, output.shape)"""
tensor([[[[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]]],[[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]]],[[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]]],[[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]]],[[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]],[[48., 48., 48.],[48., 48., 48.],[48., 48., 48.]]]]], grad_fn=<SlowConv3DBackward0>)shape:
torch.Size([1, 5, 6, 3, 3])
"""
48怎么来的?
2x2x2x1 + 2x2x2x2 + 2x2x2x3 = 48