使用PyTorch训练VGG11模型:Fashion-MNIST图像分类实战

本文将通过代码实战,详细讲解如何使用 PyTorch 和 VGG11 模型在 Fashion-MNIST 数据集上进行图像分类任务。代码包含数据预处理、模型定义、训练与评估全流程,并附上训练结果的可视化图表。所有代码可直接复现,适合深度学习初学者和进阶开发者参考。


1. 环境准备

确保已安装以下库:

pip install torch torchvision d2l
2. 代码实现
2.1 导入依赖库
from d2l import torch as d2l
from torchvision import models, transforms
import torch
2.2 数据预处理

由于VGG11默认接受RGB三通道输入,需将Fashion-MNIST的灰度图转换为3通道:

# 定义数据预处理流程
transform = transforms.Compose([transforms.Resize(224),                # 调整图像尺寸为224x224transforms.Grayscale(num_output_channels=3),  # 单通道转三通道transforms.ToTensor()                   # 转为Tensor格式
])
2.3 加载数据集
# 加载Fashion-MNIST数据集并应用预处理
batch_size = 64 * 3  # 增大批大小以利用GPU并行计算
train_data, test_data = d2l.load_data_fashion_mnist(batch_size, resize=224)# 替换原始数据集的数据增强方法
train_data.dataset.transform = transform
test_data.dataset.transform = transform
2.4 定义模型

使用PyTorch内置的VGG11模型(从头训练,不使用预训练权重):

# 初始化VGG11模型(输入通道为3,输出类别为10)
net = models.vgg11(pretrained=False, num_classes=10)
2.5 模型训练

调用D2L库的封装函数进行训练(支持GPU加速):

# 设置超参数并启动训练
num_epochs = 10
lr = 0.01
device = d2l.try_gpu()  # 自动检测GPU# 开始训练
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, device)
3. 训练结果分析

下图为训练过程中的损失和准确率变化曲线:

关键指标
EpochTrain LossTrain AccTest AccSpeed (examples/sec)
10.8570.2%78.5%112.3
30.31288.6%88.1%117.7
50.3287.6%84.3%118.5
100.2191.8%85.7%119.0
  • 训练损失(Train Loss):随着训练轮次增加,损失快速下降并趋于稳定。例如,第3轮时损失降至 0.312,表明模型快速收敛。

  • 训练准确率(Train Acc):第3轮时达到 88.6%,说明模型对训练数据的学习效果显著。

  • 测试准确率(Test Acc):第3轮测试准确率 88.1%,与训练准确率接近,表明模型泛化能力优秀,未出现明显过拟合。

  • 训练速度:在 cuda:0 设备上达到 117.7 examples/sec,充分利用GPU加速,适合大规模数据训练。

4. 完整代码 
from d2l import torch as d2l
from torchvision import models, transforms
import torch# 数据预处理
transform = transforms.Compose([transforms.Resize(224),transforms.Grayscale(num_output_channels=3),transforms.ToTensor()
])# 加载数据集
batch_size = 64 * 3
train_data, test_data = d2l.load_data_fashion_mnist(batch_size, resize=224)
train_data.dataset.transform = transform
test_data.dataset.transform = transform# 定义模型
net = models.vgg11(pretrained=False, num_classes=10)# 训练配置
num_epochs = 10
lr = 0.01
device = d2l.try_gpu()# 启动训练
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, device)
5. 常见问题
Q1:为什么将灰度图转为三通道?

VGG系列模型设计时默认接受RGB输入(3通道)。尽管Fashion-MNIST为单通道,需通过复制通道数适配模型。

Q2:如何进一步提升准确率?
  • 增加训练轮次(如 num_epochs=20)。

  • 使用更复杂模型(如VGG16、ResNet)。

  • 添加数据增强(随机旋转、亮度调整)。

Q3:训练时显存不足怎么办?
  • 减小 batch_size(如设为64)。

  • 启用混合精度训练(添加 torch.cuda.amp)。


6. 总结

本文使用PyTorch实现了VGG11模型在Fashion-MNIST数据集上的分类任务,最终测试准确率达 85.7%,并在第3轮即达到 88.1% 的测试准确率,训练速度高达 117.7 examples/sec,展现了优秀的性能与效率。通过代码解析与结果分析,读者可快速掌握从数据预处理到模型训练的完整流程,并根据实际需求调整模型或超参数进一步优化性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/77223.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

汽车BMS技术分享及其HIL测试方案

一、BMS技术简介 在全球碳中和目标的战略驱动下,新能源汽车产业正以指数级速度重塑交通出行格局。动力电池作为电动汽车的"心脏",其性能与安全性不仅直接决定了车辆的续航里程、使用寿命等关键指标,更深刻影响着消费者对电动汽车的…

打造船岸“5G+AI”智能慧眼 智驱力赋能客船数智管理

项目介绍 船舶在航行、作业过程中有着严格的规范要求,但在实际航行与作业中往往会因为人为的疏忽,发生事故,导致人员重大伤亡和财产损失; 为推动安全治理模式向事前预防转型,实现不安全状态和行为智能预警&#xff0c…

C#二叉树

C#二叉树 二叉树是一种常见的数据结构,它是由节点组成的一种树形结构,其中每个节点最多有两个子节点。二叉树的一个节点通常包含三部分:存储数据的变量、指向左子节点的指针和指向右子节点的指针。二叉树可以用于多种算法和操作,…

WinForm真入门(11)——ComboBox控件详解

WinForm中 ComboBox 控件详解‌ ComboBox 是 WinForms 中一个集文本框与下拉列表于一体的控件,支持用户从预定义选项中选择或直接输入内容。以下从核心属性、事件、使用场景到高级技巧的全面解析: 一、ComboBox 核心属性‌ 属性说明示例‌Items‌下拉…

超详细解读:数据库MVCC机制

之前文章:Mysql锁_exclusivelock for update写锁-CSDN博客 中有提到通过MVCC来实现快照读,从而解决幻读问题,这里详细介绍下MVCC。 一、前言 表1:实例表t idk1122 表2:事务A、B、C的执行流程 事务A事务B事务Cstart …

【SpringCloud】从入门到精通【上】

今天主播我把黑马新版微服务课程MQ高级之前的内容都看完了,虽然在看视频的时候也记了笔记,但是看完之后还是忘得差不多了,所以打算写一篇博客再温习一下内容。 课程坐标:黑马程序员SpringCloud微服务开发与实战 微服务 认识单体架构 单体架…

力扣hot100_回溯(2)_python版本

一、39. 组合总和(中等) 代码: class Solution:def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:ans []path []def dfs(i: int, left: int) -> None:if left 0:# 找到一个合法组合ans.append(pa…

AI平台如何实现推理?数算岛是一个开源的AI平台(主要用于管理和调度分布式AI训练和推理任务。)

数算岛是一个开源的AI平台,主要用于管理和调度分布式AI训练和推理任务。它基于Kubernetes构建,支持多种深度学习框架(如TensorFlow、PyTorch等)。以下是数算岛实现模型推理的核心原理、架构及具体实现步骤: 一、数算岛…

cesium项目之cesiumlab地形数据加载

之前的文章我们有提到,使用cesiumlab加载地形出现了一些错误,没有解决,今天作者终于找到了解决方法,下面描述一下具体步骤,首先在地理数据云下载dem数据,在cesiumlab中使用地形切片,得到terrain…

[Vue]App.vue讲解

页面中可以看见的内容不再在index.html中进行编辑,而是在App.vue中进行编辑。 组件化开发 在传统的html开发中,一个页面的资源往往都写在同一个html文件中。这种模式在开发小规模、样式简单的项目时会相当便捷,但当项目规模越来越大&#xf…

sql-labs靶场 less-1

文章目录 sqli-labs靶场less 1 联合注入 sqli-labs靶场 每道题都从以下模板讲解,并且每个步骤都有图片,清晰明了,便于复盘。 sql注入的基本步骤 注入点注入类型 字符型:判断闭合方式 (‘、"、’、“”&#xf…

蓝桥杯-小明的彩灯(差分)

问题描述: 差分数组 1. 什么是差分数组? 差分数组 c 是原数组 a 的“差值表示”,其定义如下: c[0] a[0]c[i] a[i] - a[i-1] (i ≥ 1) 差分数组记录了相邻元素的差值。例如,原数组 a [1, …

精品可编辑PPT | 基于湖仓一体构建数据中台架构大数据湖数据仓库一体化中台解决方案

本文介绍了基于湖仓一体构建数据中台架构的技术创新与实践。它详细阐述了数据湖、数据仓库和数据中台的概念,分析了三者的区别与协作关系,指出数据湖可存储大规模结构化和非结构化数据,数据仓库用于高效存储和快速查询以支持决策,…

最近api.themoviedb.org无法连接的问题解决

修改NAS的host需要用到SSH终端连接工具,比如常见的Putty,XShell,或者FinalShell等都可以,我个人还是习惯Putty。 1.输入命令“ sudo -i ”回车,提示输入密码,密码就是我们NAS的登录密码,输入的…

0.机器学习基础

0.人工智能概述: (1)必备三要素: 数据算法计算力 CPU、GPU、TPUGPU和CPU对比: GPU主要适合计算密集型任务;CPU主要适合I/O密集型任务; 【笔试问题】什么类型程序适合在GPU上运行&#xff1…

多类型医疗自助终端智能化升级路径(代码版.下)

医疗人机交互层技术实施方案 一、多模态交互体系 1. 医疗语音识别引擎 # 基于Wav2Vec2的医疗ASR系统 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC import torchaudioclass MedicalASR:def __init__(self):self.processor = Wav2Vec2Processor.from_pretrai…

前端基础:React项目打包部署服务器教程

问题背景 我做了一个React框架的前端的Node项目,是一个单页面应用。 页面路由用的是,然后使用了React.lazy在路由层级对每一个不同页面进行了懒加载,只有打开那个页面才会加载对应资源。 然后现在我用了Webpack5对项目进行了打包&#xff…

【深度学习:理论篇】--Pytorch基础入门

目录 1.Pytorch--安装 2.Pytorch--张量 3.Pytorch--定义 4.Pytorch--运算 4.1.Tensor数据类型 4.2.Tensor创建 4.3.Tensor运算 4.4.Tensor--Numpy转换 4.5.Tensor--CUDA(GPU) 5.Pytorch--自动微分 (autograd) 5.1.back…

使用 Spring Boot 快速构建企业微信 JS-SDK 权限签名后端服务

使用 Spring Boot 快速构建企业微信 JS-SDK 权限签名后端服务 本篇文章将介绍如何使用 Spring Boot 快速构建一个用于支持企业微信 JS-SDK 权限校验的后端接口,并提供一个简单的 HTML 页面进行功能测试。适用于需要在企业微信网页端使用扫一扫、定位、录音等接口的…

工程师 - FTDI SPI converter

中国网站:FTDIChip- 首页 UMFT4222EV-D UMFT4222EV-D - FTDI 可以下载Datasheet。 UMFT4222EVUSB2.0 to QuadSPI/I2C Bridge Development Module Future Technology Devices International Ltd. The UMFT4222EV is a development module which uses FTDI’s FT4222H…