【MindSpore学习打卡】应用实践-计算机视觉-ShuffleNet图像分类:从理论到实践

在当今的深度学习领域,卷积神经网络(CNN)已经成为图像分类任务的主流方法。然而,随着网络深度和复杂度的增加,计算资源的消耗也显著增加,特别是在移动设备和嵌入式系统中,这种资源限制尤为突出。ShuffleNet作为一种高效的卷积神经网络,通过引入Pointwise Group Convolution和Channel Shuffle两种操作,大大降低了计算量,同时保持了较高的分类精度。在本篇博客中,我们将详细探讨ShuffleNet的设计原理,并通过MindSpore框架实现ShuffleNet在CIFAR-10数据集上的训练与评估,帮助读者更好地理解和应用这一高效的网络结构。

ShuffleNet网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,主要应用在移动端。其设计核心在于引入了两种操作:Pointwise Group Convolution和Channel Shuffle。这些操作在保持精度的同时,大大降低了模型的计算量。

Pointwise Group Convolution

Pointwise Group Convolution:我们在代码中定义了一个GroupConv类,用于实现逐点分组卷积。这种卷积操作通过将输入特征图分成多个组,每组单独进行卷积操作,从而显著减少了参数量和计算量。具体来说,逐点分组卷积的卷积核大小为 1 × 1 1 \times 1 1×1,这使得每个卷积核只作用于一个通道,进一步降低了计算复杂度。

Group Convolution

class GroupConv(nn.Cell):def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):super(GroupConv, self).__init__()self.groups = groupsself.convs = nn.CellList()for _ in range(groups):self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,kernel_size=kernel_size, stride=stride, has_bias=has_bias,padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))def construct(self, x):features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)outputs = ()for i in range(self.groups):outputs = outputs + (self.convs[i](features[i].astype("float32")),)out = ops.cat(outputs, axis=1)return out

Channel Shuffle

Channel Shuffle:为了克服分组卷积带来的不同组别通道无法进行信息交流的问题,ShuffleNet引入了Channel Shuffle机制。我们在代码中实现了一个channel_shuffle方法,通过对通道进行重排,使得不同组别的通道能够进行信息交互。这一步骤在保持网络高效性的同时,增强了特征的表达能力。
Channel Shuffle

def channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

ShuffleNet模块

ShuffleNet模块:在ShuffleNet模块中,我们结合了Pointwise Group Convolution和Channel Shuffle,并在降采样模块中引入了步长为2的Depth Wise Convolution。这种设计不仅提高了网络的计算效率,还保证了特征提取的有效性。在代码实现中,我们通过ShuffleV1Block类定义了ShuffleNet的基本模块,并在其中实现了上述操作。

ShuffleNet模块

class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return out

构建ShuffleNet网络

ShuffleNet网络结构如下图所示。以输入图像 224 × 224 224 \times 224 224×224,组数3(g = 3)为例,经过多个ShuffleNet模块和全局平均池化,最终得到分类结果。

ShuffleNet网络结构

class ShuffleNetV1(nn.Cell):def __init__(self, n_class=1000, model_size='2.0x', group=3):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)self.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedErrorinput_channel = self.stage_out_channels[1]self.first_conv = nn.SequentialCell(nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage + 2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.SequentialCell(features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)def construct(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = ops.reshape(x, (-1, self.stage_out_channels[-1]))x = self.classifier(x)return x

模型训练和评估

模型训练和评估:在训练部分,我们使用了CIFAR-10数据集,并通过数据增强技术(如随机裁剪和水平翻转)来提高模型的泛化能力。我们定义了ShuffleNet网络,并使用交叉熵损失函数和Momentum优化器进行训练。在评估部分,我们加载训练好的模型,并在测试集上进行评估,计算模型的Top-1和Top-5准确率,以全面衡量模型的性能。

训练集准备与加载

首先下载并加载CIFAR-10数据集。CIFAR-10共有60000张32x32的彩色图像,分为10个类别,其中50000张图片作为训练集,10000张图片作为测试集。

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url, "./dataset", kind="tar.gz", replace=True)import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transformsdef get_dataset(train_dataset_path, batch_size, usage):image_trans = []if usage == "train":image_trans = [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5),vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]elif usage == "test":image_trans = [vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]label_trans = transforms.TypeCast(ms.int32)dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)dataset = dataset.map(image_trans, 'image')dataset = dataset.map(label_trans, 'label')dataset = dataset.batch(batch_size, drop_remainder=True)return datasetdataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()

模型训练

定义ShuffleNet网络,并使用交叉熵损失函数和Momentum优化器进行训练。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracydef train():mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")net = ShuffleNetV1(model_size="2.0x", n_class=10)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)min_lr = 0.0005base_lr = 0.05lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,batches_per_epoch*250,batches_per_epoch,decay_epoch=250)lr = Tensor(lr_scheduler[-1])optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)callback = [TimeMonitor(), LossMonitor()]save_ckpt_path = "./"config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)callback += [ckpt_callback]print("============== Starting Training ==============")start_time = time.time()model.train(5, dataset, callbacks=callback)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))print("total time:" + hour + "h " + minute + "m " + second + "s")print("============== Train Success ==============")if __name__ == '__main__':train()

在这里插入图片描述

模型评估

在CIFAR-10测试集上对训练好的模型进行评估。

from mindspore import load_checkpoint, load_param_into_netdef test():mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")net = ShuffleNetV1(model_size="2.0x", n_class=10)param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")load_param_into_net(net, param_dict)net.set_train(False)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),'Top_5_Acc': Top5CategoricalAccuracy()}model = Model(net, loss_fn=loss, metrics=eval_metrics)start_time = time.time()res = model.eval(dataset, dataset_sink_mode=False)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \+ "', time: " + hour + "h " + minute + "m " + second + "s"print(log)filename = './eval_log.txt'with open(filename, 'a') as file_object:file_object.write(log + '\n')if __name__ == '__main__':test()

在这里插入图片描述

模型预测

在CIFAR-10测试集上对模型进行预测,并将预测结果可视化。

import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as dsnet = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5),vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:plt.subplot(2, 8, index+1)plt.title('{}'.format(class_dict[pred[index]]))index += 1plt.imshow(image)plt.axis("off")
plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39698.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

25计算机考研,这些学校双非闭眼入,性价比超高!

计算机考研,好的双非院校也很多! 对于一些二本准备考研的同学来说,没必要一直盯着985/211这些院校,竞争激烈不说,容易当陪跑,下面这些就是不错的双非院校: 燕山大学南京邮电大学南京信息工程大…

WPS-Word文档表格分页

一、问题描述 这种情况不好描述 就是像这种表格内容,但是会有离奇的分页的情况。这种情况以前的错误解决办法就是不断地调整表格的内容以及间隔显得很乱,于是今天去查了解决办法,现在学会了记录一下避免以后忘记了。 二、解决办法 首先记…

8.SQL注入-基于insert,update利用案例

SQL注入-基于insert/update利用案例 sql语句正常插入表中的数据 insert into member(username,pw,sex,phonenum,address,email) values(xiaoqiang,1111,1,2,3,4); select * from member;例如插入小强数据,如图所示: 采用or这个运算符,构造…

实测有效:Win11右键默认显示更多

Win11最大的变化之一莫过于右键菜单发生了变化,最大的问题是什么,是右键菜单很多时候需要点两次,实在是反人类 第一步 复制以下命令直接运行: reg.exe add "HKCU\Software\Classes\CLSID\{86ca1aa0-34aa-4e8b-a509-50c905ba…

python_zabbix

zabbix官网地址:19. API19. APIhttps://www.zabbix.com/documentation/4.2/zh/manual/api 每个版本可以有些差异,选择目前的版本在查看对于的api接口#token接口代码 import requests apiurl "http://zabbix地址/api_jsonrpc.php" data {&quo…

web的学习和开发

这个使同步和异步的区别 今天主要就是学了一些前端,搞了一些前端的页面,之后准备学一下后端。 我写的这个项目使百度贴吧,还没有写er图。 先看一下主界面是什么样子的。 这个是主界面,将来后面的主要功能点基本上全部是放在这个上…

推动能源绿色低碳发展,风机巡检进入国产超高清+AI时代

全球绿色低碳能源数字转型发展正在进入一个重要窗口期。风电作为一种清洁能源,在碳中和过程中扮演重要角色,但风电场运维却是一件十足的“苦差事”。 传统的风机叶片人工巡检方式主要依靠巡检人员利用高倍望远镜检查、高空绕行下降目测检查(蜘蛛人)、叶…

STM32——Modbus协议

一、Modbus协议简介: 1.modbus介绍: Modbus是一种串行通信协议,是Modicon公司(现在的施耐德电气 Schneider Electric)于1979年为使用可编程逻辑控制器(PLC)通信而发表。Modbus已经成为工业领域…

PDF压缩工具选哪个?6款免费PDF压缩工具分享

PDF文件已经成为一种常见的文档格式。然而,PDF文件的体积有时可能非常庞大,尤其是在包含大量图像或复杂格式的情况下。选择一个高效的PDF压缩工具就显得尤为重要。小编今天给大家整理了2024年6款市面上反响不错的PDF压缩文件工具。轻松帮助你找到最适合自…

漆包线行业生产管理革新:万界星空科技MES系统解决方案

一、引言 在科技日新月异的今天,万界星空科技凭借其在智能制造领域的深厚积累,为漆包线行业量身打造了一套先进的生产管理执行系统(MES)解决方案。随着市场竞争的加剧,漆包线作为电气设备的核心材料,其生产…

React+TS前台项目实战(二十四)-- 绘制组件Qrcode封装

文章目录 前言Qrcode组件1. 功能分析2. 代码详细注释3. 使用方式4. 效果展示(pc端 / 移动端) 总结 前言 今天要封装的Qrcode 组件,是通过传入的信息,绘制在二维码上,可用于很多场景,如区块链项目中的区块显示交易地址时就可以用到…

无人值守停车场管理系统具备哪些功能?无人值守收费停车场系统多少钱

随着城市化进程的加快,停车难已成为制约城市发展的一个突出问题。在传统停车场管理中,人工收费、车辆登记等环节不仅效率低下,而且容易出错。无人值守停车系统的出现,无人值守停车场系统以其高效、智能的特点,通过集成…

Meta 3D Gen:文生 3D 模型

是由 Meta 公布的一个利用 Meta AssetGen(模型生成)和 TextureGen(贴图材质生成)的组合 AI 系统,可以在分分钟内生成高质量 3D 模型和高分辨率贴图纹理。 视频演示的效果非常好,目前只有论文,期…

多源BFS——AcWing 173. 矩阵距离

多源BFS 定义 多源BFS(多源广度优先搜索)是一种图遍历算法,它是标准BFS(广度优先搜索)的扩展,主要用于解决具有多个起始节点的最短路径问题。在多源BFS中,不是从单一源点开始搜索整个图&#…

怎么把webp格式转换成jpg?5个图片格式转换方法全面解析(2024最新)

webp 图片常用于网站,可显著改善页面的浏览和加载体验。然而,许多设备(如苹果手机设备、安卓手机等)不支持webp文件。在这些设备上查看webp文件时,最佳做法是将其转换为其他常见格式,如jpg或 png。Windows电…

2024上海大学生程序设计竞赛I-六元组计数原根知识详解

以前基本没有了解原根相关的一块内容,最近正好碰到了这个题,于是写一篇博客记录一下。 这道题的总体思路就是比较明显,就是先算出 a p x a^px apx对于每个 x x x的解的个数,然后NTT算一下即可。 先来讲一下怎么求欧拉函数 ϕ ( …

前端FCP指标优化

优化前 第三方依赖按需引入之后,打包的总体积减小到初始值的55%,但是依然存在很大的js文件,需要继续优化 chunk-vendors.js进行分包之后 截图 compression-webpack-plugin压缩之后 截图

大学新生人工智能学习路线规划

1. 引言 七月来临,各省高考分数已揭榜完成。而高考的完结并不意味着学习的结束,而是新旅程的开始。对于有志于踏入IT领域的高考少年们,这个假期是开启探索IT世界的绝佳时机。作为该领域的前行者和经验前辈,我愿意为准新生们提供一…

剪映 v5.5 Pro Vip解锁版:使用指南与注意事项

摘要:本文介绍了剪映Pro VIP解锁版的使用方法,包括安装、测试和使用VIP素材的步骤,以及如何避免误报和保持解锁状态的建议。 正文: 剪映Pro是一款广受欢迎的视频编辑软件,提供了丰富的视频编辑功能和大量高质量的素材…

发送微信消息和文件

参考:https://www.bilibili.com/video/BV1S84y1m7xd 安装: pip install PyOfficeRobotimport PyOfficeRobotPyOfficeRobot.chat.send_message(who"文件传输助手", message"你好,我是PyOfficeRobot,有什么可以帮助…