程序员学长 | 快速学习一个算法,GAN

本文来源公众号“程序员学长”,仅用于学术分享,侵权删,干货满满。

原文链接:快速学习一个算法,GAN

GAN 如何工作?

GAN 由两个部分组成:生成器(Generator)和判别器(Discriminator)。这两个部分通过一种对抗的过程来相互改进和优化。

c03159492fff41b799068bb5cf88bd59.png

生成器(Generator)

生成器的任务是接收一个随机噪声向量,并将其转换为看起来尽可能真实的数据。它通常使用一个深度神经网络来实现,从随机噪声中生成类似于训练数据的样本。

生成器的目标是生成的样本能够骗过判别器,使判别器认为生成的数据是真实的。

判别器(Discriminator)

判别器是一个二分类模型,它的任务是区分真实数据和生成器生成的假数据。

判别器接收一个数据样本,并输出一个概率值,表示该样本是真实数据的概率。

判别器的目标是尽可能准确地将真实数据与生成数据区分开来。

对抗过程

GAN 的训练过程可以被看作是生成器和判别器之间的博弈。具体来说,生成器试图生成逼真的数据以欺骗判别器,而判别器则试图更好地识别生成的假数据

这个过程可以描述为一个最小最大化的优化问题。

4c8b68b8b9184d9092ffda2498ba9173.png

鉴别器 D 想要最大化目标函数,使得 D(x) 接近于 1,D(G(z)) 接近于 0。这意味着鉴别器应该将训练集中的所有图像识别为真实 (1),将所有生成的图像识别为假 (0)。

生成器 (G) 想要最小化目标函数,使得 D(G(z)) 为 1。这意味着生成器试图生成被鉴别器网络分类为 1 的图像。

训练步骤

  1. 初始化生成器和判别器的参数

  2. 判别器训练

    • 从真实数据集中采样一个 mini-batch 的真实数据。

    • 从生成器的噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新判别器的参数,使其能够更好地区分真实数据和生成数据。

      f6487b7759c24cf9aa3570065e6f957b.png

  3. 生成器训练

    861dc76c582746c281d4d47c3e1156c0.png
    • 从噪声分布中采样一个 mini-batch 的噪声,并生成假数据。

    • 更新生成器的参数,使其生成的数据更能够欺骗判别器。

  4. 重复上述过程,直到生成器生成的数据足够逼真,判别器无法准确区分真实数据和生成数据

GAN 在图像生成、图像修复、超分辨率、风格迁移、数据增强等领域有广泛应用。例如,通过 GAN,可以生成高分辨率的图像,将低分辨率图像转换为高分辨率图像,或者将某种风格的图像转换为另一种风格的图像。

案例分享

下面是使用 GAN 来生成图像的案例。这里我们以手写数字识别数据集为例进行说明。

1.读取数据集

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
import matplotlib.pyplot as plt
import sys
import numpy as npnum_rows = 28
num_cols = 28
num_channels = 1
input_shape = (num_rows, num_cols, num_channels)
z_size = 100
batch_size = 32
(train_ims, _), (_, _) = mnist.load_data()
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

2.定义生成器

生成器 (D) 在 GAN 中扮演着至关重要的角色,因为它负责生成能够欺骗鉴别器的真实图像。

它是 GAN 中图像形成的主要组件。

在本文中,我们为生成器使用了一种特定的架构,该架构包含一个完全连接 (FC) 层并采用 Leaky ReLU 激活。然而,值得注意的是,生成器的最后一层使用 TanH 激活而不是 LeakyReLU。

def build_generator():gen_model = Sequential()gen_model.add(Dense(256, input_dim=z_size))gen_model.add(LeakyReLU(alpha=0.2))gen_model.add(BatchNormalization(momentum=0.8))gen_model.add(Dense(512))gen_model.add(LeakyReLU(alpha=0.2))gen_model.add(BatchNormalization(momentum=0.8))gen_model.add(Dense(1024))gen_model.add(LeakyReLU(alpha=0.2))gen_model.add(BatchNormalization(momentum=0.8))gen_model.add(Dense(np.prod(input_shape), activation='tanh'))gen_model.add(Reshape(input_shape))gen_noise = Input(shape=(z_size,))gen_img = gen_model(gen_noise)return Model(gen_noise, gen_img)

3.定义鉴别器

在生成对抗网络 (GAN) 中,鉴别器 (D) 通过评估真实性和可能性来执行区分真实图像和生成图像的关键任务。

此组件可以看作是一个二元分类问题。

为了解决此任务,我们可以采用一个简化的网络架构,该架构由全连接层 (FC)、Leaky ReLU 激活和 Dropout 层组成。值得一提的是,鉴别器的最后一层包括 FC 层,后跟 Sigmoid 激活。Sigmoid 激活函数产生所需的分类概率。

def build_discriminator():disc_model = Sequential()disc_model.add(Flatten(input_shape=input_shape))disc_model.add(Dense(512))disc_model.add(LeakyReLU(alpha=0.2))disc_model.add(Dense(256))disc_model.add(LeakyReLU(alpha=0.2))disc_model.add(Dense(1, activation='sigmoid'))disc_img = Input(shape=input_shape)validity = disc_model(disc_img)return Model(disc_img, validity)

4.计算损失函数

我们可以使用二元交叉熵损失来实现生成器和鉴别器。

# discriminator
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',optimizer='sgd',metrics=['accuracy'])z = Input(shape=(z_size,))# generator
img = generator(z)disc.trainable = Falsevalidity = disc(img)# combined model
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')

5.优化损失

def intialize_model():disc= build_discriminator()disc.compile(loss='binary_crossentropy',optimizer='sgd',metrics=['accuracy'])generator = build_generator()z = Input(shape=(z_size,))img = generator(z)disc.trainable = Falsevalidity = disc(img)combined = Model(z, validity)combined.compile(loss='binary_crossentropy', optimizer='sgd')return disc, generator, combined

6. 模型训练

def train(epochs, batch_size=128, sample_interval=50):# load images(train_ims, _), (_, _) = mnist.load_data()# preprocesstrain_ims = train_ims / 127.5 - 1.train_ims = np.expand_dims(train_ims, axis=3)valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))# training loopfor epoch in range(epochs):batch_index = np.random.randint(0, train_ims.shape[0], batch_size)imgs = train_ims[batch_index]# create noisenoise = np.random.normal(0, 1, (batch_size, z_size))# predict using a Generatorgen_imgs = gen.predict(noise)# calculate loss functionsreal_disc_loss = disc.train_on_batch(imgs, valid)fake_disc_loss = disc.train_on_batch(gen_imgs, fake)disc_loss_total = 0.5 * np.add(real_disc_loss, fake_disc_loss)noise = np.random.normal(0, 1, (batch_size, z_size))g_loss = full_model.train_on_batch(noise, valid)# save outputs every few epochsif epoch % sample_interval == 0:one_batch(epoch)

7.生成手写数字

使用 MNIST 数据集,我们可以创建一个实用函数,使生成器为一组图像生成预测。

此函数生成随机声音,将其提供给生成器,运行它以显示生成的图像并将其保存在特殊文件夹中。建议定期运行此实用函数,例如每 200 个周期运行一次,以监控网络进度。

def one_batch(epoch):r, c = 5, 5noise_model = np.random.normal(0, 1, (r * c, z_size))gen_images = gen.predict(noise_model)# Rescale images 0 - 1gen_images = gen_images*(0.5) + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_images[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()

在我们的实验中,我们使用 32 的批次大小对 GAN 进行了大约 10,000 个时期的训练。

为了跟踪训练的进度,我们每 200 个时期保存一次生成的图像,并将它们存储在名为 “images” 的指定文件夹中。

disc, gen, full_model = intialize_model()
train(epochs=10000, batch_size=32, sample_interval=200)

现在,我们来检查一下不同阶段的 GAN 模拟结果。

初始化、5000 个 epoch 以及 10000 个 epoch 的最终结果。

最初,我们以随机噪声作为生成器的输入。

2e954b9c4c274f3e98d6ff8c9ba2be87.png

经过 5000 个时期的训练后,我们可以观察到生成的图形开始类似于 MNIST 数据集。

1a0dc57d0de14beebed11deee9879f68.png

经过 10,000 个时期的训练后,我们获得以下输出。

7c9ae539548c463db3d7b6a41fe4db86.png

可以看到,这些生成的图像与手写数字数据已经非常相似了。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45073.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从0开始基于transformer进行股价预测(pytorch版本)

目录 数据阶段两个问题开始利用我们的代码进行切分 backbone网络训练效果 感觉还行,没有调参数。源码比较长,如果需要我后续会发(因为太长了!!) 数据阶段 !!!注意&#…

还不懂 OOM ?详解内存溢出与内存泄漏区别!

内存溢出与内存泄漏 1. 内存溢出(Out Of Memory,OOM) 概念: 内存溢出是指程序在运行过程中,尝试申请的内存超过了系统所能提供的最大内存限制,并且垃圾收集器也无法提供更多的内存,导致程序无…

# Redis 入门到精通(一)数据类型(3)

Redis 入门到精通(一)数据类型(3) 一、redis 数据类型–set 类型介绍与基本操作 1、set 类型 新的存储需求: 存储大量的数据,在查询方面提供更高的效率。需要的存储结构: 能够保存大量的数据,高效的内部…

【爬虫】解析爬取的数据

目录 一、正则表达式1、常用元字符2、量词3、Re模块4、爬取豆瓣电影 二、Xpath1、Xpath解析Ⅰ、节点选择Ⅱ、路径表达式Ⅲ、常用函数 2、爬取豆瓣电影 解析数据,除了前面的BeautifulSoup库,还有正则表达式和Xpath两种方法。 一、正则表达式 正则表达式…

C++|智能指针

目录 引入 一、智能指针的使用及原理 1.1RAII 1.2智能指针原理 1.3智能指针发展 1.3.1std::auto_ptr 1.3.2std::unique_ptr 1.3.3std::shared_ptr 二、循环引用问题及解决方法 2.1循环引用 2.2解决方法 三、删除器 四、C11和boost中智能指针的关系 引入 回顾上…

谷粒商城学习笔记-19-快速开发-逆向生成所有微服务基本CRUD代码

文章目录 一,使用逆向工程步骤梳理1,修改逆向工程的application.yml配置2,修改逆向工程的generator.properties配置3,以Debug模式启动逆向工程4,使用逆向工程生成代码5,整合生成的代码到对应的模块中 二&am…

VMware Workstation 虚拟机网络配置为与主机使用同一网络

要将 VMware Workstation 虚拟机网络配置为与主机使用同一网络,我们需要将虚拟机的网络适配器设置为桥接模式。具体步骤如下: 配置 VMware Workstation 虚拟机网络为桥接模式 打开 VMware Workstation: 启动 VMware Workstation。 选择虚拟机…

实验场:在几分钟内使用 Bedrock Anthropic Models 和 Elasticsearch 进行 RAG 实验

作者:来自 Elastic Joe McElroy, Aditya Tripathi 我们最近发布了 Elasticsearch Playground,这是一个新的低代码界面,开发人员可以通过 A/B 测试 LLM、调整提示(prompt)和分块数据来迭代和构建生产 RAG 应用程序。今天…

Web3学习路线图,从入门到精通

前面我们聊了Web3的知识图谱,内容是相当的翔实,要从哪里入手可以快速的入门Web3,本篇就带你看看Web3的学习路线图,一步一步深入学习Web3。 这张图展示了Web3学习路线图,涵盖了区块链基础知识、开发方向、应用开发等内…

接上一回C++:补继承漏洞+多态原理(带图详解)

引子:接上一回我们讲了继承的分类与六大默认函数,其实继承中的菱形继承是有一个大坑的,我们也要进入多态的学习了。 注意:我学会了,但是讲述上可能有一些不足,希望大家多多包涵 继承复习: 1&…

windows环境下基于3DSlicer 源代码编译搭建工程开发环境详细操作过程和中间关键错误解决方法说明

说明: 该文档适用于  首次/重新 搭建3D-Slicer工程环境  Clean up(非增量) 编译生成 1. 3D-slicer 软件介绍 (1)3D Slicer为处理MRI\CT等图像数据软件,可以实行基于MRI图像数据的目标分割、标记测量、坐标变换及三维重建等功能,其源于3D slicer 4.13.0-2022-01-19开…

OS Copilot测评

1.按照第一步管理重置密码时报错了,搞不懂为啥?本来应该跳转到给的那个实例的,我的没跳过去 2.下一步重置密码的很丝滑没问题 3安全组新增入库22没问题 很方便清晰 4.AccessKey 还能进行预警提示 5.远程连接,网速还是很快,一点没卡,下载很棒 6.替换的时候我没有替换<>括…

【JavaEE】网络编程——UDP

&#x1f921;&#x1f921;&#x1f921;个人主页&#x1f921;&#x1f921;&#x1f921; &#x1f921;&#x1f921;&#x1f921;JavaEE专栏&#x1f921;&#x1f921;&#x1f921; 文章目录 1.数据报套接字(UDP)1.1特点1.2编码1.2.1DatagramSocket1.2.2DatagramPacket…

Spring Cloud Alibaba AI 介绍及使用

一、Spring Cloud Alibaba AI 介绍 Spring AI 是 Spring 官方社区项目&#xff0c;旨在简化 Java AI 应用程序开发&#xff0c;让 Java 开发者像使用 Spring 开发普通应用一样开发 AI 应用。而 Spring Cloud Alibaba AI 是阿里以 Spring AI 为基础&#xff0c;并在此基础上提供…

dive deeper into tensor:从底层开始学习tensor

inspired by karpathy/micrograd: A tiny scalar-valued autograd engine and a neural net library on top of it with PyTorch-like API (github.com)and Taking PyTorch for Granted | wh (nrehiew.github.io). 这属于karpathy的karpathy/nn-zero-to-hero: Neural Networks…

阐述 C 语言中的参数传递机制

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; &#x1f4d9;C 语言百万年薪修炼课程 通俗易懂&#xff0c;深入浅出&#xff0c;匠心打磨&#xff0c;死磕细节&#xff0c;6年迭代&#xff0c;看过的人都说好。 文章目…

多表查询sql

概述&#xff1a;项目开发中,在进行数据库表结构设计时,会根据业务需求及业务模块之间的关系,分析并设计表结构,由于业务之间相互关联,所以各个表结构之间也存在着各种联系&#xff0c;分为三种&#xff1a; 一对多多对多一对一 一、多表关系 一对多 案例&#xff1a;部门与…

【PowerShell】-1-快速熟悉并使用PowerShell

目录 PowerShell是什么&#xff1f;和CMD的区别&#xff1f; PowerShell的演变 自动化IT管理任务 一些名词 详尽的PowerShell开始之路 1.打开PowerShell&#xff1a; 2.基本命令&#xff1a; &#xff08;1&#xff09;Get-Process &#xff08;2&#xff09;变量赋值…

【核心笔记】Java入门到起飞,小白都能看懂的Java教程 (五)——数组

一 数组的定义和初始化 定义数组 数据类型[] 数组名&#xff1b;例 int[] arr; 数据类型 数组名[]&#xff1b;例 int arr[]; 数组初始化 数据类型[] 数组名 new 数据类型[] {值}&#xff1b;例 int[] arr new int[] {1,2,3}; &#xff08;简化形式&#xff09;数据类型[] 数…

超赞!只需粘贴复制超赞,视频快速转换成文章

大家好&#xff01;我是闷声轻创&#xff01;是否还在为撰写高质量的文章而熬夜奋战&#xff1f;今天&#xff0c;我要给你们带来一个超级棒的消息——视频变文章的神奇工具&#xff0c;让你的创作之路从此不再艰辛&#xff01; 视频素材的宝藏——油管&#xff08;YTB&#xf…