使用coco数据集进行语义分割:数据预处理与损失函数

如何coco数据集进行目标检测的介绍已经有很多了,但是关于语义分割几乎没有。本文旨在说明如何处理 stuff_train2017.json    stuff_val2017.json    panoptic_train2017.json    panoptic_val2017.json,将上面那些json中的dict转化为图片的label mask,也就是制作图片像素的标签的ground truth。

首先下载图片和annotation文件(也就是json文件)

2017 Train Images和2017 Val Images解压得到这两个文件夹,里面放着的是所有图片。

annotation下载下来有4个json文件

以及两个压缩包

这两个压缩包里面是panoptic segmentation的mask,用于制作ground truth。

首先说明一下那四个json文件的含义,stuff是语义分割,panoptic是全景分割。二者的区别在与

语义分割只区别种类,全景分割还区分个体。上图中,中间的语义分割将所有的人类视作同一个种类看待,而右边的全景分割还将每一个个体区分出来。

做语义分割,可以使用stuff_train2017.json生成ground truth。但是即便是只做语义分割,不做全景分割,在只看语义分割的情况下,stuff的划分精度不如panoptic,因此建立即使是做语义分割也用panoptic_train2017.json。本文中将两种json都进行说明。

stuff_train2017.json

下面这段代码是处理stuff_train2017.json生成ground truth

from pycocotools.coco import COCO
import os
from PIL import Image
import numpy as np
from matplotlib import pyplot as pltdef convert_coco2mask_show(image_id):print(image_id)img = coco.imgs[image_id]image = np.array(Image.open(os.path.join(img_dir, img['file_name'])))plt.imshow(image, interpolation='nearest')# cat_ids = coco.getCatIds()cat_ids = list(range(183))anns_ids = coco.getAnnIds(imgIds=img['id'], catIds=cat_ids, iscrowd=None)anns = coco.loadAnns(anns_ids)mask = np.zeros((image.shape[0], image.shape[1]))for i in range(len(anns)):tmp = coco.annToMask(anns[i]) * anns[i]["category_id"]mask += tmpprint(np.max(mask), np.min(mask))# 绘制二维数组对应的颜色图plt.figure(figsize=(8, 6))plt.imshow(mask, cmap='viridis', vmin=0, vmax=182)plt.colorbar(ticks=np.linspace(0, 182, 6), label='Colors')  # 添加颜色条# plt.savefig(save_dir + str(image_id) + ".jpg")plt.show()if __name__ == '__main__':Dataset_dir = "/home/xxxx/Downloads/coco2017/"coco = COCO("/home/xxxx/Downloads/coco2017/annotations/stuff_train2017.json")# createIndex# coco = COCO("/home/robotics/Downloads/coco2017/annotations/annotations/panoptic_val2017.json")img_dir = os.path.join(Dataset_dir, 'train2017')save_dir = os.path.join(Dataset_dir, "Mask/stuff mask/train")if not os.path.isdir(save_dir):os.makedirs(save_dir)image_id = 9convert_coco2mask_show(image_id)# for keyi, valuei in coco.imgs.items():#     image_id = valuei["id"]#     convert_coco2mask_show(image_id)

解释一下上面的代码。

coco = COCO("/home/xxxx/Downloads/coco2017/annotations/stuff_train2017.json")

从coco的官方库中导入工具,读取json。会得到这样的字典结构

在 coco.imgs ,找到 "id" 对应的value,这个就是每个图片唯一的编号,根据这个编号,到 coco.anns 中去索引 segmentation

注意 coco.imgs 的 id 对应的是 anns 中的 image_id,这个也是 train2017 文件夹中图片的文件名。

# cat_ids = coco.getCatIds()
cat_ids = list(range(183))

网上的绝大多数教程都写的是上面我注释的那行代码,那行代码会得到91个类,那个只能用于图像检测,但对于图像分割任务,包括的种类是182个和一个未归类的0类一共183个类。

tmp = coco.annToMask(anns[i]) * anns[i]["category_id"]

上面那行代码就是提取每一个类所对应的id,将所有的类都加进一个二维数组,就制作完成了ground truth。

panoptic_train2017.json

上面完成了通过stuff的json文件自作ground truth,那么用panoptic的json文件制作ground truth,是不是只要把

coco = COCO("/home/xxxx/Downloads/coco2017/annotations/stuff_train2017.json")

换为

coco = COCO("/home/xxxx/Downloads/coco2017/annotations/annotations/panoptic_val2017.json")

就行了呢?

答案是不行!会报错

Traceback (most recent call last):File "/home/xxxx/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3508, in run_codeexec(code_obj, self.user_global_ns, self.user_ns)File "<ipython-input-4-66f5c6e7dbe4>", line 1, in <module>coco = COCO("/home/xxxx/Downloads/coco2017/annotations/panoptic_val2017.json")File "/home/robotics/.local/lib/python3.8/site-packages/pycocotools/coco.py", line 86, in __init__self.createIndex()File "/home/xxxx/.local/lib/python3.8/site-packages/pycocotools/coco.py", line 96, in createIndexanns[ann['id']] = ann
KeyError: 'id'

用官方的库读官方的json还报错,真是巨坑!关键还没有官方的文档教你怎么用,只能一点点摸索,非常浪费时间精力。

制作panoptic的ground truth,用到了Meta公司的detectron2这个库

# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import json
import osfrom detectron2.data import MetadataCatalog
from detectron2.utils.file_io import PathManagerdef load_coco_panoptic_json(json_file, image_dir, gt_dir, meta):"""Args:image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".Returns:list[dict]: a list of dicts in Detectron2 standard format. (See`Using Custom Datasets </tutorials/datasets.html>`_ )"""def _convert_category_id(segment_info, meta):if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][segment_info["category_id"]]segment_info["isthing"] = Trueelse:segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][segment_info["category_id"]]segment_info["isthing"] = Falsereturn segment_infowith PathManager.open(json_file) as f:json_info = json.load(f)ret = []for ann in json_info["annotations"]:image_id = int(ann["image_id"])# TODO: currently we assume image and label has the same filename but# different extension, and images have extension ".jpg" for COCO. Need# to make image extension a user-provided argument if we extend this# function to support other COCO-like datasets.image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")label_file = os.path.join(gt_dir, ann["file_name"])segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]ret.append({"file_name": image_file,"image_id": image_id,"pan_seg_file_name": label_file,"segments_info": segments_info,})assert len(ret), f"No images found in {image_dir}!"assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]return retif __name__ == "__main__":from detectron2.utils.logger import setup_loggerfrom detectron2.utils.visualizer import Visualizerimport detectron2.data.datasets  # noqa # add pre-defined metadatafrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as nplogger = setup_logger(name=__name__)meta = MetadataCatalog.get("coco_2017_train_panoptic")dicts = load_coco_panoptic_json("/home/xxxx/Downloads/coco2017/annotations/panoptic_train2017.json","/home/xxxx/Downloads/coco2017/train2017","/home/xxxx/Downloads/coco2017/Mask/panoptic mask/panoptic_train2017", meta.as_dict())logger.info("Done loading {} samples.".format(len(dicts)))dirname = "coco-data-vis"os.makedirs(dirname, exist_ok=True)new_dic = {}num_imgs_to_vis = 100for i, d in enumerate(dicts):img = np.array(Image.open(d["file_name"]))visualizer = Visualizer(img, metadata=meta)pan_seg, segments_info = visualizer.draw_dataset_dict(d)seg_cat = {0: 0}for segi in segments_info:seg_cat[segi["id"]] = segi["category_id"]mapped_seg = np.vectorize(seg_cat.get)(pan_seg)# 保存数组为txt文件save_name = "/home/xxxx/Downloads/coco2017/Mask/panoptic label/train/" + str(d["image_id"]) + ".txt"np.savetxt(save_name, mapped_seg, fmt='%i')new_dic[d["image_id"]] = {"image_name": d["file_name"], "label_name": save_name}# # 将numpy数组转换为PIL Image对象# img1 = Image.fromarray(np.uint8(mapped_seg))# # 缩放图片# img_resized = img1.resize((224, 224))# # 将PIL Image对象转换回numpy数组# mapped_seg_resized = np.array(img_resized)# # 创建一个新的图形# plt.figure(figsize=(6, 8))# # 使用imshow函数来显示数组,并使用cmap参数来指定颜色映射# plt.imshow(mapped_seg_resized, cmap='viridis')# # 显示图形# plt.show()# fpath = os.path.join(dirname, os.path.basename(d["file_name"]))# # vis.save(fpath)if i + 1 >= num_imgs_to_vis:# 将字典转换为json字符串json_data = json.dumps(new_dic)# 将json字符串写入文件with open('/home/xxxx/Downloads/coco2017/data_for_train.json', 'w') as f:f.write(json_data)break
dicts = load_coco_panoptic_json("/home/xxxx/Downloads/coco2017/annotations/panoptic_train2017.json","/home/xxxx/Downloads/coco2017/train2017","/home/xxxx/Downloads/coco2017/Mask/panoptic mask/panoptic_train2017", meta.as_dict())

load_coco_panoptic_json的第三个参数,就是下载panoptic的annotations时,里面包含的两个压缩包。

上面这段代码,还需要修改一下detectron2库内部的文件

pan_seg, segments_info = visualizer.draw_dataset_dict(d)

进入 draw_dataset_dict这个函数内部,将两个中间结果拿出来

603行和604行是我自己加的代码。

这两个中间结果拿出来后,pan_seg是图片每个像素所属的种类,但是这个种类不是coco分类的那183个类

pan_seg中的数,要映射到segments_info中的 category_id,这个才是coco数据集所规定的183个类 。下图节选了183个中的前10个展示

mapped_seg = np.vectorize(seg_cat.get)(pan_seg)

上面这行代码就是完成pan_seg中的数,映射到segments_info中的 category_id。不要用两层的for循环,太低效了,numpy中有函数可以完成数组的映射。

这样就得到了语义分割的ground truth,也就是每个像素所属的种类。

通过以下代码训练网络

import argparse
import time
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.optim as optim
from config import get_config
from build import build_model, build_transform, build_criterion
from utils import load_pretrained
from prepare_data import get_dataloaderCHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/my_model/")
os.makedirs(CHECKPOINT_PATH, exist_ok=True)def parse_option():parser = argparse.ArgumentParser("Swin Backbone", add_help=False)parser.add_argument("--cfg", type=str, default="configs/swin/swin_tiny_patch4_window7_224_22k.yaml")parser.add_argument("--pretrained", default="checkpoint/swin_tiny_patch4_window7_224_22k.pth")parser.add_argument("--device", default="cuda")args = parser.parse_args()config = get_config(args)return configif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")config = parse_option()model = build_model(config).to(device)load_pretrained(config, model)transform = build_transform(config)# img_addr = "/home/xxxx/Downloads/coco2017/train2017/000000000009.jpg"# img = Image.open(img_addr)# img = transform(img)# # img = img.numpy().transpose(1, 2, 0)# # plt.imshow(img)# # plt.show()# img = img.unsqueeze(0).cuda()# output = model(img)train_loader = get_dataloader(transform)criterion = build_criterion()num_epochs = 500for epoch in range(num_epochs):for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)output = model(data)output = output.permute(0, 2, 3, 1).contiguous()features = output.shape[-1]output = output.view(-1, features)  # (batchsize*width*height, features)target = target.view(-1)  # (batchsize*width*height)loss = criterion(output, target)optimizer = optim.AdamW(model.parameters(), lr=0.0001)optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 10 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))if (epoch + 1) % 10 == 0:torch.save(model.state_dict(), CHECKPOINT_PATH)print(f"Model weights saved at epoch {epoch+1}.")

 网络输出的tensor的形状是(batch_size, 特征数,width, height)。将输出reshape成(batch_size*width*height, 特征数),将ground truth形状变成(batch_size*width*height,)这样的一个一维数组,数组中每个元素就是种类的index,代表该像素属于什么类。通过

criterion = nn.CrossEntropyLoss()

损失函数训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/197235.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BUUCTF 小易的U盘 1

BUUCTF:https://buuoj.cn/challenges 题目描述&#xff1a; 小易的U盘中了一个奇怪的病毒&#xff0c;电脑中莫名其妙会多出来东西。小易重装了系统&#xff0c;把U盘送到了攻防实验室&#xff0c;希望借各位的知识分析出里面有啥。请大家加油噢&#xff0c;不过他特别关照&a…

【C++ STL】vector类最全详解(什么是vector?vector类的常用接口有哪些?)

目录 一、前言 二、什么是vector ? &#x1f4a6; vector的基本概念 &#x1f4a6;vector的作用是什么 &#x1f4a6;总结 三、 vector的(一维)定义 四、vector(一维)常用接口的使用 &#x1f4a6;vector的常见构造&#xff08;初始化&#xff09; &#x1f4a6;vector…

ISIS配置以及详解

作者简介&#xff1a;大家好&#xff0c;我是Asshebaby&#xff0c;热爱网工&#xff0c;有网络方面不懂的可以加我一起探讨 :1125069544 个人主页&#xff1a;Asshebaby博客 当前专栏&#xff1a; 网络HCIP内容 特色专栏&#xff1a; 常见的项目配置 本文内容&am…

2024年天津中德应用技术大学专升本专业课报名及考试时间通知

天津中德应用技术大学2024年高职升本科专业课报名确认及考试通知 按照市高招办《2024年天津市高职升本科招生实施办法》&#xff08;津招办高发〔2023〕14号&#xff09;文件要求&#xff0c;天津中德应用技术大学制定了2024年高职升本科专业课考试报名、确认及考试实施方案&a…

JFrog----软件成分分析(SCA)简介

文章目录 1. SCA的重要性2. SCA的工作方式3. 安全漏洞分析4. 许可证合规性5. 代码质量和维护性结语 在当今的快速发展的软件行业中&#xff0c;软件成分分析&#xff08;Software Composition Analysis&#xff0c;简称SCA&#xff09;已成为一个不可或缺的工具。SCA的主要任务…

服务异步通讯

四、服务异步通讯 4.1初始MQ 4.1.1同步通讯和异步通讯 同步调用的优点: 时效性较强,可以立即得到结果 同步调用的问题: 耦合度高性能和吞吐能力下降有额外的资源消耗有级联失败问题异步通信的优点: 耦合度低吞吐量提升故障隔离流量削峰异步通信的缺点: 依赖于Broker的…

记一次SQL Server磁盘突然满了导致数据库锁死事件is full due to ‘LOG_BACKUP‘.

背景 最近我们的sql server 数据库磁盘在80左右&#xff0c;需要新增磁盘空间。还是处以目前可控的范围内&#xff0c;但是昨天晚上告警是80%&#xff0c;凌晨2:56分告警是90%&#xff0c;今天早上磁盘就满了。 经过 通过阿里云后台查看&#xff0c;磁盘已经占据99%&#xff0c…

蓝牙概述及基本架构介绍

蓝牙概述及基本架构介绍 1. 概述1.1 蓝牙的概念1.2 蓝牙的发展历程1.3 蓝牙技术概述1.3.1 Basic Rate(BR)1.3.2 Low Energy&#xff08;LE&#xff09; 2. 蓝牙的基本架构2.1 芯片架构2.2 协议架构2.2.1 官方协议中所展示的蓝牙协议架构2.2.1.1 全局分析2.2.1.2 局部分析 2.2.2…

广告公司选择企业邮箱的策略与技巧

对于广告公司而言&#xff0c;选择一款适合的企业邮箱不仅能提升工作效率&#xff0c;更能维护并强化公司的品牌形象。以下是在选择企业邮箱时需关注的关键因素和注意事项。 1、邮件服务商的安全性。 邮件服务商应具备严密的安全防护措施&#xff0c;包括反垃圾邮件、防病毒、防…

使用JDBC连接和操作数据库以及myBatis初级入门

JDBC简介和使用 java程序操作数据库的方式有很多种&#xff0c;下面列举一些市面上常用的方式&#xff1a; 从图片分析的知&#xff1a; MyBatis MyBatisPlus 这两个所占的比重比较大。都是用于简化JDBC开发的 JDBC&#xff1a;(Java DataBase Connectivity)&#xff0c;就…

95基于matlab的多目标优化算法NSGA3

基于matlab的多目标优化算法NSGA3&#xff0c;动态输出优化过程&#xff0c;得到最终的多目标优化结果。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 95matlab多目标优化 (xiaohongshu.com)

leetcode做题笔记1038. 从二叉搜索树到更大和树

给定一个二叉搜索树 root (BST)&#xff0c;请将它的每个节点的值替换成树中大于或者等于该节点值的所有节点值之和。 提醒一下&#xff0c; 二叉搜索树 满足下列约束条件&#xff1a; 节点的左子树仅包含键 小于 节点键的节点。节点的右子树仅包含键 大于 节点键的节点。左右…

智能优化算法应用:基于树种算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于树种算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于树种算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.树种算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…

【腾讯云云上实验室】个人对腾讯云向量数据库的体验心得

目录 前言Tencent Cloud VectorDB概念使用初体验腾讯云向量数据库的优势应用场景有哪些&#xff1f;未来展望番外篇&#xff1a;腾讯云向量数据库的设计核心结语 前言 还是那句话&#xff0c;不用多说想必大家都能猜到&#xff0c;现在技术圈最火的是什么&#xff1f;非人工智…

【LeetCode热题100】【双指针】三数之和

给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums[j] nums[k] 0 。请 你返回所有和为 0 且不重复的三元组。 注意&#xff1a;答案中不可以包含重复的三元组。 示例 …

Java开发中一些重要软件安装配置

Java技术栈中重要过程 1、JavaWeb1、开发工具VsCode的安装和使用2、Tomcat服务器3、nodejs的简介和安装4、Vite创建Vue3工程化项目ViteVue3项目的创建、启动、停止ViteVue3项目的目录结构 5、Maven安装和配置 1、JavaWeb 1、开发工具VsCode的安装和使用 1 安装过程 安装过程比…

OpenResty(nginx+lua+resty-http)实现访问鉴权

OpenResty(nginxluaresty-http)实现访问鉴权 最近用BI框架解决了一些报表需求并生成了公开链接&#xff0c;现在CMS开发人员打算将其嵌入到业务系统中&#xff0c;结果发现公开链接一旦泄露任何人都可以访问&#xff0c;需要实现BI系统报表与业务系统同步的权限控制。但是目前…

视频分割方法:批量剪辑高效分割视频,提取m3u8视频技巧

随着互联网的快速发展&#xff0c;视频已成为获取信息、娱乐、学习等多种需求的重要载体。然而&#xff0c;很多时候&#xff0c;需要的只是视频的一部分&#xff0c;这就要对视频进行分割。而m3u8视频是一种常见的流媒体文件格式&#xff0c;通常用于在线视频播放。本文将分享…

【前端首屏加载速度优化(0): 谷歌浏览器时间参数】

DOMContentLoaded 浏览器已经完全加载了 HTML&#xff0c;DOM树构建完成&#xff0c;但是像是 <img> 和样式表等外部资源可能并没有下载完毕。 Load DOM树构建完成后&#xff0c;继续加载 html/css 中的外部资源&#xff0c;加载完成之后&#xff0c;视为页面加载完成。…

听说这样寄快递会很省钱!速来收藏!便宜寄快递!

哈喽&#xff0c;小伙伴们&#xff0c;今天介绍一个寄快递省钱的小妙招。话不多说&#xff0c;小编直接开整&#xff01; 在闪侠惠递这个快递折扣平台下单送快递&#xff0c;不仅方便&#xff0c;而且运费更便宜。平台整合了京东、顺丰、圆通、申通、极兔、等快递公司可供选择…