PyTorch数据Pipeline标准化代码模板

前言

PyTorch作为一款流行深度学习框架其热度大有超越TensorFlow的感觉。根据此前的统计,目前TensorFlow虽然仍然占据着工业界,但PyTorch在视觉和NLP领域的顶级会议上已呈一统之势。

这篇文章笔者将和大家聚焦于PyTorch的自定义数据读取pipeline模板和相关trciks以及如何优化数据读取的pipeline等。我们从PyTorch的数据对象类Dataset开始。Dataset在PyTorch中的模块位于utils.data下。

from torch.utils.data import Dataset

本文将围绕Dataset对象分别从原始模板、torchvision的transforms模块、使用pandas来辅助读取、torch内置数据划分功能和DataLoader来展开阐述。

Dataset原始模板

PyTorch官方为我们提供了自定义数据读取的标准化代码代码模块,作为一个读取框架,我们这里称之为原始模板。其代码结构如下:

from torch.utils.data import Dataset
class CustomDataset(Dataset):def __init__(self, ...):# stuffdef __getitem__(self, index):# stuffreturn (img, label)def __len__(self):# return examples sizereturn count

根据这个标准化的代码模板,我们只需要根据自己的数据读取任务,分别往__init__()、__getitem__()和__len__()三个方法里添加读取逻辑即可。作为PyTorch范式下的数据读取以及为了后续的data loader,三个方法缺一不可。其中:

  • __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。

  • __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。

  • __len__()函数则用于返回样本数量。

现在我们往这个框架里填几行代码来形成一个简单的数字案例。创建一个从1到100的数字例子:

from torch.utils.data import Dataset
class CustomDataset(Dataset):def __init__(self):self.samples = list(range(1, 101))def __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]if __name__ == '__main__':dataset = CustomDataset()print(len(dataset))print(dataset[50])print(dataset[1:100])

添加torchvision.transforms

然后我们来看如何从内存中读取数据以及如何在读取过程中嵌入torchvision中的transforms功能。torchvision是一个独立于torch的关于数据、模型和一些图像增强操作的辅助库。主要包括datasets默认数据集模块、models经典模型模块、transforms图像增强模块以及utils模块等。在使用torch读取数据的时候,一般会搭配上transforms模块对数据进行一些处理和增强工作。

添加了tranforms之后的读取模块可以改写为:

from torch.utils.data import Dataset
from torchvision import transforms as Tclass CustomDataset(Dataset):def __init__(self, ...):# stuff...# compose the transforms methodsself.transform = T.Compose([T.CenterCrop(100),T.ToTensor()])def __getitem__(self, index):# stuff...data = # Some data read from a file or image# execute the transformdata = self.transform(data)return (img, label)def __len__(self):# return examples sizereturn countif __name__ == '__main__':# Call the datasetcustom_dataset = CustomDataset(...)

可以看到,我们使用了Compose方法来把各种数据处理方法聚合到一起进行定义数据转换方法。通常作为初始化方法放在__init__()函数下。我们以猫狗图像数据为例进行说明。

定义数据读取方法如下:

class DogCat(Dataset):    def __init__(self, root, transforms=None, train=True, val=False):"""get images and execute transforms."""self.val = valimgs = [os.path.join(root, img) for img in os.listdir(root)]# train: Cats_Dogs/trainset/cat.1.jpg# val: Cats_Dogs/valset/cat.10004.jpgimgs = sorted(imgs, key=lambda x: x.split('.')[-2])self.imgs = imgs         if transforms is None:# normalize      normalize = T.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])# trainset and valset have different data transform # trainset need data augmentation but valset don't.# valsetif self.val:self.transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])# trainsetelse:self.transforms = T.Compose([T.Resize(256),T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),normalize])def __getitem__(self, index):"""return data and label"""img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0data = Image.open(img_path)data = self.transforms(data)return data, labeldef __len__(self):"""return images size."""return len(self.imgs)if __name__ == "__main__":train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)print(len(train_dataset))print(train_dataset[0])

因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:

与pandas一起使用

很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:

此时在数据读取的pipeline中我们需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:

class CustomDatasetFromCSV(Dataset):def __init__(self, csv_path):"""Args:csv_path (string): path to csv filetransform: pytorch transforms for transforms and tensor conversion"""# Transformsself.to_tensor = transforms.ToTensor()# Read the csv fileself.data_info = pd.read_csv(csv_path, header=None)# First column contains the image pathsself.image_arr = np.asarray(self.data_info.iloc[:, 0])# Second column is the labelsself.label_arr = np.asarray(self.data_info.iloc[:, 1])# Calculate lenself.data_len = len(self.data_info.index)def __getitem__(self, index):# Get image name from the pandas dfsingle_image_name = self.image_arr[index]# Open imageimg_as_img = Image.open(single_image_name)# Transform image to tensorimg_as_tensor = self.to_tensor(img_as_img)# Get label of the image based on the cropped pandas columnsingle_image_label = self.label_arr[index]return (img_as_tensor, single_image_label)def __len__(self):return self.data_lenif __name__ == "__main__":# Call datasetdataset =  CustomDatasetFromCSV('./labels.csv')

以mnist_label.csv文件为示例:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import pandas as pdclass CustomDatasetFromCSV(Dataset):def __init__(self, csv_path):"""Args:csv_path (string): path to csv file            transform: pytorch transforms for transforms and tensor conversion"""# Transformsself.to_tensor = T.ToTensor()# Read the csv fileself.data_info = pd.read_csv(csv_path, header=None)# First column contains the image pathsself.image_arr = np.asarray(self.data_info.iloc[:, 0])# Second column is the labelsself.label_arr = np.asarray(self.data_info.iloc[:, 1])# Third column is for an operation indicatorself.operation_arr = np.asarray(self.data_info.iloc[:, 2])# Calculate lenself.data_len = len(self.data_info.index)def __getitem__(self, index):# Get image name from the pandas dfsingle_image_name = self.image_arr[index]# Open imageimg_as_img = Image.open(single_image_name)# Check if there is an operationsome_operation = self.operation_arr[index]# If there is an operationif some_operation:# Do some operation on image# ...# ...pass# Transform image to tensorimg_as_tensor = self.to_tensor(img_as_img)# Get label of the image based on the cropped pandas columnsingle_image_label = self.label_arr[index]return (img_as_tensor, single_image_label)def __len__(self):return self.data_lenif __name__ == "__main__":transform = T.Compose([T.ToTensor()])dataset = CustomDatasetFromCSV('./mnist_labels.csv')print(len(dataset))print(dataset[5])

运行示例如下:

训练集验证集划分

一般来说,为了模型训练的稳定,我们需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

以kaggle的花朵数据为例:

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_splittransform = T.Compose([T.Resize((224, 224)),T.RandomHorizontalFlip(),T.ToTensor()])dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)trainset, valset = random_split(dataset, [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):img, label = img.numpy(), label.numpy()print(img, label)valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):img, label = img.numpy(), label.numpy()print(img.shape, label)

这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下:

使用DataLoader

dataset方法写好之后,我们还需要使用DataLoader将其逐个喂给模型。上一节的数据划分我们已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:

from torch.utils.data import DataLoader
from torchvision import transforms as Tif __name__ == "__main__":# Define transformstransformations = T.Compose([T.ToTensor()])# Define custom datasetdataset = CustomDatasetFromCSV('./labels.csv')# Define data loaderdata_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)for images, labels in data_loader:# Feed the data to the model

以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。

本文原创首发于公众号【机器学习实验室】,开创了【深度学习60讲】、【机器学习算法手推30讲】和【深度学习100问】三大系列文章。

一个算法工程师的成长之路


长按二维码.关注机器学习实验室

机器学习实验室的近期文章:

  • 机器学习公式推导和算法手写之XGBoost

  • 机器学习公式推导和算法手写之马尔科夫链蒙特卡洛

  • 如何部署一个轻量级深度学习项目?

  • 基于C++的PyTorch模型部署

  • PyTorch数据Pipeline标准化代码模板

  • 算法工程师的一天

参考文献

【1】https://pytorch.org/docs/stable/data.html

【2】https://towardsdatascience.com/building-efficient-custom-datasets-in-pytorch-2563b946fd9f

【3】https://github.com/utkuozbulak/pytorch-custom-dataset-examples

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480403.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2019 最全支付宝高级Java现场面试37题

支付宝现场三面面试题目,文末有福利:阿里经典面试88题目答案 01 支付宝一面 介绍一下自己。 项目参与的核心设计有哪些 ArrayList和LinkedList底层 HashMap及线程安全的ConcurrentHashMap,以及各自优劣势 Java如何实现线程安全 Synchronized和Lock…

腾讯互娱刘伟 | 知识图谱在运维中的应用

本文转载自公众号:InfoQ。随着业务监控建设不断完善,海量业务故障时产生成百上千条告警,如何智能定位故障根源、实时统计业务影响是现阶段运营面临的一个难题。Google 利用知识图谱优化了其搜索服务以来,知识图谱得到了迅速发展。…

中文人物关系图谱构建与应用项目(人物关系抽取,关系抽取评测)

ChinesePersonRelationGraph ChinesePersonRelationGraph, person relationship extraction based on nlp methods.中文人物关系知识图谱项目,内容包括中文人物关系图谱构建,基于知识库的数据回标,基于远程监督与bootstrapping方法的人物关系抽取,基于知识图谱的知识问答等应用…

2019 最新阿里中间件Java 4轮面试题!60万年薪起步~

Java中间件一面 1.技术一面考察范围: 重点问了Java线程锁:synchronized 和ReentrantLock相关的底层实现 线程池的底层实现以及常见的参数 数据结构基本都问了一遍:链表、队列等 Java内存模型:常问的JVM分代模型,以…

0011【冥想】87天冥想感悟汇总

0011【冥想】87天冥想感悟汇总 2018.1.6 Day1图片发自简书App1.7 Day2图片发自简书App❤️1.8冥想Day3❤️听了谷老师的分享,挺受益的。当我们的心越来越柔软,身体也会越来越柔软,所谓相由心生,冥想的时候,身体也会听从…

这个自然语言处理“工具”,玩得停不下来

今天推荐一个有趣的自然语言处理公众号「AINLP」,关注后玩得根本停不下来!AINLP的维护者是我爱自然语言处理(52nlp)博主,他之前在腾讯从事NLP相关的研发工作,目前在一家创业公司带技术团队。AINLP公众号的定…

观点 | 抛开炒作看知识图谱,为什么现在才爆发?

本文转载自公众号:AI前线。 作者 | George Anadiotis 译者 | 无明 导读:知识图谱究竟是什么,都有哪些围绕它们的炒作?如果你想要像 Airbnb、亚马逊…

算法--排序--寻找数组内第K大的元素

此题目,需要用到快速排序里的划分数组操作: 快排参考:https://blog.csdn.net/qq_21201267/article/details/81516569#t2 先选取一个合适的哨兵(三数取中法)将数组分成三部分【小于哨兵的】【哨兵】【大于等于哨兵的】…

淘宝网Java五面:现场面试49题含答案!

淘宝一面: 面试介绍 1)自我介绍? 2)项目介绍? 3)遇到的最大困难是什么?怎么解决的? 4)你觉得你能怎么优化这个项目? 面试题目 1)讲一下JVM 2&#xff…

告别自注意力,谷歌为Transformer打造新内核Synthesizer

一只小狐狸带你解锁 炼丹术&NLP 秘籍作者:舒意恒(南京大学硕士生,知识图谱方向)今天给大家介绍一篇来自Google的最新论文《SYNTHESIZER: Rethinking Self-Attention in Transformer Models》[4],该论文重新探索了T…

50万抽象知识图谱项目(实体抽象、性状抽象与动作抽象)

AbstractKnowledgeGraph AbstractKnowledgeGraph, a systematic knowledge graph that concentrate on abstract thing including abstract entity and action. 抽象知识图谱,目前规模50万,支持名词性实体、状态性描述、事件性动作进行抽象。目标于抽象…

算法--排序--大小写字母数字分离(桶排序思想)

题目: 对D,a,F,B,c,A,z这个字符串进行排序,要求将其中所有小写字母都排在大写字母的前面,但小写字母内部和大写字母内部不要求有序。比如经过排序之后为a,c&a…

2019 最新蚂蚁花呗Java三面题目:红黑树+并发容器+CAS+Solr+分布式等

蚂蚁金服专场 涵盖了蚂蚁金服从Java工程师到技术专家面试题目 支付宝高级Java三面题目:线程锁事务雪崩Docker等 蚂蚁花呗团队面试题:LinkedHashMapSpringCloud线程锁分布式 蚂蚁金服高级Java面试题目 支付宝Java开发四面:NgnixMQ队列集群并发抢购 蚂…

论文浅尝 | 实体图的预览表格生成

链接:ranger.uta.edu/~cli/pubs/2016/tabview-sigmod16-yan.pdf动机对于结构化数据和关系数据,通常使用Schema图为数据库的使用者提供基本信息。因此,作者提出了生成预览表格(preview table)的方法,为实体图…

万能的BERT连文本纠错也不放过

一只小狐狸带你解锁炼丹术&NLP秘籍作者:孙树兵学校:河北科技大学方向:QA/NLU/信息抽取编辑:小轶背景文本纠错(Spelling Error Correction)技术常用于文本的预处理阶段。在搜索引擎、输入法和 OCR 中有着…

POJ 1664 苹果放盘子(递归)

题目链接:http://poj.org/problem?id1664 m个相同的苹果放在n个相同的盘子里,有多少种不一样的方法。 例如,3个苹果放在4个盘子里有(3,0,0,0)(1,1&#xf…

蚂蚁金服4轮面经(Java研发):G1收集器+连接池+分布式架构

一面 线程池有哪些参数?分别有什么用?如果任务数超过的核心线程数,会发生什么?阻塞队列大小是多少? 数据库连接池介绍下,底层实现说下 hashset底层实现,hashmap的put操作过程 说说HaspMap底层…

算法--二分查找--求平方根(循环法/递归法)

二分查找: 数据需要是顺序表(数组)数据必须有序可以一次排序,多次查找;如果数据频繁插入,删除操作,就必须保证每次操作后有序,或者查找前继续排序,这样成本高&#xff0…

论文浅尝 | 在生成式多跳机器阅读任务中引入外部常识知识

Commonsense for Generative Multi-Hop Question Answering Tasks链接: https://arxiv.org/abs/1809.06309背景机器阅读任务按照答案类型的不同,可以大致分为:(1) 分类问题: 从所有候选实体选择一个(2) answer span: 答案是输入文本的一个片段(3) …

Overleaf v2 评测

原文链接:https://www.jianshu.com/p/1d73d4b9e880 Overleaf v2 评测 去年,两个著名的Latex在线编辑器Overleaf和Sharelatex合并了,强强联手,让我们对他们合并之后的新产品充满了期待。最近,他们的新产品发布了&#x…