白话生成对抗网络 GAN,50 行代码玩转 GAN 模型!【附源码】

今天,带大家一起来了解一下如今非常火热的深度学习模型:生成对抗网络(Generate Adversarial Network,GAN)。GAN 非常有趣,我就以最直白的语言来讲解它,最后实现一个简单的 GAN 程序来帮助大家加深理解。

什么是 GAN?

好了,GAN 如此强大,那它到底是一个什么样的模型结构呢?我们之前学习过的机器学习或者神经网络模型主要能做两件事:预测和分类,这也是我们所熟知的。那么是否可以让机器模型自动来生成一张图片、一段语音?而且可以通过调整不同模型输入向量来获得特定的图片和声音。例如,可以调整输入参数,获得一张红头发、蓝眼睛的人脸,可以调整输入参数,得到女性的声音片段,等等。也就是说,这样的机器模型能够根据需求,自动生成我们想要的东西。因此,GAN 应运而生!

GAN,即生成对抗网络,主要包含两个模块:生成器(Generative Model)和判别器(Discriminative Model)。生成模型和判别模型之间互相博弈、学习产生相当好的输出。以图片为例,生成器的主要任务是学习真实图片集,从而使得自己生成的图片更接近于真实图片,以“骗过”判别器。而判别器的主要任务是找出出生成器生成的图片,区分其与真实图片的不同,进行真假判别。在整个迭代过程中,生成器不断努力让生成的图片越来越像真的,而判别器不断努力识别出图片的真假。这类似生成器与判别器之间的博弈,随着反复迭代,最终二者达到了平衡:生成器生成的图片非常接近于真实图片,而判别器已经很难识别出真假图片的不同了。其表现是对于真假图片,判别器的概率输出都接近 0.5。
对 GAN 的概念还是有点不清楚?没关系,举个生动的例子来说明。
最近,红色石头想学习绘画,是因为看到梵大师的画作,也想画出类似的作品。梵大师的画作像这样:
说画就画,红色石头找来一个研究梵大师作品很多年的王教授来指导我。王教授经验丰富,眼光犀利,市面上模仿梵大师的画作都难逃他的法眼。王教授跟我说了一句话:什么时候你的画这幅画能骗过我,你就算是成功了。

红色石头很激动,立马给王教授画了这幅画:
王教授轻轻扫了一眼,满脸黑线,气的直哆嗦,“0 分!这也叫画?差得太多了!” 听了王教授的话,红色石头自我反省,确实画的不咋地,连眼睛、鼻子都没有。于是,又 重新画了一幅:
王教授一看,不到 2 秒钟,就丢下四个字:1 分!重画!红色石头一想,还是不行,画得太差了,就回去好好研究梵大师的画作风格,不断改进,重新创作,直到有一天,红色石头拿着新的画作给王教授看:
王教授看了一看,说有点像了。我得仔细看看。最后,还是跟我说,不行不行,细节太差!继续重新画吧。唉,王教授越来越严格了!红色石头叹了口气回去继续研究,最后将自我很满意的一幅画交给了王教授鉴赏:
这下,王教授戴着眼镜,仔细品析,许久之后,王教授拍着我的肩膀说,画得很好,我已经识别不了真假了。哈哈,得到了王教授的夸奖和肯定,心里美滋滋,终于可以创作出梵大师样的绘画作品了。下一步考虑转行去。
好了,例子说完了(接受大家对我绘画天赋的吐槽)。这个例子,其实就是一个 GAN 训练的过程。红色石头就是生成器,目的就是要输出一幅画能够骗过王教授,让王教授真假难辨!王教授就是判别器,目的就是要识别出红色石头的画作,判断其为假的!整个过程就是“生成 — 对抗”的博弈过程,最终,红色石头(生成器)输出一幅“以假乱真”的画作,连王教授(判别器)都难以区分了。
这就是 GAN,懂了吧。

GAN 模型基本结构

在认识 GAN 模型之前,我们先来看一看 Yann LeCun 对未来深度学习重大突破技术点的个人看法:

The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This is an idea that was originally proposed by Ian Goodfellow when he was a student with Yoshua Bengio at the University of Montreal (he since moved to Google Brain and recently to OpenAI).
This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.

Yann LeCun 认为 GAN 很可能会给深度学习模型带来新的重大突破,是20年来机器学习领域最酷的想法。这几年 GAN 发展势头非常强劲。下面这张图是近几年 ICASSP 会议上所有提交的论文中包含关键词 “generative”、“adversarial” 和 “reinforcement” 的论文数量统计。
数据表明,2018 年,包含关键词 “generative” 和 “adversarial” 的论文数量发生井喷式增长。不难预见, 未来几年关于 GAN 的论文会更多。下面来介绍一下 GAN 的基本结构,我们已经知道了 GAN 由生成器和判别器组成,各用 G 和 D 表示。以生成图片应用为例,其模型结构如下所示:
GAN 基本模型由 输入 Vector、G 网络、D 网络组成。其中,G 和 D 一般都是由神经网络组成。G 的输出是一幅图片,只不过是以全连接形式。G 的输出是 D 的输入,D 的输入还包含真实样本集。这样, D 对真实样本尽量输出 score 高一些,对 G 产生的样本尽量输出 score 低一些。每次循环迭代,G 网络不断优化网络参数,使 D 无法区分真假;而 D 网络也在不断优化网络参数,提高辨识度,让真假样本的 score 有差距。
最终,经过多次训练迭代,GAN 模型建立:
最终的 GAN 模型中,G 生成的样本以假乱真,D 输出的 score 接近 0.5,即表示真假样本难以区分,训练成功。
这里,重点要讲解一下输入 vector。输入向量是用来做什么的呢?其实,输入 vector 中的每一维度都可以代表输出图片的某个特征。比如说,输入 vector 的第一个维度数值大小可以调节生成图片的头发颜色,数值大一些是红色,数值小一些是黑色;输入 vector 的第二个维度数值大小可以调节生成图片的肤色;输入 vector 的第三个维度数值大小可以调节生成图片的表情情绪,等等。
GAN 的强大之处也正是在于此,通过调节输入 vector,就可以生成具有不同特征的图片。而这些生成的图片不是真实样本集里有的,而是即合理而又没有见过的图片。是不是很有意思呢?下面这张图反映的是不同的 vector 生成不同的图片。
说完了 GAN 的模型之后,我们再来简单看一下 GAN 的算法原理。既然有两个模块:G 和 D,每个模块都有相应的网络参数。
先来看 D 模块,它的目标是让真实样本 score 越大越好,让 G 产生的样本 score 越小越好。那么可以得到 D 的损失函数为:
其中,x 是真实样本,G(z) 是 G 生成样本。我们希望 D(x) 越大越好,D(G(z)) 越小越好,也就是希望 -D(x) 越小越好,-log(1-D(G(z))) 越小越好。从损失函数的角度来说,能够得到上式。
再来看 G 模块,它的目标就是希望其生成的模型能够在 D 中得到越高的分数越好。那么可以得到 G 的损失函数为:
知道了损失函数之后,接下来就可以使用各种优化算法来训练模型了。

动手写个 GAN 模型

接下来,我将使用 PyTorch 实现一个简单的 GAN 模型。仍然以绘画创作为例,假设我们要创造如下“名画”(以正弦图形为例):

生成该“艺术画作”的代码如下:

def artist_works():    # painting from the famous artist (real target)    r = 0.02 * np.random.randn(1, ART_COMPONENTS)    paintings = np.sin(PAINT_POINTS * np.pi) + r    paintings = torch.from_numpy(paintings).float()    return paintings

然后,分别定义 G 网络和 D 网络模型:

G = nn.Sequential(                  # Generator    nn.Linear(N_IDEAS, 128),        # random ideas (could from normal distribution)    nn.ReLU(),    nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas)D = nn.Sequential(                  # Discriminator    nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G    nn.ReLU(),    nn.Linear(128, 1),    nn.Sigmoid(),                   # tell the probability that the art work is made by artist
)

我们设置 Adam 算法进行优化:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

最后,构建 GAN 迭代训练过程:

plt.ion()    # something about continuous plottingD_loss_history = []
G_loss_history = []
for step in range(10000):    artist_paintings = artist_works()          # real painting from artist    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideas     G_paintings = G(G_ideas)                   # fake painting from G (random ideas)        prob_artist0 = D(artist_paintings)         # D try to increase this prob    prob_artist1 = D(G_paintings)              # D try to reduce this prob        D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))    G_loss = torch.mean(torch.log(1. - prob_artist1))        D_loss_history.append(D_loss)    G_loss_history.append(G_loss)  opt_D.zero_grad()    D_loss.backward(retain_graph=True)    # reusing computational graph    opt_D.step()        opt_G.zero_grad()    G_loss.backward()    opt_G.step()        if step % 50 == 0:  # plotting        plt.cla()        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)        plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')        plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 8})     plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 8})        plt.ylim((-1, 1));plt.legend(loc='lower right', fontsize=10);plt.draw();plt.pause(0.01)plt.ioff()
plt.show()

我采用了动态绘图的方式,便于时刻观察 GAN 模型训练情况。
迭代次数为 1 时:
迭代次数为 200 时:
迭代次数为 1000 时:
迭代次数为 10000 时:
完美!经过 10000 次迭代训练之后,生成的曲线已经与标准曲线非常接近了。D 的 score 也如预期接近 0.5。
完整代码有 .py 和 .ipynb 两种版本,需要的请点击「阅读原文」获取。

一个值得关注的 AI 技术的公众号。作者红色石头是专注于人工智能的 CSDN 博客专家和知乎专栏作者。本公众号主要涉及人工智能领域 Python、ML 、CV、NLP 等前沿知识、干货笔记和优质资源!我们致力于为您提供切实可行的 AI 学习路线。

个人网站 :www.redstonewill.com


AI有道优质文章精选

  • 干货 | 126 篇 AI 原创文章精选(ML、DL、资源、教程) 

  • 【通俗易懂】机器学习中 L1 和 L2 正则化的直观解释 

  • 简单的交叉熵损失函数,你真的懂了吗? 

  • 划重点!十分钟掌握牛顿法凸优化 

  • 简单的梯度下降算法,你真的懂了吗?

夕小瑶的卖萌屋

_

关注&星标小夕,带你解锁AI秘籍

订阅号主页下方「撩一下」有惊喜哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480493.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java架构师进阶之独孤九剑(一)-算法思想与经典算法

“ 这是整个架构师连载系列,分为9大步骤,我们现在还在第一个步骤:程序设计和开发->数据结构与算法。 我们今天讲解重点讲解算法。 算法思想 1 贪心思想 顾名思义,贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并…

数据结构--链表--单链表中环的检测,环的入口,环的长度的计算

就如数字6一样的单链表结构,如何检测是否有6下部的○呢,并且求交叉点位置 思路 使用快慢指针(一个一次走2步,一个走1步),若快慢指针第一次相遇,则有环 慢指针路程 sabs absab 快指针路程 2sa…

ACL 2010-2020研究趋势总结

一只小狐狸带你解锁 炼丹术&NLP 秘籍作者:哈工大SCIR 车万翔教授导读2020年5月23日,有幸受邀在中国中文信息学会青年工作委员会主办的AIS(ACL-IJCAI-SIGIR)2020顶会论文预讲会上介绍了ACL会议近年来的研究趋势,特整…

架构师进阶之独孤九剑:设计模式详解

我们继续架构师进阶之独孤九剑进阶,目前我们仍然在第一阶段:程序设计和开发环节。 “ 设计模式不仅仅只是一种规范,更多的是一种设计思路和经验总结,目的只有一个:提高你高质量编码的能力。以下主要分为三个环节&…

知识表示发展史:从一阶谓词逻辑到知识图谱再到事理图谱

研究证实,人类从一出生即开始累积庞大且复杂的数据库,包括各种文字、数字、符码、味道、食物、线条、颜色、公式、声音等,大脑惊人的储存能力使我们累积了海量的资料,这些资料构成了人类的认知知识基础。实验表明,将数…

领域应用 | 基于知识图谱的警用安保机器人大数据分析技术研究

本文转载自公众号:警察技术杂志。 郝久月 樊志英 汪宁 王欣 摘 要:构建大数据支撑下的智能应用是公安信息化发展的趋势,警用安保机器人大数据分析平台的核心功能包括机器人智能人机交互和前…

数据挖掘学习指南!!

入门数据挖掘,必须理论结合实践。本文梳理了数据挖掘知识体系,帮助大家了解和提升在实际场景中的数据分析、特征工程、建模调参和模型融合等技能。完整项目实践(共100多页)后台回复 数据挖掘电子版 获取数据分析探索性数据分析&am…

数据结构--栈--顺序栈/链式栈(附: 字符括号合法配对检测)

栈结构:先进后出,后进先出,像叠盘子一样,先叠的后用。 代码github地址 https://github.com/hitskyer/course/tree/master/dataAlgorithm/chenmingming/stack 1.顺序栈(数组存储,需给定数组大小&#xff0c…

银行计考试-计算机考点2-计算机系统组成与基本工作原理

版权声明&#xff1a;本文为博主原创文章&#xff0c;未经博主允许不得转载。 https://blog.csdn.net/sinat_33363493/article/details/53647129 </div><link rel"stylesheet" href"https://csdnimg.cn/release/pho…

我们的实践: 400万全行业动态事理图谱Demo

历史经验知识在未来预测的应用 华尔街的独角兽Kensho&#xff0c;是智能金融Fintech的一个不得不提的成功案例&#xff0c;这个由高盛领投的6280万美元投资&#xff0c;总融资高达7280万美元的公司自推出后便名声大噪。Warren是kensho是一个代表产品&#xff0c;用户能够以通俗…

蚂蚁花呗团队面试题:LinkedHashMap+SpringCloud+线程锁+分布式

一面 自我介绍 map怎么实现hashcode和equals,为什么重写equals必须重写hashcode 使用过concurrent包下的哪些类&#xff0c;使用场景等等。 concurrentHashMap怎么实现&#xff1f;concurrenthashmap在1.8和1.7里面有什么区别 CountDownLatch、LinkedHashMap、AQS实现原理 …

肖仰华 | SIGIR 2018、WWW2018 知识图谱研究综述

本文转载自公众号&#xff1a;知识工场。全国知识图谱与语义计算大会&#xff08;CCKS: China Conference on Knowledge Graph and Semantic Computing&#xff09;由中国中文信息学会语言与知识计算专委会定期举办的全国年度学术会议。CCKS源于国内两个主要的相关会议&#xf…

数据结构--栈--共享顺序栈

共享顺序栈&#xff1a;内部也是一个数组 将两个栈放在数组的两端&#xff0c;一个从数组首端开始压栈&#xff0c;一个从数组尾部开始压栈&#xff0c;等到两边栈顶在中间相遇时&#xff0c;栈满。 共享顺序栈在某些情况下可以节省空间。 头文件 sharingStack.h //共享顺序…

一个励志PM小哥哥的Java转型之路

先给大家看张我朋友圈截图&#xff1a; 这哥们本科学英语的&#xff0c;毕业后做了产品经理&#xff0c;去年 9 月份开始学 Java&#xff0c;6 个月的时间&#xff0c;拿到了快手的 Offer。如果你对 Java 也有兴趣&#xff0c;不妨听完这个故事。你是不是也和他当时的处境…

最全蚂蚁金服高级Java面试题目(3面)

一面&#xff1a; JVM数据存储模型&#xff0c;新生代、年老代的构造&#xff1f; java GC算法&#xff0c;什么时候会触发minor gc&#xff0c;什么时候会触发full gc&#xff1f; GC 可达性分析中哪些算是GC ROOT&#xff1f; 你熟悉的JVM调优参数&#xff0c;使用过哪些调…

运用事理图谱搞事情:新闻预警、事件监测、文本可视化、出行规划与历时事件流生成

目前&#xff0c;事理图谱在描述领域事件时空信息上具有独特性&#xff0c;这种逻辑图结构能够以一种直观的方式向我们展现出一个领域知识的链路信息。从学术的角度上来说&#xff0c;事理图谱与事件抽取、事件关系抽取、脚本学习、事件链生成、篇章句间关系识别、图谱图结构运…

CCKS 2018 | 最佳论文:南京大学提出 DSKG,将多层 RNN 用于知识图谱补全

本文转载自公众号&#xff1a;机器之心。 选自CCKS 2018作者&#xff1a;Lingbing Guo、Qingheng Zhang、Weiyi Ge、Wei Hu、Yuzhong Qu机器之心编译参与&#xff1a;Panda、刘晓坤2018 年 8 月 14-17 日&#xff0c;主题为「知识计算与语言理解」的 2018 全国知识图谱…

计算机软件系统

计算机软件系统按其功能可分为系统软件和应用软件两大类。1、系统软件系统软件是指管理、控制、和维护计算机及其外部设备&#xff0c;提供用户与计算机之间操作界面等方面的软件&#xff0c;它并不专门针对具体的应用问题。代表性的系统软件有&#xff1a;操作系统、数据库管理…

数据结构--栈--浏览器前进后退应用

浏览器前进后退&#xff1a; 当你依次浏览a&#xff0c;b&#xff0c;c,然后回到b&#xff0c;再浏览d&#xff0c;就只能查看a&#xff0c;b&#xff0c;d&#xff0c;了。 原理&#xff1a; 利用两个栈A,B 浏览新网页的时候&#xff0c;压入栈A&#xff0c;清空栈B前进&…

关于BERT,面试官们都怎么问

1.BERT 的基本原理是什么&#xff1f;BERT 来自 Google 的论文Pre-training of Deep Bidirectional Transformers for Language Understanding&#xff0c;BERT 是“Bidirectional Encoder Representations from Transformers”的首字母缩写&#xff0c;整体是一个自编码语言模…