基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码
在这里插入图片描述
原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:

PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape:  (252, 255)
采样后图片shape:  (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,[[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>

测试代码如下所示:

#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__)    #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__)   #2.4.2#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''print("图像数据形状和对应数据为:", train_data0.shape)                          #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0)         #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0))   # [5]#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义一层全连接层,输出维度是1self.fc = paddle.nn.Linear(in_features=784, out_features=1)# 定义网络结构的前向计算过程def forward(self, inputs):outputs = self.fc(inputs)return outputs
#===========================================#训练配置
# 声明网络结构
model = MNIST()
def train(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):# 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]assert len(img.shape) == 3batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]# 归一化图像数据img = img / 255# 将图像形式reshape为[batch_size, 784]img = paddle.reshape(img, [batch_size, img_h*img_w])return img  
#===========================================   
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')# 声明网络结构
model = MNIST()
#===========================================
def run(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())EPOCH_NUM = 10for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):images = norm_img(data[0]).astype('float32')labels = data[1].astype('float32')#前向计算的过程predicts = model(images)# 计算损失loss = F.square_error_cost(predicts, labels)avg_loss = paddle.mean(loss)#每训练了1000批次的数据,打印下当前Loss的情况if batch_id % 1000 == 0:print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()opt.step()opt.clear_grad()
#===========================================
#调用训练            
run(model)
paddle.save(model.state_dict(), './mnist.pdparams')  #模型测试#===========================================
def showImage(im):#img_path = 'example_0.jpg'# 读取原始图像并显示#im = Image.open('example_0.jpg')plt.imshow(im)plt.show()# 将原始图像转为灰度图im = im.convert('L')print('原始图像shape: ', np.array(im).shape)# 使用Image.ANTIALIAS方式采样原始图片im = im.resize((28, 28), Image.LANCZOS)plt.imshow(im)plt.show()print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)# 读取一张本地的样例图片,转变成模型输入的格式
#=========================================== 
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')# print(np.array(im))im = im.resize((28, 28), Image.LANCZOS)im = np.array(im).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255return im
#=========================================== 
# 定义预测过程
def test():model = MNIST()params_file_path = 'mnist.pdparams'img_path = 'example_0.jpg'# 加载模型参数param_dict = paddle.load(params_file_path)model.load_dict(param_dict)# 灌入数据model.eval()tensor_img = load_image(img_path)  result = model(paddle.to_tensor(tensor_img))print('result',result)#  预测输出取整,即为预测的数字,打印结果print("本次预测的数字是", result.numpy().astype('int32'))
#=========================================== 
test(); 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/17421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

你知道HTTP与HTTPS有什么区别吗?

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 目录 一、什么是HTTP? 二、什么是HTTPS? 三、HTTPS 的工作原理 1、客户端发起 HTTPS 请求 2、服务端的配置 3、…

Android如何用系统签名打包应用

前言 应用使用系统签名可以在用户不需要手动授权的情况下自动获取权限。适合一些定制系统中集成apk的方案商。 步骤 需要在AndroidManifest.xml中添加共享系统进程属性: android:sharedUserId"android.uid.system"如下图所示: 找到系统定制…

windows环境安装elasticsearch+kibana并完成JAVA客户端查询

下载elasticsearch和kibana安装包 原文连接:https://juejin.cn/post/7261262567304298554 elasticsearch官网下载比较慢,有时还打不开,可以通过https://elasticsearch.cn/download/下载,先找到对应的版本,最好使用迅…

LeetCode每日一题——1331.数组序号转换

题目传送门 题目描述 给你一个整数数组 arr ,请你将数组中的每个元素替换为它们排序后的序号。 序号代表了一个元素有多大。序号编号的规则如下: 序号从 1 开始编号。一个元素越大,那么序号越大。如果两个元素相等,那么它们的…

集团MySQL的酒店管理系统

酒店管理系统 概述 基于Spring Spring MVC MyBatis的酒店管理系统,主要实现酒店客房的预定、入住以及结账等功能。使用Maven进行包管理。 用户端主要功能包括: 登录注册、客房预订、客房评论(编写评论和查看评论) 后台管理主要…

Java maven的下载解压配置(保姆级教学)

mamen基本概念 Maven项目对象模型(POM),可以通过一小段描述信息来管理项目的构建,报告和文档的项目管理工具软件。 Maven 除了以程序构建能力为特色之外,还提供高级项目管理工具。由于 Maven 的缺省构建规则有较高的可重用性,所以…

【已解决】windows7添加打印机报错:加载Tcp Mib库时的错误,无法加载标准TCP/IP端口的向导页

windows7 添加打印机的时候,输入完打印机的IP地址后,点击下一步,报错: 加载Tcp Mib库时的错误,无法加载标准TCP/IP端口的向导页 解决办法: 复制以下的代码到新建文本文档.txt中,然后修改文本文…

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 (ML) 中,一些最重要的线性代数概念是奇异值分解 (SVD) 和主成分分析 (PCA)。收集到所有原始数据后,我们如何发现结构?例如,通过过去 6 天…

华为OD机试真题 JavaScript 实现【小朋友排队】【2023 B卷 100分】,附详细解题思路

目录 一、题目描述二、输入描述三、输出描述四、解题思路五、JavaScript算法源码六、效果展示1、输入2、输出 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试&am…

Vue引入

1. vue引入 第一种方法&#xff1a;在线引入 <script src"https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script> 第二种方法&#xff1a;本地引入 2. 语法学习 el用于绑定id&#xff0c;data用于定义数据如下例题 <!DOCTYPE html> <html…

GFS分布式文件系统概述以及集群部署

目录 一、GlusterFS简介 二、GlusterFS特点 1、扩展性和高性能 2、高可用性 3、全局统一命名空间 4、弹性卷管理 5、基于标准协议 三、GlusterFS术语 四、GlusterFS构成 五、后端存储如何定位文件 六、GFS支持的七种卷 1、分布式卷&#xff08;Distribute volume&a…

JAVA- SQL注入案例(黑马程序员)和避免 超级详细

文章目录 sql注入准备1.创建应该新的数据库用于测试&#xff1b;2.修改配置3.启动jar包4.打开网页测试5.测试sql注入 sql注入避免1. java中的登录逻辑代码2.演示sql注入3.原因5.参数化查询-PreparedStatement SQL注入是什么&#xff1f; SQL 注入&#xff08;SQL Injection&…

Kubernetes系列

文章目录 1 详解docker,踏入容器大门1.1 引言1.2 初始docker1.3 docker安装1.4 docker 卸载1.5 docker 核心概念和底层原理1.5.1 核心概念1.5.2 docker底层原理 1.6 细说docker镜像1.6.1 镜像的常用命令 1.7 docker 容器1.8 docker 容器数据卷1.8.1 直接命令添加1.8.2 Dockerfi…

cocosCreator 之 2D物理

版本&#xff1a; v3.4.0 简介 cocosCreator 内置了 2D 物理系统 和 3D 物理系统&#xff0c;开发者可以通过项目 -> 项目设置 -> 功能裁切来配置物理系统相关&#xff1a; 本文仅对2D 物理系统 做下说明和遇到的问题汇总。该物理系统在cocosCreator的功能裁切中&#x…

android 如何分析应用的内存(十三)——perfetto

android 如何分析应用的内存&#xff08;十三&#xff09; 本篇文章是native内存的最后一篇文章——perfetto perfetto简介 从2018年始&#xff0c;android开发者峰会正式推出perfetto工具。从此perfetto成为安卓最重要的工具之一。在2018年以前&#xff0c;android使用syst…

微信小程序tab加列表demo

一、效果 代码复制即可使用&#xff0c;记得把图标替换成个人工程项目图片。 微信小程序开发经常会遇到各种各样的页面组合&#xff0c;本demo为list列表与tab组合&#xff0c;代码如下&#xff1a; 二、json代码 {"usingComponents": {},"navigationStyle&q…

matlab使用教程(6)—线性方程组的求解

进行科学计算时&#xff0c;最重要的一个问题是对联立线性方程组求解。在矩阵表示法中&#xff0c;常见问题采用以下形式&#xff1a;给定两个矩阵 A 和 b&#xff0c;是否存在一个唯一矩阵 x 使 Ax b 或 xA b&#xff1f; 考虑一维示例具有指导意义。例如&#xff0c;方程 …

20.3 HTML 表格

1. table表格 table标签是HTML中用来创建表格的元素. table标签通常包含以下子标签: - th标签: 表示表格的表头单元格(table header), 用于描述列的标题. - tr标签: 表示表格的行(table row). - td标签: 表示表格的单元格(table data), 通常位于tr标签内, 用于放置单元格中的…

奥迪A3:最新款奥迪A3内饰设计及智能科技应用

奥迪A3一直以来都是奥迪的入门级车型&#xff0c;但这并不意味着它在科技和内饰方面会有所退步。最新款奥迪A3的内饰设计和智能科技应用让人们再次惊叹奥迪的创新能力。 内饰设计 奥迪A3最新款的内饰设计引入了奥迪最新的设计元素&#xff0c;比如8.8英寸的中控显示屏&#xf…

干货 ,ChatGPT 4.0插件Review Reader,秒杀一切选品神器

Hi! 大家好&#xff0c;我是专注于AI项目实战的赤辰&#xff0c;今天继续跟大家介绍另外一款GPT4.0插件Review Reader&#xff08;评论阅读器&#xff09;。 做电商领域的小伙伴们&#xff0c;都知道选品分析至关重要&#xff0c;可以说选品决定成败&#xff0c;它直接关系到产…