生成对抗网络——GAN深度卷积实现(代码+理解)

        本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。

  生成对抗网络—GAN(代码+理解)

http://t.csdnimg.cn/HDfLOicon-default.png?t=N7T8http://t.csdnimg.cn/HDfLO


目录

一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

2. 模型训练时

3. 优化器定义

4. 训练数据

5. 模型结构

(1)生成器        

(2)判别器


一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

import torch
import torch.nn as nn
import argparse
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as npparser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)# 加载数据
dataloader = torch.utils.data.DataLoader(datasets.MNIST("./others/",train=False,download=False,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02) # 给定均值和标准差的正态分布N(mean,std)中生成值torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.init_size = opt.img_size // 4  # 原为28*28,现为32*32,两边各多了2self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),    # 调整数据的分布,使其 更适合于 下一层的 激活函数或学习nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, z):out = self.l1(z)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.model = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样(图片进行 4次卷积操作,变为ds_size * ds_size尺寸大小)ds_size = opt.img_size // 2 ** 4self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1),nn.Sigmoid())def forward(self, img):out = self.model(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)return validity# 实例化
generator = Generator()
discriminator = Discriminator()# 初始化参数
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# 交叉熵损失函数
adversarial_loss = torch.nn.BCELoss()def gen_img_plot(model, epoch, text_input):prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16])plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((prediction[i] + 1) / 2)plt.axis('off')plt.show()# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):# 初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dataloader)  # 返回批次数for i, (imgs, _) in enumerate(dataloader):valid = torch.ones(imgs.shape[0], 1)fake = torch.zeros(imgs.shape[0], 1)# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()z = torch.randn(imgs.shape[0], opt.latent_dim)gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# batches_done = epoch * len(dataloader) + i# if batches_done % opt.sample_interval == 0:#     save_image(gen_imgs.data[:25], "others/images/%d.png" % batches_done, nrow=5, normalize=True)# 累计每一个批次的losswith torch.no_grad():D_epoch_loss += d_lossG_epoch_loss += g_loss# 求平均损失with torch.no_grad():D_epoch_loss /= countG_epoch_loss /= countD_loss_.append(D_epoch_loss.item())G_loss_.append(G_epoch_loss.item())text_input = torch.randn(opt.batch_size, opt.latent_dim)gen_img_plot(generator, epoch, text_input)x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss','D_loss'])
plt.show()

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

        函数 weights_init_normal 用于初始化 模型参数,为什么要 以 均值和标准差 的正态分布中采样的数 为标准?

2. 模型训练时

        这里“d_loss = (real_loss + fake_loss) / 2” 中的 “/ 2” 操作,在 实际训练中 有什么作用?

        由(real_loss + fake_loss) / 2的 得到 的 d_loss 与(real_loss+fake_loss)得到的 d_loss 进行 回溯,两者结果会 有什么不同吗?

3. 优化器定义

        设置 betas=(opt.b1, opt.b2) 有什么 实际的作用?通俗易懂的讲一下

        betas=(opt.b1, opt.b2) 是怎样 更新学习率的?

4. 训练数据

        这里我们用的data为 MNIST,为什么img_size设置为 32,不是 28?

5. 模型结构

(1)生成器        

        解释一下为什么是“Upsample, Conv2d, BatchNorm2d, LeakyReLU ”这种顺序?

(2)判别器

        模型的 基本 运算步骤是什么?其中为什么需要 “Dropout2d( p=0.25, inplace=False)”这一步?

        关于“ds_size” 和 “128 * ds_size ** 2”的实际意义?


                                后续更新 GAN的其他模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/855049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

The First项目报告:深度解读Layer 2生态zkSync

zkSync发币了,这个无数撸毛党心心念念数年之久的项目终于要来了,zkSync 是由Matter Labs 于2019 年推出的以太坊Layer 2 扩容解决方案,作为L2龙头项目之一,与其同属一个层次的L2四大天王之三Optimism、Arbitrum、zkSync、StarkNet…

Profibus协议转Modbus协议网关模块帮助PLC实现智能激光设备通讯

一、前言 Profibus转Modbus网关(XD-MDPB100)是一种工业通信协议转换设备,用于实现Profibus协议与Modbus协议之间的转换。Profibus转Modbus网关在工业自动化系统中具有广泛的应用,它解决了不同协议设备之间的通信问题。本文将深入…

怎么样判断真假单北斗

国产化替代正在中国各行各业逐步提升中,特别涉及重点产业——国家安全! 只有仅支持B1I和B3信号的芯片才是真正的单北斗芯片。但凡你支持了B1C、B2a、B2b中的一个就是假的单北斗。 B1C/L1/E1、B2a/ L5/E5a、B2b/G3/E5b这些频点与其他GNSS系统是完全重合的…

湖北科技学院2024年成人高等继续教育招生简章

湖北科技学院,这所坐落在荆楚大地的高等学府,一直以来都是培养各类专业人才的重要基地。随着社会的快速发展,终身学习的理念深入人心,成人高等继续教育作为满足广大成年人提升学历、增强职业技能的重要途径,受到了越来…

Java输入输出语句 和 保留字

目录 键盘输入语句 保留字 键盘输入语句 Input.java , 需要一个 扫描器(对象), 就是Scanner 步骤 : 导入该类的所在包, java.util.*创建该类对象(声明变量)调用里面的功能 案例要求:可以从控制台接收用户信息,【姓…

润滑不良:滚珠花键磨损的隐形杀手!

滚珠花键作为一种精密机械传动元件,被广泛应用于各种机器和设备中,起着传递动力和运动的重要作用。滚珠花键经过长时间的运行,难免会多少些磨损,严重的话还会导致设备不能正常运转。那么,如何保证它的正常运行呢&#…

88. 合并两个有序数组(简单)

88. 合并两个有序数组 1. 题目描述2.详细题解3.代码实现3.1 Python3.2 Java 1. 题目描述 题目中转:88. 合并两个有序数组 2.详细题解 两个数组均有序(非递减),要求合并两个数组,直观的思路,借助第三个数…

【Linux环境下Hadoop部署】—报错“Unit ntpd.service could not be found.“

项目场景: 执行 “systemctl status ntpd” 命令。 问题描述 报错:Unit ntpd.service could not be found. 原因分析: 没有安装ntp 解决方案: 执行 “yum install ntp” 命令,再次执行 “systemctl status ntpd” 命令…

Docker部署私有仓库Harbor

Harbor构建Docker私有仓库 文章目录 Harbor构建Docker私有仓库资源列表一、部署Docker-Compose服务1.1、下载最新Docker-Compose1.2、查看Docker-Compose版本 二、部署Harbor服务2.1、下载Harbor安装程序2.2、配置Harbor参数文件2.3、所需参数和可选参数2.3.1、所需参数2.3.2、…

CP AUTOSAR标准之MemoryDriver(AUTOSAR_CP_SWS_MemoryDriver)

1 简介和功能概述 该规范描述了AUTOSAR基础软件模块内存驱动程序(Mem)的功能、API和配置。   内存驱动程序提供访问不同类型内存设备的基本服务,如读取、写入、擦除和空白检查。   尽管闪存仍然是最常见的非易失性存储器技术,但内存驱动程序规范考虑了所有相关的内存设备…

虚拟警示教育馆如何革新安全教育?揭秘其深远意义与实际优势

一、推动警示教育的创新与普及 虚拟警示教育馆是将传统警示教育与现代科技相结合的新型教育模式。其意义主要体现在以下几个方面: 1、增强教育的互动性和沉浸感:虚拟警示教育馆通过3D建模、VR等技术,创建逼真的警示场景。这种身临其境的体验能…

(资料收藏)王阳明传《知行合一》共74讲,王阳明知行合一音频讲解资料

今天给大家带来的不是软件,而是一份精神食粮——《知行合一》的教程福利。这可不是一般的教程,它关乎心灵,关乎智慧,关乎我们如何在纷繁复杂的世界中找到自己的位置。 咱们得聊聊王阳明,这位明代的大儒,他…

餐饮业应该购置精酿啤酒设备吗?

近几年,啤酒行业刮起了一股“精酿风”,它不只是一种饮品口味上的变化,更像是一个生活方式的升级。精酿啤酒的兴起,不仅体现在味道的多样性和层次感上,更重要的是它代表了一种生活态度,是对品质生活的追求。…

可复用验证的测试用例 5大编写技巧

编写可复用验证的测试用例,节省了编写新测试用例的时间和资源,提高了测试效率和项目质量,减少错误修复成本,有利于实现较高的投入产出比。缺乏可复用的测试用例会导致测试团队不断重复创建相似的测试场景,消耗大量时间…

每日一练:攻防世界:Ditf

这是难度1的题吗??? 拿到一个png图片,第一反应就是CRC爆破,结果还真的是高度被修改了 这里拿到一个字符串,提交flag结果发现不是,那么只可能是密钥之类的了 看看有没有压缩包,搜索…

IMU应用于体操训练

考虑到在艺术体操训练与竞赛中艺术体操的训练与比赛中,地板项目导致的伤率最高,最近,一个来自澳大利亚的科研团队利用IMU评估运动员执行基础翻腾技巧训练时,他们上肢与下肢所承受的冲击负荷。 本次实验共有十四名艺术体操运动员参…

(五)React受控表单、获取DOM

1. React受控表单 概念&#xff1a;使用React组件的状态&#xff08;useState&#xff09;控制表单的状态 准备一个React状态值 const [value, setValue] useState()通过value属性绑定状态&#xff0c;通过onChange属性绑定状态同步的函数 <input type"text"…

计算机网络期末复习1(最后一天才开始学版)

1.一个PPP帧的数据部分&#xff08;用十六进制写出&#xff09;是7D 5E FE 27 7D 5D 7D 5D 65 7D 5E。试问真正的数据是&#xff08;用十六进制写出&#xff09; 由于PPP帧的标志字段为7E,因此,为了区别标志字段和信息字段,将信息字段中出现的每一个0x7E转变成(0x7D,0x5E),0x7…

PIL保存后的图像莫名的失真,部分不失真部分很失真

原图片是这样的&#xff1a; PIL会自行**“自救”被正则化的图片&#xff0c;导致自救过曝&#xff0c;部分颜色非常失真&#xff0c;但是部分又保存的还行。现象如下&#xff1a; 这里你检查一下你保存的是不是被正则化的图片**&#xff0c;如果是&#xff0c;改改。 查看一…

【vue大作业-端午节主题网站】【预览展示视频和详细文档】

vue大作业-端午节主题网站介绍 端午节&#xff0c;又称为龙舟节&#xff0c;是中国的传统节日之一&#xff0c;每年农历五月初五庆祝。这个节日不仅是纪念古代爱国诗人屈原的日子&#xff0c;也是家人团聚、共享美食的时刻。今天&#xff0c;我们非常高兴地分享一个以端午节为…