一、定义
- 定义
- 接口介绍
- 案例
二、实现
-
定义
- torch.compile 是加速 PyTorch 代码的最新方法! torch.compile 通过 JIT 将 PyTorch 代码编译成优化的内核,使 PyTorch 代码运行得更快,大部分过程仅需修改一行代码。
- torch.compile 的一个重要组件就是 TorchDynamo。TorchDynamo 负责将任意 Python 代码即时编译成 FX Graph(计算图),然后可以进一步优化。TorchDynamo 通过在运行时分析 Python 字节码并检测对 PyTorch 操作的调用来提取 FX Graph。
- torch.compile 的另一个重要组件 TorchInductor 会将 FX Graph 进一步编译成优化的内核。TorchDynamo 允许使用不同的后端,所以为了检查 TorchDynamo 输出的 FX Graph,可以创建一个自定义后端来输出 FX Graph 并简单地返回 Graph 未优化的前向内容。
- 允许自定义函数
开始编译的时候需要耗费大量的时间,即第一次请求,时间较长。
5. 详情见: https://pytorch.org/docs/stable/torch.compiler.html
https://pytorch.org/get-started/pytorch-2.0/
-
接口介绍
modoel_compile = torch.compile(model, mode="reduce-overhead")
(默认)default: 适合加速大模型,编译速度快且无需额外存储空间
reduce-overhead:适合加速小模型,需要额外存储空间
max-autotune:编译速度非常耗时,但提供最快的加速
- 案例
import torch
def foo(x, y):a = torch.sin(x)b = torch.cos(x)return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
#方式二
@torch.compile
def opt_foo2(x, y):a = torch.sin(x)b = torch.cos(x)return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
方式三
class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.lin = torch.nn.Linear(100, 10)def forward(self, x):return torch.nn.functional.relu(self.lin(x))
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
训练
import torch
import torchvision.models as modelsmodel = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()
保存:
torch.save(optimized_model.state_dict(), "foo.pt")
# both these lines of code do the same thing
torch.save(model.state_dict(), "foo.pt")
推理:
# API Not Final
exported_model = torch._dynamo.export(model, input)
torch.save(exported_model, "foo.pt")