pytorch forward_【Pytorch部署】TorchScript

v2-3bc7b6600232fe459c94ca4e517dde98_1440w.jpg?source=172ae18b

TorchScript是什么?

TorchScript - PyTorch master documentation​pytorch.org

TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。

我们提供了一些工具来增量地将模型从纯Python程序转换为能够独立于Python运行的TorchScript程序,例如在独立的c++程序中。这使得使用熟悉的Python工具在PyTorch中训练模型,然后通过TorchScript将模型导出到生产环境中成为可能,在这种环境中,Python程序可能由于性能和多线程的原因不适用。

编写TorchScript代码

torch.jit.script(obj)

脚本化一个函数或者nn.Module对象,将会检查它的源代码, 将其作为TorchScript代码使用TorchScrit编译器编译它,返回一个ScriptModule或ScriptFunction。 TorchScript语言自身是Python语言的一个子类, 因此它并非具有所有的Python语言特性。 torch.jit.script能够被作为函数或装饰器使用。参数obj可以是class, function, nn.Module。

具体地,脚本化一个函数torch.jit.script 装饰器将会通过编译函数被装饰函数体来构造一个ScriptFunction对象。例如:

import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rprint(type(foo))  # torch.jit.ScriptFuncion# See the compiled graph as Python code
print(foo.code)

脚本化一个nn.Module:默认地编译其forward方法,并递归地编译其子模块以及被forward调用的函数。如果一个模块只使用TorchScript中支持的特性,则不需要更改原始模块代码。编译器将构建ScriptModule,其中包含原始模块的属性、参数和方法的副本。例如:

import torchclass MyModule(torch.nn.Module):def __init__(self, N, M):super(MyModule, self).__init__()# This parameter will be copied to the new ScriptModuleself.weight = torch.nn.Parameter(torch.rand(N, M))# When this submodule is used, it will be compiledself.linear = torch.nn.Linear(N, M)def forward(self, input):output = self.weight.mv(input)# This calls the `forward` method of the `nn.Linear` module, which will# cause the `self.linear` submodule to be compiled to a `ScriptModule` hereoutput = self.linear(output)return outputscripted_module = torch.jit.script(MyModule(2, 3))

编译一个不在forward中的方法以及递归地编译其内的所有方法,可在此方法上使用装饰器torch.jit.export为了忽视某些方法也可以使用装饰器为了忽视某些方法也可以使用装饰器torch.jit.ignoretorch.jit.unused

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()@torch.jit.exportdef some_entry_point(self, input):return input + 10@torch.jit.ignoredef python_only_fn(self, input):# This function won't be compiled, so any# Python APIs can be usedimport pdbpdb.set_trace()def forward(self, input):if self.training:self.python_only_fn(input)return input * 99scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2))) 

torch.jit.trace(func,example_inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)

跟踪一个函数并返回一个可执行的或ScriptFunction对象,将使用即时编译(JIT)进行优化。跟踪非常适合那些只操作单张量或张量的列表、字典和元组的代码。使用torch.jit.tracetorch.jit.trace_module ,你能将一个模型或python函数转为TorchScript中的ScriptModuleScriptFunction。根据你提供的输入样例,它将会运行 该函数并记录所有张量上执行的操作。

Tracing 仅仅正确地记录那些不是数据依赖的函数和nn.Module(例如没有对数据的条件判断) 并且它们也没有任何未跟踪的外部依赖(例如执行输入输出或访问全局变量). Tracing 只记录在给定张量上运行给定函数时所执行的操作。 因此,返回的ScriptModule将始终在任何输入上运行相同的跟踪图。当你的模块需要根据输入和/或模块状态运行不同的操作集时,这就产生了一些重要的影响。例如:

  • Tracing不会记录任何类似if语句或循环的控制流。当这个控制流在您的模块中是常量时,这是没有问题的,并且它通常内联了控制流决策。但有时控制流实际上是模型本身的一部分。例如,一个递归网络是一个输入序列长度(可能是动态的)的循环。
  • 在返回的ScriptModule中,无论ScriptModule处于哪种模式,在train和eval模式中具有不同行为的操作都将始终表现为处于跟踪时所处的模式。

在这种情况下,Trace是不合适的,Script是更好的选择。如果你跟踪这样的模型,您可能会在后续的模型调用中得到不正确的结果。当执行可能导致产生错误跟踪的操作时,跟踪程序将尝试发出警告。

tracing a function:

import torchdef foo(x, y):return 2 * x + y# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

tracing a existing module

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Conv2d(1, 1, 3)def forward(self, x):return self.conv(x)n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

torch.jit.trace_module(mod,inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)

跟踪一个模块并返回一个可执行的ScriptModule,该脚本模块将使用即时编译进行优化。当一个模块被传递到torch.jit.trace,只运行和跟踪forward方法。使用trace_module,您可以为要跟踪的示例输入指定一个方法名字典(参见下面的example_input参数)。

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Conv2d(1, 1, 3)def forward(self, x):return self.conv(x)def weighted_kernel_sum(self, weight):return weight * self.conv.weightn = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

class torch.jit.ScriptModule

ScriptModule 封装一个c++接口中的torch::jit::Module类, 有下列属性及方法:

  • code 返回forward方法的内部图的打印表示(具有有效的Python语法)
  • graph返回forward方法的内部图的字符串表示形式
  • inlined_graph返回forward方法的内部图的字符串表示形式。此图将被预处理为内联所有函数和方法调用。
  • save(f,_extra_files=ExtraFilesMap{})

class torch.jit.ScriptFunction 与上者类似

torch.jit.save(m,f,_extra_files=ExtraFilesMap{})

保存此模块的脱机版本,以便在单独的进程中使用。所保存的模块序列化此模块的所有方法、子模块、参数和属性。它可以使用torch::jit::load(文件名)加载到c++ API中,也可以使用torch.jit.load加载到Python API中。为了能够保存模块,它必须不调用任何本机Python函数。这意味着所有子模块也必须是ScriptModule的子类。所有模块,不管它们的设备是什么,总是在加载过程中加载到CPU上。这与torch.load()的语义不同,将来可能会改变。

import torch
import ioclass MyModule(torch.nn.Module):def forward(self, x):return x + 10m = torch.jit.script(MyModule())# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)# Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

torch.jit.load(f,map_location=None,_extra_files=ExtraFilesMap{})

加载先前用torch.jit.save保存的ScriptModule或ScriptFunction所有之前保存的模块,无论它们的设备是什么,都首先加载到CPU上,然后移动到它们保存的设备上。如果失败(例如,因为运行时系统没有特定的设备),就会引发异常。

import torch
import iotorch.jit.load('scriptmodule.pt')# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:buffer = io.BytesIO(f.read())# Load all tensors to the original device
torch.jit.load(buffer)# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')# Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])

torch.jit.ignore(drop=False, **kwargs)

这个装饰器向编译器表明,一个函数或方法应该被忽略,并保留为Python函数。这允许您在模型中保留尚未与TorchScript兼容的代码。如果从TorchScript调用,被忽略的函数将把调用分派给Python解释器。函数被忽略的模型不能导出。使用drop=True参数时可以,但会抛出异常。最好使用torch.jit.unused

import torch
import torch.nn as nnclass MyModule(nn.Module):@torch.jit.ignoredef debugger(self, x):import pdbpdb.set_trace()def forward(self, x):x += 10# The compiler would normally try to compile `debugger`,# but since it is `@ignore`d, it will be left as a call# to Pythonself.debugger(x)return xm = torch.jit.script(MyModule())# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")

使用torch.jit.ignore(drop=True), 这一方法已被torch.jit.unused替代。

import torch
import torch.nn as nnclass MyModule(nn.Module):@torch.jit.ignore(drop=True)def training_method(self, x):import pdbpdb.set_trace()def forward(self, x):if self.training:self.training_method(x)return xm = torch.jit.script(MyModule())# This is OK since `training_method` is not saved, the call is replaced
# with a `raise`.
m.save("m.pt")

torch.jit.unused(fn)

这个装饰器向编译器表明,应该忽略一个函数或方法,并用引发异常来替换它。这允许您在模型中保留与TorchScript不兼容的代码,同时仍然导出模型。

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self, use_memory_efficent):super(MyModule, self).__init__()self.use_memory_efficent = use_memory_efficent@torch.jit.unuseddef memory_efficient(self, x):import pdbpdb.set_trace()return x + 10def forward(self, x):# Use not-yet-scriptable memory efficient modeif self.use_memory_efficient:return self.memory_efficient(x)else:return x + 10m = torch.jit.script(MyModule(use_memory_efficent=False))
m.save("m.pt")m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))

混合Tracing和Scripting

在许多情况下,跟踪或脚本是将模型转换为TorchScript的一种更简单的方法。可以编写跟踪和脚本来满足模型某一部分的特定需求。

脚本函数可以调用跟踪函数。当您需要围绕一个简单的前馈模型使用控制流时,这一点特别有用。例如,序列到序列模型的波束搜索通常用脚本编写,但可以调用使用跟踪生成的编码器模块。

例如在脚本中调用跟踪函数

import torchdef foo(x, y):return 2 * x + ytraced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))@torch.jit.script
def bar(x):return traced_foo(x, x)

跟踪函数也可以调用脚本函数。当模型的一小部分需要一些控制流时,这是很有用的,即使大部分模型只是一个前馈网络。由跟踪函数调用的脚本函数中的控制流被正确保存。

例如在跟踪函数中调用脚本函数

import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rdef bar(x, y, z):return foo(x, y) + ztraced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

这个组合也适用于nn.Module。

import torch
import torchvisionclass MyScriptModule(torch.nn.Module):def __init__(self):super(MyScriptModule, self).__init__()self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1))self.resnet = torch.jit.trace(torchvision.models.resnet18(),torch.rand(1, 3, 224, 224))def forward(self, input):return self.resnet(input - self.means)my_script_module = torch.jit.script(MyScriptModule())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/289101.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

兼容ie8 rgba()用法

今天遇到了一个问题,要在一个页面中设置一个半透明的白色div。这个貌似不是难题,只需要给这个div设置如下的属性即可: background: rgba(255,255,255,.1); 但是要兼容到ie8。这个就有点蛋疼了。因为ie8不支持rgba()函数。下面我们总结一下rgb…

python中的标识符能不能使用关键字_Python中的标识符不能使用关键字

Python中的标识符不能使用关键字答:√智慧职教: 检查客室座椅外观良好,确认?无破损答:坐垫 靠背关于投标报价时综合单价的确定,下列做法中正确的是()答:以项目特征描述为依据确定综合单价城市总体规划调查时&#xff…

C# WPF实战项目升级了

概述之前用Caliburn.Micro搭建的WPF实战项目,CM框架选用了 3.0.3,实际上CM框架目前最新版已经到4.0。173了,所有很有必须升级一下项目了. 本来打算把平台框架也直接升级到.NET 6 的,但是项目里面很多库不支持最新的平台版本&#…

ArcGIS Engine开发模板及C#代码

目 录 1. 模板 2. 代码 1. 模板 以下为AE开发软件自带的模板及代码,开发工具为VS 2012+ArcGIS Engine 10.2。 2. 代码 using System; using System.Drawing; using System.Collections; using System.ComponentModel; using System.Windows.Forms; using System.Data; us…

高级iOS面试题

非标准答案 2 1: 类方法是可以直接通过类名直接调用,无需进行实例化对象。类方法是以开头2. 实例方法,需要显示实例化对象,为对象分配堆栈空间,并通过对象实例调用实例方法3. RUNTIME 是在程序运行过程动态对实例对象进行操作&…

dotTrace 6.1帮你理解SQL查询如何影响应用性能

dotTrace是JetBrains公司旗下的一款.NET应用程序性能瓶颈检测工具。该工具是ReSharper旗舰版的一部分,也可以单独安装。近日,dotTrace 6.1发布,主要增加了人们期待已久的SQL查询性能分析,开发人员可以通过它获得特定查询的执行时间…

mysql申请审核系统_Mysql审核工具archery

Mysql审核工具archery系统:Centos6.8ip:192.168.122.150安装Python和virtualenv编译安装[rootwww ~]# yum install wget gcc make zlib-devel openssl openssl-devel[rootwww src]# wget "https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tar.xz"[…

iOS——Core Animation 知识摘抄(二)

阴影 主要是shadowOpacity 、shadowColor、shadowOffset和shadowRadius四个属性 shadowPath属性 我们已经知道图层阴影并不总是方的,而是从图层内容的形状继承而来。这看上去不错,但是实时计算阴影也是一个非常消耗资源的,尤其是图层有多个子…

Blazor University (11)组件 — 替换子组件的属性

原文链接:https://blazor-university.com/components/replacing-attributes-on-child-components/替换子组件的属性源代码[1]到目前为止,我们已经了解了如何创建代码生成的属性[2],以及如何捕获意外参数[3]。除了这两种技术之外,B…

HTTPS实现原理

HTTPS实现原理 HTTPS(全称:Hypertext Transfer Protocol over Secure Socket Layer),是以安全为目标的HTTP通道,简单讲是HTTP的安全版。即HTTP下加入SSL层,HTTPS的安全基础是SSL。其所用的端口号是443。…

C#使用ServiceController控制windows服务

C#中,使用ServiceController类控制windows服务,使用之前要先添加引用:System.ServiceProcess,然后在命名空间中引用:using System.ServiceProcess。下面举例获取本机的所有已安装的Windows服务和应用,然后查找某一应用活服务是否已经安装。 代码: using System; using S…

电信aep平台是什么意思_江苏天鼎证券:股票平台跳水是什么意思?股票为什么会跳水?...

相信很多新手在刚玩股票的时候会遇到很多的专业的基础知识不能理解,比如什么是跳水?为什么会跳水呢?接下来就为大家详细来说股票的跳水以及为何会跳水。一、股票平台跳水是什么意思?股票跳水通常指股价在较短的时间内,出现从高位下降到低位的现象。出…

mysql mgr简介_MySQL Group Replication(MGR)使用简介与注意事项

MySQL Group Replication(MGR)是MySQL官方在5.7.17版本引进的一个数据库高可用与高扩展的解决方案,以插件形式提供。MGR基于分布式paxos协议,实现组复制,保证数据一致性。内置故障检测和自动选主功能,只要不是集群中的大多数节点都…

python beautifulsoup4 table tr_python BeautifulSoup解析表

牧羊人nacy这是通用的工作示例(表数据)标记。它返回带有内部列的行的列表。第一行仅接受一个(表头/数据)。def tableDataText(table): rows [] trs table.find_all(tr) headerow [td.get_text(stripTrue) for td in trs[0].find_all(th)] # header row i…

clob字段怎么导出_Oracle 11g及12c+版本下为啥有些表不能exp导出?

【引言】今天有同事问了一个问题,在Oracle 11g下,为啥exp方式导出一个用户的数据表,在imp后却发现有些表并没有迁移过来。经查阅官方文档,发现和Oracle11g及12c 版本相对于10g,有一个新特性deferred_segment_creation(…

C# 读写二进制文件

读写二进制文件的一种选择是直接使用流类型;在这种情况下,最好使用字节数组执行读写操作。另一个选择是使用为这个场景定义的读取器和写入器:BinaryReader和BinaryWriter。使用它们的方式类似于使用 StreamReader 和 StreamWriter&#xff0c…

推荐系统(1)--splitting approaches for context-aware recommendation

开篇语: 大一的时候。在实验室老师和师兄的带领下。我開始接触推荐系统。时光匆匆,转眼已是大三,因为大三课甚是少。于是便有了时间将自己所学的东西做下总结。第一篇博客。献给过去三年里带我飞的老师和师兄们,感谢你们的无私帮助…

python 百度云文字识别 proxy_python使用百度文字识别功能方法详解

介绍python使用百度智能去的文字识别功能,可以识别截图中的文,登陆路验证码等等。, 登陆百度智能云,选择产品服务。选择“人工智能”---文字识别。点击创建应用。 如图下面有关于“文字识别”的各类信息,如通用文字识别…

Android性能优化典范(转)

本文转自:http://hukai.me/android-performance-patterns/ 2015新年伊始,Google发布了关于Android性能优化典范的专题,一共16个短视频,每个3-5分钟,帮助开发者创建更快更优秀的Android App。课程专题不仅仅介绍了Andr…

Xamarin效果第二十一篇之GIS中可扩展浮动操作按钮

在前面文章中简单玩了玩GIS的基本操作、Mark相关、AR、测距和加载三维白模,今天再次对操作栏又一次修改了,直接放到了右下角可伸缩效果;啥也不说了都在效果里:添加支持圆角 ContentView:Xamarin.Forms.PancakeView再来Xamarin 社区工具包:Xamarin.CommunityToolkit再来看看最终…