深度学习每周学习总结J4(ResDenseNet 算法探索实践 - 鸟类识别)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

目录

    • 一:回顾与总结: 三种神经网络模型对比研究及尝试构成新的网络结构模型
      • 卷积计算过程
      • ResNet-50 模型
        • 1. 关于残差
        • 2. 关于ConvBlock的思考
        • 3. 关于IdentityBlock的思考
          • Identity Block 的特征
          • 相关代码解析
          • 完整前向传播过程
      • ResNet-50v2 模型
        • 1.与ResNet-50的区别
        • 2. 各模块分解
          • **Block2 类**
          • **Stack2 类**
          • **ResNet50V2 类**
      • DenseNet 模型
        • 1. 设计理念
        • 2. 网络结构
        • 3. 与CNN和ResNet的对比
        • 4. 代码及关键点解析
          • 1. 模块分解
            • **_DenseLayer 类**
            • **_DenseBlock 类**
            • **_Transition 类**
            • **DenseNet 类**
          • 2. DenseNet 关键特点分析
            • 特点 1: 特征重用机制
            • 特点 2: 瓶颈设计
            • 特点 3: 过渡层
            • 特点 4: Dropout 机制
            • 特点 5: Globally Average Pooling
          • 3. 总结
      • 新的模型框架
        • 1. 理解 ResNet 和 DenseNet 的关键特点
          • ResNet 的关键特点
          • DenseNet 的关键特点
        • 2. 设计新的模型框架
        • 3. 新模型框架的构建
        • 4. 关键点和总结
          • 关键点
        • 5. 未来的方向
    • 二:代码流程
      • 0. 总结
      • 1. 设置GPU
      • 2. 导入数据及处理部分
      • 3. 划分数据集
      • 4. 模型构建部分
      • 5. 设置超参数:定义损失函数,学习率,以及根据学习率定义优化器等
      • 6. 训练函数
      • 7. 测试函数
      • 8. 正式训练
      • 9. 结果可视化
      • 10. 模型的保存
      • 11. 使用训练好的模型进行预测

一:回顾与总结: 三种神经网络模型对比研究及尝试构成新的网络结构模型

卷积计算过程

对于卷积新的理解:

我之前的误解是错误以为有几个卷积核,就有几个权重矩阵。比如输入通道数为3,输出通道数为2,我误以为只有两个权重矩阵。但其实对于一个卷积层来说,卷积核的数量等同于输出通道数(也称为输出的特征图数)。如果输入通道数为3,输出通道数为2,那么根据卷积操作的定义,每个输出通道对应一个单独的卷积核组。每个卷积核组由数量等于输入通道数的权重矩阵组成,因此在这个例子中,实际上有 3 × 2 = 6 3×2=6 3×2=6 个权重矩阵(即每个输出通道有3个权重矩阵,与输入通道数相对应)。

此外,在训练神经网络时,除了卷积核的参数矩阵(也称为权重矩阵)以外,偏置参数 $ b$ 也是会随着训练过程而改变的。偏置参数 b b b 是每个神经元的一个独立参数。需要注意这个偏置是与整个卷积核相关的,而不是与单个权重矩阵相关,一个卷积核组可能有复数个权重矩阵。因此,对于每个输出通道,模型在计算完成卷积后,会加上各自的偏置来调整输出,它在每次反向传播过程中通过梯度下降或其他优化算法进行更新。

这样的更新机制是为了让神经网络模型能够更好地拟合训练数据。偏置参数 b b b的存在使得每个神经元在没有任何输入时就可以有一个非零输出,从而增加了模型的灵活性和表达能力。

Source: http://cs231n.github.io/convolutional-networks/

ResNet-50 模型

1. 关于残差

目前残差结构主要有两种:
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

左边的单元为 ResNet 两层的残差单元,两层的残差单元包含两个相同输出的通道数的 3x3 卷积,只是用于较浅的 ResNet 网络,对较深的网络主要使用三层的残差单元。三层的残差单元又称为 bottleneck 结构,先用一个 1x1 卷积进行降维,然后 3x3 卷积,最后用 1x1 升维恢复原有的维度。另外,如果有输入输出维度不同的情况,可以对输入做一个线性映射变换维度,再连接后面的层。三层的残差单元对于相同数量的层又减少了参数量,因此可以拓展更深的模型。通过残差单元的组合有经典的 ResNet-50,ResNet-101 等网络结构。

2. 关于ConvBlock的思考

关于convblock的参数设置例如:[64,64,256]:

主要基于参数选择的经验性:第一个是常用的通道数64是常见的开始值,第二个值为64是为了保持相同的特征数保留和加强特征,最后一个256会大一些是为了增加特征维度和网络模型的表达能力

残差层(layer1 到 layer4)
每个残差层通过 _build_layer 函数构建,每个层的输出通道配置如 [64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048],仍然是基于经验法则和文献中的设计,确保了网络在进行深度学习时能够高效提取特征。

3. 关于IdentityBlock的思考
Identity Block 的特征

在经典的残差网络实现中,一个 Identity Block 通常具有以下特征:

无尺寸或通道变化:在 Identity Block 中,输入和输出的通道数相同,并且通常步幅也为 1。
直接使用输入作为快捷路径:在没有下采样和通道变换的情况下,快捷连接就直接是输入。

相关代码解析
self.shortcut = nn.Sequential()
if in_channels != filters3 or strides != 1:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, filters3, kernel_size=1, stride=strides),nn.BatchNorm2d(filters3))
  1. shortcut的定义
  • self.shortcut 是一个用于储存快捷连接的神经网络模块(Sequential)。
  • 当输入层in_channels 与输出层 filters3 不同,或者 strides 不等于1(即需要下采样)时,表示需要调整输入特征图的尺寸或通道数。
    • 这种情况下,shortcut 被定义为一个卷积层(Conv2d)和一个批量归一化层(BatchNorm2d)的组合。
  • 如果通道数和步幅满足条件(即相等且为1),那么 shortcut 直接使用输入 x,实际上不会发生任何变化。
  1. Identity Block 的特性
  • 身份连接:在残差网络中,身份连接使得输入能直接回馈到输出。当 in_channelsfilters3 相等,且步幅为1时,shortcut 就会等于输入 x。这正是 Identity Block 的特性,即不对输入进行任何修改。
完整前向传播过程

forward 方法中,这段代码实现了身份连接:

def forward(self, x):shortcut = self.shortcut(x)  # 使用 shortcut 将 x 进行改动或保持不变x = F.relu(self.bn1(self.conv1(x)))  # 卷积操作和激活x = F.relu(self.bn2(self.conv2(x)))  # 另一个卷积操作和激活x = self.bn3(self.conv3(x))          # 最后的卷积操作x += shortcut  # 将调整(或不调整)的 shortcut 加到输出 x 上x = F.relu(x)  # 再次进行激活return x
  • 加法操作x += shortcut):这一行实现了残差连接。根据前面定义的逻辑,如果 shortcut 是通过卷积操作得到的结果,那么输入信号经过了变换,以下采样的方式对信号进行匹配。
  • 如果 shortcut 是直接使用的输入 x,那么就实现了身份连接,保持了输入特征图的维度和内容。

总结:

你的 ConvBlock 实现中已经巧妙地综合了 ConvBlockIdentity Block 的特性。通过对 shortcut 的定义和处理,你的代码在有需要时能够对信号进行变换,而在不需要时则能够保持信号不变。这种灵活性使得你的实现既可以实现残差学习的功能,也符合对身份块的定义。因此,不需要独立声明一个单独的 Identity Block。你的实现已经具备了这种能力。

ResNet-50v2 模型

1.与ResNet-50的区别

关键区别

  1. 预激活结构

    • ResNet50V2 使用了预激活方式(pre-activation block),在执行卷积操作之前先进行 Batch Normalization 和激活函数处理。这种结构使得网络在训练时更平稳,梯度传播更有效。
  2. 快捷连接的实现

    • Block2 中,允许通过卷积层进行快捷连接(conv_shortcut),而不是简单地保持输入不变。这种方式在输出通道数需要变化时能够更好地对齐特征图。
  3. 堆叠的结构

    • Stack2 类负责管理多个 Block2 的堆叠,每个堆叠都在调整通道数和步幅时保持更清晰的结构。这在模型构建时,使得模块更易于组织和理解。
  4. 模块化和可重用性

    • ResNet50V2 更加模块化,各模块之间具有良好的隔离性,这样有助于模型的可读性和重用性。
  5. 整体设计

    • 整体的通道调整和层数设定使得 ResNet50V2 能够在特征提取上更加灵活,适应不同的输入特征。

总结

ResNet50V2 改进了原始 ResNet50 的结构,使之在处理深度学习任务时更加高效。通过预激活、灵活的快捷连接和模块化设计,使得网络能够更好地训练和泛化。这些改进为构建更深更复杂的神经网络打下了基础,也进一步解放了网络结构设计的限制。

2. 各模块分解
Block2 类
  • 初始化方法 __init__

    • 设定了 conv_shortcutstride 等参数,用于控制是否使用卷积快捷连接以及步幅的设置。
    • 预激活(Pre-activation):使用了 Batch Normalization 和 ReLU 激活函数。预激活是一种改进,有助于缓解深层网络的梯度消失问题。
    • 快捷路径(Shortcut Path)
      • 如果 conv_shortcut 为真,通过 1x1 卷积调整通道数(4 * filters),以实现跨层连接。
      • 如果步幅不为 1 而 conv_shortcut 为假,则使用最大池化。
    • 该模块的三个卷积层结构沿用了标准的 ResNet 架构。
  • 前向传播 forward

    • 先进行预激活处理(Batch Normalization + ReLU)。
    • 根据设定的 conv_shortcut 决定快捷连接的实现。
    • 将经过卷积的特征图与快捷路径相加,然后返回结果。
Stack2 类
  • 在该类中,多个 Block2 被堆叠在一起。
  • 第一层使用卷积快捷连接,后续层保持输入和输出通道一致。
ResNet50V2 类
  • 初始化方法 __init__

    • 模型的输入从 3(RGB 图像)开始,经过一系列卷积层和堆栈结构。
    • stack1stack4 各自包含不同数量的 Block2,每层都对应通道数和块数的调整。
    • 使用了标准的卷积后,增加了一个 Batch Normalization 层和 ReLU 激活。
  • 前向传播 forward

    • 输入数据经过了一系列处理,包括卷积、池化、堆叠处理、全局平均池化和最终的全连接层输出。

DenseNet 模型

它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能

1. 设计理念

DenseNet的核心设计理念是“密集连接”(Dense Connectivity)。其目标是通过在每一层与所有前面层的特征图进行连接,来更有效地利用特征,缓解深度网络中的梯度消失问题。具体来说,DenseNet的主要优点包括:

  • 特征重用:每一层都直接连接到前面的所有层,使得特征的重用变得更加高效,提高了信息流动性。
  • 降低参数数量:通过密集连接,DenseNet显著减少了需要学习的参数数量,从而降低了过拟合的风险。
  • 改善梯度流动:更好的梯度流动使得网络训练更加稳定,即使在极深的网络结构中也能保持良好的性能。
2. 网络结构

DenseNet的基本结构由多个稠密块(Dense Block)组成,每个稠密块内的每一层都与前面所有层进行连接。DenseNet的结构通常包含以下几个部分:

  • 输入层:输入原始图像。
  • 卷积层:初始卷积层用于提取特征。
  • 稠密块:由多个卷积层组成,每个卷积层的输出都会被连接到下一层。
  • 过渡层(Transition Layer):通常用于降低特征图的尺寸,采用1x1卷积和平均池化。
  • 分类层:最终的全局平均池化层和全连接层进行分类。

以下是DenseNet的一个简单结构示意图:

在这里插入图片描述
在这里插入图片描述

3. 与CNN和ResNet的对比
  • CNN(卷积神经网络):传统的CNN采用每层单独连接的结构,深度较大时会面临梯度消失和特征梯度消失问题。DenseNet通过密集连接,有效地解决了这些问题。

  • ResNet(残差网络):ResNet引入了残差学习,通过跳跃连接(skip connections)来缓解深度网络中的梯度消失问题。与ResNet不同,DenseNet将每一层都连接起来,这意味着每层的信息都可以从所有前面层中获得。DenseNet通常比同深度的ResNet具有更少的参数,从而提高了计算效率和分类精度。

以下是一个简单的对比图示:

在这里插入图片描述

DenseNet通过密集连接层间的特征图,有效地利用了信息流,降低了对参数的需求,并且改善了梯度流动。与传统的CNN和ResNet相比,DenseNet在许多应用上能够提供更好的性能和更高的效率。

4. 代码及关键点解析

让我们逐步分析这段 DenseNet 的代码实现,并详细解析模型的关键特点在代码中的体现。

1. 模块分解
_DenseLayer 类
  • 初始化方法 __init__

    • 创建一个基本的稠密层(Dense Layer),实现了瓶颈结构,主要由两个卷积层和两个 Batch Normalization 层组成。
    • 首先将输入特征通过 Batch Normalization 和 ReLU 激活,然后进行 1x1 卷积,接着经过 Batch Normalization 和 ReLU,再进行 3x3 的卷积。
    • drop_rate 是用来随机失活(Dropout)的概率。
  • 前向传播 forward

    • 先经过 BatchNormReLU,再执行 Conv2d 操作。
    • 如果 drop_rate 大于 0,则对新特征进行随机失活。
    • 最后,使用 torch.cat 将原始输入 x 和新产生的特征 new_features 在通道维度上连接。这是 DenseNet 的核心思想之一:每层都可以访问前面所有层的特征。
_DenseBlock 类
  • 初始化方法 __init__
    • 创建一个稠密块(Dense Block),由多个 _DenseLayer 组成。
    • 根据传入的 num_layers 为每层创建一个 _DenseLayer,并将其添加到该块中。
_Transition 类
  • 初始化方法 __init__
    • 实现了稠密块之间的过渡层(Transition Layer),主要用于减少特征图的通道数和进行空间下采样。
    • 包括 Batch Normalization、ReLU、1x1 卷积和 2x2 的平均池化。
DenseNet 类
  • 初始化方法 __init__

    • 初始化 DenseNet 模型的各个部分。
    • 卷积层和最大池化:从输入的 RGB 图片开始,然后进行 7x7 的卷积和 3x3 的最大池化,生成初始特征。
    • DenseBlocks:根据 block_config 指定的结构,多个 _DenseBlock 被堆叠在一起。在每个 Block 之后,如果不是最后一个 Block,会加入一个过渡层。
    • 最终的 Batch Normalization 和 ReLU 处理经过所有 DenseBlocks 的特征。
    • 通过一个全连接层进行分类。
  • 参数初始化:使用 Kaiming 初始化方法初始化卷积层,并将 Batch Normalization 和全连接层的偏置和权重进行初始化。

  • 前向传播 forward

    • 将输入 x 传递通过 features 串联的各层。
    • 在最后进行全局平均池化(7x7),将特征展平,并通过分类层输出最终的类概率。
2. DenseNet 关键特点分析
特点 1: 特征重用机制
  • DenseNet 通过 torch.cat 将所有前面层的特征连接起来,允许每一层直接访问所有之前层的输出。这种设计能够有效减轻深层网络的退化问题,并增强特征重用的机制。

代码片段:

return torch.cat([x, new_features], 1)  # 在通道维度连接所有特征
特点 2: 瓶颈设计
  • _DenseLayer 中使用 1x1 卷积作为瓶颈层,这减少了在特征图上进一步运算需要的计算量和空间复杂度。
self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate, kernel_size=1, stride=1, bias=False))
特点 3: 过渡层
  • 使用 _Transition 类来控制模型的复杂度和特征图的尺寸,使网络更具灵活性。通过平均池化来降低特征图尺寸,通过卷积减少通道数,防止过拟合。
self.add_module("pool", nn.AvgPool2d(2, stride=2))  # 下采样
特点 4: Dropout 机制
  • 在每个稠密层之后可以使用 Dropout 来减少过拟合风险,增加了模型的泛化能力。
if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
特点 5: Globally Average Pooling
  • 最终使用全局平均池化,这有助于减少模型的参数并提取全局特征,特别是在分类任务中表现更好。
3. 总结

DenseNet 在结构设计和数据流动态上与传统的卷积神经网络(如 ResNet)有显著差异,具体包括:

  • 特征重用:通过所有层的特征连接,增强了特征的流动和重用。
  • 瓶颈结构:减小了每个 DenseLayer 的计算负担,提高了计算效率。
  • 稠密块和过渡层:通过更细致地控制特征图的生长,保持了模型的紧凑和高效。
  • Dropout 和全局平均池化:增强了模型的泛化能力并减少了参数数量。

这些特性使 DenseNet 在处理复杂视觉任务时具有更好的性能,并在一定程度上克服了深度网络常见的退化和过拟合问题。

新的模型框架

探索将 ResNet 和 DenseNet 结合的可能性,是深度学习模型设计中的一个有趣方向。两者各自在不同的特性和优势,为我们提供了独特的构建块。我们可以尝试创建一个新的模型框架,结合这两者的优点,提升模型的性能。

1. 理解 ResNet 和 DenseNet 的关键特点
ResNet 的关键特点
  • 残差连接:通过直接连接输入和输出,允许梯度流过深层网络,缓解梯度消失问题。
  • 简化结构:在架构上相对简单,通过较少的层数实现深度,便于训练。
DenseNet 的关键特点
  • 特征重用:每一层利用之前所有层的特征图,使得网络更加高效和有效,增强了特征的传递能力。
  • 密集连接:与前一层直接相连,促进梯度流动和信息流动。
2. 设计新的模型框架

结合 ResNet 和 DenseNet 的特性,我们可以考虑以下几个关键设计原则:

  1. 引入残差连接:在 DenseNet 的每个层之间保留残差连接,以使训练过程更稳定。
  2. 特征重用与密集连接相结合:在模型内部使用 DenseNet 的特征重用机制,同时保持 ResNet 的残差连接。
  3. 选择合适的转换块:在 Dense Block 和 Residual Block 之间逐步过渡,以便控制特征图的维度和复杂度。
3. 新模型框架的构建

下面是一个新的混合模型框架的代码实现示例,结合了 ResNet 和 DenseNet 的特性。

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDictclass DenseResLayer(nn.Module):"""A layer that combines DenseNet and ResNet features."""def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):super(DenseResLayer, self).__init__()# Pre-activation Blockself.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)# Dense Blockself.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, bias=False)self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)self.drop_rate = drop_ratedef forward(self, x):# First part of the block (DenseNet)out = self.norm1(x)out = self.relu1(out)out = self.conv1(out)out = self.norm2(out)out = self.relu2(out)out = self.conv2(out)# Adding Residual Connectionout += xif self.drop_rate > 0:out = F.dropout(out, p=self.drop_rate, training=self.training)return outclass DenseResNet(nn.Module):"""A new Dense-Residual network model."""def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,bn_size=4, drop_rate=0, num_classes=1000):super(DenseResNet, self).__init__()# Initial Conv Layerself.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1))]))num_features = num_init_featuresfor i, num_layers in enumerate(block_config):for j in range(num_layers):layer = DenseResLayer(num_features, growth_rate, bn_size, drop_rate)self.features.add_module("denselayer%d-%d" % (i + 1, j + 1), layer)num_features += growth_rate  # Update feature number# Final BatchNorm + ReLUself.features.add_module("norm_final", nn.BatchNorm2d(num_features))self.features.add_module("relu_final", nn.ReLU(inplace=True))# Classification Layerself.classifier = nn.Linear(num_features, num_classes)def forward(self, x):features = self.features(x)out = F.avg_pool2d(features, features.size(2)).view(features.size(0), -1)  # Global Average Poolingout = self.classifier(out)return out
4. 关键点和总结
关键点
  1. DenseResLayer:

    • 结合了 DenseNet 的特征重用机制(通过添加特征)和 ResNet 的残差连接。
    • 使用 Batch Normalization 和 ReLU 激活确保梯度流动和训练稳定性。
  2. DenseResNet:

    • 在网络构建中,使用多层稠密残差层,将 ResNet 的残差连接与 Dense Block 的特征连接相结合。
    • 通过堆叠多个这种层,模型集成了 DenseNet 和 ResNet 的优势。
  3. 灵活性和扩展性

    • 可以通过调整 growth_rateblock_config 和其他超参数来创建不同深度和宽度的网络,来适应不同任务和数据集的需求。
5. 未来的方向

通过结合 ResNet 和 DenseNet 的特性,我们获得了一种新的模型框架,理论上会在不同的任务上表现良好。在实际应用中,可以通过以下方式进一步优化和评估:

  • 超参数调整:根据具体任务,在训练前进行超参数优化,以获得更好的性能。
  • 扩展到更深的网络:在具有足够计算资源的情况下,尝试构建更深的网络,评估深度增加对性能的影响。
  • 多任务学习或迁移学习:评估该框架在多个计算机视觉任务上的有效性,或在预训练模型的基础上进行微调。

结合 ResNet 和 DenseNet 的方法,是一种深层网络结构设计的探索,能够更好地利用特征并提升模型表现。希望这个新的模型框架对你的研究和应用有所帮助!

二:代码流程

0. 总结

数据导入及处理部分:本次数据导入没有使用torchvision自带的数据集,需要将原始数据进行处理包括数据导入,查看数据分类情况,定义transforms,进行数据类型转换等操作。

划分数据集:划定训练集测试集后,再使用torch.utils.data中的DataLoader()分别加载上一步处理好的训练及测试数据,查看批处理维度.

模型构建部分:resdesnet

设置超参数:在这之前需要定义损失函数,学习率(动态学习率),以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的,一般是默认放在反向传播之前。

定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

结果可视化

模型的保存,调取及使用。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。

需要改进优化的地方:确保模型和数据的一致性,都存到GPU或者CPU;注意numclasses不要直接用默认的1000,需要根据实际数据集改进;实例化模型也要注意numclasses这个参数;此外注意测试模型需要用(3,224,224)3表示通道数,这和tensorflow定义的顺序是不用的(224,224,3),做代码转换时需要注意。

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.nn.functional as F
from collections import OrderedDict import os,PIL,pathlib
import matplotlib.pyplot as plt
import warningswarnings.filterwarnings('ignore') # 忽略警告信息plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False   # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 # 分辨率

1. 设置GPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

2. 导入数据及处理部分

# 获取数据分布情况
path_dir = './data/bird_photos/'
path_dir = pathlib.Path(path_dir)paths = list(path_dir.glob('*'))
# classNames = [str(path).split("\\")[-1] for path in paths] # ['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
classNames = [path.parts[-1] for path in paths]
classNames
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
# 定义transforms 并处理数据
train_transforms = transforms.Compose([transforms.Resize([224,224]),      # 将输入图片resize成统一尺寸transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),             # 将PIL Image 或 numpy.ndarray 装换为tensor,并归一化到[0,1]之间transforms.Normalize(              # 标准化处理 --> 转换为标准正太分布(高斯分布),使模型更容易收敛mean = [0.485,0.456,0.406],    # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。std = [0.229,0.224,0.225])
])
test_transforms = transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize(mean = [0.485,0.456,0.406],std = [0.229,0.224,0.225])
])
total_data = datasets.ImageFolder('./data/bird_photos/',transform = train_transforms)
total_data
Dataset ImageFolderNumber of datapoints: 565Root location: ./data/bird_photos/StandardTransform
Transform: Compose(Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)RandomHorizontalFlip(p=0.5)ToTensor()Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
total_data.class_to_idx
{'Bananaquit': 0,'Black Skimmer': 1,'Black Throated Bushtiti': 2,'Cockatoo': 3}

3. 划分数据集

# 划分数据集
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_sizetrain_dataset,test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x24b5c8a54e0>,<torch.utils.data.dataset.Subset at 0x24b5c8a5720>)
# 定义DataLoader用于数据集的加载batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle = True,num_workers = 1
)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size,shuffle = True,num_workers = 1
)
# 观察数据维度
for X,y in test_dl:print("Shape of X [N,C,H,W]: ",X.shape)print("Shape of y: ", y.shape,y.dtype)break
Shape of X [N,C,H,W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

4. 模型构建部分

class _DenseLayer(nn.Sequential):"""Basic unit of DenseBlock (using bottleneck layer) """def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):super(_DenseLayer, self).__init__()self.add_module("norm1", nn.BatchNorm2d(num_input_features))self.add_module("relu1", nn.ReLU(inplace=True))self.add_module("conv1", nn.Conv2d(num_input_features, bn_size*growth_rate,kernel_size=1, stride=1, bias=False))self.add_module("norm2", nn.BatchNorm2d(bn_size*growth_rate))self.add_module("relu2", nn.ReLU(inplace=True))self.add_module("conv2", nn.Conv2d(bn_size*growth_rate, growth_rate,kernel_size=3, stride=1, padding=1, bias=False))self.drop_rate = drop_ratedef forward(self, x):new_features = super(_DenseLayer, self).forward(x)if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return torch.cat([x, new_features], 1)class _DenseBlock(nn.Sequential):"""DenseBlock"""def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):super(_DenseBlock, self).__init__()for i in range(num_layers):layer = _DenseLayer(num_input_features+i*growth_rate, growth_rate, bn_size,drop_rate)self.add_module("denselayer%d" % (i+1,), layer)class _Transition(nn.Sequential):"""Transition layer between two adjacent DenseBlock"""def __init__(self, num_input_feature, num_output_features):super(_Transition, self).__init__()self.add_module("norm", nn.BatchNorm2d(num_input_feature))self.add_module("relu", nn.ReLU(inplace=True))self.add_module("conv", nn.Conv2d(num_input_feature, num_output_features,kernel_size=1, stride=1, bias=False))self.add_module("pool", nn.AvgPool2d(2, stride=2))class DenseNet(nn.Module):"DenseNet-BC model"def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=1000):""":param growth_rate: (int) number of filters used in DenseLayer, `k` in the paper:param block_config: (list of 4 ints) number of layers in each DenseBlock:param num_init_features: (int) number of filters in the first Conv2d:param bn_size: (int) the factor using in the bottleneck layer:param compression_rate: (float) the compression rate used in Transition Layer:param drop_rate: (float) the drop rate after each DenseLayer:param num_classes: (int) number of classes for classification"""super(DenseNet, self).__init__()# first Conv2dself.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(3, stride=2, padding=1))]))# DenseBlocknum_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)self.features.add_module("denseblock%d" % (i + 1), block)num_features += num_layers*growth_rateif i != len(block_config) - 1:transition = _Transition(num_features, int(num_features*compression_rate))self.features.add_module("transition%d" % (i + 1), transition)num_features = int(num_features * compression_rate)# final bn+ReLUself.features.add_module("norm5", nn.BatchNorm2d(num_features))self.features.add_module("relu5", nn.ReLU(inplace=True))# classification layerself.classifier = nn.Linear(num_features, num_classes)# params initializationfor m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)out = self.classifier(out)return out
# Now, instantiate and use the model
densenet121 = DenseNet(num_init_features=64, # init_channel=64,growth_rate=32,block_config=(6,12,24,16),num_classes=len(classNames))  model = densenet121.to(device)
model
DenseNet((features): Sequential((conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu0): ReLU(inplace=True)(pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(denseblock1): _DenseBlock((denselayer1): _DenseLayer((norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer2): _DenseLayer((norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer3): _DenseLayer((norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer4): _DenseLayer((norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer5): _DenseLayer((norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer6): _DenseLayer((norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)))(transition1): _Transition((norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(pool): AvgPool2d(kernel_size=2, stride=2, padding=0))(denseblock2): _DenseBlock((denselayer1): _DenseLayer((norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer2): _DenseLayer((norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer3): _DenseLayer((norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer4): _DenseLayer((norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer5): _DenseLayer((norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer6): _DenseLayer((norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer7): _DenseLayer((norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer8): _DenseLayer((norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer9): _DenseLayer((norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer10): _DenseLayer((norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer11): _DenseLayer((norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer12): _DenseLayer((norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)))(transition2): _Transition((norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(pool): AvgPool2d(kernel_size=2, stride=2, padding=0))(denseblock3): _DenseBlock((denselayer1): _DenseLayer((norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer2): _DenseLayer((norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer3): _DenseLayer((norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer4): _DenseLayer((norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer5): _DenseLayer((norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer6): _DenseLayer((norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer7): _DenseLayer((norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer8): _DenseLayer((norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer9): _DenseLayer((norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer10): _DenseLayer((norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer11): _DenseLayer((norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer12): _DenseLayer((norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer13): _DenseLayer((norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer14): _DenseLayer((norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer15): _DenseLayer((norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer16): _DenseLayer((norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer17): _DenseLayer((norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer18): _DenseLayer((norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer19): _DenseLayer((norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer20): _DenseLayer((norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer21): _DenseLayer((norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer22): _DenseLayer((norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer23): _DenseLayer((norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer24): _DenseLayer((norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)))(transition3): _Transition((norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(pool): AvgPool2d(kernel_size=2, stride=2, padding=0))(denseblock4): _DenseBlock((denselayer1): _DenseLayer((norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer2): _DenseLayer((norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer3): _DenseLayer((norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer4): _DenseLayer((norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer5): _DenseLayer((norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer6): _DenseLayer((norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer7): _DenseLayer((norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer8): _DenseLayer((norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer9): _DenseLayer((norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer10): _DenseLayer((norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer11): _DenseLayer((norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer12): _DenseLayer((norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer13): _DenseLayer((norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer14): _DenseLayer((norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer15): _DenseLayer((norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))(denselayer16): _DenseLayer((norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu1): ReLU(inplace=True)(conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu2): ReLU(inplace=True)(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)))(norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu5): ReLU(inplace=True))(classifier): Linear(in_features=1024, out_features=4, bias=True)
)
# 查看模型详情
import torchsummary as summary
summary.summary(model,(3,224,224))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 112, 112]           9,408BatchNorm2d-2         [-1, 64, 112, 112]             128ReLU-3         [-1, 64, 112, 112]               0MaxPool2d-4           [-1, 64, 56, 56]               0BatchNorm2d-5           [-1, 64, 56, 56]             128ReLU-6           [-1, 64, 56, 56]               0Conv2d-7          [-1, 128, 56, 56]           8,192BatchNorm2d-8          [-1, 128, 56, 56]             256ReLU-9          [-1, 128, 56, 56]               0Conv2d-10           [-1, 32, 56, 56]          36,864BatchNorm2d-11           [-1, 96, 56, 56]             192ReLU-12           [-1, 96, 56, 56]               0Conv2d-13          [-1, 128, 56, 56]          12,288BatchNorm2d-14          [-1, 128, 56, 56]             256ReLU-15          [-1, 128, 56, 56]               0Conv2d-16           [-1, 32, 56, 56]          36,864BatchNorm2d-17          [-1, 128, 56, 56]             256ReLU-18          [-1, 128, 56, 56]               0Conv2d-19          [-1, 128, 56, 56]          16,384BatchNorm2d-20          [-1, 128, 56, 56]             256ReLU-21          [-1, 128, 56, 56]               0Conv2d-22           [-1, 32, 56, 56]          36,864BatchNorm2d-23          [-1, 160, 56, 56]             320ReLU-24          [-1, 160, 56, 56]               0Conv2d-25          [-1, 128, 56, 56]          20,480BatchNorm2d-26          [-1, 128, 56, 56]             256ReLU-27          [-1, 128, 56, 56]               0Conv2d-28           [-1, 32, 56, 56]          36,864BatchNorm2d-29          [-1, 192, 56, 56]             384ReLU-30          [-1, 192, 56, 56]               0Conv2d-31          [-1, 128, 56, 56]          24,576BatchNorm2d-32          [-1, 128, 56, 56]             256ReLU-33          [-1, 128, 56, 56]               0Conv2d-34           [-1, 32, 56, 56]          36,864BatchNorm2d-35          [-1, 224, 56, 56]             448ReLU-36          [-1, 224, 56, 56]               0Conv2d-37          [-1, 128, 56, 56]          28,672BatchNorm2d-38          [-1, 128, 56, 56]             256ReLU-39          [-1, 128, 56, 56]               0Conv2d-40           [-1, 32, 56, 56]          36,864BatchNorm2d-41          [-1, 256, 56, 56]             512ReLU-42          [-1, 256, 56, 56]               0Conv2d-43          [-1, 128, 56, 56]          32,768AvgPool2d-44          [-1, 128, 28, 28]               0BatchNorm2d-45          [-1, 128, 28, 28]             256ReLU-46          [-1, 128, 28, 28]               0Conv2d-47          [-1, 128, 28, 28]          16,384BatchNorm2d-48          [-1, 128, 28, 28]             256ReLU-49          [-1, 128, 28, 28]               0Conv2d-50           [-1, 32, 28, 28]          36,864BatchNorm2d-51          [-1, 160, 28, 28]             320ReLU-52          [-1, 160, 28, 28]               0Conv2d-53          [-1, 128, 28, 28]          20,480BatchNorm2d-54          [-1, 128, 28, 28]             256ReLU-55          [-1, 128, 28, 28]               0Conv2d-56           [-1, 32, 28, 28]          36,864BatchNorm2d-57          [-1, 192, 28, 28]             384ReLU-58          [-1, 192, 28, 28]               0Conv2d-59          [-1, 128, 28, 28]          24,576BatchNorm2d-60          [-1, 128, 28, 28]             256ReLU-61          [-1, 128, 28, 28]               0Conv2d-62           [-1, 32, 28, 28]          36,864BatchNorm2d-63          [-1, 224, 28, 28]             448ReLU-64          [-1, 224, 28, 28]               0Conv2d-65          [-1, 128, 28, 28]          28,672BatchNorm2d-66          [-1, 128, 28, 28]             256ReLU-67          [-1, 128, 28, 28]               0Conv2d-68           [-1, 32, 28, 28]          36,864BatchNorm2d-69          [-1, 256, 28, 28]             512ReLU-70          [-1, 256, 28, 28]               0Conv2d-71          [-1, 128, 28, 28]          32,768BatchNorm2d-72          [-1, 128, 28, 28]             256ReLU-73          [-1, 128, 28, 28]               0Conv2d-74           [-1, 32, 28, 28]          36,864BatchNorm2d-75          [-1, 288, 28, 28]             576ReLU-76          [-1, 288, 28, 28]               0Conv2d-77          [-1, 128, 28, 28]          36,864BatchNorm2d-78          [-1, 128, 28, 28]             256ReLU-79          [-1, 128, 28, 28]               0Conv2d-80           [-1, 32, 28, 28]          36,864BatchNorm2d-81          [-1, 320, 28, 28]             640ReLU-82          [-1, 320, 28, 28]               0Conv2d-83          [-1, 128, 28, 28]          40,960BatchNorm2d-84          [-1, 128, 28, 28]             256ReLU-85          [-1, 128, 28, 28]               0Conv2d-86           [-1, 32, 28, 28]          36,864BatchNorm2d-87          [-1, 352, 28, 28]             704ReLU-88          [-1, 352, 28, 28]               0Conv2d-89          [-1, 128, 28, 28]          45,056BatchNorm2d-90          [-1, 128, 28, 28]             256ReLU-91          [-1, 128, 28, 28]               0Conv2d-92           [-1, 32, 28, 28]          36,864BatchNorm2d-93          [-1, 384, 28, 28]             768ReLU-94          [-1, 384, 28, 28]               0Conv2d-95          [-1, 128, 28, 28]          49,152BatchNorm2d-96          [-1, 128, 28, 28]             256ReLU-97          [-1, 128, 28, 28]               0Conv2d-98           [-1, 32, 28, 28]          36,864BatchNorm2d-99          [-1, 416, 28, 28]             832ReLU-100          [-1, 416, 28, 28]               0Conv2d-101          [-1, 128, 28, 28]          53,248BatchNorm2d-102          [-1, 128, 28, 28]             256ReLU-103          [-1, 128, 28, 28]               0Conv2d-104           [-1, 32, 28, 28]          36,864BatchNorm2d-105          [-1, 448, 28, 28]             896ReLU-106          [-1, 448, 28, 28]               0Conv2d-107          [-1, 128, 28, 28]          57,344BatchNorm2d-108          [-1, 128, 28, 28]             256ReLU-109          [-1, 128, 28, 28]               0Conv2d-110           [-1, 32, 28, 28]          36,864BatchNorm2d-111          [-1, 480, 28, 28]             960ReLU-112          [-1, 480, 28, 28]               0Conv2d-113          [-1, 128, 28, 28]          61,440BatchNorm2d-114          [-1, 128, 28, 28]             256ReLU-115          [-1, 128, 28, 28]               0Conv2d-116           [-1, 32, 28, 28]          36,864BatchNorm2d-117          [-1, 512, 28, 28]           1,024ReLU-118          [-1, 512, 28, 28]               0Conv2d-119          [-1, 256, 28, 28]         131,072AvgPool2d-120          [-1, 256, 14, 14]               0BatchNorm2d-121          [-1, 256, 14, 14]             512ReLU-122          [-1, 256, 14, 14]               0Conv2d-123          [-1, 128, 14, 14]          32,768BatchNorm2d-124          [-1, 128, 14, 14]             256ReLU-125          [-1, 128, 14, 14]               0Conv2d-126           [-1, 32, 14, 14]          36,864BatchNorm2d-127          [-1, 288, 14, 14]             576ReLU-128          [-1, 288, 14, 14]               0Conv2d-129          [-1, 128, 14, 14]          36,864BatchNorm2d-130          [-1, 128, 14, 14]             256ReLU-131          [-1, 128, 14, 14]               0Conv2d-132           [-1, 32, 14, 14]          36,864BatchNorm2d-133          [-1, 320, 14, 14]             640ReLU-134          [-1, 320, 14, 14]               0Conv2d-135          [-1, 128, 14, 14]          40,960BatchNorm2d-136          [-1, 128, 14, 14]             256ReLU-137          [-1, 128, 14, 14]               0Conv2d-138           [-1, 32, 14, 14]          36,864BatchNorm2d-139          [-1, 352, 14, 14]             704ReLU-140          [-1, 352, 14, 14]               0Conv2d-141          [-1, 128, 14, 14]          45,056BatchNorm2d-142          [-1, 128, 14, 14]             256ReLU-143          [-1, 128, 14, 14]               0Conv2d-144           [-1, 32, 14, 14]          36,864BatchNorm2d-145          [-1, 384, 14, 14]             768ReLU-146          [-1, 384, 14, 14]               0Conv2d-147          [-1, 128, 14, 14]          49,152BatchNorm2d-148          [-1, 128, 14, 14]             256ReLU-149          [-1, 128, 14, 14]               0Conv2d-150           [-1, 32, 14, 14]          36,864BatchNorm2d-151          [-1, 416, 14, 14]             832ReLU-152          [-1, 416, 14, 14]               0Conv2d-153          [-1, 128, 14, 14]          53,248BatchNorm2d-154          [-1, 128, 14, 14]             256ReLU-155          [-1, 128, 14, 14]               0Conv2d-156           [-1, 32, 14, 14]          36,864BatchNorm2d-157          [-1, 448, 14, 14]             896ReLU-158          [-1, 448, 14, 14]               0Conv2d-159          [-1, 128, 14, 14]          57,344BatchNorm2d-160          [-1, 128, 14, 14]             256ReLU-161          [-1, 128, 14, 14]               0Conv2d-162           [-1, 32, 14, 14]          36,864BatchNorm2d-163          [-1, 480, 14, 14]             960ReLU-164          [-1, 480, 14, 14]               0Conv2d-165          [-1, 128, 14, 14]          61,440BatchNorm2d-166          [-1, 128, 14, 14]             256ReLU-167          [-1, 128, 14, 14]               0Conv2d-168           [-1, 32, 14, 14]          36,864BatchNorm2d-169          [-1, 512, 14, 14]           1,024ReLU-170          [-1, 512, 14, 14]               0Conv2d-171          [-1, 128, 14, 14]          65,536BatchNorm2d-172          [-1, 128, 14, 14]             256ReLU-173          [-1, 128, 14, 14]               0Conv2d-174           [-1, 32, 14, 14]          36,864BatchNorm2d-175          [-1, 544, 14, 14]           1,088ReLU-176          [-1, 544, 14, 14]               0Conv2d-177          [-1, 128, 14, 14]          69,632BatchNorm2d-178          [-1, 128, 14, 14]             256ReLU-179          [-1, 128, 14, 14]               0Conv2d-180           [-1, 32, 14, 14]          36,864BatchNorm2d-181          [-1, 576, 14, 14]           1,152ReLU-182          [-1, 576, 14, 14]               0Conv2d-183          [-1, 128, 14, 14]          73,728BatchNorm2d-184          [-1, 128, 14, 14]             256ReLU-185          [-1, 128, 14, 14]               0Conv2d-186           [-1, 32, 14, 14]          36,864BatchNorm2d-187          [-1, 608, 14, 14]           1,216ReLU-188          [-1, 608, 14, 14]               0Conv2d-189          [-1, 128, 14, 14]          77,824BatchNorm2d-190          [-1, 128, 14, 14]             256ReLU-191          [-1, 128, 14, 14]               0Conv2d-192           [-1, 32, 14, 14]          36,864BatchNorm2d-193          [-1, 640, 14, 14]           1,280ReLU-194          [-1, 640, 14, 14]               0Conv2d-195          [-1, 128, 14, 14]          81,920BatchNorm2d-196          [-1, 128, 14, 14]             256ReLU-197          [-1, 128, 14, 14]               0Conv2d-198           [-1, 32, 14, 14]          36,864BatchNorm2d-199          [-1, 672, 14, 14]           1,344ReLU-200          [-1, 672, 14, 14]               0Conv2d-201          [-1, 128, 14, 14]          86,016BatchNorm2d-202          [-1, 128, 14, 14]             256ReLU-203          [-1, 128, 14, 14]               0Conv2d-204           [-1, 32, 14, 14]          36,864BatchNorm2d-205          [-1, 704, 14, 14]           1,408ReLU-206          [-1, 704, 14, 14]               0Conv2d-207          [-1, 128, 14, 14]          90,112BatchNorm2d-208          [-1, 128, 14, 14]             256ReLU-209          [-1, 128, 14, 14]               0Conv2d-210           [-1, 32, 14, 14]          36,864BatchNorm2d-211          [-1, 736, 14, 14]           1,472ReLU-212          [-1, 736, 14, 14]               0Conv2d-213          [-1, 128, 14, 14]          94,208BatchNorm2d-214          [-1, 128, 14, 14]             256ReLU-215          [-1, 128, 14, 14]               0Conv2d-216           [-1, 32, 14, 14]          36,864BatchNorm2d-217          [-1, 768, 14, 14]           1,536ReLU-218          [-1, 768, 14, 14]               0Conv2d-219          [-1, 128, 14, 14]          98,304BatchNorm2d-220          [-1, 128, 14, 14]             256ReLU-221          [-1, 128, 14, 14]               0Conv2d-222           [-1, 32, 14, 14]          36,864BatchNorm2d-223          [-1, 800, 14, 14]           1,600ReLU-224          [-1, 800, 14, 14]               0Conv2d-225          [-1, 128, 14, 14]         102,400BatchNorm2d-226          [-1, 128, 14, 14]             256ReLU-227          [-1, 128, 14, 14]               0Conv2d-228           [-1, 32, 14, 14]          36,864BatchNorm2d-229          [-1, 832, 14, 14]           1,664ReLU-230          [-1, 832, 14, 14]               0Conv2d-231          [-1, 128, 14, 14]         106,496BatchNorm2d-232          [-1, 128, 14, 14]             256ReLU-233          [-1, 128, 14, 14]               0Conv2d-234           [-1, 32, 14, 14]          36,864BatchNorm2d-235          [-1, 864, 14, 14]           1,728ReLU-236          [-1, 864, 14, 14]               0Conv2d-237          [-1, 128, 14, 14]         110,592BatchNorm2d-238          [-1, 128, 14, 14]             256ReLU-239          [-1, 128, 14, 14]               0Conv2d-240           [-1, 32, 14, 14]          36,864BatchNorm2d-241          [-1, 896, 14, 14]           1,792ReLU-242          [-1, 896, 14, 14]               0Conv2d-243          [-1, 128, 14, 14]         114,688BatchNorm2d-244          [-1, 128, 14, 14]             256ReLU-245          [-1, 128, 14, 14]               0Conv2d-246           [-1, 32, 14, 14]          36,864BatchNorm2d-247          [-1, 928, 14, 14]           1,856ReLU-248          [-1, 928, 14, 14]               0Conv2d-249          [-1, 128, 14, 14]         118,784BatchNorm2d-250          [-1, 128, 14, 14]             256ReLU-251          [-1, 128, 14, 14]               0Conv2d-252           [-1, 32, 14, 14]          36,864BatchNorm2d-253          [-1, 960, 14, 14]           1,920ReLU-254          [-1, 960, 14, 14]               0Conv2d-255          [-1, 128, 14, 14]         122,880BatchNorm2d-256          [-1, 128, 14, 14]             256ReLU-257          [-1, 128, 14, 14]               0Conv2d-258           [-1, 32, 14, 14]          36,864BatchNorm2d-259          [-1, 992, 14, 14]           1,984ReLU-260          [-1, 992, 14, 14]               0Conv2d-261          [-1, 128, 14, 14]         126,976BatchNorm2d-262          [-1, 128, 14, 14]             256ReLU-263          [-1, 128, 14, 14]               0Conv2d-264           [-1, 32, 14, 14]          36,864BatchNorm2d-265         [-1, 1024, 14, 14]           2,048ReLU-266         [-1, 1024, 14, 14]               0Conv2d-267          [-1, 512, 14, 14]         524,288AvgPool2d-268            [-1, 512, 7, 7]               0BatchNorm2d-269            [-1, 512, 7, 7]           1,024ReLU-270            [-1, 512, 7, 7]               0Conv2d-271            [-1, 128, 7, 7]          65,536BatchNorm2d-272            [-1, 128, 7, 7]             256ReLU-273            [-1, 128, 7, 7]               0Conv2d-274             [-1, 32, 7, 7]          36,864BatchNorm2d-275            [-1, 544, 7, 7]           1,088ReLU-276            [-1, 544, 7, 7]               0Conv2d-277            [-1, 128, 7, 7]          69,632BatchNorm2d-278            [-1, 128, 7, 7]             256ReLU-279            [-1, 128, 7, 7]               0Conv2d-280             [-1, 32, 7, 7]          36,864BatchNorm2d-281            [-1, 576, 7, 7]           1,152ReLU-282            [-1, 576, 7, 7]               0Conv2d-283            [-1, 128, 7, 7]          73,728BatchNorm2d-284            [-1, 128, 7, 7]             256ReLU-285            [-1, 128, 7, 7]               0Conv2d-286             [-1, 32, 7, 7]          36,864BatchNorm2d-287            [-1, 608, 7, 7]           1,216ReLU-288            [-1, 608, 7, 7]               0Conv2d-289            [-1, 128, 7, 7]          77,824BatchNorm2d-290            [-1, 128, 7, 7]             256ReLU-291            [-1, 128, 7, 7]               0Conv2d-292             [-1, 32, 7, 7]          36,864BatchNorm2d-293            [-1, 640, 7, 7]           1,280ReLU-294            [-1, 640, 7, 7]               0Conv2d-295            [-1, 128, 7, 7]          81,920BatchNorm2d-296            [-1, 128, 7, 7]             256ReLU-297            [-1, 128, 7, 7]               0Conv2d-298             [-1, 32, 7, 7]          36,864BatchNorm2d-299            [-1, 672, 7, 7]           1,344ReLU-300            [-1, 672, 7, 7]               0Conv2d-301            [-1, 128, 7, 7]          86,016BatchNorm2d-302            [-1, 128, 7, 7]             256ReLU-303            [-1, 128, 7, 7]               0Conv2d-304             [-1, 32, 7, 7]          36,864BatchNorm2d-305            [-1, 704, 7, 7]           1,408ReLU-306            [-1, 704, 7, 7]               0Conv2d-307            [-1, 128, 7, 7]          90,112BatchNorm2d-308            [-1, 128, 7, 7]             256ReLU-309            [-1, 128, 7, 7]               0Conv2d-310             [-1, 32, 7, 7]          36,864BatchNorm2d-311            [-1, 736, 7, 7]           1,472ReLU-312            [-1, 736, 7, 7]               0Conv2d-313            [-1, 128, 7, 7]          94,208BatchNorm2d-314            [-1, 128, 7, 7]             256ReLU-315            [-1, 128, 7, 7]               0Conv2d-316             [-1, 32, 7, 7]          36,864BatchNorm2d-317            [-1, 768, 7, 7]           1,536ReLU-318            [-1, 768, 7, 7]               0Conv2d-319            [-1, 128, 7, 7]          98,304BatchNorm2d-320            [-1, 128, 7, 7]             256ReLU-321            [-1, 128, 7, 7]               0Conv2d-322             [-1, 32, 7, 7]          36,864BatchNorm2d-323            [-1, 800, 7, 7]           1,600ReLU-324            [-1, 800, 7, 7]               0Conv2d-325            [-1, 128, 7, 7]         102,400BatchNorm2d-326            [-1, 128, 7, 7]             256ReLU-327            [-1, 128, 7, 7]               0Conv2d-328             [-1, 32, 7, 7]          36,864BatchNorm2d-329            [-1, 832, 7, 7]           1,664ReLU-330            [-1, 832, 7, 7]               0Conv2d-331            [-1, 128, 7, 7]         106,496BatchNorm2d-332            [-1, 128, 7, 7]             256ReLU-333            [-1, 128, 7, 7]               0Conv2d-334             [-1, 32, 7, 7]          36,864BatchNorm2d-335            [-1, 864, 7, 7]           1,728ReLU-336            [-1, 864, 7, 7]               0Conv2d-337            [-1, 128, 7, 7]         110,592BatchNorm2d-338            [-1, 128, 7, 7]             256ReLU-339            [-1, 128, 7, 7]               0Conv2d-340             [-1, 32, 7, 7]          36,864BatchNorm2d-341            [-1, 896, 7, 7]           1,792ReLU-342            [-1, 896, 7, 7]               0Conv2d-343            [-1, 128, 7, 7]         114,688BatchNorm2d-344            [-1, 128, 7, 7]             256ReLU-345            [-1, 128, 7, 7]               0Conv2d-346             [-1, 32, 7, 7]          36,864BatchNorm2d-347            [-1, 928, 7, 7]           1,856ReLU-348            [-1, 928, 7, 7]               0Conv2d-349            [-1, 128, 7, 7]         118,784BatchNorm2d-350            [-1, 128, 7, 7]             256ReLU-351            [-1, 128, 7, 7]               0Conv2d-352             [-1, 32, 7, 7]          36,864BatchNorm2d-353            [-1, 960, 7, 7]           1,920ReLU-354            [-1, 960, 7, 7]               0Conv2d-355            [-1, 128, 7, 7]         122,880BatchNorm2d-356            [-1, 128, 7, 7]             256ReLU-357            [-1, 128, 7, 7]               0Conv2d-358             [-1, 32, 7, 7]          36,864BatchNorm2d-359            [-1, 992, 7, 7]           1,984ReLU-360            [-1, 992, 7, 7]               0Conv2d-361            [-1, 128, 7, 7]         126,976BatchNorm2d-362            [-1, 128, 7, 7]             256ReLU-363            [-1, 128, 7, 7]               0Conv2d-364             [-1, 32, 7, 7]          36,864BatchNorm2d-365           [-1, 1024, 7, 7]           2,048ReLU-366           [-1, 1024, 7, 7]               0Linear-367                    [-1, 4]           4,100
================================================================
Total params: 6,957,956
Trainable params: 6,957,956
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.57
Params size (MB): 26.54
Estimated Total Size (MB): 321.69
----------------------------------------------------------------

5. 设置超参数:定义损失函数,学习率,以及根据学习率定义优化器等

# loss_fn = nn.CrossEntropyLoss() # 创建损失函数# learn_rate = 1e-3 # 初始学习率
# def adjust_learning_rate(optimizer,epoch,start_lr):
#     # 每两个epoch 衰减到原来的0.98
#     lr = start_lr * (0.92 ** (epoch//2))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr# optimizer = torch.optim.Adam(model.parameters(),lr=learn_rate)
# 调用官方接口示例
loss_fn = nn.CrossEntropyLoss()learn_rate = 1e-4
lambda1 = lambda epoch:(0.92**(epoch//2))optimizer = torch.optim.Adam(model.parameters(),lr = learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1) # 选定调整方法

6. 训练函数

# 训练函数
def train(dataloader,model,loss_fn,optimizer):size = len(dataloader.dataset) # 训练集大小num_batches = len(dataloader) # 批次数目train_loss,train_acc = 0,0for X,y in dataloader:X,y = X.to(device),y.to(device)# 计算预测误差pred = model(X)loss = loss_fn(pred,y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录acc与losstrain_acc += (pred.argmax(1)==y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss

7. 测试函数

# 测试函数
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc,test_loss = 0,0with torch.no_grad():for X,y in dataloader:X,y = X.to(device),y.to(device)# 计算losspred = model(X)loss = loss_fn(pred,y)test_acc += (pred.argmax(1)==y).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc,test_loss

8. 正式训练

import copyepochs = 40train_acc = []
train_loss = []
test_acc = []
test_loss = []best_acc = 0.0for epoch in range(epochs):# 更新学习率——使用自定义学习率时使用# adjust_learning_rate(optimizer,epoch,learn_rate)model.train()epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,optimizer)scheduler.step() # 更新学习率——调用官方动态学习率时使用model.eval()epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print('Done')
Epoch: 1,Train_acc:57.5%,Train_loss:1.125,Test_acc:67.3%,Test_loss:0.976,Lr:1.00E-04
Epoch: 2,Train_acc:74.8%,Train_loss:0.812,Test_acc:78.8%,Test_loss:0.694,Lr:9.20E-05
Epoch: 3,Train_acc:81.9%,Train_loss:0.677,Test_acc:78.8%,Test_loss:0.528,Lr:9.20E-05
Epoch: 4,Train_acc:86.5%,Train_loss:0.463,Test_acc:84.1%,Test_loss:0.457,Lr:8.46E-05
Epoch: 5,Train_acc:86.5%,Train_loss:0.427,Test_acc:90.3%,Test_loss:0.311,Lr:8.46E-05
Epoch: 6,Train_acc:90.5%,Train_loss:0.350,Test_acc:89.4%,Test_loss:0.417,Lr:7.79E-05
Epoch: 7,Train_acc:92.7%,Train_loss:0.287,Test_acc:85.8%,Test_loss:0.303,Lr:7.79E-05
Epoch: 8,Train_acc:92.0%,Train_loss:0.339,Test_acc:90.3%,Test_loss:0.313,Lr:7.16E-05
Epoch: 9,Train_acc:93.1%,Train_loss:0.231,Test_acc:92.0%,Test_loss:0.243,Lr:7.16E-05
Epoch:10,Train_acc:94.9%,Train_loss:0.286,Test_acc:90.3%,Test_loss:0.285,Lr:6.59E-05
Epoch:11,Train_acc:95.4%,Train_loss:0.185,Test_acc:88.5%,Test_loss:0.334,Lr:6.59E-05
Epoch:12,Train_acc:95.6%,Train_loss:0.151,Test_acc:92.9%,Test_loss:0.226,Lr:6.06E-05
Epoch:13,Train_acc:96.9%,Train_loss:0.118,Test_acc:91.2%,Test_loss:0.263,Lr:6.06E-05
Epoch:14,Train_acc:98.5%,Train_loss:0.114,Test_acc:94.7%,Test_loss:0.177,Lr:5.58E-05
Epoch:15,Train_acc:98.0%,Train_loss:0.145,Test_acc:92.9%,Test_loss:0.244,Lr:5.58E-05
Epoch:16,Train_acc:96.9%,Train_loss:0.126,Test_acc:94.7%,Test_loss:0.200,Lr:5.13E-05
Epoch:17,Train_acc:97.8%,Train_loss:0.196,Test_acc:93.8%,Test_loss:0.166,Lr:5.13E-05
Epoch:18,Train_acc:97.6%,Train_loss:0.167,Test_acc:93.8%,Test_loss:0.236,Lr:4.72E-05
Epoch:19,Train_acc:98.0%,Train_loss:0.102,Test_acc:92.9%,Test_loss:0.258,Lr:4.72E-05
Epoch:20,Train_acc:98.7%,Train_loss:0.093,Test_acc:93.8%,Test_loss:0.207,Lr:4.34E-05
Epoch:21,Train_acc:98.9%,Train_loss:0.070,Test_acc:96.5%,Test_loss:0.180,Lr:4.34E-05
Epoch:22,Train_acc:99.1%,Train_loss:0.115,Test_acc:93.8%,Test_loss:0.213,Lr:4.00E-05
Epoch:23,Train_acc:99.3%,Train_loss:0.051,Test_acc:92.9%,Test_loss:0.241,Lr:4.00E-05
Epoch:24,Train_acc:99.1%,Train_loss:0.064,Test_acc:92.9%,Test_loss:0.201,Lr:3.68E-05
Epoch:25,Train_acc:99.3%,Train_loss:0.070,Test_acc:91.2%,Test_loss:0.313,Lr:3.68E-05
Epoch:26,Train_acc:98.5%,Train_loss:0.083,Test_acc:92.0%,Test_loss:0.288,Lr:3.38E-05
Epoch:27,Train_acc:99.1%,Train_loss:0.125,Test_acc:97.3%,Test_loss:0.146,Lr:3.38E-05
Epoch:28,Train_acc:99.3%,Train_loss:0.109,Test_acc:94.7%,Test_loss:0.259,Lr:3.11E-05
Epoch:29,Train_acc:98.7%,Train_loss:0.073,Test_acc:92.0%,Test_loss:0.268,Lr:3.11E-05
Epoch:30,Train_acc:99.6%,Train_loss:0.097,Test_acc:91.2%,Test_loss:0.254,Lr:2.86E-05
Epoch:31,Train_acc:98.9%,Train_loss:0.083,Test_acc:90.3%,Test_loss:0.255,Lr:2.86E-05
Epoch:32,Train_acc:99.1%,Train_loss:0.058,Test_acc:90.3%,Test_loss:0.326,Lr:2.63E-05
Epoch:33,Train_acc:99.6%,Train_loss:0.042,Test_acc:92.9%,Test_loss:0.155,Lr:2.63E-05
Epoch:34,Train_acc:100.0%,Train_loss:0.037,Test_acc:94.7%,Test_loss:0.157,Lr:2.42E-05
Epoch:35,Train_acc:99.1%,Train_loss:0.140,Test_acc:97.3%,Test_loss:0.144,Lr:2.42E-05
Epoch:36,Train_acc:99.3%,Train_loss:0.059,Test_acc:93.8%,Test_loss:0.174,Lr:2.23E-05
Epoch:37,Train_acc:99.8%,Train_loss:0.038,Test_acc:91.2%,Test_loss:0.215,Lr:2.23E-05
Epoch:38,Train_acc:99.8%,Train_loss:0.045,Test_acc:92.9%,Test_loss:0.163,Lr:2.05E-05
Epoch:39,Train_acc:99.1%,Train_loss:0.076,Test_acc:93.8%,Test_loss:0.185,Lr:2.05E-05
Epoch:40,Train_acc:98.9%,Train_loss:0.090,Test_acc:93.8%,Test_loss:0.176,Lr:1.89E-05
Done

9. 结果可视化

epochs_range = range(epochs)plt.figure(figsize = (12,3))plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label = 'Training Accuracy')
plt.plot(epochs_range,test_acc,label = 'Test Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label = 'Test Accuracy')
plt.plot(epochs_range,test_loss,label = 'Test Loss')
plt.legend(loc = 'lower right')
plt.title('Training and validation Loss')
plt.show()

在这里插入图片描述

10. 模型的保存

# 自定义模型保存
# 状态字典保存
torch.save(model.state_dict(),'./模型参数/J3_densenet121_model_state_dict.pth') # 仅保存状态字典# 定义模型用来加载参数best_model = DenseNet(num_init_features=64, # init_channel=64,growth_rate=32,block_config=(6,12,24,16),num_classes=len(classNames)).to(device)best_model.load_state_dict(torch.load('./模型参数/J3_densenet121_model_state_dict.pth')) # 加载状态字典到模型
<All keys matched successfully>

11. 使用训练好的模型进行预测

# 指定路径图片预测
from PIL import Image
import torchvision.transforms as transformsclasses = list(total_data.class_to_idx) # classes = list(total_data.class_to_idx)def predict_one_image(image_path,model,transform,classes):test_img = Image.open(image_path).convert('RGB')# plt.imshow(test_img) # 展示待预测的图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)print(output) # 观察模型预测结果的输出数据_,pred = torch.max(output,1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')
# 预测训练集中的某张照片
predict_one_image(image_path='./data/bird_photos/Bananaquit/007.jpg',model = model,transform = test_transforms,classes = classes)
tensor([[ 5.3812, -3.3856, -0.2259, -2.2854]], device='cuda:0',grad_fn=<AddmmBackward0>)
预测结果是:Bananaquit
classes
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/57685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【天线&空中农业】作物病害检测系统源码&数据集全套:改进yolo11-EfficientFormerV2

改进yolo11-attention等200全套创新点大全&#xff1a;作物病害检测系统源码&#xff06;数据集全套 1.图片效果展示 项目来源 人工智能促进会 2024.11.01 注意&#xff1a;由于项目一直在更新迭代&#xff0c;上面“1.图片效果展示”和“2.视频效果展示”展示的系统图片或者…

Linux版更新流程

一.下载更新包 下载地址&#xff1a;https://www.nvisual.com/%e4%b8%8b%e8%bd%bd/ 二.更新包组成 更新包由三部分组成&#xff1a; 前端更新包&#xff1a;压缩的ZIP文件&#xff0c;例如&#xff1a;dist-2.2.26-20231227.zip (2.2.26是版本号 20231227是发布日期)后端更…

c++仿函数--通俗易懂

1.仿函数是什么 仿函数也叫函数对象&#xff0c;是一种可以像函数一样被调用的对象。从编程实现的角度看&#xff0c;它是一个类&#xff0c;不过这个类重载了函数调用运算符() class Add { public:int operator()(int a, int b) {return a b;} }; 注意&#xff1a;使用的时…

《中安证件阅读机:边检执法办案的得力助手》

在边检执法办案的过程中&#xff0c;高效、准确地识别和查验各类证件至关重要。而中安证件阅读机的出现&#xff0c;为边检工作带来了极大的便利&#xff0c;成为了边检执法人员的得力助手。 一、中安证件阅读机的强大功能 中安证件阅读机具备先进的技术和丰富的功能。它能够快…

计算机网络:网络层 —— IP数据报的发送和转发过程

文章目录 IP数据报的发送和转发过程主机发送IP数据报路由器转发IP数据报示例 IP数据报的发送和转发过程 IP 数据报的发送和转发过程包含以下两个过程&#xff1a; 主机发送IP数据报路由器转发IP数据报 直接交付&#xff1a;源主机与目的主机在同一网络中间接交付&#xff1a;…

104. UE5 GAS RPG 实现技能火焰爆炸

这一篇文章我们再实现一个技能火焰爆炸&#xff0c;由于我们之前已经实现了三个玩家技能&#xff0c;这一个技能有一些总结的味道&#xff0c;对于创建技能相同的部分&#xff0c;长话短说&#xff0c;我们过一遍。 准备工作 我们需要一个技能类&#xff0c;继承于伤害技能基…

【C语言】动态内存开辟

写在前面 C语言中有不少开辟空间的办法&#xff0c;但是在堆上开辟的方法也就只有动态内存开辟&#xff0c;其访问特性与数组相似&#xff0c;但最大区别是数组是开辟在栈上&#xff0c;而动态内存开辟是开辟在堆上的。这篇笔记就让不才娓娓道来。 PS:本篇没有目录实在抱歉CSD…

Excel:vba实现插入图片

实现的效果&#xff1a; 实现的代码&#xff1a; Sub InsertImageNamesAndPictures()Dim PicPath As StringDim PicName As StringDim PicFullPath As StringDim RowNum As IntegerDim Pic As ObjectDim Name As String 防止表格里面有脏数据Cells.Clear 遍历工作表中的每个图…

6.FreeRTOS之任务通知

什么是任务通知&#xff1f; FreeRTOS 从版本 V8.2.0 开始提供任务通知这个功能&#xff0c;每个任务都有一个 32 位的通知值。按照 FreeRTOS 官方的说法&#xff0c;使用消息通知比通过二进制信号量方式解除阻塞任务快 45% &#xff0c; 并且更加 省内存&#xff08;无需创…

前端之html(一)

HTML定义: HTML 超文本标记语言 (1)骨架: HTML:整个网页 head:网页头部 boby:网页主体 title:网页标题 (2)标签关系: 1.嵌套 2.并列 (3)注释 语法:<!-- ... --> 基础: (4) 标签:双标签:<> ... </> 单标签:<> <br> …

书生第四期实训营基础岛——L1G3000浦语提示词工程实践

基础任务 任务要求 背景问题&#xff1a;近期相关研究指出&#xff0c;在处理特定文本分析任务时&#xff0c;语言模型的表现有时会遇到挑战&#xff0c;例如在分析单词内部的具体字母数量时可能会出现错误。任务要求&#xff1a;利用对提示词的精确设计&#xff0c;引导语言…

Android启动流程_SystemServer阶段

前言 上一篇文档我们描述了在 Android 启动流程中 Zygote 部分的内容&#xff0c;从 Zygote 的配置、启动、初始化等内容展开&#xff0c;描述了 Zygote 在 Android 启动中的功能逻辑。本篇文档将会继续 Android 启动流程的描述&#xff0c;从 SystemServer 进程的内容展开&am…

Flutter CustomScrollView 效果-顶栏透明与标签栏吸顶

CustomScrollView 效果 1. 关键组件 CustomScrollView, SliverOverlapAbsorber, SliverPersistentHeader 2. 关键内容 TLDR SliverOverlapAbsorber 包住 pinned为 true 的组件 可以被CustomScrollView 忽略高度。 以下的全部内容的都为了阐述上面这句话。初阶 Flutter 开发知…

Litctf-web

Litctf-web exx xxe&#xff0c; <?xml version"1.0" encoding"utf-8"?> <!DOCTYPE xxe [<!ELEMENT name ANY ><!ENTITY xxe SYSTEM "file:///flag" >]><user><username>&xxe;</username> …

线程模型介绍

线程模型的介绍 线程有三种模型&#xff1a;N:1用户线程模型&#xff0c;1:1核心线程模式&#xff0c;N:M混合线程模型 POSIX: Portable Operating System Interface(可移值操作系统接口) N&#xff1a;1用户线程模型 线程的实现建立在进程控制的机制之上&#xff0c;有用户…

2024 Rust现代实用教程:1.3获取rust的库国内源以及windows下的操作

文章目录 一、使用Cargo第三方库1.直接修改Cargo.toml2.使用cargo-edit插件3.设置国内源4.与windows下面的rust不同点 参考 一、使用Cargo第三方库 1.直接修改Cargo.toml rust语言的库&#xff1a;crate 黏贴至Cargo.toml 保存完毕之后&#xff0c;自动下载依赖 拷贝crat…

ML 系列:第 18 部 - 高级概率论:条件概率、随机变量和概率分布

文章目录 一、说明二、关于条件概率2.1 为什么我们说条件概率&#xff1f;2.2 为什么条件概率在统计学中很重要 三、 随机变量的定义3.1 定义3.2 条件概率中的随机变量 四、概率分布的定义五、结论 一、说明 条件概率是极其重要的概率概念&#xff0c;它是因果关系的数学表述&…

基于springboot的社区团购系统设计与实现

一、项目背景 网络交易&#xff08;Electronic Commerce&#xff09;&#xff1a;是指实现整个贸易过程中各阶段的贸易活动的电子化。网络交易是一种多技术的集合体。其业务可包括&#xff1a;信息交换、售后服务、销售、电子支付、运输、组建虚拟企业、公司和贸易伙伴可以共同…

一文读懂系列:SSL加密流量检测技术详解

SSL加密流量检测功能的主要目的是为了对加密流量做解密处理&#xff0c;并对解密后的流量做内容安全检查&#xff08;比如反病毒、入侵防御、URL远程查询、内容过滤、文件过滤和邮件过滤等&#xff09;和审计&#xff08;防止信息泄露&#xff09;。接下来我们详细介绍SSL加密流…