文章目录
- 一、基本使用
- 二、常见指标
- 2.1Input size
- 2.2Forward/backward pass size
一、基本使用
torchsummary
库是一个好用的模型可视化工具,用于帮助开发者把握每个网络层级的细节,包括其中的连接和维度。使用方法:
from torchsummary import summary
库中仅有一个函数:
summary(model, input_size, batch_size=-1, device="cuda"):
model
:模型对象。input_size
:输入数据的格式,使用(C,H,W)格式。batch_size
:批数据的数量。device
:使用的设备。
以自定义的LeNet网络模型为例:
import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 手写数字图片大小为32*32,故需填充2个像素self.model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(in_features=16 * 5 * 5, out_features=120),nn.Linear(in_features=120, out_features=84),nn.Linear(in_features=84, out_features=10),)def forward(self, x):return self.model(x)myLeNet = LeNet().to(device)
print(summary(myLeNet, input_size=(1, 28, 28), batch_size=64, device='cuda'))
二、常见指标
2.1Input size
Input size
表示输入数据的大小。在上述例子中,batch_size=64
,每张图片大小为(1,28,28)
,而Pytorch默认使用float32(双精度浮点数)占4字节,则每个batch所用内存大小为:
64 x 1 x 28 x 28 x 4 = 200 , 704 ( B y t e s ) 64x1x28x28x4=200,704(Bytes) 64x1x28x28x4=200,704(Bytes)
转化为以MB为单位:
200 , 704 / 102 4 2 ( B y t e s ) = 0.19140625 ( B y t e s ) 200,704/1024^2(Bytes)=0.19140625(Bytes) 200,704/10242(Bytes)=0.19140625(Bytes)
约等于0.19MB。
2.2Forward/backward pass size
https://blog.csdn.net/weixin_43589323/article/details/137105988?ops_request_misc=&request_id=&biz_id=102&utm_term=torchsummary&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-5-137105988.142v100pc_search_result_base2&spm=1018.2226.3001.4187