使用 PyTorch 实现 ZFNet 进行 MNIST 图像分类

         在本篇博客中,我们将通过两个主要部分来演示如何使用 PyTorch 实现 ZFNet,并在 MNIST 数据集上进行训练和测试。ZFNet(ZFNet)是基于卷积神经网络(CNN)的图像分类模型,广泛用于图像识别任务。

环境准备

        在开始之前,请确保你的环境已经安装了以下依赖:

pip install torch torchvision matplotlib tqdm

一、训练部分:训练 ZFNet 模型

首先,我们需要准备训练数据、定义 ZFNet 模型,并进行模型训练。

1. 数据加载与预处理

MNIST 数据集由 28x28 的手写数字图像组成。我们将通过 torchvision.datasets 来加载数据,并进行必要的预处理。

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from zfnet import ZFNet  # 假设 ZFNet 定义在 zfnet.py 文件中
from tqdm import tqdm  # 导入 tqdm
from torch.cuda.amp import autocast, GradScaler  # 导入混合精度训练def prepare_data(batch_size=128, num_workers=2, data_dir='D:/workspace/data'):"""准备 MNIST 数据集并返回数据加载器:param batch_size: 批处理大小:param num_workers: 数据加载的工作线程数:param data_dir: 数据存储的目录:return: 训练数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 正则化])trainset = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)return trainloader

2. 初始化模型与优化器

在这里,我们将初始化模型和优化器。我们选择 Adam 优化器,并且为提高计算效率,我们采用混合精度训练。

def initialize_device():"""初始化计算设备(GPU 或 CPU):return: 计算设备"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")return devicedef initialize_model(device):"""初始化模型并移动到指定设备:param device: 计算设备:return: 初始化好的模型"""model = ZFNet().to(device)  # 假设 ZFNet 是自定义模型return modeldef initialize_optimizer(model, lr=0.001):"""初始化优化器:param model: 需要优化的模型:param lr: 学习率:return: 优化器"""optimizer = optim.Adam(model.parameters(), lr=lr)return optimizer

3. 训练模型

使用训练数据进行训练,并且每训练一个 epoch 就更新一次进度条,同时使用混合精度训练来提高效率。

def train_model(model, trainloader, criterion, optimizer, num_epochs=5, device='cuda'):"""训练模型:param model: 训练的模型:param trainloader: 数据加载器:param criterion: 损失函数:param optimizer: 优化器:param num_epochs: 训练的轮数:param device: 计算设备"""scaler = GradScaler()  # 用于自动缩放梯度for epoch in range(num_epochs):model.train()running_loss = 0.0# 使用 tqdm 包裹 DataLoader 来显示进度条with tqdm(trainloader, unit="batch", desc=f"Epoch {epoch + 1}/{num_epochs}") as tepoch:for inputs, labels in tepoch:# 直接将数据和标签移动到 GPUinputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)optimizer.zero_grad()# 混合精度前向和反向传播with autocast():  # 自动混合精度outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播与优化scaler.scale(loss).backward()  # 使用混合精度反向传播scaler.step(optimizer)  # 更新参数scaler.update()  # 更新缩放因子running_loss += loss.item()# 更新进度条显示tepoch.set_postfix(loss=running_loss / (tepoch.n + 1))# 打印每个 epoch 的平均损失print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}")# 保存模型torch.save(model.state_dict(), 'zfnet_model.pth')print("Model saved as zfnet_model.pth")

4. 主函数

在主函数中,我们会初始化设备、模型、损失函数,并启动训练过程。

if __name__ == '__main__':"""主函数:组织所有步骤的执行"""# 数据加载trainloader = prepare_data()# 设备选择device = initialize_device()# 模型初始化model = initialize_model(device)# 损失函数criterion = torch.nn.CrossEntropyLoss()# 优化器初始化optimizer = initialize_optimizer(model)# 启动训练train_model(model, trainloader, criterion, optimizer, num_epochs=5, device=device)

二、测试部分:评估 ZFNet 模型

训练完成后,我们将加载训练好的模型,并在测试集上评估其性能。

1. 加载和预处理数据
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from zfnet import ZFNet  # 假设 ZFNet 定义在 zfnet.py 文件中def load_and_preprocess_data(batch_size=1000):"""加载并预处理 MNIST 数据集:param batch_size: 数据加载的批次大小:return: 测试数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 下载 MNIST 测试集testset = datasets.MNIST(root='D:/workspace/data', train=False, download=True, transform=transform)# 数据加载器testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)return testloader

2. 加载训练好的模型
def load_and_preprocess_data(batch_size=1000):"""加载并预处理 MNIST 数据集:param batch_size: 数据加载的批次大小:return: 测试数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 下载 MNIST 测试集testset = datasets.MNIST(root='D:/workspace/data', train=False, download=True, transform=transform)# 数据加载器testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)return testloaderdef load_trained_model(model_path='zfnet_model.pth'):"""加载训练好的模型:param model_path: 模型文件路径:return: 加载的模型"""model = ZFNet()model.load_state_dict(torch.load(model_path))model.eval()  # 设置为评估模式return model

3. 评估模型
def evaluate_model(model, testloader):"""评估模型在测试集上的表现:param model: 训练好的模型:param testloader: 测试数据加载器:return: 模型准确率"""correct = 0total = 0with torch.no_grad():for inputs, labels in testloader:outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalreturn accuracy

4. 可视化预测结果
def visualize_predictions(model, testloader, num_images=6):"""可视化模型对多张测试图片的预测结果:param model: 训练好的模型:param testloader: 测试数据加载器:param num_images: 显示图像的数量"""model.eval()data_iter = iter(testloader)images, labels = next(data_iter)outputs = model(images)_, predicted = torch.max(outputs, 1)# 绘制结果fig, axes = plt.subplots(2, 3, figsize=(10, 7))axes = axes.ravel()for i in range(num_images):ax = axes[i]img = images[i].numpy().transpose(1, 2, 0)  # 将 Tensor 转换为 NumPy 数组并转置为 HWC 格式ax.imshow(img.squeeze(), cmap='gray')  # squeeze 去除单通道维度ax.set_title(f"Pred: {predicted[i].item()} | Actual: {labels[i].item()}")ax.axis('off')plt.tight_layout()plt.show()

5. 主函数

在测试阶段,我们加载模型并在测试数据集上评估它。

def main():"""主函数,组织数据加载、模型加载、评估和可视化步骤"""# 加载并预处理数据testloader = load_and_preprocess_data()# 加载训练好的模型model = load_trained_model()# 评估模型accuracy = evaluate_model(model, testloader)print(f"Accuracy: {accuracy * 100:.2f}%")# 可视化预测结果visualize_predictions(model, testloader, num_images=6)if __name__ == '__main__':main()


结语

通过本文的介绍,我们实现了一个基于 ZFNet 模型的图像分类任务,使用 PyTorch 对 MNIST 数据集进行训练与测试,并展示了如何进行混合精度训练以提高效率。在未来,你可以根据不同的任务修改模型结构、优化器或者训练策略,进一步提升性能。


完整项目ZFNet-PyTorch: 使用 PyTorch 实现 ZFNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/zfnet-py-torch


  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/886757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【计算机网络实验】之静态路由配置

【计算机网络实验】之静态路由配置 实验题目实验目的实验任务实验设备实验环境实验步骤路由器配置设置静态路由测试路由器之间的连通性配置主机PC的IP测试 实验题目 静态路由协议的配置 实验目的 熟悉路由器工作原理和机制;巩固静态路由理论;设计简单…

driver.js实现页面操作指引

概述 在访问某些网站的时候,第一次进去你会发现有个操作指引,本文引用driver.js,教你在你的页面也加入这般高大上的操作指引。 实现效果 实现 driver.js简介 driver.js是一个功能强大且高度可定制的基于原生JavaScript开发的新用户引导库…

无人机航测技术算法概述!

一、核心技术 传感器技术: GPS/GLONASS:无人机通过卫星定位系统实现高精度的飞行控制和数据采集。 高清相机:用于拍摄地面图像,通过后续图像处理生成三维模型。 激光雷达(LiDAR):通过激光扫…

Docker 基础命令介绍和常见报错解决

介绍一些 docker 可能用到的基础命令,并解决三个常见报错: 权限被拒绝(Permission Denied)无法连接到 Docker 仓库(Timeout Exceeded)磁盘空间不足(No Space Left on Device) 命令以…

Java RPC框架的接口预热原理及无损实现

🚀 博主介绍:大家好,我是无休居士!一枚任职于一线Top3互联网大厂的Java开发工程师! 🚀 🌟 在这里,你将找到通往Java技术大门的钥匙。作为一个爱敲代码技术人,我不仅热衷…

java的强,软,弱,虚引用介绍以及应用

写在前面 本文看下Java的强,软,弱,虚引用相关内容。 1:各种引用介绍 顶层类是java.lang.ref.Reference,注意是一个抽象类,而不是接口,其中比较重要的引用队列ReferenceQueue就在该类中定义,子…

已有docker增加端口号,不用重新创建Docker

已有docker增加端口号,不用重新创建Docker 1. 整体描述2. 具体实现2.1 查看容器id2.2 停止docker服务2.3 修改docker配置文件2.4 重启docker服务 3. 总结 1. 整体描述 docker目前使用的非常多,但是每次更新都需要重新创建docker,也不太方便&…

jmeter常用配置元件介绍总结之断言

系列文章目录 1.windows、linux安装jmeter及设置中文显示 2.jmeter常用配置元件介绍总结之安装插件 3.jmeter常用配置元件介绍总结之线程组 4.jmeter常用配置元件介绍总结之函数助手 5.jmeter常用配置元件介绍总结之取样器 6.jmeter常用配置元件介绍总结之jsr223执行pytho…

OpenLayers教程12_WebGL自定义着色器:实现高级渲染效果

在 OpenLayers 中使用 WebGL 自定义着色器实现高级渲染效果 目录 一、引言二、WebGL 自定义着色器的优势三、示例应用:实现动态渲染效果 1. 项目结构2. 主要代码实现3. 运行与效果 四、代码讲解与扩展 1. 动态圆的半径和填充颜色2. 动态透明度与边框效果 五、总结…

Axure二级菜单下拉交互实例

1.使用boxlabe进行基础布局 2.设置鼠标悬浮和选中状态 3.转换为动态面板 选中所有二级菜单,进行按钮组转换 选中所有二级菜单,进行动态面板转换 4.给用户管理增加显示/隐藏事件 1)选择toggle代表上拉和下拉切换加载 2)勾选Bring to Front,并选择Push/Pull Widgets代表收缩时…

SpringSecurity+OAuth2权限管理

Spring Security 零 介绍 功能: 身份认证(authentication) 授权(authorization) 防御常见攻击(protection against common attacks) 身份认证: 身份认证是验证谁正在访问系统资…

为什么芯麦的 GC4931P 可以替代A4931/Allegro 的深度对比介绍

在电机驱动芯片领域,芯麦 GC4931P 和 A4931 都是备受关注的产品。它们在多种应用场景中发挥着关键作用,今天我们就来详细对比一下这两款芯片。 一、性能参数对比 (一)电流输出能力 A4931 具有一定的电流输出能力,但芯…

ThreadLocal原理及其内存泄漏

ThreadLocal通过为每个线程创建一个共享变量的副本来保证各个线程之间变量的访问和修改互不影响。 ThreadLocal存放的值是线程内共享的,线程间互斥的,主要用于线程内共享数据,避免通过参数传递。 ThreadLocal有四个方法: initialV…

工业大数据分析与应用:开启智能制造新时代

在全球工业4.0浪潮的推动下,工业大数据分析已经成为推动智能制造、提升生产效率和优化资源配置的重要工具。通过收集、存储、处理和分析海量工业数据,企业能够获得深刻的业务洞察,做出更明智的决策,并实现生产流程的全面优化。本文…

web安全测试渗透案例知识点总结(上)——小白入狱

目录 一、Web安全渗透测试概念详解1. Web安全与渗透测试2. Web安全的主要攻击面与漏洞类型3. 渗透测试的基本流程 二、知识点详细总结1. 常见Web漏洞分析2. 渗透测试常用工具及其功能 三、具体案例教程案例1:SQL注入漏洞利用教程案例2:跨站脚本&#xff…

每天五分钟机器学习:支持向量机算法数学基础之核函数

本文重点 从现在开始,我们将开启支持向量机算法的学习,不过在学习支持向量机算法之前,我们先来学习一些支持向量机所依赖的数学知识,这会帮助我们更加深刻的理解支持向量机算法,本文我们先来学习核函数。 定义 核函数(Kernel Function)是一种在支持向量机(SVM)、高…

【小程序】dialog组件

这个比较简单 我就直接上代码了 只需要传入title即可&#xff0c; 内容部分设置slot 代码 dialog.ttml <view class"dialog-wrapper" hidden"{{!visible}}"><view class"mask" /><view class"dialog"><view …

【MySQL】ubantu 系统 MySQL的安装与免密码登录的配置

&#x1f351;个人主页&#xff1a;Jupiter. &#x1f680; 所属专栏&#xff1a;MySQL初阶探索&#xff1a;构建数据库基础 欢迎大家点赞收藏评论&#x1f60a; 目录 &#x1f4da;mysql的安装&#x1f4d5;MySQL的登录&#x1f30f;MySQL配置免密码登录 &#x1f4da;mysql的…

Dubbo源码解析-服务注册(五)

一、服务注册 当确定好了最终的服务配置后&#xff0c;Dubbo就会根据这些配置信息生成对应的服务URL&#xff0c;比如&#xff1a; dubbo://192.168.65.221:20880/org.apache.dubbo.springboot.demo.DemoService? applicationdubbo-springboot-demo-provider&timeout300…

计算机网络-理论部分(二):应用层

网络应用体系结构 Client-Server客户-服务器体系结构&#xff1a;如Web&#xff0c;FTP&#xff0c;Telnet等Peer-Peer&#xff1a;点对点P2P结构&#xff0c;如BitTorrent 应用层协议定义了&#xff1a; 交换的报文类型&#xff0c;请求or响应报文类型的语法字段的含义如何…