PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(# Input: N x channels_img x 64 x 64nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # 32 x 32nn.LeakyReLU(0.2),self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1nn.Sigmoid(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.gen = nn.Sequential(# Input: N x z_dim x 1 x 1self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1,),nn.Tanh(),)def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False,),nn.BatchNorm2d(out_channels),nn.ReLU(),)def forward(self, x):return self.gen(x)def initialize_weights(model):for m in model.modules():if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):nn.init.normal_(m.weight.data, 0.0, 0.02)def test():N, in_channels, H, W = 8, 3, 64, 64z_dim = 100x = torch.randn((N, in_channels, H, W))disc = Discriminator(in_channels, 8)initialize_weights(disc)assert disc(x).shape == (N, 1, 1, 1)gen = Generator(z_dim, in_channels, 8)initialize_weights(gen)z = torch.randn((N, z_dim, 1, 1))assert gen(z).shape == (N, in_channels, H, W)print("success")if __name__ == "__main__":test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64transforms = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize([0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),]
)# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )# comment mnist above and uncomment below if train on CelebA# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0gen.train()
disc.train()for epoch in range(NUM_EPOCHS):# Target labels not needed! <3 unsupervisedfor batch_idx, (real, _) in enumerate(dataloader):real = real.to(device)noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()# Print losses occasionally and print to tensorboardif batch_idx % 100 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)# take out (up to) 32 examplesimg_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)writer_real.add_image("Real", img_grid_real, global_step=step)writer_fake.add_image("Fake", img_grid_fake, global_step=step)step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VSCode 使用总结

快捷键 在 Visual Studio Code (VSCode) 中&#xff0c;有许多常用的快捷键可以提高编程效率。以下是一些常见的 VSCode 编程项目快捷键&#xff1a; 编辑器操作&#xff1a; 撤销&#xff1a;Ctrl Z重做&#xff1a;Ctrl Shift Z复制&#xff1a;Ctrl C剪切&#xff1a;C…

Electron入门,项目启动。

electron 简单介绍&#xff1a; 实现&#xff1a;HTML/CSS/JS桌面程序&#xff0c;搭建跨平台桌面应用。 electron 官方文档&#xff1a; [https://electronjs.org/docs] 本文是基于以下2篇文章且自行实践过的&#xff0c;可行性真实有效。 文章1&#xff1a; https://www.cnbl…

题解 | #1005.List Reshape# 2023杭电暑期多校9

1005.List Reshape 签到题 题目大意 按一定格式给定一个纯数字一维数组&#xff0c;按给定格式输出成二维数组。 解题思路 读入初始数组字符串&#xff0c;将每个数字分离&#xff0c;按要求输出即可 参考代码 参考代码为已AC代码主干&#xff0c;其中部分功能需读者自行…

学习Vue:组件通信

组件化开发在现代前端开发中是一种关键的方法&#xff0c;它能够将复杂的应用程序拆分为更小、更可管理的独立组件。在Vue.js中&#xff0c;父子组件通信是组件化开发中的重要概念&#xff0c;同时我们还会讨论其他组件间通信的方式。 父子组件通信&#xff1a;Props 和 Events…

TDSQL赤兔管理台无管理员用户密码解决方案

解决方案 问题描述&#xff1a; tdsql使用过程中&#xff0c;可能会遇到控制台用户密码忘记的情况&#xff0c;用户登录次数过多被锁的情况&#xff0c;没有管理员的用户密码又急需某些权限的情况。 解决过程&#xff1a; 获取配置库信息&#xff1a; 在浏览器上打开如下命…

基于Javaweb的摄影作品网站/摄影网站

摘 要 随着信息化时代的到来&#xff0c;系统管理都趋向于智能化、系统化&#xff0c;摄影作品网站也不例外&#xff0c;但目前国内的有些网站仍然都使用人工管理&#xff0c;浏览网站人数越来越多&#xff0c;同时信息量也越来越庞大&#xff0c;人工管理显然已无法应对时代的…

AMD fTPM RNG的BUG使得Linus Torvalds不满

导读因为在 Ryzen 系统上对内核造成了困扰&#xff0c;Linus Torvalds 最近在邮件列表中表达了对 AMD fTPM 硬件随机数生成器的不满&#xff0c;并提出了禁用该功能的建议。 因为在 Ryzen 系统上对内核造成了困扰&#xff0c;Linus Torvalds 最近在邮件列表中表达了对 AMD fTPM…

【【Verilog典型电路设计之FIFO设计】】

典型电路设计之FIFO设计 FIFO (First In First Out&#xff09;是一种先进先出的数据缓存器&#xff0c;通常用于接口电路的数据缓存。与普通存储器的区别是没有外部读写地址线&#xff0c;可以使用两个时钟分别进行写和读操作。FIFO只能顺序写入数据和顺序读出数据&#xff0…

ThinkPHP中实现IP地址定位

在网站开发中&#xff0c;我们经常需要获取用户的地理位置信息以提供个性化的服务。一种常见的方法是通过IP地址定位。在本文中&#xff0c;我们将介绍如何在ThinkPHP框架中实现IP地址定位。 一、IP地址定位的基本原理 IP地址是Internet上的设备在网络中的标识符。每个设备都有…

【从0开始学架构笔记】01 基础架构

文章目录 一、架构的定义1. 系统与子系统2. 模块与组件3. 框架与架构4. 重新定义架构 二、架构设计的目的三、复杂度来源&#xff1a;高性能1. 单机复杂度2. 集群复杂度2.1 任务分配2.2 任务分解&#xff08;微服务&#xff09; 四、复杂度来源&#xff1a;高可用1. 计算高可用…

GitKraken保姆级图文使用指南

前言 写这篇文章的原因是组内的产品和美术同学&#xff0c;开始参与到git工作流中&#xff0c;但是网上又没有找到一个比较详细的使用教程&#xff0c;所以干脆就自己写了一个[doge]。文章的内容比较基础&#xff0c;介绍了Git内的一些基础概念和基本操作&#xff0c;适合零基…

合并多个文本文件

使用 wxPython 模块合并多个文本文件的博客。以下是一篇示例博客&#xff1a; C:\pythoncode\blog\txtmerge.py 在 Python 编程中&#xff0c;我们经常需要处理文本文件。有时候&#xff0c;我们可能需要将多个文本文件合并成一个文件&#xff0c;以便进行进一步的处理或分析。…

QT报表Limereport v1.5.35编译及使用

1、编译说明 下载后QT CREATER中打开limereport.pro然后直接编译就可以了。编译后结果如下图&#xff1a; 一次编译可以得到库文件和DEMO执行程序。 2、使用说明 拷贝如下图编译后的lib目录到自己的工程目录中。 release版本的重新命名为librelease. PRO文件中配置 QT …

openpose姿态估计【学习笔记】

文章目录 1、人体需要检测的关键点2、Top-down方法3、Openpose3.1 姿态估计的步骤3.2 PAF&#xff08;Part Affinity Fields&#xff09;部分亲和场3.3 制作PAF标签3.4 PAF权值计算3.5 匹配方法 4、CPM&#xff08;Convolutional Pose Machines&#xff09;模型5、Openpose5.1 …

怎么修改图片的分辨率?

怎么修改图片的分辨率&#xff1f;很多人还不知道分辨率是什么意思&#xff0c;以为代表了图片的清晰度&#xff0c;然而并不是这样的&#xff0c;其实图片的分辨率就是图片尺寸大小的意思。修改图片的分辨率即改变图片的尺寸&#xff0c;通常以像素为单位表示。分辨率决定了图…

【linux基础(四)】对Linux权限的理解

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Linux从入门到开通⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学更多操作系统知识   &#x1f51d;&#x1f51d; Linux权限 1. 前言2. shell命…

八、Linux下,grep/wc/管道符/echo/重定向符/tail如何使用?

1、grep命令 &#xff08;1&#xff09;主要用于文件 &#xff08;2&#xff09;主要作用是“通过关键字&#xff0c;过滤文件行” &#xff08;3&#xff09;示例&#xff1a; 2、wc命令 &#xff08;1&#xff09;统计文件的行数、单词数等 &#xff08;2&#xff09;示例…

react之路由的安装与使用

一、路由安装 路由官网2021.11月初&#xff0c;react-router 更新到 v6 版本。使用最广泛的 v5 版本的使用 npm i react-router-dom5.3.0二、路由使用 2.1 路由的简单使用 第一步 在根目录下 创建 views 文件夹 ,用于放置路由页面 films.js示例代码 export default functio…

一文预览 | 8 月 16 日 NVIDIA 在 WAVE SUMMIT深度学习开发者大会 2023精彩亮点抢先看!

由深度学习技术及应用国家工程研究中心主办&#xff0c;百度飞桨和文心大模型承办的 WAVE SUMMIT深度学习开发者大会2023&#xff0c;将于 8 月 16 日在北京与大家见面。NVIDIA 作为技术合作伙伴&#xff0c;将携手百度飞桨参与这场技术盛会。 在这次大会中&#xff0c;NVIDIA…

Java 项目日志实例基础:Log4j

点击下方关注我&#xff0c;然后右上角点击...“设为星标”&#xff0c;就能第一时间收到更新推送啦~~~ 介绍几个日志使用方面的基础知识。 1 Log4j 1、Log4j 介绍 Log4j&#xff08;log for java&#xff09;是 Apache 的一个开源项目&#xff0c;通过使用 Log4j&#xff0c;我…