GAN简介
让我们先来简单了解一下GAN
GAN的全称是Generative Adversarial Networks,中文称为“生成对抗网络”,是一种在深度学习领域广泛使用的无监督学习方法。
GAN主要由两部分组成:生成器和判别器。生成器的目标是尽可能地生成真实的样本数据,而判别器的目标是尽可能准确地辨别出生成样本与真实样本。这两个组件通过竞争和对抗的方式共同工作,以提升各自的能力。这种网络结构能够处理没有标注数据的问题,并且在图像处理、自然语言处理等多个领域都有广泛应用。
它通过对抗生成来训练,目的是估测数据样本的潜在分布并生成新的数据样本。
GAN结构图
原理
生成器根据噪声,也就是随机值,来生成样本,而判别器判断哪些是真实数据,哪些是生成数据,然后将学习的经验反向传播给生成器,让生成器生成的样本不断向真实样本靠拢。
在训练过程中,生成器努力让生成的数据更加真实,而判别器努力的去判别数据的真假,二者·形成了对抗。最终两个网络形成了动态平衡,生成样本接近真实样本,而判别器也分辨不出来样本的真假,最终对给定图像预测为真的概率基本接近0.5,也就相当于随即猜测类别了。
公式
在公式中,
z代表输入G网络的噪声,
x代表真实图片
G(z)表示G网络生成的图片,
D(*)表示D网络判断图片是否真实的概率
2.GAN的算法流程和公式详解_哔哩哔哩_bilibili
在这个视频里有对这个公式的详解,这里就不详细说了/
我们经过简单了解之后,就要开始搭建GAN网络了,这里我们以手写字体识别数据集为例。
构建GAN网络的步骤
GAN生成对抗网络,步骤:
首先编写生成器和判别器
然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近1为目的优化我们的生成器
生成器的代码(针对手写字体识别)
预备知识
transforms.Normalize
transforms.Normalize()
函数用于对图像数据进行【标准化】处理。在深度学习中,数据标准化是一个常见的预处理步骤,它有助于模型更快地收敛,并提高模型的性能。
作用
数据标准化:如上所述,transforms.Normalize()函数可以对图像数据进行标准化处理,使数据分布符合标准正态分布。这有助于模型更快地收敛,并提高模型的性能。
提高模型泛化能力:通过对数据进行标准化,我们可以减少模型对特定数据集的过拟合,从而提高模型在未见过的数据上的泛化能力。
加速模型训练:标准化的数据可以使模型在训练过程中更快地学习到数据的特征,从而加速模型的训练速度。
参数
- mean:(list)长度与输入的通道数相同,代表每个通道上所有数值的平均值。
- std:(list)长度与输入的通道数相同,代表每个通道上所有数值的标准差。
Datadoder
参数
dataset(数据集):需要提取数据的数据集,Dataset对象
batch_size(批大小):每一次装载样本的个数,int型
shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据
LeakyReLu函数
图像及参数
我们可以与ReLu函数对比,看一下区别:
主要区别就是在小于0的部分了
代码
导库
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
数据集处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
])
traindata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=True, download=True,transform=transform) # 训练集60,000张用于训练
在加载数据集时,我们要将数据进行归一化,在GAN中,我们就需要将数据归一化到(-1,1)之间,这是为什么呢?原因是我们在下面会用到Tanh激活函数,而Tanh函数的范围是在-1到1之间的,见下图:
在我们既然知道了为什么要这样,下面就要学会如何做到了
在ToTensor中,我们是将数据的范围限制在了(0,1)之间,而后面的Normalize是将数据限制在(-1,1)之间,计算公式为(x-均值)/方差
生成器
class Generator(torch.nn.Module):def __init__(self):super(Generator, self).__init__()self.main = torch.nn.Sequential(torch.nn.Linear(100, 256),torch.nn.ReLU(),torch.nn.Linear(256, 512),torch.nn.ReLU(),torch.nn.Linear(512, 28 * 28),torch.nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img
在这里,我们需要知道,生成器的输入和输出是什么,输入时我们的噪音,而输出一张图片。
在后向传播中,我们最后再将图片进行展平。
判别器
class Discraiminator(torch.nn.Module):def __init__(self):super(Discraiminator, self).__init__()self.mainf = torch.nn.Sequential(torch.nn.Linear(28 * 28, 512),torch.nn.LeakyReLU(),torch.nn.Linear(512, 256),torch.nn.LeakyReLU(),torch.nn.Linear(256, 1),torch.nn.Sigmoid())def forward(self, x):x = x.view(-1, 28 * 28)x = self.mainf(x)return x
我们同样需要了解判别器的输入和输出,输入是一张(1,28,28)图片,输出为二分类的概率值。
在判别器中,我们如果使用ReLu函数,在小于0的部分就会出现梯度消失的问题,这时候我们就可以用到LeadkyReLu了,它能够优化GAN的训练。
最后的Sigmoid激活函数,将输出压缩到0到1之间,这通常用于二分类问题,但在这里,它用于表示输入是真实数据的概率。
而在后向传播中,我们需要先对图片进行展平。
定义损失函数,优化函数和优化器
# 定义损失函数和优化函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discraiminator().to(device)
# 定义优化器
gen_opt = torch.optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = torch.optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss() # 损失函数
在这里,我们选择使用BCELoss,交叉熵损失函数,这是因为在GAN中,判别器通常被视为一个二分类器,它试图区分输入是真实样本还是由生成器生成的假样本,而BCELoss就是用来做二分类的损失函数,正好对应。
在优化器部分,它们分别对生成器和判别器的参数进行优化。
图像显示
def gen_img_plot(model, testdata):pre = np.squeeze(model(testdata).detach().cpu().numpy())# tensor.detach()# 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。# 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad# 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播plt.figure()for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow(pre[i])plt.show()
因为我们最终要得到要得到的是处理数据输出的数组,所以我们要用squeeze将额外的单维度删除。
detach是单独开辟空间来保存数据,从而保证数据的稳定性。
plt.figure用来生成一个新画布。
使用subplot函数在一个4x4的网格中定位每个子图。i + 1是因为子图的索引是从1开始的,而不是从0开始。
imshow是在子图中显示图像。
最后的show来显示整体的图像。
后向传播与训练模型
dis_loss = [] # 判别器损失值记录
gen_loss = [] # 生成器损失值记录
lun = [] # 轮数
for epoch in range(60):d_epoch_loss = 0g_epoch_loss = 0cout = len(trainload) # 938批次for step, (img, _) in enumerate(trainload):img = img.to(device) # 图像数据# print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])size = img.size(0) # 一批次的图片数量64# 随机生成一批次的100维向量样本,或者说100个像素点random_noise = torch.randn(size, 100, device=device)# 判断器的后向传播dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 真实数据的损失函数值d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 人造的数据的损失函数值d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()# 生成器的后向传播gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()d_epoch_loss += d_lossg_epoch_loss += g_lossdis_loss.append(float(d_epoch_loss))gen_loss.append(float(g_epoch_loss))print(f'第{epoch + 1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')lun.append(epoch + 1)
使用enumerate
遍历训练数据集trainload
,其中img
是图像数据,但_
表示我们在这里不使用标签(因为GAN是无监督的)。
step()用来更新判别器的模型参数。
在生成器的后向传播部分,
我们先进行梯度清零,然后通过生成器生成假图像,然后进行前向传播。
我们期望判别器对假图像的评分接近1(真实),因此我们将目标标签设置为与fake_output
形状相同的全1张量torch.ones_like(fake_output)
。
在这里,d_loss和g_loss是一张图像中的损失值,而d_epoch_loss和g_epoch_loss是每一轮损失值的累加,用于最后图像的绘制。
生成图像
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
plt.figure()
plt.plot(lun, dis_loss, 'r', label='判别器损失值')
plt.plot(lun, gen_loss, 'b', label='生成器损失值')
plt.xlabel('训练轮数', fontsize=12)
plt.ylabel('损失值', fontsize=12)
plt.title('损失值随着训练轮数得变化情况:', fontsize=18)
plt.legend()
plt.show()
random_noise = torch.randn(16, 100, device=device)
gen_img_plot(gen, random_noise)
随机生成的噪声有16个样本,100个维度
全部代码
import matplotlib.pyplot as plt
import matplotlib
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np# 导入数据集并且进行数据处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)
])
traindata = torchvision.datasets.MNIST(root='./data', train=True, download=True,transform=transform) # 训练集60,000张用于训练
# 利用DataLoader加载数据集
trainload = DataLoader(dataset=traindata, shuffle=True, batch_size=64)# GAN生成对抗网络,步骤:
# 首先编写生成器和判别器
# 然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
# 接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
# 生成器的代码(针对手写字体识别)
class Generator(torch.nn.Module):def __init__(self):super(Generator, self).__init__()self.main = torch.nn.Sequential(torch.nn.Linear(100, 256),torch.nn.ReLU(),torch.nn.Linear(256, 512),torch.nn.ReLU(),torch.nn.Linear(512, 28 * 28),torch.nn.Tanh())def forward(self, x):img = self.main(x)img = img.reshape(-1, 28, 28)return img# 判别器,最后判断0,1,这意味着最后可以是一个神经元或者两个神经元
class Discraiminator(torch.nn.Module):def __init__(self):super(Discraiminator, self).__init__()self.mainf = torch.nn.Sequential(torch.nn.Linear(28 * 28, 512),torch.nn.LeakyReLU(),torch.nn.Linear(512, 256),torch.nn.LeakyReLU(),torch.nn.Linear(256, 1),torch.nn.Sigmoid())def forward(self, x):x = x.view(-1, 28 * 28)x = self.mainf(x)return x# 定义损失函数和优化函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discraiminator().to(device)
# 定义优化器
gen_opt = torch.optim.Adam(gen.parameters(), lr=0.0001)
dis_opt = torch.optim.Adam(dis.parameters(), lr=0.0001)
loss_fn = torch.nn.BCELoss() # 损失函数def gen_img_plot(model, testdata):pre = np.squeeze(model(testdata).detach().cpu().numpy())# tensor.detach()# 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。# 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad# 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播plt.figure()for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow(pre[i])plt.show()# 后向传播
dis_loss = [] # 判别器损失值记录
gen_loss = [] # 生成器损失值记录
lun = [] # 轮数
for epoch in range(60):d_epoch_loss = 0g_epoch_loss = 0cout = len(trainload) # 938批次for step, (img, _) in enumerate(trainload):img = img.to(device) # 图像数据# print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])size = img.size(0) # 一批次的图片数量64# 随机生成一批次的100维向量样本,或者说100个像素点random_noise = torch.randn(size, 100, device=device)# 判断器的后向传播dis_opt.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 真实数据的损失函数值d_real_loss.backward()gen_img = gen(random_noise)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 人造的数据的损失函数值d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossdis_opt.step()# 生成器的后向传播gen_opt.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()d_epoch_loss += d_lossg_epoch_loss += g_lossdis_loss.append(float(d_epoch_loss))gen_loss.append(float(g_epoch_loss))print(f'第{epoch + 1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')lun.append(epoch + 1)matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
plt.figure()
plt.plot(lun, dis_loss, 'r', label='判别器损失值')
plt.plot(lun, gen_loss, 'b', label='生成器损失值')
plt.xlabel('训练轮数', fontsize=12)
plt.ylabel('损失值', fontsize=12)
plt.title('损失值随着训练轮数得变化情况:', fontsize=18)
plt.legend()
plt.show()
random_noise = torch.randn(16, 100, device=device)
gen_img_plot(gen, random_noise)
运行结果
第1轮的生成器损失值:2226.86328125,判别器损失值461.5265808105469
第2轮的生成器损失值:2378.969970703125,判别器损失值459.3701477050781
第3轮的生成器损失值:2422.438232421875,判别器损失值355.0154113769531
第4轮的生成器损失值:3410.994873046875,判别器损失值172.3834686279297
第5轮的生成器损失值:3589.7734375,判别器损失值168.7844696044922
第6轮的生成器损失值:3944.258544921875,判别器损失值125.10688781738281
第7轮的生成器损失值:4293.7861328125,判别器损失值138.3419952392578
第8轮的生成器损失值:4436.89404296875,判别器损失值159.64407348632812
第9轮的生成器损失值:4485.7646484375,判别器损失值177.5517578125
第10轮的生成器损失值:4136.85986328125,判别器损失值210.64602661132812
第11轮的生成器损失值:4072.7958984375,判别器损失值246.29910278320312
第12轮的生成器损失值:4298.8623046875,判别器损失值183.00152587890625
第13轮的生成器损失值:4899.4794921875,判别器损失值171.33628845214844
第14轮的生成器损失值:4851.458984375,判别器损失值161.920654296875
第15轮的生成器损失值:4995.62646484375,判别器损失值155.28732299804688
第16轮的生成器损失值:4987.4140625,判别器损失值142.6618194580078
第17轮的生成器损失值:5511.90673828125,判别器损失值126.41560363769531
第18轮的生成器损失值:5509.65771484375,判别器损失值157.1754913330078
第19轮的生成器损失值:5164.8671875,判别器损失值143.5445556640625
第20轮的生成器损失值:5490.17236328125,判别器损失值156.86929321289062
第21轮的生成器损失值:5189.4921875,判别器损失值177.5731201171875
第22轮的生成器损失值:5293.32080078125,判别器损失值168.159912109375
第23轮的生成器损失值:4971.2646484375,判别器损失值189.78167724609375
第24轮的生成器损失值:4590.87158203125,判别器损失值211.07289123535156
第25轮的生成器损失值:4739.5732421875,判别器损失值214.7382354736328
第26轮的生成器损失值:4700.568359375,判别器损失值218.89926147460938
第27轮的生成器损失值:4146.5048828125,判别器损失值269.0607604980469
第28轮的生成器损失值:3846.898681640625,判别器损失值287.00604248046875
第29轮的生成器损失值:3559.870361328125,判别器损失值317.5647888183594
第30轮的生成器损失值:3378.71240234375,判别器损失值336.30572509765625
第31轮的生成器损失值:4269.37060546875,判别器损失值257.89910888671875
第32轮的生成器损失值:5209.896484375,判别器损失值191.99989318847656
第33轮的生成器损失值:4632.1728515625,判别器损失值261.9479064941406
第34轮的生成器损失值:2979.66015625,判别器损失值363.874267578125
第35轮的生成器损失值:2710.74462890625,判别器损失值405.0263671875
第36轮的生成器损失值:2661.800048828125,判别器损失值421.5466613769531
第37轮的生成器损失值:2625.377197265625,判别器损失值414.751708984375
第38轮的生成器损失值:2809.101318359375,判别器损失值399.09942626953125
第39轮的生成器损失值:3797.715087890625,判别器损失值314.6676025390625
第40轮的生成器损失值:6223.8974609375,判别器损失值151.0428924560547
第41轮的生成器损失值:3305.96533203125,判别器损失值355.9456481933594
第42轮的生成器损失值:2672.400634765625,判别器损失值395.23834228515625
第43轮的生成器损失值:2538.265625,判别器损失值425.629638671875
第44轮的生成器损失值:2496.415283203125,判别器损失值443.06085205078125
第45轮的生成器损失值:2451.716796875,判别器损失值449.18194580078125
第46轮的生成器损失值:2397.526123046875,判别器损失值467.0350341796875
第47轮的生成器损失值:2427.2900390625,判别器损失值459.0263977050781
第48轮的生成器损失值:2440.54736328125,判别器损失值469.6186218261719
第49轮的生成器损失值:2597.76953125,判别器损失值439.3223876953125
第50轮的生成器损失值:2724.003173828125,判别器损失值438.4668273925781
第51轮的生成器损失值:2539.636474609375,判别器损失值459.2343444824219
第52轮的生成器损失值:2288.4130859375,判别器损失值498.2747802734375
第53轮的生成器损失值:2244.51513671875,判别器损失值506.4640197753906
第54轮的生成器损失值:2242.865478515625,判别器损失值502.57275390625
第55轮的生成器损失值:2198.66552734375,判别器损失值506.5917053222656
第56轮的生成器损失值:2217.268310546875,判别器损失值502.77081298828125
第57轮的生成器损失值:2246.22802734375,判别器损失值502.93206787109375
第58轮的生成器损失值:2165.259033203125,判别器损失值516.4965209960938
第59轮的生成器损失值:2146.760009765625,判别器损失值519.462890625
第60轮的生成器损失值:2110.582763671875,判别器损失值528.8636474609375进程已结束,退出代码为 0
我们得生成器损失值是波动的,判别器损失值也是,很难说他们的趋势走向(当然估计和我的训练轮数有关)
这是我们生成器生成的“伪造的图片”,从这里可以看出来已经很不错了。