paddle实现手写数字模型(一)

  1. 参考文档:paddle官网文档
  2. 环境:Python 3.12.2 ,pip 24.0 ,paddlepaddle 2.6.0
    python -m pip install paddlepaddle==2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 调试代码如下:
    LeNet.py
import paddle
import paddle.nn.functional as Fclass LeNet(paddle.nn.Layer):def __init__(self):super().__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=2)self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1,stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

train.py


import paddle
from paddle.vision.transforms import Compose,Normalize,ToTensor
import paddle.vision.transforms as T  import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracyfrom LeNet import LeNet
from PIL import Imageprint(paddle.__version__)
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
print('下载和加载训练数据...')
train_dataset = paddle.vision.datasets.MNIST(mode='train',transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test',transform=transform)
print('load finished')train_data0,train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0,cmap=plt.cm.binary)
#plt.show()
print('train_data0 label is: '+str(train_label_0))model = paddle.Model(LeNet())   # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
print('配置模型...')
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())
# 训练模型
print('训练模型...')
model.fit(train_dataset,epochs=2,batch_size=64,verbose=1)
# 保存模型  
model.save('./model/mnist_model')  # 默认保存模型结构和参数 #预测模型
print('预测模型...')
model.evaluate(test_dataset, batch_size=64, verbose=1)

predicted.py


import paddleimport numpy as npfrom LeNet import LeNet
from PIL import Image# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')#plt.imshow(im,cmap='gray')# print(np.array(im))im = im.resize((28, 28), Image.Resampling.LANCZOS)im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255 return im# 加载训练好的模型参数
model = LeNet()
model.load_dict(paddle.load('./model/mnist_model.pdparams'))# 设置模型为评估模式
model.eval()# 准备一个MNIST样例图像
example_image = load_image("d:/8.png")# 转换为Tensor并进行推理
with paddle.no_grad():example_tensor = paddle.to_tensor(example_image)prediction = model(example_tensor)print(prediction)# 获取预测类别
predicted_class = np.argmax(prediction.numpy(), axis=1)[0]
print(f"Predicted class: {predicted_class}")

说明:先通过执行train.py训练数据集,将模型保存在model文件夹中,
然后运行predicted.py加载训练出来的数据集,推理出d:/8.png图片的结果。
结果图片如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/812338.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

013_NaN_in_Matlab中的非数与调试方法

Matlab中的非数与调试方法 是什么? Matlab编程(计算器使用)中经常有个错误给你,这句话里可能包含一个关键词NaN。大部分学生都有过被 NaN 支配的痛苦记忆。 NaN 是 Not a Number 的缩写,表示不是一个数字。在 Matla…

【SVN】clean up报错:Cleanup failed to process the following paths 解决方法

报错来源:代码更新有一个文件既不能接受自己的也不能接受别人的,只能取消,再提交提醒clean up,随后报标题错误。 解决方法:参考https://www.cnblogs.com/pinpin/p/11395438.html 注:如果clean up的时候有…

基于ssm项目校园快递平台系统

采用技术 基于SpringBoot框架实现的web的智慧社区系统的设计与实现~ 开发语言:Java 数据库:MySQL 技术:SpringMVCMyBatis 工具:IDEA/Ecilpse、Navicat、Maven 页面展示效果 管理员功能 订单管理 快递管理 公告管理 公告类…

基于YOLOv5s的电动车入梯识别系统(数据集+权重+登录界面+GUI界面+mysql)

本人训练的yolov5s模型,准确率在98.6%左右,可准确完成电梯内检测电动车任务,并搭配了GUI检测界面,支持权重选择、图片检测、视频检测、摄像头检测、识别结果拍照和在线标注数据集等功能。 共有4000张左右图片,全部为电…

机器人瓶胚检测工作站(H3U脉冲轴控制)

1、变量定义 2、程序监控1 2、 程序监控2 3、程序监控3 机器人输送料和机构的动作安全尤为重要,下面我们讨论下安全联锁控制逻辑

基于STM32F103单片机的时间同步项目

一、前言 本项目为前一个时间同步项目的更迭版本,由于之前的G031开发板没有外部晶振,从机守时能力几乎没有,5秒以上不同步从机时间就开始飞了。在考虑成本选型后,选择了带有外部有缘晶振的STM32F103C8T6最小单片机,来作…

解决mac本git安装后找不到命令的问题

不熟悉mac配置,折腾了半天,记录一下。 1.问题描述2.解决方法 1.问题描述 从https://sourceforge.net/projects/git-osx-installer/files/下载的git安装包: 安装时提示: 这里的解决办法是按住control键再打开文件安装。 安装完…

react antd 实现修改密码(原密码,新密码,再次输入新密码,新密码增加正则复杂度校验)

先看样子 组件代码: import React, { useState, useEffect } from react import { Row, Col, Modal, Spin, Input, Button, message, Form } from antd import { LockOutlined, EyeTwoTone, EyeInvisibleOutlined } from ant-design/icons import * as Serve from …

pyside6的QSpinBox自定义特性初步研究(二)

当前的需求是,蓝色背景的画面,需要一个相对应色系的QSpinBox部件。已有的部件风格是这样的,需要新的部件与之般配。 首先新建一个QDoubleSpinBox,并定义其背景色和边框: QDoubleSpinBox { color: white; border:1px…

基于无线物联网的智能配电监控系统设计应用

摘要:阐述基于电力物联网的智能配电监控系统的特点,探讨物联网结构及其关键技术,电力物联网下的智能配电监控系统设计,包括整体结构设计、硬件和软件系统设计。 安科瑞薛瑶瑶18701709087 关键词:电力物联网&#xff…

【好用】推荐10套后端管理系统前端模板

后台管理系统前端模板是开发者在构建后台管理系统时使用的一种工具,它提供了预先设计好的界面和组件,以帮助开发者快速搭建出功能完善、用户体验良好的管理系统。以下是V哥整理的10款流行的后台管理系统前端模板,它们基于不同的技术栈和设计理…

zookeeper分布式应用程序协调服务

一、zookeeper基本介绍 1.1 zookeeper的概念 Zookeeper是一个开源的分布式的,为分布式框架提供协调服务的Apache项目。 是Hadoop和Hbase的重要组件。它是一个为分布式应用提供一致性服务的软件,提供的功能包括:配置维护、域名服务、…

[Python图像识别] 五十二.水书图像识别 (2)基于机器学习的濒危水书古文字识别研究

该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门、OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子、图像增强技术、图像分割等,后期结合深度学习研究图像识别、图像分类应用。目前我进入第二阶段Python图像识别,该部分主要以目标检测、图像…

CVPR 2024 | Retrieval-Augmented Open-Vocabulary Object Detection

CVPR 2024 - Retrieval-Augmented Open-Vocabulary Object Detection 论文:https://arxiv.org/abs/2404.05687代码:https://github.com/mlvlab/RALF原始文档:https://github.com/lartpang/blog/issues/13 本文提出了一种新的开放词汇目标检…

去除pycharm运行pytest的默认参数--no-header --no-summary -q

进入pycharm设置(Settings),找到高级设置(Advanced Settings)—>Python–>Pytest:不添加"–no-header --no-summary -q"(Pytest:do not add “–no-header --no-summary -q”)

2024年妈妈杯数学建模C题思路分析-物流网络分拣中心货量预测及人员排班

# 1 赛题 C 题 物流网络分拣中心货量预测及人员排班 电商物流网络在订单履约中由多个环节组成,图 ’ 是一个简化的物流 网络示意图。其中,分拣中心作为网络的中间环节,需要将包裹按照不同 流向进行分拣并发往下一个场地,最终使包裹…

Android中基于DWARF的stack unwind实现原理

一、简介 在软件开发中,unwind stack(栈回溯 或 调用栈展开)是调试和异常处理中至关重要的一环,通过理解其实现原理,可以更好地理解程序的执行流程,更有效地进行调试和错误排查。 本文主要介绍 AArch64 架构下的两种最典型的栈回溯…

RabbitMQ的介绍

为什么使用 MQ? 流量削峰和缓冲 如果订单系统最多能处理一万次订单,这个处理能力在足够应付正常时段的下单,但是在高峰期,可能会有两万次下单操作,订单系统只能处理一万次下单操作,剩下的一万次被阻塞。我们…

.NET JWT入坑

前言 JWT (JSON Web Token) 是一种安全传输信息的开放标准,由Header、Payload和Signature三部分组成。它主要用于身份验证、信息交换和授权。JWT可验证用户身份,确保访问权限,实现单点登录,并在客户端和服务器之间安全地交换信息…

SQLite 在Android安装与定制方案(十七)

返回:SQLite—系列文章目录 上一篇:SQLite超详细的编译时选项(十六) 下一篇:SQLite Android 绑定(十八) 安装 有三种方法可以将 SQLite Android 绑定添加到应用程序: 1、通过…