PyTorch Day 10: Model Deployment and Inference

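Model Deployment & Inference

  • Model deployment
  • Model inference

We will convert a model trained in PyTorch to the ONNX format and then run inference on it with ONNX Runtime.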

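1. ONNX

ONNX (Open Neural Network Exchange) is a standard format for describing computation graphs, released jointly by Facebook (now Meta) and Microsoft in 2017. By defining a set of standard formats that are independent of environment and platform, ONNX lets AI models be exchanged between frameworks and runtimes. It can be seen as a bridge between deep learning frameworks and the deployment target, much like the intermediate language of a compiler.

Because frameworks differ in how well they support the format, ONNX is usually used only to represent static graphs, which are easier to deploy. Hardware and software vendors only need to optimize model performance against the ONNX standard, and every framework compatible with that standard benefits.

ONNX focuses mainly on model inference: a model trained in any supported framework can, once converted to ONNX, be deployed easily to any ONNX-compatible runtime.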

  • ONNX website: https://onnx.ai/
  • ONNX GitHub: https://github.com/onnx/onnx


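2. ONNX Runtime

  • ONNX Runtime website: https://www.onnxruntime.ai/
  • ONNX Runtime GitHub: https://github.com/microsoft/onnxruntime

ONNX Runtime is a cross-platform machine learning inference accelerator maintained by Microsoft. It works with ONNX directly: it reads .onnx files and runs inference on them without converting them to any other format.

With ONNX Runtime, PyTorch also covers the last mile of deployment, forming the pipeline PyTorch --> ONNX --> ONNX Runtime.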

Install ONNX

pip install onnx

Install ONNX Runtime

pip install onnxruntime # for CPU inference

pip install onnxruntime-gpu # for GPU inference
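Which package is installed determines which execution providers are available at run time. The following is a minimal sketch of picking providers in order of preference; choose_providers is a hypothetical helper written for this example, and the fallback list used when onnxruntime is missing is an assumption that only keeps the sketch runnable.

```python
def choose_providers(available, preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Return the preferred providers that are actually available, in preference order."""
    chosen = [p for p in preferred if p in available]
    if not chosen:
        raise RuntimeError(f"none of {preferred} found in {available}")
    return chosen

try:
    import onnxruntime
    # Lists the providers compiled into the installed onnxruntime build.
    available = onnxruntime.get_available_providers()
except ImportError:
    # Assumed fallback so the sketch still runs where onnxruntime is not installed.
    available = ["CPUExecutionProvider"]

providers = choose_providers(available)
print(providers)
```

The resulting list can be passed as the providers argument of onnxruntime.InferenceSession, so the same script runs with either the CPU or the GPU package.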

Note: the ONNX and ONNX Runtime versions must be compatible with each other. The compatibility table can be found in the ONNX Runtime GitHub repository.

URL: https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
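Before consulting the compatibility table it helps to know exactly which versions are installed. A small sketch using only the standard library follows; the package list is an assumption and can be adjusted to the environment.

```python
from importlib import metadata

def installed_versions(packages=("onnx", "onnxruntime", "onnxruntime-gpu", "torch")):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

report = installed_versions()
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

The printed versions can then be compared by hand against the tables linked in this section.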


Compatibility between ONNX Runtime and CUDA

URL: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html


[Figure: compatibility between ONNX Runtime, TensorRT, and CUDA versions]

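3. Converting the model to ONNX format

  • torch.onnx.export() is the function that converts a model to the ONNX format
  • Before exporting a model to ONNX, we must call model.eval() or model.train(False) so that the model is in inference mode, because layers such as dropout and batch normalization behave differently during training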
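import torch
import torch.onnx
import torchvision
# Name of the exported file; the extension must be .onnx
onnx_file_name = "resnet50.onnx"
# The model to convert; replace this with your own model
model = torchvision.models.resnet50()
# Load the weights (replace "resnet50.pt" with your own file); load_state_dict() updates the model in place
model.load_state_dict(torch.load("resnet50.pt"))
# Before exporting, call model.eval() or model.train(False)
model.eval()
# dummy_input is a sample input that only supplies the input shape, dtype, and so on
batch_size = 1  # arbitrary value; it matters little once dynamic_axes is set
dummy_input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
# The model output for this input
output = model(dummy_input)
# Export the model
torch.onnx.export(
    model,                      # the model to export
    dummy_input,                # a sample input
    onnx_file_name,             # path/name of the output file
    export_params=True,         # True (the default) also exports the parameters; set to False to export an untrained model
    opset_version=10,           # ONNX operator set version (15 was the latest when this was written)
    do_constant_folding=True,   # whether to apply constant-folding optimization
    input_names=['conv1'],      # names of the model's input tensors
    output_names=['fc'],        # names of the model's output tensors
    # dynamic_axes marks the batch_size dimension as dynamic, so inference data
    # may use a batch size different from that of dummy_input
    dynamic_axes={'conv1': {0: 'batch_size'}, 'fc': {0: 'batch_size'}},
)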

Note:
Operator version reference: https://github.com/onnx/onnx/blob/main/docs/Operators.md

Validating the ONNX model

We need to check that the exported model file is usable, which we do with onnx.checker.check_model():

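import onnx
# Load the exported file (onnx_file_name is the path used in the export step)
onnx_model = onnx.load(onnx_file_name)
# Validate inside try/except
try:
    # An exception is raised if the model is invalid
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # No exception means the model is valid, and "The model is valid!" is printed
    print("The model is valid!")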

Visualizing the ONNX model

Use Netron for visualization. Download: https://netron.app/


[Figure: the model's input & output information]

Running inference with ONNX Runtime


import onnxruntime
# Name of the ONNX model file to run
onnx_file_name = "xxxxxx.onnx"
# onnxruntime.InferenceSession creates an ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])
# Build the input as a dict whose keys match the input_names used when exporting the ONNX model
# input_img must also be converted to an ndarray
# ort_inputs = {'conv1': input_img}
# The approach below is recommended because it avoids typing the key by hand
ort_inputs = {ort_session.get_inputs()[0].name: input_img}
# run() performs inference; the first argument is a list of output tensor names and can usually be None
# The second argument is the input dict built above
# The result is wrapped in a list, so we index it with [0]
ort_output = ort_session.run(None, ort_inputs)[0]
# output_name = ort_session.get_outputs()[0].name
# ort_output = ort_session.run([output_name], ort_inputs)[0]

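Notes:

  • A PyTorch model takes tensors as input, while ONNX Runtime takes NumPy arrays, so we must convert the tensor or read the data directly as an array
  • The shape of the input array must match the shape of the dummy_input used for export, except along axes declared in dynamic_axes; if the image size differs, resize it first
  • run() returns a list, so we must index into it to get the result as an array
  • When building the input dict, the keys must match the input_names set when exporting to ONNX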
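The first two notes can be sketched with NumPy alone. The helper below, to_model_input, is a hypothetical name; it assumes the image has already been decoded and resized to 224x224 RGB, and it uses the ImageNet mean and standard deviation that the complete code in this post passes to transforms.Normalize.

```python
import numpy as np

# ImageNet statistics, one value per RGB channel.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_model_input(image_hwc_uint8):
    """Convert an HxWx3 uint8 RGB image into a 1x3xHxW float32 array."""
    scaled = image_hwc_uint8.astype(np.float32) / 255.0   # [0, 255] -> [0, 1]
    normalized = (scaled - MEAN) / STD                    # per-channel normalization
    chw = np.transpose(normalized, (2, 0, 1))             # HWC -> CHW
    batched = np.expand_dims(chw, axis=0)                 # add the batch axis
    return np.ascontiguousarray(batched)                  # contiguous memory layout

# Random stand-in for a decoded, already resized 224x224 image.
fake_image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
input_img = to_model_input(fake_image)
print(input_img.shape, input_img.dtype)
```

An array prepared this way has the same layout and dtype as the dummy_input used for export, so it can be used as the value in the input dict.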

Complete code

1. Install & download

#!pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
# Download ImageNet labels
#!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

2. Define the model

import torch
import io
import time
from PIL import Image
import torchvision.transforms as transforms
from torchvision import datasets
import onnx
import onnxruntime
import torchvision
import numpy as np
from torch import nn
import torch.nn.init as init
onnx_file = 'resnet50.onnx'
save_dir = './resnet50.pt'

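# Download the pretrained model (newer torchvision releases prefer the weights= argument, as the warning in the output says)
Resnet50 = torchvision.models.resnet50(pretrained=True)
# Save the model weights
torch.save(Resnet50.state_dict(), save_dir)
print(Resnet50)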
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

3. Export the model to ONNX format


batch_size = 1    # just a random number
# First build the model structure
loaded_model = torchvision.models.resnet50()
# Then load the model weights
loaded_model.load_state_dict(torch.load(save_dir))
# Single-GPU case
# loaded_model.cuda()
# Put the model in inference mode
loaded_model.eval()
# Input to the model
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
torch_out = loaded_model(x)
torch_out
tensor([[-5.8050e-01,  7.5065e-02,  1.9404e-01, -9.1107e-01,  9.9716e-01,
         -1.2941e+00, -1.3402e-01, -6.4496e-01,  6.0434e-01, -1.6355e+00,
         ...
         -1.4367e+00, -1.5546e+00, -2.9878e+00, -3.5215e+00, -4.2169e+00,
         -3.7416e+00, -2.0244e+00, -2.6461e+00, -1.1108e+00,  1.1864e+00]],
       grad_fn=<AddmmBackward0>)
torch_out.size()
torch.Size([1, 1000])

# Export the model
torch.onnx.export(
    loaded_model,               # model being run
    x,                          # model input (or a tuple for multiple inputs)
    onnx_file,                  # where to save the model (can be a file or file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=10,           # the ONNX version to export the model to
    do_constant_folding=True,   # whether to execute constant folding for optimization
    input_names=['conv1'],      # the model's input names
    output_names=['fc'],        # the model's output names
    # variable length axes
    dynamic_axes={'conv1': {0: 'batch_size'}, 'fc': {0: 'batch_size'}},
)
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

4. Validate the ONNX model

# Validate inside try/except
try:
    # An exception is raised if the model is invalid
    onnx.checker.check_model(onnx_file)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # No exception means the model is valid, and "The model is valid!" is printed
    print("The model is valid!")
The model is valid!

5. Run inference with ONNX Runtime

import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
# Convert a tensor to an ndarray
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# Build the input dict and compute the output
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
# Compare the results from PyTorch and ONNX Runtime
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!
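The comparison passes when every element satisfies |actual - desired| <= atol + rtol * |desired|, which is how np.testing.assert_allclose combines the two tolerances. The sketch below illustrates this with made-up numbers standing in for the PyTorch and ONNX Runtime outputs; within_tolerance is a hypothetical helper.

```python
import numpy as np

def within_tolerance(actual, desired, rtol=1e-03, atol=1e-05):
    """Return True if assert_allclose accepts the pair, False if it raises."""
    try:
        np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)
    except AssertionError:
        return False
    return True

# Made-up reference values standing in for one model's output.
reference = np.array([1.0, -2.0, 100.0], dtype=np.float32)
# Small deviations: each is below atol + rtol * |reference|.
close = reference + np.array([5e-4, -1e-3, 5e-2], dtype=np.float32)
# The last element is off by 0.5, above the allowed 1e-5 + 1e-3 * 100.
far = reference + np.array([0.0, 0.0, 0.5], dtype=np.float32)

print(within_tolerance(close, reference))
print(within_tolerance(far, reference))
```

Because the relative term scales with the magnitude of the values, large logits are allowed a larger absolute difference than values near zero.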

6. Run a real prediction and visualize the results

# Inference data
from PIL import Image
from torchvision.transforms import transforms
# Prepare the inference image
image = Image.open('./images/cat.jpg')
# Resize the image to the target size
image = image.resize((224, 224))
# Convert the image to RGB mode
image = image.convert('RGB')
image.save('./images/cat_224.jpg')
categories = []
# Read the categories
with open("./imagenet/imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

def get_class_name(probabilities):
    # Show top categories per image
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    for i in range(top5_prob.size(0)):
        print(categories[top5_catid[i]], top5_prob[i].item())

# Preprocessing
def pre_image(image_file):
    input_image = Image.open(image_file)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    inputs = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    # input_arr = inputs.cpu().detach().numpy()
    return inputs

# Inference with the PyTorch model
# First build the model structure
resnet50 = torchvision.models.resnet50()
# Then load the model weights
resnet50.load_state_dict(torch.load(save_dir))
resnet50.eval()
# Inference
input_batch = pre_image('./images/cat_224.jpg')
# move the input and model to GPU for speed if available
print("GPU Availability: ", torch.cuda.is_available())
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    resnet50.to('cuda')
with torch.no_grad():
    output = resnet50(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
# print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
get_class_name(probabilities)
GPU Availability:  False
Persian cat 0.6668420433998108
lynx 0.023987364023923874
bow tie 0.016234245151281357
hair slide 0.013150070793926716
Japanese spaniel 0.012279157526791096
input_batch.size()
torch.Size([1, 3, 224, 224])
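The post-processing used for the prediction (softmax over the raw scores, then the five largest probabilities) can be sketched without any framework. This is only an illustration on a hypothetical four-class score vector; the helper names `softmax` and `top_k` below are plain-Python stand-ins, not the `torch` functions themselves.

```python
import math


def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, k=5):
    """Return (index, probability) pairs for the k largest probabilities."""
    order = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(i, probabilities[i]) for i in order[:k]]


# Hypothetical raw scores (logits) for four classes
logits = [1.0, 3.0, 0.5, 2.0]
for index, prob in top_k(softmax(logits), k=2):
    print(index, round(prob, 4))
```

Subtracting the maximum does not change the result, but it keeps `exp` from overflowing when the scores are large.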
# Benchmark performance
import time

latency = []
for i in range(10):
    with torch.no_grad():
        start = time.time()
        output = resnet50(input_batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        # for catid in range(top5_catid.size(0)):
        #     print(categories[catid])
        latency.append(time.time() - start)
    # Print the running average latency over the iterations so far
    print("{} model inference CPU time:cost {} ms".format(str(i), format(sum(latency) * 1000 / len(latency), '.2f')))
0 model inference CPU time:cost 149.59 ms
1 model inference CPU time:cost 130.74 ms
2 model inference CPU time:cost 133.76 ms
3 model inference CPU time:cost 130.64 ms
4 model inference CPU time:cost 131.72 ms
5 model inference CPU time:cost 130.88 ms
6 model inference CPU time:cost 136.31 ms
7 model inference CPU time:cost 139.95 ms
8 model inference CPU time:cost 141.90 ms
9 model inference CPU time:cost 140.96 ms
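The numbers above are running averages that include the first, usually slowest, iteration. One common alternative is to discard a few warm-up runs and keep the per-run timings. The sketch below shows that pattern with `time.perf_counter`. The helper `benchmark` and the stand-in workload `dummy_inference` are hypothetical; in the notebook, the callable would wrap `resnet50(input_batch)` inside `torch.no_grad()`.

```python
import statistics
import time


def benchmark(fn, warmup=3, runs=10):
    """Call fn() `warmup` times untimed, then return `runs` timings in ms."""
    for _ in range(warmup):
        fn()
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000)
    return timings_ms


def dummy_inference():
    # Stand-in workload so the sketch runs without a model
    return sum(i * i for i in range(10_000))


timings = benchmark(dummy_inference)
print("mean {:.2f} ms, median {:.2f} ms".format(
    statistics.mean(timings), statistics.median(timings)))
```

When timing on a GPU, the measured call should also wait for the device to finish (for example with `torch.cuda.synchronize()`), since CUDA kernels run asynchronously.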
# Inference with ONNX Runtime
import onnxruntime
from onnx import numpy_helper
import time
onnx_file = 'resnet50.onnx'
session_fp32 = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

latency = []

def run_sample(session, categories, inputs):
    # Query the input name from the session so the call works
    # whatever name was chosen when the model was exported
    input_name = session.get_inputs()[0].name
    start = time.time()
    input_arr = inputs
    ort_outputs = session.run(None, {input_name: input_arr})[0]
    output = ort_outputs.flatten()
    output = softmax(output)  # this is optional
    top5_catid = np.argsort(-output)[:5]
    # for catid in top5_catid:
    #     print(categories[catid])
    latency.append(time.time() - start)
    return ort_outputs

input_tensor = pre_image('./images/cat_224.jpg')
input_arr = input_tensor.cpu().detach().numpy()
for i in range(10):
    ort_output = run_sample(session_fp32, categories, input_arr)
    print("{} ONNX Runtime CPU Inference time = {} ms".format(str(i), format(sum(latency) * 1000 / len(latency), '.2f')))
0 ONNX Runtime CPU Inference time = 67.66 ms
1 ONNX Runtime CPU Inference time = 56.30 ms
2 ONNX Runtime CPU Inference time = 53.90 ms
3 ONNX Runtime CPU Inference time = 58.18 ms
4 ONNX Runtime CPU Inference time = 64.53 ms
5 ONNX Runtime CPU Inference time = 62.79 ms
6 ONNX Runtime CPU Inference time = 61.75 ms
7 ONNX Runtime CPU Inference time = 60.51 ms
8 ONNX Runtime CPU Inference time = 59.35 ms
9 ONNX Runtime CPU Inference time = 57.57 ms
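The session above is created with a single execution provider, with CUDA and OpenVINO left as commented-out alternatives. ONNX Runtime accepts a list of providers in priority order, so one possible way to write this once is to filter a preferred list against what the installed build supports and always keep the CPU provider as a fallback. The helper `pick_providers` below is a hypothetical sketch, not an ONNX Runtime API; the provider names are the ones that appear in the code above.

```python
def pick_providers(available,
                   preferred=("CUDAExecutionProvider", "OpenVINOExecutionProvider")):
    """Keep preferred providers that are available, then fall back to CPU."""
    chosen = [name for name in preferred if name in available]
    if "CPUExecutionProvider" not in chosen:
        chosen.append("CPUExecutionProvider")
    return chosen


# Hypothetical list; in practice it would come from
# onnxruntime.get_available_providers()
print(pick_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))

# Assumed usage with the session from this section:
# providers = pick_providers(onnxruntime.get_available_providers())
# session_fp32 = onnxruntime.InferenceSession(onnx_file, providers=providers)
```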

4. Further topics

  • Model quantization
  • Model pruning
  • Engineering optimization
  • Operator optimization
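As a small taste of the first topic, model quantization stores weights as low-precision integers plus a scale factor. The sketch below shows symmetric int8 quantization of a short list of floats in plain Python. It is only an assumption-laden illustration of the idea: the function names and sample weights are hypothetical, and real toolchains (for example the quantization tools shipped with ONNX Runtime) use calibrated, per-tensor or per-channel schemes.

```python
def quantize_int8(values):
    """Symmetric quantization: map floats to integers in [-127, 127]."""
    max_abs = max((abs(v) for v in values), default=0.0)
    # Avoid dividing by zero when all values are zero
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float values from the integers and the scale."""
    return [q * scale for q in quantized]


# Hypothetical weights
weights = [-0.82, 0.0, 0.31, 1.27]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
print(q, scale)
```

Each restored value differs from the original by at most half a quantization step (`scale / 2`), which is the precision traded for a smaller and often faster model.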
