基于TensorFlow的手写体数字识别训练与测试

需求:

  • 选择一个最简单的细分方向,初步了解AI图像识别的训练、测试过程
  • TensorFlow、PyTorch、c,三种代码方案,先从TensorFlow入手
  • 探讨最基本问题的优化问题

总结:

  • 基于TensorFlow的python代码库自带了mnist 训练数据集、测试数据集。避免了自己去收集图像、标注的问题。
  • 利用chatgpt逐步完善代码,输出图像(字符方式、bmp方式)辅助分析
  • x为0-9的图像、y为对应数字标签0-9,train训练集60000个,test测试集10000个
  • 实际测试结果能达到98%成功识别率,但是剩下的2%错得也很离谱,有优化的空间。
  • 每次训练、测试的结果,存在差别,并不是完全一样的结果,TensorFlow算法中可能存在随机数
  • 测试失败的数字2中,部分与训练集比较类似,直观看起来不应该失败

代码和注释

# 环境: 20241030 win10 vs2022 python3.9.13
# 安装tensorflow: pip install tensorflow
# vs2022时,在 C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python39_64\Scripts 下运行import os
import numpy as np
import PIL.Image as Image# 显示图像
#import matplotlib.pyplot as plt# oneDNN: Intel 推出的一款深度学习性能优化库,可以加速深度学习计算。
# 1启用/0禁用 oneDNN 优化
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'# 第一次import tensorflow耗时较长
import tensorflow as tf
from tensorflow.keras import layers, models# 检查GPU是否可用
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))def display_mnist_image_console(image):# 设置字符映射,空格代表最暗,#代表最亮chars = " .:-=+*#%@"# 归一化图像到0-9的整数范围normalized_image = (image / 255 * (len(chars) - 1)).astype(int)# 使用字符映射显示图像for row in normalized_image:print("".join(chars[pixel] for pixel in row))def save_mnist_image_as_bmp(image, filename="1.bmp"):"""将MNIST图像保存为BMP格式Args:image: MNIST图像数据,形状为(28, 28)filename: 保存的文件名"""# 确保图像数据在0-255范围内image = np.clip(image, 0, 255).astype(np.uint8)# 将图像数据转换为PIL Image对象img = Image.fromarray(image, 'L')  # 'L'表示灰度图像# 保存图像img.save(filename)# 定义MNIST数据集的下载地址
mnist_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"# 检查本地是否存在MNIST数据集文件
data_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(data_dir, "mnist.npz")if not os.path.exists(data_file):print(f"本地未找到MNIST数据集,正在从 {mnist_url} 下载...")# 使用tensorflow自带的下载函数下载数据集tf.keras.utils.get_file(filename="mnist.npz", origin=mnist_url, extract=True)
else:print(f"本地已存在MNIST数据集,将使用本地文件 {data_file}")# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test_original = x_test.copy()  # 创建 x_test 的备份
# 下载mnist.npz文件,解压后是x_train.npy x_test.npy等4个文件
# .npy 文件是 NumPy(Numerical Python)的一种自描述二进制文件格式。# 输出数据集基本信息
# x图像、y标签,train训练集60000个,test测试集10000
print(f"训练集图像形状:{x_train.shape},数据类型{x_train.dtype};标签形状:{y_train.shape},数据类型{y_train.dtype}")
print(f"测试集图像形状:{x_test.shape},数据类型{x_test.dtype};标签形状:{y_test.shape},数据类型{y_test.dtype}")# 输出更多详细信息
print(f"\n标签{y_train[0]}对应的图像示例:")
#print(x_train[0]) # 这个图像示例是数字528*28的灰度图
display_mnist_image_console(x_train[0])
save_mnist_image_as_bmp(x_train[0])# 输出图像的最小值和最大值
print(f"\n训练集图像像素值的最小值:{np.min(x_train)};最大值:{np.max(x_train)}") # 0 - 255# x图像 归一化
x_train, x_test = x_train / 255.0, x_test / 255.0  # 定义一个简单的CNN模型
model = models.Sequential([ # Sequential: 创建一个顺序模型,即神经网络的层按顺序堆叠。layers.Flatten(input_shape=(28, 28)), # Flatten: 将输入的 28x28 的二维图像展平为一维向量,以便输入到全连接层。layers.Dense(128, activation='relu'), # Dense: 全连接层,神经元之间全连接。128: 输出神经元的数量,即隐藏层的神经元数量。activation='relu': 使用 ReLU 作为激活函数,引入非线性。layers.Dropout(0.2), # Dropout 层,随机丢弃部分神经元,防止过拟合。每次训练时,随机丢弃 20% 的神经元。layers.Dense(10, activation='softmax') # Dense(10, activation='softmax'): 输出层,有 10 个神经元,对应 10 个数字分类。使用 softmax 激活函数,将输出转换为概率分布。
])# 编译模型
model.compile(optimizer='adam', # 使用 Adam 优化器,一种常用的优化算法。loss='sparse_categorical_crossentropy', # 使用稀疏分类交叉熵作为损失函数,适用于多分类问题且标签是整数的情况。metrics=['accuracy']) # 评估指标为准确率。# 训练模型
model.fit(x_train, y_train, epochs=5) # 训练 5 个 epoch,每个 epoch 遍历一遍整个训练集。训练5次。# 评估模型的性能,并输出损失和准确率。
# 损失(loss): 模型在测试集上的平均损失值,反映了模型预测值与真实值之间的差异。损失越小,说明模型预测越准确。
# 准确率(accuracy): 模型在测试集上预测正确的样本比例,直接反映了模型的分类性能。
#model.evaluate(x_test, y_test) # 评估模型的性能,并输出损失和准确率。
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"\n模型评估 - 损失: {loss:.4f}, 准确率: {accuracy:.4f}")# 预测测试集标签
predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)# 初始化错误样本计数
wrong_count = 0
total_count = len(x_test)# 遍历测试集,输出识别错误的样本
print("\n识别错误的样本:")
for i in range(len(x_test)):if predicted_labels[i] != y_test[i]:  # 判断是否识别错误wrong_count += 1print(f"\n样本索引: {i} 模型预测结果: {predicted_labels[i]}, 正确结果: {y_test[i]}")display_mnist_image_console(x_test_original[i])  # 显示图像# 输出错误样本总数和总样本数
print(f"\n总共 {total_count} 个样本,识别错误 {wrong_count} 个")
# 总共 10000 个样本,识别错误 222 个。  部分识别错误的明显不应该错。

示例图像:
在这里插入图片描述

识别错误的图像举例:
在这里插入图片描述
在这里插入图片描述
输出指定img列表到bmp文件

def save_images_to_bmp(images, labels, filename, max_per_row=50):"""将图像保存到 BMP 文件中,每行最多 max_per_row 张图像。:param images: 图像数组,形状为 (n, 28, 28):param labels: 标签数组,形状为 (n,):param filename: 保存的 BMP 文件名:param max_per_row: 每行最大图像数量"""img_count = len(images)rows = (img_count + max_per_row - 1) // max_per_rowimg_width, img_height = 28, 28# 创建画布canvas_width = max_per_row * img_widthcanvas_height = rows * img_heightcanvas = Image.new("L", (canvas_width, canvas_height), color=255)  # 灰度图# 绘制每张图片for idx, img in enumerate(images):x_offset = (idx % max_per_row) * img_widthy_offset = (idx // max_per_row) * img_height# img_pil = Image.fromarray((img * 255).astype(np.uint8))  # 恢复像素值范围 0-255img_pil = Image.fromarray(img, 'L')canvas.paste(img_pil, (x_offset, y_offset))# 保存到文件canvas.save(filename)print(f"保存 {filename} 成功!")# 遍历训练集,分类存储
print("\n遍历训练集,分类存储:")
train_images = {i: [] for i in range(10)}
for i in range(len(x_train)):label = y_train[i] # 标签train_images[label].append(x_train[i])# 保存训练集图像
for digit in range(10):# 保存正确分类的样本if train_images[digit]:save_images_to_bmp(train_images[digit],[digit] * len(train_images[digit]),f"train_{digit}.bmp")# 预测测试集标签
predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)# 初始化错误样本计数
wrong_count = 0
total_count = len(x_test)
# 初始化存储字典
correct_images = {i: [] for i in range(10)}
wrong_images = {i: [] for i in range(10)}# # 遍历测试集,输出识别错误的样本
# print("\n识别错误的样本:")
# for i in range(len(x_test)):
#     if predicted_labels[i] != y_test[i]:  # 判断是否识别错误
#         wrong_count += 1
#         print(f"\n样本索引: {i} 模型预测结果: {predicted_labels[i]}, 正确结果: {y_test[i]}")
#         display_mnist_image_console(x_test_original[i])  # 显示图像# 遍历测试集,分类存储识别结果
print("\n识别错误的样本统计汇总:")
for i in range(len(x_test)):label = y_test[i]                # 真实标签predicted = predicted_labels[i]  # 模型预测结果if predicted == label:correct_images[label].append(x_test_original[i])else:wrong_images[label].append(x_test_original[i])wrong_count += 1print(f"样本索引: {i} 模型预测结果: {predicted_labels[i]}, 正确结果: {y_test[i]}")# display_mnist_image_console(x_test_original[i])  # 显示图像# 保存图像
for digit in range(10):# 保存正确分类的样本if correct_images[digit]:save_images_to_bmp(correct_images[digit],[digit] * len(correct_images[digit]),f"test_{digit}.bmp")# 保存错误分类的样本if wrong_images[digit]:save_images_to_bmp(wrong_images[digit],[digit] * len(wrong_images[digit]),f"test_error_{digit}.bmp",max_per_row=10)

以数字2为例,以下分别为训练集图像、测试集通过的图像、测试集失败的图像:
训练集
测试集通过的
测试集失败的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/61430.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通信与网络基础

1.网络通信基本概念 通信:人、物通过某种介质和行为进行信息传递与交流 网络通信:终端设备之间通过计算机网络进行通信 两个终端通过网线传递文件 多个终端通过路由器传递文件 终端通过Internet下载文件 2.信息传递过程 图1-1 假定A计算机访问B的web…

[免费]SpringBoot+Vue景区订票(购票)系统【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的SpringBootVue大景区订票(购票)系统,分享下哈。 项目视频演示 【免费】SpringBootVue景区订票(购票)系统 Java毕业设计_哔哩哔哩_bilibili 项目介绍 现代经济快节奏发展以及不断完善升级的信息…

医疗知识图谱的问答系统详解

一、项目介绍 该项目的数据来自垂直类医疗网站寻医问药,使用爬虫脚本data_spider.py,以结构化数据为主,构建了以疾病为中心的医疗知识图谱,实体规模4.4万,实体关系规模30万。schema的设计根据所采集的结构化数据生成&…

【设计模式系列】解释器模式(十七)

一、什么是解释器模式 解释器模式(Interpreter Pattern)是一种行为型设计模式,它的核心思想是分离实现与解释执行。它用于定义语言的文法规则,并解释执行语言中的表达式。这种模式通常是将每个表达式抽象成一个类,并通…

AI表情神同步!LivePortrait安装配置,一键包,使用教程

快手在AI视频这领域还真有点东西,视频生成工具“可灵”让大家玩得不亦乐乎。 现在又开源了一个超好玩的表情同步(表情控制)项目。 一看这图片,就充满了娱乐性。发布没几天就已经有8000Star。 项目****简介 LivePortrait 是一款…

阿里云服务器(centos7.6)部署前后端分离项目

Mysql8安装部署 确定一下系统的glibc版本,可以使用以下命令进行查看,当前系统glibc版本:2.17(重要!!!) 要根据自己服务器的版本去选择对应的mysql,不然后续安装会报错&a…

Java中TimedCache缓存对象的详细使用

一、TimedCache 是什么? TimedCache是一个泛型类,它的主要作用通常是在一定时间范围内对特定键值对进行缓存,并且能够根据设定的时间策略来自动清理过期的缓存项。 TimedCache是一种带有时间控制功能的缓存数据结构。在 Java 中&#xff0c…

11、数组

1、数组概念 数组就是存储多个相同数据类型的数据。 比如:存储26个字母,存储一个班级的学生成绩。 2、数组使用 数组要遵循先定义再使用 2.1、数组定义的格式 存储数据---空间 ---- 数据类型 多少个 --- 数据个数 >> 数据类型 数…

六、文本搜索工具(grep)和正则表达式

一、grep工具的使用 1、概念 grep: 是 linux 系统中的一个强大的文本搜索工具,可以按照 正则表达式 搜索文本,并把匹配到的行打印出来(匹配到的内容标红)。 2、语法 grep [options]…… pattern [file]…… 工作方式…

【python】爬去二手车数据 未完成

技术方案 python selenium 先下载Microsoft Edge WebDriver Microsoft Edge WebDriver 官网 先看一下自己的edge版本 搜索到版本然后下载自己的版本 安装依赖 pip install seleniumimport time from selenium import webdriverdriver webdriver.Edge(executable_pathr&qu…

玩游戏常常出现vc++runtime library error R6025 这是什么意思,该怎么解决?

当玩游戏时常常出现“vc runtime library error R6025”错误,这通常表明微软C开发运行库组件存在问题。以下是对该错误及其解决方法的详细解释: 错误含义 “vc runtime library error R6025”是一个与Visual C运行时库相关的错误,该错误表明…

【深度学习基础】一篇入门模型评估指标(分类篇)

🌈 个人主页:十二月的猫-CSDN博客 🔥 系列专栏: 🏀深度学习_十二月的猫的博客-CSDN博客 💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 目录 1. 前言 2. 模…

深度学习基础02_损失函数BP算法(上)

目录 一、损失函数 1、线性回归损失函数 1.MAE损失 2.MSE损失 3.SmoothL1Loss 2、多分类损失函数--CrossEntropyLoss 3、二分类损失函数--BCELoss 4、总结 二、BP算法 1、前向传播 1.输入层(Input Layer)到隐藏层(Hidden Layer) 2.隐藏层(Hidden Layer)到输出层(Ou…

从技术视角看AI在Facebook全球化中的作用

在全球化日益加深的今天,人工智能(AI)作为一种变革性技术,正在深刻影响全球互联网巨头的发展方向。Facebook作为全球最大的社交媒体平台之一,正通过AI技术突破语言、文化和技术的障碍,推动全球化战略的实现…

41 基于单片机的小车行走加温湿度检测系统

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于51单片机,采样DHT11温湿度传感器检测温湿度,滑动变阻器连接数码转换器模拟电量采集传感器, 电机采样L298N驱动,各项参数通过LCD1602显示&#x…

Python3 爬虫 Scrapy的使用

安装完成Scrapy以后&#xff0c;可以使用Scrapy自带的命令来创建一个工程模板。 一、创建项目 使用Scrapy创建工程的命令为&#xff1a; scrapy startproject <工程名> 例如&#xff0c;创建一个抓取百度的Scrapy项目&#xff0c;可以将命令写为&#xff1a; scrapy s…

【S500无人机】--地面端下载

之前国庆的时候导师批了无人机&#xff0c;我们几个也一起研究了几次&#xff0c;基本把无人机组装方面弄的差不多了&#xff0c;还差个相机搭载&#xff0c;今天我们讲无人机的调试 硬件配置如下 首先是地面端下载&#xff0c;大家可以选择下载&#xff1a; Mission Planne地…

Android -- 简易音乐播放器

Android – 简易音乐播放器 播放器功能&#xff1a;* 1. 播放模式&#xff1a;单曲、列表循环、列表随机&#xff1b;* 2. 后台播放&#xff08;单例模式&#xff09;&#xff1b;* 3. 多位置同步状态回调&#xff1b;处理模块&#xff1a;* 1. 提取文件信息&#xff1a;音频文…

常用端口与Udp协议

目录 1.再谈端口 1.1 五元组 1.2 端口号范围划分 1.3 两个指令 1.3.1 netstat 1.3.2 pidof 2.UDP协议 2.1 协议整体格式 2.2 udp特点 2.3 udo缓冲区 1.再谈端口 1.1 五元组 端口号表示了一个主机上进行通信的不同的应用程序&#xff1b;在Tcp/IP协议中&#xff0c;用…

计算机毕业设计SpringCloud+大模型微服务高考志愿填报推荐系统 高考大数据 SparkML机器学习 深度学习 人工智能 Python爬虫 知识图谱

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…