torchsummary
要在 Jupyter Notebook 中查看神经网络的结构,可以使用 `torchsummary` 库中的 `summary` 函数。该函数以文本表格的形式列出模型每一层的类型、输出形状和参数数量(它不会绘制图形化的结构图)。首先,确保你已经安装了 `torchsummary`:
pip install torchsummary
然后,在 Jupyter Notebook 中运行以下代码,即可以表格形式显示模型结构:
from torchsummary import summary
import torch
import torch.nn as nn


class FeatureExtractor_01(nn.Module):
    """1-D convolutional feature extractor producing a 512-dim vector.

    Five stages of ``Conv1d -> BatchNorm1d -> ReLU -> pooling`` widen the
    channels ``in_channel -> 32 -> 64 -> 128 -> 256 -> 512``. Stages 1-4
    downsample with ``MaxPool1d``. Stage 5 uses ``AdaptiveMaxPool1d(1)``
    followed by ``Flatten``, so the output shape ``(N, 512)`` does not depend
    on the input length.

    The module hierarchy (``self.fs[stage][layer]``) matches the original
    hand-written version, so existing ``state_dict`` checkpoints still load.

    Args:
        in_channel: Number of input channels seen by the first convolution.
        kernel_size: Kernel size of every ``Conv1d``.
        stride: Stride of every ``Conv1d``.
        padding: Zero padding of every ``Conv1d``.
        mp_kernel_size: Kernel size of the ``MaxPool1d`` in stages 1-4.
        mp_stride: Stride of the ``MaxPool1d`` in stages 1-4.
    """

    # Output channels of the five conv stages, in order.
    _STAGE_CHANNELS = (32, 64, 128, 256, 512)

    def __init__(self, in_channel=5, kernel_size=3, stride=1, padding=2,
                 mp_kernel_size=2, mp_stride=2):
        super().__init__()
        stages = []
        # The first conv must honour in_channel (it used to be hard-coded to 5).
        prev_channels = in_channel
        last = len(self._STAGE_CHANNELS) - 1
        for idx, out_channels in enumerate(self._STAGE_CHANNELS):
            if idx == last:
                # Global max over the length axis, then drop that axis -> (N, C).
                pooling = [nn.AdaptiveMaxPool1d(1), nn.Flatten()]
            else:
                pooling = [nn.MaxPool1d(kernel_size=mp_kernel_size, stride=mp_stride)]
            stages.append(nn.Sequential(
                nn.Conv1d(prev_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(inplace=True),
                *pooling,
            ))
            prev_channels = out_channels
        self.fs = nn.Sequential(*stages)

    def forward(self, tar, x=None, y=None):
        """Extract features from ``tar``.

        Args:
            tar: Input tensor, expected as ``(N, in_channel, L)``.
            x: Unused; presumably kept for call-site compatibility -- TODO confirm.
            y: Unused; presumably kept for call-site compatibility -- TODO confirm.

        Returns:
            Feature tensor of shape ``(N, 512)``.
        """
        return self.fs(tar)


# Create an instance of the model
model = FeatureExtractor_01()

# Display the model summary.
# torchsummary's summary() defaults to device="cuda". When a GPU is present,
# it then feeds CUDA input tensors. The model above was never moved off the
# CPU, so request a CPU pass explicitly to avoid a device-mismatch error.
summary(model, (5, 100), device="cpu")
这将在输出中显示类似于以下内容的模型结构信息:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv1d-1              [-1, 32, 102]             512
       BatchNorm1d-2              [-1, 32, 102]              64
              ReLU-3              [-1, 32, 102]               0
         MaxPool1d-4               [-1, 32, 51]               0
            Conv1d-5               [-1, 64, 53]           6,208
       BatchNorm1d-6               [-1, 64, 53]             128
              ReLU-7               [-1, 64, 53]               0
         MaxPool1d-8               [-1, 64, 26]               0
            Conv1d-9              [-1, 128, 28]          24,704
      BatchNorm1d-10              [-1, 128, 28]             256
             ReLU-11              [-1, 128, 28]               0
        MaxPool1d-12              [-1, 128, 14]               0
           Conv1d-13              [-1, 256, 16]          98,560
      BatchNorm1d-14              [-1, 256, 16]             512
             ReLU-15              [-1, 256, 16]               0
        MaxPool1d-16               [-1, 256, 8]               0
           Conv1d-17              [-1, 512, 10]         393,728
      BatchNorm1d-18              [-1, 512, 10]           1,024
             ReLU-19              [-1, 512, 10]               0
AdaptiveMaxPool1d-20               [-1, 512, 1]               0
          Flatten-21                  [-1, 512]               0
================================================================
Total params: 525,696
Trainable params: 525,696
Non-trainable params: 0
----------------------------------------------------------------