模型训练识别手写数字(一)

  一、模型训练数据集

1. 导入所需库

import numpy as np
from sklearn.datasets import fetch_openml

numpy 是用于数值计算的库。

fetch_openml 是用于从 OpenML 下载数据集的函数。

  2. 获取 MNIST 数据集

X, y = fetch_openml('mnist_784', version=1, return_X_y=True)

fetch_openml('mnist_784', version=1, return_X_y=True) 从 OpenML 下载 MNIST 数据集。X 存储图像数据(784 个特征,28x28 像素的扁平化图像),y 存储对应的标签(数字 0 到 9)。

   3. 将像素值二值化

X[X > 0] = 1

这行代码将 X 中所有大于 0 的像素值设置为 1,二值化处理。这样处理后的图像只有两个值:0(黑色)和 1(白色),有助于简化模型的输入。

   4. 保存数据集

np.save("Data/dataset", X)
np.save("Data/class", y)

np.save("Data/dataset", X) 将图像数据保存为 dataset.npy

np.save("Data/class", y) 将标签数据保存为 class.npy

二、模型训练及预测

1. 导入所需库

import matplotlib.pyplot as plt
import numpy as np
from keras import Sequential
from keras import layers
from keras.api.models import load_model
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

matplotlib.pyplot: 用于绘图和数据可视化。

numpy: 用于处理数组和数据加载。

keras: 用于构建和训练深度学习模型。

sklearn: 提供数据划分和预处理的工具。

f1_score: 用于评估模型性能。

 2. 加载数据

X = np.load("Data/dataset.npy", allow_pickle=True)
y = np.load("Data/class.npy", allow_pickle=True)

使用 numpyload 方法加载训练数据 (X) 和标签 (y)。allow_pickle=True 允许加载包含对象的数组。

 3. One-Hot 编码

onehot = OneHotEncoder(sparse_output=False)
y = onehot.fit_transform(y.reshape(-1, 1))

OneHotEncoder: 将标签转换为独热编码格式,方便用于分类任务。每个标签会被转换为一个二进制数组。 

 4. 划分训练集和测试集

x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=14)

使用 train_test_split 将数据集分为训练集和测试集,通常使用 70%-80% 的数据用于训练,其余用于测试。 

 5. 构建模型

model = Sequential()
model.add(layers.Dense(100, activation='relu', input_shape=(x_train.shape[1],)))
model.add(layers.Dense(y.shape[1], activation='softmax'))  

Sequential: 表示模型是线性的,按顺序堆叠各个层。

Dense: 添加全连接层,第一层有 100 个神经元,使用 ReLU 激活函数;第二层为输出层,使用 Softmax 激活函数,适合多类分类任务。

 6. 编译模型

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

使用 Adam 优化器,损失函数为交叉熵(适合多分类),并监控准确率。 

 7. 训练模型

model.fit(x_train, y_train, epochs=100, batch_size=32, verbose=1)

 训练模型 100 个周期,批量大小为 32,verbose=1 表示输出训练过程的信息。

 8. 评估模型

predictions = model.predict(x_test)
predictions_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)print("F-score: {0:.2f}".format(f1_score(y_test_classes, predictions_classes, average='micro')))

使用测试集进行预测,并计算 F-score 作为评估指标。np.argmax 用于获取每个样本预测概率最高的类。 

 9. 保存模型

model.save("my_model.h5")

将训练好的模型保存到文件 my_model.h5 中,以便后续加载和使用。 

 10. 加载模型

loaded_model = load_model("my_model.h5")

 加载之前保存的模型,以便进行预测。

 11. 进行预测

predictions = loaded_model.predict(x_test)

使用加载的模型对测试集进行预测,获取每个样本的预测结果。

 12. 获取预测和真实标签

y_pred_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)

 使用 np.argmax 从预测结果和真实标签中获取每个样本的类别索引。

  13. 可视化预测结果

plt.figure(figsize=(12, 6))for i in range(20):plt.subplot(4, 5, i + 1)plt.imshow(x_test[i].reshape(28, 28), cmap='gray')  # 假设输入是28x28的图像plt.title(f'True: {y_test_classes[i]}\nPred: {y_pred_classes[i]}')plt.axis('off')plt.tight_layout()
plt.show()

创建一个图形窗口,设置大小为 12x6。

使用 subplot 在 4 行 5 列的网格中绘制 20 个图像。

每个子图中显示测试样本的图像、真实标签和预测标签。

imshow 将图像进行灰度显示,axis('off') 隐藏坐标轴。

tight_layout() 调整子图参数,以避免重叠。

show() 显示图形。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/57072.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

揭开MySQL并发中的“死锁”之谜:从原理到解决方案的深度解析

目录 1. 环境准备:创建“账户”和“标记”表1.1 创建 dl_account_t 表1.2 创建 dl_mark_t 表 2. 死锁详解2.1 死锁情景一:相反加锁顺序导致的死锁2.2 死锁情景二:唯一索引冲突引发的死锁 3. 事务隔离级别与锁机制4. 预防与解决死锁的方法4.1 …

centos源码升级glibc2.19时遇到的错误

基本安装步骤 wget http://mirrors.ustc.edu.cn/gnu/libc/glibc-2.19.tar.gz tar -zvxf glibc-2.19.tar.gz cd glibc-2.19 mkdir build cd build ../configure --prefix/usr --enable-profile --enable-add-ons --with-headers/usr/include --with-binutils/usr/bin在configu…

在基于go开发的web应用中加入Nginx反向代理

文章目录 学习笔记-Nginx0. Nginx介绍1. Nginx下载安装2. 启动web服务和Nginx配置2.1 启动服务2.2 Nginx配置 3. 测试4. 扩展 学习笔记-Nginx 在查阅资料时发现,很少有人介绍怎么在golang中使用nginx,为此,我们选择写一篇简单的,…

LeetCode算法(双指针)

今天的题目主要都是力扣前100中,关于双指针的题 1.移动零 链接:移动零 示例: 示例 : 输入: nums [0,1,0,3,12] 输出:[1,3,12,0,0] 可以看到保持原有元素的顺序,将所有的0,移动到数组最后方即可。 这…

【论文写作】10.26 讨论

本次讨论主要是对方向的修正,以及一些科研素养的补足 一、关于方向 方案 波分离开了,可以分别喂给不同的网络,最后将不同的结果融合。观察结果,获得直观的估计,猜测可能的优化方案。 对于第二点的解释,在…

开源项目工具:LeanTween - 为Unity 3D打造的高效缓动引擎详解(比较麻烦的API版)

1.LeanTween.reset() 一、工具介绍 参考:推荐开源项目:LeanTween - 为Unity 3D打造的高效缓动引擎-CSDN博客 LeanTween是一个专为Unity 3D引擎设计的高效缓动(tweening)库,它提供了简单易用的API,帮助开…

ctfshow(175->178)--SQL注入--联合注入及其过滤

Web175 进入界面: 审计: 查询语句: $sql "select username,password from ctfshow_user5 where username !flag and id ".$_GET[id]." limit 1;";返回逻辑: if(!preg_match(/[\x00-\x7f]/i, json_enc…

数据结构(8.4_3)——堆的插入删除

在堆中插入新元素 在堆中删除元素 总结:

Linux:权限的深度解析(小白必看!!!)

文章目录 前言一、Linux重要的几个热键二、关机三、扩展命令总结四、shell命令以及运行原理感性理解五、Linux权限的概念1. 权限的概念2. 认识人(用户)1)创建人2)人分类3)人切换4)指令提权 3. ll下文件的权…

《学会提问》

只要他们使说出口的话听起来显得信誓旦旦,你就极有可能会相信他们的说法 我们倾听他们,是为了构建出自己的答案,而不是听了他们的话以后,马上就按他们说的去做,就好像自己是只无助的羔羊,或者是个牵线的木…

一些待机电流波形特征

一、待机电流波形 最干净的待机电流波形应该只有paging,不过需要注意2点: 每个paging的间隔,不同网络可能不一样,有可能是320ms, 640ms 待机网络 paging 间隔 1分钟的耗电量 单个耗电量 单个待机电流 单个波形时长 4G 64…

二十三、Python基础语法(包)

包(package):包是一种组织代码的方式,可以将相关的模块组合在一起,以便更好地管理和重用代码,包的目录中有一个特殊代码文件__init__.py,包的命名也要遵循标识符的规则。 一、包的结构 一个 Python 包通常是一个包含…

NLTK无法下载?

以下内容仅为当前认识,可能有不足之处,欢迎讨论! 文章目录 nltk无法下载怎么办?什么是NLTK?为什么要用NLTK?如何下载? nltk无法下载怎么办? 什么是NLTK? NLTK是学习自然…

python项目实战——多协程下载美女图片

协程 文章目录 协程协程的优劣势什么是IO密集型任务特点示例与 CPU 密集型任务的对比处理 I/O 密集型任务的方式总结 创建并使用协程asyncio模块 创建协程函数运行协程函数asyncio.run(main())aiohttp模块调用aiohttp模块步骤 aiofiles————协程异步函数遇到的问题一 await …

代码随想录跟练21天——LeetCode332.重新安排行程, 51. N皇后,37. 解数独

332.重新安排行程 力扣题目链接(opens new window) 给定一个机票的字符串二维数组 [from, to],子数组中的两个成员分别表示飞机出发和降落的机场地点,对该行程进行重新规划排序。所有这些机票都属于一个从 JFK(肯尼迪国际机场)出…

3、java if流程控制、while循环语句

目录 选择流程控制语句循环流程控制语句控制循环语句顺序结构If语句Switch语句For循环While循环Do-While循环控制跳转语句1. 选择流程控制语句 引入话题 想象一下,你正在过马路,你需要先检查是否有车辆经过。如果没有车辆,你才会过马路。这种先判断条件再执行动作的过程,…

【Python可视化系列】一文教你绘制双Y轴的双折线图(案例+源码)

这是我的第369篇原创文章。 一、引言 在日常工作和学习中,我们会遇到将两个折线画在一张图上的情况,且这两个折线代表了两个特征,具有不同的涵义和量纲表示,这时候我们就需要绘制一个双Y轴折线图,一边代表一个特征&…

Redis 持久化 总结

前言 相关系列 《Redis & 目录》(持续更新)《Redis & 持久化 & 源码》(学习过程/多有漏误/仅作参考/不再更新)《Redis & 持久化 & 总结》(学习总结/最新最准/持续更新)《Redis & …

GraphQL语法入门

目录 一、介绍GraphQL二、GraphQL基本使用方法三、Schema 定义语言 (SDL)3.1 类型定义1)对象类型2)标量类型3)枚举类型4)输入类型5)列表类型6)非空类型7)接口类型8)联合类型 3.2 查询…

python进阶集锦

一、迭代器和生成器 区别 关于迭代器和生成器 迭代器与生成器的区别 迭代器(Iterator)和生成器(Generator)是Python中处理序列数据的两种不同概念。迭代器是遵循迭代协议的对象,而生成器是一种特殊类型的迭代器&am…