机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成:

1. 输入层
输入层接收输入数据,这里是一个 28x28 的灰度图像,每个像素值表示图像中的亮度值。

2. Flatten 层
Flatten 层用于将输入数据展平为一维向量,以便传递给后续的全连接层。在这里,我们将 28x28 的图像展平为一个长度为 784 的向量。

3. 全连接层(Dense 层)
全连接层是神经网络中最常见的层之一,每个神经元与上一层的每个神经元都连接。在这里,我们有一个包含 128 个神经元的隐藏层,以及一个包含 10 个神经元的输出层。隐藏层使用 ReLU(Rectified Linear Unit)激活函数,输出层使用 softmax 激活函数。

4. 输出层
输出层产生神经网络的输出,这里是一个包含 10 个元素的向量,每个元素表示对应类别的概率。softmax 函数用于将网络的原始输出转换为概率分布。

5. 编译模型
在编译模型时,我们指定了优化器(optimizer)和损失函数(loss function)。在这里,我们使用 Adam 优化器和稀疏分类交叉熵损失函数。

6. 训练模型
使用训练数据集对模型进行训练,以学习如何将输入映射到正确的输出。在训练过程中,模型通过优化损失函数来调整权重和偏置,使其尽可能准确地预测输出。

总的来说,这个神经网络是一个经典的多层感知器(MLP),它在输入层和输出层之间包含一个或多个隐藏层,通过学习逐步提取和组合特征来进行分类或回归任务。

代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 构建神经网络模型
model = Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_images, train_labels, epochs=5)# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)# 保存模型
model.save('mnist_model.h5')# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(test_images)

结果:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2586 - accuracy: 0.9265
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1136 - accuracy: 0.9656
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0773 - accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0587 - accuracy: 0.9823
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0462 - accuracy: 0.9855
313/313 [==============================] - 0s 1ms/step - loss: 0.0750 - accuracy: 0.9775

Test accuracy: 0.9775000214576721

识别准确率挺高,然后我们也得到了训练好的模型

应用测试:

import tensorflow as tf
import numpy as np
from PIL import Image# 加载保存的模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 打开手写图片文件
image_path = 'pic/handwritten_digit_thick_5.png'  # 修改为你的手写图片文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像# 调整图片大小为 28x28 像素
image = image.resize((28, 28))# 将图片转换为 NumPy 数组并进行归一化处理
image_array = np.array(image) / 255.0# 将图片转换为模型输入的格式(添加批次维度)
input_image = np.expand_dims(image_array, axis=0)# 使用模型进行预测
predictions = loaded_model.predict(input_image)# 获取预测结果(最大概率的类别)
predicted_class = np.argmax(predictions)print('Predicted digit:', predicted_class)

准备了4张图片,3张自己手写,1张摘自minst:

前两张画笔比较细,第三张是minst的5,第四张是用了粗笔自己写的5,最终结果是就minst预测对了。

Predicted digit: 2

Predicted digit: 8

Predicted digit: 5

Predicted digit: 3


结论:

可见这个模型的扩展适应性能还是不够,只能预测正确训练过的minst数字。

改进:

想办法提升训练的质量,让预测能力达标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/675734.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

跳过mysql5.7密码并重置密码 shell脚本

脚本 目前只是验证了5.7 版本是可以的,8.多的还需要验证 以下是一个简单的Shell脚本,用于跳过MySQL密码设置并重置密码: #!/bin/bash yum install psmisc -y# 停止MySQL服务 sudo service mysqld stop# 跳过密码验证 sudo mysqld --skip-g…

【Linux】进程学习(二):进程状态

目录 1.进程状态1.1 阻塞1.2 挂起 2. 进程状态2.1 运行状态-R进一步理解运行状态 2.2 睡眠状态-S2.3 休眠状态-D2.4 暂停状态-T2.5 僵尸状态-Z僵尸进程的危害 2.6 死亡状态-X2.7 孤儿进程 1.进程状态 1.1 阻塞 阻塞:进程因为等待某种条件就绪,而导致的…

Elasticsearch: 非结构化的数据搜索

很多大数据组件在快速原型时期都是Java实现,后来因为GC不可控、内存或者向量化等等各种各样的问题换到了C,比如zookeeper->nuraft(https://www.yuque.com/treblez/qksu6c/hu1fuu71hgwanq8o?singleDoc# 《olap/clickhouse keeper 一致性协调服务》)&a…

掌握Vue,开启你的前端开发之路!

介绍:Vue.js是一个构建数据驱动的Web应用的渐进式框架,它以简洁和轻量级著称。 首先,Vue.js的核心在于其视图层,它允许开发者通过简单的模板语法将数据渲染进DOM(文档对象模型)。以下是Vue.js的几个重要特点…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之StepperItem组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之StepperItem组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、StepperItem组件 用作Stepper组件的页面子组件。 子组件 无。 接口 St…

Project2007下载安装教程,保姆级教程,附安装包和工具

前言 Project是一款项目管理软件,不仅可以快速、准确地创建项目计划,而且可以帮助项目经理实现项目进度、成本的控制、分析和预测,使项目工期大大缩短,资源得到有效利用,提高经济效益。软件设计目的在于协助专案经理发…

如何修复Mac的“ kernel_task” CPU使用率过高的Bug?

当计算机开始缓慢运行时,这从来都不是一件有趣的事情,但是当您弄不清它为何如此缓慢时,甚至会变得更糟。如果您已经关闭了所有程序,并且Mac上的所有内容仍然感觉像是在糖蜜中移动,这可能是令人讨厌的kernel_task导致高…

Vue3快速上手(一)使用vite创建项目

一、准备 在此之前,你的电脑,需要安装node.js,我这边v18.19.0 wangdymb 2024code % node -v v18.19.0二、创建 执行npm create vuelatest命令即可使用vite创建vue3项目 有的同学可能卡主不动,可能是npm的registry设置的问题 先看下&#x…

C语言----内存函数

内存函数主要用于动态分配和管理内存,它直接从指针的方位上进行操作,可以实现字节单位的操作。 其包含的头文件都是:string.h memcpy copy block of memory的缩写----拷贝内存块 格式: void *memcpy(void *dest, const void …

数据库切片大对决:ShardingSphere与Mycat技术解析

欢迎来到我的博客,代码的世界里,每一行都是一个故事 数据库切片大对决:ShardingSphere与Mycat技术解析 前言ShardingSphere与Mycat简介工作原理对比功能特性对比 前言 在数据库的舞台上,有两位颇受欢迎的明星,它们分别…

基于OpenCV灰度图像转GCode的螺旋扫描实现

基于OpenCV灰度图像转GCode的螺旋扫描实现 引言激光雕刻简介OpenCV简介实现步骤 1.导入必要的库2. 读取灰度图像3. 图像预处理4. 生成GCode5. 保存生成的GCode6. 灰度图像螺旋扫描代码示例 总结 系列文章 ⭐深入理解G0和G1指令:C中的实现与激光雕刻应用⭐基于二值…

攻防世界 CTF Web方向 引导模式-难度1 —— 1-10题 wp精讲

目录 view_source robots backup cookie disabled_button get_post weak_auth simple_php Training-WWW-Robots view_source 题目描述: X老师让小宁同学查看一个网页的源代码,但小宁同学发现鼠标右键好像不管用了。 不能按右键,按F12 robots …

springboot微信小程序uniapp学习计划与日程管理系统

基于springboot学习计划与日程管理系统,确定学习计划小程序的目标,明确用户需求,学习计划小程序的主要功能是帮助用户制定学习计划,并跟踪学习进度。页面设计主要包括主页、计划学习页、个人中心页等,然后用户可以利用…

Elasticsearch(四)

是这样的前面的几篇笔记,感觉对我没有形成知识体系,感觉乱糟糟的,只是大概的了解了一些基础知识,仅此而已,而且对于这技术栈的学习也是为了在后面的java开发使用,但是这里的API学的感觉有点乱!然…

【零基础入门TypeScript】Union

目录 语法:Union文字 示例:Union类型变量 示例:Union 类型和函数参数 Union类型和数组 示例:Union类型和数组 TypeScript 1.4 使程序能够组合一种或两种类型。Union类型是表达可以是多种类型之一的值的强大方法。使用管道符号…

高仿原神官网UI 纯html源码

高仿原神官网UI源码介绍 如果您希望打造一个与原神官方网站相似的外观和用户体验,但又不想使用复杂的框架或模板,那么我们的高仿原神官网UI源码将是一个完美的选择。它采用纯HTML5构建,无需任何额外的CSS或JavaScript库支持,即可…

代码随想录算法训练营第四十六天(动态规划篇)|01背包(滚动数组方法)

01背包(滚动数组方法) 学习资料:代码随想录 (programmercarl.com) 题目链接(和上次一样):题目页面 (kamacoder.com) 思路 使用一维滚动数组代替二维数组。二维数组的解法记录在:代码随想录算…

C#,十进制展开数(Decimal Expansion Number)的算法与源代码

1 十进制展开数 十进制展开数(Decimal Expansion Number)的计算公式: DEN n^3 - n - 1 The decimal expansion of a number is its representation in base -10 (i.e., in the decimal system). In this system, each "decimal place…

Zabbix 配置实时开通的LDAP认证-基于AD

介绍 本教程适用于6.4-7.0版本的Zabbix,域控(AD)使用Windows Server 2022搭建,域控等级为 2016。 域控域名为 songxwn.com 最终实现AD用户统一认证,统一改密,Zabbix用户自动添加。(6.4之前不…

使用npm包js-web-screen-shot做网页截图,可以对截图加文字,箭头等等,类似于微信截图

<template><div class"m-feedback-wrap" :style"{ top: ${feedbackHeight}px }"><div class"m-feedback-icon-wrap"><el-tooltipclass"item"effect"dark"content"内容"placement"left-…