Pytorch 实现 GAN 对抗网络

GAN 对抗网络

GAN(Generative Adversarial Network)对抗网络指的是神经网络中包括两个子网络,一个用于生成信息,一个用于验证信息。下面的例子是生成图片的对抗网络,一个网络用于生成图片,另一个网络用于验证。G 网络用于生成图片,不断的学习并生成更接近于训练数据的图像,D 网络用于鉴别图片,通过学习更准确的识别出图片的真假,最终通过学习让 G 网络能够生成高质量的目标图片。下面通过代码实现两种不同的 GAN,图片为自动生成手写图片,采用 MNIST数据集。

  • DCGAN (Deep Convolution)深度卷积生成对抗网络
  • SAGAN(Self Attention)自注意力生成对抗网络

安装依赖

本文将使用 sklearn,首先安装 sklearn。

pip install -U scikit-learn

DCGAN 深度卷积对抗网络

数据准备

采用 MNIST 数据,并只选用 7、8 两个数字,MINST 中 7、8 数字各有 200 张。

def make_datapath_list():"""创建用于学习和验证的图像数据及标注数据的文件路径列表。 """train_img_list = list() #保存图像文件的路径for img_idx in range(200):img_path = "./data/img_78/img_7_" + str(img_idx)+'.jpg'train_img_list.append(img_path)img_path = "./data/img_78/img_8_" + str(img_idx)+'.jpg'train_img_list.append(img_path)return train_img_list
class ImageTransform():"""图像的预处理类"""def __init__(self, mean, std):self.data_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])def __call__(self, img):return self.data_transform(img)
class GAN_Img_Dataset(data.Dataset):"""图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""def __init__(self, file_list, transform):self.file_list = file_listself.transform = transformdef __len__(self):'''返回图像的张数'''return len(self.file_list)def __getitem__(self, index):'''获取经过预处理后的图像的张量格式的数据'''img_path = self.file_list[index]img = Image.open(img_path)  #[ 高度 ][ 宽度 ] 黑白#图像的预处理img_transformed = self.transform(img)return img_transformed
#创建DataLoader并确认执行结果#创建文件列表
train_img_list=make_datapath_list()#创建Dataset
mean = (0.5,)
std = (0.5,)
train_dataset = GAN_Img_Dataset(file_list=train_img_list, transform=ImageTransform(mean, std))#创建DataLoader
batch_size = 64train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)#确认执行结果
batch_iterator = iter(train_dataloader)  #转换为迭代器
imges = next(batch_iterator)  #取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

生成网络实现

需要根据输入的随机数生成图像,对数据的维度进行放大,并增加维度中的元素数量,通过 nn.ConvTranspose2d 转置卷积进行实现。转置卷积是卷积的反向操作,卷积输出特征通常比输入数据小,反向卷积输出比输入大,可以看做数据放大操作。

# 导入软件包
import random
import math
import time
import pandas as pd
import numpy as np
from PIL import Imageimport torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimfrom torchvision import transforms# Setup seeds
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)
class Generator(nn.Module):def __init__(self, z_dim=20, image_size=64):super(Generator, self).__init__()self.layer1 = nn.Sequential(nn.ConvTranspose2d(z_dim, image_size * 8,kernel_size=4, stride=1),nn.BatchNorm2d(image_size * 8),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.ConvTranspose2d(image_size * 8, image_size * 4,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 4),nn.ReLU(inplace=True))self.layer3 = nn.Sequential(nn.ConvTranspose2d(image_size * 4, image_size * 2,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size * 2),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.ConvTranspose2d(image_size * 2, image_size,kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(image_size),nn.ReLU(inplace=True))self.last = nn.Sequential(nn.ConvTranspose2d(image_size, 1, kernel_size=4,stride=2, padding=1),nn.Tanh())# 注意:因为是黑白图像,所以只有一个输出通道def forward(self, z):out = self.layer1(z)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.last(out)return out

根据输入,生成图片

#动作确认
import matplotlib.pyplot as plt
%matplotlib inlineG = Generator(z_dim=20, image_size=64)# 输入的随机数
input_z = torch.randn(1, 20)# 将张量尺寸变形为(1,20,1,1)
input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)#输出假图像
fake_images = G(input_z)img_transformed = fake_images[0][0].detach().numpy()
plt.imshow(img_transformed, 'gray')
plt.show()

没有经过学习生成的图片,目标是通过学习生成手写数字的效果。
在这里插入图片描述

鉴别网络实现

鉴别网络是一个进行图片分类的神经网络模型,由 5 层网络组成。

class Discriminator(nn.Module):def __init__(self, z_dim=20, image_size=64):super(Discriminator, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(1, image_size, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))#注意:因为是黑白图像,所以输入通道只有一个self.layer2 = nn.Sequential(nn.Conv2d(image_size, image_size*2, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer3 = nn.Sequential(nn.Conv2d(image_size*2, image_size*4, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(image_size*4, image_size*8, kernel_size=4,stride=2, padding=1),nn.LeakyReLU(0.1, inplace=True))self.last = nn.Conv2d(image_size*8, 1, kernel_size=4, stride=1)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.last(out)return out

DCGAN 训练

生成网络和识别网络使用 BCEWithLogitsLoss 作为损失函数。

#网络的初始化处理
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:#Conv2d和ConvTranspose2d的初始化nn.init.normal_(m.weight.data, 0.0, 0.02)nn.init.constant_(m.bias.data, 0)elif classname.find('BatchNorm') != -1:#BatchNorm2d的初始化nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)#开始初始化
G.apply(weights_init)
D.apply(weights_init)print("网络已经成功地完成了初始化")#网络已经成功地完成了初始化def train_model(G, D, dataloader, num_epochs):#确认是否能够使用GPU加速device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("使用设备:", device)#设置最优化算法g_lr, d_lr = 0.0001, 0.0004beta1, beta2 = 0.0, 0.9g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])#定义误差函数criterion = nn.BCEWithLogitsLoss(reduction='mean')#使用硬编码的参数z_dim = 20mini_batch_size = 64#将网络载入GPU中G.to(device)D.to(device)G.train()  #将模式设置为训练模式D.train()  #将模式设置为训练模式#如果网络相对固定,则开启加速torch.backends.cudnn.benchmark = True#图像张数num_train_imgs = len(dataloader.dataset)batch_size = dataloader.batch_size#设置迭代计数器iteration = 1logs = []#epoch循环for epoch in range(num_epochs):#保存开始时间t_epoch_start = time.time()epoch_g_loss = 0.0  #epoch的损失总和epoch_d_loss = 0.0  #epoch的损失总和print('-------------')print('Epoch {}/{}'.format(epoch, num_epochs))print('-------------')print('(train)')#以minibatch为单位从数据加载器中读取数据的循环for imges in dataloader:# --------------------#1.判别器D的学习# --------------------#如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免if imges.size()[0] == 1:continue#如果能使用GPU,则将数据送入GPU中imges = imges.to(device)#创建正确答案标签和伪造数据标签#在epoch最后的迭代中,小批次的数量会减少mini_batch_size = imges.size()[0]label_real = torch.full((mini_batch_size,), 1).to(device)label_fake = torch.full((mini_batch_size,), 0).to(device)#对真正的图像进行判定d_out_real = D(imges)#生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)#计算误差d_loss_real = criterion(d_out_real.view(-1), label_real.float())d_loss_fake = criterion(d_out_fake.view(-1), label_fake.float())d_loss = d_loss_real + d_loss_fake#反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# --------------------#2.生成器G的学习# --------------------#生成伪造图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images = G(input_z)d_out_fake = D(fake_images)#计算误差g_loss = criterion(d_out_fake.view(-1), label_real.float())#反向传播处理g_optimizer.zero_grad()d_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# --------------------#3.记录结果# --------------------epoch_d_loss += d_loss.item()epoch_g_loss += g_loss.item()iteration += 1#epoch的每个phase的loss和准确率t_epoch_finish = time.time()print('-------------')print('epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss/batch_size, epoch_g_loss/batch_size))print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))t_epoch_start = time.time()return G, D
#执行训练和验证操作
num_epochs = 200
G_update, D_update = train_model(G, D, dataloader=train_dataloader, num_epochs=num_epochs)
#将生成的图像和训练数据可视化
#反复执行本单元中的代码,直到生成感觉良好的图像为止device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)#生成图像
fake_images = G_update(fixed_z.to(device))#训练数据
batch_iterator = iter(train_dataloader) #转换成迭代器
imges = next(batch_iterator)  #取出位于第一位的元素#输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):#将训练数据放入上层plt.subplot(2, 5, i+1)plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')#将生成数据放入下层plt.subplot(2, 5, 5+i+1)plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

SAGAN自注意力生成对抗网络

自注意力对抗网络是为了解决卷积操作只关注局部的问题,在深度对抗网络中,卷积操作只是对周围的信息进行了放大而忽略了全局信息。自注意力机制中用于计算特征向量是关注与其相似的那些像素点,而不只是周围的像素点。首先,通过 1x1 卷积进行逐点卷积,当数据进行压缩,之后再通过频谱归一化,对网络层的权重进行归一化操作吗,更有助于模型的收敛。

准备数据

准备 Dataset、Dataloader和图像预处理

def make_datapath_list():"""制作用于学习、验证的图像数据和标注数据的文件路径表。 """train_img_list = list()  # 保存图像文件路径for img_idx in range(200):img_path = "./data/img_78/img_7_" + str(img_idx)+'.jpg'train_img_list.append(img_path)img_path = "./data/img_78/img_8_" + str(img_idx)+'.jpg'train_img_list.append(img_path)return train_img_list
class ImageTransform():"""图像的预处理类"""def __init__(self, mean, std):self.data_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])def __call__(self, img):return self.data_transform(img)class GAN_Img_Dataset(data.Dataset):"""图像的Dataset类。继承PyTorch的Dataset类"""def __init__(self, file_list, transform):self.file_list = file_listself.transform = transformdef __len__(self):'''返回图像的张数'''return len(self.file_list)def __getitem__(self, index):'''获取预处理图像的Tensor格式的数据'''img_path = self.file_list[index]img = Image.open(img_path)  #[高][宽]黑白# 图像的预处理img_transformed = self.transform(img)return img_transformed# 创建DataLoader并确认操作# 制作文件列表
train_img_list=make_datapath_list()#制作Dataset
mean = (0.5,)
std = (0.5,)
train_dataset = GAN_Img_Dataset(file_list=train_img_list, transform=ImageTransform(mean, std))# 制作DataLoader
batch_size = 64train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# 动作的确认
batch_iterator = iter(train_dataloader)  # 转换成迭代器
imges = next(batch_iterator)  # 找出第一个要素
print(imges.size())  # torch.Size([64, 1, 64, 64])

创建网络

以下代码用于实现 SelfAttention、生成网络和验证网络。

#导入软件包
import random
import math
import time
import pandas as pd
import numpy as np
from PIL import Imageimport torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimfrom torchvision import transforms# Setup seeds
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)class Self_Attention(nn.Module):""" Self-AttentionのLayer"""def __init__(self, in_dim):super(Self_Attention, self).__init__()#准备1×1的卷积层的逐点卷积self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)#创建Attention Map时归一化用的SoftMax函数self.softmax = nn.Softmax(dim=-2)#原有输入数据x与作为Self−Attention Map的o进行加法运算时使用的系数# output = x +gamma*o#刚开始gamma=0,之后让其进行学习self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):#输入变量X = x#先计算卷积,再对尺寸进行变形,将形状由B、C'、W、H变为B、C'、Nproj_query = self.query_conv(X).view(X.shape[0], -1, X.shape[2]*X.shape[3]) #尺寸 :B、C'、Nproj_query = proj_query.permute(0, 2, 1)  #转置操作proj_key = self.key_conv(X).view(X.shape[0], -1, X.shape[2]*X.shape[3])  #尺寸 :B、C'、N#乘法运算S = torch.bmm(proj_query, proj_key)  #bmm是以批次为单位进行的矩阵乘法运算#归一化attention_map_T = self.softmax(S)  #将行i方向上的和转换为1的SoftMax函数attention_map = attention_map_T.permute(0, 2, 1)  #进行转置#计算Self-Attention Mapproj_value = self.value_conv(X).view(X.shape[0], -1, X.shape[2]*X.shape[3])  #尺寸 :B、C、No = torch.bmm(proj_value, attention_map.permute(0, 2, 1))  #对Attention Map进行转置并计算乘积#将作为Self−Attention Map的o的张量尺寸与x对齐,并输出结果o = o.view(X.shape[0], X.shape[1], X.shape[2], X.shape[3])out = x+self.gamma*oreturn out, attention_mapclass Generator(nn.Module):def __init__(self, z_dim=20, image_size=64):super(Generator, self).__init__()self.layer1 = nn.Sequential(#添加频谱归一化处理nn.utils.spectral_norm(nn.ConvTranspose2d(z_dim, image_size * 8,kernel_size=4, stride=1)),nn.BatchNorm2d(image_size * 8),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(#添加频谱归一化处理nn.utils.spectral_norm(nn.ConvTranspose2d(image_size * 8, image_size * 4,kernel_size=4, stride=2, padding=1)),nn.BatchNorm2d(image_size * 4),nn.ReLU(inplace=True))self.layer3 = nn.Sequential(#添加频谱归一化处理nn.utils.spectral_norm(nn.ConvTranspose2d(image_size * 4, image_size * 2,kernel_size=4, stride=2, padding=1)),nn.BatchNorm2d(image_size * 2),nn.ReLU(inplace=True))#添加Self−Attentin网络层self.self_attntion1 = Self_Attention(in_dim=image_size * 2)self.layer4 = nn.Sequential(#添加频谱归一化处理nn.utils.spectral_norm(nn.ConvTranspose2d(image_size * 2, image_size,kernel_size=4, stride=2, padding=1)),nn.BatchNorm2d(image_size),nn.ReLU(inplace=True))#添加Self−Attentin网络层self.self_attntion2 = Self_Attention(in_dim=image_size)self.last = nn.Sequential(nn.ConvTranspose2d(image_size, 1, kernel_size=4,stride=2, padding=1),nn.Tanh())#注意 :由于是黑白图像,因此输出的通道数为1self.self_attntion2 = Self_Attention(in_dim=64)def forward(self, z):out = self.layer1(z)out = self.layer2(out)out = self.layer3(out)out, attention_map1 = self.self_attntion1(out)out = self.layer4(out)out, attention_map2 = self.self_attntion2(out)out = self.last(out)return out, attention_map1, attention_map2class Discriminator(nn.Module):def __init__(self, z_dim=20, image_size=64):super(Discriminator, self).__init__()self.layer1 = nn.Sequential(#追加Spectral Normalizationnn.utils.spectral_norm(nn.Conv2d(1, image_size, kernel_size=4,stride=2, padding=1)),nn.LeakyReLU(0.1, inplace=True))#注意 :由于是黑白图像,因此输入的通道数为1self.layer2 = nn.Sequential(#追加Spectral Normalizationnn.utils.spectral_norm(nn.Conv2d(image_size, image_size*2, kernel_size=4,stride=2, padding=1)),nn.LeakyReLU(0.1, inplace=True))self.layer3 = nn.Sequential(#追加频谱归一化nn.utils.spectral_norm(nn.Conv2d(image_size*2, image_size*4, kernel_size=4,stride=2, padding=1)),nn.LeakyReLU(0.1, inplace=True))#追加Self-Attentin层self.self_attntion1 = Self_Attention(in_dim=image_size*4)self.layer4 = nn.Sequential(#追加频谱归一化nn.utils.spectral_norm(nn.Conv2d(image_size*4, image_size*8, kernel_size=4,stride=2, padding=1)),nn.LeakyReLU(0.1, inplace=True))#追加Self-Attentin层self.self_attntion2 = Self_Attention(in_dim=image_size*8)self.last = nn.Conv2d(image_size*8, 1, kernel_size=4, stride=1)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out, attention_map1 = self.self_attntion1(out)out = self.layer4(out)out, attention_map2 = self.self_attntion2(out)out = self.last(out)return out, attention_map1, attention_map2

模型训练和验证

# 创建一个函数来学习模型def train_model(G, D, dataloader, num_epochs):# 确认是否可以使用GPUdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("使用设备:", device)# 优化方法的设定g_lr, d_lr = 0.0001, 0.0004beta1, beta2 = 0.0, 0.9g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])#将误差函数从定义变更为hinge version of the adversarial loss# criterion = nn.BCEWithLogitsLoss(reduction='mean')#参数硬编码z_dim = 20mini_batch_size = 64# 将网络变成GPUG.to(device)D.to(device)G.train()  # 将模型转换为训练模式D.train()  #将模型转换为训练模式# 如果网络固定到一定程度,就可以提高速度torch.backends.cudnn.benchmark = True#图像的张数num_train_imgs = len(dataloader.dataset)batch_size = dataloader.batch_size# 设置了迭代计数器iteration = 1logs = []# epoch循环for epoch in range(num_epochs):# 保存开始时间t_epoch_start = time.time()epoch_g_loss = 0.0  #epoch损失总和epoch_d_loss = 0.0  # epoch损失总和print('-------------')print('Epoch {}/{}'.format(epoch, num_epochs))print('-------------')print('(train)')# 从数据加载器中每次提取minibatch的循环for imges in dataloader:# --------------------# 1. Discriminator的学习# --------------------# 如果迷你batch的大小为1,则在batch normatization中会出错,所以避免if imges.size()[0] == 1:continue# 如果你能用GPU,就把数据传送给GPU。imges = imges.to(device)# 制作正确标签和假标签# epoch的最后迭代会导致小批量的数量变少mini_batch_size = imges.size()[0]#label_real = torch.full((mini_batch_size,), 1).to(device)#label_fake = torch.full((mini_batch_size,), 0).to(device)# 判断真正的图像d_out_real, _, _ = D(imges)# 生成假图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images, _, _ = G(input_z)d_out_fake, _, _ = D(fake_images)# 计算误差→hinge version of the adversarial loss变更# d_loss_real = criterion(d_out_real.view(-1), label_real)# d_loss_fake = criterion(d_out_fake.view(-1), label_fake)d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()#误差d_out_real大于1,误差为0。d out译文:real & gt;在1中,#1.0 - d_out_real为负时用ReLU设为0d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()# 如果误差d_out_fake小于-1,则误差为0。# 1.0 + d_out_real为负时用ReLU设为0d_loss = d_loss_real + d_loss_fake# 反向传播g_optimizer.zero_grad()d_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# --------------------# 2. Generator的学习# --------------------# 生成假图像并进行判定input_z = torch.randn(mini_batch_size, z_dim).to(device)input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)fake_images, _, _ = G(input_z)d_out_fake, _, _ = D(fake_images)# 计算误差→hinge version of the adversarial loss变更#g_loss = criterion(d_out_fake.view(-1), label_real)g_loss = - d_out_fake.mean()# 反向传播g_optimizer.zero_grad()d_optimizer.zero_grad()g_loss.backward()g_optimizer.step()# --------------------# 3. 记录# --------------------epoch_d_loss += d_loss.item()epoch_g_loss += g_loss.item()iteration += 1# epoch的每个phase的loss和正确答案率t_epoch_finish = time.time()print('-------------')print('epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(epoch, epoch_d_loss/batch_size, epoch_g_loss/batch_size))print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))t_epoch_start = time.time()# print("总迭代次数:", iteration)return G, D
# 网络的初始化
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:# Conv2d和ConvTranspose2d的初始化nn.init.normal_(m.weight.data, 0.0, 0.02)nn.init.constant_(m.bias.data, 0)elif classname.find('BatchNorm') != -1:#  BatchNorm2dの初期化nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)#初始化的实施
G.apply(weights_init)
D.apply(weights_init)print("网络的初始化完成")# 进行训练和验证
num_epochs = 300
G_update, D_update = train_model(G, D, dataloader=train_dataloader, num_epochs=num_epochs)
# 将生成图像和训练数据可视化
# 本直到生成感觉良好的图像为止,这个单元格会重新运行几次。device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 输入的随机数。
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)# 画像生成
fake_images, am1, am2 = G_update(fixed_z.to(device))#训练数据
batch_iterator = iter(train_dataloader)  #转换成迭代器
imges = next(batch_iterator)  # 找出第一个要素fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):# 把训练数据放在上层plt.subplot(2, 5, i+1)plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')# 在下层显示生成数据plt.subplot(2, 5, 5+i+1)plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

GAN 对抗网络可以通过两个网络互相对抗,最终让生成网络更加有效的生成目标图片,由于两个网络都是在不断学习中互相促进,可以完善学习的效果。如果通过人工识别的方式,生成网络很难进行学习,在网络开始阶段生成效果很差,人眼直接可以识别图片是假的。而鉴别网络模型刚开始可以给出一些错误信息,反而促进模型的最终效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/6754.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[C++基础学习-06]----C++指针详解

前言 指针是一个存储变量地址的变量,可以用来访问内存中的数据。在C中,指针是一种非常有用的数据类型,可以帮助我们在程序中对内存进行操作和管理。 正文 01-指针简介 指针的基本概念如下: 声明指针:使用“*”符…

[单片机课设]十字路口交通灯的设计

题目要求: 模拟交通灯运行情况。南北绿灯亮30秒,南北黄灯亮3秒,东西红灯亮33秒;南北红灯亮33秒,东西绿灯亮30秒,东西黄灯亮3秒;要求数码管同步显示时间的倒计时,用定时器实现延时。…

(HAL)STM32F103C8T6——内部flash模拟EEPROM

内部Flash大部分空间是用来存储烧录进单片机的程序代码,因此可以将非代码等无关区域用来存储数据。项目工程的代码量可以通过Keil uVision5软件底下框查看,如下图所示。一般只需参考代码量(Code)以及只读数据(RO-data&…

某盾BLACKBOX逆向关键点

需要准备的东西: 1、原JS码 2、AST解混淆码 3、token(来源于JSON) 一、原JS码很好获取,每次页面刷新,混淆的代码都会变,这是正常,以下为部分代码 while (Qooo0) {switch (Qooo0) {case 110 14 - 55: {function O0…

C++入门第二节--关键字、命名空间、输入输出

点赞关注不迷路!本节涉及c入门关键字、命名空间、输入输出... 1. C关键字 C总计63个关键字,C语言32个关键字 asmdoifreturntrycontinueautodoubleinlineshorttypedefforbooldynamic_castintsignedtypeidpublicbreakelselongsizeoftypenamethrowcaseen…

A Dexterous Hand-Arm Teleoperation System

A Dexterous Hand-Arm Teleoperation System Based on Hand Pose Estimation and Active Vision解读 摘要1. 简介2.相关工作2.1 机器人遥操作2.2 主动视觉(Active Vision) 3. 硬件设置4. 基于视觉的机器人手部姿态估计4.1 Transteleop4.2 Dataset 5. 主动…

升级OpenSSH版本(安装telnet远程管理主机)

一 OpenSSH是什么 OpenSSH 是 SSH (Secure SHell) 协议的免费开源实现。SSH协议族可以用来进行远程控制, 或在计算机之间传送文件。而实现此功能的传统方式,如telnet(终端仿真协议)、 rcp ftp、 rlogin、 rsh都是极为不安全的&…

C++奇迹之旅:string类接口详解(上)

文章目录 📝为什么学习string类?🌉 C语言中的字符串🌉string考察 🌠标准库中的string类🌉string类的常用接口说明🌠string类对象的常见构造 🚩总结 📝为什么学习string类…

二维泊松方程(Neumann+Direchliet边界条件)有限元Matlab编程求解|程序源码+说明文本

专栏导读 作者简介:工学博士,高级工程师,专注于工业软件算法研究本文已收录于专栏:《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现,并提供所有案例完整源码;2.单元…

stm32开发之netxduo网口通讯,网线热插拔处理

前言 在使用netxduo组件时,如果在上电过程中,未插入网线,eth驱动使能过程中未正常初始化本次使用以下几种方式进行设置 问题原因 使用定时器事件回调方式 网络组件中进行调整 /** Copyright (c) 2024-2024,shchl** SPDX-Licen…

Initialize failed: invalid dom.

项目场景: 在vue中使用Echarts出现的错误 问题描述 提示:这里描述项目中遇到的问题: 例如:在vue中使用Echarts出现的错误 ERROR Initialize failed: invalid dom.at Module.init (webpack-internal:///./node_modules/echarts…

Delta lake with Java--入门

最近在研究数据湖,虽然不知道研究成果是否可以用于工作,但我相信机会总是留给有准备的人。 数据湖尤其是最近提出的湖仓一体化概念,很少有相关的资料,目前开源的项目就三个,分别是hudi, delta lake, iceberg。最终选择…

算法打卡day41

今日任务: 1)198.打家劫舍 2)213.打家劫舍II 3)337.打家劫舍III 4)复习day16 198.打家劫舍 题目链接:198. 打家劫舍 - 力扣(LeetCode) 你是一个专业的小偷,计划偷窃沿街…

【hive】transform脚本

文档地址:https://cwiki.apache.org/confluence/display/Hive/LanguageManualTransform 一、介绍二、实现1.脚本上传到本地2.脚本上传到hdfs 三、几个需要注意的点1.脚本名不要写全路径2.using后面语句中,带不带"python"的问题3.py脚本Shebang…

LNMP部署wordpress

1.环境准备 总体架构介绍 序号类型名称外网地址内网地址软件02负载均衡服务器lb0110.0.0.5192.168.88.5nginx keepalived03负载均衡服务器lb0210.0.0.6192.168.88.6nginx keepalived04web服务器web0110.0.0.7192.168.88.7nginx05web服务器web0210.0.0.8192.168.88.8nginx06we…

基于Springboot的校园生活服务平台(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的校园生活服务平台(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构…

shell脚本-监控系统内存和磁盘容量

监控内存和磁盘容量除了可以使用zabbix监控工具来监控,还可以通过编写Shell脚本来监控。 #! /bin/bash #此脚本用于监控内存和磁盘容量,内存小于500MB且磁盘容量小于1000MB时报警#提取根分区剩余空间 disk_size$(df / | awk /\//{print $4})#提取内存剩…

《机器学习算法面试宝典》重磅发布!

我们经常会组织场算法岗技术&面试讨论会,会邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。 针对新手如何入门算法岗、该如何准备面试攻略、面试常考点等热门话题进行了深入的讨论。 基于讨论和经验总结,历时半年的梳理和修改&#xff…

eNSP-浮动静态路由配置

ip route-static 192.168.1.0 24 192.168.3.2 preference 60 #设置路由 目标网络地址 和 下一跳地址 preference值越大 优先级越低 一、搭建拓扑结构 二、主机配置 pc1 pc2 三、配置路由器 1.AR1路由器配置 <Huawei>sys #进入系统视图 [Huawei]int g0/0/0 #进入接…

详解面向对象-类和对象

1.面向对象与面向过程的区别 ①面向过程 &#xff1a;关注点是在实现功能的步骤上面&#xff0c;就是分析出解决问题所需要的步骤&#xff0c;让后函数把这些步骤一步一步实现&#xff0c;使用的时候一个一个依次调用就可以。对于简单的流程是适合面向过程的方式进行的&#x…