PyTorch JIT和TorchScript,一个API提升推理性能50%

PyTorch支持两种模式:eager模式和script模式。eager模式主要用于模型的编写、训练和调试,script模式主要是针对部署的,其包含PytorchJIT和TorchScript(一种在 PyTorch 中执行高效的序列化代码格式)。

script模式使用torch.jit.tracetorch.jit.script创建一个PyTorch eager module的中间表示(intermediate representation, IR),IR 经过内部优化,并在运行时使用 PyTorch JIT 编译。PyTorch JIT 编译器使用运行时信息来优化 IR。该 IR 与 Python 运行时是解耦的。

PyTorch JIT(Just-In-Time Compilation)是 PyTorch 中的即时编译器。

  1. 它允许你将模型转化为 TorchScript 格式,从而提高模型的性能和部署效率。
  2. JIT 允许你在动态图和静态图之间无缝切换。你可以在 Python 中以动态图的方式构建和调试模型,然后将模型编译为 TorchScript 以进行优化和部署。
  3. JIT 允许你在不同的深度学习框架之间进行模型转换,例如将 PyTorch 模型转换为 ONNX 格式,从而可以在其他框架中运行。

TorchScript 是 PyTorch 提供的一种将模型序列化以便在其他环境中运行的机制。它将 PyTorch 模型编译成一种中间表示形式,可以在没有 Python 解释器的环境中运行。这使得模型可以在 C++ 等其他语言中运行,也可以在嵌入式设备等资源受限的环境中实现高效的推理。

以下是 TorchScript 的一些重要特性和用途:

  1. 静态图表示形式:TorchScript 是一种静态图表示形式,它在模型构建阶段对计算图进行编译和优化,而不是在运行时动态构建。这可以提高模型的执行效率。
  2. 模型导出:TorchScript 允许将 PyTorch 模型导出到一个独立的文件中,然后可以在没有 Python 环境的设备上运行。
  3. 跨平台部署:TorchScript 允许在不同的深度学习框架之间进行模型转换,例如将 PyTorch 模型转换为 ONNX 格式,从而可以在其他框架中运行。
  4. 模型优化和量化:通过 TorchScript,你可以使用各种技术(如量化)对模型进行优化,从而减小模型的内存占用和计算资源消耗。
  5. 融合和集成:TorchScript 可以帮助你将多个模型整合到一个整体流程中,从而提高系统的整体性能。
  6. 嵌入式设备:对于资源受限的嵌入式设备,TorchScript 可以帮助你优化模型以适应这些环境。

使用 TorchScript 可以将 PyTorch 模型变得更容易在生产环境中部署和集成。然而,它也可能需要你对模型进行一些修改以使其可以成功编译为 TorchScript。

总的来说,TorchScript 是一个强大的工具,特别是对于需要在不同环境中部署 PyTorch 模型的情况。通过将模型导出为 TorchScript,你可以实现更广泛的模型应用和部署。

一段话总结,为什么要用以及什么时候要用script模式呢?

  1. 可以脱离python GIL以及python runtime的限制来运行模型,比如通过LibTorch通过C++来运行模型。这样方便了模型部署,例如可以在IoT等平台上运行。例如这个tutorial,使用C++来运行pytorch的model。
  2. PyTorch JIT是用于pytorch的优化的JIT编译器,它使用运行时信息来优化 TorchScript modules,可以自动进行层融合、量化、稀疏化等优化。因此,相比pytorch model,TorchScript的性能会更高。

Script mode通过torch.jit.trace或者torch.jit.script来调用。这两个函数都是将python代码转换为TorchScript的两种不同的方法。torch.jit.trace将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个PyTorch模型,torch.jit.trace会跟踪此input在model中的计算过程,然后将其转换为Torch脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。torch.jit.script直接将Python函数(或者一个Python模块)通过python语法规则和编译转换为Torch脚本。torch.jit.script更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于RNN或者一些具有可变序列长度的模型,使用torch.jit.script会更为方便。

在通常情况下,更应该倾向于使用torch.jit.trace而不是torch.jit.script

在上一篇blog中,我们非常非常详细介绍了torch.jit.tracetorch.jit.script的区别以及使用建议。强烈建议先阅读上一篇blog,再来阅读此篇内容。

本篇中,我们重点看一下TorchScript model与eager model的性能区别。

JIT Trace

torch.jit.trace使用eager model和一个dummy input作为输入,tracer会根据提供的model和input记录数据在模型中的流动过程,然后将整个模型转换为TorchScript module。看一个具体的例子:

我们使用BERT(Bidirectional Encoder Representations from Transformers)作为例子。

from transformers import BertTokenizer, BertModel
import numpy as np
import torch
from time import perf_counterdef timer(f,*args):   start = perf_counter()f(*args)return (1000 * (perf_counter() - start))# 加载bert model
native_model = BertModel.from_pretrained("bert-base-uncased")
# huggingface的API中,使用torchscript=True参数可以直接加载TorchScript model
script_model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)script_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', torchscript=True)# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = script_tokenizer.tokenize(text)# Masking one of the input tokens
masked_index = 8tokenized_text[masked_index] = '[MASK]'indexed_tokens = script_tokenizer.convert_tokens_to_ids(tokenized_text)segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

然后分别在CPU和GPU上测试eager mode的pytorch推理速度。

# 在CPU上测试eager model推理性能
native_model.eval()
np.mean([timer(native_model,tokens_tensor,segments_tensors) for _ in range(100)])# 在GPU上测试eager model推理性能
native_model = native_model.cuda()
native_model.eval()
tokens_tensor_gpu = tokens_tensor.cuda()
segments_tensors_gpu = segments_tensors.cuda()
np.mean([timer(native_model,tokens_tensor_gpu,segments_tensors_gpu) for _ in range(100)])

再分别在CPU和GPU上测试script mode的TorchScript模型的推理速度

# 在CPU上测试TorchScript性能
traced_model = torch.jit.trace(script_model, [tokens_tensor, segments_tensors])
# 因模型的trace时,已经包含了.eval()的行为,因此不必再去显式调用model.eval()
np.mean([timer(traced_model,tokens_tensor,segments_tensors) for _ in range(100)])# 在GPU上测试TorchScript的性能

最终运行结果如表

CPU latency (ms)GPU latency (ms)
PyTorch171.2730.42
TorchScript165.2413.50

我使用的硬件规格是google colab,cpu是Intel(R) Xeon(R) CPU @ 2.00GHz,GPU是Tesla T4

从结果来看,在CPU上,TorchScript比pytorch eager快了3.5%,在GPU上,TorchScript比pytorch快了55.6%

然后我们再用ResNet做一个测试。

import torchvision
import torch
from time import perf_counter
import numpy as npdef timer(f,*args):   start = perf_counter()f(*args)return (1000 * (perf_counter() - start))# Pytorch cpu versionmodel_ft = torchvision.models.resnet18(pretrained=True)
model_ft.eval()
x_ft = torch.rand(1,3, 224,224)
print(f'pytorch cpu: {np.mean([timer(model_ft,x_ft) for _ in range(10)])}')# Pytorch gpu versionmodel_ft_gpu = torchvision.models.resnet18(pretrained=True).cuda()
x_ft_gpu = x_ft.cuda()
model_ft_gpu.eval()
print(f'pytorch gpu: {np.mean([timer(model_ft_gpu,x_ft_gpu) for _ in range(10)])}')# TorchScript cpu versionscript_cell = torch.jit.script(model_ft, (x_ft))
print(f'torchscript cpu: {np.mean([timer(script_cell,x_ft) for _ in range(10)])}')# TorchScript gpu versionscript_cell_gpu = torch.jit.script(model_ft_gpu, (x_ft_gpu))
print(f'torchscript gpu: {np.mean([timer(script_cell_gpu,x_ft.cuda()) for _ in range(100)])}')
CPU latency (ms)GPU latency (ms)
PyTorch77.472.99
TorchScript74.241.64

TorchScript相比PyTorch eager model,CPU性能提升4.2%,GPU性能提升45%。与Bert的结论一致。

总结

  1. 本文重点说明了Pytorch的eager模式和script模式,重点是script模式的TorchScript和Pytorch JIT
  2. 上一篇文章重点说明了eager模式的model转为script模式的TorchScript的两个api,torch.jit.tracetorch.jit.script的区别,这是这一篇文章的基础,建议先阅读上一篇文章
  3. 使用Bert和ResNet两个网络进行了Pytorch eager model和TorchScript的CPU和GPU性能测试。结论在两个网络上一致,使用TorchScript在CPU上,相比PyTorch eager mode,会有4%左右的性能提升,在GPU上,会有50%左右的性能提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/114995.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Openssl数据安全传输平台008:业务数据分析+工厂方法

文章目录 UML图1.1 客户端1.2 服务器端 UML图 1.1 客户端 // 准备要发送的数据 struct RequestMsg {//1 密钥协商 //2 密钥校验; // 3 密钥注销int cmdType; // 报文类型string clientId; // 客户端编号string serverId; // 服务器端编号string sign;string data; };1.2 服务器…

Unity之ShaderGraph如何实现UV抖动

前言 今天我们通过噪波图来实现一个UV抖动的效果。 如下图所示: 关键节点 Time:提供对着色器中各种时间参数的访问 UV:提供对网格顶点或片段的UV坐标的访问。可以使用通道下拉参数选择输出值的坐标通道。 SimpleNoise:根据…

GEE图表——利用chirps降水数据进行某个区域累计降水量的图表绘制

简介 以下是在GEE云平台利用chirps降水数据进行某个区域累计降水量的图表绘制的具体步骤: 1. 打开GEE云平台的网站(https://code.earthengine.google.com)并登录账户。 2. 在左上角的搜索栏中输入“Chirps”,点击回车以搜索Chirps降水数据集。 3. 点击搜索结果中的Chir…

Windows Server 2019 搭建FTP站点

目录 1.添加IIS及FTP服务角色 2.创建FTP账户(用户名和密码)和组 3.设置共享文件夹的权限 4.添加及设置FTP站点 5.配置FTP防火墙支持 6.配置安全组策略 7.客户端测试 踩过的坑说明: 1.添加IIS及FTP服务角色 a.选择【开始】→【服务器…

电流监测芯片SGM8199A2应用电路设计

SGM8199是一系列具有电压输出功能的双向电流监测芯片,用于监测共模电压范围内分流电阻上的压降,而不受电源电压的影响。该器件具有-0.1V至26V的宽共模电压范围输入。低偏移使得在监测电流时允许分流器上的满量程最大压降为10mV。SGM8199系列提供三种固定…

关于vant 的tabbar功能

1、想要实现tabbar页面A,其他的页面B(非tabbar页面)。 从A页面进入B页面,底部的active选中效果应该被取消掉,但是还是选中A。 按照官网的说法有两个方法 一、根据path路径 二、自定义的model 但是!但是…

贪吃蛇项目实践

游戏背景: 贪吃蛇是久负盛名的游戏,它也和俄罗斯⽅块,扫雷等游戏位列经典游戏的⾏列。 实现基本的功能: 贪吃蛇地图绘制 蛇吃⻝物的功能 (上、下、左、右⽅向键控制蛇的动作) 蛇撞墙死亡 蛇撞⾃⾝死亡 计…

rust学习——栈、堆、所有权

文章目录 栈、堆、所有权栈(Stack)与堆(Heap)栈堆性能区别所有权与堆栈 所有权原则变量作用域所有权与函数返回值与作用域 栈、堆、所有权 栈(Stack)与堆(Heap) 栈和堆是编程语言最核心的数据结构,但是在很多语言中,你并不需要深入了解栈与堆。 但对于…

ReentrantLock与synchronized区别之比较(面试)

背景: 我们Java开发中需要保证数据线程安全时有多重选择,直接使用线程安全的集合类,或者某些变量我们通过ReentrantLock来保证安全,或者使用synchronized关键字,那两者有何区别? 备注: Reent…

Linux编程——多任务间通信和同步

在前面的文章中(Linux编程基础——多线程),简单对Linux中的多线程进行了介绍,包括pthread、信号量与互斥锁,本文将对Linux编程中的多任务间通信与同步技术进行相对完整的补充。 在Linux中有两种多任务实现手段&#xf…

ubuntu20.04安装MySQL8、MySQL服务管理、mysql8卸载

ubuntu20.04安装MySQL8 #更新源 sudo apt-get update #安装 sudo apt-get install mysql-serverMySQL服务管理 # 查看服务状态 sudo service mysql status # 启动服务 sudo service mysql start # 停止服务 sudo service mysql stop # 重启服务 sudo service mysql restart登…

中间件安全-CVE复现WeblogicJenkinsGlassFish漏洞复现

目录 服务攻防-中间件安全&CVE复现&Weblogic&Jenkins&GlassFish漏洞复现中间件-Weblogic安全问题漏洞复现CVE_2017_3506漏洞复现 中间件-JBoos安全问题漏洞复现CVE-2017-12149漏洞复现CVE-2017-7504漏洞复现 中间件-Jenkins安全问题漏洞复现CVE-2017-1000353漏…

idea设置字体大小快捷键 Ctrl+鼠标上下滑 字体快捷键缩放设置

双击 按住ctrl鼠标滑轮上划放大就好了 这个双击设置为,Ctrl鼠标下滑 字体缩小就好了

03-垃圾收集策略与算法

垃圾收集策略与算法 程序计数器、虚拟机栈、本地方法栈随线程而生,也随线程而灭;栈帧随着方法的开始而入栈,随着方法的结束而出栈。这几个区域的内存分配和回收都具有确定性,在这几个区域内不需要过多考虑回收的问题,因…

手把手创建属于自己的ASP.NET Croe Web API项目

第一步:创建项目的时候选择ASP.NET Croe Web API 点击下一步,然后配置: 下一步:

Adobe Photoshop 基本操作

PS快捷键 图层 选择图层 Ctrl T:可以对图层的大小和位置进行调整 填充图层 MAC: AltBackspace (前景) or CtrlBackspace (背景) WINDOWS: AltDelete (前景) or CtrlDelete (背景) 快速将图层填充为前景色或背景色 平面化图层(盖印图层&#xff09…

性能测试LoadRunner02

本篇主要讲:通过Controller设计简单的测试场景,可以简单的分析性能测试报告。 Controller 设计场景 Controller打开方式 1)通过VUG打开 2)之间双击Controller 不演示了,双击打开,选择Manual Scenario自…

《视觉 SLAM 十四讲》V2 第 9 讲 后端优化1 【扩展卡尔曼滤波器 EKF BA+非线性优化(Ceres、g2o)】

文章目录 第9讲 后端19.1.2 线性系统和 KF9.1.4 扩展卡尔曼滤波器 EKF 不足 9.2 BA 与 图优化9.2.1 投影模型和 BA 代价函数9.2.2 BA 的求解9.2.3 稀疏性 和 边缘化9.2.4 鲁棒核函数 9.3 实践: Ceres BA 【Code】本讲 CMakeLists.txt 9.4 实践:g2o 求解 …

100 # mongoose 的使用

mongoose elegant mongodb object modeling for node.js https://mongoosejs.com/ 安装 mongoose npm i mongoose基本示例 const mongoose require("mongoose");// 1、连接 mongodb let conn mongoose.createConnection("mongodb://kaimo313:kaimo313loc…

如何从小白成长为AI工程师笔记

📚入门机器学习基础 对于本科生来说,需要打好数学基础,包括高数、概率论和线性代数。 对于已经上研究生或工作想转行的人来说,可以直接开始学习机器学习算法,重要的是理解算法的原理和推导过程。如果有时间和需要&am…