python与深度学习(六):CNN和手写数字识别二

目录

  • 1. 说明
  • 2. 手写数字识别的CNN模型测试
    • 2.1 导入相关库
    • 2.2 加载数据和模型
    • 2.3 设置保存图片的路径
    • 2.4 加载图片
    • 2.5 图片预处理
    • 2.6 对图片进行预测
    • 2.7 显示图片
  • 3. 完整代码和显示结果
  • 4. 多张图片进行测试的完整代码以及结果

1. 说明

本篇文章是对上篇文章训练的模型进行测试。首先是将训练好的模型进行重新加载,然后采用opencv对图片进行加载,最后将加载好的图片输送给模型并且显示结果。

2. 手写数字识别的CNN模型测试

2.1 导入相关库

在这里导入需要的第三方库如cv2,如果没有,则需要自行下载。

from tensorflow import keras
# 引入内置手写体数据集mnist
from keras.datasets import mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np

2.2 加载数据和模型

把MNIST数据集进行加载,并且把训练好的模型也加载进来。

# 加载mnist数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 加载cnn_mnist.h5文件,重新生成模型对象, 等价于之前训练好的cnn_model
recons_model = keras.models.load_model('cnn_mnist.h5')

2.3 设置保存图片的路径

将数据集的某个数据以图片的形式进行保存,便于测试的可视化。
在这里设置图片存储的位置。

# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)

在书写完上述代码后,需要在代码的当前路径下新建一个imgs的文件夹用于存储图片,如下。
在这里插入图片描述

执行完上述代码后就会在imgs的文件中可以发现多了一张图片,如下(下面测试了很多次)。
在这里插入图片描述

2.4 加载图片

采用cv2对图片进行加载,下面最后一行代码取一个通道的原因是用opencv库也就是cv2读取图片的时候,图片是三通道的,而训练的模型是单通道的,因此取单通道。

# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
print(test_img.shape)

2.5 图片预处理

对图片进行预处理,即进行归一化处理和改变形状处理,这是为了便于将图片输入给训练好的模型进行预测。

# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)

2.6 对图片进行预测

将图片输入给训练好我的模型并且进行预测。
预测的结果是10个概率值,所以需要进行处理, np.argmax()是得到概率值最大值的序号,也就是预测的数字。

# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别/手写体数字:', class_id)
class_id = str(class_id)

2.7 显示图片

对预测的图片进行显示,把预测的数字显示在图片上。
下面6行代码分别是创建窗口,设定窗口大小,显示数字,显示图片,停留图片,清除内存。

# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.putText(image, class_id, (2, 5), cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.2, (255, 0, 0), 1)
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

3. 完整代码和显示结果

以下是完整的代码和图片显示结果。

from tensorflow import keras
# 引入内置手写体数据集mnist
from keras.datasets import mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np# 加载mnist数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 加载cnn_mnist.h5文件,重新生成模型对象, 等价于之前训练好的cnn_model
recons_model = keras.models.load_model('cnn_mnist.h5')
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
print(test_img.shape)
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别/手写体数字:', class_id)
class_id = str(class_id)
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.putText(image, class_id, (2, 5), cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.2, (255, 0, 0), 1)
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
(28, 28)
1/1 [==============================] - 0s 210ms/step
test.png的预测概率: [[2.3381226e-05 1.1173951e-09 2.5884110e-09 2.3000638e-10 1.5515226e-073.6373976e-07 9.9997604e-01 5.8317045e-13 1.0071908e-07 1.6725430e-09]]
test.png的预测概率: 0.99997604
test.png的所属类别/手写体数字: 6

在这里插入图片描述

4. 多张图片进行测试的完整代码以及结果

为了测试更多的图片,引入循环进行多次测试,效果更好。

from tensorflow import keras
# 引入内置手写体数据集mnist
from keras.datasets import mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np# 加载mnist数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 加载cnn_mnist.h5文件,重新生成模型对象, 等价于之前训练好的cnn_model
recons_model = keras.models.load_model('cnn_mnist.h5')prepicture = int(input("input the number of test picture :"))
for i in range(prepicture):path1 = input("input the test picture path:")# 创建图片保存路径test_file_path = os.path.join(sys.path[0], 'imgs', path1)# 存储测试数据的任意一个num = int(input("input the test picture num:"))Image.fromarray(x_test[num]).save(test_file_path)# 加载本地test.png图像image = cv2.imread(test_file_path)# 复制图片test_img = image.copy()# 将图片大小转换成(28,28)test_img = cv2.resize(test_img, (28, 28))# 取单通道值test_img = test_img[:, :, 0]# 预处理: 归一化 + reshapenew_test_img = (test_img/255.0).reshape(1, 28, 28, 1)# 预测y_pre_pro = recons_model.predict(new_test_img, verbose=1)# 哪一类数字class_id = np.argmax(y_pre_pro, axis=1)[0]print('test.png的预测概率:', y_pre_pro)print('test.png的预测概率:', y_pre_pro[0, class_id])print('test.png的所属类别/手写体数字:', class_id)class_id = str(class_id)# # 显示cv2.namedWindow('img', 0)cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小cv2.putText(image, class_id, (2, 5), cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.2, (255, 0, 0), 1)cv2.imshow('img', image)cv2.waitKey()cv2.destroyAllWindows()

下面的test picture num指的是数据集中该数据的序号(0-59999),并不是值实际的数字。

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
input the number of test picture :2
input the test picture path:1.jpg
input the test picture num:1
1/1 [==============================] - 0s 156ms/step
test.png的预测概率: [[4.3549915e-07 4.7153802e-07 9.9998319e-01 5.7891691e-07 2.7986115e-085.3348625e-08 7.1938064e-09 1.4849566e-05 3.6678301e-07 2.2624316e-09]]
test.png的预测概率: 0.9999832
test.png的所属类别/手写体数字: 2

在这里插入图片描述

input the test picture path:2.jpg
input the test picture num:2
1/1 [==============================] - 0s 26ms/step
test.png的预测概率: [[1.4249144e-10 9.9994874e-01 6.1170212e-08 2.7543174e-09 1.9512597e-065.1548787e-09 1.5619334e-07 3.3457465e-07 4.5184272e-05 3.6284032e-06]]
test.png的预测概率: 0.99994874
test.png的所属类别/手写体数字: 1

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/9395.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云国际版云服务器防火墙设置

阿里云国际版云服务器防火墙设置 入侵防御页面为您实时展示云防火墙拦截流量的源IP、目的IP、阻断应用、阻断来源和阻断事件详情等信息。本文介绍了入侵防御页面展示的信息和相关操作,下面和012一起来了解阿里云国际版云服务器防火墙设置: 前提条件 您需…

工具推荐:Linux Busybox

文章首发地址 BusyBox是一个开源的、轻量级的、可嵌入式的、多个Unix工具的集合。BusyBox提供了各种Unix工具的实现,包括文件处理工具、网络工具、shell工具、系统管理工具、进程管理工具等等。它被设计为一个小巧、高效、可靠、易于维护的工具,适用于嵌…

Java在线OJ项目(二)、数据库与题目的增删改查【后端如何操作数据和数据库】

Java在线OJ项目(二)、数据库与题目的增删改查【后端如何操作数据和数据库】 (二)、数据库与题目的增删改查【后端如何操作数据和数据库】1. 设计题目的数据库 格式2. 存储题目类3. 数据库连接代码(common所有模块都可以…

微服务——服务异步通讯RabbitMQ

前置文章 消息队列——RabbitMQ基本概念容器化部署和简单工作模式程序_北岭山脚鼠鼠的博客-CSDN博客 消息队列——rabbitmq的不同工作模式_北岭山脚鼠鼠的博客-CSDN博客 消息队列——spring和springboot整合rabbitmq_北岭山脚鼠鼠的博客-CSDN博客 目录 Work queues 工作队列…

设计模式 - 工厂模式

一、 简单工厂(Simple Factory Pattern) 1、概念 一个工厂对象决定创建出哪一种产品类的实力,但不属于GOF23种设计模式。 简单工厂适用于工厂类负责创建的对象较少的场景,且客户端只需要传入工厂类的参数,对于如何创…

Andrew算法求凸包模板

前置知识 向量的叉乘: 设 a ⃗ ( x a , y a , z a ) , b ⃗ ( x b , y b , z b ) \vec a(x_a,y_a,z_a), \vec b(x_b, y_b,z_b) a (xa​,ya​,za​),b (xb​,yb​,zb​), 令 a ⃗ \vec a a 和 b ⃗ \vec b b 的叉乘为 c ⃗ \vec c c , 有: c ⃗ ∣ i j k x a y a z a x b y…

【深度学习】GPT-3

2020年5月,OpenAI在长达72页的论文《https://arxiv.org/pdf/2005.14165Language Models are Few-Shot Learners》中发布了GPT-3,共有1750亿参数量,需要700G的硬盘存储,(GPT-2有15亿个参数),它比GPT-2有了极大的改进。根…

钉钉返回:访问ip不在白名单之中,请参考FAQ

新版钉钉 在开发管理-服务器出口IP-配置返回错误信息返回给你的requestIp

k8s部署新版elasticsearch+kibana并配置快照备份

版本:es 7.17.6 kibana 7.17.6 k8s:1.19.16 一、介绍 Elasticsearch和Kibana是一对强大的开源工具,通常一起使用以构建实时数据分析和可视化解决方案。 Elasticsearch: Elasticsearch是一个分布式、高性能的实时搜索和分析引擎。它构建在开源搜索引擎库Lucene之上…

【C++】开源:Redis数据库配置与使用

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍Redis数据库配置与使用。 无专精则不能成,无涉猎则不能通。。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下&#xff0c…

边缘计算对现代交通的重要作用

边缘计算之所以重要,是在于即使在5G真正商用之时,可以实现超大带宽(eMBB)的应用场景,但庞大数据量的涌现也就意味着需要在云和端传输过程中找到一个承接点,对数据进行预处理再选择是否上云。 边缘计算应用演…

【Python入门【推导式创建序列、字典推导式、集合推导式】(九)

👏作者简介:大家好,我是爱敲代码的小王,CSDN博客博主,Python小白 📕系列专栏:python入门到实战、Python爬虫开发、Python办公自动化、Python数据分析、Python前后端开发 📧如果文章知识点有错误…

流媒体协议

1 RTP报⽂格式 V:RTP协议的版本号,占2位,当前协议版本号为2。 P:填充标志,占1位,如果P1,则在该报⽂的尾部填充⼀个或多个额外的⼋位组,它们不是有效载荷 的⼀部分。 X:扩…

SkyWalking链路追踪-技术文档首页

SkyWalking 文档中文版(社区提供) (skyapm.github.io)https://skyapm.github.io/document-cn-translation-of-skywalking/ SkyWalking-基本概念 SkyWalking链路追踪是一个用于分布式系统的性能监控工具,它帮助开发人员了解系统中各组件之间…

工程安全监测无线振弦采集仪在建筑物的应用分析

工程安全监测无线振弦采集仪在建筑物的应用分析 工程安全监测无线振弦采集仪是一种在建筑物中应用的重要设备。它通过无线采集建筑物内部的振动信息,对建筑物的安全性进行监测和评估,为建筑物的施工和使用提供了可靠的技术支持。本文将详细介绍工程安全…

GBDT算法

GBDT 是 Gradient Boosting Decison Tree,是集成学习下boosting家族的一个算法。GBDT 可以用于分类和回归任务,但基学习器都是 CART 回归树,因为它使用的是负梯度拟合的方法做的,分类任务是通过采用损失函数来做的,类似…

[Spark] 大纲

1、Spark任务提交流程 2、SparkSQL执行流程 2.1 RBO,基于规则的优化 2.2 CBO,基于成本的优化 3、Spark性能调优 3.1 固定资源申请和动态资源分配 3.2 数据倾斜常见解决方法 3.3 小文件优化 4、Spark 3.0 4.1 动态分区裁剪(Dynamic Partition Pr…

ElasticSearch基础篇-安装与基本操作

ElasticSearch基础篇 安装 官网 下载地址 下载完成后对文件进行解压,项目结构如下 进入bin目录点击elasticsearch.bat启动服务 9300 端口为 Elasticsearch 集群间组件的通信端口, 9200 端口为浏览器访问的 http协议 RESTful 端口 打开浏览器&#…

力扣热门100题之矩阵置0【中等】

题目描述 给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例 1: 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]] 示例 2&#xff…

C++ - list介绍 和 list的模拟实现

list介绍 list 是一个支持在常数范围内,任意位置进行插入删除的序列式容器,且这个容器可以前后双向迭代。我们可以把 list 理解为 双向循环链表的结构。 于其他结构的容器相比,list在 任意位置进行插入和函数的效率要高很多;而li…