Pytorch入门实战 P2-CIFAR10彩色图片识别

目录

一、前期准备

1、数据集CIFAR10

2、判断自己的设备,是否可以使用GPU运行。

3、下载数据集,划分好训练集和测试集

4、加载训练集、测试集

5、取一个批次查看下

6、数据可视化

二、搭建简单的CNN网络模型

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

四、模型训练结果可视化

五、模型训练结果:


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这周的实战内容,主要使用的数据集是CIFAR10数据集。用来验证彩色图片的识别。

一、前期准备

1、数据集CIFAR10

我们使用的数据集的文档地址:Datasets — Torchvision 0.17 documentation

简单介绍下CIFAR10数据集:

CIFAR-10数据集由60000张32 × 32彩色图像组成,分为10个类,每个类有6000张图像。

50000张训练图像10000张测试图像

2、判断自己的设备,是否可以使用GPU运行。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

3、下载数据集,划分好训练集和测试集

import torchvision.datasets# 下载训练集
train_ds = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
# 下载测试集
test_ds = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)

4、加载训练集、测试集

# 使用dataloader加载数据集,并设置好batch_size
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,shuffle=True,batch_size=batch_size)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=batch_size)

5、取一个批次查看下

# 取一个批次,查看下数据
imgs,labels = next(iter(train_dl))
print(imgs.shape)   #  数据的shape为:[batch_size,channel,height,weight]  
'''对于CIFAR10,这里的shape是 [32,3,32,32],即 因为取得是train_dl的数据,batch_size为32;channel为3是因为,是彩色图片RGB的3通道,如果是黑白图片,则channel为1;剩下的32x32是高度和宽度;
'''

6、数据可视化

即:展示下取到的数据。

# 数据可视化
plt.figure(figsize=(20,5))
for i, imgs in enumerate(imgs[:20]):npimg = imgs.numpy().transpose((1,2,0))   #.numpy()用于将Tensor转换为一个Numpy数组。transpose是Numpy数组的一个方法,用于重新排列数组的维度。plt.subplot(2, 10, i+1)plt.imshow(npimg, cmap=plt.cm.binary)plt.axis('off')
plt.show()

运行结果展示: 

二、搭建简单的CNN网络模型

 CNN(卷积神经网络),需要注意其结构、层与层之间的连接关系以及各层的功能。

①卷积层:负责提取特征。(通常使用局部连接权值共享方式,这有助于减少网络的参数数量和计算复杂度。)

②池化层:负责降低数据的空间尺寸和计算复杂度。

③全连接层:负责将提取的特征映射到输出类别。

# 构建简单的CNN网络
num_classes = 10
class Model(nn.Module):def __init__(self):super().__init__()# 特征提取self.conv1 = nn.Conv2d(3, 64, kernel_size=3)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(64, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(2)self.conv3 = nn.Conv2d(64, 128, kernel_size=3)self.pool3 = nn.MaxPool2d(2)# 分类网络self.fc1 = nn.Linear(512, 256)self.fc2 = nn.Linear(256, num_classes)# 前向传播def forward(self,x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.pool3(F.relu(self.conv3(x)))x = torch.flatten(x, start_dim=1)  # 线性层+激活函数  是构建复杂模型的基础x = F.relu(self.fc1(x))x = self.fc2(x)return x# 打印并加载模型
model = Model().to(device)
print(model)

三、训练模型

1、设置超参数

# 1、设置超参数
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2   #学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)   # 定义一个随机梯度下降优化器,即SGD优化器。# model.parameters() 返回模型中所有可训练的参数(通常是权重和偏置)

2、编写训练函数

# 2、编写训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 数据集的大小,一共60000张图片num_batches = len(dataloader)  # 批次数目 1875 (60000/32 = 1875)train_loss, train_acc = 0, 0   # 初始化训练的损失和正确率for X,y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,y为真实值,计算二者差值,即为损失。# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3、编写测试函数

# 3、编写测试函数
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器。
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 数据集的大小,共10000张num_batches = len(dataloader)  # 批次数目 ,313( 10000/32 = 321.5 ,向上取整)test_loss, test_acc = 0, 0  # 初始化测试的损失和精确# 不进行训练时,停止梯度下降,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4、正式训练

# 4、正式训练
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []'''model.train()和model.eval() 是深度学习中常见的两个方法,它们用于设置模型的训练模式和评估模式。①当你调用model.train()时,你正在告诉模型你即将进入训练阶段。通常意味着模型中的某些层(如Dropout层和BatchNormalization层)会改变它们的行为以适应训练过程。Dropout层:在训练模式下,Dropout层会随机将一部分神经元的输出设置为0,有助于防止过拟合。BatchNormalization层:在训练模式下,BatchNoralization层会使用当前批次的数据来更新其运行均值和方差,并应用这些统计量来标准化输入。②当你调用model.eval()时,你正在告诉模型你即将进入评估或推断阶段。在这种模式下,模型的某些层会改变它们的行为,以确保在评估时模型给出一致的结果。
'''
for epoch in range(epochs):model.train()  # 进入训练阶段epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'print(template.format(epoch+1, epoch_train_acc*100,epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Finish')

四、模型训练结果可视化

# 四、结果可视化
warnings.filterwarnings('ignore')   # 忽略警告信息
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100    # 分辨率epochs_range = range(epochs)  # 生成从0到epoches-1的整数序列plt.figure(figsize=(12,3))  # figsize=(12,3)  包含两个元素的元组,分别代表图形的宽度和高度,单位是英寸。plt.subplot(1,2,1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validataion Loss')# 在远程服务器上面跑代码,想要保存下,plt.show()的结果,打下下面的注释
# plt.savfig('想要保存的服务器的地址+图片的名称.png/jpg自行定义即可')  
# eg:plt.savefig('/data/jupyter/deepinglearning/resultImg.jpg')plt.show()
print("画图结束。。。")

五、模型训练结果:

这周和上周的代码类似,但是,比起刚开始的时候,好多代码都清晰了很多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/744960.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【CSP试题回顾】201709-3-JSON查询

CSP-201709-3-JSON查询 解题思路 1. 初始化数据结构 map<string, string> strContent: 存储字符串类型属性的内容。键是属性名&#xff08;可能包含通过点.连接的多级属性名&#xff09;&#xff0c;值是属性的字符串值。vector<string> keyVec: 存储当前正在处…

【DAY11 软考中级备考笔记】数据结构 排序操作系统

数据结构 排序&&操作系统 3月14日 – 天气&#xff1a;晴 今天天气非常热&#xff0c;已经到20度了&#xff0c;春天已经来了。 1. 堆排序 堆排序的思想是首先建立一个堆&#xff0c;然后弹出堆顶元素&#xff0c;剩下的元素再形成一个堆&#xff0c;然后继续弹出元素&…

为什么要用scrapy爬虫库?而不是纯python进行爬虫?

为什么要用scrapy爬虫库&#xff1f;而不是纯python进行爬虫&#xff1f; Scrapy的优点Scrapy节省的工作使用纯Python编写爬虫的不足 Scrapy是一个使用Python编写的开源和协作的web爬虫框架&#xff0c;它被设计用于爬取网页数据并从中提取结构化数据。Scrapy的强大之处在于其广…

工科硕士研究生毕业论文撰写总结

工科硕士研究生毕业论文撰写总结 最近一段看了十几篇研究生毕业论文&#xff0c;发现不少问题。结合最近几年当评委及审论文的经验来总结下工科硕士研究生毕业论文撰写毕业论文问题与经验。 一&#xff0e;科技论文的总要求 论文是写给同行看的&#xff0c;注意读者对象。&a…

页面侧边栏顶部固定和底部固定方法

顶部固定用于侧边栏低于屏幕高度----左侧边栏 底部固定用于侧边栏高于屏幕高度----右侧边栏 vue页面方法 页面布局 页面样式&#xff0c;因为内容比较多&#xff0c; 只展示主要代码 * {margin: 0;padding: 0;text-align: center; } .head {width: 100%;height: 88px;back…

在notion里面实现四象限清单

四象限清单是一种时间管理工具&#xff0c;旨在帮助人们根据任务的重要性和紧急性来优先排序他们的工作。这个概念最早由德怀特艾森豪威尔提出&#xff0c;后来又被史蒂芬柯维在他的著作《高效能人士的七个习惯》中进一步普及。四象限清单将任务分为四个类别&#xff1a; 第一…

VueX详解

Vuex 主要应用于Vue.js中管理数据状态的一个库通过创建一个集中的数据存储&#xff0c;供程序中所有组件访问 使用场景 涉及到非父子关系的组件&#xff0c;例如兄弟关系、祖孙关系&#xff0c;甚至更远的关系组件之间的联系中大型单页应用&#xff0c;考虑如何更好地在组件外部…

洛谷 P5018 对称二叉树

题目背景 NOIP2018 普及组 T4 题目描述 一棵有点权的有根树如果满足以下条件&#xff0c;则被轩轩称为对称二叉树&#xff1a; 二叉树&#xff1b;将这棵树所有节点的左右子树交换&#xff0c;新树和原树对应位置的结构相同且点权相等。 下图中节点内的数字为权值&#xf…

window server2012 卸载iis后,远程连接黑屏

原因分析&#xff1a; 因为自己在卸载IIS的时候&#xff0c;不小心卸载了.net framework&#xff0c;系统没有了图形界面&#xff08;由完整模式Full变为了核心模式core&#xff09;&#xff0c;需要重新恢复.net framework4.5。 解决方法分析&#xff1a; 需要将核心模式co…

基于Vue移动端电影票务服务APP设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 相关技术 3 1.1 Vue框架 3 1.2 数据库MongoDB 3 1.3 Axios请求 3 1.4 H5、CSS3和JavaScript 4 1.5 本章小结 4 2 系统分析 5 2.1 功能需求 5 2.2 用例分析 5 2.3 用户功能 6 2.4本章小结 6 3 基于Vue电影票务服务APP设计 7 3.1 页面设计 …

YOLOv9改进策略:注意力机制 |通道注意力和空间注意力CBAM | GAM超越CBAM,不计成本提高精度

&#x1f4a1;&#x1f4a1;&#x1f4a1;本文改进内容&#xff1a;通道注意力和空间注意力CBAM&#xff0c;全新注意力GAM&#xff1a;超越CBAM&#xff0c;不计成本提高精度 改进结构图如下&#xff1a; YOLOv9魔术师专栏 ☁️☁️☁️☁️☁️☁️☁️☁️☁️☁️☁️☁️…

Kotlin/Java中String的equals和==

Kotlin/Java中String的equals和 在Java中&#xff0c;如果定义一个常量String和new出一个String对象&#xff0c;是不同的&#xff1a; String s1 "zhang" String s2 new String("zhang") 因为在Java看来&#xff0c;s1只是一个常量&#xff0c;会放在…

Let’s Move Sui , 一起来学习吧

Let’s Move Sui是一个全新的交互式学习平台&#xff0c;通过SuiFrens的帮助教您如何在Sui上构建。设计供新手和经验丰富的开发者使用&#xff0c;Let’s Move Sui提供了一次非凡的Sui开发之旅&#xff0c;利用了Move在Sui上的独特之处&#xff0c;从基于对象的数据模型的基础知…

飞桨图像分割套件PaddleSeg初探

飞桨图像分割套件PaddleSeg初探 PaddleSeg是基于飞桨PaddlePaddle的端到端图像分割套件&#xff0c;内置45模型算法及140预训练模型&#xff0c;支持配置化驱动和API调用开发方式&#xff0c;打通数据标注、模型开发、训练、压缩、部署的全流程&#xff0c;提供语义分割、交互式…

项目性能优化—性能优化的指标、目标

项目性能优化—性能优化的指标、目标 性能优化的终极目标是什么 性能优化的目标实际上是为了更好的用户体验&#xff1a; 一般我们认为用户体验是下面的公式&#xff1a; 用户体验 产品设计&#xff08;非技术&#xff09; 系统性能 ≈ 系统性能 快 那什么样的体验叫快呢…

交换机/路由器的存储介质-华三

交换机/路由器的存储介质-华三 本文主要介绍网络设备的存储介质组成。 ROM(read-only memory&#xff0c;只读存储器) 用于存储 BootROM程序。BootROM程序是一个微缩的引导程序&#xff0c;主要任务是查找应用程序文件并引导到操作系统&#xff0c;在应用程序文件或配置文件出…

Learn OpenGL 10 Assimp+网格+模型

Assimp 一个非常流行的模型导入库是Assimp&#xff0c;它是Open Asset Import Library&#xff08;开放的资产导入库&#xff09;的缩写。Assimp能够导入很多种不同的模型文件格式&#xff08;并也能够导出部分的格式&#xff09;&#xff0c;它会将所有的模型数据加载至Assim…

WebAssembly探索篇(四)emcc和cmake编译opencv复杂案例

文章目录 开发环境工程目录CMakeLists.txtmain.cpp web端index.html效果图 遇到的问题JS与C传值Uncaught TypeError: Module._malloc is not a functioncanvas像素RGBA四通道 经验&&教训参考 最近因为项目原因&#xff0c;研究了一下WebAssembly。2015年上线与JS、HTML…

C语言——详解字符函数和字符串函数(一)

Hi,铁子们好呀&#xff01;今天博主来给大家更一篇C语言的字符函数和字符串函数~ 具体讲的内容如下&#xff1a; 文章目录 &#x1f386;1.字符分类函数&#x1f4af;&#x1f4af;⏩1.1 什么是字符分类函数的&#xff1f;&#x1f4af;&#x1f4af;⏩1.2 字符函数的类型有哪…

基于Python的中医药知识问答系统设计与实现

[简介] 这篇文章主要介绍了基于Python的中医药知识问答系统的设计与实现。该系统利用Python编程语言&#xff0c;结合中医药领域的知识和技术&#xff0c;实现了一个功能强大的问答系统。文章首先介绍了中医药知识的特点和传统问答系统的局限性&#xff0c;然后提出了设计思路…