AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

AIGC实战——条件生成对抗网络

    • 0. 前言
    • 1. CGAN架构
    • 2. 模型训练
    • 3. CGAN 分析
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控制模型生成男性或女性的面部图像;我们可以从潜空间中随机采样一个点,但是不能预知给定潜变量能够生成什么样的图像。在本节中,我们将构建一个能够控制输出的 GAN,即条件生成对抗网络 (Conditional Generative Adversarial Net, GAN)。该模型最早由 MirzaOsindero2014 年提出,是对 GAN 架构的简单改进。

1. CGAN架构

在节中,我们将使用面部数据集中的头发颜色属性来设置 CGAN 的条件。也就是说,我们将能够明确指定是否要生成带有金发的图像。头发颜色标签作为 CelebA 数据集的一部分已在数据集中提供,CGAN 的架构如下图所示。

CGAN 架构

标准 GANCGAN 之间的关键区别在于:在 CGAN 中,我们需要向生成器和判别器传递与标签相关的额外信息。在生成器中,标签信息转化为独热编码 (one-hot) 向量后附加在潜空间样本之后。在判别器中,通过重复独热编码向量填充得到与输入图像相同形状的通道,将标签信息添加为 RGB 图像的额外通道。
CGAN 之所以能够生成指定类型的图像,是因为其判别器可以获得关于图像内容的额外信息,因此生成器必须确保其输出与提供的标签一致,以继续欺骗判别器。如果生成器生成了与图像标签不一致的图像,即使图像非常逼真,判别器会将它们判定为伪造图像,因为图像和标签并不匹配。
在本节所构建的 CGAN 中,因为有两个类别(金发和非金发),独热编码标签的长度是 2。但是,我们也可以根据需要拥有使用多个标签。例如,在 Fashion-MNIST 数据集上训练 CGAN 时,为了输出 10 种不同类型的 Fashion-MNIST 图像,可以通过将长度为 10 的独热编码标签向量并入生成器的输入,并将 10 个额外的独热编码标签通道并入判别器的输入。
综上,我们需要对标准 GAN 架构所进行的修改是,将标签信息与生成器和判别器的现有输入连接起来:

# 图像通道和标签通道分别传递给判别器,并进行连接
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
label_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CLASSES))
x = layers.Concatenate(axis=-1)([critic_input, label_input])
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)critic = models.Model([critic_input, label_input], critic_output)
print(critic.summary())
# 潜向量和标签类别分别传递给生成器,并在调整形状之前进行连接
generator_input = layers.Input(shape=(Z_DIM,))
label_input = layers.Input(shape=(CLASSES,))
x = layers.Concatenate(axis=-1)([generator_input, label_input])
x = layers.Reshape((1, 1, Z_DIM + CLASSES))(x)
x = layers.Conv2DTranspose(128, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
generator_output = layers.Conv2DTranspose(CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model([generator_input, label_input], generator_output)
print(generator.summary())

2. 模型训练

调整 CGANtrain_step 方法,以令生成器和判别器适应新的输入格式:

    def train_step(self, data):# 从数据集中提取图像和标签real_images, one_hot_labels = data# 将独热编码向量扩展为具有与输入图像相同空间尺寸 (64×64) 的独热编码图像image_one_hot_labels = one_hot_labels[:, None, None, :]image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=1)image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=2)batch_size = tf.shape(real_images)[0]for i in range(self.critic_steps):random_latent_vectors = tf.random.normal( shape=(batch_size, self.latent_dim))with tf.GradientTape() as tape:# 生成器接受包含两个输入的列表——随机潜向量和独热编码的标签向量fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)# 判别器接受包含两个输入的列表——真实/生成图像和独热编码的标签通道fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)real_predictions = self.critic([real_images, image_one_hot_labels], training=True)c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)c_gp = self.gradient_penalty(batch_size, real_images, fake_images, image_one_hot_labels)# 梯度惩罚函数还需要通过独热编码的标签通道传递(由于其流经判别器)c_loss = c_wass_loss + c_gp * self.gp_weightc_gradient = tape.gradient(c_loss, self.critic.trainable_variables)self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))with tf.GradientTape() as tape:# 生成器训练过程的修改与判别器训练步骤的修改相同fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)g_loss = -tf.reduce_mean(fake_predictions)gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))self.c_loss_metric.update_state(c_loss)self.c_wass_loss_metric.update_state(c_wass_loss)self.c_gp_metric.update_state(c_gp)self.g_loss_metric.update_state(g_loss)return {m.name: m.result() for m in self.metrics}

3. CGAN 分析

我们可以通过将特定的独热编码标签传递到生成器的输入中来控制 CGAN 的输出。例如,要生成一张非金发的人脸图像,我们传入向量 [1, 0];要生成一张金发的人脸图像,我们传入向量 [0, 1]
CGAN 的输出如下图所示。可以看到,在保持随机潜向量不变的情况下,只改变条件标签向量,显然 CGAN 已经学会使用标签向量来控制图像的头发颜色属性,且图像的其余部分几乎没有改变。这证明了 GAN 能够以这种方式组织潜空间中的点,使得各个特征可以相互解耦。

生成结果

如果数据集中有标签可用,即使不一定需要将生成的输出与标签相关联,将它们作为 GAN 的输入通常也可以提高生成图像的质量,我们可以把标签看作是像素输入的信息扩展。

小结

在本节中,构建了一个条件生成对抗网络 (Conditional Generative Adversarial Net, CGAN),通过将标签作为输入传递给判别器和生成器,能够生成可控类别的图像,这是由于标签为网络提供了额外的信息,以便使生成的输出与给定的标签相关联。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/230306.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用JavaScript 将数据网格绑定到 GraphQL 服务

前言 作为一名前端开发人员,GraphQL对于我们来说是令人难以置信的好用。它可以用来简化数据访问,这让我们的工作变得更加容易。 什么是 GraphQL?它是一个抽象层,位于任意数量的数据源之上,并为您提供一个简单的 API …

日期计算 C语言xdoj68

问题描述 给定一个年份y和一个整数d,问这一年的第d天是几月几日? 注意闰年的2月有29天,且满足下面条件之一的是闰年: 1) 年份是4的整数倍,而且不是100的整数倍; 2) 年份是…

激光雷达LIDAR

1. 历史 公元前440年,古希腊哲学家、预言家、科学家、江湖术士恩培多克勒提出月亮是由反射发光;提出光有速度。1638年,伽利略提着一盏灯站在山头上,默默的把灯盖了起来… 远处的另一个山头上,他的助手在看到灯灭的一瞬…

传统FC存储向NoF发展进化

全闪存时代背景下,传统的FC(Fibre Channel,网状通道)存储网络已经无法满足全闪存数据中心的要求,NVMe(Non-Volatile Memory express,非易失性内存主机控制器接口规范)存储协议的出现…

现网问题处理策略

收到问题时需要确认的问题 哪个局点 识别局点重要性使用哪些业务开启的特性数据量大小 哪个版本 找到我们产品的版本以及上游组件/底座版本 谁找来的 一线同事。一线直接面对客户,压力会比较大,需要严肃对待。下游业务同事。压力相对会较小。 最近…

12.18拓扑排序,DAG,模板,课程安排

拓扑排序 有向无环图一定是拓扑序列,有向有环图一定不是拓扑序列。 无向图没有拓扑序列。 首先我们先来解释一下什么是有向无环图: 有向就是我们两个结点之间的边是有方向的,无环的意思就是整个序列中没有几个结点通过边形成一个圆环。 下图就是一个…

【web安全】万能密码总结

前言 菜某的总结,欢迎提意见补充~ 万能密码的原理 万能密码实际上也算是sql注入的一种。 登录界面是一个与数据库交互的位置,很容易产生sql注入的位置。 我们登录时输入的数据会带入数据库查询进行比对,当用户名与用户的密码对的上的话&…

个人版 AI 辅助系统的尝试

在 CSDN 的时候,我就一直想要有自己的 AI 工作环境。我们组只有一台高配的办公服务器,用于训练模型,分析数据。通常来说这台机器都很忙。如果想要 做一些研究工作或试验,资源就有点紧张了。而我自己的工作机,虽然是一台…

基于Vue的汽车服务商城系统设计与实现论文

摘 要 本课题是根据用户的需要以及网络的优势建立的一个基于Vue的汽车服务商城系统,来更好的为用户提供服务。 本基于Vue的汽车服务商城系统应用Java技术,MYSQL数据库存储数据,基于SSMVue框架开发。在网站的整个开发过程中,首先对…

clipboard.js实现复制和粘贴

// 复制文本到剪贴板 function copyToClipboard(text) {navigator.clipboard.writeText(text).then(() > {console.log(Text copied to clipboard);}).catch((error) > {console.error(Failed to copy text:, error);}); }// 从剪贴板粘贴文本 function pasteFromClipboa…

linux网络管理_网络接口名称规则

11.1 网络接口名称规则 11.1.1 简介 目标:认识网卡》》找到网卡文件》》学会修改文件》》多台服务器互通 网络接口名称 ​ 传统上,Linux中的网络接口被枚举为eth0 (ethernet0)、eth1、eth2等,然而使用这些网络设备名可能遇到不确定性,且不…

面试算法56:二叉搜索树中两个节点的值之和

题目 给定一棵二叉搜索树和一个值k,请判断该二叉搜索树中是否存在值之和等于k的两个节点。假设二叉搜索树中节点的值均唯一。例如,在如图8.12所示的二叉搜索树中,存在值之和等于12的两个节点(节点5和节点7)&#xff0…

WebSocket网络协议

一、简介 WebSocket 是一种在客户端和服务器之间建立双向通信信道的网络协议。它在客户端和服务器之间建立一个持久的、全双工的连接,允许数据在两个方向上实时传输,而不需要像HTTP一样进行多次请求和响应。 WebSocket 的主要优势是减少了服务器和客户…

Redis发布与订阅

什么是发布与订阅 答: redis发布订阅是一种消息通信通信模式,由发送者(pub)发送消息,订阅者(sub)接收消息。 如下图client2、4、5就是订阅着,订阅了channel1的消息。 当channel1要发送消息时,这几个订阅者都会实时收到消息。 发布订阅的方式…

C++ STL泛型算法

泛型算法 <algorithm>定义了大约 80 个标准算法。 它们操作由一对迭代器定义的&#xff08;输入&#xff09;序列或单一迭代器定义的&#xff08;输出&#xff09;序列。 当对两个序列进行拷贝、比较操作时&#xff0c;第一个序列由一对迭代器[b,e)表示&#xff0c;但第…

移动零算法(leetcode第283题)

题目描述&#xff1a; 给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元素的相对顺序。请注意 &#xff0c;必须在不复制数组的情况下原地对数组进行操作。示例 1:输入: nums [0,1,0,3,12] 输出: [1,3,12,0,0] 示例 2:输入: n…

用uniapp写一个点击左侧可以滑动的menu

完成后的图片&#xff08;点击左侧右边或滑动&#xff0c;滑动右边左侧的选中也会变化&#xff09;&#xff1a; 数据js &#xff08;classifyData&#xff09;&#xff1a; export default[{"name": "女装","foods": [{"name": &q…

消息幂等:如何保证消息不被重复消费?

应用的幂等是在分布式系统设计时必须要考虑的一个方面&#xff0c;如果对幂等没有额外的考虑&#xff0c;那么在消息失败重新投递&#xff0c;或者远程服务重试时&#xff0c;可能会出现许多诡异的问题。本文一起来看一下&#xff0c;在消息队列应用中&#xff0c;如何处理因为…

命名之美:探索Java的标识符与命名规范

目录 ​编辑 前言 一、Java关键字&#xff1a; class&#xff1a; public、private、protected&#xff1a; static&#xff1a; final&#xff1a; void&#xff1a; int、double、char、boolean&#xff1a; if、else、switch&#xff1a; for、while、do&#xf…

01到底应该怎么理解“平均负载”

1、如何了解系统的负载情况&#xff1f; 每次发现系统变慢时&#xff0c; 我们通常做的第⼀件事&#xff0c; 就是执⾏top或者uptime命令&#xff0c; 来了解系统的负载情况。 ⽐如像下⾯这样&#xff0c; 我在命令⾏⾥输⼊了uptime命令&#xff0c; 系统也随即给出了结果。 …