使用pytorch构建一个无监督的深度卷积GAN网络模型

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。
因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其他还有不太了解的原理可以返回上一篇进行观看。

本文仍然使用MNIST手写数字数据集来构建一个深度卷积GAN(Deep Convolutional GAN)DCGAN,将使用卷积来替代全连接层,点击查看论文,generator的网络结构图如下:
在这里插入图片描述
DCGAN模型有以下特点:

  1. 判别器模型使用卷积步长取代了空间池化,生成器模型中使用反卷积操作扩大数据维度。
  2. 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其它层上都使用了Batch Normalization,原因是Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  3. 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  4. 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上使用Leaky ReLU激活函数。

代码

model.py:

from torch import nnclass Generator(nn.Module):def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):super(Generator, self).__init__()self.z_dim = z_dim# Build the neural networkself.gen = nn.Sequential(self.make_gen_block(z_dim, hidden_dim * 4),self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),self.make_gen_block(hidden_dim * 2, hidden_dim),self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),)def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True))else: # Final Layerreturn nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh())def unsqueeze_noise(self, noise):return noise.view(len(noise), self.z_dim, 1, 1)    # [b,c,h,w]def forward(self, noise):x = self.unsqueeze_noise(noise)return self.gen(x)class Discriminator(nn.Module):def __init__(self, im_chan=1, hidden_dim=16):super(Discriminator, self).__init__()self.disc = nn.Sequential(self.make_disc_block(im_chan, hidden_dim),self.make_disc_block(hidden_dim, hidden_dim * 2),self.make_disc_block(hidden_dim * 2, 1, final_layer=True),)def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):if not final_layer:return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride),nn.BatchNorm2d(output_channels),nn.LeakyReLU(0.2, inplace=True))else:  # Final Layerreturn nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride))def forward(self, image):disc_pred = self.disc(image)return disc_pred.view(len(disc_pred), -1)

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples, z_dim, device=device)criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 1280
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),
])dataloader = DataLoader(MNIST('.', download=False, transform=transform),batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))def weights_init(m):if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)if isinstance(m, nn.BatchNorm2d):torch.nn.init.normal_(m.weight, 0.0, 0.02)torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)n_epochs = 500
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)real = real.to(device)## Update discriminator ##disc_opt.zero_grad()fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()## Update generator ##gen_opt.zero_grad()fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)fake_2 = gen(fake_noise_2)disc_fake_pred = disc(fake_2)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))gen_loss.backward()gen_opt.step()# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step## Visualization code ##if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

每500个batch展示一次
每500个batch展示一次。
在这里插入图片描述
可以看到生成器的网络模型不再使用全连接,使用反卷积操作扩大数据维度;在输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在隐藏层中每层都使用BN来讲输出归到一定的范围内来稳定学习,使得后层的隐藏单元不过分依赖本层的隐藏单元,减弱内部协变量偏移,从而加速对特征的学习。
因为不再使用全连接而是使用卷积,所以输入的dimension变为channel,所以输入之前先改变noise的shape为(batch_size,channel,high,width)。
在这里插入图片描述
判别器的网络模型使用卷积代替的全连接,使用卷积操作减小数据维度;隐藏层中每层在激活之前使用BN。
在这里插入图片描述
对生成器和鉴别器的权重进行初始化,对于卷积层和转置卷积层(也就是反卷积层)使用正态分布来初始化权重(均值为0,标准差为0.02)的原因是为了确保权重的初始值具有适当的大小,并且不会过大或过小,从而避免梯度消失或梯度爆炸的问题。
对于BN化层,同样使用正态分布来初始化权重,同时将偏置项初始化为0。这是因为批归一化层在训练中通过调整均值和方差来规范化输入数据,因此初始的权重和偏置项都设置为较小的值,有助于加速网络的收敛。

下一篇构建WGAN_GP。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/776507.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python算法100例-4.6 歌星大奖赛

完整源代码项目地址,关注博主私信源代码后可获取 1.问题描述2.问题分析3.算法设计4.确定程序框架5.完整的程序6.问题拓展7.知识点补充 1.问题描述 在歌星大奖赛中,有10个评委为参赛的选手打分,分数为1~100分。选手最…

Spring Cloud+Spring Alibaba笔记

Spring CloudSpring Alibaba 文章目录 Spring CloudSpring AlibabaNacos服务发现配置中心 OpenFeign超时机制开启httpclient5重试机制开启日志 SeataSentinel流量控制熔断降级热点控制规则持久化集成 OpenFeign集成 Gateway MicrometerZipKinGateway路由断言过滤器 Nacos 服务…

TikTok养号保姆级教程:学好这9招,你就是流量宠儿!

01 什么是养号? 在回答这个问题之前,我们首先要明白Tiktok智能推荐机制,也就是著名TikTok算法。 每个创作者发布视频时,TikTok都会根据视频内容、视频发布地区、标题内容进行特征识别,来判断视频可能会有哪些人群喜欢…

鱼哥赠书活动第14期:看完这本《数字化运维》掌握数字化运维方法,构建数字化运维体系

鱼哥赠书活动第14期:看完这本《数字化运维》掌握数字化运维方法,构建数字化运维体系 主要内容:读者对象:赠书抽奖规则:往期赠书福利: 数字化转型已经成为大势所趋,各行各业正朝着数字化方向转型&#xff0c…

SpringBoot3集成PostgreSQL

标签:PostgreSQL.Druid.Mybatis.Plus; 一、简介 PostgreSQL是一个功能强大的开源数据库系统,具有可靠性、稳定性、数据一致性等特点,且可以运行在所有主流操作系统上,包括Linux、Unix、Windows等。 通过官方文档可以…

MySQL数据库高级语句

文章目录 MySQL高级语句older by 排序区间判断查询或与且(or 与and)嵌套查询(多条件)查询不重复记录distinctcount 计数限制结果条目limit别名as常用通配符嵌套查询(子查询)同表不同表嵌套查询还能用于删除…

C语言例4-36:求Fibonacci数列的前40个数

教材优化代码如下&#xff1a; //求Fibonacci数列的前40个数 #include<stdio.h> int main(void) {long int f11,f21;int i1;for(;i<20;i){printf("%15ld%15ld",f1,f2);if(i%20)printf("\n");f1f2;f2f1;}return 0; } 结果如下&#xff1a; 我的基…

IC-随便记

1、移远通信---通信模组 物联网解决方案供应商&#xff0c;可提供完备的IoT产品和服务&#xff0c;涵盖蜂窝模组(5G/4G/3G/2G/LPWA)、车载前装模组、智能模组&#xff08;5G/4G/边缘计算&#xff09;、短距离通信模组(Wi-Fi&BT)、GNSS定位模组、卫星通信模组、天线等硬件产…

Radio Silence for mac 好用的防火墙软件

Radio Silence for Mac是一款功能强大的网络防火墙软件&#xff0c;专为Mac用户设计&#xff0c;旨在保护用户的隐私和网络安全。它具备实时网络监视和控制功能&#xff0c;可以精确显示每个网络连接的状态&#xff0c;让用户轻松掌握网络活动情况。 软件下载&#xff1a;Radio…

DNS 服务 Unbound 部署最佳实践

文章目录 安装unbound-control配置启动服务测试 参考&#xff1a; 官网地址&#xff1a;https://nlnetlabs.nl/projects/unbound/about/ 详细文档&#xff1a;https://unbound.docs.nlnetlabs.nl/en/latest/index.html DNS服务Unbound部署于使用 https://cloud.tencent.com/…

Redis项目实战

本文用用代码演示Redis实现分布式缓存、分布式锁、接口幂等性、接口防刷的功能。 课程地址&#xff1a;Redis实战系列-课程大纲_哔哩哔哩_bilibili 目录 一. 新建springBoot项目整合Redis 二. Redis实现分布式缓存 2.1 原理及好处 2.2 数据准备 2.3 Redis实现分布式缓存…

知行之桥EDI系统功能介绍——FlatFile 端口介绍

FlatFile 端口能够实现平面文件与XML文件的互相转换。 每个 Flat File 端口配置一个特定的平面文件格式&#xff0c;从而实现与 XML 格式的互相转换。Flat File 端口有两个主要的模式&#xff1a; Position DelimitedCharacter Delimited 对于 Position Delimited 平面文件&a…

【Git篇】复习git

文章目录 &#x1f354;什么是git⭐git和svn的区别 &#x1f354;搭建本地仓库&#x1f354;克隆远程仓库&#x1f6f8;git常用命令 &#x1f354;什么是git Git是一种分布式版本控制系统&#xff0c;它可以追踪文件的变化、协调多人在同一个项目上的工作、恢复文件的旧版本等…

在宝塔面板中,为自己的云服务器安装SSL证书,为所搭建的网站启用https(主要部分攻略)

前提条件 My HTTP website is running Nginx on Debian 10&#xff08;或者11&#xff09; 时间&#xff1a;2024-3-28 16:25:52 你的网站部署在Debain 10&#xff08;或者11&#xff09;的 Nginx上 安装单域名证书&#xff08;默认&#xff09;&#xff08;非泛域名&#xf…

现在做抖音小店都需要准备什么?需要什么条件?门槛很高吗?

大家好&#xff0c;我是电商花花。 自从抖音小店这个项目做的人越来越多&#xff0c;很多人都想赶上抖音小店这个红利项目&#xff0c;但是很多新手在刚开始接触这个项目时候因为不懂&#xff0c;开始频频踩雷&#xff0c;不得不关店重新再来。 我们今天汇总了一下抖音小店的…

OSCP靶场--image

OSCP靶场–image 考点(CVE-2023-34152 suid strace提权) 1.nmap扫描 ## ┌──(root㉿kali)-[~/Desktop] └─# nmap -Pn -sC -sV 192.168.178.178 --min-rate 2500 Starting Nmap 7.92 ( https://nmap.org ) at 2024-03-27 23:43 EDT Nmap scan report for 192.168.178.17…

如何在群晖NAS搭建bitwarden密码管理软件并实现无公网IP远程访问

前言 作者简介&#xff1a; 懒大王敲代码&#xff0c;计算机专业应届生 今天给大家聊聊如何在群晖NAS搭建bitwarden密码管理软件并实现无公网IP远程访问&#xff0c;希望大家能觉得实用&#xff01; 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01;&#x1f496;&am…

降分违规?90%新手会遇到的抖音小店运营问题!解决方法快围观!

哈喽~我是电商月月 今天我们聊聊新手开抖音小店会遇到的问题以及解决方法 为了完整性我们从头到尾分析&#xff0c;根据情况不同可自行翻阅 一&#xff0c;入驻和运营时的操作问题 1.营业执照的办理&#xff0c;选择&#xff0c;填写 营业执照的办理可以去当地工商局办理&…

迭代器模式(统一对集合的访问方式)

目录 前言 UML plantuml 类图 实战代码 Iterator ArrayList Client 自定义迭代器 TreeNode TreeUtils Client 前言 在实际开发过程中&#xff0c;常用各种集合来存储业务数据并处理&#xff0c;比如使用 List&#xff0c;Map&#xff0c;Set 等等集合来存储业务数…