编写高效的PyTorch代码技巧(下)

点击上方“算法猿的成长“,关注公众号,选择加“星标“或“置顶”

总第 133 篇文章,本文大约 3000 字,阅读大约需要 15 分钟

原文:https://github.com/vahidk/EffectivePyTorch

作者:vahidk

前言

这是一份 PyTorch 教程和最佳实践笔记,目录如下所示:

  1. PyTorch 基础

  2. 将模型封装为模块

  3. 广播机制的优缺点

  4. 使用好重载的运算符

  5. 采用 TorchScript 优化运行时间

  6. 构建高效的自定义数据加载类

  7. PyTorch 的数值稳定性

上篇文章的链接如下:

编写高效的PyTorch代码技巧(上)

这次介绍后面3点,写出高效的代码以及保证做数值计算时候的稳定性。


5. 采用 TorchScript 优化运行时间

PyTorch 优化了维度很大的张量的运算操作。在 PyTorch 中对小张量进行太多的运算操作是非常低效的。所以有可能的话,将计算操作都重写为批次(batch)的形式,可以减少消耗和提高性能。而如果没办法自己手动实现批次的运算操作,那么可以采用 TorchScript 来提升代码的性能。

TorchScript 是一个 Python 函数的子集,但经过了 PyTorch 的验证,PyTorch 可以通过其 just in time(jtt) 编译器来自动优化 TorchScript 代码,提高性能。

下面给出一个具体的例子。在机器学习应用中非常常见的操作就是 batch gather ,也就是 output[i] = input[i, index[i]]。其代码实现如下所示:

import torch
def batch_gather(tensor, indices):output = []for i in range(tensor.size(0)):output += [tensor[i][indices[i]]]return torch.stack(output)

通过 torch.jit.script 装饰器来使用 TorchScript 的代码:

@torch.jit.script
def batch_gather_jit(tensor, indices):output = []for i in range(tensor.size(0)):output += [tensor[i][indices[i]]]return torch.stack(output)

这个做法可以提高 10% 的运算速度。

但更好的做法还是手动实现批次的运算操作,下面是一个向量化实现的代码例子,提高了 100 倍的速度:

def batch_gather_vec(tensor, indices):shape = list(tensor.shape)flat_first = torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:])offset = torch.reshape(torch.arange(shape[0]).cuda() * shape[1],[shape[0]] + [1] * (len(indices.shape) - 1))output = flat_first[indices + offset]return output

6. 构建高效的自定义数据加载类

上一节介绍了如何写出更加高效的 PyTorch 的代码,但为了让你的代码运行更快,将数据更加高效加载到内存中也是非常重要的。幸运的是 PyTorch 提供了一个很容易加载数据的工具,即 DataLoader 。一个 DataLoader 会采用多个 workers 来同时将数据从 Dataset 类中加载,并且可以选择使用 Sampler 类来对采样数据和组成 batch 形式的数据。

如果你可以随时访问你的数据,那么使用 DataLoader 会非常简单:只需要继承 Dataset 类别并实现 __getitem__ (读取每个数据)和 __len__(返回数据集的样本数量)这两个方法。下面给出一个代码例子,如何从给定的文件夹中加载图片数据:

import glob
import os
import random
import cv2
import torchclass ImageDirectoryDataset(torch.utils.data.Dataset):def __init__(path, pattern):self.paths = list(glob.glob(os.path.join(path, pattern)))def __len__(self):return len(self.paths)def __item__(self):path = random.choice(paths)return cv2.imread(path, 1)

比如想将文件夹内所有的 jpeg 图片都加载,代码实现如下所示:

dataloader = torch.utils.data.DataLoader(ImageDirectoryDataset("/data/imagenet/*.jpg"), num_workers=8)
for data in dataloader:# do something with data

这里采用了 8 个 workers 来并行的从硬盘中读取数据。这个数量可以根据实际使用机器来进行调试,得到一个最佳的数量。

当你的数据都很大或者你的硬盘读写速度很快,采用DataLoader进行随机读取数据是可行的。但也可能存在一种情况,就是使用的是一个很慢的连接速度的网络文件系统,请求单个文件的速度都非常的慢,而这可能就是整个训练过程中的瓶颈。

一个更好的做法就是将数据保存为一个可以连续读取的连续文件格式。例如,当你有非常大量的图片数据,可以采用 tar 命令将其压缩为一个文件,然后用 python 来从这个压缩文件中连续的读取图片。要实现这个操作,需要用到 PyTorch 的 IterableDataset。创建一个 IterableDataset 类,只需要实现 __iter__ 方法即可。

下面给出代码实现的例子:

import tarfile
import torchdef tar_image_iterator(path):tar = tarfile.open(self.path, "r")for tar_info in tar:file = tar.extractfile(tar_info)content = file.read()yield cv2.imdecode(content, 1)file.close()tar.members = []tar.close()class TarImageDataset(torch.utils.data.IterableDataset):def __init__(self, path):super().__init__()self.path = pathdef __iter__(self):yield from tar_image_iterator(self.path)

不过这个方法有一个问题,当使用 DataLoader 以及多个 workers 读取这个数据集的时候,会得到很多重复的数据:

dataloader = torch.utils.data.DataLoader(TarImageDataset("/data/imagenet.tar"), num_workers=8)
for data in dataloader:# data contains duplicated items

这个问题主要是因为每个 worker 都会创建一个单独的数据集的实例,并且都是从数据集的起始位置开始读取数据。一种避免这个问题的办法就是不是压缩为一个tar 文件,而是将数据划分成 num_workers 个单独的 tar 文件,然后每个 worker 分别加载一个,代码实现如下所示:

class TarImageDataset(torch.utils.data.IterableDataset):def __init__(self, paths):super().__init__()self.paths = pathsdef __iter__(self):worker_info = torch.utils.data.get_worker_info()# For simplicity we assume num_workers is equal to number of tar filesif worker_info is None or worker_info.num_workers != len(self.paths):raise ValueError("Number of workers doesn't match number of files.")yield from tar_image_iterator(self.paths[worker_info.worker_id])

所以使用例子如下所示:

dataloader = torch.utils.data.DataLoader(TarImageDataset(["/data/imagenet_part1.tar", "/data/imagenet_part2.tar"]), num_workers=2)
for data in dataloader:# do something with data

这是一种简单的避免重复数据的问题。而 tfrecord 则用了比较复杂的办法来共享数据,具体可以查看:

https://github.com/vahidk/tfrecord


7. PyTorch 的数值稳定性

当使用任意一个数值计算库,比如 NumPy 或者 PyTorch ,都需要知道一点,编写数学上正确的代码不一定会得到正确的结果,你需要确保这个计算是稳定的。

首先以一个简单的例子开始。从数学上来说,对任意的非零 x ,都可以知道式子 是成立的。但看看具体实现的时候,是不是总是正确的:

import numpy as npx = np.float32(1)y = np.float32(1e-50)  # y would be stored as zero
z = x * y / yprint(z)  # prints nan

代码的运行结果是打印 nan ,原因是 y 的数值对于 float32 类型来说非常的小,这导致它的实际数值是 0 而不是 1e-50。

另一种极端情况就是 y 非常的大:

y = np.float32(1e39)  # y would be stored as inf
z = x * y / yprint(z)  # prints nan

输出结果依然是 nan ,因为 y 太大而被存储为 inf 的情况,对于 float32 类型来说,其范围是 1.4013e-45 ~ 3.40282e+38,当超过这个范围,就会被置为 0 或者 inf。

下面是如何查看一种数据类型的数值范围:

print(np.nextafter(np.float32(0), np.float32(1)))  # prints 1.4013e-45
print(np.finfo(np.float32).max)  # print 3.40282e+38

为了让计算变得稳定,需要避免过大或者过小的数值。这看起来很容易,但这类问题是很难进行调试,特别是在 PyTorch 中进行梯度下降的时候。这不仅因为需要确保在前向传播过程中的所有数值都在使用的数据类型的取值范围内,还要保证在反向传播中也做到这一点。

下面给出一个代码例子,计算一个输出向量的 softmax,一种不好的代码实现如下所示:

import torchdef unstable_softmax(logits):exp = torch.exp(logits)return exp / torch.sum(exp)print(unstable_softmax(torch.tensor([1000., 0.])).numpy())  # prints [ nan, 0.]

这里计算 logits 的指数数值可能会得到超出 float32 类型的取值范围,即过大或过小的数值,这里最大的 logits 数值是 ln(3.40282e+38) = 88.7,超过这个数值都会导致 nan

那么应该如何避免这种情况,做法很简单。因为有 ,也就是我们可以对 logits 减去一个常量,但结果保持不变,所以我们选择logits 的最大值作为这个常数,这种做法,指数函数的取值范围就会限制为 [-inf, 0] ,然后最终的结果就是 [0.0, 1.0] 的范围,代码实现如下所示:

import torchdef softmax(logits):exp = torch.exp(logits - torch.reduce_max(logits))return exp / torch.sum(exp)print(softmax(torch.tensor([1000., 0.])).numpy())  # prints [ 1., 0.]

接下来是一个更复杂点的例子。

假设现在有一个分类问题。我们采用 softmax 函数对输出值 logits 计算概率。接着定义采用预测值和标签的交叉熵作为损失函数。对于一个类别分布的交叉熵可以简单定义为 :

所以有一个不好的实现交叉熵的代码实现为:

def unstable_softmax_cross_entropy(labels, logits):logits = torch.log(softmax(logits))return -torch.sum(labels * logits)labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])xe = unstable_softmax_cross_entropy(labels, logits)print(xe.numpy())  # prints inf

在上述代码实现中,当 softmax 结果趋向于 0,其 log 输出会趋向于无穷,这就导致计算结果的不稳定性。所以可以对其进行重写,将 softmax 维度拓展并做一些归一化的操作:

def softmax_cross_entropy(labels, logits, dim=-1):scaled_logits = logits - torch.max(logits)normalized_logits = scaled_logits - torch.logsumexp(scaled_logits, dim)return -torch.sum(labels * normalized_logits)labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])xe = softmax_cross_entropy(labels, logits)print(xe.numpy())  # prints 500.0

可以验证计算的梯度也是正确的:

logits.requires_grad_(True)
xe = softmax_cross_entropy(labels, logits)
g = torch.autograd.grad(xe, logits)[0]
print(g.numpy())  # prints [0.5, -0.5]

这里需要再次提醒,进行梯度下降操作的时候需要额外的小心谨慎,需要确保每个网络层的函数和梯度的范围都在合法的范围内,指数函数和对数函数在不正确使用的时候都可能导致很大的问题,它们都能将非常小的数值转换为非常大的数值,或者从很大变为很小的数值。


精选AI文章

1. 10个实用的机器学习建议

2. 深度学习算法简要综述(上)

3. 深度学习算法简要综述(上)

4. 常见的数据增强项目和论文介绍

5. 实战|手把手教你训练一个基于Keras的多标签图像分类器

精选python文章

1.  python数据模型

2. python版代码整洁之道

3. 快速入门 Jupyter notebook

4. Jupyter 进阶教程

5. 10个高效的pandas技巧

精选教程资源文章

1. [资源分享] TensorFlow 官方中文版教程来了

2. [资源]推荐一些Python书籍和教程,入门和进阶的都有!

3. [Github项目推荐] 推荐三个助你更好利用Github的工具

4. Github上的各大高校资料以及国外公开课视频

5. GitHub上有哪些比较好的计算机视觉/机器视觉的项目?

欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!

 

如果觉得不错,在看、转发就是对小编的一个支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/408382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

统计(1 - 2)

统计学基础定义 Statistics的前部分为“state”,政府,原由是统计是300年前被首次应用在政府部门统计人口出生和死亡信息的;如今的统计学早已被应用在各个专业领域; 统计学是用以收集数据、分析数据和数据推论的一组概念、原则和方…

2020年计算机视觉学习指南

点击上方“算法猿的成长“,关注公众号,选择加“星标“或“置顶”总第 134 篇文章,本文大约 3000 字,阅读大约需要 10 分钟原文:https://towardsdatascience.com/guide-to-learn-computer-vision-in-2020-36f19d92c934作…

是选择Keras还是PyTorch开始你的深度学习之旅呢?

点击上方“算法猿的成长“,关注公众号,选择加“星标“或“置顶”总第 135 篇文章,本文大约 7000 字,阅读大约需要 20 分钟原文:https://medium.com/karan_jakhar/keras-vs-pytorch-dilemma-dc434e5b5ae0作者&#xff1…

关于myeclipse打开jsp巨慢解决方案

作为企业级开发最流行的工具,用Myeclipse开发java web程序无疑是最合适的,java web前端采用jsp来显示,myeclipse默认打开jsp的视图有卡顿的现象,那么如何更改jsp默认的打开方式,让我们可以进行更快速的jsp开发呢? 简单…

event

听取了网友:kenwang的意见,我的Blog在记流水账啊,现在才发现我发表的都是代码,一个感想也没有,以后要慢慢改正。明天要培训公司的框架,后天要搬家,这个周未没有得休息。

60分钟快速入门PyTorch

点击上方“算法猿的成长“,关注公众号,选择加“星标“或“置顶”总第 136 篇文章,本文大约 26000 字,阅读大约需要 60 分钟PyTorch 是由 Facebook 开发,基于 Torch 开发,从并不常用的 Lua 语言转为 Python …

react学习(38)----react是什么

什么是组件? 官方定义:将一些简短、独立的代码片段组合成复杂的 UI 界面,这些代码片段被称作“组件”。 解读:我们可以理解为能够组成一个UI界面的每一个独立的代码片段,例如表单的代码集合,轮播图的代码集…

大端与小端

/*************************************大端与小端:与大端存储格式相反,在小端存储格式中,低地址中存放的是字数据的低字节,高地址存放的是字数据的高字节**************************************//*联合体union的存放顺序是所有…

react学习(39)----react中的Hello World

ReactDOM.render(<h1>Hello, world!</h1>,document.getElementById(root) ); 它将在页面上展示一个 “Hello, world!” 的标题。

[libGDX游戏开发教程]使用libGDX进行游戏开发(12)-Action动画

前文章节列表&#xff1a;使用libGDX进行游戏开发(11)-高级编程技巧 使用libGDX进行游戏开发(10)-音乐音效不求人&#xff0c;程序员也可以DIY 使用libGDX进行游戏开发(9)-场景过渡使用libGDX进行游戏开发(8)-没有美工的程序员&#xff0c;能够依赖的还有粒子系统 使用libGDX进…

一年了

到温州出差一年了,一个项目做了一年啊,郁闷啊很想回家,回武汉

react学习(40)----react中的jsx简介

const name Josh Perez;const element <h1>Hello, {name}</h1>; ReactDOM.render(element,document.getElementById(root) ); jsx语法是个表达式 可以直接声明变量

将DataSet中的操作更新到Access数据库

代码如下&#xff1a;<%import Namespace Namespacesystem.data%><%import Namespace Namespacesystem.data.oledb%><script languagevb runatserver>Sub page_load()sub page_load() dim strConnection as string dim strSQL as string dim ob…

react学习(41)----react中的jsx简介

JSX 特定属性你可以通过使用引号&#xff0c;来将属性值指定为字符串字面量&#xff1a;const element <div tabIndex"0"></div>;也可以使用大括号&#xff0c;来在属性值中插入一个 JavaScript 表达式&#xff1a;const element <img src{user.ava…

@synthesize obj=_obj的意义详解 @property和@synthesize

本文转载至&#xff1a;http://blog.csdn.net/showhilllee/article/details/8971159我们在进行iOS开发时&#xff0c;经常会在类的声明部分看见类似于synthesize window_window; 的语句&#xff0c;那么&#xff0c;这个window是什么&#xff0c;_ window又是什么&#xff0c;两…

我喜欢的一首歌--《幸福的瞬间》

看了《薰衣草》就开始喜欢这首歌了&#xff0c;看的时候还会为了电视里男女主角痴情的爱情故事落泪&#xff0c;可见我还不成熟。今天正当我和下班人群一起在572上被挤得快变形的时候&#xff0c;车厢里放起了这首歌&#xff0c;我差点以为是我的手机响了。&#xff08;呵呵&am…

react学习(42)----react中的jsx表达对象

JSX 表示对象 Babel 会把 JSX 转译成一个名为 React.createElement() 函数调用。 以下两种示例代码完全等效&#xff1a; const element (<h1 className"greeting">Hello, world!</h1> ); const element React.createElement(h1,{className: greet…

react学习(43)----react中将一个元素渲染为 DOM

假设你的 HTML 文件某处有一个 <div>&#xff1a; <div id"root"></div> 我们将其称为“根” DOM 节点&#xff0c;因为该节点内的所有内容都将由 React DOM 管理。 仅使用 React 构建的应用通常只有单一的根 DOM 节点。如果你在将 React 集成进…

win7 IIS7.5配置伪静态

第一部: 从如下地址中下载URLRewriter组件组件&#xff1a;官方下载地址&#xff1a;http://download.microsoft.com/download/0/4/6/0463611e-a3f9-490d-a08c-877a83b797cf/MSDNURLRewriting.msi第二部&#xff1a;在网站项目中添加URLRewriter程序集的引用。第三部&#xff1…