Matlab多输入单输出之倾斜手写数字识别

本文主要介绍使用matlab构建多输入单输出的网络架构,来实现倾斜的手写数字识别,使用concatenationLayer来拼接特征,实现网络输入多个特征。

1.加载训练数据

加载数据:手写数字的图像、真实数字标签和数字顺时针旋转的角度。

load DigitsDataTrain

网络的输入数据类型需要是datastore,使用 arrayDatastore 将三个普通矩阵变为datastore,最后再使用combine合并。

dsX1Train = arrayDatastore(XTrain,IterationDimension=4);
dsX2Train = arrayDatastore(anglesTrain);
dsTTrain = arrayDatastore(labelsTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);

显示20个随机训练图像:

numObservationsTrain = numel(labelsTrain);
idx = randperm(numObservationsTrain,20);figure
tiledlayout("flow");
for i = 1:numel(idx)nexttileimshow(XTrain(:,:,:,idx(i)))title("Angle: " + anglesTrain(idx(i)))
end

图片

2.设计网络架构

设计如下的网络结构:

图片

  • 对于图像输入,指定一个大小与输入数据匹配的图像输入层。

  • 对于特征输入,指定一个大小与输入特征数量匹配的特征输入层。

  • 对于图像输入分支,进行卷积、批归一化和ReLU层块,其中卷积层有16个5×5滤波器。

  • 为了将批归一化层的输出转换为特征向量,需要用一个大小为50的全连接层。

  • 要将第一个全连接层的输出与特征输入连接起来,使用flatten layer将全连接层中的 "SSCB" (空间、空间、通道、批处理)输出展平,使其具有 "CB" 格式。

  • 沿第一维度(channel维度)将平坦层的输出与特征输入连接起来

  • 对于分类输出,包括一个输出大小与类数匹配的全连接层,然后是softmax层。

创建一个空的神经网络:

net = dlnetwork;

创建一个网络主分支,并将其添加到网络中:

[h,w,numChannels,numObservations] = size(XTrain);
numFeatures = 1;
classNames = categories(labelsTrain);
numClasses = numel(classNames);imageInputSize = [h w numChannels];
filterSize = 5;
numFilters = 16;layers = [imageInputLayer(imageInputSize,Normalization="none")convolution2dLayer(filterSize,numFilters)batchNormalizationLayerreluLayerfullyConnectedLayer(50)flattenLayerconcatenationLayer(1,2,Name="cat")fullyConnectedLayer(numClasses)softmaxLayer];net = addLayers(net,layers);

将feature input layer添加到网络中,并将其连接到 concatenation layer的第二个输入:

featInput = featureInputLayer(numFeatures,Name="features");
net = addLayers(net,featInput);
net = connectLayers(net,"features","cat/in2");

在绘图中可视化网络:

figure
plot(net)

3.训练网络

使用SGDM优化器进行训练,训练15个epochs,以0.01的学习率进行训练,在图表中显示训练进度并监控accuracy指标,不显示详细输出。

options = trainingOptions("sgdm", ...MaxEpochs=15, ...InitialLearnRate=0.01, ...Plots="training-progress", ...Metrics="accuracy", ...Verbose=0);

使用 trainnet 函数训练神经网络,对于分类使用交叉熵损失。

net = trainnet(dsTrain,net,"crossentropy",options);

图片

4.测试网络

通过将测试集上的预测与真实标签进行比较来测试网络的分类准确性,加载测试数据:

load DigitsDataTest

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。

scores = minibatchpredict(net,XTest,anglesTest);
YTest = scores2label(scores,classNames);

在混淆图中可视化预测:

figure
confusionchart(labelsTest,YTest)

图片

评估分类准确性:

accuracy = mean(YTest == labelsTest)

accuracy = 0.9878

查看一些预测的图像:

idx = randperm(size(XTest,4),9);
figure
tiledlayout(3,3)
for i = 1:9nexttileI = XTest(:,:,:,idx(i));imshow(I)label = string(YTest(idx(i)));title("Predicted Label: " + label)
end

图片

5.不用角度特征训练和测试网络

% 网络设计
net_without_feature = dlnetwork;
layers = [imageInputLayer(imageInputSize,Normalization="none")convolution2dLayer(filterSize,numFilters)batchNormalizationLayerreluLayerfullyConnectedLayer(numClasses)softmaxLayer];net_without_feature = addLayers(net_without_feature,layers);
% 网络训练
options = trainingOptions("sgdm", ...MaxEpochs=15, ...InitialLearnRate=0.01, ...Plots="training-progress", ...Metrics="accuracy", ...Verbose=0);dsTrain_without_feature = combine(dsX1Train,dsTTrain);net_without_feature = trainnet(dsTrain_without_feature,net_without_feature,"crossentropy",options);

图片

% 在混淆矩阵中可视化预测。
scores = minibatchpredict(net_without_feature,XTest);
YTest = scores2label(scores,classNames);
figure
confusionchart(labelsTest,YTest)

图片

% 评估分类准确性。
accuracy = mean(YTest == labelsTest)

accuracy = 0.9858

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/886970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R | 统一栅格数据的坐标系、分辨率和行列号

各位同学,在做相关性等分析时,经常会遇到各栅格数据间的行列号不统一等问题,下面的代码能直接解决这类麻烦。以某个栅格数据的坐标系、分辨率和行列号为准,统一文件夹内所有栅格并输出到新的文件夹。 代码只需要更改输入输出和ti…

UE5 第一人称射击项目学习(完结)

这个项目几乎完结了。 也算我上手的第一个纯蓝图小项目。 现在只剩下缝缝补补了。 之前把子弹设计为蓝图,这里要引入C的面向对象思想,建立成员函数。 首先双击打开子弹的蓝图 这边就可以构造成员函数 写一个print your name 在这里生成成员函数后&am…

三相正弦交流电的相序:揭秘正相序与反相序的奥秘

在电力系统中,三相正弦交流电的应用无处不在,从家庭用电到大型工业设备,都离不开它的稳定供电。然而,在三相交流电中,有一个概念常常让初学者感到困惑,那就是“相序”。今天,我们就来深入探讨一…

力扣面试题 - 24 插入

题目&#xff1a; 给定两个整型数字 N 与 M&#xff0c;以及表示比特位置的 i 与 j&#xff08;i < j&#xff0c;且从 0 位开始计算&#xff09;。 编写一种方法&#xff0c;使 M 对应的二进制数字插入 N 对应的二进制数字的第 i ~ j 位区域&#xff0c;不足之处用 0 补齐…

华为云容器监控平台

首先搜索CCE,点击云容器引擎CCE 有不同的测试&#xff0c;生产&#xff0c;正式环境 工作负载--直接查询服务名看监控 数据库都是走的一个 Redis的查看

Spring Cloud Stream实现数据流处理

1.什么是Spring Cloud Stream&#xff1f; 我看很多回答都是“为了屏蔽消息队列的差异&#xff0c;使我们在使用消息队列的时候能够用统一的一套API&#xff0c;无需关心具体的消息队列实现”。 这样理解是有些不全面的&#xff0c;Spring Cloud Stream的核心是Stream&#xf…

Notepad++--在开头快速添加行号

原文网址&#xff1a;Notepad--在开头快速添加行号_IT利刃出鞘的博客-CSDN博客 简介 本文介绍Notepad怎样在开头快速添加行号。 需求 原文件 想要的效果 方法 1.添加点号 Alt鼠标左键&#xff0c;从首行选中首列下拉&#xff0c;选中需要添加序号的所有行的首列&#xff…

【ROS2】多传感器融合、实现精准定位:robot_localization

1、简述 robot_localization在SLAM建图、导航中常用于将多个传感器融合(IMU、里程计、深度相机、GPS等),以提高定位精度,为机器人提供了在三维空间中的非线性状态估计 robot_localization包含两个状态估计节点: ekf_localization_node:扩展卡尔曼滤波(EKF),缺点是非…

【智谱开放平台-注册_登录安全分析报告】

前言 由于网站注册入口容易被机器执行自动化程序攻击&#xff0c;存在如下风险&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露&#xff0c;不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 &#xff0c;造成用户无法登陆、注册&#xff0c;大量收到垃圾短信的…

永夜星河主题特效2(星河背景 + 闪烁文字+点击星星 + 文字弹出特效)

目录 图片展示 星河背景 闪烁文字点击星星 文字弹出特效 特效介绍&#xff1a; 使用方式&#xff1a; 图片展示 星河背景 闪烁文字点击星星 文字弹出特效 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8">&l…

SAP Ariba Contracts_Author the Main Agreement

SAP Ariba Contracts_Author the Main Agreement Workspace Documents 从SAP Ariba 项目模板中继承的文档将自动出现在控件的“文档”选项卡中项目&#xff0c;Workspace Documents 管理所有合同相关文件的地方&#xff0c;如主要协议&#xff0c;附录&#xff0c;合同条款等…

C++练习题(5)

//函数打印素数 #include <iostream> using namespace std; int is_prime(int n){ for(int j2;j<n;j){ if(n%j0) return 0; } return 1; } int main(){ int i0; for(int i100;i<200;i){ if(is_prime(i)1) pr…

编程之路,从0开始:动态内存管理

Hello&#xff0c;大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路。 我们今天来学习C语言中的动态内存管理。 目录 1、为什么要有动态内存管理&#xff1f; 2、malloc和free &#xff08;1&#xff09;malloc函数 &…

【题解】AT_joisc2007_mall ショッピングモール (Mall)

原题传送门 温馨提示&#xff1a;岛国题要换行&#xff01; 需要求一个矩阵的和&#xff0c;考虑二维前缀和。 题目中不允许矩阵中有负数&#xff0c;结合求和的最小值&#xff0c;我们把负数赋为最大值不就行了吗。 接下来就是求二维前缀和了。 基于容斥原理&#xff0c;二…

Apifox软件Mock前端数据,帮忙生成API接口文档

Apifox是一款功能强大的接口调试软件&#xff0c;其特色功能丰富&#xff0c;且在前端mock数据生成方面表现出色。以下是对Apifox软件特色功能的详解&#xff0c;以及如何进行前端mock数据生成的步骤&#xff1a; https://apifox.com/help/api-docs/exporting-api https://www…

入门pandas

pandas是本书后续内容的首选库。它含有使数据清洗和分析工作变得更快更简单的数据结构和操作工具。pandas经常和其它工具一同使用&#xff0c;如数值计算工具NumPy和SciPy&#xff0c;分析库statsmodels和scikit-learn&#xff0c;和数据可视化库matplotlib。pandas是基于NumPy…

Apple Vision Pro开发001-开发配置

一、Vision Pro开发硬件和软件要求 硬件要求软件要求 1、Apple Silicon Mac(M系列芯片的Mac电脑) 2、Apple vision pro-真机调试 XCode15.2及以上&#xff0c;调试开发和打包发布Unity开发者账号&&苹果开发者账号 二 、开启无线调试 1、Apple Vision Pro和Mac连接同…

无人机与低空经济:开启新质生产力的新时代

无人机技术作为低空经济的核心技术之一&#xff0c;正以其独特的优势在多个行业中发挥着重要作用&#xff0c;成为推动新质生产力革命的重要力量。无人机的应用范围广泛&#xff0c;从农业植保到物流配送&#xff0c;从城市监测到紧急救援&#xff0c;无人机的身影无处不在&…

故障排除-------K8s挂载集群外NFS异常

故障排除-------K8s挂载集群外NFS异常 1. 故障现象2. 原因梳理2.1 排查思路2.2 确认yaml内容2.3 创建k8s内的nfs测试2.3.1 创建nfs和svc2.3.2 测试创建pvc2.3.3 测试结果 2.4 NFS服务端故障排除2.4.1 网络阻断排除2.4.2 排除服务状态问题2.4.3 排查NFS权限问题 3. 故障排除 1. …

微服务即时通讯系统的实现(客户端)----(2)

目录 1. 将protobuf引入项目当中2. 前后端交互接口定义2.1 核心PB类2.2 HTTP接口定义2.3 websocket接口定义 3. 核心数据结构和PB之间的转换4. 设计数据中心DataCenter类5. 网络通信5.1 定义NetClient类5.2 引入HTTP5.3 引入websocket 6. 小结7. 搭建测试服务器7.1 创建项目7.2…