【Week-P2】CNN彩色图片分类-CIFAR10数据集

文章目录

  • 一、环境配置
  • 二、准备数据
  • 三、搭建网络结构
  • 四、开始训练
  • 五、查看训练结果
  • 六、总结
    • 3.1 ⭐ `torch.nn.Conv2d()`详解
    • 3.2 ⭐ `torch.nn.Linear()`详解
    • 3.3 ⭐`torch.nn.MaxPool2d()`详解
    • 3.4 ⭐ 关于卷积层、池化层的计算
    • 4.2.1 `optimizer.zero_grad()`说明
    • 4.2.2 `loss.backward()`说明
    • 4.2.3 `optimizer.step()`说明
    • 4.4.1 `model.train()`说明
    • 4.4.2 `model.eval()`说明

本文采用CIFAR10数据集,通过简单CNN来实现彩色图片识别。

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、环境配置

# 1. 设置环境
import sys
from datetime import datetimeimport torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvisionprint("---------------------1.配置环境------------------")
print("Start time: ", datetime.today())
print("Pytorch version: " + torch.__version__)
print("Python version: " + sys.version)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

在这里插入图片描述

二、准备数据

导入数据的方式和【Week P1】中的方法是一致的,都是通过dataset下载数据集、通过dataloader加载数据集。

'''
2. 导入数据使用dataset下载CIFAR10数据集,并划分好训练集与测试集使用dataloader加载数据,并设置好基本的batch_size
'''
print("---------------------2.1 下载CIFAR10数据集,并划分训练集和测试集------------------")
train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensordownload=True)test_ds  = torchvision.datasets.CIFAR10('data', train=False, transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensordownload=True)print("---------------------2.2 设置batch_size------------------")
batch_size = 32train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)test_dl  = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)print("---------------------2.2.1 取一个批次查看数据格式------------------")
# 取一个批次查看数据格式
# 数据的shape为:[batch_size, channel, height, weight]
# 其中batch_size为自己设定,channel,height和weight分别是图片的通道数,高度和宽度。
imgs, labels = next(iter(train_dl))
imgs.shapeprint("---------------------2.3 数据可视化------------------")
import numpy as np# 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 5)) 
for i, imgs in enumerate(imgs[:20]):# 维度缩减npimg = imgs.numpy().transpose((1, 2, 0))# 将整个figure分成2行10列,绘制第i+1个子图。plt.subplot(2, 10, i+1)plt.imshow(npimg, cmap=plt.cm.binary)plt.axis('off')#plt.show()  如果你使用的是Pycharm编译器,请加上这行代码

等待漫长的4h35min后:
在这里插入图片描述

三、搭建网络结构

对于一般的CNN网络来说,都是由特征提取网络和分类网络构成,其中特征提取网络用于提取图片的特征,分类网络用于将图片进行分类。

用到的运算主要有:卷积、池化

网络结构:
在这里插入图片描述

以下几点涉及到的内容,统一在文末说明:
3.1 ⭐ torch.nn.Conv2d()详解
3.2 ⭐ torch.nn.Linear()详解
3.3 ⭐torch.nn.MaxPool2d()详解
3.4 ⭐ 关于卷积层、池化层的计算

print("---------------------3.1 定义简单CNN网络,要点:卷积和池化运算------------------")
import torch.nn.functional as Fnum_classes = 10  # 图片的类别数class Model(nn.Module):def __init__(self):super().__init__()# 特征提取网络self.conv1 = nn.Conv2d(3, 64, kernel_size=3)   # 第一层卷积,卷积核大小为3*3self.pool1 = nn.MaxPool2d(kernel_size=2)       # 设置池化层,池化核大小为2*2self.conv2 = nn.Conv2d(64, 64, kernel_size=3)  # 第二层卷积,卷积核大小为3*3   self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3) # 第二层卷积,卷积核大小为3*3   self.pool3 = nn.MaxPool2d(kernel_size=2) # 分类网络self.fc1 = nn.Linear(512, 256)          self.fc2 = nn.Linear(256, num_classes)# 前向传播def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))     x = self.pool2(F.relu(self.conv2(x)))x = self.pool3(F.relu(self.conv3(x)))x = torch.flatten(x, start_dim=1)x = F.relu(self.fc1(x))x = self.fc2(x)return xprint("---------------------3.2 加载和打印网络结构------------------")
from torchinfo import summary
# 将模型转移到GPU中(我们模型运行均在GPU中进行)
model = Model().to(device)summary(model)

在这里插入图片描述

四、开始训练

4.2 编写训练函数中,用到的函数有:

  • optimizer.zero_grad()
  • loss.backward()
  • optimizer.step()

在文末说明每个函数的使用方法

4.3 编写测试函数中:

  • 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器

4.4 正式训练中,使用的训练方法包括:

  • model.train():作用是启用 Batch Normalization 和 Dropout
  • model.eval():作用是不启用 Batch Normalization 和 Dropout
# 4. 训练模型
print("---------------------4.1 设置超参数------------------")
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)print("---------------------4.2 编写训练函数-----------------")
# 训练循环
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片num_batches = len(dataloader)   # 批次数目,1875(60000/32)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_lossprint("---------------------4.3 编写测试函数-----------------")
def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_lossprint("---------------------4.4 正式训练-----------------")
epochs     = 10
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

在这里插入图片描述

五、查看训练结果

print("---------------------5. 查看训练结果-----------------")
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述
可以看到,训练10个epoch后的效果是非常差的,训练准确率和测试准确率都不到60%。

六、总结

3.1 ⭐ torch.nn.Conv2d()详解

函数原型:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

关键参数说明:

  • in_channels ( int ):输入图像中的通道数
  • out_channels ( int ) : 卷积产生的通道数
  • kernel_size ( int or tuple ) :卷积核的大小
  • stride ( int or tuple , optional ) :卷积的步长。默认值:1
  • padding ( int , tuple或str , optional ) : 添加到输入的所有四个边的填充。默认值:0
  • dilation (int or tuple, optional):膨胀操作,控制kernel点(卷积核点)的间距,默认值:1。
  • padding_mode (字符串,可选) : ‘zeros’, ‘reflect’, ‘replicate’或’circular’. 默认:‘zeros’
  • 关于dilation参数图解:
    在这里插入图片描述

3.2 ⭐ torch.nn.Linear()详解

函数原型:

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

关键参数说明:

  • in_features:每个输入样本的大小
  • out_features:每个输出样本的大小

3.3 ⭐torch.nn.MaxPool2d()详解

函数原型:

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

关键参数说明:

  • kernel_size:最大的窗口大小
  • stride:窗口的步幅,默认值为kernel_size(核的大小)
  • padding:填充值,默认为0
  • dilation:控制窗口中元素步长的参数

3.4 ⭐ 关于卷积层、池化层的计算

下面的网络数据shape变化过程为:

3, 32, 32(输入数据)→ 64, 30, 30(经过卷积层1)→ 64, 15, 15(经过池化层1)→ 64, 13, 13(经过卷积层2)→ 64, 6, 6(经过池化层2)→ 128, 4, 4(经过卷积层3) → 128, 2, 2(经过池化层3)→ 512 -> 256→ num_classes(10)

计算过程如下:
(1)卷积输出shape公式:
在这里插入图片描述
输入数据为:[3, 32, 32],即图片矩阵大小为32*32,卷积核大小为3,填充步长为默认值0,步长为默认值1,代入计算得到输出的大小为:30*30,输出通道不变,所以输入数据[3, 32, 32]经过Conv1层后得到的shape为·[64, 30, 30]·。

(2)池化输出公式:
在这里插入图片描述
输入的数据格式(从Conv1得到)是:[64, 30, 30] [C*Hin*Win],已知:Hin=30,padding=0,dilation=1,kernel_size=2,stride=2(即kernel_size),代入上述池化公式,可得Hout=15
同理,Wout=15,C保持不变,故而output.shape [64, 15, 15]
在这里插入图片描述

4.2.1 optimizer.zero_grad()说明

  • optimizer.zero_grad()函数会遍历模型的所有参数,通过内置方法截断反向传播的梯度流,再将每个参数的梯度值设为0,即上一次的梯度记录被清空。

4.2.2 loss.backward()说明

  • PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。

  • 具体来说,torch.tensorautograd包的基础类,如果设置tensorrequires_gradsTrue,就会开始跟踪在这个tensor上的所有运算,如果做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。

  • 更具体地说,损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个wrequires_gradsTrue,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w.grad属性中。

  • 如果没有进行tensor.backward()的话,梯度值将会是None因此loss.backward()要写在optimizer.step()之前

4.2.3 optimizer.step()说明

  • step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。

  • 注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。

4.4.1 model.train()说明

  • model.train()的作用是:启用 Batch NormalizationDropout

  • 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()

  • model.train()是保证BN层能够用到每一批数据的均值和方差。

  • 对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

4.4.2 model.eval()说明

  • model.eval()的作用是:不启用 Batch NormalizationDropout

  • 如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()

  • model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。

  • 对于Dropout,model.eval()是将所有网络连接都利用起来,即不进行随机舍弃神经元。

  • 训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/237202.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis Plus使用遇到的问题

如果想使用Mapper的xxxById()方法,实体类的主键上面必须加上TableId注解,如果不加,会报错 2023-12-21 22:48:33.526 WARN 11212 --- [ main] c.b.m.core.injector.DefaultSqlInjector : class com.example.mybatisplusdemo.dom…

ubuntu18.04 64 位安装笔记——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项——任务2:离线数据处理

进入VirtuakBox官网,网址链接:Oracle VM VirtualBoxhttps://www.virtualbox.org/ 网页连接:Ubuntu Virtual Machine Images for VirtualBox and VMwarehttps://www.osboxes.org/ubuntu/ 将下发的ds_db01.sql数据库文件放置mysql中 12、编写S…

无约束优化问题求解笔记(2):最速下降法

目录 3. 最速下降法3.1 最速下降法的基本思想3.2 基于精确搜索的最速下降法3.3 基于精确搜索的最速下降法的程序实现3.4 基于精确搜索的最速下降法的缺点 Reference 3. 最速下降法 3.1 最速下降法的基本思想 最速下降法是典型的线搜索方法. 设 f f f 是 R n \mathbb{R}^n R…

Easyexcel读取单/多sheet页

Easyexcel读取单/多sheet页 此文档会说明单个和多个的sheet页的读取方法,包括本人在使用过程中的踩坑点。 依赖不会的自行百度导入,话不多说,直接上干货。以下示例基于2.x,新版本基本类似 1、创建实体 实体是用来接收对应列的数据…

【QT】QGraphicsView和QGraphicsItem坐标转换

坐标转换 QGraphicsItem和QGraphicsView之间的坐标转换需要通过QGraphicsScene进行转换 QGraphicsView::mapToScene() - 视图 -> 场景QGraphicsView::mapFromScene() - 场景 -> 视图QGraphicsItem::mapToScene() - 图元 -> 场景QGraphicsItem::mapFromScene() - 场景 …

C++ Qt开发:StringListModel字符串列表映射组件

Qt 是一个跨平台C图形界面开发库,利用Qt可以快速开发跨平台窗体应用程序,在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置,实现图形化开发极大的方便了开发效率,本章将重点介绍QStringListModel字符串映射组件的常用方法及…

线程(四)

线程(一) ~ 线程(四)章节导图 导图https://naotu.baidu.com/file/07f437ff6bc3fa7939e171b00f133e17 线程安全 什么是线程安全? 业务中多线程同时访问一个对象或方法时我们不需要做额外的处理(像单线程编程一样)程序可以正常运行并能获取…

JS模块化规范之ES6及UMD

JS模块化规范之ES6及总结 前言ES6模块化概念基本使用ES6实现 UMD(Universal Module Definition)总结 前言 ESM在模块之间的依赖关系是高度确定的,与运行状态无关,编译工具只需要对ESM模块做静态分析,就可以从代码字面中推断出哪些模块值未曾被…

RocketMQ系统性学习-RocketMQ原理分析之Broker接收消息的处理流程

🌈🌈🌈🌈🌈🌈🌈🌈 【11来了】文章导读地址:点击查看文章导读! 🍁🍁🍁🍁🍁🍁&#x1f3…

【git学习笔记 01】打标签

文章目录 一、声明二、对标签的基本认知什么是标签?为什么要打标签?如何生成类似github中readme的图标 三、标签相关命令四、示例操作 一、声明 本帖持续更新中如有纰漏,望批评指正!参考视频链接,非常感谢原作者&…

5 分钟内搭建一个免费问答机器人:Milvus + LangChain

搭建一个好用、便宜又准确的问答机器人需要多长时间? 答案是 5 分钟。只需借助开源的 RAG 技术栈、LangChain 以及好用的向量数据库 Milvus。必须要强调的是,该问答机器人的成本很低,因为我们在召回、评估和开发迭代的过程中不需要调用大语言…

Backtrader 文档学习-Data Feeds(下)

Backtrader 文档学习-Data Feeds(下) 1. Data Resampling 当数据仅在单个时间范围内可用,需要在不同的时间范围内进行分析时,就需要进行一些重采样。 “重采样”实际上应该称为“上采样”,因为它是从一个源时间区间到…

C++的泛型编程—模板

目录 一.什么是泛型编程? ​编辑 ​编辑 二.函数模板 函数模板的实例化 当不同类型形参传参时的处理 使用多个模板参数 三.模板参数的匹配原则 四.类模板 1.定义对象时要显式实例化 2.类模板不支持声明与定义分离 3.非类型模板参数 4.模板的特化 函数模板…

MySQL的安装及如何连接到Navicat和IntelliJ IDEA

MySQL的安装及如何连接到Navicat和IntelliJ IDEA 文章目录 MySQL的安装及如何连接到Navicat和IntelliJ IDEA1 MySQL安装1.1 下载1.2 安装(解压)1.3 配置1.3.1 添加环境变量1.3.2 新建配置文件1.3.3 初始化MySQL1.3.4 注册MySQL服务1.3.5 启动MySQL服务1.3.6 修改默认账户密码 1…

Windows中安装nvm进行Node版本控制

1.nvm介绍 nvm英文全程也叫node.js version management,是一个node.js的版本管理工具。nvm和npm都是node.js版本管理工具,但是为了解决node各种不同之间版本存在不兼容的问题,因此可以通过nvm安装和切换不同版本的node。 2.nvm下载 可在点…

6个免费设计资源站,设计师们赶紧收藏!

本期给大家分享5个免费的设计资源站,设计师必备的设计设计神奇,绝对能帮助你在工作中事半功倍,赶紧收藏吧~ 1、菜鸟图库 https://www.sucai999.com/?vNTYwNDUx 菜鸟图库是我推荐过很多次的网站,主要是站内素材多,像…

PHPStorm一站式配置

phpstorm安装好之后,先别急着编码。工欲善其事,必先利其器,配置好下面这些之后让编码事半功倍。 主题 Appearance & Behavior -> Appearance -> Theme 选中 [Light with Light Header] 亮色较为护眼 关闭更新 Appearance & …

C#学习笔记 - C#基础知识 - C#从入门到放弃 - C# 方法

C# 入门基础知识 - 方法 第8节 方法8.1 C# 函数/方法简介8.2 方法的声明及调用8.2.1 参数列表方法的声明及调用8.2.2 参数数组方法的声明及调用8.2.3、引用参数与值参数 8.3 静态方法和实例方法8.3.1 静态、实例方法的区别8.2.3 静态、实例方法的声明及其调用 8.4 虚方法8.4.1 …

Linux学习(3)——基本命令-文件

1、cat:查看文件内容--上下合并文件 注意:cat只能查看普通的文本文件 如果文件内容过多会显示不全 选项效果-n显示行号包括空行-b跳过空白行编号;注意,在一行打了空格不算空白行,enter键直接跳过这一行才算-s将所有连续…

【JAVA】CyclicBarrier源码解析以及示例

文章目录 前言CyclicBarrier源码解析以及示例主要成员变量核心方法 应用场景任务分解与合并应用示例 并行计算应用示例 游戏开发应用示例输出结果 数据加载应用示例 并发工具的协同应用示例 CyclicBarrier和CountDownLatch的区别循环性:计数器的变化:用途…