【深度学习PyTorch简介】7.Load and run model predictions 加载和运行模型预测

Load and run model predictions 加载和运行模型预测

Load the model 加载模型

在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。

%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor

为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。

class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU(),)def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits

加载模型权重时,我们需要首先实例化模型类,因为该类定义了网络的结构。接下来,我们使用 load_state_dict() 方法加载参数。

model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True)(5): ReLU())
)

**注意:**请务必在推理之前调用 model.eval() 方法,以将 dropout 和批量归一化层设置为评估模式。否则,您将看到不一致的推理结果。

Model Inference 模型推理

优化模型以在各种平台和编程语言上运行是很困难的。在所有不同的框架和硬件组合中最大限度地提高性能非常耗时。Open Neural Network Exchange (ONNX) 开放神经网络交换运行时为您提供了一种解决方案,可在任何硬件、云或边缘设备上进行一次训练并加速推理。

ONNX 是许多供应商支持的通用格式,用于共享神经网络和其他机器学习模型。您可以使用 ONNX 格式在其他编程语言(Java, JavaScript, C# 和 ML.NET)和框架上对模型进行推理。

Exporting the model to ONNX 将模型导出到 ONNX

PyTorch 还具有本机 ONNX 导出支持。然而,考虑到 PyTorch 执行图的动态特性,导出过程必须遍历执行图以生成持久的 ONNX 模型。因此,应将适当大小的测试变量传递到导出例程中(在我们的例子中,我们将创建正确大小的虚拟零张量。您可以从训练数据集的shape函数中获取大小:tensor.shape):

input_image = torch.zeros((1,28,28))
onnx_model = 'data/model.onnx'
onnx.export(model, input_image, onnx_model)

我们将使用测试数据集作为示例数据,从 ONNX 模型进行推理以进行预测。

test_data = datasets.FashionMNIST(root = "data",train = False,download = True,transform = ToTensor()
)classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]x, y = test_data[0][0], test_data[0][1]

我们使用 onnxruntime.InferenceSession 创建推理会话。要推断 ONNX 模型,请调用 run 并传入您想要返回的输出列表(如果您需要所有输出,请保留为空)和输入值的映射。结果是输出列表。

session = onnxruntime.InferenceSession(onnx_model, None)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].nameresult = session.run([output_name], {input_name:x.numpy()})
predicted, actual = classes[result[0][0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: {actual}')
Predicted: "Ankle boot", Actual: Ankle boot

ONNX 模型使您能够在不同平台上以不同编程语言运行推理。

知识检查

什么是 PyTorch 模型 state_dict?

它是模型的内部状态字典,用于存储已学习的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/642799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS高级技巧导读

1,精灵图 1.1 为什么需要精灵图? 目的:为了有效地减少服务器接收和发送请求的次数,提高页面的加载速度 核心原理:将网页中的一些小背景图像整合到一张大图中,这样服务器只需要一次请求就可以了 1.2 精灵…

【WPF.NET开发】克隆打印机

本文内容 大多数企业有时会购买多台同一型号的打印机。 通常,这些打印机都安装了几乎相同的配置设置。 安装每台打印机既费时又容易出错。 使用 Microsoft .NET Framework 公开的 System.Printing.IndexedProperties 命名空间和 InstallPrintQueue 类可以立即安装从…

化妆-护肤品选购

粉底液 在下颚线涂抹 选择贴近肤色的 在选购时不能立即选购,而是涂抹后逛个街吃个饭,再看持久程度和服帖程度,所有粉底液都会脱妆 品牌 韩系品牌 VDL 韩系品牌偏光泽感(因为韩国人皮肤偏好) 所以一般会带有”亮泽“ “光感” 中国品牌 …

查看并解析当前jdk的垃圾收集器

概述:复习的时候,学看一下。 命令: -XX:PrintCommandLineFlags 打开idea,配置jvm 把上面命令输入jvm options中即可。 举例代码 这个代码的解析,我上篇文章有写,这个跟本文没有任何关系: …

C++--enum--枚举

C/C枚举类型: 不限定作用域的枚举类型 关键字:enum 声明枚举类型,然后可以用枚举类型来定义变量(如同结构体): enum Color{white,black,yellow}; {注意分号} Color color_type; color_type 变量的值只限于枚举类型Color中的值 枚…

深度学习|RCNNFast-RCNN

1.RCNN 2014年提出R-CNN网络,该网络不再使用暴力穷举的方法,而是使用候选区域方法(region proposal method)创建目标检测的区域来完成目标检测的任务,R-CNN是以深度神经网络为基础的目标检测的模型 ,以R-C…

Hikvision综合安防管理平台files;.css接口存在任意文件读取漏洞 附POC软件

免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该文章仅供学习用途使用。 1. Hikvisi…

MySQL 常规操作指南

1. 连接MySQL服务器 (1)通过命令行连接 mysql -u username -p在提示下输入对应用户的密码,即可进入MySQL命令行界面。 (2)指定数据库连接 mysql -u username -p -D database_name这里会直接连接到名为database_nam…

设计模式——1_6 代理(Proxy)

诗有可解不可解,若镜花水月勿泥其迹可也 —— 谢榛 文章目录 定义图纸一个例子:图片搜索器图片加载搜索器直接在Image添加组合他们 各种各样的代理远程代理:镜中月,水中花保护代理:对象也该有隐私引用代理:…

Jupyter Notebook安装使用教程

Jupyter Notebook 是一个基于网页的交互式计算环境,允许你创建和共享包含代码、文本说明、图表和可视化结果的文档。它支持多种编程语言,包括 Python、R、Julia 等。其应用场景非常广泛,特别适用于数据科学、机器学习和教育领域。它可以用于数…

vue element MessageBox.prompt this.$prompt组件禁止显示右上角关闭按钮,取消按钮,及点击遮罩层关闭

vue element MessageBox.prompt this.$prompt组件禁止或取消显示右上角关闭按钮,取消按钮,及点击遮罩层关闭 实现效果: 实现代码 MessageBox.prompt(请先完成手机号绑定, 系统提示, {confirmButtonText: 提 交,showClose: false,closeOnClic…

linux之安装配置VM+CentOS7+换源

文章目录 一、centos07安装二、CentOS 07网络配置2.1解决CentOS 07网络名不出现问题此博主的论文可以解决2.2配置(命令: 【ip a】也可查看ip地址) 三、使用链接工具链接CentOS进行命令控制四、换软件源 一、centos07安装 1、在vmvare中新建虚拟机 2、下…

Linux:动静态库的概念与制作使用

文章目录 动静态库基础认知动静态库基本概念静态库的制作库的概念包的概念 静态库的使用第三方库小结 动态库的制作动态库的使用动态库如何找到内容?小结 本篇要谈论的内容是关于动静态库的问题,具体的逻辑框架是建立在库的制作,库的使用&…

JavaFX增删改查其他控件

小技巧 增删改查思路 --查 底层select * from 表 where sname like %% --1.拿文本框的关键字 --2.调模糊查询的方法 myShow("") --删 底层 delete from tb_stu where sid? --1.想方设法拿学号 --1.先拿到用户所选中的学生对象 Student --2.调用方法传对象.getSi…

mysql INSERT数据覆盖现有元素(若存在)

INSERT...ON DUPLICATE KEY UPDATE的使用 如果指定了ON DUPLICATE KEY UPDATE,并且插入行后会导致在一个UNIQUE索引或PRIMARY KEY中出现重复值,则会更新ON DUPLICATE KEY UPDATE关键字后面的字段值。 例如,如果列a被定义为UNIQUE&#xff0…

不要为了学习而学习

经常有朋友问我: 老师,从您这里学了很多方法,也一直想要改变自己,但总是没办法坚持下去,怎么办? 这个问题,我也很无奈啊。毕竟我也没办法飞到你身边,手把手把每一步都教给你。&…

Eyes Wide Shut? Exploring the Visual Shortcomings of Multimodal LLMs

大开眼界?探索多模态模型种视觉编码器的缺陷。 论文中指出,上面这些VQA问题,人类可以瞬间给出正确的答案,但是多模态给出的结果却是错误的。是哪个环节出了问题呢?视觉编码器的问题?大语言模型出现了幻觉&…

七八分钟快速用k8s部署springboot前后端分离项目

前置依赖 k8s集群,如果没有安装,请先安装 kubectl ,客户端部署需要依赖 应用镜像构建 应用镜像构建不用自己去执行,相关镜像已经推送到docker hub 仓库,如果要了解过程和细节,可以看一下,否…

基于springboot+vue的足球青训俱乐部管理系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容:毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 研究背景…

C++逆向分析--虚函数(多态的前置)

先理解一件事,在intel汇编层面来说,直接调用和间接调用的区别。 直接调用语法: call 地址 硬编码为 :e8 间接调用语法: call [ ...] 硬编码为: FF 那么在C语法中,实现多态的前提是父类需要实现多态的成员…