【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序:

import torch
import torch.onnx
from torchvision import transforms, models
from PIL import Image
import os# Load MobileNetV2 model
model = models.mobilenet_v2(pretrained=True)
model.eval()# Download an example image from the PyTorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try:os.system(f"wget {url} -O {filename}")
except Exception as e:print(f"Error downloading image: {e}")# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)# Download ImageNet labels using wget
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")# Read the categories
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())# Save the PyTorch model
torch.save(model.state_dict(), "mobilenet_v2.pth")# Convert the PyTorch model to ONNX with specified input and output names
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "mobilenet_v2.onnx"
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)print(f"PyTorch model saved to 'mobilenet_v2.pth'")
print(f"ONNX model saved to '{onnx_path}'")# Load the ONNX model
import onnx
import onnxruntimeonnx_model = onnx.load(onnx_path)
onnx_session = onnxruntime.InferenceSession(onnx_path)# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX with the correct input name
onnx_output = onnx_session.run(['output'], {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("\nTop categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

得到:

在这里插入图片描述
用本地pth推理:

import torch
from torchvision import transforms, models
from PIL import Image# Load MobileNetV2 model
model = models.mobilenet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
model.eval()# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())

用onnx推理:

import torch
import onnxruntime
from torchvision import transforms
from PIL import Image# Load the ONNX model
onnx_path = "mobilenet_v2.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_path)# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX
onnx_output = onnx_session.run(None, {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("Top categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/150687.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

韦东山linux驱动开发学习【常更】

1.linux目录简单介绍 2.直接运行需要在$path路径下

windows 安裝字體Font

或者直接Copy到C:\Windows\fonts 目錄下

maven 添加 checkstyle 插件约束代码规范

本例示例&#xff0c;是引用 http 链接这种在线 checkstyle.xml 文件的配置方式&#xff0c;如下示例&#xff1a; <properties><maven.checkstyle.plugin.version>3.3.0</maven.checkstyle.plugin.version><!--支持本地绝对路径、本地相对路径、HTTP远程…

用三智者交易策略澳福加减仓轻松盈利,就是这么厉害

就是这么厉害&#xff0c; 用三智者交易策略&#xff0c;澳福通过加减仓就可以在交易市场中轻松盈利。各位投资者都知道三智者交易策略的两个重要的原则。当市场超过外部极限时&#xff0c;在向上分形的高点和向下分形的低点&#xff0c;就会跟随外部方向/分形点。 fpmarkets澳…

如何使用Docker部署Apache+Superset数据平台并远程访问?

大数据可视化BI分析工具Apache Superset实现公网远程访问 文章目录 大数据可视化BI分析工具Apache Superset实现公网远程访问前言1. 使用Docker部署Apache Superset1.1 第一步安装docker 、docker compose1.2 克隆superset代码到本地并使用docker compose启动 2. 安装cpolar内网…

LLM大模型 (chatgpt) 在搜索和推荐上的应用

目录 1 大模型在搜索的应用1.1 召回1.1.1 倒排索引1.1.2 倒排索引存在的问题1.1.3 大模型在搜索召回的应用 (实体倒排索引&#xff09; 1.2 排序1.2.1 大模型在搜索排序应用&#xff08;融入LLM实体排序&#xff09; 2 大模型在推荐的应用2.1 学术界关于大模型在推荐的研究2.2 …

什么是硬分叉?硬分叉的原因是什么?硬分叉的影响是什么?

目录 什么是硬分叉? 硬分叉的原因是什么? 区块大小的改变 共识机制的修改

在vscode中使用Latex:TexLive2023

安装TexLive2023及配置vscode可参考https://zhuanlan.zhihu.com/p/166523064 然后编译模板 .tex文件时&#xff0c;出现以下几个错误&#xff1a; 1. ctexbook找不到字体集 d:/texlive/2023/texmf-dist/tex/latex/ctex/ctexbook.cls:1678: Class ctexbook Error: CTeX fo…

采用Nexus搭建Maven私服

采用Nexus搭建Maven私服 1.采用docker安装 1.创建数据目录挂载的目录&#xff1a; /usr/local/springcloud_1113/nexus3/nexus-data2.查询并拉取镜像docker search nexus3docker pull sonatype/nexus33.查看拉取的镜像docker images4.创建docker容器&#xff1a;可能出现启动…

OpenWrt环境下,由于wget不支持ssl/tls导致执行opkg update失败的解决方法

执行&#xff1a; opkg update 显示&#xff1a; wget: SSL support not available, please install one of the libustream-ssl-* libraries as well as the ca-bundle and ca-certificates packages. 提示opkg依赖的wget不支持ssl/tls。 此时需要下载支持ssl/tls的wget。但是…

为关键信息基础设施安全助力!持安科技加入关保联盟

近日&#xff0c;中关村华安关键信息基础设施安全保护联盟发布了其新一批的会员单位&#xff0c;零信任办公安全代表企业持安科技成功加入&#xff0c;与联盟企业共同为关键信息基础设施提供各类支撑和保障。 中关村华安关键信息基础设施安全保护联盟由北京市科学技术委员会、中…

软件测试面试时问你的项目经验,你知道该怎么说吗?

很简单&#xff0c;我来给你们一个公式 0 自我介绍&#xff0c;名字 学历 荣誉。 1 简述项目背景&#xff0c;你身处这个项目是做什么的。 不要太细&#xff0c;试着引导一下面试官让他提问。这样&#xff0c;请问您对此有什么疑问吗&#xff1f; 2 简述 你在项目中的角色&…

函数的极限和联系以及与数列收敛的联系

函数 极限定义:函数在 x 0 x_0 x0​处逼近 L L L,则给定任意正数 ϵ \epsilon ϵ,都有存在的 δ \delta δ使得 ∣ f ( x ) − L ∣ < ϵ , 0 < ∣ x − x 0 ∣ < δ |f(x)-L|<\epsilon, 0<|x-x_0|<\delta ∣f(x)−L∣<ϵ,0<∣x−x0​∣<δ 连续定…

2021年6月青少年软件编程(Python)等级考试试卷(一级)

青少年软件编程(Python)等级考试试卷(一级) 分数:100.00 题数:37一、单选题(共25题,每题2分,共50分)二、判断题(共10题,每题2分,共20分)三、编程题(共2题,共30分)分数:100.00 题数:37 一、单选题(共25题,每题2分,共50分) 下列程序运行的结果是?( ) …

Dubbo快速实践

文章目录 架构相关概念集群和分布式架构演进 Dubbo概述Dubbo快速入门前置准备配置服务接口配置Provider配置Consumer Dubbo基本使用总结 本文参考https://www.bilibili.com/video/BV1VE411q7dX 架构相关概念 集群和分布式 集群&#xff1a;很多“人”一起 &#xff0c;干一样…

Visual Studio Code 从英文界面切换中文

1、先安装中文的插件&#xff0c;直接安装。 2、点击右下角的 change language restart&#xff0c; 让软件重启即可以完成了。

Kafka学习笔记(二)

目录 第3章 Kafka架构深入3.3 Kafka消费者3.3.1 消费方式3.3.2 分区分配策略3.3.3 offset的维护 3.4 Kafka高效读写数据3.5 Zookeeper在Kafka中的作用3.6 Kafka事务3.6.1 Producer事务3.6.2 Consumer事务&#xff08;精准一次性消费&#xff09; 第4章 Kafka API4.1 Producer A…

异常控制流——(中断、陷阱、故障、终止、进程等操作系统干货)

异常 异常控制流 控制流&#xff1a; 假设从处理机上电运行一直到断电关机的这段时间内&#xff0c;程序计数器的值是下图序列&#xff0c;其中ak表示某一条指令Ik的地址。 **控制转移&#xff1a;**每一次从ak到ak1的过渡 **平滑&#xff1a;**Ik和Ik1在内存中是相邻的&am…

怎样备份电脑文件比较安全

域智盾软件是一款功能强大的电脑监控软件&#xff0c;它不仅具备实时屏幕监控、行为审计等功能&#xff0c;还能够对电脑文件进行备份和管理。下面将介绍域智盾软件如何备份电脑文件&#xff0c;以确保数据安全。 1、开启文档备份功能 部署后台&#xff0c;然后点击文档安全&a…