pytorch图像数据集定义

文章目录

    • 相关链接
    • Dataset
      • VisionDataset
      • DatsetFolder
      • ImageFolder
    • torchvision.transforms
    • Pytorch Lightning
      • LightningDataModule

对于图像数据集来说,首先是在Dataset类对数据集进行定义,一般来说不定义transform,则数据为PIL Image,PIL格式到tensor的转换也是transforms变换的一种,所以定Dataset+transforms+Dataloader,最后在训练部分to(device)来得到模型的输入。

相关链接

torchvision.datasets的三个基础类
torchvision.datasets
torch.utils.data.Dataset
Pillow(PIL Fork) Image模块

Dataset

Dataset是数据集在pytorch中的化身,需要重写__ getitem__ 和 __ len__。__ __ getitem__ 通过传入的索引加载指定路径的数据,路径常常是一个列表,如很多张图片组成的数据集,需要在初始化时定义函数得到路径列表,或者在外部定义,总之要得到一个路径List。也需要在其中定义或调用具体读取的代码,如PIL库的Image.open()来读取图片,或Image.fromarray()来创建图片,也就是需要知道数据在哪里和怎么读取。

└─Dataset└─VisionDataset└─DatasetFolder└─ImageFolder

Dataset是torch.utils.data中的类,是数据集的基础类

VisionDataset

VisionDataset是torchvision.datasets.vision中的类,是torchvision类数据集的基础类,相比于原始的Dataset类,提供了transform,transforms,target_transform数据变换的接口

DatasetFolder,ImageFolder都来自torchvision.datasets.folder ,既然叫做folder,实际上已经有了完整的数据集功能,可以按照默认的目录结构读取数据。DatasetFolder还需要定义loader以读取特定类型的数据,和is_valid_file或者extensions,is_valid_file和extensions不能同时定义,但必须有一个定义,如果定义了有效后缀名,会自动通过后缀来判断文件有效性。而ImageFolder更进一步,默认使用读取图像数据的loader读取,还默认定义了图像后缀名。从Dataset到ImageFolder构成了不同层次的封装,完成度越高,灵活性越低,可以根据自己的需要选择。

除了在__ getitem__ 中通过得到的路径列表来读取数据,对于不同格式的数据也有不同的做法,如torchvision中内置cifar数据集,会直接从原始数据中以矩阵的形式读取, 因此 __ getitem__ 会从矩阵中创建Image对象。总而言之,一般来讲对于图片数据集来说,__ get __返回的都是PIL Image对象,不管是从路径列表中读取,还是整个以矩阵形式读取,如果不定义transform,最后在Dataset阶段都是PIL对象。

DatsetFolder

默认的排列结构如下,每一个文件夹表示一类,下面是这一类的样本

​      directory/​      ├── class_x​      │  ├── xxx.ext​      │  ├── xxy.ext​      │  └── ...​      │    └── xxz.ext​      └── class_y​        ├── 123.ext​        ├── nsdf3.ext​        └── ...​        └── asd932_.ext

用文件夹来区分不同的类别。比较重要的有两类操作,find_class函数得到类别名和类别序号。make_dataset得到路径列表。

默认的findclass函数

文件夹名是类名。

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:"""Finds the class folders in a dataset.See :class:`DatasetFolder` for details."""classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idx

默认的make_dataset

得到instance列表,表示文件的路径列表。

基本上很大一部分是在定义有效性判断相关,主要部分是一个双层for循环,因为类名定义为文件夹名,所以会遍历各个类的文件夹,会将遍历到的有效文件的路径加入instance,遍历过的非空类添加到available_classe。

def make_dataset(directory: Union[str, Path],class_to_idx: Optional[Dict[str, int]] = None,extensions: Optional[Union[str, Tuple[str, ...]]] = None,is_valid_file: Optional[Callable[[str], bool]] = None,allow_empty: bool = False,
) -> List[Tuple[str, int]]:directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")both_none = extensions is None and is_valid_file is Noneboth_something = extensions is not None and is_valid_file is not Noneif both_none or both_something:raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")if extensions is not None:def is_valid_file(x: str) -> bool:return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]is_valid_file = cast(Callable[[str], bool], is_valid_file)instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):path = os.path.join(root, fname)if is_valid_file(path):item = path, class_indexinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes and not allow_empty:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "if extensions is not None:msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"raise FileNotFoundError(msg)return instances

ImageFolder

 root/dog/xxy.pngroot/dog/[...]/xxz.pngroot/cat/123.pngroot/cat/nsdf3.png

ImageFolder如名字所示,如果数据集是这种文件夹排列,而且是图像文件,又没有需要特殊定义的部分 ,可以直接实例化一个ImageFolder,而不需要重写任何部分 ,实例化一个数据集只需要传入数据集路径和tansform变换。

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']]

torchvision.transforms

一般会在数据集实例化时,从外部传入,通常自定义的Transforms序列包含ToTensor,可以将上一阶段的PIL Image转换为Tensor,而下一次变化要到训练时的to(device),这样数据最终输入完成,也可以在Dataset类中写入默认的transform。

通过torchvision.get_image_backend得到torchvision现在的后端默认为PILtorchvision.set_image_backend(backend)指定用来读取图片的包,可选accimage

Loader将数据读取为PIL对象,一般数据集定义不在数据集内部定义默认的transform图像变换,而是在外部定义一个transform序列,通常倒数第二个是torchvision.transforms.ToTensor()操作,会将一个PIL Image或者一个ndarray转换为tensor并缩放到[0.0, 1.0]。因此接下来会通过transforms.Normalize进行归一化。

PILToTensor会把PIL Image转化为tensor,但是不会进行缩放, ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)(C×H×W)

ToTensor会把PIL Image或者ndarray转换成tensor而且会进行缩放。 ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)(C×H×W)​ 在规定的模式如RGBA,RGB,YCbCr或者dtype = np.uint8情况下,别的情况下不缩放。

Normalize只支持tensor,其他大部分操作也支持PIL,所以在ToTensor之后最后进行Normalize

data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

Pytorch Lightning

Pytorch Lightning是Pytorch中的kersas,简称pl

LightningDataModule

Pytorch Lightning继承LightningDataModule定义数据集,pl中的Dataset和Dataloader是高度耦合的。

import lightning.pytorch as L
import torch.utils.data as data
from pytorch_lightning.demos.boring_classes import RandomDatasetclass MyDataModule(L.LightningDataModule):def prepare_data(self):# download, IO, etc. Useful with shared filesystems# only called on 1 GPU/TPU in distributed...def setup(self, stage):# make assignments here (val/train/test split)# called on every process in DDPdataset = RandomDataset(1, 100)self.train, self.val, self.test = data.random_split(dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42))def train_dataloader(self):return data.DataLoader(self.train)def val_dataloader(self):return data.DataLoader(self.val)def test_dataloader(self):return data.DataLoader(self.test)def teardown(self):# clean up state after the trainer stops, delete files...# called on every process in DDP...

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/766530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sora入门级概念、Open-Sora 1.0和现状挑战(附多个文生视频 Prompt 案例)

OpenAI Sora入门级概念 Sora模型是OpenAI 发布的人工智能模型,它主要用于生成和处理视频内容。以下是Sora模型的一些入门级概念: 视频内容生成:Sora模型能够根据文本描述生成视频内容。这意味着你可以输入一段描述性的文本,模型将基于这段文本生成相应的视频画面。场景和角…

Github 2024-03-19 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-03-19统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目9TypeScript项目2HTML项目1GDScript项目1MetaGPT: 多代理框架 创建周期:260 天开发语言:Python协议类型:MIT LicenseStar数量:35…

MongoDb数据库介绍安装使用

#安装mongodb# 第一步 下载mongoDb: 官网https://www.mongodb.com/ 第二步 进行安装配置修改Data directory 和 Log Directory 将数据目录和日志目录存放在D盘 第三步 取消install MongoDb Compass这个是安装可视化工具的意思在这里不需要 #配置环境变量加入到系统中的path环…

鸿蒙Harmony应用开发—ArkTS-LazyForEach:数据懒加载

LazyForEach从提供的数据源中按需迭代数据,并在每次迭代过程中创建相应的组件。当在滚动容器中使用了LazyForEach,框架会根据滚动容器可视区域按需创建组件,当组件滑出可视区域外时,框架会进行组件销毁回收以降低内存占用。 接口…

EI级!高创新原创未发表!VMD-TCN-BiGRU-MATT变分模态分解卷积神经网络双向门控循环单元融合多头注意力机制多变量时间序列预测(Matlab)

EI级!高创新原创未发表!VMD-TCN-BiGRU-MATT变分模态分解卷积神经网络双向门控循环单元融合多头注意力机制多变量时间序列预测(Matlab) 目录 EI级!高创新原创未发表!VMD-TCN-BiGRU-MATT变分模态分解卷积神经…

Spring设计模式-实战篇之模板方法模式

什么是模板方法模式? 模板方法模式用于定义一个算法的框架,并允许子类在不改变该算法结构的情况下重新定义算法中的某些步骤。这种模式提供了一种将算法的通用部分封装在一个模板方法中,而将具体步骤的实现延迟到子类中的方式。 模板方法模式…

【Go】Go语言中的数组与切片

纵使微茫如烟 纵有万般思念 流光总将故人搁浅在断简残篇 不成眠 不等谁来证明 不必狂歌痛饮 唯盼重相见 我如倦鸟归林 🎵 流浪的蛙蛙《从别后》 摘要 Go语言提供了强大的数据结构来处理固定长度的序列和动态长度的序列,分别称为数…

Superset二次开发之 配置Docker

手动安装 安装必要的一些系统工具 在设置仓库之前,需先安裝所需的软件包。yum-utils提供了yum-config-manager,并且device mapper存储驱动程序需要device-mapper-persistent-data和lvm2。 yum install -y yum-utils device-mapper-persistent-data lvm2 设置源仓库 使用阿里云…

爬虫工作量由小到大的思维转变---<第五十一章 Scrapy 深入理解Scrapy爬虫引擎(2)--引擎的工作流程>

前言: 继续上一篇:https://hsnd-91.blog.csdn.net/article/details/136943552 本章主要介绍Scrapy引擎的启动流程、请求处理的生命周期、如何处理下载的内容以及触发Item Pipeline的过程。还讨论了数据处理在爬虫解析函数和Item Pipeline中的作用,并介绍了引擎关闭…

2024年 前端JavaScript Web APIs 第五天 笔记

5.1-BOM和延迟函数setTimeout 5.2-事件循环eventloop 1-》 3 -》2 1-》 3 -》2 5.3-location对象 案例&#xff1a;5秒钟之后自动跳转页面 <body><a href"http://www.itcast.cn">支付成功<span>5</span>秒钟之后跳转到首页</a><sc…

数据库测试案例20240322-binlog_format为row binlog日志分析,主备数据不一致会导致复制出问题

1 测试概述 master-1&#xff0c;master-2表数据test如下&#xff1a; 9:26: [mytest]> select *From test; ---------- | id | name | ---------- | 10 | 123 | ---------- 1 row in set (0.00 sec) 2 在主库将数据删除导致数据不一致 09:26: [mytest]> set sql_…

git的实际应用场景

本文章的场景主要来源于实际工作&#xff0c;用于记载回看&#xff1b;持续更新&#xff0c;最后更新日期&#xff1a;2024-03-23软件&#xff1a;Git BASH、GitK、Git GUI三者配合使用 1、git reset < file > 作用&#xff1a;把文件从暂存区状态重置为工作区状态&…

对象操作篇

文章目录 9.1 dir()9.2 hash()9.3 help()9.4 id()9.5 type() 9.1 dir() dir() 是 Python 中的一个内置函数&#xff0c;用于返回一个对象的所有属性和方法的列表。当dir()不带参数调用时&#xff0c;它会返回当前作用域中的变量、方法和定义的类型列表。如果dir()带有一个参数…

从零开始学HCIA之网络自动化02

1、Python 是一种解释型&#xff08;即不需要编译环节&#xff09;的、面向对象&#xff08;即支持面向对象的风格或代码&#xff09;的、动态数据类型的高级程序设计语言。对于所谓的高级程序设计语言&#xff0c;你可以理解为“同声传译”的过程。 2、Python标准库很庞大&am…

Shut down, sleep, or hibernate your PC 关闭、睡眠或休眠

最近一段时间没有整服务器了~自己开始捉摸18年买的笔记本-x280&#xff0c;除了发现usb type c和thunderbolt 3接口的不一样外&#xff0c;也开始研究这个待机的功能了~找了官方文档&#xff0c;做个简易的翻译&#xff0c;给大家一起看看学习把。 官方文档URL&#xff1a; S…

Docker搭建LNMP环境实战(02):Win10下安装VMware

实战开始&#xff0c;先安装 VMware 虚拟机。话不多说&#xff0c;上手就干&#xff01; 1、基本环境检查 1.1、本机Bios是否支持虚拟化 进入&#xff1a;任务管理器- 性能&#xff0c;查看“虚拟化”是否启用&#xff0c;如果已启用&#xff0c;则满足要求&#xff0c;如果未…

【Swagger】接口文档生成

文章目录 一、前后端分离开发流程二、YApi导入接口文档三、Swagger3.1 介绍3.2 使用步骤3.2.1 导入 knife4j 的maven依赖3.2.2 在配置类中加入 knife4j 相关配置3.2.3 配置类中设置静态资源映射3.2.4 访问测试 3.3 常用注解3.4 全局参数设置 四、YApi 与 Swagger 一、前后端分离…

Day18:LeedCode 513.找树左下角的值 112. 路径总和 106.从中序与后序遍历序列构造二叉树

513. 找树左下角的值 给定一个二叉树的 根节点 root&#xff0c;请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 示例 1: 输入: root [2,1,3] 输出: 1 思路:出该二叉树的 最底层 最左边 节点的值找出深度最大的第一个结点(左结点先遍历) 方法一…

一个单生产-多消费模式下无锁方案(ygluu/卢益贵)

一个单生产-多消费模式下无锁方案 ygluu/卢益贵 关键词&#xff1a;生产者-消费者模型、无锁队列、golang、RWMutex 本文介绍一个“单生产(低频)-多消费”模式下的无锁哈希类方案&#xff0c;这个方案的性能优于golang的RWMutex&#xff0c;因为它永远不会因为“写”而导致与…

i2c-tools基本用法

一. 前言 前面调试一个I2C设备&#xff0c;用到了i2c-tools&#xff0c;觉得是一个调试I2C不错的工具&#xff0c;本文对i2c-tools的基本用法做一些介绍。i2c-tools是一些控制2C接口工具的集合&#xff0c;其中包括i2cdetect&#xff0c;i2cdump&#xff0c;i2cget&#xff0c;…