pytorch Nvidia 数据预处理加速

目录

安装 不支持Windows:

官方说明:

预处理加速:

学习笔记:


参考:

深度学习预处理工具---DALI详解_nvidia.dali.fn_扫地的小何尚的博客-CSDN博客

安装 不支持Windows:

官方说明:

Installation — NVIDIA DALI 1.30.0 documentation

pip install nvidia-pyindex
pip install nvidia-dali-cuda110


import nvidia.dali.ops
import nvidia.dali.types
 
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
 

官网下载地址:看起来么有windows版本,

Index of /compute/redist///nvidia-dali-cuda110

预处理加速:

Nvidia Dali: 强大的数据增强库_笔记大全_设计学院

学习笔记:

对于深度学习任务,训练速度决定了模型的迭代速度,而训练速度又取决于数据预处理和网络的前向和后向耗时。
对于识别任务,batch size通常较大,并且需要做数据增强,因此常常导致训练速度的瓶颈在数据读取和预处理上,尤其对于小网络而言。
对于数据读取耗时的提升,粗暴且有效的解决办法是使用固态硬盘,或者将数据直接拷贝至/tmp文件夹(内存空间换时间)。
对于数据预处理的耗时,则可以通过使用Nvidia官方开发的Dali预处理加速工具包,将预处理放在cpu/gpu上进行加速。pytorch1.6版本内置了Dali,无需自己安装。

官方的Dali交程较为简单,实际训练通常要根据任务需要自定义Dataloader,并于分布式训练结合使用。这里将展示一个使用Dali定义DataLoader的例子,功能是返回序列图像,并对序列图像做常见的统一预处理操作。
`

from nvidia.dali.plugin.pytorch import DALIGenericIteratorfrom nvidia.dali.types import DALIImageType
import cv2
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from sklearn.utils import shuffle
import numpy as np
from torchvision import transforms
import torch.utils.data as torchdata
import random
from pathlib import Path
import torchclass TRAIN_INPUT_ITER(object):def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards=1, shard_id=0,root_dir=Path('') ,list_file='', is_training=True):self.batch_size = batch_sizeself.num_class = num_classself.seq_len = seq_lenself.sample_rate = sample_rateself.num_shards = num_shardsself.shard_id = shard_idself.train = is_trainingself.image_name_formatter = lambda x: f'image_{x:05d}.jpg'self.root_dir = root_dirwith open(list_file,'r') as f:self.ori_lines = f.readlines()def __iter__(self):self.i = 0bucket = len(self.ori_lines)//self.num_shardsself.n = bucketreturn selfdef __next__(self):batch = [[] for _ in range(self.seq_len)]labels = []for _ in range(self.batch_size):# self.sample_rate = random.randint(1,2)if self.train and self.i % self.n == 0:bucket = len(self.ori_lines)//self.num_shardsself.ori_lines= shuffle(self.ori_lines, random_state=0)self.lines = self.ori_lines[self.shard_id*bucket:(self.shard_id+1)*bucket]line = self.lines[self.i].strip()dir_name,start_f,end_f, label = line.split(' ')start_f = int(start_f)end_f = int(end_f)label = int(label)begin_frame = random.randint(start_f,max(end_f-self.sample_rate*self.seq_len,start_f))begin_frame = max(1,begin_frame)last_frame = Nonefor k in range(self.seq_len):filename = self.root_dir/dir_name/self.image_name_formatter(begin_frame+self.sample_rate*k)if filename.exists():f = open(filename,'rb')last_frame = filenameelif last_frame is not None:f = open(last_frame,'rb')else:print('{} does not exist'.format(filename))raise IOErrorbatch[k].append(np.frombuffer(f.read(), dtype = np.uint8))if random.randint(0,1)%2 == 0:end_frame = start_f + random.randint(0,self.sample_rate*self.seq_len//2)begin_frame = max(1,end_frame-self.sample_rate*self.seq_len)else:begin_frame = end_f - random.randint(0,self.sample_rate*self.seq_len//2)begin_frame = max(1,begin_frame)end_frame = begin_frame + self.sample_rate*self.seq_lenlast_frame = Nonefor k in range(self.seq_len):filename = self.root_dir/dir_name/self.image_name_formatter(begin_frame+self.sample_rate*k)if filename.exists():f = open(filename,'rb')last_frame = filenameelif last_frame is not None:f = open(last_frame,'rb')else:print('{} does not exist'.format(filename))raise IOErrorbatch[k].append(np.frombuffer(f.read(), dtype = np.uint8))labels.append(np.array([label], dtype = np.uint8))if label==8 or label == 9:labels.append(np.array([label], dtype = np.uint8))else:labels.append(np.array([self.num_class-1], dtype = np.uint8))self.i = (self.i + 1) % self.nreturn (batch, labels)next = __next__class VAL_INPUT_ITER(object):def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards=1, shard_id=0,root_dir=Path('') ,list_file='', is_training=False):self.batch_size = batch_sizeself.num_class = num_classself.seq_len = seq_lenself.sample_rate = sample_rateself.num_shards = num_shardsself.shard_id = shard_idself.train = is_trainingself.image_name_formatter = lambda x: f'image_{x:05d}.jpg'self.root_dir = root_dirwith open(list_file,'r') as f:self.ori_lines = f.readlines()self.ori_lines= shuffle(self.ori_lines, random_state=0)def __iter__(self):self.i = 0bucket= len(self.ori_lines)//self.num_shardsself.n = bucketreturn selfdef __next__(self):batch = [[] for _ in range(self.seq_len)]labels = []for _ in range(self.batch_size):# self.sample_rate = random.randint(1,2)if self.train and self.i % self.n == 0:bucket = len(self.ori_lines)//self.num_shardsself.ori_lines= shuffle(self.ori_lines, random_state=0)self.lines = self.ori_lines[self.shard_id*bucket:(self.shard_id+1)*bucket]if self.i % self.n == 0:bucket = len(self.ori_lines)//self.num_shardsself.lines = self.ori_lines[self.shard_id*bucket:(self.shard_id+1)*bucket]line = self.lines[self.i].strip()dir_name,start_f,end_f, label = line.split(' ')start_f = int(start_f)end_f = int(end_f)label = int(label)begin_frame = random.randint(start_f,max(end_f-self.sample_rate*self.seq_len,start_f))begin_frame = max(1,begin_frame)last_frame = Nonefor k in range(self.seq_len):filename = self.root_dir/dir_name/self.image_name_formatter(begin_frame+self.sample_rate*k)if filename.exists():f = open(filename,'rb')last_frame = filenameelif last_frame is not None:f = open(last_frame,'rb')else:print('{} does not exist'.format(filename))raise IOErrorbatch[k].append(np.frombuffer(f.read(), dtype = np.uint8))labels.append(np.array([label], dtype = np.uint8))self.i = (self.i + 1) % self.nreturn (batch, labels)next = __next__class HybridPipe(Pipeline):def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards,shard_id,root_dir, list_file, num_threads, device_id=0, dali_cpu=True,size = (224,224),is_gray = True,is_training = True):super(HybridPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)if is_training:self.external_data = TRAIN_INPUT_ITER(batch_size//2, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)else:self.external_data = VAL_INPUT_ITER(batch_size, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)# self.external_data = VAL_INPUT_ITER(batch_size, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)self.seq_len = seq_lenself.training = is_trainingself.iterator = iter(self.external_data)self.inputs = [ops.ExternalSource() for _ in range(seq_len)]self.input_labels = ops.ExternalSource()self.is_gray = is_graydecoder_device = 'cpu' if dali_cpu else 'mixed'self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)if self.is_gray:self.space_converter = ops.ColorSpaceConversion(device='gpu',image_type=types.RGB,output_type=types.GRAY)self.resize = ops.Resize(device='gpu', size=size)self.cast_fp32 = ops.Cast(device='gpu',dtype = types.FLOAT)if self.training:self.crop_coin = ops.CoinFlip(probability=0.5)self.crop_pos_x = ops.Uniform(range=(0., 1.))self.crop_pos_y = ops.Uniform(range=(0., 1.))self.crop_h = ops.Uniform(range=(256*0.85,256))self.crop_w = ops.Uniform(range=(256*0.85,256))self.crmn = ops.CropMirrorNormalize(device="gpu",output_layout=types.NHWC)self.u_rotate = ops.Uniform(range=(-8, 8))self.rotate = ops.Rotate(device='gpu',keep_size=True)self.brightness = ops.Uniform(range=(0.9,1.1))self.contrast = ops.Uniform(range=(0.9,1.1))self.saturation = ops.Uniform(range=(0.9,1.1))self.hue = ops.Uniform(range=(-0.3,0.3))self.color_jitter = ops.ColorTwist(device='gpu')else:self.crmn = ops.CropMirrorNormalize(device="gpu",crop=(224,224),output_layout=types.NHWC)def define_graph(self):self.batch_data = [i() for i in self.inputs]self.labels = self.input_labels()out = self.decode(self.batch_data)out = [out_elem.gpu() for out_elem in out]if self.training:out = self.color_jitter(out,brightness=self.brightness(),contrast=self.contrast())if self.is_gray:out = self.space_converter(out)if self.training:out = self.rotate(out,angle=self.u_rotate())out = self.crmn(out,crop_h=self.crop_h(),crop_w=self.crop_w(),crop_pos_x=self.crop_pos_x(),crop_pos_y=self.crop_pos_y(),mirror=self.crop_coin())else:out = self.crmn(out)out = self.resize(out)if not self.training:out = self.cast_fp32(out)return (*out, self.labels)def iter_setup(self):try:(batch_data, labels) = self.iterator.next()for i in range(self.seq_len):self.feed_input(self.batch_data[i], batch_data[i])self.feed_input(self.labels, labels)except StopIteration:self.iterator = iter(self.external_data)raise StopIterationdef dali_loader(batch_size,num_class,seq_len,sample_rate,num_shards,shard_id,root_dir,list_file,num_workers,device_id,dali_cpu=True,size = (224,224),is_gray = True,is_training=True):print('##########',root_dir)pipe = HybridPipe(batch_size,num_class,seq_len,sample_rate,num_shards,shard_id,root_dir,list_file,num_workers,device_id=device_id,dali_cpu=dali_cpu,size = size,is_gray=is_gray,is_training=is_training)# pipe.build()names = []for i in range(seq_len):names.append(f'data{i}')names.append('label')print('##############',names)loader = DALIGenericIterator(pipe,names,pipe.external_data.n,last_batch_padded=True, fill_last_batch=True)return loade

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/101671.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【设计模式】使用建造者模式组装对象并加入自定义校验

文章目录 1.前言1.1.创建对象时的痛点 2.建造者模式2.1 被建造类准备2.2.建造者类实现2.3.构建对象测试2.4.使用lombok简化建造者2.5.lombok简化建造者的缺陷 3.总结 1.前言 在我刚入行不久的时候就听说过建造者模式这种设计模式,当时只知道是用来组装对象&#xf…

Vuex存值取值与异步请求处理

目录 前言 一、Vuex简介 1.Vuex是什么 2.Vuex的核心概念 3.使用Vuex的好处 4.Vuex执行流程 二、Vuex的使用步骤 1.安装Vuex 2.创建store模块,分别维护state/actions/mutations/getters 3.使用Vuex存储值,获取值和改变值 1.state.js---存值 2.…

关于Vuex的基础使用存值及异步

目录 一.概述 二.取值 2.1.安装 2.2.菜单栏 2.3.模块 2.4.引用 三.改值 四.异步&后台请求 好啦今天就到这里了希望能帮到你哦!!! 一.概述 Vuex 是一个用于 Vue.js 应用程序的状态管理库。它主要用于集中管理应用程序中的共享状态&a…

idea中maven plugin提示not found

在终端中输入: mvn dependency:resolve 然后 解决了部分问题 Plugin org.apache.maven.plugins:maven-jar-plugin:3.1.0 not found 改为3.3.0了 Plugin maven-source-plugin:3.3.0 not found 改为 2.4 了 版本下降了 感觉后继有坑 待观察

Linux网络和系统管理

网络管理命令 1、ifconfig 命令 作用 ifconfig 命令用于显示或设置网络设备的信息。格式 ifconfig [网卡名字] [参数]可选项 网卡名字:指定要操作的网络设备。参数: up:启动指定网卡。down:关闭指定网卡。-a:显示所有网卡接口的信息,包括未激活的网卡接口。使用示例 1…

时代风口中的Web3.0基建平台,重新定义Web3.0!

近年来,Web3.0概念的广泛兴起,给加密行业带来了崭新的叙事方式,同时也为加密行业提供了更加具有想象力的应用场景与商业空间,并让越来越多的行业从业者们意识到只有更大众化的市场共性需求才能推动加密市场的持续繁荣。当前围绕这…

IDEA设置自动导入包

IDEA设置自动导入包 首先进入设置选项 之后勾选以下两项: 第一项:IntelliJ IDEA 将在我们书写代码的时候自动帮我们优化导入的包,比如自动去掉一些没有用到的包。 第二项: IntelliJ IDEA 将在我们书写代码的时候自动帮我们导入…

.NET ABP.Zero 项目疑似内存排查历程

当前项目是 .NET 5 EentityFrameworkCore,疑似内存泄漏,之所以说是疑似是因为到目前位置还没有能准确的定位到问题。当前这个框架从 .NET Core 2.1 就开始用,期间有升级到 3.1、5.0、6.0,在排查过程中还把 5.0 分支升级到了 7.0 。…

HashMap -- 调研

HashMap 调研 前言JDK1.8之前拉链法: JDK1.8之后JDK1.7 VS JDK1.8 比较优化了一下问题: HashMap的put方法的具体流程?HashMap的扩容resize操作怎么实现的? 前言 在Java中,保存数据有两种比较简单的数据结构:数组和链表。 数组的特点是:寻址容易,插入…

【RabbitMQ 实战】11 队列的结构和惰性队列

一、 队列的结构 队列的组成: 队列由 rabbit_amgqueue_process 和 backing_queue两部分组成。rabbit_amqqueue_process负责协议相关的消息处理,即接收生产者发布的消息、向消费者交付消息、处理消息的确认 (包括生产端的 confirm 和消费端的 ack) 等。…

Spring Boot读取配置文件

Spring Boot 是一种用于快速构建基于Spring的应用程序的框架,它提供了很多便利的功能和约定,使开发者可以快速搭建、配置和部署应用程序。在Spring Boot中,读取配置文件是一个非常常见的任务,本文将介绍如何在Spring Boot应用程序…

Qt/C++原创推流工具/支持多种流媒体服务/ZLMediaKit/srs/mediamtx等

一、前言 1.1 功能特点 支持各种本地视频文件和网络视频文件。支持各种网络视频流,网络摄像头,协议包括rtsp、rtmp、http。支持将本地摄像头设备推流,可指定分辨率和帧率等。支持将本地桌面推流,可指定屏幕区域和帧率等。自动启…

CAN通信-应用

up起来 驱动加载完成&#xff0c;使用ifconfig -a 可以看到两个节点 can0: flags128<NOARP> mtu 16unspec 00-00-00-00-00-00-00-00-00-00-00-00-00-00-00-00 txqueuelen 10 (UNSPEC)RX packets 0 bytes 0 (0.0 B)RX errors 0 dropped 0 overruns 0 frame 0TX p…

【Vuex+ElementUI】Vuex中取值存值以及异步加载的使用

一、导言 1、引言 Vuex是一个用于Vue.js应用程序的状态管理模式和库。它建立在Vue.js的响应式系统之上&#xff0c;提供了一种集中管理应用程序状态的方式。使用Vuex&#xff0c;您可以将应用程序的状态存储在一个单一的位置&#xff08;即“存储”&#xff09;中&#xff0c;…

MATLAB算法实战应用案例精讲-【图像处理】SLAM技术详解(应用篇)

目录 前言 知识储备 概率论基础 边缘概率 联合概率和独立 独立与条件独立

iPhone15手机拓展坞方案,支持手机快充+传输数据功能

手机拓展坞的组合有何意义&#xff1f;首先是数据存储场景&#xff0c;借助拓展坞扩展出的接口&#xff0c;可以连接U盘、移动硬盘等采用USB接口的设备&#xff0c;实现大文件的快速存储或者流转&#xff1b;其次是图片、视频的读取场景&#xff0c;想要读取相机、无人机SD/TF存…

options.html 页面设计成聊天框,左侧是功能列表,右侧是根据左侧的功能切换成不同的内容。--chatGpt

gpt: 要将 options.html 页面设计成一个聊天框式的界面&#xff0c;其中左侧是功能列表&#xff0c;右侧根据左侧的功能切换成不同的内容&#xff0c;你可以按照以下步骤进行设计和实现&#xff1a; 1. 首先&#xff0c;创建 options.html 文件&#xff0c;并在其中定义基本的…

【Xcode-宏定义配置】

1&#xff0c;项目中配置宏-DEBUG 打开Xcode项目工程 -> Targets -> Build Settings -> Preprocessor Macros -> &#xff0c;在Debug状态下点击右侧&#xff0c;“”添加&#xff0c;在对话框中输入DEBUG1&#xff0c;并保存&#xff0c;注意别把Release给覆盖了…

shell脚本学习

shell是一个用 C 语言编写的程序&#xff0c;是用户使用 Linux 的桥梁。Shell 既是一种命令语言&#xff0c;又是一种程序设计语言。实际上Shell是一个命令解释器&#xff0c;它解释由用户输入的命令并且把它们送到内核。不仅如此&#xff0c;Shell有自己的编程语言用于对命令的…

【angular】实现简单的angular国际化(i18n)

文章目录 目标过程运行参考 目标 实现简单的angular国际化。本博客实现中文版和法语版。 将Hello i18n!变为中文版&#xff1a;你好 i18n!或法语版:Bonjour l’i18n !。 过程 创建一个项目&#xff1a; ng new i18nDemo在集成终端中打开。 添加本地化包&#xff1a; ng a…