模型训练识别手写数字(一)

  一、模型训练数据集

1. 导入所需库

import numpy as np
from sklearn.datasets import fetch_openml

numpy 是用于数值计算的库。

fetch_openml 是用于从 OpenML 下载数据集的函数。

  2. 获取 MNIST 数据集

X, y = fetch_openml('mnist_784', version=1, return_X_y=True)

fetch_openml('mnist_784', version=1, return_X_y=True) 从 OpenML 下载 MNIST 数据集。X 存储图像数据(784 个特征,28x28 像素的扁平化图像),y 存储对应的标签(数字 0 到 9)。

   3. 将像素值二值化

X[X > 0] = 1

这行代码将 X 中所有大于 0 的像素值设置为 1,二值化处理。这样处理后的图像只有两个值:0(黑色)和 1(白色),有助于简化模型的输入。

   4. 保存数据集

np.save("Data/dataset", X)
np.save("Data/class", y)

np.save("Data/dataset", X) 将图像数据保存为 dataset.npy

np.save("Data/class", y) 将标签数据保存为 class.npy

二、模型训练及预测

1. 导入所需库

import matplotlib.pyplot as plt
import numpy as np
from keras import Sequential
from keras import layers
from keras.api.models import load_model
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

matplotlib.pyplot: 用于绘图和数据可视化。

numpy: 用于处理数组和数据加载。

keras: 用于构建和训练深度学习模型。

sklearn: 提供数据划分和预处理的工具。

f1_score: 用于评估模型性能。

 2. 加载数据

X = np.load("Data/dataset.npy", allow_pickle=True)
y = np.load("Data/class.npy", allow_pickle=True)

使用 numpyload 方法加载训练数据 (X) 和标签 (y)。allow_pickle=True 允许加载包含对象的数组。

 3. One-Hot 编码

onehot = OneHotEncoder(sparse_output=False)
y = onehot.fit_transform(y.reshape(-1, 1))

OneHotEncoder: 将标签转换为独热编码格式,方便用于分类任务。每个标签会被转换为一个二进制数组。 

 4. 划分训练集和测试集

x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=14)

使用 train_test_split 将数据集分为训练集和测试集,通常使用 70%-80% 的数据用于训练,其余用于测试。 

 5. 构建模型

model = Sequential()
model.add(layers.Dense(100, activation='relu', input_shape=(x_train.shape[1],)))
model.add(layers.Dense(y.shape[1], activation='softmax'))  

Sequential: 表示模型是线性的,按顺序堆叠各个层。

Dense: 添加全连接层,第一层有 100 个神经元,使用 ReLU 激活函数;第二层为输出层,使用 Softmax 激活函数,适合多类分类任务。

 6. 编译模型

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

使用 Adam 优化器,损失函数为交叉熵(适合多分类),并监控准确率。 

 7. 训练模型

model.fit(x_train, y_train, epochs=100, batch_size=32, verbose=1)

 训练模型 100 个周期,批量大小为 32,verbose=1 表示输出训练过程的信息。

 8. 评估模型

predictions = model.predict(x_test)
predictions_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)print("F-score: {0:.2f}".format(f1_score(y_test_classes, predictions_classes, average='micro')))

使用测试集进行预测,并计算 F-score 作为评估指标。np.argmax 用于获取每个样本预测概率最高的类。 

 9. 保存模型

model.save("my_model.h5")

将训练好的模型保存到文件 my_model.h5 中,以便后续加载和使用。 

 10. 加载模型

loaded_model = load_model("my_model.h5")

 加载之前保存的模型,以便进行预测。

 11. 进行预测

predictions = loaded_model.predict(x_test)

使用加载的模型对测试集进行预测,获取每个样本的预测结果。

 12. 获取预测和真实标签

y_pred_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)

 使用 np.argmax 从预测结果和真实标签中获取每个样本的类别索引。

  13. 可视化预测结果

plt.figure(figsize=(12, 6))for i in range(20):plt.subplot(4, 5, i + 1)plt.imshow(x_test[i].reshape(28, 28), cmap='gray')  # 假设输入是28x28的图像plt.title(f'True: {y_test_classes[i]}\nPred: {y_pred_classes[i]}')plt.axis('off')plt.tight_layout()
plt.show()

创建一个图形窗口,设置大小为 12x6。

使用 subplot 在 4 行 5 列的网格中绘制 20 个图像。

每个子图中显示测试样本的图像、真实标签和预测标签。

imshow 将图像进行灰度显示,axis('off') 隐藏坐标轴。

tight_layout() 调整子图参数,以避免重叠。

show() 显示图形。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/57072.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode算法(双指针)

今天的题目主要都是力扣前100中,关于双指针的题 1.移动零 链接:移动零 示例: 示例 : 输入: nums [0,1,0,3,12] 输出:[1,3,12,0,0] 可以看到保持原有元素的顺序,将所有的0,移动到数组最后方即可。 这…

开源项目工具:LeanTween - 为Unity 3D打造的高效缓动引擎详解(比较麻烦的API版)

1.LeanTween.reset() 一、工具介绍 参考:推荐开源项目:LeanTween - 为Unity 3D打造的高效缓动引擎-CSDN博客 LeanTween是一个专为Unity 3D引擎设计的高效缓动(tweening)库,它提供了简单易用的API,帮助开…

ctfshow(175->178)--SQL注入--联合注入及其过滤

Web175 进入界面: 审计: 查询语句: $sql "select username,password from ctfshow_user5 where username !flag and id ".$_GET[id]." limit 1;";返回逻辑: if(!preg_match(/[\x00-\x7f]/i, json_enc…

数据结构(8.4_3)——堆的插入删除

在堆中插入新元素 在堆中删除元素 总结:

Linux:权限的深度解析(小白必看!!!)

文章目录 前言一、Linux重要的几个热键二、关机三、扩展命令总结四、shell命令以及运行原理感性理解五、Linux权限的概念1. 权限的概念2. 认识人(用户)1)创建人2)人分类3)人切换4)指令提权 3. ll下文件的权…

一些待机电流波形特征

一、待机电流波形 最干净的待机电流波形应该只有paging,不过需要注意2点: 每个paging的间隔,不同网络可能不一样,有可能是320ms, 640ms 待机网络 paging 间隔 1分钟的耗电量 单个耗电量 单个待机电流 单个波形时长 4G 64…

二十三、Python基础语法(包)

包(package):包是一种组织代码的方式,可以将相关的模块组合在一起,以便更好地管理和重用代码,包的目录中有一个特殊代码文件__init__.py,包的命名也要遵循标识符的规则。 一、包的结构 一个 Python 包通常是一个包含…

NLTK无法下载?

以下内容仅为当前认识,可能有不足之处,欢迎讨论! 文章目录 nltk无法下载怎么办?什么是NLTK?为什么要用NLTK?如何下载? nltk无法下载怎么办? 什么是NLTK? NLTK是学习自然…

python项目实战——多协程下载美女图片

协程 文章目录 协程协程的优劣势什么是IO密集型任务特点示例与 CPU 密集型任务的对比处理 I/O 密集型任务的方式总结 创建并使用协程asyncio模块 创建协程函数运行协程函数asyncio.run(main())aiohttp模块调用aiohttp模块步骤 aiofiles————协程异步函数遇到的问题一 await …

代码随想录跟练21天——LeetCode332.重新安排行程, 51. N皇后,37. 解数独

332.重新安排行程 力扣题目链接(opens new window) 给定一个机票的字符串二维数组 [from, to],子数组中的两个成员分别表示飞机出发和降落的机场地点,对该行程进行重新规划排序。所有这些机票都属于一个从 JFK(肯尼迪国际机场)出…

【Python可视化系列】一文教你绘制双Y轴的双折线图(案例+源码)

这是我的第369篇原创文章。 一、引言 在日常工作和学习中,我们会遇到将两个折线画在一张图上的情况,且这两个折线代表了两个特征,具有不同的涵义和量纲表示,这时候我们就需要绘制一个双Y轴折线图,一边代表一个特征&…

Redis 持久化 总结

前言 相关系列 《Redis & 目录》(持续更新)《Redis & 持久化 & 源码》(学习过程/多有漏误/仅作参考/不再更新)《Redis & 持久化 & 总结》(学习总结/最新最准/持续更新)《Redis & …

python进阶集锦

一、迭代器和生成器 区别 关于迭代器和生成器 迭代器与生成器的区别 迭代器(Iterator)和生成器(Generator)是Python中处理序列数据的两种不同概念。迭代器是遵循迭代协议的对象,而生成器是一种特殊类型的迭代器&am…

Vue学习笔记(八)

透传attribute "透传attribute"指的是传递给一个组件,却没有被改组件声明为props或emits的attribute或者v-on事件监听器。最常见的例子就是class、style和id。 当一个组件以单个元素为根作渲染时,透传的attribute会自动被添加到根元素上。 …

4个提取音频办法,轻松实现视频转音频!

在信息爆炸的时代,视频内容以其直观、生动的特点占据了互联网的大半江山。然而,在某些场景下,我们可能更倾向于只听取音频部分,无论是驾驶途中听讲座、跑步时享受音乐视频中的纯音乐的场景,还是为了节省流量和存储空间…

C++ 类与对象入门:基础知识与定义

引言: 本来打算用一篇介绍清楚C中的类与对象,再三考虑后觉得不妥:第一,知识点实在太多;第二,对于从刚学完C并打算过渡到C的朋友来说,学的太深较有难度… 总而言之,我打算用三到四篇文…

一篇文章总结 SQL 基础知识点

1. 官方文档 MySQL:https://dev.mysql.com/doc/refman/8.4/en/ SQL Server:What is SQL Server? - SQL Server | Microsoft Learn Oracle:https://docs.oracle.com/en/database/oracle/oracle-database/23/lnpls/loe.html 2. 术语 SQL S…

电脑程序变化监控怎么设置?实时监控电脑程序变化的五大方法,手把手教会你!

​在现代办公和信息安全领域,实时监控电脑程序变化是一项至关重要的任务。 无论是企业内网安全、员工行为审计,还是个人电脑的隐私保护,了解并设置有效的监控方法都是必不可少的。 本文将详细介绍五种电脑程序变化监控的方法,帮助…

️ Vulnhuntr:利用大型语言模型(LLM)进行零样本漏洞发现的工具

在网络安全领域,漏洞的发现和修复是保护系统安全的关键。今天,我要向大家介绍一款创新的工具——Vulnhuntr,这是一款利用大型语言模型(LLM)进行零样本漏洞发现的工具,能够自动分析代码,检测远程…

SAP-ABAP开发学习-FUNCTION ALV

ALV概览 ALV全称SAP List View,是SAP提供的一个强大的数据报表显示工具。ALV实质上是一个屏幕控件对象,它通过程序传递数据内表的方式来显示数据。 实现方式:调用标准函数;优化接口:用户可以实现对字段的排序、筛选及统计等功能。…