文章目录
- 相关链接
- Dataset
- VisionDataset
- DatsetFolder
- ImageFolder
- torchvision.transforms
- Pytorch Lightning
- LightningDataModule
对于图像数据集来说,首先是在Dataset类对数据集进行定义,一般来说不定义transform,则数据为PIL Image,PIL格式到tensor的转换也是transforms变换的一种,所以定Dataset+transforms+Dataloader,最后在训练部分to(device)来得到模型的输入。
相关链接
torchvision.datasets的三个基础类
torchvision.datasets
torch.utils.data.Dataset
Pillow(PIL Fork) Image模块
Dataset
Dataset是数据集在pytorch中的化身,需要重写__ getitem__ 和 __ len__。__ __ getitem__ 通过传入的索引加载指定路径的数据,路径常常是一个列表,如很多张图片组成的数据集,需要在初始化时定义函数得到路径列表,或者在外部定义,总之要得到一个路径List。也需要在其中定义或调用具体读取的代码,如PIL库的Image.open()来读取图片,或Image.fromarray()来创建图片,也就是需要知道数据在哪里和怎么读取。
└─Dataset└─VisionDataset└─DatasetFolder└─ImageFolder
Dataset是torch.utils.data中的类,是数据集的基础类
VisionDataset
VisionDataset是torchvision.datasets.vision中的类,是torchvision类数据集的基础类,相比于原始的Dataset类,提供了transform,transforms,target_transform数据变换的接口
DatasetFolder,ImageFolder都来自torchvision.datasets.folder ,既然叫做folder,实际上已经有了完整的数据集功能,可以按照默认的目录结构读取数据。DatasetFolder还需要定义loader以读取特定类型的数据,和is_valid_file或者extensions,is_valid_file和extensions不能同时定义,但必须有一个定义,如果定义了有效后缀名,会自动通过后缀来判断文件有效性。而ImageFolder更进一步,默认使用读取图像数据的loader读取,还默认定义了图像后缀名。从Dataset到ImageFolder构成了不同层次的封装,完成度越高,灵活性越低,可以根据自己的需要选择。
除了在__ getitem__ 中通过得到的路径列表来读取数据,对于不同格式的数据也有不同的做法,如torchvision中内置cifar数据集,会直接从原始数据中以矩阵的形式读取, 因此 __ getitem__ 会从矩阵中创建Image对象。总而言之,一般来讲对于图片数据集来说,__ get __返回的都是PIL Image对象,不管是从路径列表中读取,还是整个以矩阵形式读取,如果不定义transform,最后在Dataset阶段都是PIL对象。
DatsetFolder
默认的排列结构如下,每一个文件夹表示一类,下面是这一类的样本
directory/ ├── class_x │ ├── xxx.ext │ ├── xxy.ext │ └── ... │ └── xxz.ext └── class_y ├── 123.ext ├── nsdf3.ext └── ... └── asd932_.ext
用文件夹来区分不同的类别。比较重要的有两类操作,find_class函数得到类别名和类别序号。make_dataset得到路径列表。
默认的findclass函数
文件夹名是类名。
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:"""Finds the class folders in a dataset.See :class:`DatasetFolder` for details."""classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idx
默认的make_dataset
得到instance列表,表示文件的路径列表。
基本上很大一部分是在定义有效性判断相关,主要部分是一个双层for循环,因为类名定义为文件夹名,所以会遍历各个类的文件夹,会将遍历到的有效文件的路径加入instance,遍历过的非空类添加到available_classe。
def make_dataset(directory: Union[str, Path],class_to_idx: Optional[Dict[str, int]] = None,extensions: Optional[Union[str, Tuple[str, ...]]] = None,is_valid_file: Optional[Callable[[str], bool]] = None,allow_empty: bool = False,
) -> List[Tuple[str, int]]:directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")both_none = extensions is None and is_valid_file is Noneboth_something = extensions is not None and is_valid_file is not Noneif both_none or both_something:raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")if extensions is not None:def is_valid_file(x: str) -> bool:return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]is_valid_file = cast(Callable[[str], bool], is_valid_file)instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):path = os.path.join(root, fname)if is_valid_file(path):item = path, class_indexinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes and not allow_empty:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "if extensions is not None:msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"raise FileNotFoundError(msg)return instances
ImageFolder
root/dog/xxy.pngroot/dog/[...]/xxz.pngroot/cat/123.pngroot/cat/nsdf3.png
ImageFolder如名字所示,如果数据集是这种文件夹排列,而且是图像文件,又没有需要特殊定义的部分 ,可以直接实例化一个ImageFolder,而不需要重写任何部分 ,实例化一个数据集只需要传入数据集路径和tansform变换。
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']]
torchvision.transforms
一般会在数据集实例化时,从外部传入,通常自定义的Transforms序列包含ToTensor,可以将上一阶段的PIL Image转换为Tensor,而下一次变化要到训练时的to(device),这样数据最终输入完成,也可以在Dataset类中写入默认的transform。
通过torchvision.get_image_backend得到torchvision现在的后端默认为PILtorchvision.set_image_backend(backend)指定用来读取图片的包,可选accimage
Loader将数据读取为PIL对象,一般数据集定义不在数据集内部定义默认的transform图像变换,而是在外部定义一个transform序列,通常倒数第二个是torchvision.transforms.ToTensor()操作,会将一个PIL Image或者一个ndarray转换为tensor并缩放到[0.0, 1.0]。因此接下来会通过transforms.Normalize进行归一化。
PILToTensor会把PIL Image转化为tensor,但是不会进行缩放, ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)→(C×H×W)
ToTensor会把PIL Image或者ndarray转换成tensor而且会进行缩放。 ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)→(C×H×W) 在规定的模式如RGBA,RGB,YCbCr或者dtype = np.uint8情况下,别的情况下不缩放。
Normalize只支持tensor,其他大部分操作也支持PIL,所以在ToTensor之后最后进行Normalize
data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
Pytorch Lightning
Pytorch Lightning是Pytorch中的kersas,简称pl
LightningDataModule
Pytorch Lightning继承LightningDataModule定义数据集,pl中的Dataset和Dataloader是高度耦合的。
import lightning.pytorch as L
import torch.utils.data as data
from pytorch_lightning.demos.boring_classes import RandomDatasetclass MyDataModule(L.LightningDataModule):def prepare_data(self):# download, IO, etc. Useful with shared filesystems# only called on 1 GPU/TPU in distributed...def setup(self, stage):# make assignments here (val/train/test split)# called on every process in DDPdataset = RandomDataset(1, 100)self.train, self.val, self.test = data.random_split(dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42))def train_dataloader(self):return data.DataLoader(self.train)def val_dataloader(self):return data.DataLoader(self.val)def test_dataloader(self):return data.DataLoader(self.test)def teardown(self):# clean up state after the trainer stops, delete files...# called on every process in DDP...