深度学习J6周 ResNeXt-50实战解析

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

本周任务:

1.阅读ResNeXt论文,了解作者的构建思路

2.对比之前介绍的ResNet50V2、DenseNet算法

3.复现ResNeXt-50算法

一、模型结构

ResNeXt由何凯明团队,2017年CVPR会议上提出新型图像分类网络。它是ResNet升级版,在ResNet的基础上,引入cardinality概念。

在论文中,作者提出当时普遍存在的一个问题,如果要提高模型准确率,往往采取加深网络或者加宽网络的方法。但网络设计的难度和计算开销也增加了。为了一点精度的提升往往付出更大的代价。因此,需要在不额外增加计算代价的情况下,提升网络精度。

左边--ResNet,输入的具有256个通道的特征经过1*1卷积压缩到64个通道,之后3*3的卷积核用于处理特征,经1*1卷积扩大通道数与原特征残差连接后输出。

右边--ResNeXt,输入的具有256个通道的特征被分为32个组,每组被压缩到4个通道后处理,32个组相加后与原特征残差连接后输出。cardinality指的是一个block中所具有相同的分支的数目。

二、分组卷积

1.ResNeXt采用分组卷积:将特征图分为不同的组,再对每组特征图分别进行卷积,有效降低计算量。

2.分组卷积中,每个卷积核只处理部分通道,如下图,红色卷积核只处理红色通道,绿色卷积核只处理绿色通道,黄色卷积核只处理黄色通道。此时,每个卷积核有2个通道,每个卷积核生成一张特征图。

三、代码

学习于深度学习第J6周:ResNeXt-50实战解析_resnext50-CSDN博客

 1.前期准备

#配置GPU
import os, PIL, random, pathlib
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)#导入数据集
data_dir = './data/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
print(classeNames)image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)#数据预处理+划分数据集
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸# transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder("./data/", transform=train_transforms)
print(total_data.class_to_idx)train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break

结果:

2.模型

class BN_Conv2d(nn.Module):"""BN_CONV_RELU"""def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):super(BN_Conv2d, self).__init__()self.seq = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, groups=groups, bias=bias),nn.BatchNorm2d(out_channels))def forward(self, x):return F.relu(self.seq(x))class ResNeXt_Block(nn.Module):"""ResNeXt block with group convolutions"""def __init__(self, in_chnls, cardinality, group_depth, stride):super(ResNeXt_Block, self).__init__()self.group_chnls = cardinality * group_depthself.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls*2, 1, stride=1, padding=0)self.bn = nn.BatchNorm2d(self.group_chnls*2)self.short_cut = nn.Sequential(nn.Conv2d(in_chnls, self.group_chnls*2, 1, stride, 0, bias=False),nn.BatchNorm2d(self.group_chnls*2))def forward(self, x):out = self.conv1(x)out = self.conv2(out)out = self.bn(self.conv3(out))out += self.short_cut(x)return F.relu(out)class ResNeXt(nn.Module):"""ResNeXt builder"""def __init__(self, layers: object, cardinality, group_depth, num_classes) -> object:super(ResNeXt, self).__init__()self.cardinality = cardinalityself.channels = 64self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)d1 = group_depthself.conv2 = self.___make_layers(d1, layers[0], stride=1)d2 = d1 * 2self.conv3 = self.___make_layers(d2, layers[1], stride=2)d3 = d2 * 2self.conv4 = self.___make_layers(d3, layers[2], stride=2)d4 = d3 * 2self.conv5 = self.___make_layers(d4, layers[3], stride=2)self.fc = nn.Linear(self.channels, num_classes)   # 224x224 input sizedef ___make_layers(self, d, blocks, stride):strides = [stride] + [1] * (blocks-1)layers = []for stride in strides:layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))self.channels = self.cardinality*d*2return nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = F.max_pool2d(out, 3, 2, 1)out = self.conv2(out)out = self.conv3(out)out = self.conv4(out)out = self.conv5(out)out = F.avg_pool2d(out, 7)out = out.view(out.size(0), -1)out = F.softmax(self.fc(out),dim=1)return out
# 定义完成,测试一下
model = ResNeXt([3, 4, 6, 3], 32, 4, 4)
model.to(device)# 统计模型参数量以及其他指标
import torchsummary as summary
summary.summary(model, (3, 224, 224))

结果:

 

 

 3.训练运行

 
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_lossdef test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)  # 批次数目test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
 
import copyoptimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数epochs = 10train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标for epoch in range(epochs):# 更新学习率(使用自定义学习率时使用)# adjust_learning_rate(optimizer, epoch, learn_rate)model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)# scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)print('Done')

结果:

 

4.打印训练图

import matplotlib.pyplot as plt
# 隐藏警告
import warningswarnings.filterwarnings("ignore")  # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

四、总结

1.读论文原文要花很长时间,但有讲义,就会快速知道论文的创新点是什么。

2.实验的流程已经很熟悉,现在就在慢慢学每一步的具体内容,争取下次能自己写出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/64435.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对话 Project Astra 研究主管:打造通用 AI 助理,主动视频交互和全双工对话是未来重点

Project Astra 愿景之一:「系统不仅能在你说话时做出回应,还能在持续的过程中帮助你。」 近期,Google DeepMind 的 YouTube 频道采访了 Google DeepMind 研究主管格雷格韦恩 (Greg Wayne)。 格雷格韦恩的研究工作为 DeepMind 的诸多突破性成…

LunarVim安装

LunarVim以其丰富的功能和灵活的定制性,迅速在Nvim用户中流行开来。它不仅提供了一套完善的默认配置,还允许用户根据自己的需求进行深度定制。无论是自动补全、内置终端、文件浏览器,还是模糊查找、LSP支持、代码检测、格式化和调试&#xff…

高质量 Next.js 后台管理模板源码分享,开发者必备

高质量 Next.js后台管理模板源码分享,开发者必备 Taplox 是一个基于 Bootstrap 5 和 Next.js 构建的现代化后台管理模板和 UI 组件库。它不仅设计精美,还提供了一整套易用的工具,适合各种 Web 应用、管理系统和仪表盘项目。无论你是初学者还是…

开发场景中Java 集合的最佳选择

在 Java 开发中,集合类是处理数据的核心工具。合理选择集合,不仅可以提高代码效率,还能让代码更简洁。本篇文章将重点探讨 List、Set 和 Map 的适用场景及优缺点,帮助你在实际开发中找到最佳解决方案。 一、List:有序存…

Java包装类型的缓存

Java 基本数据类型的包装类型的大部分都用到了缓存机制来提升性能。 Byte,Short,Integer,Long 这 4 种包装类默认创建了数值 [-128,127] 的相应类型的缓存数据,Character 创建了数值在 [0,127] 范围的缓存数据,Boolean 直接返回 True or Fal…

工程师 - MinGW

MinGW Minimalist GNU for Windows,前身为mingw32,是一个免费开源的软件开发环境,从2010年开始项目停止并不再使用。后续提供MinGW-w64。 MinGW包括: - 移植到Windows上的GNU编译器集(GCC),包括C、C、ADA和…

EasyExcel(读取操作和填充操作)

文章目录 1.准备Read.xlsx(具有两个sheet)2.读取第一个sheet中的数据1.模板2.方法3.结果 3.读取所有sheet中的数据1.模板2.方法3.结果 EasyExcel填充1.简单填充1.准备 Fill01.xlsx2.无模版3.方法4.结果 2.列表填充1.准备 Fill02.xlsx2.模板3.方法4.结果 …

CKA认证 | Day7 K8s存储

第七章 Kubernetes存储 1、数据卷与数据持久卷 为什么需要数据卷? 容器中的文件在磁盘上是临时存放的,这给容器中运行比较重要的应用程序带来一些问题。 问题1:当容器升级或者崩溃时,kubelet会重建容器,容器内文件会…

关于JAVA方法值传递问题

1.1 前言 之前在学习C语言的时候,将实参传递给方法(或函数)的方式分为两种:值传递和引用传递,但在JAVA中只有值传递(颠覆认知,基础没学踏实) 参考文章:https://blog.csd…

Excel基础知识

一:数组 一行或者一列数据称为一维数组,多行多列称为二维数组,数组支持算术运算(如加减乘除等)。 行:{1,2,3,4} 数组中的每个值用逗号分隔列:{1;2;3;4} 数组中的每个值用分号分隔行列&#xf…

基于DIODES AP43781+PI3USB31531+PI3DPX1207C的USB-C PD Video 之全功能显示器连接端口方案

随着USB-C连接器和PD功能的出现,新一代USB-C PD PC显示器可以用作个人和专业PC工作环境的电源和数据集线器。 虽然USB-C PD显示器是唯一插入墙壁插座的交流电源输入设备,但它可以作为数据UFP(上游接口)连接到连接到TCD&#xff0…

gazebo_world 基本围墙。

如何使用&#xff1f; 参考gazebo harmonic的官方教程。 本人使用harmonic的template&#xff0c;在里面进行修改就可以分流畅地使用下去。 以下是world 文件. <?xml version"1.0" ?> <!--Try sending commands:gz topic -t "/model/diff_drive/…

解决无法在 Ubuntu 24.04 上运行 AppImage 应用

在 Ubuntu 24.04 中运行 AppImage 应用的完整指南 在 Ubuntu 24.04 中&#xff0c;许多用户可能会遇到 AppImage 应用无法启动的问题。即使你已经设置了正确的文件权限&#xff0c;AppImage 仍然拒绝运行。这通常是由于缺少必要的库文件所致。 问题根源&#xff1a;缺少 FUSE…

springboot配置oracle+达梦数据库多数据源配置并动态切换

项目场景&#xff1a; 在工作中很多情况需要跨数据库进行数据操作,自己总结的经验希望对各位有所帮助 问题描述 总结了几个问题 1.识别不到mapper 2.识别不到xml 3.找不到数据源 原因分析&#xff1a; 1.配置文件编写导致识别mapper 2.配置类编写建的格式有问题 3.命名…

html+css+js网页设计 美食 家美食1个页面

htmlcssjs网页设计 美食 家美食1个页面 网页作品代码简单&#xff0c;可使用任意HTML辑软件&#xff08;如&#xff1a;Dreamweaver、HBuilder、Vscode 、Sublime 、Webstorm、Text 、Notepad 等任意html编辑软件进行运行及修改编辑等操作&#xff09;。 获取源码 1&#xf…

【机器学习】【朴素贝叶斯分类器】从理论到实践:朴素贝叶斯分类器在垃圾短信过滤中的应用

&#x1f31f; 关于我 &#x1f31f; 大家好呀&#xff01;&#x1f44b; 我是一名大三在读学生&#xff0c;目前对人工智能领域充满了浓厚的兴趣&#xff0c;尤其是机器学习、深度学习和自然语言处理这些酷炫的技术&#xff01;&#x1f916;&#x1f4bb; 平时我喜欢动手做实…

Vue使用Tinymce 编辑器

目录 一、下载并重新组织tinymce结构二、使用三、遇到的坑 一、下载并重新组织tinymce结构 下载 npm install tinymce^7 or yarn add tinymce^7重构目录 在node_moudles里找到tinymce文件夹&#xff0c;把里面文件拷贝一份放到public下&#xff0c;如下&#xff1a; -- pub…

EMNLP'24 最佳论文解读 | 大语言模型的预训练数据检测:基于散度的校准方法

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入&#xff01; 点击 阅读原文 观看作者讲解回放&#xff01; 作者简介 张伟超&#xff0c;中国科学院计算所网络数据科学与技术重点实验室三年级直博生 内容简介 近年来&#xff0c;大语言模型&#xff08;LLMs&#xff09;的…

大数据技术-Hadoop(一)Hadoop集群的安装与配置

目录 一、准备工作 1、安装jdk&#xff08;每个节点都执行&#xff09; 2、修改主机配置 &#xff08;每个节点都执行&#xff09; 3、配置ssh无密登录 &#xff08;每个节点都执行&#xff09; 二、安装Hadoop&#xff08;每个节点都执行&#xff09; 三、集群启动配置&a…

折腾日记:如何让吃灰笔记本发挥余热——搭建一个相册服务

背景 之前写过&#xff0c;我在家里用了一台旧的工作站笔记本做了服务器&#xff0c;连上一个绿联的5位硬盘盒实现简单的网盘功能&#xff0c;然而&#xff0c;还是觉的不太理想&#xff0c;比如使用filebrowser虽然可以备份文件和图片&#xff0c;当使用手机使用网页&#xf…