【深度学习】pytorch快速得到mobilenet_v2 pth 和onnx

在linux执行这个程序:

import torch
import torch.onnx
from torchvision import transforms, models
from PIL import Image
import os# Load MobileNetV2 model
model = models.mobilenet_v2(pretrained=True)
model.eval()# Download an example image from the PyTorch website
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try:os.system(f"wget {url} -O {filename}")
except Exception as e:print(f"Error downloading image: {e}")# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])input_image = Image.open(filename)
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)# Download ImageNet labels using wget
os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")# Read the categories
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())# Save the PyTorch model
torch.save(model.state_dict(), "mobilenet_v2.pth")# Convert the PyTorch model to ONNX with specified input and output names
dummy_input = torch.randn(1, 3, 224, 224)
onnx_path = "mobilenet_v2.onnx"
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names)print(f"PyTorch model saved to 'mobilenet_v2.pth'")
print(f"ONNX model saved to '{onnx_path}'")# Load the ONNX model
import onnx
import onnxruntimeonnx_model = onnx.load(onnx_path)
onnx_session = onnxruntime.InferenceSession(onnx_path)# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX with the correct input name
onnx_output = onnx_session.run(['output'], {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("\nTop categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

得到:

在这里插入图片描述
用本地pth推理:

import torch
from torchvision import transforms, models
from PIL import Image# Load MobileNetV2 model
model = models.mobilenet_v2()
model.load_state_dict(torch.load("mobilenet_v2.pth", map_location=torch.device('cpu')))
model.eval()# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Perform inference on CPU
with torch.no_grad():output = model(input_tensor)# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())

用onnx推理:

import torch
import onnxruntime
from torchvision import transforms
from PIL import Image# Load the ONNX model
onnx_path = "mobilenet_v2.onnx"
onnx_session = onnxruntime.InferenceSession(onnx_path)# Preprocess the input image
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Load the example image
input_image = Image.open("dog.jpg")
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension# Convert input tensor to ONNX-compatible format
input_tensor_onnx = input_tensor.numpy()# Perform inference on ONNX
onnx_output = onnx_session.run(None, {'input': input_tensor_onnx})
onnx_probabilities = torch.nn.functional.softmax(torch.tensor(onnx_output[0]), dim=1)# Load ImageNet labels
categories = []
with open("imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]# Show top categories per image for ONNX
onnx_top5_prob, onnx_top5_catid = torch.topk(onnx_probabilities, 5)
print("Top categories for ONNX:")
for i in range(onnx_top5_prob.size(1)):print(categories[onnx_top5_catid[0][i]], onnx_top5_prob[0][i].item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/150687.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

韦东山linux驱动开发学习【常更】

1.linux目录简单介绍 2.直接运行需要在$path路径下

windows 安裝字體Font

或者直接Copy到C:\Windows\fonts 目錄下

用三智者交易策略澳福加减仓轻松盈利,就是这么厉害

就是这么厉害, 用三智者交易策略,澳福通过加减仓就可以在交易市场中轻松盈利。各位投资者都知道三智者交易策略的两个重要的原则。当市场超过外部极限时,在向上分形的高点和向下分形的低点,就会跟随外部方向/分形点。 fpmarkets澳…

如何使用Docker部署Apache+Superset数据平台并远程访问?

大数据可视化BI分析工具Apache Superset实现公网远程访问 文章目录 大数据可视化BI分析工具Apache Superset实现公网远程访问前言1. 使用Docker部署Apache Superset1.1 第一步安装docker 、docker compose1.2 克隆superset代码到本地并使用docker compose启动 2. 安装cpolar内网…

LLM大模型 (chatgpt) 在搜索和推荐上的应用

目录 1 大模型在搜索的应用1.1 召回1.1.1 倒排索引1.1.2 倒排索引存在的问题1.1.3 大模型在搜索召回的应用 (实体倒排索引) 1.2 排序1.2.1 大模型在搜索排序应用(融入LLM实体排序) 2 大模型在推荐的应用2.1 学术界关于大模型在推荐的研究2.2 …

采用Nexus搭建Maven私服

采用Nexus搭建Maven私服 1.采用docker安装 1.创建数据目录挂载的目录: /usr/local/springcloud_1113/nexus3/nexus-data2.查询并拉取镜像docker search nexus3docker pull sonatype/nexus33.查看拉取的镜像docker images4.创建docker容器:可能出现启动…

为关键信息基础设施安全助力!持安科技加入关保联盟

近日,中关村华安关键信息基础设施安全保护联盟发布了其新一批的会员单位,零信任办公安全代表企业持安科技成功加入,与联盟企业共同为关键信息基础设施提供各类支撑和保障。 中关村华安关键信息基础设施安全保护联盟由北京市科学技术委员会、中…

软件测试面试时问你的项目经验,你知道该怎么说吗?

很简单,我来给你们一个公式 0 自我介绍,名字 学历 荣誉。 1 简述项目背景,你身处这个项目是做什么的。 不要太细,试着引导一下面试官让他提问。这样,请问您对此有什么疑问吗? 2 简述 你在项目中的角色&…

Dubbo快速实践

文章目录 架构相关概念集群和分布式架构演进 Dubbo概述Dubbo快速入门前置准备配置服务接口配置Provider配置Consumer Dubbo基本使用总结 本文参考https://www.bilibili.com/video/BV1VE411q7dX 架构相关概念 集群和分布式 集群:很多“人”一起 ,干一样…

Visual Studio Code 从英文界面切换中文

1、先安装中文的插件,直接安装。 2、点击右下角的 change language restart, 让软件重启即可以完成了。

Kafka学习笔记(二)

目录 第3章 Kafka架构深入3.3 Kafka消费者3.3.1 消费方式3.3.2 分区分配策略3.3.3 offset的维护 3.4 Kafka高效读写数据3.5 Zookeeper在Kafka中的作用3.6 Kafka事务3.6.1 Producer事务3.6.2 Consumer事务(精准一次性消费) 第4章 Kafka API4.1 Producer A…

异常控制流——(中断、陷阱、故障、终止、进程等操作系统干货)

异常 异常控制流 控制流: 假设从处理机上电运行一直到断电关机的这段时间内,程序计数器的值是下图序列,其中ak表示某一条指令Ik的地址。 **控制转移:**每一次从ak到ak1的过渡 **平滑:**Ik和Ik1在内存中是相邻的&am…

怎样备份电脑文件比较安全

域智盾软件是一款功能强大的电脑监控软件,它不仅具备实时屏幕监控、行为审计等功能,还能够对电脑文件进行备份和管理。下面将介绍域智盾软件如何备份电脑文件,以确保数据安全。 1、开启文档备份功能 部署后台,然后点击文档安全&a…

李宏毅2023机器学习作业HW05解析和代码分享

ML2023Spring - HW5 相关信息: 课程主页 课程视频 Sample code HW05 视频 HW05 PDF 个人完整代码分享: GitHub | Gitee | GitCode 运行日志记录: wandb P.S. HW05/06 是在 Judgeboi 上提交的,完全遵循 hint 就可以达到预期效果。 因为无法在 Judgeboi 上…

Vue3 源码解读系列(八)——生命周期

生命周期 正常的生命周期 // 注册钩子函数 const onBeforeMount createHook(bm/* BEFORE_MOUNT */) const onMounted createHook(m/* MOUNTED */) const onBeforeUpdate createHook(bu/* BEFORE_UPDATE */) const onUpdated createHook(u/* UPDATED */) const onBeforeUnm…

商城免费搭建之java商城 java电子商务Spring Cloud+Spring Boot+mybatis+MQ+VR全景+b2b2c

1. 涉及平台 平台管理、商家端(PC端、手机端)、买家平台(H5/公众号、小程序、APP端(IOS/Android)、微服务平台(业务服务) 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis 3. 前端框架…

java桌面程序

目标之一是把打印导出的功能最终用java实现一套,首先选定javafx,因为idea默认创建工程就带的javafx,没找到swing。 创建工程,这里要选1.8,高版本jdk默认不带fx 实现主界面的代码 package sample;import javafx.app…

将word中的表格无变形的弄进excel中

在上篇文章中记录了将excel表拷贝到word中来: 记录将excel表无变形的弄进word里面来-CSDN博客 本篇记录:将word中的表格无变形的弄进excel中。 1.按F12,“另存为...”,保存类型:“单个文件页面”,保存。…

numpy报错:AttributeError: module ‘numpy‘ has no attribute ‘float‘

报错:AttributeError: module numpy has no attribute float numpy官网:NumPy 报错原因:从numpy1.24起删除了numpy.bool、numpy.int、numpy.float、numpy.complex、numpy.object、numpy.str、numpy.long、numpy.unicode类型的支持。 解决办法…