PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快

PyTorch训练中Dataset多线程加载数据,而不是在DataLoader

背景与需求

现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多。今天就给大家带来最近PyTorch训练的一些小小的心得。

大家做机器学习、深度学习都恨不得机器卡越多越好,这样可以跑得越快,道理好像也很直白,大家都懂。实际上我们在训练的时候很大一部分制约我们的训练的速度快慢被IO限制住了,然面CPU的利用率却不高,就算有8卡了,然而GPU的利用率却长期处理低水平,不能发挥设备本应该有的水平。所以我一直在想,有什么办法能加快IO的读取,当然最直截的就换SSD,那上速度会直接上去了。那如果是我们在服务器或者是普通的电脑就没有办法呢吗?

而且经常用PyTorch的人应该会发现,如果我们把DataLoader的num_workers设置比较大的时候,在训练启动时会等待比较久,而且在每一个epoch之间的切换也是需要等挺久的(更换,加载数据)。

如果是一个程序员的话,肯定会想到多线程、多进程,这是否会能加速我们训练的IO?答案是肯定的。

今天给大家带来的就是,多线程读取数据的实例,本次测试不含训练部分,只是对Dataset, DataLoader数据加载的部分进行测试。

PyTorch DataLoader会产生一个index然后Dataset再进行读取,如果一个batch_size=128的话,那就要产生128次的数据调试,并读取。

我的想法就很简单,我想要不我就直接在Dataset就生成好所需的Batches,这样在DataLoader的batch_size=1的话,那也是对应一个batch的数据,而我在Dataset的可以用线程去加载数据,这样应该能提高读取的效率。

有了想法就是干了。

平时我们重要Dataset的结构如下,这里用到了albumentations作为数据处理的库,而不是torchvision的transforms,其它没有什么区别的

def default_loader(path):return Image.open(path).convert('RGB')class AlbumentationsDatasetList(Dataset):""" Data processing using albumentation same as torchvision transforms"""def __init__(self, imgs, transform=None, loader=default_loader, percentage=1):# here can control the dataset size percentage    img_num = int(len(imgs) * percentage)self.imgs = imgs[:img_num]self.transform = transformself.loader = loaderdef __getitem__(self, index):fn = self.imgs[index]img = self.loader(fn)if self.transform is not None:image_np = np.array(img)augmented = self.transform(image=image_np)img = augmented['image']return imgdef __len__(self):return len(self.imgs)

方法的实现

说干就干,把多线程加进来进行改造Dataset,下面来看一下代码,代码加入了一些细节,所以会比较长,但结构还是跟上面的是一样的。只是Dataset就已经把batches都处理好了,在加载数据后,是把他们都stack在一起,这样就可以形成[N, C, W, H]结构的数据了。

注意:如果drop_last=False的话,那么最后的一个batch的数量一般不会与batch_size相同,所以在DataLoader的里batch_size要设置成1。还有DataLoader设置成1后,实际加载的数据是[1, N, C, W, H],所以在用的时候要squeeze一下。

class AlbumentationsDatasetList(Dataset):def __init__(self, images, batch_num, percentage=1,transform=None, multi_load=True,shuffle=True,seed=None,drop_last=False,num_workers=4,loader=default_loader) -> None:#==============================================# Set seed#==============================================if seed is None:self.seed = np.random.randint(0, 1e6, 1)[0] # Fix bug 2021-12-10else:self.seed = seedrandom.seed(self.seed)# add some assertation if the image empty donot proceed. Fix 2021-12-12assert images is not None, f'images must be NOT empty, but got {images}'  self.images = imagesself.batch_num = batch_num   # use batch_num instead of batch_size, same thingself.percentage = percentageself.transform = transformself.multi_load = multi_loadself.shuffle = shuffleself.drop_last = drop_lastself.num_workers = num_workers # Dataset num_workersself.loader = loaderself.batches = self._create_batches()self.batches = self._get_len_batches(self.percentage)def _get_len_batches(self, percentage):"""Description:- you could control how many batches you want to use for training or validatingindices sort, so that could keep the batches got in order from originla batchesParameters:- percentage: float, range [0, 1]Return- numpy array of the new bags"""batch_num = int(len(self.batches) * percentage)indices = random.sample(list(range(len(self.batches))), batch_num)indices.sort()new_batches = np.array(self.batches, dtype='object')[indices]return new_batchesdef _create_batches(self,):if self.shuffle:random.shuffle(self.images)batches = []ranges = list(range(0, len(self.images), self.batch_num))for i in ranges[:-1]:batch = self.images[i:i + self.batch_num]batches.append(batch)#== Drop last ===============================================last_batch = self.images[ranges[-1]:]if len(last_batch) == self.batch_num:batches.append(last_batch)elif self.drop_last:passelse:batches.append(last_batch)return batchesdef __getitem__(self, index):batch = self.batches[index]#== Stack all images, become a 4 dimensional tensor ===============if self.multi_load:batch_images = self._multi_loader(batch)else:batch_images = []for image in batch:img = self._load_transform(image)batch_images.append(img)batch_images_tensor = torch.stack(batch_images, dim=0)return batch_images_tensordef _load_transform(self, tile):img = self.loader(tile)if self.transform is not None:image_np = np.array(img)augmented = self.transform(image=image_np)img = augmented['image']return imgdef _multi_loader(self, tiles):images = []executor = ThreadPoolExecutor(max_workers=self.num_workers)results = executor.map(self._load_transform, tiles)executor.shutdown()for result in results:images.append(result)return imagesdef __len__(self):return len(self.batches)

代码与数据测试

接下来就是拿数据进行测试了,这里还设置了multi_load的参数,这样我们可以方便控制是否用多线程与否,这样我们就可以对比一下在相同的机器,相同的数据下,多线程加载数据是否比单线程快。

  • 测试的目的:

    • 1,是否多线程多单线程快;
    • 2,多线程能比单线路快多少;
    • 3,找到这机器最快(或者比较全适)的越参数,可作为其它机器的参考。
  • 测试平台:Window10

  • CPU:Intel Core i7-9850H @ 2.60GHz

  • RAM: 32 GB

  • 测试的数据:是5000张图像,全部都是3通道RBG,8位的512x512像素图像,图像格式是.PNG。

  • 测试方法:

    • 超参数如下:搜索空间为1024

      • multi_loads = [True, False]
        prefetch_factors = list(range(0, 17, 2))[1:] # [2, 4, 6, 8, 10, 12, 14, 16]
        dataset_workers = list(range(0, 17, 2))[1:]
        dataloader_workers = list(range(0, 17, 2))[1:]
        
    • 利用grid search方法,每一个搜索空间都对Dataset, DataLoader设置不同的参数,而且每轮数据都是读完、并处理完5000张图像,drop_last=False

    • 数据增强:只做了resize,normalize

下面是全部的测试代码。

albumentations_valid = album.Compose([album.Resize(480, 480),album.Normalize(mean=[0.7347, 0.4894, 0.6820, ], std=[0.1747, 0.2223, 0.1535, ]),ToTensorV2(),])from utils import get_specified_filespath = r"xxxxx"images = get_specified_files(path, suffixes=[".png"], recursive=True) # glob.globimages = images[:5000]print(len(images))results = []log_file = open(r"grid_search_log.txt", mode='a', encoding='utf-8')multi_loads = [True, False]prefetch_factors = list(range(0, 17, 2))[1:] # [2, 4, 6, 8, 10, 12, 14, 16]dataset_workers = list(range(0, 17, 2))[1:]dataloader_workers = list(range(0, 17, 2))[1:]for multi_load in multi_loads:for prefetch_factor in prefetch_factors:for dataset_worker in dataset_workers:for dataloader_worker in dataloader_workers:multi_load = multi_loadif multi_load:prefetch_factor = prefetch_factorelse:prefetch_factor = prefetch_factordataloader_worker = dataloader_workertrain_dataset = AEDataset(images, batch_num=128, percentage=1, transform=albumentations_valid, multi_load=multi_load, shuffle=True, seed=0, drop_last=False,num_workers=dataset_worker,)train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=False, num_workers=dataloader_worker, pin_memory=True, prefetch_factor=prefetch_factor, persistent_workers=False)print("Start loading")start_time = time.time()for i, (batches) in enumerate(train_loader):i+1elapse = time.time() - start_timeprint(f"multi_load: {multi_load}, prefetch_factors: {prefetch_factor}, dataset_workers: {dataset_worker}, data_loader_workers: {dataloader_worker}, elapse: {elapse:.4f}")log_file.write(f"multi_load: {multi_load}, prefetch_factors: {prefetch_factor}, dataset_workers: {dataset_worker}, data_loader_workers: {dataloader_worker}, elapse: {elapse:.4f}\n")

测试结果

回到我们上面的测试目标

测试的目的:

  • 1,是否多线程多单线程快;
  • 2,多线程能比单线程快多少;
  • 3,找到这台机器最快(或者比较全适)的越参数,可作为其它机器的参考。

我们带着这3个问题,看一下下面的测试结果:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
path = "C:/Users/jasne/Desktop/grid_search_multi_load.csv"
df = pd.read_csv(path)
df.head()
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
0True1414219.9746
1True1410219.9816
2True1412220.0205
3True810220.0514
4True1416220.0943

Max elapse

也是我们平时用的普通load的方法,时间是72.28秒

df[df["elapse"]==df["elapse"].max()]
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
1024False11172.2857

Multi Load Max elapse

多线程时最慢的时间

multi_load = df[df["multi_load"]==True]
multi_load[multi_load["elapse"]==multi_load["elapse"].max()]
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
1023True6141648.3309

Min elapse

相差的倍数的计算公式为(max−min)/min(\text{max} - \text{min}) / \text{min}(maxmin)/min
时间是19.97秒,比最长的时间少了 52.31秒,快了2.6倍的时间,所以可以看出用multi_load肯定是比single load要快的。

多线程的时间,也受prefetch_factors, dataset_workers, dataloader_workers的影响。而且影响还是比较大的。

多线程时,最快与最慢的相差1.42倍

df[df["elapse"]==df["elapse"].min()]
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
0True1414219.9746

下面来看是否 data_loader_workers越大越好?

dataloader_workers = multi_load[(multi_load["prefetch_factors"]==2) & (multi_load["dataset_workers"]==2)]
dataloader_workers.sort_values("data_loader_workers", inplace=True)
dataloader_workers
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
376True22228.6076
102True22424.4866
144True22626.3106
410True22830.3909
536True221033.2621
724True221236.9114
946True221441.3437
986True221644.4443
plt.figure(figsize=(8, 5))
plt.scatter(dataloader_workers["data_loader_workers"], dataloader_workers["elapse"])
plt.show()

请添加图片描述

从图上可以看出,dataloader_workers并非越大越好,dataloader_workers=4时是在2-8之间是比较好的选择。随着dataloader_workers的增加,所需要的时间也呈线性的增加。

下面来看是否 dataset_workers越大越好

dataset_workers = multi_load[(multi_load["prefetch_factors"]==2) & (multi_load["data_loader_workers"]==2)]
dataset_workers.sort_values("dataset_workers", inplace=True)
dataset_workers
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
376True22228.6076
75True24223.5092
52True26222.4270
49True28222.2465
26True210221.7578
37True212222.0112
46True214222.1947
35True216221.9832
plt.figure(figsize=(8, 5))
plt.scatter(dataset_workers["dataset_workers"], dataset_workers["elapse"])
plt.show()

请添加图片描述

从图上可以看出,dataset_workers增加也可以明显减少数据加载所需要时间。但是当dataset_workers超过10后,不再呈现出减少的趋势,当达到12、14时有一点点上降。由于测试平台有限,这里所应该让测试一下dataset_workers达到128或者更高的数之间,是否会达到更少的数据加载时间。

下面来看是否 prefetch_factors越大越好

prefetch_factors = multi_load[(multi_load["dataset_workers"]==2) & (multi_load["data_loader_workers"]==2)]
prefetch_factors.sort_values("prefetch_factors", inplace=True)
prefetch_factors

multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
376True22228.6076
289True42227.7318
309True62228.0899
141True82226.2518
378True102228.6515
332True122228.2445
135True142226.0284
134True162226.0025
plt.figure(figsize=(8, 5))
plt.scatter(prefetch_factors["prefetch_factors"], prefetch_factors["elapse"])
plt.show()

请添加图片描述

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UUp7MHiu-1634438695527)(C:/Users/jasne/Desktop/Untitled/output_18_0.png)]

从图上可以看出,prefetch_factors似乎好像越大,加载的时间越少,但似乎也相差不多,最多的时间与最小的时间相差也仅为2.6秒。

prefetch_factors的外一个筛选条件

prefetch_factors = multi_load[(multi_load["dataset_workers"]==10) & (multi_load["data_loader_workers"]==4)]
prefetch_factors.sort_values("prefetch_factors", inplace=True)
prefetch_factors
multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
70True210423.3808
103True410424.4975
108True610424.6660
53True810422.5058
90True1010424.1555
92True1210424.1825
39True1410422.0710
120True1610425.0829
plt.figure(figsize=(8, 5))
plt.scatter(prefetch_factors["prefetch_factors"], prefetch_factors["elapse"])
plt.show()

请添加图片描述

从图上可以看出,prefetch_factors数量似乎对加载时间的影响似乎不太明显,最多的时间与最小的时间相差也仅为2.6秒。

multi_loadprefetch_factorsdataset_workersdata_loader_workerselapse
70True210423.3808
103True410424.4975
108True610424.6660
53True810422.5058
90True1010424.1555
92True1210424.1825
39True1410422.0710
120True1610425.0829
plt.figure(figsize=(8, 5))
plt.scatter(prefetch_factors["prefetch_factors"], 
prefetch_factors["elapse"])plt.show()

请添加图片描述

从图上可以看出,prefetch_factors数量似乎对加载时间的影响似乎不太明显,最多的时间与最小的时间相差也仅为2.6秒。

结论

  1. 多线程加载数据肯定是比单线程快的?
    • 这点是不用质疑的,单从计算机的运行方式就可以得出这个结论,这也是并行的优势。
  2. 多线程能比单线程快多少?
    • 从上面的结果,我们看到,当选用合适的超参数时,多线程加载相同的数据与相同的处理方法,比单线程快了52.31秒,快了2.6倍有多。就算是最不好的参数,多线和最长的加载时间为48.33秒,也比单线程的72.28秒,快差不多0.5倍。
  3. 找到这台机器最快(或者比较全适)的越参数,可作为其它机器的参考
    • dataset_workers 越大越好,但达到了一个临界值后,不会再增加了,本测试平台的值为10
    • data_loader_workers,不是越大越好,本测试平台最好的值为4,在4左右的值都是较好的参考值。然后随着此参数的数量的增加,所需要的时间也呈线性的增涨,这也说明了PyTorch大data_loader_workers启动需要等待更久的时间
    • prefetch_factors的数量似乎对数据的加载时间影响不大,但最好不要是1。

本次测试没有监测内存还有CPU的使用率,但在过程中观察了一下,CPU使用率基本都可以达到100%。也可以把这些参数也监测起来,形成更多的超参数,以便参考。
注意:由于在训练的过程中也是需要利用CPU的,所以尽量不要太多的dataset_workers,尽量不要把CPU都使用到100%,而造成死机。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/260515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python调用摄像头人脸识别代码_利用face_recognition,dlib与OpenCV调用摄像头进行人脸识别...

用已经搭建好 face_recognition,dlib 环境来进行人脸识别 未搭建好环境请参考: 使用opencv 调用摄像头 import face_recognition import cv2 video_capture cv2.videocapture(0) # videocapture打开摄像头,0为笔记本内置摄像头,1…

【转】彻底搞清计算结构体大小和数据对齐原则

数据对齐: 许多计算机系统对基本数据类型合法地址做出了一些限制,要求某种类型对象的地址必须是某个值K(通常是2,4或8)的倍数。这种对齐限制简化了形成处理器和存储器系统之间的接口的硬件设计。例如,假设一个处理器总是从存储器中取出8个字节…

Pytorch 学习率衰减 之 余弦退火与余弦warmup 自定义学习率衰减scheduler

学习率衰减,通常我们英文也叫做scheduler。本文学习率衰减自定义,通过2种方法实现自定义,一是利用lambda,另外一个是继承pytorch的lr_scheduler import math import matplotlib.pyplot as plt import numpy as np import torch i…

c++ 字符串赋给另一个_7.2 C++字符串处理函数

点击上方“C语言入门到精通”,选择置顶第一时间关注程序猿身边的故事作者闫小林白天搬砖,晚上做梦。我有故事,你有酒么?C字符串处理函数C语言和C提供了一些字符串函数,使得用户能很方便地对字符串进行处理。这些是放在…

如何检测远程主机上的某个端口是否开启

有时候我们要测试远程主机上的某个端口是否开启,无需使用太复杂的工作,windows下就自带了工具,那就是telnet。怎么检测呢,按下面的步骤: 1、安装telnet。我的win7下就没有telnet,在cmd下输入telnet提示没有…

Windows10 + WSL (Ubuntu) + Anaconda + vscode 手把手配置python运行环境(含虚拟环境)

配置WSL windows桌面下,按下面顺序可以找到 "启动或关闭windows功能” , 开始 -> 设置 -> 应用 -> 应用和功能 -> 可选功能 -> 相关设置下 更多Windows功能(滚动鼠标到底部)点击后,会弹出 启动或…

2019编译ffepeg vs_如何在windows10下使用vs2017编译最新版本的FFmpeg和ffplay

该文章描述了如何在windows10 64位系统下面编译出FFmpeg的库及其自带的ffplay播放器,而且全部采用最新的版本,这样我们可以在vs2017的ide下调试ffplay,能使我们更容易学习FFmpeg的架构以及音视频播放器的原理。步骤:1.安装vs2017在…

训练集山准确率高测试集上准确率很低_推荐算法改版前的AB测试

编辑导语:所谓推荐算法就是利用用户的一些行为,通过一些数学算法,推测出用户可能喜欢的东西;如今很多软件都有这样的操作,对于此系统的设计也会进行测试;本文作者分享了关于推荐算法改版前的AB测试&#xf…

C#实现渐变颜色的Windows窗体控件

C#实现渐变颜色的Windows窗体控件! 1,定义一个BaseFormGradient,继承于System.Windows.Forms.Form2,定义三个变量: privateColor _Color1 Color.Gainsboro; privateColor _Color2 Color.White; privatefloat_ColorAngle 0f;3,重载OnPaintBackground方法 protecte…

Windows下 jupyter notebook 运行multiprocessing 报错的问题与解决方法

文章目录测试用的代码错误解决方法测试用的代码 下面每一个对应一个jupyter notebook的单元格 import time from multiprocessing import Process, Queuedef generator():c 0while True:time.sleep(1.0) # read somethingyield cc 1%%timeds generator() for i in range(3…

vc mysql_vc6.0连接mysql数据库

一、MySQL的安装Mysql的安装去官网下载就可以。。。最新的是5.7版本。。二、VC6.0的设置(1)打开VC6.中选0 工具栏Tools菜单下的Options选项,在Directories的标签页中右边的“Show directories for:”下拉列表中“Includefiles”,然后在中间列表框中添加你…

python class用法_python原类、类的创建过程与方法

【小宅按】今天为大家介绍一下python中与class 相关的知识……获取对象的类名python是一门面向对象的语言,对于一切接对象的python来说,咱们有必要深入的学习与了解一些知识首先大家都知道,要获取一个对象所对应的类,需要使用clas…

深度学习中的一些常见的激活函数集合(含公式与导数的推导)sigmoid, relu, leaky relu, elu, numpy实现

文章目录Sigmoid(x)双曲正切线性整流函数 rectified linear unit (ReLu)PReLU(Parametric Rectified Linear Unit) Leaky ReLu指数线性单元 Exponential Linear Units (ELU)感知机激活%matplotlib inline %config InlineBackend.f…

最牛X的GCC 内联汇编

正如大家知道的,在C语言中插入汇编语言,其是Linux中使用的基本汇编程序语法。本文将讲解 GCC 提供的内联汇编特性的用途和用法。对于阅读这篇文章,这里只有两个前提要求,很明显,就是 x86 汇编语言和 C 语言的基本认识。…

mysql的告警日志_MySQL Aborted connection告警日志的分析

前言:有时候,连接MySQL的会话经常会异常退出,错误日志里会看到"Got an error reading communication packets"类型的告警。本篇文章我们一起来讨论下该错误可能的原因以及如何来规避。1.状态变量Aborted_clients和Aborted_connects…

hosts多个ip对应一个主机名_一个简单的Web应用程序,用作连接到ssh服务器的ssh客户端...

WebSSH一个简单的Web应用程序,用作连接到ssh服务器的ssh客户端。它是用Python编写的,基于tornado,paramiko和xterm.js。特征支持SSH密码验证,包括空密码。支持SSH公钥认证,包括DSA RSA ECDSA Ed25519密钥。支持加密密钥…

Shell Notes(1)

> vi复制粘贴 光标移动到要复制的部分的开头,Esc退出插入模式,按v进入Visual模式,用hjkl选中要复制的部分 按Y或者yy,复制 移动光标到目标位置,按p,粘贴 > echo –e 参数 –e 可以使echo解释由反斜杠…

mysql多表查询语句_mysql查询语句 和 多表关联查询 以及 子查询

1.查询一张表:select * from 表名;2.查询指定字段:select 字段1,字段2,字段3….from 表名;3.where条件查询:select字段1,字段2,字段3 frome表名 where 条件表达式&#x…

Pytorch 自定义激活函数前向与反向传播 sigmoid

文章目录Sigmoid公式求导过程优点:缺点:自定义Sigmoid与Torch定义的比较可视化import matplotlib import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F%matplotlib inlineplt.rcPa…

js高级编程_这位设计师用Processing把创意编程玩到了极致!

Processing作为新媒体从业者的必备工具,近来却越来越成为设计师们的新宠!今天小编将介绍以为用Processing把创意编程玩到极致的设计师Tim Rodenbrker。“我们的世界正在以惊人的速度变化。新技术为创作带来了根本性的转变。编程是我们这个时代最宝贵的技…