TorchScript是什么?
TorchScript - PyTorch master documentationpytorch.orgTorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。
我们提供了一些工具来增量地将模型从纯Python程序转换为能够独立于Python运行的TorchScript程序,例如在独立的c++程序中。这使得使用熟悉的Python工具在PyTorch中训练模型,然后通过TorchScript将模型导出到生产环境中成为可能,在这种环境中,Python程序可能由于性能和多线程的原因不适用。
编写TorchScript代码
torch.jit.script(obj)
脚本化一个函数或者nn.Module对象,将会检查它的源代码, 将其作为TorchScript代码使用TorchScrit编译器编译它,返回一个ScriptModule或ScriptFunction。 TorchScript语言自身是Python语言的一个子类, 因此它并非具有所有的Python语言特性。 torch.jit.script能够被作为函数或装饰器使用。参数obj可以是class, function, nn.Module。
具体地,脚本化一个函数: torch.jit.script
装饰器将会通过编译函数被装饰函数体来构造一个ScriptFunction对象。例如:
import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rprint(type(foo)) # torch.jit.ScriptFuncion# See the compiled graph as Python code
print(foo.code)
脚本化一个nn.Module:默认地编译其forward方法,并递归地编译其子模块以及被forward调用的函数。如果一个模块只使用TorchScript中支持的特性,则不需要更改原始模块代码。编译器将构建ScriptModule,其中包含原始模块的属性、参数和方法的副本。例如:
import torchclass MyModule(torch.nn.Module):def __init__(self, N, M):super(MyModule, self).__init__()# This parameter will be copied to the new ScriptModuleself.weight = torch.nn.Parameter(torch.rand(N, M))# When this submodule is used, it will be compiledself.linear = torch.nn.Linear(N, M)def forward(self, input):output = self.weight.mv(input)# This calls the `forward` method of the `nn.Linear` module, which will# cause the `self.linear` submodule to be compiled to a `ScriptModule` hereoutput = self.linear(output)return outputscripted_module = torch.jit.script(MyModule(2, 3))
编译一个不在forward中的方法以及递归地编译其内的所有方法,可在此方法上使用装饰器torch.jit.export
为了忽视某些方法也可以使用装饰器为了忽视某些方法也可以使用装饰器torch.jit.ignore
和torch.jit.unused
import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()@torch.jit.exportdef some_entry_point(self, input):return input + 10@torch.jit.ignoredef python_only_fn(self, input):# This function won't be compiled, so any# Python APIs can be usedimport pdbpdb.set_trace()def forward(self, input):if self.training:self.python_only_fn(input)return input * 99scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))
torch.jit.trace(func,example_inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)
跟踪一个函数并返回一个可执行的或ScriptFunction对象,将使用即时编译(JIT)进行优化。跟踪非常适合那些只操作单张量或张量的列表、字典和元组的代码。使用torch.jit.trace
和torch.jit.trace_module
,你能将一个模型或python函数转为TorchScript中的ScriptModule
或ScriptFunction
。根据你提供的输入样例,它将会运行 该函数并记录所有张量上执行的操作。
Tracing 仅仅正确地记录那些不是数据依赖的函数和nn.Module(例如没有对数据的条件判断) 并且它们也没有任何未跟踪的外部依赖(例如执行输入输出或访问全局变量). Tracing 只记录在给定张量上运行给定函数时所执行的操作。 因此,返回的ScriptModule将始终在任何输入上运行相同的跟踪图。当你的模块需要根据输入和/或模块状态运行不同的操作集时,这就产生了一些重要的影响。例如:
- Tracing不会记录任何类似if语句或循环的控制流。当这个控制流在您的模块中是常量时,这是没有问题的,并且它通常内联了控制流决策。但有时控制流实际上是模型本身的一部分。例如,一个递归网络是一个输入序列长度(可能是动态的)的循环。
- 在返回的ScriptModule中,无论ScriptModule处于哪种模式,在train和eval模式中具有不同行为的操作都将始终表现为处于跟踪时所处的模式。
在这种情况下,Trace是不合适的,Script是更好的选择。如果你跟踪这样的模型,您可能会在后续的模型调用中得到不正确的结果。当执行可能导致产生错误跟踪的操作时,跟踪程序将尝试发出警告。
tracing a function:
import torchdef foo(x, y):return 2 * x + y# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
tracing a existing module
import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Conv2d(1, 1, 3)def forward(self, x):return self.conv(x)n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)
torch.jit.trace_module(mod,inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)
跟踪一个模块并返回一个可执行的ScriptModule,该脚本模块将使用即时编译进行优化。当一个模块被传递到torch.jit.trace
,只运行和跟踪forward方法。使用trace_module,您可以为要跟踪的示例输入指定一个方法名字典(参见下面的example_input参数)。
import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Conv2d(1, 1, 3)def forward(self, x):return self.conv(x)def weighted_kernel_sum(self, weight):return weight * self.conv.weightn = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)
class torch.jit.ScriptModule
ScriptModule 封装一个c++接口中的torch::jit::Module
类, 有下列属性及方法:
code
返回forward方法的内部图的打印表示(具有有效的Python语法)graph
返回forward方法的内部图的字符串表示形式inlined_graph
返回forward方法的内部图的字符串表示形式。此图将被预处理为内联所有函数和方法调用。save(f,_extra_files=ExtraFilesMap{})
class torch.jit.ScriptFunction 与上者类似
torch.jit.save(m,f,_extra_files=ExtraFilesMap{})
保存此模块的脱机版本,以便在单独的进程中使用。所保存的模块序列化此模块的所有方法、子模块、参数和属性。它可以使用torch::jit::load(文件名)加载到c++ API中,也可以使用torch.jit.load加载到Python API中。为了能够保存模块,它必须不调用任何本机Python函数。这意味着所有子模块也必须是ScriptModule的子类。所有模块,不管它们的设备是什么,总是在加载过程中加载到CPU上。这与torch.load()的语义不同,将来可能会改变。
import torch
import ioclass MyModule(torch.nn.Module):def forward(self, x):return x + 10m = torch.jit.script(MyModule())# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)# Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
torch.jit.load(f,map_location=None,_extra_files=ExtraFilesMap{})
加载先前用torch.jit.save保存的ScriptModule或ScriptFunction所有之前保存的模块,无论它们的设备是什么,都首先加载到CPU上,然后移动到它们保存的设备上。如果失败(例如,因为运行时系统没有特定的设备),就会引发异常。
import torch
import iotorch.jit.load('scriptmodule.pt')# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:buffer = io.BytesIO(f.read())# Load all tensors to the original device
torch.jit.load(buffer)# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')# Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])
torch.jit.ignore(drop=False, **kwargs)
这个装饰器向编译器表明,一个函数或方法应该被忽略,并保留为Python函数。这允许您在模型中保留尚未与TorchScript兼容的代码。如果从TorchScript调用,被忽略的函数将把调用分派给Python解释器。函数被忽略的模型不能导出。使用drop=True参数时可以,但会抛出异常。最好使用torch.jit.unused
import torch
import torch.nn as nnclass MyModule(nn.Module):@torch.jit.ignoredef debugger(self, x):import pdbpdb.set_trace()def forward(self, x):x += 10# The compiler would normally try to compile `debugger`,# but since it is `@ignore`d, it will be left as a call# to Pythonself.debugger(x)return xm = torch.jit.script(MyModule())# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")
使用torch.jit.ignore(drop=True), 这一方法已被torch.jit.unused替代。
import torch
import torch.nn as nnclass MyModule(nn.Module):@torch.jit.ignore(drop=True)def training_method(self, x):import pdbpdb.set_trace()def forward(self, x):if self.training:self.training_method(x)return xm = torch.jit.script(MyModule())# This is OK since `training_method` is not saved, the call is replaced
# with a `raise`.
m.save("m.pt")
torch.jit.unused(fn)
这个装饰器向编译器表明,应该忽略一个函数或方法,并用引发异常来替换它。这允许您在模型中保留与TorchScript不兼容的代码,同时仍然导出模型。
import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self, use_memory_efficent):super(MyModule, self).__init__()self.use_memory_efficent = use_memory_efficent@torch.jit.unuseddef memory_efficient(self, x):import pdbpdb.set_trace()return x + 10def forward(self, x):# Use not-yet-scriptable memory efficient modeif self.use_memory_efficient:return self.memory_efficient(x)else:return x + 10m = torch.jit.script(MyModule(use_memory_efficent=False))
m.save("m.pt")m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))
混合Tracing和Scripting
在许多情况下,跟踪或脚本是将模型转换为TorchScript的一种更简单的方法。可以编写跟踪和脚本来满足模型某一部分的特定需求。
脚本函数可以调用跟踪函数。当您需要围绕一个简单的前馈模型使用控制流时,这一点特别有用。例如,序列到序列模型的波束搜索通常用脚本编写,但可以调用使用跟踪生成的编码器模块。
例如在脚本中调用跟踪函数
import torchdef foo(x, y):return 2 * x + ytraced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))@torch.jit.script
def bar(x):return traced_foo(x, x)
跟踪函数也可以调用脚本函数。当模型的一小部分需要一些控制流时,这是很有用的,即使大部分模型只是一个前馈网络。由跟踪函数调用的脚本函数中的控制流被正确保存。
例如在跟踪函数中调用脚本函数
import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rdef bar(x, y, z):return foo(x, y) + ztraced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
这个组合也适用于nn.Module。
import torch
import torchvisionclass MyScriptModule(torch.nn.Module):def __init__(self):super(MyScriptModule, self).__init__()self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1))self.resnet = torch.jit.trace(torchvision.models.resnet18(),torch.rand(1, 3, 224, 224))def forward(self, input):return self.resnet(input - self.means)my_script_module = torch.jit.script(MyScriptModule())