Gan论文阅读笔记

GAN论文阅读笔记

2014年老论文了,主要记录一些重要的东西。论文链接如下:

Generative Adversarial Nets (neurips.cc)

文章目录

  • GAN论文阅读笔记
    • 出发点
    • 创新点
    • 设计
    • 训练代码
    • 网络结构代码
    • 测试代码

出发点

Deep generative models have had less of an impact, due to the difficulty of approximating many intractable probabilistic computations that arise in maximum likelihood estimation and related strategies, and due to difficulty of leveraging the benefits of piecewise linear units in the generative context.

​ 当时的生成模型效果不佳在于近似许多棘手的概率计算十分困难,如最大似然估计等。除此之外,把利用分段线性单元运用到生成场景中也有困难。于是作者提出新的生成模型:GAN。

​ 我的理解是,当时的生成模型都是去学习模型生成数据的分布,比如确定方差,确定均值之类的参数,然而这种方法十分难以学习,而且计算量大而复杂,作者考虑到这一点,对生成模型采用端到端的学习策略,不去学习生成数据的分布,而是直接学习模型,只要这个模型的生成结果能够逼近Ground-Truth,那么就可以直接用这个模型代替分布去生成数据。这是典型的黑箱思想。

创新点

adiscriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles.

创新点1:提出对抗学习策略:提出两个model之间相互对抗,相互抑制的策略。一个model名为生成器Generator,一个model名为判别器Discriminator,生成器尽可能生成接近真实的数据,判别器尽可能识别出生成器数据是Fake。

In this article, we explore the special case when the generative model generates samples by passing random noise through a multilayer perceptron, and the discriminative model is also a multilayer perceptron.

创新点2:当两个model都使用神经网络时,可以运用反向传播和Dropout等算法进行学习,这样就可以避免使用马尔科夫链。

设计

To learn the generator’s distribution pgover data x, we define a prior on input noise variables pz(z), then represent a mapping to data space as G(z; θg), where G is a differentiable function represented by a multilayer perceptron with parameters θg. We also define a second multilayer perceptron D(x; θd) that outputs a single scalar. D(x) represents the probability that x came from the data rather than pg.

1.输入:为了让生成器G生成的数据分布pg与真实数据分布x接近,策略是给G输入一个噪音变量z,然后学习参数θg,这个θg是G网络权重。因此,G可以被写作:G(z;θg)。
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \underset{G}{min}\underset{D}{max}V(D, G) =\mathbb{E}_{x \sim p_{data}(x)}\left[ logD(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[log(1 - D(G(z)))\right] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
2.对抗性损失函数:从代码可知,对抗性损失是两个BCELoss的和,V尽可能使D(x)更大,在此基础上尽可能使G(z)更小。这是有先后顺序的,在后面会做说明。

在代码中可知,先人为生成两个标签,第一个标签是用torch.ones生成的全为1的矩阵,形状为(batch,1)。其中batch是输入噪声的batch,第二维度只是一个数字——1,这个标签用于判别器D的BCELoss中,代入BCELoss即可得到上面对抗性损失中左侧的期望。第二个标签是用torch.zeors生成的全为0的矩阵,形状同理为(batch,1),运用于生成器G的BCELoss中,代入即可得到对抗性损失的右侧期望。

we alternate between k steps of optimizing D and one step of optimizing G.

This results in D being maintained near its optimal solution, so long as G changes slowly enough.

3.D与G的训练有先后顺序:判别器D先于生成器G训练,而且要求先对D训练k步,再为G训练1步,这就保证G的训练比D足够慢。

如果生成器G足够强大,那么判别器无法再监测生成器,也就没有对抗的必要了。相反,如果判别器D太过于强大,那么生成器也训练地十分缓慢。

在这里插入图片描述

4.算法图如上。

训练代码

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from Model import generator
from Model import discriminatorimport osif not os.path.exists('gan_train.py'):  # 报错中间结果os.mkdir('gan_train.py')def to_img(x):  # 将结果的-0.5~0.5变为0~1保存图片out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outbatch_size = 96
num_epoch = 200
z_dimension = 100# 数据预处理
img_transform = transforms.Compose([transforms.ToTensor(),  # 图像数据转换成了张量,并且归一化到了[0,1]。transforms.Normalize([0.5], [0.5])  # 这一句的实际结果是将[0,1]的张量归一化到[-1, 1]上。前面的(0.5)均值, 后面(0.5)标准差,
])
# MNIST数据集
mnist = datasets.MNIST(root='./data', train=True, transform=img_transform, download=True)
# 数据集加载器
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)D = discriminator()  # 创建生成器
G = generator()  # 创建判别器
if torch.cuda.is_available():  # 放入GPUD = D.cuda()G = G.cuda()criterion = nn.BCELoss()  # BCELoss 因为可以当成是一个分类任务,如果后面不加Sigmod就用BCEWithLogitsLoss
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)  # 优化器
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)  # 优化器# 开始训练
for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):  # img[96,1,28,28]G.train()num_img = img.size(0)  # num_img=batchsize# =================train discriminatorimg = img.view(num_img, -1)  # 把图片拉平,为了输入判别器 [96,784]real_img = img.cuda()  # 装进cuda,真实图片real_label = torch.ones(num_img).reshape(num_img, 1).cuda()  # 希望判别器对real_img输出为1 [96,1]fake_label = torch.zeros(num_img).reshape(num_img, 1).cuda()  # 希望判别器对fake_img输出为0  [96,1]# 先训练鉴别器# 计算真实图片的lossreal_out = D(real_img)  # 将真实图片输入鉴别器 [96,1]d_loss_real = criterion(real_out, real_label)  # 希望real_out越接近1越好 [1]real_scores = real_out  # 后面print用的# 计算生成图片的lossz = torch.randn(num_img, z_dimension).cuda()  # 创建一个100维度的随机噪声作为生成器的输入 [96,1]#   这个z维度和生成器第一个Linear第一个参数一致# 避免计算G的梯度fake_img = G(z).detach()  # 生成伪造图片 [96,748]fake_out = D(fake_img)  # 给判别器判断生成的好不好 [96,1]d_loss_fake = criterion(fake_out, fake_label)  # 希望判别器给fake_out越接近0越好 [1]fake_scores = fake_out  # 后面print用的d_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 计算生成图片的lossz = torch.randn(num_img, z_dimension).cuda()  # 生成随机噪声 [96,100]fake_img = G(z)  # 生成器伪造图像 [96,784]output = D(fake_img)  # 将伪造图像给判别器判断真伪 [96,1]g_loss = criterion(output, real_label)  # 生成器希望判别器给的值越接近1越好 [1]# 更新生成器g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print(f'Epoch [{epoch}/{num_epoch}], d_loss: {d_loss.cpu().detach():.6f}, g_loss: {g_loss.cpu().detach():.6f}',f'D real: {real_scores.cpu().detach().mean():.6f}, D fake: {fake_scores.cpu().detach().mean():.6f}')if epoch == 0:  # 保存图片real_images = to_img(real_img.detach().cpu())save_image(real_images, './img_gan/real_images.png')fake_images = to_img(fake_img.detach().cpu())save_image(fake_images, f'./img_gan/fake_images-{epoch + 1}.png')G.eval()with torch.no_grad():new_z = torch.randn(batch_size, 100).cuda()test_img = G(new_z)print(test_img.shape)test_img = to_img(test_img.detach().cpu())test_path = f'./test_result/the_{epoch}.png'save_image(test_img, test_path)# 保存模型
torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

网络结构代码

import torch
from torch import nn# 判别器 判别图片是不是来自MNIST数据集
class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),  # 784=28*28nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid()#   sigmoid输出这个生成器是或不是原图片,是二分类)def forward(self, x):x = self.dis(x)return x# 生成器 生成伪造的MNIST数据集
class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),  # 输入为100维的随机噪声nn.ReLU(),nn.Linear(256, 256),nn.ReLU(),nn.Linear(256, 784),#   生成器输出的特征维和正常图片一样,这是一个可参考的点nn.Tanh())def forward(self, x):x = self.gen(x)return xclass FinetuneModel(nn.Module):def __init__(self, weights):super(FinetuneModel, self).__init__()self.G = generator()base_weights = torch.load(weights)model_parameters = dict(self.G.named_parameters())#   不是对model进行named_parameters,而是对model里面的具体网络进行named_parameters取出参数,否则取出的是model冗余的参数去测试pretrained_weights = {k: v for k, v in base_weights.items() if k in model_parameters}new_state_dict = {k: pretrained_weights[k] for k in model_parameters.keys()}self.G.load_state_dict(new_state_dict)def forward(self, input):output = self.G(input)return output

测试代码

import os
import sys
import numpy as np
import torch
import argparse
import torch.utils.data
from PIL import Image
from Model import FinetuneModel
from Model import generator
from torchvision.utils import save_imageparser = argparse.ArgumentParser("GAN")
parser.add_argument('--save_path', type=str, default='./test_result')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=2)
parser.add_argument('--model', type=str, default='generator.pth')args = parser.parse_args()
save_path = args.save_path
os.makedirs(save_path, exist_ok=True)def to_img(x):  # 将结果的-0.5~0.5变为0~1保存图片out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outdef main():if not torch.cuda.is_available():print("no gpu device available")sys.exit(1)model = FinetuneModel(args.model)model = model.to(device=args.gpu)model.eval()z_dimension = 100with torch.no_grad():for i in range(100):z = torch.randn(96, z_dimension).cuda()  # 创建一个100维度的随机噪声作为生成器的输入 [96,100]output = model(z)print(output.shape)u_name = f'the_{i}.png'print(f'processing {u_name}')u_path = save_path + '/' + u_nameoutput = to_img(output.cpu().detach())save_image(output, u_path)if __name__ == '__main__':main()

本文毕

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/208299.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件压力测试的重要性与用途

在当今数字化的时代,软件已经成为几乎所有行业不可或缺的一部分。随着软件应用规模的增加和用户数量的上升,软件的性能变得尤为关键。为了确保软件在面对高并发和大负载时仍然能够保持稳定性和可靠性,软件压力测试变得至关重要。下面是软件压…

提醒事项日历同步怎么设置?可实时同步日历的提醒事项工具

随着生活节奏的加快,我们每天都需要处理许多琐碎的事务。为了不忘记重要的事情,很多人选择使用提醒事项工具来帮助自己。然而,市场上的提醒事项工具五花八门,有些并不具备日历月视图功能,也无法与手机日历同步&#xf…

Linux学习笔记7-IIC的应用和AP3216C

接下来进入其他两种串行通信方式:SPI和I2C的学习,因为以后的项目中会用到这些通信方式,而且正点原子的开发板里面也有用I2C和SPI通信的传感器来做实例,分别是一个距离传感器和六轴陀螺仪,这样就可以很好的通过实例来学…

GRE与顺丰圆通快递盒子

1. DNS污染 随想: 在输入一串网址后,会发生如下变化如果你在系统中配置了 Hosts 文件,那么电脑会先查询 Hosts 文件如果 Hosts 里面没有这个别名,就通过域名服务器查询域名服务器回应了,那么你的电脑就可以根据域名服…

【LeetCode:1466. 重新规划路线 | DFS + 图 + 树】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

Vue 子路由页面发消息给主路由页面 ,实现主页面显示子页面的信息

需求 子页面进入后,能在主页面显示子页的相关信息,比如说主页面的菜单激活的是哪个子页面的菜单项 如上图,当刷新浏览器页面时,让菜单的激活项仍保持在【最近浏览】。 实现方式: 在子页面的create事件中增加&#xff…

Vue:绘制图例

本文记录使用Vue框架绘制图例的代码片段。 可以嵌入到cesium视图中,也可以直接绘制到自己的原生系统中。 一、绘制图例Vue组件 <div v-for="(color, index) in colors" :key="index" class="legend-item"><div class="color-…

深度学习还可以从如下方面进行创新!!

文章目录 一、我认为可以从如下5个方向进行创新总结 一、我认为可以从如下5个方向进行创新 新的模型结构&#xff1a;尽管现在的深度学习模型已经非常强大&#xff0c;但是还有很多未被探索的模型结构。探索新的模型结构可以带来更好的性能和更低的计算成本。 新的优化算法&a…

一个简单的postman设置断言,为何会难住一个工作5年的测试?

postman设置断言 作为一款接口测试工 具&#xff0c;postman需要对发送请求后返回的结果是否正确做验证&#xff0c;在postman中通过 tests页签做请求的验证&#xff0c;也称为断言。 postman设置断言的流程 1、在tests页签截取要对比的实际响应信息&#xff08;响应头、响应…

眼花缭乱的ADN/ADX/DSP/DMP/SSP和他们的关系链

做过互联网广告尤其是程序化广告的同学都遇到过以下这些名词&#xff0c;或许正被他们折磨的焦头烂额&#xff0c;这篇文章&#xff0c;我们就来说说这些概念的含义及他们之间的关系链。 ADN&#xff1a;AD Network——广告网络或广告联盟。连接广告主和媒体的中间商。 ADX&…

stm32串口编程实例-实现数据的收发功能

大家好&#xff0c;今天给大家介绍stm32串口编程实例&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 串口是USART(通用同步/异步收发器)的俗称。 实际上&#xff0c;串行总线并不…

2023年8月8日 Go生态洞察:Go 1.21 版本发布探索

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

中小企业都在用哪些开源项目管理工具?分享15款

推荐15个优秀的开源项目管理工具&#xff0c;比如&#xff1a;ProjectLibre、OpenProject、ERPNext、Redmine、禅道、Tuleap、Restyaboard等。 项目经理面临各种复杂任务&#xff0c;包括追踪任务的进度、评估交付风险和管理整体工作量。为了顺利达成目标&#xff0c;一款靠谱的…

学习IO的第四天

作业 : 使用两个子进程完成两个文件的拷贝&#xff0c;子进程1拷贝前一半内容&#xff0c;子进程2拷贝后一般内容&#xff0c;父进程用于回收两个子进程的资源 #include <head.h>int main(int argc, const char *argv[]) {int rd -1;if((rdopen("./01_test.c&quo…

零基础如何入门HarmonyOS开发?

HarmonyOS鸿蒙应用开发是当前非常热门的一个领域&#xff0c;许多人都想入门学习这个技术。但是&#xff0c;对于零基础的人来说&#xff0c;如何入门确实是一个问题。下面&#xff0c;我将从以下几个方面来介绍如何零基础入门HarmonyOS鸿蒙应用开发学习。 一、了解HarmonyOS鸿…

[JSMSA_CTF] 2023年12月练习题 pwn

一开始没给附件&#xff0c;还以为是3个盲pwn结果&#xff0c;pwn了一晚上没出来&#xff0c;今天看已经有附件了。 pwn1 在init_0里使用mallopt(1,0) 设置global_max_fast0 任何块释放都会进入unsort在free函数里没有清理指针&#xff0c;有UAF将v6:0x100清0&#xff0c;便于…

甘草书店:#10 2023年11月24日 星期五 「麦田创业分享2—世界奇奇怪怪,请保持可可爱爱」

今日继续分享麦田创业经验。 如果你问我&#xff0c;创业过程中是否想过放弃。那么答案是&#xff0c;有那么一次。 那时想要放弃的原因并不是辛苦没有回报&#xff0c;或是资金短缺&#xff0c;而是没能理解“异见者”。 其实事情非常简单&#xff0c;现在反观那时的自己&a…

实例解析关于兔鲜登录tab栏切换案例详细讲解!

文章目录 文章目录 效果图展示 整体制作的一个思路 代码展示 技术细节 小结 效果图展示 点击账户登录显示登录的模块&#xff0c;点击二维码登录显示二维码的模块 整体制作的一个思路 点击哪个模块哪个显示&#xff0c;另外一个模块让它隐藏即可&#xff01; 代码展示 <!…

好莱坞明星识别

一、前期工作 1. 设置GPU from tensorflow import keras from tensorflow.keras import layers,models import os, PIL, pathlib import matplotlib.pyplot as plt import tensorflow as tfgpus tf.config.list_physical_devices("GPU")if gpus:gpu0 …

动态规划——完全背包问题(公式推导,组合、排列)

本文章是对于完全背包 一些题型(如题目所示&#xff0c;组合、排列和最小值类型)的总结和理解&#xff0c;依次记录一下&#xff0c;方便回顾与复习。 本文章是基于个人所总结 实现的&#xff0c;但在其中遇到了一些疑惑与困难&#xff0c;所以总结一篇与完全背包相关的问题。 …