基于PyTorch搭建你的生成对抗性网络

cd0651e77ef4cd23df4cfc0475b1bb07.jpeg

前言

你听说过GANs吗?还是你才刚刚开始学?GANs是2014年由蒙特利尔大学的学生 Ian Goodfellow 博士首次提出的。GANs最常见的例子是生成图像。有一个网站包含了不存在的人的面孔,便是一个常见的GANs应用示例。也是我们将要在本文中进行分享的。

生成对抗网络由两个神经网络组成,生成器和判别器相互竞争。我将在后面详细解释每个步骤。希望在本文结束时,你将能够从零开始训练和建立自己的生财之道对抗性网络。所以闲话少说,让我们开始吧。

目录

步骤0: 导入数据集

步骤1: 加载及预处理图像

步骤2: 定义判别器算法

步骤3: 定义生成器算法

步骤4: 编写训练算法

步骤5: 训练模型

步骤6: 测试模型

步骤0: 导入数据集

第一步是下载并将数据加载到内存中。我们将使用 CelebFaces Attributes Dataset (CelebA)来训练你的对抗性网络。主要分以下三个步骤:

1. 下载数据集:

https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be7eb6f_processed-celeba-small/processed-celeba-small.zip;

2. 解压缩数据集;

3. Clone 如下 GitHub地址:

https://github.com/Ahmad-shaikh575/Face-Generation-using-GANS

这样做之后,你可以在 colab 环境中打开它,或者你可以使用你自己的 pc 来训练模型。

导入必要的库

#import the neccessary libraries
import pickle as pkl
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets
from torchvision import transforms
import torch
import torch.optim as optim

步骤1: 加载及预处理图像

在这一步中,我们将预处理在前一节中下载的图像数据。

将采取以下步骤:

  1. 调整图片大小

  2. 转换成张量

  3. 加载到 PyTorch 数据集中

  4. 加载到 PyTorch DataLoader 中

# Define hyperparameters
batch_size = 32
img_size = 32
data_dir='processed_celeba_small/'# Apply the transformations
transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor()])
# Load the dataset
imagenet_data = datasets.ImageFolder(data_dir,transform= transform)# Load the image data into dataloader
celeba_train_loader = torch.utils.data.DataLoader(imagenet_data,batch_size,shuffle=True)

图像的大小应该足够小,这将有助于更快地训练模型。Tensors 基本上是 NumPy 数组,我们只是将图像转换为在 PyTorch 中所必需的 NumPy 数组。

然后我们加载这个转换成的 PyTorch 数据集。在那之后,我们将把我们的数据分成小批量。这个数据加载器将在每次迭代时向我们的模型训练过程提供图像数据。

随着数据的加载完成。现在,我们可以预处理图像。

图像的预处理

我们将在训练过程中使用 tanh 激活函数。该生成器的输出范围在 -1到1之间。我们还需要对这个范围内的图像进行缩放。代码如下所示:

def scale(img, feature_range=(-1, 1)):'''Scales the input image into given feature_range'''min,max = feature_rangeimg = img * (max-min) + minreturn img

这个函数将对所有输入图像缩放,我们将在后面的训练中使用这个函数。

现在我们已经完成了无聊的预处理步骤。

接下来是令人兴奋的部分,现在我们需要为我们的生成器和判别器神经网络编写代码。

步骤2: 定义判别器算法

97127e11857e4e2d08f27fe2e2848c52.png

判别器是一个可以区分真假图像的神经网络。真实的图像和由生成器生成的图像都将提供给它。

我们将首先定义一个辅助函数,这个辅助函数在创建卷积网络层时非常方便。

# helper conv function
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):layers = []conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)#Appending the layerlayers.append(conv_layer)#Applying the batch normalization if it's given trueif batch_norm:layers.append(nn.BatchNorm2d(out_channels))# returning the sequential containerreturn nn.Sequential(*layers)

这个辅助函数接收创建任何卷积层所需的参数,并返回一个序列化的容器。现在我们将使用这个辅助函数来创建我们自己的判别器网络。

class Discriminator(nn.Module):def __init__(self, conv_dim):super(Discriminator, self).__init__()self.conv_dim = conv_dim#32 x 32self.cv1 = conv(3, self.conv_dim, 4, batch_norm=False)#16 x 16self.cv2 = conv(self.conv_dim, self.conv_dim*2, 4, batch_norm=True)#4 x 4self.cv3 = conv(self.conv_dim*2, self.conv_dim*4, 4, batch_norm=True)#2 x 2self.cv4 = conv(self.conv_dim*4, self.conv_dim*8, 4, batch_norm=True)#Fully connected Layerself.fc1 = nn.Linear(self.conv_dim*8*2*2,1)def forward(self, x):# After passing through each layer# Applying leaky relu activation functionx = F.leaky_relu(self.cv1(x),0.2)x = F.leaky_relu(self.cv2(x),0.2)x = F.leaky_relu(self.cv3(x),0.2)x = F.leaky_relu(self.cv4(x),0.2)# To pass throught he fully connected layer# We need to flatten the image firstx = x.view(-1,self.conv_dim*8*2*2)# Now passing through fully-connected layerx = self.fc1(x)return x

步骤3: 定义生成器算法

d30da4cfe3a34a4e28baacfc833e75f8.png

正如你们从图中看到的,我们给网络一个高斯矢量或者噪声矢量,它输出 s 中的值。图上的“ z”表示噪声,右边的 G (z)表示生成的样本。

与判别器一样,我们首先创建一个辅助函数来构建生成器网络,如下所示:

def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):layers = []convt_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)# Appending the above conv layerlayers.append(convt_layer)if batch_norm:# Applying the batch normalization if Truelayers.append(nn.BatchNorm2d(out_channels))# Returning the sequential containerreturn nn.Sequential(*layers)

现在,是时候构建生成器网络了! !

class Generator(nn.Module):def __init__(self, z_size, conv_dim):super(Generator, self).__init__()self.z_size = z_sizeself.conv_dim = conv_dim#fully-connected-layerself.fc = nn.Linear(z_size, self.conv_dim*8*2*2)#2x2self.dcv1 = deconv(self.conv_dim*8, self.conv_dim*4, 4, batch_norm=True)#4x4self.dcv2 = deconv(self.conv_dim*4, self.conv_dim*2, 4, batch_norm=True)#8x8self.dcv3 = deconv(self.conv_dim*2, self.conv_dim, 4, batch_norm=True)#16x16self.dcv4 = deconv(self.conv_dim, 3, 4, batch_norm=False)#32 x 32def forward(self, x):# Passing through fully connected layerx = self.fc(x)# Changing the dimensionx = x.view(-1,self.conv_dim*8,2,2)# Passing through deconv layers# Applying the ReLu activation functionx = F.relu(self.dcv1(x))x= F.relu(self.dcv2(x))x= F.relu(self.dcv3(x))x= F.tanh(self.dcv4(x))#returning the modified imagereturn x

为了使模型更快地收敛,我们将初始化线性和卷积层的权重。根据相关研究论文中的描述:所有的权重都是从0中心的正态分布初始化的,标准差为0.02

我们将为此目的定义一个功能如下:

def weights_init_normal(m):classname = m.__class__.__name__# For the linear layersif 'Linear' in classname:torch.nn.init.normal_(m.weight,0.0,0.02)m.bias.data.fill_(0.01)# For the convolutional layersif 'Conv' in classname or 'BatchNorm2d' in classname:torch.nn.init.normal_(m.weight,0.0,0.02)

现在我们将超参数和两个网络初始化如下:

# Defining the model hyperparamameters
d_conv_dim = 32
g_conv_dim = 32
z_size = 100   #Size of noise vectorD = Discriminator(d_conv_dim)
G = Generator(z_size=z_size, conv_dim=g_conv_dim)
# Applying the weight initialization
D.apply(weights_init_normal)
G.apply(weights_init_normal)print(D)
print()
print(G)

输出结果大致如下:

b0dbbadcdafb3a9f8bf28da76d19bd8f.png

判别器损失:

根据 DCGAN Research Paper 论文中描述:

        判别器总损失 = 真图像损失 + 假图像损失,即:d_loss = d_real_loss + d_fake_loss。

       不过,我们希望鉴别器输出1表示真正的图像和0表示假图像,所以我们需要设置的损失来反映这一点。

我们将定义双损失函数。一个是真正的损失,另一个是假的损失,如下:

def real_loss(D_out,smooth=False):batch_size = D_out.size(0)if smooth:labels = torch.ones(batch_size)*0.9else:labels = torch.ones(batch_size)labels = labels.to(device)criterion = nn.BCEWithLogitsLoss()loss = criterion(D_out.squeeze(), labels)return lossdef fake_loss(D_out):batch_size = D_out.size(0)labels = torch.zeros(batch_size)labels = labels.to(device)criterion = nn.BCEWithLogitsLoss()loss = criterion(D_out.squeeze(), labels)return loss

生成器损失:

根据 DCGAN Research Paper 论文中描述:

        生成器的目标是让判别器认为它生成的图像是真实的。

现在,是时候为我们的网络设置优化器了:

lr = 0.0005
beta1 = 0.3
beta2 = 0.999 # default value
# Optimizers
d_optimizer = optim.Adam(D.parameters(), lr, betas=(beta1, beta2))
g_optimizer = optim.Adam(G.parameters(), lr, betas=(beta1, beta2))

我将为我们的训练使用 Adam 优化器。因为它目前被认为是对GANs最有效的。根据上述介绍论文中的研究成果,确定了超参数的取值范围。他们已经尝试了它,这些被证明是最好的!超参数设置如下:

步骤4: 编写训练算法

我们必须为我们的两个神经网络编写训练算法。首先,我们需要初始化噪声向量,并在整个训练过程中保持一致。

# Initializing arrays to store losses and samples
samples = []
losses = []# We need to initilialize fixed data for sampling
# This would help us to evaluate model's performance
sample_size=16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()

对于判别器:

我们首先将真实的图像输入判别器网络,然后计算它的实际损失。然后生成伪造图像并输入判别器网络以计算虚假损失。

在计算了真实和虚假损失之后,我们对其进行求和,并采取优化步骤进行训练。

# setting optimizer parameters to zero
# to remove previous training data residue
d_optimizer.zero_grad()# move real images to gpu memory
real_images = real_images.to(device)# Pass through discriminator network
dreal = D(real_images)# Calculate the real loss
dreal_loss = real_loss(dreal)# For fake images# Generating the fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()# move z to the GPU memory
z = z.to(device)# Generating fake images by passing it to generator
fake_images = G(z)# Passing fake images from the disc network        
dfake = D(fake_images)
# Calculating the fake loss
dfake_loss = fake_loss(dfake)#Adding both lossess
d_loss = dreal_loss + dfake_loss
# Taking the backpropogation step
d_loss.backward()
d_optimizer.step()

对于生成器:

对于生成器网络的训练,我们也会这样做。刚才在通过判别器网络输入假图像之后,我们将计算它的真实损失。然后优化我们的生成器网络。

## Training the generator for adversarial loss
#setting gradients to zero
g_optimizer.zero_grad()# Generate fake images
z = np.random.uniform(-1, 1, size=(batch_size, z_size))
z = torch.from_numpy(z).float()
# moving to GPU's memory
z = z.to(device)# Generating Fake images
fake_images = G(z)# Calculating the generator loss on fake images
# Just flipping the labels for our real loss function
D_fake = D(fake_images)
g_loss = real_loss(D_fake, True)# Taking the backpropogation step
g_loss.backward()
g_optimizer.step()

步骤5: 训练模型

现在我们将开始100个epoch的训练: D

经过训练,损失的图表看起来大概是这样的:

297583a673349a29e8b1e16942b38a73.png

我们可以看到,判别器 Loss 是相当平滑的,甚至在100个epoch之后收敛到某个特定值。而生成器的Loss则飙升。

我们可以从下面步骤6中的结果看出,60个时代之后生成的图像是扭曲的。由此可以得出结论,60个epoch是一个最佳的训练节点。

步骤6: 测试模型

10个epoch之后:

96d12f9703a12ac0e77a0c2f87c3f146.png

20个epoch之后:

dc196f60f97d770c6d0c79c8f4d85ae7.png

30个epoch之后:

33365e822eda268d5e048d296cbea7be.png

40个epoch之后:

0198316b876a53571d5550f9f97556d5.png

50个epoch之后:

bdcf6ad6c3a1977c55cdd7f6fd9a4816.png

60个epoch之后:

4152c319441182ae8907a6b354fb6a44.png

70个epoch之后:

b48254e7675b521a66d059da034db13b.png

80个epoch之后:

e567311618525ee9d2e67e0292fe067d.png

90个epoch之后:

c933353f759c001a2d1cd8d9ec81f3ba.png

100个epoch之后:

0f83fe27e164d062d3afad75e7960e75.png

总结

我们可以看到,训练一个生成对抗性网络并不意味着它一定会产生好的图像。

从结果中我们可以看出,训练40-60个 epoch 的生成器生成的图像相对比其他更好。

您可以尝试更改优化器、学习速率和其他超参数,以使其生成更好的图像!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/147736.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot中使用MongoDB完成数据存储

我们在开发中用到的数据存储工具有许多种,我们常见的数据存储工具包括: 关系性数据库:使用表格来存储数据,支持事务和索引。(如:MySQL,Oracle,SQL Server等)。NoSQL数据…

Redis篇---第七篇

系列文章目录 文章目录 系列文章目录前言一、是否使用过 Redis Cluster 集群,集群的原理是什么?二、 Redis Cluster 集群方案什么情况下会导致整个集群不可用?三、Redis 集群架构模式有哪几种?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分…

OPPO Watch纯手机开启远程ADB调试

Wear OS手表中,我们可以直接在开发者设置中打开WiFi调试。但是这在OPPO等魔改Android系统中不再奏效。 需要什么?? 手表一台手机一个OTG转接头一个手表充电器一个 演示设备 手机: OPPO Find X手表: OPPO Watch 1代 …

Android——模块级build.gradle配置——applicationId和namespace

官方地址: 配置应用模块-applicationId和namespace了解 build.gradle 中的实用设置。https://developer.android.google.cn/studio/build/configure-app-module?hlzh-cn 产生那些异常场景: Android:Namespace not specified. Please spec…

Halcon (3):窗体常用语法使用

文章目录 文章专栏视频资源前言halcon图像使用加载图片示例绘制常用图像批量批注绘制 文章专栏 Halcon开发 视频资源 机器视觉之C#联合Halcon 前言 在使用halcon的算子之前,我们要先学会如何在图片上面进行标注。因为我们不仅要导出处理的结果,还要导出…

论文阅读:YOLOV: Making Still Image Object Detectors Great at Video Object Detection

发表时间:2023年3月5日 论文地址:https://arxiv.org/abs/2208.09686 项目地址:https://github.com/YuHengsss/YOLOV 视频物体检测(VID)具有挑战性,因为物体外观的高度变化以及一些帧的不同恶化。有利的信息…

Windows10电脑没有微软商店的解决方法

在Windows10电脑中用户可以打开微软商店,下载自己需要的应用程序。但是,有用户反映自己Windows10电脑上没有微软商店,但是不清楚具体的解决方法,接下来小编给大家详细介绍关于解决Windows10电脑内微软商店不见了的方法&#xff0c…

计算机的发展

硬件的发展 第一台电子数字计算机:ENIAC(1946),作者:冯诺依曼,逻辑元件:电子管 bug:小虫子,会影响打点 Intel: 机器字长:计算机一次整数运算所能…

springcloudalibaba-3

一、Nacos Config入门 1. 搭建nacos环境【使用现有的nacos环境即可】 使用之前的即可 2. 在微服务中引入nacos的依赖 <!-- nacos配置依赖 --><dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-…

深入Rust:探索所有权和借用机制

大家好&#xff01;我是lincyang。 今天&#xff0c;我们将一起深入探索Rust语言中的一个核心概念&#xff1a;所有权和借用机制。 这些特性是Rust区别于其他语言的重要特点&#xff0c;它们在内存管理和并发编程中扮演着关键角色。 一、Rust所有权机制 1. 什么是所有权&#x…

Qt HTTP 摘要认证(海康球机摄像机ISAPI开发)

接到一个需求是开发下海康的球机,控制云台,给到我的是一个开发手册,当然了是海康的私有协议 ISAPI开发手册https://download.csdn.net/download/qq_37059136/88547425关于开发这块读文档就可以理解了,海康使用的是摘要认证,当然了海康已经给出使用范例 通过libcurl就可以直接连…

C/C++---------------LeetCode第1189. “气球” 的最大数量

气球的最大数量 题目及要求统计法在main内使用 题目及要求 给你一个字符串 text&#xff0c;你需要使用 text 中的字母来拼凑尽可能多的单词 “balloon”&#xff08;气球&#xff09;。 字符串 text 中的每个字母最多只能被使用一次。请你返回最多可以拼凑出多少个单词 “ba…

Openssl X509 v3 AuthorityKeyIdentifier实验与逻辑分析

Openssl是X509的事实标准&#xff0c;目前主流OS或个别安全性要求较高的设计场景&#xff0c;对X509的证书链验证已经不在停留在只从数字签名校验了&#xff0c;也就是仅仅从公钥验签的角度&#xff0c;在这些场景中&#xff0c;往往还会校验AuthorityKeyIdentifier和SubjectKe…

Egress-TLS-Origination

目录 文章目录 目录本节实战1、出口网关TLS发起2、通过 egress 网关发起双向 TLS 连接关于我最后 本节实战 实战名称&#x1f6a9; 实战&#xff1a;Egress TLS Origination-2023.11.19(failed)&#x1f6a9; 实战&#xff1a;通过 egress 网关发起双向 TLS 连接-2023.11.19(测…

STM32_SPI总线驱动OLED详细原理讲解

目录 这里写目录标题 第13章 Cortex-M4-SPI总线13.1 SPI总线概述13.1.1 SPI总线介绍13.1.2 SPI总线接口与物理拓扑结构13.1.3 SPI总线通信原理13.1.4 SPI总线数据格式 13.2 IO口模拟SPI操作OLED13.2.1 常见的显示设备13.2.2 OLED显示屏概述13.2.3 OLED特征13.2.4 显示原理13.2.…

【图数据库实战】HugeGraph架构

一、概述 作为一款通用的图数据库产品&#xff0c;HugeGraph需具备图数据的基本功能&#xff0c;如下图所示。HugeGraph包括三个层次的功能&#xff0c;分别是存储层、计算层和用户接口层。 HugeGraph支持OLTP和OLAP两种图计算类型&#xff0c;其中OLTP实现了Apache TinkerPop3…

webAPP基础学习

###视觉基础 part-I ####1.面试中常见的像素问题 >什么是像素? *1.什么是px? px-虚拟像素,css像素的单位 px是一个相对单位,相对于设备像素而言 >相对性 a.相对于同一个设备,css像素的可变的 css像素物理像素>会受到缩放的影响 css像素缩放倍数*单个物理像…

QEMU显示虚拟化的几种选项

QEMU可以通过通过命令行"-vga type"选择为客户机模拟的VGA卡的类别,可选择的类型有多个: -vga typeSelect type of VGA card to emulate. Valid values for type arecirrusCirrus Logic GD5446 Video card. All Windows versions starting from Windows 95 should …

MySQL集群高可用架构之MMM

目录 一、MMM概述 1.1 MMM 简介 1.2 MMM高可用架构 1.3 MMM工作原理 1.4 工作流程图 二、MMM高可用双主双从架构部署 1、架构&#xff1a; 2、搭建 MySQL 多主多从模式 3、安装配置 MySQL-MMM 4、故障测试 一、MMM概述 1.1 MMM 简介 MMM&#xff08;Master-Master re…