用matlab搭建一个简单的图像分类网络

文章目录

  • 1、数据集准备
  • 2、网络搭建
  • 3、训练网络
  • 4、测试神经网络
  • 5、进行预测
  • 6、完整代码

1、数据集准备

首先准备一个包含十个数字文件夹的DigitsData,每个数字文件夹里包含1000张对应这个数字的图片,图片的尺寸都是 28×28×1 像素的,如下图所示

在这里插入图片描述

matlab 中imageDatastore 函数会根据文件夹名称自动为图像进行分类标注。该数据集包含 10 个类别。

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels

将数据划分为训练集、验证集和测试集。使用 70% 的图像作为训练数据,15% 作为验证数据,15% 作为测试数据。指定使用 “randomized”(随机化),以便从每个类别中按指定比例随机分配图像到新的数据集中。
splitEachLabel 函数用于将图像数据存储对象划分成三个新的数据存储对象。

% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");
  • splitEachLabel:MATLAB 中的函数,用于根据每个标签(类别)分别划分图像数据集。这样可以确保每个类别在训练集、验证集和测试集中都有代表性。
  • imds:原始的图像数据存储对象,包含所有图像和对应的标签。
  • 0.7:表示将每个类别中 70% 的图像用于训练集
  • 0.15:表示每个类别中 15% 的图像用于验证集
  • 0.15:表示每个类别中 15% 的图像用于测试集
  • "randomized":表示在划分数据集时使用随机抽样,避免按文件顺序导致划分不均衡。
  • [imdsTrain, imdsValidation, imdsTest]:返回三个新的 imageDatastore 对象,分别代表:
    • imdsTrain:训练数据集
    • imdsValidation:验证数据集
    • imdsTest:测试数据集

2、网络搭建

这里,我们需要借用到matlab工具栏里APPS里的Deep Network Designer,如下图所示

在这里插入图片描述

在Deep Network Designer, 我们创建一个空白Designer画布

在这里插入图片描述

然后我们可以拖动相应的层到Designer里,并连接各个层,如下图所示

在这里插入图片描述

这里,我们只需要改一下输入层的InputSize就行,如下图

在这里插入图片描述

然后,我们可以检查这个网络可行不可行,通过Analyze按钮,就会得到这个网络的分析结果,如下图

在这里插入图片描述

没有错误,就可以通过Export按钮输出这个网络到Matlab工作区,这个网络被自动被命名为net_1。
在这里插入图片描述

3、训练网络

指定训练选项。不同选项的选择需要依赖实验分析(即通过反复试验和比较来确定最优配置)。

% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  MaxEpochs = 4, ...  ValidationData = imdsValidation, ... ValidationFrequency = 30, ...  Plots = "training-progress", ...  Metrics = "accuracy", ...  Verbose = false); 

trainingOptions 是 MATLAB 中用于设置神经网络训练参数的函数。

"sgdm" 是一种常用优化算法,适用于多数分类问题。

MaxEpochs=4 设置为 4 是为了快速试验,实际训练中可以设置更大,比如 10、20 甚至更多。

ValidationFrequency=30 表示每 30 次 mini-batch 后在验证集上评估一次性能,值越小越频繁,但也会增加验证的耗时。

Plots="training-progress" 是非常有用的调试和可视化工具,能帮助你观察训练是否收敛。

Verbose=false 适合在图形界面中查看结果时使用;如果希望看到文字日志,可以设置为 true

使用 trainnet 函数训练神经网络。由于目标是分类任务,因此使用交叉熵损失函数(cross-entropy loss)

% 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);
  • imdsTrain:训练数据集,是一个图像数据存储对象(imageDatastore),包含用于训练的图像和对应标签。
  • net_1:要训练的神经网络结构(可由 layerGraphdlnetwork 等方式定义的网络)。
  • "crossentropy":指定损失函数为交叉熵损失函数(cross-entropy loss),这是分类任务中最常用的损失函数,特别适用于多类分类问题。
  • options:训练选项,由前面设置的 trainingOptions 定义,包含训练轮数、验证数据、优化器、可视化等信息。

返回值:

  • net:训练完成后的神经网络,包含了优化后的权重和结构,可用于后续的预测或评估。

在这里插入图片描述

4、测试神经网络

使用 testnet 函数对神经网络进行测试。对于单标签分类任务,评估指标为准确率(accuracy),即预测正确的百分比。默认情况下,testnet 函数会在可用时自动使用 GPU。如果希望手动选择执行环境,可以使用 testnet 函数的 ExecutionEnvironment 参数进行设置。

% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");
  • net:已训练好的神经网络模型,是前面通过 trainnet 得到的结果。
  • imdsTest:测试数据集,是一个图像数据存储对象(imageDatastore),用于测试模型的性能。
  • "accuracy":评估指标,这里指定为准确率,即预测正确的样本数量占总样本数量的百分比。

返回值:

  • accuracy:一个介于 0 和 1 之间的小数,表示模型在测试集上的准确率。例如,accuracy = 0.93 表示模型在测试集中有 93% 的预测是正确的。

testnet 函数自动根据你的硬件情况选择在 CPU 还是 GPU 上运行。如果你想手动指定环境,比如使用 CPU,可以这样写:

accuracy = testnet(net, imdsTest, "accuracy", ExecutionEnvironment="cpu");

5、进行预测

使用 minibatchpredict 函数进行预测,并通过 scores2label 函数将预测得分转换为类别标签。默认情况下,如果有可用的 GPU,minibatchpredict 会自动使用 GPU 进行计算。

% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsValidation);% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YValidation = scores2label(scores, classNames);

可视化部分预测结果:

% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9nexttile  % 在下一个网格位置准备绘图img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像imshow(img)  % 显示图像title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

在这里插入图片描述

6、完整代码

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  MaxEpochs = 4, ...  ValidationData = imdsValidation, ... ValidationFrequency = 30, ...  Plots = "training-progress", ...  Metrics = "accuracy", ...  Verbose = false); % 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9nexttile  % 在下一个网格位置准备绘图img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像imshow(img)  % 显示图像title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74972.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go 语言语法精讲:从 Java 开发者的视角全面掌握

《Go 语言语法精讲:从 Java 开发者的视角全面掌握》 一、引言1.1 为什么选择 Go?1.2 适合 Java 开发者的原因1.3 本文目标 二、Go 语言环境搭建2.1 安装 Go2.2 推荐 IDE2.3 第一个 Go 程序 三、Go 语言基础语法3.1 变量与常量3.1.1 声明变量3.1.2 常量定…

如何选择优质的安全工具柜:材质、结构与功能的考量

在工业生产和实验室环境中,安全工具柜是必不可少的设备。它不仅承担着工具的存储任务,还直接影响工作环境的安全和效率。那么,如何选择一个优质的安全工具柜呢?关键在于对材质、结构和功能的考量。 01材质:耐用与防腐 …

系统与网络安全------Windows系统安全(11)

资料整理于网络资料、书本资料、AI,仅供个人学习参考。 制作U启动盘 U启动程序 下载制作U启程序 Ventoy是一个制作可启动U盘的开源工具,只需要把ISO等类型的文件拷贝到U盘里面就可以启动了 同时支持x86LegacyBIOS、x86_64UEFI模式。 支持Windows、L…

【5】搭建k8s集群系列(二进制部署)之安装master节点组件(kube-controller-manager)

注&#xff1a;承接专栏上一篇文章 一、创建配置文件 cat > /opt/kubernetes/cfg/kube-controller-manager.conf << EOF KUBE_CONTROLLER_MANAGER_OPTS"--logtostderrfalse \\ --v2 \\ --log-dir/opt/kubernetes/logs \\ --leader-electtrue \\ --kubeconfig/op…

C#里第一个WPF程序

WPF程序对界面进行优化,但是比WINFORMS的程序要复杂很多, 并且界面UI基本上不适合拖放,所以需要比较多的时间来布局界面, 产且需要开发人员编写更多的代码。 即使如此,在面对诱人的界面表现, 随着客户对界面的需求提高,还是需要采用这样的方式来实现。 界面的样式采…

createContext+useContext+useReducer组合管理React复杂状态

createContext、useContext 和 useReducer 的组合是 React 中管理全局状态的一种常见模式。这种模式非常适合在不引入第三方状态管理库&#xff08;如 Redux&#xff09;的情况下&#xff0c;管理复杂的全局状态。 以下是一个经典的例子&#xff0c;展示如何使用 createContex…

记一次常规的网络安全渗透测试

目录&#xff1a; 前言 互联网突破 第一层内网 第二层内网 总结 前言 上个月根据领导安排&#xff0c;需要到本市一家电视台进行网络安全评估测试。通过对内外网进行渗透测试&#xff0c;网络和安全设备的使用和部署情况&#xff0c;以及网络安全规章流程出具安全评估报告。本…

el-table,新增、复制数据后,之前的勾选状态丢失

需要考虑是否为 更新数据的方式不对 如果新增数据的方式是直接替换原数据数组&#xff0c;而不是通过正确的响应式数据更新方式&#xff08;如使用 Vue 的 this.$set 等方法 &#xff09;&#xff0c;也可能导致勾选状态丢失。 因为 Vue 依赖数据的响应式变化来准确更新视图和…

第15届蓝桥杯java-c组省赛真题

目录 一.拼正方形 1.题目 2.思路 3.代码 二.劲舞团 1.题目 2.思路 3.代码 三.数组诗意 1.题目 2.思路 3.代码 四.封闭图形个数 1.题目 2.思路 3.代码 五.吊坠 1.题目 六.商品库存管理 1.题目 2.思路 3.代码 七.挖矿 1.题目 2.思路 3.代码 八.回文字…

玄机-应急响应-入侵排查

靶机排查目标&#xff1a; 1.web目录存在木马&#xff0c;请找到木马的密码提交 查看/var/www/html。 使用find命令查找 find ./ -type f -name "*.php | xargs grep "eval("查看到1.php里面存在无条件一句话木马。 2.服务器疑似存在不死马&#xff0c;请找…

usbip学习记录

USB/IP: USB device sharing over IP make menuconfig配置&#xff1a; Device Drivers -> Staging drivers -> USB/IP support Device Drivers -> Staging drivers -> USB/IP support -> Host driver 如果还有作为客户端的需要&#xff0c;继续做以下配置&a…

爱普生高精度车规晶振助力激光雷达自动驾驶

在自动驾驶技术快速落地的今天&#xff0c;激光雷达作为车辆的“智慧之眼”&#xff0c;其测距精度与可靠性直接决定了自动驾驶系统的安全上限。而在这双“眼睛”的核心&#xff0c;爱普生&#xff08;EPSON&#xff09;的高精度车规晶振以卓越性能成为激光雷达实现毫米级感知的…

28--当路由器开始“宫斗“:设备控制面安全配置全解

当路由器开始"宫斗"&#xff1a;设备控制面安全配置全解 引言&#xff1a;路由器的"大脑保卫战" 如果把网络世界比作一座繁忙的城市&#xff0c;那么路由器就是路口执勤的交通警察。而控制面&#xff08;Control Plane&#xff09;就是警察的大脑&#xf…

58.基于springboot老人心理健康管理系统

目录 1.系统的受众说明 2.相关技术 2.1 B/S结构 2.2 MySQL数据库 3.系统分析 3.1可行性分析 3.1.1时间可行性 3.1.2 经济可行性 3.1.3 操作可行性 3.1.4 技术可行性 3.1.5 法律可行性 3.2系统流程分析 3.3系统功能需求分析 3.4 系统非功能需求分析 4.系统设计 …

去中心化固定利率协议

核心机制与分类 协议类型&#xff1a; 借贷协议&#xff08;如Yield、Notional&#xff09;&#xff1a;通过零息债券模型&#xff08;如fyDai、fCash&#xff09;锁定固定利率。 收益聚合器&#xff08;如Saffron、BarnBridge&#xff09;&#xff1a;通过风险分级或博弈论…

反射率均值与RCS均值的计算方法差异

1. 反射率均值&#xff08;Mean Reflectance&#xff09; 定义&#xff1a; 反射率是物体表面反射的电磁波能量与入射能量的“比例”&#xff0c;通常以百分比或小数表示。 反射率均值是对多个测量点反射率的算术平均&#xff0c;反映目标区域整体的平均反射特性。 特点&a…

[MySQL初阶]MySQL(8)索引机制:下

标题&#xff1a;[MySQL初阶]MySQL&#xff08;8&#xff09;索引机制&#xff1a;下 水墨不写bug 文章目录 四、从问题到底层&#xff0c;从现象到本质1.为什么插入的数据默认排好序2.MySQL的Page&#xff08;1&#xff09;为什么选择用Page&#xff1f;&#xff08;2&#x…

Access:在移动互联网与AI时代焕发新生

Microsoft Access&#xff1a;在移动互联网与AI时代焕发新生 在移动互联网和人工智能&#xff08;AI&#xff09;技术快速发展的今天&#xff0c;许多传统工具被认为已经过时。然而&#xff0c;Microsoft Access&#xff0c;这款曾经风靡一时的数据库&#xff0c;真的已经被淘…

【无人机】无人机PX4飞控系统高级软件架构

目录 1、概述&#xff08;图解&#xff09; 一、数据存储层&#xff08;Storage&#xff09; 二、外部通信层&#xff08;External Connectivity&#xff09; 三、核心通信枢纽&#xff08;Message Bus&#xff09; 四、硬件驱动层&#xff08;Drivers&#xff09; 五、飞…

【项目日记】高并发服务器项目总结

生活总是让我们遍体鳞伤&#xff0c; 但到后来&#xff0c; 那些受伤的地方一定会变成我们最强壮的地方。 -- 《老人与海》-- 高并发服务器项目总结 模块关系图项目工具模块缓冲区模块通用类型模块套接字socket模块信道Channel模块多路转接Poller模块 Reactor模块时间轮Tim…