深度学习神经网络实战:多层感知机,手写数字识别

目的

利用tensorflow.js训练模型,搭建神经网络模型,完成手写数字识别

设计

简单三层神经网络

  • 输入层
    28*28个神经原,代表每一张手写数字图片的灰度
  • 隐藏层
    100个神经原
  • 输出层
    -10个神经原,分别代表10个数字

代码

// 导入 TensorFlow.js 库
import tf from "@tensorflow/tfjs";
import * as tfjsnode from "@tensorflow/tfjs-node";
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
// 定义模型
const model = tf.sequential();// 添加输入层
model.add(tf.layers.dense({ units: 64, inputShape: [784], activation: "relu" })
);// 添加隐藏层
model.add(tf.layers.dense({ units: 100, activation: "relu" }));// 添加输出层
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));// 编译模型
model.compile({optimizer: "sgd",loss: "categoricalCrossentropy",metrics: ["accuracy"],
});
const trainDataLen = 3000;
const testDataLen = 2000;// 加载 MNIST 数据集
import pkg from "mnist";
const { set: Dataset } = pkg;
const set = Dataset(trainDataLen, testDataLen);
const trainingSet = set.training;
const testSet = set.test;const trainXs = [];
const testXs = [];const trainLabels = [];
const testLabels = [];for (let i = 0; i < trainingSet.length; i++) {trainXs.push(trainingSet[i].input);trainLabels.push(trainingSet[i].output.indexOf(1));
}for (let i = 0; i < testSet.length; i++) {testXs.push(testSet[i].input);testLabels.push(testSet[i].output.indexOf(1));
}// 准备数据
const trainXsTensor = tf.tensor(trainXs, [trainDataLen, 784]);
const trainYsOneHot = tf.oneHot(trainLabels, 10);//记录每轮模型训练中的损失和精度,为了绘制曲线图
var accPlot = [];
var lossPlot = [];// 模型训练
model.fit(trainXsTensor, trainYsOneHot, {batchSize: 64,epochs: 100,validationSplit: 0.2,callbacks: {onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),onEpochEnd: async (epoch, logs) => {console.log(`Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(3)}, Accuracy: ${logs.acc.toFixed(3)}`);//记录loss和acc,绘制曲线图accPlot.push(logs.acc.toFixed(3));lossPlot.push(logs.loss.toFixed(3));await tf.nextFrame(); // 防止阻塞},onBatchEnd: async (batch, logs) => {console.log(`Batch ${batch} completed. Loss: ${logs.loss.toFixed(3)}, Accuracy: ${logs.acc.toFixed(3)}`);await tf.nextFrame(); // 防止阻塞},},}).then((history) => {console.log("Training completed!", history);//绘制模型训练过程中的损失函数和模型精度曲线变化const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);plot.plot([{ x: epochs, y: lossPlot, name: "Loss" },{ x: epochs, y: accPlot, name: "Accuracy" },],{filename: "loss_acc.png",});//模型评估const testXsTensor = tf.tensor(testXs, [testDataLen, 784]);const testYsOneHot = tf.oneHot(testLabels, 10);const result = model.evaluate(testXsTensor, testYsOneHot);const testLoss = result[0].dataSync()[0];const testAccuracy = result[1].dataSync()[0];console.log(`Test loss: ${testLoss.toFixed(3)}`);console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);//保存模型model.save("file://./my-model").then(() => {console.log("Model saved!");});});

package.json

{"name": "neural_network","version": "1.0.0","description": "","type": "module","main": "mlpTest.js","scripts": {"test": "echo \"Error: no test specified\" && exit 1",},"author": "","license": "ISC","dependencies": {"@tensorflow/tfjs": "^4.17.0","@tensorflow/tfjs-node": "^4.17.0","@tensorflow/tfjs-vis": "^1.0.0","mnist": "^1.1.0","nodeplotlib": "^0.7.7"},"devDependencies": {"@babel/core": "^7.0.0","@babel/preset-env": "^7.0.0","babel-loader": "^8.0.0","webpack": "^5.0.0","webpack-cli": "^4.0.0"}
}

模型结果

模型精度

损失函数与模型精度变化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/705997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

负载均衡.

简介: 将请求/数据【均匀】分摊到多个操作单元上执行&#xff0c;负载均衡的关键在于【均匀】。 负载均衡的分类: 网络通信分类 四层负载均衡:基于 IP 地址和端口进行请求的转发。七层负载均衡:根据访问用户的 HTTP 请求头、URL 信息将请求转发到特定的主机。 载体维度分类 硬…

前端开发_Vue入门

Vue概念 Vue 是一个用于构建用户界面的渐进式框架 构建用户界面&#xff1a;基于数据渲染出用户看到的页面渐进式&#xff1a;循序渐进框架&#xff1a;一套完整的项目解决方案 创建Vue实例 准备容器 引包&#xff08;开发版本/生产版本&#xff09; <script src"h…

消息中间件篇之Kafka-数据清理机制

一、Kafka文件存储机制 Kafka文件存储结构&#xff1a;一个Topic有多个分区。每一个分区都有多个段&#xff0c;每个段都有三个文件。 为什么要分段&#xff1f;1. 删除无用文件方便&#xff0c;提高磁盘利用率。 2. 查找数据便捷。 二、数据清理机制 1.日志的清理策略方案1 根…

学习面向对象

面向对象 概念 现实生活&#xff1a; 类&#xff1a;抽象的概念&#xff0c;把具有相同特征和操作的事物归为一类 先有实体&#xff0c;再有类的概念 代码世界&#xff1a; 类&#xff1a;抽象的概念&#xff0c;把具有相同属性和方法的对象归为一类 编写顺序&#xff1a;先有…

神经网络系列---池化

文章目录 池化最大池化平均池化 池化 最大池化 最大池化&#xff08;Max Pooling&#xff09;是卷积神经网络中常用的一种池化技术。其操作是&#xff1a;在输入特征图的一个局部窗口内选取最大的值作为该窗口的输出。 数学表达式如下&#xff1a; 考虑一个输入特征图 A A…

[C++][linux]Linux上内存共享内存用法

一&#xff0c;什么是共享内存 共享内存&#xff08;Shared Memory&#xff09;&#xff0c;指两个或多个进程共享一个给定的存储区。进程可以将同一段共享内存连接到它们自己的地址空间中&#xff0c;所有进程都可以访问共享内存中的地址&#xff0c;就好像它们是由用C语言函…

【ELK05】es的java-api操作-Java High Level REST Client常用功能

1.客户端概括 1.1支持多种客户端 ES支持多种语言客户都安,包括ruby js python java go .net等,其中java目前最新版本的客户都安支持2种方式。一种是旧版已经过时的transport client ,一种是java high level rest client,前者是通过tcp协议链接访问es,后者就是java代码实…

系统学习Python——装饰器:类装饰器-[装饰器与管理器函数]

分类目录&#xff1a;《系统学习Python》总目录 抛开这些细节微妙性&#xff0c;Tracer类装饰器示例最终仍然是依赖于__getattr__来拦截对包装的和内嵌实例对象的获取。正如我们在前面见到的&#xff0c;我们真正需要完成的只是把实例创建调用移入一个类的内部&#xff0c;而不…

GEE入门篇|遥感专业术语(实践操作4):光谱分辨率(Spectral Resolution)

目录 光谱分辨率&#xff08;Spectral Resolution&#xff09; 1.MODIS 2.EO-1 光谱分辨率&#xff08;Spectral Resolution&#xff09; 光谱分辨率是指传感器进行测量的光谱带的数量和宽度。 您可以将光谱带的宽度视为每个波段的波长间隔&#xff0c;在多个波段测量辐射亮…

RestTemplate启动问题解决

⭐ 作者简介&#xff1a;码上言 ⭐ 代表教程&#xff1a;Spring Boot vue-element 开发个人博客项目实战教程 ⭐专栏内容&#xff1a;个人博客系统 ⭐我的文档网站&#xff1a;http://xyhwh-nav.cn/ RestTemplate启动问题解决 问题&#xff1a;在SpringCloud架构项目中配…

服务器双线什么意思?有什么使用优势?

对于企业而言服务器至关重要&#xff0c;它几乎链接着企业的业务&#xff0c;也是员工业务沟通的桥梁&#xff0c;为了保持服务器稳定持续的工作&#xff0c;很多企业都很关心服务器双线的问题&#xff0c;相对来说现在大部分企业使用的都是服务器双线&#xff0c;那服务器双线…

Java SpringBoot 整合 MyBatis 小案例

Java SpringBoot 整合 MyBatis 小案例 基础配置&#xff08;注意版本号&#xff0c;容易报错&#xff09; pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http…

TikTok东南亚小店爆单思路,怎么玩?

东南亚地区的跨境电商市场已经成为全球范围内最具吸引力的市场之一&#xff0c;在各个跨境电商平台上&#xff0c;都是转化率最高的站点之一。TikTok作为电商黑马&#xff0c;吸引了一大波跨境电商玩家入驻&#xff0c;其中东南亚小店也成为热门的选择&#xff0c;那么东南亚小…

当Vue项目启动后,通过IP地址方式在相同网络段的其他电脑上无法访问前端页面?

当Vue项目启动后&#xff0c;通过IP地址方式在相同网络段的其他电脑上无法访问前端页面&#xff0c;可能是由以下几个原因造成的&#xff1a; 服务监听地址&#xff1a;默认情况下&#xff0c;许多开发服务器&#xff08;如Vue CLI的vue-cli-service serve&#xff09;只监听lo…

ubuntu22.04安装cuda11.5+cudnn8.8.0

因为pytorch1.11.0与cuda版本的关系 需要用到cuda11.5 否则报错 "addmm_sparse_cuda" not implemented for Half cuda11.5.0及以前的版本不会出现这个问题 因此重新安装&#xff0c;步骤如下&#xff1a; 安装CUDA-11.5.0 wget https://developer.download.nvi…

2023年09月CCF-GESP编程能力等级认证C++编程六级真题解析

本文收录于专栏《C++等级认证CCF-GESP真题解析》,专栏总目录・点这里 一、单选题(共15题,共30分) 第1题 近年来,线上授课变得普遍,很多有助于改善教学效果的设备也逐渐流行,其中包括比较常用的手写板,那么它属于哪类设备?( ) A:输入 B:输出 C:控制 D:记录 答…

ky10-server docker 离线安装包、离线安装

离线安装脚本 # ---------------离线安装docker------------------- rpm -Uvh --force --nodeps *.rpm# 修改docker拉取源为国内 rm -rf /etc/docker mkdir -p /etc/docker touch /etc/docker/daemon.json cat >/etc/docker/daemon.json<<EOF{"registry-mirro…

kubectl 命令行管理K8S(上)

目录 陈述式资源管理方式 介绍 命令 项目的生命周期 创建 kubectl create命令 发布 kubectl expose命令 更新 kubectl set 回滚 kubectl rollout 删除 kubectl delete 应用发布策略 金丝雀发布 陈述式资源管理方式 介绍 1.kubernetes 集群管理集群资源…

深圳市萨科微半导体有限公司一直研究新材料新工艺

深圳市萨科微&#xff08;www.slkoric.com&#xff09;半导体有限公司一直研究新材料新工艺&#xff0c;不断推出新产品&#xff0c;驱动公司不断发展。最近萨科微slkor推出SL40T120FL系列IGBT单管&#xff0c;和CMOS运算放大器SLA333等产品&#xff0c;为新能源汽车、太阳能光…

【lv14 day10内核模块参数传递和依赖】

一、模块传参 module_param(name,type,perm);//将指定的全局变量设置成模块参数 /* name:全局变量名 type&#xff1a; 使用符号 实际类型 传参方式 bool bool insmod xxx.ko 变量名0 或 1 invbool bool insmod xxx.ko 变量名0 或 1 charp char * insmod xxx.ko 变量名“字符串…