python与深度学习(五):CNN和手写数字识别

目录

  • 1. 说明
  • 2. 卷积运算
  • 3. 填充
  • 4. 池化
  • 5. 卷积神经网络实战-手写数字识别的CNN模型
    • 5.1 导入相关库
    • 5.2 加载数据
    • 5.3 数据预处理
    • 5.4 数据处理
    • 5.5 构建网络模型
    • 5.6 模型编译
    • 5.7 模型训练、保存和评价
    • 5.8 模型测试
    • 5.9 模型训练结果的可视化
  • 6. 手写数字识别的CNN模型可视化结果图
  • 7. 完整代码

1. 说明

从这篇文章开始介绍卷积神经网络CNN,CNN比ANN更适用于图片分类问题。

2. 卷积运算

卷积运算的过程如下图所示,灰度图像(可以认为是一个二维矩阵),通过将卷积核从最左上角贴合,然后将对应位置的数进行相乘再相加就得到了特征图的第一个数。
在这里插入图片描述
然后通过移动卷积核,再进行对应位置相乘再相加可以得到特征图的第二个结果,这里步长是1。
卷积核的运动方向是先往右运动,到了末尾,然后再向下移动并且回到开头,然后再继续向右运动,经过如此往复运动可以得到下图结果。
在这里插入图片描述
针对于单通道输入和多核现象,就是用每一个卷积核和单通道图像进行运算,最后得到两个特征图,如下图所示。
在这里插入图片描述
针对于多通道情况,如下图所示。
在这里插入图片描述
多通道的每个通道和卷积核的对应通道进行卷积,然后把所有通道的运算结果进行相加,得到特征图的结果,如下图。
在这里插入图片描述
针对于多通道多核现象,就是将每个卷积核和多通道进行运算得到多个特征图,如下图所示。
在这里插入图片描述
以上均是步长为1的情况,即卷积核每次移动一个像素。如果步长为2,即卷积核每次移动两个像素。

3. 填充

通过上述介绍,我们发现经过多层卷积之后特征图会变得越来越小,会造成信息的缺失,为了降低影响可以进行填充操作,即在特征图的外层加一圈0,然后和卷积核进行运算,如下图所示。
在这里插入图片描述
这时候卷积核的参数padding为same,不进行填充时候为valid,默认为valid。

4. 池化

有时候特征图很大,网络的参数太多,为了降低网络的参数量,采用池化操作。池化可分为最大池化和平均池化。
最大池化,就是对特征图在每个池化区域内找出最大值作为输出结果,如下图。
在这里插入图片描述
平均池化,就是对特征图在每个池化区域内找进行平均运算得到的结果作为输出结果,如下图。
在这里插入图片描述

5. 卷积神经网络实战-手写数字识别的CNN模型

5.1 导入相关库

以下第三方库是python专门用于深度学习的库

from keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow import keras
from keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPool2D
from keras.models import Sequential
from keras.callbacks import EarlyStopping
import tensorflow as tf
from keras import optimizers, losses

5.2 加载数据

把MNIST数据集进行加载

"1.加载数据"
"""
x_train是mnist训练集图片,大小的28*28的,y_train是对应的标签是数字
x_test是mnist测试集图片,大小的28*28的,y_test是对应的标签是数字
"""
(x_train, y_train), (x_test, y_test) = mnist.load_data()  # 加载mnist数据集
print('mnist_data:', x_train.shape, y_train.shape, x_test.shape, y_test.shape)  # 打印训练数据和测试数据的形状

5.3 数据预处理

(1) 将输入的图片进行归一化,从0-255变换到0-1;
(2) 将输入图片的形状(60000,28,28)转换成(60000,28,28,1),便于输入给神经网络;
(3) 将标签y进行独热编码,因为神经网络的输出是10个概率值,而y是1个数, 计算loss时无法对应计算,因此将y进行独立编码成为10个数的行向量,然后进行loss的计算 独热编码:例如数值1的10分类的独热编码是[0 1 0 0 0 0 0 0 0 0,即1的位置为1,其余位置为0。

"2.数据预处理"def preprocess(x, y):  # 数据预处理函数x = tf.cast(x, dtype=tf.float32) / 255.  # 将输入的图片进行归一化,从0-255变换到0-1x = tf.reshape(x, [28, 28, 1])"""# 将输入图片的形状(60000,28,28)转换成(60000,28,28,1),相当于将图片拉直,便于输入给神经网络"""y = tf.cast(y, dtype=tf.int32)  # 将输入图片的标签转换为int32类型y = tf.one_hot(y, depth=10)"""# 将标签y进行独热编码,因为神经网络的输出是10个概率值,而y是1个数,计算loss时无法对应计算,因此将y进行独立编码成为10个数的行向量,然后进行loss的计算独热编码:例如数值1的10分类的独热编码是[0 1 0 0 0 0 0 0 0 0,即1的位置为1,其余位置为0"""return x, y

5.4 数据处理

数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
通过 Dataset.from_tensor_slices 可以将训练部分的数据图片 x 和标签 y 都转换成Dataset 对象

batchsz = 128  # 每次输入给神经网络的图片数
"""
数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
通过 Dataset.from_tensor_slices 可以将训练部分的数据图片 x 和标签 y 都转换成Dataset 对象
"""
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))  # 构建训练集对象
db = db.map(preprocess).shuffle(60000).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理
ds_val = tf.data.Dataset.from_tensor_slices((x_test, y_test))  # 构建测试集对象
ds_val = ds_val.map(preprocess).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理

5.5 构建网络模型

构建了两层卷积层,两层池化层,然后是展平层(将二维特征图拉直输入给全连接层),然后是三层全连接层。

"3.构建网络模型"
model = Sequential([Conv2D(filters=6, kernel_size=(5, 5), activation='relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(10,activation='softmax')])model.build(input_shape=(None, 28, 28, 1))  # 模型的输入大小
model.summary()  # 打印网络结构

5.6 模型编译

模型的优化器是Adam,学习率是0.01,
损失函数是losses.CategoricalCrossentropy,
性能指标是正确率accuracy

"4.模型编译"
model.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=False),metrics=['accuracy'])
"""
模型的优化器是Adam,学习率是0.01,
损失函数是losses.CategoricalCrossentropy,
性能指标是正确率accuracy
"""

5.7 模型训练、保存和评价

模型训练的次数是5,每1次循环进行测试;
以.h5文件格式保存模型;
得到测试集的正确率。

"5.模型训练"
history = model.fit(db, epochs=5, validation_data=ds_val, validation_freq=1)
"""
模型训练的次数是5,每1次循环进行测试
"""
"6.模型保存"
model.save('cnn_mnist.h5')  # 以.h5文件格式保存模型"7.模型评价"
model.evaluate(ds_val)  # 得到测试集的正确率

5.8 模型测试

对模型进行测试

"8.模型测试"
sample = next(iter(ds_val))  # 取一个batchsz的测试集数据
x = sample[0]  # 测试集数据
y = sample[1]  # 测试集的标签
pred = model.predict(x)  # 将一个batchsz的测试集数据输入神经网络的结果
pred = tf.argmax(pred, axis=1)  # 每个预测的结果的概率最大值的下标,也就是预测的数字
y = tf.argmax(y, axis=1)  # 每个标签的最大值对应的下标,也就是标签对应的数字
print(pred)  # 打印预测结果
print(y)  # 打印标签数字

5.9 模型训练结果的可视化

对模型的训练结果进行可视化

"9.模型训练时的可视化"
# 显示训练集和验证集的acc和loss曲线
acc = history.history['accuracy']  # 获取模型训练中的accuracy
val_acc = history.history['val_accuracy']  # 获取模型训练中的val_accuracy
loss = history.history['loss']  # 获取模型训练中的loss
val_loss = history.history['val_loss']  # 获取模型训练中的val_loss
# 绘值acc曲线
plt.figure(1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
# 绘制loss曲线
plt.figure(2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()  # 将结果显示出来

6. 手写数字识别的CNN模型可视化结果图

Epoch 1/5
469/469 [==============================] - 12s 22ms/step - loss: 0.1733 - accuracy: 0.9465 - val_loss: 0.0840 - val_accuracy: 0.9763
Epoch 2/5
469/469 [==============================] - 11s 21ms/step - loss: 0.0704 - accuracy: 0.9793 - val_loss: 0.0581 - val_accuracy: 0.9819
Epoch 3/5
469/469 [==============================] - 11s 22ms/step - loss: 0.0566 - accuracy: 0.9833 - val_loss: 0.0576 - val_accuracy: 0.9844
Epoch 4/5
469/469 [==============================] - 11s 22ms/step - loss: 0.0573 - accuracy: 0.9833 - val_loss: 0.0766 - val_accuracy: 0.9784
Epoch 5/5
469/469 [==============================] - 11s 22ms/step - loss: 0.0556 - accuracy: 0.9844 - val_loss: 0.0537 - val_accuracy: 0.9830

在这里插入图片描述
在这里插入图片描述
从以上结果可知,模型的准确率达到了98%。

7. 完整代码

from keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow import keras
from keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPool2D
from keras.models import Sequential
from keras.callbacks import EarlyStopping
import tensorflow as tf
from keras import optimizers, losses
"1.加载数据"
"""
x_train是mnist训练集图片,大小的28*28的,y_train是对应的标签是数字
x_test是mnist测试集图片,大小的28*28的,y_test是对应的标签是数字
"""
(x_train, y_train), (x_test, y_test) = mnist.load_data()  # 加载mnist数据集
print('mnist_data:', x_train.shape, y_train.shape, x_test.shape, y_test.shape)  # 打印训练数据和测试数据的形状"2.数据预处理"def preprocess(x, y):  # 数据预处理函数x = tf.cast(x, dtype=tf.float32) / 255.  # 将输入的图片进行归一化,从0-255变换到0-1x = tf.reshape(x, [28, 28, 1])"""# 将输入图片的形状(60000,28,28)转换成(60000,28,28,1),相当于将图片拉直,便于输入给神经网络"""y = tf.cast(y, dtype=tf.int32)  # 将输入图片的标签转换为int32类型y = tf.one_hot(y, depth=10)"""# 将标签y进行独热编码,因为神经网络的输出是10个概率值,而y是1个数,计算loss时无法对应计算,因此将y进行独立编码成为10个数的行向量,然后进行loss的计算独热编码:例如数值1的10分类的独热编码是[0 1 0 0 0 0 0 0 0 0,即1的位置为1,其余位置为0"""return x, ybatchsz = 128  # 每次输入给神经网络的图片数
"""
数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
通过 Dataset.from_tensor_slices 可以将训练部分的数据图片 x 和标签 y 都转换成Dataset 对象
"""
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))  # 构建训练集对象
db = db.map(preprocess).shuffle(60000).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理
ds_val = tf.data.Dataset.from_tensor_slices((x_test, y_test))  # 构建测试集对象
ds_val = ds_val.map(preprocess).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理"3.构建网络模型"
model = Sequential([Conv2D(filters=6, kernel_size=(5, 5), activation='relu'),MaxPool2D(pool_size=(2, 2), strides=2),Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),MaxPool2D(pool_size=(2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(10,activation='softmax')])model.build(input_shape=(None, 28, 28, 1))  # 模型的输入大小
model.summary()  # 打印网络结构"4.模型编译"
model.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=False),metrics=['accuracy'])
"""
模型的优化器是Adam,学习率是0.01,
损失函数是losses.CategoricalCrossentropy,
性能指标是正确率accuracy
""""5.模型训练"
history = model.fit(db, epochs=5, validation_data=ds_val, validation_freq=1)
"""
模型训练的次数是5,每1次循环进行测试
"""
"6.模型保存"
model.save('cnn_mnist.h5')  # 以.h5文件格式保存模型"7.模型评价"
model.evaluate(ds_val)  # 得到测试集的正确率"8.模型测试"
sample = next(iter(ds_val))  # 取一个batchsz的测试集数据
x = sample[0]  # 测试集数据
y = sample[1]  # 测试集的标签
pred = model.predict(x)  # 将一个batchsz的测试集数据输入神经网络的结果
pred = tf.argmax(pred, axis=1)  # 每个预测的结果的概率最大值的下标,也就是预测的数字
y = tf.argmax(y, axis=1)  # 每个标签的最大值对应的下标,也就是标签对应的数字
print(pred)  # 打印预测结果
print(y)  # 打印标签数字"9.模型训练时的可视化"
# 显示训练集和验证集的acc和loss曲线
acc = history.history['accuracy']  # 获取模型训练中的accuracy
val_acc = history.history['val_accuracy']  # 获取模型训练中的val_accuracy
loss = history.history['loss']  # 获取模型训练中的loss
val_loss = history.history['val_loss']  # 获取模型训练中的val_loss
# 绘值acc曲线
plt.figure(1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
# 绘制loss曲线
plt.figure(2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()  # 将结果显示出来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/10060.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode面向运气之Javascript—第2600题-K件物品的最大和-94.68%

LeetCode第2600题-K件物品的最大和 题目要求 袋子中装有一些物品,每个物品上都标记着数字 1 、0 或 -1 。 四个非负整数 numOnes 、numZeros 、numNegOnes 和 k 。 袋子最初包含: numOnes 件标记为 1 的物品。numZeroes 件标记为 0 的物品。numNegOn…

12 扩展Spring MVC

✔ 12.1 实现页面跳转功能 页面跳转功能:访问localhost:8081/jiang会自动跳转到另一个页面。 首先,在config包下创建一个名为MyMvcConfig的配置类: 类上加入Configuration注解,类实现WebMvcConfiger接口,实现里面的视…

Tomcat中的缓存配置

Tomcat中的缓存配置通常是通过Web应用程序的context.xml文件或Tomcat的server.xml文件进行设置。下面提供一个简单的案例来说明如何在Tomcat中配置缓存。 假设您的Web应用程序名为"myapp",我们将在context.xml中添加缓存配置。 打开Tomcat安装目录&…

【学习心得】sublime text 4 自定义编译系统

一、问题描述 在电脑中有多个版本的Python解释器,而sublime默认选择最新版本的解释器,如何指定自己想要的解释器呢? 二、自定义编译系统 1、选择新建编译系统(如图) 2、重写两个键值对(只修改中文部分其…

升级你的数据采集引擎 使用多线程与代理池提升HTTP代理爬虫性能

在信息爆炸的时代,海量数据的采集和分析成为了企业发展和决策的关键。本文将分享如何通过多线程和代理池的应用,助您升级数据采集引擎,提高数据获取效率和稳定性。 HTTP代理爬虫作为数据采集的重要工具,其性能直接影响着数据采集…

【Hive 01】简介、安装部署、高级函数使用

1 Hive简介 1.1 Hive系统架构 Hive是建立在 Hadoop上的数据仓库基础构架,它提供了一系列的工具,可以进行数据提取、转化、加载( ETL )Hive定义了简单的类SQL查询语言,称为HQL,它允许熟悉SQL的用户直接查询…

待学习列表

列表 梦是人生番外篇 语雀 语言 区块链 TS 微信小程序 前端 Java Python C

【《机器学习和深度学习:原理、算法、实战(使用Python和TensorFlow)》——以机器学习理论为基础并包含其在工业界的实践的一本书】

机器学习和深度学习已经成为从业人员在人工智能时代必备的技术,被广泛应用于图像识别、自然语言理解、推荐系统、语音识别等多个领域,并取得了丰硕的成果。目前,很多高校的人工智能、软件工程、计算机应用等专业均已开设了机器学习和深度学习…

AVFoundation - 视频过渡

文章目录 一、简要说明二、使用一、简要说明 相关类 AVMutableVideoCompositionAVMutableVideoCompositionInstruction 视频操作指令 AVMutableVideoCompositionLayerInstruction二、使用 - (void)testCom5 {// Compositionを生成AVMutableComposition *mutableComposition =…

OBS 迁移--华为云

一、创建迁移i任务 1. 登录管理控制台。 2. 单击管理控制台左上角的 在下拉框中选择区域。 3. 单击“ 服务列表 ”,选择“ 迁移 > 对象存储迁移服务 OMS ”,进入“ 对象存储迁移服务 ”页面。 4. 单击页面右上角“ 创建迁移任务 ”。 5. 仔细阅读…

el-upload上传图片和视频,支持预览和删除

话不多说&#xff0c; 直接上代码&#xff1a; 视图层&#xff1a; <div class"contentDetail"><div class"contentItem"><div style"margin-top:5px;" class"label csAttachment">客服上传图片:</div><el…

iOS--KVO和KVC

KVC 简单介绍 KVC的全称是KeyValueCoding&#xff0c;俗称“键值编码”&#xff0c;可以通过一个key来访问某个属性&#xff1b; KVC提供了一种间接访问其属性方法或成员变量的机制&#xff0c;可以通过字符串来访问对应的属性方法或成员变量&#xff1b; 它是一个非正式的…

大数据课程C4——ZooKeeper结构运行机制

文章作者邮箱&#xff1a;yugongshiyesina.cn 地址&#xff1a;广东惠州 ▲ 本章节目的 ⚪ 了解Zookeeper的特点和节点信息&#xff1b; ⚪ 掌握Zookeeper的完全分布式安装 ⚪ 掌握Zookeeper的选举机制、ZAB协议、AVRO&#xff1b; 一、Zookeeper-简介 1. 特点…

js中reduce方法的常用应用场景

reduce() 方法可以用来迭代数组并且执行一个回调函数&#xff0c;将数组中的元素聚合成一个单独的值。它可以被用于一系列的操作&#xff0c;如累加求和&#xff0c;计算平均值和查找最大值或最小值等。以下是reduce() 方法的几个应用场景和相应的示例&#xff1a; 求和或求积…

vue build 打包遇到bug解决记录

文章目录 vue-cli-service servevue打包修改dist文件夹名字vue build require is not defined 和 exports is not defind 错误 vue-cli-service serve 通常vue是不能直接使用vue-cli-service命令在终端运行的&#xff0c;所以才会在package.json中配置了scripts&#xff1a; …

Gorm 单表操作 插入数据

创建表AutoMigrate&#xff0c;并且插入数据db.Create package mainimport ("fmt""gorm.io/driver/mysql""gorm.io/gorm" )type Student struct {ID uint gorm:"size:3"Name string gorm:"size:3"Age int …

JDBC的的使用

首先导入jar包。 https://downloads.mysql.com/archives/c-j/ package com.test.sql;import java.sql.*;public class StudySql {public static void init() throws SQLException {Statement stmt null;Connection conn null;ResultSet res null;PreparedStatement pstm…

微信小程序生成二维码(weapp-qrcode)可添加logo

插件 npm 地址&#xff1a;https://www.npmjs.com/package/weapp-qrcode 插件 GitHub 地址&#xff1a;https://github.com/yingye/weapp-qrcode/tree/master 一、引入 1、根据 GitHub 指引将 weapp-qrcode 放到本地 uitl 文件夹下&#xff1b; 2、创建 canvas <canvas c…

自动化运维工具—Ansible概述

Ansible是什么&#xff1f; Ansible是一个基于Python开发的配置管理和应用部署工具&#xff0c;现在也在自动化管理领域大放异彩。它融合了众多老牌运维工具的优点&#xff0c;Pubbet和Saltstack能实现的功能&#xff0c;Ansible基本上都可以实现。 Ansible能批量配置、部署、…

CTFshow web入门 爆破

web21 bp 攻击模块的迭代器 输入账号密码 抓包 发现下面存在一个base64编码 我输入的是 123 123 发现就是base加密 账号 密码 那我们怎么通过intruder模块 自动变为 base64呢 然后去payload------>custom iterator(自定义迭代器) 位置1导入admin 因为是 账号:密码 这…