深度学习神经网络实战:多层感知机,手写数字识别

目的

利用tensorflow.js训练模型,搭建神经网络模型,完成手写数字识别

设计

简单三层神经网络

  • 输入层
    28*28个神经原,代表每一张手写数字图片的灰度
  • 隐藏层
    100个神经原
  • 输出层
    -10个神经原,分别代表10个数字

代码

// 导入 TensorFlow.js 库
import tf from "@tensorflow/tfjs";
import * as tfjsnode from "@tensorflow/tfjs-node";
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
// 定义模型
const model = tf.sequential();// 添加输入层
model.add(tf.layers.dense({ units: 64, inputShape: [784], activation: "relu" })
);// 添加隐藏层
model.add(tf.layers.dense({ units: 100, activation: "relu" }));// 添加输出层
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));// 编译模型
model.compile({optimizer: "sgd",loss: "categoricalCrossentropy",metrics: ["accuracy"],
});
const trainDataLen = 3000;
const testDataLen = 2000;// 加载 MNIST 数据集
import pkg from "mnist";
const { set: Dataset } = pkg;
const set = Dataset(trainDataLen, testDataLen);
const trainingSet = set.training;
const testSet = set.test;const trainXs = [];
const testXs = [];const trainLabels = [];
const testLabels = [];for (let i = 0; i < trainingSet.length; i++) {trainXs.push(trainingSet[i].input);trainLabels.push(trainingSet[i].output.indexOf(1));
}for (let i = 0; i < testSet.length; i++) {testXs.push(testSet[i].input);testLabels.push(testSet[i].output.indexOf(1));
}// 准备数据
const trainXsTensor = tf.tensor(trainXs, [trainDataLen, 784]);
const trainYsOneHot = tf.oneHot(trainLabels, 10);//记录每轮模型训练中的损失和精度,为了绘制曲线图
var accPlot = [];
var lossPlot = [];// 模型训练
model.fit(trainXsTensor, trainYsOneHot, {batchSize: 64,epochs: 100,validationSplit: 0.2,callbacks: {onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),onEpochEnd: async (epoch, logs) => {console.log(`Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(3)}, Accuracy: ${logs.acc.toFixed(3)}`);//记录loss和acc,绘制曲线图accPlot.push(logs.acc.toFixed(3));lossPlot.push(logs.loss.toFixed(3));await tf.nextFrame(); // 防止阻塞},onBatchEnd: async (batch, logs) => {console.log(`Batch ${batch} completed. Loss: ${logs.loss.toFixed(3)}, Accuracy: ${logs.acc.toFixed(3)}`);await tf.nextFrame(); // 防止阻塞},},}).then((history) => {console.log("Training completed!", history);//绘制模型训练过程中的损失函数和模型精度曲线变化const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);plot.plot([{ x: epochs, y: lossPlot, name: "Loss" },{ x: epochs, y: accPlot, name: "Accuracy" },],{filename: "loss_acc.png",});//模型评估const testXsTensor = tf.tensor(testXs, [testDataLen, 784]);const testYsOneHot = tf.oneHot(testLabels, 10);const result = model.evaluate(testXsTensor, testYsOneHot);const testLoss = result[0].dataSync()[0];const testAccuracy = result[1].dataSync()[0];console.log(`Test loss: ${testLoss.toFixed(3)}`);console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);//保存模型model.save("file://./my-model").then(() => {console.log("Model saved!");});});

package.json

{"name": "neural_network","version": "1.0.0","description": "","type": "module","main": "mlpTest.js","scripts": {"test": "echo \"Error: no test specified\" && exit 1",},"author": "","license": "ISC","dependencies": {"@tensorflow/tfjs": "^4.17.0","@tensorflow/tfjs-node": "^4.17.0","@tensorflow/tfjs-vis": "^1.0.0","mnist": "^1.1.0","nodeplotlib": "^0.7.7"},"devDependencies": {"@babel/core": "^7.0.0","@babel/preset-env": "^7.0.0","babel-loader": "^8.0.0","webpack": "^5.0.0","webpack-cli": "^4.0.0"}
}

模型结果

模型精度

损失函数与模型精度变化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/705997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

负载均衡.

简介: 将请求/数据【均匀】分摊到多个操作单元上执行&#xff0c;负载均衡的关键在于【均匀】。 负载均衡的分类: 网络通信分类 四层负载均衡:基于 IP 地址和端口进行请求的转发。七层负载均衡:根据访问用户的 HTTP 请求头、URL 信息将请求转发到特定的主机。 载体维度分类 硬…

前端开发_Vue入门

Vue概念 Vue 是一个用于构建用户界面的渐进式框架 构建用户界面&#xff1a;基于数据渲染出用户看到的页面渐进式&#xff1a;循序渐进框架&#xff1a;一套完整的项目解决方案 创建Vue实例 准备容器 引包&#xff08;开发版本/生产版本&#xff09; <script src"h…

消息中间件篇之Kafka-数据清理机制

一、Kafka文件存储机制 Kafka文件存储结构&#xff1a;一个Topic有多个分区。每一个分区都有多个段&#xff0c;每个段都有三个文件。 为什么要分段&#xff1f;1. 删除无用文件方便&#xff0c;提高磁盘利用率。 2. 查找数据便捷。 二、数据清理机制 1.日志的清理策略方案1 根…

[C++][linux]Linux上内存共享内存用法

一&#xff0c;什么是共享内存 共享内存&#xff08;Shared Memory&#xff09;&#xff0c;指两个或多个进程共享一个给定的存储区。进程可以将同一段共享内存连接到它们自己的地址空间中&#xff0c;所有进程都可以访问共享内存中的地址&#xff0c;就好像它们是由用C语言函…

GEE入门篇|遥感专业术语(实践操作4):光谱分辨率(Spectral Resolution)

目录 光谱分辨率&#xff08;Spectral Resolution&#xff09; 1.MODIS 2.EO-1 光谱分辨率&#xff08;Spectral Resolution&#xff09; 光谱分辨率是指传感器进行测量的光谱带的数量和宽度。 您可以将光谱带的宽度视为每个波段的波长间隔&#xff0c;在多个波段测量辐射亮…

RestTemplate启动问题解决

⭐ 作者简介&#xff1a;码上言 ⭐ 代表教程&#xff1a;Spring Boot vue-element 开发个人博客项目实战教程 ⭐专栏内容&#xff1a;个人博客系统 ⭐我的文档网站&#xff1a;http://xyhwh-nav.cn/ RestTemplate启动问题解决 问题&#xff1a;在SpringCloud架构项目中配…

Java SpringBoot 整合 MyBatis 小案例

Java SpringBoot 整合 MyBatis 小案例 基础配置&#xff08;注意版本号&#xff0c;容易报错&#xff09; pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http…

TikTok东南亚小店爆单思路,怎么玩?

东南亚地区的跨境电商市场已经成为全球范围内最具吸引力的市场之一&#xff0c;在各个跨境电商平台上&#xff0c;都是转化率最高的站点之一。TikTok作为电商黑马&#xff0c;吸引了一大波跨境电商玩家入驻&#xff0c;其中东南亚小店也成为热门的选择&#xff0c;那么东南亚小…

当Vue项目启动后,通过IP地址方式在相同网络段的其他电脑上无法访问前端页面?

当Vue项目启动后&#xff0c;通过IP地址方式在相同网络段的其他电脑上无法访问前端页面&#xff0c;可能是由以下几个原因造成的&#xff1a; 服务监听地址&#xff1a;默认情况下&#xff0c;许多开发服务器&#xff08;如Vue CLI的vue-cli-service serve&#xff09;只监听lo…

ky10-server docker 离线安装包、离线安装

离线安装脚本 # ---------------离线安装docker------------------- rpm -Uvh --force --nodeps *.rpm# 修改docker拉取源为国内 rm -rf /etc/docker mkdir -p /etc/docker touch /etc/docker/daemon.json cat >/etc/docker/daemon.json<<EOF{"registry-mirro…

kubectl 命令行管理K8S(上)

目录 陈述式资源管理方式 介绍 命令 项目的生命周期 创建 kubectl create命令 发布 kubectl expose命令 更新 kubectl set 回滚 kubectl rollout 删除 kubectl delete 应用发布策略 金丝雀发布 陈述式资源管理方式 介绍 1.kubernetes 集群管理集群资源…

深圳市萨科微半导体有限公司一直研究新材料新工艺

深圳市萨科微&#xff08;www.slkoric.com&#xff09;半导体有限公司一直研究新材料新工艺&#xff0c;不断推出新产品&#xff0c;驱动公司不断发展。最近萨科微slkor推出SL40T120FL系列IGBT单管&#xff0c;和CMOS运算放大器SLA333等产品&#xff0c;为新能源汽车、太阳能光…

【lv14 day10内核模块参数传递和依赖】

一、模块传参 module_param(name,type,perm);//将指定的全局变量设置成模块参数 /* name:全局变量名 type&#xff1a; 使用符号 实际类型 传参方式 bool bool insmod xxx.ko 变量名0 或 1 invbool bool insmod xxx.ko 变量名0 或 1 charp char * insmod xxx.ko 变量名“字符串…

基于Java学生管理系统设计与实现(源码+部署文档)

博主介绍&#xff1a; ✌至今服务客户已经1000、专注于Java技术领域、项目定制、技术答疑、开发工具、毕业项目实战 ✌ &#x1f345; 文末获取源码联系 &#x1f345; &#x1f447;&#x1f3fb; 精彩专栏 推荐订阅 &#x1f447;&#x1f3fb; 不然下次找不到 Java项目精品实…

1904_ARM Cortex M系列芯片特性小结

1904_ARM Cortex M系列芯片特性小结 全部学习汇总&#xff1a; g_arm_cores: ARM内核的学习笔记 (gitee.com) ARM Cortex M系列的MCU用过好几款了&#xff0c;也涉及到了不同的内核。不过&#xff0c;关于这些内核的基本的特性还是有些不了解。从ARM的官方网站上找来了一个对比…

【UE 材质】冰冻效果

效果 步骤 1. 打开“Quixel Bridge” 下载冰的纹理 2. 新建一个材质&#xff0c;这里命名为“M_Frozen” 打开“M_Frozen”&#xff0c;添加如下节点&#xff0c;此时我们可以通过控制参数“偏移”来改变边界的偏移 此时预览效果如下 如果增加参数“偏移”的默认值效果如下 3.…

OpenCV 16 - Qt使用opencv视觉库

1 下载好opencv视觉库 不知道怎么下载和编译opencv视觉库的可以直接使用这个 : opencvcv_3.4.2_qt 2 解压opencv包 3 打开opencv的安装目录 4.打开x86/bin 复制里面所有的dll文件&#xff0c;黏贴到C/windows/syswow64里面 5 新建Qt项目 6 修改pro文件:添加对应的头文件和库文件…

[python] 代码转换成流程图 (pycallgraph)

1. centos 7 安装dot 在 CentOS 7 上安装 Graphviz 中的 dot 工具可以通过 yum 命令进行。dot 工具是 Graphviz 提供的一个用于生成图形的命令行工具&#xff0c;通常在安装 Graphviz 的时候会一并安装。 以下是在 CentOS 7 上安装 Graphviz 的步骤&#xff1a; 更新 yum 软…

腾讯文档(excel也一样)设置单元格的自动行高列宽

1. 选中单元格 可选择任意一个或者几个 2. 设置自动 行高和列宽 即可生效

sql-labs第46关 order by盲注

sql-labs第46关 order by盲注 来到了第46关进入关卡发现让我们输入的参数为sort&#xff0c;我们输入?sort1尝试&#xff1a; 输入?sort2,3,发现表格按照顺序进行排列输出&#xff0c;明显是使用了order by相关的函数。 我们将参数变成1进行尝试&#xff0c;就会报错&…