YOLOv5、YOLOv8改进:RepVGG结构

1.简介 

论文参考:最新RepVGG结构: Paper

我们所说的“VGG式”指的是:

  1. 没有任何分支结构。即通常所说的plain或feed-forward架构。

  2. 仅使用3x3卷积。

  3. 仅使用ReLU作为激活函数。

主要创新点为结构重参数化。在训练时,网络的结构是多分支进行的,而在推理时则将分支的参数进行重参数化,合为一个分支来进行的,所以推理的速度要比多分支网络快很多,并且精度也比单分支的网络更高。

 1.1 RepVGG Block

整个RepVGG网络结构很简单,就是不断地堆叠RepVGG Block,所以了解了RepVGG Block就基本了解了整个RepVGG网络结构。下图中,左图为stride=2进行下采样时的RepVGG Block结构,右图为stride=1时的RepVGG Block结构。可以看到一般不进行下采样时,RepVGG Block有三个分支,分别是卷积核为3x3的主分支、卷积核为1x1的shortcut分支和只含BN层的shortcut分支。

这里有几个问题,

一、为什么多分支结构的精度会比单分支高?

二、为什么单分支结构会比多分支结构速度快了将近一倍?

先来说第一个问题,为什么多分支结构的精度会比单分支高?因为之前的模型像Inception系列、ResNet以及DenseNet等模型,我们能够发现这些模型都并行了多个分支。至少根据现有的一些经验来看,并行多个分支一般能够增加模型的表征能力。在论文的表6中,作者也做了个简单的消融实验,证明了增加分支是能够提升精度的。
 

接着是第二个问题,为什么单分支结构会比多分支结构速度快了将近一倍?根据论文3.1章节的内容可知,采用单路模型会更快、更省内存并且更加的灵活。

更快:主要是考虑到模型在推理时硬件计算的并行程度以及MAC(memory access cost),对于多分支模型,硬件需要分别计算每个分支的结果,有的分支计算的快,有的分支计算的慢,而计算快的分支计算完后只能干等着,等其他分支都计算完后才能做进一步融合,这样会导致硬件算力不能充分利用,或者说并行度不够高。而且每个分支都需要去访问一次内存,计算完后还需要将计算结果存入内存(不断地访问和写入内存会在IO上浪费很多时间)。

但是这里的说法又有一个问题,在结构重参数化之前,最耗时的是3x3卷积的时间,1x1卷积和恒等映射的时间比较短,需要等3x3的卷积时间结束之后,才继续走下一个模块,所以应该可以理解成,重参数化之前的时间其实就是一个3x3卷积操作的时间加上三个分支融合的时间呢?然而在重参数化之后,3x3卷积还是存在的,所以总的时间还是3x3卷积的时间,也就是说,结构重参数化之后所节省的时间仅仅只是一个三分支融合的时间,但是为什么最后的结果却节省了将近一倍的时间,这里不太理解。

更省内存:在论文的图3当中,作者举了个例子,如图(A)所示的Residual模块,假设卷积层不改变channel的数量,那么在主分支和shortcut分支上都要保存各自的特征图或者称Activation,那么在add操作前占用的内存大概是输入Activation的两倍,而图(B)的Plain结构占用内存始终不变。

更加灵活:作者在论文中提到了模型优化的剪枝问题,对于多分支的模型,结构限制较多剪枝很麻烦,而对于Plain结构的模型就相对灵活很多,剪枝也更加方便。

除此之外,在多分支转化成单路模型后很多算子进行了融合(比如Conv2d和BN融合),使得计算量变小了,而且算子减少后启动kernel的次数也减少了(比如在GPU中,每次执行一个算子就要启动一次kernel,启动kernel也需要消耗时间)。而且现在的硬件一般对3x3的卷积操作做了大量的优化,转成单路模型后采用的都是3x3卷积,这样也能进一步加速推理。

1.2结构重参数化

在了解了RepVGG Block后就到了这篇文章最重要的、也是最核心的部分,结构重参数化。也就是如何将三个分支的内容最后融合到一个主分支当中。

这个过程主要分为两步,

一、将三个分支中的卷积算子和BN算子都融合为卷积算子(一个卷积核加一个偏置的形式)

二、将三个分支上的卷积算子都化为3x3卷积核和偏置的形式,相加得到最终的主分支上的结果。

下图就能完美体现这两个过程融合Conv2d和BN,将三个分支上的卷积算子和BN算子都转化为卷积算子(包括卷积核和偏置)


因为Conv2d和BN两个算子都是做线性运算,所以可以融合成一个算子。如果不了解卷积层的计算过程以及BN的计算过程的话建议先了解后再看该部分的内容。这里还需要强调一点,融合是在网络训练完之后做的,所以现在讲的默认都是推理模式,注意BN在训练以及推理时计算方式是不同的。
 

2.在YOLOv7中加入RepVGG模块

2.1YOLOv7的yaml配置文件

代码
# YOLOv7 🚀, GPL-3.0 license
# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 1.0  # layer channel multiple# anchors
anchors:- [12,16, 19,36, 40,28]  # P3/8- [36,75, 76,55, 72,146]  # P4/16- [142,110, 192,243, 459,401]  # P5/32# yolov7 backbone by yoloair
backbone:# [from, number, module, args][[-1, 1, Conv, [32, 3, 1]],  # 0[-1, 1, Conv, [64, 3, 2]],  # 1-P1/2[-1, 1, Conv, [64, 3, 1]],[-1, 1, Conv, [128, 3, 2]],  # 3-P2/4 [-1, 1, RepVGGBlock, [128, 3, 2]], # 5-P4/16[-1, 1, Conv, [256, 3, 2]], [-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3], 1, Concat, [1]],  # 16-P3/8[-1, 1, Conv, [128, 1, 1]],[-2, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[-1, 1, Conv, [128, 3, 1]],[[-1, -3, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [512, 1, 1]],[-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]],[[-1, -3], 1, Concat, [1]],[-1, 1, Conv, [256, 1, 1]],[-2, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[-1, 1, Conv, [256, 3, 1]],[[-1, -3, -5, -6], 1, Concat, [1]],[-1, 1, Conv, [1024, 1, 1]],          [-1, 1, MP, []],[-1, 1, Conv, [512, 1, 1]],[-3, 1, Conv, [512, 1, 1]],[-1, 1, Conv, [512, 3, 2]],[[-1, -3], 1, Concat, [1]],[-1, 1, C3C2, [1024]],[-1, 1, Conv, [256, 3, 1]],]# yolov7 head by yoloair
head:[[-1, 1, SPPCSPC, [512]],[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[31, 1, Conv, [256, 1, 1]],[[-1, -2], 1, Concat, [1]],[-1, 1, C3C2, [128]],[-1, 1, Conv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[18, 1, Conv, [128, 1, 1]],[[-1, -2], 1, Concat, [1]],[-1, 1, C3C2, [128]],[-1, 1, MP, []],[-1, 1, Conv, [128, 1, 1]],[-3, 1, Conv, [128, 1, 1]],[-1, 1, Conv, [128, 3, 2]],[[-1, -3, 44], 1, Concat, [1]],[-1, 1, C3C2, [256]], [-1, 1, MP, []],[-1, 1, Conv, [256, 1, 1]],[-3, 1, Conv, [256, 1, 1]],[-1, 1, Conv, [256, 3, 2]], [[-1, -3, 39], 1, Concat, [1]],[-1, 3, C3C2, [512]],# 检测头 -----------------------------[49, 1, RepConv, [256, 3, 1]],[55, 1, RepConv, [512, 3, 1]],[61, 1, RepConv, [1024, 3, 1]],[[62,63,64], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)]

2.2 common.py配置

在./models/common.py文件中增加以下模块,直接复制即可        

class RepVGGBlock(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3,stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):super(RepVGGBlock, self).__init__()self.deploy = deployself.groups = groupsself.in_channels = in_channelspadding_11 = padding - kernel_size // 2self.nonlinearity = nn.SiLU()# self.nonlinearity = nn.ReLU()if use_se:self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)else:self.se = nn.Identity()if deploy:self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,stride=stride,padding=padding, dilation=dilation, groups=groups, bias=True,padding_mode=padding_mode)else:self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else Noneself.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=groups)self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,padding=padding_11, groups=groups)# print('RepVGG Block, identity = ', self.rbr_identity)
def switch_to_deploy(self):if hasattr(self, 'rbr_1x1'):kernel, bias = self.get_equivalent_kernel_bias()self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)self.rbr_reparam.weight.data = kernelself.rbr_reparam.bias.data = biasfor para in self.parameters():para.detach_()self.rbr_dense = self.rbr_reparam# self.__delattr__('rbr_dense')self.__delattr__('rbr_1x1')if hasattr(self, 'rbr_identity'):self.__delattr__('rbr_identity')if hasattr(self, 'id_tensor'):self.__delattr__('id_tensor')self.deploy = Truedef get_equivalent_kernel_bias(self):kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasiddef _pad_1x1_to_3x3_tensor(self, kernel1x1):if kernel1x1 is None:return 0else:return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])def _fuse_bn_tensor(self, branch):if branch is None:return 0, 0if isinstance(branch, nn.Sequential):kernel = branch.conv.weightrunning_mean = branch.bn.running_meanrunning_var = branch.bn.running_vargamma = branch.bn.weightbeta = branch.bn.biaseps = branch.bn.epselse:assert isinstance(branch, nn.BatchNorm2d)if not hasattr(self, 'id_tensor'):input_dim = self.in_channels // self.groupskernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)for i in range(self.in_channels):kernel_value[i, i % input_dim, 1, 1] = 1self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)kernel = self.id_tensorrunning_mean = branch.running_meanrunning_var = branch.running_vargamma = branch.weightbeta = branch.biaseps = branch.epsstd = (running_var + eps).sqrt()t = (gamma / std).reshape(-1, 1, 1, 1)return kernel * t, beta - running_mean * gamma / stddef forward(self, inputs):if self.deploy:return self.nonlinearity(self.rbr_dense(inputs))if hasattr(self, 'rbr_reparam'):return self.nonlinearity(self.se(self.rbr_reparam(inputs)))if self.rbr_identity is None:id_out = 0else:id_out = self.rbr_identity(inputs)return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

其中缺少的C3C2模块看本专栏相关文章,你也可以改回原来的模块

2.3 yolo.py配置

然后找到./models/yolo.py文件下里的parse_model函数,将类名加入进去
在 models/yolo.py文件夹下

  • parse_model函数中
  • for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
  • 对应位置 下方只需要增加 RepVGGBlock模块
elif m is RepVGGBlock:c1, c2 = ch[f], args[0]if c2 != no:  # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/94792.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

P4727 [HNOI2009] 图的同构计数

B u r n s i d e : ∣ X / G ∣ 1 ∣ G ∣ ∑ g ∈ G ∣ X ( g ) ∣ Burnside:|X/G|\frac{1}{|G|}\sum\limits_{g\in G}|X^{(g)}| Burnside:∣X/G∣∣G∣1​g∈G∑​∣X(g)∣ 该题: ∣ X / G ∣ 1 ∣ G ∣ ∑ b 2 k Π ( b i …

Visopsys 0.92 发布

Visopsys 是一个 PC 机的操作系统,系统小型、快速而且开源。有着丰富的图形界面、抢先式多任务机制以及支持虚拟内存。Visopsys 视图兼容很多操作系统,但并不是他们的克隆版本。Visopsys 0.92 现已发布,此维护版本引入了多任务处理程序、文件…

JS里实现判断条件不通过退出整个循环

总结,使用for...in循环实现。 如果你需要在JavaScript中使用类似break的行为,你可以使用for...in循环。下面是一个例子: let arr [1, 2, 3, 4, 5];for (let index in arr) {if (arr[index] > 3) {console.log(找到了大于3的元素&#xf…

3126: 【C2】【结构体】小明做调查

题目描述 小明的班级同学们很不喜欢小明,小明想做点什么改变同学们的印象,现在小明想调查班级中每个同学的生日,并按照从大到小的顺序排序。但和蔼的编程老师总让他去放鸭子了,没有时间,所以请你帮他排序。 输入 有…

二叉树题目:路径总和 II

文章目录 题目标题和出处难度题目描述要求示例数据范围 前言解法一思路和算法代码复杂度分析 解法二思路和算法代码复杂度分析 题目 标题和出处 标题:路径总和 II 出处:113. 路径总和 II 难度 4 级 题目描述 要求 给你二叉树的根结点 root \tex…

MySQL找回误删的数据,数据恢复

原创作品,未经同意,请勿转载;允许复制链接,对原文直接进行转发。 原创作者玉龙有着十几年大厂软件开发工作经验, 目前自由职业, 欢迎业务洽谈。 误删了几十万条MySQL记录, 要如何找回物理删除的…

我的第一个react.js 的router工程

react.js 开发的时候,都是针对一个页面的,多个页面就要用Router了,本文介绍我在vscode 下的第一个router 工程。 我在学习react.js 前端开发,学到router 路由的时候有点犯难了。经过1-2天的努力,终于完成了第一个工程…

使用Pytorch构建神经网络

构建神经网络的典型流程 定义一个拥有可学习参数的神经网络遍历训练数据集处理输入数据使其流经神经网络计算损失值将网络参数的梯度进行反向传播以一定的规则更新网络的权重 我们首先定义一个Pytorch实现的神经网络: # 导入若干工具包 import torch import torch.nn as nn …

亲,您的假期余额已经严重不足了......

引言 大家好,我是亿元程序员,一位有着8年游戏行业经验的主程。 转眼八天长假已经接近尾声了,今天来总结一下大家的假期,聊一聊假期关于学习的看法,并预估一下大家节后大家上班时的样子。 1.放假前一天 即将迎来八天…

基于Web安全的Python编程(1)

目录 一、http协议基础知识介绍 1、http协议分类 2、请求方法 3、什么是URL 4、请求头 5、响应状态码 二、常用Python库、函数、操作 三、http常用请求方法 1、不带参请求 2、带参数请求(get和post存在细微区别) 四、http响应属性获取 1、获取…

计算机网络(六):应用层

参考引用 计算机网络微课堂-湖科大教书匠计算机网络(第7版)-谢希仁 1. 应用层概述 应用层是计算机网络体系结构的最顶层,是设计和建立计算机网络的最终目的,也是计算机网络中发展最快的部分 早期基于文本的应用 (电子邮件、远程登…

【古谷彻】算法模板(更新ing···)

目录 一、数学 1、逆元 (一)费马小定理/欧拉定理(快速幂) 2、组合数 (1)求组合数C(n,m) 方法一:阶乘+逆元+快速幂求组合数 方法二:记忆化搜索 方法三:递推公式 (2)组合数求概率 3、高精度sqrt (1)二分法 (2)递加递减 4、快速幂 5、欧拉函数 方法一:…

分布式架构篇

1、微服务 微服务架构风格,就像是把一个单独的应用程序开发为一套小服务,每个服务运行在自己的进程中,并使用轻量级机制通信,通常是 HTTP API。这些服务围绕业务能力来构建,并通过完全自动化部署机制来独立部署。这些…

Spring 原理

它是一个全面的、企业应用开发一站式的解决方案,贯穿表现层、业务层、持久层。但是 Spring仍然可以和其他的框架无缝整合。 1 Spring 特点 轻量级控制反转面向切面容器框架集合 2 Spring 核心组件 3 Spring 常用模块 4 Spring 主要包 5 Spring 常用注解 bean…

第十七章:Java连接数据库jdbc(java和myql数据库连接)

1.进入命令行:输入cmd,以管理员身份运行 windowsr 2.登录mysql 3.创建库和表 4.使用Java命令查询数据库操作 添加包 导入包的快捷键 选择第四个 找到包的位置 导入成功 创建java项目 二:连接数据库: 第一步:注册驱动…

设计模式 - 策略模式

目录 一. 前言 二. 实现 一. 前言 策略模式 (Strategy Pattern) 是指对一系列的算法定义,并将每一个算法封装起来,而且使它们还可以相互替换。此模式让算法的变化独立于使用算法的客户。 与状态模式的比较 状态模式的类图和策略模式类似,并…

VUE3照本宣科——内置指令与自定义指令及插槽

VUE3照本宣科——内置指令与自定义指令及插槽 前言一、内置指令1.v-text2.v-html3.v-show4.v-if5.v-else6.v-else-if7.v-for8.v-on9.v-bind10.v-model11.v-slot12.v-pre13.v-once14.v-memo15.v-cloak 二、自定义指令三、插槽1.v-slot2.useSlots3.defineSlots() 前言 &#x1f…

Windows下启动freeRDP并自适应远端桌面大小

几个二进制文件 xfreerdp # Linux下的,an X11 Remote Desktop Protocol (RDP) client which is part of the FreeRDP project wfreerdp.exe # Windows下的,freerdp2.0 主程序,freerdp3.0将废弃 sdl-freerdp.exe # Windows下的&…

【AI视野·今日NLP 自然语言处理论文速览 第四十三期】Thu, 28 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Thu, 28 Sep 2023 Totally 38 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Cross-Modal Multi-Tasking for Speech-to-Text Translation via Hard Parameter Sharing Authors Brian Yan,…

STM32CubeMX学习笔记-USB接口使用(CDC虚拟串口)

STM32CubeMX学习笔记-USB接口使用(CDC虚拟串口) 一、USB简介二、新建工程1. 打开 STM32CubeMX 软件,点击“新建工程”2. 选择 MCU 和封装3. 配置时钟4. 配置调试模式 三、USB3.1 参数配置3.3 配置时钟3.4 USB Device 四、生成代码五、查看端口…