1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)

1、卷积神经网络的手写数字旋转角度预测原理及流程

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:

原理:

  1. 数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。

  2. 模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。

  3. 训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。

流程:

  1. 加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。

  2. 构建CNN模型:使用MATLAB深度学习工具箱中的函数(如convolution2dLayermaxPooling2dLayerfullyConnectedLayerclassificationLayer)构建一个适合手写数字旋转角度预测的CNN模型。

  3. 定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。

  4. 训练模型:使用训练数据集对CNN模型进行训练,通过调用trainNetwork函数并传入训练数据和训练选项来完成训练过程。

  5. 评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。

  6. 预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。

这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。

2、卷积神经网络的手写数字旋转角度预测案例说明

1)解决问题

卷积神经网络来预测手写数字的旋转角度

2)技术方案

回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。

3、加载数据

1)数据说明

数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。

2)加载数据代码

说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。

load DigitsDataTrain
load DigitsDataTest

3)显示训练集代码

numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);

 视图效果

06f6163700a5438b802835e39b5c0504.png

4)数据集划分代码

说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。

[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);

4、检查数据归一化

1)归一化说明

训练神经网络时,确保数据在网络的所有阶段均归一化

对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.

数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离

归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1

2)绘制响应的分布代码

说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。

figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")

视图效果 

b293b09f497e445ea96b6b4fd31045df.png

5、定义神经网络架构

1)神经网络架构说明

对于图像输入,指定一个图像输入层。

指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。

在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。

在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。

2)神经网络架构代码

numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];

6、指定训练选项

1)指定训练选项说明

使用Experiment Manager。

将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。

通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。

在图中显示训练进度并监控均方根误差。

2)指定训练选项代码

miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);

7、训练神经网络

1)训练神经网络说明

使用 trainnet 函数训练神经网络。

对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

2)训练神经网络代码

net = trainnet(XTrain,anglesTrain,layers,"mse",options);

视图效果

 eaa6e5e1078e4fd8bd4d9c5fe0cf06f0.png

8、测试网络

1)测试网络说明

基于测试数据评估准确度来测试网络性能。

使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。

2)测试网络代码

YTest = minibatchpredict(net,XTest);

3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异 

predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))

 4)散点图中可视化预测。绘制预测值对真实值的图。

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")

视图效果 

a35e3dd73632423a92451bfbd0ae7b66.png

9、使用新数据进行预测

1)测试说明

使用 predict 函数并使用神经网络对第一个测试图像进行预测

2)测试代码

X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)

10、总结

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:

总结要点:

  1. 数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。

  2. 模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。

  3. 训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。

实现流程:

  1. 数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。

  2. CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。

  3. 训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。

  5. 预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。

通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。

11、源代码

代码

%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。load DigitsDataTrain
load DigitsDataTest%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")%%  定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)

工程文件

https://download.csdn.net/download/XU157303764/89494539

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

铁威马教程丨如何收集NAS的日志

适用版本: 适用于TOS 5.0.xxx、TOS5.1.xxx版本。 适用机型: TNAS型号(除F2-210、F4-210) 故障现象: 当TNAS宕机导致网页不可访问且PC无法搜索到该设备时,重启后TOS网页的系统报告缺失相关日志,不利于异常…

try-catch-finally使用注意事项

1.catch块和finally块可以省略其中一个。 2.finally块在try和catch的return之前执行。(return时会暂存,执行finally后再return) 如果finally中有return,那就直接return了。 /*** 省略finally 语句块*/public static void omitFina…

第一百六十八节 Java IO教程 - Java Zip文件

Java IO教程 - Java Zip文件 Java对ZIP文件格式有直接支持。通常,我们将使用java.util.zip包中的以下四个类来处理ZIP文件格式: ZipEntryZipInputStreamZipOutputStreamZipFile ZipEntry对象表示ZIP文件格式的归档文件中的条目。 zip条目可以是压缩的…

大部分公司都是草台班子,甚至更水

我第一份实习是在一家咨询公司,我以为我们能够给我们的客户提供极具商业价值的战略指导,但其实开始干活了之后,发现我们就是PPT和调研报告的搬运工。后来我去了一家互联网大厂,我以为我的身边全都是逻辑超强的技术和产品大佬&…

20240715题目

589. N 叉树的前序遍历 给定一个 n 叉树的根节点 root &#xff0c;返回 其节点值的 前序遍历 。 n 叉树 在输入中按层序遍历进行序列化表示&#xff0c;每组子节点由空值 null 分隔&#xff1b; /* // Definition for a Node. class Node { public:int val;vector<Node*…

先“精益”,还是先“信息化”

在当今这个快速变化、竞争激烈的制造业时代&#xff0c;工厂的高效运营与持续创新成为了企业脱颖而出的关键。而工厂精益管理与信息化系统&#xff0c;正是这两把推动企业转型升级的利器&#xff0c;它们相辅相成。 精益管理&#xff0c;作为一种以消除浪费、提升价值为核心的…

pgsql(guass)可获取到对应的表名称、字段名称、注释、字段类型

pgsql可获取到对应的表名称、字段名称、注释、字段类型(GUASS的也是适用) SELECT c.relname as 表名,a.attname as 字段名,format_type(a.atttypid,a.atttypmod) as 类型,a.attnotnull as 非空, col_description(a.attrelid,a.attnum) as 注释 FROM pg_class as c,pg_attri…

[ACM独立出版] 2024年虚拟现实、图像和信号处理国际学术会议(VRISP 2024,8月2日-4)

2024年虚拟现实、图像和信号处理国际学术会议&#xff08;VRISP 2024&#xff09;将于2024年8月2-4日在中国厦门召开。 VRISP 2024将围绕“虚拟现实、图像和信号处理”的最新研究领域&#xff0c;为来自国内外高等院校、科学研究所、企事业单位的专家、教授、学者、工程师等提供…

llama-cpp-python

文章目录 一、关于 llama-cpp-python二、安装安装配置支持的后端Windows 笔记MacOS笔记升级和重新安装 三、高级API1、简单示例2、从 Hugging Face Hub 中提取模型3、聊天完成4、JSON和JSON模式JSON模式JSON Schema 模式 5、函数调用6、多模态模型7、Speculative Decoding8、Em…

【Leetcode】二十一、前缀树 + 词典中最长的单词

文章目录 1、背景2、前缀树Trie3、leetcode208&#xff1a;实现Trie4、leetcode720&#xff1a;词典中最长的单词 1、背景 如上&#xff0c;以浏览器搜索时的自动匹配为例&#xff1a; 如果把所有搜索关键字放一个数组里&#xff0c;则&#xff1a;插入、搜索一个词条时&#x…

SEO:6个避免被搜索引擎惩罚的策略-华媒舍

在当今数字时代&#xff0c;搜索引擎成为了绝大多数人获取信息和产品的首选工具。为了在搜索结果中获得良好的排名&#xff0c;许多网站采用了各种优化策略。有些策略可能会适得其反&#xff0c;引发搜索引擎的惩罚。以下是彭博社发稿推广的6个避免被搜索引擎惩罚的策略。 1. 内…

一文带你看懂SAP-HANA的基本架构与原理

注&#xff1a;本篇主要对SAP HANA做了总结与论述&#xff0c;如有错误欢迎读者提出并补充 创作不易,希望大家一键三连支持!!!♥♥♥ 创作不易,希望大家一键三连支持!!!♥♥♥ 创作不易,希望大家一键三连支持!!!♥♥♥ 目录 一. 背景引入1.1 硬件与数据库系统1.2 行业现状 …

AES Android IOS H5 加密方案

前景&#xff1a; 1、本项目原有功能RSA客户端对敏感信息进行加密 2、本次漏洞说是服务端返回值有敏感信息&#xff0c;需要密文返回 方案&#xff1a; 本次方案不算完美&#xff0c;还是有被劫持篡改的风险&#xff0c;但基本https证书认证加持&#xff0c;风险相对较小 …

对Spring的理解,项目中都用什么?怎么用的?对IOC、和AOP的理解及实现原理

对Spring的理解 Spring是一个开源的Java平台&#xff0c;它提供了全面的基础设施支持&#xff0c;以便你可以更容易地开发Java应用程序。Spring解决了企业应用程序开发中的很多复杂性&#xff0c;提供了以下核心功能&#xff1a; - 依赖注入&#xff08;IoC&#xff09;&…

Camera XTS 处理笔记

和你一起终身学习&#xff0c;这里是程序员Android 经典好文推荐&#xff0c;通过阅读本文&#xff0c;您将收获以下知识点: 常用测试步骤&#xff08;下面均以CTS为例&#xff09; 打开终端&#xff0c;进入 cts 包 tools目录下执行 ./cts-tradefed 进入cts测试 :~/XTS/CTS/14…

C++知识点总结(45):序列动态规划

序列动态规划 一、意义二、例题1. 最长上升子序列2. 合唱队形&#xff08;加强版&#xff09;3. 公共子序列4. 编辑距离 一、意义 动态规划&#xff08;dynamic programming&#xff09;&#xff0c;将一个目标大问题“大事化小&#xff0c;小事化了”&#xff0c;分成很多的子…

永磁同步电机高性能控制算法(14)—— 有源阻尼电流环

1.前言 在之前的之后中已经发过一篇复矢量电流环和我们平时用的比较多的前馈补偿的电流环的对比&#xff0c;感觉复矢量电流环的效果还是挺明显的。 https://zhuanlan.zhihu.com/p/682880365https://zhuanlan.zhihu.com/p/682880365 当时在看文献的时候&#xff0c;复矢量电…

AI算法17-贝叶斯岭回归算法Bayesian Ridge Regression | BRR

贝叶斯岭回归算法简介 贝叶斯岭回归&#xff08;Bayesian Ridge Regression&#xff09;是一种回归分析方法&#xff0c;它结合了岭回归&#xff08;Ridge Regression&#xff09;的正则化特性和贝叶斯统计的推断能力。这种方法在处理具有大量特征的数据集时特别有用&#xff…

13、Shell自动化运维编程基础

弋.目录 RHCE板块一、为什么学习和使用Shell编程二、Shell是什么1、shell起源2、查看当前系统支持的shell3、查看当前系统默认shell4、Shell 概念 三、Shell 程序设计语言1、Shell 也是一种脚本语言2、用途 四、如何学好shell1、熟练掌握shell编程基础知识2、建议 五、Shell脚本…

英伟达股票1拆10后,现在再买入是否为时已晚?

英伟达股票1拆10后&#xff0c;现在再买入是否为时已晚&#xff1f; 英伟达的股价在过去18个月里已经上涨了近800% 人工智能领域无疑是当下最受投资者关注的焦点之一&#xff0c;而这一领域的佼佼者--英伟达&#xff0c;也被一些华尔街投资机构和看好半导体、数据中心行业的专业…