Load and run model predictions 加载和运行模型预测
Load the model 加载模型
在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU(),)def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits
加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict()
方法加载参数。
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True)(5): ReLU())
)
**注意:**请务必在推理之前调用
model.eval()
方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。
Model Inference 模型推理
优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。
ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。
Exporting the model to ONNX 将模型导出到 ONNX
PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape
函数中获取大小:tensor.shape
):
input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)
我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。
test_data = datasets.FashionMNIST(root = "data",train = False,download = True,transform = ToTensor()
)classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]x, y = test_data[0][0], test_data[0][1]
我们使用 onnxruntime.InferenceSession
创建推理会话。要推断 ONNX 模型,请调用 run
并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。
session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].nameresult = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
Predicted: "Ankle boot", Actual: Ankle boot
ONNX 模型使您能够在不同平台上以不同编程语言运行推理。
知识检查
什么是 PyTorch 模型 state_dict?
它是模型的内部状态字典,用于存储已学习的参数。