谷歌开源的 GAN 库--TFGAN

本文大约 8000 字,阅读大约需要 12 分钟

第一次翻译,限于英语水平,可能不少地方翻译不准确,请见谅!

最近谷歌开源了一个基于 TensorFlow 的库–TFGAN,方便开发者快速上手 GAN 的训练,其 Github 地址如下:

https://github.com/tensorflow/models/tree/master/research/gan

原文网址:Generative Adversarial Networks: Google open sources TensorFlow-GAN (TFGAN)


如果你玩过波斯王子,那你应该知道你需要保护自己不被”影子“所杀掉,但这也是一个矛盾:如果你杀死“影子”,那游戏就结束了;但你不做任何事情,那么游戏也会输掉。

尽管生成对抗网络(GAN)有不少优点,但它也面临着相似的区分问题。大部分支持 GAN 的深度学习专业也是非常谨慎的支持它,并指出它确实存在稳定性的问题。

GAN 的这个问题也可以称做整体收敛性问题。尽管判别器 D 和 生成器 D 相互竞争博弈,但同时也相互依赖对方来达到有效的训练。如果其中一方训练得很差,那整个系统也会很差(这也是之前提到的梯度消失或者模式奔溃问题)。并且你也需要确保他们不会训练太过度,造成另一方无法训练了。因此,波斯王子是一个很有趣的概念。

首先,神经网络的提出就是为了模仿人类的大脑(尽管是人为的)。它们也已经在物体识别和自然语言处理方面取得成功。但是,想要在思考和行为上与人类一致,这还有非常大的差距。

那么是什么让 GANs 成为机器学习领域一个热门话题呢?因为它不仅只是一个相对新的结构,它更加是一个比之前其他模型都能更加准确的对真实数据建模,可以说是深度学习的一个革命性的变化。

最后,它是一个同时训练两个独立的网络的新模型,这两个网络分别是判别器和生成器。这样一个非监督神经网络却能比其他传统网络得到更好性能的结果。

但目前事实是我们对 GANs 的研究还只是非常浅层,仍然有着很多挑战需要解决。GANs 目前也存在不少问题,比如无法区分在某个位置应该有多少特定的物体,不能应用到 3D 物体,以及也不能理解真实世界的整体结构。当然现在有大量研究正在研究如何解决上述问题,新的模型也取得更好的性能。

而最近谷歌为了让 GANs 更容易实现,设计开发并开源了一个基于 TensorFlow 的轻量级库–TFGAN。

根据谷歌的介绍,TFGAN 提供了一个基础结构来减少训练一个 GAN 模型的难度,同时提供非常好测试的损失函数和评估标准,以及给出容易上手的例子,这些例子强调了 TFGAN 的灵活性和易于表现的优点。

此外,还提供了一个教程,包含一个高级的 API 可以快速使用自己的数据集训练一个模型。

来源: research.googleblog.com

上图是展示了对抗损失在图像压缩方面的效果。最上方第一行图片是来自 ImageNet 数据集的图片,也是原始输入图片,中间第二行展示了采用传统损失函数训练得到的图像压缩神经网络的压缩和解压缩效果,最底下一行则是结合传统损失函数和对抗损失函数训练的网络的结果,可以看到尽管基于对抗损失的图片并不像原始图片,但是它比第二行的网络得到更加清晰和细节更好的图片。

TFGAN 既提供了几行代码就可以实现的简答函数来调用大部分 GAN 的使用例子,也是建立在包含复杂 GAN 设计的模式化方式。这就是说,我们可以采用自己需要的模块,比如损失函数、评估策略、特征以及训练等等,这些都是独立的模块。TFGAN 这样的设计方式其实就满足了不同使用者的需求,对于入门新手可以快速训练一个模型来看看效果,对于需要修改其中任何一个模块的使用者也能修改对应模块,而不会牵一发而动全身。

最重要的是,谷歌也保证了这个代码是经过测试的,不需要担心一般的 GAN 库造成的数字或者统计失误。

开始使用

首先添加以下代码来导入 tensorflow 和 声明一个 TFGAN 的实例:

import tensorflow as tf
tfgan = tf.contrib.gan

为何使用 TFGAN

  • 采用良好测试并且很灵活的调用接口实现快速训练生成器和判别器网络,此外,还可以混合 TFGAN、原生 TensorFlow以及其他自定义框架代码;
  • 使用实现好的GAN 的损失函数和惩罚策略 (比如 Wasserstein loss、梯度惩罚等)
  • 训练阶段对 GAN 进行监控和可视化操作,以及评估生成结果
  • 使用实现好的技巧来稳定和提高性能
  • 基于常规的 GAN 训练例子来开发
  • 采用GANEstimator接口里快速训练一个 GAN 模型
  • TFGAN 的结构改进也会自动提升你的 TFGAN 项目的性能
  • TFGAN 会不断添加最新研究的算法成果

TFGAN 的部件有哪些呢?

TFGAN 是由多个设计为独立的部件组成的,分别是:

  • core:提供了一个主要的训练 GAN 模型的结构。训练过程分为四个阶段,每个阶段都可以采用自定义代码或者 调用 TFGAN 库接口来完成;
  • features:包含许多常见的 GAN 运算和正则化技术,比如实例正则化(instance normalization)
  • losses:包含常见的 GAN 的损失函数和惩罚机制,比如 Wasserstein loss、梯度惩罚、相互信息惩罚等
  • evaulation:使用一个预训练好的 Inception 网络来利用Inception Score或者Frechet Distance评估标准来评估非条件生成模型。当然也支持利用自己训练的分类器或者其他方法对有条件生成模型的评估
  • examples and tutorial:使用 TFGAN 训练 GAN 模型的例子和教程。包含了使用非条件和条件式的 GANs 模型,比如 InfoGANs 等。

训练一个 GAN 模型

典型的 GAN 模型训练步骤如下:

  1. 为你的网络指定输入,比如随机噪声,或者是输入图片(一般是应用在图片转换的应用,比如 pix2pixGAN 模型)
  2. 采用GANModel接口定义生成器和判别器网络
  3. 采用GANLoss指定使用的损失函数
  4. 采用GANTrainOps设置训练运算操作,即优化器
  5. 开始训练

当然,GAN 的设置有多种形式。比如,你可以在非条件下训练生成器生成图片,或者可以给定一些条件,比如类别标签等输入到生成器中来训练。无论是哪种设置,TFGAN 都有相应的实现。下面将结合代码例子来进一步介绍。

实例

非条件 MNIST 图片生成

第一个例子是训练一个生成器来生成手写数字图片,即 MNIST 数据集。生成器的输入是从多变量均匀分布采样得到的随机噪声,目标输出是 MNIST 的数字图片。具体查看论文“Generative Adversarial Networks”。代码如下:

# 配置输入
# 真实数据来自 MNIST 数据集
images = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的输入,从多变量均匀分布采样得到的随机噪声
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 调用 tfgan.gan_model() 函数定义生成器和判别器网络模型
gan_model = tfgan.gan_model(generator_fn=mnist.unconditional_generator,  discriminator_fn=mnist.unconditional_discriminator,  real_data=images,generator_inputs=noise)# 调用 tfgan.gan_loss() 定义损失函数
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss)# 调用 tfgan.gan_train_ops() 指定生成器和判别器的优化器
train_ops = tfgan.gan_train_ops(gan_model,gan_loss,generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))# tfgan.gan_train() 开始训练,并指定训练迭代次数 num_steps
tfgan.gan_train(train_ops,hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],logdir=FLAGS.train_log_dir)
条件式 MNIST 图片生成

第二个例子同样还是生成 MNIST 图片,但是这次输入到生成器的不仅仅是随机噪声,还会给类别标签,这种 GAN 模型也被称作条件 GAN,其目的也是为了让 GAN 训练不会太过自由。具体可以看论文“Conditional Generative Adversarial Nets”。

代码方面,仅仅需要修改输入和建立生成器与判别器模型部分,如下所示:

# 配置输入
# 真实数据来自 MNIST 数据集,这里增加了类别标签--one_hot_labels
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的输入,从多变量均匀分布采样得到的随机噪声
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 调用 tfgan.gan_model() 函数定义生成器和判别器网络模型
gan_model = tfgan.gan_model(generator_fn=mnist.conditional_generator,  discriminator_fn=mnist.conditional_discriminator,  real_data=images,generator_inputs=(noise, one_hot_labels)) # 生成器的输入增加了类别标签# 剩余的代码保持一致
...
对抗损失

第三个例子结合了 L1 pixel loss 和对抗损失来学习自动编码图片。瓶颈层可以用来传输图片的压缩表示。如果仅仅使用 pixel-wise loss,网络只回倾向于生成模糊的图片,但 GAN 可以用来让这个图片重建过程更加逼真。具体可以看论文“Full Resolution Image Compression with Recurrent Neural Networks”来了解如何用 GAN 来实现图像压缩,以及论文“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”了解如何用 GANs 来增强生成的图片质量。

代码如下:

# 配置输入
images = image_provider.provide_data(FLAGS.batch_size)# 配置生成器和判别器网络
gan_model = tfgan.gan_model(generator_fn=nets.autoencoder,  # 自定义的 autoencoderdiscriminator_fn=nets.discriminator,  # 自定义的 discriminatorreal_data=images,generator_inputs=images)# 建立 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)# 结合两个 loss
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)# 剩下代码保持一致
...
图像转换

第四个例子是图像转换,它是将一个领域的图片转变成另一个领域的同样大小的图片。比如将语义分割图变成街景图,或者是灰度图变成彩色图。具体细节看论文“Image-to-Image Translation with Conditional Adversarial Networks”。

代码如下:

# 配置输入,注意增加了 target_image
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)# 配置生成器和判别器网络
gan_model = tfgan.gan_model(generator_fn=nets.generator,  discriminator_fn=nets.discriminator,  real_data=target_image,generator_inputs=input_image)#  建立 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.least_squares_generator_loss,discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)# 结合两个 loss
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)# 剩下代码保持一致
...
InfoGAN

最后一个例子是采用 InfoGAN 模型来生成 MNIST 图片,但是可以不需要任何标签来控制生成的数字类型。具体细节可以看论文“InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets”。

代码如下:

# 配置输入
images = mnist_data_provider.provide_data(FLAGS.batch_size)# 配置生成器和判别器网络
gan_model = tfgan.infogan_model(generator_fn=mnist.infogan_generator,  discriminator_fn=mnist.infogran_discriminator,  real_data=images,unstructured_generator_inputs=unstructured_inputs,  # 自定义输入structured_generator_inputs=structured_inputs)  # 自定义# 配置 GAN loss 以及相互信息惩罚
gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,gradient_penalty=1.0,mutual_information_penalty_weight=1.0)# 剩下代码保持一致
...
自定义模型的创建

最后同样是非条件 GAN 生成 MNIST 图片,但利用GANModel函数来配置更多参数从而更加精确控制模型的创建。

代码如下:

# 配置输入
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 手动定义生成器和判别器模型
with tf.variable_scope('Generator') as gen_scope:generated_images = generator_fn(noise)
with tf.variable_scope('Discriminator') as dis_scope:discriminator_gen_outputs = discriminator_fn(generated_images)
with variable_scope.variable_scope(dis_scope, reuse=True):discriminator_real_outputs = discriminator_fn(images)
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)# 依赖于你需要使用的 TFGAN 特征,你并不需要指定 `GANModel`函数的每个参数,不过
# 最少也需要指定判别器的输出和变量
gan_model = tfgan.GANModel(generator_inputs,generated_data,generator_variables,gen_scope,generator_fn,real_data,discriminator_real_outputs,discriminator_gen_outputs,discriminator_variables,dis_scope,discriminator_fn)# 剩下代码和第一个例子一样
...

最后,再次给出 TFGAN 的 Github 地址如下:

https://github.com/tensorflow/models/tree/master/research/gan


如果有翻译不当的地方或者有任何建议和看法,欢迎留言交流;也欢迎关注我的微信公众号–机器学习与计算机视觉或者扫描下方的二维码,和我分享你的建议和看法,指正文章中可能存在的错误,大家一起交流,学习和进步!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/408849.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux快速php,Linux 下的这些高效指令,是你快速学习的神器

linux是一套免费使用和自由传播的类unix操作系统,是一个基于posix和unix的多用户、多任务、支持多线程和多cpu的操作系统。它能运行主要的unix工具软件、应用程序和网络协议。它支持32位和64位硬件。linux继承了unix以网络为核心的设计思想,是一个性能稳…

TensorFlow 加载多个模型的方法

采用 TensorFlow 的时候,有时候我们需要加载的不止是一个模型,那么如何加载多个模型呢? 原文:https://bretahajek.com/2017/04/importing-multiple-tensorflow-models-graphs/ 关于 TensorFlow 可以有很多东西可以说。但这次我只…

[资源分享] TensorFlow 官方中文版教程来了

最近,TensorFlow 提供了中文版的教程(Tutorials)和指南(Guide)。其中,教程是介绍了一些基本的机器学习模型,包括分类、回归等,也包括一些深度学习方面的模型,包括常用的卷…

深度学习4线性回归,逻辑回归

y是连续的则是一个回归问题,y是离散的则是一个分类问题,这边就开始考虑y是离散的情况。 对于这样的问题很多,比如判断一个人是否生病,或者判断一个邮件是否是垃圾邮件。 回归时连续型的,一般不用在上述的分类问题中&am…

linux系统shell知识点,Linux 系统中shell知识点说明和常用的帮助命令简单介绍 | IT工程师的生活足迹...

linux 系统内核和各种驱动程序覆盖在下层的硬件系统之上;对上提供各种系统调用接口API,供shell和各种程序应用程序调用。总体结构图如下:操作系统的层次架构一般我们理解shell指的是BASH,即linux系统默认的字符界面使用的shell版本。另外还有…

必读的AI和深度学习博客

技术的提高是需要日积月累的努力,除了看书看视频外,一个很有效的提高方法当然就是阅读大牛的博客文章了,所谓听君一席话,胜读十年书,虽然读大牛的文章没有这么夸张,但也可以让你解决技术上的一些难题&#…

[教程]一份简单易懂的 TensorFlow 教程

上周分享了一份 TensorFlow 官方的中文版教程,这次分享的是在 Github 上的一份简单易懂的教程,项目地址是: https://github.com/open-source-for-science/TensorFlow-Course#why-use-tensorflow 如下图所示,已经有超过7000的 St…

linux 2.6 hash表作用,高性能分布式哈希表FastDHT介绍及安装配置

FastDHT介绍FastDHT 是一个高性能的分布式哈希系统 (DHT) ,使用 Berkeley DB 做数据存储,使用 libevent 做网络IO处理,提供 Java 版的客户端接口包。适合用来存储用户在线、会话等小数据量信息。FastDHT存储Key Value Pair支持两种存储方式:缓…

[GAN学习系列3]采用深度学习和 TensorFlow 实现图片修复(上)

在之前的两篇 GAN 系列文章–[GAN学习系列1]初识GAN以及[GAN学习系列2] GAN的起源中简单介绍了 GAN 的基本思想和原理,这次就介绍利用 GAN 来做一个图片修复的应用,主要采用的也是 GAN 在网络结构上的升级版–DCGAN,最初始的 GAN 采用的还是神…

用虚拟机把ubuntu安装到TF卡上

最近在学习Linux,考虑到将来可能不会带着自己的笔记本到处跑,而我又希望能随身带着个Ubuntu系统 ,总不能在别人的电脑上装个Linux系统吧。刚好最近入手了一张 Sandisk 16G class 10 的TF卡,加上一个PNY的手机宝贝读卡器&#xff0…

C 语言调用CPU指令,CPU 1214C中 TSEND_C指令 最多可以使用几次-工业支持中心-西门子中国...

8次硬件版本 V3.0 支持的协议和最大的连接资源:3个连接用于操作面板1个连接用于编程设备(PG)与 CPU 的通信8个连接用于Open IE ( TCP, ISO on TCP, UDP) 的编程通信,使用T-block 指令来实现3个连接用于S7 通信的服务器端连接,可以实现与S7-20…

[GAN学习系列3]采用深度学习和 TensorFlow 实现图片修复(中)

上一篇文章–[GAN学习系列3]采用深度学习和 TensorFlow 实现图片修复(上)中,我们先介绍了对于图像修复的背景,需要利用什么信息来对缺失的区域进行修复,以及将图像当做概率分布采样的样本来看待,通过这个思路来开始进行…

[资源分享] 推荐两本电子书

又到了一周一次的资源和教程推荐。这周会推荐两本电子书,希望大家不只是收藏不阅读系列哦!1. 《模式识别与机器学习》(PRML)免费开放下载第一本推荐的书籍就是 AI 领域里面一直都非常有名的书籍--《模式识别与机器学习》&#xff…