昇思25天学习打卡营第14天 | ShuffleNet图像分类

昇思25天学习打卡营第14天 | ShuffleNet图像分类

文章目录

  • 昇思25天学习打卡营第14天 | ShuffleNet图像分类
    • ShuffleNet
      • Pointwise Group Convolution
      • Channel Shuffle
      • ShuffleNet模块
      • 网络构建
    • 模型训练与评估
      • 数据集
      • 训练
      • 模型评估
      • 模型预测
    • 总结
    • 打卡

ShuffleNet

ShuffleNetV1是旷世科技提出的一种计算高效的CNN模型,这种模型利用有限的计算资源来达到最好的模型精度,主要应用在移动端。

ShuffleNetV1的核心是引入了两种操作:

  • Pointwise Group Convolution
  • Channel Shuffle
    这两种操作在保持精度的同时大大降低了模型的计算量。

Pointwise Group Convolution

Group Convolution(分组卷积)相对于普通卷积,每一组的卷积核大小为 in_channels / g ∗ k ∗ k \text{in\_channels} / g * k *k in_channels/gkk,一共有 g g g组,所有组共有 ( in_channels / g ∗ k ∗ k ) ∗ out_channels (\text{in\_channels}/g*k*k)*\text{out\_channels} (in_channels/gkk)out_channels个参数,是正常卷积参数的 1 / g 1/g 1/g
分组卷积的每个卷积核只处理特征图的一部分通道,但输出通道数仍等于卷积核的数量。
shufflenet2

图片来源:Huang G, Liu S, Van der Maaten L, et al. Condensenet: An efficient densenet using learned group convolutions[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 2752-2761.

Depthwise Convolution(深度可分离卷积)将输入特征图的每个通道分开,分别使用一个卷积核进行卷积,假设卷积核大小为 1 × k × k 1\times k\times k 1×k×k,其中 1 1 1表示只对一个通道进行卷积。由于有in_channels个卷积核,故有 in_channels × k × k \text{in\_channels}\times k \times k in_channels×k×k个参数,得到的特征图通道数与输入相同。

Pointwise Group Convolution(逐点分组卷积)在分组卷积的基础上,令每一组卷积核大小为 1 × 1 1\times 1 1×1,故共有 ( in_channels / g × 1 × 1 ) × out_channels (\text{in\_channels}/g\times1\times1)\times \text{out\_channels} (in_channels/g×1×1)×out_channels个参数。

from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensorclass GroupConv(nn.Cell):def __init__(self, in_channels, out_channels, kernel_size,stride, pad_mode="pad", pad=0, groups=1, has_bias=False):super(GroupConv, self).__init__()self.groups = groupsself.convs = nn.CellList()for _ in range(groups):self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,kernel_size=kernel_size, stride=stride, has_bias=has_bias,padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))def construct(self, x):features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)outputs = ()for i in range(self.groups):outputs = outputs + (self.convs[i](features[i].astype("float32")),)out = ops.cat(outputs, axis=1)return out

Channel Shuffle

Group Convolution只能保证组内的特征提取,而不同组之间的特征是不通信的,这就降低了网络的特征提取能力。
为了解决这个问题,ShuffleNet引入了Channel Shuffle机制,将不同分组通道均匀分散重组,使得网络在下一层能处理不同组别通道的信息。
shufflenet3
对于 g g g组,每组有 n n n个通道的特征图:

  1. reshape g × n g\times n g×n的矩阵;
  2. 转置为 n × g n\times g n×g的矩阵;
  3. 通过flatten操作得到新的排列。
    shufflenet4

ShuffleNet模块

ShuffleNet对ResNet中Bottleneck结构进行由(a)到(b), (c)的更改:
shufflenet5

  1. 将开始和最后的 1 × 1 1\times 1 1×1卷积模块改成Pointwise Group Convolution;
  2. 在降维后进行Channel Shuffle;
  3. 降采样模块中, 3 × 3 3\times 3 3×3Depthwise Convolution 的步长设置为2,特征图大小减半,因此shortcuts中采用步长为2的 3 × 3 3\times 3 3×3平均池化,并把相加改成拼接。
class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return outdef channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

网络构建

shufflenet6

class ShuffleNetV1(nn.Cell):def __init__(self, n_class=1000, model_size='2.0x', group=3):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)self.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedErrorinput_channel = self.stage_out_channels[1]self.first_conv = nn.SequentialCell(nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage + 2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.SequentialCell(features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)def construct(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = ops.reshape(x, (-1, self.stage_out_channels[-1]))x = self.classifier(x)return x

模型训练与评估

数据集

采用CIFAR-10数据集进行预训练。


![shufflenet6](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/shufflenet_6.png)import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transformsdef get_dataset(train_dataset_path, batch_size, usage):image_trans = []if usage == "train":image_trans = [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5),vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]elif usage == "test":image_trans = [vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]label_trans = transforms.TypeCast(ms.int32)dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)dataset = dataset.map(image_trans, 'image')dataset = dataset.map(label_trans, 'label')dataset = dataset.batch(batch_size, drop_remainder=True)return datasetdataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()

训练

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracydef train():mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")net = ShuffleNetV1(model_size="2.0x", n_class=10)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)min_lr = 0.0005base_lr = 0.05lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,batches_per_epoch*250,batches_per_epoch,decay_epoch=250)lr = Tensor(lr_scheduler[-1])optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)callback = [TimeMonitor(), LossMonitor()]save_ckpt_path = "./"config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)callback += [ckpt_callback]print("============== Starting Training ==============")start_time = time.time()# 由于时间原因,epoch = 5,可根据需求进行调整model.train(5, dataset, callbacks=callback)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))print("total time:" + hour + "h " + minute + "m " + second + "s")print("============== Train Success ==============")if __name__ == '__main__':train()

模型评估

调用model.eval()接口对模型进行评估。

from mindspore import load_checkpoint, load_param_into_netdef test():mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")net = ShuffleNetV1(model_size="2.0x", n_class=10)param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")load_param_into_net(net, param_dict)net.set_train(False)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),'Top_5_Acc': Top5CategoricalAccuracy()}model = Model(net, loss_fn=loss, metrics=eval_metrics)start_time = time.time()res = model.eval(dataset, dataset_sink_mode=False)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \+ "', time: " + hour + "h " + minute + "m " + second + "s"print(log)filename = './eval_log.txt'with open(filename, 'a') as file_object:file_object.write(log + '\n')if __name__ == '__main__':test()

模型预测

import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as dsnet = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5),vision.Resize((224, 224)),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:plt.subplot(2, 8, index+1)plt.title('{}'.format(class_dict[pred[index]]))index += 1plt.imshow(image)plt.axis("off")
plt.show()

总结

这一节介绍了ShuffleNet的基本结构,为了在移动设备这样的有限资源上进行训练,ShuffleNet提出了Pointwise Group Convolution操作以大幅减少参数量,使用Channel Shuffle来确保网络的特征提取能力。通过在ResNet的Bottleneck结构中应用上面的两种操作,得到ShuffleNet的基本网络模块。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/46506.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙实训笔记

第一天 #初始化一个新的NPM项目(根据提示操作) npm init #安装TSC、TSLint和NodeJS的类型声明 npm install -s typescript tslint types/node 在根目录中新建一个名为tsconfig.json的文件,然后在代码编辑器中打开,写入下述内容: {"co…

MATLAB激光通信和-积消息传递算法(Python图形模型算法)模拟调制

🎯要点 🎯概率论和图论数学形式和图结构 | 🎯数学形式、图结构和代码验证贝叶斯分类器算法:🖊多类型:朴素贝叶斯,求和朴素贝叶斯、高斯朴素贝叶斯、树增强贝叶斯、贝叶斯网络增强贝叶斯和半朴素…

网络层重点协议—IP协议

在复杂的网络环境中确定一个合适的路径 协议头格式如下: 4位版本号(version) 指定协议的版本(IPV4-4,IPV6-6) 4位首部长度(header length) IP头部的长度是多少个32bit,也就是length*4的字节数。4bit表示最大的数字是15&#x…

【密码学】密码学数学基础:群的定义

一、群的定义 在密码学中,群(Group)的概念是从抽象代数借用来的,它是一种数学结构,通常用于描述具有特定性质的运算集合。 群的定义 群定义中的几个关键要素: 集合:首先,群是由一系…

AutoMQ 中的元数据管理

本文所述 AutoMQ 的元数据管理机制均基于 AutoMQ Release 1.1.0 版本 [1]。 01 前言 AutoMQ 作为新一代基于云原生理念重新设计的 Apache Kafka 发行版,其底层存储从传统的本地磁盘替换成了以对象存储为主的共享存储服务。对象存储为 AutoMQ 带来可观成本优势的…

draggable 实现一个简单的拖拽

拖拽区域代码 <draggable v-if="activeFirstIndex !== 8" :list="showResourseList" :group="{ name: resources, pull: clone, put: false }" :sort="false" :multiple="false" :move="onMove1" @end="…

【JavaScript 算法】冒泡排序:简单有效的排序方法

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、算法原理二、算法实现三、应用场景四、优化与扩展五、总结 冒泡排序&#xff08;Bubble Sort&#xff09;是一种基础的排序算法&#xff0c;通过重复地遍历要排序的数列&#xff0c;一次比较两个元素&#xff0c;如果它…

【香橙派 AIpro测评:探索高效图片分类项目实战】

前言 最近入手了一块香橙派 AIpro开发板&#xff0c;在使用中被它的强大深深震撼&#xff0c;有感而发写下这篇文章。 本文旨在深入探讨OrangePi AIpro的各项性能&#xff0c;从硬件配置、软件兼容性到实际应用案例&#xff0c;全方位解析这款设备如何在开源社区中脱颖而出&am…

案例 | 人大金仓助力山西政务服务核心业务系统实现全栈国产化升级改造

近日&#xff0c;人大金仓支撑山西涉企政策服务平台、政务服务热线联动平台、政务网、办件中心等近30个政务核心系统完成全栈国产化升级改造&#xff0c;推进全省通办、跨省通办、综合业务受理、智能审批、一件事一次办等业务的数字化办结进程&#xff0c;为我国数字政务服务提…

数据结构(Java):LinkedList集合Stack集合

1、集合类LinkedList 1.1 什么是LinkedList LinkedList的底层是一个双向链表的结构&#xff08;故不支持随机访问&#xff09;&#xff1a; 在LinkedList中&#xff0c;定义了first和last&#xff0c;分别指向链表的首节点和尾结点。 每个节点中有一个成员用来存储数据&…

构建高效智能标准化仓库

在快节奏的现代商业环境中&#xff0c;仓库作为供应链的核心枢纽&#xff0c;其运营效率与管理水平直接影响着企业的整体竞争力。一个“高效智能标准化的仓库”&#xff0c;不仅是货物有序存储的空间&#xff0c;更是降本增效、提升客户满意度的关键所在。 在传统工厂管理模式下…

AI Agent 开发综合指南

本文介绍了 ReAct 模式以改进功能&#xff0c;并演示了如何从头开始创建 AI 代理。它涵盖了测试、调试和优化 AI 代理&#xff0c;以及工具、库、环境设置和实施。本教程为用户提供了创建有效 AI 代理所需的技能&#xff0c;无论他们是开发人员还是爱好者。 NSDT工具推荐&#…

【Linux】01.Linux 的常见指令

1. ls 指令 语法&#xff1a;ls [选项] [目录名或文件名] 功能&#xff1a;对于目录&#xff0c;该命令列出该目录下的所有子目录与文件。对于文件&#xff0c;将列出文件名以及其他信息 常用选项&#xff1a; -a&#xff1a;列出当前目录下的所有文件&#xff0c;包含隐藏文件…

从 Pandas 到 Polars 十八:数据科学 2025,对未来几年内数据科学领域发展的预测或展望

我在2021年底开始使用Polars和DuckDB。我立刻意识到这些库很快就会成为数据科学生态系统的核心。自那时起&#xff0c;这些库的受欢迎程度呈指数级增长。 在这篇文章中&#xff0c;我做出了一些关于未来几年数据科学领域的发展方向和原因的预测。 这篇文章旨在检验我的预测能力…

开始Linux之路

人生得一知己足矣&#xff0c;斯世当以同怀视之。——鲁迅 Linux操作系统简单操作指令 1、ls指令2、pwd命令3、cd指令4、mkdir指令(重要)5、whoami命令6、创建一个普通用户7、重新认识指令8、which指令9、alias命令10、touch指令11、rmdir指令 及 rm指令(重要)12、man指令(重要…

记录自己Ubuntu加Nvidia驱动从入门到入土的一天

前言 记录一下自己这波澜壮阔的一天&#xff0c;遇到了很多问题&#xff0c;解决了很多问题&#xff0c;但是还有很多问题&#xff0c;终于在晚上的零点彻底放弃&#xff0c;重启windows。 安装乌班图 1.安装虚拟机 我开始什么操作系统的基础都没有&#xff0c;网上随便搜了…

JDBC基础 -获取连接的方式、结果集、批处理、事务处理、连接池、Apache-DBUtils

文章目录 概述快速入门(增删改)获取数据库的五种方式方式一&#xff1a;获取Driver实现类对象方式二&#xff1a;反射方式三&#xff1a;使用DriverManager代替Driver方式四&#xff1a;Class.forName自动完成注册驱动&#xff08;推荐&#xff09;方式五&#xff1a;使用prope…

请你谈谈:BeanDefinition类作为Spring Bean的建模对象,与BeanFactoryPostProcessor之间的羁绊

那么&#xff0c;我们如何理解Spring Bean的建模对象呢&#xff1f;简而言之&#xff0c;它是指用于描述和配置Bean实例化过程的模型对象。有人可能会提出疑问&#xff0c;既然只需要Class&#xff08;类&#xff09;就可以实例化一个对象&#xff0c;Class作为类的元数据&…

springboot websocket 知识点汇总

以下是一个详细全面的 Spring Boot 使用 WebSocket 的知识点汇总 1. 配置 WebSocket 添加依赖 进入maven官网, 搜索spring-boot-starter-websocket&#xff0c;选择版本, 然后把依赖复制到pom.xml的dependencies标签中 配置 WebSocket 创建一个配置类 WebSocketConfig&…

mysql不初始化升级

1、下载mysql&#xff0c;下载地址&#xff1a;MySQL :: Download MySQL Community Server 2、解压下载好的mysql&#xff0c;修改配置文件的datadir指定目录为当前数据存储的目录 3、通过管理员cmd进入新版本mysql的bin目录&#xff0c; 然后执行命令安装mysql服务&#xff…