MMdetection3.0 训练DETR问题分析
针对在MMdetection3.0框架下训练DETR模型,验证集AP值一直为0.000的原因作出如下分析并得出结论。
条件:
1、NWPU-VHR-10数据集:共650张,训练:验证=611:39;
2、MMdetection3.0框架实验分析;
3、DETR原论文提供源代码实验分析;
4、已在代码中完成了数据类别定义(num_classes)等相关配置的修改。
分析:
1、在MMdetection3.0框架下,只是加载backbone的预训练权重,val上AP始终为0.0000.如下图所示:=》loss收敛较慢,val始终为0.0000.
2、在MMdetection3.0框架下,直接加载detr的完整预训练权重。如下图所示:=》存在警告(size mismatch for bbox_head.fc_cls.weight: copying a param with shape torch.Size([81, 256]) from checkpoint, the shape in current model is torch.Size([11, 256]).
size mismatch for bbox_head.fc_cls.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([11]).
),但训练测试指标还算正常。
=》警告原因:自定义数据集的类别是10+1,而MMdetection3.0提供的是coco数据集与训练权重80+1.
=》因此,需要修改预训练模型的全连接层输出(见下述第4点)。
3、在MMdetection3.0框架下,直接加载修改后的detr的完整预训练权重训练测试结果见下图所示:=》警告消除,一切正常,并且修改证据权重类别后loss下降变快,val指标更好(不能说更好,只能说更正常)
4、修改模型权重参数脚本
=》代码中的METAINFO不想修改 不修改也行。
=》主要是pretrained_weights[‘state_dict’][‘bbox_head.fc_cls.weight’].resize_(11, 256)
pretrained_weights[‘state_dict’][‘bbox_head.fc_cls.bias’].resize_(11)
import torch
METAINFO = dict(CLASSES=('airplane','ship','storage tank','baseball diamond','tennis court','basketball court','ground track field','harbor','bridge','vehicle',),PALETTE=[(120,120,120,),(180,120,120,),(6,230,230,),(80,50,50,),(4,200,3,),(120,120,80,),(140,140,140,),(204,5,255,),(230,230,230,),(4,250,7,),])pretrained_weights = torch.load('/home/admin1/pywork/data/weigh/resnet50-0676ba61.pth')
# 11 是指 数据类别 + 1
pretrained_weights['state_dict']['bbox_head.fc_cls.weight'].resize_(11, 256)
pretrained_weights['state_dict']['bbox_head.fc_cls.bias'].resize_(11)
pretrained_weights['meta']['experiment_name'] = 'detr_r50_8xb2-150e_coco_11'
pretrained_weights['meta']['dataset_meta'] = METAINFO
torch.save(pretrained_weights, "detr_r50_8xb2-150e_coco_%d.pth" % num_classes)
5、DETR原论文提供的源代码训练情况跟MMdetection3.0框架下的情况类似,都必须加载预训练模型,否则就是一直0.000000000000000.
总结分析:
1、NWPU-VHR-10数据量太小导致的问题(90%),等待进一步测试。
2、Transformer模型提出来的时候就已经说明很吃数据,所以没有足够的数据直接使用transformer训练往往效果不好,所以数据量不足的情况下,还是加载预训练权重吧。
3、backbone的权重在模型的比例其实很小,主要还是后面的编码、解码器,所以只加载backbone的权重也没什么用。
总之,数据、数据、数据要足够哇