深度学习--对抗生成网络(GAN, Generative Adversarial Network)

对抗生成网络(GAN, Generative Adversarial Network)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN主要用于生成数据,通过两个神经网络相互对抗,来生成以假乱真的新数据。以下是对GAN的详细阐述,包括其概念、作用、核心要点、实现过程、代码实现和适用场景。

1. 概念

GAN由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。

  • 生成器负责生成伪造的样本数据,它的目标是生成足够真实的数据,使判别器难以区分。
  • 判别器负责区分数据是真实的(来自训练数据集)还是生成的(来自生成器)。

这两个网络通过博弈的方式相互对抗:

  • 生成器尝试欺骗判别器,生成与真实数据无差别的虚假数据;
  • 判别器试图提高辨别能力,正确区分真假数据。

最终的目标是使生成器生成的数据越来越接近于真实数据,直至判别器无法区分两者。

2. 作用

GAN的主要作用是生成新数据,常用于图像生成、数据增强、艺术创作等领域。它的优势在于无需明确的监督信号,仅通过数据分布的隐含特征进行学习和生成。

具体应用包括:

  • 图像生成:例如生成逼真的人脸、风景等图像。
  • 数据增强:扩充小样本数据集,改善模型训练效果。
  • 超分辨率重建:将低分辨率图像生成高分辨率图像。
  • 风格转换:将一种图像风格转换为另一种,例如将照片转化为绘画风格。
  • 生成虚拟数据:例如医学影像、合成声音、文本等。

3. 核心要点

GAN的核心在于生成器和判别器的相互博弈,这种机制使模型能够自我优化,但同时也存在一些关键挑战和要点:

  • 损失函数:GAN的损失函数是基于极小极大博弈的。生成器的目标是最大化判别器的损失,即让判别器判断出错;而判别器的目标是最小化这个损失,使其能够更好地区分真假数据。

    通常使用交叉熵损失(Binary Cross-Entropy)来优化生成器和判别器:

  • 模式崩溃:生成器有时会陷入生成某些特定模式的数据(称为模式崩溃),即生成器输出的多样性不足,难以生成多样的真实数据。为了解决这一问题,改进的GAN模型(如WGAN)引入了不同的损失函数和训练策略。

  • 平衡训练:生成器和判别器的训练需要保持平衡,过强的判别器会导致生成器无法学习,而过强的生成器又会让判别器失效。训练GAN时,需要小心调节它们的训练速率。

  • 网络架构:生成器和判别器的网络结构设计非常重要,通常使用深度卷积神经网络(DCNN)进行构建,尤其在图像生成任务中,DCGAN(Deep Convolutional GAN)表现优异。

4. 实现过程

GAN的实现过程包括以下几个步骤:

  1. 数据准备:选择训练数据集,例如图像或其他类型的数据集,通常需要大量真实样本。

  2. 生成噪声:生成器的输入是随机噪声,一般从高维的均匀分布或正态分布中采样。

  3. 构建生成器网络:生成器将噪声数据映射为真实数据的空间,通过深度神经网络进行逐层生成,最终输出一个逼真的样本。

  4. 构建判别器网络:判别器是一个二分类网络,输入为真实数据或生成器生成的数据,输出为其判断的概率值(0-1之间,表示真假)。

  5. 训练:采用交替训练方式,先固定生成器,训练判别器;再固定判别器,训练生成器。这个过程不断循环,生成器和判别器相互竞争,直至生成器的生成能力足以欺骗判别器。

  6. 模型评估:训练过程中,使用对抗损失或其他指标来评估生成器和判别器的效果。视觉上,生成的图像逐渐从粗糙变得逼真。

5.GAN的代码实现

下面是一个简单的GAN实现,用于生成与MNIST数据集类似的手写数字图像。

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist

# 设置随机种子,便于复现
np.random.seed(1000)
tf.random.set_seed(1000)

# 超参数设置
latent_dim = 100  # 生成器输入的噪声维度
batch_size = 128
epochs = 10000
save_interval = 1000

# 1. 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = (x_train - 127.5) / 127.5  # 将图像归一化到[-1, 1]
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 重塑为28x28x1的图像

# 2. 创建生成器模型
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model

# 3. 创建判别器模型
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1, activation='sigmoid'))  # 输出0或1,判断真伪
    return model

# 4. 编译生成器和判别器
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# 5. 创建并编译GAN模型
discriminator.trainable = False  # 固定判别器,训练时只训练生成器
gan_input = tf.keras.Input(shape=(latent_dim,))
generated_image = generator(gan_input)
validity = discriminator(generated_image)

gan = tf.keras.Model(gan_input, validity)
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

# 6. 训练GAN
def train(epochs, batch_size=128, save_interval=100):
    half_batch = int(batch_size / 2)

    for epoch in range(epochs):
        # 训练判别器
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        real_images = x_train[idx]

        noise = np.random.normal(0, 1, (half_batch, latent_dim))
        generated_images = generator.predict(noise)

        real_labels = np.ones((half_batch, 1))
        fake_labels = np.zeros((half_batch, 1))

        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        valid_labels = np.ones((batch_size, 1))

        g_loss = gan.train_on_batch(noise, valid_labels)

        # 每隔save_interval保存并展示一次结果
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")
            save_images(epoch)

# 7. 生成并保存图像
def save_images(epoch):
    noise = np.random.normal(0, 1, (25, latent_dim))
    gen_images = generator.predict(noise)
    gen_images = 0.5 * gen_images + 0.5  # 缩放回[0, 1]区间

    fig, axs = plt.subplots(5, 5)
    cnt = 0
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()

# 开始训练
train(epochs=epochs, batch_size=batch_size, save_interval=save_interval)

6. 适用场景

GAN适用于许多生成任务,特别是那些需要从数据中提取复杂模式的任务:

  • 图像生成与修复:GAN可用于生成逼真的图像,修复图像中的缺失部分。
  • 数据增强:在数据稀缺的场景下,GAN可以生成类似于训练数据的样本,帮助改进模型的泛化能力。
  • 超分辨率图像重建:通过生成细节清晰的高分辨率图像,应用于图像处理、视频质量提升等场景。
  • 风格迁移:通过GAN实现不同风格的图像、视频转换,例如将照片转为艺术风格画。
  • 医学影像生成:GAN可以生成医学图像,例如CT扫描、MRI数据等,辅助疾病检测与诊断。
  • 文本到图像生成:通过输入文本描述,GAN可以生成与描述相匹配的图像,应用于自动图像生成等场景。

总结

对抗生成网络(GAN)是近年来在生成式模型领域的重要突破,通过生成器与判别器的对抗博弈,GAN能够生成高度逼真的数据。其应用范围广泛,涵盖了图像生成、数据增强、超分辨率重建、风格迁移等多个领域。然而,GAN的训练过程具有挑战性,特别是在平衡两者的对抗关系上仍然存在技术难题。随着技术的不断发展,GAN在生成数据、创造内容等方面的应用前景将更加广阔。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/878875.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如果电脑一直提示微软账号登录……

前言 今天小白接了个电脑故障问题:电脑提示微软账号登录,然后经过各种操作…… 电脑重启之后就变成了这样: 按理说,登录了微软账号之后,Windows系统要进入到桌面就必须有一个输入密码验证的过程,但这个界…

Harmony OS DevEco Studio 如何导入第三方库(以lottie为例)?-- HarmonyOS自学2

在做鸿蒙开发时,离不开第三方库的引入 一.有哪些支持的Harmony OS的 第三方库? 第三方库下载地址: 1 tpc_resource: 三方组件资源汇总 2 OpenHarmony三方库中心仓 二. 如何加入到DevEco Studio工程 以 lottie为例 OpenHarmony-TPC/lot…

nginx 新建一个 PC web 站点

注意:进行实例之前必须完成nginx的源码编译。(阅读往期文章完成步骤) 1.编辑nginx的配置文件,修改内容 [rootlocalhost ~]# vim /usr/local/nginx/conf/nginx.conf 2.创建新目录/usr/local/nginx/conf.d/,编辑新文件…

基于飞腾平台的Hive的安装配置

【写在前面】 飞腾开发者平台是基于飞腾自身强大的技术基础和开放能力,聚合行业内优秀资源而打造的。该平台覆盖了操作系统、算法、数据库、安全、平台工具、虚拟化、存储、网络、固件等多个前沿技术领域,包含了应用使能套件、软件仓库、软件支持、软件适…

通信工程学习:什么是PCM脉冲编码调制、DPCM差分脉冲编码调制、ADPCM自适应差分脉冲编码调制

PCM脉冲编码调制、DPCM差分脉冲编码调制、ADPCM自适应差分脉冲编码调制 PCM、DPCM、ADPCM是音频编码技术中的三种重要方式,它们在音频信号的数字化、压缩和传输中起着关键作用。以下是对这三种技术的详细解释: 一、PCM(Pulse Code Modulatio…

2024数学建模国赛选题建议+团队助攻资料(已更新完毕)

目录 一、题目特点和选题建议 二、模型选择 1、评价模型 2、预测模型 3、分类模型 4、优化模型 5、统计分析模型 三、white学长团队助攻资料 1、助攻代码 2、成品论文PDF版 3、成品论文word版 9月5日晚18:00就要公布题目了,根据历年竞赛题目…

[C#学习笔记]注释

官方文档&#xff1a;Documentation comments - C# language specification | Microsoft Learn 一、常用标记总结 1.1 将文本设置为代码风格的字体&#xff1a;<c> 1.2 源代码或程序输出:<code> 1.3 异常指示:<exception> 1.4 段落 <para> 1.5 换行&…

Reflection 70B:震撼AI行业的开源模型

随着人工智能&#xff08;AI&#xff09;技术的快速发展&#xff0c;开源与闭源模型的竞争变得越来越激烈。近日&#xff0c;Reflection 70B模型的发布在AI行业引发了巨大的震动。这款拥有70亿参数的开源模型不仅在多项基准测试中取得了优异成绩&#xff0c;还在很多情况下超越…

滑动窗口系列(同向双指针)/9.7

新的解题思路 一、三数之和的多种可能 给定一个整数数组 arr &#xff0c;以及一个整数 target 作为目标值&#xff0c;返回满足 i < j < k 且 arr[i] arr[j] arr[k] target 的元组 i, j, k 的数量。 由于结果会非常大&#xff0c;请返回 109 7 的模。 输入&…

【阿里云】个人认证与公司认证

个人认证和企业认证的区别 更新时间&#xff1a;2024-05-20 09:32:52 本文档主要介绍个人认证账号和企业认证账号的区别。 账号实名认证分为个人实名认证和企业实名认证。 个人账号认证&#xff0c;请选择认证类型为 个人&#xff0c;支持个人支付宝授权认证和个人扫脸认证。…

使用cage工具包生成验证码

目录 1. 导入依赖2. 控制类3. 测试 1. 导入依赖 <!-- 验证码工具 --><dependency><groupId>com.github.cage</groupId><artifactId>cage</artifactId><version>1.0</version></dependency>2. 控制类 RestControl…

探索 RAD:5 个最佳实践案例解析

天下武功&#xff0c;唯快不破&#xff01;应用开发&#xff0c;唯速称王&#xff01; 在当今快速发展的科技环境中&#xff0c;企业面临的挑战不断升级。传统的应用开发方法往往因其复杂的流程和较长的开发周期而无法满足快速变化的市场需求。在这种背景下&#xff0c;快速应…

Mybatis【分页插件,缓存,一级缓存,二级缓存,常见缓存面试题】

文章目录 MyBatis缓存分页延迟加载和立即加载什么是立即加载&#xff1f;什么是延迟加载&#xff1f;延迟加载/懒加载的配置 缓存什么是缓存&#xff1f;缓存的术语什么是MyBatis 缓存&#xff1f;缓存的适用性缓存的分类一级缓存引入案例一级缓存的配置一级缓存的工作流程一级…

【JavaSE基础】Java 基础知识

Java 转义字符 Java 常用的转义字符 在控制台&#xff0c;输入 tab 键&#xff0c;可以实现命令补全 转义字符含义作用\t制表符一个制表位&#xff0c;实现对齐的功能\n &#xff1a;换行符\n换行符一个换行符\r回车符一个回车键 System.out.println(“韩顺平教育\r 北京”);&…

java实现,PDF转换为TIF

目录 ■JDK版本 ■java代码・实现效果 ■POM引用 ■之前TIF相关的问题&#xff08;两张TIF合并&#xff09; ■对于成果物TIF&#xff0c;需要考虑的点 ■问题 ■问题1&#xff1a;无法生成TIF&#xff0c;已解决 ■问题2&#xff1a;生成的TIF过大&#xff0c;已解决 …

MySQL之DQL-分组函数

1、分组函数 1. 分组函数语法 分组函数也叫聚合函数。是对表中一组记录进行操作&#xff0c;每组只返回一个结果。我们只讲如下5个常用的分组函数&#xff1a; 分组函数 含义 MAX 求最大值 MIN 求最小值 SUM 求和 AVG 求平均值 COUNT 求个数 分组函数的语法如下…

Java中的强引用、软引用、弱引用和虚引用于JVM的垃圾回收机制

参考资料 https://juejin.cn/post/7123853933801373733 在 Java 中&#xff0c;引用类型分为四种&#xff1a;强引用&#xff08;Strong Reference&#xff09;、软引用&#xff08;Soft Reference&#xff09;、弱引用&#xff08;Weak Reference&#xff09;和虚引用&#xf…

水晶连连看 - 无限版软件操作说明书

水晶连连看 – 无限版游戏软件使用说明书 文章目录 水晶连连看 – 无限版游戏软件使用说明书1 引言1.1 编写目的1.2 项目名称1.3 项目背景1.4 项目开发环境 2 概述2.1 目标2.2 功能2.3 性能 3 运行环境3.1 硬件3.2 软件 4 使用说明4.1 游戏开始界面4.2 游戏设定4.2.1 游戏帮助4…

Android 15 正式发布到 AOSP ,来了解下新特性和适配需求

其实在年初的时候就整理过《2024 &#xff0c;Android 15 预览版来了》 和《提前窥探 Android 15 的新功能与适配》的相关内容&#xff0c;而随着时间进度推进&#xff0c;近日谷歌也正式发布了 Android 15 的正式版&#xff0c;虽然没什么「大亮点」&#xff0c;但是作为开发者…

11.2.软件系统分析与设计-数据库分析与设计

数据库分析与设计 数据库分析与设计的步骤 ER图和关系模型