文章目录
- 一、什么是TorchVision
- 二、以torchvision.datasets子模块下的CIFAR10数据集为例
- 1、CIFAR10数据集参数
- 2、代码中使用
一、什么是TorchVision
torchvision是pytorch的一个图形库,用来处理图像,主要用来构建计算机视觉模型。
从下面的官网截图可以看到torchvision有很多模块,下面以dataset模块进行举例。
torchvision中datasets包:用来进行数据加载,主要有以下几个模块
CelebA
CIFAR
Cityscapes
COCO
Captions
Detection
DatasetFolder
EMNIST
FakeData
Fashion-MNIST
Flickr
HMDB51
ImageFolder
ImageNet
Kinetics-400
KMNIST
LSUN
MNIST
Omniglot
PhotoTour
Places365
QMNIST
SBD
SBU
STL10
SVHN
UCF101
USPS
VOC
二、以torchvision.datasets子模块下的CIFAR10数据集为例
从上图可知:
CIFAR-10数据集由60000张32 × 32彩色图像组成,分为10个类,每个类有6000张图像。有50000张训练图像和10000张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批包含从每个类随机选择的1000个图像。训练批次以随机顺序包含剩余的图像,但一些训练批次可能包含来自一个类的更多图像。在它们之间,训练批次包含来自每个类的5000张图像。
1、CIFAR10数据集参数
class CIFAR10(VisionDataset):"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.Args:root (string): Root directory of dataset where directory``cifar-10-batches-py`` exists or will be saved to if download is set to True.train (bool, optional): If True, creates dataset from training set, otherwisecreates from test set.transform (callable, optional): A function/transform that takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.download (bool, optional): If true, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again."""
由上述代码可知,有如下5个参数
root :即指定数据集要下载在哪一个文件夹里面,如:root=“./dataset” 即将数据集下载到当前目录的dataset文件夹下
train :是否为训练集,布尔类型,如果train=True即为训练集,否则train=False则为非训练集。
transform :进行图像变换的各种操作,如RandomCrop、Compose等。
target_transform :对于标签进行transform 操作。
download :是否下载数据集,download = True表示下载数据集,download = False表示不下载数据集。(如果当前文件夹已经有需要下载的数据集,但是在程序编写中又把download属性值设定为True,此时不会再下载。)
2、代码中使用
import torchvision
from torch.utils.tensorboard import SummaryWriter# 创建训练数据集
# step1 准备创建数据集需要的各种参数
trans_tool = torchvision.transforms.Compose([torchvision.transforms.ToTensor() # 转为Tensor类型# torchvision.transforms.Resize((5, 5)) # 进行大小裁剪
])
# 第一个参数root表示下载的数据集需要放在哪一个文件夹里面,第二个参数tran表示是否是训练数据集,第三个参数transform表示进行变换操作,第四个参数download表示是否在线下载
tran_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=trans_tool,download=True)
# 创建测试数据集
test_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans_tool,download=True)
print(tran_dataset[0]) # 此时显示的是(<PIL.Image.Image image mode=RGB size=32x32 at 0x259A43BA350>, 6),即元组的形式,显示图片类别和标签
# step2 在tensorboard中显示writer = SummaryWriter("logs")
for i in range(10):img, laber = tran_dataset[i]writer.add_image("CIFAR10",img,i)
writer.close()