昇思25天学习打卡营第13天 | ShuffleNet图像分类

在这里插入图片描述

ShuffleNet网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速。

如下图所示,ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。

图片来源:Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE access, 2018, 6: 64270-64277.

模型架构

ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过对ResNet的Bottleneck单元进行改进,在较小的计算量的情况下达到了较高的准确率。

Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示,相比于普通的卷积操作,分组卷积的情况下,每一组的卷积核大小为in_channels/gkk,一共有g组,所有组共有(in_channels/gkk)*out_channels个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量。

在这里插入图片描述
Depthwise Convolution(深度可分离卷积)将组数g分为和输入通道相等的in_channels,然后对每一个in_channels做卷积操作,每个卷积核只处理一个通道,记卷积核大小为1kk,则卷积核参数量为:in_channelskk,得到的feature maps通道数与输入通道数相等;

Pointwise Group Convolution(逐点分组卷积)在分组卷积的基础上,令每一组的卷积核大小为 1×1
,卷积核参数量为(in_channels/g11)*out_channels。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensorclass GroupConv(nn.Cell):def __init__(self, in_channels, out_channels, kernel_size,stride, pad_mode="pad", pad=0, groups=1, has_bias=False):super(GroupConv, self).__init__()self.groups = groupsself.convs = nn.CellList()for _ in range(groups):self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,kernel_size=kernel_size, stride=stride, has_bias=has_bias,padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))def construct(self, x):features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)outputs = ()for i in range(self.groups):outputs = outputs + (self.convs[i](features[i].astype("float32")),)out = ops.cat(outputs, axis=1)return out

Channel Shuffle

Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积GConv层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。

为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

在这里插入图片描述
如下图所示,对于g组,每组有n个通道的特征图,首先reshape成g行n列的矩阵,再将矩阵转置成n行g列,最后进行flatten操作,得到新的排列。这些操作都是可微分可导的且计算简单,在解决了信息交互的同时符合了ShuffleNet轻量级网络设计的轻量特征。

在这里插入图片描述

ShuffleNet模块

如下图所示,ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), ©的更改:

  1. 将开始和最后的 1×1
    卷积模块(降维、升维)改成Point Wise Group Convolution;

  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle;

  3. 降采样模块中, 3×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为2的 3×3
    平均池化,并把相加改成拼接。

在这里插入图片描述

class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return outdef channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

构建ShuffleNet网络

ShuffleNet网络结构如下图所示,以输入图像 224×224
,组数3(g = 3)为例,首先通过数量24,卷积核大小为 3×3
,stride为2的卷积层,输出特征图大小为 112×112
,channel为24;然后通过stride为2的最大池化层,输出特征图大小为 56×56
,channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图©),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为 1×1×960
,再经过全连接层和softmax,得到分类概率。

在这里插入图片描述

class ShuffleNetV1(nn.Cell):def __init__(self, n_class=1000, model_size='2.0x', group=3):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)self.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedErrorinput_channel = self.stage_out_channels[1]self.first_conv = nn.SequentialCell(nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage + 2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.SequentialCell(features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)def construct(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = ops.reshape(x, (-1, self.stage_out_channels[-1]))x = self.classifier(x)return x

模型训练和评估

采用CIFAR-10数据集对ShuffleNet进行预训练。

训练集准备与加载
采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。如下示例使用mindspore.dataset.Cifar10Dataset接口下载并加载CIFAR-10的训练集。目前仅支持二进制版本(CIFAR-10 binary version)。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracydef train():mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")net = ShuffleNetV1(model_size="2.0x", n_class=10)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)min_lr = 0.0005base_lr = 0.05lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,batches_per_epoch*250,batches_per_epoch,decay_epoch=250)lr = Tensor(lr_scheduler[-1])optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)callback = [TimeMonitor(), LossMonitor()]save_ckpt_path = "./"config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)callback += [ckpt_callback]print("============== Starting Training ==============")start_time = time.time()# 由于时间原因,epoch = 5,可根据需求进行调整model.train(5, dataset, callbacks=callback)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))print("total time:" + hour + "h " + minute + "m " + second + "s")print("============== Train Success ==============")if __name__ == '__main__':train()

在这里插入图片描述

学习心得

通过本次学习,我不仅掌握了ShuffleNetV1的网络结构和实现方法,还深入理解了分组卷积和通道重排在提高模型效率中的作用。未来,我希望能够进一步探索ShuffleNetV2以及其他高效模型的设计与应用,并尝试将其应用于更多复杂的数据集和任务中。同时,我还计划研究模型压缩和加速的其他技术,如模型剪枝和量化,以进一步提升模型的应用性能。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/42699.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ExcelVBA运用Excel的【条件格式】(二)

ExcelVBA运用Excel的【条件格式】(二)前面知识点回顾1. 访问 FormatConditions 集合 Range.FormatConditions2. 添加条件格式 FormatConditions.Add 方法语法表达式。添加 (类型、 运算符、 Expression1、 Expression2)3. 修改或删除条件格式4. …

密码技术中分组模式解析

目录 1. 概述 2. ECB模式 2.1 概述 2.2 ECB模式的加密 2.3 ECB模式的解密 2.4 优点 2.5 缺点 3. CBC模式【推荐】 3.1 概述 3.2 CBC模式的加密 3.3 CBC模式的解密 3.4 优点 3.5 缺点 4. CFB模式 4.1 概述 4.2 CFB模式的加密 4.3 CFB模式的解密 4.4 优点 4.…

智慧地产视觉监控系统开源了,系统采用多种优化技术,提高系统的响应速度和资源利用率

智慧地产视觉监控平台是一款功能强大且简单易用的实时算法视频监控系统。它的愿景是最底层打通各大芯片厂商相互间的壁垒,省去繁琐重复的适配流程,实现芯片、算法、应用的全流程组合,从而大大减少企业级应用约95%的开发成本。用户只需在界面上…

Python打开Excel文档并读取数据

Python 版本 目前 Python 3 版本为主流版本,这里测试的版本是:Python 3.10.5。 常用库说明 Python 操作 Excel 的常用库有:xlrd、xlwt、xlutils、openpyxl、pandas。这里主要说明下 Excel 文档 .xls 格式和 .xlsx 格式的文档打开和读取。 …

Drools开源业务规则引擎(二)- Drools规则语言(DRL)

文章目录 1.DRL文件的组成:2.package3.import4.function5.query6.declare7.global8.rule8.1.规则属性8.2.LHS8.2.1.语法格式8.2.2.运算符优先级8.2.3.特殊的运算符1.matches, not matches2.contains, not contains3.memberOf, not memberOf4.in, notin5.soundslike6…

Powershell 获取电脑保存的所有wifi密码

一. 知识点 netsh wlan show profiles 用于显示计算机上已保存的无线网络配置文件 Measure-Object 用于统计数量 [PSCustomObject]{ } 用于创建Powershell对象 [math]::Round 四舍五入 Write-Progress 显示进度条 二. 代码 只能获取中文Windows操作系统的wifi密码如果想获取…

护网在即,助力安服仔漏洞扫描~

整合了个漏扫系统,安服仔必备~ 使用场景 网前布防,漏洞扫描,资产梳理 使用方法: 启动虚拟机后运行命令: ./StartSystemScript.sh 输入密码attack 启动完成后浏览器打开网站: http://IP:5000 相关账户…

02-android studio实现下拉列表+单选框+年月日功能

一、下拉列表功能 1.效果图 2.实现过程 1&#xff09;添加组件 <LinearLayoutandroid:layout_width"match_parent"android:layout_height"wrap_content"android:layout_marginLeft"20dp"android:layout_marginRight"20dp"android…

运维系列.Nginx中使用HTTP压缩功能

运维专题 Nginx中使用HTTP压缩功能 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550…

【刷题汇总--字符串中找出连续最长的数字串、岛屿数量、拼三角】

C日常刷题积累 今日刷题汇总 - day0071、字符串中找出连续最长的数字串1.1、题目1.2、思路1.3、程序实现 -- 比较1.4、程序实现 -- 双指针 2、岛屿数量2.1、题目2.2、思路2.3、程序实现 - dfs 3、拼三角3.1、题目3.2、思路3.3、程序实现 -- 蛮力法3.4、程序实现 -- 巧解(单调性…

pwm 呼吸灯(如果灯一直亮或者一直灭)

&#xff08;这个文章收藏在我的csdn keil文件夹下面&#xff09; 如果这样设置预分频和计数周期&#xff0c;那么算出来的pwm频率如下 人眼看起来就只能是一直亮或者灭&#xff0c;因为pwm的频率太高了&#xff0c;但是必须是频率够高&#xff0c;才能实现呼吸灯的缓慢亮缓慢…

SPL-404:如何彻底改变Solana上的NFT与DeFi

在不断发展的数字资产领域中&#xff0c;非同质化Token&#xff08;NFT&#xff09;已成为一股革命性力量&#xff0c;彻底改变了我们对数字所有权的看法和互动方式。从艺术和收藏品到游戏和虚拟房地产&#xff0c;NFT吸引了创作者、投资者和爱好者的想象力。 本指南将带您进入…

MyBatisPlus-分页插件的基本使用

目录 配置插件 使用分页API 配置插件 首先&#xff0c;要在配置类中注册MyBatisPlus的核心插件&#xff0c;同时添加分页插件。&#xff08;可以放到config软件包下&#xff09; 可以看到&#xff0c;我们定义了一个配置类&#xff0c;在配置类里声明了一个Bean,这个Bean的名…

排序 -- 计数排序以及对排序的总结

到了这篇文章就说明常见的排序我们就快要讲完了&#xff0c;那这篇文章我们就讲一下非比较排序--计数排序。 一、非比较排序 1.基本思想 计数排序又称为鸽巢原理&#xff0c;是对哈希直接定址法的变形应用。 操作步骤&#xff1a; 统计相同元素出现次数 根据统计的结果将序列…

昇思25天学习打卡营第14天|基于MindNLP的文本解码原理

基于MindNLP的文本解码原理 文本解码 文本解码是自然语言处理中的一个关键步骤,特别是在任务如机器翻译、文本摘要、自动回复生成等领域。解码过程涉及将编码器(如语言模型、翻译模型等)的输出转换为可读的文本序列。以下是一些常见的文本解码方法和原理: 1. 自回归解码:…

打造属于你的私人云盘:在 OrangePi AIpro 上搭建个人云盘

随着数字化时代的到来&#xff0c;数据的存储和管理变得愈发重要。相比于公共云存储服务&#xff0c;搭建一个属于自己的个人云盘不仅能够更好地保护隐私&#xff0c;还可以更灵活地管理数据。 近期刚好收到了一个 香橙派 AIpro 的开发板&#xff0c;借此机会用来搭建一个属于…

美股交易相关知识点 持续完善中

美股交易时间 美东时间&#xff1a;除了凌晨 03:50 ~ 04:00 这10分钟时间不可交易以外&#xff0c;其他时间都是可以交易的。 如果是在香港或者北京时间下交易要区分两种: 美东夏令时&#xff1a;除了下午 15:50 ~ 16:00 这10分钟时间不可交易以外&#xff0c;其他时间都是可…

法国工程师IMT联盟 密码学及其应用 2022年期末考试

1 密码学 1.1 问题1 对称加密&#xff08;密钥加密) 1.1.1 问题 对称密钥la cryptographie symtrique和公开密钥有哪些优缺点&#xff1f; 1.1.1.1 对称加密&#xff08;密钥加密)的优缺点 1.1.1.1.1 优点 加解密速度快encrypt and decrypt&#xff1a;对称加密算法通常基于…

【vue组件库搭建06】组件库构建及npm发包

一、格式化目录结构 根据以下图片搭建组件库目录 index.js作为入口文件&#xff0c;将所有组件引入&#xff0c;并注册组件名称 import { EButton } from "./Button"; export * from "./Button"; import { ECard } from "./Card"; export * fr…

一、MyBatis

一、MyBatis 1、MyBatis简介 1.1、MyBatis历史 MyBatis最初是Apache的一个开源项目iBatis, 2010年6月这个项目由Apache Software Foundation迁移到了Google Code。随着开发团队转投Google Code旗下&#xff0c; iBatis3.x正式更名为MyBatis。代码于2013年11月迁移到Github。…