一、定义
- 定义
- script、eager、onnx 模式对比
- 案例
- 生成的模型可以被c++调用
- 接口解读
二、实现
- 定义
- 可以在高性能环境libtorch(C ++)中直接加载,实现模型推理,而无需Pytorch训练框架依赖
- 无需代码,直接加载模型,实现推理。
- 主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。
教程网址:https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
接口网址:https://pytorch.org/docs/stable/jit.html
- script、eager、onnx 模式对比
- 案例
模型转为脚本:
import torch
from torch import nn
from torchvision.models.resnet import resnet18
model = resnet18(pretrained=True)
# 修改模型
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity() # type: ignore
model.fc = nn.Linear(model.fc.in_features, 10)
model.eval()script_module = torch.jit.trace(model, example_inputs=torch.randn([1, 3, 224, 224]))
torch.jit.save(script_module, "quant_model.pth")
脚本文件加载、推理
import torch
#推理时加载模型
quantized_recover_model = torch.jit.load("quant_model.pth")
with torch.no_grad(): # 设置禁止计算梯度inputs = torch.randn([1, 3, 224, 224])outputs = quantized_recover_model(inputs) # 前向传播print(outputs)
4.生成的模型可以被c++调用
c++代码// 加载生成的torchscript模型
auto module = torch::jit::load('jit_model.pth');
// 根据任务需求读取数据
std::vector<torch::jit::IValue> inputs = ...;
// 计算推理结果
auto output = module.forward(inputs).toTensor();
- 接口解读