全连接神经网络案例——手写数字识别

文章目录

  • 1.我们导入需要的工具包
  • 2.数据加载
  • 3.数据处理
  • 4.模型构建
  • 5.模型编译
  • 6.模型训练
  • 7.模型测试
  • 8.模型保存

在这里插入图片描述
使⽤⼿写数字的MNIST数据集如上图所示,该数据集包含60,000个⽤于训练的样本和10,000个⽤于测试的样本,图像是固定⼤⼩(28x28像素),其值为0到255。

整个案例的实现流程是:

  • 数据加载
  • 数据处理
  • 模型构建
  • 模型训练
  • 模型测试
  • 模型保存

1.我们导入需要的工具包

# 1.导入所需的工具包
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# 构建模型
from tensorflow.keras.models import Sequential
# 相关的网络层
from tensorflow.keras.layers import Dense, Dropout, Activation, BatchNormalization
# 导入辅助工具包
from tensorflow.keras import utils
# 正则化
from tensorflow.keras import regularizers
# 数据集
from tensorflow.keras.datasets import mnist

2.数据加载

首先我们加载手写数字图像

# 2.数据加载
# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

在这里插入图片描述
数据展示:
在这里插入图片描述

3.数据处理

神经⽹络中的每个训练样本是⼀个向量,因此需要对输⼊进⾏重塑,使每个28x28的图像成为⼀个的784维向量。另外,将输⼊数据进⾏归⼀化处理,从0-255调整到0-1。
在这里插入图片描述
另外对于⽬标值我们也需要进⾏处理,将其转换为热编码的形式(本):
在这里插入图片描述

# 3.数据处理
# 数据维度的调整
x_train = x_train.reshape(60000, 784)
x_test = x_test(10000, 784)
# 数据类型调整
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# 归一化
x_train = x_train/255
x_test = x_test/255
# 将目标值转换成热编码的形式
y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)

4.模型构建

在这⾥我们构建只有3层全连接的⽹络来进⾏处理:
在这里插入图片描述
构建模型如下所示:

# 4.模型构建
# 使用序列模型进行构建
model = Sequential()
# 全连接层,2个隐藏层,一个输出层
# 第一个隐藏层,512个神经元,先BN再激活,随机失活
model.add(Dense(512, input_shape=(784,)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.2))
# 第二个隐藏层,512个神经元,先BN再激活,随机失活
model.add(Dense(512, kernel_regularizer=regularizers.l2(0.01)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.2))
# 输出层
model.add(Dense(10, activation='softmax'))
# 查看模型架构
model.summary()

在这里插入图片描述

5.模型编译

设置模型训练使⽤的损失函数交叉熵损失和优化⽅法adam,损失函数⽤来衡量预测值与真实值之间的差异,优化器⽤来使⽤损失函数达到最优:

# 5.模型编译
# 损失函数(交叉熵损失),优化器,评价指标
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(),metrics=tf.keras.metrics.Accuracy())

6.模型训练

# 6.模型训练
# 使用fit,指定训练集,epochs,batch_size,val,verbose
history = model.fit(x_train, y_train, epochs=4, batch_size=128, validation_data=(x_test, y_test), verbose=1)

在这里插入图片描述
我们将损失绘制成曲线:

# 绘制损失函数
plt.figure()
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='val')
plt.legend()
plt.grid()
plt.show()

在这里插入图片描述

我们再来绘制一下准确率变化曲线:

# 绘制准确率变化曲线
plt.figure()
plt.plot(history.history['accuracy'], label='train')
plt.plot(history.history['val_accuracy'], label='val')
plt.legend()
plt.grid()
plt.show()

在这里插入图片描述
其中,history会保存loss的变化和在compile中指定的评价指标的结果。

7.模型测试

# 7.模型测试
model.evaluate(x_test, y_test, verbose=1)

8.模型保存

# 8.模型保存
model.save("my_model.h5")

之后要是想要想在这个模型时,只需要使用model = tf.keras.models.load_model('my_model.h5')即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/58543.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2.ARM_ARM是什么

CPU工作原理 CPU与内存中的内容: 内存中存放了指令,每一个指令存放的地址不一样,所需的内存空间也不一样。 运算器能够进行算数运算和逻辑运算,这些运算在CPU中都是以运算电路的形式存在,一个运算功能对应一种运算电…

MetaGeneMark:宏转录组转录本基因预测

GeneMark™ download 下载 gunzip gm_key_64.gz tar -xvzf MetaGeneMark_linux_64.tar.gz #查看安装 (完整路径)/gmhmmp #解压文件里面这个比较重要 MetaGeneMark_linux_64/mgm/MetaGeneMark_v1.mod #复制gm_key文件到主路径 mv gm_key_64 .gm_key cp .gm_key /home/zhongpei…

腾讯轻量云服务器docker拉取不到镜像的问题:拉取超时

前言 也是尝试了各种解决方案之后,无果, 后来发现每个服务器提供商都有自己的镜像加速,且只给自家服务器使用,我用的腾讯云 教程 安装docker 直接上链接:云服务器 搭建 Docker-实践教程-文档中心-腾讯云 配置加速镜…

各家AI性格不同,怎样取长补短

你发现了么,每家的AI性格也有区别呢,有些AI比较啰嗦,有些AI回答简洁明了,有些AI条理清晰喜欢列1、2、3。 我们在利用AI的时候,也要学会取长补短,参考各家AI的回答,择优录用。 例如&#xff0c…

Django安装

在终端创建django项目 1.查看自己的python版本 输入对应自己本机python的版本,列如我的是3.11.8 先再全局安装django依赖包 2.在控制窗口输入安装命令: pip3.11 install django 看到Successflully 说明我们就安装成功了 python的Scripts文件用于存…

Socket 和 WebSocket 的应用

Socket(套接字)是计算机网络中的一个抽象层,它允许应用程序通过网络进行通信。套接字用于跨网络的不同主机上的应用程序之间的数据交换。在互联网中,套接字通常基于 TCP(传输控制协议)或 UDP(用…

Materials Studio 2023安装教程(仅作分享参考)

目录 一、软件下载 二、软件介绍 2.1 软件特点 2.2 功能模块 2.3 应用领域 三、安装步骤 一、软件下载 软件名称:Materials Studio 2023 软件语言:英文 软件大小:2.03G 系统要求:Windows10或更高, 64位操作系…

Spark SQL大数据分析快速上手-DataFrame应用体验

【图书介绍】《Spark SQL大数据分析快速上手》-CSDN博客 《Spark SQL大数据分析快速上手》【摘要 书评 试读】- 京东图书 大数据与数据分析_夏天又到了的博客-CSDN博客 本节主要介绍如何使用DataFrame进行编程。 4.1.1 SparkSession 在旧版本中,Spark SQL提供…

SSM中maven

一:maven的分模块开发 maven分模块就是在多人操作一个项目时将maven模块导入依赖,注意仓库里面没有资源坐标,需要使用install操作下载。 二:maven的依赖管理 pom文件中直接写的依赖叫做直接依赖,直接依赖中用到的依…

25中海油笔试测评春招秋招校招暑期实习社招笔试入职测评行测题型微测网题型分享

中海油笔试一般采用线上机考的形式。考试时间为 120 分钟,满分 100 分。笔试内容主要包括思想素质测评和通用能力测评两个科目。以下是具体介绍: 1. 思想素质测评: ✅价值观:考察考生对工作、职业、企业等方面的价值观念和态度&…

【笔记】变压器-热损耗-频响曲线推导 - 04 额定功率处损耗特性

0.最大的问题 - 散热 对变压器这类功率器件,最大的问题是散热的效率。因为传统的电路基板热导率并不高,几乎和良性导热材料有近乎两个数量级的导热差异,所以,会采用特殊的导热技术,把热量尽可能快地传导到散热片。 传…

定高虚拟列表:让大数据渲染变得轻松

定高虚拟列表 基本认识 在数据如潮水般涌来的今天,如何高效地展示和管理这些数据成为了开发者们面临的一大挑战,传统的列表渲染方式在处理大量数据时,往往会导致页面卡顿、滚动不流畅等问题,严重影响用户体验(在页面…

我的博客网站为什么又回归Blazor了

引言 在博客网站的开发征程中,站长可谓是一路披荆斩棘。从最初的构思到实践,先后涉足了多种开发技术,包括 MVC、Razor Pages、Vue、Go、Blazor 等。在这漫长的过程中,网站版本更迭近 10 次,每一个版本都凝聚着站长的心…

Uniapp安装Pinia并持久化(Vue3)

安装pinia 在uni-app的Vue3版本中,Pinia已被内置,无需额外安装即可直接使用(Vue2版本则内置了Vuex)。 HBuilder X项目:直接使用,无需安装。CLI项目:需手动安装,执行yarn add pinia…

<网络> 协议

目录 文章目录 一、认识协议 1. 协议概念 2. 结构化数据传输 3. 序列化和反序列化 二、网络计算器 1. 封装socket类 2. 协议定制 request类的序列化和反序列化 response类的序列化和反序列化 报头的添加与去除 Json序列化工具 Jsoncpp 的主要特点: Jsoncpp 的使用方法: 3. Ser…

群控系统服务端开发模式-应用开发-文件上传功能开发

一、文件上传路由 在根目录下route文件夹中app.php文件中,添加文件上传功能路由,代码如下: Route::post(upload/file,common.Upload/file);// 上传文件接口 二、功能代码开发 在根目录下app文件夹下common文件夹中创建上传控制器并命名为Up…

pycharm小游戏贪吃蛇及pygame模块学习()

由于代码量大,会逐渐发布 一.pycharm学习 在PyCharm中使用Pygame插入音乐和图片时,有以下这些注意事项: 插入音乐: - 文件格式支持:Pygame常用的音乐格式如MP3、OGG等,但MP3可能需额外安装库&#xf…

检索增强和知识冲突学习笔记

检索增强生成任务(Retrieval-Augmented Generation, RAG)是一种自然语言处理技术,它结合了信息检索和生成模型,用于生成高质量的文本输出。具体来说,RAG 模型在生成文本时,会先通过检索模块从外部知识库或文…

从0开始深度学习(25)——多输入多输出通道

之前我们都只研究了一个通道的情况(二值图、灰度图),但实际情况中很多是彩色图像,即有标准的RGB三通道图片,本节将更深入地研究具有多输入和多输出通道的卷积核。 1 多输入通道 当输入包含多个通道时,需要…

网管平台(进阶篇):如何正确的管理网络设备?

网络设备作为构建计算机网络的重要基石,扮演着数据传输、连接和管理的关键角色。从交换机、路由器到防火墙、网关,各类网络设备共同协作,形成了高效、稳定的网络系统。本文将详细介绍网络设备的种类,并探讨如何正确管理这些设备&a…