# 生成对抗网络
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超参数
latent_size = 64 # 潜在空间(latent space)的维度数量
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# Create a directory if not exists
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# Image processing
# transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
# std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], # 1 for greyscale channels
std=[0.5])])
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./datasets',
train=True,
transform=transform,
download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
for i,_ in data_loader:
print(i.shape,i.max(),i.min(),torch.unique(_))
break
# 鉴别器
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())
# 生成器
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1) # 裁剪到0-1
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
torch.zeros(4, 1).shape
torch.randn(4,50).shape
# 生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,
# 判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由
# 真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数
# 鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断
# 它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差
# 特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是
#生成器模型中的参数
# 训练
total_step = len(data_loader) # 总批次数
for epoch in range(num_epochs):
# 遍历每个批次数据
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device)
# 创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# 训练鉴别器
outputs = D(images) # 获取鉴别器对真实图片的鉴别分数
d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失
real_score = outputs # 真实鉴别分数
# z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z) # 获取生成的图片
outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数
d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失
fake_score = outputs
# Backprop and optimize
d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失
reset_grad()
# 鉴别器损失反向传播
d_loss.backward()
# 根据梯度更新参数
d_optimizer.step()
# 训练生成器
# 随机初始化一个噪音图片(用一定大小的向量表示)
z = torch.randn(batch_size, latent_size).to(device)
# 通过生成器生成图片
fake_images = G(z)
outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分
# 生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差
g_loss = criterion(outputs, real_labels)
# 清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数
reset_grad()
g_loss.backward()
g_optimizer.step()
# 每隔200批次打印日志
if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# 真实图片只需要保存一次
if (epoch+1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
# 每个轮次都会保存一次生成器生成的图片
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
上面是真实图片,下面是经过200个轮次后的生成图片