生成对抗网络——CGAN(代码+理解)

目录

一、CGAN模型介绍

二、CGAN训练流程

1. 初始化

2. 数据准备

3. 输出模型计算结果

4. 计算损失

5. 反向传播和优化

6. 迭代训练

三、CGAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

3. 训练结果

四、学习中产生的疑问,及文心一言回答

1. torch.cat((self.label_emb(labels.long()), noise), -1) 函数理解

2. Discriminator 模型疑问


一、CGAN模型介绍

        CGAN(Conditional Generative Adversarial Network)模型是一种 深度学习模型,属于生成对抗网络(GAN)的一种 变体。它的 基本思想是通过 训练生成器和判别器 两个网络,使生成器能够生成与给定条件 相匹配的 合成数据,而判别器则 负责区分真实数据和 生成数据。相比于GAN引入了条件信息(y),使得生成器可以生成与给定条件相匹配的合成数据,从而提高了生成数据的可控性和针对性。

二、CGAN训练流程

1. 初始化

        首先,初始化生成器和判别器的网络参数本例未初始化

2. 数据准备

        准备真实数据集和对应的条件信息。条件信息可以是类别标签、文本描述等。

# labels 即真事条件信息
for i, (imgs, labels) in enumerate(dataloader):# gen_labels 即假条件信息
gen_labels = torch.randint(0, opt.n_classes, (batch_size,))

3. 输出模型计算结果

1对于生成器:从随机噪声分布中采样噪声向量,并与条件信息一起输入到生成器中,生成合成数据。

gen_imgs = generator(z, gen_labels)

(2)对于判别器:将真实数据 及其 条件信息 和 生成数据及 其条件信息分别输入到判别器中,得到真实数据 和 生成数据的 判别结果。

validity_fake = discriminator(gen_imgs.detach(), gen_labels)validity_real = discriminator(imgs, labels)

4. 计算损失

1生成器损失:鼓励判别器对生成样本及相应条件c的判断为“真实”,即最大化log(D(G(z|c), c))。

g_loss = adversarial_loss(validity, valid)

2判别器损失:激励判别器正确区分真实样本(X, c)与生成样本(G(z|c), c)

d_loss = (d_real_loss + d_fake_loss) / 2

5. 反向传播和优化

        使用梯度下降算法或其他优化算法更新生成器和判别器的网络参数,以最小化各自的损失函数。

6. 迭代训练

        重复步骤 3至 5,直到达到预设的训练轮数或满足其他停止条件。

三、CGAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
import matplotlib.pyplot as plt
import argparse
import numpy as npparser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)dataloader = torch.utils.data.DataLoader(datasets.MNIST("./others/",train=False,download=False,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)img_shape = (opt.channels, opt.img_size, opt.img_size)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim + opt.n_classes, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),   # np.prod 计算所有元素的乘积nn.Tanh())def forward(self, noise, labels):# 噪声样本与标签的拼接,-1 表示最后一个维度gen_input = torch.cat((self.label_emb(labels.long()), noise), -1)img = self.model(gen_input)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)self.model = nn.Sequential(nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),    # 将输入单元的一部分(本例中为40%)设置为0,有助于 防止过拟合nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1),)def forward(self, img, labels):d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels.long())), -1)validity = self.model(d_in)return validity# 实例化模型
generator = Generator()
discriminator = Discriminator()# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# 均方误差
adversarial_loss = torch.nn.MSELoss()def sample_image(n_row, batches_done):"""Saves a grid of generated digits ranging from 0 to n_classes"""# Sample noisez = torch.randn(n_row ** 2, opt.latent_dim)# Get labels ranging from 0 to n_classes for n rowslabels = torch.Tensor(np.array([num for _ in range(n_row) for num in range(n_row)]))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "./others/images/CGAN/%d.png" % batches_done, nrow=n_row, normalize=True)def gen_img_plot(model, text_input, labels):prediction = np.squeeze(model(text_input, labels).detach().cpu().numpy()[:16])plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((prediction[i] + 1) / 2)plt.axis('off')plt.show()# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):# 初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dataloader)  # 返回批次数for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]valid = torch.ones(batch_size, 1)fake = torch.zeros(batch_size, 1)# 生成随机噪声 和 标签z = torch.randn(batch_size, opt.latent_dim)gen_labels = torch.randint(0, opt.n_classes, (batch_size,))# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()gen_imgs = generator(z, gen_labels)validity_fake = discriminator(gen_imgs.detach(), gen_labels)d_fake_loss = adversarial_loss(validity_fake, fake)validity_real = discriminator(imgs, labels)d_real_loss = adversarial_loss(validity_real, valid)d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()optimizer_D.step()# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()validity = discriminator(gen_imgs, gen_labels)g_loss = adversarial_loss(validity, valid)g_loss.backward()optimizer_G.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# batches_done = epoch * len(dataloader) + i# if batches_done % opt.sample_interval == 0:#     sample_image(n_row=10, batches_done=batches_done)with torch.no_grad():D_epoch_loss += d_lossG_epoch_loss += g_loss# 求平均损失with torch.no_grad():D_epoch_loss /= countG_epoch_loss /= countD_loss_.append(D_epoch_loss.item())G_loss_.append(G_epoch_loss.item())text_input = torch.randn(opt.batch_size, opt.latent_dim)text_labels = torch.randint(0, opt.n_classes, (opt.batch_size,))gen_img_plot(generator, text_input, text_labels)x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss', 'D_loss'])
plt.show()

3. 训练结果

四、学习中产生的疑问,及文心一言回答

1. torch.cat((self.label_emb(labels.long()), noise), -1) 函数理解

2. Discriminator 模型疑问


                          后续更新 GAN 的其他模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/31249.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ShuffleNet系列论文阅读笔记(ShuffleNetV1和ShuffleNetV2)

目录 ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices摘要Approach—方法Channel Shuffle for Group Convolutions—用于分组卷积的通道重排ShuffleNet Unit—ShuffleNet单元Network Architecture—网络体系结构 总结 ShuffleNet V2: Pra…

Vmware与Windows之间复制、粘贴内容、拖拽文件

Vmware17.0Ubuntu20 Vmware正确安装完linux虚拟机之后,这里以Ubuntu为例(其他linux或windows系统也是类似的),如果你使用的默认配置,正常情况下就可以复制、粘贴和拖拽内容的,双方向都是支持的。如果不能复…

nvdiadocker相关配置S3Gaussian

https://download.csdn.net/download/sinat_21699465/89458214 dockerfile文件参考: https://download.csdn.net/download/sinat_21699465/89458214 prework: 显卡驱动决定了cuda版本支持的上限。例如nvdia535驱动最高支持cuda12.2所以显卡驱动版本选…

15.树形虚拟列表实现(支持10000+以上的数据)el-tree(1万+数据页面卡死)

1.问题使用el-tree渲染的树形结构&#xff0c;当数据超过一万条以上的时候页面卡死 2.解决方法&#xff1a; 使用vue-easy-tree来实现树形虚拟列表&#xff0c;注意&#xff1a;vue-easy-tree需要设置高度 3.代码如下 <template><div class"ve-tree" st…

《web应用技术》第12次课后作业

1、了解servlet技术 Servlet(server applet)&#xff1a;运行在服务器的小程序&#xff0c;Servlet就是一个接口&#xff0c;定义了Java类被浏览器访问到的规则。将来我们自定义一个类&#xff0c;实现Servlet接口&#xff0c;复写方法。 Servlet本身不能独立运行&#xff0c…

2024广东省职业技能大赛云计算赛项实战——OpenStack搭建

OpenStack搭建 前言 搭建采用双节点安装&#xff0c;即controller控制节点和compute计算节点。 CentOS7 系统选择 2009 版本&#xff1a;CentOS-7-x86_64-DVD-2009.iso 可从阿里镜像站下载&#xff1a;https://mirrors.aliyun.com/centos/7/isos/x86_64/ OpenStack使用竞赛培…

JaveEE进阶----Spring Web MVC入门

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、什么是 Spring Web MVC&#xff1f;&#xff1f;1.1MVC 定义1.2 什么是Spring MVC ?1.3过浏览器和用户程序交互 二、 RequestMapping 注解三、Postman 前言…

Python中栈的实现与应用

Python中栈的实现与应用 一、引言 栈&#xff08;Stack&#xff09;是一种重要的数据结构&#xff0c;它遵循后进先出&#xff08;LIFO&#xff0c;Last In First Out&#xff09;的原则。栈的基本操作包括入栈&#xff08;push&#xff09;和出栈&#xff08;pop&#xff09…

容器基本概念_从虚拟化技术_到容器化技术_开通青云服务器_并远程连接_容器安装---分布式云原生部署架构搭建007

这一部分,属于以前都会用到的,会快速过一遍,对于关键技术问题会加以说明 https://www.yuque.com/leifengyang/oncloud文档地址在这里,可以看,有些命令可以复制使用 可以看到容器的出现就是 目的就是,让你做的所有的软件,都可以一键部署启动 打包就是docker build 然后: 对于…

关于后端幂等性问题分析与总结

后端幂等性&#xff08;Idempotency&#xff09;是指对系统执行一次操作或多次执行相同的操作&#xff0c;其结果始终如一。在分布式系统和API设计中&#xff0c;这是一个关键概念&#xff0c;因为它能保证用户无论请求被路由到哪个节点&#xff0c;多次执行相同的请求都不会导…

陈晓婚前婚后大变样

陈晓婚前婚后大变样&#xff1f;陈妍希揭秘甜蜜与现实的碰撞在娱乐圈的星光璀璨中&#xff0c;有一对夫妻总是津津乐道&#xff0c;那就是陈晓和陈妍希。他们的爱情故事&#xff0c;从荧幕到现实&#xff0c;一直备受关注。然而&#xff0c;近日陈妍希在节目中透露&#xff0c;…

22、架构-资源与调度

1、资源与调度 调度是容器编排系统最核心的功能之一&#xff0c;“编排”一词本身便包 含“调度”的含义。调度是指为新创建的Pod找到一个最恰当的宿主机 节点来运行它&#xff0c;这个过程成功与否、结果恰当与否&#xff0c;关键取决于容器 编排系统是如何管理与分配集群节点…

Hadoop 面试题(一)

1. 简述Hadoop核心组件 &#xff1f; Hadoop是一个开源的分布式计算平台&#xff0c;其核心组件主要包括以下几个方面&#xff1a; HDFS (Hadoop Distributed File System)&#xff1a; 一个分布式文件系统&#xff0c;用于在廉价的硬件上存储和管理大量数据。 MapReduce&…

Elasticsearch**Elasticsearch自定义插件开发入门

Elasticsearch作为一个强大的搜索引擎和数据分析工具&#xff0c;其强大的扩展性是其受欢迎的重要原因之一。自定义插件开发入门** Elasticsearch作为一个强大的搜索引擎和数据分析工具&#xff0c;其强大的扩展性是其受欢迎的重要原因之一。通过自定义插件&#xff0c;用户可…

QT设计模式:备忘录模式

备忘录模式&#xff08;Memento Pattern&#xff09;是一种行为型设计模式&#xff0c;主要用于保存一个对象当前的状态&#xff0c;并在需要时恢复该状态。它常应用于以下场景&#xff1a; 撤销操作&#xff1a;如文本编辑器撤销、软件开发中的版本控制等&#xff0c;用户可以…

差分总结(一维+二维)

差分&#xff0c;可以视作前缀和的逆运算。 前缀和用于去求一个区间段的和 差分用于改变一个区间的值&#xff08;比如说某个区间都加上或者减去一个数&#xff09; P2367 语文成绩 #include<bits/stdc.h> using namespace std; #define int long long int n,p; int a…

RabbitMQ 学习笔记

RabbitMQ学习笔记 一些概念 Broker &#xff1a;RabbitMQ服务。 virtual host&#xff1a; 其实就是分组。 Connection&#xff1a;连接&#xff0c;生产者消费者与Broker之间的TCP连接。 Channel&#xff1a;网络信道&#xff0c;轻量级的Connection&#xff0c;使用Chann…

2024广东省职业技能大赛云计算赛项实战——Minio服务搭建

Minio服务搭建 前言 这道题是比赛时考到的&#xff0c;没找到具体题目&#xff0c;但在公布的样题中找到了&#xff0c;虽然很短~ 使用提供的 OpenStack 云平台&#xff0c;申请一台云主机&#xff0c;使用提供的软件包安装部署 MINIO 服务并使用 systemctl 管理 Minio是一个…

HTML静态网页成品作业(HTML+CSS)——手机电子商城网页(4个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有4个页面。 二、作品演示 三、代…

Vue 封装组件之Input框

封装Input组件:MyInput.vue <template><div class"base-input-wraper"><el-inputv-bind"$attrs"v-on"$listeners"class"e-input":style"inputStyle":value"value":size"size"input&quo…