神经网络字符分类

按照题目要求修改了多层感知机

题目将图片的每个点作为输入,其中大小为28*28,中间有两个大小为100的隐藏层,激活函数是relu,然后输出大小是10,激活函数是softmax

优化器是Adam,结合了AdaGrad和RMSProp算法的优点,为每个参数计算自适应的学习率。

损失函数是交叉熵损失的函数,通常用于分类问题,交叉熵损失函数衡量的是实际输出(probability distribution)与期望输出(true labels)的相似程度,在多分类问题中特别有用。

准确率(Accuracy)指标衡量的是模型预测正确的样本数与总样本数之间的比例。

epochs:训练的轮数5

batch_size:每次训练时使用的样本数量64

---------------------------------------------------------------------------------------------------------------------------------

本实践使用多层感知器训练(DNN)模型,用于预测手写数字图片。

本次实验主要考查以下内容 (1)尝试调整隐藏层单元数量、激活函数、隐藏层数量对于模型性能的影响 激活函数参照https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html#activation-functional 或paddle.nn.functional (2)调整不同的训练的迭代轮次(epoch)、学习率、优化器并学会观察训练阶段与测试阶段loss变化,并依据此调整模型 优化器、学习率可参照https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html (2)补全测试数据集上计算accuracy的过程,可以采用model下的evaluate,也可以利用predict之后的result结果进行计算 模型训练与评估相关API调用举例 https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Model_cn.html

首先导入必要的包

numpy---------->python第三方库,用于进行科学计算

PIL------------> Python Image Library,python第三方图像处理库

matplotlib----->python的绘图库 pyplot:matplotlib的绘图框架

os------------->提供了丰富的方法来处理文件和目录

#导入需要的包
import numpy as np
import paddle as paddle
import paddle.nn as nn
import paddle.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import paddle
from paddle.io import Dataset
import os
print("本教程基于Paddle的版本号为:"+paddle.__version__)
! python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple

Step1:准备数据。

(1)数据集介绍

MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

(2)transform函数是定义了一个归一化标准化的标准

(3)train_dataset和test_dataset

paddle.vision.datasets.MNIST()中的mode='train'和mode='test'分别用于获取mnist训练集和测试集

transform=transform参数则为归一化标准

#导入数据集Compose的作用是将用于数据集预处理的接口以列表的方式进行组合。
#导入数据集Normalize的作用是图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('下载并加载训练数据')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
#print(np.array(test_dataset).shape)
print('加载完成')
#让我们一起看看数据集中的图片是什么样子的
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('train_data0 的标签为: ' + str(train_label_0))
AxesImage(25,22;155x154)
train_data0 的标签为: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))
#让我们再来看看数据样子是什么样的吧
print(train_data0)

Step2.网络配置

以下的代码判断就是定义一个简单的多层感知器,一共有三层,两个大小为100的隐层和一个大小为10的输出层,因为MNIST数据集是手写0到9的灰度图像,类别有10个,所以最后的输出大小是10。最后输出层的激活函数是Softmax,所以最后的输出层相当于一个分类器。加上一个输入层的话,多层感知器的结构是:输入层-->>隐层-->>隐层-->>输出层。

# 定义多层感知器  
#动态图定义多层感知器
class mnist(paddle.nn.Layer):def __init__(self):super(mnist,self).__init__()#输入通道784,输出通道100self.conv1=nn.Linear(in_features=784,out_features=100)#输入通道100,输出通道100self.conv2=nn.Linear(in_features=100,out_features=100)#输入通道100,输出通道10self.conv3=nn.Linear(in_features=100,out_features=10)def forward(self, input_):x = paddle.reshape(input_, [input_.shape[0], -1])# print(x.shape)[64, 784]y=F.relu(self.conv1(x))y=F.relu(self.conv2(y))y=F.softmax(self.conv3(y))return y

 


from paddle.metric import Accuracy# 用Model封装模型
model = paddle.Model(mnist())   # 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())

Step3.模型训练及评估

callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# 训练保存并验证模型
model.fit(train_dataset,test_dataset,epochs=5,batch_size=64,save_dir='multilayer_perceptron',verbose=1)#模型预测
result = model.predict(test_dataset, batch_size=1)#请补全模型性能验证代码,可使用model下的evaluate函数或者利用上面的预测出来的结果model.evaluate(test_dataset,verbose=1)
test_data0, test_label_0 = test_dataset[0][0],test_dataset[0][1]ress=model.predict_batch(test_data0)test_data0 = test_data0.reshape([28,28])
plt.figure(figsize=(2,2))
#展示测试集中的第一个图片
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('test_data0 的标签为: ' + str(test_label_0))print('test_data0 预测的数值为:' ,end='')
print(ress)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/26830.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习python实践——关于ward聚类分层算法的一些个人心得

最近在利用python跟着参考书进行机器学习相关实践,相关案例用到了ward算法,但是我理论部分用的是周志华老师的《西瓜书》,书上没有写关于ward的相关介绍,所以自己网上查了一堆资料,都很难说清楚ward算法,幸…

AIGC绘画设计:Midjourney V6 来袭,该版本有哪些新功能?

Midjourney V6 支持更自然的语言输入,可以处理更自然地对话式(以前的版本是以关键字为中心的)提示,对复杂提示有了更好的解释能力。大幅增加了每个 /image 的内存,可以处理更长、更详细的提示(从40 直接提升…

Android framework的Zygote源码分析

文章目录 Android framework的Zygote源码分析linux的fork Android framework的Zygote源码分析 init.rc 在Android系统中,zygote是一个native进程,是Android系统上所有应用进程的父进程,我们系统上app的进程都是由这个zygote分裂出来的。zyg…

12、云服务器上搭建环境

云服务器上搭建环境 12.1 选择一款远程连接工具(mobax) 有很多,比如mobax、xshll等等,我这里选择mobax,下载个免费版的即可 安装完成后,双击打开: 第一步,创建远程连接的用户,用户默认为root,密码为远程服务器的密码 第二步,输入远程公网IP,选择刚刚创建的用…

[C][数据结构][排序][下][快速排序][归并排序]详细讲解

文章目录 1.快速排序1.基本思想2.hoare版本3.挖坑法4.前后指针版本5.非递归版本改写 2.归并排序 1.快速排序 1.基本思想 任取待排序元素序列的某元素作为基准值,按照该排序码将待排序集合分割成两子序列,左子序列中所有元素均小于基准值,右…

目标检测中的anchor机制

目录 一、目标检测中的anchor机制 1.什么是anchor boxes? 二、什么是Anchor? ​编辑三、为什么需要anchor boxes? 四、anchor boxes是怎么生成的? 五、高宽比(aspect ratio)的确定 六、尺度(scale)的…

工业高温烤箱:现代工业的重要设备

工业高温烤箱,作为现代工业生产中不可或缺的关键设备,以其独特的高温烘烤能力,为各种工业产品的加工与制造提供了强有力的支持。斯博欣将对工业高温烤箱的原理、特点、应用领域及未来发展进行简要介绍。 一、工业高温烤箱的特点 1、高温性能优…

怎么修改Visual Studio Code中现在github账号

git config --global user.name “你的用户名” git config --global user.email “你的邮箱” git config --global --list git push -u origin your_branch_name git remote add origin

FastAPI 作为H5中流式输出的后端

FastAPI 作为H5中流式输出的后端 最近大家都在玩LLM,我也凑了热闹,简单实现了一个本地LLM应用,分享给大家,百分百可以用哦~^ - ^ 先介绍下我使用的三种工具: Ollama:一个免费的开源框架&…

centos7 xtrabackup mysql 基本测试(4)---虚拟机环境 mysql 修改datadir(有问题)

centos7 xtrabackup mysql 基本测试(4)—虚拟机环境 mysql 修改datadir 参考 centos更改mysql数据库目录 https://blog.csdn.net/sinat_33151213/article/details/125079593 https://blog.csdn.net/jx_ZhangZhaoxuan/article/details/129139499 创建目…

锌,能否成为下一个“铜”?

光大期货认为,今年以来,市场关注锌能否接棒铜价牛市。铜需求增长空间大,而锌消费结构传统,缺乏新亮点。虽然在供应的扰动上锌强于铜,但因需求乏善可陈,金融属性弱势,锌很难接棒铜,引…

数据质量守护者:数据治理视角下的智能数据提取策略

一、引言 在信息化和数字化高速发展的今天,数据已成为企业决策、运营和创新的核心要素。然而,随着数据量的快速增长和来源的多样化,数据质量问题逐渐凸显,成为制约企业数据价值发挥的关键因素。数据治理作为确保数据质量、提升数…

KEIL5.39 5.40 fromelf 不能生成HEX bug

使用AC6 编译,只要勾选了生成HEX。 结果报如下错误 暂时没有好的解决办法 1.替换法 2.在编译完后用命令生成HEX

蚓链研究院告诉你:蚓链数字化营销如何帮助力助你打造品牌!

在打造产品品牌的过程中,数字化营销会带来哪些利弊影响?如何消除或减少弊端?蚓链来和你一起分析、解决。 利处: 1.高度精准的目标定位:凭借大数据和先进算法,能精确锁定潜在客户,使营销资源得到…

数栈xAI:轻量化、专业化、模块化,四大功能革新 SQL 开发体验

在这个数据如潮的时代,SQL 已远远超越了简单的查询语言范畴,它已成为数据分析和决策制定的基石,成为撬动企业智慧决策的关键杠杆。SQL 的编写和执行效率直接关系到数据处理的速度和分析结果的深度,对企业洞察市场动态、优化业务流…

针对k8s集群已经加入集群的服务器进行驱逐

例如k8s 已经有很多服务器,现在由于服务器资源过剩,需要剥离一些服务器出来 查找节点名称: kubectl get nodes设置为不可调度: kubectl cordon k8s-node13恢复可调度 kubectl uncordon k8s-node13在驱逐之前先把需要剥离驱逐的节…

File及典型案例

File File对象表示一个路径,可以是文件的路径,也可以是文件夹的路径 这个路径可以是存在的,也允许不存在 常见的构造方法 图来自黑马程序员网课 package com.lazyGirl.filedemo;import java.io.File;public class Demo1 {public static vo…

立式护眼台灯十大品牌哪个好?立式护眼台灯十大品牌排行

立式护眼台灯十大品牌哪个好?根据国际市场的研究数据表明,我国在日常生活中对电子产品的依赖度极高,每天看电子产品的时间超过8小时,出现眼睛酸痛、干涩、视觉疲劳的人群也不再少数,而给眼睛带来伤害的除了电子产品中所含的蓝光之…

Vue3-滑动到最右验证功能

1、思路 1、在登录页面需要启动向右滑块验证 2、效果图 3、文章地址:滑动验证码的实现-vue-simple-verify 2、成分分析 1、由三块构成,分别是底部条、拖动条、拖动移动部分 2、底部条:整体容器,包括背景、边框和文字&#xf…

端午假期新房销售较去年下降16%,6月核心城市有望继续好转

内容提要 国常会强调政策措施落地见效,继续研究新去库存、稳市场政策。多城市二手房市场活跃,新房成交回暖缓慢。端午假期新房销售下降,核心城市市场有望好转。 文章正文 6月7日,国常会强调“着力推动已出台政策措施落地见效&am…