CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

from tensorflow.keras.datasets import cifar10(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/660455.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

易优CMS采集插件使用教程

本易优CMS采集教程说明如何使用易优CMS采集插件&#xff0c;批量获取互联网上的文章数据&#xff0c;并自动更新到易优cms&#xff08;eyoucms&#xff09;网站&#xff0c;快速丰富网站的内容。 目录 1. 下载并安装易优CMS采集插件 2. 对接网页文章采集工具 3. 采集数据发…

GPT-4级别模型惨遭泄露!引爆AI社区,“欧洲版OpenAI”下场认领

大家好&#xff0c;我是二狗。 这两天&#xff0c;一款性能接近GPT-4的模型惨遭泄露&#xff0c;引发了AI社区的热议。 这背后究竟是怎么回事呢&#xff1f; 起因是1月28日&#xff0c;一位名为“Miqu Dev”的用户在 HuggingFace 上发布了一组文件&#xff0c;这些文件共同组…

智慧工地可视化综合管理云平台 PC+APP

目录 一、智慧工地可视化数据大屏功能一览 1.首页 2.视频监控 3.机械设备 4.环境监测 5.安全管理 6.质量管理 7.劳务分析 8.进度管理 9.报警统计 二、项目人员管理 1.信息管理 2.信息采集 3.证件管理 危大工程管理 一、智慧工地可视化数据大屏功能一览 包括&am…

【C语言】const修饰指针的不同作用

目录 const修饰变量 const修饰指针变量 ①不用const修饰 ②const放在*的左边 ③const放在*的右边 ④*的左右两边都有const 结论 const修饰变量 变量是可以修改的&#xff0c;如果把变量的地址交给⼀个指针变量&#xff0c;通过指针变量的也可以修改这个变量。 但…

电脑文件打不开是什么原因?常见原因有这9点

在日常生活和工作中&#xff0c;我们经常会使用电脑来处理文件。然而&#xff0c;有时候我们会遇到电脑文件打不开的情况&#xff0c;这给我们的工作和生活带来了很大的不便。本文将为大家介绍电脑文件打不开的原因&#xff0c;帮助大家更好地应对这一问题。 原因1、文件格式问…

论文解读:DeepBDC小样本图像分类

Joint Distribution Matters: Deep Brownian Distance Covariance for Few-Shot Classification 摘要 由于每个新任务只给出很少的训练样例&#xff0c;所以few -shot分类是一个具有挑战性的问题。解决这一挑战的有效研究路线之一是专注于学习由查询图像和某些类别的少数支持…

shell脚本自动备份数据库表

今日目标&#xff1a;shell脚本自动备份数据库中的表并记录执行日志和mysql输出日志 编写思路&#xff1a; &#xff08;1&#xff09;shell脚本运行mysql命令 &#xff08;2&#xff09;脚本输出记录到日志中 &#xff08;3&#xff09;定时任务自动执行shell脚本 1、she…

【Tomcat与网络9】提高Tomcat启动速度的八大措施

本文我们来看一下如何对Tomcat进行调优&#xff0c;我们对于Tomcat的调优主要集中在三个方面&#xff1a;提高启动速度、提高系统稳定性和提高并发能力&#xff0c;后两者很多时候是相辅相成的&#xff0c;我们放在一起看。 Tomcat现在一般都嵌入在SpringBoot里&#xff0c;因…

Linux 驱动开发基础知识——总线设备驱动模型(八)

个人名片&#xff1a; &#x1f981;作者简介&#xff1a;学生 &#x1f42f;个人主页&#xff1a;妄北y &#x1f427;个人QQ&#xff1a;2061314755 &#x1f43b;个人邮箱&#xff1a;2061314755qq.com &#x1f989;个人WeChat&#xff1a;Vir2021GKBS &#x1f43c;本文由…

动网格-尺寸函数耦合运动(五)

尺寸函数 **尺寸函数(Size Function)**通常和局部体网格重构结合使用&#xff0c;尺寸函数用于控制重构过程中的网格分布。简单地说&#xff0c;尺寸函数的功能就是在运动边界处约束网格&#xff0c;使其维持在一个较小的尺度&#xff0c;在远离运动边界处&#xff0c;逐步将其…

Windows存储空间不足局域网文件共享 Dism备份系统空间不足

问题情景 在日常使用中难免遇到Windows的空间不足的情况&#xff0c;常用办法是清理垃圾释放空间&#xff0c;部分场景例如我们需要使用Dism备份完整系统&#xff0c;所以需要非常大的存储空间不够&#xff0c;如果空间不够什么才是最有效的方案呢&#xff1f; 我们假设身边没有…

如何使用docker部署Swagger Editor并实现无公网ip远程协作编辑文档

文章目录 Swagger Editor本地接口文档公网远程访问1. 部署Swagger Editor2. Linux安装Cpolar3. 配置Swagger Editor公网地址4. 远程访问Swagger Editor5. 固定Swagger Editor公网地址 Swagger Editor本地接口文档公网远程访问 Swagger Editor是一个用于编写OpenAPI规范的开源编…

【方案】TSINGSEE青犀智能分析网关V4+EasyCVR智慧服务区一体化监控平台

随着年关将近&#xff0c;春运大潮已然开启&#xff0c;届时又伴随着大雨暴雪天气&#xff0c;高速路况的新闻层出不穷。由于长期驾车且高速拥堵严重&#xff0c;不少人就聚集在服务区休息&#xff0c;导致服务区流量爆满&#xff0c;空前的拥堵极易导致服务区瘫痪。如何利用智…

计算机毕业设计 | springboot 多功能商城 购物网站(附源码)

1&#xff0c; 概述 国家大力推进信息化建设的大背景下&#xff0c;城市网络基础设施和信息化应用水平得到了极大的提高和提高。特别是在经济发达的沿海地区&#xff0c;商业和服务业也比较发达&#xff0c;公众接受新事物的能力和消费水平也比较高。开展商贸流通产业的信息化…

OpenHarmony—编辑器使用技巧

DevEco Studio支持使用多种语言进行应用/服务的开发&#xff0c;包括ArkTS、JS和C/C。在编写应用/服务阶段&#xff0c;可以通过掌握代码编写的各种常用技巧&#xff0c;来提升编码效率。 代码高亮 支持对代码关键字、运算符、字符串、类、标识符、注释等进行高亮显示&#x…

少儿编程教育市场分析:行业规模有望在2025年达到约500亿元

少儿编程教育是通过编程游戏启蒙、可视化图形编程等课程&#xff0c;培养学生的计算思维和创新解难能力的课程。与成人的编程不同&#xff0c;少儿编程教育并非高等教育那样学习如何写代码、编制应用程序&#xff0c;而是通过编程游戏启蒙、可视化图形编程等课程&#xff0c;培…

C语言——标准输入函数(scanf、getchar和gets)

目录 1. 标准输入输出头文件2. scanf2.1 scanf2.1.1 函数申明2.1.2 基本用法2.1.3 返回值2.1.4 占位符2.1.5 赋值忽略符 3. getchar3.1 函数申明3.2 基本用法 4. gets4.1 函数申明4.2 基本用法 1. 标准输入输出头文件 #include <stdio.h>在使用标准输入输出函数的时候都…

摄影分享|基于Springboot的摄影分享网站设计与实现(源码+数据库+文档)

摄影分享网站目录 目录 基于Springboot的摄影分享网站设计与实现 一、前言 二、系统功能设计 三、系统实现 1、用户信息管理 2、图片素材管理 3、视频素材管理 4、公告信息管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐…

企业网络基础架构监控工具

IT 基础架构已成为提供基本业务服务的基石&#xff0c;无论是内部管理操作还是为客户托管的应用程序服务&#xff0c;监控 IT 基础设施至关重要&#xff0c;并且已经建立起来&#xff0c;SMB IT 基础架构需要简单的网络监控工具来监控性能和报告问题。通常&#xff0c;几个 IT …

UE5 虚幻游戏报错常用解决方法(幻兽帕鲁UE5报错)

在体验使用虚幻引擎5、4&#xff08;UE5/UE4&#xff09;开发的游戏如《幻兽帕鲁》时&#xff0c;玩家可能会遇到各种报错情况&#xff0c;例如黑屏、闪退、C运行时错误等。本博客将汇集一系列有效解决方案&#xff0c;通过调整虚幻引擎内置命令行参数以及优化系统环境&#xf…