PyTorch深度学习实战(32)——DCGAN详解与实现

PyTorch深度学习实战(32)——DCGAN详解与实现

    • 0. 前言
    • 1. 模型与数据集分析
      • 1.1 模型分析
      • 1.2 数据集介绍
    • 2. 构建 DCGAN 生成人脸图像
    • 小结
    • 系列链接

0. 前言

DCGAN (Deep Convolutional Generative Adversarial Networks) 是基于生成对抗网络 (Convolutional Generative Adversarial Networks, GAN) 的深度学习模型,相比传统的 GAN 模型,DCGAN 通过引入卷积神经网络 (Convolutional Neural Networks, CNN) 架构来提升生成网络和判别网络的性能。DCGAN 中的生成网络和判别网络都是使用卷积层和反卷积层构建的深度神经网络。生成网络接收一个随机噪声向量作为输入,并通过反卷积层将其逐渐转化为与训练数据相似的输出图像,判别网络则是一个用于分类真实和生成图像的卷积神经网络。

1. 模型与数据集分析

1.1 模型分析

我们已经学习了 GAN 的基本原理并并使用 PyTorch 实现了 GAN 模型用于生成 MNIST 手写数字图像。同时,我们已经知道,与普通神经网络相比,卷积神经网络 (Convolutional Neural Networks, CNN) 架构能够更好地学习图像中的特征。在本节中,我们将学习使用深度卷积生成对抗网络生成图像,在模型中使用卷积和池化操作替换全连接层。
首先,介绍如何使用随机噪声( 100 维向量)生成图像,将噪声形状转换为 batch size x 100 x 1 x 1,其中 batch size 表示批大小,由于在 DCGAN 使用 CNN,因此需要添加额外的通道信息,即 batch size x channel x height x width 的形式,channel 表示通道数,heightwidth 分别表示高度和宽度。
接下来,利用 ConvTranspose2d 将生成的噪声向量转换为图像,ConvTranspose2d 与卷积操作相反,将输入的小特征图通过预定义的核大小、步幅和填充上上采样到较大的尺寸。利用上采样逐渐将向量形状从 batch size x 100 x 1 x 1 转换为 batch size x 3 x 64 x 64,即将 100 维的随机噪声向量转换成一张 64 x 64 的图像。

1.2 数据集介绍

为了训练对抗生成网络,我们需要了解本节所用的数据集,数据集取自 Celeb A,可以自行构建数据集,也可以下载本文所用数据集,下载地址:https://pan.baidu.com/s/1dvDCBLSGwblg57p9RDBEJQ,提取码:y9fiCelebA 是一个大规模的人脸属性数据集,其中包含超过 20 万张名人图像,每张图像有 40 个属性注释。CelebA 数据集的图像来源于互联网上的名人照片,包括电影、音乐和体育界等各个领域。这些图像具有多样的姿势、表情、背景和装扮,涵盖了各种真实世界的场景。

2. 构建 DCGAN 生成人脸图像

接下来,我们使用 PyTorch 构建 DCGAN 模型生成人脸图像。

(1) 下载并获取人脸图像,示例图像如下所示:

示例图像
(2) 导入相关库:

from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np
import torch
import os
from glob import glob
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
device = "cuda" if torch.cuda.is_available() else "cpu"

(3) 定义数据集和数据加载器。

裁剪图像,只保留面部区域并丢弃图像中的其他部分。首先,使用级联滤波器识别图像中的人脸:

face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')

OpenCV 提供了 4 个级联分类器用于人脸检测,可以从 OpenCV 官方下载这些级联分类器文件:

  • haarcascade_frontalface_alt.xml (FA1)
  • haarcascade_frontalface_alt2.xml (FA2)
  • haarcascade_frontalface_alt_tree.xml (FAT)
  • haarcascade_frontalface_default.xml (FD)

可以使用不同的数据集评估这些级联分类器的性能,总的来说这些分类器具有相似的准确率。

创建一个新文件夹,并将所有裁剪后的人脸图像转储到新文件夹中:

if not os.path.exists('cropped_faces'):os.mkdir('cropped_faces')images = glob('male_female_face_images/females/*.jpg')+glob('male_female_face_images/males/*.jpg')
for i in range(len(images)):img = cv2.imread(images[i],1)gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)faces = face_cascade.detectMultiScale(gray, 1.3, 5)for (x,y,w,h) in faces:img2 = img[y:(y+h),x:(x+w),:]cv2.imwrite('cropped_faces/'+str(i)+'.jpg', img2)

裁剪后的面部示例图像如下:

面部裁剪图像
定义要对每个图像执行的转换:

transform=transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

定义 Faces 数据集类:

class Faces(Dataset):def __init__(self, folder):super().__init__()self.folder = folderself.images = sorted(glob(folder))def __len__(self):return len(self.images)def __getitem__(self, ix):image_path = self.images[ix]image = Image.open(image_path)image = transform(image)return image

创建数据集对象 ds

ds = Faces(folder='cropped_faces/*.jpg')

定义数据加载器类:

dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)

(4) 定义权重初始化函数,使权重的分布较小:

def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)

(5) 定义判别网络模型类 Discriminator,接收形状为 batch size x 3 x 64 x 64 的图像,并预测输入图像是真实图像还是生成图像:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Conv2d(3,64,4,2,1,bias=False),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(64,64*2,4,2,1,bias=False),nn.BatchNorm2d(64*2),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(64*2,64*4,4,2,1,bias=False),nn.BatchNorm2d(64*4),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(64*4,64*8,4,2,1,bias=False),nn.BatchNorm2d(64*8),nn.LeakyReLU(0.2,inplace=True),nn.Conv2d(64*8,1,4,1,0,bias=False),nn.Sigmoid())self.apply(weights_init)def forward(self, input):return self.model(input)

打印模型的摘要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (3,64,64)))

模型摘要输出结果如下所示:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1           [-1, 64, 32, 32]           3,072LeakyReLU-2           [-1, 64, 32, 32]               0Conv2d-3          [-1, 128, 16, 16]         131,072BatchNorm2d-4          [-1, 128, 16, 16]             256LeakyReLU-5          [-1, 128, 16, 16]               0Conv2d-6            [-1, 256, 8, 8]         524,288BatchNorm2d-7            [-1, 256, 8, 8]             512LeakyReLU-8            [-1, 256, 8, 8]               0Conv2d-9            [-1, 512, 4, 4]       2,097,152BatchNorm2d-10            [-1, 512, 4, 4]           1,024LeakyReLU-11            [-1, 512, 4, 4]               0Conv2d-12              [-1, 1, 1, 1]           8,192Sigmoid-13              [-1, 1, 1, 1]               0
================================================================
Total params: 2,765,568
Trainable params: 2,765,568
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 2.31
Params size (MB): 10.55
Estimated Total Size (MB): 12.91
----------------------------------------------------------------

(6) 定义生成网络模型类,使用形状为 batch size x 100 x 1 x 1 的输入生成图像:

class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.model = nn.Sequential(nn.ConvTranspose2d(100,64*8,4,1,0,bias=False,),nn.BatchNorm2d(64*8),nn.ReLU(True),nn.ConvTranspose2d(64*8,64*4,4,2,1,bias=False),nn.BatchNorm2d(64*4),nn.ReLU(True),nn.ConvTranspose2d( 64*4,64*2,4,2,1,bias=False),nn.BatchNorm2d(64*2),nn.ReLU(True),nn.ConvTranspose2d( 64*2,64,4,2,1,bias=False),nn.BatchNorm2d(64),nn.ReLU(True),nn.ConvTranspose2d( 64,3,4,2,1,bias=False),nn.Tanh())self.apply(weights_init)def forward(self,input):return self.model(input)

打印模型的摘要信息:

generator = Generator().to(device)
print(summary(generator, (100,1,1)))

代码输出结果如下所示:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================ConvTranspose2d-1            [-1, 512, 4, 4]         819,200BatchNorm2d-2            [-1, 512, 4, 4]           1,024ReLU-3            [-1, 512, 4, 4]               0ConvTranspose2d-4            [-1, 256, 8, 8]       2,097,152BatchNorm2d-5            [-1, 256, 8, 8]             512ReLU-6            [-1, 256, 8, 8]               0ConvTranspose2d-7          [-1, 128, 16, 16]         524,288BatchNorm2d-8          [-1, 128, 16, 16]             256ReLU-9          [-1, 128, 16, 16]               0ConvTranspose2d-10           [-1, 64, 32, 32]         131,072BatchNorm2d-11           [-1, 64, 32, 32]             128ReLU-12           [-1, 64, 32, 32]               0ConvTranspose2d-13            [-1, 3, 64, 64]           3,072Tanh-14            [-1, 3, 64, 64]               0
================================================================
Total params: 3,576,704
Trainable params: 3,576,704
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 3.00
Params size (MB): 13.64
Estimated Total Size (MB): 16.64
----------------------------------------------------------------

(7) 定义训练生成网络 (generator_train_step) 和判别网络 (discriminator_train_step) 的函数:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):d_optimizer.zero_grad()prediction_real = discriminator(real_data)error_real = loss(prediction_real.squeeze(), torch.ones(len(real_data)).to(device))error_real.backward()prediction_fake = discriminator(fake_data)error_fake = loss(prediction_fake.squeeze(), torch.zeros(len(fake_data)).to(device))error_fake.backward()d_optimizer.step()return error_real + error_fakedef generator_train_step(real_data, fake_data, loss, g_optimizer):g_optimizer.zero_grad()prediction = discriminator(fake_data)error = loss(prediction.squeeze(), torch.ones(len(real_data)).to(device))error.backward()g_optimizer.step()return error

在以上代码中,在判别网络预测结果上执行 .squeeze 操作,因为模型的输出形状为 batch size x 1 x 1 x 1,而预测结果需要与形状为 batch size x 1 的张量进行比较。

(8) 创建生成网络和判别网络模型对象、优化器以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
loss = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

(9) 训练模型。

加载真实数据 (real_data) 并通过生成网络生成图像 (fake_data):

num_epochs = 100
d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):N = len(dataloader)d_loss_items = []g_loss_items = []for i, images in enumerate(dataloader):real_data = images.to(device)fake_data = generator(torch.randn(len(real_data), 100, 1, 1).to(device)).to(device)fake_data = fake_data.detach()

原始 GANDCGAN 的主要区别在于,在 DCGAN 模型中,由于使用了 CNN,因此不必展平 real_data

使用 discriminator_train_step 函数训练判别网络:

        d_loss = discriminator_train_step(real_data, fake_data, loss, d_optimizer)

利用噪声数据 (torch.randn(len(real_data))) 生成新图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(torch.randn(len(real_data), 100, 1, 1).to(device)).to(device)g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())g_loss_items.append(g_loss.item())d_loss_epoch.append(np.average(d_loss_items))
g_loss_epoch.append(np.average(g_loss_items))

(10) 绘制模型训练期间,判别网络和生成网络损失变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')

损失变化
从上图中可以看出,生成网络和判别网络损失的变化与手写数字生成模型的损失变化模式并不相同,原因如下:

  • 人脸图像的尺寸相比手写数字更大,手写数字图像形状为 28 x 28 x 1,人脸图像形状为 64 x 64 x 3
  • 与人脸图像中的特征相比,手写数字图像中的特征较少
  • 与人脸图像中的信息相比,手写数字图像中仅少数像素中存在可用信息

(11) 训练过程完成后,生成图像样本:

generator.eval()
noise = torch.randn(64, 100, 1, 1, device=device)
sample_images = generator(noise).detach().cpu()
grid = vutils.make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0))
plt.show()

生成图像样本

小结

DCGAN 是优秀的图像生成模型,其生成网路和判别网络都是使用卷积层和反卷积层构建的深度神经网络。生成网络接收一个随机噪声向量作为输入,并通过逐渐减小的反卷积层将其逐渐转化为与训练数据相似的输出图像;判别网络则是一个用于分类真实和生成图像的卷积神经网络。在本节中,我们学习了如何构建并训练 DCGAN 生成人脸图像。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes
PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/645733.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

el-form动态检验无法生效问题

<el-form-item label"状态&#xff1a;"prop"zt"class"bitianxian"><el-select v-model"ruleForm.zt"placeholder"请选择"change"emptyztXM()"><el-option v-for"(item,index) in ZTdata&quo…

PWN入门Protostar靶场Stack系列

Protostar靶场地址 https://exploit.education/protostar/溢出 源码分析 #include <stdlib.h> #include <unistd.h> #include <stdio.h>int main(int argc, char **argv) {volatile int modified; //定义一个变量char buffer[64]; //给…

Git 入门精讲

我们为什么要学习git&#xff1f; 就当下的发展而言&#xff0c;只要你从事开发就一定会接触git。作为最强大的分布式版本控制器&#xff0c;git 与 svn 有着本质上的区别。 Git是一种分布式版本控制系统&#xff0c;每个开发者都可以在本地维护完整的代码库&#xff0c;可以离…

c++ 加密与解密代码(普通加密 + 凯撒加密 + 图灵来了都解不开的加密)

当你和你的好朋友聊天的时候&#xff0c;你们的聊天内容很容易就被看出来&#xff0c;那么小天狼星这边可以给到一些建议~~ 一、用另一种语言 通常来说&#xff0c;使用除中文和其他常用语言外的语言是一个优秀的选择&#xff01; 例如&#xff1a;乌伯克语、阿亚帕涅科语。 …

智能泊车,再上热搜

编者按&#xff1a;相比于行车&#xff0c;低速可控场景&#xff0c;更有利于泊车功能快速迭代。同时&#xff0c;对于部分消费者来说&#xff0c;泊车智能化也是加分项。 智能泊车赛道&#xff0c;正在重新成为各路势力争夺的焦点。而上一次“高潮”&#xff0c;要追溯到2018年…

开源客户沟通平台Chatwoot账号激活问题

安装docker docker-compose 安装git clone https://github.com/chatwoot/chatwoot 下载之后根目录有一个docker-compose.production.yaml将其复制到一个目录 重命名 docker-compose.yaml 执行docker-compose up -d 构建 构建之后所有容器都安装好了 直接访问http://ip:3…

护眼台灯怎么选——明基、书客、孩视宝实测横评

最近护眼台灯的热度真是不小&#xff0c;许多博主纷纷推荐。考虑到孩子即将放寒假&#xff0c;市场上的产品也是五花八门&#xff0c;于是我决定认真研究一下&#xff0c;找出其中的水货和宝藏产品。我挑选了市场上口碑较好的3款产品进行深入评估&#xff0c;主要从照度、显色指…

Revit二次开发 设置材质

设置此处材质&#xff0c;需要在材质浏览器中创建材质&#xff0c;根据材质名字设置此材质。 代码如下&#xff1a; Material material new FilteredElementCollector(doc).OfClass(typeof(Material)).FirstOrDefault(x > x.Name "窗框") as Material; Element…

如何利用streamlit 將 gemini pro vision 進行圖片內容介紹

如何利用streamlit 將 gemini pro vision 進行圖片內容介紹 1.安裝pip install google-generativeai 2.至 gemini pro 取 api key 3.撰寫如下文章:(方法一) import json import requests import base64 import streamlit as st 讀取圖片檔案&#xff0c;並轉換成 Base64 編…

ES6对象新增了哪些扩展?

ES6&#xff08;ECMAScript 2015&#xff09;为JavaScript中的对象引入了一些新的扩展功能。以下是一些主要的ES6对象扩展&#xff1a; 对象字面量的增强&#xff08;Object Literal Enhancements&#xff09;&#xff1a; ES6允许在对象字面量中更简洁地定义属性和方法。您可以…

Android SeekBar 进度条圆角

先看下效果图&#xff1a; 之前&#xff1a; 优化后&#xff1a; 之前的不是圆角是clip切割导致的 全代码&#xff1a; <SeekBarandroid:layout_width"188dp"android:layout_height"wrap_content"android:background"null"android:focusa…

程序员裁员潮

裁员对程序员的影响可以是相当大的&#xff0c;特别是在技术变革的时期。以下是一些可能的影响&#xff1a; 失业&#xff1a;当公司裁员时&#xff0c;程序员可能会失去他们的工作。这将导致失业风险的增加&#xff0c;特别是如果他们在特定行业或领域内专门从事工作。 就业机…

风速预测 | Python基于CEEMDAN-CNN-Transformer+ARIMA的风速时间序列预测

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 CEEMDAN-CNN-TransformerARIMA是一种用于风速时间序列预测的模型&#xff0c;结合了不同的技术和算法。收集风速时间序列数据&#xff0c;并确保数据的质量和完整性。这些数据通常包括风速的观测值和时间戳。CEEMDAN分…

php 面向对象与反序列

目录 1.类和对象 2.序列化 3.反序列化 1.类和对象 <?php//类 class cl {var $name "fly"; // 类属性//函数function _destruct(){echo $this->name;}//函数function eat() {echo apple;} }//对象 $a new cl(); echo $a->name.<br>; //直接调用…

使用Spring Boot实现基于HTTP的API

Spring Boot是一个用于简化Spring应用程序开发的框架&#xff0c;它提供了一系列的开箱即用的功能&#xff0c;使得快速构建RESTful Web服务和基于HTTP的API变得简单。以下是使用Spring Boot实现基于HTTP的API的步骤&#xff1a; 添加依赖&#xff1a;在Maven项目中&#xff0c…

(每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理第8章 项目整合管理(五)

博主2023年11月通过了信息系统项目管理的考试&#xff0c;考试过程中发现考试的内容全部是教材中的内容&#xff0c;非常符合我学习的思路&#xff0c;因此博主想通过该平台把自己学习过程中的经验和教材博主认为重要的知识点分享给大家&#xff0c;希望更多的人能够通过考试&a…

企业能源消耗监测管理系统是否可以做好能源计量与能耗分析?

能源消耗与分析是能源科学管理的基础&#xff0c;也可促进能源管理工作的改善&#xff0c;在企业中能源管理系统的作用也愈加重要。 首先&#xff0c;能源计量是能源管理的基础&#xff0c;通过能源精准计老化&#xff0c;容易出现测量设备不准确以及其他一些人为因素原因导致…

生信软件12 - 基于Symbol和ENTREZID查询基因注释的R包(easyConvert )

使用easyConvert R包可获取基因的注释,包括基因的别名称别名、基因类型、Ensembl ID、Entrez ID及多个数据库的基因简介(summary)。 R包安装 # remotes包安装 install.packages("remotes")# easyConvert包安装 remotes::install_github("dming1024/easyCon…

P1184 高手之在一起 题解

还是先复习 or 预习一下set。 先给set一个名字&#xff1a; set<元素类型>qwq;插入元素&#xff1a; qwq.insert(元素);查找元素&#xff1a; qwq.find(元素);如果元素没有找到&#xff0c;返回qwq.end()&#xff0c;是一个空的位置迭代器。 注&#xff1a; 1.迭代器…

[AIGC 大数据基础]hive浅谈

在当今大数据时代&#xff0c;随着数据量的不断增大&#xff0c;如何高效地处理和分析海量数据已经成为一个重要的挑战。为了满足这一需求&#xff0c;Hive应运而生。 Hive作为一个基于Hadoop的数据仓库基础设施&#xff0c;为用户提供了类SQL的查询语言和丰富的功能&#xff0…