Pytorch从零开始实战18

Pytorch从零开始实战——人脸图像生成

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——人脸图像生成
    • 环境准备
    • 模型定义
    • 开始训练
    • 可视化
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。本次实验的目的是了解并使用DCGAN模型,完成人脸图生成。
第一步,导入常用包

import torch, random, random, os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

检查设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.cuda.device_count()

设置超参数,其中数据集源于K同学

dataroot = "./data/face"  # 数据路径
batch_size = 128  # 训练过程中的批次大小
image_size = 64   # 图像的尺寸(宽度和高度)
nz  = 100         # z潜在向量的大小(生成器输入的尺寸)
ngf = 64          # 生成器中的特征图大小
ndf = 64          # 判别器中的特征图大小
num_epochs = 50   # 训练的总轮数,如果你显卡不太行,可调小,但是生成效果会随之降低
lr    = 0.0002    # 学习率
beta1 = 0.5       # Adam优化器的Beta1超参数

使用datasets读取数据

dataset = datasets.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(image_size),        # 调整图像大小transforms.CenterCrop(image_size),    # 中心裁剪图像transforms.ToTensor(),                # 将图像转换为张量transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量(0.5, 0.5, 0.5)),]))

随机查看五张数据

def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(dataset)

在这里插入图片描述
使用dataloader进行批量划分和打乱

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,  # 批量大小shuffle=True,           # 是否打乱数据集num_workers=5 # 使用多个线程加载数据的工作进程数)

模型定义

深度卷积对抗网络(Deep Convolutional Generative Adversarial Networks, DCGAN)是生成对抗网络的一种模型改进,其将卷积运算的思想引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果。

判别器网络和生成器网络: DCGAN包括两个主要部分,即判别器(Discriminator)和生成器(Generator)。判别器负责评估输入图像是真实图像还是生成图像,而生成器则试图生成逼真的图像以欺骗判别器。

卷积层和批量归一化: DCGAN使用卷积神经网络(CNN)作为判别器和生成器的主要组件,以有效地捕捉图像的空间结构。此外,引入批量归一化来稳定训练过程,加速收敛。

生成器输入和输出: 生成器接收一个随机噪声向量作为输入,通过反卷积(或称为转置卷积)操作生成图像。这使得生成器能够从随机噪声中学到数据分布的特征。

判别器的激活函数和损失函数: 判别器使用Leaky ReLU(带有泄漏的修正线性单元)作为激活函数,以防止梯度消失问题。损失函数采用二元交叉熵,用于判断生成图像和真实图像的相似度。

避免全连接层: DCGAN的设计避免使用全连接层,而是主要使用卷积和反卷积层,以减少模型参数,降低过拟合风险。

其中,反卷积核心思想是通过在输入特征图之间插入一些新的值(通常用零填充),使得输出的尺寸比输入更大。

DCGAN模型主要包括了一个生成网络 G 和一个判别网络 D,生成网络 G 负责生成图像,它接受一个随机的噪声z,通过该噪声生成图像,将生成的图像记为G(z),判别网络D负责判别一张图像是否为真实的,它的输入是x,代表一张图像,输出D(x)表示x为真实图像的概率。

实际上判别网络D是对数据的来源进行一个判别:究竟这个数据是来自真实的数据分布Pd(x)判别为“1”),还是来自于一个生成网络G所产生的一个数据分布Pg(z)(判别为“0”)。所以在整个训练过程中,生成网络G的目标是生成可以以假乱真的图像G(z),当判别网络D无法区分,即D(G(z))=0.5时,便得到了一个生成网络G用来生成图像扩充数据集。
在这里插入图片描述
初始化权重

# 自定义权重初始化函数,作用于netG和netD
def weights_init(m):# 获取当前层的类名classname = m.__class__.__name__# 如果类名中包含'Conv',即当前层是卷积层if classname.find('Conv') != -1:# 使用正态分布初始化权重数据,均值为0,标准差为0.02nn.init.normal_(m.weight.data, 0.0, 0.02)# 如果类名中包含'BatchNorm',即当前层是批归一化层elif classname.find('BatchNorm') != -1:# 使用正态分布初始化权重数据,均值为1,标准差为0.02nn.init.normal_(m.weight.data, 1.0, 0.02)# 使用常数初始化偏置项数据,值为0nn.init.constant_(m.bias.data, 0)

定义生成器网络

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(# 输入为Z,经过一个转置卷积层nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),  # 批归一化层,用于加速收敛和稳定训练过程nn.ReLU(True),  # ReLU激活函数# 输出尺寸:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 输出尺寸:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 输出尺寸:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 输出尺寸:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),nn.Tanh()  # Tanh激活函数# 输出尺寸:3 x 64 x 64)def forward(self, input):return self.main(input)netG = Generator().to(device)
netG.apply(weights_init) # 使用 "weights_init" 函数来随机初始化所有权重
print(netG)

在这里插入图片描述
定义判别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器的主要结构,使用Sequential容器将多个层按顺序组合在一起self.main = nn.Sequential(# 输入大小为3 x 64 x 64nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 将输入通过判别器的主要结构进行前向传播return self.main(input)# 创建判别器模型
netD = Discriminator().to(device)
netD.apply(weights_init) # 使用 "weights_init" 函数来随机初始化所有权重
print(netD)

在这里插入图片描述

开始训练

定义损失函数和优化算法

# 损失函数
criterion = nn.BCELoss()# 创建用于可视化生成器进程的潜在向量批次
fixed_noise = torch.randn(64, nz, 1, 1, device=device)real_label = 1.
fake_label = 0.# 为生成器(G)和判别器(D)设置Adam优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

对于每个dataloader中的atch,会进行以下步骤:
1.更新判别器网络: 通过最大化判别器对真实图像和生成图像的损失来训练判别器。这包括计算对真实图像的损失和对生成图像的损失,然后通过梯度反向传播来更新判别器的参数。
2.更新生成器网络: 通过最大化生成器在生成图像上的损失来训练生成器。生成器的目标是欺骗判别器,使其无法区分生成的图像和真实图像。同样,通过梯度反向传播来更新生成器的参数。
3.记录损失值,输出训练统计信息,并定期保存生成器在固定噪声上的输出图像。

img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数print("Starting Training Loop...")  # 输出训练开始的提示信息
# 对于每个epoch(训练周期)
for epoch in range(num_epochs):# 对于dataloader中的每个batchfor i, data in enumerate(dataloader, 0):############################# (1) 更新判别器网络:最大化 log(D(x)) + log(1 - D(G(z)))############################# 使用真实图像样本训练netD.zero_grad()  # 清除判别器网络的梯度# 准备真实图像的数据real_cpu = data[0].to(device)b_size = real_cpu.size(0)label = torch.full((b_size,), real_label, dtype=torch.float, device=device)  # 创建一个全是真实标签的张量# 将真实图像样本输入判别器,进行前向传播output = netD(real_cpu).view(-1)# 计算真实图像样本的损失errD_real = criterion(output, label)# 通过反向传播计算判别器的梯度errD_real.backward()D_x = output.mean().item()  # 计算判别器对真实图像样本的输出的平均值## 使用生成图像样本训练# 生成一批潜在向量noise = torch.randn(b_size, nz, 1, 1, device=device)# 使用生成器生成一批假图像样本fake = netG(noise)label.fill_(fake_label)  # 创建一个全是假标签的张量# 将所有生成的图像样本输入判别器,进行前向传播output = netD(fake.detach()).view(-1)# 计算判别器对生成图像样本的损失errD_fake = criterion(output, label)# 通过反向传播计算判别器的梯度errD_fake.backward()D_G_z1 = output.mean().item()  # 计算判别器对生成图像样本的输出的平均值# 计算判别器的总损失,包括真实图像样本和生成图像样本的损失之和errD = errD_real + errD_fake# 更新判别器的参数optimizerD.step()############################# (2) 更新生成器网络:最大化 log(D(G(z)))###########################netG.zero_grad()  # 清除生成器网络的梯度label.fill_(real_label)  # 对于生成器成本而言,将假标签视为真实标签# 由于刚刚更新了判别器,再次将所有生成的图像样本输入判别器,进行前向传播output = netD(fake).view(-1)# 根据判别器的输出计算生成器的损失errG = criterion(output, label)# 通过反向传播计算生成器的梯度errG.backward()D_G_z2 = output.mean().item()  # 计算判别器对生成器输出的平均值# 更新生成器的参数optimizerG.step()# 输出训练统计信息if i % 400 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))# 保存损失值以便后续绘图G_losses.append(errG.item())D_losses.append(errD.item())# 通过保存生成器在固定噪声上的输出来检查生成器的性能if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(vutils.make_grid(fake, padding=2, normalize=True))iters += 1

在这里插入图片描述

可视化

查看训练过程

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

在这里插入图片描述

查看生成的图像

# 创建一个大小为8x8的图形对象
fig = plt.figure(figsize=(8, 8))
# 不显示坐标轴
plt.axis("off")
# 将图像列表img_list中的图像转置并创建一个包含每个图像的单个列表ims
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
# 使用图形对象、图像列表ims以及其他参数创建一个动画对象ani
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# 将动画以HTML形式呈现
HTML(ani.to_jshtml())

在这里插入图片描述
对比一下真实图像和生成的图像

# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))# 绘制上一个时期生成的假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

在这里插入图片描述

总结

DCGAN是生成对抗网络的一种应用,包含生成器和判别器,通过对抗训练的方式,使得生成器能够生成逼真的数据,而判别器则学会区分真实数据和生成数据。总之,DCGAN的设计使得生成对抗网络在图像生成领域取得了显著的进展,促进了后续对GAN的研究和发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/660821.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【2024年美赛即将开赛】最后一天如何提高获奖率

美赛思路预定 01 美赛赛中时间分配美赛时间安排比赛前2~3天第一天(2号)第二天(3号)第三天(4号)第四天(5号)第五天(6号)8:00~10:00 02 …

Hadoop-生产调优(更新中)

第1章 HDFS-核心参数 1.1 NameNode内存生产配置 1)NameNode 内存计算 每个文件块大概占用 150 byte,一台服务器 128G 内存为例,能存储多少文件块呢? 128 * 1024 * 1024 * 1024 / 150byte ≈ 9.1 亿G MB KB Byte 2&#xff09…

前端构建变更:从 webpack 换 vite

现状 这里以一个 op (内部运营管理用)项目为例,从 webpack 构建改为 vite 构建,提高本地开发效率,顺便也加深对 webpack 、 vite 的了解。 vite 是前端构建工具,使用 一系列预配置进行rollup 打包&#x…

gdb 调试 - 在vscode图形化展示在远程的gdb debug过程

前言 本地机器的操作系统是windows,远程机器的操作系统是linux,开发在远程机器完成,本地只能通过ssh登录到远程。现在目的是要在本地进行图形化展示在远程的gdb debug过程。(注意这并不是gdb remote !!&am…

实现vue3响应式系统核心-shallowReactive

简介 今天来实现一下 shallowReactive 这个 API。 reactive函数是一个深响应,当你取出的值为对象类型,需要再次调用 reactive进行响应式处理。很明显我们目前的代码是一个浅响应,即 只代理了对象的第一层,也就是 shallowReactiv…

36万的售价,蔚来理想卖得,小米卖不得?

文 | AUTO芯球 作者 | 雷歌 Are you OK?雷军被网友们叫“小雷”! 被网友一猜再猜的小米SU7的价格,因为一份保险上牌价格单的曝光被网友吵得热热闹闹,曝出的小米汽车顶配上牌保险价格为36.14万。 20万以下,人们愿称…

【python】OpenCV—Tracking(10.1)

学习来自《Learning OpenCV 3 Computer Vision with Python》Second Edition by Joe Minichino and Joseph Howse 文章目录 检测移动的目标涉及到的 opencv 库cv2.GaussianBlurcv2.absdiffcv2.thresholdcv2.dilatecv2.getStructuringElementcv2.findContourscv2.contourAreacv2…

css多行文本擦拭效果

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>多行文本擦拭效果</title><style>* …

亚信安全助力宁夏首个人工智能数据中心建成 铺设绿色算力安全底座

近日&#xff0c;由宁夏西云算力科技有限公司倾力打造&#xff0c;亚信安全科技股份有限公司&#xff08;股票代码&#xff1a;688225&#xff09;全力支撑&#xff0c;总投资达数十亿元人民币的宁夏智算中心项目&#xff0c;其一期工程——宁夏首个采用全自然风冷技术的30KW机…

复刻桌面小电视【包含代码分析】

宗旨&#xff1a;开源、分享、学习、进步&#xff0c;生命不息&#xff0c;折腾不止。 复刻小电视 感谢各位大佬的开源项目&#xff0c;让我有了学习的机会&#xff0c;如果侵权&#xff0c;请联系我删除。本人能力有限&#xff0c;如果有什么不对的地方&#xff0c;欢迎指正…

c/c++串的链式操作

文章目录 1.链式串的定义2.初始化3.赋值为04.赋值操作5.打印操作6.源码 本篇博客中都是带头结点的串。 1.链式串的定义 这里的数据域是4个字节&#xff0c;是为了节省空间。 typedef struct StringNode{char ch[4]; //按串长分配存储区&#xff0c;ch指向串的基地址struct S…

C++引用、内联函数、auto关键字介绍以及C++中无法使用NULL的原因

文章目录 一、引用1.1 引用概念1.2 引用特性1.3 常引用1.4 使用场景1.4.1 做参数1.4.2做返回值 1.5 引用和指针的区别1.6 小结一下 二、内联函数2.1 内联的概念2.2 内联的特性2.3 【面试题】 三、auto关键字(C11)3.1 类型别名思考3.2 auto简介 四、auto的使用细则4.1 基于范围的…

【2024年美国大学生数学建模竞赛】完整解析+模型代码+技术文档

美赛思路预定 01 美赛赛中时间分配美赛时间安排比赛前2~3天第一天&#xff08;2号&#xff09;第二天&#xff08;3号&#xff09;第三天&#xff08;4号&#xff09;第四天&#xff08;5号&#xff09;第五天&#xff08;6号&#xff09;8&#xff1a;00~10&#xff1a;00 02 …

固态硬盘颗粒,让我们了解下SLC、MLC、TLC

前文提要 近些年SSD的市场越来越好&#xff0c;大家的家用PC也逐渐都转向速度更快&#xff0c;玩游戏更流程的SSD,反而更加推动了SSD厂商的生产种类&#xff0c;但是其实大家还是挺关注SSD盘的使用寿命&#xff0c;处理数据速度&#xff0c;以及更重要的价格&#xff0c;面对市…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之TextPicker组件

鸿蒙&#xff08;HarmonyOS&#xff09;项目方舟框架&#xff08;ArkUI&#xff09;之TextPicker组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、TextPicker组件 TextClock组件通过文本将当前系统时间显示在设备上。支持不…

什么是回归测试?回归测试的类型和方法?

随着软件开发进程的进行&#xff0c;每一次的修改和更新都有可能引入新的问题和错误。为了确保产品质量和稳定性&#xff0c;需要进行回归测试。那么&#xff0c;什么是回归测试&#xff1f;本文将为您解答。 回归测试是指在软件代码、使用环境或产品需求发生改变时&#xff0…

【Oracle云】使用 boto3 访问 OCI 对象存储 (AWS S3协议兼容)

在现代云计算环境中&#xff0c;S3&#xff08;Simple Storage Service&#xff09;协议已经成为云对象存储的事实标准。它提供了简单、可扩展、高度耐用的存储解决方案&#xff0c;得到了广泛应用。Oracle Cloud Infrastructure&#xff08;OCI&#xff09;秉承着开放性和灵活…

C++初阶 类和对象(补充)

目录 一、友元 1.1什么是友元&#xff1f; 1.2如何使用友元&#xff1f; 1.3使用友元 1.4使用友元注意事项 二、初始化列表 2.1什么是初始化列表? 2.2为什么要有初始化列表&#xff1f; 2.3使用初始化列表 2.4注意事项 一、友元 1.1什么是友元&#xff1f; 友元是一…

大数据知识图谱之深度学习——基于BERT+LSTM+CRF深度学习识别模型医疗知识图谱问答可视化系统

文章目录 大数据知识图谱之深度学习——基于BERTLSTMCRF深度学习识别模型医疗知识图谱问答可视化系统一、项目概述二、系统实现基本流程三、项目工具所用的版本号四、所需要软件的安装和使用五、开发技术简介Django技术介绍Neo4j数据库Bootstrap4框架Echarts简介Navicat Premiu…

Windows Server 2003 DNS服务器搭建

系列文章目录 目录 系列文章目录 文章目录 前言 一、DNS服务器是什么&#xff1f; 二、配置服务器 1.实验环境搭建 2.服务器搭建 3)安装Web服务器和DNS服务器 4)查看安装是否成功 5)这里直接配置DNS服务器了,Web服务器如何配置我已经发布过了 文章目录 Windows Serve…