pytorch forward_【Pytorch部署】TorchScript

v2-3bc7b6600232fe459c94ca4e517dde98_1440w.jpg?source=172ae18b

TorchScript是什么?

TorchScript - PyTorch master documentation​pytorch.org

TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。

我们提供了一些工具来增量地将模型从纯Python程序转换为能够独立于Python运行的TorchScript程序,例如在独立的c++程序中。这使得使用熟悉的Python工具在PyTorch中训练模型,然后通过TorchScript将模型导出到生产环境中成为可能,在这种环境中,Python程序可能由于性能和多线程的原因不适用。

编写TorchScript代码

torch.jit.script(obj)

脚本化一个函数或者nn.Module对象,将会检查它的源代码, 将其作为TorchScript代码使用TorchScrit编译器编译它,返回一个ScriptModule或ScriptFunction。 TorchScript语言自身是Python语言的一个子类, 因此它并非具有所有的Python语言特性。 torch.jit.script能够被作为函数或装饰器使用。参数obj可以是class, function, nn.Module。

具体地,脚本化一个函数torch.jit.script 装饰器将会通过编译函数被装饰函数体来构造一个ScriptFunction对象。例如:

import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rprint(type(foo))  # torch.jit.ScriptFuncion# See the compiled graph as Python code
print(foo.code)

脚本化一个nn.Module:默认地编译其forward方法,并递归地编译其子模块以及被forward调用的函数。如果一个模块只使用TorchScript中支持的特性,则不需要更改原始模块代码。编译器将构建ScriptModule,其中包含原始模块的属性、参数和方法的副本。例如:

import torchclass MyModule(torch.nn.Module):def __init__(self, N, M):super(MyModule, self).__init__()# This parameter will be copied to the new ScriptModuleself.weight = torch.nn.Parameter(torch.rand(N, M))# When this submodule is used, it will be compiledself.linear = torch.nn.Linear(N, M)def forward(self, input):output = self.weight.mv(input)# This calls the `forward` method of the `nn.Linear` module, which will# cause the `self.linear` submodule to be compiled to a `ScriptModule` hereoutput = self.linear(output)return outputscripted_module = torch.jit.script(MyModule(2, 3))

编译一个不在forward中的方法以及递归地编译其内的所有方法,可在此方法上使用装饰器torch.jit.export为了忽视某些方法也可以使用装饰器为了忽视某些方法也可以使用装饰器torch.jit.ignoretorch.jit.unused

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()@torch.jit.exportdef some_entry_point(self, input):return input + 10@torch.jit.ignoredef python_only_fn(self, input):# This function won't be compiled, so any# Python APIs can be usedimport pdbpdb.set_trace()def forward(self, input):if self.training:self.python_only_fn(input)return input * 99scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2))) 

torch.jit.trace(func,example_inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)

跟踪一个函数并返回一个可执行的或ScriptFunction对象,将使用即时编译(JIT)进行优化。跟踪非常适合那些只操作单张量或张量的列表、字典和元组的代码。使用torch.jit.tracetorch.jit.trace_module ,你能将一个模型或python函数转为TorchScript中的ScriptModuleScriptFunction。根据你提供的输入样例,它将会运行 该函数并记录所有张量上执行的操作。

Tracing 仅仅正确地记录那些不是数据依赖的函数和nn.Module(例如没有对数据的条件判断) 并且它们也没有任何未跟踪的外部依赖(例如执行输入输出或访问全局变量). Tracing 只记录在给定张量上运行给定函数时所执行的操作。 因此,返回的ScriptModule将始终在任何输入上运行相同的跟踪图。当你的模块需要根据输入和/或模块状态运行不同的操作集时,这就产生了一些重要的影响。例如:

  • Tracing不会记录任何类似if语句或循环的控制流。当这个控制流在您的模块中是常量时,这是没有问题的,并且它通常内联了控制流决策。但有时控制流实际上是模型本身的一部分。例如,一个递归网络是一个输入序列长度(可能是动态的)的循环。
  • 在返回的ScriptModule中,无论ScriptModule处于哪种模式,在train和eval模式中具有不同行为的操作都将始终表现为处于跟踪时所处的模式。

在这种情况下,Trace是不合适的,Script是更好的选择。如果你跟踪这样的模型,您可能会在后续的模型调用中得到不正确的结果。当执行可能导致产生错误跟踪的操作时,跟踪程序将尝试发出警告。

tracing a function:

import torchdef foo(x, y):return 2 * x + y# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

tracing a existing module

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Conv2d(1, 1, 3)def forward(self, x):return self.conv(x)n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

torch.jit.trace_module(mod,inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)

跟踪一个模块并返回一个可执行的ScriptModule,该脚本模块将使用即时编译进行优化。当一个模块被传递到torch.jit.trace,只运行和跟踪forward方法。使用trace_module,您可以为要跟踪的示例输入指定一个方法名字典(参见下面的example_input参数)。

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Conv2d(1, 1, 3)def forward(self, x):return self.conv(x)def weighted_kernel_sum(self, weight):return weight * self.conv.weightn = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

class torch.jit.ScriptModule

ScriptModule 封装一个c++接口中的torch::jit::Module类, 有下列属性及方法:

  • code 返回forward方法的内部图的打印表示(具有有效的Python语法)
  • graph返回forward方法的内部图的字符串表示形式
  • inlined_graph返回forward方法的内部图的字符串表示形式。此图将被预处理为内联所有函数和方法调用。
  • save(f,_extra_files=ExtraFilesMap{})

class torch.jit.ScriptFunction 与上者类似

torch.jit.save(m,f,_extra_files=ExtraFilesMap{})

保存此模块的脱机版本,以便在单独的进程中使用。所保存的模块序列化此模块的所有方法、子模块、参数和属性。它可以使用torch::jit::load(文件名)加载到c++ API中,也可以使用torch.jit.load加载到Python API中。为了能够保存模块,它必须不调用任何本机Python函数。这意味着所有子模块也必须是ScriptModule的子类。所有模块,不管它们的设备是什么,总是在加载过程中加载到CPU上。这与torch.load()的语义不同,将来可能会改变。

import torch
import ioclass MyModule(torch.nn.Module):def forward(self, x):return x + 10m = torch.jit.script(MyModule())# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)# Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

torch.jit.load(f,map_location=None,_extra_files=ExtraFilesMap{})

加载先前用torch.jit.save保存的ScriptModule或ScriptFunction所有之前保存的模块,无论它们的设备是什么,都首先加载到CPU上,然后移动到它们保存的设备上。如果失败(例如,因为运行时系统没有特定的设备),就会引发异常。

import torch
import iotorch.jit.load('scriptmodule.pt')# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:buffer = io.BytesIO(f.read())# Load all tensors to the original device
torch.jit.load(buffer)# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')# Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])

torch.jit.ignore(drop=False, **kwargs)

这个装饰器向编译器表明,一个函数或方法应该被忽略,并保留为Python函数。这允许您在模型中保留尚未与TorchScript兼容的代码。如果从TorchScript调用,被忽略的函数将把调用分派给Python解释器。函数被忽略的模型不能导出。使用drop=True参数时可以,但会抛出异常。最好使用torch.jit.unused

import torch
import torch.nn as nnclass MyModule(nn.Module):@torch.jit.ignoredef debugger(self, x):import pdbpdb.set_trace()def forward(self, x):x += 10# The compiler would normally try to compile `debugger`,# but since it is `@ignore`d, it will be left as a call# to Pythonself.debugger(x)return xm = torch.jit.script(MyModule())# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")

使用torch.jit.ignore(drop=True), 这一方法已被torch.jit.unused替代。

import torch
import torch.nn as nnclass MyModule(nn.Module):@torch.jit.ignore(drop=True)def training_method(self, x):import pdbpdb.set_trace()def forward(self, x):if self.training:self.training_method(x)return xm = torch.jit.script(MyModule())# This is OK since `training_method` is not saved, the call is replaced
# with a `raise`.
m.save("m.pt")

torch.jit.unused(fn)

这个装饰器向编译器表明,应该忽略一个函数或方法,并用引发异常来替换它。这允许您在模型中保留与TorchScript不兼容的代码,同时仍然导出模型。

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self, use_memory_efficent):super(MyModule, self).__init__()self.use_memory_efficent = use_memory_efficent@torch.jit.unuseddef memory_efficient(self, x):import pdbpdb.set_trace()return x + 10def forward(self, x):# Use not-yet-scriptable memory efficient modeif self.use_memory_efficient:return self.memory_efficient(x)else:return x + 10m = torch.jit.script(MyModule(use_memory_efficent=False))
m.save("m.pt")m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))

混合Tracing和Scripting

在许多情况下,跟踪或脚本是将模型转换为TorchScript的一种更简单的方法。可以编写跟踪和脚本来满足模型某一部分的特定需求。

脚本函数可以调用跟踪函数。当您需要围绕一个简单的前馈模型使用控制流时,这一点特别有用。例如,序列到序列模型的波束搜索通常用脚本编写,但可以调用使用跟踪生成的编码器模块。

例如在脚本中调用跟踪函数

import torchdef foo(x, y):return 2 * x + ytraced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))@torch.jit.script
def bar(x):return traced_foo(x, x)

跟踪函数也可以调用脚本函数。当模型的一小部分需要一些控制流时,这是很有用的,即使大部分模型只是一个前馈网络。由跟踪函数调用的脚本函数中的控制流被正确保存。

例如在跟踪函数中调用脚本函数

import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rdef bar(x, y, z):return foo(x, y) + ztraced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

这个组合也适用于nn.Module。

import torch
import torchvisionclass MyScriptModule(torch.nn.Module):def __init__(self):super(MyScriptModule, self).__init__()self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1))self.resnet = torch.jit.trace(torchvision.models.resnet18(),torch.rand(1, 3, 224, 224))def forward(self, input):return self.resnet(input - self.means)my_script_module = torch.jit.script(MyScriptModule())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/289101.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

兼容ie8 rgba()用法

今天遇到了一个问题,要在一个页面中设置一个半透明的白色div。这个貌似不是难题,只需要给这个div设置如下的属性即可: background: rgba(255,255,255,.1); 但是要兼容到ie8。这个就有点蛋疼了。因为ie8不支持rgba()函数。下面我们总结一下rgb…

hdu水仙花

水仙花数Time Limit : 2000/1000ms (Java/Other) Memory Limit : 65536/32768K (Java/Other) Total Submission(s) : 11 Accepted Submission(s) : 6 Problem Description 春天是鲜花的季节,水仙花就是其中最迷人的代表,数学上有个水仙花数&#xff…

python中的标识符能不能使用关键字_Python中的标识符不能使用关键字

Python中的标识符不能使用关键字答:√智慧职教: 检查客室座椅外观良好,确认?无破损答:坐垫 靠背关于投标报价时综合单价的确定,下列做法中正确的是()答:以项目特征描述为依据确定综合单价城市总体规划调查时&#xff…

C# WPF实战项目升级了

概述之前用Caliburn.Micro搭建的WPF实战项目,CM框架选用了 3.0.3,实际上CM框架目前最新版已经到4.0。173了,所有很有必须升级一下项目了. 本来打算把平台框架也直接升级到.NET 6 的,但是项目里面很多库不支持最新的平台版本&#…

Android之通过ContentResolver获取手机图片和视频的路径和生成缩略图和缩略图路径

1 问题 获取手机所有图片和视频的路径和生成图片和视频的缩略图和缩略图路径 生成缩略图我们用的系统函数 public static Bitmap getThumbnail(ContentResolver cr, long origId, int kind, Options options) {throw new RuntimeException("Stub!");} 调用如下 M…

ArcGIS Engine开发模板及C#代码

目 录 1. 模板 2. 代码 1. 模板 以下为AE开发软件自带的模板及代码,开发工具为VS 2012+ArcGIS Engine 10.2。 2. 代码 using System; using System.Drawing; using System.Collections; using System.ComponentModel; using System.Windows.Forms; using System.Data; us…

为何解析浏览器地址参数会为null_request 包中出现 DNS 解析超时的探究

事情的起因是这样的,公司使用自建 dns 服务器,但是有一个致命缺陷,不支持 ipv6 格式的地址解析,而 node 的 DNS 解析默认是同时请求 v4 和 v6 的地址的,这样会导致偶尔在解析 v6 地址的时候出现超时。本文链接地址 htt…

高级iOS面试题

非标准答案 2 1: 类方法是可以直接通过类名直接调用,无需进行实例化对象。类方法是以开头2. 实例方法,需要显示实例化对象,为对象分配堆栈空间,并通过对象实例调用实例方法3. RUNTIME 是在程序运行过程动态对实例对象进行操作&…

dotTrace 6.1帮你理解SQL查询如何影响应用性能

dotTrace是JetBrains公司旗下的一款.NET应用程序性能瓶颈检测工具。该工具是ReSharper旗舰版的一部分,也可以单独安装。近日,dotTrace 6.1发布,主要增加了人们期待已久的SQL查询性能分析,开发人员可以通过它获得特定查询的执行时间…

React Native之函数作为参数传递给另外一个函数去调用

1 用法 我们一般喜欢把js里面的函数作为参数传递给另外一个函数,然后再调用这个函数,有点像C语言里面的函数指针 2 代码测试 写了一个函数,2个参数分别是函数,然后更具数据决定调用哪个函数 /*** Sample React Native App* https://github.com/facebook/react-native** form…

STL—list

前面我们分析了vector&#xff0c;这篇介绍STL中另一个重要的容器list list的设计 list由三部分构成&#xff1a;list节点、list迭代器、list本身 list节点 list是一个双向链表&#xff0c;所以其list节点中有前后两个指针。如下&#xff1a; // list节点 template <typenam…

C#语法糖 Null 条件运算符 【?.】

例子比如说:我们有一个UserInformation类public class UserInformation{ public string Name { get; set; }public List<string> Address { get; set; }}有下面一段代码,我们获取张三的第一个地址static void Main(string[] args){UserInformation user new UserInforma…

用单片机测量流体流速的_流量测量的主要方法

电磁流量计由于流量检测的复杂性和多样性&#xff0c;流量检测的方法非常多&#xff0c;常用于工业生产中的有10多种。流量测量与仪表可以分为测量瞬时流量和总流量两类。生产过程中流量大多作为监控参数&#xff0c;测量的是瞬时流量&#xff0c;但在物料平衡和能源计量的贸易…

C#帮助控件HelpProvider的使用

using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms;namespace 帮助控件的使用 {public partial class Form1 : Form{public Form1(…

mysql申请审核系统_Mysql审核工具archery

Mysql审核工具archery系统&#xff1a;Centos6.8ip:192.168.122.150安装Python和virtualenv编译安装[rootwww ~]# yum install wget gcc make zlib-devel openssl openssl-devel[rootwww src]# wget "https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tar.xz"[…

iOS——Core Animation 知识摘抄(二)

阴影 主要是shadowOpacity 、shadowColor、shadowOffset和shadowRadius四个属性 shadowPath属性 我们已经知道图层阴影并不总是方的&#xff0c;而是从图层内容的形状继承而来。这看上去不错&#xff0c;但是实时计算阴影也是一个非常消耗资源的&#xff0c;尤其是图层有多个子…

Blazor University (11)组件 — 替换子组件的属性

原文链接&#xff1a;https://blazor-university.com/components/replacing-attributes-on-child-components/替换子组件的属性源代码[1]到目前为止&#xff0c;我们已经了解了如何创建代码生成的属性[2]&#xff0c;以及如何捕获意外参数[3]。除了这两种技术之外&#xff0c;B…

HTTPS实现原理

HTTPS实现原理 HTTPS&#xff08;全称&#xff1a;Hypertext Transfer Protocol over Secure Socket Layer&#xff09;&#xff0c;是以安全为目标的HTTP通道&#xff0c;简单讲是HTTP的安全版。即HTTP下加入SSL层&#xff0c;HTTPS的安全基础是SSL。其所用的端口号是443。…

Android之在ubuntu上过滤多条关键字日志

1 问题 比如我们在查问题的时候,需要过滤多个关键字,我平时的做法是一个终端执行下面的命令,然后几个关键字就几个终端,切换来切换去不方便看日志 adb logcat | grep **** 2 改进办法 今天看到同事用了grep -E,我们可以通过-E这个参数过滤多个关键字,比如 adb logcat | gre…

C#使用ServiceController控制windows服务

C#中,使用ServiceController类控制windows服务,使用之前要先添加引用:System.ServiceProcess,然后在命名空间中引用:using System.ServiceProcess。下面举例获取本机的所有已安装的Windows服务和应用,然后查找某一应用活服务是否已经安装。 代码: using System; using S…