使用自定义 PyTorch 运算符优化深度学习数据输入管道

这篇文章[1]中,我们讨论 PyTorch 对创建自定义运算符的支持,并演示它如何帮助我们解决数据输入管道的性能瓶颈、加速深度学习工作负载并降低训练成本。

构建 PyTorch 扩展

PyTorch 提供了多种创建自定义操作的方法,包括使用自定义模块和/或函数扩展 torch.nn。在这篇文章中,我们感兴趣的是 PyTorch 对集成定制 C++ 代码的支持。此功能很重要,因为某些操作在 C++ 中比在 Python 中更有效和/或更容易地实现。使用指定的 PyTorch 实用程序(例如 CppExtension),可以轻松地将这些操作作为 PyTorch 的“扩展”包含在内,而无需拉取和重新编译整个 PyTorch 代码库。由于我们对这篇文章的兴趣是加速基于 CPU 的数据预处理管道,因此我们只需使用 C++ 扩展即可,不需要 CUDA 代码。

玩具示例

在我们之前的文章中,我们定义了一个数据输入管道,首先解码 533x800 JPEG 图像,然后提取随机的 256x256 裁剪,经过一些额外的转换后,将其输入训练循环。我们使用 PyTorch Profiler 和 TensorBoard 来测量与从文件加载图像相关的时间,并承认解码的浪费。为了完整起见,我们复制以下代码:

import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset
input_img_size = [533800]
img_size = 256

class FakeDataset(VisionDataset):
    def __init__(self, transform):
        super().__init__(root=None, transform=transform)
        size = 10000
        self.img_files = [f'{i}.jpg' for i in range(size)]
        self.targets = np.random.randint(low=0,high=num_classes,
                                         size=(size),dtype=np.uint8).tolist()

    def __getitem__(self, index):
        img_file, target = self.img_files[index], self.targets[index]
        img = Image.open(img_file)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.img_files)


transform = T.Compose(
    [T.PILToTensor(),
     T.RandomCrop(img_size),
     RandomMask(),
     ConvertColor(),
     Scale()])

据推测,如果我们能够仅解码我们感兴趣的作物,我们的管道会运行得更快。不幸的是,截至撰写本文时,PyTorch 不包含支持此功能的函数。然而,使用自定义操作创建工具,我们可以定义并实现我们自己的函数!

自定义 JPEG 图像解码和裁剪函数

libjpeg-turbo 库是一个 JPEG 图像编解码器,与 libjpeg 相比,它包含许多增强和优化。特别是,libjpeg-turbo 包含许多函数,使我们能够仅解码图像中的预定义裁剪,例如 jpeg_skip_scanlines 和 jpeg_crop_scanline。如果您在 conda 环境中运行,可以使用以下命令进行安装:

conda install -c conda-forge libjpeg-turbo

请注意,libjpeg-turbo 已预安装在我们将在下面的实验中使用的官方 AWS PyTorch 2.0 深度学习 Docker 映像中。 在下面的代码块中,我们修改了torchvision 0.15的decode_jpeg函数,以从输入的JPEG编码图像中解码并返回所请求的裁剪。

torch::Tensor decode_and_crop_jpeg(const torch::Tensor& data,
                                   unsigned int crop_y,
                                   unsigned int crop_x,
                                   unsigned int crop_height,
                                   unsigned int crop_width)
 
{
  struct jpeg_decompress_struct cinfo;
  struct torch_jpeg_error_mgr jerr;

  auto datap = data.data_ptr<uint8_t>();
  // Setup decompression structure
  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = torch_jpeg_error_exit;
  /* Establish the setjmp return context for my_error_exit to use. */
  setjmp(jerr.setjmp_buffer);
  jpeg_create_decompress(&cinfo);
  torch_jpeg_set_source_mgr(&cinfo, datap, data.numel());

  // read info from header.
  jpeg_read_header(&cinfo, TRUE);

  int channels = cinfo.num_components;

  jpeg_start_decompress(&cinfo);

  int stride = crop_width * channels;
  auto tensor =
     torch::empty({int64_t(crop_height), int64_t(crop_width), channels},
                  torch::kU8);
  auto ptr = tensor.data_ptr<uint8_t>();

  unsigned int update_width = crop_width;
  jpeg_crop_scanline(&cinfo, &crop_x, &update_width);
  jpeg_skip_scanlines(&cinfo, crop_y);

  const int offset = (cinfo.output_width - crop_width) * channels;
  uint8_t* temp = nullptr;
  if(offset > 0) temp = new uint8_t[cinfo.output_width * channels];

  while (cinfo.output_scanline < crop_y + crop_height) {
    /* jpeg_read_scanlines expects an array of pointers to scanlines.
     * Here the array is only one element long, but you could ask for
     * more than one scanline at a time if that's more convenient.
     */

    if(offset>0){
      jpeg_read_scanlines(&cinfo, &temp, 1);
      memcpy(ptr, temp + offset, stride);
    }
    else
      jpeg_read_scanlines(&cinfo, &ptr, 1);
    ptr += stride;
  }
  if(offset > 0){
    delete[] temp;
    temp = nullptr;
  }
  if (cinfo.output_scanline < cinfo.output_height) {
    // Skip the rest of scanlines, required by jpeg_destroy_decompress.
    jpeg_skip_scanlines(&cinfo,
                        cinfo.output_height - crop_y - crop_height);
  }
  jpeg_finish_decompress(&cinfo);
  jpeg_destroy_decompress(&cinfo);
  return tensor.permute({201});
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("decode_and_crop_jpeg",&decode_and_crop_jpeg,"decode_and_crop_jpeg");
}

在下一节中,我们将按照 PyTorch 教程中的步骤将其转换为可在预处理管道中使用的 PyTorch 运算符。

部署 PyTorch 扩展

如 PyTorch 教程中所述,部署自定义运算符有不同的方法。您的部署设计中可能需要考虑许多因素。以下是我们认为重要的一些示例:

  1. 及时编译:为了确保我们的 C++ 扩展是针对我们训练时使用的同一版本的 PyTorch 进行编译的,我们对部署脚本进行了编程,以便在训练环境中进行训练之前编译代码。
  2. 多进程支持:部署脚本必须支持从多个进程(例如,多个 DataLoader 工作线程)加载我们的 C++ 扩展的可能性。
  3. 托管培训支持:由于我们经常在托管培训环境(例如 Amazon SageMaker)中进行培训,因此我们要求部署脚本支持此选项。 (有关定制托管培训环境主题的更多信息,请参阅此处。)

在下面的代码块中,我们定义了一个简单的 setup.py 脚本,用于编译和安装我们的自定义函数,如此处所述。

from setuptools import setup
from torch.utils import cpp_extension

setup(name='decode_and_crop_jpeg',
      ext_modules=[cpp_extension.CppExtension('decode_and_crop_jpeg'
                                              ['decode_and_crop_jpeg.cpp'], 
                                              libraries=['jpeg'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

我们将 C++ 文件和 setup.py 脚本放在名为 custom_op 的文件夹中,并定义一个 「init」.py 以确保安装脚本由单个进程运行一次:

import os
import sys
import subprocess
import shlex
import filelock

p_dir = os.path.dirname(__file__)

with filelock.FileLock(os.path.join(pkg_dir, f".lock")):
  try:
    from custom_op.decode_and_crop_jpeg import decode_and_crop_jpeg
  except ImportError:
    install_cmd = f"{sys.executable} setup.py build_ext --inplace"
    subprocess.run(shlex.split(install_cmd), capture_output=True, cwd=p_dir)
    from custom_op.decode_and_crop_jpeg import decode_and_crop_jpeg

最后,我们修改数据输入管道以使用新创建的自定义函数:

from torchvision.datasets.vision import VisionDataset
input_img_size = [533800]
class FakeDataset(VisionDataset):
    def __init__(self, transform):
        super().__init__(root=None, transform=transform)
        size = 10000
        self.img_files = [f'{i}.jpg' for i in range(size)]
        self.targets = np.random.randint(low=0,high=num_classes,
                                        size=(size),dtype=np.uint8).tolist()

    def __getitem__(self, index):
        img_file, target = self.img_files[index], self.targets[index]
        with torch.profiler.record_function('decode_and_crop_jpeg'):
            import random
            from custom_op.decode_and_crop_jpeg import decode_and_crop_jpeg
            with open(img_file, 'rb'as f:
                x = torch.frombuffer(f.read(), dtype=torch.uint8)
            h_offset = random.randint(0, input_img_size[0] - img_size)
            w_offset = random.randint(0, input_img_size[1] - img_size)
            img = decode_and_crop_jpeg(x, h_offset, w_offset, 
                                       img_size, img_size)

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.img_files)

transform = T.Compose(
    [RandomMask(),
     ConvertColor(),
     Scale()])

结果

经过我们描述的优化后,我们的步长时间从 0.72 秒降至 0.48 秒,性能提升了 50%!当然,我们优化的影响与原始 JPEG 图像的大小和我们选择的裁剪大小直接相关。

总结

数据预处理管道中的瓶颈很常见,可能会导致 GPU 饥饿并减慢训练速度。考虑到潜在的成本影响,您必须拥有各种工具和技术来分析和解决这些问题。在这篇文章中,我们回顾了通过创建自定义 C++ PyTorch 扩展来优化数据输入管道的选项,展示了其易用性,并展示了其潜在影响。当然,这种优化机制的潜在收益会根据项目和性能瓶颈的细节而有很大差异。

Reference

[1]

Source: https://towardsdatascience.com/how-to-optimize-your-dl-data-input-pipeline-with-a-custom-pytorch-operator-7f8ea2da5206

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/112404.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

驱动day2:LED灯实现三盏灯的亮灭

head.h #ifndef __HEAD_H__ #define __HEAD_H__ #define PHY_PE_MODER 0x50006000 #define PHY_PF_MODER 0x50007000 #define PHY_PE_ODR 0x50006014 #define PHY_PF_ODR 0x50007014 #define PHY_RCC 0x50000A28#endif 应用程序 #include <stdio.h> #include <sys/…

Linux性能优化--补充

14.1. 性能工具的位置 本书描述的性能工具来源于Internet上许多不同的位置。幸运的是&#xff0c;大多数主要发行版都把它们放在一起&#xff0c;包含在了其发行版的当前版本中。表A-1描述了全部工具&#xff0c;提供了指向其原始源位置的地址&#xff0c;并注明它们是否包含在…

YOLOv7改进实战 | 更换轻量化主干网络Backbone(一)之Ghostnet

前言 轻量化网络设计是一种针对移动设备等资源受限环境的深度学习模型设计方法。下面是一些常见的轻量化网络设计方法: 网络剪枝:移除神经网络中冗余的连接和参数,以达到模型压缩和加速的目的。分组卷积:将卷积操作分解为若干个较小的卷积操作,并将它们分别作用于输入的不…

官方认证:研发效能(DevOps)工程师职业技术认证

培养端到端的研发效能人才 为贯彻落实《关于深化人才发展体制机制改革的意见》&#xff0c;推动实施人才强国战略&#xff0c;促进专业技术人员提升职业素养、补充新知识新技能&#xff0c;实现人力资源深度开发&#xff0c;推动经济社会全面发展&#xff0c;根据《中华人民共…

Apache Doris (四十五): Doris数据更新与删除 - Sequence 列

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹哥教你大数据个人主页-哔哩哔哩视频 目录 1. 基本原理

WGCNA分析教程五 | [更新版]

一边学习&#xff0c;一边总结&#xff0c;一边分享&#xff01; 往期WGCNA分析教程 WGCNA分析 | 全流程分析代码 | 代码一 WGCNA分析 | 全流程分析代码 | 代码二 WGCNA分析 | 全流程分析代码 | 代码四 关于WGCNA分析教程日常更新 学习无处不在&#xff0c;我们的教程会在…

[环境配置]anaconda3的base环境与python版本对应关系表

anaconda3版本 base环境对应python版本 Anaconda3-2018.12-Windows-x86_64.exe 3.7 Anaconda3-2019.03-Windows-x86_64.exe 3.7 Anaconda3-2019.07-Windows-x86_64.exe 3.7 Anaconda3-2019.10-Windows-x86_64.exe 3.7 Anaconda3-2020.02-Windows-x86_64.exe 3.7 An…

Alpine.js 精简重

建议有 js 基础&#xff0c;先阅读官网文档&#xff0c;如果您会 vue 类似框架&#xff0c;上手会更快 https://alpinejs.dev js 代码中可以使用 Alpine.sore 定义全局数据 Alpine.store(tabs, {current: first,items: [first, second, third], }) x-text 可以运算任何 js 表…

ubuntu16.04下标定Astra相机

ubuntu16.04下标定Astra相机 1.安装相机驱动 rosrun camera_calibration cameracalibrator.py --size 7x5 --square 0.018 image:/camera/rgb/image_raw camera:/camera/rgb 2.下载camere_calibration 3.进行标定 打开终端&#xff0c;输入 roslaunch astra_launch astrap…

从入门到进阶 之 ElasticSearch 配置优化篇

&#x1f339; 以上分享从入门到进阶 之 ElasticSearch 配置优化篇&#xff0c;如有问题请指教写。&#x1f339;&#x1f339; 如你对技术也感兴趣&#xff0c;欢迎交流。&#x1f339;&#x1f339;&#x1f339; 如有需要&#xff0c;请&#x1f44d;点赞&#x1f496;收藏…

浏览器不能访问阿里云ECS

一、浏览器不能访问端口 在阿里云ECS中构建了工程&#xff0c;nigix或者tomcat或者其他&#xff0c;然后在本地浏览器访问ip端口的时候&#xff0c;连接超时&#xff0c;解决办法&#xff1a; 进入阿里云ECS服务 -> 查看公网ip (外部连接需要使用公网) -> 进入ECS实例的…

攻防世界web篇-cookie

看到cookie立马就会想到F12键看cookie的一些信息 我这个实在存储里面看的&#xff0c;是以.php点缀结尾&#xff0c;可以试一下在链接中加上.php 得到的结果是这样 这里&#xff0c;我就只能上csdn搜索一下了&#xff0c;看到别人写的是在get请求中可以看到flag值

Mysql 约束,基本查询,复合查询与函数

文章目录 约束空属性约束默认值约束zerofill主键约束自增长约束唯一键约束外键约束 查询select的执行顺序单表查询排序 updatedelete整张表的拷贝复合语句group by分组查询 函数日期函数字符串函数数学函数其他函数 复合查询合并查询union 约束 空属性约束 两个值&#xff1a…

Windows安装Jenkins

JDK 11 以上 https://github.com/adoptium/temurin11-binaries/releases/download/jdk-11.0.20%2B8/OpenJDK11U-jdk_x64_windows_hotspot_11.0.20_8.msi https://www.jenkins.io/download/ 下载windows安装版本 授权用户administrator logon as services windows(server)安装…

操作系统【OS】中断和异常

异常&#xff08;内中断&#xff09; 中断&#xff08;外中断&#xff09; 基本概念 由CPU执行指令内部产生的事件内中断都是不可屏蔽中断&#xff0c;一旦出现&#xff0c;就要立即处理。 由来自CPU外部的设备发出的中断请求&#xff08;常用于输入输出&#xff09;典型的由…

element-ui 以CDN 方式引入原生js开发的几个别坑 (+vue)

element-ui 以CDN 方式引入原生js开发的几个坑 最近两个月太忙了 忙的没空写文章 两个月赶出来了几个的项目 一个是雪佛兰裸眼3D的一个商品屏幕展示项目 一个是广汽云渲染的一个云看车项目 一个是奥迪中国充电桩的网页开发项目&#xff0c; 奥迪中国做个饭也是目前正在做的 不…

机器人SLAM与自主导航

机器人技术的迅猛发展&#xff0c;促使机器人逐渐走进了人们的生活&#xff0c;服务型室内移动机器人更是获得了广泛的关注。但室内机器人的普及还存在许多亟待解决的问题&#xff0c;定位与导航就是其中的关键问题之一。在这类问题的研究中&#xff0c;需要把握三个重点&#…

Python 打包whl文件Setup参数

setup 函数常用的参数如下 参数说明name包名称version包版本author程序的作者author_email程序的作者的邮箱地址maintainer维护者maintainer_email维护者的邮箱地址url程序的官网地址license程序的授权信息description程序的简单描述long_description程序的详细描述platforms程…

专题:链表常考题目汇总

文章目录 反转类型&#xff1a;206.反转链表完整版二刷记录 25. K个一组反转链表1 &#xff1a;子链表左闭右闭反转版本2 &#xff1a; 子链表左闭右开反转版本&#xff08;推荐&#xff09;⭐反转链表左闭右闭和左闭右开 合并类型&#xff1a;21.合并两个有序链表1: 递归法2: …

10月19日,每日信息差

今天是2023年10月19日&#xff0c;以下是为您准备的17条信息差 第一、中国海洋石油遭南向资金净卖出2.38亿港元 第二、阅文集团侯晓楠&#xff1a;网文已经成为中国文化的一张全球名片。据了解&#xff0c;2022年以来&#xff0c;阅文已经在海外上线了自制的300多部动漫影视作…