【代码整理】基于COCO格式的pytorch Dataset类实现

import模块

import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

class Transform():'''数据预处理/数据增强(基于albumentations库)'''def __init__(self, imgSize):maxSize = max(imgSize[0], imgSize[1])# 训练时增强self.trainTF = A.Compose([A.BBoxSafeRandomCrop(p=0.5),# 最长边限制为imgSizeA.LongestMaxSize(max_size=maxSize),A.HorizontalFlip(p=0.5),# 参数:随机色调、饱和度、值变化A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),# 随机明亮对比度A.RandomBrightnessContrast(p=0.2),   # 高斯噪声A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     A.OneOf([# 使用随机大小的内核将运动模糊应用于输入图像A.MotionBlur(p=0.2),   # 中值滤波A.MedianBlur(blur_limit=3, p=0.1),    # 使用随机大小的内核模糊输入图像A.Blur(blur_limit=3, p=0.1),  ], p=0.2),# 较短的边做paddingA.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),],bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),)# 验证时增强self.validTF = A.Compose([# 最长边限制为imgSizeA.LongestMaxSize(max_size=maxSize),# 较短的边做paddingA.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),],bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),)

自定义数据集读取类COCODataset实现


class COCODataset(Dataset):def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):'''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径Args::param annPath:     COCO annotation 文件路径:param imgDir:      图像的根目录:param inputShape: 网络要求输入的图像尺寸:param trainMode:   训练集/测试集Returns:FRCNNDataset'''      self.mode = trainModeself.tf = Transform(imgSize=inputShape)self.imgDir = imgDirself.annPath = annPathself.DataNums = len(os.listdir(imgDir))# 为实例注释初始化COCO的APIself.coco=COCO(annPath)# 获取数据集中所有图像对应的imgIdself.imgIds = list(self.coco.imgs.keys())def __len__(self):'''重载data.Dataset父类方法, 返回数据集大小'''return len(self.imgIds)def __getitem__(self, index):'''重载data.Dataset父类方法, 获取数据集中数据内容这里通过pycocotools来读取图像和标签'''   # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}imgId = self.imgIds[index]imgInfo = self.coco.loadImgs(imgId)[0]# 载入图像 (通过imgInfo获取图像名,得到图像路径)               image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))image = np.array(image.convert('RGB'))# 得到图像里包含的BBox的所有idimgAnnIds = self.coco.getAnnIds(imgIds=imgId)   # 通过BBox的id找到对应的BBox信息anns = self.coco.loadAnns(imgAnnIds) # 获取BBox的坐标和类别labels, boxes = [], []for ann in anns:labelName = ann['category_id']labels.append(labelName)boxes.append(ann['bbox'])labels = np.array(labels)boxes = np.array(boxes)# 训练/验证时的数据增强各不相同if(self.mode):# albumentation的图像维度得是[W,H,C]transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)else:transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)# 这里的box是coco格式(xywh)image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']return image.transpose(2,0,1), np.array(box), np.array(label)

其他

# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):images = []bboxes = []labels = []for img, box, label in batch:images.append(img)bboxes.append(box)labels.append(label)images = torch.from_numpy(np.array(images))return images, bboxes, labels# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):worker_seed = worker_id + seedrandom.seed(worker_seed)np.random.seed(worker_seed)torch.manual_seed(worker_seed)# 固定全局随机数种子
def seed_everything(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False

batch数据集可视化

def visBatch(dataLoader:DataLoader):'''可视化训练集一个batchArgs:dataLoader: torch的data.DataLoaderRetuens:None     '''catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}for step, batch in enumerate(dataLoader):images, boxes, labels = batch[0], batch[1], batch[2]# 只可视化一个batch的图像:if step > 0: break# 图像均值mean = np.array([0.485, 0.456, 0.406]) # 标准差std = np.array([[0.229, 0.224, 0.225]]) plt.figure(figsize = (8,8))for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):img, box, label = imgBoxLabelax = plt.subplot(4,4,idx+1)img = img.numpy().transpose((1,2,0))# 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原img = img * std + meanfor instBox, instLabel in zip(box, label):x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])# 显示框ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))# 显示类别ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})plt.imshow(img)# 在图像上方展示对应的标签# 取消坐标轴plt.axis("off")# 微调行间距plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)plt.show()

example

# for test only:
if __name__ == "__main__":# 固定随机种子seed = 23seed_everything(seed)# BatcchSizeBS = 16# 图像尺寸imgSize = [800, 800]trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"imgDir =  "E:/datasets/Universal/COCO2017/train2017"# 自定义数据集读取类trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))# validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)# validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True, # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))print(f'训练集大小 : {trainDataset.__len__()}')visBatch(trainDataLoader)for step, batch in enumerate(trainDataLoader):images, boxes, labels = batch[0], batch[1], batch[2]# torch.Size([bs, 3, 800, 800])print(f'images.shape : {images.shape}')   # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一print(f'len(boxes) : {len(boxes)}')     # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  print(f'len(labels) : {len(labels)}')     break

输出

在这里插入图片描述

images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/637307.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java SSM园林绿化管理系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM园林绿化管理系统是一套完善的web设计系统(系统采用SSM框架进行设计开发,springspringMVCmybatis),对理解JSP java编程开发语言有帮助,系统具有完整的源代 码和数据库,系统主要采…

网易真的大规模裁员吗?

关注卢松松,会经常给你分享一些我的经验和观点。 以前互联网公司裁员,大家不紧张,因为容易找工作,而现在不知道怎么回事,只要以提高某某公司裁员,这就能迅速登上热榜。 这不,最近网传网易裁员1…

Linux的IO文件操作和文件系统

前要:本次我想给您带来关于 IO 和文件的知识,而文件在本系列中分为内存上的文件和磁盘上的文件。 1.文件概念 1.1.文件读写 在谈及系统接口之前,我们先来从 C 语言的角度来谈及一些前要知识,以辅助我们后续来理解系统 IO。 我们…

大数据导论(3)---大数据技术

文章目录 1. 大数据技术概述2. 数据采集与预处理2.1 数据采集2.2 预处理 3. 数据存储和管理3.1 分布式基础架构Hadoop3.2 分布式文件系统HDFS3.3 分布式数据库HBase3.4 非关系型数据库NoSQL 4. 数据可视化与保护 1. 大数据技术概述 大数据技术主要包括数据采集与预处理、数据存…

关于常见分布式组件高可用设计原理的理解和思考

文章目录 1. 数据存储场景和存储策略1.1 镜像模式-小规模数据1.2 分片模式-大规模数据 2. 数据一致性和高可用问题2.1 镜像模式如何保证数据一致性2.2 镜像模式如何保证数据高可用2.2.1 HA模式2.2.2 分布式选主模式 2.3 分片模式如何数据一致性和高可用 3. 大规模数据集群的架构…

32 登录页组件

效果演示 实现了一个登录页面的样式,包括一个容器、左侧和右侧部分。左侧部分是一个背景图片,右侧部分是一个表单,包括输入框、复选框、按钮和忘记密码链接。整个页面的背景色为白色,容器为一个圆角矩形,表单为一个半透…

linux C语言socket函数send

在Linux中,使用C语言进行网络编程时,send函数是用于发送数据到已连接的套接字的重要函数之一。它通常用于TCP连接,但也可以用于UDP(尽管对于UDP,通常更推荐使用sendto,因为它允许你指定目标地址和端口&…

【linux驱动】用户空间程序与内核模块交互-- IOCTL和Netlink

创建自定义的IOCTL(输入/输出控制)或Netlink命令以便用户空间程序与内核模块交互涉及几个步骤。这里将分别介绍这两种方法。 一、IOCTL 方法 1. 定义IOCTL命令 在内核模块中,需要使用宏定义你的IOCTL命令。通常情况下,IOCTL命令…

Rancher部署k8s集群测试安装nginx(节点重新初始化方法,亲测)

目录 一、安装前准备工作计算机升级linux内核时间同步Hostname设置hosts设置关闭防火墙,selinux关闭swap安装docker 二、安装rancher部署rancher 三、安装k8s安装k8s集群易错点,重新初始化 四、安装kutectl五、测试安装nginx工作负载 一、安装前准备工作…

SD-WAN企业组网场景深度解析

在当前快速发展的企业网络环境中,SD-WAN技术不仅仅是实现企业站点之间网络互通的关键,更是满足不同站点对因特网、SaaS云应用、公有云等多种企业应用和业务访问的理想选择。从企业的WAN业务需求出发,我们可以对SD-WAN的组网场景进行深度解析&…

VIM工程的编译 / VI的快捷键记录

文章目录 VIM工程的编译 / VI的快捷键记录概述笔记工程的编译工程的编译 - 命令行vim工程的编译 - GUI版vim备注VIM的帮助文件位置VIM官方教程vim 常用快捷键启动vi时, 指定要编辑哪个文件正常模式光标的移动退出不保存 退出保存只保存不退出另存到指定文件移动到行首移动到行尾…

替代堆叠的新技术M-lag

M-lag:跨设备链路聚合组,是一种实现跨设备链路聚合的机制。将一台设备与另外两台设备进行跨设备链路聚合,从而把链路的可靠性从单板级提升到设备级,组成双活系统。 基本概念: peer-link链路:是一条聚合链…

[C#]winform部署官方yolov8-rtdetr目标检测的onnx模型

【官方框架地址】 https://github.com/ultralytics/ultralytics 【算法介绍】 RTDETR,全称“Real-Time Detection with Transformer for Object Tracking and Detection”,是一种基于Transformer结构的实时目标检测和跟踪算法。它在目标检测和跟踪领域…

力扣刷MySQL-第五弹(详细讲解)

🎉欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克🍹 ✨博客主页:小小恶斯法克的博客 🎈该系列文章专栏:力扣刷题讲解-MySQL 🍹文章作者技术和水平很有限,如果文中出…

Java 面向对象02 封装 (黑马)

人画圆:画圆这个方法应该定义在园这个类里面。 人关门:是人给了门一个作用力,然后门自己关上了门,所以关门的方法是在门的类里面 封装对象的好处: 调用Java自带的方法举例实现: 在测试类中,对其…

电脑pdf如何转换成word格式?用它实现pdf文件一键转换

pdf转word格式可以用于提取和重用pdf文档中的内容,有时候,我们可能需要引用或引用pdf文档中的一些段落、表格或数据,通过将pdf转换为可编辑的Word文档,可以轻松地复制和粘贴所需内容,节省我们的时间,那么如…

Element-UI 多个el-upload组件自定义上传,不用上传url,并且携带自定义传参(文件序号)

1. 需求: 有多个(不确定具体数量)的upload组件,每个都需要单独上传获取文件(JS File类型),不需要action上传到指定url,自定义上传动作和http操作。而且因为不确定组件数量&#xff0…

Oracle 经典练习题 50 题

文章目录 一 CreateTable二 练习题1 查询"01"课程比"02"课程成绩高的学生的信息及课程分数2 查询"01"课程比"02"课程成绩低的学生的信息及课程分数3 查询平均成绩大于等于60分的同学的学生编号和学生姓名和平均成绩4 查询平均成绩小于…

力扣精选算法100题——串联所有单词的字串(滑动窗口专题)

本题链接——串联所有单词的字串 本题和找到字符串中所有字母异位词题目非常相似,思路都是一样。通过自己的大脑能发现其中的相似之处。 第一步:了解题意 就按实例来分析吧,这样更通俗易懂。 words["ab","cd","ef…

Pycharm Terminal 无法激活conda环境

1.问题 Failed to activate conda environment. Please open Anaconda prompt, and run conda init powershell there. 这导致我们无法在Pycharm中使用conda命令 2.解决办法 修改为第二个,然后重启Terminal 再打开时发现已经是当前的conda环境