【Python】mmSegmentation语义分割框架教程(自定义数据集、训练设定、数据增强)

文章目录

  • 0.mmSegmentation介绍
  • 1.mmSegmentation基本框架
    • 1.1.mmSegmentation的model设置
    • 1.2.mmSegmentation的dataset设置
      • 1.2.1.Dataset Class文件配置
      • 1.2.2.Dataset Config文件配置
      • 1.2.3.Total Config文件配置
  • 2.运行代码
  • 3.展示效果图和预测
  • X.附录
    • X.1.mmSegmentation框架解释
    • X.2.mmsegmentation使用的预训练backbone
    • X.2.mmsegmentation官方帮助文档

0.mmSegmentation介绍

\qquadmmSegmentation是openmmlab项目下开源的图像语义分割框架,目前支持pytorch,由于其拥有pipeline加速,完善的数据增强体系,完善的模型库,作为大数据语义分割训练及测试的代码框架是再好不过了。
\qquad在开始本教程之前,你需要解决openmmlab的环境配置问题,好在这个repo上已经有很人性化的步骤讲解了,在此附上链接,就不赘述了:

  • Github链接:安装openmmlab环境

使用教程的相关链接如下(github的项目还自带了中文版):

  • Github链接:openmmlab/mmSegmentation
  • Gitio教程:openmmlab/mmSegmenatation

\qquad对着mmSegmentation官方教程一步步做固然是能做出来,但是由于其框架结构过于复杂,加之官方教程对如何规范自定义数据集缺乏一些tips,因而本文提供了一个相对简单的教程供大家参考。本文所有讲解目录均为mmSegmentation的项目目录。
MMSegmentation

1.mmSegmentation基本框架

\qquad要说mmSegmentation(以下简称mmSeg)当中最重要的东西,固然是Config文件了,Config文件可以分为4大类:

  1. model config
  2. dataset config
  3. runtime config
  4. schedule config

\qquad如果你想知道为什么分成这四大类,请参考本文X.1.节,对这个不感兴趣就继续往下看。其实3和4大多数人都用不到的,重点还是在1和2,下面就从这两个角度给大家来一个不算精细的讲解。

1.1.mmSegmentation的model设置

\qquad如果采用的是mmSegmentation里面支持的模型,那么固然是不需要自己写class了,自己挑一个模型就可以了。这些model的目录保存在了configs/models里面了。
models
第一个下划线前面的都好理解,就是模型的名字呗,那r50-d8可能就是resnet的类型了,有人会问,那resnet101和resnet152哪去了,别急,其实这些只是baseline,它的backbone是可以改的,比如说我们要使用的是danet_r50-d8.py,我们先打开它(这里我已经将SyncBN改成了BN,因为需要单GPU训练):
danet
\qquad只需要把model.backbone.depth设为101或者152就可以使用resnet101或者resnet152啦,如果你的本地没有模型,mmSeg就会从model_zoo里面下载一个,如果本地有(应该是保存在了checkpoint里面),则自动加载本地的,不会重复下载。其他的操作后面会讲,另外如果你是多GPU操作就选择使用SyncBN,否则就使用BN就可以了。如果使用了SyncBN却只有一块可用的GPU,那可能会报类似AssertionError:Default process group is not initialized的错误。有人可能问那我直接改了这个文件不就吧原来的默认参数给覆盖了嘛,不要紧,看到后面大家就会明白这个问题很容易解决,这里只是给大家做一个demo。

1.2.mmSegmentation的dataset设置

\qquad数据集设置比model的稍微复杂一点,这里会直接定义一个自己的数据集(Custom Dataset)来说明其原理。数据集需要准备的文件有三个

  1. Dataset Class文件
  2. Dataset Config文件
  3. Total Config文件

\qquad在X.1.节提到的config文件就是Total config(顶层设置文件),也是train.py文件直接调用的config文件,而Dataset Class文件是用来定义数据集的类别数和标签名称的,Dataset Config文件则是用来定义数据集目录、数据集信息(例如图片大小)、数据增强操作以及pipeline的。

1.2.1.Dataset Class文件配置

\qquad首先来说Dataset Class文件,这个文件存放在 mmseg/datasets/ 目录下,
mmseg
\qquad在这个目录下自己建一个数据集文件,并命个名。配置文件实际上是继承该目录下custom.py当中的CustomDataset父类的,这样写起了就简单多了,大多数情况下(当你的数据集是以一张张图片出现并且可用PIL模块读入时),你只需要设置两个参数即可——类别标签名称(CLASSES)和类别标签上色的RGB颜色(PALETTE)。以我的配置文件为例,代码如下:

from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
import os.path as osp@DATASETS.register_module()
class MRDDataset(CustomDataset):CLASSES = ("background","road")PALETTE = [[0,0,0],[255,255,255]]def __init__(self, split, **kwargs):super().__init__(img_suffix='.png', seg_map_suffix='.png', split=split, **kwargs)assert osp.exists(self.img_dir) and self.split is not None

\qquadimg_suffixseg_map_suffix分别是你的数据集图片的后缀和标签图片的后缀,因个人差异而定,tif格式的图片我还没有试过,但是jpg和png的肯定是可以的。
\qquad设置好之后记得保存在mmseg/datasets/目录下(我的文件名叫my_road_detect.py)。另外还需要设置一下该目录下的__init__文件:

from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .custom import CustomDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .stare import STAREDataset
from .voc import PascalVOCDataset
from .my_road_detect import MRDDataset
__all__ = ['CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset','DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset','PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset','PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset','STAREDataset',"MRDDataset"
]

\qquad需要改两个地方,①import的时候要把自己的Dataset加载进来,②__all__数组里面需要加入自己的Dataset类名称,修改完成之后保存。这两部操作完成之后还不行,由于训练的时候需要txt文件指示训练集、验证集和测试集的txt文件,一开始我以为这只是一个optional option,但无奈Custom Dataset的__init___下面给我来了一句assert osp.exists(self.img_dir) and self.split is not None,那好吧,不知道删了and后面的条件会有什么后果,还是自己创一个吧,写来一个简单的划分数据集并保存到txt的demo,大家可以把这个py文件放到你的数据集上一级目录上并对着稍微改改:

import mmcv
import os.path as osp
data_root = "/data3/datasets/Custom/Lab/Segmentation/"
ann_dir = "ann_png1"
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:# select first 4/5 as train settrain_length = int(len(filename_list)*4/5)f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:# select last 1/5 as train setf.writelines(line + '\n' for line in filename_list[train_length:])

data_root写自己的工作目录名称,ann_dir写标签图片所在的目录,split_dir则是在data_root下生成split txt文件保存的文件夹目录,其他的就不需要怎么改了。如果你在data_root/split_dir/下成功找到了train.txt和val.txt文件,就没有问题了。

1.2.2.Dataset Config文件配置

\qquadDataset Config文件在 configs/__base__/ 目录下,需要自己新建一个xxx.py文件。
set
还是以我自己的Custom Dataset为例,它的书写格式如下:

# dataset settings
dataset_type = 'MRDDataset'
data_root = '/data3/datasets/Custom/Lab/Segmentation/'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 480)
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='LoadAnnotations'),dict(type='Resize', img_scale=(640, 480), keep_ratio=True),dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),dict(type='RandomFlip', prob=0.5),dict(type='PhotoMetricDistortion'),dict(type='Normalize', **img_norm_cfg),dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),dict(type='DefaultFormatBundle'),dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [dict(type='LoadImageFromFile'),dict(type='MultiScaleFlipAug',img_scale=(640, 480),# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],flip=False,transforms=[dict(type='Resize', keep_ratio=True),dict(type='RandomFlip'),dict(type='Normalize', **img_norm_cfg),dict(type='ImageToTensor', keys=['img']),dict(type='Collect', keys=['img']),])
]
data = dict(samples_per_gpu=2,workers_per_gpu=2,train=dict(type=dataset_type,data_root=data_root,img_dir='data1_for_ann',ann_dir='ann_png1/',pipeline=train_pipeline,split="splits/train.txt"),val=dict(type=dataset_type,data_root=data_root,img_dir='data1_for_ann',ann_dir='ann_png1',split="splits/val.txt",pipeline=test_pipeline),test=dict(type=dataset_type,data_root=data_root,img_dir='data1_for_ann',ann_dir='ann_png1',split="splits/val.txt",pipeline=test_pipeline))

需要改的地方有以下几个:

  1. img_norm_cfg:数据集的方差和均值
  2. crop_size:数据增强时裁剪的大小. img_dir:
  3. img_scale:原图像尺寸
  4. data_root:工作目录
  5. img_dir:工作目录下存图片的目录
  6. ann_dir:工作目录下存标签的目录
  7. split:之前操作做txt文件的目录
  8. sample_per_gpu:batch size
  9. workers_per_gpu:dataloader的线程数目,一般设2,4,8,根据CPU核数确定,或使用os.cpu_count()函数代替
  10. PhotoMetricDistortion是数据增强操作,有四个参数(参考博客)分别是亮度、对比度、饱和度和色调,它们的默认设定如下:
brightness_delta=32; # 32 
contrast_range=(0.5, 1.5); # (0.5, 1.5),下限-上限
saturation_range=(0.5, 1.5); # (0.5, 1.5),下限-上限
hue_delta=18; # 18

如果不想使用默认设定,仿照其他选项将自定义参数写在后面即可,例如

dict(type='PhotoMetricDistortion',contrast_range=(0.5, 1.0))

改好之后保存 configs/__base__/ 目录下。
\qquad这里也给大家提供了计算数据集方差和均值的一个样例程序(多数据集计算整体均值和标准差):

# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 10:38:17 2021@author: 17478
"""
import os
import cv2
import numpy as np
from tqdm import tqdm  # pip install tqdm
import argparsedef input_args():parser = argparse.ArgumentParser(description="calculating mean and std")parser.add_argument("--data_fmt",type=str,default='samples_{name}')parser.add_argument("--data-name",type=str,nargs="+",default=['morning','noon','afternoon','dusk','snowy'])return parser.parse_args()if __name__ == "__main__":opt = input_args()img_files =[]for name in opt.data_name:img_dir = opt.data_fmt.format(name=name)files = os.listdir(img_dir)img_files.extend([os.path.join(img_dir,file) for file in files])meanRGB = np.asarray([0,0,0],dtype=np.float64)varRGB = np.asarray([0,0,0],dtype=np.float64)for img_file in tqdm(img_files,desc="calculating mean",mininterval=0.1):img = cv2.imread(img_file,-1)meanRGB[0] += np.mean(img[:,:,0])/255.0meanRGB[1] += np.mean(img[:,:,1])/255.0meanRGB[2] += np.mean(img[:,:,2])/255.0meanRGB = meanRGB/len(img_files)for img_file in tqdm(img_files,desc="calculating var",mininterval=0.1):img = cv2.imread(img_file,-1)varRGB[0] += np.sqrt(np.mean((img[:,:,0]/255.0-meanRGB[0])**2))varRGB[1] += np.sqrt(np.mean((img[:,:,1]/255.0-meanRGB[1])**2))varRGB[2] += np.sqrt(np.mean((img[:,:,2]/255.0-meanRGB[2])**2))varRGB = varRGB/len(img_files)print("meanRGB:{}".format(meanRGB))print("stdRGB:{}".format(varRGB))

1.2.3.Total Config文件配置

\qquadTotal Config文件是train.py直接调用的config文件,在X.1.节也有介绍,在此只说明如何即可。该文件在 config/xxxmodel/ 的目录下,你选用的是哪一个model,就选择哪一个目录。
model
以DANet为例,我们书写一个total config文件,并保存在configs/danet的文件夹下:

_base_ = ['../_base_/models/danet_r50-d8.py', '../_base_/datasets/my_road_detect.py','../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
]
model = dict(decode_head=dict(num_classes=2),auxiliary_head=dict(num_classes=2))

\qquad这个代码就一个__base__的数组,第一个元素代表模型路径,也就是在1.1.节介绍的模型文件(在这个教程里就不带着大家重写模型了);第二个元素代表数据集的Dataset config文件(详见1.2.2节);第三个元素和第四个元素本教程未涉及到,按照默认参数写也没有太大问题,如果想修改训练的代数以及log和save的频率修改第4元素及响应文件,在此就不再赘述了。另外如果你的模型不是19类的(因为是原模型是根据cityscapes写的,输出通道为19),需按照上面修改一下。
\qquad到此为止要恭喜大家,代码终于可以试跑了,如果你的代码出现Error或者Exception也不要慌,从环境配置到流程一一对照一遍,调试大项目要有耐心,也欢迎大家评论区留言。

2.运行代码

\qquad在项目目录下,输入python tools/train.py xxxconfig.py --work-dir=xxx即可运行,其中xxxconfig.py就是我们刚刚保存的Total config文件(记得要把完整路径也加上),work-dir其实就是保存log和model的目录(如果没有会自己创建)。如果发现import mmseg找不到这个包,那八成是调试器运行目录不在根目录下造成的,要不就配置run的目录,要不就直接吧tools/train.py复制到根目录下运行。运行结果差不多是这样:
mmSegmentation
使用gpustat的包查看gpu状态
gpu
\qquad虽然我的数据集很小(做测试的,就50张图片),但是gpu利用率仍然接近100%,可见其代码优化做的已经相当理想了。(我开了NVIDIA的图形加速,所以出现了很多其他的利用进程)。
\qquad这里有读者会疑问为什么上面不显示epoch,因为mmseg默认是iteration-based的,所谓iteration即batch的个数,若要改成epoch,则需要参考docs/config.md进行修改:

runner = dict(type='EpochBasedRunner',max_epoch='200')
checkpoint_config = dict(by_epoch=True,interval=20)  # save checkpoint per 20 epochs

以上代码可放在Total config文件中。

3.展示效果图和预测

\qquad最后写了展示预测效果的代码,把config_file和checkpoint_file替换成你自己的config文件和pth文件(保存模型的)即可:

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette
config_file = "configs/danet/danet_r50-d8_360x480_20k_mrd.py"
checkpoint_file = 'work_dirs/danet_r50-d8_375x1242_20k_mrd/latest.pth'
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
img = '/data3/datasets/Custom/Lab/Segmentation/data1_for_ann/000000.png'
result = inference_segmentor(model, img)
show_result_pyplot(model, img, result, [[0,255,0],[255,255,255]])

pic
\qquad我上的是白色(道路)和绿色(非道路),不是特别好看,哈哈,但是mask和img的相对位置很容易看出来,这个配颜色的话,大家还是自己定吧。我这个数据集太少,只是给大家做个演示,结果肯定是过拟合的。

X.附录

X.1.mmSegmentation框架解释

在mmSegmentation的项目目录下,打开Configs/下面的目录
mmseg
随便打开一个文件就知道了
mmsegmentation
从文件的名字也可以看出,它是模型(baseline+backbone、数据集、schedule的组合(runtime是default设置,就没包含在名称内)。

X.2.mmsegmentation使用的预训练backbone

预训练backbone下载链接为:
mmcv预训练模型下载地址(.json文件,复制对应模型的链接即可下载)
download

X.2.mmsegmentation官方帮助文档

可在docs/tutorials中查看
在这里插入图片描述

希望本文对您有帮助,谢谢阅读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/544438.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基本shell编程【3】- 常用的工具awk\sed\sort\uniq\od

awk awk是个很好用的东西,大量使用在linux系统分析的结果展示处理上。并且可以使用管道, input | awk | output1.首先要知道形式awk command file 如 awk {print $0} a.txt b.txt (后面可以跟一个或多个文件)2.command学习。c…

面试官:重写 equals 时为什么一定要重写 hashCode?

作者 | 磊哥来源 | Java面试真题解析(ID:aimianshi666)转载请联系授权(微信ID:GG_Stone)重要说明:本篇为博主《面试题精选-基础篇》系列中的一篇,关注我,查看更多面试题。…

rotate array_Array.rotate! Ruby中的示例方法

rotate arrayArray.rotate! 方法 (Array.rotate! Method) In this article, we will study about Array.rotate! method. You all must be thinking the method must be doing something which is related to rotating certain elements. It is not as simple as it…

【python】获取PC机公网IP并发送至邮箱

文章目录0.引言1.获取外网IP2.打开SMTP服务3.python发送邮件4.完整代码0.引言 \qquad之前一直使用Putty连接公司的PC机进行远程办公,苦于外网的IP地址不能固定下来,所以购买了内网穿透服务,免费版还会限速。后来转念一想,如果能定…

List 去重的 6 种方法,这个方法最完美!

作者 | 王磊来源 | Java中文社群(ID:javacn666)转载请联系授权(微信ID:GG_Stone)在日常的业务开发中,偶尔会遇到需要将 List 集合中的重复数据去除掉的场景。这个时候可能有同学会问&#xff1a…

Mongodb -(3) replica set+sharding

分片集搭建---何旭东目录分片集搭建...................................................................................................................... 1生态系统...............................................................................................…

electron 菜单栏_如何在Electron JS中添加任务栏图标菜单?

electron 菜单栏If you are new here, please consider checking out my recent articles on Electron JS including Tray Icons. 如果您是新来的,请考虑查看我最近关于Electron JS的文章, 包括托盘图标 。 In this tutorial, we will set up 2 menu it…

【逆强化学习-0】Introduction

文章目录专栏传送门0.引言1.逆强化学习发展历程2.需要准备的专栏传送门 0.简介 1.学徒学习 2.最大熵学习 0.引言 \qquad相比于深度学习,国内强化学习的教程并不是特别多,而相比强化学习,逆强化学习的教程可谓是少之又少。而本人想将整理到的资…

不知道Mysql排序的特性,加班到12点,认了认了!

小弟新写了一个功能,自测和测试环境测试都没问题,但在生产环境会出现偶发问题。于是,加班到12点一直排查问题,终于定位了的问题原因:Mysql Limit查询优化导致。现抽象出问题模型及解决方案,分析给大家&…

js中==与===的区别

2019独角兽企业重金招聘Python工程师标准>>> 1、对于string,number等基础类型,和是有区别的 1)不同类型间比较,之比较“转化成同一类型后的值”看“值”是否相等,如果类型不同,其结果就是不等 2&#xff09…

c语言中memcpy函数_带有示例的C中的memcpy()函数

c语言中memcpy函数memcpy()函数 (memcpy() function) memcpy() is a library function, which is declared in the “string.h” header file - it is used to copy a block of memory from one location to another (it can also be considered as to copy a string to anothe…

【逆强化学习-1】学徒学习(Apprenticeship Learning)

文章目录0.引言1.算法原理2.仿真环境3.运行4.补充(学徒学习深度Q网络)本文为逆强化学习系列第1篇,没有看过逆强化学习介绍的那篇的朋友,可以看一下:Inverse Reinforcement Learning-Introduction 传送门 0.引言 \qquad…

面试官:HashMap有几种遍历方法?推荐使用哪种?

作者 | 磊哥来源 | Java面试真题解析(ID:aimianshi666)转载请联系授权(微信ID:GG_Stone)HashMap 的遍历方法有很多种,不同的 JDK 版本有不同的写法,其中 JDK 8 就提供了 3 种 HashMa…

HTML 5 input placeholder 属性

<input placeholder"请先选择组织" type"text" value"" </input>placeholder 属性提供可描述输入字段预期值的提示信息&#xff08;hint&#xff09;。 该提示会在输入字段为空时显示&#xff0c;并会在字段获得焦点时消失。 注释&…

【逆强化学习-2】最大熵学习(Maximum Entropy Learning)

文章目录0.引言1.算法原理2.仿真0.引言 \qquad本文是逆强化学习系列的第2篇&#xff0c;其余博客传送门如下&#xff1a; 逆强化学习0-Introduction 逆强化学习1-学徒学习 \qquad最大熵学习是2008年出现的方法&#xff0c;原论文&#xff08;链接见【逆强化学习0】的博客&#…

uselocale_Java扫描仪useLocale()方法与示例

uselocale扫描器类useLocale()方法 (Scanner Class useLocale() method) useLocale() method is available in java.util package. useLocale()方法在java.util包中可用。 useLocale() method is used to use this Scanner locale to the given locale (lo). useLocale()方法用…

面试官又整新活,居然问我for循环用i++和++i哪个效率高?

前几天&#xff0c;一个小伙伴告诉我&#xff0c;他在面试的时候被面试官问了这么一个问题&#xff1a;在for循环中&#xff0c;到底应该用 i 还是 i &#xff1f;听到这&#xff0c;我感觉这面试官确实有点不按套路出牌了&#xff0c;放着好好的八股文不问&#xff0c;净整些幺…

UVa 988 - Many Paths, One Destination

称号&#xff1a;生命是非常多的选择。现在给你一些选择&#xff08;0~n-1&#xff09;&#xff0c;和其他选项后&#xff0c;分支数每一次选择&#xff0c;选择共求。 分析&#xff1a;dp&#xff0c;图论。假设一个状态也许是选择的数量0一个是&#xff0c;代表死亡&#xff…

Java PrintWriter close()方法与示例

PrintWriter类close()方法 (PrintWriter Class close() method) close() method is available in java.io package. close()方法在java.io包中可用。 close() method is used to close this stream and free all system resources linked with the stream. close()方法用于关闭…

pipedreader_Java PipedReader ready()方法与示例

pipedreaderPipedReader类ready()方法 (PipedReader Class ready() method) ready() method is available in java.io package. ready()方法在java.io包中可用。 ready() method is used to check whether this PipedReader stream is ready to be read or not. ready()方法用…