pytorch数据载入
- 1.数据载入概况
- Dataloader 是啥
- 2.支持的三类数据集
- 2.1 torchvision.datasets.xxx
- 2.2 torchvision.datasets.ImageFolder
- 2.3 写自己的数据类,读入定制化数据
- 2.3.1 数据类的编写
- map-style范式
- iterable-style 范式
- 2.3.2 DataLoader 导入数据类
1.数据载入概况
数据是机器学习算法的驱动力, Pytorch提供了方便的数据载入和处理接口. 数据载入流程为:
step1: 指定要使用的数据集dataset
step2: 使用Dataloader载入数据
dataloader实质是一个可迭代对象,不能使用next()访问。但如果使用iter()封装,返回一个迭代器,可以使用.next()操作。
Dataloader 是啥
来自官网document的描述:
Dataloader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading,
customizing loading order and optional automatic batching (collation) and memory pinning.
See torch.utils.data documentation page for more details.
大概就是说:用来对数据集进行(小批次)迭代 载入的接口,所能够载入的数据集要么支持map-style操作,要么支持 iterable-style操作。
(这两种操作只有在编写用户数据类时才需要考虑,使用内置公开数据集和.ImageFolder不需要管这两者是啥东西,开发者已经帮你写好了)
2.支持的三类数据集
1.torchvision.datasets–内置了许多常见的公开数据集
2.torchvision.datasets.ImageFolder–用户定制数据集1(只要自己的数据集满足ImageFolder要求的格式,提供数据集所在的地址即可)
3.定制数据集–需要编写自己的dataset 类
2.1 torchvision.datasets.xxx
一些常用的公开数据集合,可以在torchvision.datasets接口中找到。
例如–MNIST、Fashion-MNIST、KMNIST、EMNIST、FakeData、COCO、Captions、Detection、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes、SBD等常用数据集合。
torchvision.datasets在使用一个新的数据集合前,需要保证本地拥有该数据集合(符合pytorch内部编码格式)。最简单额方式是第一次使用时,将download=True将默认将该数据集下载到指定的root 目录中。
CIFAR10数据集使用的例子
transform = transforms.Compose( [transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)
默认值:train=True
step1 数据集选择与图片处理方式选择
trainset = torchvision.datasets.CIFAR10(root=’./data’, train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=’./data’, train=False,download=True, transform=transform)
参数解释:
1.root=’./data’
数据集的保存目录,各种数据集有自己的文件格式,其中MNIST是以training.pt和test.pt的保存图像数据信息(具体看一下文件应该怎么存,读入之后的列表和迭代器各是什么内容)
2.train =True
处理MNIST时从training.pt读取训练数据,=False 从test.pt读取测试数据。仔细观察,上面两句话只有在train这个选项处不同.
3.download =True
会从网上下载对应的数据集文件,MNIST对应.pt文件,如果存在 .pt 文件,这个参数可以设置为False
4.transform
设置一组对图像进行处理的操作,这一组操作由Compose组成,这一组compose 的顺序还很重要按如下顺序编写:
transforms.Resize()
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
step2 数据载入接口
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
参数解释
1.将刚刚生成的trainset列表传入 torch.utils.data.DataLoader()
2.batch_size=4 设定图像数据批次大小
3.shuffle=True 每一个epoch过程中会打乱数据顺序,重新随机选择
4.导入数据时的线程数目,默认为0,主线程导入数据
2.2 torchvision.datasets.ImageFolder
当数据集超出1中所提供数据集的范围时,Pytorch还提供了ImageFolder数据集导入方式。只要将数据按照一定的要求存放,就能如方式1一样方便取用。
数据集合格式要求:同类别的图像放在一个文件夹下,用类别名称/标号来命名文件夹。要自己手工设计训练集合、测试集合。
x=torch.datasets.ImageFolder(root="图像集合中文件夹路径”)
x是一个ImageFolder格式的数据:
其中重要主要成员为:
class_to_idx ={dict} 是字典数据,以“文件夹名字:分配类别序号”作为键值的字典
classes ={list} 包含所有文件夹名字的一个序列
imgs={list} 列表元素为–(图像路径,对应文件夹名)
使用torch.utils.data.DataLoader载入数据:
trainloader = torch.utils.data.DataLoader(x, batch_size=4, shuffle=True, num_workers=4)
参考网址:
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
2.3 写自己的数据类,读入定制化数据
当用户数据个格式不能用以上两种方式读取时,可以尝试写自己的数据类
所有的datasets都是torch.utils.data.Dataset的子类,方法1中使用的是torchvision.datasets.数据集,方法 2中使用的是torchvision.datasets.ImageFolder。当我们在编写自己的数据类时,也需要继承Dataset类。
2.3.1 数据类的编写
在介绍Dataloader 使提到过,其载入的数据类需要满足两者操作中的一个(map-style操作/iterable-style操作)
map-style范式
Map-style 操作范式数据类的核心:实现了 getitem() 和 len()方法,通过data[index]获取数据样本和相应的标签。
猜测:DataLoader 在导入minibatch数据时,随机采样一批index(通过len确认index 的采样范围), 然后在经过getitem获取相应的数据
class MyDataset:def __init__(self, gentor: object, batchSize: int, imgSize: int):# 从源数据中读取数据列表,或者能操作数据的名称列表def __len__(self):# 返回数据集样本的数量return sample_map_numdef __getitem__(self, idx:int):# 通过idx获取数据datadata = get(idx) // get 依据不同的数据集定制// 进行一些tansform操作在返回return data
官方实践demo:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
iterable-style 范式
Iterable-style 操作范式数据类 是 IterableDataset的子类,实现了__iter__()方法。当随机读取非常耗时/无法实现时。(数据流,实时记录的数据)
有机会实践一下
2.3.2 DataLoader 导入数据类
编写好了自己的数据类之后,同其他两种数据类一样使用DataLoader导入数据即可。
train_set = MyDataset()data = train_set[0] # idx 读取某一个数据trainloader = DataLoader(train_set, batch_size=64, shuffle=True) # 封装成dataloader的形式print(len(trainloader))for _, data in enumerate(trainloader):....
下面提供一些可供参考的博文:
https://www.jianshu.com/p/220357ca3342
https://www.cnblogs.com/devilmaycry812839668/p/10122148.html
https://ptorch.com/news/215.html