01、Tensorflow实现二元手写数字识别

01、Tensorflow实现二元手写数字识别(二分类问题)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0

1、识别目标

识别手写仅仅是为了区分手写的0和1,所以实际上是一个二分类问题。

2、Tensorflow算法实现

STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_score

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。

from sklearn.metrics import accuracy_score:这是从scikit-learn库中引入accuracy_score函数。这个函数用于计算分类准确率,常用于评估分类模型的性能。


STEP2:屏蔽无用警告并允许中文

logging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

logging.getLogger(“tensorflow”).setLevel(logging.ERROR):这行代码用于设置 TensorFlow 的日志级别为 ERROR。这意味着只有当 TensorFlow 中发生错误时,才会在日志中输出相关信息。较低级别的日志信息(如 WARNING、INFO、DEBUG)将被忽略。

tf.autograph.set_verbosity(0):这行代码用于设置 TensorFlow 的自动图形(Autograph)日志的冗长级别为 0。这意味着在将 Python 代码转换为 TensorFlow 图形代码时,将不会输出任何日志信息。这有助于减少日志噪音,使日志更加干净。

warnings.simplefilter(action=‘ignore’,category=FutureWarning):这行代码用于忽略所有 FutureWarning 类型的警告。在 Python中,当使用某些即将过时或未来版本中可能发生变化的特性时,通常会发出 FutureWarning。通过设置action=‘ignore’,代码将不会输出这类警告,使控制台输出更加干净。

plt.rcParams[‘font.sans-serif’]=[‘SimHei’]:这行代码用于设置 matplotlib 中的默认无衬线字体为 SimHei。SimHei 是一种常用于显示中文的字体,这样设置后,matplotlib 将在绘图时使用 SimHei 字体来显示中文,从而避免中文乱码问题。

plt.rcParams[‘axes.unicode_minus’]=False:这行代码用于解决 matplotlib
中负号显示异常的问题。默认情况下,matplotlib 可能无法正确显示负号,将其设置为 False 可以使用 ASCII字符作为负号,从而正常显示。


STEP3:导入并划分数据集

划分10%作为测试:

X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

STEP4:模型构建与训练

# 构建模型,三层模型进行分类,第一层输入100个神经元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)

STEP5:结果可视化与打印准确度信息
原始的输入的数据集是400 * 1000的数组,共包含1000个手写数字的数据,其中400为20*20像素的图片,因此对每个400的数组进行reshape((20, 20))可以得到原始的图片进而绘图。

# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

3、运行结果

按照最初的划分,数据集包含1000个数据,划分10%为测试集,也就是100个数据。结果可视化随机选择其中的64个数据绘图,每个图像的上方标明了其真实标签和预测的结果,这个是一个非常简单的示例,准确度还是非常高的。
在这里插入图片描述

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现二元手写数字识别(二分类问题)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import warnings
import logging
from sklearn.metrics import accuracy_scorelogging.getLogger("tensorflow").setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
warnings.simplefilter(action='ignore', category=FutureWarning)
# 支持中文显示
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_data/X.npy")y = np.load("Handwritten_Digit_Recognition_data/y.npy")X = X[0:1000]y = y[0:1000]return X, y# 加载数据集,查看数据集大小,可以看到有1000个数据集,每个输入是20*20=400大小的图片
X, y = load_data()
print('The shape of X is: ' + str(X.shape))
print('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 下面画图,随便从原数据取出几个画图,可以注释
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(8, 8))
# fig.tight_layout(pad=0.1)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # 将1*400的数据转换为20*20的图像格式
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
# plt.show()# 构建模型,三层模型进行分类,第一层输入25个神经元...
model = Sequential([tf.keras.Input(shape=(400,)),    #specify input size### START CODE HERE ###Dense(100, activation='sigmoid'),Dense(10, activation='sigmoid'),Dense(1, activation='sigmoid')### END CODE HERE ###], name = "my_model"
)
# 打印三层模型的参数
model.summary()
# 模型设定,学习率0.001,因为是分类,使用BinaryCrossentropy损失函数
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),optimizer=tf.keras.optimizers.Adam(0.001),
)
# 开始训练,训练循环20
model.fit(X_train,y_train,epochs=20
)# 绘制测试集的预测结果,绘制64个
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
fig.tight_layout(pad=0.1, rect=[0, 0.03, 1, 0.92])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))if prediction >= 0.5:yhat = 1else:yhat = 0# Display the label above the imageax.set_title(f"{y_test[random_index, 0]},{yhat}")ax.set_axis_off()
fig.suptitle("真实标签, 预测的标签", fontsize=16)
plt.show()# 给出预测的测试集误差
y_pred=model.predict(X_test)
print("测试数据集准确率为:", accuracy_score(y_test, np.round(y_pred)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/167814.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据丢失预防措施包括什么

数据丢失预防措施是保护企业或个人重要数据的重要手段。以下是一些有效的预防措施: 可以通过域之盾软件来实现数据防丢失,具体的功能包括: https://www.yuzhidun.cn/https://www.yuzhidun.cn/ 1、备份数据 定期备份所有重要数据&#xff0…

unittest指南——不拼花哨,只拼实用

📢专注于分享软件测试干货内容,欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢交流讨论:欢迎加入我们一起学习!📢资源分享:耗时200小时精选的「软件测试」资…

代码随想录算法训练营第五十三天|1143.最长公共子序列 1035.不相交的线 53. 最大子序和

文档讲解:代码随想录 视频讲解:代码随想录B站账号 状态:看了视频题解和文章解析后做出来了 1143.最长公共子序列 class Solution:def longestCommonSubsequence(self, text1: str, text2: str) -> int:dp [[0] * (len(text2) 1) for _ i…

基于法医调查算法优化概率神经网络PNN的分类预测 - 附代码

基于法医调查算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于法医调查算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于法医调查优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神…

windows11上安装WSL

Windows电脑上要配置linux(这里指ubuntu)开发环境,主要有三种方式: 1)在windows上装个虚拟机(比如vmware)。缺点是vmware加载ubuntu后系统会变慢很多,而且需要通过samba来实现window…

计算机组成原理。3-408

1.动态存储和静态存储 2.双端口RAM 注意:cpu通过地址线和数据线读写数据时,不能同时写,但可以同时读,也不能一边读一边写。 3.多体并行存储器 分为高位存储和低位存储 小结 4.磁盘存储器的组成 5.磁盘的性能指标 磁盘读写寻道…

Vue中Slot的使用指南

目录 前言 什么是slot? 单个slot的使用 具名slot的使用 作用域插槽 总结 前言 在Vue中,slot是一种非常强大和灵活的功能,它允许你在组件模板中预留出一个或多个"插槽",然后在使用这个组件的时候动态地填充内容。这…

TSINGSEE青犀智能分析网关道路积水识别AI算法方案

在各处的街道、路口等区域,及时发现道路积水问题,可以大大减少城市管理部门压力,及时处理,减少交通事故与人员摔倒事故。通过道路积水AI算法,能有效提高城市管理部门效率,优化城市管理方式。 那么&#xff…

【Web】PhpBypassTrick相关例题wp

目录 ①[NSSCTF 2022 Spring Recruit]babyphp ②[鹤城杯 2021]Middle magic ③[WUSTCTF 2020]朴实无华 ④[SWPUCTF 2022 新生赛]funny_php 明天中期考,先整理些小知识点冷静一下 ①[NSSCTF 2022 Spring Recruit]babyphp payload: a[]1&b1[]1&b2[]2&…

NLP的使用

参考: Apache openNLP 简介 - 链滴 (ld246.com) opennlp 模型下载地址:Index of /apache/opennlp/models/ud-models-1.0/ (tencent.com) OpenNLP是一个流行的开源自然语言处理工具包,它提供了一系列的NLP模型和算法。然而,Open…

【模拟开关CH440R】2022-1-20

资料模拟开关CH440芯片手册 - 百度文库 ch440R回来了,导通usb设备没问题,降压不影响。但是我发现个严重的问题,我的电路是直接通过4067控制ch440r接地,低电平,使能三个线路连一起的,邮箱的图您看看&#xf…

N-134基于java实现捕鱼达人游戏

开发工具eclipse,jdk1.8 文档截图&#xff1a; package com.qd.fish;import java.awt.Graphics; import java.io.File; import java.util.ArrayList; import java.util.List;import javax.imageio.ImageIO;public class Fishes {//定义一个集合来管理鱼List<Fish> fish…

五种多目标优化算法(NSDBO、NSGA3、MOGWO、NSWOA、MOPSO)求解微电网多目标优化调度(MATLAB代码)

一、多目标优化算法简介 &#xff08;1&#xff09;非支配排序的蜣螂优化算法NSDBO 多目标应用&#xff1a;基于非支配排序的蜣螂优化算法NSDBO求解微电网多目标优化调度&#xff08;MATLAB&#xff09;-CSDN博客 &#xff08;2&#xff09;NSGA3 NSGA-III求解微电网多目标…

应用场景丨社区燃气管网监测系统建设

燃气作为现代社会的重要能源&#xff0c;燃气被广泛应用于居民生活、工业生产、商业服务等领域。然而&#xff0c;燃气泄漏事故时有发生&#xff0c;不仅给人们的生命财产安全带来严重威胁&#xff0c;也给燃气行业的发展带来不良影响。因此&#xff0c;对于燃气管道的监测和管…

给虚拟机配置静态id地址

1.令人头大的原因 当连接虚拟机的时候 地址不一会就改变&#xff0c;每次都要重新输入 2.配置虚拟机静态id地址 打开命令窗口执行 : vim /etc/sysconfig/network-scripts/ifcfg-ens33 按下面操作修改 查看自己子网掩码 3.重启网络 命令行输入 systemctl restart netwo…

【C语言】函数(四):函数递归与迭代,二者有什么区别

目录 前言递归定义递归的两个必要条件接受一个整型值&#xff08;无符号&#xff09;&#xff0c;按照顺序打印它的每一位使用函数不允许创建临时变量&#xff0c;求字符串“abcd”的长度求n的阶乘求第n个斐波那契数 迭代总结递归与迭代的主要区别用法不同结构不同时间开销不同…

毛利率创历史新高,三季度的小米拿出“新王牌”?

近日&#xff0c;小米正式发布了今年三季度的财报。财报数据显示&#xff0c;小米第三季度经调整净利润为59.9亿元人民币&#xff0c;同比增长182.9%&#xff0c;远超市场预期的48亿元。这其中&#xff0c;手机业务作为小米的基本盘一直是市场的关注焦点。今年三季度&#xff0…

vue3的基本使用(超详细)

一、初识vue3 1.vue3简介 2020年9月18日&#xff0c;vue3发布3.0版本&#xff0c;代号大海贼时代来临&#xff0c;One Piece特点&#xff1a; 无需构建步骤&#xff0c;渐进式增强静态的 HTML在任何页面中作为 Web Components 嵌入单页应用 (SPA)全栈 / 服务端渲染 (SSR)Jams…

搜索引擎语法

演示自定的Google hacking语法&#xff0c;解释含意以及在渗透过程中的作用 Google hacking site&#xff1a;限制搜索范围为某一网站&#xff0c;例如&#xff1a;site:baidu.com &#xff0c;可以搜索baidu.com 的一些子域名。 inurl&#xff1a;限制关键字出现在网址的某…

【Spring】MyBatis的操作数据库

目录 一&#xff0c;准备工作 1.1 创建工程 1.2 准备数据 1.3 数据库连接字符串 1.4 创建持久层接口UserInfoMapper 1.5 单元测试 二&#xff0c;注解的基础操作 2.1 打印日志 2.2 参数传递 2.3 增&#xff08;Insert&#xff09; 2.4 删&#xff08;Delete&#x…