基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码
在这里插入图片描述
原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:

PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape:  (252, 255)
采样后图片shape:  (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,[[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>

测试代码如下所示:

#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__)    #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__)   #2.4.2#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''print("图像数据形状和对应数据为:", train_data0.shape)                          #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0)         #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0))   # [5]#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义一层全连接层,输出维度是1self.fc = paddle.nn.Linear(in_features=784, out_features=1)# 定义网络结构的前向计算过程def forward(self, inputs):outputs = self.fc(inputs)return outputs
#===========================================#训练配置
# 声明网络结构
model = MNIST()
def train(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):# 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]assert len(img.shape) == 3batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]# 归一化图像数据img = img / 255# 将图像形式reshape为[batch_size, 784]img = paddle.reshape(img, [batch_size, img_h*img_w])return img  
#===========================================   
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')# 声明网络结构
model = MNIST()
#===========================================
def run(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())EPOCH_NUM = 10for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):images = norm_img(data[0]).astype('float32')labels = data[1].astype('float32')#前向计算的过程predicts = model(images)# 计算损失loss = F.square_error_cost(predicts, labels)avg_loss = paddle.mean(loss)#每训练了1000批次的数据,打印下当前Loss的情况if batch_id % 1000 == 0:print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()opt.step()opt.clear_grad()
#===========================================
#调用训练            
run(model)
paddle.save(model.state_dict(), './mnist.pdparams')  #模型测试#===========================================
def showImage(im):#img_path = 'example_0.jpg'# 读取原始图像并显示#im = Image.open('example_0.jpg')plt.imshow(im)plt.show()# 将原始图像转为灰度图im = im.convert('L')print('原始图像shape: ', np.array(im).shape)# 使用Image.ANTIALIAS方式采样原始图片im = im.resize((28, 28), Image.LANCZOS)plt.imshow(im)plt.show()print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)# 读取一张本地的样例图片,转变成模型输入的格式
#=========================================== 
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')# print(np.array(im))im = im.resize((28, 28), Image.LANCZOS)im = np.array(im).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255return im
#=========================================== 
# 定义预测过程
def test():model = MNIST()params_file_path = 'mnist.pdparams'img_path = 'example_0.jpg'# 加载模型参数param_dict = paddle.load(params_file_path)model.load_dict(param_dict)# 灌入数据model.eval()tensor_img = load_image(img_path)  result = model(paddle.to_tensor(tensor_img))print('result',result)#  预测输出取整,即为预测的数字,打印结果print("本次预测的数字是", result.numpy().astype('int32'))
#=========================================== 
test(); 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/17421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ubuntu设置主机ip

ubuntu 设置ip sudo dhclient -r enp67s0 # 是你的网卡,可以通过ifconfig 查,比如enp0 sudo ifconfig enp67s0 192.168.1.114 netmask 255.255.255.0 Ubuntu显示有线网已连接但无法上网,已经确认网口、交换机(路由器&#xff…

你知道HTTP与HTTPS有什么区别吗?

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 目录 一、什么是HTTP? 二、什么是HTTPS? 三、HTTPS 的工作原理 1、客户端发起 HTTPS 请求 2、服务端的配置 3、…

BT#蓝牙 - Link Policy Settings

对于Classic Bluetooth的Connection,有一个Link_Policy_Settings,是HCI configuration parameters中的一个。 Link_Policy_Settings 参数决定了本地链路管理器(Link Manager)在收到来自远程链路管理器的请求时的行为,还用来决定改变角色(rol…

UE4/5C++多线程插件制作(十九、异步资源读取封装,细节修改)

目录 MTPResourceLoadManage MTPThreadInterface MTPManage.h MTPManage.cpp RTPAgendy RTPAgendy.h RTPAgendy.cpp

Android如何用系统签名打包应用

前言 应用使用系统签名可以在用户不需要手动授权的情况下自动获取权限。适合一些定制系统中集成apk的方案商。 步骤 需要在AndroidManifest.xml中添加共享系统进程属性: android:sharedUserId"android.uid.system"如下图所示: 找到系统定制…

windows环境安装elasticsearch+kibana并完成JAVA客户端查询

下载elasticsearch和kibana安装包 原文连接:https://juejin.cn/post/7261262567304298554 elasticsearch官网下载比较慢,有时还打不开,可以通过https://elasticsearch.cn/download/下载,先找到对应的版本,最好使用迅…

LeetCode每日一题——1331.数组序号转换

题目传送门 题目描述 给你一个整数数组 arr ,请你将数组中的每个元素替换为它们排序后的序号。 序号代表了一个元素有多大。序号编号的规则如下: 序号从 1 开始编号。一个元素越大,那么序号越大。如果两个元素相等,那么它们的…

集团MySQL的酒店管理系统

酒店管理系统 概述 基于Spring Spring MVC MyBatis的酒店管理系统,主要实现酒店客房的预定、入住以及结账等功能。使用Maven进行包管理。 用户端主要功能包括: 登录注册、客房预订、客房评论(编写评论和查看评论) 后台管理主要…

Java maven的下载解压配置(保姆级教学)

mamen基本概念 Maven项目对象模型(POM),可以通过一小段描述信息来管理项目的构建,报告和文档的项目管理工具软件。 Maven 除了以程序构建能力为特色之外,还提供高级项目管理工具。由于 Maven 的缺省构建规则有较高的可重用性,所以…

【已解决】windows7添加打印机报错:加载Tcp Mib库时的错误,无法加载标准TCP/IP端口的向导页

windows7 添加打印机的时候,输入完打印机的IP地址后,点击下一步,报错: 加载Tcp Mib库时的错误,无法加载标准TCP/IP端口的向导页 解决办法: 复制以下的代码到新建文本文档.txt中,然后修改文本文…

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 (ML) 中,一些最重要的线性代数概念是奇异值分解 (SVD) 和主成分分析 (PCA)。收集到所有原始数据后,我们如何发现结构?例如,通过过去 6 天…

华为OD机试真题 JavaScript 实现【小朋友排队】【2023 B卷 100分】,附详细解题思路

目录 一、题目描述二、输入描述三、输出描述四、解题思路五、JavaScript算法源码六、效果展示1、输入2、输出 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试&am…

114、Spring AOP是如何实现的?它和AspectJ有什么区别?

Spring AOP是如何实现的?它和AspectJ有什么区别? 一、AOP的理解1、spring aop:动态代理实现2、spring aop 和 AspectJ的区别3、小图一、AOP的理解 其实,AOP只是一种编程思想,表示面向切面编程,如果想实现这种思想,可以使用动态代理啊,第三方的框架 AspectJ啊等等。 1…

【git合并分支自定义提交消息】

开发分支 dev主分支 master 需求 dev分支开发完后合并到master分支自定义提交信息 通过 git merge dev --squash --no-commit此命令会拉取dev分支代码到当前分支,并不会自动提交,可以自己修改提交信息

搭建帮助中心到底要重点关注哪些元素呢?

搭建帮助中心的目标是给用户提供全面的问题解决方案,所以我们在搭建帮助中心的时候就要多去注意“用户”“问题”“解决方案”“使用方法”这些元素。今天looklook就从这些重点展开,帮助大家深入了解一下帮助中心。 帮助中心的用户 在帮助中心中&#x…

Vue引入

1. vue引入 第一种方法&#xff1a;在线引入 <script src"https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script> 第二种方法&#xff1a;本地引入 2. 语法学习 el用于绑定id&#xff0c;data用于定义数据如下例题 <!DOCTYPE html> <html…

GFS分布式文件系统概述以及集群部署

目录 一、GlusterFS简介 二、GlusterFS特点 1、扩展性和高性能 2、高可用性 3、全局统一命名空间 4、弹性卷管理 5、基于标准协议 三、GlusterFS术语 四、GlusterFS构成 五、后端存储如何定位文件 六、GFS支持的七种卷 1、分布式卷&#xff08;Distribute volume&a…

JAVA- SQL注入案例(黑马程序员)和避免 超级详细

文章目录 sql注入准备1.创建应该新的数据库用于测试&#xff1b;2.修改配置3.启动jar包4.打开网页测试5.测试sql注入 sql注入避免1. java中的登录逻辑代码2.演示sql注入3.原因5.参数化查询-PreparedStatement SQL注入是什么&#xff1f; SQL 注入&#xff08;SQL Injection&…

Kubernetes系列

文章目录 1 详解docker,踏入容器大门1.1 引言1.2 初始docker1.3 docker安装1.4 docker 卸载1.5 docker 核心概念和底层原理1.5.1 核心概念1.5.2 docker底层原理 1.6 细说docker镜像1.6.1 镜像的常用命令 1.7 docker 容器1.8 docker 容器数据卷1.8.1 直接命令添加1.8.2 Dockerfi…

RVM问题记录 - Error running ‘__rvm_make -j10‘

文章目录 前言开发环境问题描述问题分析解决方案最后 前言 公司新到一台电脑需要配置开发环境&#xff0c;在用RVM安装Ruby时遇到了一个奇怪的问题。 开发环境 RVM: 1.29.12OpenSSL: 3.1.1 问题描述 执行命令安装Ruby 3.0版本&#xff1a; rvm install ruby-3.0.0在编译阶…