DETR目标检测模型训练自己的数据集

前言

基础环境:ubuntu20.04、python=3.8、pytorch:1.10.0、CUDA:11.3
代码地址:https://github.com/facebookresearch/detr

目录

  • 一、训练准备
    • 1、预训练模型下载
    • 2、txt文件转为coco模式
  • 二、修改训练模型参数
  • 三、开始训练
  • 四、实现DETR的推理

一、训练准备

1、预训练模型下载

下载地址:https://github.com/facebookresearch/detr?tab=readme-ov-file

在这里插入图片描述

下载后放到detr目录下

在这里插入图片描述

2、txt文件转为coco模式

根据图片和txt标签文件生成json数据,命名分别为instances_train2017.json和instances_val2017.json, 保存annotations文件夹下,train2017和val2107中存放训练集图片和验证集图片,文件夹结构如下:
在这里插入图片描述

第一步:创建people.names类别文件

在这里插入图片描述

第二步:格式转换,脚本如下:

import os
import json
import cv2
import random
import time
from PIL import Imagecoco_format_save_path='/root/detr/data/Crowdhuman/images/annotations'    #要生成的标准coco格式标签所在文件夹
yolo_format_classes_path='/root/detr/data/Crowdhuman/images/people.names'     #类别文件,一行一个类
yolo_format_annotation_path='/root/detr/data/Crowdhuman/labels/train'        #yolo格式标签所在文件夹
img_pathDir='/root/detr/data/Crowdhuman/images/train2017'                        #图片所在文件夹with open(yolo_format_classes_path,'r') as fr:                               #打开并读取类别文件lines1=fr.readlines()
# print(lines1)
categories=[]                                                                 #存储类别的列表
for j,label in enumerate(lines1):label=label.strip()categories.append({'id':j+1,'name':label,'supercategory':'None'})         #将类别信息添加到categories中
# print(categories)write_json_context=dict()                                                      #写入.json文件的大字典
write_json_context['info']= {'description': '', 'url': '', 'version': '', 'year': 2021, 'contributor': '', 'date_created': '2021-07-25'}
write_json_context['licenses']=[{'id':1,'name':None,'url':None}]
write_json_context['categories']=categories
write_json_context['images']=[]
write_json_context['annotations']=[]#接下来的代码主要添加'images'和'annotations'的key值
imageFileList=os.listdir(img_pathDir)                                           #遍历该文件夹下的所有文件,并将所有文件名添加到列表中
for i,imageFile in enumerate(imageFileList):imagePath = os.path.join(img_pathDir,imageFile)                             #获取图片的绝对路径image = Image.open(imagePath)                                               #读取图片,然后获取图片的宽和高W, H = image.sizeimg_context={}                                                              #使用一个字典存储该图片信息#img_name=os.path.basename(imagePath)                                       #返回path最后的文件名。如果path以/或\结尾,那么就会返回空值img_context['file_name']=imageFileimg_context['height']=Himg_context['width']=Wimg_context['date_captured']='2021-07-25'img_context['id']=i                                                         #该图片的idimg_context['license']=1img_context['color_url']=''img_context['flickr_url']=''write_json_context['images'].append(img_context)                            #将该图片信息添加到'image'列表中txtFile = imageFile.rsplit('.', 1)[0] + '.txt'                                              #获取该图片获取的txt文件,这个数字"6"要根据自己图片名修改with open(os.path.join(yolo_format_annotation_path,txtFile),'r') as fr:lines=fr.readlines()                                                   #读取txt文件的每一行数据,lines2是一个列表,包含了一个图片的所有标注信息for j,line in enumerate(lines):bbox_dict = {}                                                          #将每一个bounding box信息存储在该字典中# line = line.strip().split()# print(line.strip().split(' '))class_id,x,y,w,h=line.strip().split(' ')                                          #获取每一个标注框的详细信息class_id,x, y, w, h = int(class_id), float(x), float(y), float(w), float(h)       #将字符串类型转为可计算的int和float类型xmin=(x-w/2)*W                                                                    #坐标转换ymin=(y-h/2)*Hxmax=(x+w/2)*Wymax=(y+h/2)*Hw=w*Wh=h*Hbbox_dict['id']=i*10000+j                                                         #bounding box的坐标信息bbox_dict['image_id']=ibbox_dict['category_id']=class_id+1                                               #注意目标类别要加一bbox_dict['iscrowd']=0height,width=abs(ymax-ymin),abs(xmax-xmin)bbox_dict['area']=height*widthbbox_dict['bbox']=[xmin,ymin,w,h]bbox_dict['segmentation']=[[xmin,ymin,xmax,ymin,xmax,ymax,xmin,ymax]]write_json_context['annotations'].append(bbox_dict)                               #将每一个由字典存储的bounding box信息添加到'annotations'列表中name = os.path.join(coco_format_save_path,"instances_train2017"+ '.json')
with open(name,'w') as fw:                                                                #将字典信息写入.json文件中json.dump(write_json_context,fw,indent=2)

备注:必须严格按照笔者图中的文件命名方式进行命名,训练集清一色命名为instances_train2017.json,验证集清一色命名为instances_val2017.json,这是模型本身的命名要求,用户需要严格遵守。

二、修改训练模型参数

第一步:先在目录中新建python脚本文件detr_r50_tf.py,代码如下:

import torchpretrained_weights = torch.load('detr-r50-e632da11.pth')num_class = 2  # 类别数+1, 因为背景也算一个
pretrained_weights["model"]["class_embed.weight"].resize_(num_class + 1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class + 1)
torch.save(pretrained_weights, "detr-r50_%d.pth" % num_class)

第二步:将其中类别数改成自己数据集的类别数即可,执行完成后会在目录下生成适合自己数据集类别的预训练模型:

在这里插入图片描述

第三步:然后在models文件夹下打开detr.py,修改其中的类别数:

在这里插入图片描述

第四步:打开main.py,修改其中的coco_path(数据存放路径)、output_dir(结果输出路径)、device(没有cuda就改为cpu)、resume(自己生成的预训练模型)

在这里插入图片描述

第五步:修改epochs数

在这里插入图片描述

三、开始训练

运行python main.py

跑起来的效果是这样的:

在这里插入图片描述

四、实现DETR的推理

将要预测的图片保存在一个文件夹下,预测时一次输出所有图片的预测结果,代码如下:

import argparse
import random
import time
from pathlib import Path
import numpy as np
import torch
from models import build_model
from PIL import Image
import os
import torchvision
from torchvision.ops.boxes import batched_nms
import cv2# 设置参数
def get_args_parser():parser = argparse.ArgumentParser('Set transformer detector', add_help=False)parser.add_argument('--lr', default=1e-4, type=float)parser.add_argument('--lr_backbone', default=1e-5, type=float)parser.add_argument('--batch_size', default=2, type=int)parser.add_argument('--weight_decay', default=1e-4, type=float)parser.add_argument('--epochs', default=300, type=int)parser.add_argument('--lr_drop', default=200, type=int)parser.add_argument('--clip_max_norm', default=0.1, type=float, help='gradient clipping max norm')# Model parametersparser.add_argument('--frozen_weights', type=str, default=None, help="Path to the pretrained model. If set, only the mask head will be trained")parser.add_argument('--backbone', default='resnet50', type=str, help="Name of the convolutional backbone to use")parser.add_argument('--dilation', action='store_true', help="If true, we replace stride with dilation in the last convolutional block (DC5)")parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")# Transformerparser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")parser.add_argument('--dim_feedforward', default=2048, type=int, help="Intermediate size of the feedforward layers in the transformer blocks")parser.add_argument('--hidden_dim', default=256, type=int, help="Size of the embeddings (dimension of the transformer)")parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")parser.add_argument('--nheads', default=8, type=int, help="Number of attention heads inside the transformer's attentions")parser.add_argument('--num_queries', default=100, type=int, help="Number of query slots")parser.add_argument('--pre_norm', action='store_true')# Segmentationparser.add_argument('--masks', action='store_true', help="Train segmentation head if the flag is provided")# Lossparser.add_argument('--no_aux_loss', dest='aux_loss', default='False', help="Disables auxiliary decoding losses (loss at each layer)")# Matcherparser.add_argument('--set_cost_class', default=1, type=float, help="Class coefficient in the matching cost")parser.add_argument('--set_cost_bbox', default=5, type=float, help="L1 box coefficient in the matching cost")parser.add_argument('--set_cost_giou', default=2, type=float, help="giou box coefficient in the matching cost")# Loss coefficientsparser.add_argument('--mask_loss_coef', default=1, type=float)parser.add_argument('--dice_loss_coef', default=1, type=float)parser.add_argument('--bbox_loss_coef', default=5, type=float)parser.add_argument('--giou_loss_coef', default=2, type=float)parser.add_argument('--eos_coef', default=0.1, type=float, help="Relative classification weight of the no-object class")# dataset parametersparser.add_argument('--dataset_file', default='coco')parser.add_argument('--coco_path', type=str, default="/root/detr/data/Crowdhuman/coco")parser.add_argument('--coco_panoptic_path', type=str)parser.add_argument('--remove_difficult', action='store_true')parser.add_argument('--output_dir', default='/root/detr/inference_demo/inference_output', help='path where to save, empty for no saving')parser.add_argument('--device', default='cuda', help='device to use for training / testing')parser.add_argument('--seed', default=42, type=int)parser.add_argument('--resume', default='/root/detr/data/output/checkpoint.pth', help='resume from checkpoint')parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')parser.add_argument('--eval', default="True")parser.add_argument('--num_workers', default=2, type=int)# distributed training parametersparser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')return parserdef box_cxcywh_to_xyxy(x):# 将DETR的检测框坐标(x_center,y_center,w,h)转化成coco数据集的检测框坐标(x0,y0,x1,y1)x_c, y_c, w, h = x.unbind(1)b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]return torch.stack(b, dim=1)def rescale_bboxes(out_bbox, size):# 把比例坐标乘以图像的宽和高,变成真实坐标img_w, img_h = sizeb = box_cxcywh_to_xyxy(out_bbox)b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)return bdef filter_boxes(scores, boxes, confidence=0.7, apply_nms=True, iou=0.5):# 筛选出真正的置信度高的框keep = scores.max(-1).values > confidencescores, boxes = scores[keep], boxes[keep]if apply_nms:top_scores, labels = scores.max(-1)keep = batched_nms(boxes, top_scores, labels, iou)scores, boxes = scores[keep], boxes[keep]return scores, boxes# COCO classes
CLASSES = ['N/A', 'pedestrian']# 生成随机颜色的函数
def random_color():return [random.randint(0, 255) for _ in range(3)]# 创建类别颜色字典
COLORS = {cls: random_color() for cls in CLASSES}def plot_one_box(x, img, color=None, label=None, line_thickness=2):# 把检测框画到图片上tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thicknesscolor = [255, 0, 0]  # 固定为红色c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)if label:tf = max(tl - 1, 1)  # font thicknesst_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filledcv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)def main(args):print(args)device = torch.device(args.device)# 导入网络# 下面的criterion是算损失函数要用的,推理用不到,postprocessors是解码用的,这里也没有用,用的是自己的。model, criterion, postprocessors = build_model(args)# 加载权重checkpoint = torch.load(args.resume, map_location='cuda')model.load_state_dict(checkpoint['model'])# 把权重加载到gpu或cpu上model.to(device)# 打印出网络的参数大小n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)print("parameters:", n_parameters)# 设置好存储输出结果的文件夹output_dir = Path(args.output_dir)# 读取数据集,进行推理image_Totensor = torchvision.transforms.ToTensor()image_file_path = os.listdir("inference_demo/detect_demo")image_set = []for image_item in image_file_path:print("inference_image:", image_item)image_path = os.path.join("inference_demo/detect_demo", image_item)image = Image.open(image_path)image_tensor = image_Totensor(image)image_tensor = torch.reshape(image_tensor, [-1, image_tensor.shape[0], image_tensor.shape[1], image_tensor.shape[2]])image_tensor = image_tensor.to(device)time1 = time.time()inference_result = model(image_tensor)time2 = time.time()print("inference_time:", time2 - time1)probas = inference_result['pred_logits'].softmax(-1)[0, :, :-1].cpu()bboxes_scaled = rescale_bboxes(inference_result['pred_boxes'][0, ].cpu(), (image_tensor.shape[3], image_tensor.shape[2]))scores, boxes = filter_boxes(probas, bboxes_scaled)scores = scores.data.numpy()boxes = boxes.data.numpy()for i in range(boxes.shape[0]):class_id = scores[i].argmax()label = CLASSES[class_id]confidence = scores[i].max()text = f"{label} {confidence:.3f}"image = np.array(image)plot_one_box(boxes[i], image, color=COLORS[label], label=text)# cv2.imshow("images", image)# cv2.waitKey(1)image = Image.fromarray(image)image.save(os.path.join(args.output_dir, image_item))if __name__ == '__main__':parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])args = parser.parse_args()if args.output_dir:Path(args.output_dir).mkdir(parents=True, exist_ok=True)main(args)

需要修改的参数有:

1、使用训练时已经下载好了主干特征网络是Resnet50的DETR权重文件,放在主文件夹下

在这里插入图片描述

2、数据集有关参数
–coco_path 修改为自己的数据集路径
–outputdir 修改为建立的预测图片的保存文件夹
–resume 修改为训练好的模型文件路径

在这里插入图片描述

3、修改待预测的图片文件夹路径image_file_path和image_path

在这里插入图片描述
4、修改类别,根据自己实际情况定义

在这里插入图片描述

备注:由于我用服务器跑,无法传回图片而出现一个报错,于是把这两句注释掉了:
在这里插入图片描述

预测结果:

在这里插入图片描述

参考:
1、【DETR】训练自己的数据集-实践笔记
2、 pytorch实现DETR的推理程序
3、 DETR实现目标检测(一)-训练自己的数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48937.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【RT摩拳擦掌】RT600 4路音频同步输入1路TDM输出方案

【RT摩拳擦掌】RT600 4路音频同步输入1路TDM输出方案 一, 文章简介二,硬件平台构建2.1 音频源板2.2 音频收发板2.3 双板硬件连接 三,软件方案与软件实现3.1 方案实现3.2 软件代码实现3.2.1 4路I2S接收3.2.2 I2S DMA pingpong配置3.2.3 音频数…

Python自动化批量下载ECWMF和GFS最新预报数据脚本

一、白嫖EC和GFS预报数据 EC的openData部分公开了一部分预报数据,作为普通用户只能访问这些免费预报数据,具体位置在这 可以发现,由于是Open Data,我们只能获得临近四天的预报结果,虽然时间较短,但是我们…

vue3前端开发-小兔鲜项目-二级页面面包屑导航和跳转

vue3前端开发-小兔鲜项目-二级页面面包屑导航和跳转!这一次,做两件事。第一件事是把二级分类页面的跳转(也就是路由)设计一下。第二件事是把二级页面的面包屑导航设计一下。 第一件事,二级页面的跳转路由设计一下。 如…

Python爬虫(4) --爬取网页图片

文章目录 爬虫爬取图片指定url发送请求获取想要的数据数据解析定位想要内容的位置存放图片 完整代码实现总结 爬虫 Python 爬虫是一种自动化工具,用于从互联网上抓取网页数据并提取有用的信息。Python 因其简洁的语法和丰富的库支持(如 requests、Beaut…

科普文:后端性能优化的实战小结

一、背景与效果 ICBU的核心沟通场景有了10年的“积累”,核心场景的界面响应耗时被拉的越来越长,也让性能优化工作提上了日程,先说结论,经过这一波前后端齐心协力的优化努力,两个核心界面90分位的数据,FCP平…

Day05-readinessProbe探针,startupProbe探针,Pod生命周期,静态Pod,初始化容器,rc控制器的升级和回滚,rs控制器精讲

Day05-readinessProbe探针,startupProbe探针,Pod生命周期,静态Pod,初始化容器,rc控制器的升级和回滚,rs控制器精讲 0、昨日内容回顾1、readinessProbe可用性检查探针之exec案例2、可用性检查之httpGet案例3…

[数据集][目标检测]躺坐站识别检测数据集VOC+YOLO格式9488张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):9488 标注数量(xml文件个数):9488 标注数量(txt文件个数):9488 标注…

C语言 | Leetcode C语言题解之第242题有效的字母异位词

题目&#xff1a; 题解&#xff1a; bool isAnagram(char* s, char* t) {int len_s strlen(s), len_t strlen(t);if (len_s ! len_t) {return false;}int table[26];memset(table, 0, sizeof(table));for (int i 0; i < len_s; i) {table[s[i] - a];}for (int i 0; i &…

EMQX 跨域集群:增强可扩展性,打破地域限制

跨域集群的概念 提到 EMQX&#xff0c;人们通常首先会想到它的可扩展性。尽管 EMQX 能随着硬件数量的增加几乎实现线性扩展&#xff0c;但在单个计算实例上的扩展能力终究有限&#xff1a;资源总会耗尽&#xff0c;升级成本也会急剧上升。这时&#xff0c;分布式部署就显得尤为…

JavaScript(11)——对象

对象 声明&#xff1a; let 对象名 { 属性名&#xff1a;属性值, 方法名&#xff1a;函数 } let 对象名 new Object() 对象的操作 先创建一个对象 let op {name:jvav,id:4,num:1001} 查 对象名.属性 console.log(op.name) 对象名[属性名] 改 对象名.属性 新值 op.name …

Pytorch学习笔记day4——训练mnist数据集和初步研读

该来的还是来了hhhhhhhhhh&#xff0c;基本上机器学习的初学者都躲不开这个例子。开源&#xff0c;数据质量高&#xff0c;数据尺寸整齐&#xff0c;问题简单&#xff0c;实在太适合初学者食用了。 今天把代码跑通&#xff0c;趁着周末好好的琢磨一下里面的各种细节。 代码实…

Spring MVC的高级功能——拦截器(三)拦截器的执行流程

一、单个拦截器的执行流程 如果在项目中只定义了一个拦截器&#xff0c;单个拦截器的执行流程如图所示。 二、单个拦截器的执行流程分析 从单个拦截器的执行流程图中可以看出&#xff0c;程序收到请求后&#xff0c;首先会执行拦截器中的preHandle()方法&#xff0c;如果preHa…

bug诞生记——动态库加载错乱导致程序执行异常

大纲 背景问题发生问题猜测和分析过程是不是编译了本工程中的其他代码是不是有缓存是不是编译了非本工程的文件是不是调用了其他可执行文件查看CMakefiles分析源码检查正在运行程序的动态库 解决方案 这个案例发生在我研究ROS 2的测试Demo时发生的。 整体现象是&#xff1a;修改…

聊一聊前端动画的种类,以及动画的触发方式有哪些?

引言 动画在前端开发中扮演着重要的角色。它不仅可以提升用户体验&#xff0c;还可以使界面更加生动和有趣。在这篇文章中&#xff0c;我们将深入探讨前端动画的各种实现方式&#xff0c;包括 CSS 动画、JavaScript 动画、SVG 动画等。我们还将讨论一些触发动画的方式和动画在…

【MQTT(2)】开发一个客户端,ubuntu版本

基本流程如下&#xff0c;先生成Mosquitto的库&#xff0c;然后qt调用库进行开发界面。 文章目录 0 生成库1 有界面的QT版本2 无界面版本 0 生成库 下载源码&#xff1a;https://github.com/eclipse/mosquitto.git 编译ubuntu 版本很简单&#xff0c;安装官方说明直接make&am…

rk3568 OpenHarmony4.1 Launcher定制开发—桌面壁纸替换

Launcher 作为系统人机交互的首要入口&#xff0c;提供应用图标的显示、点击启动、卸载应用&#xff0c;并提供桌面布局设置以及最近任务管理等功能。本文将介绍如何使用Deveco Studio进行单独launcher定制开发、然后编译并下载到开发板&#xff0c;以通过Launcher修改桌面背景…

记录|如何打包C#项目

参考文章&#xff1a; c#窗体应用程序怎么打包 经过检验确实有效 Step1. 生成发布文件 在Visual Studio的菜单中&#xff0c;找到“生成”->“发布” 第一次会有个向导&#xff0c;基本上一路next下来既可以 最后&#xff0c;点击完成即可以 Step2. 获得publish文件 自…

【JavaEE】AQS原理

本文将介绍AQS的简单原理。 首先有个整体认识&#xff0c;全称是 AbstractQueuedSynchronizer&#xff0c;是阻塞式锁和相关的同步器工具的框架。常用的ReentrantLock、Semaphore、CountDownLatch等都有实现它。 本文参考&#xff1a; 深入理解AbstractQueuedSynchronizer只需…

[C++]TinyWebServer

TinyWebServer 文章目录 TinyWebServer1 主体框架2 Buffer2.1 向Buffer写入数据2.2 从Buffer读取数据2.3 动态扩容2.4 从socket中读取数据2.5 具体实现 3 日志系统3.1 生产者-消费者模型3.2 数据一致3.3 代码 4 定时器4.1 调整堆中元素操作4.2 堆的操作4.2.1 增4.2.2 删4.2.3 改…

微信小程序-应用,页面和组件生命周期总结

情景1&#xff1a;小程序冷启动时候的顺序 情景2: 使用navigator&#xff08;保留并打开另一个页面&#xff09;和redirect&#xff08;关闭并打开另一个页面&#xff09;的执行顺序 情景3&#xff1a;切后台和切前台