第G2周:人脸图像生成(DCGAN)

第G2周:人脸图像生成(DCGAN)

  • 第G2周:人脸图像生成(DCGAN)
    • 一、前言
    • 二、我的环境
    • 三、代码实现
      • 1、导入第三方库
      • 2、设置超参数
      • 3、导入数据
    • 四、定义模型
      • 4.1 初始化权重
      • 4.2 定义生成器
      • 4.3 定义鉴别器
    • 五、训练模型
    • 六、可视化代码
    • 七、总结

第G2周:人脸图像生成(DCGAN))

第G2周:人脸图像生成(DCGAN)

一、前言

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

二、我的环境

  • 电脑系统:Windows 10
  • 语言环境:Python 3.8.5
  • 编译器:Spyder

三、代码实现

1、导入第三方库

import torch, random, random, os
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTMLmanualSeed = 999  # 随机种子
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results

2、设置超参数

dataroot = 'D:/DL_Camp/GAN/G2'  # 数据路径
batch_size = 128  # 训练过程中的批次大小
image_size = 64   # 图像的尺寸(宽度和高度)
nz  = 100         # z潜在向量的大小(生成器输入的尺寸)
ngf = 64          # 生成器中的特征图大小
ndf = 64          # 判别器中的特征图大小
num_epochs = 5    # 训练的总轮数
lr    = 0.0002    # 学习率
beta1 = 0.5       # Adam优化器的Beta1超参数

3、导入数据

dataset = dset.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(image_size),        # 调整图像大小transforms.CenterCrop(image_size),    # 中心裁剪图像transforms.ToTensor(),                # 将图像转换为张量transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量(0.5, 0.5, 0.5)),]))# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,  # 批量大小shuffle=True,           # 是否打乱数据集num_workers=5 # 使用多个线程加载数据的工作进程数)# 选择要在哪个设备上运行代码
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print("使用的设备是:",device)# 绘制一些训练图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24], padding=2, normalize=True).cpu(),(1,2,0)))

在这里插入图片描述

四、定义模型

4.1 初始化权重

# 自定义权重初始化函数,作用于netG和netD
def weights_init(m):# 获取当前层的类名classname = m.__class__.__name__# 如果类名中包含'Conv',即当前层是卷积层if classname.find('Conv') != -1:# 使用正态分布初始化权重数据,均值为0,标准差为0.02nn.init.normal_(m.weight.data, 0.0, 0.02)# 如果类名中包含'BatchNorm',即当前层是批归一化层elif classname.find('BatchNorm') != -1:# 使用正态分布初始化权重数据,均值为1,标准差为0.02nn.init.normal_(m.weight.data, 1.0, 0.02)# 使用常数初始化偏置项数据,值为0nn.init.constant_(m.bias.data, 0)

4.2 定义生成器

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(# 输入为Z,经过一个转置卷积层nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),  # 批归一化层,用于加速收敛和稳定训练过程nn.ReLU(True),  # ReLU激活函数# 输出尺寸:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 输出尺寸:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 输出尺寸:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 输出尺寸:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),nn.Tanh()  # Tanh激活函数# 输出尺寸:3 x 64 x 64)def forward(self, input):return self.main(input)# 创建生成器
netG = Generator().to(device)
# 使用 "weights_init" 函数对所有权重进行随机初始化,
# 平均值(mean)设置为0,标准差(stdev)设置为0.02。
netG.apply(weights_init)
# 打印生成器模型
print(netG)

在这里插入图片描述

4.3 定义鉴别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器的主要结构,使用Sequential容器将多个层按顺序组合在一起self.main = nn.Sequential(# 输入大小为3 x 64 x 64nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):# 将输入通过判别器的主要结构进行前向传播return self.main(input)# 创建判别器模型
netD = Discriminator().to(device)# 应用 "weights_init" 函数来随机初始化所有权重
# 使用 mean=0, stdev=0.2 的方式进行初始化
netD.apply(weights_init)# 打印模型
print(netD)

在这里插入图片描述

五、训练模型

# 初始化“BCELoss”损失函数
criterion = nn.BCELoss()# 创建用于可视化生成器进程的潜在向量批次
fixed_noise = torch.randn(64, nz, 1, 1, device=device)real_label = 1.
fake_label = 0.# 为生成器(G)和判别器(D)设置Adam优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数print("Starting Training Loop...")  # 输出训练开始的提示信息
# 对于每个epoch(训练周期)
for epoch in range(num_epochs):# 对于dataloader中的每个batchfor i, data in enumerate(dataloader, 0):############################# (1) 更新判别器网络:最大化 log(D(x)) + log(1 - D(G(z)))############################# 使用真实图像样本训练netD.zero_grad()  # 清除判别器网络的梯度# 准备真实图像的数据real_cpu = data[0].to(device)b_size = real_cpu.size(0)label = torch.full((b_size,), real_label, dtype=torch.float, device=device)  # 创建一个全是真实标签的张量# 将真实图像样本输入判别器,进行前向传播output = netD(real_cpu).view(-1)# 计算真实图像样本的损失errD_real = criterion(output, label)# 通过反向传播计算判别器的梯度errD_real.backward()D_x = output.mean().item()  # 计算判别器对真实图像样本的输出的平均值## 使用生成图像样本训练# 生成一批潜在向量noise = torch.randn(b_size, nz, 1, 1, device=device)# 使用生成器生成一批假图像样本fake = netG(noise)label.fill_(fake_label)  # 创建一个全是假标签的张量# 将所有生成的图像样本输入判别器,进行前向传播output = netD(fake.detach()).view(-1)# 计算判别器对生成图像样本的损失errD_fake = criterion(output, label)# 通过反向传播计算判别器的梯度errD_fake.backward()D_G_z1 = output.mean().item()  # 计算判别器对生成图像样本的输出的平均值# 计算判别器的总损失,包括真实图像样本和生成图像样本的损失之和errD = errD_real + errD_fake# 更新判别器的参数optimizerD.step()############################# (2) 更新生成器网络:最大化 log(D(G(z)))###########################netG.zero_grad()  # 清除生成器网络的梯度label.fill_(real_label)  # 对于生成器成本而言,将假标签视为真实标签# 由于刚刚更新了判别器,再次将所有生成的图像样本输入判别器,进行前向传播output = netD(fake).view(-1)# 根据判别器的输出计算生成器的损失errG = criterion(output, label)# 通过反向传播计算生成器的梯度errG.backward()D_G_z2 = output.mean().item()  # 计算判别器对生成器输出的平均值# 更新生成器的参数optimizerG.step()# 输出训练统计信息if i % 400 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))# 保存损失值以便后续绘图G_losses.append(errG.item())D_losses.append(errD.item())# 通过保存生成器在固定噪声上的输出来检查生成器的性能if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(vutils.make_grid(fake, padding=2, normalize=True))iters += 1

六、可视化代码

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()# 创建一个大小为8x8的图形对象
fig = plt.figure(figsize=(8, 8))# 不显示坐标轴
plt.axis("off")# 将图像列表img_list中的图像转置并创建一个包含每个图像的单个列表ims
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]# 使用图形对象、图像列表ims以及其他参数创建一个动画对象ani
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)# 将动画以HTML形式呈现
HTML(ani.to_jshtml())# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))# 绘制上一个时期生成的假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

在这里插入图片描述
在这里插入图片描述在这里插入图片描述

七、总结

对GAN有更深的理解,同时学习到了GAN的变体DCGAN。DCGAN使用卷积网络的对抗网络,把CNN卷积技术用于GAN模式的网络里,G网在生成数据时,使用反卷积的重构技术来重构原始图片。D网用卷积技术来识别图片特征,进而做出判别。能提高样本的质量和收敛速度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/811787.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年蓝桥杯40天打卡总结

2024蓝桥杯40天打卡总结 真题题解其它预估考点重点复习考点时间复杂度前缀和二分的两个模板字符串相关 String和StringBuilderArrayList HashSet HashMap相关蓝桥杯Java常用算法大数类BigInteger的存储与运算日期相关考点及函数质数最小公倍数和最大公约数排序库的使用栈Math类…

Go语言图像处理实战:深入image/color库的应用技巧

Go语言图像处理实战:深入image/color库的应用技巧 引言image/color库基础颜色模型简介颜色类型和接口 image/color库实际应用基本颜色操作创建颜色颜色值转换颜色比较 颜色转换与处理与image库结合使用 性能优化和高级技巧性能考量避免频繁的颜色类型转换使用并发处…

web服务器是如何运行的?tomcat基本原理

tomcat基本流程 tomcat在启动时将webapps下的每个项目中的web.xml读取,获取相关信息。tomcat只关心Servlet 程序、Filter 过滤器、Listener 监听器,不关心其他类。 tomcat接收到请求后会将请求报文转换成一个httpServletRequest对象,包含请求…

【MATLAB源码-第45期】基于matlab的16APSK调制解调仿真,使用卷积编码软判决。

操作环境: MATLAB 2022a 1、算法描述 1. 16APSK调制解调 16APSK (16-ary Amplitude Phase Shift Keying) 是一种相位调制技术,其基本思想是在恒定幅度的条件下,改变信号的相位,从而传送信息。 - 调制:在16APSK中&am…

在vue和 js 、ts 数据中使用 vue-i18n,切换语言环境时,标签文本实时变化

我的项目需要显示两种语言(中文和英文),并且我想要切换语言时,页面语言环境会随之改变,目前发现,只能在vue中使用$t(‘’)的方式使用,但是这种方式只能在vue中使用,而我的菜单文件是定义在js中,…

gpt系列概述——从gpt1到chatgpt

GPT建模实战:GPT建模与预测实战-CSDN博客 OpenAI的GPT(Generative Pre-trained Transformer)系列模型是自然语言处理领域的重要里程碑。从2018年至2020年,该公司相继推出了GPT-1、GPT-2和GPT-3,这些模型在文本生…

CTFshow电子取证——内存取证1

关于内存与注册表 内存中的注册表项 当Windows操作系统启动时,它会将注册表的部分数据加载到内存中,以便系统和应用程序可以快速地访问这些信息。这些数据在内存中可以更快地被读取和修改,以便系统能够动态地调整其行为和配置。 系统性能和…

工业相机飞拍功能的介绍(转载学习)

转载至: https://baijiahao.baidu.com/s?id1781438589726948322&wfrspider&forpc

办公软件巨头CCED、WPS迎来新挑战,新款办公软件已形成普及之势

办公软件巨头CCED、WPS的成长经历 CCED与WPS,这两者均是中国办公软件行业的佼佼者,为人们所熟知。 然而,它们的成功并非一蹴而就,而是经过了长时间的积累与沉淀。 CCED,这款中国大陆早期的文本编辑器,在上…

【vue】购物车案例优化

对 购物车案例 进行优化 用watch实现全选/取消全选用watch实现全选状态的检查用computed计算总价格 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-w…

014:vue3 van-list van-pull-refresh实现上拉加载,下拉刷新

文章目录 1. 实现上拉加载&#xff0c;下拉刷新效果2. van-list&#xff0c;van-pull-refresh组件详解2.1 van-list组件2.2 van-pull-refresh组件 3. 完整案例4. 坑点&#xff1a;加载页面会一直调用加载接口 1. 实现上拉加载&#xff0c;下拉刷新效果 通过下拉刷新加载下一页…

淘宝扭蛋机小程序:扭出惊喜,乐享购物新体验

在快节奏的现代生活中&#xff0c;人们总是在寻找新鲜、有趣的娱乐方式。淘宝扭蛋机小程序应运而生&#xff0c;为您带来前所未有的购物与娱乐结合新体验。在这里&#xff0c;每一次的扭动都充满惊喜&#xff0c;每一次的点击都带来乐趣&#xff0c;让您在购物的同时&#xff0…

4.Labview簇、变体与类(上)

在Labview中&#xff0c;何为簇与变体&#xff0c;何为类&#xff1f;应该如何理解&#xff1f;具体有什么应用场景&#xff1f; 本文基于Labview软件&#xff0c;独到的讲解了簇与变体与类函数的使用方法和场景&#xff0c;从理论上讲解其数据流的底层概念&#xff0c;从实践上…

汽车维修类中译英的英语翻译

近年来&#xff0c;随着全球化的加速和汽车市场的不断扩大&#xff0c;汽车维修领域的交流与合作也日益频繁。汽车维修类中译英的英语翻译在汽车行业中扮演着至关重要的角色。那么&#xff0c;针对汽车维修类翻译&#xff0c;中译英的英语翻译有何技巧&#xff1f; 业内人士指出…

Linux 系统解压缩文件

Linux系统&#xff0c;可以使用unzip命令来解压zip文件 方法如下 1. 打开终端&#xff0c;在命令行中输入以下命令来安装unzip&#xff1a; sudo apt-get install unzip 1 2. 假设你想要将zip文件解压缩到名为"target_dir"的目录中&#xff0c;在终端中切换到目标路…

C#/.NET/.NET Core拾遗补漏合集(24年4月更新)

前言 在这个快速发展的技术世界中&#xff0c;时常会有一些重要的知识点、信息或细节被忽略或遗漏。《C#/.NET/.NET Core拾遗补漏》专栏我们将探讨一些可能被忽略或遗漏的重要知识点、信息或细节&#xff0c;以帮助大家更全面地了解这些技术栈的特性和发展方向。 GitHub开源地…

Java中的单元测试Junit框架

单元测试 概述与以往的测试比较 Junit框架快速入门优点测试步骤断言机制 自动化测试 Junit框架常见注解 概述 针对最小的功能单元&#xff08;方法&#xff09;&#xff0c;编写测试代码对其进行正确性测试 与以往的测试比较 只能在main方法编写测试代码&#xff0c;去调用其…

Python实现PDF页面的删除与添加

在处理PDF文档的过程中&#xff0c;我们时常会需要对PDF文档中的页面进行编辑操作的情况&#xff0c;如插入和删除页面。通过添加和删除PDF页面&#xff0c;我们可以增加内容或对不需要的内容进行删除&#xff0c;使文档内容更符合需求。而通过Python实现PDF文档中的插入和删除…

微服务学习 Eureka注册中心

服务调用时候出现问题&#xff0c;当服务者很多时候&#xff0c;比如不同的端口。消费者如何找到服务者的地址&#xff1f;又如何判断服务者是否健康。 Eureka基本原理&#xff1a; 总结:如果有多个服务提供者&#xff0c;消费者该如何选择&#xff1f; 搭建Eureka注册中心: 1.…

LeetCode-72. 编辑距离【字符串 动态规划】

LeetCode-72. 编辑距离【字符串 动态规划】 题目描述&#xff1a;解题思路一&#xff1a;动规五部曲解题思路二&#xff1a;动态规划【版本二】解题思路三&#xff1a;0 题目描述&#xff1a; 给你两个单词 word1 和 word2&#xff0c; 请返回将 word1 转换成 word2 所使用的最…