【PyTorch】图像二分类项目-部署

【PyTorch】图像二分类项目

【PyTorch】图像二分类项目-部署

在独立于训练脚本的新脚本中部署用于推理的模型,需要构造一个模型类的对象,并将权重加载到模型中。操作流程为:定义模型--加载权重--在验证和测试数据集上部署模型。

import torch.nn as nn
import numpy as np
# 设置随机种子
np.random.seed(0)
import torch.nn as nn
import torch.nn.functional as F# 定义一个函数,用于计算卷积层的输出形状
def findConv2dOutShape(H_in,W_in,conv,pool=2):# 获取卷积核的大小kernel_size=conv.kernel_size# 获取卷积的步长stride=conv.stride# 获取卷积的填充padding=conv.padding# 获取卷积的扩张dilation=conv.dilation# 计算卷积后的高度H_out=np.floor((H_in+2*padding[0]-dilation[0]*(kernel_size[0]-1)-1)/stride[0]+1)# 计算卷积后的宽度W_out=np.floor((W_in+2*padding[1]-dilation[1]*(kernel_size[1]-1)-1)/stride[1]+1)# 如果pool不为空if pool:# 将H_out除以poolH_out/=poolW_out/=pool# 返回H_out和W_out的整数形式return int(H_out),int(W_out)class Net(nn.Module):def __init__(self, params):super(Net, self).__init__()# 获取输入形状C_in,H_in,W_in=params["input_shape"]# 获取初始滤波器数量init_f=params["initial_filters"] # 获取第一个全连接层神经元数量num_fc1=params["num_fc1"]  # 获取类别数量num_classes=params["num_classes"] # 获取模型的dropout率,是0到1间的浮点数# Dropout是一种正则化技术,随机关闭部分神经元(输出设为0),防止过拟合,提高泛化能力self.dropout_rate=params["dropout_rate"] # 定义第一个卷积层self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3)# 计算第一个卷积层的输出形状h,w=findConv2dOutShape(H_in,W_in,self.conv1)self.conv2 = nn.Conv2d(init_f, 2*init_f, kernel_size=3)h,w=findConv2dOutShape(h,w,self.conv2)self.conv3 = nn.Conv2d(2*init_f, 4*init_f, kernel_size=3)h,w=findConv2dOutShape(h,w,self.conv3)self.conv4 = nn.Conv2d(4*init_f, 8*init_f, kernel_size=3)h,w=findConv2dOutShape(h,w,self.conv4)# 计算全连接层的输入形状self.num_flatten=h*w*8*init_f# 定义第一个全连接层self.fc1 = nn.Linear(self.num_flatten, num_fc1)self.fc2 = nn.Linear(num_fc1, num_classes)# 定义前向传播函数,接收输入xdef forward(self, x):# 第一个卷积层x = F.relu(self.conv1(x))# 第一个池化层x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv3(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv4(x))x = F.max_pool2d(x, 2, 2)# 将卷积层的输出展平x = x.view(-1, self.num_flatten)# 第一个全连接层x = F.relu(self.fc1(x))# Dropout层x=F.dropout(x, self.dropout_rate)# 第二个全连接层x = self.fc2(x)# 返回输入x应用对数软最大变换后的输出# log-softmax对数软最大值函数,常用于计算交叉熵损失函数(cross-entropy loss),因为交叉熵损失函数需要计算概率的对数。# dim参数指定了在哪个维度上应用log-softmax。例如,如果dim=1,则对每一行应用log-softmax。return F.log_softmax(x, dim=1)# 定义模型参数
params_model={# 输入形状"input_shape": (3,96,96),# 初始过滤器数量"initial_filters": 8, # 全连接层1的神经元数量"num_fc1": 100,# Dropout率"dropout_rate": 0.25,# 类别数量"num_classes": 2,
}# 创建一个CNN模型,参数为params_model
cnn_model = Net(params_model)import torch
# 权重文件路径
path2weights="./models/weights.pt"# 加载权重文件
cnn_model.load_state_dict(torch.load(path2weights))# 进入评估模式
cnn_model.eval()

# 移动模型至cuda设备
if torch.cuda.is_available():device = torch.device("cuda")cnn_model=cnn_model.to(device) 
import time # 定义一个函数,用于部署模型
def deploy_model(model,dataset,device, num_classes=2,sanity_check=False):# num_classes:类别数,默认为2# sanity_check:是否进行完整性检查,默认为Falsepass# 获取数据集长度len_data=len(dataset)# 初始化输出张量y_out=torch.zeros(len_data,num_classes)# 初始化真实标签张量y_gt=np.zeros((len_data),dtype="uint8")# 将模型移动到指定设备model=model.to(device)# 存储每次推理的时间elapsed_times=[]# 在不计算梯度的情况下执行以下代码with torch.no_grad():for i in range(len_data):# 获取数据集中的一个样本x,y=dataset[i]# 将真实标签存储到张量中y_gt[i]=y# 记录开始时间start=time.time()    # 进行推理y_out[i]=model(x.unsqueeze(0).to(device))# 计算推理时间elapsed=time.time()-start# 将推理时间存储到列表中elapsed_times.append(elapsed)# 如果进行完整性检查,则只进行一次推理if sanity_check is True:break# 计算平均推理时间inference_time=np.mean(elapsed_times)*1000# 打印平均推理时间print("average inference time per image on %s: %.2f ms " %(device,inference_time))# 返回推理结果和真实标签return y_out.numpy(),y_gtimport torch
from PIL import Image
from torch.utils.data import Dataset
import pandas as pd
import torchvision.transforms as transforms
import os# 设置随机种子,使得每次运行代码时生成的随机数相同
torch.manual_seed(0)class histoCancerDataset(Dataset):def __init__(self, data_dir, transform,data_type="train"):      # 获取数据目录path2data=os.path.join(data_dir,data_type)# 获取数据目录下的所有文件名self.filenames = os.listdir(path2data)# 获取数据目录下的所有文件的完整路径self.full_filenames = [os.path.join(path2data, f) for f in self.filenames]# 获取标签文件名csv_filename=data_type+"_labels.csv"# 获取标签文件的完整路径path2csvLabels=os.path.join(data_dir,csv_filename)# 读取标签文件labels_df=pd.read_csv(path2csvLabels)# 将标签文件的索引设置为文件名labels_df.set_index("id", inplace=True)# 获取每个文件的标签self.labels = [labels_df.loc[filename[:-4]].values[0] for filename in self.filenames]# 获取数据转换函数self.transform = transformdef __len__(self):# 返回数据集的长度return len(self.full_filenames)def __getitem__(self, idx):# 根据索引获取图像image = Image.open(self.full_filenames[idx])  # 对图像进行转换image = self.transform(image)# 返回图像和标签return image, self.labels[idx]import torchvision.transforms as transforms
# 创建一个数据转换器,将数据转换为张量
data_transformer = transforms.Compose([transforms.ToTensor()])data_dir = "./data/"
# 传入数据目录、数据转换器和数据集类型
histo_dataset = histoCancerDataset(data_dir, data_transformer, "train")
# 打印数据集的长度
print(len(histo_dataset))

from torch.utils.data import random_split# 获取数据集的长度
len_histo=len(histo_dataset)
# 训练集取数据集的80%
len_train=int(0.8*len_histo)
# 验证集取数据集的20%
len_val=len_histo-len_train# 将数据集随机分割为训练集和验证集
train_ds,val_ds=random_split(histo_dataset,[len_train,len_val])# 打印训练集和验证集的长度
print("train dataset length:", len(train_ds))
print("validation dataset length:", len(val_ds))

 

# 部署模型 
y_out,y_gt=deploy_model(cnn_model,val_ds,device=device,sanity_check=False)
# 打印输出和真实值的形状
print(y_out.shape,y_gt.shape)

使用预测输出计算模型在验证数据集上的精度

from sklearn.metrics import accuracy_score# 获取预测
y_pred = np.argmax(y_out,axis=1)
print(y_pred.shape,y_gt.shape)# 计算精度 
acc=accuracy_score(y_pred,y_gt)
print("accuracy: %.2f" %acc)

 

# 部署在CPU上
device_cpu = torch.device("cpu")
y_out,y_gt=deploy_model(cnn_model,val_ds,device=device_cpu,sanity_check=False)
print(y_out.shape,y_gt.shape)

复制data文件夹中的sample_submission.csv文件并命名为test_labels.csv

path2csv="./data/test_labels.csv"
# 读取csv文件,并存储到DataFrame中
labels_df=pd.read_csv(path2csv)
# 显示DataFrame的前几行
labels_df.head()

data_dir = "./data/"
# 创建测试数据集
histo_test = histoCancerDataset(data_dir, data_transformer,data_type="test")
# 打印测试数据集的长度
print(len(histo_test))

 

# 用测试数据集部署
y_test_out,_=deploy_model(cnn_model,histo_test, device, sanity_check=False)# 使用np.argmax函数对y_test_out进行操作,得到y_test_pred
y_test_pred=np.argmax(y_test_out,axis=1)# 打印y_test_pred的形状
print(y_test_pred.shape)

from torchvision import utilsimport numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
np.random.seed(0)# 定义一个函数,用于显示图像和标签
def show(img,y,color=True):# 将图像转换为numpy数组npimg = img.numpy()# 将图像的维度从(C,H,W)转换为(H,W,C)npimg_tr=np.transpose(npimg, (1,2,0))# 如果color为False,则将图像转换为灰度图像if color==False:npimg_tr=npimg_tr[:,:,0]plt.imshow(npimg_tr,interpolation='nearest',cmap="gray")else:# 否则,直接显示图像plt.imshow(npimg_tr,interpolation='nearest')# 显示图像的标签plt.title("label: "+str(y))# 定义一个网格大小
grid_size=4
# 随机选择grid_size个图像的索引
rnd_inds=np.random.randint(0,len(histo_test),grid_size)
print("image indices:",rnd_inds)# 从histo_test中获取grid_size个图像
x_grid_test=[histo_test[i][0] for i in range(grid_size)]
# 从y_test_pred中获取grid_size个标签
y_grid_test=[y_test_pred[i] for i in range(grid_size)]# 将grid_size个图像组合成一个网格
x_grid_test=utils.make_grid(x_grid_test, nrow=4, padding=2)
print(x_grid_test.shape)# 设置图像的大小
plt.rcParams['figure.figsize'] = (10.0, 5)
# 显示图像和标签
show(x_grid_test,y_grid_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47312.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于PSO算法优化PID参数的一些问题

目录 前言 Q1:惯性权重ω如何设置比较好?学习因子C1和C2如何设置? Q2:迭代速度边界设定一定能够遍历(/覆盖)整个PID参数二维空间范围吗?还是说需要与迭代次数相关?迭代次数越高&a…

MATLAB图像处理分析基础(一)

一、引言 MATLAB软件得到许多数字图像处理学生、老师和科研工作者的喜爱,成为数字图像处理领域不可或缺的工具之一,其与其他软件相比有以下诸多显著优点。首先,MATLAB 拥有强大的内置函数库,涵盖了图像读取、显示、处理及分析的全…

【学习笔记】无人机系统(UAS)的连接、识别和跟踪(九)-无人机区域地面探测与避让(DAA)

引言 3GPP TS 23.256 技术规范,主要定义了3GPP系统对无人机(UAV)的连接性、身份识别、跟踪及A2X(Aircraft-to-Everything)服务的支持。 3GPP TS 23.256 技术规范: 【免费】3GPPTS23.256技术报告-无人机系…

ESP8266模块(2)

实例1 查看附近的WiFi 步骤1:进入AT指令模式 使用USB转串口适配器将ESP8266模块连接到电脑。打开串口终端软件,并设置正确的串口和波特率(通常为115200)。输入以下命令并按回车确认: AT如果模块响应OK,…

【计算机网络】0 课程主要内容(自顶向下方法,中科大郑烇、杨坚)(待)

1 教学目标 掌握计算机网络 基本概念 工作原理 常用技术 为将来学习、应用和研究计算机网络打下坚实基础 2 课程主要内容 1 计算机网络和互联网2 应用层3 传输层4 网络层:数据平面5 网络层:控制平面6 数据链路层和局域网7 网络安全8 无线和移动网络9 多…

构建gitlab远端服务器(check->build->test->deploy)

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 TODO:写完再整理 文章目录 系列文章目录前言构建gitlab远端服务器一、步骤一:搭建gitlab的运行服务器【运维】1. 第一步:硬件服务器准备工作(1)选择合适的硬件和操作系统linux(2)安装必…

Learning vtkjs之WarpScalar

过滤器 WarpScalar 介绍 先看一个官方的一句话介绍: vtkWarpScalar - deform geometry with scalar data vtkWarpScalar - 使用标量数据变形几何体 详细介绍 vtkWarpScalar is a filter that modifies point coordinates by moving points along point normals by…

spss数据分析是什么 怎么下载spss

什么是SPSS SPSS是社会统计科学软件包的简称, 其官方全称为IBM SPSS Statistics。SPSS软件包最初由SPSS Inc.于1968年推出,于2009年被IBM收购,主要运用于各领域数据的管理和统计分析。作为世界社会科学数据分析的标准,SPSS操作操作…

C++合作开发项目:美术馆1.0

快乐星空MakerZINCFFO 合作入口&#xff1a;CM工作室 效果图&#xff1a; 代码&#xff1a; &#xff08;还有几个音乐&#xff01;&#xff09; main.cpp #include <bits/stdc.h> #include <windows.h> #include <conio.h> #include <time.h> #in…

《数据结构》--顺序表

C语言语法基础到数据结构与算法&#xff0c;前面已经掌握并具备了扎实的C语言基础&#xff0c;为什么要学习数据结构课程&#xff1f;--我们学完本章就可以实践一个&#xff1a;通讯录项目 简单了解过后&#xff0c;通讯录具备增加、删除、修改、查找联系人等操作。要想实现通…

Python学习笔记—100页Opencv详细讲解教程

目录 1 创建和显示窗口... - 4 - 2 加载显示图片... - 6 - 3 保存图片... - 7 - 4 视频采集... - 8 - 5视频录制... - 11 - 6 控制鼠标... - 12 - 7 TrackBar 控件... - 14 - 8.RGB和BGR颜色空间... - 16 - 9.HSV和HSL和YUV.. - 17 - 10 颜色空间的转化... - 18 - …

数据结构——栈的实现(java实现)与相应的oj题

文章目录 一 栈栈的概念:栈的实现&#xff1a;栈的数组实现默认构造方法压栈获取栈元素的个数出栈获取栈顶元素判断当前栈是否为空 java提供的Stack类Stack实现的接口&#xff1a; LinkedList也可以当Stack使用虚拟机栈&#xff0c;栈帧&#xff0c;栈的三个概念 二 栈的一些算…

JetBrains IDE 使用git进行多人合作开发教程

以下DEMO可以用于多人共同开发维护一个项目时&#xff0c;使用Git远程仓库的实践方案 分支管理 dev&#xff1a;开发分支test&#xff1a;测试分支prod&#xff1a;生成分支 个人开发也最起码有一个masterdev&#xff0c;作为主分支和当前开发分支。master永远是稳定版本&am…

花几千上万学习Java,真没必要!(十九)

1、StringBuilder&#xff1a; 测试代码1&#xff1a; package stringbuilder.com; import java.util.ArrayList; import java.util.List; public class StringBuilderExample { public static void main(String[] args) { // 初始化StringBuilder StringBuilder sb n…

腾讯会议产品策划的成长之路:从万字文档到功能落地的实战经验

腾讯会议产品策划的成长之路&#xff1a;从万字文档到功能落地的实战经验 在腾讯会议的产品团队中&#xff0c;有这样一位产品策划&#xff0c;他以其出色的逻辑思维、全局观念以及扎实的执行力&#xff0c;在团队中发挥着举足轻重的作用。他就是林陪同&#xff0c;一个自称“会…

JAVA进阶学习12

文章目录 一、File类1.1 File对象的构造1.2 File对象的常见方法判断功能的方法获取功能的方法绝对路径和相对路径创建删除功能的方法 1.3 File的常用遍历方法1.4 File获取并遍历的其他方法1.5 用法举例二、IO流2.1 IO的分类2.2 字节流的方法概述2.2.1 FileOutputStream2.2.2 Fi…

UE4-字体导入

一.字体导入 方法一&#xff1a; 然后通过导入将自己想要的字体导入到项目中&#xff0c;也可以直接将我们放在桌面的字体直接拖入到我们的内容浏览器中。 但是要注意想要发售游戏的话不可以这样导入微软的字体&#xff0c;因为Windows自带基本都有版权&#xff0c;所以最…

明星应援系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;线上应援管理&#xff0c;线下应援管理&#xff0c;应援物品管理&#xff0c;购买订单管理&#xff0c;集资应援管理&#xff0c;集资订单管理&#xff0c;市集订单管理&#xff0…

【CSS in Depth 2 精译_020】3.3 元素的高度

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第一章 层叠、优先级与继承&#xff08;已完结&#xff09; 1.1 层叠1.2 继承1.3 特殊值1.4 简写属性1.5 CSS 渐进式增强技术1.6 本章小结 第二章 相对单位&#xff08;已完结&#xff09; 2.1 相对…

在 CI/CD 中怎么使用 Docker 部署前端项目?

本项目代码已开源&#xff0c;具体见&#xff1a; 前端工程&#xff1a;vue3-ts-blog-frontend 后端工程&#xff1a;express-blog-backend 数据库初始化脚本&#xff1a;关注公众号程序员白彬&#xff0c;回复关键字“博客数据库脚本”&#xff0c;即可获取。 前言 在上一篇文…