CNN(六):ResNeXt-50实战

  •  🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制

        ResNeXt是有何凯明团队在2017年CVPR会议上提出来的新型图像分类网络。它是ResNet的升级版,在ResNet的基础上,引入了cardinality的概念,类似于ResNet。ResNeXt也有ResNeXt-50,ResNeXt-101版本。

1 模型结构

        在ResNeXt的论文中,作者提出了当时普遍存在的一个问题,如果要提高模型的准确率,往往采取加深网络或者加宽网络的方法。虽然这种方法有效,但随之而来的,是网络设计的难度和计算开销的增加。为了一点精度的提升往往需要付出更大的代价,因此需要一个更好的策略,在不额外增加计算代价的情况下,提升网络的精度。由此,何等人提出了cardinality的概念。

        下面是ResNet(左)与ResNeXt(右)block的差异。在ResNet中,输入的具有256个通道的特征经过1x1卷积压缩4倍到64个通道,之后3x3的卷积核用于处理特征,经1x1卷积扩大通道数与原特征残差连接后输出。ResNeXt也是相同的处理策略,但在ResNeXt中,输入的具有256个通道的特征被分为32个组,每组被压缩64倍到4个通道后进行处理。32个组相加后与原特征残差连接后输出。这里cardinatity指的是一个block中所具有的相同分支的数目。

图1 ResNet和ResNeXt​​​​

2 分组卷积 

        ResNeXt中采用的分组卷积简单来说是将特征图分为不同的组,再对每组特征图分别进行卷积,这个操作可以有效的降低计算量。

        在分组卷积中,每个卷积核只处理部分通道,比如下图中,红色卷积核只处理红色的通道,绿色卷积核只处理绿色的通道,黄色卷积核只处理黄色通道。此时每个卷积核有2个通道,每个卷积核生成一张特征图。

图2 分组卷积示意图

        分组卷积的优势在于其参数开销,图3是其对比效果。

图3 标准卷积和分组卷积参数量对比

 3 代码实现

3.1 开发环境

电脑系统:ubuntu16.04

编译器:Jupter Lab

语言环境:Python 3.7

深度学习环境:tensorflow

 3.2 前期准备

3.2.1 设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]], "GPU")

3.2.2 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号import os, PIL, pathlib
import numpy as npfrom tensorflow import keras
from tensorflow.keras import layers,modelsdata_dir = "../data/bird_photos"
data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)

3.2.3 加载数据

batch_size = 8
img_height = 224
img_width = 224train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_Names = train_ds.class_names
print("class_Names:",class_Names)

3.2.4 可视化数据

plt.figure(figsize=(10, 5)) # 图形的宽为10,高为5
plt.suptitle("imshow data")for images,labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i+1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_Names[labels[i]])plt.axis("off")

3.2.5 检查数据

for image_batch, lables_batch in train_ds:print(image_batch.shape)print(lables_batch.shape)break

3.2.6 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

3.3 ResNeXt-50代码实现

3.3.1 分组卷积模块

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x: x[:, :, :, c*g_channels:(c+1)*g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels, kernel_size=(3,3), strides=strides, padding='same', use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merge = concatenate(group_list, axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merge)x = ReLU()(x)return x

3.3.2 残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters*2, kernel_size=(1,1), strides=strides, padding='same', use_bias=False)(x)# epsilon位BN公式中防止分母为零的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:# identity_shortcutshortcut = x# 三层卷积层x = Conv2D(filters=filters, kernel_size=(1,1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 进行分组卷积x = grouped_convolution_block(x, strides, groups, g_channels)x = Conv2D(filters=filters * 2, kernel_size=(1,1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x, shortcut])x = ReLU()(x)return x

3.3.3 堆叠残差单元

        每个stack的第一个block的输入和输出shape是不一致的,所以残差连接都需要使用1x1卷积升维后才能进行Add操作。而其他block的输入和输出的shape是一致的,所以可以直接执行Add操作。

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

3.3.4 搭建ResNeXt-50网络

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224, 224, 3] -> [230, 230, 3]x = ZeroPadding2D((3,3))(inputs)x = Conv2D(filters=64, kernel_size=(7,7), strides=2, padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1, 1))(x)x = MaxPool2D(pool_size=(3,3), strides=2, padding='valid')(x)# 堆叠残差结构x = stack(x, filters=128, blocks=2, strides=1)x = stack(x, filters=256, blocks=3, strides=2)x = stack(x, filters=512, blocks=5, strides=2)x = stack(x, filters=1024, blocks=2, strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes, activation='softmax')(x)# 定义模型model = Model(inputs=inputs, outputs=x)return modelmodel = ResNext50(input_shape=(224,224,3), num_classes=4)
model.summary()

     结果显示如下(由于结果内容较多,只展示前后部分内容):

 (中间内容省略)

3.4 正式训练

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])epochs = 10history = model.fit(train_ds,validation_data=val_ds,epochs=epochs)

     结果如下图所示:

3.5 模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("ResNeXt-50 test")plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation loss')
plt.legend(loc='upper right')
plt.title('Training and Validation loss')
plt.show()

      结果如下图所示: 

4 总结

         总而言之,ResNeXt是在ResNet的网络架构上,使用类似于Inception的分治思想,即split-tranform-merge策略,将模块中的网络拆开分组,与Inception不同,每组的卷积核大小一致,这样其感受野一致,但由于每组的卷积核参数不同,提取的特征自然不同。然后将每组得到的特征进行concat操作后,再与原输入特征x或者经过卷积等处理(即进行非线性变换)的特征进行Add操作。这样做的好处是,在不增加参数复杂度的前提下提高准确率,同时还能提高超参数的数量。

        另外,cardinality是基的意思,将数个通道特征进行分组,不同的特征组之间可以看作是由不同基组成的子空间,每个组的核虽然一样,但参数不同,在各自的子空间中学到的特征就多种多样,这点跟transformer中的Multi-head attention不谋而合(Multi-head attention allows the model to jointly attend to information from different representation subspaces.)而且分组进行特征提取,使得学到的特征冗余度降低,获取能起到正则化的作用。(参考ResNeXt的分类效果为什么比Resnet好? - 知乎)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/68010.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【mybatis-plus进阶】多租户场景中多数据源自定义来源dynamic-datasource实现

Springbootmybatis-plusdynamic-datasourceDruid 多租户场景中多数据源自定义来源dynamic-datasource实现 文章目录 Springbootmybatis-plusdynamic-datasourceDruid 多租户场景中多数据源自定义来源dynamic-datasource实现0.前言1. 作者提供了接口2. 基于此接口的抽象类实现自…

简单工厂模式概述和使用

目录 一、简单工厂模式简介1. 定义2. 使用动机 二、简单工厂模式结构1.模式结构2. 时序图 三、简单工厂的使用实例四、简单工厂模式优缺点五、简单工厂模式在Java中的应用 一、简单工厂模式简介 原文链接 1. 定义 简单工厂模式(Simple Factory Pattern):又称为静…

博流RISC-V芯片JTAG debug配置与运行

文章目录 1、Windows下安装与配置2、Linux下安装与配置3、芯片默认 JTAG PIN 列表4、命令行运行JTAG5、Eclipse下使用JTAG 1、Windows下安装与配置 CKLink 驱动安装 Windows版驱动下载地址: https://occ-oss-prod.oss-cn-hangzhou.aliyuncs.com/resource//1666331…

Spring Boot 整合 Redis,使用 RedisTemplate 客户端

文章目录 一、SpringBoot 整合 Redis1.1 整合 Redis 步骤1.1.1 添加依赖1.1.2 yml 配置文件1.1.3 Config 配置文件1.1.4 使用示例 1.2 RedisTemplate 概述1.2.1 RedisTemplate 简介1.2.2 RedisTemplate 功能 二、RedisTemplate API2.1 RedisTemplate 公共 API2.2 String 类型 A…

iWatch框架设计

iWatch框架设计 一、项目框架结构设计 1、项目文件介绍 OverSeaProject:是IOS相关文件文件内容iWatchApp和iWatch Extension:是之前使用xcode14之前的xcode创建的360 app的Watch App,产生的文件结构,包含一个app和Extension的ta…

025: vue父子组件中传递方法控制:$emit,$refs,$parent,$children

第025个 查看专栏目录: VUE ------ element UI 专栏目标 在vue和element UI联合技术栈的操控下,本专栏提供行之有效的源代码示例和信息点介绍,做到灵活运用。 (1)提供vue2的一些基本操作:安装、引用,模板使…

如何在`Pycharm`中配置基于WSL的`Python Interpreters`,以及配置基于WSL的`Terminal`

文章目录 一、创建pycharm用户并授予sudo权限0. 启动WSL下的CentOS1. 创建pycharm用户并授予sudo权限2. 设置pycharm用户为wsl启动Linux的默认用户3. 重启并重新登录wsl下的CentOS4. 验证pycharm用户的sudo权限 二、创建基于WSL的Python Interpreter1. 添加基于WSL的Python Int…

DR IP-SoC China 2023 Day演讲预告 | 龙智Perforce专家解析芯片开发中的数字资产管理

2023年9月6日(周三),龙智即将亮相于上海举行的D&R IP-SoC China 2023 Day,呈现集成了Perforce与Atlassian产品的芯片开发解决方案,助力企业更好、更快地进行芯片开发。 D&R IP-SoC China 2023 Day 是中国首个…

⽹络与HTTP 笔试题精讲1

OSI七层与TCP/IP 这个就是OSI参考模型,⽽实际我们现在的互联⽹世界是就是这个理论模型的落地叫做TCP/IP协议 TCP的三次握⼿与四次挥⼿ 客户端想要发送数据给服务端,在发送实际的数据之前,需要先在两端之间建⽴连接,数据发完以后也需要将该连接关闭。建⽴连接的过程就是我们…

热门框架漏洞

文章目录 一、Thinkphp5.0.23 代码执行1.thinkphp5框架2.thinkphp5高危漏洞3.漏洞特征4.THinkphp5.0 远程代码执行--poc5.TP5实验一(Windows5.0.20)a.搭建实验环境b.测试phpinfoc.写入shelld.使用菜刀连接 6.TP5实验二(Linux5.0.23)a.搭建实验环境b.测试方法c.测试phpinfod.写入…

SQL 语句继续学习之记录三

一,数据的插入(insert 语句的使用方法) 使用insert语句可以向表中插入数据(行)。原则上,insert语句每次执行一行数据的插入。 列名和值用逗号隔开,分别扩在()内,这种形式称为清单。…

如何使用Python和正则表达式处理XML表单数据

在日常的Web开发中,处理表单数据是一个常见的任务。而XML是一种常用的数据格式,用于在不同的系统之间传递和存储数据。本文通过阐述一个技术问题并给出解答的方式,介绍如何使用Python和正则表达式处理XML表单数据。我们将探讨整体设计、编写思…

VB6.0 设置窗体的默认焦点位置在 TextBox 中

文章目录 VB6.0 窗体的加载过程确定指针的焦点位置添加代码效果如下未设置指定焦点已设置焦点 VB6.0 窗体的加载过程 在VB6.0中,窗体(Form)加载时会触发多个事件,这些事件按照特定的顺序执行。下面是窗体加载过程中常见事件的执行…

基于ETLCloud的自定义规则调用第三方jar包实现繁体中文转为简体中文

背景 前面曾体验过通过零代码、可视化、拖拉拽的方式快速完成了从 MySQL 到 ClickHouse 的数据迁移,但是在实际生产环境,我们在迁移到目标库之前还需要做一些过滤和转换工作;比如,在诗词数据迁移后,发现原来 MySQL 中…

常见路由跳转的几种方式

常见的路由跳转有以下四种&#xff1a; 1. <router-link to"跳转路径"> /* 不带参数 */ <router-link :to"{name:home}"> <router-link :to"{path:/home}"> // 更建议用name // router-link链接中&#xff0c;带/ 表示从根…

elementUI可拖拉宽度抽屉

1&#xff0c;需求&#xff1a; 在elementUI的抽屉基础上&#xff0c;添加可拖动侧边栏宽度的功能&#xff0c;实现效果如下&#xff1a; 2&#xff0c;在原组件上添加自定义命令 <el-drawer v-drawerDrag"left" :visible.sync"drawerVisible" direc…

Unity 顶点vertices,uv,与图片贴图,与mesh

mesh就是组成3d物体的三角形们。 mesh由顶点组成的三角形组成&#xff0c;三角形的大小 并不 需要一样&#xff0c;由顶点之间的位置决定。 mesh可以是一个或者多个面。 贴图的原点在左下角&#xff0c;uv是贴图的坐标&#xff0c;数量和顶点数一样&#xff08;不是100%确定…

Vue3 el-tooltip 根据内容控制宽度大小换行和并且内容太短不显示

el-tooltip 根据长度自适应换行以及显隐 环境 vue: "3.2.37" element-ui: "2.1.8"要求 tooltip 根据内容自动换行如果内容超出显示省略号显示&#xff0c;不超出不显示 tooltip 代码 组件 // ContentTip 组件 <template><el-tooltipv-bind&qu…

【数学建模竞赛】超详细Matlab二维三维图形绘制

二维图像绘制 绘制曲线图 g 是表示绿色 b--o是表示蓝色/虚线/o标记 c*是表示蓝绿色(cyan)/*标记 ‘MakerIndices,1:5:length(y) 每五个点取点&#xff08;设置标记密度&#xff09; 特殊符号的输入 序号 需求 函数字符结构 示例 1 上角标 ^{ } title( $ a…

初识c++

文章目录 前言一、C命名空间1、命名空间2、命名空间定义 二、第一个c程序1、c的hello world2、std命名空间的使用惯例 三、C输入&输出1、c输入&输出 四、c中缺省参数1、缺省参数概念2、缺省参数分类3、缺省参数应用 五、c中函数重载1、函数重载概念2、函数重载应用 六、…