主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2(包括完整代码+添加步骤+网络结构图)

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力!~🌈  

     目录

🚀1. 基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:block.py文件修改

🍀🍀步骤2:__init__.py文件修改

🍀🍀步骤3:tasks.py文件修改

🍀🍀步骤4:创建自定义yaml文件

🍀🍀步骤5:新建train.py文件

🍀🍀步骤6:模型训练测试

🚀1. 基础概念

ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。

ShuffleNetV2的主要特点包括:

  1. 分组卷积:通过将输入通道分成多个组,并在组内进行卷积操作,减少了计算量和参数数量。
  2. 逐点卷积:使用1x1的卷积核进行逐点卷积,用于调整通道数和特征图的维度。
  3. 通道重排:通过将输入特征图按通道进行重排,实现信息的混洗和交互,增强了特征的表达能力。
  4. 瓶颈结构:采用瓶颈结构,即先降维再升维,减少了计算量和参数数量。
  5. 网络设计:ShuffleNet V2通过堆叠多个ShuffleNet单元来构建整个网络,可以根据任务的需求进行不同层数和宽度的配置。

ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力。

shuffleNetV2这篇论文比较硬核,提出了不少新的思想,推荐大家可以看看论文原文。主要思想包括:

  • 模型的计算复杂度不能只看FLOPs,还需要参考一些其他的指标
  • 作者提出了4条如何设计高效网络的准则
  • 基于该准则提出了新的block设置

FLOPS网上有两种:FLOPS和 FLOPs

FLOPS:全大写,指每秒浮点运算次数,可以理解为计算的速度,是衡量硬件性能的一个指标 (硬件)
FLOPs:s小写,指浮点运算数,理解为计算量,可以用来衡量算法/模型的复杂度,(模型)在论文中常用GFLOPs(1 GFLOPs = 10^9FLOPs)

 ShuffleNetV2网络结构:

 原理图:

其中,a、b为ShuffleNetV1原理图,c、d为ShuffleNetV2原理图。

论文题目:《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》

论文地址:  https://arxiv.org/pdf/1807.11164.pdf

代码实现:  GitHub - megvii-model/ShuffleNet-Series 


🚀2.网络结构

本文的改进是基于YOLOv8,关于其网络结构具体如下图所示:

YOLOv8官方仓库地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

针对本文的改进,作者将所使用的含有预训练权重文件的YOLOv8完整源码进行了上传,大家可在我的“资源”中自行下载。  


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:block.py文件修改

步骤2:__init__.py文件修改

步骤3:tasks.py文件修改

步骤4:创建自定义yaml文件

步骤5:新建train.py文件

步骤6:模型训练测试


🚀4.改进方法

🍀🍀步骤1:block.py文件修改

在源码中找到block.py文件,具体位置是ultralytics/nn/modules/block.py,然后将ShuffleNetV2模块代码添加到block.py文件末尾位置。

ShuffleNetV2模块代码:

# ShuffleNetv2核心代码
# By CSDN 小哥谈
import torch
import torch.nn as nndef channel_shuffle(x, groups):batchsize, num_channels, height, width = x.data.size()channels_per_group = num_channels // groupsx = x.view(batchsize, groups, channels_per_group, height, width)x = torch.transpose(x, 1, 2).contiguous()x = x.view(batchsize, -1, height, width)return xclass CBRM(nn.Module):  # Conv BN ReLU Maxpool2ddef __init__(self, c1, c2):super(CBRM, self).__init__()self.conv = nn.Sequential(nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(c2),nn.ReLU(inplace=True),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)def forward(self, x):return self.maxpool(self.conv(x))class Shuffle_Block(nn.Module):def __init__(self, ch_in, ch_out, stride):super(Shuffle_Block, self).__init__()if not (1 <= stride <= 2):raise ValueError('illegal stride value')self.stride = stridebranch_features = ch_out // 2assert (self.stride != 1) or (ch_in == branch_features << 1)if self.stride > 1:self.branch1 = nn.Sequential(self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),nn.BatchNorm2d(ch_in),nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)self.branch2 = nn.Sequential(nn.Conv2d(ch_in if (self.stride > 1) else branch_features,branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),nn.BatchNorm2d(branch_features),nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)@staticmethoddef depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)def forward(self, x):if self.stride == 1:x1, x2 = x.chunk(2, dim=1)out = torch.cat((x1, self.branch2(x2)), dim=1)else:out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)out = channel_shuffle(out, 2)return out

再然后,在block.py文件最上方下图所示位置加入CBRMShuffle_Block

🍀🍀步骤2:__init__.py文件修改

在源码中找到__init__.py文件,具体位置是ultralytics/nn/modules/__init__.py

修改1:加入CBRMShuffle_Block,具体如下图所示:

修改2:加入CBRMShuffle_Block,具体如下图所示:

🍀🍀步骤3:tasks.py文件修改

在源码中找到tasks.py文件,具体位置是ultralytics/nn/tasks.py

修改1:在下图所示位置导入类名CBRMShuffle_Block

修改2:找到parse_model函数(736行左右),在下图中所示位置添加如下代码。

 # -------ShuffleNetv2------------elif m in [CBRM, Shuffle_Block]:c1, c2 = ch[f], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]# --------------------------------

具体添加位置如下图所示:

🍀🍀步骤4:创建自定义yaml文件

在源码ultralytics/cfg/models/v8目录下创建yaml文件,并命名为:yolov8_ShuffleNetV2.yaml。具体如下图所示:

yolov8_ShuffleNetV2.yaml文件完整代码如下所示:

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [ -1, 1, CBRM, [ 32 ] ] # 0-P2/4- [ -1, 1, Shuffle_Block, [ 128, 2 ] ]  # 1-P3/8- [ -1, 3, Shuffle_Block, [ 128, 1 ] ]  # 2- [ -1, 1, Shuffle_Block, [ 256, 2 ] ]  # 3-P4/16- [ -1, 7, Shuffle_Block, [ 256, 1 ] ]  # 4- [ -1, 1, Shuffle_Block, [ 512, 2 ] ]  # 5-P5/32- [ -1, 3, Shuffle_Block, [ 512, 1 ] ]  # 6# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 3], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 9- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 2], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 12 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 15 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 6], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 18 (P5/32-large)- [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
🍀🍀步骤5:新建train.py文件

在源码根目录下新建train.py文件,文件完整代码如下所示:

from ultralytics import YOLO# Load a model
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml').load('yolov8n.pt')  # build from YAML and transfer weights# Train the model
model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml', epochs=100, imgsz=640)

注意:一定要用绝对路径,以防发生报错。

🍀🍀步骤6:模型训练测试

train.py文件,点击“运行”,在作者自制的安全帽佩戴检测数据集上,模型可以正常训练。

模型训练过程: 

模型训练结果: 

 关于本次改进所使用的安全帽佩戴检测数据集,已上传至我的“资源”中,大家可免费下载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/749569.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

预备知识:深入理解接口测试!

实验简介 随着移动互联网甚至物联网的触角深入到人们生活的每个场景&#xff0c;每个角落&#xff0c;伴随而来的便是企业对其软件系统接口定义和研发&#xff0c;以便于进行数据传输和交换。由此导致目前企业急需大量专职接口测试工程师&#xff0c;因为接口测试天然具备自动…

c++算法学习笔记 (8) 树与图部分

1.树与图的存储 &#xff08;1&#xff09;邻接矩阵 &#xff08;2&#xff09;邻接表 // 链式前向星模板&#xff08;数组模拟&#xff09; #include <iostream> #include <cstring> #include <algorithm> using namespace std; const int N 100010, M …

【RS422】基于未来科技FT4232HL芯片的多波特率串口通信收发实现

功能简介 串行通信接口常常用于在计算机和低速外部设备之间传输数据。串口通信存在多种标准&#xff0c;以RS422为例&#xff0c;它将数据分成多个位&#xff0c;采用异步通信方式进行传输。   本文基于Xilinx VCU128 FPGA开发板&#xff0c;对RS422串口通信进行学习。   根…

家具工厂5G智能制造数字孪生可视化平台,推进家具行业数字化转型

家具制造5G智能制造工厂数字孪生可视化平台&#xff0c;推进家具行业数字化转型。随着科技的飞速发展&#xff0c;家具制造业正迎来一场前所未有的数字化转型。在这场家具制造业转型中&#xff0c;5G智能制造工厂数字孪生可视化平台发挥着至关重要的作用。 5G智能制造工厂数字孪…

MySQL语法分类 DQL(3)排序查询

为了更好的学习这里给出基本表数据用于查询操作 create table student (id int, name varchar(20), age int, sex varchar(5),address varchar(100),math int,english int );insert into student (id,name,age,sex,address,math,english) values (1,马云,55,男,杭州,66,78),…

华为数通方向HCIP-DataCom H12-821题库(多选题:161-180)

第161题 以下关于IPv6优势的描述,正确的是哪些项? A、底层自身携带安全特性 B、加入了对自动配置地址的支持,能够无状态自动配置地址 C、路由表相比IPv4会更大,寻址更加精确 D、头部格式灵活,具有多个扩展头 【参考答案】ABD 【答案解析】 第162题 在OSPF视图下使用Filt…

降维算法之t-SNE (t-Distributed Stochastic Neighbor Embedding)

注意&#xff1a;本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 &#xff08;[www.aideeplearning.cn]&#xff09; t-SNE是一种用于探索高维数据结构的非线性降维技术。它特别适用于高维数据的可视化&#xff0c;因为它能够在低维空间中保留原始高维数据的局部结…

[WUSTCTF2020]朴实无华

查看robots.txt 找到/fAke_flagggg.php 显然这是个假的flag&#xff0c;但是我们在header处发现了fl4g.php 近来发现中文全部变成了乱码 插件转成utf8后正常显示 <?php header(Content-type:text/html;charsetutf-8); error_reporting(0); highlight_file(__file__);//leve…

Linux 系统调用函数fork、vfork、clone详解

文章目录 1 fork1.1 基本介绍1.2 fork实例1.2.1多个fork返回值1.2.2 C语言 fork与输出1.2.3 fork &#x1f4a3; 2 vfork2.1 基本介绍2.2 验证vfork共享内存 3 clone3.1 基本介绍3.2 clone使用 1 fork 1.1 基本介绍 #include <sys/types.h> #include <unistd.h>p…

PS学习-抠图-蒙版-冰块酒杯等透明物体

选中图&#xff0c;ctrlA 全选 ctrlC复制 创建一个蒙版图层 选中蒙版Alt 点击进入 ctrlv 复制 ctrli 反转 原图层 ctrldelete填充为白色 添加一个背景&#xff0c;这个方法通用 首选创建一个 拖到最底部 给它填充颜色 这个可能是我图片的原因。视频是这样做的

五子棋小游戏(sut实验报告)

实验目的 实现人与人或人与电脑进行五子棋对弈 实验内容 启动游戏&#xff0c;显示游戏参数设置界面&#xff0c;用户输入参数后进入游戏界面&#xff0c;显示棋盘及双方博弈过程&#xff0c;游戏过程中可选择退出游戏。判定一方获胜后结束本局游戏&#xff0c;可选择继续下…

Fiddler抓不到包

fiddler该设置的设置好之后&#xff0c;为啥就是抓不到包 以下都是以谷歌浏览器为例子 方法一&#xff1a; 将fidder证书导入到浏览器&#xff0c;设置搜索证书-->安全-->管理证书 这里可以看到将证书导入之后样子&#xff0c;名字为&#xff1a;DO_NOT_TRUST_Fiddler…

Linux操作系统裸机开发-环境搭建

一、配置SSH服务 1、下载安装ssh服务输入以下命令 sudo apt-get install nfs-kernel-server portmap2、建立一个供SSH服务使用的文件夹如以下命令 mkdir linux 3、完成前两步之后需要将其文件路径放到/etc/exports文件里输入以下命令&#xff1a; sudo vi /etc/esports 4.打…

线性回归 quickstart

构建一元一次方程 100个&#xff08;X, y &#xff09;&#xff0c;大概是’y3x4’ import numpy as npnp.random.seed(42) # to make this code example reproducible m 100 # number of instances X 2 * np.random.rand(m, 1) # column vector y 4 3 * X np.random…

最详细数据仓库项目实现:从0到1的电商数仓建设(数仓部分)

1、数仓 数据仓库是一个为数据分析而设计的企业级数据管理系统&#xff0c;它是一个系统&#xff0c;不是一个框架。可以独立运行的&#xff0c;不需要你参与&#xff0c;只要运行起来就可以自己运行。 数据仓库不是为了存储&#xff08;但是能存&#xff09;&#xff0c;而是…

创业板指399006行情数据API接口

# 测试&#xff1a;返回不超过10条数据&#xff08;2年历史&#xff09; https://tsanghi.com/api/fin/index/CHN/daily?tokendemo&ticker399006&order2Python示例 import requestsurl f"https://tsanghi.com/api/fin/index/CHN/daily?tokendemo&ticker399…

EtherCAT 开源主站 IGH 在 linux 开发板的移植和伺服通信测试

手边有一套正点原子linux开发板imax6ul&#xff0c;一直在吃灰&#xff0c;周末业余时间无聊&#xff0c;把EtherCAT的开源IGH主站移植到开发板上玩玩儿&#xff0c;搞点事情做。顺便学习研究下EtherCAT总线协议及其对伺服驱动器的运动控制过程。实验很有意思&#xff0c;这里总…

【JDBC编程】 Java程序操作数据库

目录 一、数据库编程的必备条件 二、什么是JDBC&#xff1f; 三、JDBC的使用 1. 准备工作 2. 建立数据库连接 2.1 加载驱动程序 2.2 数据库连接池技术 3. 正式操作 四、JDBC的局限性与MyBatis的优势 一、数据库编程的必备条件 编程语言&#xff0c;如Java&#xff0…

创业新手看过来!四招助你开启成功之旅

如果你每个月的薪资仅有几千块&#xff0c;还背负着债务的重担&#xff0c;家中的老少都期盼着你为他们撑起一片天&#xff0c;那么&#xff0c;你每日都可能为了如何打破这一困境而焦虑不安。不过&#xff0c;请稍安勿躁&#xff0c;今天我将为你提供四个建议&#xff0c;或许…

Transformer学习笔记(二)

一、文本嵌入层Embedding 1、作用&#xff1a; 无论是源文本嵌入还是目标文本嵌入&#xff0c;都是为了将文本中词汇的数字表示转变为向量表示&#xff0c;希望在这样的高维空间捕捉词汇间的关系。 二、位置编码器Positional Encoding 1、作用&#xff1a; 因为在Transformer…