用模型预测测试数据


Hi, I’m Shendi




2、用模型预测测试数据



在之前已经训练好了一个模型,可以通过 model.save("path") 来保存模型到硬盘,下次直接使用。


这个模型使用的 mnist 数据集训练,这个数据集包含6万训练样本和一万测试样本,28*28像素,是一个手写数字数据集,相当于在学习编程语言的hello,world

接下来就开始使用训练好的模型


使用测试数据测试

最开始我尝试直接用画图工具绘制一个数组,让其识别。但识别出来的压根不对,也不清楚什么原因,所以从最开始的弄起。

既然训练的模型评估的准确度达到90%多,那么使用测试数据就没有问题了吧,我将测试数据的图片保存依然识别不对。于是直接使用加载的测试数据


最开始,当然是加载数据集

# 加载 mnist 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

通过 tf.keras.models.load_model 加载保存的模型,我存储在 my_model 文件夹中

model = tf.keras.models.load_model('my_model')

# 选择一些测试集样本
selected_samples = x_test[:5]
true_labels = y_test[:5]

其中x_test是样本数据,y_test样本的正确标签

通过 predict 进行预测,在之前训练的模型有十个输出层,0-9,预测获得的结果就是这个样本对应输出层的可信度,最终结果选择可信度最高的那个

predictions = model.predict(selected_samples)
print(predictions)
# 选取可信度最高的打印
print(tf.argmax(predictions[0]).numpy())

因为我使用vscode,所以没办法直接show,只能保存到本地文件夹查看结果,于是使用以下代码

# 保存图像和预测结果到文件
for i in range(len(selected_samples)):plt.imshow(selected_samples[i], cmap='gray')  # 显示灰度图像plt.title(f"Predicted: {tf.argmax(predictions[i]).numpy()}, True: {true_labels[i]}")plt.axis('off')plt.savefig(f"predicted_image_{i}.png")  # 保存图像plt.close()

这个结果是准确的,效果如下

在这里插入图片描述


其中上面的predicted是预测结果,true是正确结果



问题

就如上面所说,我将数据集的测试数据的某张图片保存到本地,然后加载,用模型预测加载的图片,是不准确的。

我的代码

import tensorflow as tf
import matplotlib.pyplot as pltfrom PIL import Image
import numpy as np# 加载 mnist 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()def initModel():model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10)])predictions = model(x_train[:1]).numpy()tf.nn.softmax(predictions).numpy()loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)loss_fn(y_train[:1], predictions).numpy()model.compile(optimizer='adam',loss=loss_fn,metrics=['accuracy'])model.fit(x_train, y_train, epochs=5)r = model.evaluate(x_test,  y_test, verbose=2)print(r)model.save("my_model");# probability_model = tf.keras.Sequential([#   model,#   tf.keras.layers.Softmax()# ])# probability_model(x_test[:5])
# initModel();def test():model = tf.keras.models.load_model('my_model')# 选择一些测试集样本selected_samples = x_test[:5]true_labels = y_test[:5]# 使用模型对样本进行预测predictions = model.predict(selected_samples)print(predictions)print(tf.argmax(predictions[0]).numpy())# 保存图像和预测结果到文件for i in range(len(selected_samples)):plt.imshow(selected_samples[i], cmap='gray')  # 显示灰度图像plt.title(f"Predicted: {tf.argmax(predictions[i]).numpy()}, True: {true_labels[i]}")plt.axis('off')plt.savefig(f"predicted_image_{i}.png")  # 保存图像plt.close()
# test();def test2():model = tf.keras.models.load_model('my_model')# 准备图像img_path = 'test_img.png'  # 替换为你的图像文件路径image = Image.open(img_path)image = image.convert('L')  # 转换为灰度图像image = image.resize((28, 28))  # 调整图像大小image = np.array(image)  # 转换为 numpy 数组# 归一化处理(如果在训练模型时有进行归一化)image = image.astype('float32') / 255plt.imshow(image, cmap='gray')  # 显示灰度图像plt.axis('off')plt.savefig(f"my.png")  # 保存图像plt.close()# 对图像进行预测prediction = model.predict(np.expand_dims(image, axis=0))print(prediction)print(np.argmax(prediction, axis=1))
test2()def saveImg(index):img = Image.fromarray(x_test[index])img.save('test_img.png')
saveImg(0)

我直接使用画图工具绘制数字,加载这个图片,预测,也是不准确的。对于这个,已经花了大把的时间搜索,但资料都特别少,于是准备跳过了,毕竟刚开始,一切都是未知。不应在一些非目标的事情浪费大把时间。




END

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/237023.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux下的进程组与会话的区别

进程组(Process Group)和会话(Session)是Unix/Linux操作系统中的两个概念,它们之间有一些关键区别: 定义和范围:一个进程组是一组相关进程的集合,它们具有相同的进程组ID&#xff08…

【运维面试100问】(十一)淡淡I/O过程

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》:python零基础入门学习 《python运维脚本》: python运维脚本实践 《shell》:shell学习 《terraform》持续更新中:terraform_Aws学习零基础入门到最佳实战 《k8…

华为云之ECS云产品快速入门

华为云之ECS云产品快速入门 一、ECS云服务器介绍二、本次实践目标三、创建虚拟私有云VPC1.虚拟私有云VPC介绍2.进入虚拟私有云VPC管理页面3.创建虚拟私有云4.查看创建的VPC 四、创建弹性云服务器ECS——Linux1.进入ECS购买界面2.创建弹性云服务器(Linux)——基础配置步骤3.创建…

DC-磁盘管理

2023年全国网络系统管理赛项真题 模块B-Windows解析 题目 在DC2上安装及配置软RAID 5。在安装好的DC2虚拟机中添加三块10G虚拟磁盘。组成RAID 5,磁盘分区命名为卷标H盘:Raid5。手动测试破坏一块磁盘,做RAID磁盘修复,确认RAID 5配置完毕。配置步骤 关闭虚拟机,添加3块10G磁…

【AI提示词艺术】第10期 ---希望、价值与魔法的交织:小女孩与梦幻背景的数字艺术之旅

金色猎犬被视为一种贵族犬种 金色猎犬是一种流行的犬种,通常被视为一种象征,代表着忠诚、勇气和敏锐的嗅觉。这种犬种有着悠久的历史,可以追溯到中世纪时期。 金色猎犬的外观特征是短而厚的毛发,金色的颜色,以及敏锐…

Python---TCP 网络应用程序开发流程

1. TCP 网络应用程序开发流程的介绍 TCP 网络应用程序开发分为: TCP 客户端程序开发TCP 服务端程序开发 说明: 客户端程序是指运行在用户设备上的程序 服务端程序是指运行在服务器设备上的程序,专门为客户端提供数据服务。 2. TCP 客户端程序开发流程的介绍 步…

在GitHub找开源项目

在 GitHub 的搜索框里: 使用搜索关键词可以在 GitHub 上快速的找你需要的开源项目: 限制搜索范围 通过 in 关键词 (大小写不敏感) 限制搜索范围: 公式搜索范围in:name xxx项目名包含xxxin:description xxx项目描述包含xxxin:readme xxx项目…

C/C++ 使用 MySQL API 进行数据库操作

C/C 使用 MySQL API 进行数据库操作 一、前言 随着信息时代的到来,数据库的应用日益广泛,MySQL 作为开源的关系型数据库管理系统,被广大开发者所喜爱。在 C/C 程序中,我们可以通过 MySQL 提供的 API 接口来连接数据库&#xff0…

100GPTS计划-AI学术AcademicRefiner

地址 https://chat.openai.com/g/g-LcMl7q6rk-academic-refiner https://poe.com/AcademicRefiner 测试 减少相似性 增加独特性 修改http://t.csdnimg.cn/jyHwo这篇文章微调 专注于人工智能、科技、金融和医学领域的学术论文改写,秉承严格的专业和学术标准。 …

Windows 如何在局域网中建立NTP服务器实现时间同步(设置一台设备作为主机,其他设备作为从机来同步时间)

首先简单了解一下什么是NTP 网络时间协议(NTP)是一种用于同步计算机网络上各设备时间的协议。NTP时间同步在许多项目和应用中都是关键的,特别是那些对时间同步精度有要求的场景。比如需要使用NTP时间同步的情况有:金融交易系统、科学研究实验、工业自动…

Unity 创建/删除/启用/禁用组件的惯用方法

1、创建组件&#xff1a; Unity 创建组件可以通过编辑器中的"Add Component"创建&#xff0c;或者代码动态创建&#xff1a;GameObject.AddComponent<T>()&#xff0c;如&#xff1a; ameObject.AddComponent<Rigidbody>(); 2、删除组件&#xff1a; …

WPF组合控件TreeView+DataGrid之TreeView封装

&#xff08;关注博主后&#xff0c;在“粉丝专栏”&#xff0c;可免费阅读此文&#xff09; wpf的功能非常强大&#xff0c;很多控件都是原生的&#xff0c;但是要使用TreeViewDataGrid的组合&#xff0c;就需要我们自己去封装实现。 我们需要的效果如图所示&#x…

医院影像科PACS系统源码,医学影像系统,支持MPR、CPR、MIP、SSD、VR、VE三维图像处理

PACS系统是医院影像科室中应用的一种系统&#xff0c;主要用于获取、传输、存档和处理医学影像。它通过各种接口&#xff0c;如模拟、DICOM和网络&#xff0c;以数字化的方式将各种医学影像&#xff0c;如核磁共振、CT扫描、超声波等保存起来&#xff0c;并在需要时能够快速调取…

Python之json模块和pickle模块详解

json模块和pickle模块的用法 在python中&#xff0c;可以使用pickle和json两个模块对数据进行序列化操作。 其中&#xff1a; json可以用于字符串或者字典等与python数据类型之间的序列化与反序列化操作。 pickle可以用于python特有类型与python数据类型之间的序列化与反序…

回归预测 | MATLAB实现GWO-DHKELM基于灰狼算法优化深度混合核极限学习机的数据回归预测 (多指标,多图)

回归预测 | MATLAB实现GWO-DHKELM基于灰狼算法优化深度混合核极限学习机的数据回归预测 &#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现GWO-DHKELM基于灰狼算法优化深度混合核极限学习机的数据回归预测 &#xff08;多指标&#xff0c;多图&#…

Redis基础篇-002 初识Redis

1、认识NoSQL 1.1 概念 NoSQL是一个非关系型数据库。 常见的NoSQL有&#xff1a;Redis、MongoDB 1.2 NoSQL与SQL的区别 类别SQLNoSQL数据结构结构化非结构化数据关联关联非关联查询方式SQL非SQL事务特性ACIDBASE存储方式磁盘内存扩展性垂直水平使用场景1&#xff09;数据结…

Docker安装(CentOS)+简单使用

Docker安装(CentOS) 一键卸载旧的 sudo yum remove docker* 一行代码(自动安装) 使用官方安装脚本 curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun 启动 docker并查看状态 运行镜像 hello-world docker run hello-world 简单使用 使用 docker run …

docker部署个人网站项目记录(前后端分离)

背景 项目是前后端分离&#xff0c;前端有三部分&#xff0c;分别是 个人网站&#xff08;blog&#xff09;网站后台管理系统&#xff08;admin&#xff09;数据大屏&#xff08;datascreen&#xff09; 后端是基于nodejs写的后台服务 后台接口服务&#xff08;todo-nodejs…

Github项目推荐:在线rename

项目地址 GitHub - JasonGrass/rename: 在线文件批量重命名 项目简介 一个开源的在线重命名文件工具。利用了新的浏览器API获取文件句柄&#xff0c;在不上传文件的情况下对文件进行重命名。可以作为前端文件操作api学习范例。 项目截图

《每天一分钟学习C语言·五》

1、 给一个字符数组输入字符串 char arr[10]; gets[arr]; //gets函数接收回车符&#xff0c;如果直接按回车&#xff0c;gets函数会把回车符转变成空字符作为结束&#xff0c;即arr[0]’\0’;2、 文件结尾标志ctrlz表示返回NULL 自己定义的头文件里面一般有宏定义和声明&#…