PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Electron入门,项目启动。

electron 简单介绍&#xff1a; 实现&#xff1a;HTML/CSS/JS桌面程序&#xff0c;搭建跨平台桌面应用。 electron 官方文档&#xff1a; [https://electronjs.org/docs] 本文是基于以下2篇文章且自行实践过的&#xff0c;可行性真实有效。 文章1&#xff1a; https://www.cnbl…

基于Javaweb的摄影作品网站/摄影网站

摘 要 随着信息化时代的到来&#xff0c;系统管理都趋向于智能化、系统化&#xff0c;摄影作品网站也不例外&#xff0c;但目前国内的有些网站仍然都使用人工管理&#xff0c;浏览网站人数越来越多&#xff0c;同时信息量也越来越庞大&#xff0c;人工管理显然已无法应对时代的…

AMD fTPM RNG的BUG使得Linus Torvalds不满

导读因为在 Ryzen 系统上对内核造成了困扰&#xff0c;Linus Torvalds 最近在邮件列表中表达了对 AMD fTPM 硬件随机数生成器的不满&#xff0c;并提出了禁用该功能的建议。 因为在 Ryzen 系统上对内核造成了困扰&#xff0c;Linus Torvalds 最近在邮件列表中表达了对 AMD fTPM…

【【Verilog典型电路设计之FIFO设计】】

典型电路设计之FIFO设计 FIFO (First In First Out&#xff09;是一种先进先出的数据缓存器&#xff0c;通常用于接口电路的数据缓存。与普通存储器的区别是没有外部读写地址线&#xff0c;可以使用两个时钟分别进行写和读操作。FIFO只能顺序写入数据和顺序读出数据&#xff0…

【从0开始学架构笔记】01 基础架构

文章目录 一、架构的定义1. 系统与子系统2. 模块与组件3. 框架与架构4. 重新定义架构 二、架构设计的目的三、复杂度来源&#xff1a;高性能1. 单机复杂度2. 集群复杂度2.1 任务分配2.2 任务分解&#xff08;微服务&#xff09; 四、复杂度来源&#xff1a;高可用1. 计算高可用…

GitKraken保姆级图文使用指南

前言 写这篇文章的原因是组内的产品和美术同学&#xff0c;开始参与到git工作流中&#xff0c;但是网上又没有找到一个比较详细的使用教程&#xff0c;所以干脆就自己写了一个[doge]。文章的内容比较基础&#xff0c;介绍了Git内的一些基础概念和基本操作&#xff0c;适合零基…

合并多个文本文件

使用 wxPython 模块合并多个文本文件的博客。以下是一篇示例博客&#xff1a; C:\pythoncode\blog\txtmerge.py 在 Python 编程中&#xff0c;我们经常需要处理文本文件。有时候&#xff0c;我们可能需要将多个文本文件合并成一个文件&#xff0c;以便进行进一步的处理或分析。…

QT报表Limereport v1.5.35编译及使用

1、编译说明 下载后QT CREATER中打开limereport.pro然后直接编译就可以了。编译后结果如下图&#xff1a; 一次编译可以得到库文件和DEMO执行程序。 2、使用说明 拷贝如下图编译后的lib目录到自己的工程目录中。 release版本的重新命名为librelease. PRO文件中配置 QT …

openpose姿态估计【学习笔记】

文章目录 1、人体需要检测的关键点2、Top-down方法3、Openpose3.1 姿态估计的步骤3.2 PAF&#xff08;Part Affinity Fields&#xff09;部分亲和场3.3 制作PAF标签3.4 PAF权值计算3.5 匹配方法 4、CPM&#xff08;Convolutional Pose Machines&#xff09;模型5、Openpose5.1 …

怎么修改图片的分辨率?

怎么修改图片的分辨率&#xff1f;很多人还不知道分辨率是什么意思&#xff0c;以为代表了图片的清晰度&#xff0c;然而并不是这样的&#xff0c;其实图片的分辨率就是图片尺寸大小的意思。修改图片的分辨率即改变图片的尺寸&#xff0c;通常以像素为单位表示。分辨率决定了图…

【linux基础(四)】对Linux权限的理解

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Linux从入门到开通⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学更多操作系统知识   &#x1f51d;&#x1f51d; Linux权限 1. 前言2. shell命…

八、Linux下,grep/wc/管道符/echo/重定向符/tail如何使用?

1、grep命令 &#xff08;1&#xff09;主要用于文件 &#xff08;2&#xff09;主要作用是“通过关键字&#xff0c;过滤文件行” &#xff08;3&#xff09;示例&#xff1a; 2、wc命令 &#xff08;1&#xff09;统计文件的行数、单词数等 &#xff08;2&#xff09;示例…

react之路由的安装与使用

一、路由安装 路由官网2021.11月初&#xff0c;react-router 更新到 v6 版本。使用最广泛的 v5 版本的使用 npm i react-router-dom5.3.0二、路由使用 2.1 路由的简单使用 第一步 在根目录下 创建 views 文件夹 ,用于放置路由页面 films.js示例代码 export default functio…

一文预览 | 8 月 16 日 NVIDIA 在 WAVE SUMMIT深度学习开发者大会 2023精彩亮点抢先看!

由深度学习技术及应用国家工程研究中心主办&#xff0c;百度飞桨和文心大模型承办的 WAVE SUMMIT深度学习开发者大会2023&#xff0c;将于 8 月 16 日在北京与大家见面。NVIDIA 作为技术合作伙伴&#xff0c;将携手百度飞桨参与这场技术盛会。 在这次大会中&#xff0c;NVIDIA…

Java 项目日志实例基础:Log4j

点击下方关注我&#xff0c;然后右上角点击...“设为星标”&#xff0c;就能第一时间收到更新推送啦~~~ 介绍几个日志使用方面的基础知识。 1 Log4j 1、Log4j 介绍 Log4j&#xff08;log for java&#xff09;是 Apache 的一个开源项目&#xff0c;通过使用 Log4j&#xff0c;我…

RabbitMq交换机类型介绍

RabbitMq交换机类型介绍 在RabbitMq中&#xff0c;生产者的消息都是通过交换器来接收&#xff0c;然后再从交换器分发到不同的队列&#xff0c;再由消费者从队列获取消息。这种模式也被成为“发布/订阅”。 分发的过程中交换器类型会影响分发的逻辑。 直连交换机&#xff1a…

【计算机视觉|生成对抗】逐步增长的生成对抗网络(GAN)以提升质量、稳定性和变化

本系列博文为深度学习/计算机视觉论文笔记&#xff0c;转载请注明出处 标题&#xff1a;Progressive Growing of GANs for Improved Quality, Stability, and Variation 链接&#xff1a;[1710.10196] Progressive Growing of GANs for Improved Quality, Stability, and Vari…

安防监控视频云存储平台EasyNVR出现内核报错的情况该如何解决?

安防视频监控汇聚EasyNVR视频集中存储平台&#xff0c;是基于RTSP/Onvif协议的安防视频平台&#xff0c;可支持将接入的视频流进行全平台、全终端分发&#xff0c;分发的视频流包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等格式。 近期有用户联系到我们&#xff0c;EasyNVR…

kafka集成篇

kafka的Java客户端 生产者 1.引入依赖 <dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-clients</artifactId><version>2.6.3</version></dependency>2.生产者发送消息的基本实现 /*** 消息的发送⽅*/ …

分类预测 | MATLAB实现DRN深度残差网络多输入分类预测

分类预测 | MATLAB实现DRN深度残差网络多输入分类预测 目录 分类预测 | MATLAB实现DRN深度残差网络多输入分类预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.分类预测 | MATLAB实现DRN深度残差网络多输入分类预测 2.代码说明&#xff1a;MATLAB实现DRN深度残差网络…