Matlab使用深度网络设计器为迁移学习准备网络

迁移学习通过对预训练网络进行微调,使深度学习模型能在少量数据下快速适应新任务,类似于“举一反三”,而不需要从头训练。本文使用matlab自带的深度网络设计器,可以便捷地修改预训练网络进行迁移学习,通过对预训练网络的最后一层进行解锁,调高学习率,使新层中的学习速度快于迁移层的学习速度,即可输入新数据,完成迁移学习的训练。

在实际工程应用中,构建并训练一个大规模的卷积神经网络是比较复杂的,需要大量的数据以及高性能的硬件。迁移学习则是“另辟蹊径”,将训练好的典型的网络稍加改进,用少量的数据进行训练并加以应用,将一个预先训练好的网络作为学习新任务的起点。利用迁移学习对网络进行微调,通常比从头开始训练一个随机初始化权重的网络要快得多、容易得多。使用较少数量的训练图像,就能将学习到的特征快速迁移到新任务中。

1.使用的数据

本文使用的数据集为 MathWorks Merch 数据集,这是包含 75 幅 MathWorks 商品图像的小型数据集,这些商品分属五个不同类(瓶盖、魔方、扑克牌、螺丝刀和手电筒)。数据的排列使得图像位于对应于这五个类的子文件夹中。

folderName = "MerchData";
unzip("MerchData.zip",folderName);

创建一个图像数据存储(Datastore),通过图像数据存储可以存储大图像数据集合,包括无法放入内存的数据,并在神经网络的训练过程中高效分批读取图像。对图像文件夹建立子文件夹,使用imageDatastore,可以将子文件夹名称转换为图像标签。

imds = imageDatastore(folderName, ...IncludeSubfolders=true, ...LabelSource="foldernames");

显示一些示例图像:

numImages = numel(imds.Labels);
idx = randperm(numImages,16);
I = imtile(imds,Frames=idx);
figure
imshow(I)

2.提取数据

提取类名称和类数目:

classNames = categories(imds.Labels);
numClasses = numel(classNames)
numClasses = 5

将数据划分为训练、验证和测试数据集。将 70% 的图像用于训练,15% 的图像用于验证,15% 的图像用于测试。splitEachLabel 函数将图像数据存储拆分为两个新数据存储。

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.7,0.15,0.15,"randomized");

3.选择预训练网络

3.1 深度网络设计器

从命令行打开深度网络设计器App:

deepNetworkDesigner

深度网络设计器提供一些精选的预训练图像分类网络,这些网络已学习适用于各种图像的丰富特征表示。如果您的图像与最初用于训练网络的图像相似,迁移学习效果最好。如果您的训练图像是像 ImageNet 数据库中那样的自然图像,则任一预训练网络都合适。了解有哪些预训练深度神经网络,请参阅预训练的深度神经网络。

如果数据与 ImageNet 数据相差很大(例如,如果您有很小的图像、频谱图或非图像数据),训练新网络可能效果更好。选择SqueezeNet网络:

3.2 浏览网络

深度网络设计器在设计器窗格中显示整个网络的缩小视图。

选择图像输入层 'input',可以看到该网络的输入大小为 227×227×3 像素。 

将输入大小保存在变量 inputSize 中:

inputSize = [227 227 3];

4.准备要训练的网络

要使用预训练网络进行迁移学习,更改类的数量以匹配新数据集。首先,找到网络中的最后一个可学习层。对于 SqueezeNet,最后一个可学习层是最后一个卷积层,'conv10'。选择 'conv10' 层。在属性窗格的底部,点击解锁层。在出现的警告对话框中,点击仍要解锁,这将解锁层属性,以便使其适应新任务。

在 R2023b 之前:要使网络适应新数据,替换层而不是解锁层。在新的卷积二维层中,将 FilterSize 设置为 [1 1]。

NumFilters 属性定义用于分类问题的类的数量,将 NumFilters 更改为新数据中的类数量,此示例中为 5。

通过将 WeightLearnRateFactor 和 BiasLearnRateFactor 设置为 10 来更改学习率,使新层中的学习速度快于迁移层的学习速度。

要检查网络是否准备好进行训练,点击分析按钮,深度学习网络分析器报告零错误或警告,说明网络已准备就绪,可以开始进行训练。

要导出网络,请点击导出,该 App 将网络保存在变量 net_1 中。

数据存储中图像的大小可以不同,要自动调整训练图像的大小,请使用增强的图像数据存储。数据增强还有助于防止网络过拟合和记忆训练图像的具体细节。指定要对训练图像额外执行的这些增强操作:沿垂直轴随机翻转训练图像,以及在水平和垂直方向上随机平移训练图像最多 30 个像素。

pixelRange = [-30 30];imageAugmenter = imageDataAugmenter( ...RandXReflection=true, ...RandXTranslation=pixelRange, ...RandYTranslation=pixelRange);augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...DataAugmentation=imageAugmenter);

对于验证集和测试集,只调整图像大小,而不进行其他预处理操作。

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

5.指定训练选项

指定训练选项,在选项中进行选择需要经验分析。

  • 使用 Adam 优化器进行训练。

  • 将初始学习率设置为较小的值以减慢迁移的层中的学习速度。

  • 指定少量轮数。一轮训练是对整个训练数据集的一个完整训练周期。对于迁移学习,所需的训练轮数相对较少。

  • 指定验证数据和验证频率,以便每经过一轮训练就计算一次基于验证数据的准确度。

  • 指定mini batch size,即每次迭代中使用多少个图像。为了确保在每轮训练中都使用整个数据集,请设置mini batch size以均分训练样本的数量。

  • 在图中显示训练进度并监控准确度度量。

  • 禁用详尽输出。

options = trainingOptions("adam", ...InitialLearnRate=0.0001, ...MaxEpochs=8, ...ValidationData=imdsValidation, ...ValidationFrequency=5, ...MiniBatchSize=11, ...Plots="training-progress", ...Metrics="accuracy", ...Verbose=false);

6.训练和测试网络

使用 trainnet 函数训练神经网络,对于分类任务,损失函数使用交叉熵损失。

net = trainnet(imdsTrain,net_1,"crossentropy",options);

对测试图像进行分类,要使用多个观测值进行预测,使用 minibatchpredict 函数;要将预测分数转换为标签,使用 scores2label 函数。

YTest = minibatchpredict(net,augimdsTest);
YTest = scores2label(YTest,classNames);

在混淆图中可视化分类准确度:

TTest = imdsTest.Labels;
figure
confusionchart(TTest,YTest);

图片

7.使用新数据进行预测

对一个图像进行分类,从 JPEG 文件中读取一个图像,调整其大小,并将其转换为单精度数据类型。

im = imread("MerchDataTest.jpg");
im = imresize(im,inputSize(1:2));
X = single(im);

对图像进行分类,要使用单个观测值进行预测,使用 predict 函数;要使用 GPU,先将数据转换为 gpuArray。

if canUseGPUX = gpuArray(X);
end
scores = predict(net,X);
[label,score] = scores2label(scores,classNames);

显示具有预测标签和对应分数的图像:

figure
imshow(im)
title(string(label) + " (Score: " + gather(score) + ")")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/60200.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于yolov8、yolov5的电塔缺陷检测识别系统(含UI界面、训练好的模型、Python代码、数据集)

摘要:电塔缺陷检测在电力设备巡检、运行维护和故障预防中起着至关重要的作用,不仅能帮助相关部门实时监测电塔运行状态,还为智能化检测系统提供了可靠的数据支撑。本文介绍了一款基于YOLOv8、YOLOv5等深度学习框架的电塔缺陷检测模型&#xf…

代理商培训新策略:内部知识库的高效运用

在竞争激烈的市场环境中,代理商作为企业与终端消费者之间的关键纽带,其专业能力和服务质量直接影响着企业的市场表现和品牌形象。因此,如何对代理商进行高效、系统的培训,以提升其业务能力和服务水平,成为企业亟需解决…

uniapp 相关的swiper的一些注意事项

先推荐一个一个对标pc端swiper的uniapp版本 zebra-swiper 缺点是自定义分页器不是很好处理 不知道怎么弄 优点:可以进行高度自适应 &#xff08;这个uniapp原生swiper没有 只能动态修改 采用js 或者只有几种固定高度时采用变量修改&#xff09; <swiperref"lifeMiddle…

ARM中ZI-data段和RW-data段

ARM中ZI-data段和RW-data段 1、只定义全局变量&#xff0c;不使用&#xff0c;不占用内存空间2、 定义并初始化全局变量为0 占用ZI-Data区域3、定义并初始化全局变量非0 占用RW-Data区域4、增加的是一个int8的数据为什么&#xff0c;size增加不是15、定义的全局变量为0&#xf…

jmeter--CSV数据文件设置--请求体设置变量

目录 一、示例 1、准备组织列表的TXT文件&#xff0c;如下&#xff1a; 2、添加 CSV数据文件设置 &#xff0c;如下&#xff1a; 3、接口请求体设置变量&#xff0c;如下&#xff1a; 二、CSV数据文件设置 1、CSV Data Set Config 配置选项说明 2、示例 CSV 文件内容 3、…

golang实现TCP服务器与客户端的断线自动重连功能

1.服务端 2.客户端 生成服务端口程序: 生成客户端程序: 测试断线重连: 初始连接成功

ssm148基于Spring MVC框架的在线电影评价系统设计与实现+jsp(论文+源码)_kaic

毕 业 设 计&#xff08;论 文&#xff09; 题目&#xff1a;在线电影评价系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本在线电影评价系…

DAY1 网络编程(TCP客户端服务器)

作业&#xff1a; TCP客户端服务器。 server服务器代码&#xff1a; #include <myhead.h> #define IP "192.168.110.52" #define PORT 8886 #define BACKLOG 20 int main(int argc, const char *argv[]) {int oldfdsocket(AF_INET,SOCK_STREAM,0);//IPV4通信…

基于arduino 用ESP8266获取实时MAX30102 血氧数据动态曲线显示在网页上

基于arduino 用ESP8266获取实时MAX30102 血氧数据动态曲线显示在网页上 原理&#xff1a; ESP8266获取MAX30102 血氧数据&#xff08;R,IR,G的值&#xff09;发送到路由器局域网内&#xff0c;局域网内的手机电脑&#xff0c;访问ESP的ip地址&#xff0c;获取实时的血氧数据动…

vue3:scss引用

原文查看&#xff1a;https://mp.weixin.qq.com/s?__bizMzg3NTAzMzAxNA&mid2247484356&idx2&sn44b127cd394e217b9e3c4eccafdc0aa9&chksmcec6fb1df9b1720b7bd0ca0b321bf8a995fc8cba233deb703512560cbe451cfb1f05cdf129f6&token1776233257&langzh_CN#rd…

SrpingBoot基础

SpringBoot基本框架中重要常用的包讲解: .idea包和.mvn包框架生成不经常用 src包下主要存放前后端代码: main包下的java包存放的是后端java代码主要负责数据处理 resource包下存放的是配置资源和前端页面,其中static中存放的是前端html网页一般存放静 态资源,templates包…

Nacos实现IP动态黑白名单过滤

一些恶意用户&#xff08;可能是黑客、爬虫、DDoS 攻击者&#xff09;可能频繁请求服务器资源&#xff0c;导致资源占用过高。因此我们需要一定的手段实时阻止可疑或恶意的用户&#xff0c;减少攻击风险。 本次练习使用到的是Nacos配合布隆过滤器实现动态IP黑白名单过滤 文章…

vue-next-admin框架配置(vue)

vue-next-admin 先安装依赖 npm i 依赖, npm run dev 运行 1.配置代理 2.把他的逻辑和自己的登录判断逻辑结合(我的放下面&#xff0c;可以参考哦&#xff0c;可以直接使用&#xff0c;到时候修改登录逻辑就好)&#xff0c;别忘了引入ajxio哦 const onSignIn async () &g…

算法定制LiteAIServer视频智能分析平台工业排污检测算法智控环保监管

随着工业化进程的加快&#xff0c;环境污染问题愈加严重&#xff0c;尤其是工业排污对生态环境的影响引发了广泛关注。在此背景下&#xff0c;视频智能分析平台LiteAIServer工业排污检测算法应运而生&#xff0c;作为一种先进的智能化解决方案&#xff0c;它在监测和管理工业排…

mini-lsm通关笔记Week2Day5

项目地址&#xff1a;https://github.com/skyzh/mini-lsm 个人实现地址&#xff1a;https://gitee.com/cnyuyang/mini-lsm Summary 在本章中&#xff0c;您将&#xff1a; 实现manifest文件的编解码。系统重启时从manifest文件中恢复。 要将测试用例复制到启动器代码中并运行…

【WPF】Prism学习(六)

Prism Dependency Injection 1.依赖注入&#xff08;Dependency Injection&#xff09; 1.1. Prism与依赖注入的关系&#xff1a; Prism框架一直围绕依赖注入构建&#xff0c;这有助于构建可维护和可测试的应用程序&#xff0c;并减少或消除对静态和循环引用的依赖。 1.2. P…

学习ASP.NET Core的身份认证(基于Cookie的身份认证1)

B/S架构程序可通过Cookie、Session、JWT、证书等多种方式认证用户身份&#xff0c;虽然之前测试过用户登录代码&#xff0c;也学习过开源项目中的登录认证&#xff0c;但其实还是对身份认证疑惑甚多&#xff0c;就比如登录验证后用户信息如何保存、客户端下次连接时如何获取用户…

使用Cursor和Claude AI打造你的第一个App

大家好&#xff0c;使用Cursor和Claude AI打造应用程序是一个结合智能代码辅助和人工智能对话的创新过程。Cursor是一个编程辅助工具&#xff0c;它通过智能代码补全、聊天式AI对话和代码生成等功能&#xff0c;帮助开发者提高编程效率。Claude AI则是一个强大的人工智能平台&a…

ssm152家庭财务管理系统设计与实现+jsp(论文+源码)_kaic

毕 业 设 计&#xff08;论 文&#xff09; 题目&#xff1a;家庭财务管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本家庭财务管理系…