💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可
该专栏完整目录链接: 芒果YOLOv8深度改进教程
该篇博客为
免费阅读内容
,直接改进即可🚀🚀🚀
文章目录
- 1. GAM论文
- 2. YOLOv8 核心代码改进部分
- 2.1 核心新增代码
- 2.2 修改部分
- 2.3 YOLOv8-gam 网络配置文件
- 2.4 运行代码
- 改进说明
1. GAM论文
研究了多种注意力机制来提高各种计算机视觉任务的性能。然而,现有的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息缩减和放大全局交互式表示来提高深度神经网络的性能。我们引入了带有多层感知器的 3D 排列,用于通道注意力以及卷积空间注意力子模块。对CIFAR-100和ImageNet-1K上图像分类任务的所提机制的评估表明,我们的方法在ResNet和轻量级MobileNet上都稳定地优于最近的几种注意力机制。
具体细节可以去看原论文:https://arxiv.org/pdf/2112.05561v1.pdf
2. YOLOv8 核心代码改进部分
2.1 核心新增代码
首先在ultralytics/nn/modules文件夹下,创建一个 gam.py文件,新增以下代码
import numpy as np
import torch
from torch import nn
from torch.nn import initclass GAMAttention(nn.Module):#https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True,rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), nn.BatchNorm2d(int(c1 /rate)),nn.ReLU(inplace=True),nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle out = x * x_spatial_attreturn out def channel_shuffle(x, groups=2):B, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out=out.view(B, C, H, W) return out
2.2 修改部分
在ultralytics/nn/modules/init.py中导入 定义在 gam.py 里面的模块
from .gam import GAMAttention'GAMAttention' 加到 __all__ = [...] 里面
第一步:
在ultralytics/nn/tasks.py
文件中,新增
from ultralytics.nn.modules import GAMAttention
然后在 在tasks.py
中配置
找到
elif m is nn.BatchNorm2d:args = [ch[f]]
在这句上面加一个
elif m is GAMAttention:c1, c2 = ch[f], args[0]if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]
2.3 YOLOv8-gam 网络配置文件
新增YOLOv8-gam.yaml
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 3, GAMAttention, [1024]]- [-1, 1, SPPF, [1024, 5]] # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
2.4 运行代码
直接替换YOLOv8-gam.yaml 进行训练即可
到这里就完成了这篇的改进。
改进说明
这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下
🥇🥇🥇
添加博主联系方式:
友好的读者可以添加博主QQ: 2434798737
, 有空可以回答一些答疑和问题
🚀🚀🚀
参考
https://github.com/ultralytics/ultralytics