昇思学习打卡-10-ShuffleNet图像分类

文章目录

  • 网络介绍
  • 网络结构
    • 部分实现
    • 对应网络结构
  • 模型训练
  • shuffleNet的优缺点总结
    • 优点
    • 不足

网络介绍

ShuffleNet主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。

网络结构

部分实现

class ShuffleV1Block(nn.Cell):def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = stridepad = ksize // 2self.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupself.relu = nn.ReLU()branch_main_1 = [GroupConv(in_channels=inp, out_channels=mid_channels,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=1 if first_group else group),nn.BatchNorm2d(mid_channels),nn.ReLU(),]branch_main_2 = [nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,pad_mode='pad', padding=pad, group=mid_channels,weight_init='xavier_uniform', has_bias=False),nn.BatchNorm2d(mid_channels),GroupConv(in_channels=mid_channels, out_channels=outputs,kernel_size=1, stride=1, pad_mode="pad", pad=0,groups=group),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.SequentialCell(branch_main_1)self.branch_main_2 = nn.SequentialCell(branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')def construct(self, old_x):left = old_xright = old_xout = old_xright = self.branch_main_1(right)if self.group > 1:right = self.channel_shuffle(right)right = self.branch_main_2(right)if self.stride == 1:out = self.relu(left + right)elif self.stride == 2:left = self.branch_proj(left)out = ops.cat((left, right), 1)out = self.relu(out)return outdef channel_shuffle(self, x):batchsize, num_channels, height, width = ops.shape(x)group_channels = num_channels // self.groupx = ops.reshape(x, (batchsize, group_channels, self.group, height, width))x = ops.transpose(x, (0, 2, 1, 3, 4))x = ops.reshape(x, (batchsize, num_channels, height, width))return x

对应网络结构

在这里插入图片描述

模型训练

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracydef train():mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")net = ShuffleNetV1(model_size="2.0x", n_class=10)loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)min_lr = 0.0005base_lr = 0.05lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,batches_per_epoch*250,batches_per_epoch,decay_epoch=250)lr = Tensor(lr_scheduler[-1])optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)callback = [TimeMonitor(), LossMonitor()]save_ckpt_path = "./"config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)callback += [ckpt_callback]print("============== Starting Training ==============")start_time = time.time()# 由于时间原因,epoch = 5,可根据需求进行调整model.train(5, dataset, callbacks=callback)use_time = time.time() - start_timehour = str(int(use_time // 60 // 60))minute = str(int(use_time // 60 % 60))second = str(int(use_time % 60))print("total time:" + hour + "h " + minute + "m " + second + "s")print("============== Train Success ==============")if __name__ == '__main__':train()

在这里插入图片描述

shuffleNet的优缺点总结

优点

  • 轻量化设计:ShuffleNet通过减少参数和计算量,使得模型更适合在资源受限的设备上运行。
  • 高效的通道重排:ShuffleNet引入了一种高效的通道重排机制,称为"Shuffle",以提高模型的表达能力,同时保持参数数量较低。
  • 性能与速度的平衡:ShuffleNet在保持较低延迟的同时,提供了与较大模型相当的准确性,这使得它在实时应用中非常有用。
  • 适用于移动设备:由于其轻量化的特性,ShuffleNet非常适合在移动设备和嵌入式系统中使用。
  • 参数共享:ShuffleNet通过在网络中使用参数共享技术,进一步减少了模型的大小和计算成本。
  • 灵活性:ShuffleNet提供了不同版本的模型,允许研究者和开发者根据特定应用的需求选择合适的模型大小。

不足

  • 准确率折损:由于模型的轻量化设计,ShuffleNet在某些复杂任务上可能无法达到最高精度。
  • 特定任务的局限性:在一些需要非常高精度或对模型容量要求较高的任务中,ShuffleNet可能不是最佳选择。
  • 对输入尺寸敏感:ShuffleNet的性能可能对输入图像的尺寸较为敏感,这可能需要在设计网络时进行额外的考虑。

此章节学习到此结束,感谢昇思平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/42438.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国内采用docker部署open-metadata

背景 最近看看开源的元数据管理项目,比较出名点的有open-metadata、datahub、OpenLineage、atlas。 open-metadata有1千多的贡献者,4.8K的stars,社区现在也比较活跃,支持的数据库类型还蛮多,基本市面上常见的都有支持…

【每日一练】python三目运算符的用法

""" 三目运算符与基础运算的对比 """ a 1 b 2#1.基础if运算判断写法: if a > b:print("基础判断输出:a大于b") else:print("基础判断输出: a不大于b")#2.三目运算法判断:…

揭秘IP:从虚拟地址到现实世界的精准定位

1.IP地址介绍 1.内网 IP 地址(私有 IP 地址) 内网 IP 地址,即私有 IP 地址,是在局域网(LAN)内部使用的 IP 地址。这些地址不会在公共互联网中路由,因此可以在多个局域网中重复使用。私有 IP 地…

股票Level-2行情是什么,应该怎么使用,从哪里获取数据

行情接入方法 level2行情websocket接入方法-CSDN博客 相比传统的股票行情,Level-2行情为投资者打开了更广阔的视野,不仅限于买一卖一的表面数据,而是深入到市场的核心,提供了十档乃至千档的行情信息(沪市十档&#…

STM32第十六课:WiFi模块的配置及应用

文章目录 需求一、WiFi模块概要二、配置流程1.配置通信串口,引脚和中断2.AT指令3.发送逻辑编写 三、需求实现代码总结 需求 完成WiFi模块的配置,使其最终能和服务器相互发送消息。 一、WiFi模块概要 本次使用的WiFi模块为ESP-12F模块(安信可&#xf…

【LLM第8篇】Delta Tuning

如何对large-scale PLM进行调整呢? 一个有效的方式是delta tuning;只更新PLM中的一小部分参数,其它参数不动。 把解决任务的能力具象化成delta object这样的参数,只需要几十兆参数存储。 过去模型参数是随机的,现在预…

【MySQL】逻辑架构与存储引擎

一、逻辑架构 1、MySQL逻辑架构 我们可以根据上图来对sql的执行过程进行分析 第一步:客户端与服务器建立一个连接,从连接池中分配一个线程处理SQL语句第二步:SQL接口接受SQL指令第三步:如果是5.7版本,就会先去缓存中…

Python字符串处理常用的30种操作

我们平时编写代码时,经常需要对字符串进行处理,本文详细介绍Python处理字符串常用的30种操作,并给出了对应的代码。 分割 使用split()方法将字符串按照指定的分隔符进行分割。 s "Hello,World" result s.split(","…

国产AI芯片被撕下遮羞布,宁买阉割八成性能的NVIDIA,也不买国产

曾经有传言指有国产AI芯片大受欢迎,还卖出了100万片,不过半年多时间过去,海外分析机构指出国内的互联网企业纷纷抢购NVIDIA阉割八成性能的H20,至于国产AI芯片则不获欢迎。 导致如此结果,在于NVIDIA拥有许多独特的优势&…

论文略读: LLaMA Pro: Progressive LLaMA with Block Expansion

ACL 2024 人类通常在不损害旧技能的情况下获得新技能 然而,对于大型语言模型(LLMs),例如从LLaMA到CodeLLaMA,情况正好相反。深度学习笔记:灾难性遗忘-CSDN博客——>论文提出了一种用于LLMs的新的预训练…

Nettyの源码分析

本篇为Netty系列的最后一篇,按照惯例会简单介绍一些Netty相关核心源码。 1、Netty启动源码分析 代码就使用最初的Netty服务器案例,在bind这一行打上断点,观察启动的全过程: 由于某些方法的调用链过深,节约篇幅&#xf…

昇思MindSpore学习笔记4-03生成式--Diffusion扩散模型

摘要: 记录昇思MindSpore AI框架使用DDPM模型给图像数据正向逐步添加噪声,反向逐步去除噪声的工作原理和实际使用方法、步骤。 一、概念 1. 扩散模型Diffusion Models DDPM(denoising diffusion probabilistic model) (无)条件…

【嵌入式DIY实例-ESP8266篇】-LCD ST7735显示BMP280传感器数据

LCD ST7735显示BMP280传感器数据 文章目录 LCD ST7735显示BMP280传感器数据1、硬件准备与接线2、代码实现本文介绍如何将 ESP8266 NodeMCU 板 (ESP-12E) 与 Bosch Sensortec 的 BMP280 气压和温度传感器连接。 NodeMCU 微控制器 (ESP8266EX) 从 BMP280 传感器读取温度和压力值,…

普通Java工程如何在代码中引用docker-compose.yml中的environment值

文章目录 一、概述二、常规做法1. 数据库配置分离2. 代码引用配置3. 编写启动类4. 支持打包成可执行包5. 支持可执行包打包成docker镜像6. docker运行 三、存在问题分析四、改进措施1. 包含environment 变量的编排文件2. 修改读取配置文件方式3. 为什么可以这样做 五、运行效果…

python库(6):Pygments库

1 Pygments介绍 在软件开发和文档编写中,代码的可读性是至关重要的一环。无论是在博客文章、技术文档还是教程中,通过代码高亮可以使程序代码更加清晰和易于理解。而在Python世界中,Pygments库就是这样一个强大的工具,它能够将各…

ValueError: Expected EmbeddingFunction.__call__ to have the following signature

题意: 使用 langchain 时,特别是在定义或调用嵌入函数(Embedding Function)时,签名(函数的参数列表和返回类型)不符合预期 问题背景: When I try to pass a Chroma Client to Lang…

搭建论坛和mysql数据库安装和php安装

目录 概念 步骤 安装mysql8.0.30 安装php 安装Discuz 概念 搭建论坛的架构: lnmpDISCUZ l 表示linux操作系统 n 表示nginx前端页面的web服务 m 表示 mysql 数据库 用来保存用户和密码以及论坛的相关内容 p 表示php 动态请求转发的中间件 步骤 &#xff…

【C++深度探索】:继承(定义赋值兼容转换作用域派生类的默认成员函数)

✨ 愿随夫子天坛上,闲与仙人扫落花 🌏 📃个人主页:island1314 🔥个人专栏:C学习 🚀 欢迎关注:👍点赞…

CVE-2024-0603 漏洞复现

CVE-2024-0603 源码:https://gitee.com/dazensun/zhicms 开题: CVE-2024-0603描述:ZhiCms up to 4.0版本的文件app/plug/controller/giftcontroller.php中存在一处未知漏洞。攻击者可以通过篡改参数mylike触发反序列化,从而远程…

python脚本“文档”撰写——“诱骗”ai撰写“火火的动态”python“自动”脚本文档

“火火的动态”python“自动”脚本文档,又从ai学习搭子那儿“套”来,可谓良心质量👍👍。 (笔记模板由python脚本于2024年07月07日 15:15:33创建,本篇笔记适合喜欢钻研python和页面源码的coder翻阅) 【学习的细节是欢悦…