MATLAB环境下生成对抗网络系列(11种)

为了构建有效的图像深度学习模型,数据增强是一个非常行之有效的方法。图像的数据增强是一套使用有限数据来提高训练数据集质量和规模的数据空间解决方案。广义的图像数据增强算法包括:几何变换、颜色空间增强、核滤波器、混合图像、随机擦除、特征空间增强、对抗训练、生成对抗网络和风格迁移等内容。增强的数据代表一个分布覆盖性更广、可靠性更高的数据点集,使用增强数据能够有效增加训练样本的多样性,最小化训练集和验证集以及测试集之间的距离。使用数据增强后的数据集训练模型,可以达到提升模型稳定性、泛化能力的效果。

使用生成对抗网络GAN提取原数据集特征,对抗生成新的目标域图像,已成为众多学者在数据增强技术研究中的优选方法。相比于传统的图像数据增强方法,通过基于GAN的生成式建模技术进行数据增强的思路来源于博弈论中的二人零和博弈,由网络中包含的生成器和判别器利用对抗学习的方法来指导网络训练,在两个网络对抗过程中估计原始数据样本的分布并生成与之相似的新数据。

近期的研究在原始生成对抗网络框架的基础上又提出了多种不同的改进方案,通过设计不同的神经网络架构和损失函数等手段不断提升生成对抗网络的性能。生成对抗网络应用在图像数据增强任务上的思想主要是其通过生成新的训练数据来扩充模型的训练数据,通过样本空间的扩充实现图像分类等任务效果的提升。但目前基于GAN的图像数据增强技术普遍存在模型收敛不稳定、生成图像质量低等问题,如何正确引入高频信息,提升图像数据质量是破解这一系列问题的关键。

MATLAB环境配置如下:

  • MATLAB 2021b
  • Deep Learning Toolbox
  • Parallel Computing Toolbox (optional for GPU usage)

目录如下

  • Generative Adversarial Network (GAN) [paper]
  • Least Squares Generative Adversarial Network (LSGAN) [paper]
  • Deep Convolutional Generative Adversarial Network (DCGAN) [paper]
  • Conditional Generative Adversarial Network (CGAN)[paper]
  • Auxiliary Classifier Generative Adversarial Network (ACGAN) [paper]
  • InfoGAN [paper]
  • Adversarial AutoEncoder (AAE)[paper]
  • Pix2Pix[paper]
  • Wasserstein Generative Adversarial Network (WGAN) [paper]
  • Semi-Supervised Generative Adversarial Network (SGAN) [paper]
  • CycleGAN [paper]
  • DiscoGAN [paper]

部分代码如下:

首先,导入相关的mnist手写数字图

load('mnistAll.mat')

然后对训练、测试图像进行预处理

trainX = preprocess(mnist.train_images); 
trainY = mnist.train_labels;%训练标签
testX = preprocess(mnist.test_images); 
testY = mnist.test_labels;%测试标签

preprocess为归一化函数,如下

function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end

然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等

settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1]; 
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;

下面进行编码器初始化,代码还是很容易看懂的

paramsEn.FCW1 = dlarray(initializeGaussian([512,...prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));

解码器初始化

paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));

判别器初始化

paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));%平均梯度和平均梯度平方数组
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];

开始训练

dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~outtic; shuffleid = randperm(size(trainX,2));trainXshuffle = trainX(:,shuffleid);fprintf('Epoch %d\n',epoch) for i=1:numIterationsglobal_iter = global_iter+1;idx = (i-1)*settings.batch_size+1:i*settings.batch_size;XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');[GradEn,GradDe,GradDis] = ...dlfeval(@modelGradients,XBatch,...paramsEn,paramsDe,paramsDis,settings);% 更新判别器网络参数[paramsDis,avgG.Dis,avgGS.Dis] = ...adamupdate(paramsDis, GradDis, ...avgG.Dis, avgGS.Dis, global_iter, ...settings.lrD, settings.beta1, settings.beta2);% 更新编码器网络参数[paramsEn,avgG.En,avgGS.En] = ...adamupdate(paramsEn, GradEn, ...avgG.En, avgGS.En, global_iter, ...settings.lrG, settings.beta1, settings.beta2);% 更新解码器网络参数[paramsDe,avgG.De,avgGS.De] = ...adamupdate(paramsDe, GradDe, ...avgG.De, avgGS.De, global_iter, ...settings.lrG, settings.beta1, settings.beta2);if i==1 || rem(i,20)==0progressplot(paramsDe,settings);if i==1 h = gcf;% 捕获图像frame = getframe(h); im = frame2im(frame); [imind,cm] = rgb2ind(im,256); % 写入 GIF 文件if epoch == 0imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf); else imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append'); end endendendelapsedTime = toc;disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")epoch = epoch+1;if epoch == settings.maxepochsout = true;end    
end

下面是完整的辅助函数

模型的梯度计算函数

function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...dly(settings.latent_dim+1:2*settings.latent_dim)*...randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');%训练判别器
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));%训练编码器和解码器
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));%对于每个网络,计算关于损失函数的梯度
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end

提取数据函数

function x = gatext(x)
x = gather(extractdata(x));
end

GPU深度学习数组wrapper函数

function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end

权重初始化函数

function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end

dropout函数

function dly = dropout(dlx,p)
if nargin < 2p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end

编码器函数

function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end

解码器函数

function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end

判别器函数

function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任
《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/681562.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

四、案例 - Oracle数据迁移至MySQL

Oracle数据迁移至MySQL 一、生成测试数据表和数据1.在Oracle创建数据表和数据2.在MySQL创建数据表 二、生成模板文件1.模板文件内容2.模板文件参数详解2.1 全局设置2.2 数据读取&#xff08;Reader&#xff09;2.3 数据写入&#xff08;Writer&#xff09;2.4 性能设置 三、案例…

每日一题(最大连续1的个数,完全数计算)

485. 最大连续 1 的个数 - 力扣&#xff08;LeetCode&#xff09; #include <stdio.h> int findMaxConsecutiveOnes(int* nums, int numsSize) { if (numsSize 0) return 0; // 如果数组为空&#xff0c;返回0 int maxCount 0; // 最大连续1的个数 int currentCo…

使用耳机壳UV树脂制作私模定制耳塞的大小和形状对音质有影响吗?

使用耳机壳UV树脂制作私模定制耳塞的大小和形状对音质有影响。私模定制耳塞是根据用户的耳型定制的&#xff0c;因此其大小和形状与用户的耳朵形状相匹配&#xff0c;能够减少漏音和外部噪音的干扰&#xff0c;提供更好的音质体验。 具体来说&#xff0c;私模定制耳塞的大小和形…

分享96个jQuery特效,总有一款适合您

分享96个jQuery特效&#xff0c;总有一款适合您 96个jQuery特效下载链接&#xff1a;https://pan.baidu.com/s/1Pibj41ibHKTmdW7FHfRLjg?pwd8888 提取码&#xff1a;8888 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集整理…

React+Antd实现表格自动向上滚动

1、效果 2、环境 1、react18 2、antd 4 3、代码实现 原理&#xff1a;创建一个定时器&#xff0c;修改表格ant-table-body的scrollTop属性实现滚动&#xff0c;监听表层的元素div的鼠标移入和移出实现实现鼠标进入元素滚动暂停&#xff0c;移出元素的时候表格滚动继续。 一…

【Godot4自学手册】第十三节初建创建敌人

从本节起&#xff0c;将要学习创建第一人。 一、创建敌人动画 1.导入素材。 在Sprites文件夹下新建Enemy文件夹&#xff0c;并将需要的敌人素材导入到文件夹。文档结构如下&#xff1a; 2.创建Enemy场景。 新建场景&#xff0c;根节点设置为CharacterBody2D&#xff0c;命…

最新wordpress外贸主题

日用百货wordpress外贸主题 蓝色大气的wordpress外贸主题&#xff0c;适合做日用百货的外贸公司搭建跨境电商网站使用。 https://www.jianzhanpress.com/?p5248 添加剂wordpress外贸建站主题 橙色wordpress外贸建站主题&#xff0c;适合做食品添加剂或化工添加剂的外贸公司…

使用MICE进行缺失值的填充处理

在我们进行机器学习时&#xff0c;处理缺失数据是非常重要的&#xff0c;因为缺失数据可能会导致分析结果不准确&#xff0c;严重时甚至可能产生偏差。处理缺失数据是保证数据分析准确性和可靠性的重要步骤&#xff0c;有助于确保分析结果的可信度和可解释性。 在本文中&#…

家政小程序系统源码开发:引领智能生活新篇章

随着科技的飞速发展&#xff0c;小程序作为一种便捷的应用形态&#xff0c;已经深入到我们生活的方方面面。尤其在家庭服务领域&#xff0c;家政小程序的出现为人们带来了前所未有的便利。它不仅简化了家政服务的流程&#xff0c;提升了服务质量&#xff0c;还为家政服务行业注…

工程问题与学术研究的融合 —— 校企合作

一、工程问题与学术研究的常规融合方法 工程问题与学术研究的融合通常体现在“产学研结合”的模式中&#xff0c;具体策略如下&#xff1a; 1. 需求导向&#xff1a;从实际工程问题出发&#xff0c;明确科研目标。在解决工程问题的过程中&#xff0c;识别出需要进一步研究的基…

Vue.js2+Cesium1.103.0 十五、计算方位角

Vue.js2Cesium1.103.0 十五、计算方位角 Demo <template><divid"cesium-container"style"width: 100%; height: 100%;"/> </template><script> /* eslint-disable no-undef */ /* eslint-disable new-cap */ /* eslint-disable n…

代码随想录算法训练营第三十天 | 重新安排行程、N皇后、解数独

目录 重新安排行程N皇后解数独总结 LeetCode 332.重新安排行程 LeetCode 51. N皇后 LeetCode 37. 解数独 重新安排行程 给定一个机票的字符串二维数组 [from, to]&#xff0c;子数组中的两个成员分别表示飞机出发和降落的机场地点&#xff0c;对该行程进行重新规划排序。所有…

今日早报 每日精选15条新闻简报 每天一分钟 知晓天下事 2月14日,星期三

每天一分钟&#xff0c;知晓天下事&#xff01; 2024年2月14日 星期三 农历正月初五 1、 第十四届全国冬季运动会将于17日开幕&#xff0c;部分赛事今天起陆续开赛。 2、 2024年购房政策将进一步宽松&#xff0c;专家称今年买房性价比更高。 3、 春节档票房突破45亿元&#…

docker 3.1 镜像

docker 3.1 镜像命令 拉取镜像 docker pull debian #从 Docker Hub 拉取名为 debian 的镜像docker pull hello-world #从 Docker Hub 拉入名为 hello-world 的镜像‍ 运行镜像/容器 docker run hello-world ‍ 查看本地所有的镜像 docker images​​ 容器生成镜像…

【数据结构】链表OJ面试题3《判断是否有环》(题库+解析)

1.前言 前五题在这http://t.csdnimg.cn/UeggB 后三题在这http://t.csdnimg.cn/gbohQ 记录每天的刷题&#xff0c;继续坚持&#xff01; 2.OJ题目训练 9. 给定一个链表&#xff0c;判断链表中是否有环。 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成…

如何使用六图一表七种武器

六图一表七种武器用于质量管理&#xff1a; 描述当遇到问题时应该用那张图来解决&#xff1a; 一、如果题目说出了质量问题需要找原因&#xff1f; 解&#xff1a;用因果图&#xff0c;因果图也称石川图或鱼骨图 二、如果要判断过程是否稳定受控&#xff1f; 解&#xff1a…

4核8G服务器支持多少人同时在线访问?

腾讯云4核8G服务器支持多少人在线访问&#xff1f;支持25人同时访问。实际上程序效率不同支持人数在线人数不同&#xff0c;公网带宽也是影响4核8G服务器并发数的一大因素&#xff0c;假设公网带宽太小&#xff0c;流量直接卡在入口&#xff0c;4核8G配置的CPU内存也会造成计算…

【51单片机】LCD1602(江科大)

1.LCD1602介绍 LCD1602(Liquid Crystal Display)液晶显示屏是一种字符型液晶显示模块,可以显示ASCII码的标准字符和其它的一些内置特殊字符,还可以有8个自定义字符 显示容量:162个字符,每个字符为5*7点阵 2.引脚及应用电路 3.内部结构框图 屏幕: 字模库:类似于数码管的数…

一起玩儿Proteus仿真(C51)——05. 红绿灯仿真(一)

摘要&#xff1a;本文介绍如何仿真红绿灯 今天来做一个红绿灯仿真的程序&#xff0c;这个程序主要包括一下这些功能&#xff1a; 模拟的路口为十字交叉路口&#xff0c;假设东西和南北方向都是双向行驶&#xff0c;因此需要设置4组红绿灯和4个倒计时显示屏。倒计时时间最长为9…

【教程】C++语言基础学习笔记(七)——Array数组

写在前面&#xff1a; 如果文章对你有帮助&#xff0c;记得点赞关注加收藏一波&#xff0c;利于以后需要的时候复习&#xff0c;多谢支持&#xff01; 【C语言基础学习】系列文章 第一章 《项目与程序结构》 第二章 《数据类型》 第三章 《运算符》 第四章 《流程控制》 第五章…