GAN的原理分析与实例

为了便于理解,可以先玩一玩这个网站:GAN Lab: Play with Generative Adversarial Networks in Your Browser!

GAN的本质:枯叶蝶和鸟。生成器的目标:让枯叶蝶进化,变得像枯叶,不被鸟准确识别。判别器的目标:准确判别是枯叶还是鸟

伪代码: 

案例:

原始数据:

案例结果: 

 案例完整代码:

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = './data/train_data'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 50
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),#转置卷积层:输入特征映射的尺寸会放大,通道数可能会减小,普通卷积层:输入特征映射的尺寸会缩小,但通道数可能会增加nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)#real_image = real_image.cuda()if (i + 1) % d_every == 0:  #d_every = 1,每一个batch训练一次discriminatoroptimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i)) #将Figure划分为8行8列的子图网格,并将当前的子图添加到第i个位置。# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0)) #permute()函数可以对维度进行重排,Matplotlib期望的图像格式是(H, W, C),即高度、宽度、通道i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.savefig("images/%dcgan.png" % num)if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布# if torch.cuda.is_available() == True:#     print('Cuda is available!')#     Generator.cuda()#     Discriminator.cuda()#     criterion.cuda()#     true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()#     fix_noises, noises = fix_noises.cuda(), noises.cuda()#plot_epoch = [1,5,10,50,100,200,500,800,1000,1500,2000,2500,3000]plot_epoch = [1,5,10,50,100,200,500,800,1000,1200,1500]for i in range(1500):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

http://t.csdnimg.cn/FTSriicon-default.png?t=N7T8http://t.csdnimg.cn/FTSri

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222322.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java中的链表

文章目录 前言一、链表的概念及结构二、单向不带头非循坏链表的实现2.1打印链表2.2求链表的长度2.3头插法2.4尾插法2.5任意位置插入2.6查找是否包含某个元素的节点2.7删除第一次出现这个元素的节点2.8删除包含这个元素的所以节点2.9清空链表单向链表的测试 三、双向不带头非循坏…

【Python】人工智能-机器学习——不调库手撕深度网络分类问题

1. 作业内容描述 1.1 背景 数据集大小150该数据有4个属性&#xff0c;分别如下 Sepal.Length&#xff1a;花萼长度(cm)Sepal.Width&#xff1a;花萼宽度单位(cm)Petal.Length&#xff1a;花瓣长度(cm)Petal.Width&#xff1a;花瓣宽度(cm)category&#xff1a;类别&#xff0…

【STM32】STM32学习笔记-GPIO输入(07)

00. 目录 文章目录 00. 目录01. 按键简介02. 传感器模块简介03. 光敏电阻传感器04. 按键电路图05. C语言数据类型06. C语言宏定义07. C语言typedef08. C语言结构体09. C语言枚举10. 附录 01. 按键简介 按键&#xff1a;常见的输入设备&#xff0c;按下导通&#xff0c;松手断开…

TCP/IP详解——ARP 协议

文章目录 一、ARP 协议1. ARP 数据包格式2. ARP 工作过程3. ARP 缓存4. ARP 请求5. ARP 响应6. ARP 代理7. ARP 探测IP冲突8. ARP 协议抓包分析9. ARP 断网攻击10. 总结 一、ARP 协议 ARP&#xff08;Address Resolution Protocol&#xff09;协议工作在网络层和数据链路层之间…

CCF编程能力等级认证GESP—C++2级—20230923

CCF编程能力等级认证GESP—C2级—20230923 单选题&#xff08;每题 2 分&#xff0c;共 30 分&#xff09;判断题&#xff08;每题 2 分&#xff0c;共 20 分&#xff09;编程题 (每题 25 分&#xff0c;共 50 分)⼩杨的 X 字矩阵数字⿊洞 答案及解析单选题判断题编程题1编程题…

使用Docker本地安装部署Draw.io绘图工具并实现远程访问协作办公

前言 提到流程图&#xff0c;大家第一时间可能会想到Visio&#xff0c;不可否认&#xff0c;VIsio确实是功能强大&#xff0c;但是软件为收费&#xff0c;并且因为其功能强大&#xff0c;导致安装需要很多的系统内存&#xff0c;并且是不可跨平台使用。所以&#xff0c;今天给…

电脑开机出现:CLIENT MAD ADDR (网卡启动系统)的解决办法

文章目录 前言步骤1、确定情况2、对症下药——关闭网卡启动 补充1、关于BIOS2、关于PXE 前言 最近给旧电脑重装系统安了下开发环境和常用软件啥的&#xff0c;之前还好好启动的电脑&#xff0c;开机突然需要额外加载一个页面&#xff0c;虽然最后正常启动了不影响使用&#xf…

Qt5 CMake环境配置

Qt5 CMake环境配置 设置Qt路径 有两种方法 Qt5_DIR&#xff0c;使用这个变量&#xff0c;必须把路径设置到Qt5Config.cmake所在文件夹&#xff0c;也就是安装目录下的lib/cmake/Qt5CMAKE_PREFIX_PATH&#xff0c;只需要设置到安装目录就可以了&#xff0c;这个目录就是bin、…

OpenStack和Docker结合?为何现在流行?

为何现在流行OpenStack和Docker结合&#xff1f; 结合的好处 1、资源管理与调度灵活&#xff1a; OpenStack提供了完善的虚拟机管理能力&#xff0c;而Kubernetes&#xff08;使用Docker作为容器运行环境&#xff09;在容器调度方面非常高效。将两者结合&#xff0c;可以实现…

RNN介绍及Pytorch源码解析

介绍一下RNN模型的结构以及源码&#xff0c;用作自己复习的材料。 RNN模型所对应的源码在&#xff1a;\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。 RNN的模型图如下&#xff1a; 源码注释中写道&#xff0c;RNN的数学公式&#xff1a; 表示在时刻的隐藏状态…

ES6学习(三):Set和Map容器的使用

Set容器 set的结构类似于数组,但是成员是唯一且不会重复的。 创建的时候需要使用new Set([])的方法 创建Set格式数据 let set1 new Set([])console.log(set1, set1)let set2 new Set([1, 2, 3, 4, 5])console.log(set2, set2) 对比看看Set中唯一 let set3 new Set([1, 1,…

多架构容器镜像构建实战

最近在一个国产化项目中遇到了这样一个场景&#xff0c;在同一个 Kubernetes 集群中的节点是混合架构的&#xff0c;也就是说&#xff0c;其中某些节点的 CPU 架构是 x86 的&#xff0c;而另一些节点是 ARM 的。为了让我们的镜像在这样的环境下运行&#xff0c;一种最简单的做法…

Rust语言基础语法使用

1.安装开发工具: RustRover JetBrains: Essential tools for software developers and teams 下载: RustRover: Rust IDE by JetBrains 下载成功后安装并启动RustRover 安装中文语言包插件 重启RustRover生效

vue3引入echarts正确姿势

echarts文档地址&#xff1a; echarts官网地址 echarts配置手册 echarts 模板地址 1、安装 &#xff08;1&#xff09;安装echarts包 npm install echarts --save 或者 cnpm install echarts --save&#xff08;2&#xff09;安装vue echarts工具包 npm install echart…

持续集成交付CICD:Jenkins使用基于SaltStack的CD流水线下载Nexus制品

目录 一、理论 1.salt常用命令 二、实验 1.SaltStack环境检查 2.Jenkins使用基于SaltStack的CD流水线下载Nexus制品 二、问题 1.salt未找到命令 2.salt简单测试报错 3. wget输出日志过长 一、理论 1.salt常用命令 &#xff08;1&#xff09;salt 命令 该 命令执行s…

蓝牙物联网智慧物业解决方案

蓝牙物联网智慧物业解决方案是一种利用蓝牙技术来提高物业管理和服务效率的解决方案。它通过将蓝牙技术与其他智能设备、应用程序和云服务相结合&#xff0c;为物业管理和服务提供更便捷、高效和智能化的支持。 蓝牙物联网智慧物业解决方案包括&#xff1a; 1、设备管理&#…

JupyterNotebook安装依赖 使用conda环境

怎么使用JupyterNotebook安装依赖&#xff1f;&#xff08;使用conda环境&#xff1f;&#xff09; 预装的jupyter中有conda&#xff0c;可以进入终端执行相关命令 预装的Jupyter&#xff0c;目标用户是轻度的Jupyter用户&#xff0c;如果想使用conda的多环境等高级功能&#x…

Unity中的ShaderToy

文章目录 前言一、ShaderToy网站二、ShaderToy基本框架1、我们可以在ShaderToy网站中&#xff0c;这样看用到的GLSL文档2、void mainImage 是我们的程序入口&#xff0c;类似于片断着色器3、fragColor作为输出变量&#xff0c;为屏幕每一像素的颜色&#xff0c;alpha一般赋值为…

微信小程序单图上传和多图上传

图片上传主要用到 1、wx.chooseImage(Object object) 从本地相册选择图片或使用相机拍照。 参数 Object object 属性类型默认值必填说明countnumber9否最多可以选择的图片张数sizeTypeArray.<string>[original, compressed]否所选的图片的尺寸sourceTypeArray.<s…

从开源项目中学习如何自定义 Spring Boot Starter 小组件

前言 今天参考的开源组件Graceful Response——Spring Boot接口优雅响应处理器。 具体用法可以参考github以及官方文档。 基本使用 引入Graceful Response组件 项目中直接引入如下maven依赖&#xff0c;即可使用其相关功能。 <dependency><groupId>com.feiniaoji…