在分类任务中,数据集文件存储往往是如下形式:
- train- class1- image1.jpg- image2.jpg...- class2- image1.jpg- image2.jpg......
此时,我们想要获取图片和标签,标签即为文件名(class1、class2…)
可以使用torchvision.datasets.ImageFolder()来进行获取,示例代码如下:
dataset = datasets.ImageFolder(root=DATA_PATH/'train')
torchvision.datasets.ImageFolder() 参数列表:
- root:图像文件读取路径
- transform:对图像数据采取的数据增强策略
- target_transform:对label进行转换
- loader:指定加载图像的函数
- is_valid_file:获取图像路径,检查文件的有效性
返回值
dataset 返回有如下三个属性:
- self.classes:用一个 list 保存类别名称
- self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
- self.imgs:保存(img-path, class) tuple的 list
我们得到的dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…]
下面我们打印dataset第一个元素中的图片:
返回对应的label:
其中,对于dataset[0]来说,其中也存储了两个元素,第一个是图片,第二个是类别索引号。
sample = dataset[0]
img = sample[0] #图片
label = sample[1] #类别索引
注意:
- dataset中存储的label是按文件夹顺序生成对应索引的,且以下标为0开始。如果要读取类别的字符,可以通过
self.classes[0]
来获取。 - train文件夹下的文件格式是固定的,不能有多余的文件,否则会读取出错。