YOLOv8改进 | 细节创新篇 | iAFF迭代注意力特征融合助力多目标细节涨点

一、本文介绍

本文给大家带来的改进机制是iAFF(迭代注意力特征融合),其主要思想是通过改善特征融合过程来提高检测精度。传统的特征融合方法如加法或串联简单,未考虑到特定对象的融合适用性。iAFF通过引入多尺度通道注意力模块(我个人觉得这个改进机制就算融合了注意力机制的求和操作),更好地整合不同尺度和语义不一致的特征。该方法属于细节上的改进,并不影响任何其它的模块,非常适合大家进行融合改进,单独使用也是有一定的涨点效果。

推荐指数:⭐⭐⭐⭐

涨点效果:⭐⭐⭐⭐

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备    

训练结果对比图-> 

目录

一、本文介绍

二、iAFF的基本框架原理

三、iAFF的核心代码

四、手把手教你添加iAFF

4.1 iAFF添加步骤

4.1.1 步骤一

4.1.2 步骤二

4.1.3 步骤三

五、C2f_iAFF的yaml文件和运行记录

5.1 C2f_iAFF的yaml文件

5.2 C2f_iAFF的训练过程截图 

六、本文总结


二、iAFF的基本框架原理

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转


iAFF的主要思想在于通过更精细的注意力机制来改善特征融合,从而增强卷积神经网络。它不仅处理了由于尺度和语义不一致而引起的特征融合问题,还引入了多尺度通道注意力模块,提供了一种统一且通用的特征融合方案。此外,iAFF通过迭代注意力特征融合来解决特征图初始整合可能成为的瓶颈。这种方法使得模型即使在层数或参数较少的情况下,也能取得到较好的效果。 

iAFF的创新点主要包括:

1. 注意力特征融合:提出了一种新的特征融合方式,利用注意力机制来改善传统的简单特征融合方法(如加和或串联)。

2. 多尺度通道注意力模块:解决了在不同尺度上融合特征时出现的问题,特别是语义和尺度不一致的特征融合问题。

3. 迭代注意力特征融合(iAFF):通过迭代地应用注意力机制来改善特征图的初步整合,克服了初步整合可能成为性能瓶颈的问题。

​ 

这张图片是关于所提出的AFF(注意力特征融合)和iAFF(迭代注意力特征融合)的示意图。图中展示了两种结构:

(a) AFF: 展示了一个通过多尺度通道注意力模块(MS-CAM)来融合不同特征的基本框架。特征图X和Y通过MS-CAM和其他操作融合,产生输出Z。

(b) iAFF: 与AFF类似,但添加了迭代结构。在这里,输出Z回馈到输入,与X和Y一起再次经过MS-CAM和融合操作,以进一步细化特征融合过程。

(这两种方法都是文章中提出的我仅使用了iAFF也就是更复杂的版本,大家对于AFF有兴趣的可以按照我的该法进行相似添加即可)


三、iAFF的核心代码

该代码的使用方式需要两个图片,有人去用其替换Concat操作,但是它的两个输入必须是相同shape,但是YOLOv8中我们Concat一般两个输入在图像宽高上都不一样,所以我用其替换Bottlenekc中的残差相加操作,算是一种比较细节上的创新。

import torch
import torch.nn as nndef autopad(k, p=None, d=1):  # kernel, padding, dilation"""Pad to 'same' shape outputs."""if d > 1:k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-sizeif p is None:p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-padreturn pclass Conv(nn.Module):"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):"""Initialize Conv layer with given arguments including activation."""super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):"""Apply convolution, batch normalization and activation to input tensor."""return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):"""Perform transposed convolution of 2D data."""return self.act(self.conv(x))class iAFF(nn.Module):'''多特征融合 iAFF'''def __init__(self, channels=64, r=2):super(iAFF, self).__init__()inter_channels = int(channels // r)# 本地注意力self.local_att = nn.Sequential(nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)# 全局注意力self.global_att = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),)# 第二次本地注意力self.local_att2 = nn.Sequential(nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)# 第二次全局注意力self.global_att2 = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)self.sigmoid = nn.Sigmoid()def forward(self, x, residual):xa = x + residualxl = self.local_att(xa)xg = self.global_att(xa)xlg = xl + xgwei = self.sigmoid(xlg)xi = x * wei + residual * (1 - wei)xl2 = self.local_att2(xi)xg2 = self.global_att(xi)xlg2 = xl2 + xg2wei2 = self.sigmoid(xlg2)xo = x * wei2 + residual * (1 - wei2)return xoclass AFF(nn.Module):'''多特征融合 AFF'''def __init__(self, channels=64, r=4):super(AFF, self).__init__()inter_channels = int(channels // r)self.local_att = nn.Sequential(nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)self.global_att = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)self.sigmoid = nn.Sigmoid()def forward(self, x, residual):xa = x + residualxl = self.local_att(xa)xg = self.global_att(xa)xlg = xl + xgwei = self.sigmoid(xlg)xo = 2 * x * wei + 2 * residual * (1 - wei)return xoclass C2f_iAFF(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))class Bottleneck(nn.Module):"""Standard bottleneck."""def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):"""Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, andexpansion."""super().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, k[0], 1)self.cv2 = Conv(c_, c2, k[1], 1, g=g)self.add = shortcut and c1 == c2self.iAFF = iAFF(c2)def forward(self, x):"""'forward()' applies the YOLO FPN to input data."""if self.add:results =  self.iAFF(x , self.cv2(self.cv1(x)))else:results = self.cv2(self.cv1(x))return resultsif __name__ == '__main__':x = torch.ones(8, 64, 32, 32)channels = x.shape[1]model = C2f_iAFF(channels, channels, True)output = model(x)print(output.shape)


四、手把手教你添加iAFF

4.1 iAFF添加步骤

4.1.1 步骤一

首先我们找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个py文件,名字可以根据你自己的习惯起,然后将iAFF的核心代码复制进去。

4.1.2 步骤二

之后我们找到'ultralytics/nn/tasks.py'文件,在其中注册我们的iAFF模块。

首先我们需要在文件的开头导入我们的iAFF模块,如下图所示->

4.1.3 步骤三

我们找到parse_model这个方法,可以用搜索也可以自己手动找,大概在六百多行吧。 我们找到如下的地方,然后将C2f_iAFF添加进去即可,模仿我添加即可。

到此我们就注册成功了,可以修改yaml文件中输入C2f_iAFF使用这个模块了。


五、C2f_iAFF的yaml文件和运行记录

5.1 C2f_iAFF的yaml文件

下面的添加C2f_iAFF是我实验结果的版本。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f_iAFF, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f_iAFF, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f_iAFF, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f_iAFF, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

5.2 C2f_iAFF的训练过程截图 

下面是添加了C2f_iAFF的训练截图。

大家可以看下面的运行结果和添加的位置所以不存在我发的代码不全或者运行不了的问题大家有问题也可以在评论区评论我看到都会为大家解答(我知道的)。


六、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/584886.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JDK9及之后版本使用 jlink 生成定制化的 JRE

许多java软件的运行需要依赖jre,在 jdk8 之后,不再提供默认的 jre,后续如果项目中还是想用 jre 的形式发布软件,那么可以使用 jlink 工具生成 jre。 一、jlink 命令详解 jlink 二、查看jdk中包含的所有模块 如果在 jdk 安装文件夹…

Hadoop之Yarn 详细教程

1、yarn 的基本介绍和产生背景 YARN 是 Hadoop2 引入的通用的资源管理和任务调度的平台,可以在 YARN 上运行 MapReduce、Tez、Spark 等多种计算框架,只要计算框架实现了 YARN 所定义的 接口,都可以运行在这套通用的 Hadoop 资源管理和任务调…

【经典算法】有趣的算法之---蚁群算法梳理

every blog every motto: You can do more than you think. 0. 前言 蚁群算法记录 1. 简介 蚁群算法(Ant Clony Optimization, ACO)是一种群智能算法,它是由一群无智能或有轻微智能的个体(Agent)通过相互协作而表现出智能行为,从而为求解复杂问题提供了一个新的可能性…

VSCode远程开发配置

目录 概要远程开发插件安装开始连接SSH无密码登录开发环境配置 概要 现在很多公司都是直接远程到服务器上写代码,使用远程开发,可以在与生产环境相同的环境中开发、测试和部署代码,减少因环境不同而导致的问题。当下VSCode远程开发是支持的比…

ClickHouse基础知识(六):ClickHouse的副本配置

副本的目的主要是保障数据的高可用性,即使一台 ClickHouse 节点宕机,那么也可以 从其他服务器获得相同的数据。 1. 副本写入流程 2. 配置步骤 ➢ 启动 zookeeper 集群 ➢ 在hadoop101的/etc/clickhouse-server/config.d目录下创建一个名为metrika.xml…

002文章解读与程序——中国电机工程学报EI\CSCD\北大核心《计及源荷不确定性的综合能源生产单元运行调度与容量配置两阶段随机优化》已提供下载资源

👆👆👆👆👆👆👆👆👆👆👆👆👆👆👆👆👆👆下载资源链接&#x1f4…

Collector收集器的高级用法

Collectors收集器的高级用法 场景1:获取关联的班级名称 原先如果需要通过关联字段拿到其他表的某个字段,只能遍历List匹配获取 for (Student student : studentList) {Long clazzId student.getClazzId();// 遍历班级列表,获取学生对应班级…

HarmonyOS4.0系统性深入开发08服务卡片架构

服务卡片概述 服务卡片(以下简称“卡片”)是一种界面展示形式,可以将应用的重要信息或操作前置到卡片,以达到服务直达、减少体验层级的目的。卡片常用于嵌入到其他应用(当前卡片使用方只支持系统应用,如桌…

鸿鹄电子招投标系统:基于Spring Boot、Mybatis、Redis和Layui的企业电子招采平台源码与立项流程

在数字化时代,企业需要借助先进的数字化技术来提高工程管理效率和质量。招投标管理系统作为企业内部业务项目管理的重要应用平台,涵盖了门户管理、立项管理、采购项目管理、采购公告管理、考核管理、报表管理、评审管理、企业管理、采购管理和系统管理等…

服务器被入侵后如何查询连接IP以及防护措施

目前越来越多的服务器被入侵,以及攻击事件频频的发生,像数据被窃取,数据库被篡改,网站被强制跳转到恶意网站上,网站在百度的快照被劫持等等的攻击症状层出不穷,在这些问题中,如何有效、准确地追…

使用Vscode远程debug报错找不到Module找不到File

1..报第一个错 提示我无法导入自己写的module 如图: 解决办法: stackoverflow上说的在launch.json中加了一条 env,就解决了。 "env": { "PYTHONPATH":"/home/zt/ge-sc-master/ge-sc-master"}, 2.解决完第一个…

软件测试/测试开发丨Python、pycharm 安装与环境配置

Python 安装与环境配置 1. Python 安装 版本推荐 3.10.0下载地址:www.python.org/downloads/w… 若需要安装旧版本,在页面下方选择对应版本即可,MacOS选择对应系统即可 图示下载windows 3.11.4版本 安装Python 执行安装程序,安…

numpy数组03-数组的计算

一.数组与数字之间进行计算 numpy中的数组与数字进行计算是广播形式,数组-*/数字,则数组中的每一个数字都会进行相应的四则运算。 1.1数组与数字之间的四则运算 示例代码如下: import numpy as npa np.arange(24) b a.reshape(4, 6) pr…

【Maven】<scope>provided</scope>

在Maven中,“provided”是一个常用的依赖范围,它表示某个依赖项在编译和测试阶段是必需的,但在运行时则由外部环境提供,不需要包含在最终的项目包中。下面是对Maven scope “provided”的详细解释: 编译和测试阶段可用…

帆软FineBi V6版本经验总结

帆软FineBi V6版本经验总结 BI分析出现背景 ​ 现在是一个大数据的时代,每时每刻都有海量的明细数据出现。这时大数据时代用户思维是:1、数据的爆炸式增长,人们比起明细数据,更在意样本的整体特征、相互关系。2、基于明细的“小…

数据结构之树 --- 二叉树

目录 定义二叉树的结构体 二叉树的遍历 递归遍历 非递归遍历 链式二叉树的实现 二叉树的功能接口 先序遍历创建二叉树 后序遍历销毁二叉树 先序遍历查找树中值为x的节点 层序遍历 上篇我们对二叉树的顺序存储堆进行了讲述,本文我们来看链式二叉树。 定…

SpringCloud(H版alibaba)框架开发教程之nacos做配置中心——附源码(2)

上篇主要讲了使用eureka,zk,nacos当注册中心 这篇内容是nacos配置中心 代码改动部分mysql驱动更新到8.0,数据库版本升级到了8.0,nacos版本更新到了2.x nacos2.x链接 链接:https://pan.baidu.com/s/11nObzgTjWisAfOp…

探秘交互设计:深入了解五大核心维度!

交互式设计是用户体验(UX)设计的重要组成部分。本文将解释什么是交互设计,并分享一些有用的交互设计模型,并简要描述交互设计师通常做什么。 如何解释交互设计 交互式设计可以用一个简单的术语来理解:它是用户和产品…

借贷协议 Tonka Finance:铭文资产流动性的新破局者

“Tonka Finance 是铭文赛道中首个借贷协议,它正在为铭文资产赋予捕获流动性的能力,并为其构建全新的金融场景。” 在 2023 年的 1 月,比特币 Ordinals 协议被推出后,包括 BRC20,Ordinals 等在内的系列铭文资产在包括比…

nginx源码分析-3

这一章内容讲述nginx中的事件是如何一步步添加到epoll实例中的。 在初始化http连接的函数ngx_http_init_connection中,nginx为http连接初始化了处理请求的回调函数,之后调用ngx_handle_read_event函数对可读数据进行处理。这里只为连接设置read而没有设…