程序员学长 | 快速学习一个算法,GAN

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学习一个算法,GAN

GAN 如何工作?

GAN 由两个部分组成:生成器(Generator)和判别器(Discriminator)。这两个部分通过一种对抗的过程来相互改进和优化。

c03159492fff41b799068bb5cf88bd59.png

生成器(Generator)

生成器的任务是接收一个随机噪声向量,并将其转换为看起来尽可能真实的数据。它通常使用一个深度神经网络来实现,从随机噪声中生成类似于训练数据的样本。

生成器的目标是生成的样本能够骗过判别器,使判别器认为生成的数据是真实的。

判别器(Discriminator)

判别器是一个二分类模型,它的任务是区分真实数据和生成器生成的假数据。

判别器接收一个数据样本,并输出一个概率值,表示该样本是真实数据的概率。

判别器的目标是尽可能准确地将真实数据与生成数据区分开来。

对抗过程

GAN 的训练过程可以被看作是生成器和判别器之间的博弈。具体来说,生成器试图生成逼真的数据以欺骗判别器,而判别器则试图更好地识别生成的假数据

这个过程可以描述为一个最小最大化的优化问题。

4c8b68b8b9184d9092ffda2498ba9173.png

鉴别器 D 想要最大化目标函数,使得 D(x) 接近于 1,D(G(z)) 接近于 0。这意味着鉴别器应该将训练集中的所有图像识别为真实 (1),将所有生成的图像识别为假 (0)。

生成器 (G) 想要最小化目标函数,使得 D(G(z)) 为 1。这意味着生成器试图生成被鉴别器网络分类为 1 的图像。

训练步骤

  1. 初始化生成器和判别器的参数

  2. 判别器训练

    • 从真实数据集中采样一个 mini-batch 的真实数据。

    • 从生成器的噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新判别器的参数,使其能够更好地区分真实数据和生成数据。

      f6487b7759c24cf9aa3570065e6f957b.png

  3. 生成器训练

    861dc76c582746c281d4d47c3e1156c0.png
    • 从噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新生成器的参数,使其生成的数据更能够欺骗判别器。

  4. 重复上述过程,直到生成器生成的数据足够逼真,判别器无法准确区分真实数据和生成数据

GAN 在图像生成、图像修复、超分辨率、风格迁移、数据增强等领域有广泛应用。例如,通过 GAN,可以生成高分辨率的图像,将低分辨率图像转换为高分辨率图像,或者将某种风格的图像转换为另一种风格的图像。

案例分享

下面是使用 GAN 来生成图像的案例。这里我们以手写数字识别数据集为例进行说明。

1.读取数据集

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
import matplotlib.pyplot as plt
import sys
import numpy as npnum_rows = 28
num_cols = 28
num_channels = 1
input_shape = (num_rows, num_cols, num_channels)
z_size = 100
batch_size = 32
(train_ims, _), (_, _) = mnist.load_data()
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

2.定义生成器

生成器 (D) 在 GAN 中扮演着至关重要的角色,因为它负责生成能够欺骗鉴别器的真实图像。

它是 GAN 中图像形成的主要组件。

在本文中,我们为生成器使用了一种特定的架构,该架构包含一个完全连接 (FC) 层并采用 Leaky ReLU 激活。然而,值得注意的是,生成器的最后一层使用 TanH 激活而不是 LeakyReLU。

def build_generator():gen_model = Sequential()gen_model.add(Dense(256, input_dim=z_size))gen_model.add(LeakyReLU(alpha=0.2))gen_model.add(BatchNormalization(momentum=0.8))gen_model.add(Dense(512))gen_model.add(LeakyReLU(alpha=0.2))gen_model.add(BatchNormalization(momentum=0.8))gen_model.add(Dense(1024))gen_model.add(LeakyReLU(alpha=0.2))gen_model.add(BatchNormalization(momentum=0.8))gen_model.add(Dense(np.prod(input_shape), activation='tanh'))gen_model.add(Reshape(input_shape))gen_noise = Input(shape=(z_size,))gen_img = gen_model(gen_noise)return Model(gen_noise, gen_img)

3.定义鉴别器

在生成对抗网络 (GAN) 中,鉴别器 (D) 通过评估真实性和可能性来执行区分真实图像和生成图像的关键任务。

此组件可以看作是一个二元分类问题。

为了解决此任务,我们可以采用一个简化的网络架构,该架构由全连接层 (FC)、Leaky ReLU 激活和 Dropout 层组成。值得一提的是,鉴别器的最后一层包括 FC 层,后跟 Sigmoid 激活。Sigmoid 激活函数产生所需的分类概率。

def build_discriminator():disc_model = Sequential()disc_model.add(Flatten(input_shape=input_shape))disc_model.add(Dense(512))disc_model.add(LeakyReLU(alpha=0.2))disc_model.add(Dense(256))disc_model.add(LeakyReLU(alpha=0.2))disc_model.add(Dense(1, activation='sigmoid'))disc_img = Input(shape=input_shape)validity = disc_model(disc_img)return Model(disc_img, validity)

4.计算损失函数

我们可以使用二元交叉熵损失来实现生成器和鉴别器。

# discriminator
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',optimizer='sgd',metrics=['accuracy'])z = Input(shape=(z_size,))# generator
img = generator(z)disc.trainable = Falsevalidity = disc(img)# combined model
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')

5.优化损失

def intialize_model():disc= build_discriminator()disc.compile(loss='binary_crossentropy',optimizer='sgd',metrics=['accuracy'])generator = build_generator()z = Input(shape=(z_size,))img = generator(z)disc.trainable = Falsevalidity = disc(img)combined = Model(z, validity)combined.compile(loss='binary_crossentropy', optimizer='sgd')return disc, generator, combined

6. 模型训练

def train(epochs, batch_size=128, sample_interval=50):# load images(train_ims, _), (_, _) = mnist.load_data()# preprocesstrain_ims = train_ims / 127.5 - 1.train_ims = np.expand_dims(train_ims, axis=3)valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# training loopfor epoch in range(epochs):batch_index = np.random.randint(0, train_ims.shape[0], batch_size)imgs = train_ims[batch_index]# create noisenoise = np.random.normal(0, 1, (batch_size, z_size))# predict using a Generatorgen_imgs = gen.predict(noise)# calculate loss functionsreal_disc_loss = disc.train_on_batch(imgs, valid)fake_disc_loss = disc.train_on_batch(gen_imgs, fake)disc_loss_total = 0.5 * np.add(real_disc_loss, fake_disc_loss)noise = np.random.normal(0, 1, (batch_size, z_size))g_loss = full_model.train_on_batch(noise, valid)# save outputs every few epochsif epoch % sample_interval == 0:one_batch(epoch)

7.生成手写数字

使用 MNIST 数据集,我们可以创建一个实用函数,使生成器为一组图像生成预测。

此函数生成随机声音,将其提供给生成器,运行它以显示生成的图像并将其保存在特殊文件夹中。建议定期运行此实用函数,例如每 200 个周期运行一次,以监控网络进度。

def one_batch(epoch):r, c = 5, 5noise_model = np.random.normal(0, 1, (r * c, z_size))gen_images = gen.predict(noise_model)# Rescale images 0 - 1gen_images = gen_images*(0.5) + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_images[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()

在我们的实验中,我们使用 32 的批次大小对 GAN 进行了大约 10,000 个时期的训练。

为了跟踪训练的进度,我们每 200 个时期保存一次生成的图像,并将它们存储在名为 “images” 的指定文件夹中。

disc, gen, full_model = intialize_model()
train(epochs=10000, batch_size=32, sample_interval=200)

现在,我们来检查一下不同阶段的 GAN 模拟结果。

初始化、5000 个 epoch 以及 10000 个 epoch 的最终结果。

最初,我们以随机噪声作为生成器的输入。

2e954b9c4c274f3e98d6ff8c9ba2be87.png

经过 5000 个时期的训练后,我们可以观察到生成的图形开始类似于 MNIST 数据集。

1a0dc57d0de14beebed11deee9879f68.png

经过 10,000 个时期的训练后,我们获得以下输出。

7c9ae539548c463db3d7b6a41fe4db86.png

可以看到,这些生成的图像与手写数字数据已经非常相似了。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45073.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

day11:01文件处理

一、文件与文件模式介绍 1、什么是文件 文件是操作系统提供给用户/应用程序操作硬盘的一种虚拟的概念/接口 用户/应用程序(open()) 操作系统(文件) 计算机硬件(硬盘)2、为何要用文件 ①用户/应用程序可以通过文件将数据永久保存…

使用Apache Beam进行统一批处理与流处理

Apache Beam是一个开源的统一编程模型,用于定义和执行数据处理流水线,支持批处理和流处理。Beam旨在提供一个简单、可扩展且灵活的框架,适用于各种数据处理任务。本文将详细介绍如何使用Apache Beam进行批处理和流处理,并通过Java…

从0开始基于transformer进行股价预测(pytorch版本)

目录 数据阶段两个问题开始利用我们的代码进行切分 backbone网络训练效果 感觉还行,没有调参数。源码比较长,如果需要我后续会发(因为太长了!!) 数据阶段 !!!注意&#…

多个uilabel添加同一个UITapGestureRecognizer对象,只有最后那个生效么

如果多个 UILabel 添加同一个 UITapGestureRecognizer 对象,确实只有最后一个 UILabel 会响应手势。这是因为一个手势识别器只能被添加到一个视图上,多次添加实际上是重新指定该识别器的视图目标。 要实现多个 UILabel 响应相同的手势,可以为…

还不懂 OOM ?详解内存溢出与内存泄漏区别!

内存溢出与内存泄漏 1. 内存溢出(Out Of Memory,OOM) 概念: 内存溢出是指程序在运行过程中,尝试申请的内存超过了系统所能提供的最大内存限制,并且垃圾收集器也无法提供更多的内存,导致程序无…

# Redis 入门到精通(一)数据类型(3)

Redis 入门到精通(一)数据类型(3) 一、redis 数据类型–set 类型介绍与基本操作 1、set 类型 新的存储需求: 存储大量的数据,在查询方面提供更高的效率。需要的存储结构: 能够保存大量的数据,高效的内部…

【爬虫】解析爬取的数据

目录 一、正则表达式1、常用元字符2、量词3、Re模块4、爬取豆瓣电影 二、Xpath1、Xpath解析Ⅰ、节点选择Ⅱ、路径表达式Ⅲ、常用函数 2、爬取豆瓣电影 解析数据,除了前面的BeautifulSoup库,还有正则表达式和Xpath两种方法。 一、正则表达式 正则表达式…

C++|智能指针

目录 引入 一、智能指针的使用及原理 1.1RAII 1.2智能指针原理 1.3智能指针发展 1.3.1std::auto_ptr 1.3.2std::unique_ptr 1.3.3std::shared_ptr 二、循环引用问题及解决方法 2.1循环引用 2.2解决方法 三、删除器 四、C11和boost中智能指针的关系 引入 回顾上…

谷粒商城学习笔记-19-快速开发-逆向生成所有微服务基本CRUD代码

文章目录 一,使用逆向工程步骤梳理1,修改逆向工程的application.yml配置2,修改逆向工程的generator.properties配置3,以Debug模式启动逆向工程4,使用逆向工程生成代码5,整合生成的代码到对应的模块中 二&am…

VPS拨号服务器:独享的高效与安全

在当今互联网高速发展的时代,虚拟私人服务器(VPS)已成为许多企业和个人用户托管网站、应用程序的首选。特别是带有拨号功能的VPS服务器,以其独特的优势受到广泛关注。本文将深入探讨VPS拨号服务器的独享特性,以及它如何…

Vue 使用Audio或AudioContext播放本地音频

使用Audio 第一种 使用标签方式 <audio src"./tests.mp3" ref"audio"></audio><el-button click"audioPlay()">播放Audio</el-button>audioPlay() {this.$refs.audio.currentTime 0;this.$refs.audio.play();// this.$…

c++方法

std::transform方法 std::transform 是 C 标准库算法中的一个非常有用的函数&#xff0c;它定义在头文件 中。这个函数用于将给定范围内的每个元素按照指定的操作进行转换&#xff0c;并将转换结果存储在另一个位置&#xff08;可以是原始范围的另一个容器&#xff0c;或者完全…

HarmonyOS应用开发前景及使用工具

HarmonyOS应用开发001 文章目录 前言前景一、技术特性二、使用工具1.项目目录结构 前言 学习之前&#xff0c;需要有一定的开发基础&#xff08;如&#xff1a;java、c#、c、WEB前端的一些了解)。 HarmonyOS开发使用的ArkTS&#xff0c;ArkTS是在TS的基础之上进行封装的&#…

外科休克病人的护理

一、引言 休克是外科常见的危急重症之一,它是由于机体遭受强烈的致病因素侵袭后,有效循环血量锐减、组织灌注不足所引起的以微循环障碍、细胞代谢紊乱和器官功能受损为特征的综合征。对于外科休克病人的护理,至关重要。 二、休克的分类 外科休克主要分为低血容量性休克(包括…

VMware Workstation 虚拟机网络配置为与主机使用同一网络

要将 VMware Workstation 虚拟机网络配置为与主机使用同一网络&#xff0c;我们需要将虚拟机的网络适配器设置为桥接模式。具体步骤如下&#xff1a; 配置 VMware Workstation 虚拟机网络为桥接模式 打开 VMware Workstation&#xff1a; 启动 VMware Workstation。 选择虚拟机…

博客网站目录网址导航自适应主题php源码

开源免费 博客屋网址导航自适应主题php源码v1.0是一款免费开源的PHP分类导航建站程序&#xff0c;源代码公开且无任何加密代码、安全有保障、无后门隐患。 系统稳定 内核安全稳定、PHPMYSQL/Sqlite架构、跨平台运行;版本自带ico接口集成&#xff0c;添加网站时&#xff0c;可自…

PostGIS2.4服务器编译安装

PostGIS的最新版本已经到3.5&#xff0c;但是还有一些国产数据库内核使用的旧版本的PostgreSQL&#xff0c;支持PostGIS2.4。但PostGIS2.4的版本已经在yum中找不到了&#xff0c;安装只能通过本地编译的方式。这里介绍一下如何在Centos7的系统上&#xff0c;编译部署PostGIS2.4…

实验场:在几分钟内使用 Bedrock Anthropic Models 和 Elasticsearch 进行 RAG 实验

作者&#xff1a;来自 Elastic Joe McElroy, Aditya Tripathi 我们最近发布了 Elasticsearch Playground&#xff0c;这是一个新的低代码界面&#xff0c;开发人员可以通过 A/B 测试 LLM、调整提示&#xff08;prompt&#xff09;和分块数据来迭代和构建生产 RAG 应用程序。今天…

Web3学习路线图,从入门到精通

前面我们聊了Web3的知识图谱&#xff0c;内容是相当的翔实&#xff0c;要从哪里入手可以快速的入门Web3&#xff0c;本篇就带你看看Web3的学习路线图&#xff0c;一步一步深入学习Web3。 这张图展示了Web3学习路线图&#xff0c;涵盖了区块链基础知识、开发方向、应用开发等内…

桥接模式案例

桥接模式&#xff08;Bridge Pattern&#xff09;是一种结构型设计模式&#xff0c;它将抽象部分与实现部分分离&#xff0c;使它们可以独立变化。桥接模式通过创 建一个桥接接口&#xff0c;将抽象部分和实现部分连接起来&#xff0c;从而实现两者的解耦。下面是一个详细的桥接…