Pytorch | 从零构建ResNet对CIFAR10进行分类

Pytorch | 从零构建ResNet对CIFAR10进行分类

  • CIFAR10数据集
  • ResNet
      • 核心思想
      • 网络结构
      • 创新点
      • 优点
      • 应用
  • ResNet结构代码详解
    • 结构代码
    • 代码详解
      • BasicBlock 类
      • ResNet 类
      • ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数
  • 训练过程和测试结果
  • 代码汇总
    • resnet.py
    • train.py
    • test.py

前面文章我们构建了AlexNet、Vgg、GoogleNet对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
这篇文章我们来构建ResNet.

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

ResNet

ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,它在ILSVRC 2015图像识别挑战赛中取得了优异成绩,在图像分类、目标检测、语义分割等计算机视觉任务中具有广泛应用。以下是对ResNet的详细介绍:

核心思想

  • 解决梯度消失和退化问题:随着神经网络层数的增加,会出现梯度消失或梯度爆炸问题,导致模型难以训练。同时,还会出现网络退化现象,即增加网络层数后,准确率反而下降。ResNet的核心思想是引入残差连接(Residual Connection),通过跨层的shortcut连接,将输入直接传递到后面的层,使得后面的层可以学习到输入的残差,从而缓解了梯度消失和网络退化问题。

网络结构

  • 基本残差块:ResNet的基本组成单元是残差块(Residual Block)。一个典型的残差块包含两个3×3卷积层,中间有一个ReLU激活函数,并且在第二个卷积层之后也有一个ReLU激活函数。输入通过一个shortcut连接直接与残差块的输出相加,形成残差学习。
  • 不同层数的架构:ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。其中,数字表示网络中卷积层和全连接层的总层数。层数越深,模型的表示能力越强,但计算成本也越高。

创新点

  • 瓶颈结构:在ResNet-50及更深的网络中,采用了瓶颈结构(Bottleneck)的残差块。这种结构先使用1×1卷积层进行降维,然后使用3×3卷积层进行特征提取,最后再使用1×1卷积层进行升维,这样可以在减少计算量的同时增加网络的深度和宽度,提高模型的性能。
  • 全局平均池化:在网络的最后一层,ResNet采用了全局平均池化(Global Average Pooling)代替传统的全连接层进行分类。全局平均池化可以将每个特征图的空间维度压缩为一个值,得到一个固定长度的特征向量,然后直接输入到分类器中进行分类。

优点

  • 训练深度网络更容易:残差连接使得梯度能够更有效地在网络中传播,大大降低了训练深度网络的难度,使得可以成功训练上百层甚至上千层的网络。
  • 性能出色:在各种图像识别任务中,ResNet都取得了非常出色的性能,相比之前的网络结构,具有更高的准确率和更好的泛化能力。
  • 模型可扩展性强:可以方便地通过增加残差块的数量来扩展网络的深度,以适应不同的任务和数据集需求。

应用

  • 图像分类:ResNet在图像分类任务中取得了巨大成功,如在ImageNet数据集上达到了很高的准确率,成为了图像分类领域的主流模型之一。
  • 目标检测:与其他目标检测算法结合,如Faster R-CNN、YOLO等,通过提取图像的特征,提高目标检测的准确率和召回率。
  • 语义分割:用于对图像进行像素级的分类,将图像中的每个像素分配到不同的类别中,在城市景观分割、医学图像分割等领域有广泛应用。

ResNet结构代码详解

结构代码

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

代码详解

以下是对上述提供的PyTorch代码的详细解释,这段代码实现了经典的ResNet(残差网络)系列模型,包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等不同深度的网络架构:

BasicBlock 类

class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))
  • 类定义与属性
    • 定义了一个名为BasicBlock的类,继承自nn.Module,这是PyTorch中定义神经网络模块的基类。
    • expansion属性被设置为1,用于表示该基本块在通道维度上的扩展倍数,在BasicBlock中通道数不会进行额外的扩展(后续的Bottleneck块会有不同的扩展倍数)。
  • 初始化方法__init__
    • 首先调用父类nn.Module的初始化方法super(BasicBlock, self).__init__(),确保模块正确初始化。
    • 定义了两个卷积层conv1conv2
      • conv1:输入通道数为in_channels,输出通道数为out_channels,卷积核大小为3×3,步长为stride,填充为1,并且不使用偏置(bias=False),这是遵循ResNet论文中的实现方式,通常配合后续的BatchNorm使用。
      • conv2:输入通道数为out_channels,输出通道数为out_channels * BasicBlock.expansion(实际就是out_channels,因为expansion1),卷积核大小同样是3×3,填充为1,无偏置。
    • 定义了两个BatchNorm2dbn1bn2,分别对应两个卷积层之后,用于对卷积后的特征进行归一化处理,有助于加速训练和提高模型的稳定性。
    • 定义了一个ReLU激活函数relu,并且设置inplace=True,表示直接在原张量上进行激活操作,节省内存空间(但要注意使用不当可能导致梯度计算问题,如前面提到的错误情况)。
    • 定义了shortcut,初始化为一个空的nn.Sequential序列。当输入和输出的通道数不一致或者步长不为1时(意味着尺寸或通道数有变化),会重新构建shortcut,使其包含一个1×1卷积层(用于调整通道数)和一个BatchNorm2d层,以保证shortcut连接的特征维度能与主分支的输出特征维度相匹配,便于后续进行相加操作。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return out
  • 前向传播方法forward
    • 首先将输入x经过conv1卷积、bn1归一化后,再通过relu激活函数得到中间特征。
    • 接着将该中间特征再经过conv2卷积和bn2归一化。
    • 然后将主分支得到的特征outshortcut分支(直接连接输入x经过调整后的特征)进行逐元素相加,实现残差连接的操作。
    • 最后再经过一次relu激活函数后返回结果,作为该基本块的输出。

ResNet 类

class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)
  • 类定义与属性
    • 定义了ResNet类,同样继承自nn.Module,用于构建完整的ResNet网络架构。
    • 初始化了一个属性in_channels64,用于记录当前层的输入通道数,后续会动态更新。
    • 定义了网络的起始层,包括一个3×3卷积层conv1(输入通道为3,对应彩色图像的RGB三个通道,输出通道为64),一个BatchNorm2dbn1用于归一化,一个ReLU激活函数relu,以及一个最大池化层maxpool(其参数设置按照常规的ResNet结构配置)。
    • 分别定义了layer1layer2layer3layer4这四层网络结构,它们通过调用_make_layer方法来构建,每层的输出通道数以及重复的块数量由传入的参数决定,并且随着层数加深,步长会相应改变(从第二层开始步长为2,用于逐步减小特征图尺寸)。
    • 定义了一个自适应平均池化层avgpool,它能将输入的特征图尺寸自适应地变为(1, 1)大小,无论输入特征图的尺寸原本是多少,便于后续全连接层处理。最后定义了一个全连接层fc,用于将池化后的特征映射到指定的类别数num_classes上进行分类。
    def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)
  • _make_layer方法
    • 这个方法用于构建ResNet中的每一层网络结构(由多个基本块组成)。
    • 首先根据传入的stridenum_blocks生成一个步长列表strides,例如,如果传入stride=2num_blocks=3,那么strides会是[2, 1, 1],意味着第一个基本块可能会改变特征图的尺寸和通道数,后面的基本块保持步长为1
    • 然后循环遍历strides列表,每次创建一个指定的block(可以是BasicBlock或者后续定义的Bottleneck块),并传入当前的输入通道数、输出通道数以及对应的步长,将创建好的块添加到layers列表中。同时,更新self.in_channels为当前块输出的通道数(考虑了块的扩展倍数)。
    • 最后将layers列表中的所有块组合成一个nn.Sequential序列并返回,形成一层完整的网络结构。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
  • 前向传播方法forward
    • 首先将输入x依次经过网络起始层的卷积、归一化、激活和池化操作,得到初步的特征表示。
    • 然后将该特征依次通过layer1layer2layer3layer4这四层网络结构,不断提取和融合特征,每一层都会进一步加深特征的抽象程度并且改变特征图的尺寸和通道数。
    • 接着经过自适应平均池化层avgpool,将特征图变为(1, 1)大小的特征向量。
    • 通过out.view(out.size(0), -1)操作将特征向量展平为一维向量,使其能输入到全连接层fc中。
    • 最后将全连接层的输出作为整个网络的最终输出,返回分类结果。

ResNet18、ResNet34、ResNet50、ResNet101、ResNet152函数

# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
  • 这两个函数分别用于创建ResNet-18ResNet-34网络模型。它们通过调用ResNet类的构造函数,传入BasicBlock作为构建块类型,以及对应不同层数的重复块数量列表(如ResNet-18中每层分别重复2个基本块),还有指定的类别数num_classes,最终返回构建好的相应深度的ResNet模型实例,用于图像分类等任务。
# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride!= 1 or in_channels!= out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))
  • Bottleneck类定义与初始化
    • 定义了Bottleneck类,同样继承自nn.Module,用于构建更深层的ResNet网络(如ResNet-50及以上)中的基本块。
    • expansion属性被设置为4,意味着该块在经过一系列操作后,输出通道数会是输入通道数的4倍,通过这种方式在增加网络深度的同时控制计算量。
    • 在初始化方法中,定义了三个卷积层conv1conv2conv3,分别是1×1卷积用于降维、3×3卷积进行主要的特征提取、1×1卷积用于升维,并且每个卷积层后都有对应的BatchNorm2d层进行归一化,还有ReLU激活函数用于激活中间特征。
    • 同样定义了shortcut,其构建逻辑和BasicBlock中类似,根据输入输出通道数和步长情况来决定是否需要构建包含1×1卷积和BatchNorm2d层的调整结构,以保证残差连接的维度匹配。
    def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return out
  • Bottleneck块的前向传播方法
    • 前向传播过程与BasicBlock类似,只是中间经过了三个卷积层及对应的归一化和激活操作,最后同样是将主分支特征与shortcut分支特征相加后再经过ReLU激活函数输出,实现残差学习。
def ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
  • 这几个函数分别用于创建ResNet-50ResNet-101ResNet-152网络模型,它们与创建ResNet-18ResNet-34的函数类似,只是传入的构建块类型变为Bottleneck,以及对应不同层数的重复Bottleneck块数量列表,还有指定的类别数num_classes,最终返回相应深度的ResNet模型实例,用于更复杂的图像分类等任务,这些更深层的网络结构在处理大规模图像数据集时往往能取得更好的性能表现。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述

训练过程准确率变化曲线:

在这里插入图片描述

测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models|--__init__.py|-resnet.py|--...
|--results
|--weights
|--train.py
|--test.py

resnet.py

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels * BasicBlock.expansion)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BasicBlock.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += self.shortcut(x)out = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.layer1 = self._make_layer(block, 64, num_blocks[0], 1)self.layer2 = self._make_layer(block, 128, num_blocks[1], 2)self.layer3 = self._make_layer(block, 256, num_blocks[2], 2)self.layer4 = self._make_layer(block, 512, num_blocks[3], 2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride=1):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# ResNet18, ResNet34
def ResNet18(num_classes):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def ResNet34(num_classes):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)# ResNet50, ResNet101, ResNet152 需要 BottleNeck 
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1= nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * self.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * self.expansion))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += self.shortcut(x)out = self.relu(out)return outdef ResNet50(num_classes):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def ResNet101(num_classes):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def ResNet152(num_classes):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model_name = 'ResNet18'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安装MongoDB,环境配置

官网下载地址:MongoDB Shell Download | MongoDB 选择版本 安装 下载完成双击打开 点击mongodb-windows-x86_64-8.0.0-signed 选择安装地址 检查安装地址 安装成功 二.配置MongoDB数据库环境 1.找到安装好MongoDB的bin路径 复制bin路径 打开此电脑 -> 打开高级…

7.C语言 宏(Macro) 宏定义,宏函数

目录 宏定义 宏函数 1.注释事项 2.注意事项 宏(Macro)用法 常量定义 简单函数实现 类型检查 条件编译 宏函数计算参数个数 宏定义进行类型转换 宏定义进行位操作 宏定义进行断言 总结 宏定义 #include "stdio.h" #include "string.h" #incl…

基于高云GW5AT-15 FPGA的SLVS-EC桥MIPI设计方案分享

作者:Hello,Panda 一、设计需求 设计一个4Lanes SLVS-EC桥接到2组4lanes MIPI DPHY接口的电路模块: (1)CMOS芯片:IMX537-AAMJ-C,输出4lanes SLVS-EC 4.752Gbps Lane速率; (2&…

【漏洞复现】CVE-2023-29944 Expression Injection

漏洞信息 NVD - cve-2023-29944 Metersphere v1.20.20-lts-79d354a6 is vulnerable to Remote Command Execution. The system command reverse-shell can be executed at the custom code snippet function of the metersphere system workbench. 背景介绍 MeterSphere is…

在VBA中结合正则表达式和查找功能给文档添加交叉连接

在VBA中搜索文本有两种方式可用,一种是利用Range.Find对象(更常见的形式可能是Selection.Find,Selection是Range的子类,Selection.Find其实就是特殊的Range.Find),另一种方法是利用正则表达式,但…

AW36518芯片手册解读(3)

接前一篇文章:AW36518芯片手册解读(2) 二、详述 3. 功能描述 (1)上电复位 当电源电压VIN降至预定义电压VPOR(典型值为2.0V)以下时,该设备会产生复位信号以执行上电复位操作&#x…

【mysql】唯一性约束unique

文章目录 唯一性约束1. 作用2. 关键字3. 特点4. 添加唯一约束5. 关于复合唯一约束 唯一性约束 1. 作用 用来限制某个字段/某列的值不能重复。 2. 关键字 UNIQUE3. 特点 同一个表可以有多个唯一约束。唯一约束可以是某一个列的值唯一,也可以多个列组合的值唯一。…

实操给桌面机器人加上超拟人音色

前面我们讲了怎么用CSK6大模型开发板做一个桌面机器人充当AI语音助理,近期上线超拟人方案,不仅大模型语音最快可以1秒内回复,还可以让我们的桌面机器人使用超拟人音色、具备声纹识别等能力,本文以csk6大模型开发板为例实操怎么把超…

SYD881X RTC定时器事件在调用timeAppClockSet后会出现比较大的延迟

RTC定时器事件在调用timeAppClockSet后会出现比较大的延迟 这里RTC做了两个定时器一个是12秒,一个是185秒: #define RTCEVT_NUM ((uint8_t) 0x02)//当前定时器事件数#define RTCEVT_12S ((uint32_t) 0x0000002)//定时器1s事件 /*整分钟定时器事件,因为其余的…

LearnOpenGL学习(碰撞检测,粒子)

完整代码见:zaizai77/OpenGLTo2DGame: 基于OpenGL制作2D游戏 物体本身的数据来检测碰撞会很复杂,一半使用重叠在物体上的更简单的外形来检测。 AABB - AABB 碰撞 AABB代表的是轴对齐碰撞箱(Axis-aligned Bounding Box),碰撞箱是指与场景基…

SwinTransformer 改进:添加SelfAttention自注意力层

目录 1. SelfAttention自注意力层 2. SwinTransformer SelfAttention 3. 代码 1. SelfAttention自注意力层 Self-Attention自注意力层是一种在神经网络中用于处理序列数据的注意力机制。它通过对输入序列中的不同位置进行关注,来计算每个位置与其他位置的关联程…

c++ ------语句

一、简单语句 简单语句是C中最基本的语句单元,通常以分号(;)结尾,用于执行一个单一的操作。常见的简单语句类型有: 表达式语句:由一个表达式后面加上分号构成,用于计算表达式的值或者执行具有…

【他山之石】The SVG path Syntax: An Illustrated Guide:SVG 中的 path 语法图解指南

写在前面 本文为我的自学精译专栏《CSS in Depth 2》第 086 篇文章、在介绍 CSS 的 clip-path 属性的用法时作者提到的一篇延伸阅读材料,以图文并茂的形式系统梳理了 SVG path 属性的方方面面。其中最为精彩的是文中列举的大量使用案例。为了方便查找,特…

小型 Vue 项目,该不该用 Pinia 、Vuex呢?

说到 Vue3 的状态管理,我们会第一时间想到 Pinia、Vuex,但是经过很长一段时间的 Vue3 项目开发,我逐渐发现,我们真的有必要用 Pinia、Vuex 这类的状态管理工具吗? 带着这样的疑惑,我首先是想知道一下 Pini…

c4d动画怎么导出mp4视频,c4d动画视频格式设置

宝子们,今天来给大家讲讲 C4D 咋导出mp4视频的方法。通过用图文教程的形式给大家展示得明明白白的,让你能轻松理解和掌握,不管是理论基础,还是实际操作和技能技巧,都能学到,快速入门然后提升自己哦。 c4d动…

EfficienetAD异常值检测之瓷砖表面缺陷检测(免费下载测试数据集和模型)

背景 当今制造业蓬勃发展,产品质量把控至关重要。从精密电子元件到大型工业板材,表面缺陷哪怕细微,都可能引发性能故障或外观瑕疵。人工目视检测耗时费力且易漏检,已无法适应高速生产线节奏。在此背景下,表面缺陷异常…

将Minio设置为Django的默认Storage(django-storages)

这里写自定义目录标题 前置说明静态文件收集静态文件 使用django-storages来使Django集成Minio安装依赖settings.py测试收集静态文件测试媒体文件 前置说明 静态文件 Django默认的Storage是本地,项目中的CSS、图片、JS都是静态文件。一般会将静态文件放到一个单独…

Redis生产实践中相关疑问记录

1. Redis相关疑问 1.1. redis内存使用率100% 就等同于redis不可用吗? 正常使用情况下,不是。 redis有【缓存淘汰机制】,Redis 在内存使用率达到 100% 时不会直接崩溃。相反,它依赖内存淘汰策略来释放内存,确保系统的…

量化交易——RSI策略(vectorbt实现)

本文为通过vectorbt(以下简称vbt)实现量化交易系列第一篇文章,通过使用vbt实现RSI策略从而熟悉其代码框架。 关于本文所使用数据的说明 由于vbt官方文档提供的入门案例使用的数据是通过其内置的yfinance包获取,在国内无法直接访…

本地摄像头视频流在html中打开

1.准备ffmpeg 和(rtsp-simple-server srs搭建流媒体服务器)视频服务器. 2.解压视频流服务器修改配置文件mediamtx.yml ,hlsAlwaysRemux: yes 3.双击运行服务器。 4,安装ffmpeg ,添加到环境变量。 5.查询本机设备列表 ffmpeg -list_devices true -f dshow -i d…