【PyTorch】图像多分类项目部署

【PyTorch】图像多分类项目

【PyTorch】图像多分类项目部署

如果需要在独立于训练脚本的新脚本中部署模型,这种情况模型和权重在内存中不存在,因此需要构造一个模型类的对象,然后将存储的权重加载到模型中。

加载模型参数,验证模型的性能,并在测试数据集上部署模型

from torch import nn
from torchvision import models# 定义一个resnet18模型,不使用预训练参数
model_resnet18 = models.resnet18(pretrained=False)
# 获取模型的全连接层的输入特征数
num_ftrs = model_resnet18.fc.in_features
# 定义分类的类别数
num_classes=10
# 将全连接层的输出特征数改为分类的类别数
model_resnet18.fc = nn.Linear(num_ftrs, num_classes)import torch 
path2weights="./models/resnet18_pretrained.pt"
# 加载预训练的ResNet18模型权重
model_resnet18.load_state_dict(torch.load(path2weights))
# 将ResNet-18模型设置为评估模式
model_resnet18.eval();
# 检查CUDA是否可用
if torch.cuda.is_available():# 如果可用,将设备设置为CUDAdevice = torch.device("cuda")# 将模型移动到CUDA设备上model_resnet18=model_resnet18.to(device)def deploy_model(model,dataset,device, num_classes=10,sanity_check=False):# 获取数据集的长度len_data=len(dataset)# 初始化输出张量y_out=torch.zeros(len_data,num_classes)# 初始化真实标签张量y_gt=np.zeros((len_data),dtype="uint8")# 将模型移动到指定设备model=model.to(device)# 初始化时间列表elapsed_times=[]with torch.no_grad():for i in range(len_data):# 获取数据集中的一个样本x,y=dataset[i]# 将真实标签存入张量y_gt[i]=y# 记录开始时间start=time.time()    # 将输入数据传入模型进行预测yy=model(x.unsqueeze(0).to(device))# 将预测结果存入张量y_out[i]=torch.softmax(yy,dim=1)# 计算预测时间elapsed=time.time()-start# 将预测时间存入列表elapsed_times.append(elapsed)# 如果进行完整性检查,则跳出循环if sanity_check is True:break# 计算平均预测时间inference_time=np.mean(elapsed_times)*1000# 打印平均预测时间print("average inference time per image on %s: %.2f ms " %(device,inference_time))# 返回预测结果和真实标签return y_out.numpy(),y_gt
from torchvision import datasets
import torchvision.transforms as transforms# 数据转换
data_transformer = transforms.Compose([transforms.ToTensor()])path2data="./data"# 加载数据
test0_ds=datasets.STL10(path2data, split='test', download=True,transform=data_transformer)
print(test0_ds.data.shape)

from sklearn.model_selection import StratifiedShuffleSplit# 创建StratifiedShuffleSplit对象,设置分割次数为1,测试集大小为0.2,随机种子为0
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)# 获取test0_ds的索引
indices=list(range(len(test0_ds)))# 获取test0_ds的标签
y_test0=[y for _,y in test0_ds]# 对索引和标签进行分割
for test_index, val_index in sss.split(indices, y_test0):# 打印测试集和验证集的索引print("test:", test_index, "val:", val_index)# 打印测试集和验证集的大小print(len(val_index),len(test_index))

from torch.utils.data import Subset# 从test0_ds中选取val_index索引的子集,赋值给val_ds
val_ds=Subset(test0_ds,val_index)
# 从test0_ds中选取test_index索引的子集,赋值给test_ds
test_ds=Subset(test0_ds,test_index)
# 定义均值
mean=[0.4467106, 0.43980986, 0.40664646]
# 定义标准差
std=[0.22414584,0.22148906,0.22389975]
# 定义一个名为test0_transformer的变量,用于将一系列的图像变换操作组合在一起
test0_transformer = transforms.Compose([# 将图像转换为Tensor类型transforms.ToTensor(),# 对图像进行归一化操作,使用mean和std作为均值和标准差transforms.Normalize(mean, std),])   
# 将test0_transformer赋值给test0_ds的transform属性
test0_ds.transform=test0_transformer
import time
import numpy as np# 调用deploy_model函数,传入model_resnet18,val_ds,device和sanity_check参数,返回y_out和y_gt
y_out,y_gt=deploy_model(model_resnet18,val_ds,device=device,sanity_check=False)
# 打印y_out和y_gt的形状
print(y_out.shape,y_gt.shape)

from sklearn.metrics import accuracy_score# 将y_out中的最大值索引赋值给y_pred
y_pred = np.argmax(y_out,axis=1)
# 打印y_pred和y_gt的形状
print(y_pred.shape,y_gt.shape)# 计算并打印y_pred和y_gt的准确率
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署模型,得到预测结果和真实标签
y_out,y_gt=deploy_model(model_resnet18,test_ds,device=device)# 取出预测结果中概率最大的类别
y_pred = np.argmax(y_out,axis=1)# 计算准确率
acc=accuracy_score(y_pred,y_gt)# 打印准确率
print(acc)

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
np.random.seed(1)# 定义一个函数,用于显示图像
def imshow(inp, title=None):# 定义图像的均值和标准差mean=[0.4467106, 0.43980986, 0.40664646]std=[0.22414584,0.22148906,0.22389975]# 将图像从tensor转换为numpy数组,并转置inp = inp.numpy().transpose((1, 2, 0))# 将均值和标准差转换为numpy数组mean = np.array(mean)std = np.array(std)# 将图像的像素值进行归一化inp = std * inp + mean# 将像素值限制在0和1之间inp = np.clip(inp, 0, 1)# 显示图像plt.imshow(inp)# 如果有标题,则显示标题if title is not None:plt.title(title)# 暂停0.001秒plt.pause(0.001) # 定义网格大小
grid_size=16
# 随机生成4个索引
rnd_inds=np.random.randint(1,len(test_ds),grid_size)
# 打印随机生成的索引
print("image indices:",rnd_inds)# 根据索引获取对应的图像和标签
x_grid_test=[test_ds[i][0] for i in rnd_inds]
y_grid_test=[(y_pred[i],y_gt[i]) for i in rnd_inds]# 将图像转换为网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
# 打印网格的形状
print(x_grid_test.shape)# 设置图像的大小
plt.rcParams['figure.figsize'] = (10, 10)
# 显示网格
imshow(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/50016.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图解 HDFS 架构 |读写过程

HDFS HDFS 全称 Hadoop Distributed File System,是一个分布式文件系统。HDFS(Hadoop Distributed File System)是 Apache Hadoop 生态系统的一部分,它是一个分布式文件系统,用于存储和处理大规模数据集。HDFS 专门设…

源代码防泄密如何做?企业如何有效选择源代码防泄密产品?

源代码防泄密怎么选?如何高效做源代码防泄密工作? 源代码开发环境复杂,涉及的开发软件和文件类型众多且变化多端,那么究竟有哪些源代码防泄密软件能够适应各种开发软件而不影响原有的工作效率呢? 对于研发人员来说&a…

探索 Framer Motion 高级动画技巧:提升前端设计水平

在现代的网页和应用设计中,动画不仅仅是视觉的点缀,更是用户体验的重要组成部分。它能够使界面更具吸引力,提升交互的流畅性,甚至在不经意间传达品牌的个性和态度。然而,要创造出令人惊叹的动效并不容易——直到有了 F…

经验——OLED的使用

型号:HS96L01W 4S03 分辨率:120*64 通讯方式:4线SPI 模式00 MCU:MSPM0G3507(只影响SPI的配置) 原本照着型号搜到了嘉立创的使用文档,但是实际上并不能正常使用,后来寻到了一篇博客…

MFC与QT中禁用Esc、Alt+F4、关闭图标

在业务中,我们需要按指定的方式才能关闭当前对话框。如下图需输入密码点击确认后,界面才能关闭。 方法1:通过禁用界面的按钮以及键盘上对应关闭对话框的按键。 1.灰度化关闭按钮 在对话框初始化部分添加将关闭按钮禁用 //MFC CMenu *pSysMe…

主要的国产信创数据库有哪些

数据库生态分类 当前数据库生态可以大致分类三类: 一、传统商业数据库,以 Oracle 为代表,其在 40 余年时间里所创造的数据库帝国已拥有了极其完善的生态; 二、开源数据库,以 MYSQL、PostgreSQL为代表,遍布全球的社区组织形成了强…

大文件分片上传(前端TS实现)

大文件分片上传 内容 一般情况下,前端上传文件就是new FormData,然后把文件 append 进去,然后post发送给后端就完事了,但是文件越大,上传的文件也就越长,如果在上传过程中,突然网络故障,又或者…

AHK是让任何软件都支持 Shift + 鼠标滚轮 实现界面水平滚动

目录 基本介绍 详细特点 图解安装 下载失败?缓慢? 创建并运行脚本代码😃 新建空 xxx.ahk文件 vscode/记事本等编辑工具打开 复制并粘贴简易脚本 运行 其他问题 问题一:弹出无法执行此脚本 关闭脚本 基本介绍 AutoHot…

【MetaGPT系列】【MetaGPT完全实践宝典——如何定义单一行为多行为Agent】

目录 前言一、智能体1-1、Agent概述1-2、Agent与ChatGPT的区别 二、多智能体框架MetaGPT2-1、安装&配置2-2、使用已有的Agent(ProductManager)2-3、拥有单一行为的Agent(SimpleCoder)2-3-1、定义写代码行为2-3-2、角色定义2-3…

B站音视频分开 大小问题

音频是33331 kb,视频是374661 kb 合并之后却是2561363 kb 这可能是B站音频和视频分开的原因吧

Zabbix监控案例

文章目录 一、监控linux TCP连接状态TCP端口的十一种连接状态自定义监控项监控示例二、监控模板监控tcp连接监控nginx 一、监控linux TCP连接状态 TCP,全称Transfer Control Protocol,中文名为传输控制协议,它工作在OSI的传输层,…

3.Fabric系统架构、网络拓扑图、交易流程

Hyperledger Fabric系统架构 Fabric网络拓扑图 Fabric交易流程 多通道

【数字范围按位与】python刷题记录

run到位运算。 顿悟&#xff1a; 只看第一个二进制位&#xff0c;只存在0,1两种情况&#xff0c;所以如果left<right&#xff0c;区间中必然存在left1,那么最低位&一下一定等于0了&#xff0c;然后不停的右移&#xff0c;一直移到两个相等为止&#xff0c;就这么简单 …

Qt自定义下拉列表-可为选项设置标题、可禁用选项

在Qt中,ComboBox&#xff08;组合框&#xff09;是一种常用的用户界面控件,它提供了一个下拉列表,允许用户从预定义的选项中选择一个。在项目开发中&#xff0c;如果简单的QComboBox无法满足需求&#xff0c;可以通过自定义QComboBox来实现更复杂的功能。本文介绍一个自定义的下…

二级医院LIS系统源码,医学检验系统,支持DB2,Oracle,MS SQLServer等主流数据库

系统概述&#xff1a; LIS系统即实验室信息管理系统。LIS系统能实现临床检验信息化&#xff0c;检验科信息管理自动化。其主要功能是将检验科的实验仪器传出的检验数据经数据分析后&#xff0c;自动生成打印报告&#xff0c;通过网络存储在数据库中&#xff0c;使医生能够通过医…

7.消息应答

消费者完成一个任务可能需要一段时间&#xff0c;如果其中一个消费者处理一个长时间的任务并且只完成了部分突然就挂掉了&#xff0c;会发生什么情况&#xff1f; RabbitMQ一旦向消费者传递了一条消息&#xff0c;便立即将该消息标记为删除。这种情况下&#xff0c;突然有个消…

代码随想录算法训练营day6 | 242.有效的字母异位词、349. 两个数组的交集、202. 快乐数、1.两数之和

文章目录 哈希表键值 哈希函数哈希冲突拉链法线性探测法 常见的三种哈希结构集合映射C实现std::unordered_setstd::map 小结242.有效的字母异位词思路复习 349. 两个数组的交集使用数组实现哈希表的情况思路使用set实现哈希表的情况 202. 快乐数思路 1.两数之和思路 总结 今天是…

OpenCV 遍历Mat,像素操作,使用TrackBar 调整图像的亮度和对比度 C++实现

文章目录 1.使用C遍历Mat,完成颜色反转1.1 常规遍历方式1.2 迭代器遍历方式1.3指针访问方式遍历&#xff08;最快&#xff09;1.4不同遍历方式的时间对比 2.图像像素操作&#xff0c;提高图像的亮度3.TrackBar 进度条操作3.1使用TrackBar 调整图像的亮度3.2使用TrackBar 调整图…

学术研讨 | 区块链网络体系结构研讨会顺利召开

添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; 近日&#xff0c;国家区块链技术创新中心组织了“区块链网络体系结构研讨会”&#xff0c;会议面向跨域交互多、计算规模大、数据管理复杂、性能与扩展性要求高等特征的区块链网络的体系结构展开交流研讨&…

docker相关内容学习

一、docker的四部分 二、镜像相关命令 三、容器相关命令