学习注意力机制并将其应用到网络中

什么是注意力机制

注意力机制的核心重点就是让网络关注到它更需要关注的地方

当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为重要。

注意力机制就是实现网络自适应注意的一个方式。

注意力机制可以分为通道注意力机制空间注意力机制,以及二者的结合。

通道注意力机制关注的是某些重要的通道,空间注意力机制关注的是图片中某些重要的区域

注意力机制的实现方式

在深度学习中,常见的注意力机制的实现方式有SENet,CBAM,ECA等等。

1.SENet的实现

SENet是通道注意力机制的典型实现。
对于输入进来的特征层,我们关注其每一个通道的权重,对于SENet而言,其重点是获得输入进来的特征层,每一个通道的权值。利用SENet,我们可以让网络关注它最需要关注的通道

其具体实现方式就是:
1、对输入进来的特征层进行全局平均池化
2、然后进行两次全连接,第一次全连接神经元个数较少,第二次全连接神经元个数和输入特征层相同。
3、在完成两次全连接后,我们再取一次Sigmoid将值固定到0-1之间,此时我们获得了输入特征层每一个通道的权值(0-1之间)。
4、在获得这个权值后,我们将这个权值乘上原输入特征层即可。

实现代码:

def se_block(input_feature, ratio=16, name=""):channel = input_feature._keras_shape[-1]se_feature = GlobalAveragePooling2D()(input_feature)se_feature = Reshape((1, 1, channel))(se_feature)se_feature = Dense(channel // ratio,activation='relu',kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "se_block_one_"+str(name))(se_feature)se_feature = Dense(channel,kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "se_block_two_"+str(name))(se_feature)se_feature = Activation('sigmoid')(se_feature)se_feature = multiply([input_feature, se_feature])return se_feature

2.CBAM的实现

CBAM将通道注意力机制和空间注意力机制进行一个结合,相比于SENet只关注通道的注意力机制可以取得更好的效果。CBAM会对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理。
通道注意力机制的实现可以分为两个部分,我们会对输入进来的单个特征层,分别进行全局平均池化和全局最大池化。之后对平均池化和最大池化的结果,利用共享的全连接层进行处理,我们会对处理后的两个结果进行相加,然后取一个sigmoid,此时我们获得了输入特征层每一个通道的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可

空间注意力机制的实现:我们会对输入进来的特征层,在每一个特征点的通道上取最大值和平均值。之后将这两个结果进行一个堆叠,利用一次通道数为1的卷积调整通道数,然后取一个sigmoid,此时我们获得了输入特征层每一个特征点的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。

实现代码如下:

def channel_attention(input_feature, ratio=8, name=""):channel = input_feature._keras_shape[-1]shared_layer_one = Dense(channel//ratio,activation='relu',kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "channel_attention_shared_one_"+str(name))shared_layer_two = Dense(channel,kernel_initializer='he_normal',use_bias=False,bias_initializer='zeros',name = "channel_attention_shared_two_"+str(name))avg_pool = GlobalAveragePooling2D()(input_feature)    max_pool = GlobalMaxPooling2D()(input_feature)avg_pool = Reshape((1,1,channel))(avg_pool)max_pool = Reshape((1,1,channel))(max_pool)avg_pool = shared_layer_one(avg_pool)max_pool = shared_layer_one(max_pool)avg_pool = shared_layer_two(avg_pool)max_pool = shared_layer_two(max_pool)cbam_feature = Add()([avg_pool,max_pool])cbam_feature = Activation('sigmoid')(cbam_feature)return multiply([input_feature, cbam_feature])def spatial_attention(input_feature, name=""):kernel_size = 7cbam_feature = input_featureavg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)concat = Concatenate(axis=3)([avg_pool, max_pool])cbam_feature = Conv2D(filters = 1,kernel_size=kernel_size,strides=1,padding='same',kernel_initializer='he_normal',use_bias=False,name = "spatial_attention_"+str(name))(concat)	cbam_feature = Activation('sigmoid')(cbam_feature)return multiply([input_feature, cbam_feature])def cbam_block(cbam_feature, ratio=8, name=""):cbam_feature = channel_attention(cbam_feature, ratio, name=name)cbam_feature = spatial_attention(cbam_feature, name=name)return cbam_feature

3、ECA的实现
ECANet是也是通道注意力机制的一种实现形式。ECANet可以看作是SENet的改进版
ECANet的作者认为SENet对通道注意力机制的预测带来了副作用,捕获所有通道的依赖关系是低效并且是不必要的
ECA模块的思想是非常简单的,它去除了原来SE模块中的全连接层,直接在全局平均池化之后的特征上通过一个1D卷积进行学习

既然使用到了1D卷积,那么1D卷积的卷积核大小的选择就变得非常重要了,了解过卷积原理的同学很快就可以明白,1D卷积的卷积核大小会影响注意力机制每个权重的计算要考虑的通道数量

实现代码如下:

def eca_block(input_feature, b=1, gamma=2, name=""):channel = input_feature._keras_shape[-1]kernel_size = int(abs((math.log(channel, 2) + b) / gamma))kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1avg_pool = GlobalAveragePooling2D()(input_feature)x = Reshape((-1,1))(avg_pool)x = Conv1D(1, kernel_size=kernel_size, padding="same", name = "eca_layer_"+str(name), use_bias=False,)(x)x = Activation('sigmoid')(x)x = Reshape((1, 1, -1))(x)output = multiply([input_feature,x])return output

开始应用:将注意力机制加入到YOLOv8中

1.找到conv.py文件

2.在conv.py中添加名字

3.在__init__.py中添加名字

4.在tasks.py文件中添加名字

5.在tasks.py中添加配置

在该函数中添加代码

添加的代码为:

elif m in {CBAM}:c1, c2 = ch[f], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, *args[1:]]

添加后的为:

6.打开yaml文件

7.尽量不要在这个文件中更改内容,我们可以自己创建一个yaml文件(my_yolov8_CBAM.yaml),然后将yolov8.yaml中的内容复制过来

8.在backbone中进行修改

from列中的-1表示应用上一层的参数、repeats列表示重复多少次、module列表示模型的名字、args列表示参数

9.第八点操作添加完后层数会改变,head部分需要进行相应的修改

修改前:

# YOLOv8.0n head
head:- [ -1, 1, nn.Upsample, [ None, 2, "nearest" ] ]- [ [ -1, 6 ], 1, Concat, [ 1 ] ] # cat backbone P4- [ -1, 3, C2f, [ 512 ] ] # 12- [ -1, 1, nn.Upsample, [ None, 2, "nearest" ] ]- [ [ -1, 4 ], 1, Concat, [ 1 ] ] # cat backbone P3- [ -1, 3, C2f, [ 256 ] ] # 15 (P3/8-small)- [ -1, 1, Conv, [ 256, 3, 2 ] ]- [ [ -1, 12 ], 1, Concat, [ 1 ] ] # cat head P4- [ -1, 3, C2f, [ 512 ] ] # 18 (P4/16-medium)- [ -1, 1, Conv, [ 512, 3, 2 ] ]- [ [ -1, 9 ], 1, Concat, [ 1 ] ] # cat head P5- [ -1, 3, C2f, [ 1024 ] ] # 21 (P5/32-large)- [ [ 15, 18, 21 ], 1, Detect, [ nc ] ] # Detect(P3, P4, P5)

修改后:

为什么都+1了?

举个例子,原来要连接第六层,加了注意力层后,原来的第六层就变成第七层,所以在Concat连接时需要修改相应的层数

至此,注意力机制已经插入,可以开始使用了

10.在根目录下新建一个main.py文件,代码如下:

from ultralytics import YOLOmodel = (YOLO("ultralytics/cfg/models/v8/my_yolov8_CBAM.yaml"))
model.train(**{'cfg': 'ultralytics/cfg/default.yaml'})

运行即可开始训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/11411.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

证明力引导算法forceatlas2为什么不是启发式算法

一、基本概念 吸引力 F a ( n i ) ∑ n j ∈ N c t d ( n i ) ω i , j d E ( n i , n j ) V i , j \displaystyle \bm{F}_a(n_i) \sum_{n_j \in \mathcal{N}_{ctd}(n_i)} \omega_{i,j} \; d_E(n_i,n_j) \bm{V}_{i,j} Fa​(ni​)nj​∈Nctd​(ni​)∑​ωi,j​dE​(ni​,nj​…

class常量池、运行时常量池和字符串常量池的关系

类常量池、运行时常量池和字符串常量池这三种常量池,在Java中扮演着不同但又相互关联的角色。理解它们之间的关系,有助于深入理解Java虚拟机(JVM)的内部工作机制,尤其是在类加载、内存分配和字符串处理方面。 类常量池…

NeurIPS‘24 截稿日期逼近 加拿大温哥华邀你共赴盛会

会议之眼 快讯 第38届NeurIPS24(Conference and Workshop on Neural Information Processing Systems)即神经信息处理系统研讨会将于 2024 年 12月9日-15日在加拿大温哥华会议中心举行! NeurIPS 每一年都是全球AI领域的一场盛宴,吸引着来自世界各地的顶…

5.10.8 Transformer in Transformer

Transformer iN Transformer (TNT)。具体来说,我们将局部补丁(例如,1616)视为“视觉句子”,并将它们进一步划分为更小的补丁(例如,44)作为“视觉单词”。每个单词的注意力将与给定视…

信号和槽基本概念

🐌博主主页:🐌​倔强的大蜗牛🐌​ 📚专栏分类:QT❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、概述 二、信号的本质 三、槽的本质 一、概述 在 Qt 中,用户和控件的每次交互过程称…

Bootloader+升级方案

随着设备的功能越来越强大,系统也越来越复杂,产品升级也成为了开发过程不可或缺的一道程序。在工程应用中,如何在不更改硬件的前提下通过软件的方式实现产品升级。通过Bootloader来实现固件的升级是一种极好的方式,Bootloader是单…

I2CKD : INTRA- AND INTER-CLASS KNOWLEDGE DISTILLATION FOR SEMANTIC SEGMENTATION

摘要 本文提出了一种新的针对图像语义分割的知识蒸馏方法,称为类内和类间知识蒸馏(I2CKD)。该方法的重点是在教师(繁琐模型)和学生(紧凑模型)的中间层之间捕获和传递知识。对于知识提取&#x…

12个乒乓球,有一个次品,不知轻重,用一台无砝码天平称三次,找出次品,告知轻重?

前言 B站上看到个视频:为什么有人不认可清北的学生大多是智商高的? 然后试了下,发现我真菜 自己的思路(失败) 三次称重要获取到12个乒乓球中那个是次品,我想着将12个小球编号,分为四组,每组…

yo!这里是socket网络编程相关介绍

目录 前言 基本概念 源ip&&目的ip 源端口号&&目的端口号 udp&&tcp初识 socket编程 网络字节序 socket常见接口 socket bind listen accept connect 地址转换函数 字符串转in_addr in_addr转字符串 套接字读写函数 recvfrom&&a…

Java入门基础学习笔记2——JDK的选择下载安装

搭建Java的开发环境: Java的产品叫JDK(Java Development Kit: Java开发者工具包),必须安装JDK才能使用Java。 JDK的发展史: LTS:Long-term Support:长期支持版。指的Java会对这些版…

pycharm报错Process finished with exit code -1073740791 (0xC0000409)

pycharm报错Process finished with exit code -1073740791 (0xC0000409) 各种垃圾文章(包括chatgpt产生的垃圾文章),没有给出具体的解决办法。 解决办法就是把具体报错信息显示出来,然后再去查。 勾选 然后再运行就能把错误显示…

MetaRTC-play拉流客户端代码分析

渲染使用opengl,音频播放使用alsa。 当点击播放按钮后,以此调用的类如下,开始建立rtc连接,AV解码,音频渲染,视频渲染。 如果想去除QT,改为cmake工程管理,去掉渲染部分即可。 下方是…

Linux---vim编辑器(续写)

5. vim正常模式命令集 插入模式 按「i」切换进入插入模式「insert mode」, 按“i”进入插入模式后是从光标当前位置开始输入文件; 按「a」进入插入模式后,是从目前光标所在位置的下一个位置开始输入文字; 按「o」进入插入模式…

从头开始学Spring—01Spring介绍和IOC容器思想

目录 1.Spring介绍 1.1Spring概述 1.2特性 1.3五大功能模块 2.IOC容器 2.1IOC思想 ①获取资源的传统方式 ②反转控制方式获取资源 ③DI 2.2IOC容器在Spring中的实现 ①BeanFactory ②ApplicationContext ③ApplicationContext的主要实现类 1.Spring介绍 1.1Sprin…

Linux系统一步一脚印式学习

Linux操作系统具有许多特点和优势。首先,它是开放源代码的,也就意味着任何人都可以对源代码进行查看和修改。其次,可以同时支持多个用户且可以同时执行多个任务,此外,Linux操作系统也非常稳定和安全。相对于其他操作系…

安全测试|常见SQL注入攻击方式、影响及预防

SQL注入 什么是SQL注入? SQL注入是比较常见的网络攻击方式之一,主要攻击对象是数据库,针对程序员编写时的疏忽,通过SQL语句,实现无账号登录,篡改数据库。 SQL注入简单来说就是通过在表单中填写包含SQL关键…

SSD-60S施耐德电机保护器EOCR-SSD

EOCR主要产品有电子式电动机保护继电器,电子式过电流继电器,电子式欠电流继电器,电子式欠电压继电器,其它保护装置,电流互感器。EOCR-SSD 10-60A电机保护器 系列型号: EOCRSSD-05SEOCRssD-30s EOCRSSD-60SEOCRSSD-0…

开源即时通讯IM框架 MobileIMSDK v6.5 发布

一、更新内容简介 本次更新为次要版本更新,进行了bug修复和优化升级(更新历史详见:码云 Release Notes、Github Release Notes)。 MobileIMSDK 可能是市面上唯一同时支持 UDPTCPWebSocket 三种协议的同类开源IM框架。轻量级、高…

8种常见的CMD命令

1.怎么打开CMD窗口 步骤1:winr 步骤2:在弹出的窗口输入cmd,然后点击确认,就会出现一个cmd的窗口 2.CMD的8种常见命令 2.1盘符名称冒号 说明:切换盘的路径 打开CMD窗口这里默认的是C盘的Users的27823路径底下&#xf…

基于微信小程序+JAVA Springboot 实现的【网上商城小程序】app+后台管理系统 (内附设计LW + PPT+ 源码+ 演示视频 下载)

项目名称 项目名称: 基于微信小程序的网上商城 项目技术栈 该项目采用了以下核心技术栈: 后端框架/库: Java, SSM框架数据库: MySQL前端技术: 微信开发者工具,微信小程序框架 项目展示 5.1 管理员服务…