生成对抗网络DCGAN学习实践

在AI内容生成领域,有三种常见的AI模型技术:GAN、VAE、Diffusion。其中,Diffusion是较新的技术,相关资料较为稀缺。VAE通常更多用于压缩任务,而GAN由于其问世较早,相关的开源项目和科普文章也更加全面,适合入门学习。

博主从入门和学习角度用Tensorflow跑通了DCGAN,本文对其进行记录以及分享。

1.简介

GAN(Generative Adversarial Network)是一种用于生成模型的机器学习框架。其原理基于两个主要组件:生成器(Generator)和判别器(Discriminator),二者通过对抗学习的方式相互竞争和提升。

从2014年左右发展至今,GAN目前有很多分支:

  • GAN 朴素GAN,最原始版本
  • DCGAN 卷积神经网络GAN
  • CGAN 条件GAN,训练时传入额外条件,例如通过不同的mask区域生成不同内容,可控制的生成
  • SeqGAN 使用GAN生成某些风格的句子,但不能进行对答
  • Cycle GAN 可实现图像风格迁移,其实现略复杂
  • 省略

2.原理介绍

先来看图

梯度
判别
G
LeakyReLU
tanh
InputNoise
FullConnectLayer123
OutputImage
D
LeakyReLU
Sigmoid
InputImage
FullConnectLayer12
OutputOneValue

生成器(Generator)和判别器(Discriminator)是GAN的两个主要模型,生成器在上图中用缩写G表示,判别器用缩写D表示。
生成器G输入[N]的一维噪声,即InputNoise。输出[W * H * RGB](大致类似)的张量
判别器D输入一张图像,输出[1]的张量,即一个浮点数,通过0-1的值得到图像是真还是假

判别器需要尽可能的认出造假图片,生成器需要尽可能的骗过判别器,两者会在这2个目标上不断的通过反向传播进行学习,从而达到生成器和判别器的纳什均衡,最终输出质量很高的生成图像。

2.2 重点1

在训练中,判别器返回一个0-1区间的浮点数(如[0]=0.63,[0]=0.21)作为判断结果,值越高也越认为是真实图片。由于判别器也是一个神经网络模型,因此可以将输出层的梯度一直传递回输入层,然后将输入层的梯度作为生成器的梯度继续反向传播,从而完成一次训练。

然而,很多文章并没有提到这一点。如果没有接触过这种多模型梯度传递训练方法,可能会认为使用一个数学方法或者计算机视觉方法来构建判别器也可以让整个模型正常运行。但事实上,这种方法是不可行的(通常情况下)。

2.3 重点2

使用更多的层可以增强模型的推理能力。例如,在训练过程中,如果模型生成出眉毛 A 的特征,则有鼻子 B、C 和 D 相关的备选项;而如果生成出眉毛 E 的特征,则有鼻子 F 和 G 相关的备选项。

这也是为什么生成器需要使用三个隐层的原因(博主的观点)。通过增加隐层的数量,模型可以捕捉到更多的特征和抽象概念,从而提高生成器的表现能力和推理能力。更深层次的网络结构能够帮助模型学习更复杂的模式和关联,使其在生成结果时更加准确和多样化。

上图生成器部分的激活函数用的是LeakyReLU,实际上就单隐层神经网络来说,ReLU要比Sigmoid能多解决很多类型问题,Sigmoid更适合分类问题,遇到一些奇怪的问题不容易收敛,而LeakyReLU激活函数即和ReLU逻辑一样也可以返回负数信息,这是博主觉得采用这个激活函数的原因。
而至于tanH和Sigmoid的比较,它们在某种程度上相似。一般来说,网上普遍认为tanH比Sigmoid更好,主要原因是它具有较窄的数值边界范围。

2.4 重点3

对于2套样本比较损失这类问题,一般使用二分类交叉熵,这不同于分类问题。
而二分类交叉熵又是在只有2种结果(r和1-r),的情况下对公式进行的简化:
https://blog.csdn.net/grayrail/article/details/131619144

2.5 模式崩溃

训练时还会出现一种情况,即生成器始终卡在一个生成结果上,比如生成0-9数字,结果训练几轮后始终在生成数字3。
这种情况称为模式崩溃,一般增加训练样本数量并调节参数,没有比较好的办法。

3.实践准备

python库下载使用国内镜像源:
https://zhuanlan.zhihu.com/p/477179822

使用方式:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyspider

github库下载耽误时间,可以缓存到gitee:
在这里插入图片描述

而gitee也有自己缓存好的镜像库,可以先去这里查:
https://gitcode.net/mirrors

python库查找:
https://pypi.org/

在pip中查找python库:
先 pip install pip-search 再使用命令 pip_search 搜索

4.实践

全连接神经网络版本的朴素GAN效果相对较差,而DCGAN(Deep Convolutional GAN)是卷积神经网络版本的GAN,下面以DCGAN为例使用Tensorflow进行实现:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers# 定义生成器模型
def build_generator():model = tf.keras.Sequential()model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Reshape((7, 7, 256)))assert model.output_shape == (None, 7, 7, 256)  # 注意:batch size 没有限制model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))assert model.output_shape == (None, 7, 7, 128)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))assert model.output_shape == (None, 14, 14, 64)model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))assert model.output_shape == (None, 28, 28, 1)return model# 定义判别器模型
def build_discriminator():model = tf.keras.Sequential()model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',input_shape=[28, 28, 1]))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))model.add(layers.LeakyReLU())model.add(layers.Dropout(0.3))model.add(layers.Flatten())model.add(layers.Dense(1))return model# 定义生成器和判别器
generator = build_generator()
discriminator = build_discriminator()# 定义损失函数
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)# 定义生成器和判别器的优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)def generator_loss(fake_output):return loss_fn(tf.ones_like(fake_output), fake_output)def discriminator_loss(real_output, fake_output):real_loss = loss_fn(tf.ones_like(real_output), real_output)fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)total_loss = real_loss + fake_lossreturn total_loss# 定义训练循环
@tf.function  #这个是tensorflow的装饰器,标记后可提升性能,不加此标记也可
def train_step(images):# 生成噪声向量noise = tf.random.normal([BATCH_SIZE, 100])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:# 使用生成器生成假图片generated_images = generator(noise, training=True)# 使用判别器判断真假图片real_output = discriminator(images, training=True)fake_output = discriminator(generated_images, training=True)# 计算损失函数gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)# 计算梯度并更新生成器和判别器的参数gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))def generate_and_save_images(model, epoch, test_input):predictions = model(test_input, training=False)print("predictions.shape:", predictions.shape)num_images = predictions.shape[0]rows = int(num_images ** 0.5) # 计算行数cols = num_images // rows # 计算列数fig = plt.figure(figsize=(8, 8))for i in range(num_images):plt.subplot(rows, cols, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))#plt.show()# 加载MNIST数据集
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()# 标准化数据
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5# 批量大小与训练次数
BATCH_SIZE = 256
EPOCHS = 50# 数据集切分为批次并进行训练
dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(BATCH_SIZE)for epoch in range(EPOCHS):for i,image_batch in enumerate(dataset):print("sub i",i)train_step(image_batch)print("------------------------------------------------------epoch:", epoch)# 每个 epoch 结束后生成并保存一组图像if (epoch + 1) % 5 == 0:seed = tf.random.normal([BATCH_SIZE, 100])generate_and_save_images(generator, epoch + 1, seed)

跑一阵子MNIST数据集后,结果如下:
在这里插入图片描述


参考:

论文精读: https://www.bilibili.com/video/BV1rb4y187vD

同济子豪兄精读版本: https://www.bilibili.com/video/BV1oi4y1m7np

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】Gradient Descent

Gradient Descent for Linear Regression 1、梯度下降2、梯度下降算法的实现(1) 计算梯度(2) 梯度下降(3) 梯度下降的cost与迭代次数(4) 预测 3、绘图4、学习率 首先导入所需的库: import math, copy import numpy as np import matplotlib.pyplot as plt plt.styl…

Devops系统中jira平台迁移

需求:把aws中的devops系统迁移到华为云中,其中主要是jira系统中的数据迁移,主要方法为在华为云中建立一套 与aws相同的devops平台,再把数据库和文件系统中的数据迁移,最后进行测试。 主要涉及到的服务集群CCE、数据库mysql、弹性文件服务SFS、数据复制DRS、弹性负载均衡ELB。 迁…

问道管理:补仓什么意思?怎么补仓可以降低成本?

补仓这个术语我们在理财出资中经常听到,例如基金补仓,股票补仓。那么,补仓什么意思?怎样补仓能够降低成本?问道管理为我们预备了相关内容,以供参阅。 补仓什么意思? 股票补仓是指出资者在某一只…

Debian 12.1 “书虫 “发布,包含 89 个错误修复和 26 个安全更新

导读Debian 项目今天宣布,作为最新 Debian GNU/Linux 12 “书虫 “操作系统系列的首个 ISO 更新,Debian 12.1 正式发布并全面上市。 Debian 12.1 是在 Debian GNU/Linux 12 “书虫 “发布六周后推出的,目的是为那些希望在新硬件上部署操作系统…

Vivado进行自定义IP封装

一. 简介 本篇文章将介绍如何使用Vivado来对上篇文章(FPGA驱动SPI屏幕)中的代码进行一个IP封装,Vivado自带的IP核应该都使用过,非常方便。 这里将其封装成IP核的目的主要是为了后续项目的调用,否则当我新建一个项目的时候,我需要将…

VirtualBox Ubuntu无法安装增强功能以及无法复制粘贴踩坑记录

在VirtualBox安装增强功能想要和主机双向复制粘贴,中间查了很多资料,终于是弄好了。记录一下过程,可能对后来人也有帮助,我把我参考的几篇主要的博客都贴上来了,如果觉得我哪里讲得不清楚的,可以去对应的博…

Shell脚本学习-Shell函数

函数的作用就是将程序里多次被调用的相同代码组合起来(函数体),并为其取一个名字,即函数名。其他所有想重复调用这部分代码的地方都只需要调用这个名字就可以了。当需要修改这部分代码时候,只需要修改函数体内的这部分…

【简单认识GFS分布式文件系统】

文章目录 一.GlusterFS 概述1.GlusterFS简介2.特点3.GlusterFS 术语4.模块化堆栈式架构5.GlusterFS 的工作流程6.GlusterFS的卷类型1、**分布式卷(Distribute volume)**2、条带卷(Stripe volume)3、复制卷(Replica vol…

Web后端基本设计思想

JavaWeb应用的后端一般基于MVC和三层架构思想实现。 MVC是一种设计模式,用于开发用户界面和交互式应用程序。M即Model,业务模型,负责处理应用程序的业务逻辑和数据;V即View,视图,负责给用户展示界面和数据&…

快速创建vue3+vite+ts项目

安装nodejs 创建项目 npm init vitelatest 默认之后回车 选择项目名字my-vue-project 选择vue框架 选择ts 运行项目 cd my-vue-project npm install --registryhttps://registry.npm.taobao.org npm run dev

Vue2 第十二节 Vue组件化编程(一)

1.模块与组件,模块化与组件化概念 2. 非单文件组件 3. 组件编写注意事项 4. 组件的嵌套 一. 模块与组件,模块化与组件化 传统方式编写存在的问题 (1)依赖关系混乱,不好维护 (2)代码的复用…

炒股杠杆途乐证券;股票买入卖出时间规则?

股票买入卖出时刻规则是指出资者在股票商场上进行生意交易时需求遵循的一系列时刻规定。正确的买入和卖出时刻能够协助出资者最大化出资回报,一起降低风险。但是,在股票商场上,生意时刻的挑选是一个复杂的问题,需求从多个角度剖析…

vSphere ESXI 7.0 网络规划

ESXi 网络 业务网络、Vmotion(漂移)、管理网络、存储网络 ESXi 管理网络 vCenter Server 管理网络 vCenter Server SSO域名 Single Sign-on域名:在没有指定的情况下,默认填写 vsphere.local VMware vSphere整体解决方案和网络…

汽车行业案例 | 联合汽车电子全新质量问题管理平台上线,燕千云助力汽车电子领军者实现数字化质量管理

据权威调研机构显示,2022年中国智能电动汽车的销量已占新能源汽车的52%以上。到2025年,在新能源汽车50%的汽车出行市场渗透率的基础上,智能电动汽车的销量将超1220万辆,占新能源汽车的80.1%。在技术进步和产业变革快速推进的背景下…

git常用指令

git add命令 作用:移动文件:工作区-->暂存区 git add .:把所有文件都放到暂存区 git commit命令 作用:移动文件:暂存区-->本地仓库 git status命令 作用:查看修改状态 git log命令 作用&#xf…

嵌入式软件开发有没有捷径

嵌入式软件开发有没有什么捷径?不定期会收到类似的问题,我只想说:嵌入式软件开发没有捷径 说实话,有这种想法的人,我其实想劝你放弃。对于绝大多数普通人,一步一个脚印就是捷径。 当然,这个问题…

VLAN介绍

目录 VLAN的特点: VLAN的好处: VLAN的实现原理 VLAN标签 VLAN的划分方式 接口划分VLAN--接口类型 Access接口 Trunk接口 Hybrid接口 实现VLAN之间通信 使用路由器物理接口 使用子接口 使用三层交换机的VLANIF接口 配置 VLANIF的转发流程 三层交换机参与下的三层…

IDEA偶尔编译的时候不识别lombok

偶尔IDEA启动项目的时候会识别不到lombok,识别不到get()跟set()方法 方案 在settings添加下面代码 -Djps.track.ap.dependenciesfalse

dialog => :before-close的属性应用

在element-ui里面关闭弹窗的时候before-close会触发。 也就是点击X的时候回触发before-close这个属性, 代码实例: <el-dialogtitle"新增用户":visible.sync"dialogVisible"width"50%":before-close"handleClose"> handleClose…

linux学习笔记(2)----汇编LED灯实验

MX6ULL 的 IO IO的复用功能 这里的只使用了低五位&#xff0c;用来配置io口&#xff0c;其中bit0~bit3(MUX_MODE)就是设置 GPIO1_IO00 的复用功能的&#xff0c;GPIO1_IO00 一共可以复用为 9种功能 IO&#xff0c;分别对应 ALT0~ALT8。每种对应了不同的功能 io的属性配置 HY…