昇思25天学习打卡营第22天|GAN图像生成

今天是参加昇思25天学习打卡营的第22天,今天打卡的课程是“GAN图像生成”,这里做一个简单的分享。

1.简介

今天来学习“GAN图像生成”,这是一个基础的生成式模型。

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

  • 生成器的任务是生成看起来像训练图像的“假”图像;
  • 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

2.模型架构

  • 模型原理

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 𝐺 和估计样本是否来自训练数据的判别模型 𝐷。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

用 𝑥 代表图像数据,用 𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥) 需要处理作为二进制文件的大小为 1×28×28的图像数据。当 𝑥 来自训练数据时,𝐷(𝑥) 数值应该趋近于 1 ;而当 𝑥 来自生成器时,𝐷(𝑥)𝐷数值应该趋近于 00 。因此 𝐷(𝑥) 也可以被认为是传统的二分类器。

用 𝑧 代表标准正态分布中提取出的隐码(隐向量),用 𝐺(𝑧):表示将隐码(隐向量) 𝑧 映射到数据空间的生成器函数。函数 𝐺(𝑧) 的目标是将服从高斯分布的随机噪声 𝑧 通过生成网络变换为近似于真实分布 𝑝𝑑𝑎𝑡𝑎(𝑥) 的数据分布,我们希望找到 θ 使得 𝑝𝐺(𝑥;𝜃) 和𝑝𝑑𝑎𝑡𝑎(𝑥) 尽可能的接近,其中𝜃 代表网络参数。

𝐷(𝐺(𝑧))表示生成器 𝐺𝐺生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数 log⁡𝐷(𝑥);而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数log⁡(1−𝐷(𝐺(𝑧)))。因此GAN的损失函数为:
在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是𝑝𝐺(𝑥;𝜃)=𝑝𝑑𝑎𝑡𝑎(𝑥),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。
  • 核心代码

生成器代码:

from mindspore import nn
import mindspore.ops as opsimg_size = 28  # 训练图像长(宽)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器代码:

# 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256))  # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')
  • 损失函数和优化器
lr = 0.0002  # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

3.小结

今天学了GAN用于图像生成的基本理论和编码方法。GAN模型由生成器(Generative Model)和判别器(Discriminative Model)构成两个相互对抗的模型。生成器负责生成图像进行,判别器用于判定图像真假,通过对抗的模式使得真假判定的结果接近1:1,进而完成训练。这样训练好的生成器即可用于图形生成。这里面着重要掌握对抗网络损失函数的意义,这是的对抗网络能够输出最正确结果的要点。

以上是第22天的学习内容,附上今日打卡记录:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/46926.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Bug:时间字段显示有问题

Bug:时间字段显示有问题 文章目录 Bug:时间字段显示有问题1、问题2、解决方法一:添加注解3、解决方法二:消息转换器自定义对象映射器配置消息转换器 1、问题 ​ 在后端传输时间给前端的时候,发现前端的时间显示有问题…

代码trick 类型判断

文章目录 判断 null 和 undefined 判断 null 和 undefined vue 源码里的技巧,即 value null 用的是双等号。在双等号中,null 和 undefined 是相等的,也就是说 value 是 null 或 undefined 都会为 true if( value null ){// ... }

[Spring] Spring Web MVC案例实战

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

AV1技术学习:Translational Motion Compensation

编码块根据运动矢量在参考帧中找到相应的预测块,如下图所示,当前块的左上角的位置为(x0, y0),在参考帧中找到同样位置(x0, y0)的块,根据运动矢量移动到目标参考块(左上角位置为:(x1, y1))。 AV1…

前端a-tree遇到的问题

在使用a-tree时候,给虚拟滚动的高度,然后展开a-tree滑动一段距离 比如这样 随后你切换页面,在返回这个页面的时候 就会出现这样的bug 解决方法: onBeforeRouteLeave((to, from, next) > {// 可以在路由参数变化时执行的逻辑ke…

白山云荣获信通院“算网安全行业应用优秀案例”奖

日前,在由中国通信标准化协会算网融合产业及标准推进委员会与信通院共同组织召开的“2024年算网融合产业发展大会”上,白山云凭借创新的SD-WAN算网融合方案,荣获“算网安全行业应用优秀案例”奖。 算网融合是多元异构、海量泛在的算力设施&am…

path模块和HTTP协议

一。path模块常用API ./相对路径,/绝对路径 二,HTTP协议 1.请求报文 1.请求行 URL的组成 2.请求头 3.请求体 可以是空:GET请求 可以是字符串,还可以是json:POST请求 2.响应报文 1.响应行 HTTP / 1.1 200 OK H…

VsCode 与远程服务器 ssh免密登录

首先配置信息 加入下列信息 Host qb-zn HostName 8.1xxx.2xx.3xx User root ForwardAgent yes Port 22 IdentityFile ~/.ssh/id_rsa 找到自己的公钥,不带pub是私钥,打死都不能给别人。复制公钥 拿到公钥后,来到远程服务器 vim ~/.ss…

Leetcode—3011. 判断一个数组是否可以变为有序【中等】(__builtin_popcount()、ranges::is_sorted())

2024每日刷题&#xff08;144&#xff09; Leetcode—3011. 判断一个数组是否可以变为有序 O(n)复杂度实现代码 class Solution { public:bool canSortArray(vector<int>& nums) {// 二进制数位下1数目相同的元素就不进行组内排序// 只进行分组// 当前组的值若小于…

人工智能算法工程师(中级)课程12-PyTorch神经网络之LSTM和GRU网络与代码详解1

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程12-PyTorch神经网络之LSTM和GRU网络与代码详解。在深度学习领域,循环神经网络(RNN)因其处理序列数据的能力而备受关注。然而,传统的RNN存在梯度消失和梯度爆炸的问题,这使得它在长序列任务中的表现不尽…

MySQL--C_C++语言连接访问

Connector/C的使用 首先需要在mysql官网下载C接口库 解压指令 tar -zxvf 压缩包名 下载并解压好后 但是还有比这更优的做法。 这样子手动安装不仅麻烦&#xff0c;还可能存在兼容性的问题。 其实在我们使用yum安装mysql时&#xff0c;大概率会自动帮我们把其他的环境都安装…

[JS]认识feach

介绍 feach是浏览器内置的api, 用于发送网络请求 请求方式对比 AJAX: 基于XMLHttpRequest接收请求, 使用繁琐Axios: 基于Promise的请求客户端, 在浏览器和node中均可使用, 使用简单fetch: 浏览器内置的api, 基于Promise, 功能简单 基础语法 <body><button>发请求…

OracleLinux6.9升级UEK内核

方法一: [root@localhost ~]# uname -r 4.1.12-61.1.28.el6uek.x86_64 [root@localhost ~]# rpm -qa | grep kernel-uek kernel-uek-firmware-4.1.12-61.1.28.el6uek.noarch kernel-uek-4.1.12-61.1.28.el6uek.x86_64 [root@localhost ~]# yum list kernel-uek Loaded plug…

金豺狼优化算法(GWO)及其Python和MATLAB实现

金豺狼优化算法&#xff08;GWO&#xff09;是一种启发式优化算法&#xff0c;灵感来源于灰狼群体的社会行为和层级结构。GWO算法由Mirjalili等人于2014年提出&#xff0c;通过模拟灰狼群体的捕猎行为&#xff0c;寻找最优解。相比于其他优化算法&#xff0c;GWO算法具有较好的…

探索Gradle自动化测试:一站式测试框架配置指南

探索Gradle自动化测试&#xff1a;一站式测试框架配置指南 在当今快速迭代的软件开发周期中&#xff0c;自动化测试是确保代码质量和快速反馈的关键。Gradle&#xff0c;作为一个强大的构建工具&#xff0c;提供了丰富的插件和配置选项来支持自动化测试。本文将深入探讨如何在…

【Datawhale AI夏令营】电力需求预测挑战赛 Task01

整个学习活动&#xff0c;将带你从 跑通最简的Baseline&#xff0c;到了解竞赛通用流程、深入各个竞赛环节&#xff0c;精读Baseline与进阶实践 文章目录 一、赛题背景二、赛题任务三、实践步骤学习规划分析思路常见时序场景 task01codecode 解读 一、赛题背景 随着全球经济的…

如何在linux中给vim编辑器添加插件

在Linux系统中给Vim编辑器添加插件通常通过插件管理器来完成&#xff0c;以下是一般的步骤&#xff1a; 1.使用插件管理器安装插件 安装插件管理器&#xff08;如果尚未安装&#xff09;&#xff1a; 常见的插件管理器包括 Vundle、vim-plug 和 Pathogen 等。你可以根据个人喜…

TF和TF-IDF区别和联系

TF&#xff08;Term Frequency&#xff09;和TF-IDF&#xff08;Term Frequency-Inverse Document Frequency&#xff09;都是用于文本挖掘和信息检索的统计方法&#xff0c;用于评估一个词在文档或文档集合中的重要性。 一.TF&#xff08;Term Frequency&#xff09; 1.定义…

CSA笔记1-基础知识和目录管理命令

[litonglocalhost ~]$ 是终端提示符&#xff0c;类似于Windows下的cmd的命令行 litong 当前系统登录的用户名 分隔符 localhost 当前机器名称&#xff0c;本地主机 ~ 当前用户的家目录 $ 表示当前用户为普通用户若为#则表示当前用户为超级管理员 su root 切换root权限…

昇思25天学习打卡营第12天|munger85

基于MindSpore通过GPT实现情感分类 这个实现情感分类意思就是通过一些电影的数据最后知道他对于这个电影的评价&#xff0c;最后知道他对于这个电影的评价到底是好还是不好&#xff0c;零就是不好&#xff0c;一就是好。首先我们肯定是按安装这些依赖包了为了今天这个模型我们…