深度学习系列53:mmdetection上手

1. 安装

使用openmim安装:

pip install -U openmim
mim install "mmengine>=0.7.0"
mim install "mmcv>=2.0.0rc4"

2. 测试案例

下载代码和模型:

git clone https://github.com/open-mmlab/mmdetection.git
mkdir ./checkpoints
mim download mmdet --config rtmdet_tiny_8xb32-300e_coco --dest ./checkpoints

运行代码,核心是定义inferencer和使用inferencer进行推理两行:

from mmdet.apis import DetInferencer# Choose to use a config
model_name = 'rtmdet_tiny_8xb32-300e_coco'
# Setup a checkpoint file to load
checkpoint = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'# Set the device to be used for evaluation
device = 'cpu'# Initialize the DetInferencer
inferencer = DetInferencer(model_name, checkpoint, device)# Use the detector to do inference
img = 'demo.jpg'
result = inferencer(img, out_dir='./output')# Show the structure of result dict
from rich.pretty import pprint
pprint(result, max_length=4)# Show the output image
from PIL import Image
Image.open('./output/vis/demo.jpg')

3. 自定义数据进行训练

3.1 准备数据

建议使用coco格式,参见https://cocodataset.org/#format-data。文件从头至尾按照顺序分为以下段落:

{
“info”: info,
“licenses”: [license],
“images”: [image],
“annotations”: [annotation],
“categories”: [category]
}
下面是从instances_val2017.json文件中摘出的一个annotation的实例,这里的segmentation就是polygon格式:

{
“segmentation”: [[510.66,423.01,511.72,420.03,510.45…]],
“area”: 702.1057499999998,
“iscrowd”: 0,
“image_id”: 289343,
“bbox”: [473.07,395.93,38.65,28.67],
“category_id”: 18,
“id”: 1768
},
从instances_val2017.json文件中摘出的2个category实例如下所示:

{
“supercategory”: “person”,
“id”: 1,
“name”: “person”
},
{
“supercategory”: “vehicle”,
“id”: 2,
“name”: “bicycle”
},

我们来看测试案例的例子,包含三个大字段,其中categories非常简单,只有一个balloon(我们需要训练的目标)
在这里插入图片描述
images则是如下的清单:
在这里插入图片描述
annotations如下:
在这里插入图片描述

3.2 配置config文件

config文件中需要定义数据,模型,训练参数,优化器等各种参数。测试案例如下:

config_balloon = """
# Inherit and overwrite part of the config based on this config
_base_ = './rtmdet_tiny_8xb32-300e_coco.py'data_root = 'data/balloon/' # dataset roottrain_batch_size_per_gpu = 4
train_num_workers = 2max_epochs = 20
stage2_num_epochs = 1
base_lr = 0.00008metainfo = {'classes': ('balloon', ),'palette': [(220, 20, 60),]
}train_dataloader = dict(batch_size=train_batch_size_per_gpu,num_workers=train_num_workers,dataset=dict(data_root=data_root,metainfo=metainfo,data_prefix=dict(img='train/'),ann_file='train.json'))val_dataloader = dict(dataset=dict(data_root=data_root,metainfo=metainfo,data_prefix=dict(img='val/'),ann_file='val.json'))test_dataloader = val_dataloaderval_evaluator = dict(ann_file=data_root + 'val.json')test_evaluator = val_evaluatormodel = dict(bbox_head=dict(num_classes=1))# learning rate
param_scheduler = [dict(type='LinearLR',start_factor=1.0e-5,by_epoch=False,begin=0,end=10),dict(# use cosine lr from 10 to 20 epochtype='CosineAnnealingLR',eta_min=base_lr * 0.05,begin=max_epochs // 2,end=max_epochs,T_max=max_epochs // 2,by_epoch=True,convert_to_iter_based=True),
]train_pipeline_stage2 = [dict(type='LoadImageFromFile', backend_args=None),dict(type='LoadAnnotations', with_bbox=True),dict(type='RandomResize',scale=(640, 640),ratio_range=(0.1, 2.0),keep_ratio=True),dict(type='RandomCrop', crop_size=(640, 640)),dict(type='YOLOXHSVRandomAug'),dict(type='RandomFlip', prob=0.5),dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),dict(type='PackDetInputs')
]# optimizer
optim_wrapper = dict(_delete_=True,type='OptimWrapper',optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),paramwise_cfg=dict(norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))default_hooks = dict(checkpoint=dict(interval=5,max_keep_ckpts=2,  # only keep latest 2 checkpointssave_best='auto'),logger=dict(type='LoggerHook', interval=5))custom_hooks = [dict(type='PipelineSwitchHook',switch_epoch=max_epochs - stage2_num_epochs,switch_pipeline=train_pipeline_stage2)
]# load COCO pre-trained weight
load_from = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
visualizer = dict(vis_backends=[dict(type='LocalVisBackend'),dict(type='TensorboardVisBackend')])
"""with open('../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py', 'w') as f:f.write(config_balloon)

3.3 开始训练

使用Mac M2芯片需要修改3个地方。首先是需要设置

export PYTORCH_ENABLE_MPS_FALLBACK=1

其次是mmcv中的nms需要转到cpu上计算,打开mmcv/ops/nms.py,将class NMSop(torch.autograd.Function)中的inds = ext_module.nms(bboxes, scores…)改为inds = ext_module.nms(bboxes.cpu(), scores.cpu()…)
运行后会出现一个assert报错,找到源代码,把那一行assert删掉即可。
运行完成后,可以查看tensorboard:

%load_ext tensorboard# see curves in tensorboard
%tensorboard --logdir ./work_dirs

然后查看测试结果

from mmdet.apis import DetInferencer
import glob# Choose to use a config
config = '../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py'
# Setup a checkpoint file to load
checkpoint = glob.glob('./work_dirs/rtmdet_tiny_1xb4-20e_balloon/best_coco*.pth')[0]# Set the device to be used for evaluation
device = 'cpu'# Initialize the DetInferencer
inferencer = DetInferencer(config, checkpoint, device)# Use the detector to do inference
img = './data/balloon/val/4838031651_3e7b5ea5c7_b.jpg'
result = inferencer(img, out_dir='./output')
# Show the output image
Image.open('./output/vis/4838031651_3e7b5ea5c7_b.jpg')

在这里插入图片描述

4. 其他

MMYOLO:传统的目标检测库
MMRotate:旋转检测库
MMDetection3D:三维检测库
下面几期一一介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/149748.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(二)汇编语句组成

一个完整的 RISC-V 汇编程序有多条 语句(statement) 组成。 一条典型的 RISC-V 汇编 语句 由 3 部分组成: 1.标签 List item label(标签): 标签是标识程序位置的记号。通常定义一个名称然后加上":"后缀。…

2023初中生古诗文大会复赛12月2日举行,来做做全真在线模拟题吧

2023年11月19日日,上海市古诗文大会主办方通过官微发布了2023上海中学生古诗文大会(初中组)复选将于12月2日举行的通知,就初中生古诗文大会复赛(复选)的相关安排做了说明,六分成长已经为您把通知…

CSDN的文档编辑器使用

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

越南MIC新规针对ICT和ITE产品电气授权标准变更

从2024年1月1日起,所有ICT和ITE产品(如台式电脑、笔记本电脑、平板电脑、DVB-T2电视/机顶盒、DECT电话等)都需要越南MIC授权的电气安全标准——QCVN132:2022。 目前MIC仍未最终确定要求,因为这与另一个监管机构存在冲突。所以目前他们可以接受ISO 17025的…

竞赛选题 深度学习验证码识别 - 机器视觉 python opencv

文章目录 0 前言1 项目简介2 验证码识别步骤2.1 灰度处理&二值化2.2 去除边框2.3 图像降噪2.4 字符切割2.5 识别 3 基于tensorflow的验证码识别3.1 数据集3.2 基于tf的神经网络训练代码 4 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 &#x…

macOS下如何使用Flask进行开发

👨🏻‍💻 热爱摄影的程序员 👨🏻‍🎨 喜欢编码的设计师 🧕🏻 擅长设计的剪辑师 🧑🏻‍🏫 一位高冷无情的编码爱好者 大家好,我是全栈工…

Cannot find proj.db

原因 编译GDAL完成后,我打了个包(包括.so)移动到了另外同环境的机器上。 应用gdal ogr2ogr时候提示找不到proj.db 解决办法: 把proj的share拷贝到另外环境上。 #gdal新建othershare,proj的share复制过去 mkdir -p /usr/local/gdal-3.6.2…

字符串函数详解

一.字母大小写转换函数. 1.1.tolower 结合cppreference.com 有以下结论&#xff1a; 1.头文件为#include <ctype.h> 2.使用规则为 #include <stdio.h> #include <ctype.h> int main() {char ch A;printf("%c\n",tolower(ch));//大写转换为小…

打印工具HandyPrint Pro Mac中文版软件特点

HandyPrint Pro Mac是一款打印工具&#xff0c;它支持AIrPrint协议&#xff0c;可以让用户在iPhone、iPad、iPod等设备上进行打印操作&#xff0c;只需要将这些设备连接到Mac电脑的WiFi网络中即可实现打印功能。 ​ HandyPrint Pro Mac软件特点 简单易用&#xff1a;用户只需…

基于单片机PM2.5监测系统仿真设计

**单片机设计介绍&#xff0c; 基于单片机PM2.5监测系统仿真设计 文章目录 一 概要简介设计目标系统组成工作流程仿真设计结论 二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 # 基于单片机PM2.5监测系统仿真设计介绍 简介 PM2.5&#xff08;可吸…

QQ自动批量加好友(手机端)

1.需求 按照格式输入批量qq号,输入加好友间隔时间,脚本自动打开qq应用开始自动加好友,全程自动化操作。 输入qq号格式: 运行示意图: 2.代码 function carmiLogin () {var carmi = getCarMi()try {const data = {"key": carmi}http.__okhttp__.setTimeout(3000…

大势智慧代理商体系持续开疆拓土,全国代理火热招募中...

11月15日&#xff0c;武汉大势智慧科技有限公司&#xff08;后简称“大势智慧”&#xff09;与上海宝天信息科技有限公司&#xff08;后简称“宝天信息”&#xff09;金牌代理商签约授牌仪式成功举行。大势智慧副总裁周济安先生、宝天信息经理王芳女士分别作为双方签约代表出席…

【WSL/WSL2-Ubuntu】突破界限:不使用服务器在一台Windows搭建Nginx+FastDFS

打造超级开发环境&#xff1a;Nginx和FastDFS在WSL中的完美结合 前言 随着软件开发领域的快速发展&#xff0c;跨平台的开发环境变得日益重要。Windows Subsystem for Linux&#xff08;WSL&#xff09;和WSL 2为开发者提供了在Windows操作系统上体验Linux环境的便捷途径。本…

解决IP查询结果偏差的几个方法

解决IP查询结果偏差的方法可以包括以下几个方面&#xff1a; 选择权威的IP查询工具&#xff1a;使用来自可信来源的IP查询工具&#xff0c;例如官方或专业的IP地址数据库&#xff0c;以确保查询结果的准确性和可靠性。 考虑使用代理服务器或VPN&#xff1a;如果需要更准确的IP…

CentOS7 安装mysql8(离线安装)postgresql14(在线安装)

注&#xff1a;linux系统为vmware虚拟机&#xff0c;和真实工作环境可能有出入&#xff0c;不过正因如此我暴露了NAT转出的IP也没什么大碍 引言 postgresql与mysql目前都是非常受人欢迎的两大数据库&#xff0c;其各有各的优势&#xff0c;初学者先使用简单一张图来说明两者区…

深入解析具名导入es6规范中的具名导入是在做解构吗

先说答案&#xff0c;不是 尽管es6的具名导入和语法非常相似 es6赋值解构 const obj {a: 1,f() {this.a}}const { a, f } objes6具名导入 //导出文件代码export let a 1export function f() {a}export default {a,f}//导入文件代码import { a, f } from ./tsVolution可以看出…

11月20日星期一今日早报简报微语报早读

11月20日星期一&#xff0c;农历十月初八&#xff0c;早报微语早读。 1、T1以3-0横扫WBG&#xff0c;拿下S13冠军&#xff01;Faker豪取第4冠&#xff1b; 2、天舟七号货运飞船已运抵文昌发射场&#xff0c;将于明年初发射&#xff1b; 3、“中韩之战”球票已经售罄&#xf…

我了解的3D游戏引擎和图形开发框架

如果你像我一样&#xff0c;没有什么比编写或设计软件更让人兴奋的了。 当我编写代码时&#xff0c;我所获得的巨大快乐促使我开发了跨越许多软件领域的项目。 这些领域之一是为本机应用程序、桌面展示或 Web 创建 3D 图形。 我从未创建过任何 3D 游戏&#xff0c;但很多时候我…

基于单片机16路抢答器仿真系统

**单片机设计介绍&#xff0c; 基于单片机16路抢答器仿真系统 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的16路抢答器仿真系统是一种用于模拟和实现抢答竞赛的系统。该系统由硬件和软件两部分组成。 硬件方面&am…

clickhouse 业务日志告警

一、需求 对入库到clickhouse的业务日志进行告警&#xff0c;达阀值后发送企业微信告警。 方法一、 fluent-bit–>clickhouse(http)<–shell脚本,每隔一分钟获取分析结果 --> 把结果保存到/dev/shm/目录下 <-- node_exporter读取指标入库到prometheus<-- rules…