【代码整理】基于COCO格式的pytorch Dataset类实现

import模块

import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

class Transform():'''数据预处理/数据增强(基于albumentations库)'''def __init__(self, imgSize):maxSize = max(imgSize[0], imgSize[1])# 训练时增强self.trainTF = A.Compose([A.BBoxSafeRandomCrop(p=0.5),# 最长边限制为imgSizeA.LongestMaxSize(max_size=maxSize),A.HorizontalFlip(p=0.5),# 参数:随机色调、饱和度、值变化A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),# 随机明亮对比度A.RandomBrightnessContrast(p=0.2),   # 高斯噪声A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     A.OneOf([# 使用随机大小的内核将运动模糊应用于输入图像A.MotionBlur(p=0.2),   # 中值滤波A.MedianBlur(blur_limit=3, p=0.1),    # 使用随机大小的内核模糊输入图像A.Blur(blur_limit=3, p=0.1),  ], p=0.2),# 较短的边做paddingA.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),],bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),)# 验证时增强self.validTF = A.Compose([# 最长边限制为imgSizeA.LongestMaxSize(max_size=maxSize),# 较短的边做paddingA.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),],bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),)

自定义数据集读取类COCODataset实现


class COCODataset(Dataset):def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):'''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径Args::param annPath:     COCO annotation 文件路径:param imgDir:      图像的根目录:param inputShape: 网络要求输入的图像尺寸:param trainMode:   训练集/测试集Returns:FRCNNDataset'''      self.mode = trainModeself.tf = Transform(imgSize=inputShape)self.imgDir = imgDirself.annPath = annPathself.DataNums = len(os.listdir(imgDir))# 为实例注释初始化COCO的APIself.coco=COCO(annPath)# 获取数据集中所有图像对应的imgIdself.imgIds = list(self.coco.imgs.keys())def __len__(self):'''重载data.Dataset父类方法, 返回数据集大小'''return len(self.imgIds)def __getitem__(self, index):'''重载data.Dataset父类方法, 获取数据集中数据内容这里通过pycocotools来读取图像和标签'''   # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}imgId = self.imgIds[index]imgInfo = self.coco.loadImgs(imgId)[0]# 载入图像 (通过imgInfo获取图像名,得到图像路径)               image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))image = np.array(image.convert('RGB'))# 得到图像里包含的BBox的所有idimgAnnIds = self.coco.getAnnIds(imgIds=imgId)   # 通过BBox的id找到对应的BBox信息anns = self.coco.loadAnns(imgAnnIds) # 获取BBox的坐标和类别labels, boxes = [], []for ann in anns:labelName = ann['category_id']labels.append(labelName)boxes.append(ann['bbox'])labels = np.array(labels)boxes = np.array(boxes)# 训练/验证时的数据增强各不相同if(self.mode):# albumentation的图像维度得是[W,H,C]transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)else:transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)# 这里的box是coco格式(xywh)image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']return image.transpose(2,0,1), np.array(box), np.array(label)

其他

# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):images = []bboxes = []labels = []for img, box, label in batch:images.append(img)bboxes.append(box)labels.append(label)images = torch.from_numpy(np.array(images))return images, bboxes, labels# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):worker_seed = worker_id + seedrandom.seed(worker_seed)np.random.seed(worker_seed)torch.manual_seed(worker_seed)# 固定全局随机数种子
def seed_everything(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False

batch数据集可视化

def visBatch(dataLoader:DataLoader):'''可视化训练集一个batchArgs:dataLoader: torch的data.DataLoaderRetuens:None     '''catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}for step, batch in enumerate(dataLoader):images, boxes, labels = batch[0], batch[1], batch[2]# 只可视化一个batch的图像:if step > 0: break# 图像均值mean = np.array([0.485, 0.456, 0.406]) # 标准差std = np.array([[0.229, 0.224, 0.225]]) plt.figure(figsize = (8,8))for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):img, box, label = imgBoxLabelax = plt.subplot(4,4,idx+1)img = img.numpy().transpose((1,2,0))# 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原img = img * std + meanfor instBox, instLabel in zip(box, label):x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])# 显示框ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))# 显示类别ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})plt.imshow(img)# 在图像上方展示对应的标签# 取消坐标轴plt.axis("off")# 微调行间距plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)plt.show()

example

# for test only:
if __name__ == "__main__":# 固定随机种子seed = 23seed_everything(seed)# BatcchSizeBS = 16# 图像尺寸imgSize = [800, 800]trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"imgDir =  "E:/datasets/Universal/COCO2017/train2017"# 自定义数据集读取类trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))# validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)# validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True, # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))print(f'训练集大小 : {trainDataset.__len__()}')visBatch(trainDataLoader)for step, batch in enumerate(trainDataLoader):images, boxes, labels = batch[0], batch[1], batch[2]# torch.Size([bs, 3, 800, 800])print(f'images.shape : {images.shape}')   # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一print(f'len(boxes) : {len(boxes)}')     # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  print(f'len(labels) : {len(labels)}')     break

输出

在这里插入图片描述

images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/637307.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java SSM园林绿化管理系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM园林绿化管理系统是一套完善的web设计系统(系统采用SSM框架进行设计开发,springspringMVCmybatis),对理解JSP java编程开发语言有帮助,系统具有完整的源代 码和数据库,系统主要采…

网易真的大规模裁员吗?

关注卢松松,会经常给你分享一些我的经验和观点。 以前互联网公司裁员,大家不紧张,因为容易找工作,而现在不知道怎么回事,只要以提高某某公司裁员,这就能迅速登上热榜。 这不,最近网传网易裁员1…

二、项目开发计划模板

1.引言 1.1编写目的 1.2项目背景 1.3定义 1.4参考资料 2.项目概述 2.1工作内容 2.2条件与限制 2.3产品 2.4运行环境 2.5服务 2.6验收标准 3.实施计划 3.1任务分解 3.2进度 3.3预算 3.4关键问题 4&#xff…

Linux的IO文件操作和文件系统

前要:本次我想给您带来关于 IO 和文件的知识,而文件在本系列中分为内存上的文件和磁盘上的文件。 1.文件概念 1.1.文件读写 在谈及系统接口之前,我们先来从 C 语言的角度来谈及一些前要知识,以辅助我们后续来理解系统 IO。 我们…

大数据导论(3)---大数据技术

文章目录 1. 大数据技术概述2. 数据采集与预处理2.1 数据采集2.2 预处理 3. 数据存储和管理3.1 分布式基础架构Hadoop3.2 分布式文件系统HDFS3.3 分布式数据库HBase3.4 非关系型数据库NoSQL 4. 数据可视化与保护 1. 大数据技术概述 大数据技术主要包括数据采集与预处理、数据存…

关于常见分布式组件高可用设计原理的理解和思考

文章目录 1. 数据存储场景和存储策略1.1 镜像模式-小规模数据1.2 分片模式-大规模数据 2. 数据一致性和高可用问题2.1 镜像模式如何保证数据一致性2.2 镜像模式如何保证数据高可用2.2.1 HA模式2.2.2 分布式选主模式 2.3 分片模式如何数据一致性和高可用 3. 大规模数据集群的架构…

32 登录页组件

效果演示 实现了一个登录页面的样式,包括一个容器、左侧和右侧部分。左侧部分是一个背景图片,右侧部分是一个表单,包括输入框、复选框、按钮和忘记密码链接。整个页面的背景色为白色,容器为一个圆角矩形,表单为一个半透…

linux C语言socket函数send

在Linux中,使用C语言进行网络编程时,send函数是用于发送数据到已连接的套接字的重要函数之一。它通常用于TCP连接,但也可以用于UDP(尽管对于UDP,通常更推荐使用sendto,因为它允许你指定目标地址和端口&…

建议数据库设计的必选字段

在数据库设计时,建议以下13个字段设置为数据库必要字段,以保证数据的完整和连续。(参考阿里开发规范,结合业务特点) id(id) id 是否删除(if_delete) 用于表达该记录是…

vivado 平台板流程

介绍 板文件使用XML格式来定义有关使用或的系统级板的信息包括AMD设备。AMD可以使用板文件中包含的信息Vivado™ Design Suite和Vivado IP集成商,以促进和验证AMD的连接设备到板。本章讨论董事会文件的不同部分及其用法本附录中所示的示例使用AMD Kintex 7 KC705评…

【linux驱动】用户空间程序与内核模块交互-- IOCTL和Netlink

创建自定义的IOCTL(输入/输出控制)或Netlink命令以便用户空间程序与内核模块交互涉及几个步骤。这里将分别介绍这两种方法。 一、IOCTL 方法 1. 定义IOCTL命令 在内核模块中,需要使用宏定义你的IOCTL命令。通常情况下,IOCTL命令…

python 基础知识点(蓝桥杯python 科目个人复习计划22)

今日复习内容:基础算法中的时间复杂度 时间复杂度分析 时间复杂度是衡量算法执行时间随输入规模增长的增长率。通过分析算法中基本操作的执行次数来确定时间复杂度‘常见的时间复杂度包括:常数时间O(1),线性时间O(n),对数时间O(log n)&…

[GN] Vue3.2 快速上手 ---- 核心语法(终章)_3

文章目录 路由器工作模式命名路由to的三种写法嵌套路由路由传参query参数params参数 路由的props配置replace 和 push编程式导航重定向 总结 路由器工作模式 history模式 优点:URL更加美观,不带有#,更接近传统的网站URL。 缺点:后…

UIElement编辑器扩展 组件 Inspector

UIElement编辑器扩展 组件 Inspector https://docs.unity.cn/cn/2021.3/Manual/UIE-create-a-binding-uxml-inspector.html 简单开始 声明序列化VisualTreeAsset [SerializeField] VisualTreeAsset visualTree; 声明完,直接在脚本的Inspector面板,把你…

水塘抽样算法

水塘抽样算法 1、问题描述 最近经常能看到面经中出现在大数据流中的随机抽样问题 即:当内存无法加载全部数据时,如何从包含未知大小的数据流中随机选取k个数据,并且要保证每个数据被抽取到的概率相等。 假设数据流含有N个数,我…

JS中运算符的算术、赋值、+、比较(不同类型之间比较)、逻辑

在JavaScript中,运算符用于执行各种计算和操作。 算术运算符: :用于加法运算。 javascriptlet a 5; let b 3; let sum a b; // 结果: 8 -:用于减法运算。 javascriptlet difference a - b; // 结果: 2 *:用于乘法…

树莓派挂载fat32 u盘

通过fdisk -l 查到设备是sda1 sudo nano /etc/fstab 文件末尾添加: /dev/sda1 /home/pi/mydic_mount auto defaults,noexec,umask0000 0 0 参考文章树莓派linux系统 挂载硬盘(U盘)相关知识总结(五星推荐)_树莓派挂…

Rancher部署k8s集群测试安装nginx(节点重新初始化方法,亲测)

目录 一、安装前准备工作计算机升级linux内核时间同步Hostname设置hosts设置关闭防火墙,selinux关闭swap安装docker 二、安装rancher部署rancher 三、安装k8s安装k8s集群易错点,重新初始化 四、安装kutectl五、测试安装nginx工作负载 一、安装前准备工作…

SD-WAN企业组网场景深度解析

在当前快速发展的企业网络环境中,SD-WAN技术不仅仅是实现企业站点之间网络互通的关键,更是满足不同站点对因特网、SaaS云应用、公有云等多种企业应用和业务访问的理想选择。从企业的WAN业务需求出发,我们可以对SD-WAN的组网场景进行深度解析&…

参数校验: spring-boot-starter-validation

参数校验: spring-boot-starter-validation pom.xml <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-validation</artifactId></dependency>应用 PostMapping("/login")public Re…