python手写数字识别(PaddlePaddle框架、MNIST数据集)

python手写数字识别(PaddlePaddle框架、MNIST数据集)

import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalizetransform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
# 使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')# 用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成卷积神经网络的构建
class CNN(paddle.nn.Layer):def __init__(self):super().__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1,stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x# 开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练
train_loader = paddle.io.DataLoader(train_dataset, batch_size=128, shuffle=True)
# 加载训练集 batch_size 设为 128
def train(model):model.train()epochs = 3optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 用Adam作为优化函数print("Training:")for epoch in range(epochs):for batch_id, data in enumerate(train_loader()):x_data = data[0]y_data = data[1]predicts = model(x_data)loss = F.cross_entropy(predicts, y_data)# 计算损失acc = paddle.metric.accuracy(predicts, y_data)loss.backward()if batch_id % 300 == 0:print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))optim.step()optim.clear_grad()
model = CNN()
train(model)# 训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=128)
# 加载测试数据集
def test(model):model.eval()print("Testing:")for batch_id, data in enumerate(test_loader()):x_data = data[0]y_data = data[1]predicts = model(x_data)# 获取预测结果loss = F.cross_entropy(predicts, y_data)acc = paddle.metric.accuracy(predicts, y_data)if batch_id % 50 == 0:print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/13155.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Java基础揉碎]多线程基础

多线程基础 什么是程序, 进程 什么是线程 什么是单线程,多线程 并发, 并行的概念 单核cpu来回切换, 造成貌似同时执行多个任务, 就是并发; 在我们的电脑中可能同时存在并发和并行; 怎么查看自己电脑的cpu有几核 1.资源监视器查看 2.此电脑图标右键管理- 设备管理器- 处理器…

k8s 二进制安装 详细安装步骤

目录 一 实验环境 二 操作系统初始化配置(所有机器) 1,关闭防火墙 2,关闭selinux 3,关闭swap 4, 根据规划设置主机名 5, 做域名映射 6,调整内核参数 7, 时间同步 三 部署 dock…

uniapp vu3 scroll-view 滚动到指定位置

设置 scroll-view <scroll-view :scroll-y"true" :scroll-with-animation"true" :scroll-top"scrollTop" :style"height:${height}px"><view v-for"item in 10" :id"box${item}">box {{item}}</v…

原生IP介绍

原生IP&#xff0c;顾名思义&#xff0c;即初始真实IP地址&#xff0c;是指从互联网服务提供商获得的IP地址&#xff0c;IP地址在互联网与用户之间直接建立联系&#xff0c;不需要经过代理服务器代理转发。 原生IP具备以下特点。 1.直接性 原生IP可以直接连接互联网&#xff…

337_C++_内存对齐操作,内存分配、或其他需要数据对齐的场合中是很常见的操作

size_t ImagesCache::_alignSize(size_t srcSz, size_t alnSz) {if (0 == alnSz) {printf("[ImagesCache] Incorrect input parameters\n");return srcSz;

代码随想录算法训练营第五十四天

第二题我看了很久还是没太明白&#xff0c;我发现理解动规有一点点吃力了啊&#xff0c;努努力。 392.判断子序列 总感觉在不等于的时候&#xff0c;应该是dp[i][j] dp[i-1][j-2]; 这里其实按他那个图会更好理解一点。 class Solution { public:bool isSubsequence(string s, …

Gone框架介绍19 -如何进行单元测试?

gone是可以高效开发Web服务的Golang依赖注入框架 github地址&#xff1a;https://github.com/gone-io/gone 文档地址&#xff1a;https://goner.fun/zh/ 请帮忙在github上点个 ⭐️吧&#xff0c;这对我很重要 &#xff1b;万分感谢&#xff01;&#xff01; 文章目录 单元测试…

CentOs安装

安装 开发工具 &#xff1a;GCC、 JDK、mysql 如果出现蓝屏&#xff0c;要在BIOS开启虚拟化支持&#xff0c;或者移除打印机。

Google:站长移除无效网址

当您的网址不需要呈现在Google站长中时&#xff0c;您可以在站长工具中移除网址 操作步骤&#xff1a;登录Google站长&#xff0c;绑定网站完成后&#xff0c;点击左侧删除 >> 输入网址 如果遇到一些网址&#xff0c;可以找寻网址间的规律&#xff0c;比如说&#xff0…

2024生日快乐祝福HTML源码

源码介绍 2024生日快乐祝福HTML源码&#xff0c;源码由HTMLCSSJS组成&#xff0c;记事本打开源码文件可以进行内容文字之类的修改&#xff0c;双击html文件可以本地运行效果&#xff0c;也可以上传到服务器里面&#xff0c; 源码截图 源码下载 2024生日快乐祝福HTML源码

Shell脚本 <<EOF ... EOF语法(Here Document)(特殊的输入重定向方式)(定界符)

文章目录 Here Document语法Here Document 的基本语法使用场景 关于定界符定界符不是变量定界符在 Here Document 中只是一个字符串&#xff0c;主要功能是标记输入文本的开始和结束&#xff0c;使用时应遵循最佳实践格式要求例子和说明如何使用定界符定界符可重复使用&#xf…

Spring数据访问全攻略:从JdbcTemplate到声明式事务

上文讲到 —— 航向数据之海&#xff1a;Spring的JPA与Hibernate秘籍 本文目录 四. JdbcTemplate的使用定义JdbcTemplate及其在Spring中的作用展示如何使用JdbcTemplate简化数据库操作1. 配置JdbcTemplate2. 使用JdbcTemplate查询数据3. 打印查询结果 五. Spring的事务管理介绍…

桥接模式

桥接模式&#xff1a;在这种模式下&#xff0c;虚拟机就像是局域网中一台独立的主机&#xff0c;能够访问网内任何一台机器。在桥接模式下&#xff0c;必须为虚拟系统手动配置IP地址、子网掩码&#xff0c;并且这些配置需要与宿主机器处于同一网段&#xff0c;以便虚拟系统和宿…

leetcode-42. 接雨水(双指针,前缀)

42. 接雨水 /*** param {number[]} height* return {number}*/ var trap function (height) {let len height.length;let pre_max new Array(len).fill(0);let suf_max new Array(len).fill(0);pre_max[0] height[0];suf_max[len - 1] height[len - 1];for (let i 1; i…

queue使用

C的queue是一种先进先出&#xff08;FIFO&#xff09;的数据结构&#xff0c;可以用来存储一系列元素。它属于STL&#xff08;Standard Template Library&#xff09;的一部分&#xff0c;以queue模板类的形式提供。 要使用queue&#xff0c;需要包含头文件&#xff0c;并使用…

Linux shell编程学习笔记49:strings命令

0 前言 在使用Linux的过程中&#xff0c;有时我们需要在obj文件或二进制文件中查找可打印的字符串&#xff0c;那么可以strings命令。 1. strings命令 的功能、格式和选项说明 我们可以使用命令 strings --help 来查看strings命令的帮助信息。 pupleEndurer bash ~ $ strin…

在k8s中搭建elasticsearch高可用集群,并对数据进行持久化存储

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《洞察之眼&#xff1a;ELK监控与可视化》&#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、引言 1、Elasticsearch简介 2、k8s简介 二、环境准备 …

Git项目管理——提交项目和版本回退(二)

个人名片&#xff1a; &#x1f393;作者简介&#xff1a;嵌入式领域优质创作者&#x1f310;个人主页&#xff1a;妄北y &#x1f4de;个人QQ&#xff1a;2061314755 &#x1f48c;个人邮箱&#xff1a;[mailto:2061314755qq.com] &#x1f4f1;个人微信&#xff1a;Vir2025WB…

android绘制多个黑竖线条

本文实例为大家分享了android绘制多个黑竖线条展示的具体代码&#xff0c;供大家参考&#xff0c;具体内容如下 1.写一个LinearLayout的布局&#xff0c;将宽度写成5dp将高度写成match_parent. 2.在写一个类继承LinearLayout&#xff0c;用LayoutInflater实现子布局的在这个L…

train_gpt2_fp32.cu - main

llm.c/test_gpt2_fp32.cu at master karpathy/llm.c (github.com) 源码 // ---------------------------------------------------------------------------- // main training loop int main(int argc, char *argv[]) {// read in the (optional) command line argumentsco…