【PyTorch】图像多分类项目部署

【PyTorch】图像多分类项目

【PyTorch】图像多分类项目部署

如果需要在独立于训练脚本的新脚本中部署模型,这种情况模型和权重在内存中不存在,因此需要构造一个模型类的对象,然后将存储的权重加载到模型中。

加载模型参数,验证模型的性能,并在测试数据集上部署模型

from torch import nn
from torchvision import models# 定义一个resnet18模型,不使用预训练参数
model_resnet18 = models.resnet18(pretrained=False)
# 获取模型的全连接层的输入特征数
num_ftrs = model_resnet18.fc.in_features
# 定义分类的类别数
num_classes=10
# 将全连接层的输出特征数改为分类的类别数
model_resnet18.fc = nn.Linear(num_ftrs, num_classes)import torch 
path2weights="./models/resnet18_pretrained.pt"
# 加载预训练的ResNet18模型权重
model_resnet18.load_state_dict(torch.load(path2weights))
# 将ResNet-18模型设置为评估模式
model_resnet18.eval();
# 检查CUDA是否可用
if torch.cuda.is_available():# 如果可用,将设备设置为CUDAdevice = torch.device("cuda")# 将模型移动到CUDA设备上model_resnet18=model_resnet18.to(device)def deploy_model(model,dataset,device, num_classes=10,sanity_check=False):# 获取数据集的长度len_data=len(dataset)# 初始化输出张量y_out=torch.zeros(len_data,num_classes)# 初始化真实标签张量y_gt=np.zeros((len_data),dtype="uint8")# 将模型移动到指定设备model=model.to(device)# 初始化时间列表elapsed_times=[]with torch.no_grad():for i in range(len_data):# 获取数据集中的一个样本x,y=dataset[i]# 将真实标签存入张量y_gt[i]=y# 记录开始时间start=time.time()    # 将输入数据传入模型进行预测yy=model(x.unsqueeze(0).to(device))# 将预测结果存入张量y_out[i]=torch.softmax(yy,dim=1)# 计算预测时间elapsed=time.time()-start# 将预测时间存入列表elapsed_times.append(elapsed)# 如果进行完整性检查,则跳出循环if sanity_check is True:break# 计算平均预测时间inference_time=np.mean(elapsed_times)*1000# 打印平均预测时间print("average inference time per image on %s: %.2f ms " %(device,inference_time))# 返回预测结果和真实标签return y_out.numpy(),y_gt
from torchvision import datasets
import torchvision.transforms as transforms# 数据转换
data_transformer = transforms.Compose([transforms.ToTensor()])path2data="./data"# 加载数据
test0_ds=datasets.STL10(path2data, split='test', download=True,transform=data_transformer)
print(test0_ds.data.shape)

from sklearn.model_selection import StratifiedShuffleSplit# 创建StratifiedShuffleSplit对象,设置分割次数为1,测试集大小为0.2,随机种子为0
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)# 获取test0_ds的索引
indices=list(range(len(test0_ds)))# 获取test0_ds的标签
y_test0=[y for _,y in test0_ds]# 对索引和标签进行分割
for test_index, val_index in sss.split(indices, y_test0):# 打印测试集和验证集的索引print("test:", test_index, "val:", val_index)# 打印测试集和验证集的大小print(len(val_index),len(test_index))

from torch.utils.data import Subset# 从test0_ds中选取val_index索引的子集,赋值给val_ds
val_ds=Subset(test0_ds,val_index)
# 从test0_ds中选取test_index索引的子集,赋值给test_ds
test_ds=Subset(test0_ds,test_index)
# 定义均值
mean=[0.4467106, 0.43980986, 0.40664646]
# 定义标准差
std=[0.22414584,0.22148906,0.22389975]
# 定义一个名为test0_transformer的变量,用于将一系列的图像变换操作组合在一起
test0_transformer = transforms.Compose([# 将图像转换为Tensor类型transforms.ToTensor(),# 对图像进行归一化操作,使用mean和std作为均值和标准差transforms.Normalize(mean, std),])   
# 将test0_transformer赋值给test0_ds的transform属性
test0_ds.transform=test0_transformer
import time
import numpy as np# 调用deploy_model函数,传入model_resnet18,val_ds,device和sanity_check参数,返回y_out和y_gt
y_out,y_gt=deploy_model(model_resnet18,val_ds,device=device,sanity_check=False)
# 打印y_out和y_gt的形状
print(y_out.shape,y_gt.shape)

from sklearn.metrics import accuracy_score# 将y_out中的最大值索引赋值给y_pred
y_pred = np.argmax(y_out,axis=1)
# 打印y_pred和y_gt的形状
print(y_pred.shape,y_gt.shape)# 计算并打印y_pred和y_gt的准确率
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署模型,得到预测结果和真实标签
y_out,y_gt=deploy_model(model_resnet18,test_ds,device=device)# 取出预测结果中概率最大的类别
y_pred = np.argmax(y_out,axis=1)# 计算准确率
acc=accuracy_score(y_pred,y_gt)# 打印准确率
print(acc)

from torchvision import utils
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
np.random.seed(1)# 定义一个函数,用于显示图像
def imshow(inp, title=None):# 定义图像的均值和标准差mean=[0.4467106, 0.43980986, 0.40664646]std=[0.22414584,0.22148906,0.22389975]# 将图像从tensor转换为numpy数组,并转置inp = inp.numpy().transpose((1, 2, 0))# 将均值和标准差转换为numpy数组mean = np.array(mean)std = np.array(std)# 将图像的像素值进行归一化inp = std * inp + mean# 将像素值限制在0和1之间inp = np.clip(inp, 0, 1)# 显示图像plt.imshow(inp)# 如果有标题,则显示标题if title is not None:plt.title(title)# 暂停0.001秒plt.pause(0.001) # 定义网格大小
grid_size=16
# 随机生成4个索引
rnd_inds=np.random.randint(1,len(test_ds),grid_size)
# 打印随机生成的索引
print("image indices:",rnd_inds)# 根据索引获取对应的图像和标签
x_grid_test=[test_ds[i][0] for i in rnd_inds]
y_grid_test=[(y_pred[i],y_gt[i]) for i in rnd_inds]# 将图像转换为网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
# 打印网格的形状
print(x_grid_test.shape)# 设置图像的大小
plt.rcParams['figure.figsize'] = (10, 10)
# 显示网格
imshow(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/50016.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sftp和ftp 区别、工作原理

Sftp的工作原理: SFTP的工作原理基于SSH协议,通过加密连接和安全认证来保障文件传输的安全性。 SFTP(SSH File Transfer Protocol)是一个确保数据在传输过程中安全的协议,它通过为传输的数据提供加密保护和对用户…

Docker重启策略和缩小镜像体积

目录 Docker重启策略 命令格式 命令选项 Docker缩小镜像体积 Docker重启策略 命令格式 docker run --restartno|always|on-failure|unless-stopped .... 命令选项 no:不管容器是正常退出还是异常退出,都不会重启容器。默认策略。always&#xf…

Java面试锦集 之 一、Java基础(1)

一、Java基础(1) 1.final 关键字的作用? 修饰变量: 一旦被赋值,就不能再被修改,保证了变量值的稳定性。 例: final int NUMBER 10; //之后就不能再改变 NUMBER 的值了。修饰方法:…

图解 HDFS 架构 |读写过程

HDFS HDFS 全称 Hadoop Distributed File System,是一个分布式文件系统。HDFS(Hadoop Distributed File System)是 Apache Hadoop 生态系统的一部分,它是一个分布式文件系统,用于存储和处理大规模数据集。HDFS 专门设…

【学习笔记】子集DP

背景 有一类问题和子集有关。 给你一个集合 S S S&#xff0c;令 T T T 为 S S S 的超集&#xff0c;也就是 S S S 所有子集的集合&#xff0c;求 T T T 中所有元素的和。 暴力1 先预处理子集的元素和 A i A_i Ai​&#xff0c;再枚举子集。 for(int s0; s<(1<…

源代码防泄密如何做?企业如何有效选择源代码防泄密产品?

源代码防泄密怎么选&#xff1f;如何高效做源代码防泄密工作&#xff1f; 源代码开发环境复杂&#xff0c;涉及的开发软件和文件类型众多且变化多端&#xff0c;那么究竟有哪些源代码防泄密软件能够适应各种开发软件而不影响原有的工作效率呢&#xff1f; 对于研发人员来说&a…

SpringBoot3 使用虚拟线程

目录 环境 说明 application.properties pom.xml Controller Java 验证 环境 Java Java: graalvm-jdk-21Springboot3.3.1 说明 springboot3.x 打开虚拟线程非常简单&#xff0c;只需添加一行配置信息即可。 spring.threads.virtual.enabled true application.properti…

探索 Framer Motion 高级动画技巧:提升前端设计水平

在现代的网页和应用设计中&#xff0c;动画不仅仅是视觉的点缀&#xff0c;更是用户体验的重要组成部分。它能够使界面更具吸引力&#xff0c;提升交互的流畅性&#xff0c;甚至在不经意间传达品牌的个性和态度。然而&#xff0c;要创造出令人惊叹的动效并不容易——直到有了 F…

经验——OLED的使用

型号&#xff1a;HS96L01W 4S03 分辨率&#xff1a;120*64 通讯方式&#xff1a;4线SPI 模式00 MCU&#xff1a;MSPM0G3507&#xff08;只影响SPI的配置&#xff09; 原本照着型号搜到了嘉立创的使用文档&#xff0c;但是实际上并不能正常使用&#xff0c;后来寻到了一篇博客…

PyTorch可以用来干嘛?

PyTorch 是一个广泛使用的开源机器学习库&#xff0c;由 Facebook AI Research&#xff08;FAIR&#xff09;开发。它主要用于计算机视觉和自然语言处理等深度学习领域&#xff0c;但也可以应用于许多其他类型的机器学习任务。PyTorch 提供了丰富的功能和灵活的设计&#xff0c…

MFC与QT中禁用Esc、Alt+F4、关闭图标

在业务中&#xff0c;我们需要按指定的方式才能关闭当前对话框。如下图需输入密码点击确认后&#xff0c;界面才能关闭。 方法1&#xff1a;通过禁用界面的按钮以及键盘上对应关闭对话框的按键。 1.灰度化关闭按钮 在对话框初始化部分添加将关闭按钮禁用 //MFC CMenu *pSysMe…

主要的国产信创数据库有哪些

数据库生态分类 当前数据库生态可以大致分类三类: 一、传统商业数据库&#xff0c;以 Oracle 为代表&#xff0c;其在 40 余年时间里所创造的数据库帝国已拥有了极其完善的生态; 二、开源数据库&#xff0c;以 MYSQL、PostgreSQL为代表&#xff0c;遍布全球的社区组织形成了强…

大文件分片上传(前端TS实现)

大文件分片上传 内容 一般情况下&#xff0c;前端上传文件就是new FormData,然后把文件 append 进去&#xff0c;然后post发送给后端就完事了&#xff0c;但是文件越大&#xff0c;上传的文件也就越长&#xff0c;如果在上传过程中&#xff0c;突然网络故障&#xff0c;又或者…

AHK是让任何软件都支持 Shift + 鼠标滚轮 实现界面水平滚动

目录 基本介绍 详细特点 图解安装 下载失败&#xff1f;缓慢&#xff1f; 创建并运行脚本代码&#x1f603; 新建空 xxx.ahk文件 vscode/记事本等编辑工具打开 复制并粘贴简易脚本 运行 其他问题 问题一&#xff1a;弹出无法执行此脚本 关闭脚本 基本介绍 AutoHot…

【MetaGPT系列】【MetaGPT完全实践宝典——如何定义单一行为多行为Agent】

目录 前言一、智能体1-1、Agent概述1-2、Agent与ChatGPT的区别 二、多智能体框架MetaGPT2-1、安装&配置2-2、使用已有的Agent&#xff08;ProductManager&#xff09;2-3、拥有单一行为的Agent&#xff08;SimpleCoder&#xff09;2-3-1、定义写代码行为2-3-2、角色定义2-3…

B站音视频分开 大小问题

音频是33331 kb&#xff0c;视频是374661 kb 合并之后却是2561363 kb 这可能是B站音频和视频分开的原因吧

grub之loongarch架构调试

一 什么是grub GNU GRUB 是一个多重操作系统启动管理器。GNU GRUB是由GRUB&#xff08;GRandUnified Bootloader&#xff09;派生而来。 GRUB最初由Erich Stefan Boleyn 设计和应用&#xff1b; 主流发行版 Fedora、Redhat、Centos、Kylin 等基于RPM包的系统&#xff0c;在最新…

04 ES6中对象的简写

在 ES6 中&#xff0c;对象字面量的书写方式进行了一些简化&#xff0c;使得对象的创建更加简洁。以下是 ES6 中对象简写的几种形式&#xff1a; 属性值缩写&#xff1a; 当对象的属性名和属性值的变量名相同时&#xff0c;可以省略属性值&#xff0c;只写属性名。 // ES5 cons…

如何在Linux上安装配置RabbitMQ消息队列

RabbitMQ是一种开源的消息中间件&#xff0c;基于AMQP协议实现。它可以在分布式系统中传递消息&#xff0c;并提供了可靠的消息传递机制。RabbitMQ使用一种称为"消息队列"的方式来管理消息的发送和接收。它的主要特性包括&#xff1a; 可靠性&#xff1a;RabbitMQ使用…

Zabbix监控案例

文章目录 一、监控linux TCP连接状态TCP端口的十一种连接状态自定义监控项监控示例二、监控模板监控tcp连接监控nginx 一、监控linux TCP连接状态 TCP&#xff0c;全称Transfer Control Protocol&#xff0c;中文名为传输控制协议&#xff0c;它工作在OSI的传输层&#xff0c;…