02、Tensorflow实现手写数字识别(数字0-9)

02、Tensorflow实现手写数字识别(数字0-9)

01、Tensorflow实现二元手写数字识别(二分类问题)
02、Tensorflow实现手写数字识别(数字0-9)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0与pycharm

1、识别目标

识别手写仅仅是为了区分手写的0到9,所以实际上是一个多分类问题。
STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

from sklearn.metrics import classification_report 这行代码的主要作用是导入classification_report 函数,以便在后续的代码中使用它来评估分类模型的性能。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。


STEP2:屏蔽无用警告并允许中文

# 使用warnings模块来忽略特定类型的警告  
warnings.simplefilter(action='ignore', category=FutureWarning)  
# 配置tensorflow的日志记录级别  
logging.getLogger("tensorflow").setLevel(logging.ERROR)  
# 设置TensorFlow的autograph模块的详细级别  
tf.autograph.set_verbosity(0)  
# 设置numpy的打印选项  
np.set_printoptions(precision=2)  

STEP3:加载数据集并分割测试集

# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")return X, y# load dataset
X, y = load_data()print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

原始的输入的数据集是5000* 400数组,共包含5000个手写数字的数据,其中400为20*20像素的图片,
在这里插入图片描述


STEP4:模型构建与训练

# 构建模型  
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的  
model = Sequential([### START CODE HERE ###  tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维   Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"  Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"  Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"  ### END CODE HERE ###  ], name="my_model"
)  # 定义模型名称为"my_model"  
model.summary()  # 打印模型的概述信息  # 配置模型的训练参数  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001  
)# 训练模型  
history = model.fit(X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据  epochs=100  # 训练100轮  
)

STEP5:结果可视化与打印准确度信息

fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))prediction_p = tf.nn.softmax(prediction)yhat = np.argmax(prediction_p)# 错误结果标红if y_test[random_index, 0] == yhat:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)ax.set_axis_off()else:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')ax.set_axis_off()fig.suptitle("Label, yhat", fontsize=14)
plt.show()# 给出预测的测试集误差
def evaluation(y_test, y_predict):accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']precision=s['precision']recall=s['recall']f1_score=s['f1-score']#kappa=cohen_kappa_score(y_test, y_predict)return accuracy,precision,recall,f1_score #, kappay_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

3、运行结果

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现手写数字识别(数字0-9)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings# 使用warnings模块来忽略特定类型的警告
warnings.simplefilter(action='ignore', category=FutureWarning)
# 配置tensorflow的日志记录级别
logging.getLogger("tensorflow").setLevel(logging.ERROR)
# 设置TensorFlow的autograph模块的详细级别
tf.autograph.set_verbosity(0)
# 设置numpy的打印选项
np.set_printoptions(precision=2)# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")return X, y# load dataset
X, y = load_data()print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 绘图可选
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(5, 5))
# fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
# # fig.tight_layout(pad=0.5)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # reshape the image
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
#     fig.suptitle("Label, image", fontsize=14)
# plt.show()# 构建模型
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的
model = Sequential([### START CODE HERE ###tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"### END CODE HERE ###], name="my_model"
)  # 定义模型名称为"my_model"
model.summary()  # 打印模型的概述信息# 配置模型的训练参数
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001
)# 训练模型
history = model.fit(X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据epochs=100  # 训练100轮
)fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))prediction_p = tf.nn.softmax(prediction)yhat = np.argmax(prediction_p)# Display the label above the imageif y_test[random_index, 0] == yhat:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)ax.set_axis_off()else:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')ax.set_axis_off()fig.suptitle("Label, yhat", fontsize=14)
plt.show()# 给出预测的测试集误差
def evaluation(y_test, y_predict):accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']precision=s['precision']recall=s['recall']f1_score=s['f1-score']#kappa=cohen_kappa_score(y_test, y_predict)return accuracy,precision,recall,f1_score #, kappay_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

zookeeper集群和kafka集群

(一)kafka 1、kafka3.0之前依赖于zookeeper 2、kafka3.0之后不依赖zookeeper,元数据由kafka节点自己管理 (二)zookeeper 1、zookeeper是一个开源的、分布式的架构,提供协调服务(Apache项目&…

【Openstack Train安装】二、NTP安装

网络时间协议:Network Time Protocol(NTP)是用来使计算机时间同步化的一种协议,它可以使计算机对其服务器或时钟源(如石英钟,GPS等等)做同步化,它可以提供高精准度的时间校正(LAN上与…

ACM程序设计课内实验(2) 排序问题

基础知识‘ sort函数 C中的sort函数是库中的一个函数&#xff0c;用于对容器中的元素进行排序。它的原型如下&#xff1a; template <class RandomAccessIterator, class Compare> void sort (RandomAccessIterator first, RandomAccessIterator last, Compare comp);参数…

IC设计简单概述

IC设计行业是一个高科技行业&#xff0c;有着复杂而细致的分工&#xff0c;严格的流程规范、多种不同类型的EDA工具。下面简单概述以下几个方面。 IC设计公司的分类 IC设计公司有多种分类方法。若按有无芯片生产能力来分&#xff0c;可以分为兼具设计与生产能力&#xff08;I…

Shopee引流妙招!Shopee产品标签重要吗?教你有效打标签引爆流量!

对Shopee平台的卖家来说&#xff0c;在新产品上架时除了要注重产品title、介绍以及图文的优化&#xff0c;还有一件事情很重要&#xff0c;那就是——产品打标签。 对于每个跨境电商卖家来讲&#xff0c;对产品打标签都是必不可少的一个运营环节 下面小宇就来告诉大家&#xf…

SpringSecurity的默认登录页的使用

SpringSecurity的默认登录页的使用 01 前期准备 引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><!--mysql驱动--><dependency><grou…

主机的容器化技术介绍

☞ ░ 前往老猿Python博客 ░ https://blog.csdn.net/LaoYuanPython 一、什么是容器 容器是一个标准化的单元&#xff0c;是一种轻量级、可移植的软件打包技术&#xff0c;容器将软件代码及其相关依赖打包&#xff0c;使应用程序可以在任何计算介质运行。例如开发人员在自己的…

python pytorch实现RNN,LSTM,GRU,文本情感分类

python pytorch实现RNN,LSTM&#xff0c;GRU&#xff0c;文本情感分类 数据集格式&#xff1a; 有需要的可以联系我 实现步骤就是&#xff1a; 1.先对句子进行分词并构建词表 2.生成word2id 3.构建模型 4.训练模型 5.测试模型 代码如下&#xff1a; import pandas as pd im…

命名管道:简单案例实现

&#x1f4df;作者主页&#xff1a;慢热的陕西人 &#x1f334;专栏链接&#xff1a;Linux &#x1f4e3;欢迎各位大佬&#x1f44d;点赞&#x1f525;关注&#x1f693;收藏&#xff0c;&#x1f349;留言 本博客主要内容讲解了什么是命名管道&#xff0c;匿名管道和命名管道的…

深入了解Rabbit加密技术:原理、实现与应用

一、引言 在信息时代&#xff0c;数据安全愈发受到重视&#xff0c;加密技术作为保障信息安全的核心手段&#xff0c;得到了广泛的研究与应用。Rabbit加密技术作为一种新型加密方法&#xff0c;具有较高的安全性和便捷性。本文将对Rabbit加密技术进行深入探讨&#xff0c;分析…

六、初识FreeRTOS之FreeRTOS的任务挂起和恢复函数介绍

本节需要掌握以下内容&#xff1a; 1&#xff0c;任务的挂起与恢复的API函数&#xff08;熟悉&#xff09; 2&#xff0c;任务挂起与恢复实验&#xff08;掌握&#xff09; 3&#xff0c;课堂总结&#xff08;掌握&#xff09; 一、任务的挂起与恢复的API函数&#xff08;熟…

C++ day41 动态规划 整数拆分 不同的二叉搜索树

题目1&#xff1a;343 整数拆分 题目链接&#xff1a;整数拆分 对题目的理解 将正整数n&#xff0c;拆分成k个正整数的和&#xff08;k>2&#xff09;使得这些整数的乘积最大化&#xff0c;返回最大乘积 动规五部曲 1&#xff09;dp数组的含义以及其下标i的含义 dp[i]…

Verilog 入门(四)(门电平模型化)

文章目录 内置基本门多输入门简单示例 内置基本门 Verilog HDL 中提供下列内置基本门&#xff1a; 多输入门 and&#xff0c;nand&#xff0c;or&#xff0c;nor&#xff0c;xor&#xff0c;xnor 多输出门 buf&#xff0c;not 三态门上拉、下拉电阻MOS 开关双向开关 门级逻辑…

OSG编程指南<十七>:OSG光照与材质

1、OSG光照 OSG 全面支持 OpenGL 的光照特性&#xff0c;包括材质属性&#xff08;material property&#xff09;、光照属性&#xff08;light property&#xff09;和光照模型&#xff08;lighting model&#xff09;。与 OpenGL 相似&#xff0c;OSG 中的光源也是不可见的&a…

工博会新闻稿汇总

23届工博会媒体报道汇总 点击文章标题即可进入详情页 9月23日&#xff0c;第23届工博会圆满落幕&#xff01;本届工博会规模之大、能级之高、新展品之多创下历史之最。高校展区在规模、能级和展品上均也创下新高。工博会系列报道深入探讨了高校科技发展的重要性和多方面影响。…

【合集】MQ消息队列——Message Queue消息队列的合集文章 RabbitMQ入门到使用

前言 RabbitMQ作为一款常用的消息中间件&#xff0c;在微服务项目中得到大量应用&#xff0c;其本身是微服务中的重点和难点。本篇博客是Message Queue相关的学习博客文章的合集篇&#xff0c;目前主要是RabbitMQ入门到使用文章&#xff0c;后续会扩展其他MQ。 目录 前言一、R…

自定义链 SNAT / DNAT 实验举例

参考原理图 实验前的环境搭建 1. 准备三台虚拟机&#xff0c;定义为内网&#xff0c;外网以及网卡服务器 2. 给网卡服务器添加网卡 3. 将三台虚拟机的防火墙和安全终端全部关掉 systemctl stop firewalld && setenforce 0 4. 给内网虚拟机和外网虚拟机 yum安装 httpd…

阿里云国际短信业务网络超时排障指南

选取一台或多台线上的应用服务器或选取相同网络环境下的机器&#xff0c;执行以下操作。 获取公网出口IP。 curl ifconfig.me 测试连通性。 &#xff08;推荐&#xff09;执行MTR命令&#xff08;可能需要sudo权限&#xff09;&#xff0c;检测连通性&#xff0c;执行30秒。 m…

Scrapy框架中间件(一篇文章齐全)

1、Scrapy框架初识&#xff08;点击前往查阅&#xff09; 2、Scrapy框架持久化存储&#xff08;点击前往查阅&#xff09; 3、Scrapy框架内置管道&#xff08;点击前往查阅&#xff09; 4、Scrapy框架中间件 Scrapy 是一个开源的、基于Python的爬虫框架&#xff0c;它提供了…

HashMap的实现原理

1.HashMap实现原理 HashMap的数据结构&#xff1a; *底层使用hash表数据结构&#xff0c;即数组链表红黑树 当我们往HashMap中put元素时&#xff0c;利用key的hashCode重新hash计算出当前对象的元素在数组中的下标 存储时&#xff0c;如果出现hash值相同的key&#xff0c;此时…