CNN对 MNIST 数据库中的图像进行分类

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图
from keras.datasets import mnist# 使用 Keras 导入MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

 将前六个训练图像可视化

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])ax.imshow(X_train[i], cmap='gray')ax.set_title(str(y_train[i]))

查看图像的更多细节 

def visualize_input(img, ax):ax.imshow(img, cmap='gray')width, height = img.shapethresh = img.max()/2.5for x in range(width):for y in range(height):ax.annotate(str(round(img[x][y],2)), xy=(y,x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y]<thresh else 'black')fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

 预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# 调整比例,使数值在 0 - 1 范围内 [0,255] --> [0,1]
X_train = X_train.astype('float32')/255
X_test = X_test.astype('float32')/255 print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

 对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import to_categoricalnum_classes = 10 
# 打印前十个(整数值)训练标签
print('Integer-valued labels:')
print(y_train[:10])# 对标签进行一次性编码
# 将类别向量转换为二进制类别矩阵
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)# 打印前十个(单次)训练标签
print('One-hot labels:')
print(y_train[:10])

 重塑数据以适应我们的 CNN(和 input_shape)

# 输入图像尺寸为 28x28 像素的图像。
img_rows, img_cols = 28, 28X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)print('input_shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

您必须传递以下参数:

  • filters - 滤波器的数量。
  • kernel_size - 指定(正方形)卷积窗口高度和宽度的数值。

还有一些额外的、可选的参数需要调整:

  • strides - 卷积的步长。如果不指定任何参数,strides 将设为 1。
  • padding - "有效 "或 "相同 "之一。如果不做任何指定,padding 将设置为 "有效"。
  • activation - 通常为 "relu"。如果不指定任何内容,则不会应用激活。我们强烈建议你为网络中的每个卷积层添加 ReLU 激活函数。

 需要注意的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除网络中的最后一层外,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最终层应是具有 softmax 激活函数的密集层。最终层的节点数应等于数据集中的类总数。
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout# 创建模型对象
model = Sequential()# CONV_1: 添加 CONV 层,采用 RELU 激活,深度 = 32 内核
model.add(Conv2D(32, kernel_size=(3, 3), padding='same',activation='relu',input_shape=(28,28,1)))
# POOL_1: 对图像进行下采样,选择最佳特征
model.add(MaxPooling2D(pool_size=(2, 2)))# CONV_2: 在这里,我们将深度增加到 64
model.add(Conv2D(64, (3, 3),padding='same', activation='relu'))
# POOL_2: more downsampling
model.add(MaxPooling2D(pool_size=(2, 2)))# 由于维度过多,我们只需要一个分类输出
model.add(Flatten())# FC_1: 完全连接,获取所有相关数据
model.add(Dense(64, activation='relu'))# FC_2: 输出软最大值,将矩阵压制成 10 个类别的输出概率
model.add(Dense(10, activation='softmax'))model.summary()

需要注意的事项:
  • 网络以两个卷积层的序列开始,然后是最大池化层。
  • 最后一层为数据集中的每个对象类别设置了一个条目,并具有软最大激活函数,因此可以返回概率。
  • Conv2D 深度从输入层的 1 增加到 32 到 64。
  • 我们还想减少高度和宽度--这就是 maxpooling 的作用所在。请注意,在池化层之后,图像尺寸从 28 减小到 14。
  • 可以看到,每个输出形状都用 None 代替了批量大小。这是为了便于在运行时更改批次大小。
  • 最后,我们会添加一个或多个全连接层来确定图像中包含的对象。例如,如果在上一个最大池化层中发现了车轮,那么这个 FC 层将转换该信息,以更高的概率预测图像中出现了一辆汽车。如果图像中有眼睛、腿和尾巴,那么这可能意味着图像中有一只狗。

编译模型

# rmsprop 和自适应学习率 (adaDelta) 是梯度下降的流行形式,仅次于 adam 和 adagrad
# 因为我们有多个类别 (10)# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

训练模型

from keras.callbacks import ModelCheckpoint   # 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)
hist = model.fit(X_train, y_train, batch_size=32, epochs=20,validation_data=(X_test, y_test), callbacks=[checkpointer], verbose=2, shuffle=True)

 在验证集上加载分类准确率最高的模型

# 加载能获得最佳验证精度的权重
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率 

# 评估测试的准确性
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]# 打印测试精度
print('Test accuracy: %.4f%%' % accuracy)

评估模型 

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['accuracy'], 'o-')
ax.plot([None] + hist.history['val_accuracy'], 'x-')
# 绘制图例并自动使用最佳位置: loc = 0。
ax.legend(['Train acc', 'Validation acc'], loc = 0)
ax.set_title('Training/Validation acc per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('acc')
plt.show()

 

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['loss'], 'o-')
ax.plot([None] + hist.history['val_loss'], 'x-')# Plot legend and use the best location automatically: loc = 0.
ax.legend(['Train loss', "Val loss"], loc = 0)
ax.set_title('Training/Validation Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

 

注意事项:

MLP 和 CNN 通常不会产生可比较的结果。MNIST 数据集非常特别,因为它非常干净,而且经过了完美的预处理。例如,所有图像大小相同,并以 28x28 像素网格为中心。如果数字稍有偏斜或不居中,这项任务就会难得多。对于真实世界中杂乱无章的图像数据,CNN 将真正超越 MLP。

为了直观地了解为什么会出现这种情况,要将图像输入 MLP,首先必须将图像转换为矢量。然后,MLP 会将图像视为没有特殊结构的简单数字向量。它不知道这些数字原本是按空间网格排列的。

相比之下,CNN 的设计目的完全相同,即处理多维数据中的模式。与 MLP 不同的是,CNN 知道,相距较近的图像像素比相距较远的像素关系密切。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/187680.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第三十六天| 435 无重叠区间 763 划分字母区间 56 合并区间

目录 435 无重叠区间 763 划分字母区间 56 合并区间 435 无重叠区间 将intervals数组按照左端点进行升序排序。 设置变量len标志此时新加入端点后所有区间的位置&#xff0c;将其赋初值为第一对区间的右端点&#xff0c;因为该点是一定可达的。设置变量res来存储需要移除空间…

上海亚商投顾:沪指震荡反弹 消费、医药股走势活跃

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 一&#xff0e;市场情绪 指数今日窄幅震荡&#xff0c;黄白二线分化明显&#xff0c;权重股力挺指数&#xff0c;题材小票走…

INA219电流感应芯片_程序代码

详细跳转借鉴链接INA219例程此处进行总结 简单介绍一下 INA219&#xff1a; 1、 输入脚电压可以从 0V~26V,INA219 采用 3.3V/5V 供电. 2、 能够检测电流&#xff0c;电压和功率&#xff0c;INA219 内置基准器和乘法器使之能够直接以 A 为单位 读出电流值。 3、 16 位可编程地…

《数字图像处理-OpenCV/Python》连载(50)非线性灰度变换

《数字图像处理-OpenCV/Python》连载&#xff08;50&#xff09;非线性灰度变换 本书京东优惠购书链接&#xff1a;https://item.jd.com/14098452.html 本书CSDN独家连载专栏&#xff1a;https://blog.csdn.net/youcans/category_12418787.html 第 7 章 图像的灰度变换 灰度变…

Unity | 渡鸦避难所-0 | 创建 URP 项目并导入商店资源

0 前言 知识点零零碎碎&#xff0c;没有目标&#xff0c;所以&#xff0c;一起做游戏吧 各位老师如果有什么指点、批评、漫骂、想法、建议、疑惑等&#xff0c;欢迎留言&#xff0c;一起学习 1 创建 3D&#xff08;URP&#xff09;项目 在 Unity Hub 中点击新项目&#xff…

【ESP32】手势识别实现笔记:红外温度阵列 | 双三次插值 | 神经网络 | TensorFlow | ESP-DL

目录 一、开发环境搭建与新建工程模板1.1、开发环境搭建与卸载1.2、新建工程目录1.3、自定义组件 二、驱动移植与应用开发2.1、I2C驱动移植与AMG8833应用开发2.2、SPI驱动移植与LCD应用开发2.3、绘制温度云图2.4、启用PSRAM&#xff08;可选&#xff09;2.5、画面动静和距离检测…

SSM框架详解:结构创建与注解应用

文章目录 1. 引言2. SSM框架项目结构创建2.1 目录结构2.2 说明 3. 注解的应用3.1 Controller3.2 Service3.3 Repository3.4 Autowired3.5 RequestMapping3.6 Select、Insert等 4. 结语 &#x1f388;个人主页&#xff1a;程序员 小侯 &#x1f390;CSDN新晋作者 &#x1f389;欢…

专业级音频处理 Logic Pro X 中文 for Mac

Logic Pro X是一款专业音频制作和音乐创作软件。它是Mac电脑上最受欢迎和广泛使用的音频工作站&#xff08;DAW&#xff09;。Logic Pro X提供了丰富的功能和工具&#xff0c;适用于音乐制作、录音、编辑、混音和音频处理等方面。以下是Logic Pro X软件的一些主要特点和功能&am…

怎么取消苹果订阅自动续费?分享3个可行方法!

在日常生活中&#xff0c;我们经常会使用到各种应用程序或服务&#xff0c;其中很多都提供了订阅自动续费的功能。然而&#xff0c;有时候用户可能会忘记取消订阅&#xff0c;从而导致不必要的扣费&#xff0c;给用户带来麻烦和困扰。 那么&#xff0c;对于使用苹果手机的小伙…

【JUC】十八、happens-before先行发生原则

文章目录 1、先行发生原则happens-before2、happens-before总原则3、8条happens-before规则4、案例 1、先行发生原则happens-before 在Java中&#xff0c;Happends-Before本质上是规定了一种可见性&#xff0c; A Happends-Before B&#xff0c;则A发生过的事情对B来说是可见的…

Discuz论坛自动采集发布软件

随着网络时代的不断发展&#xff0c;Discuz论坛作为一个具有广泛用户基础的开源论坛系统&#xff0c;其采集全网文章的技术也日益受到关注。在这篇文章中&#xff0c;我们将专心分享通过输入关键词实现Discuz论坛的全网文章采集&#xff0c;同时探讨采集过程中伪原创的发布方法…

.net-去重的几种情况

文章目录 前言1. int 类型的list 去重2. string类型的 list 去重3. T泛型 List去重4. 使用HashSet List去重5. 创建静态扩展方法 总结 前言 .net 去重的几种情况 1. int 类型的list 去重 // List<int> List<int> myList new List<int>(){ 100 , 200 ,100…

Selenium定位元素的方法css和xpath的区别!

selenium是一种自动化测试工具&#xff0c;它可以通过不同的定位方式来识别网页上的元素&#xff0c;如id、name、class、tag、link text、partial link text、css和xpath。 css和xpath是两种常用的定位方式&#xff0c;它们都可以通过元素的属性或者层级关系来定位元素&#…

Win10任务栏卡死?三个技巧,让你轻松应对!

windows 10作为广受欢迎的操作系统&#xff0c;为用户提供了强大的功能和友好的用户界面。然而&#xff0c;有时用户可能会面临任务栏卡死的问题&#xff0c;这不仅影响使用体验&#xff0c;还可能导致一系列其他问题。本文将深入介绍win10任务栏卡死的原因&#xff0c;并提供三…

【Linux】-信号-(信号的产生,保存,处理,以及os是怎么读取硬件的输入,硬件异常和coredump,定时器的原理简单的用户态和内核态的详细介绍)

&#x1f496;作者&#xff1a;小树苗渴望变成参天大树&#x1f388; &#x1f389;作者宣言&#xff1a;认真写好每一篇博客&#x1f4a4; &#x1f38a;作者gitee:gitee✨ &#x1f49e;作者专栏&#xff1a;C语言,数据结构初阶,Linux,C 动态规划算法&#x1f384; 如 果 你 …

外贸B2B网站独立站建站(零基础全流程)

1.第一步是要先去买个域名&#xff1a; 一般做外贸的购买.com 后缀的国际域名就好&#xff0c;域名可以在阿里云&#xff0c;腾讯云等大的平台上购买&#xff0c;方法很简单&#xff08;但是在确定购买新的域名最好要分析下这个域名有没有被黑过&#xff0c;要不然后期对这个网…

【OpenGL】Clion配置

OpenGL简介 OpenGL&#xff08;Open Graphics Library&#xff09;是指定义了一个跨编程语言、跨平台的编程接口规格的专业的图形程序接口。它用于三维图像&#xff08;二维的亦可&#xff09;&#xff0c;是一个功能强大&#xff0c;调用方便的底层图形库。OpenGL是行业领域中…

JVS低代码按钮组件触发逻辑,打破传统功能界限

在现代应用开发中&#xff0c;按钮组件的功能不仅仅局限于触发页面上的简单动作&#xff0c;它更可以成为连接前后端数据交互的桥梁。当按钮被点击时&#xff0c;其背后可能隐藏着复杂的逻辑远程调用过程&#xff0c;这些过程旨在从远程服务器获取数据&#xff0c;并将这些数据…

微信怎么设置自动回复

微信作为一款广受欢迎的社交媒体平台&#xff0c;其聊天功能是非常重要的。许多用户都希望能够快速、自动地回复消息 首先&#xff0c;点击设置&#xff0c;选择机器人下面的自动通过好友 点击新增规则&#xff0c;设置你自动通过好友的时间段&#xff0c;自动通过好友的微信工…

LeetCode Hot100 3.无重复字符的最长子串

题目&#xff1a; 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。 代码&#xff1a; class Solution {public int lengthOfLongestSubstring(String s) {char[] arr s.toCharArray(); // 转换成 char[] 加快效率&#xff08;忽略带来的空间…