孪生神经网络MATLAB实战[含源码]

​一、算法原理

      孪生神经网络( Siamese neural network)是一种深度学习网络,它使用两个或多个具有相同架构、共享相同参数和权重的相同子网。孪生网络通常用于寻找两个可比较事物之间的关系的任务。孪生网络的一些常见应用包括面部识别、签名验证或释义识别。孪生神经网络是基于两个人工神经网络建立的耦合架构,孪生神经网络以两个样本为输入,输出其嵌入高维空间的表征,以比较两个样本的相似程度,狭义的孪生神经网络由两个结构相同,且权重共享的神经网络拼接而成,网络框架如下图所示。

        广义的孪生神经网络(又称pseudo-siamese network,伪孪生神经网络),可由两个任意的神经网络拼接而成,可由卷积神经网络、循环神经网络等组成,网络框架如下图所示。

    简单来说,孪生神经网络就是衡量两个输入的相似程度。孪生神经网络有两个输入,将两个输入输入到两个神经网络,这两个神经网络分别将输入映射到新的空间,形成输入在新的空间中的表示。通过Loss的计算,评价两个输入的相似度。孪生神经网络在这些任务中表现良好,因为它们的共享权重意味着在训练过程中需要学习的参数更少,并且它们可以用相对较少的训练数据产生良好的结果。

二、代码实战

%% matlab学习之家
%% 孪生神经网络
clc
clear
%% 读取训练集
dataFolderTrain = "D:\S\孪生神经网络\images_background"; %% 更换路径
imdsTrain = imageDatastore(dataFolderTrain, ...IncludeSubfolders=true, ...LabelSource="none");
​
files = imdsTrain.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"-");
imdsTrain.Labels = categorical(labels);
​
%% 显示图片
idx = randperm(numel(imdsTrain.Files),8);
​
for i = 1:numel(idx)subplot(4,2,i)imshow(readimage(imdsTrain,idx(i)))title(imdsTrain.Labels(idx(i)),Interpreter="none");
end
​
batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getTwinBatch(imdsTrain,batchSize);
​
for i = 1:batchSizeif pairLabel(i) == 1s = "similar";elses = "dissimilar";endsubplot(2,5,i)imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);title(s)
end
%% 定义孪生神经网络架构
layers = [imageInputLayer([105 105 1],Normalization="none")convolution2dLayer(10,64,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")reluLayermaxPooling2dLayer(2,Stride=2)convolution2dLayer(7,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")reluLayermaxPooling2dLayer(2,Stride=2)convolution2dLayer(4,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")reluLayermaxPooling2dLayer(2,Stride=2)convolution2dLayer(5,256,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")reluLayerfullyConnectedLayer(4096,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")];
​
net = dlnetwork(layers);
​
fcWeights = dlarray(0.01*randn(1,4096));
fcBias = dlarray(0.01*randn(1,1));
​
fcParams = struct(..."FcWeights",fcWeights,..."FcBias",fcBias);
%% 定义损失函数
numIterations = 10000;
miniBatchSize = 180;
learningRate = 6e-5;
gradDecay = 0.9;
gradDecaySq = 0.99;
executionEnvironment = "auto";
​
if canUseGPUgpu = gpuDevice;disp(gpu.Name + " GPU detected and available for training.")
end
​
trailingAvgSubnet = [];
trailingAvgSqSubnet = [];
trailingAvgParams = [];
trailingAvgSqParams = [];
​
monitor = trainingProgressMonitor(Metrics="Loss",XLabel="Iteration",Info="ExecutionEnvironment");
if canUseGPUupdateInfo(monitor,ExecutionEnvironment=gpu.Name + " GPU")
elseupdateInfo(monitor,ExecutionEnvironment="CPU")
end
​
start = tic;
iteration = 0;
%% 使用自定义训练循环训练模型。循环遍历训练数据并在每次迭代时更新网络参数。
while iteration < numIterations && ~monitor.Stop
​iteration = iteration + 1;
​[X1,X2,pairLabels] = getTwinBatch(imdsTrain,miniBatchSize);
​X1 = dlarray(X1,"SSCB");X2 = dlarray(X2,"SSCB");
​if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"X1 = gpuArray(X1);X2 = gpuArray(X2);end
​% 使用dlfeval和modelLoss评估模型损失和梯度[loss,gradientsSubnet,gradientsParams] = dlfeval(@modelLoss,net,fcParams,X1,X2,pairLabels);
​% 更新孪生神经网络参数[net,trailingAvgSubnet,trailingAvgSqSubnet] = adamupdate(net,gradientsSubnet, ...trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
​% 更新全连接层参数.[fcParams,trailingAvgParams,trailingAvgSqParams] = adamupdate(fcParams,gradientsParams, ...trailingAvgParams,trailingAvgSqParams,iteration,learningRate,gradDecay,gradDecaySq);
​recordMetrics(monitor,iteration,Loss=loss);monitor.Progress = 100 * iteration/numIterations;
​
end
%% 读取测试集
dataFolderTest = "D:\S\孪生神经网络\images_evaluation" %% 更换路径
imdsTest = imageDatastore(dataFolderTest, ...IncludeSubfolders=true, ...LabelSource="none");
​
files = imdsTest.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"_");
imdsTest.Labels = categorical(labels);
​
numClasses = numel(unique(imdsTest.Labels));
accuracy = zeros(1,5);
accuracyBatchSize = 150;
​
for i = 1:5[X1,X2,pairLabelsAcc] = getTwinBatch(imdsTest,accuracyBatchSize);
​X1 = dlarray(X1,"SSCB");X2 = dlarray(X2,"SSCB");
​if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"X1 = gpuArray(X1);X2 = gpuArray(X2);end
​Y = predictTwin(net,fcParams,X1,X2);
​Y = gather(extractdata(Y));Y = round(Y);
​
​accuracy(i) = sum(Y == pairLabelsAcc)/accuracyBatchSize;
end
​
averageAccuracy = mean(accuracy)*100;
​
testBatchSize = 10;
​
[XTest1,XTest2,pairLabelsTest] = getTwinBatch(imdsTest,testBatchSize);
​
XTest1 = dlarray(XTest1,"SSCB");
XTest2 = dlarray(XTest2,"SSCB");
​
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"XTest1 = gpuArray(XTest1);XTest2 = gpuArray(XTest2);
end
YScore = predictTwin(net,fcParams,XTest1,XTest2);
YScore = gather(extractdata(YScore));
YPred = round(YScore);
XTest1 = extractdata(XTest1);
XTest2 = extractdata(XTest2);
f = figure;
tiledlayout(2,5);
f.Position(3) = 2*f.Position(3);
​
predLabels = categorical(YPred,[0 1],["dissimilar" "similar"]);
targetLabels = categorical(pairLabelsTest,[0 1],["dissimilar","similar"]);
%% 用预测的标签和预测的分数绘制图像
for i = 1:numel(pairLabelsTest)nexttileimshow([XTest1(:,:,:,i) XTest2(:,:,:,i)]);
​title( ..."Target: " + string(targetLabels(i)) + newline + ..."Predicted: " + string(predLabels(i)) + newline + ..."Score: " + YScore(i))
end

仿真结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/606102.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

node.js+mysql旅游景点分享网站-计算机毕业设计源码03796

摘 要 随着社会的发展&#xff0c;社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。旅游景点分享网站设计&#xff0c;主要的模块包括查看后台首页、轮播图&#xff08;轮播图管理&#xff09;、网站公告管理&#xff08;网站公告…

AQS 抽象队列同步器

AQS AQS &#xff08;抽象队列同步器&#xff09;&#xff1a; AbstractQueuedSynchronizer 是什么 来自jdk1.5&#xff0c;是用来实现锁或者其他同步器组件的公共基础部分的抽象实现&#xff0c;是重量级基础框架以及JUC的基石&#xff0c;主要用于解决锁分配给谁的问题整体…

Linux第17步_安装SSH服务

secure shell protocol简称SSH。 目的&#xff1a;在进行数据传输之前&#xff0c;SSH先对联级数据包通过加密技术进行加密处理&#xff0c;然后再进行数据传输&#xff0c;确保数据传输安全。 1、在安装前&#xff0c;要检查虚拟机可以上网&#xff0c;否则可能会导致安装失…

电商带货品牌直播间SOP运营执行步骤

【干货资料持续更新&#xff0c;以防走丢】 电商带货品牌直播间SOP运营执行步骤 部分资料预览 资料部分是网络整理&#xff0c;仅供学习参考。 直播运营模板合集&#xff08;完整资料包含以下内容&#xff09; 目录 直播业务商业框架.png直播工作流程SOP梳理.xlsx 2023年抖…

HubSpot的内容管理系统(CMS)好用吗?

HubSpot的内容管理系统&#xff08;CMS&#xff09;通常被认为是功能强大且用户友好的工具&#xff0c;尤其适用于数字营销和在线业务。以下是一些HubSpot CMS的优势和功能&#xff1a; 用户友好的编辑界面&#xff1a; HubSpot CMS提供直观的编辑界面&#xff0c;具有拖放式编…

数字藏品如何赋能线下实体?以 BOOMSHAKE 潮流夜店为例

此篇为报告内容精华版&#xff0c;更多详细精彩内容请点击 完整版 在数字化浪潮的推动下&#xff0c;品牌和企业正在迎来一场前所未有的变革。传统市场营销策略逐渐让位于新兴技术&#xff0c;特别是非同质化代币&#xff08;NFT&#xff09;的应用。这些技术不仅改变了品牌资…

scala 安装和创建项目

Scala&#xff0c;一种可随您扩展的编程语言&#xff1a;从小型脚本到大型多平台应用程序。Scala不是Java的扩展&#xff0c;但它完全可以与Java互操作。在编译时&#xff0c;Scala文件将转换为Java字节码并在JVM&#xff08;Java虚拟机&#xff09;上运行。Scala被设计成面向对…

【JAVA】Iterator 怎么使用?有什么特点

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; JAVA ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 Iterator 接口的主要方法&#xff1a; 例子 特点&#xff1a; 结语 我的其他博客 前言 在编程的世界里&#xff0c;迭代…

DTM分布式事务

DTM分布式事务 从内网看到了关于事务在业务中的讨论&#xff0c;评论区大佬有提及DTM开源项目[https://dtm.pub/]&#xff0c;开学开学 基础理论 一、Why DTM ​ 项目产生于实际生产中的问题&#xff0c;涉及订单支付的服务会将所有业务相关逻辑放到一个大的本地事务&#xff…

卷积神经网络|迁移学习-猫狗分类完整代码实现

还记得这篇文章吗&#xff1f;迁移学习|代码实现 在这篇文章中&#xff0c;我们知道了在构建模型时&#xff0c;可以借助一些非常有名的模型&#xff0c;这些模型在ImageNet数据集上早已经得到了检验。 同时torchvision模块也提供了预训练好的模型。我们只需稍作修改&#xf…

qtday1(2024/1/8)

#include "mywidget.h"MyWidget::MyWidget(QWidget *parent): QMainWindow(parent) {//设置界面固定大小this->resize(1728,972);this->setFixedSize(1728,972);this->setWindowIcon(QIcon("C:\\Users\\78507\\Desktop\\pic\\qq1.png"));this->…

高级RAG(五):TruLens 评估-扩大和加速LLM应用程序评估

之前我们介绍了&#xff0c;RAGAs评估&#xff0c;今天我们再来介绍另外一款RAG的评估工具:TruLens , trulens是TruEra公司的一款开源软件工具&#xff0c;它可帮助您使用反馈功函数客观地评估基于 LLM 的应用程序的质量和有效性。反馈函数有助于以编程方式评估输入、输出和中间…

vue3 内置组件

文章目录 前言一、过渡效果相关的组件1、Transition2、TransitionGroup 二、状态缓存组件&#xff08;KeepAlive&#xff09;三、传送组件&#xff08;Teleport &#xff09;四、异步依赖处理组件&#xff08;Suspense&#xff09; 前言 在vue3中 其提供了5个内置组件 Transiti…

antv/x6_2.0学习使用(四、边)

一、添加边 节点和边都有共同的基类 Cell&#xff0c;除了从 Cell 继承属性外&#xff0c;还支持以下选项。 属性名类型默认值描述sourceTerminalData-源节点或起始点targetTerminalData-目标节点或目标点verticesPoint.PointLike[]-路径点routerRouterData-路由connectorCon…

猫咪吃哪种猫粮好?主食冻干猫粮哪种性价比高

由于猫咪是肉食动物&#xff0c;对蛋白质的需求很高&#xff0c;如果摄入的蛋白质不足&#xff0c;就会影响猫咪的成长。而冻干猫粮本身因为制作工艺的原因&#xff0c;能保留原有的营养成分和营养元素&#xff0c;所以冻干猫粮蛋白含量比较高&#xff0c;营养又高&#xff0c;…

第二十七周:文献阅读笔记

第二十七周&#xff1a;文献阅读笔记 摘要AbstractDenseNet 网络1. 文献摘要2. 引言3. ResNets4. Dense Block5. Pooling layers6. Implementation Details7. Experiments8. Feature Reuse9. 代码实现 总结 摘要 DenseNet&#xff08;密集连接网络&#xff09;是一种深度学习神…

工智能基础知识总结--词嵌入之FastText

什么是FastText FastText是Facebook于2016年开源的一个词向量计算和文本分类工具,它提出了子词嵌入的方法,试图在词嵌入向量中引入构词信息。一般情况下,使用fastText进行文本分类的同时也会产生词的embedding,即embedding是fastText分类的产物。 FastText流程 FastText的架…

计算机组成原理简答题

目录 1、指令和数据在计算机内部以几进制存储&#xff0c;又是如何区分的呢&#xff1f; 2、计算机内部为什么要使用二进制&#xff1f; 3、简单描述计算机系统的层次结构 4、DRAM为什么要进行刷新&#xff0c;如何刷新的&#xff1f; 5、简述不同操作码的指令格式&#xf…

FileStream文件管理

文件管理 FileStream&#xff1a;是一个用于读写文件的一个类。它提供了基于流的方式操作文件&#xff0c;可以进行读取、写入、查找和关闭等操作。 第一个参数&#xff1a;path&#xff08;路径&#xff09; 相对路径&#xff1a;相对于当前项目的bin目录下的Debug和Realse来…

[嵌入式AI从0开始到入土]10_yolov5在昇腾上应用

[嵌入式AI从0开始到入土]嵌入式AI系列教程 注&#xff1a;等我摸完鱼再把链接补上 可以关注我的B站号工具人呵呵的个人空间&#xff0c;后期会考虑出视频教程&#xff0c;务必催更&#xff0c;以防我变身鸽王。 第一章 昇腾Altas 200 DK上手 第二章 下载昇腾案例并运行 第三章…