参考
如何计算神经网络模型的复杂度
深度学习卷积、全连接层、深度可分离层参数量和FLOPs计算公式
概念
- Params:模型的参数量。(空间复杂度)
- FLOPs:FLoating point Operations,前向推理的计算量。(时间复杂度)
- MAC:Memory Access Cost。基本上看每个计算输出结果 C o u t × H o u t × W o u t C_{out} \times H_{out} \times W_{out} Cout×Hout×Wout 相加的总和。
- MACC(MADD):multiply-accumulate operations:先乘起来再加起来的运算次数。 也就是 乘加 看做一次运算。
所以 1个 MACC = 2个 FLOPs。 - 内存量
H i n H_{in} Hin: 输入的 height
W i n W_{in} Win: 输入的 width
H o u t H_{out} Hout: 输出的 height
W o u t W_{out} Wout: 输入的 width
K K K: 卷积核size
C i n C_{in} Cin: 输入的channel 数
C o u t C_{out} Cout: 输出的 channel数
参数量计算
全连接层
卷积层
普通卷积: 输入尺寸 C i ∗ H i ∗ W i C_i * H_i * W_i Ci∗Hi∗Wi, 卷积核的大小为 K ∗ K K*K K∗K, 输出的尺寸大小为 C o ∗ H o ∗ W o C_o * H_o * W_o Co∗Ho∗Wo.
- 不考虑 bias
K 2 × C i × C o K^2 \times C_{i} \times C_{o} K2×Ci×Co - 考虑bias
( K 2 × C i + 1 ) × C o (K^2 \times C_{i} + 1) \times C_{o} (K2×Ci+1)×Co
池化层
对于池化层而言,常用的Max-pooling,Avg-pooling等是不存在参数量的。
batch norm
每个 batch 减均值,除方差。
再根据参数 α \alpha α, β \beta β 做缩放
在训练时计算的均值方差是直接计算,在预测时是用 running mean,running var.
所以参数量是?2HW*C, 错了, 是
2 × C i 2 \times C_{i} 2×Ci
激活函数
无参数
FLOPs
卷积层
- 不考虑 bias
( 2 × ( K 2 × C i ) − 1 ) × ( C o × H o × W o ) (2\times (K^2 \times C_{i} ) -1 ) \times (C_{o} \times H_{o} \times W_{o}) (2×(K2×Ci)−1)×(Co×Ho×Wo)
先计算输出的feature中一个元素需要的计算量。 ( K 2 × C i ) (K^2 \times C_{i} ) (K2×Ci) 表示乘法次数, ( K 2 × C i ) − 1 (K^2 \times C_{i} ) -1 (K2×Ci)−1 表示加法次数。
- 考虑bias
带bias 的计算(一部分是乘法,一部分是加法)
2 × ( K 2 × C i ) × ( C o × H o × W o ) 2\times (K^2 \times C_{i} ) \times (C_{o} \times H_{o} \times W_{o}) 2×(K2×Ci)×(Co×Ho×Wo)
全连接层
输入维度 C i C_i Ci, 输出 C o C_o Co. 全连接层就理解为一个矩阵,矩阵行数,矩阵列数,如考虑bias,则先计算输出向量中一个元素需要多少计算量,首先要做 C i C_i Ci 次乘法,然后做 C i − 1 C_i -1 Ci−1 次加法。若考虑 bias,则做的加法会多一次。
- 不考虑 bias : ( 2 N i n − 1 ) N o u t (2N_{in}-1)N_{out} (2Nin−1)Nout
N i n N o u t N_{in}N_{out} NinNout 为乘法的运算量,
( N i n − 1 ) N o u t (N_{in} - 1)N_{out} (Nin−1)Nout为加法的运算量 - 考虑 bias : ( 2 N i n ) N o u t (2N_{in})N_{out} (2Nin)Nout
工具
torchinfo
mmdetection 工具代码
More
https://github.com/sovrasov/flops-counter.pytorch
https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/utils/flops_counter.py