目录
关于GAN网络
关于生成模型和判别模型
GAN网路的特性和搭建步骤(以手写字体识别数据集为例)
搭建步骤
特性
GAN的目标函数(损失函数)
目标函数原理
torch.nn.BCELoss(实际应用的损失函数)
代码
生成器
判别器
后向传播
完整代码
运行结果
图像显示
关于GAN网络
GAN包含有两个模型,一个是生成模型(generative model),一个是判别模型(discriminative model)。生成模型的任务是生成看起来自然真实的、和原始数据相似的实例。判别模型的任务是判断给定的实例看起来是自然真实的还是人为伪造的
关于生成模型和判别模型
通俗理解:生成模型像“一个造假币的机器”,而判别模型像“检测假币的机器”。
在判别模型这里只有“是假币”,“不是假币”。
而生成器的目的就是让自己的假币越来越逼真以至于能够“骗过”判断模型,判别器则努力不被生成器欺骗。模型经过交替优化训练,两种模型都能得到提升,但最终我们要得到的是效果提升到很高很好的生成模型(造假币的机器),这个生成模型(造假币的机器)所生成的产品能达到真假难分的地步。
(虽然最后判别器也能得到提升,可是随着我们不断地训练,判别器的准确率会停留在0.5左右,这是由于两个模型的特性决定的,也是之所以叫生成对抗神经网络的原因。)
GAN网路的特性和搭建步骤(以手写字体识别数据集为例)
搭建步骤
#GAN生成对抗网络,步骤:
#首先编写生成器和判别器
#然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
#接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
#生成器的代码(针对手写字体识别)
特性
我们第一开始后向传播的时候,我们先更新判断器参数,然后更新生成器的,然后接着这样重复,后面随着循环的进行,我们生成器生成的的图片大体上肯定是越来越接近真实的图片的。
假设我们已经训练了50轮,我们假设这个时候已经很接近正确的真实图片。
当他开始下一轮的时候,此时我们更新判别器在判0方面的参数更新的依据是:把这个很接近真实图片的伪造图片放进判断模型判断,通过损失函数计算和判断结果和0的误差来进行参数的更新
此时,就会出现一个矛盾点,这个矛盾点有两个主要方面:
第一点,伪造图片数据用来判别器判为0的优化(过损失函数计算和判断结果和0的误差),而我们的真实图片用来判别器判为1的优化(过损失函数计算和判断结果和1的误差)
第二点,因为第一点的原因也会影响到我们生成器,因为我们生成一是将生成的伪造图片进行判别器判断后,通过损失函数计算和判断结果和1的误差来进行参数的更新
这样就有了个矛盾点:生成器是让伪造数据判断结果趋向于1,判别器是让伪造数据的判断结果趋向于0,这样就形成了‘对抗’
我们生成器一开始肯定是损失下降的。
可慢慢的不论我们怎么优化生成器的参数,我们生成的图片经过判断器判断后和1的区别似乎小到某个临界点之后越来越大了,以至于后面这个区别似乎在某个区间反复的“振荡”,并不是愈来愈小。
这表示如果把每一轮的损失都分为生成器损失和判别器损失,这两个损失都应该是波动型的。
按我的理解,这个模型的特点是体现在一开始我们生成器生成的图片是愈来愈“真”,但是如果次数很多的话,就会导致判别器很难判断0或者1,以至于影响到后面生成器生成的图片
GAN的目标函数(损失函数)
目标函数原理
这个是判别器的目标函数,D(x)表示生成器判1的概率,D(G(x))表示生成器判0的概率,光从这里看这个应当是愈来愈大的。
这是我们生成器 目标函数,可以看出来是愈来愈小的。
可我们实际应用的话也有一些区别。
torch.nn.BCELoss(实际应用的损失函数)
(以下内容来自网络)
可以看出来在计算我们生成模型的损失值得时候,我们对我们伪造的图片进行判0得时候,这一块得损失值应当是增大的,因为实际上预测他的概率或者说最后输出层那一个神经元里面的数是应该不断接近一的。
代码
生成器
class Generator(torch.nn.Module):def __init__(self):super(Generator,self).__init__()self.main=torch.nn.Sequential(torch.nn.Linear(100,256),torch.nn.ReLU(),torch.nn.Linear(256,512),torch.nn.ReLU(),torch.nn.Linear(512,28*28),torch.nn.Tanh())def forward(self,x):img=self.main(x)img=img.reshape(-1,28,28)return img
判别器
class Discraiminator(torch.nn.Module):def __init__(self):super(Discraiminator,self).__init__()self.mainf = torch.nn.Sequential(torch.nn.Linear(28*28, 512),torch.nn.LeakyReLU(),torch.nn.Linear(512, 256),torch.nn.LeakyReLU(),torch.nn.Linear(256,1),torch.nn.Sigmoid())def forward(self,x):x=x.view(-1,28*28)x=self.mainf(x)return x
后向传播
#定义损失函数和优化函数
device='cuda' if torch.cuda.is_available() else 'cpu'
gen=Generator().to(device)
dis=Discraiminator().to(device)
#定义优化器
gen_opt=torch.optim.Adam(gen.parameters(),lr=0.0001)
dis_opt=torch.optim.Adam(dis.parameters(),lr=0.0001)
loss_fn=torch.nn.BCELoss()#损失函数
#图像显示
def gen_img_plot(model,testdata):pre=np.squeeze(model(testdata).detach().cpu().numpy())
#tensor.detach()
#返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
# 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
# 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播plt.figure()for i in range(16):plt.subplot(4,4,i+1)plt.imshow(pre[i])plt.show()#后向传播
dis_loss=[]#判别器损失值记录
gen_loss=[]#生成器损失值记录
lun=[]#轮数
for epoch in range(60):d_epoch_loss=0g_epoch_loss=0cout=len(trainload)#938批次for step, (img, _) in enumerate(trainload):img=img.to(device) #图像数据#print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])size=img.size(0)#一批次的图片数量64#随机生成一批次的100维向量样本,或者说100个像素点random_noise=torch.randn(size,100,device=device)#先进性判断器的后向传播dis_opt.zero_grad()real_output=dis(img)d_real_loss=loss_fn(real_output,torch.ones_like(real_output))#真实数据的损失函数值d_real_loss.backward()gen_img=gen(random_noise)fake_output=dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))#人造的数据的损失函数值d_fake_loss.backward()d_loss=d_real_loss+d_fake_lossdis_opt.step()#生成器的后向传播gen_opt.zero_grad()fake_output=dis(gen_img)g_loss=loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()d_epoch_loss += d_lossg_epoch_loss += g_loss
完整代码
import matplotlib.pyplot as plt
from matplotlib import font_manager
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
#导入数据集并且进行数据处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
#GAN生成对抗网络,步骤:
#首先编写生成器和判别器
#然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
#接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
#生成器的代码(针对手写字体识别)
class Generator(torch.nn.Module):def __init__(self):super(Generator,self).__init__()self.main=torch.nn.Sequential(torch.nn.Linear(100,256),torch.nn.ReLU(),torch.nn.Linear(256,512),torch.nn.ReLU(),torch.nn.Linear(512,28*28),torch.nn.Tanh())def forward(self,x):img=self.main(x)img=img.reshape(-1,28,28)return img
#判别器,最后判断0,1,这意味着最后可以是一个神经元或者两个神经元
class Discraiminator(torch.nn.Module):def __init__(self):super(Discraiminator,self).__init__()self.mainf = torch.nn.Sequential(torch.nn.Linear(28*28, 512),torch.nn.LeakyReLU(),torch.nn.Linear(512, 256),torch.nn.LeakyReLU(),torch.nn.Linear(256,1),torch.nn.Sigmoid())def forward(self,x):x=x.view(-1,28*28)x=self.mainf(x)return x
#定义损失函数和优化函数
device='cuda' if torch.cuda.is_available() else 'cpu'
gen=Generator().to(device)
dis=Discraiminator().to(device)
#定义优化器
gen_opt=torch.optim.Adam(gen.parameters(),lr=0.0001)
dis_opt=torch.optim.Adam(dis.parameters(),lr=0.0001)
loss_fn=torch.nn.BCELoss()#损失函数
#图像显示
def gen_img_plot(model,testdata):pre=np.squeeze(model(testdata).detach().cpu().numpy())
#tensor.detach()
#返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
# 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
# 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播plt.figure()for i in range(16):plt.subplot(4,4,i+1)plt.imshow(pre[i])plt.show()#后向传播
dis_loss=[]#判别器损失值记录
gen_loss=[]#生成器损失值记录
lun=[]#轮数
for epoch in range(60):d_epoch_loss=0g_epoch_loss=0cout=len(trainload)#938批次for step, (img, _) in enumerate(trainload):img=img.to(device) #图像数据#print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])size=img.size(0)#一批次的图片数量64#随机生成一批次的100维向量样本,或者说100个像素点random_noise=torch.randn(size,100,device=device)#先进性判断器的后向传播dis_opt.zero_grad()real_output=dis(img)d_real_loss=loss_fn(real_output,torch.ones_like(real_output))#真实数据的损失函数值d_real_loss.backward()gen_img=gen(random_noise)fake_output=dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))#人造的数据的损失函数值d_fake_loss.backward()d_loss=d_real_loss+d_fake_lossdis_opt.step()#生成器的后向传播gen_opt.zero_grad()fake_output=dis(gen_img)g_loss=loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()d_epoch_loss += d_lossg_epoch_loss += g_lossdis_loss.append(float(d_epoch_loss))gen_loss.append(float(g_epoch_loss))print(f'第{epoch+1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')lun.append(epoch+1)
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
plt.figure()
plt.plot(lun,dis_loss,'r',label='判别器损失值')
plt.plot(lun,gen_loss,'b',label='生成器损失值')
plt.xlabel('训练轮数', fontproperties=font, fontsize=12)
plt.ylabel('损失值', fontproperties=font, fontsize=12)
plt.title('损失值随着训练轮数得变化情况:',fontproperties=font, fontsize=18)
plt.legend(prop=font)
plt.show()
random_noise=torch.randn(16,100,device=device)
gen_img_plot(gen,random_noise)
运行结果
D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第七周的代码\GAN.py
第1轮的生成器损失值:2528.59814453125,判别器损失值353.325927734375
第2轮的生成器损失值:2921.286376953125,判别器损失值248.41244506835938
第3轮的生成器损失值:3441.375732421875,判别器损失值159.46466064453125
第4轮的生成器损失值:3337.6513671875,判别器损失值215.8180694580078
第5轮的生成器损失值:3971.089111328125,判别器损失值125.399658203125
第6轮的生成器损失值:4233.75341796875,判别器损失值117.00439453125
第7轮的生成器损失值:4556.33935546875,判别器损失值124.78838348388672
第8轮的生成器损失值:4610.49658203125,判别器损失值130.80592346191406
第9轮的生成器损失值:4639.509765625,判别器损失值141.80921936035156
第10轮的生成器损失值:4648.3037109375,判别器损失值148.4178466796875
第11轮的生成器损失值:4764.65380859375,判别器损失值158.76829528808594
第12轮的生成器损失值:5018.9208984375,判别器损失值138.4495391845703
第13轮的生成器损失值:5489.87939453125,判别器损失值142.05093383789062
第14轮的生成器损失值:5771.021484375,判别器损失值122.76760864257812
第15轮的生成器损失值:5327.2373046875,判别器损失值142.30992126464844
第16轮的生成器损失值:5096.5966796875,判别器损失值131.3242950439453
第17轮的生成器损失值:5312.2607421875,判别器损失值143.69143676757812
第18轮的生成器损失值:5721.6220703125,判别器损失值135.46119689941406
第19轮的生成器损失值:4894.25341796875,判别器损失值178.1021728515625
第20轮的生成器损失值:4638.419921875,判别器损失值182.18136596679688
第21轮的生成器损失值:4944.3798828125,判别器损失值167.8157958984375
第22轮的生成器损失值:4811.68798828125,判别器损失值185.08151245117188
第23轮的生成器损失值:4419.439453125,判别器损失值213.2805633544922
第24轮的生成器损失值:4265.69873046875,判别器损失值224.0750732421875
第25轮的生成器损失值:3908.226318359375,判别器损失值263.6612243652344
第26轮的生成器损失值:3829.155517578125,判别器损失值275.2219543457031
第27轮的生成器损失值:3941.054443359375,判别器损失值260.74066162109375
第28轮的生成器损失值:4242.72216796875,判别器损失值236.7626953125
第29轮的生成器损失值:3621.1337890625,判别器损失值297.2765808105469
第30轮的生成器损失值:3010.93310546875,判别器损失值355.7358703613281
第31轮的生成器损失值:2710.3935546875,判别器损失值389.6776428222656
第32轮的生成器损失值:2636.367919921875,判别器损失值399.9930725097656
第33轮的生成器损失值:2637.282470703125,判别器损失值402.30859375
第34轮的生成器损失值:2988.23974609375,判别器损失值357.68804931640625
第35轮的生成器损失值:5289.29541015625,判别器损失值199.2281951904297
第36轮的生成器损失值:5623.095703125,判别器损失值194.9387969970703
第37轮的生成器损失值:3149.01416015625,判别器损失值348.37677001953125
第38轮的生成器损失值:2634.451904296875,判别器损失值405.61566162109375
第39轮的生成器损失值:2547.442138671875,判别器损失值418.43597412109375
第40轮的生成器损失值:2448.359130859375,判别器损失值432.43096923828125
第41轮的生成器损失值:2429.5537109375,判别器损失值443.5609130859375
第42轮的生成器损失值:2452.239013671875,判别器损失值449.2001953125
第43轮的生成器损失值:2652.614990234375,判别器损失值427.6022033691406
第44轮的生成器损失值:3612.08935546875,判别器损失值330.52484130859375
第45轮的生成器损失值:5931.68701171875,判别器损失值155.98318481445312
第46轮的生成器损失值:2881.45751953125,判别器损失值420.7775573730469
第47轮的生成器损失值:2272.14111328125,判别器损失值491.7653503417969
第48轮的生成器损失值:2222.366943359375,判别器损失值504.3684387207031
第49轮的生成器损失值:2263.885498046875,判别器损失值493.1470031738281
第50轮的生成器损失值:2237.902587890625,判别器损失值507.2246398925781
第51轮的生成器损失值:2207.36962890625,判别器损失值512.4918823242188
第52轮的生成器损失值:2241.222900390625,判别器损失值509.4316711425781
第53轮的生成器损失值:2264.532958984375,判别器损失值501.825927734375
第54轮的生成器损失值:2370.49365234375,判别器损失值487.0404052734375
第55轮的生成器损失值:2869.4560546875,判别器损失值425.4640197753906
第56轮的生成器损失值:3764.148681640625,判别器损失值317.8245849609375
第57轮的生成器损失值:3606.946533203125,判别器损失值340.7588195800781
第58轮的生成器损失值:2236.921875,判别器损失值521.7313842773438
第59轮的生成器损失值:2077.1865234375,判别器损失值551.8837890625
第60轮的生成器损失值:2065.66943359375,判别器损失值555.025634765625进程已结束,退出代码0
图像显示
果然,从本次实验来看,我们得生成器损失值是波动的,判别器损失值也是,很难说他们的趋势走向(当然估计和我的训练轮数有关)
这是我们生成器生成的“伪造的图片”,从这里可以看出来已经很不错了。