前面研究了一下YOLOX的网络结构,在YOLOv5(tag7.0)集成了yolox的骨干网络,现在继续下一步集成YOLOX的head模块。YOLOX的head模块是双分支解耦合网络,把目标置信度的预测和目标的位置预测分成两条支路,并验证双分支解耦合头性能要优于单分支耦合头。
1、关于双分支解耦合头decouplehead
主要是个人理解。这里的双分支解耦合和单分支耦合是根据网络的结构和承担的任务来分配的,具体地讲,就是head模块是几条之路,如果是一条支路,那目标的类别和位置预测任务就都在这一条支路上进行,类别分类和位置预测使用同一套权重进行预测,对于这两个任务来说,这种单分支网络结构是耦合的,我理解的就是绑定在一块,经过同一个结构得到不同任务的结果。
基于任务特点,一个是关注类别信息,一个是关注位置信息,在不同的支路上完成这两个任务,这样的网络就可以针对不同任务独享各自的权重,但这两个任务都是为了表达同一个目标的信息,是什么,在哪里,所以这种双分支的结构对于同一个目标来说是解耦合的,decouple.
了解上述区别后,而只能使用单分支的yolov5已经性能显著了。那加上双分支是不是效果会更好呢,这非常让人期待,于是实践操作一下就可以进行验证亲自感受一下。而且yolov8使用的也是解耦合头和anchor free方式取得了更好的效果,这也让我们看到了做这件事的意义。
在实践操作的过程中,我发现直接在yolov5中复现yolox的head有点繁琐,因为需要直接把 单分支、anchor-based 改为 双分支、anchoe-free,所以我先把任务拆分,先实现 双分支、anchor-based.
2、YOLOX的head结构
这个可以直接看yolox的官方代码,找到YOLOXHead的定义代码。 其中__init__ 定义了head模块的基础结构,forward则是定义网络结构连接的方式,用以约束数据运算。
class YOLOXHead(nn.Module):def __init__(self,num_classes,width=1.0,strides=[8, 16, 32],in_channels=[256, 512, 1024],act="silu",depthwise=False,):"""Args:act (str): activation type of conv. Defalut value: "silu".depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False."""super().__init__()self.num_classes = num_classesself.decode_in_inference = True # for deploy, set to Falseself.cls_convs = nn.ModuleList()self.reg_convs = nn.ModuleList()self.cls_preds = nn.ModuleList()self.reg_preds = nn.ModuleList()self.obj_preds = nn.ModuleList()self.stems = nn.ModuleList()Conv = DWConv if depthwise else BaseConvfor i in range(len(in_channels)):self.stems.append(BaseConv(in_channels=int(in_channels[i] * width),out_channels=int(256 * width),ksize=1,stride=1,act=act,))self.cls_convs.append(nn.Sequential(*[Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),]))self.reg_convs.append(nn.Sequential(*[Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),]))self.cls_preds.append(nn.Conv2d(in_channels=int(256 * width),out_channels=self.num_classes,kernel_size=1,stride=1,padding=0,))self.reg_preds.append(nn.Conv2d(in_channels=int(256 * width),out_channels=4,kernel_size=1,stride=1,padding=0,))self.obj_preds.append(nn.Conv2d(in_channels=int(256 * width),out_channels=1,kernel_size=1,stride=1,padding=0,))self.use_l1 = Falseself.l1_loss = nn.L1Loss(reduction="none")self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")self.iou_loss = IOUloss(reduction="none")self.strides = stridesself.grids = [torch.zeros(1)] * len(in_channels)def initialize_biases(self, prior_prob):for conv in self.cls_preds:b = conv.bias.view(1, -1)b.data.fill_(-math.log((1 - prior_prob) / prior_prob))conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)for conv in self.obj_preds:b = conv.bias.view(1, -1)b.data.fill_(-math.log((1 - prior_prob) / prior_prob))conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)def forward(self, xin, labels=None, imgs=None):outputs = []origin_preds = []x_shifts = []y_shifts = []expanded_strides = []for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(zip(self.cls_convs, self.reg_convs, self.strides, xin)):x = self.stems[k](x)cls_x = xreg_x = xcls_feat = cls_conv(cls_x)cls_output = self.cls_preds[k](cls_feat)reg_feat = reg_conv(reg_x)reg_output = self.reg_preds[k](reg_feat)obj_output = self.obj_preds[k](reg_feat)
从上述代码可知,head模块根据前面输出的三个维度的特征图数据进行网络构建,根据数据的维度具体定义网络的结构,具体方法是通过一个 for 循环来构建基础的结构:self.stem, self.cls_convs, self.reg_convs, self.cls_preds, self.reg_preds和self.obj_preds.通过看forward函数可以了解这些模块是怎么衔接的,进而可以按照自己的方式灵活的复现出一样的网络结构。
单独实例化YOLOXHead()得到的网络结构:
"""YOLOXHead((cls_convs): ModuleList((0): Sequential((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(1): Sequential((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(2): Sequential((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))))(reg_convs): ModuleList((0): Sequential((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(1): Sequential((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(2): Sequential((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))))(cls_preds): ModuleList((0): Conv2d(256, 80, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(256, 80, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(256, 80, kernel_size=(1, 1), stride=(1, 1)))(reg_preds): ModuleList((0): Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1)))(obj_preds): ModuleList((0): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))(1): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))(2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1)))(stems): ModuleList((0): BaseConv((conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): BaseConv((conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(l1_loss): L1Loss()(bcewithlog_loss): BCEWithLogitsLoss()(iou_loss): IOUloss()
)
"""
3、重构YOLOX的DecoupledHead
基于yolov5的网络构建方式和自己的理解,重新根据网络的结构写了一下。我写的结构里 由于涉及到锚框,所以这里还不是完全体的yolox的head结构,只写了一条线的数据处理结构,因为yolov5的框架head模块要加入锚框,还会再定义。
class DecoupledHead(nn.Module):def __init__(self,in_channels=[256, 512, 1024],num_classes=80,width=0.5,anchors=(),act="silu",depthwise=False,prior_prob=1e-2,):super().__init__()self.num_classes = num_classes# self.nl = len(anchors)self.na = len(anchors[0]) // 2# self.in_channels = in_channels# import ipdb;ipdb.set_trace()Conv = DWConv if depthwise else BaseConvself.stems=BaseConv(in_channels=int(in_channels * width),out_channels=int(256 * width),ksize=1,stride=1,act=act,)self.cls_convs=nn.Sequential(Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),)self.cls_preds=nn.Conv2d(in_channels=int(256 * width),out_channels=self.num_classes * self.na,kernel_size=1,stride=1,padding=0,)self.reg_convs=nn.Sequential(Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),Conv(in_channels=int(256 * width),out_channels=int(256 * width),ksize=3,stride=1,act=act,),)self.reg_preds=nn.Conv2d(in_channels=int(256 * width),out_channels=4 * self.na,kernel_size=1,stride=1,padding=0,)self.obj_preds=nn.Conv2d(in_channels=int(256 * width),out_channels=1 * self.na,kernel_size=1,stride=1,padding=0,)#没用上初始化函数def initialize_biases(self):prior_prob = self.prior_probfor conv in self.cls_preds:b = conv.bias.view(1, -1)b.data.fill_(-math.log((1 - prior_prob) / prior_prob))conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)for conv in self.obj_preds:b = conv.bias.view(1, -1)b.data.fill_(-math.log((1 - prior_prob) / prior_prob))conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)def forward(self,x):# import ipdb;ipdb.set_trace()x = self.stems(x)cls_x = xreg_x = xcls_feat = self.cls_convs(cls_x)cls_output = self.cls_preds(cls_feat)reg_feat = self.reg_convs(reg_x)reg_output = self.reg_preds(reg_feat)obj_output = self.obj_preds(reg_feat)# out = torch.cat([cls_output,reg_output,obj_output], 1)out = torch.cat([reg_output,obj_output,cls_output], 1)return out
可以看到YOLOXHead的 init 和 forward() 都包含两个for循环,分别用来构建head结构和处理数据,为了避免麻烦和匹配yolov5的框架,我写了单流程结构。通过后面再在yolo.py进一步实现yoloxhead的双分支结构。
yolo.py下再定义的头部结构代码:
class DetectDcoupleHead(nn.Module):stride = None # strides computed during builddynamic = False # force grid reconstructionexport = False # export modedef __init__(self, nc=80, anchors=(), width=1.0, ch=(), inplace=True):super().__init__()self.prior_prob = 1e-2self.in_ch = [256, 512, 1024]self.nc = nc # number of classesself.width = widthself.no = nc + 5 # number of outputs per anchorself.nl = len(anchors) # number of detection layersself.na = len(anchors[0]) // 2 # number of anchorsself.grid = [torch.empty(0) for _ in range(self.nl)] # init gridself.anchor_grid = [torch.empty(0) for _ in range(self.nl)] # init anchor gridself.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)# self.DecoupledHead = DecoupledHead()self.m = nn.ModuleList(DecoupledHead(x, self.nc, self.width, anchors) for x in self.in_ch) # output convself.inplace = inplace # use inplace ops (e.g. slice assignment)def forward(self, x):z = [] # inference output# import ipdb;ipdb.set_trace()for i in range(self.nl):x[i] = self.m[i](x[i]) # convbs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)# import ipdb;ipdb.set_trace()x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()# if not self.training: # inferenceif self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)if isinstance(self, Segment): # (boxes + masks)xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i] # xywh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i] # why = torch.cat((xy, wh, conf.sigmoid(), mask), 4)else: # Detect (boxes only)xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)xy = (xy * 2 + self.grid[i]) * self.stride[i] # xywh = (wh * 2) ** 2 * self.anchor_grid[i] # why = torch.cat((xy, wh, conf), 4)z.append(y.view(bs, self.na * nx * ny, self.no))return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):d = self.anchors[i].devicet = self.anchors[i].dtypeshape = 1, self.na, ny, nx, 2 # grid shapey, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x) # torch>=0.7 compatibilitygrid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)return grid, anchor_griddef initialize_biases(self):prior_prob = self.prior_probfor i in range(self.nl):conv_cls = self.m[i].cls_predsb = conv_cls.bias.view(1, -1)b.data.fill_(-math.log((1 - prior_prob) / prior_prob))conv_cls.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)conv_obj = self.m[i].obj_predsb_obj = conv_obj.bias.view(1, -1)b_obj.data.fill_(-math.log((1 - prior_prob) / prior_prob))conv_obj.bias = torch.nn.Parameter(b_obj.view(-1), requires_grad=True)
在DetectDcoupleHead的定义中可以发现,self.m是最终的yolox的双分支head结构,它通过输入数据的维度分别重新构建了三个decoupledhead的结构,效果完全与yoloxhead一致,即下面这行代码
self.m = nn.ModuleList(DecoupledHead(x, self.nc, self.width, anchors) for x in self.in_ch)
另外,定义DetectDcoupleHead可以更好的解决锚框铺设问题,并且使得该头部模块可以在yaml配置文件中进行参数配置,最终实现yoloxhead的双分支、anchor-based结构
想要可以训练,还要再完善一些细节,主要是yolo.py中的网络构建代码
"""parse_model中添加:"""
elif m in {DetectDcoupleHead}:"""args是yaml配置文件的字典中每行的列表里模块后的参数"""# import ipdb;ipdb.set_trace()args.append([ch[x] for x in f])#append导致输入维度参数在最后一个位置if isinstance(args[1], int): # 锚框 number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)"""DetectionModel中添加:"""
if isinstance(m, DetectDcoupleHead):s = 256 # 2x min stridem.inplace = self.inplaceforward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forwardcheck_anchor_order(m)m.anchors /= m.stride.view(-1, 1, 1)self.stride = m.stridem.initialize_biases()
做了上面的这些修改工作,基本上就是把yolox的head结构复现到yolov5上了。如下是复现构建后的head模块的结构,可以发现self.m与yoloxhead的结构是一致的。
"""
self.m
ModuleList((0): DecoupledHead((stems): BaseConv((conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(cls_convs): Sequential((0): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(cls_preds): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))(reg_convs): Sequential((0): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(reg_preds): Conv2d(128, 12, kernel_size=(1, 1), stride=(1, 1))(obj_preds): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1)))(1): DecoupledHead((stems): BaseConv((conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(cls_convs): Sequential((0): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(cls_preds): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))(reg_convs): Sequential((0): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(reg_preds): Conv2d(128, 12, kernel_size=(1, 1), stride=(1, 1))(obj_preds): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1)))(2): DecoupledHead((stems): BaseConv((conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(cls_convs): Sequential((0): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(cls_preds): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))(reg_convs): Sequential((0): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): BaseConv((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act): SiLU(inplace=True)))(reg_preds): Conv2d(128, 12, kernel_size=(1, 1), stride=(1, 1))(obj_preds): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1)))
)"""
4、总结
其实做完这些工作回过头来,发现事情其实很简单,我们如果自己想要弄一个双分支的head结构,直接自己写一个,或者在原来的基础直接改就可以了,把一些数据维度上,让网络正常跑起来就行。这样做的话主要是一个学习理解的过程。
如下是本文复现工作的训练截图:
训练结果还是有一定优势的,双分支比单分支效果更好一些,有高几个点。
后续我打算把双分支、anchor-free也实现一下。目前已经做了一点工作,已经基本上跑起来了,anchor-free其实目前的代码没什么特殊的,预测的数据都基本上一样,不过训练的是一种位置解码方式,用一种方式来表达位置信息,同时去掉生成锚框的操作,但是会基于每个特征图的特征点进行预测,本质和锚框也差不多。