PyTorch 系列 | 数据加载和预处理教程

640?wx_fmt=jpeg

图片来源:Unsplash,作者:Damiano Baschiera

2019 年第 66 篇文章,总第 90 篇文章

本文大约 8000 字,建议收藏阅读

原题 | DATA LOADING AND PROCESSING TUTORIAL

作者 | Sasank Chilamkurthy

译者 | kbsc13("算法猿的成长"公众号作者)

原文 | https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出于,请勿用作商业或者非法用途

简介

本文教程主要是介绍如何加载、预处理并对数据进行增强的方法。

首先需要确保安装以下几个 python 库:

  • scikit-image :处理图片数据

  • pandas :处理 csv 文件

导入模块代码如下:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils# Ignore warnings
import warnings
warnings.filterwarnings("ignore")plt.ion()   # interactive mode

本次教程采用的是一个人脸姿势数据集,其图片如下所示:

640?wx_fmt=png

每张人脸都是有 68 个人脸关键点,它是由 dlib 生成的,具体实现可以查看其官网介绍:

https://blog.dlib.net/2014/08/real-time-face-pose-estimation.html

数据集下载地址:

https://download.pytorch.org/tutorial/faces.zip

数据集中的 csv 文件的格式如下所示,图片名字和每个关键点的坐标 x, y

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

数据集下载解压缩后放到文件夹 data/faces 中,然后我们先快速打开 face_landmarks.csv 文件,查看文件内容,即标注信息,代码如下所示:

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出如下所示:

640?wx_fmt=png

接着写一个辅助函数来显示人脸图片及其关键点,代码如下所示:

def show_landmarks(image, landmarks):"""Show image with landmarks"""plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001)  # pause a bit so that plots are updatedplt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),landmarks)
plt.show()

输出如下所示:

640?wx_fmt=png

Dataset 类

torch.utils.data.Dataset 是表示一个数据集的抽象类,在自定义自己的数据集的时候需要继承 Dataset 类别,并重写下方这些方法:

  • len :调用 len(dataset) 时可以返回数据集的数量;

  • getitem:获取数据,可以实现索引访问,即 dataset[i] 可以访问第 i 个样本数据

接下来将给我们的人脸关键点数据集自定义一个类别,在 __init__ 方法中将读取数据集的信息,并在 __getitem__ 方法调用获取的数据集,这主要是基于内存的考虑,这种做法不需要将所有数据一次读取存储在内存中,可以在需要读取数据的时候才读取加载到内存里。

数据集的样本将用一个字典表示:{'image': image, 'landmarks': landmarks},另外还有一个可选参数 transform 用于预处理读取的样本数据,下一节将介绍这个 transform 的用处。

自定义函数的代码如下所示:

class FaceLandmarksDataset(Dataset):"""Face Landmarks dataset."""def __init__(self, csv_file, root_dir, transform=None):"""Args:csv_file (string): 带有标注信息的 csv 文件路径root_dir (string): 图片所在文件夹transform (callable, optional): 可选的用于预处理图片的方法"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):# 读取图片img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)# 读取关键点并转换为 numpy 数组landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample

接下来是一个简单的例子来使用上述我们自定义的数据集类,例子中将读取前 4 个样本并展示:

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')fig = plt.figure()
# 读取前 4 张图片并展示
for i in range(len(face_dataset)):sample = face_dataset[i]print(i, sample['image'].shape, sample['landmarks'].shape)ax = plt.subplot(1, 4, i + 1)plt.tight_layout()ax.set_title('Sample #{}'.format(i))ax.axis('off')show_landmarks(**sample)if i == 3:plt.show()break

输出结果如下所示:

640?wx_fmt=png

Transforms

从上述例子输出的结构可以看到一个问题,图片的大小并不一致,但大多数神经网络都需要输入图片的大小固定。因此,接下来是给出一些预处理的代码,主要是下面三种预处理方法:

  • Rescale :调整图片大小

  • RandomCrop:随机裁剪图片,这是一种数据增强的方法

  • ToTensor:将 numpy 格式的图片转换为 pytorch 的数据格式 tensors,这里需要交换坐标。

这几种方法都将写成可调用的类,而不是简单的函数,这样就不需要每次都传递参数。因此,我们需要实现 __call__ 方法,以及有必要的话,__init__ 方法也是要实现的,然后就可以如下所示一样调用这些方法:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

Rescale 方法的实现代码如下:

class Rescale(object):"""将图片调整为给定的大小.Args:output_size (tuple or int): 期望输出的图片大小. 如果是 tuple 类型,输出图片大小就是给定的 output_size;如果是 int 类型,则图片最短边将匹配给的大小,然后调整最大边以保持相同的比例。"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]# 判断给定大小的形式,tuple 还是 int 类型if isinstance(self.output_size, int):# int 类型,给定大小作为最短边,最大边长根据原来尺寸比例进行调整if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# 根据调整前后的尺寸比例,调整关键点的坐标位置,并且 x 对应 w,y 对应 hlandmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}

RandomCrop 的代码实现:

class RandomCrop(object):"""给定图片,随机裁剪其任意一个和给定大小一样大的区域.Args:output_size (tuple or int): 期望裁剪的图片大小。如果是 int,将得到一个正方形大小的图片."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_size# 随机选择裁剪区域的左上角,即起点,(left, top),范围是由原始大小-输出大小top = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]# 调整关键点坐标,平移选择的裁剪起点landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}

ToTensor 的方法实现:

class ToTensor(object):"""将 ndarrays 转换为 tensors."""def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']# 调整坐标尺寸,numpy 的维度是 H x W x C,而 torch 的图片维度是 C X H X Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}

组合使用预处理方法

接下来就是介绍使用上述自定义的预处理方法的例子。

假设我们希望将图片的最短边长调整为 256,然后随机裁剪一个 224*224 大小的图片区域,也就是我们需要组合调用 Rescale 和 RandomCrop 预处理方法。

torchvision.transforms.Compose 是一个可以实现组合调用欲处理方法的类,实现代码如下所示:

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 对图片数据调用上述 3 种形式预处理方法,即单独使用 Rescale,RandomCrop,组合使用 Rescale和 RandomCrop
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):transformed_sample = tsfrm(sample)ax = plt.subplot(1, 3, i + 1)plt.tight_layout()ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()

输出结构:

640?wx_fmt=png

迭代整个数据集

现在我们已经定义好一个处理数据集的类,3种预处理数据的类,那么可以将它们整合在一起,实现加载并预处理数据的流程,流程如下所示:

  • 首先根据图片路径读取图片

  • 对图片都调用预处理的方法

  • 预处理方法也可以实现数据增强

实现的代码如下所示:

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/',transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))for i in range(len(transformed_dataset)):sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:break

输出结果:

640?wx_fmt=png

上述只是一个简单的处理过程,实际上处理和加载数据的时候,我们一般还对数据做以下的处理:

  • 将数据按给定大小分成一批一批数据

  • 打乱数据排列顺序

  • 采用 multiprocessing 来并行加载数据

torch.utils.data.DataLoader 是一个可以实现上述操作的迭代器。其需要的参数如下代码所示,其中一个参数 collate_fn 是用于指定如何对数据进行分批的操作,但也可以采用默认函数。

dataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=4)# 辅助函数,用于展示一个 batch 的数据
def show_landmarks_batch(sample_batched):"""Show image with landmarks for a batch of samples."""images_batch, landmarks_batch = \sample_batched['image'], sample_batched['landmarks']batch_size = len(images_batch)im_size = images_batch.size(2)grid_border_size = 2grid = utils.make_grid(images_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,landmarks_batch[i, :, 1].numpy() + grid_border_size,s=10, marker='.', c='r')plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(),sample_batched['landmarks'].size())# observe 4th batch and stop.if i_batch == 3:plt.figure()show_landmarks_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()break

输出结果:

640?wx_fmt=png

torchvision

最后介绍 torchvision 这个库,它提供了一些常见的数据集和预处理方法,采用这个库就可以不需要自定义类,它比较常用的方法是 ImageFolder ,它假定图片的保存路径如下所示:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

这里的 antsbees 等等都是类别标签,此外对 PIL.Image 的预处理方法,如 RandomHorizontalFlip 、Scale 都包含在 torchvision 中,一个使用例子如下所示:

import torch
from torchvision import transforms, datasetsdata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4, shuffle=True,num_workers=4)

小结

本教程主要介绍如何对自己的数据集自定义一个类来加载,以及预处理的方法,同时最后也介绍了 PyTorch 中的 torchvision ,torch.utils.data.DataLoader 方法。

本文的代码上传至 Github:

https://github.com/ccc013/DeepLearning_Notes/blob/master/Pytorch/pytorch_dataloader_tutorial.ipynb

另外,还有用 dlib 生成人脸关键点的代码:

https://github.com/ccc013/DeepLearning_Notes/blob/master/Pytorch/create_landmark_dataset.py

此外,也可以公众号后台回复“PyTorch”获取本次教程的数据集和代码。

留言时间

640?wx_fmt=png


欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!

640?wx_fmt=png

如果觉得不错,在看、转发就是对小编的一个支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/408572.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch系列 | 如何加快你的模型训练速度呢?

图片来源:Pixabay,作者:talha khalil2019 年第 67 篇文章,总第 91 篇文章本文大约 6500 字,建议收藏阅读!原题 | Speed Up your Algorithms Part 1 — PyTorch作者 | Puneet Grover译者 | kbsc13("算法…

Leetcode 系列 | 反转链表

点击上方“算法猿的成长”,选择“加为星标”第一时间关注 AI 和 Python 知识最近会更新一个 leetcode 的刷题系列,每次更新一道题目,并且通过画图辅助介绍自己的解题思路,大家如果有更好的解题思路也欢迎在文末留言,或…

Android UI开发第二十九篇——Android中五种常用的menu(菜单)

Android Menu在手机的应用中起着导航的作用,作者总结了5种常用的Menu。 1、左右推出的Menu 前段时间比较流行,我最早是在海豚浏览器中看到的,当时耳目一新。最早使用左右推出菜单的,听说是Facebook,我不确定消息的真实…

深度学习的一些经验总结和建议| To do v.s Not To Do

关注&置顶“Charlotte数据挖掘”每日9:00,干货速递!昨天看到几篇不同的文章写关于机器学习的to do & not to do,有些观点赞同,有些不赞同,是现在算法岗位这么热门,已经不像几年前一样,可…

PyTorch系列 | 快速入门迁移学习

点击上方“算法猿的成长”,选择“加为星标”第一时间关注 AI 和 Python 知识图片来源:Pexels,作者:Arthur Ogleznev2019 年第 68 篇文章,总第 92 篇文章本文大约 6800 字,建议收藏阅读原题 | TRANSFER LEAR…

AI研发工程师成长指南

AI研发工程师成长指南本文为数据茶水间群友原创,经授权在本公众号发表。关于作者:Japson。某人工智能公司AI平台研发工程师,专注于AI工程化及场景落地。持续学习中,期望与大家多多交流技术以及职业规划。0x00 前言首先&#xff0c…

进程和线程(上)

点击上方“算法猿的成长”,选择“加为星标”第一时间关注 AI 和 Python 知识2019 年第 70 篇文章,总第 94 篇文章本文大约 6000 字,阅读大约需要 15 分钟最近会开始继续 Python 的进阶系列文章,这是该系列的第一篇文章&#xff0c…

你有哪些deep learning(rnn、cnn)调参的经验?

点击上方“算法猿的成长”,选择“加为星标”第一时间关注 AI 和 Python 知识来源:知乎问题深度学习中调参其实是一个比较重要的技巧,但很多时候都需要多尝试多积累经验,因此算法工程师也被调侃为调参工程师。这里分享来自知乎上的…

A+B for Matrices 及 C++ transform的用法

题目大意&#xff1a;给定两个矩阵&#xff0c;矩阵的最大大小是M*N&#xff08;小于等于10&#xff09;&#xff0c;矩阵元素的值的绝对值小于等于100&#xff0c;求矩阵相加后全0的行以及列数。 1 #include<iostream>2 using namespace std;3 #define N 104 5 int main…

进程和线程(下)

点击上方“算法猿的成长”&#xff0c;选择“加为星标”第一时间关注 AI 和 Python 知识2019 年第 71 篇文章&#xff0c;总第 95 篇文章本文大约 8000 字&#xff0c;建议收藏阅读上一篇文章介绍了进程和线程的基本概念&#xff0c;以及多进程如何实现&#xff0c;本文则介绍下…

8月总结抽奖

1首先是小小总结下 8 月发文的情况&#xff0c;总共推文时间是 21 天&#xff0c;原创文章有 9 篇&#xff0c;分别如下&#xff1a;Jupyter 进阶教程PyTorch 系列 | 数据加载和预处理教程PyTorch系列 | 如何加快你的模型训练速度呢&#xff1f;Leetcode 系列 | 反转链表PyTorc…

【文件系统】浅解释FAT32

了解完linux下的文件系统之后&#xff0c;顺便对FAT32也研究一下。 假如一个FAT32表如下所示。 文件的簇应该保留在目录中&#xff0c;根据此簇&#xff0c;应该能得到一个块。 要找到文件的下一块&#xff0c;就要根据簇在FAT中寻找&#xff0c;所以FAT中存储的不是本簇的簇号…