深度学习9:简单理解生成对抗网络原理

目录

生成算法

生成对抗网络(GAN)

“生成”部分

“对抗性”部分

GAN如何运作?

培训GAN的技巧?

GAN代码示例

如何改善GAN?

结论


生成算法

您可以将生成算法分组到三个桶中的一个:

  1. 鉴于标签,他们预测相关的功能(朴素贝叶斯)
  2. 给定隐藏的表示,他们预测相关的特征(变分自动编码器,生成对抗网络)
  3. 鉴于一些功能,他们预测其余的(修复,插补)

我们将探索生成对抗网络的一些基础知识!GAN具有令人难以置信的潜力,因为他们可以学习模仿任何数据分布。也就是说,GAN可以学习在任何领域创造类似于我们自己的世界:图像,音乐,语音。

示例GAN架构

生成对抗网络(GAN)

“生成”部分

  • 叫做发电机
  • 给定某个标签,尝试预测功能
  • EX:鉴于电子邮件被标记为垃圾邮件,预测(生成)电子邮件的文本。
  • 生成模型学习各个类的分布。

“对抗性”部分

  • 称为判别者
  • 鉴于这些功能,尝试预测标签
  • EX:根据电子邮件的文本,预测(区分)垃圾邮件或非垃圾邮件。
  • 判别模型学习了类之间的界限。

GAN如何运作?

一个称为Generator的神经网络生成新的数据实例,而另一个神经网络Discriminator则评估它们的真实性。

您可以将GAN视为伪造者(发电机)和警察(Discriminator)之间的猫捉老鼠游戏。伪造者正在学习制造假钱,警察正在学习如何检测假钱。他们都在学习和提高。伪造者不断学习创造更好的假货,并且警察在检测它们时不断变得更好。最终的结果是,伪造者(发电机)现在接受了培训,可以创造出超现实的金钱!

让我们用MNIST手写数字数据集探索一个具体的例子:

MNIST手写数字数据集

我们将让Generator创建新的图像,如MNIST数据集中的图像,它取自现实世界。当从真实的MNIST数据集中显示实例时,Discriminator的目标是将它们识别为真实的。

同时,Generator正在创建传递给Discriminator的新图像。它是这样做的,希望它们也将被认为是真实的,即使它们是假的。Generator的目标是生成可通过的手写数字,以便在不被捕获的情况下进行说谎。Discriminator的目标是将来自Generator的图像分类为假的。

MNIST手写数字+ GAN架构

GAN步骤:

  1. 生成器接收随机数并返回图像。
  2. 将生成的图像与从实际数据集中获取的图像流一起馈送到鉴别器中。
  3. 鉴别器接收真实和假图像并返回概率,0到1之间的数字,1表示真实性的预测,0表示假

两个反馈循环:

  1. 鉴别器处于反馈循环中,具有图像的基本事实(它们是真实的还是假的),我们知道。
  2. 发生器与Discriminator处于反馈循环中(Discriminator将其标记为真实或伪造,无论事实如何)。

培训GAN的技巧?

在开始训练发生器之前预先识别鉴别器将建立更清晰的梯度。

训练Discriminator时,保持Generator值不变。训练发生器时,保持Discriminator值不变。这使网络能够更好地了解它必须学习的梯度。

GAN被制定为两个网络之间的游戏,重要:保持它们的平衡。如果发电机或鉴别器太好,GAN可能很难学习。

GAN需要很长时间才能训练。在单个GPU上,GAN可能需要数小时,在单个CPU上,GAN可能需要数天。

GAN代码示例

class GAN():def __init__(self):self.img_rows = 28 self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)optimizer = Adam(0.0002, 0.5)# Build and compile the discriminatorself.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer,metrics=['accuracy'])# Build and compile the generatorself.generator = self.build_generator()self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)# The generator takes noise as input and generated imgsz = Input(shape=(100,))img = self.generator(z)# For the combined model we will only train the generatorself.discriminator.trainable = False# The valid takes generated images as input and determines validityvalid = self.discriminator(img)# The combined model  (stacked generator and discriminator) takes# noise as input => generates images => determines validity self.combined = Model(z, valid)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):noise_shape = (100,)model = Sequential()model.add(Dense(256, input_shape=noise_shape))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))model.summary()noise = Input(shape=noise_shape)img = model(noise)return Model(noise, img)def build_discriminator(self):img_shape = (self.img_rows, self.img_cols, self.channels)model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))model.summary()img = Input(shape=img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, save_interval=50):# Load the dataset(X_train, _), (_, _) = mnist.load_data()# Rescale -1 to 1X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)half_batch = int(batch_size / 2)for epoch in range(epochs):# ---------------------#  Train Discriminator# ---------------------# Select a random half batch of imagesidx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]noise = np.random.normal(0, 1, (half_batch, 100))# Generate a half batch of new imagesgen_imgs = self.generator.predict(noise)# Train the discriminatord_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  Train Generator# ---------------------noise = np.random.normal(0, 1, (batch_size, 100))# The generator wants the discriminator to label the generated samples# as valid (ones)valid_y = np.array([1] * batch_size)# Train the generatorg_loss = self.combined.train_on_batch(noise, valid_y)# Plot the progressprint ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))# If at save interval => save generated image samplesif epoch % save_interval == 0:self.save_imgs(epoch)def save_imgs(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, 100))gen_imgs = self.generator.predict(noise)# Rescale images 0 - 1gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("gan/images/mnist_%d.png" % epoch)plt.close()if __name__ == '__main__':gan = GAN()gan.train(epochs=30000, batch_size=32, save_interval=200)

如何改善GAN?

GAN刚刚在2014年发明 – 它们非常新!GAN是一个很有前途的生成模型家族,因为与其他方法不同,它们可以生成非常干净和清晰的图像,并学习包含有关基础数据的有价值信息的权重。但是,如上所述,可能难以使Discriminator和Generator网络保持平衡。有很多正在进行的工作使GAN培训更加稳定。

除了生成漂亮的图片之外,还开发了一种利用GAN进行半监督学习的方法,该方法涉及鉴别器产生指示输入标签的附加输出。这种方法可以使用极少数标记示例在数据集上实现最前沿结果。例如,在MNIST上,通过完全连接的神经网络,每个类只有10个标记示例,实现了99.1%的准确度 – 这一结果非常接近使用所有60,000个标记示例的完全监督方法的最佳已知结果。这是非常有希望的,因为在实践中获得标记的示例可能非常昂贵。

结论

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54809.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

火山引擎 DataLeap:从短视频 APP 实践看如何统一数据指标口径

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 短视频正在成为越来越多人发现世界的窗口,其背后的创作者生态建设是各大短视频 APP 不可忽视的重要组成部分。 为了激励更多优质内容生产,某…

iOS 分别对一张图的局部进行磨砂,拼接起来不能贴合

效果图 需求,由于视图层级的原因,需要对图片分开进行磨砂, 然后组合在一起 如图,上下两部分,上下两个UIImageVIew大小相同,都是和图片同样的大小,只是上面的UIimageVIew 只展示上半部份 &#…

初识【类和对象】

目录 1.面向过程和面向对象初步认识 2.类的引入 3.类的定义 4.类的访问限定符及封装 5.类的作用域 6.类的实例化 7.类的对象大小的计算 8.类成员函数的this指针 1.面向过程和面向对象初步认识 C语言是面向过程的,关注的是过程,分析出求解问题的…

stm32之9.中断优先级配置

主函数main.c #include <stm32f4xx.h> #include "led.h" #include "key.h"#define PAin(n) (*(volatile uint32_t *)(0x42000000 (GPIOA_BASE0x10-0x40000000)*32 (n)*4)) #define PEin(n) (*(volatile uint32_t *)(0x42000000 (GP…

pytorch里面的nn.AdaptiveAvgPool2d

今天遇到nn.AdaptiveAvgPool2d((None, 1)) AdaptiveAvgPool2d函数详细解释&#xff1a; 2D自适应平均池化&#xff08;2D adaptive average pooling&#xff09;是一种对输入信号进行二维平均池化的操作&#xff0c;输入信号由多个输入平面&#xff08;input planes&#xff0…

为什么学嵌入式还要学单片机和人工智能?

从企业用人需求的角度来看&#xff0c;许多企业在招聘嵌入式工程师时都希望其具备一定的技能要求。其中&#xff0c;熟悉STM32单片机开发、熟悉嵌入式Linux开发以及熟悉实时操作系统开发&#xff0c;如FreeRTOS等&#xff0c;是常见的要求。掌握这些技术点的课程将为学生提供更…

成都国际车展来袭:有颜值有实力 大运新能源两款车型备受关注

今年成都国际车展现场最大看点是什么&#xff1f;诸多实力车企汇聚一堂各显神通&#xff0c;形式各样的新能源车型更是吸晴无数&#xff0c;成为消费者的购车首选。老品牌、新势力的大运新能源独具匠心设计特色展台&#xff0c;旗下两款车型悦虎和远志M1重磅登场。两款车型既有…

【设备树笔记整理6】中断系统中的设备树

1 中断概念的引入与处理流程 1.1 中断处理框图 1.2 中断程序的使用 主函数() while(1) {do_routine_task(); }中断处理函数() {handle_interrupt_task(); }如何调用中断处理函数&#xff1f; 1.3 ARM对异常(中断)的处理过程 &#xff08;1&#xff09;初始化 ① 设置中断…

Apollo自动驾驶:引领未来的智能出行

自动驾驶技术正日益成为当今科技领域的焦点&#xff0c;它代表着未来出行的一大趋势&#xff0c;而Baidu公司推出的Apollo自动驾驶平台则在这一领域中展现出强大的领导地位。本文将深入探讨Apollo自动驾驶技术的关键特点、挑战以及它对未来智能出行的影响。 Apollo自动驾驶平台…

wx:for的使用和事件传参,解构赋值的应用

在页面的.js文件中创建了一个对象&#xff0c; 并且在页面的view中调用了两种不同的方法将对象中的元素显示出来&#xff01; 第2种代码要加强理解&#xff01;&#xff01;&#xff01; 小程序中的wx:if wx:elif wx:else 其实好像c语言中的 if-elif-else 在页面的.j…

一个专业级 AI 聊天浏览器,开源了!

公众号关注 “GitHubDaily” 设为 “星标”&#xff0c;每天带你逛 GitHub&#xff01; 在 AI 模型大爆炸的今天&#xff0c;我们每天都能在技术圈见证无数个大语言模型诞生&#xff0c;其诞生速度之快&#xff0c;着实让人看得目不暇接。 对于热衷于体验、调试、评测各种大模型…

Pytorch-day08-模型进阶训练技巧

PyTorch 模型进阶训练技巧 自定义损失函数 如 cross_entropy L2正则化动态调整学习率 如每十次 *0.1 典型案例&#xff1a;loss上下震荡 1、自定义损失函数 1、PyTorch已经提供了很多常用的损失函数&#xff0c;但是有些非通用的损失函数并未提供&#xff0c;比如&#xf…

如何优化因为高亮造成的大文本(大字段)检索缓慢问题

首先还是说一下背景&#xff0c;工作中用到了 elasticsearch 的检索以及高亮展示&#xff0c;但是索引中的content字段是读取的大文本内容&#xff0c;所以后果就是索引的单个字段很大&#xff0c;造成单独检索请求的时候速度还可以&#xff0c;但是加入高亮之后检索请求的耗时…

开始MySQL之路——MySQL约束概述详解

MySQL约束 create table [if not exists] 表名(字段名1 类型[(宽度)] [约束条件] [comment 字段说明],字段名2 类型[(宽度)] [约束条件] [comment 字段说明],字段名3 类型[(宽度)] [约束条件] [comment 字段说明] )[表的一些设置]; 概念 约束英文&#xff1a;constraint 约束实…

GeoHash之存储篇

前言&#xff1a; 在上一篇文章GeoHash——滴滴打车如何找出方圆一千米内的乘客主要介绍了GeoHash的应用是如何的&#xff0c;本篇文章我想要带大家探索一下使用什么样的数据结构去存储这些Base32编码的经纬度能够节省内存并且提高查询的效率。 前缀树、跳表介绍&#xff1a; …

数据结构队列的实现

本章介绍数据结构队列的内容&#xff0c;我们会从队列的定义以及使用和OJ题来了解队列&#xff0c;话不多说&#xff0c;我们来实现吧 队列 1。队列的概念及结构 队列&#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;…

centos7搭建apache作为文件站后,其他人无法访问解决办法

在公司内网的一个虚拟机上搭建了httpsd服务&#xff0c;准备作为内部小伙伴们的文件站&#xff0c;但是搭建好之后发现别的小伙伴是无法访问我机器的。 于是寻找一下原因&#xff0c;排查步骤如下&#xff1a; 1.netstat -lnp 和 ps aux 先看下端口和 服务情况 发现均正常 2.…

设计模式-工厂设计模式

核心思想 在简单工厂模式的基础上进一步的抽象化具备更多的可扩展和复用性&#xff0c;增强代码的可读性使添加产品不需要修改原来的代码&#xff0c;满足开闭原则 优缺点 优点 符合单一职责&#xff0c;每个工厂只负责生产对应的产品符合开闭原则&#xff0c;添加产品只需添…

探讨uniapp的路由与页面生命周期问题

1 首先我们引入页面路由 2 页面生命周期函数 onLoad() {console.log(页面加载)},onShow() {console.log(页面显示)},onReady(){console.log(页面初次显示)},onHide() {console.log(页面隐藏)},onUnload() {console.log(页面卸载)},onBackPress(){console.log(页面返回)}3 页面…

代码随想录算法训练营之JAVA|第三十九天|474. 一和零

今天是第39天刷leetcode&#xff0c;立个flag&#xff0c;打卡60天。 算法挑战链接 474. 一和零https://leetcode.cn/problems/ones-and-zeroes/ 第一想法 题目理解&#xff1a;找到符合条件的子集&#xff0c;这又是一个组合的问题。 看到这个题目的时候&#xff0c;我好像…