GAN对抗生成网络(二)——算法及Python实现

1 算法步骤

上一篇提到的GAN的最优化问题是G^{*}=\arg\min\limits_{G}\max\limits_{D}V(G,D),本文记录如何求解这一问题。

首先为了表示方便,记\max\limits_{D}V(G,D)=L(G),这里让V(G,D)最大的D=D^{*}可视作常量。

第一步,给定初始的G_{0},使用梯度上升找到 D_0^{*},最大化L(G_0)。关于梯度下降,可以参考笔者另一篇文章《BP神经网络原理-CSDN博客》误差反向传播的部分。

第二步,使用梯度下降法,找到G最佳的参数\theta_{G}.其中\eta为学习率。

\theta_{G}\leftarrow \theta_{G}-\eta\frac{\partial{L(G)}}{\theta_{G}}得到G_{1}

 之后这两步交替进行。

这里的L(G)是有max运算的,可以被微分吗?答案是可以的

引用李宏毅老师的例子,f(x)是有max运算的,相当于分段函数,在求微分的时候,根据当前x落在哪个区域决定微分的形式如何。

2 算法与JS散度的关系

上述算法第一步训练D时本质是增大JS散度,第二步训练G时看起来是减小JS散度,但实际上不完全等同。

如下图所示,左侧表示算法第一步根据G_{0}得到了最优的D_0^{*}。当进行到算法第二步,需要根据D_0^{*}找到一个更小的JS散度,如右图所示,G选择了G_{1}从而使得V(G_1, D)<V(G_0, D)。虽然此时JS散度更小,但是由于G_{0}更换成了G_{1}D_0^{*}将更新参数变成D_1^{*},此时JS散度更大了。只能说不能让G更新得太多,否则不能达到减小JS散度的目标。回到文物造假的例子,造假者收到鉴宝者的反馈后,应该微调技艺,而不是彻底更换技艺,否则只能从头来过。

从快速收敛的角度来说,G应该不能更新太过,但是如果太小也忽略了G更好的形式,可能陷入局部最优。

3 实际训练过程

实际训练时,我们是无法计算出真实数据或生成数据实际的期望的,只能通过抽样近似得到期望。因此实际的做法如下:

3.1 第一步,初始化

初始化生成器G和判别器D

3.2 第二步,固定G,训练D

从分布函数(如高斯分布)中随机抽样出m个样本\left \{ z^1,z^2,...z^m \right \}输入给G,输出m个样本\left \{ \tilde{x}^1,\tilde{x}^2,..\tilde{x}^m \right \}G本质上概率分布转化器——将高斯分布的噪声转变成样本的分布。从真实数据中随机抽样出m个样本\left \{ x^1,x^2,...x^m \right \},将二者输入给D。训练D的参数使其接收x_{1}时打出0分,接收x_{2}时打出1分,即最大化\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

建模成分类或回归问题均可。

使用梯度上升法,\theta_d\leftarrow \theta_d+\eta\nabla\tilde{V}(\theta_d)

实际中需要更新多次,使得V值最大。这一步实际上只找到了一个\max\limits_{D}V(G,D)的下限(lower bound),原因是:(1)训练次数不会非常大,没法训练到收敛;(2)即使能收敛,也可能只是一个局部最优解;(3)推导时假设了D可以是任意的函数,即针对不同的x都给出最高的值,但实际中这个假设不成立。

3.3 第三步,固定D,训练G

从分布函数(如高斯分布)中随机抽样出另外m个样本\left \{ z^1,z^2,...z^m \right \}

更新G的参数\theta_G使得下式最小:

\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}

其中第一项与G无关,因此只需要看第二项。

根据上文的讨论,这里一般只训练一次,避免G改变过多,无法收敛。

实践中是将GD合在一起作为一个大的神经网络,前几层是G,后几层是D,中间有一个隐含层是G的输出,就是GAN希望得到的输出。第二步和第三步可分别固定神经网络中的某几层参数不动,训练其它层参数。

4 Python实现

关于GAN的代码,参考了https://github.com/junqiangchen/GAN。项目可以产生数字图片和人脸图片,其中人脸图片的生成使用了GAN的变种——WGAN,之后会专门讨论,本文讨论最原始的GAN模型。

4.1 使用新版tensorflow需要修改的地方

原始的代码直接运行是不通的,需要做一些调整;原始代码采用的是旧版Tensorflow(V1),如果安装了新版TensorFlow(V2)也需要做调整;有些包如果安装的新版同样不支持部分API,需要替换。具体如下表所示

问题调整方法备注
部分文件路径不对

调整路径,例如

from GAN.face_model import WGAN_GPModel

调整为

from GAN.genface.face_model import WGAN_GPModel
其他几处不再赘述
imresize报错

 例如

from scipy.misc import imresize

调整为

from skimage.transform import resize
最新版本scipy不支持此函数,将
imresize(test_image, (init_width * scale_factor, init_height * scale_factor))

替换为

resize(
Image.fromarray(test_image).resize(init_width * scale_factor, init_height * scale_factor))
imsave
报错

例如

scipy.misc.imsave(path, merge_img)

调整为

import cv2cv2.imwrite(path, merge_img * 255)

最新版本scipy不支持此函数,替换成cv2。个人认为最后应该乘255,因为原始数据是0~1的数据,直接存会存成几乎黑白的图片,需要还原
使用新版tensorflow的问题
import tensorflow as tf

替换为

import tensorflow.compat.v1 as tftf.compat.v1.disable_eager_execution()
新版tensorflow提供了向下兼容的compat.v1的使用方式,统一替换即可。同时要取消eager_execution模式,新版默认是“即时计算”模式,如果兼容旧版则应取消该模式。

4.2 GAN的代码解析

代码位置:gan/GAN/genmnist/mnist_model.py, class名为GANModel

4.2.1 Generator

定义在_GAN_generator函数中,总结为以下要点:

(1)含有五层网络,除最后一层,其他层在进入下一层之前都用batch_normalization归一化+relu激活函数

g4 = tf.contrib.layers.batch_norm(g4, epsilon=1e-5, is_training=self.phase, scope='bn4')
g4 = tf.nn.relu(g4)

(2)每一层都定义w和b,使用truncated_normal,即截断异常值的正态分布

tf.truncated_normal_initializer

(3)第1~2层使用全连接层,即使用w乘输入,并加上b偏置

tf.matmul(g1, g_w2) + g_b2

(4)第3~4层使用反卷积运算。是卷积运算的逆过程,关于反卷积的介绍笔者正在整理

tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, strides, strides, 1], padding='SAME')

(5)第5层使用卷积运算,并使用tanh激活函数

g5 = convolution_2d(g4, g_w5)

4.2.2 Discriminator

与Generator类似,简述如下:

(1)共4层,其中1、2层使用卷积,3、4层使用全连接

(2)卷积后使用平均池化

d1 = average_pool_2x2(d1)

(3)最后一层使用sigmoid将输出控制在0~1之间

out = tf.nn.sigmoid(out_logit)

4.2.3 损失函数

Generator的损失函数为

-tf.reduce_mean(tf.log(self.D_fake))

对应前文提到的\frac{1}{m}\sum_{i=1}^m{log{(1-D(G(z^i)))}}。注意这里是用-\frac{1}{m}\sum_{i=1}^m{log{D(G(z^i))}},方向是一样的,之后笔者会讨论他们的区别。

Discriminator的损失函数为

-tf.reduce_mean(tf.log(self.D_real) + tf.log(1 - self.D_fake))

对应前文提到的\tilde{V}=\frac{1}{m}\sum_{i=1}^m{log{D(x^i)}}+\frac{1}{m}\sum_{i=1}^m{log{(1-D(\tilde{x}^i))}}

4.2.5 训练

定义D和G的训练函数,使得各自损失函数最小化。

trainD_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.d_loss, var_list=D_vars)
trainG_op = tf.train.AdamOptimizer(learning_rate, beta1).minimize(self.g_loss, var_list=G_vars)

先让D预训练30次,然后D和G交替训练。为什么先让D预训练30次?笔者认为D本质上就是个图片分类器,可以不依赖于G,比较好训练,预训练可以加快收敛速度。

训练时使用feed“喂数据”

feed_dict={self.X: batch_xs, self.Z: z_batch, self.phase: 1}

其中self.X表示真实的图片,self.Z表示噪声,self.phase表示batchnorm训练阶段还是测试阶段。

4.2.6 预测

生成随机噪声Z之后,喂给G,即可生成图片

outimage = self.Gen.eval(feed_dict={self.Z: z_batch, self.phase: 1}, session=sess)

不过笔者对这里的phase有些疑问,是否应该设置为0?恕笔者对Tensorflow不熟,代码解析有些走马观花,没有深究细节以及为什么这么写,等功力提高再回过头来优化。

至此,原始GAN的算法以及Python实现已介绍完毕,下一篇笔者将拓展讨论一些细节并介绍GAN的变种。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/66272.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[读书日志]从零开始学习Chisel 第二篇:Scala的变量与函数(敏捷硬件开发语言Chisel与数字系统设计)

第一篇https://blog.csdn.net/m0_74021449/article/details/144887921 2.2 Scala的变量及函数 2.2.1变量定义与基本类型 变量声明 变量首次定义必须使用关键字var或者val&#xff0c;二者的区别是val修饰的变量禁止被重新赋值&#xff0c;它是一个只读的变量。首次定义变量时…

Spring Boot - 日志功能深度解析与实践指南

文章目录 概述1. Spring Boot 日志功能概述2. 默认日志框架&#xff1a;LogbackLogback 的核心组件Logback 的配置文件 3. 日志级别及其配置配置日志级别3.1 配置文件3.2 环境变量3.3 命令行参数 4. 日志格式自定义自定义日志格式 5. 日志文件输出6. 日志归档与清理7. 自定义日…

NVIDIA DLI课程《NVIDIA NIM入门》——学习笔记

先看老师给的资料&#xff1a; NVIDIA NIM是 NVIDIA AI Enterprise 的一部分&#xff0c;是一套易于使用的预构建容器工具&#xff0c;目的是帮助企业客户在云、数据中心和工作站上安全、可靠地部署高性能的 AI 模型推理。这些预构建的容器支持从开源社区模型到 NVIDIA AI 基础…

【HF设计模式】05-单例模式

声明&#xff1a;仅为个人学习总结&#xff0c;还请批判性查看&#xff0c;如有不同观点&#xff0c;欢迎交流。 摘要 《Head First设计模式》第5章笔记&#xff1a;结合示例应用和代码&#xff0c;介绍单例模式&#xff0c;包括遇到的问题、采用的解决方案、以及达到的效果。…

【FlutterDart】页面切换 PageView PageController(9 /100)

上效果&#xff1a; 有些不能理解官方例子里的动画为什么没有效果&#xff0c;有可能是我写法不对 后续如果有动画效果修复了&#xff0c;再更新这篇&#xff0c;没有动画效果&#xff0c;总觉得感受的丝滑效果差了很多 上代码&#xff1a; import package:flutter/material.…

Electron快速入门——跨平台桌面端应用开发框架

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

Android NDK开发实战之环境搭建篇(so库,Gemini ai)

文章流程 音视频安卓开发首先涉及到ffmpeg编译打包动态库&#xff0c;先了解动态库之间的cpu架构差异性。然后再搭建可运行的Android 环境。 So库适配 ⽇常开发我们经常会使⽤到第三库&#xff0c;涉及到底层的语⾳&#xff0c;视频等都需要添加so库。⽽so库的体积⼀般来说 ⾮…

【Java回顾】Day2 正则表达式----异常处理

参考资料&#xff1a;菜鸟教程 https://www.runoob.com/java/java-exceptions.html 正则表达式 有一部分没看完 介绍 字符串的模式搜索、编辑或处理文本java.util.regex包&#xff0c;包含了pattern和mathcer类&#xff0c;用于处理正则表达式的匹配操作。 捕获组 把多个字符…

Unity性能优化总结

目录 前言 移动端常见性能优化指标​编辑 包体大小优化 FPS CPU占用率 GPU占用率 内存 发热和耗电量 流量优化 前言 终于有时间了&#xff0c;我将在最近两个项目中进行优化的一些经验进行归纳总结以飨读者。因为我习惯用思维导图&#xff0c;所以归纳的内容主要以图来…

北京航空航天大学惊现技术商业“宫斗剧”!背后隐藏的内幕遭曝光!

北京航空航天大学&#xff08;以下称北航&#xff09;与源亿&#xff08;北京&#xff09;网络科技有限公司&#xff08;以下称源亿&#xff09;的派驻的员工恶意串通&#xff0c;指定北京蚂蚁非标科技有限公司&#xff08;以下称蚂蚁公司&#xff09;挖走源亿公司在现场派驻的…

transfomer深度学习实战水果识别

本文采用RT-DETR作为核心算法框架&#xff0c;结合PyQt5构建用户界面&#xff0c;使用Python3进行开发。RT-DETR以其高效的实时检测能力&#xff0c;在多个目标检测任务中展现出卓越性能。本研究针对水果数据集进行训练和优化&#xff0c;该数据集包含丰富的水果图像样本&#…

TensorFlow深度学习实战(3)——深度学习中常用激活函数详解

TensorFlow深度学习实战&#xff08;3&#xff09;——深度学习中常用激活函数详解 0. 前言1. 引入激活函数1.1 感知器1.2 多层感知器1.3 训练感知器存在的问题 2. 激活函数3. 常见激活函数3.1 sigmoid3.2 tanh3.3 ReLU3.4 ELU和Leaky ReLU 小结系列链接 0. 前言 使用激活函数…

数据结构C语言描述9(图文结合)--二叉树和特殊书的概念,二叉树“最傻瓜式创建”与前中后序的“递归”与“非递归遍历”

前言 这个专栏将会用纯C实现常用的数据结构和简单的算法&#xff1b;有C基础即可跟着学习&#xff0c;代码均可运行&#xff1b;准备考研的也可跟着写&#xff0c;个人感觉&#xff0c;如果时间充裕&#xff0c;手写一遍比看书、刷题管用很多&#xff0c;这也是本人采用纯C语言…

Leetcode打卡:设计一个ATM机器

执行结果&#xff1a;通过 题目 2241 设计一个ATM机器 一个 ATM 机器&#xff0c;存有 5 种面值的钞票&#xff1a;20 &#xff0c;50 &#xff0c;100 &#xff0c;200 和 500 美元。初始时&#xff0c;ATM 机是空的。用户可以用它存或者取任意数目的钱。 取款时&#xff0c…

vscode如何离线安装插件

在没有网络的时候,如果要安装插件,就会麻烦一些,需要通过离线安装的方式进行。下面记录如何在vscode离线安装插件。 一、下载离线插件 在一台能联网的电脑中,下载好离线插件,拷贝到无法联网的电脑上。等待安装。 vscode插件商店地址:https://marketplace.visualstudio.co…

用Tkinter制作一个用于合并PDF文件的小程序

需要安装PyPDF2库&#xff0c;具体原代码如下&#xff1a; # -*- coding: utf-8 -*- """ Created on Sun Dec 29 14:44:20 2024author: YBK """import PyPDF2 import os import tkinter as tk import windndpdf_files [] def dragged_files(f…

CDP集成Hudi实战-Hive

[〇]关于本文 本文测试一下使用Hive和Hudi的集成 软件版本Hudi1.0.0Hadoop Version3.1.1.7.3.1.0-197Hive Version3.1.3000.7.3.1.0-197Spark Version3.4.1.7.3.1.0-197CDP7.3.1 [一]部署Jar包 1-部署hudi-hive-sync-bundle-1.0.0.jar文件 [rootcdp73-1 ~]# for i in $(se…

公司资产网站

本文结尾处获取源码。 本文结尾处获取源码。 本文结尾处获取源码。 一、相关技术 后端&#xff1a;Java、JavaWeb / Springboot。前端&#xff1a;Vue、HTML / CSS / Javascript 等。数据库&#xff1a;MySQL 二、相关软件&#xff08;列出的软件其一均可运行&#xff09; I…

CSS——2.书写格式一

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title></title></head><body><!--css书写中&#xff1a;--><!--1.css 由属性名:属性值构成--><!--style"color: red;font-size: 20px;&quo…

基于FPGA的辩论赛系统设计-8名选手-正反两方-支持单选手评分-正反两方评分总和

基于FPGA的辩论赛系统设计 功能描述一、系统概述二、仿真波形视频 功能描述 1.答辩倒计时功能&#xff0c;当正反任意一方开始答辩后&#xff0c;倒计时30S。在倒计时最后10S后&#xff0c;LED灯开始闪烁。 2.答辩评分和计分功能&#xff0c;当答辩方结束答辩后&#xff0c;评…