生成对抗网络——CGAN(代码+理解)

目录

一、CGAN模型介绍

二、CGAN训练流程

1. 初始化

2. 数据准备

3. 输出模型计算结果

4. 计算损失

5. 反向传播和优化

6. 迭代训练

三、CGAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

3. 训练结果

四、学习中产生的疑问,及文心一言回答

1. torch.cat((self.label_emb(labels.long()), noise), -1) 函数理解

2. Discriminator 模型疑问


一、CGAN模型介绍

        CGAN(Conditional Generative Adversarial Network)模型是一种 深度学习模型,属于生成对抗网络(GAN)的一种 变体。它的 基本思想是通过 训练生成器和判别器 两个网络,使生成器能够生成与给定条件 相匹配的 合成数据,而判别器则 负责区分真实数据和 生成数据。相比于GAN引入了条件信息(y),使得生成器可以生成与给定条件相匹配的合成数据,从而提高了生成数据的可控性和针对性。

二、CGAN训练流程

1. 初始化

        首先,初始化生成器和判别器的网络参数本例未初始化

2. 数据准备

        准备真实数据集和对应的条件信息。条件信息可以是类别标签、文本描述等。

# labels 即真事条件信息
for i, (imgs, labels) in enumerate(dataloader):# gen_labels 即假条件信息
gen_labels = torch.randint(0, opt.n_classes, (batch_size,))

3. 输出模型计算结果

1对于生成器:从随机噪声分布中采样噪声向量,并与条件信息一起输入到生成器中,生成合成数据。

gen_imgs = generator(z, gen_labels)

(2)对于判别器:将真实数据 及其 条件信息 和 生成数据及 其条件信息分别输入到判别器中,得到真实数据 和 生成数据的 判别结果。

validity_fake = discriminator(gen_imgs.detach(), gen_labels)validity_real = discriminator(imgs, labels)

4. 计算损失

1生成器损失:鼓励判别器对生成样本及相应条件c的判断为“真实”,即最大化log(D(G(z|c), c))。

g_loss = adversarial_loss(validity, valid)

2判别器损失:激励判别器正确区分真实样本(X, c)与生成样本(G(z|c), c)

d_loss = (d_real_loss + d_fake_loss) / 2

5. 反向传播和优化

        使用梯度下降算法或其他优化算法更新生成器和判别器的网络参数,以最小化各自的损失函数。

6. 迭代训练

        重复步骤 3至 5,直到达到预设的训练轮数或满足其他停止条件。

三、CGAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
import matplotlib.pyplot as plt
import argparse
import numpy as npparser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)dataloader = torch.utils.data.DataLoader(datasets.MNIST("./others/",train=False,download=False,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)img_shape = (opt.channels, opt.img_size, opt.img_size)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim + opt.n_classes, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),   # np.prod 计算所有元素的乘积nn.Tanh())def forward(self, noise, labels):# 噪声样本与标签的拼接,-1 表示最后一个维度gen_input = torch.cat((self.label_emb(labels.long()), noise), -1)img = self.model(gen_input)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)self.model = nn.Sequential(nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),    # 将输入单元的一部分(本例中为40%)设置为0,有助于 防止过拟合nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 512),nn.Dropout(0.4),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 1),)def forward(self, img, labels):d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels.long())), -1)validity = self.model(d_in)return validity# 实例化模型
generator = Generator()
discriminator = Discriminator()# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# 均方误差
adversarial_loss = torch.nn.MSELoss()def sample_image(n_row, batches_done):"""Saves a grid of generated digits ranging from 0 to n_classes"""# Sample noisez = torch.randn(n_row ** 2, opt.latent_dim)# Get labels ranging from 0 to n_classes for n rowslabels = torch.Tensor(np.array([num for _ in range(n_row) for num in range(n_row)]))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "./others/images/CGAN/%d.png" % batches_done, nrow=n_row, normalize=True)def gen_img_plot(model, text_input, labels):prediction = np.squeeze(model(text_input, labels).detach().cpu().numpy()[:16])plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((prediction[i] + 1) / 2)plt.axis('off')plt.show()# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):# 初始化损失值D_epoch_loss = 0G_epoch_loss = 0count = len(dataloader)  # 返回批次数for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]valid = torch.ones(batch_size, 1)fake = torch.zeros(batch_size, 1)# 生成随机噪声 和 标签z = torch.randn(batch_size, opt.latent_dim)gen_labels = torch.randint(0, opt.n_classes, (batch_size,))# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()gen_imgs = generator(z, gen_labels)validity_fake = discriminator(gen_imgs.detach(), gen_labels)d_fake_loss = adversarial_loss(validity_fake, fake)validity_real = discriminator(imgs, labels)d_real_loss = adversarial_loss(validity_real, valid)d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()optimizer_D.step()# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()validity = discriminator(gen_imgs, gen_labels)g_loss = adversarial_loss(validity, valid)g_loss.backward()optimizer_G.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# batches_done = epoch * len(dataloader) + i# if batches_done % opt.sample_interval == 0:#     sample_image(n_row=10, batches_done=batches_done)with torch.no_grad():D_epoch_loss += d_lossG_epoch_loss += g_loss# 求平均损失with torch.no_grad():D_epoch_loss /= countG_epoch_loss /= countD_loss_.append(D_epoch_loss.item())G_loss_.append(G_epoch_loss.item())text_input = torch.randn(opt.batch_size, opt.latent_dim)text_labels = torch.randint(0, opt.n_classes, (opt.batch_size,))gen_img_plot(generator, text_input, text_labels)x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss', 'D_loss'])
plt.show()

3. 训练结果

四、学习中产生的疑问,及文心一言回答

1. torch.cat((self.label_emb(labels.long()), noise), -1) 函数理解

2. Discriminator 模型疑问


                          后续更新 GAN 的其他模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/31249.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ShuffleNet系列论文阅读笔记(ShuffleNetV1和ShuffleNetV2)

目录 ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices摘要Approach—方法Channel Shuffle for Group Convolutions—用于分组卷积的通道重排ShuffleNet Unit—ShuffleNet单元Network Architecture—网络体系结构 总结 ShuffleNet V2: Pra…

Vmware与Windows之间复制、粘贴内容、拖拽文件

Vmware17.0Ubuntu20 Vmware正确安装完linux虚拟机之后,这里以Ubuntu为例(其他linux或windows系统也是类似的),如果你使用的默认配置,正常情况下就可以复制、粘贴和拖拽内容的,双方向都是支持的。如果不能复…

nvdiadocker相关配置S3Gaussian

https://download.csdn.net/download/sinat_21699465/89458214 dockerfile文件参考: https://download.csdn.net/download/sinat_21699465/89458214 prework: 显卡驱动决定了cuda版本支持的上限。例如nvdia535驱动最高支持cuda12.2所以显卡驱动版本选…

15.树形虚拟列表实现(支持10000+以上的数据)el-tree(1万+数据页面卡死)

1.问题使用el-tree渲染的树形结构&#xff0c;当数据超过一万条以上的时候页面卡死 2.解决方法&#xff1a; 使用vue-easy-tree来实现树形虚拟列表&#xff0c;注意&#xff1a;vue-easy-tree需要设置高度 3.代码如下 <template><div class"ve-tree" st…

2024广东省职业技能大赛云计算赛项实战——OpenStack搭建

OpenStack搭建 前言 搭建采用双节点安装&#xff0c;即controller控制节点和compute计算节点。 CentOS7 系统选择 2009 版本&#xff1a;CentOS-7-x86_64-DVD-2009.iso 可从阿里镜像站下载&#xff1a;https://mirrors.aliyun.com/centos/7/isos/x86_64/ OpenStack使用竞赛培…

JaveEE进阶----Spring Web MVC入门

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、什么是 Spring Web MVC&#xff1f;&#xff1f;1.1MVC 定义1.2 什么是Spring MVC ?1.3过浏览器和用户程序交互 二、 RequestMapping 注解三、Postman 前言…

容器基本概念_从虚拟化技术_到容器化技术_开通青云服务器_并远程连接_容器安装---分布式云原生部署架构搭建007

这一部分,属于以前都会用到的,会快速过一遍,对于关键技术问题会加以说明 https://www.yuque.com/leifengyang/oncloud文档地址在这里,可以看,有些命令可以复制使用 可以看到容器的出现就是 目的就是,让你做的所有的软件,都可以一键部署启动 打包就是docker build 然后: 对于…

陈晓婚前婚后大变样

陈晓婚前婚后大变样&#xff1f;陈妍希揭秘甜蜜与现实的碰撞在娱乐圈的星光璀璨中&#xff0c;有一对夫妻总是津津乐道&#xff0c;那就是陈晓和陈妍希。他们的爱情故事&#xff0c;从荧幕到现实&#xff0c;一直备受关注。然而&#xff0c;近日陈妍希在节目中透露&#xff0c;…

差分总结(一维+二维)

差分&#xff0c;可以视作前缀和的逆运算。 前缀和用于去求一个区间段的和 差分用于改变一个区间的值&#xff08;比如说某个区间都加上或者减去一个数&#xff09; P2367 语文成绩 #include<bits/stdc.h> using namespace std; #define int long long int n,p; int a…

RabbitMQ 学习笔记

RabbitMQ学习笔记 一些概念 Broker &#xff1a;RabbitMQ服务。 virtual host&#xff1a; 其实就是分组。 Connection&#xff1a;连接&#xff0c;生产者消费者与Broker之间的TCP连接。 Channel&#xff1a;网络信道&#xff0c;轻量级的Connection&#xff0c;使用Chann…

2024广东省职业技能大赛云计算赛项实战——Minio服务搭建

Minio服务搭建 前言 这道题是比赛时考到的&#xff0c;没找到具体题目&#xff0c;但在公布的样题中找到了&#xff0c;虽然很短~ 使用提供的 OpenStack 云平台&#xff0c;申请一台云主机&#xff0c;使用提供的软件包安装部署 MINIO 服务并使用 systemctl 管理 Minio是一个…

HTML静态网页成品作业(HTML+CSS)——手机电子商城网页(4个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有4个页面。 二、作品演示 三、代…

python API自动化(Pytest+Excel+Allure完整框架集成+yaml入门+大量响应报文处理及加解密、签名处理)

1.pytest数据参数化 假设你需要测试一个登录功能&#xff0c;输入用户名和密码后验证登录结果。可以使用参数化实现多组输入数据的测试: 测试正确的用户名和密码登录成功 测试正确的用户名和错误的密码登录失败 测试错误的用户名和正确的密码登录失败 测试错误的用户名和密码登…

定时器-前端使用定时器3s轮询状态接口,2min为接口超时

背景 众所周知&#xff0c;后端是处理不了复杂的任务的&#xff0c;所以经过人家的技术讨论之后&#xff0c;把业务放在前端来实现。记录一下这次的离大谱需求吧。 如图所示&#xff0c;这个页面有5个列表&#xff0c;默认加载计划列表。但是由于后端的种种原因&#xff0c;这…

C++ | Leetcode C++题解之第171题Excel表列序号

题目&#xff1a; 题解&#xff1a; class Solution { public:int titleToNumber(string columnTitle) {int number 0;long multiple 1;for (int i columnTitle.size() - 1; i > 0; i--) {int k columnTitle[i] - A 1;number k * multiple;multiple * 26;}return num…

QT中利用QMovie实现动态加载效果

1、效果 2、代码 #include "widget.h" #include "ui_widget.h" #include <QLabel> #include <QMovie>

YOLOv10训练自己的数据集(图像目标检测)

目录 1、下载代码 2、环境配置 3、准备数据集 4、yolov10训练 可能会出现报错&#xff1a; 1、下载代码 源码地址&#xff1a;https://github.com/THU-MIG/yolov10 2、环境配置 打开源代码&#xff0c;在Terminal中&#xff0c;使用conda 创建虚拟环境配置 命令如下&a…

Python基础教程(二十五):内置函数整理

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; &#x1f49d;&#x1f49…

将AI带入企业,红帽选择了开源

伴随着生成式AI与大模型技术的飞速发展&#xff0c;业界人士对于生成式AI应用在企业的落地也愈发关注。 近日在2024红帽媒体Open讲上&#xff0c;红帽全球副总裁兼大中华区总裁曹衡康深入剖析了AI在混合云中的应用及其带来的资源利用最大化优势&#xff0c;并同与会媒体共同探讨…

Redis-数据类型-List

文章目录 1、通过客户端连接redis2、切换到第二个数据库 db13、查看当前库所有key4、从左边插入一个或多个值5、按照索引下标获得元素(从左到右)6、针对key指定的list&#xff0c;从右边放入元素7、返回list集合的长度8、从左边弹出一个元素。弹出返回删除9、从右边弹出一个元素…