detectron2 DiffusionDet 训练自己的数据集

配环境
git clone https://github.com/ShoufaChen/DiffusionDet# 创建环境
conda create -n diffusion python=3.9
conda activate diffusion
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install opencv-python# 安装detectron2
cd /data2/zy/DiffusionDet/
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2pip install timm # 不装就会报错 No module named 'timm' (diffusion) 
prepare datasets
mkdir -p datasets/coco
mkdir -p datasets/lvisln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017
ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017
修改配置文件等

复制一份train_net.py,命名为train.py,在其中添加下列代码注册数据集

#引入以下注释
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import load_coco_json
import pycocotools
#声明类别,尽量保持
CLASS_NAMES =["__background__","Inlet","Slightshort","Generalshort","Severeshort","Outlet"]
# 数据集路径
DATASET_ROOT = '/data2/zy/DiffusionDet/datasets/coco/'
ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')TRAIN_PATH = os.path.join(DATASET_ROOT, 'train2017')
VAL_PATH = os.path.join(DATASET_ROOT, 'val2017')
TEST_PATH = os.path.join(DATASET_ROOT, 'test2017')TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train2017.json')
VAL_JSON = os.path.join(ANN_ROOT, 'instances_val2017.json')
TEST_JSON = os.path.join(ANN_ROOT, 'instances_test2017.json')# 声明数据集的子集
PREDEFINED_SPLITS_DATASET = {"coco_my_train": (TRAIN_PATH, TRAIN_JSON),"coco_my_val": (VAL_PATH, VAL_JSON),
}
#===========以下有两种注册数据集的方法,本人直接用的第二个plain_register_dataset的方式 也可以用register_dataset的形式==================
#注册数据集(这一步就是将自定义数据集注册进Detectron2)
def register_dataset():"""purpose: register all splits of dataset with PREDEFINED_SPLITS_DATASET"""for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items():register_dataset_instances(name=key,json_file=json_file,image_root=image_root)#注册数据集实例,加载数据集中的对象实例
def register_dataset_instances(name, json_file, image_root):"""purpose: register dataset to DatasetCatalog,register metadata to MetadataCatalog and set attribute"""DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))MetadataCatalog.get(name).set(json_file=json_file,image_root=image_root,evaluator_type="coco")#=============================
# 注册数据集和元数据
def plain_register_dataset():#训练集DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH))MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES,  # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭evaluator_type='coco', # 指定评估方式json_file=TRAIN_JSON,image_root=TRAIN_PATH)#DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH, "coco_2017_val"))#验证/测试集DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH))MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭evaluator_type='coco', # 指定评估方式json_file=VAL_JSON,image_root=VAL_PATH)
# 查看数据集标注,可视化检查数据集标注是否正确,
#这个也可以自己写脚本判断,其实就是判断标注框是否超越图像边界
#可选择使用此方法
def checkout_dataset_annotation(name="coco_my_val"):#dataset_dicts = load_coco_json(TRAIN_JSON, TRAIN_PATH, name)dataset_dicts = load_coco_json(TRAIN_JSON, TRAIN_PATH)print(len(dataset_dicts))for i, d in enumerate(dataset_dicts,0):#print(d)img = cv2.imread(d["file_name"])visualizer = Visualizer(img[:, :, ::-1], metadata=MetadataCatalog.get(name), scale=1.5)vis = visualizer.draw_dataset_dict(d)#cv2.imshow('show', vis.get_image()[:, :, ::-1])cv2.imwrite('out/'+str(i) + '.jpg',vis.get_image()[:, :, ::-1])#cv2.waitKey(0)if i == 200:break

main中调用注册函数

def main(args):cfg = setup(args)register_dataset() # here to registerif args.eval_only:model = Trainer.build_model(cfg)kwargs = may_get_ema_checkpointer(cfg, model)if cfg.MODEL_EMA.ENABLED:EMADetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS,resume=args.resume)else:DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS,resume=args.resume)res = Trainer.ema_test(cfg, model)if cfg.TEST.AUG.ENABLED:res.update(Trainer.test_with_TTA(cfg, model))if comm.is_main_process():verify_results(cfg, res)return restrainer = Trainer(cfg)trainer.resume_or_load(resume=args.resume)return trainer.train()

在 DiffisionDet/configs 下新建demo.yaml,主要是修改batchsize和max_iter

_BASE_: "Base-DiffusionDet.yaml"
MODEL:WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"RESNETS:DEPTH: 50STRIDE_IN_1X1: FalseDiffusionDet:NUM_PROPOSALS: 100NUM_CLASSES: 5
DATASETS:TRAIN: ("coco_my_train",)TEST:  ("coco_my_val",)
SOLVER:IMS_PER_BATCH: 16BASE_LR: 0.000025STEPS: (5850, 7000)MAX_ITER: 7500# TOTAL_NUM_IMAGES / (IMS_PER_BATCH * NUM_GPUS) * num_epochs = MAX_ITER# 2000/(16*1)*60=7500 
INPUT:MIN_SIZE_TRAIN: (800,)CROP:ENABLED: FalseFORMAT: "RGB"
OUTPUT_DIR: ./OUTPUT/bs16
训练
 python train.py --num-gpus 1     --config-file configs/diffdet.coco.res50.yaml

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/725128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32CubeMX学习笔记12 ---低功耗模式

在实际使用中很多产品都需要考虑低功耗的问题,STM32F10X提供了三种低功耗模式:睡眠模式(Sleep mode)、停机模式(Stop mode)和待机模式(Standby mode)。这些低功耗模式可以有效减少系…

jnitrace的用法(查看jni的执行流程,方便unidbg补环境)

一、简单执行 jnitrace -l <要trace的so库> <包名> jnitrace -l libxxx.so com.xxx.app二、插入js脚本执行 jnitrace -p E:\kill_sll.js -l libxxx.so com.xxx.app三、attach模式执行 默认使用spawn执行&#xff0c;attach模式可能有有时bug没反应 jnitrace -l li…

Casper Network(CSPR)即将迎来两项重大升级,以实现功能上的进一步完善

Casper Network&#xff08;CSPR&#xff09;即将实现更加完备的功能升级&#xff0c;现已进入倒计时阶段。 Casper Network&#xff08;CSPR&#xff09;将升级到其最先进以及更全的版本&#xff0c;即“功能完备”的版本&#xff0c;让Casper Network&#xff08;CSPR&#…

腾讯云十大优惠活动曝光,TOP10值得买云服务器配置报价

腾讯云服务器多少钱一年&#xff1f;61元一年起&#xff0c;2核2G3M配置&#xff0c;腾讯云2核4G5M轻量应用服务器165元一年、756元3年&#xff0c;4核16G12M服务器32元1个月、312元一年&#xff0c;8核32G22M服务器115元1个月、345元3个月&#xff0c;腾讯云服务器网txyfwq.co…

Java实现读取转码写入ES构建检索PDF等文档全栈流程

背景 之前已简单使用ES及Kibana和在线转Base64工具实现了检索文档的demo&#xff0c;并已实现WebHook的搭建和触发流程接口。 传送门&#xff1a; 基于GitBucket的Hook构建ES检索PDF等文档全栈方案 使用ES检索PDF、word等文档快速开始 实现读取本地文件入库ES 总体思路&…

索引类型介绍

4、说说你知道的MySQL的索引类型&#xff0c;并分别简述一下各自的场景。 普通索引&#xff1a;没有任何限制条件的索引&#xff0c;该索引可以在任何数据类型中创建。 唯一索引&#xff1a;使用UNIQUE参数可以设置唯一索引。创建该索引时&#xff0c;索引列的值必须唯一&…

44、网络编程/数据库相关操作练习20240306

一、代码实现数据库的创建&#xff08;员工信息表&#xff09;&#xff0c;并存储员工信息&#xff08;工号、姓名、薪资&#xff09;&#xff0c;能实现增加人员信息、删除人员信息、修改人员薪资操作。 代码&#xff1a; #include<myhead.h>int do_update(sqlite3 *p…

基于canvas纯前端实现验证码的绘制

验证码功能是实现登录功能中比较常见的一个问题 验证码的整体思路是&#xff1a; 1.前端登录页面发起获取验证码图片请求. 2.服务端收到请求后,生成一个唯一id,对应的验证码图片 以及验证码图片对应的值(这个值使用缓存保存,id-值一一对应,缓存可使用redis或本地缓存,本地缓存…

Python中的模块包第三方库详解

模块&包 模块 一个.py文件就是一个模块&#xff0c;里面是一些函数和变量&#xff0c;需要的时候可以导入。 模块命名规范: 1.以英文开头&#xff0c;不出现中文 2.模块名不应与系统内置函数重名 包 包本身就是一个文件夹&#xff0c;如果文件夹内有__init__.py文件&…

Java电梯模拟升级版

Java电梯模拟升级版 文章目录 Java电梯模拟升级版前言一、UML类图二、代码三、测试 前言 在上一版的基础上进行升级&#xff0c;楼层采用享元模式进行升级&#xff0c;并对楼层对象进一步抽象 一、UML类图 二、代码 电梯调度器抽象类 package cn.xx.evevator;import java.ut…

K倍区间 刷题笔记

法一 前缀和暴力搜索 &#xff08;数据大会超时&#xff09; #include<iostream> #include<cstring> #include<algorithm> #include<cstdio> using namespace std; const int N100010; int a[N],s[N]; int n,k; int main(){ cin>>n>>…

实现一个作用域插槽的场景

vue项目中&#xff0c;插槽slot有三种分别是&#xff1a;默认插槽、具名插槽、作用域插槽。默认插槽和具名插槽在平时的开发中用的比较多&#xff0c;作用域插槽用的相对较少&#xff0c;以前我对作用域插槽不是很理解&#xff0c;现在理解了一下。下面通过代码来实现一个作用域…

QLabel的setPixmap和setPicture有什么不同,请详细讲解

QLabel类提供了一个方便的方式来显示文本和图像。在Qt中&#xff0c;QLabel的setPixmap()和setPicture()方法都可以用来在标签中显示图像&#xff0c;但它们之间存在一些关键的区别&#xff0c;主要体现在它们接受的参数类型和用途上。 setPixmap() 参数类型&#xff1a;setP…

Linux系统之部署复古游戏平台

Linux系统之部署复古游戏平台 前言一、项目介绍1.1 项目简介1.2 项目特点1.3 游戏平台介绍二、本次实践介绍二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍三、本地环境检查3.1 安装Docker环境3.2 检查Docker服务状态3.3 检查Docker版本3.4 检查docker compose 版本四、构建…

RISC-V架构学习资料整理

1、韦东山——D1S哪吒开发板的裸机代码仓库 https://github.com/bigmagic123/d1-nezha-baremeta 2、melis系统移植到D1S https://blog.51cto.com/u_13800193/6268813 3、韦东山的gitee仓库 https://gitee.com/weidongshan 4、D1S编译工具链下载 https://github.com/Tina-Linux/…

ModbusTcp协议

Modbus TCP是一种通信协议,用于工业设备之间的通信。它是Modbus协议家族中的一个成员,最初是为串行通信设计的,但后来扩展到了TCP/IP网络。Modbus TCP/IP是一种公开的标准,由Modbus组织制定,并且被广泛应用于工业自动化和楼宇自动化领域。 Modbus TCP的主要特点: 基于TC…

LabVIEW管道缺陷智能检测系统

LabVIEW管道缺陷智能检测系统 管道作为一种重要的输送手段&#xff0c;其安全运行状态对生产生活至关重要。然而&#xff0c;随着时间的推移和环境的影响&#xff0c;管道可能会出现老化、锈蚀、裂缝等多种缺陷&#xff0c;这些缺陷若不及时发现和处理&#xff0c;将严重威胁到…

Java将File转换为MultipartFile

MultipartFile 是 Java 中用于处理 HTTP 文件上传的一个接口。它通常与 Spring 框架一起使用&#xff0c;特别是在 Spring MVC 中&#xff0c;用于处理 multipart/form-data 类型的 HTTP 请求。当用户在网页表单中选择并上传文件时&#xff0c;服务器端的控制器方法可能会接收一…

ProxySQL实现mysql8主从同步读写分离

ProxySQL基本介绍 ProxySQL是 MySQL 的高性能、高可用性、协议感知代理。以下为结合主从复制对ProxySQL读写分离、黑白名单、路由规则等做些基本测试。 先简单介绍下ProxySQL及其功能和配置&#xff0c;主要包括&#xff1a; 最基本的读/写分离&#xff0c;且方式有多种&…

python基础训练-for循环

适应人群&#xff1a;学习python大概在10-20天&#xff0c;比较勤于动手的同学&#xff0c;比较混沌的新手可以手敲一遍下面这些for循环&#xff0c;在AI时代只有脑子智能&#xff0c;手不生疏&#xff0c;多多运用AI进行语义编程&#xff0c;看懂代码&#xff0c;通过openai教…