CGAN原理讲解与源码

1.CGAN原理

生成器,输入的是c和z,z是随机噪声,c是条件,对应MNIST数据集,要求规定生成数字是几。
输出是生成的虚假图片。
在这里插入图片描述

判别器的输入是
1.生成器输出的虚假图片x;
2.对应图片的标签c

在这里插入图片描述
来自真实数据集,且标签是对的,就是1
如果是生成器生成的虚假照片就直接是1,都不需要看是否与标签对应

上面第二张图的意思就是,当图片是来自真实数据集,再来看是否与标签对应

2.CGAN损失函数

在这里插入图片描述
上面这个值,生成器越小越好,即判别器认为真实图片是真实图片的概率越低越好,认为虚假图片是真实图片的概率越高越好
判别器越大越好,即判别器认为真实图片是真实图片的概率越大越好,认为虚假图片是真实图片的概率越小越好

criterion(output, label)

在判别器中,
1)output是预测来自真实数据集的图片和标签是否是真实且符合标签的概率,label是1
2)output是预测虚假图片是否是虚假图片的概率,label是0
在生成器中,
output是判别器预测虚假图片是否是真实图片的概率,label是1
以上三种,都是交叉熵越小越好

3.生成器和判别器的源码

class Generator(nn.Module):def __init__(self, num_channel=1, nz=100, nc=10, ngf=64):super(Generator, self).__init__()self.main = nn.Sequential(# 输入维度 110 x 1 x 1nn.ConvTranspose2d(nz + nc, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# 特征维度 (ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 特征维度 (ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 特征维度 (ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 特征维度 (ngf) x 32 x 32nn.ConvTranspose2d(ngf, num_channel, 4, 2, 1, bias=False),nn.Tanh()# 特征维度. (num_channel) x 64 x 64)self.apply(weights_init)def forward(self, input_z, onehot_label):input_ = torch.cat((input_z, onehot_label), dim=1)n, c = input_.size()input_ = input_.view(n, c, 1, 1)return self.main(input_)class Discriminator(nn.Module):def __init__(self, num_channel=1, nc=10, ndf=64):super(Discriminator, self).__init__()self.main = nn.Sequential(# 输入维度 (num_c3# channel+nc) x 64 x 64  1*64*64的图像和10维的类别   10维类别先转换成10*64*64    然后合并就是11*64*64# 输入通道  输出通道   卷积核的大小   步长  填充#原始输入张量:b 11 64  64nn.Conv2d(num_channel + nc, ndf, 4, 2, 1, bias=False),   #b  64  32  32nn.LeakyReLU(0.2, inplace=True),# 特征维度 (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  #b   64*2   16  16nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# 特征维度 (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),    #b   64*4   8    8nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# 特征维度 (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),    #b   64*8    4    4nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# 特征维度 (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        #b   1    1    1      其实就是一个数值,区间在正无穷到负无穷之间nn.Sigmoid())self.apply(weights_init)def forward(self, images, onehot_label):device = 'cuda' if torch.cuda.is_available() else 'cpu'h, w = images.shape[2:]n, nc = onehot_label.shape[:2]label = onehot_label.view(n, nc, 1, 1) * torch.ones([n, nc, h, w]).to(device)input_ = torch.cat([images, label], 1)return self.main(input_)

4.训练过程

MODEL_G_PATH = "./"
LOG_G_PATH = "Log_G.txt"
LOG_D_PATH = "Log_D.txt"
IMAGE_SIZE = 64
BATCH_SIZE = 128
WORKER = 1
LR = 0.0002
NZ = 100
NUM_CLASS = 10
EPOCH = 10data_loader = loadMNIST(img_size=IMAGE_SIZE, batch_size=BATCH_SIZE)  #原始图片宽高是28*28的,给改变成64*64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netD = Discriminator().to(device)
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))g_writer = LossWriter(save_path=LOG_G_PATH)
d_writer = LossWriter(save_path=LOG_D_PATH)fix_noise = torch.randn(BATCH_SIZE, NZ, device=device)
fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)
fix_input_c = onehot(fix_input_c, NUM_CLASS)img_list = []
G_losses = []
D_losses = []
iters = 0print("开始训练>>>")
for epoch in range(EPOCH):print("正在保存网络并评估...")save_network(MODEL_G_PATH, netG, epoch)with torch.no_grad():fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()images = recover_image(fake_imgs)full_image = np.full((5 * 64, 5 * 64, 3), 0, dtype="uint8")for i in range(25):row = i // 5col = i % 5full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]plt.imshow(full_image)#plt.show()plt.imsave("{}.png".format(epoch), full_image)for data in data_loader:##################################################判别器交叉熵越小越好# 1. 更新判别器D: 最大化 log(D(x)) + log(1 - D(G(z)))# 等同于最小化 - log(D(x)) - log(1 - D(G(z)))#################################################netD.zero_grad()real_imgs, input_c = data   #这里的input_c其实就是数据集每一批中的每个图片对应的标签input_c = input_c.to(device)input_c = onehot(input_c, NUM_CLASS).to(device)# 1.1 来自数据集的样本#这里这一步就是想训练判别器,能够识别出是否真实图片,以及图片与对应的标签是否对应real_imgs = real_imgs.to(device)b_size = real_imgs.size(0)label = torch.full((b_size,), real_label, dtype=torch.float, device=device)#上面的torch.full是生成一维的 b_size这么多的,填充值为1.的张量# real_label = 1.# fake_label = 0.# 使用鉴别器对数据集样本做判断output = netD(real_imgs, input_c).view(-1)   #view() 方法被用来将模型输出的张量进行扁平化操作,即将张量中的所有元素都展开成一个一维向量# 计算交叉熵损失 -log(D(x))errD_real = criterion(output, label)# 对判别器进行梯度回传errD_real.backward()D_x = output.mean().item()    #对同一批预测结果的交叉熵取平均值## 1.2 生成随机向量   这一步想要训练判别器是否能够识别出是虚假图片noise = torch.randn(b_size, NZ, device=device)# 生成随机标签input_c = (torch.rand(b_size, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)input_c = onehot(input_c, NUM_CLASS)#fix_noise = torch.randn(BATCH_SIZE, NZ, device=device)#fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(device)#fix_input_c = onehot(fix_input_c, NUM_CLASS)# 来自生成器生成的样本fake = netG(noise, input_c)label.fill_(fake_label)# real_label = 1.# fake_label = 0.# 使用鉴别器对生成器生成样本做判断output = netD(fake.detach(), input_c).view(-1)   #view() 方法被用来将模型输出的张量进行扁平化操作,即将张量中的所有元素都展开成一个一维向量# 计算交叉熵损失 -log(1 - D(G(z)))errD_fake = criterion(output, label)# 对判别器进行梯度回传errD_fake.backward()D_G_z1 = output.mean().item()# 对判别器计算总梯度,-log(D(x))-log(1 - D(G(z)))errD = errD_real + errD_fake# 更新判别器optimizerD.step()################################################## 2. 更新生成器G: 最小化 log(D(x)) + log(1 - D(G(z))),# 等同于最小化log(1 - D(G(z))),即最小化-log(D(G(z)))# 也就等同于最小化-(log(D(G(z)))*1+log(1-D(G(z)))*0)# 令生成器样本标签值为1,上式就满足了交叉熵的定义#################################################netG.zero_grad()# 对于生成器训练,令生成器生成的样本为真,label.fill_(real_label)# real_label = 1.# fake_label = 0.output = netD(fake, input_c).view(-1)# 对生成器计算损失errG = criterion(output, label)# 因为这里判别器的角度label真实应该是0,但是站在生成器的角度,label真实应该是1,即生成器希望生成的虚假图片让判别器识别的时候,会误以为1才比较好,即误以为是真实的图片# 所以生成器交叉熵也是越小越好# 对生成器进行梯度回传errG.backward()D_G_z2 = output.mean().item()# 更新生成器optimizerG.step()# 输出损失状态if iters % 5 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, EPOCH, iters % len(data_loader), len(data_loader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))d_writer.add(loss=errD.item(), i=iters)g_writer.add(loss=errG.item(), i=iters)# 保存损失记录G_losses.append(errG.item())D_losses.append(errD.item())iters += 1

5.关于交叉熵

熵代表确定性,熵越小越好,说明确定性越好
在这里,因为参照的是真实标签,它的熵是0
而交叉熵-熵=相对熵
故相对熵在预测情况相对真实情况的时候,相对熵=交叉熵,相对熵越小,说明预测情况越接近真实情况;
同理,交叉熵越小,说明预测情况越接近真实情况。

在二分类0,1任务中,经过卷积、正则化、激活函数ReLU等操作之后,假如生成了一个(B,1,1,1)的张量,每个值在(无穷小,无穷大)之间,经过sigmoid函数,会变成一个(B,1,1,1)的张量,数值h在(0,1)之间,如果这个h>0.5说明模型预测的是1,如果h<0.5说明模型预测的是0,但是这是模型预测的标签值y*,而还有个真实标签值y。假如现在h=0.6,那么说明模型预测的标签y*是1,真实标签却是0,

交叉熵= -y(lgh) -(1-y)(lg(1-h))
即当y=1时,交叉熵是-lgh 这个情况下,h越大越好
当y=0时,交叉熵是-(lg(1-h)) 这个情况下,h越小越好

6.训练过程运行结果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

7.测试结果

在这里插入图片描述

测试代码


NZ = 100
NUM_CLASS = 10
BATCH_SIZE = 10
DEVICE = "cpu"# fix_input_c = (torch.rand(BATCH_SIZE, 1) * NUM_CLASS).type(torch.LongTensor).squeeze().to(DEVICE)netG = Generator()
netG = restore_network("./", "49", netG)
fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
fix_input_c = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
device = "cuda" if torch.cuda.is_available() else "cpu"
fix_input_c = onehot(fix_input_c, NUM_CLASS)
fix_input_c = fix_input_c.to(device)
fix_noise = fix_noise.to(device)
netG = netG.to(device)
#fake_imgs = netG(fix_noise, fix_input_c).detach().cpu()# images = recover_image(fake_imgs)
# full_image = np.full((1 * 64, 10 * 64, 3), 0, dtype="uint8")
# for i in range(10):
#     row = i // 10
#     col = i % 10
#     full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]#fix_noise = torch.randn(BATCH_SIZE, NZ, device=DEVICE)
full_image = np.full((10 * 64, 10 * 64, 3), 0, dtype="uint8")
for num in range(10):input_c = torch.tensor(np.ones(10, dtype="int64") * num)input_c = onehot(input_c, NUM_CLASS)fix_noise = fix_noise.to(device)input_c = input_c.to(device)fake_imgs = netG(fix_noise, input_c).detach().cpu()images = recover_image(fake_imgs)for i in range(10):row = numcol = i % 10full_image[row * 64:(row + 1) * 64, col * 64:(col + 1) * 64, :] = images[i]plt.imshow(full_image)
plt.show()
plt.imsave("hah.png", full_image)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/174838.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】概率图模型(一)概率图模型理论简介

文章目录 一、概率图模型1. 联合概率表2. 条件独立性假设3. 三个基本问题 二、模型表示1. 有向图模型&#xff08;贝叶斯网络&#xff09;2. 无向图模型&#xff08;马尔可夫网络&#xff09; 三、学习四、推断 概率图模型&#xff08;Probabilistic Graphical Model&#xff0…

ROS知识:卡尔曼滤波

https://en.wikipedia.org/wiki/Kalman_filter 一、提要 在卡尔曼滤波的相关技术文献中,其数学表达看起来都非常晦涩和不透明。这很糟糕,如果您以正确的方式看待卡尔曼滤波器,它实际上非常简单易懂。这里的叙述简单,先决条件也很简单;您所需要的只是对概率和矩阵的基本了解…

【C++】友元

1. 友元的概念 友元的目的就是让一个函数或者类 访问另一个类中私有成员。 友元的三种实现&#xff1a; 全局函数做友元类做友元成员函数做友元 2. 友元的实现方式 2.1 全局函数做友元 #include <iostream> using namespace std; class Building {// 告诉编译器 go…

【Android Gradle】之一小时 Gradle及 wrapper 入门

&#x1f604;作者简介&#xff1a; 小曾同学.com,一个致力于测试开发的博主⛽️&#xff0c;主要职责&#xff1a;测试开发、CI/CD 如果文章知识点有错误的地方&#xff0c;还请大家指正&#xff0c;让我们一起学习&#xff0c;一起进步。 &#x1f60a; 座右铭&#xff1a;不…

PC删除数据,并提示删除成功

<template<el-button size"mini" type"text">分配权限</el-button><el-button size"mini" type"text" click"btnEditRow(row)">编辑</el-button ><el-popconfirmtitle"这是一段内容确定…

计算机毕业设计springboot+vue高校田径运动会报名管理系统61s38

高校田径运动会管理采用java技术&#xff0c;基于springboot框架&#xff0c;mysql数据库进行开发&#xff0c;实现了首页、个人中心、运动员管理、裁判员管理、场地信息管理、项目类型管理、比赛项目管理、比赛报名管理、比赛成绩管理、通知公告管理、留言板管理、交流论坛、系…

微软发布了Orca 2,一对小型语言模型,它们的性能超越了体积更大的同类产品

尽管全球目睹了OpenAI的权力斗争和大规模辞职&#xff0c;但作为AI领域的长期支持者&#xff0c;微软并没有放慢自己的人工智能努力。今天&#xff0c;由萨提亚纳德拉领导的公司研究部门发布了Orca 2&#xff0c;这是一对小型语言模型&#xff0c;它们在零样本设置下对复杂推理…

数据结构---顺序表

文章目录 线性表线性表的定义线性表分类 顺序表顺次表的存储结构实现顺序表的主要接口函数初始化顺序表顺序表尾插顺序表尾删顺序表头插顺序表头删在指定位置插入数据在指定的位置删除数据头插&#xff0c;头删&#xff0c;尾插&#xff0c;尾删新写法打印顺序表销毁顺序表 线性…

基于halo框架采用docker-compose快速部署个人博客

halo快速部署个人博客 技术方案 dockerdocker-composenginxmysql halo简介 Halo是一款现代化的开源博客/CMS系统&#xff0c;所有代码开源在GitHub上且处于积极维护状态。它是基于 Java Spring Boot 构建的&#xff0c;易于部署&#xff0c;支持REST API、模板系统、附件系…

关于微服务的思考

目录 什么是微服务 定义 特点 利弊 引入时机 需要哪些治理环节 从单体架构到微服务架构的演进 单体架构 集群和垂直化 SOA 微服务架构 如何实现微服务架构 服务拆分 主流微服务解决方案 基础设施 下一代微服务架构Service Mesh 什么是Service Mesh&#xff1f…

python实现自动刷平台学时

背景 前一阵子有个朋友让我帮给小忙&#xff0c;因为他每学期都要看视频刷学时&#xff0c;一门平均需要刷500分钟&#xff0c;一学期有3-4门需要刷的。 如果是手动刷的话&#xff0c;比较麻烦&#xff0c;能否帮他做成自动化的。搞成功的话请我吃饭。为了这顿饭&#xff0c;咱…

京东秒杀之商品展示

1 在gitee上添加.yml文件 1.1 添加good-server.yml文件 server:port: 8084 spring:datasource:url: jdbc:mysql://localhost:3306/shop_goods?serverTimezoneGMT%2B8driverClassName: com.mysql.cj.jdbc.Drivertype: com.alibaba.druid.pool.DruidDataSourceusername: rootpa…

多功能音乐沙漏的设计与实现

【摘要】随着当今社会快节奏生活的发展&#xff0c;当代大学生越来忽视时间管理的重要性&#xff0c;在原本计划只看几个视频只玩几个游戏的碎片化娱乐中耗费了大量的时光&#xff0c;对于自己原本的学习生活产生了巨大的影响。为更加有效的反映时间的流逝&#xff0c;特设计该…

第十七章 解读PyTorch断点训练(工具)

主要有以下几方面的内容&#xff1a; 对于多步长训练需要保存lr_schedule初始化随机数种子保存每一代最好的结果 简单详细介绍 最近在尝试用CIFAR10训练分类问题的时候&#xff0c;由于数据集体量比较大&#xff0c;训练的过程中时间比较长&#xff0c;有时候想给停下来&…

Gitee上传代码教程

1. 本地安装git 官网下载太慢&#xff0c;我们也可以使用淘宝镜像下载&#xff1a;CNPM Binaries Mirror 安装成功以后电脑会有Git Bush标识&#xff0c;空白处右键也可查看。 2. 注册gitee账号&#xff08;略&#xff09; 3. 创建远程仓库 4. 上传代码 4.1 在项目文件目录…

go当中的channel 无缓冲channel和缓冲channel的适用场景、结合select的使用

Channel Go channel就像Go并发模型中的“胶水”&#xff0c;它将诸多并发执行单元连接起来&#xff0c;或者正是因为有channel的存在&#xff0c;Go并发模型才能迸发出强大的表达能力。 无缓冲channel 无缓冲channel兼具通信和同步特性&#xff0c;在并发程序中应用颇为广泛。…

坚鹏:贵州银行西南财经大学零售业务数字化转型与场景营销策略

中国银保监会2022年1月正式发布了中国银保监会发布《关于银行业保险业数字化转型的指导意见》&#xff0c;这标准着中国银行业从局部的数字化转型向全面的数字化转型转变&#xff0c;进一步加速了银行数字化转型高潮的到来。 《关于银行业保险业数字化转型的指导意见》提出明确…

【教学类-06-12】20231126 (二)三位数 如何让加减乘除题目从小到大排序(以0-110之间加法为例,做正序排列用)

结果展示 背景需求&#xff1a; 二位数&#xff1a;去0 三位数&#xff08;需要排除很多0&#xff09; 解决思路 一、把数字改成三位数 二、对数组内的题目&#xff0c;8种可能性进行去“0”处理 1、十位数&#xff08;去百位数0&#xff09;十位数&#xff08;去百位数0&am…

数据增强让模型更健壮

在做一些图像分类训练任务时,我们经常会遇到一个很尴尬的情况,那就是: 明明训练数据集中有很多可爱猫咪的照片,但是当我们给训练好的模型输入一张戴着头盔的猫咪进行测试时,模型就不认识了,或者说识别精度很低。 很明显,模型的泛化能力太差,难道戴着头盔的猫咪就不是猫…

线性分类器--数据处理

数据集划分 通常按照 70%&#xff0c;20% &#xff0c;10% 来分数据集 数据处理 斯坦福的线性分类器体验 http://vision.stanford.edu/teaching/cs231n-demos/linear-classify/