要将注意力机制模块添加到YoloV5工程项目中的yolo.py中,可参考以下四种情况。
以下4个elif代码来自https://yolov5.blog.csdn.net/article/details/129108082
elif m in [SimAM, ECA, SpatialGroupEnhance,TripletAttention]:args = [*args[:]]elif m in [CoordAtt, GAMAttention]:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:c1 = ch[f]args = [c1, *args[0:]]elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:c1 = ch[f]args = [c1]
根据这4种情况,我们在yaml文件中,填写args时(比如下图中RefConv的[1024,3,1]以及SE中的[1024]),需要填入的参数个数是不同的
# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2[-1, 1, Conv, [128, 3, 2]], # 1-P2/4[-1, 3, C3, [128]], #2[-1, 1, Conv, [256, 3, 2]], # 3-P3/8[-1, 6, C3, [256]], #4[-1, 1, Conv, [512, 3, 2]], # 5-P4/16[-1, 9, C3, [512]], #6[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32[-1, 3, C3, [1024]], # 8[-1,1,SE,[1024]], #9[[-1, 1, SimAM, [1e-4]], # 10[-1, 1, SPPF, [1024, 5]], # 11[-1, 1, RefConv, [1024, 3, 1]], # 12]
具体来说,是要将elif模块添加到yolo.py文件中的parse_model函数里。在编写elif模块代码时,我们需要关注的是,你的注意力机制模块代码(在common.py)中的__init__里面的参数,是否设置了“输入通道数”和“输出通道数”,这两个参数。
一、先上结论
情况1:_init_中不包含“输入通道数”和“输出通道数”,但含有其它参数
以下这些模块:SimAM,ECA,SpatialGroupEnhance,TripletAttention 全都满足情况1
elif m in [SimAM, ECA, SpatialGroupEnhance,TripletAttention]:args = [*args[:]]
以下是SimAM,ECA,SpatialGroupEnhance,TripletAttention 的__init__
class SimAM(torch.nn.Module):def __init__(self, e_lambda=1e-4):class ECA(nn.Module):def __init__(self, k_size=3):class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):class SpatialGroupEnhance(nn.Module):def __init__(self, groups=8):class TripletAttention(nn.Module):def __init__(self, no_spatial=False):
这里解释一下args = [*args[:]]是什么意思
在Python中,*args 是一个特殊的语法,用于在函数定义中处理不确定数量的位置参数。args 是一个元组,包含了所有传递给函数的位置参数。
args[:] 是一个切片操作,它会创建一个 args 的浅拷贝。这意味着如果你修改了 args[:] 的内容,原始的 args 不会被改变。
[*args[:]] 则是将 args[:] 中的元素解包(unpack)成一个列表。这样做的目的通常是为了创建一个新的列表,而不是修改原始的 args。
例如:
args = [1, 2, 3, 4]
new_args = [*args[:]]
print(new_args) #输出为[1,2,3,4]
情况2:_init_中同时包含“输入通道数”和“输出通道数”,且含有其它参数
以下这些模块:CoordAtt, GAMAttention 全都满足情况2
elif m in [CoordAtt, GAMAttention]:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]
以下是CoordAtt, GAMAttention 的__init__
class CoordAtt(nn.Module):def __init__(self, inp, oup, reduction=32):class GAMAttention(nn.Module):def __init__(self, c1, c2, group=True, rate=4):
情况3:_init_中只包含“输入通道数”,不包含“输出通道数”,且含有其它参数
以下这些模块:SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA全都满足情况3
elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:c1 = ch[f]args = [c1, *args[0:]]
以下是SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA的__init__
#SE机制它的输入通道和输出通道是一样的,所以在实现上可以只传入输入通道c1,但如果也给出输出通道c2的参数也是可以的。下面这两种都是在迪菲赫尔曼博客中实现过的SE模块
class SE(nn.Module):def __init__(self, c1, ratio=16):
class SE(nn.Module):def __init__(self, c1, c2, ratio=16):class ShuffleAttention(nn.Module):def __init__(self, channel=512, G=8):class CBAM(nn.Module):def __init__(self, c1, ratio=16, kernel_size=7):class SKAttention(nn.Module):def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):class DoubleAttention(nn.Module):def __init__(self, in_channels, reconstruct=True):class CoTAttention(nn.Module):def __init__(self, dim=512, kernel_size=3):class EffectiveSEModule(nn.Module):def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):class GlobalContext(nn.Module):def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,rd_ratio=1. / 8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
class GatherExcite(nn.Module):def __init__(self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,rd_ratio=1. / 16, rd_channels=None, rd_divisor=1, add_maxpool=False,act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):class MHSA(nn.Module):def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
情况4:_init_中只包含“输入通道数”,不包含“输出通道数”,且不含有其他参数(注意对比情况3)
以下这些模块:S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention全都满足情况4
elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:c1 = ch[f]args = [c1]
以下是S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention的__init__
class S2Attention(nn.Module):def __init__(self, channels=512):class NAMAttention(nn.Module):def __init__(self, channels):class CrissCrossAttention(nn.Module):def __init__(self, in_dim):class SequentialPolarizedSelfAttention(nn.Module):def __init__(self, channel=512):class ParallelPolarizedSelfAttention(nn.Module):def __init__(self, channel=512):class ParNetAttention(nn.Module):def __init__(self, channel=512):
二、解释代码
在理解代码前,我们需要知道,在parse_model函数中,args列表的前两个位置被设计为存放输入通道数(c1)和输出通道数(c2)。这是因为在创建这些模块时,我们通常会按照这个顺序传递参数。例如,对于nn.Conv2d,其构造函数的签名为Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True),其中in_channels和out_channels就对应于我们的c1和c2。因此,当我们在parse_model函数中创建这些模块时,我们需要先获取c1和c2,然后将它们放在args的前两个位置,以确保它们能被正确地传递给模块的构造函数。
情况1
'''
这段代码的目的是为了处理SimAM模块的参数。elif m in [SimAM]:
这行代码检查当前的模块m是否是SimAM模块。如果是,那么就执行下一行代码。args = [*args[:]]:
这行代码创建了args的一个浅拷贝。在Python中,args[:]会创建一个新的列表,这个新列表包含了args中的所有元素。*操作符会将这个新列表解包,然后我们再用[]将解包后的元素重新组装成一个新的列表。所以,[*args[:]]就等于args[:],它们都会创建args的一个浅拷贝。综上,这段代码使args列表保持不变,因为SimAM模块不需要修改输入和输出通道数。
'''elif m in [SimAM, ECA, SpatialGroupEnhance,TripletAttention]:args = [*args[:]]
情况2
'''
这段代码的目的是为了处理CoordAtt和GAMAttention模块的参数。elif m in [CoordAtt, GAMAttention]:
这行代码检查当前的模块m是否是CoordAtt模块或GAMAttention模块。如果是,那么就执行下面的代码。c1, c2 = ch[f], args[0]:
这行代码从ch和args两个列表中获取了两个值c1和c2。ch是一个列表,它保存了模型中每一层的输出通道数。f是一个索引,它指向ch中的某个元素,这个元素表示当前层的输入通道数。所以ch[f]就是当前层的输入通道数。args也是一个列表,它保存了当前层的参数。在模型配置文件中,每一层的参数都被保存在一个列表中,例如[64, 6, 2, 2]。这个列表的第一个元素通常是当前层的输出通道数。所以args[0]就是当前层的输出通道数。if c2 != no:
这行代码检查当前层的输出通道数c2是否不等于no。no是模型的总输出通道数。c2 = make_divisible(c2 * gw, 8):
如果c2不等于no,那么就重新计算c2的值。gw是模型的宽度倍数,make_divisible(c2 * gw, 8)会将c2 * gw调整为最接近的8的倍数。这是因为某些硬件(如GPU)在处理通道数为8的倍数的数据时,可以获得更好的性能。args = [c1, c2, *args[1:]]:
这行代码创建了一个新的参数列表args。新的args列表的第一个元素是c1,第二个元素是c2,剩下的元素是原始args列表的第二个元素及其后面的所有元素。*args[1:]是Python的解包(unpack)操作,它可以将列表args[1:]中的所有元素解包出来。所以,[c1, c2, *args[1:]]就等于[c1, c2]和args[1:]两个列表的连接。
'''elif m in [CoordAtt, GAMAttention]:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]
情况3
'''
这段代码的目的是为了处理SE, ShuffleAttention等模块的参数。elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:
这行代码检查当前的模块m是否是SE, ShuffleAttention等模块中的一个。如果是,那么就执行下面的代码。c1 = ch[f]:
这行代码从ch列表中获取了一个值c1。ch是一个列表,它保存了模型中每一层的输出通道数。f是一个索引,它指向ch中的某个元素,这个元素表示当前层的输入通道数。所以ch[f]就是当前层的输入通道数。args = [c1, *args[0:]]:
这行代码创建了一个新的参数列表args。新的args列表的第一个元素是c1,剩下的元素是原始args列表的第一个元素及其后面的所有元素。*args[0:]是Python的解包(unpack)操作,它可以将列表args[0:]中的所有元素解包出来。所以,[c1, *args[0:]]就等于[c1]和args[0:]两个列表的连接。综上,这段代码将当前层的输入通道数c1添加到参数列表args的开始位置,因为这些模块的初始化函数通常需要输入通道数作为第一个参数。'''elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:c1 = ch[f]args = [c1, *args[0:]]
情况4
'''
这段代码的目的是为了处理S2Attention, NAMAttention等模块的参数。elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:
这行代码检查当前的模块m是否是S2Attention, NAMAttention等模块中的一个。如果是,那么就执行下面的代码。c1 = ch[f]:
这行代码从ch列表中获取了一个值c1。ch是一个列表,它保存了模型中每一层的输出通道数。f是一个索引,它指向ch中的某个元素,这个元素表示当前层的输入通道数。所以ch[f]就是当前层的输入通道数。args = [c1]:这行代码创建了一个新的参数列表args。新的args列表只有一个元素,就是c1。综上,这段代码将当前层的输入通道数c1作为唯一的参数传递给这些模块,因为这些模块的初始化函数通常只需要输入通道数作为参数。
'''elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:c1 = ch[f]args = [c1]