1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)

1、卷积神经网络的手写数字旋转角度预测原理及流程

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:

原理:

  1. 数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。

  2. 模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。

  3. 训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。

流程:

  1. 加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。

  2. 构建CNN模型:使用MATLAB深度学习工具箱中的函数(如convolution2dLayermaxPooling2dLayerfullyConnectedLayerclassificationLayer)构建一个适合手写数字旋转角度预测的CNN模型。

  3. 定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。

  4. 训练模型:使用训练数据集对CNN模型进行训练,通过调用trainNetwork函数并传入训练数据和训练选项来完成训练过程。

  5. 评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。

  6. 预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。

这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。

2、卷积神经网络的手写数字旋转角度预测案例说明

1)解决问题

卷积神经网络来预测手写数字的旋转角度

2)技术方案

回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。

3、加载数据

1)数据说明

数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。

2)加载数据代码

说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。

load DigitsDataTrain
load DigitsDataTest

3)显示训练集代码

numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);

 视图效果

06f6163700a5438b802835e39b5c0504.png

4)数据集划分代码

说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。

[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);

4、检查数据归一化

1)归一化说明

训练神经网络时,确保数据在网络的所有阶段均归一化

对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.

数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离

归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1

2)绘制响应的分布代码

说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。

figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")

视图效果 

b293b09f497e445ea96b6b4fd31045df.png

5、定义神经网络架构

1)神经网络架构说明

对于图像输入,指定一个图像输入层。

指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。

在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。

在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。

2)神经网络架构代码

numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];

6、指定训练选项

1)指定训练选项说明

使用Experiment Manager。

将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。

通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。

在图中显示训练进度并监控均方根误差。

2)指定训练选项代码

miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);

7、训练神经网络

1)训练神经网络说明

使用 trainnet 函数训练神经网络。

对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

2)训练神经网络代码

net = trainnet(XTrain,anglesTrain,layers,"mse",options);

视图效果

 eaa6e5e1078e4fd8bd4d9c5fe0cf06f0.png

8、测试网络

1)测试网络说明

基于测试数据评估准确度来测试网络性能。

使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。

2)测试网络代码

YTest = minibatchpredict(net,XTest);

3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异 

predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))

 4)散点图中可视化预测。绘制预测值对真实值的图。

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")

视图效果 

a35e3dd73632423a92451bfbd0ae7b66.png

9、使用新数据进行预测

1)测试说明

使用 predict 函数并使用神经网络对第一个测试图像进行预测

2)测试代码

X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)

10、总结

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:

总结要点:

  1. 数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。

  2. 模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。

  3. 训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。

实现流程:

  1. 数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。

  2. CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。

  3. 训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。

  5. 预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。

通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。

11、源代码

代码

%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。load DigitsDataTrain
load DigitsDataTest%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")%%  定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)

工程文件

https://download.csdn.net/download/XU157303764/89494539

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

铁威马教程丨如何收集NAS的日志

适用版本: 适用于TOS 5.0.xxx、TOS5.1.xxx版本。 适用机型: TNAS型号(除F2-210、F4-210) 故障现象: 当TNAS宕机导致网页不可访问且PC无法搜索到该设备时,重启后TOS网页的系统报告缺失相关日志,不利于异常…

第一百六十八节 Java IO教程 - Java Zip文件

Java IO教程 - Java Zip文件 Java对ZIP文件格式有直接支持。通常,我们将使用java.util.zip包中的以下四个类来处理ZIP文件格式: ZipEntryZipInputStreamZipOutputStreamZipFile ZipEntry对象表示ZIP文件格式的归档文件中的条目。 zip条目可以是压缩的…

大部分公司都是草台班子,甚至更水

我第一份实习是在一家咨询公司,我以为我们能够给我们的客户提供极具商业价值的战略指导,但其实开始干活了之后,发现我们就是PPT和调研报告的搬运工。后来我去了一家互联网大厂,我以为我的身边全都是逻辑超强的技术和产品大佬&…

先“精益”,还是先“信息化”

在当今这个快速变化、竞争激烈的制造业时代,工厂的高效运营与持续创新成为了企业脱颖而出的关键。而工厂精益管理与信息化系统,正是这两把推动企业转型升级的利器,它们相辅相成。 精益管理,作为一种以消除浪费、提升价值为核心的…

[ACM独立出版] 2024年虚拟现实、图像和信号处理国际学术会议(VRISP 2024,8月2日-4)

2024年虚拟现实、图像和信号处理国际学术会议(VRISP 2024)将于2024年8月2-4日在中国厦门召开。 VRISP 2024将围绕“虚拟现实、图像和信号处理”的最新研究领域,为来自国内外高等院校、科学研究所、企事业单位的专家、教授、学者、工程师等提供…

【Leetcode】二十一、前缀树 + 词典中最长的单词

文章目录 1、背景2、前缀树Trie3、leetcode208:实现Trie4、leetcode720:词典中最长的单词 1、背景 如上,以浏览器搜索时的自动匹配为例: 如果把所有搜索关键字放一个数组里,则:插入、搜索一个词条时&#x…

SEO:6个避免被搜索引擎惩罚的策略-华媒舍

在当今数字时代,搜索引擎成为了绝大多数人获取信息和产品的首选工具。为了在搜索结果中获得良好的排名,许多网站采用了各种优化策略。有些策略可能会适得其反,引发搜索引擎的惩罚。以下是彭博社发稿推广的6个避免被搜索引擎惩罚的策略。 1. 内…

一文带你看懂SAP-HANA的基本架构与原理

注:本篇主要对SAP HANA做了总结与论述,如有错误欢迎读者提出并补充 创作不易,希望大家一键三连支持!!!♥♥♥ 创作不易,希望大家一键三连支持!!!♥♥♥ 创作不易,希望大家一键三连支持!!!♥♥♥ 目录 一. 背景引入1.1 硬件与数据库系统1.2 行业现状 …

AES Android IOS H5 加密方案

前景: 1、本项目原有功能RSA客户端对敏感信息进行加密 2、本次漏洞说是服务端返回值有敏感信息,需要密文返回 方案: 本次方案不算完美,还是有被劫持篡改的风险,但基本https证书认证加持,风险相对较小 …

Camera XTS 处理笔记

和你一起终身学习,这里是程序员Android 经典好文推荐,通过阅读本文,您将收获以下知识点: 常用测试步骤(下面均以CTS为例) 打开终端,进入 cts 包 tools目录下执行 ./cts-tradefed 进入cts测试 :~/XTS/CTS/14…

永磁同步电机高性能控制算法(14)—— 有源阻尼电流环

1.前言 在之前的之后中已经发过一篇复矢量电流环和我们平时用的比较多的前馈补偿的电流环的对比,感觉复矢量电流环的效果还是挺明显的。 https://zhuanlan.zhihu.com/p/682880365https://zhuanlan.zhihu.com/p/682880365 当时在看文献的时候,复矢量电…

AI算法17-贝叶斯岭回归算法Bayesian Ridge Regression | BRR

贝叶斯岭回归算法简介 贝叶斯岭回归(Bayesian Ridge Regression)是一种回归分析方法,它结合了岭回归(Ridge Regression)的正则化特性和贝叶斯统计的推断能力。这种方法在处理具有大量特征的数据集时特别有用&#xff…

13、Shell自动化运维编程基础

弋.目录 RHCE板块一、为什么学习和使用Shell编程二、Shell是什么1、shell起源2、查看当前系统支持的shell3、查看当前系统默认shell4、Shell 概念 三、Shell 程序设计语言1、Shell 也是一种脚本语言2、用途 四、如何学好shell1、熟练掌握shell编程基础知识2、建议 五、Shell脚本…

英伟达股票1拆10后,现在再买入是否为时已晚?

英伟达股票1拆10后,现在再买入是否为时已晚? 英伟达的股价在过去18个月里已经上涨了近800% 人工智能领域无疑是当下最受投资者关注的焦点之一,而这一领域的佼佼者--英伟达,也被一些华尔街投资机构和看好半导体、数据中心行业的专业…

SoulApp创始人张璐团队以AI驱动社交进化,平台社交玩法大变革

在科技飞速发展的今天,人工智能正逐步渗透到社交媒体的各个环节,赋能全链路社交体验。AI的引入不仅提升了内容推荐的精准度,使用户能够更快速地发现感兴趣的内容,还能通过用户行为预测,帮助平台更好地理解和满足用户需求。此外,AI驱动的虚拟助手和聊天机器人也正在改变用户互动…

NVIDIA RTX 50系显卡接口全变,功耗爆炸超500W

七月伊始,手机圈就开始打的不可开交了。 例如真我 GT6、IQOO Neo 9S、以及蓄势待发的红米 K70 Ultra,都想在这个暑假向莘莘学子发出最诚挚的「邀请函」。 反观电脑圈这边,不能说一潭死水,只能说毫无波澜。 不过该来的还是要来的&…

Redis的使用(四)常见使用场景-缓存使用技巧

1.绪论 redis本质上就是一个缓存框架,所以我们需要研究如何使用redis来缓存数据,并且如何解决缓存中的常见问题,缓存穿透,缓存击穿,缓存雪崩,以及如何来解决缓存一致性问题。 2.缓存的优缺点 2.1 缓存的…

睿考网:造价员和造价工程师是一个意思吗?

在工程建设领域中,经常会有人问:“造价员和造价工程师是一样的吗?”这两者代表的是两种独立的职业身份,职责和资格要求有明显的差异,是两种完全不同的考试。 造价工程师是一种具有专业资质的人员,通过国家统一的执业…

『 Linux 』命名管道

文章目录 命名管道与匿名管道命名管道特点命名管道的理解命名管道实现两个毫无关联的进程间通信 命名管道与匿名管道 命名管道是管道的一种,数据流向为单向故被称为管道; 与匿名管道相同属于一种内存级文件; 区别如下: 名字 匿名管道 没有名字,只存在于内存当中(类似内核缓冲…

【软件测试】编写测试用例篇

前面部分主要是编写测试用例的方法和方向,后面一部分是编写出具体的测试用例 目录 什么是测试用例 1.设计测试用例的万能公式 1.1.从思维出发 1.2.万能公式 1.3.弱网测试 1.4.安装与卸载测试 2.设计测试用例的方法 2.1.基于需求的设计方法 2.2.等价类 2.3…