【保姆级教程|YOLOv8添加注意力机制】【1】添加SEAttention注意力机制步骤详解、训练及推理使用

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于YOLOv8深度学习的行人跌倒检测系统】
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】
11.【基于YOLOv8深度学习的安全帽目标检测系统】12.【基于YOLOv8深度学习的120种犬类检测与识别系统】
13.【基于YOLOv8深度学习的路面坑洞检测系统】14.【基于YOLOv8深度学习的火焰烟雾检测系统】
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】16.【基于YOLOv8深度学习的舰船目标分类检测系统】
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】18.【基于YOLOv8深度学习的血细胞检测与计数系统】
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】22.【基于YOLOv8深度学习的路面标志线检测与识别系统】
22.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】23.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】
24.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】25.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】
26.【基于YOLOv8深度学习的人脸面部表情识别系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】

《------正文------》

前言

在这里插入图片描述
SENet是一种新颖的架构单元,通过明确建模通道之间的相互依赖性,自适应地重新校准通道特征响应。本文详细介绍了如何在YOLOv8的主干网络中添加SEAttention注意力机制,并且使用修改后的yolov8网络结构进行目标检测训练与推理。本文提供了所有源码免费供小伙伴们学习参考,需要的可以通过文末方式自行下载。

本文使用的ultralytics版本为:ultralytics == 8.0.227

目录

  • 前言
  • 1. SENet简介
    • 1.1 SENet网络结构
    • 1.2 性能对比
  • 2.在YOLOV8主干中添加SEAttention注意力
    • 第1步:新建SEAttention模块并导入
    • 第2步:修改tasks.py部分代码
    • 第3步:加载配置文件训练模型
    • 第4步:模型推理
  • 【源码免费获取】
  • 结束语

1. SENet简介

github地址:https://github.com/hujie-frank/SENet
paper地址:https://arxiv.org/pdf/1709.01507.pdf

摘要:卷积神经网络(CNN)的核心构建块是卷积运算符,它使网络能够通过在每一层的局部感受域内融合空间信息和通道信息来构建有信息量的特征。以前的研究广泛探讨了这种关系的空间部分,旨在通过增强特征层次结构中空间编码的质量来增强CNN的表征能力。在这项工作中,我们转而关注通道关系,并提出了一种新颖的架构单元,我们称之为“挤压激励”(SE)块,通过明确建模通道之间的相互依赖性,自适应地重新校准通道特征响应。我们展示了这些块可以堆叠在一起形成可以在不同数据集上极其有效地推广的SENet架构。我们进一步证明,SE块在现有最先进的CNN上带来了显著的性能改进,稍微增加了计算成本。挤压激励网络构成了我们在ILSVRC 2017分类竞赛中的首个提交,获得了第一名,并将前五错误率降低到2.251%,相对于2016年的获胜成绩改善了约25%。

论文亮点如下:
SE块是一种旨在通过使网络能够进行动态通道特征校准来改善网络表征能力的架构单元。广泛的实验表明了SENets的有效性,它们在多个数据集和任务上实现了最先进的性能。此外,SE块也为之前的架构无法充分建模通道特征依赖关系的问题提供了一些启示。我们希望这一洞见对于其他需要强有力的区分特征的任务也能有所帮助。最后,SE块产生的特征重要性值也可以在其他任务中发挥作用,例如网络剪枝以进行模型压缩。

1.1 SENet网络结构

在这里插入图片描述
在这里插入图片描述

1.2 性能对比

在这里插入图片描述
在这里插入图片描述

2.在YOLOV8主干中添加SEAttention注意力

第1步:新建SEAttention模块并导入

ultralytics/nn目录下,新建SEAttention.py文件,内容如下:
在这里插入图片描述

import numpy as np
import torch
from torch import nn
from torch.nn import initclass SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

然后在ultralytics/nn/tasks.py中,导入SEAttention模块。
在这里插入图片描述

第2步:修改tasks.py部分代码

修改在ultralytics/nn/task.py中的parse_model函数【作用是解析模型结构】:在解析的地方添加如下代码:
在这里插入图片描述

        elif m in {SEAttention}:args = [ch[f], *args]

然后创建SEAtt_yolov8.yaml文件,用于修改网络结构添加注意力,内容如下:【将注意力添加到自己想添加的层就行】,在这示例中我们是添加到了主干网络的最后面。
在这里插入图片描述

# Ultralytics YOLO ?, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9- [-1, 1, SEAttention, [16]]# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在这里插入图片描述
此处注意修改层数的变化,层数是从0开始数的,由于此处是添加到了第10层,因此后面层数都发生了变化。10层以后的相关层数都需要加1.具体修改内容如下:【左边是原始的yolov8.yaml文件,右边是新建的SEAtt_yolov8.yaml文件】
在这里插入图片描述

第3步:加载配置文件训练模型

运行训练代码train.py文件,内容如下:

#coding:utf-8
from ultralytics import YOLO# 加载预训练模型
# 添加注意力机制,SEAtt_yolov8.yaml 默认使用的是n。
# SEAtt_yolov8s.yaml,则使用的是s,模型。
model = YOLO("ultralytics/cfg/models/v8/SEAtt_yolov8n.yaml").load('yolov8n.pt')# Use the model
if __name__ == '__main__':# Use the modelresults = model.train(data='datasets/TomatoData/data.yaml', epochs=250, batch=4)  # 训练模型# 将模型转为onnx格式# success = model.export(format='onnx')

训练开始的时候,注意一下,打印出的网络结构是否有修改,如下图所示:模型训练开始时,打印的网络结构中显示SEAttention已经添加成功。
在这里插入图片描述

第4步:模型推理

模型训练完成后,我们使用训练好的模型对图片进行检测:

#coding:utf-8
from ultralytics import YOLO
import cv2# 所需加载的模型目录
# path = 'models/best2.pt'
path = 'runs/detect/train2/weights/best.pt'
# 需要检测的图片地址
img_path = "TestFiles/Riped tomato_20.jpeg"# 加载预训练模型
# conf	0.25	object confidence threshold for detection
# iou	0.7	intersection over union (IoU) threshold for NMS
model = YOLO(path, task='detect')# 检测图片
results = model(img_path)
res = results[0].plot()
res = cv2.resize(res,dsize=None,fx=0.5,fy=0.5,interpolation=cv2.INTER_LINEAR)
cv2.imshow("YOLOv8 Detection", res)
cv2.waitKey(0)

在这里插入图片描述

【源码免费获取】

为了小伙伴们能够,更好的学习实践,本文已将所有代码、示例数据集、论文等相关内容打包上传,供小伙伴们学习。获取方式如下:

关注下方名片G-Z-H:【阿旭算法与机器学习】,发送【yolov8改进】即可免费获取

在这里插入图片描述


结束语

关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

觉得不错的小伙伴,感谢点赞、关注加收藏哦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/621045.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot+thymeleaf实战遇到的问题

目录 一、控制台: 二、数据库查询异常: 三、前后端错误校验 四、在serviceImp中需要添加一个eq条件,表示和数据库中的哪个字段进行比较,否则会查出所有数据,导致500 五、使用流转换数据更简洁 六、重复报错&…

动态规划篇-03:打家劫舍

198、打家劫舍 状态转移方程 base case 边界问题就是:走到最后一间房子门口也没抢,那么最终抢到的金额为0 明确状态 “原问题和子问题中会变化的变量” 抢到的金额数就是状态,因为随着在每一件房子门口做选择,抢到的金额数会随…

大模型训练营Day3 基于 InternLM 和 LangChain 搭建你的知识库 作业

本篇记录大模型训练营第三次的作业,属实是拖延症本症患者。 主要步骤前面的安装各种包和依赖如前面作业一样,按照文档操作即可: 再按照文档进行各种克隆,把知识库复制到本地: 复制粘贴操作文档中的构建向量数据库的文…

七:Day08_任务调度

第一章 定时任务概述 在项目中开发定时任务应该一种比较常见的需求,在 Java 中开发定时任务主要有三种解决方案:一是使用JDK 自带的 Timer,二是使用 Spring Task,三是使用第三方组件 Quartz。 建议: 单体项目架构使用…

基于51单片机的智能热水器设计

需要全部文件请私信关注我!!! 基于51单片机的智能热水器设计 摘要一、绪论1.1 选题背景及意义1.2 完成目标与功能设计 二、硬件系统设计2.1 硬件完成要求2.2 方案选择2.3 电源电路设计2.4 键盘电路2.5 蜂鸣器报警电路2.6 温度检测电路2.7 红…

UR5机械臂控制

1.ros环境安装 快速安装命令:wget http://fishros.com/install -O fishros && . fishros 2.ur驱动安装 虚拟机Ubuntu16.04ros-kinetic控制真实UR5机械臂总结记录ros kinetic控制UR3机械臂 3.ur命令行控制 使用了 URScript 语言来描述机器人的运动指令&…

数学建模.斯皮尔曼相关系数

一、两种定义 二、用matlab计算 三、两种相关系数计算结果的对比 四、取检验值(临界值)分为两种情况 (1)小样本查表 (2)大样本 P值是大于检验值的概率 本文是学习清风网课后的总结,希望对大家有…

1000以内的质数,用python获取放到list1中,1000以内的斐波那契数,用python获取放到list2中,然后两个list画出曲线图

# -*- coding: utf-8 -*- import matplotlib.pyplot as plt # 获取1000以内的质数 def get_primes(n): primes [] for possiblePrime in range(2, n 1): # 假设数是质数 isPrime True for num in range(2, int(possiblePrime ** 0.5) 1): if possiblePrime % num …

自动驾驶车辆运动规划方法综述 - 论文阅读

本文旨在对自己的研究方向做一些记录,方便日后自己回顾。论文里面有关其他方向的讲解读者自行阅读。 参考论文:自动驾驶车辆运动规划方法综述 1 摘要 规划决策模块中的运动规划环节负责生成车辆的局部运动轨迹 ,决定车辆行驶质量的决定因素…

RabbitMQ如何保证消息不丢失?

RabbitMQ如何保证消息不丢失? 消息丢失的情况 生产者发送消息未到达交换机生产者发送消息未到达队列MQ宕机,消息丢失消费者服务宕机,消息丢失 生产者确认机制 解决的问题:publisher confirm机制来避免消息发送到MQ过程中消失。…

数据结构排序——计数排序和排序总结(附上912. 排序数组讲解)

数据结构排序——计数排序和排序总结 现在常见算法排序都已讲解完成,今天就再讲个计数排序。再总结一下 文章目录 1.计数排序2.排序总结3.排序oj(排序数组)题目详情代码思路 1.计数排序 计数排序是一种非基于比较的排序算法,它通…

STL标准库与泛型编程(侯捷)笔记2

STL标准库与泛型编程(侯捷) 本文是学习笔记,仅供个人学习使用。如有侵权,请联系删除。 参考链接 Youbute: 侯捷-STL标准库与泛型编程 B站: 侯捷 - STL Github:STL源码剖析中源码 https://github.com/SilverMaple/STLSourceCo…

SLA(服务等级协议)

在硅谷一线大厂所维护的系统服务中,我们经常可以看见SLA这样的承诺。 例如,在谷歌的云计算服务平台Google Cloud Platform中,他们会写着“99.9% Availability”这样的承诺。那什么是“99.9% Availability”呢? 要理解这个承诺是…

七:Day07_redis进阶02

第一章 Redis 事务 1.1 节 数据库事务复习 在数据库层面,事务是指一组操作,这些操作要么全都被成功执行,要么全都不执行。 数据库事务的四大特性: A:Atomic, 原子性。要么全部执行,要么全部不…

复合机器人作为一种新型的智能制造装备高效、精准和灵活的生产方式

随着汽车制造业的快速发展,对于高效、精准和灵活的生产方式需求日益增强。复合机器人作为一种新型的智能制造装备,以其独特的优势在汽车制造中发挥着越来越重要的作用。因此,富唯智能顺应时代的发展趋势,研发出了ICR系列的复合机器…

03 Strategy策略

抽丝剥茧设计模式 之 Strategy策略 - 更多内容请见 目录 文章目录 一、Strategy策略二、Comparable和Comparator源码分析使用案例Arrays.sort源码Collections.sort源码Comparable源码Comparator源码 一、Strategy策略 策略模式是一种设计模式,它定义了一系列的算法…

Unity-生命周期函数

目录 生命周期函数是什么? 生命周期函数有哪些? Awake() OnEnable() Start() FixedUpdate() Update() Late Update() OnDisable() OnDestroy() Unity中生命周期函数支持继承多态吗? 生命周期函数是什么? 在Unity中&…

uniapp微信小程序投票系统实战 (SpringBoot2+vue3.2+element plus ) -关于我们页面实现

锋哥原创的uniapp微信小程序投票系统实战: uniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )_哔哩哔哩_bilibiliuniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )共计21条视频…

4_【Linux版】重装数据库问题处理记录

1、卸载已安装的oracle数据库。 2、知识点补充: 3、调整/dev/shm/的大小 【linux下修改/dev/shm tmpfs文件系统大小 - saratearing - 博客园 (cnblogs.com)】 mount -o remount,size100g /dev/shm 4、重装oracle后没有orainstRoot.sh 【重装oracle后没有orains…

隧道应用2-netsh端口转发监听Meterpreter

流程介绍: 跳板机 A 和目标靶机 B 是可以互相访问到的,在服务器 A 上可以通过配置 netsh 端口映射访问 B 服务器。如果要拿 B 服务器的权限通常是生成正向后门,使用 kali 的 msf 正向连接B服务器,进而得到 Meterpreter&#xff0c…