深度学习项目--分组卷积与ResNext网络实验探究(pytorch复现)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

  • ResNext是分组卷积的开始之作,这里本文将学习ResNext网络;
  • 本文复现了ResNext50神经网络,并用其进行了猴痘病分类实验;
  • 没有最好的网络。只有最适合的网络,网络不是越复杂,越优秀越好,必须根据实际数据情况,目标要求决定,很多时候,简单的网络反而效果更好
  • 欢迎收藏 + 关注,本人将会持续更新

文章目录

  • 1、知识简介
    • 1、分组卷积
    • 2、split-transform-merge
    • 3、ResNext-50简介
  • 2、ResNext-50实验
    • 1、导入数据
      • 1、导入库
      • 2、查看数据信息和导入数据
      • 3、展示数据
      • 4、数据导入
      • 5、数据划分
      • 6、动态加载数据
    • 2、构建ResNext-50网络
    • 3、模型训练
      • 1、构建训练集
      • 2、构建测试集
      • 3、设置超参数
    • 4、模型训练
    • 5、结果可视化
    • 6、模型评估
  • 3、参考资料

1、知识简介

1、分组卷积

分组卷积最早出现在AlexNet网络中,在这里将通道数分成两组,采用两个GPU并行提取特征,网络结构如下:

在这里插入图片描述

提取到的特征图如下:

在这里插入图片描述

作者发现第一组提取的主要是黑白特征,第二组提取的主要是彩色特征,这样分组特征可以更好的提取不同特征数据。


普通卷积 VS 分组卷积

先看常规卷积,在常规卷积中,输入feature map尺寸为 n 个,输出feature map与卷积和数量相同也是n个,卷积核大小为:c * k * k,n个卷积核总大小为:n * c * k * k,最后输出的维度是:n * h1 * w1如下图左边所示

在这里插入图片描述

分组卷积,就是对输入的feature map进行分组,然后每组分别卷积。假设输入feature map的尺寸为 c * h * w,输出的feature map为 n,假设分为 g 组,则每组的输入的feature map数量为 c / g,每组输出的feature map为 n / g。但是注意只是每个卷积核的输入通道数量变成了 c / g,卷积核大小是不变的,每一组卷积核运算后得到了 (n / g) * h1 * w1,最后将各组矩阵进行拼接就可以得出最后的结果,最后输出的维度依然是n * h1 * w1,与常规卷积一样。

参数了对比

  • 常规卷积:c * k * k * n,c通道数,k * k:卷积核矩阵大小,n卷积核数量;
  • 分组卷积:(c / g) * k * k * (n / g) * g = k * k * c * n * (1 / g),从参数了来看,分组卷积更小

更详细的图如下

在这里插入图片描述

2、split-transform-merge

“Split-Transform-Merge” 是一种常见的设计模式或处理流程,广泛应用于软件开发、数据处理和系统架构中。它的核心思想是将一个复杂的问题分解为更小的部分(Split),对每个部分进行独立的处理或转换(Transform),最后将处理后的结果重新组合(Merge)以完成整体任务。


1. Split(拆分)

在这一阶段,输入数据或任务被分解成更小、更易于管理的部分。拆分的方式取决于具体问题和上下文。例如:

  • 数据拆分:将大数据集分割成多个小块。
  • 任务拆分:将一个复杂的任务分解为多个子任务。
  • 并行化:通过拆分实现并行处理,提高效率。

示例

  • 分组卷积中,输入通道分组拆分,分组进行卷积。

2. Transform(转换/处理)

在拆分后,每个部分被独立处理或转换。这是整个流程的核心阶段,通常涉及计算、分析或修改操作。转换的具体内容取决于任务需求:

  • 数据清洗、格式转换。
  • 算法计算或模型推理。
  • 对子任务的独立执行。

示例

  • 分组卷积中 ,每一组分别进行卷积计算,互补干扰。

3. Merge(合并)

在所有子任务完成后,将处理后的结果重新组合起来,形成最终的输出。合并的方式需要确保结果的完整性和一致性:

  • 数据合并:将多个处理后的数据块拼接成完整的数据集。
  • 结果整合:将多个子任务的结果汇总为最终答案。
  • 冲突解决:如果子任务之间存在冲突或重复,需要在合并阶段解决。

示例

  • 分组卷积中,最后将每一组卷积的结果进行组合。

3、ResNext-50简介

ResNext网络被誉为,分组卷积的开山之作,是何凯明团队在2017年CVPR会与提出的,是ResNet网络的升级版。

在论文中,作者提到了一个普遍存在的现象,提高模型准确率,往往采用的是加深或加宽网络的方法,这种方法虽然有一定效果,但是网络设计的难度和计算了也随着增加,因为不代表网络越深就越好,有时候提升了精度,但是代价也大,就如VGG16提出来的时候,计算了庞大。

在论文中,作者提出了在不额外增加计算代价的情况下,提升网络精度,提出了cardinality概念(cardinality指的是分组卷积中的“组数”).

下图中,左边是(Resnet)右边数(Resnext)的模块差异,在ResNet中,输入具有256个通道特征经过1 * 1卷积压缩到4倍到64个通道特征,然后通过3 * 3卷积核进行特征提取,最后经过 3 * 3卷积核进行还原通道数量输出,并于原来特征进行残差连接。在ResNext中,将256个输入通道特征分成32个组,每个组首先进行64倍压缩到4个通道,然后用3 * 3卷积核大小进行特征提取,最后通过1 * 1卷积核进行通道还原,后会将每个分组的结构进行维度拼接并与原始特征进行残差连接。

在这里插入图片描述

cardinatity指的是一个block中所具有的相同分支的数目,即“组数”.

下面进行ResNext-50网络图的搭建(pytorch复现)

2、ResNext-50实验

1、导入数据

1、导入库

import torch  
import torch.nn as nn
import torchvision 
import numpy as np 
import os, PIL, pathlib # 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"device 
'cuda'

2、查看数据信息和导入数据

数据目录有两个文件:一个数据文件,一个权重。

data_dir = "./data/"data_dir = pathlib.Path(data_dir)# 类别数量
classnames = [str(path).split('/')[0] for path in os.listdir(data_dir)]classnames
['Monkeypox', 'Others']

3、展示数据

import matplotlib.pylab as plt  
from PIL import Image # 获取文件名称
data_path_name = "./data/Others"
data_path_list = [f for f in os.listdir(data_path_name) if f.endswith(('jpg', 'png'))]# 创建画板
fig, axes = plt.subplots(2, 8, figsize=(16, 6))for ax, img_file in zip(axes.flat, data_path_list):path_name = os.path.join(data_path_name, img_file)img = Image.open(path_name) # 打开# 显示ax.imshow(img)ax.axis('off')plt.show()


在这里插入图片描述

4、数据导入

from torchvision import transforms, datasets # 数据统一格式
img_height = 224
img_width = 224 data_tranforms = transforms.Compose([transforms.Resize([img_height, img_width]),transforms.ToTensor(),transforms.Normalize(   # 归一化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
])# 加载所有数据
total_data = datasets.ImageFolder(root=data_dir, transform=data_tranforms)

5、数据划分

# 大小 8 : 2
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size train_data, test_data = torch.utils.data.random_split(total_data, [train_size, test_size])

6、动态加载数据

batch_size = 32 train_dl = torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True
)test_dl = torch.utils.data.DataLoader(test_data,batch_size=batch_size,shuffle=False
)
# 查看数据维度
for data, labels in train_dl:print("data shape[N, C, H, W]: ", data.shape)print("labels: ", labels)break
data shape[N, C, H, W]:  torch.Size([32, 3, 224, 224])
labels:  tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,0, 1, 0, 0, 0, 1, 0, 0])

2、构建ResNext-50网络

ResNet-50网络结构图
在这里插入图片描述

在这里插入图片描述

在复现ResNext50网络中,我查阅了不少资料,但是我好像都没怎么看懂那个代码,后面我发现这个就是在ResNet50上加了分组卷积,其他网络结构就是在每一层,第二层的数量是resnet的2倍,后面基于以前搭建的ResNet网络结果进行修改,代码如下所示。

在ResNext50中,有几个参数需要注意:

  • 分组卷积:cardinality参数代表分组卷积数量,在Conv2d中groups参数就是分组卷积数量。
  • 通道数计算:每组的输出通道数由 group_depth 决定,总输出通道数为 cardinality × group_depth。这里,下面本人搭建的ResNext50网络结构,每一层输入通道数,输出通道数,都是自己手动输入的,故这里group_depth隐藏在filters中(手动计算).

回忆
Bottleneck 的基本概念

Bottleneck 结构通常由三个卷积层组成,他是ResNet以及其变体的基本网络层单元。

  1. 第一个 1×1 卷积:降低输入特征图的通道数,减少后续计算量。
  2. 中间的 3×3 卷积:核心特征提取过程。在 ResNeXt 中,这一层使用分组卷积来增强表达能力。
  3. 最后一个 1×1 卷积:恢复通道数到原始或者更高的数量,以便与输入特征图进行残差连接。

注意:

  • 在ResNext网络结构中,分组卷积只在Bottleneck只在第二层使用
import torch.nn.functional as F# Bottleneck: 分为残差模块一、残差模块二# 定义残差模块一,这个用于处理输入和输出通道一样的情况
'''  
卷积核大小:1       3       1
核心特点:尺寸不变:输入和输出的尺寸保持一致。 没有下采样:没有使用步长大于1的卷积操作,因此没有改变特征图的空间尺寸
'''
class Identity_block(nn.Module):def __init__(self, in_channels, kernel_size, filters, cardinality):super(Identity_block, self).__init__()# 输出通道filter1, filter2, filter3 = filters# 卷积层一, 降维self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=1)self.bn1 = nn.BatchNorm2d(filter1)# 卷积层2, 分组卷积, 核心:特征提取self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1,groups=cardinality)   # 通过卷积输入输出公式发现,padding=1,可以保证输入和输出尺寸相同self.bn2 = nn.BatchNorm2d(filter2)# 卷积层3, 升维self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)self.bn3 = nn.BatchNorm2d(filter3)def forward(self, x):# 记录原始值xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))# 残差连接,输入、输出维度不变x += xxx = F.relu(x)return x # 定义卷积模块二:用于处理输入和输出不一样的情况
'''  
* 卷积核还是:1 3 1
* stride=2
* 这里的分支是采用一个Conv2D,和一个归一化BN层,也是为了处理数据维度吧, 这种维度的变化,可以用ai举例子核心特点:尺寸变化,stride=2降维
'''
class ConvBlock(nn.Module):def __init__(self, in_channels, kernel_size, filters, cardinality, stride=2):super(ConvBlock, self).__init__()filter1, filter2, filter3= filters# 卷积层1, 降维self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride)self.bn1 = nn.BatchNorm2d(filter1)# 卷积2, 分组卷积,核心:特征提取self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1,groups=cardinality) # 需要维持维度不变self.bn2 = nn.BatchNorm2d(filter2)# 卷积3, 降维self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)  # stride = 1,维持通道不变self.bn3 = nn.BatchNorm2d(filter3)# 用于匹配维度的shortcut卷积,这个就是上面Identity_block的x分支self.shortcut = nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride)self.shortcut_bn = nn.BatchNorm2d(filter3)def forward(self, x):xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))temp = self.shortcut_bn(self.shortcut(xx))x += tempx = F.relu(x)return x # 定义ResNext50
class ResNext50(nn.Module):def __init__(self, classes):   # 类别数量super().__init__()# 头顶, resnet以及变体一般都是这个self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 第一部分self.part1_1 = ConvBlock(64, 3, [128, 128, 256], cardinality=32, stride=1)self.part1_2 = Identity_block(256, 3, [128, 128, 256], cardinality=32)self.part1_3 = Identity_block(256, 3, [128, 128, 256], cardinality=32)# 第二部分self.part2_1 = ConvBlock(256, 3, [256, 256, 512], cardinality=32)self.part2_2 = Identity_block(512, 3, [256, 256, 512], cardinality=32)self.part2_3 = Identity_block(512, 3, [256, 256, 512], cardinality=32)self.part2_4 = Identity_block(512, 3, [256, 256, 512], cardinality=32)# 第三部分self.part3_1 = ConvBlock(512, 3, [512, 512, 1024], cardinality=32)self.part3_2 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_3 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_4 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_5 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_6 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)# 第四部分self.part4_1 = ConvBlock(1024, 3, [1024, 1024, 2048], cardinality=32)self.part4_2 = Identity_block(2048, 3, [1024, 1024, 2048], cardinality=32)self.part4_3 = Identity_block(2048, 3, [1024, 1024, 2048], cardinality=32)# 平均池化self.avg_pool = nn.AvgPool2d(kernel_size=7)# 全连接self.fn1 = nn.Linear(2048, classes)def forward(self, x):# 头部x = F.relu(self.bn1(self.conv1(x)))x = self.max_pool(x)x = self.part1_1(x)x = self.part1_2(x)x = self.part1_3(x)x = self.part2_1(x)x = self.part2_2(x)x = self.part2_3(x)x = self.part2_4(x)x = self.part3_1(x)x = self.part3_2(x)x = self.part3_3(x)x = self.part3_4(x)x = self.part3_5(x)x = self.part3_6(x)x = self.part4_1(x)x = self.part4_2(x)x = self.part4_3(x)x = self.avg_pool(x)x = x.view(x.size(0), -1)  # 扁平化x = self.fn1(x)return x model = ResNext50(classes=len(classnames)).to(device)model
ResNext50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(part1_1): ConvBlock((conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_2): Identity_block((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_3): Identity_block((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_1): ConvBlock((conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_2): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_3): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_4): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_1): ConvBlock((conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_2): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_3): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_4): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_5): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_6): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_1): ConvBlock((conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_2): Identity_block((conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_3): Identity_block((conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(avg_pool): AvgPool2d(kernel_size=7, stride=7, padding=0)(fn1): Linear(in_features=2048, out_features=2, bias=True)
)

3、模型训练

1、构建训练集

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)batch_size = len(dataloader)train_acc, train_loss = 0, 0 for X, y in dataloader:X, y = X.to(device), y.to(device)# 训练pred = model(X)loss = loss_fn(pred, y)# 梯度下降法optimizer.zero_grad()loss.backward()optimizer.step()# 记录train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_acc /= sizetrain_loss /= batch_sizereturn train_acc, train_loss

2、构建测试集

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)batch_size = len(dataloader)test_acc, test_loss = 0, 0 with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= batch_sizereturn test_acc, test_loss

3、设置超参数

loss_fn = nn.CrossEntropyLoss()  # 损失函数     
learn_lr = 1e-4            # 超参数
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)   # 优化器

4、模型训练

import copy train_acc = []
train_loss = []
test_acc = []
test_loss = []epoches = 50best_acc = 0for i in range(epoches):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到 best_model     if epoch_test_acc > best_acc:         best_acc   = epoch_test_acc         best_model = copy.deepcopy(model)  # 拷贝最好模型train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率     lr = optimizer.state_dict()['param_groups'][0]['lr']# 输出template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))print("Done")PATH = './best_model.pth'  # 保存的参数文件名 
torch.save(best_model.state_dict(), PATH)
Epoch: 1, Train_acc:62.3%, Train_loss:0.696, Test_acc:66.4%, Test_loss:0.604
Epoch: 2, Train_acc:67.9%, Train_loss:0.620, Test_acc:69.9%, Test_loss:0.580
Epoch: 3, Train_acc:69.5%, Train_loss:0.580, Test_acc:68.3%, Test_loss:0.603
Epoch: 4, Train_acc:71.6%, Train_loss:0.547, Test_acc:73.9%, Test_loss:0.530
Epoch: 5, Train_acc:74.7%, Train_loss:0.519, Test_acc:75.1%, Test_loss:0.520
Epoch: 6, Train_acc:78.2%, Train_loss:0.464, Test_acc:67.8%, Test_loss:0.683
Epoch: 7, Train_acc:78.1%, Train_loss:0.459, Test_acc:69.0%, Test_loss:0.652
Epoch: 8, Train_acc:80.8%, Train_loss:0.411, Test_acc:72.7%, Test_loss:0.643
Epoch: 9, Train_acc:84.8%, Train_loss:0.362, Test_acc:74.8%, Test_loss:0.575
Epoch:10, Train_acc:87.4%, Train_loss:0.314, Test_acc:77.9%, Test_loss:0.536
Epoch:11, Train_acc:89.3%, Train_loss:0.266, Test_acc:79.0%, Test_loss:0.505
Epoch:12, Train_acc:89.4%, Train_loss:0.260, Test_acc:78.3%, Test_loss:0.601
Epoch:13, Train_acc:90.7%, Train_loss:0.226, Test_acc:81.4%, Test_loss:0.493
Epoch:14, Train_acc:93.9%, Train_loss:0.159, Test_acc:80.4%, Test_loss:0.616
Epoch:15, Train_acc:93.8%, Train_loss:0.152, Test_acc:80.4%, Test_loss:0.620
Epoch:16, Train_acc:92.2%, Train_loss:0.190, Test_acc:82.3%, Test_loss:0.621
Epoch:17, Train_acc:94.0%, Train_loss:0.142, Test_acc:82.3%, Test_loss:0.582
Epoch:18, Train_acc:95.8%, Train_loss:0.106, Test_acc:79.3%, Test_loss:0.625
Epoch:19, Train_acc:95.5%, Train_loss:0.127, Test_acc:81.1%, Test_loss:0.625
Epoch:20, Train_acc:95.4%, Train_loss:0.113, Test_acc:83.0%, Test_loss:0.482
Epoch:21, Train_acc:96.7%, Train_loss:0.087, Test_acc:83.0%, Test_loss:0.667
Epoch:22, Train_acc:97.3%, Train_loss:0.083, Test_acc:80.4%, Test_loss:0.695
Epoch:23, Train_acc:97.1%, Train_loss:0.077, Test_acc:83.7%, Test_loss:0.634
Epoch:24, Train_acc:96.6%, Train_loss:0.086, Test_acc:82.5%, Test_loss:0.732
Epoch:25, Train_acc:96.6%, Train_loss:0.098, Test_acc:83.9%, Test_loss:0.711
Epoch:26, Train_acc:96.0%, Train_loss:0.107, Test_acc:75.3%, Test_loss:0.821
Epoch:27, Train_acc:95.6%, Train_loss:0.105, Test_acc:81.6%, Test_loss:0.596
Epoch:28, Train_acc:96.7%, Train_loss:0.088, Test_acc:84.4%, Test_loss:0.606
Epoch:29, Train_acc:97.5%, Train_loss:0.071, Test_acc:86.5%, Test_loss:0.615
Epoch:30, Train_acc:98.2%, Train_loss:0.051, Test_acc:80.4%, Test_loss:0.772
Epoch:31, Train_acc:98.5%, Train_loss:0.041, Test_acc:83.7%, Test_loss:0.694
Epoch:32, Train_acc:98.5%, Train_loss:0.048, Test_acc:82.8%, Test_loss:0.671
Epoch:33, Train_acc:97.7%, Train_loss:0.064, Test_acc:84.1%, Test_loss:0.745
Epoch:34, Train_acc:98.4%, Train_loss:0.054, Test_acc:83.7%, Test_loss:0.661
Epoch:35, Train_acc:98.2%, Train_loss:0.068, Test_acc:83.0%, Test_loss:0.605
Epoch:36, Train_acc:96.8%, Train_loss:0.086, Test_acc:83.2%, Test_loss:0.551
Epoch:37, Train_acc:97.8%, Train_loss:0.063, Test_acc:82.3%, Test_loss:0.739
Epoch:38, Train_acc:97.6%, Train_loss:0.065, Test_acc:83.0%, Test_loss:0.583
Epoch:39, Train_acc:98.2%, Train_loss:0.045, Test_acc:83.4%, Test_loss:0.697
Epoch:40, Train_acc:98.1%, Train_loss:0.048, Test_acc:82.5%, Test_loss:0.710
Epoch:41, Train_acc:98.2%, Train_loss:0.054, Test_acc:83.2%, Test_loss:0.564
Epoch:42, Train_acc:98.4%, Train_loss:0.051, Test_acc:85.5%, Test_loss:0.514
Epoch:43, Train_acc:99.0%, Train_loss:0.025, Test_acc:83.9%, Test_loss:0.663
Epoch:44, Train_acc:99.1%, Train_loss:0.029, Test_acc:85.5%, Test_loss:0.594
Epoch:45, Train_acc:98.3%, Train_loss:0.036, Test_acc:84.6%, Test_loss:0.719
Epoch:46, Train_acc:98.7%, Train_loss:0.036, Test_acc:84.4%, Test_loss:0.631
Epoch:47, Train_acc:97.7%, Train_loss:0.055, Test_acc:81.4%, Test_loss:0.643
Epoch:48, Train_acc:98.7%, Train_loss:0.040, Test_acc:85.1%, Test_loss:0.607
Epoch:49, Train_acc:98.8%, Train_loss:0.037, Test_acc:80.2%, Test_loss:0.897
Epoch:50, Train_acc:98.6%, Train_loss:0.042, Test_acc:84.4%, Test_loss:0.601
Done

5、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息epochs_range = range(epoches)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()


在这里插入图片描述

6、模型评估

# 加载最好模型
best_model.load_state_dict(torch.load(PATH, map_location=device)) 
# 模型测试
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)print(epoch_test_acc, epoch_test_loss)
0.8648018648018648 0.6145411878824234

3、参考资料

  • 深度学习——分类之ResNeXt - 知乎
  • 通义 - 你的个人AI助手
  • ResNeXt代码复现+超详细注释(PyTorch)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/76049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从代码学习深度学习 - RNN PyTorch版

文章目录 前言一、数据预处理二、辅助训练工具函数三、绘图工具函数四、模型定义五、模型训练与预测六、实例化模型并训练训练结果可视化总结前言 循环神经网络(RNN)是深度学习中处理序列数据的重要模型,尤其在自然语言处理和时间序列分析中有着广泛应用。本篇博客将通过一…

JS DOM节点增删改查

增加节点 通过document.createNode()函数创建对象 // 创建节点 const div document.createElement(div) // 追加节点 document.body.appendChild(div) 克隆节点 删除节点

IMX6ULL学习整理篇——Linux使用更现代的GPIO操作简单设备

IMX6ULL学习篇——实战:使用设备树/Pinctl-gpio子系统驱动LED 前言 ​ 经过层层考验,我们即将接近现代的LED驱动的解决方案了。那就是使用最现代的方式开发一个简单的GPIO驱动外设。 ​ 如果您忘记了设备树的相关内容,请自行到笔者的上一篇…

2025-04-07 NO.3 Quest3 MR 配置

文章目录 1 MR 介绍1.1 透视1.2 场景理解1.3 空间设置 2 配置 MR 环境2.1 场景配置2.2 MR 配置 3 运行测试 配置环境: Windows 11Unity 6000.0.42f1Meta SDK v74.0.2Quest3 1 MR 介绍 1.1 透视 ​ 透视(Passthrough)是将应用的背景从虚拟的…

如何在 GitHub 上开源一个小项目:从创建到长期维护的完整指南

如何在 GitHub 上开源一个小项目:从创建到长期维护的完整指南 适用于 个人开发者、团队合作、企业开源,涵盖 Git 基础、GitHub 配置、最佳实践、社区互动、自动化 CI/CD 及长期维护策略。 📌 1. 注册 GitHub 账户 如果你还没有 GitHub 账户&…

【技术报告】GPT-4o 原生图像生成的应用与分析

【技术报告】GPT-4o 原生图像生成的应用与分析 1. GPT-4o 原生图像生成简介1.1 文本渲染能力1.2 多轮对话迭代1.3 指令遵循能力1.4 上下文学习能力1.5 跨模态知识调用1.6 逼真画质与多元风格1.7 局限性与安全性 2. GPT-4o 技术报告2.1 引言2.2 安全挑战、评估与缓解措施2.2.1 安…

React中的跨组件通信

在React中,跨组件通信有几种常见的方式。每种方式适用于不同的场景,下面是几种常见的跨组件通信方法: 1. 通过父子组件传递 Props 父组件可以通过 props 将数据传递给子组件,子组件只能接收和使用这些数据。 父组件&#xff08…

系统与网络安全------Windows系统安全(8)

资料整理于网络资料、书本资料、AI,仅供个人学习参考。 DNS DNS概述 为什么需要DNS系统 www.baidu.com与119.75.217.56,哪个更好记? 互联网中的114查号台/导航员 DNS(Domian Name System,域名系统)的功…

[ctfshow web入门] web16

信息收集 提示:对于测试用的探针,使用完毕后要及时删除,可能会造成信息泄露 试试url/phpinfo.php url/phpsysinfo.php url/tz.php tz.php能用 点击phpinfo,查看phpinfo信息,搜索flag,发现flag被保存为变量…

Go基础一(Maps Functions 可变参数 闭包 递归 Range 指针 字符串和符文 结构体)

Maps 1.创建map make(map[键类型]值类型) 2.设置键值对 name[key]value; 3. name[key]获取键值 3.1 key不存在 则返回 0 4.len()方法 返回 map 上 键值对数量 len(name) 5.delete()方法 从map中删除 键值对 delete(name,key) 6.clear()方法 map中删除所有键值对 clear(name) 7…

✅ 2025最新 | YOLO 获取 COCO 指标终极指南 | 从标签转换到 COCOAPI 评估 (训练/验证) 全覆盖【B 站教程详解】

✅ YOLO 轻松获取论文 COCO 指标:AP(small,medium,large )| 从标签转换到 COCOAPI 评估 (训练/验证) 全覆盖 文章目录 一、摘要二、为什么需要 COCO 指标评估 YOLO 模型?三、核心挑战与解决方案 (视频教程核…

ResNet改进(18):添加 CPCA通道先验卷积注意力机制

1. CPCA 模块 CPCA(Channel Prior Convolutional Attention)是一种结合通道先验信息的卷积注意力机制,旨在通过显式建模通道间关系来增强特征表示能力。 核心思想 CPCA的核心思想是将通道注意力机制与卷积操作相结合,同时引入通道先验知识,通过以下方式优化特征学习: 通…

SpringMVC的简单介绍

SpringMVC的简单介绍 SpringMVC 是一个基于 Java 的 Web 框架,是 Spring Framework 中用于构建 Web 应用的一个核心模块。它采用了 模型-视图-控制器 (MVC) 设计模式,能够帮助开发者更加清晰地分离业务逻辑、用户界面和请求处理,从而提高应用…

MES生产工单管理系统,Java+Vue,含源码与文档,实现生产工单全流程管理,提升制造执行效率与精准度

前言: MES生产工单管理系统是制造业数字化转型的核心工具,通过集成生产、数据、库存等模块,实现全流程数字化管理。以下是对各核心功能的详细解析: 一、生产管理 工单全生命周期管理 创建与派发:根据销售订单或生产计…

Redis常见问题排查与解决方案指南

Redis作为高性能的内存数据库,广泛应用于缓存、队列、实时统计等场景。但在实际使用中,开发者和运维人员常会遇到性能下降、内存溢出、主从同步失败等问题。本文将针对高频问题进行详细分析,并提供对应的解决方案和预防措施,助你快…

目标跟踪Deepsort算法学习2025.4.7

一.DeepSORT概述 1.1 算法定义 DeepSORT(Deep Learning and Sorting)是一种先进的多目标跟踪算法,它结合了深度学习和传统的目标跟踪技术,在复杂环境下实现了高精度和鲁棒性的目标跟踪。该算法的核心思想是通过融合目标的外观特征和运动特征,实现对多个目标的持续跟踪,…

从零开始开发HarmonyOS应用并上架

开发环境搭建(1-2天) 硬件准备 操作系统:Windows 10 64位 或 macOS 10.13 内存:8GB以上(推荐16GB) 硬盘:至少10GB可用空间 软件安装 下载 DevEco Studio 3.1(官网:…

Linux | 无头 Linux 服务器安装和设置

注:本文为 “Headless Linux” 相关文章合辑。 机翻未校。 How to Install and Set Up Headless Linux Server 如何安装和设置无头 Linux 服务器 Winnie Ondara Last Updated: January 31, 2023 A vast majority of Linux users are familiar with a Linux desk…

AI赋能数据库管理“最后一公里”,融合架构重塑数据库承载成本效能——zCloud 6.7与zData X 3.3正式发布

点击蓝字 关注我们 在数据驱动的新时代,数据库的多元化和智能化已成不可逆的趋势。3月31日,云和恩墨以“奇点时刻数智跃迁”为主题举办线上发布会,云和恩墨创始人兼总经理盖国强、公司数据库和生态产品群总经理熊军共同带来 zCloud 6.7与 zD…

I have something to say about Vue Node.js

关于Vue Node.js,我真的说了很多次了,让我难以理解为啥这么粗糙的东西能流行一起。真疯狂的世界。 vue让感觉就像玩猫德一样的,如此的疯狂,天哪。睡觉了 Node.js v13 window7_nodejsv13-CSDN博客