CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

from tensorflow.keras.datasets import cifar10(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/660455.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

易优CMS采集插件使用教程

本易优CMS采集教程说明如何使用易优CMS采集插件&#xff0c;批量获取互联网上的文章数据&#xff0c;并自动更新到易优cms&#xff08;eyoucms&#xff09;网站&#xff0c;快速丰富网站的内容。 目录 1. 下载并安装易优CMS采集插件 2. 对接网页文章采集工具 3. 采集数据发…

GPT-4级别模型惨遭泄露!引爆AI社区,“欧洲版OpenAI”下场认领

大家好&#xff0c;我是二狗。 这两天&#xff0c;一款性能接近GPT-4的模型惨遭泄露&#xff0c;引发了AI社区的热议。 这背后究竟是怎么回事呢&#xff1f; 起因是1月28日&#xff0c;一位名为“Miqu Dev”的用户在 HuggingFace 上发布了一组文件&#xff0c;这些文件共同组…

Steam爆火游戏幻兽帕鲁自建多人联机专用服务器配置要求

《幻兽帕鲁》这款多人游戏模式的全新开放世界生存制作游戏&#xff0c;在短短上线5天就卖出700万份&#xff0c;同时在线人数最高达到了180万人&#xff0c;创下Steam历史榜单第二名的好成绩&#xff0c;意料之外的爆火也一度导致幻兽帕鲁出现无法创建4人游戏房间、官方服务器连…

C语言-算法-最短路

【模板】Floyd 题目描述 给出一张由 n n n 个点 m m m 条边组成的无向图。 求出所有点对 ( i , j ) (i,j) (i,j) 之间的最短路径。 输入格式 第一行为两个整数 n , m n,m n,m&#xff0c;分别代表点的个数和边的条数。 接下来 m m m 行&#xff0c;每行三个整数 u …

VUE3:组合式API生命周期

1、onMounted 注册一个回调函数&#xff0c;在组件挂载完成后执行。 组件在以下情况下被视为已挂载&#xff1a; – 1. 其所有同步子组件都已经被挂载。 – 2. 其自身的 DOM 树已经创建完成并插入了父容器中。注意仅当根容器在文档中时&#xff0c;才可以保证组件 DOM 树也在文…

已定式,未定式【高数笔记】

【已定式】 将x-->? 的过程代入到lim中&#xff0c;如果得出的结果可以判断出&#xff0c;lim是有极限的&#xff0c;则为已定式 [举例] lim(1/x)&#xff0c;x--> 无穷 &#xff0c;即&#xff0c;1/ 无穷 0 &#xff0c;所以为已定式 【未定式】 将x-->? 的过程代…

docker 搭建 Seafile 集成 onlyoffice

docker-compose一键部署yaml文件 version: "3"services:db:image: mariadb:10.11container_name: seafile-mysqlenvironment:- MYSQL_ROOT_PASSWORDdb_dev # Requested, set the roots password of MySQL service.- MYSQL_LOG_CONSOLEtruevolumes:- /share/ZFS18_D…

Rust - 变量

不管学什么语言好像都得从变量开始&#xff0c;不过只需要懂得大概就可以了。 但在Rust里不先把变量研究明白后面根本无法进行… 变量绑定 变量赋值❌ 变量绑定✔️ Rust中没有“赋值”一说&#xff0c;而是称为绑定。 int a 3; //C中的变量赋值 a 3; //python中的…

智慧工地可视化综合管理云平台 PC+APP

目录 一、智慧工地可视化数据大屏功能一览 1.首页 2.视频监控 3.机械设备 4.环境监测 5.安全管理 6.质量管理 7.劳务分析 8.进度管理 9.报警统计 二、项目人员管理 1.信息管理 2.信息采集 3.证件管理 危大工程管理 一、智慧工地可视化数据大屏功能一览 包括&am…

transformer_多头注意力机制代码笔记

transformer_多头注意力机制代码笔记 以GPT-2中多头注意力机制代码为例 class CausalSelfAttention(nn.Module):"""因果掩码多头自注意力机制A vanilla multi-head masked self-attention layer with a projection at the end.It is possible to use torch.nn…

【C语言】const修饰指针的不同作用

目录 const修饰变量 const修饰指针变量 ①不用const修饰 ②const放在*的左边 ③const放在*的右边 ④*的左右两边都有const 结论 const修饰变量 变量是可以修改的&#xff0c;如果把变量的地址交给⼀个指针变量&#xff0c;通过指针变量的也可以修改这个变量。 但…

电脑文件打不开是什么原因?常见原因有这9点

在日常生活和工作中&#xff0c;我们经常会使用电脑来处理文件。然而&#xff0c;有时候我们会遇到电脑文件打不开的情况&#xff0c;这给我们的工作和生活带来了很大的不便。本文将为大家介绍电脑文件打不开的原因&#xff0c;帮助大家更好地应对这一问题。 原因1、文件格式问…

交易策略开发:如何揣摩投资心理,研究交易策略

文章目录 揣摩其他投资者的心理&#xff0c;首先要知道他们学习了什么投资知识。永远记住策略一定是弱于机制的。每种交易技术是如何做交易的&#xff0c;各位可以对号入座**马丁网格类均线&#xff0c;MACD等指标类价格行为类缠论类对冲套利类基本面类订单流资金流秘籍类 揣摩…

论文解读:DeepBDC小样本图像分类

Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification 摘要 由于每个新任务只给出很少的训练样例&#xff0c;所以few -shot分类是一个具有挑战性的问题。解决这一挑战的有效研究路线之一是专注于学习由查询图像和某些类别的少数支持…

shell脚本自动备份数据库表

今日目标&#xff1a;shell脚本自动备份数据库中的表并记录执行日志和mysql输出日志 编写思路&#xff1a; &#xff08;1&#xff09;shell脚本运行mysql命令 &#xff08;2&#xff09;脚本输出记录到日志中 &#xff08;3&#xff09;定时任务自动执行shell脚本 1、she…

【Tomcat与网络9】提高Tomcat启动速度的八大措施

本文我们来看一下如何对Tomcat进行调优&#xff0c;我们对于Tomcat的调优主要集中在三个方面&#xff1a;提高启动速度、提高系统稳定性和提高并发能力&#xff0c;后两者很多时候是相辅相成的&#xff0c;我们放在一起看。 Tomcat现在一般都嵌入在SpringBoot里&#xff0c;因…

Linux 驱动开发基础知识——总线设备驱动模型(八)

个人名片&#xff1a; &#x1f981;作者简介&#xff1a;学生 &#x1f42f;个人主页&#xff1a;妄北y &#x1f427;个人QQ&#xff1a;2061314755 &#x1f43b;个人邮箱&#xff1a;2061314755qq.com &#x1f989;个人WeChat&#xff1a;Vir2021GKBS &#x1f43c;本文由…

动网格-尺寸函数耦合运动(五)

尺寸函数 **尺寸函数(Size Function)**通常和局部体网格重构结合使用&#xff0c;尺寸函数用于控制重构过程中的网格分布。简单地说&#xff0c;尺寸函数的功能就是在运动边界处约束网格&#xff0c;使其维持在一个较小的尺度&#xff0c;在远离运动边界处&#xff0c;逐步将其…

Windows存储空间不足局域网文件共享 Dism备份系统空间不足

问题情景 在日常使用中难免遇到Windows的空间不足的情况&#xff0c;常用办法是清理垃圾释放空间&#xff0c;部分场景例如我们需要使用Dism备份完整系统&#xff0c;所以需要非常大的存储空间不够&#xff0c;如果空间不够什么才是最有效的方案呢&#xff1f; 我们假设身边没有…

如何使用docker部署Swagger Editor并实现无公网ip远程协作编辑文档

文章目录 Swagger Editor本地接口文档公网远程访问1. 部署Swagger Editor2. Linux安装Cpolar3. 配置Swagger Editor公网地址4. 远程访问Swagger Editor5. 固定Swagger Editor公网地址 Swagger Editor本地接口文档公网远程访问 Swagger Editor是一个用于编写OpenAPI规范的开源编…