PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

PyTorch深度学习实战(31)——生成对抗网络

    • 0. 前言
    • 1. GAN
    • 2. GAN 模型分析
    • 3. 利用 GAN 模型生成手写数字
    • 小结
    • 系列链接

0. 前言

生成对抗网络 (Generative Adversarial Networks, GAN) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN 的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN 网络的原理并使用 PyTorch 实现 GAN

1. GAN

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像。
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

GAN原理

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。

2. GAN 模型分析

为了生成手写数字的图像,我们采取以下策略:

  • 导入 MNIST 数据
  • 初始化随机噪声
  • 定义生成网络模型
  • 定义判别网络模型
  • 使用生成网络生成伪造图像,生成网络在最初只能生成噪声图像,噪声图像是通过将一组噪声值通过权重随机的神经网络得到的图像
  • 交替训练两个模型
    • 将生成的图像与原始图像串联起来,判别网络预测每个图像是伪造图像还是真实图像,对判别网络进行训练,判别网络的损失是图像的预测值和实际值(标签)的二进制交叉熵,生成的伪造图像的实际值(标签)为 0,原始数据集中真实图像的实际值(标签)为 1
    • 训练生成网络利用输入噪声生成伪造图像,使其看起来更接近真实图像,从而使生成图像有可能欺骗判别网络
    • 输入噪声通过生成网络传递输出伪造图像,将生成网络生成的图像输入到判别网络中,此时,判别网络权重被冻结,因为生成网络的目标是欺骗判别网络,因此,假设生成的伪造图像实际值(标签)为 1,生成网络的损失是判别网络对输入图像的预测值和实际值 (1) 的二进制交叉熵

了解了 GAN 的基本原理后,在下一小节,我们实现 GAN 生成 MNIST 手写数字图像。

3. 利用 GAN 模型生成手写数字

(1) 导入相关库并定义设备:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
device = "cuda" if torch.cuda.is_available() else "cpu"from torchvision.datasets import MNIST
from torchvision import transforms

(2) 导入 MNIST 数据,定义具有内置数据转换功能的数据加载器,以便缩放输入数据:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,), std=(0.5,))
])data_loader = torch.utils.data.DataLoader(MNIST('MNIST/', train=True, download=True, transform=transform),batch_size=128, shuffle=True, drop_last=True)

(3) 定义判别网络模型类:

class Discriminator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential( nn.Linear(784, 1024),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):return self.model(x)

在以上代码中,使用 LeakyReLU 激活函数替换 ReLU。打印判别网络的简要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (1,784)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1              [-1, 1, 1024]         803,840LeakyReLU-2              [-1, 1, 1024]               0Dropout-3              [-1, 1, 1024]               0Linear-4               [-1, 1, 512]         524,800LeakyReLU-5               [-1, 1, 512]               0Dropout-6               [-1, 1, 512]               0Linear-7               [-1, 1, 256]         131,328LeakyReLU-8               [-1, 1, 256]               0Dropout-9               [-1, 1, 256]               0Linear-10                 [-1, 1, 1]             257Sigmoid-11                 [-1, 1, 1]               0
================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.57
Estimated Total Size (MB): 5.61
----------------------------------------------------------------

(4) 定义生成网络模型类 Generator

class Generator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(100, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, x):return self.model(x)

生成网络根据 100 维随机噪声输入生成图像。打印生成网络模型的简要信息:

generator = Generator().to(device)
print(summary(generator, (1,100)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1               [-1, 1, 256]          25,856LeakyReLU-2               [-1, 1, 256]               0Linear-3               [-1, 1, 512]         131,584LeakyReLU-4               [-1, 1, 512]               0Linear-5              [-1, 1, 1024]         525,312LeakyReLU-6              [-1, 1, 1024]               0Linear-7               [-1, 1, 784]         803,600Tanh-8               [-1, 1, 784]               0
================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.67
Estimated Total Size (MB): 5.71
----------------------------------------------------------------

(5) 定义函数生成随机噪声并将其注册到设备中:

def noise(size):n = torch.randn(size, 100)return n.to(device)

(6) 定义函数来训练判别网络。

判别网络训练函数 (discriminator_train_step) 将真实数据 (real_data) 和伪造数据 (fake_data) 作为输入:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):

重置优化器梯度:

    d_optimizer.zero_grad()

在对损失值执行反向传播之前,预测真实数据 (real_data) 并计算损失 (error_real):

    prediction_real = discriminator(real_data)error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))error_real.backward()

在真实数据上计算判别网络损失时,我们期望判别网络预测输出为 1。因此,在判别网络的训练过程中,使用 torch.ones 作为标签,期望判别网络在真实数据上的输出为 1,从而计算判别网络在真实数据上的损失。

在对损失值执行反向传播之前,预测伪造数据 (fake_data) 并计算损失 (error_fake):

    prediction_fake = discriminator(fake_data)error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))error_fake.backward()

在伪造数据上计算判别网络损失时,我们期望判别网络预测输出为 0。因此,在判别网络的训练过程中,使用 torch.zeros 作为标签,期望判别网络在伪造数据上的输出为 0,从而计算判别网络在伪造数据上的损失。

更新权重并返回整体损失(将模型在 real_dataerror_realfake_dataerror_fake 的损失值相加):

    d_optimizer.step()return error_real + error_fake

(7) 训练生成网络模型。

定义生成网络训练函数 generator_train_step 并传入伪造数据 fake_data 作为参数:

def generator_train_step(real_data, fake_data, loss, g_optimizer):

重置优化器梯度:

    g_optimizer.zero_grad()

预测判别网络对伪造数据 (fake_data) 的输出:

    prediction = discriminator(fake_data)

在计算生成网络的损失时,使用 torch.ones 作为标签,期望判别网络在伪造数据上的输出为 1,以在训练生成网络时欺骗判别网络输出值 1,以此来鼓励生成网络生成更加逼真的数据,并让判别网络无法区分其真伪:

    error = loss(prediction, torch.ones(len(real_data), 1).to(device))

执行反向传播,更新权重,并返回损失:

    error.backward()g_optimizer.step()return error

(8) 定义模型对象、生成网络和判别网络的优化器,以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()

(9) 训练模型。

循环训练模型 200epochs (num_epochs):

num_epochs = 200d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):N = len(data_loader)d_loss_items = []g_loss_items = []for i, (images, _) in enumerate(data_loader):

加载真实数据 (real_data) 和伪造数据,其中,伪造数据是通过将大小与真实数据样本数相同的噪声数据 (batch_size = len(real_data)) 传入生成网络网络获得的。需要注意的是,必须调用 fake_data.detach(),否则训练无法正常进行。通过 detach() 函数分离出来一个新的张量,这样在 discriminator_train_step() 中调用 error.backward() 时,与生成网络相关的张量(生成 fake_data )不会受到影响。使用 discriminator_train_step 函数训练判别网络:

        real_data = images.view(len(images), -1).to(device)fake_data = generator(noise(len(real_data))).to(device)fake_data = fake_data.detach()

训练判别网络后,继续训练生成网络。从噪声数据生成一组新的伪造图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(noise(len(real_data))).to(device)g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())g_loss_items.append(g_loss.item())d_loss_epoch.append(np.average(d_loss_items))g_loss_epoch.append(np.average(g_loss_items))

绘制判别网络和生成网络的损失随训练的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型性能检测

(10) 可视化模型训练后生成的伪造数据:

z = torch.randn(64, 100).to(device)
sample_images = generator(z).data.cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0), cmap='gray')
plt.show()

生成结果

在上图中,可以看到利用 GAN 生成逼真的图像,但仍有一定的改进空间,在之后的学习中,我们将介绍更多 GAN 的改进模型生成更逼真的图像。

小结

生成对抗网络是一种强大的深度学习模型,由生成器网络和判别器网络组成,通过彼此之间的竞争来提高性能,已经在图像生成、图像修复、图像转换和自然语言处理等领域取得了巨大的成功。其核心思想是通过生成器和判别器之间的博弈过程来实现真实样本的生成。生成器负责生成逼真的样本,而判别器则负责判断样本是真实还是伪造。通过不断的训练和迭代,生成器和判别器会相互竞争并逐渐提高性能。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/639343.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flask框架小程序后端分离开发学习笔记《3》客户端向服务器端发送请求

Flask框架小程序后端分离开发学习笔记《3》客户端向服务器端发送请求 Flask是使用python的后端,由于小程序需要后端开发,遂学习一下后端开发。 一、为什么请求数据需要先编码 #构造一个HTTP请求 http_request GET / HTTP/1.1\r\nhost:{}\r\n\r\n.for…

大语言模型系列-ELMo

文章目录 前言一、ELMo的网络结构和流程二、ELMo的创新点总结 前言 在前文大语言模型系列-word2vec已经提到word2vec的缺点: 为每个词汇表中每个分词静态生成一个对应的词向量表示,没有考虑到语境,因此无法无法处理多义词 ps:先…

Python Web 开发之 Flask 入门实践

导语:Flask 是一个轻量级的 Python Web 框架,广受开发者喜爱。本文将带领大家了解 Flask 的基本概念、搭建一个简单的 Web 项目以及如何进一步扩展功能。 一、Flask 简介 Flask 是一个基于 Werkzeug 和 Jinja2 的微型 Web 框架,它的特点是轻…

JAVA RPC Thrift基操实现与微服务间调用

一、Thrift 基操实现 1.1 thrift文件 namespace java com.zn.opit.thrift.helloworldservice HelloWorldService {string sayHello(1:string username) }1.2 执行命令生成Java文件 thrift -r --gen java helloworld.thrift生成代码HelloWorldService接口如下 /*** Autogene…

MBR扇区修复和GRUB引导修复实验

修复MBR扇区 步骤一:在进行实验之前我们需要新加一块磁盘,并对新加磁盘进行分区处理,用来备份sda磁盘的MBR及分区表信息。(注:在实验中可以不像我如此这么繁琐,一个主分区,并格式化挂载即可&am…

Android 通过adb命令查看应用流量

一. 获取应用pid号 通过adb shell ps | grep 包名 来获取app的 pid号 二. 查看应用流量情况 使用adb shell cat /proc/#pid#/net/dev 命令 来获取流量数据 备注: Recevice: 表示收包 Transmit: 表示发包 bytes: 表示收发的字节数 packets: 表示收发正确的包量…

【CompletableFuture任务编排】游戏服务器线程模型及其线程之间的交互(以排行榜线程和玩家线程的交互为例子)

需求: 1.我们希望玩家的业务在玩家线程执行,无需回调,因此是多线程处理。 2.匹配线程负责匹配逻辑,是单独一个线程。 3.排行榜线程负责玩家的上榜等。 4.从排行榜线程获取到排行榜列表后,需要给玩家发奖修改玩家数…

【GitHub项目推荐--不错的 C 开源项目】【转载】

大学时接触的第一门语言就是 C语言,虽然距 C语言创立已过了40多年,但其经典性和可移植性任然是当今众多高级语言中不可忽视的,想要学好其他的高级语言,最好是先从掌握 C语言入手。 今天老逛盘点 GitHub 上不错的 C语言 开源项目&…

【代码随想录11】239. 滑动窗口最大值 347. 前 K 个高频元素

目录 239. 滑动窗口最大值题目描述做题思路参考代码 347. 前 K 个高频元素题目描述参考代码 239. 滑动窗口最大值 题目描述 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每…

详解C语言中`||`的短路机制

在C语言中,逻辑或运算符(||)是一种常用的逻辑运算符,用于组合多个条件表达式。与其他编程语言一样,C语言中的逻辑或运算符具有短路机制,这是一种非常重要的概念,本文将深入解释C语言中的||短路机…

sportplay项目

1.编写userMapping.xml时报错, Error querying database. Cause: java.sql.SQLSyntaxErrorException: You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near ‘‘easyuser’ W…

MSVS C# Matlab的混合编程系列2 - 构建一个复杂(含多个M文件)的动态库:

前言: 本节我们尝试将一个有很多函数和文件的Matlab算法文件集成到C#的项目里面。 本文缩语: MT = Matlab 问题提出: 1 我们有一个比较复杂的Matlab文件: 这个MATLAB的算法,写了很多的算法函数在其他的M文件里面,这样,前面博客的方法就不够用了。会报错: 解决办法如下…

华为机考入门python3--(0)模拟题2-vowel元音字母翻译

分类:字符串 知识点: 字符串转list,每个字符成为list中的一个元素 list(string) 字符串变大小写 str.upper(), str.lower() 题目来自【华为招聘模拟考试】 # If you need to import additional packages or classes, please import …

分享5款简单实用的软件,值得收藏

​ 电脑上的各类软件有很多,除了那些常见的大众化软件,还有很多不为人知的小众软件,专注于实用功能,简洁干净、功能强悍。 1.自定义图标——TileIconifier ​ TileIconifier 是一款可以自定义 Windows 开始菜单图标的软件&#…

蓝牙运动耳机什么牌子的好?2024年运动无线耳机推荐

​在选择运动耳机时,我们需要综合考虑音质、舒适度以及适应不同运动场景的能力。好的运动耳机能够提高运动效率,增添锻炼的乐趣。今天,我为大家介绍几款在音质、佩戴舒适度、防水防汗等方面表现卓越的运动耳机,助你选购最适合的一…

《GreenPlum系列》GreenPlum初级教程-05GreenPlum语言DDLDMLDQL

文章目录 第五章 DDL&DML&DQL1.DDL(Data Definition Language)数据定义语言1.1 创建数据库1.2 查询数据库1.3 删除数据库1.4 创建表1.5 修改表1.6 清除表1.7 删除表 2.DML(Data Manipulation Language)数据操作语言2.1 数据导入2.2 数据更新和删除2.3 数据导出 3.DQL(D…

04 单链表

目录 链表的概念和结构单链表OJ练习 1. 链表的概念和结构 1.1 链表的概念 链表是一种物理存储结构上非连续、非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的 1.从上图可以看出链式结构在逻辑上是连续的,物理上不一定连续 2.现…

139 删除链表中的重复元素II

问题描述:存在一个按照升序排列的链表,给你这个链表的头结点head,请你删除链表中所有存在数字重复情况的节点,只保留链表中没有出现的数字,返回的结果同样按升序的结果链表。 求解思路:双指针求解&#xf…

常用的gpt-4 prompt words收集4

1. it poses a certain risk to my work 这对我来说是一个风险点 2. one point to note is that 需要说的一个问题是 3. What is the English phonetic transcription of ‘emoji’? emoji的音标是什么? 4. it would be best if you can insert some proper …

Docker(十)Docker Compose

作者主页: 正函数的个人主页 文章收录专栏: Docker 欢迎大家点赞 👍 收藏 ⭐ 加关注哦! Docker Compose 项目 Docker Compose 是 Docker 官方编排(Orchestration)项目之一,负责快速的部署分布式…