第G7周:Semi-Supervised GAN 理论与实战

 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  🍦 参考文章:365天深度学习训练营-第G7周:Semi-Supervised GAN 理论与实战(训练营内部成员可读)

  🍖 原作者:K同学啊|接辅导、项目定制

 🏡 运行环境:
电脑系统:Windows 10
语言环境:python 3.10
编译器:Pycharm 2022.1.1
深度学习环境:Pytorch  


目录

一、理论知识讲解

二、代码实现

1、配置代码 

 2、初始化权重

3、定义算法模型

4、配置模型

 5、训练模型


一、理论知识讲解

该算法将产生式对抗网络(GAN) 拓展到半监督学习,通过强制判别器D来输出类别标签。我们
在一个数据集上训练一个生成器G以及一个判别器D,输入是N类当中的一个。在训练的时候,判别器D被用于预测输入是属于N+1类中的哪一个,这个N+1是对应了生成器G的输出,这里的判别器
D同时也充当起了分类器C的效果。这种方法可以用于训练效果更好的判别器D,并且可以比普通的GAN产性更加高质量的样本。Semi-Supervised GAN有如下优点:
(1)作者对GANs做了一个新的扩展,允许它同时学习一个生成模型和一个分类器。我们把这个 扩展叫做半监督GAN或SGAN
(2)论文实验结果表明,SGAN在有限数据集比没有生成部分的基准分类器提升了分类性能
(3)论文实验结果表明,SGAN可以显著地提升生成样本的质量并降低生成器的训练时间。 

二、代码实现

1、配置代码 
import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)cuda = True if torch.cuda.is_available() else False
Namespace(n_epochs=2, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=2, latent_dim=100, num_classes=10, img_size=32, channels=1, sample_interval=400)
 2、初始化权重
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)
3、定义算法模型
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)self.init_size = opt.img_size // 4  # Initial size before upsamplingself.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise):out = self.l1(noise)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):"""Returns layers of each discriminator block"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4# Output layersself.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label
4、配置模型
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
 5、训练模型
# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# Adversarial ground truthsvalid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise and labels as generator inputz = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorvalidity, _ = discriminator(gen_imgs)g_loss = adversarial_loss(validity, valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Loss for real imagesreal_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# Loss for fake imagesfake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2# Total discriminator lossd_loss = (d_real_loss + d_fake_loss) / 2# Calculate discriminator accuracypred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))
[Epoch 0/2] [Batch 937/938] [D loss: 1.358861, acc: 50%] [G loss: 0.671799]
[Epoch 1/2] [Batch 937/938] [D loss: 1.343094, acc: 50%] [G loss: 0.681119]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/131309.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centralized Feature Pyramid for Object Detection解读

Centralized Feature Pyramid for Object Detection 问题 主流的特征金字塔集中于层间特征交互,而忽略了层内特征规则。尽管一些方法试图在注意力机制或视觉变换器的帮助下学习紧凑的层内特征表示,但它们忽略了对密集预测任务非常重要的被忽略的角点区…

【论文精读】PlanT: Explainable Planning Transformers via Object-Level Representations

1 基本信息 院校:德国的图宾根大学 网站:https://www.katrinrenz.de/plant 2 论文背景 2.1 现有问题 现在的基于学习的方法使用高精地图和BEV,认为准确的(达到像素级的pixel-level)场景理解是鲁棒的输出的关键。re…

Java自学第1课:安装JDK+Eclipse

1 引言 在学习前,我想说一句,那就是为什么要学习Java。 每个人的出发点都不同,对于做信息化的工程技术人员来说,java不懂,就没法干项目。 尽管有c和matlab等基础,但java看起来与这些语言都不太一样。 做…

【从0到1设计一个网关】基于Hystrix实现熔断降级

文章目录 依赖引入服务降级效果演示上文我们已经成功实现了请求重试与请求限流,接下来我们开始实现熔断与服务降级。 熔断与服务降级,在SpringCloud中设计到的就是我们的hystrix,这里我们也将会考虑配合hystrix来实现熔断与服务降级。 如果不了解hystix的可以先进行一下了解…

LeetCode 面试题 16.16. 部分排序

文章目录 一、题目二、C# 题解 一、题目 给定一个整数数组,编写一个函数,找出索引m和n,只要将索引区间[m,n]的元素排好序,整个数组就是有序的。注意:n-m尽量最小,也就是说,找出符合条件的最短序…

基于单片机的智能饮水机系统

收藏和点赞,您的关注是我创作的动力 文章目录 概要 一、系统设计方案分析2.1 设计功能及性能分析2.2设计方案分析 二、系统的硬件设计3.1 系统设计框图系统软件设计4.1 总体介绍原理图 四、 结论 概要 现在很多学校以及家庭使用的饮水机的功能都是比较单一的&#…

【JavaScript】事件监听、事件委托和回调函数

1. 事件监听 on 方法:box.onclick function(){},但是这种方式会被覆盖,直接使用null覆盖偶就可以实现事件的解绑。 addEventListener 是 DOM 对象专门用来添加事件监听的方法,它的前两个参数分别为【事件类型】和【事件回调】&…

MySQL InnoDB数据存储结构

1. 数据库的存储结构:页 索引结构给我们提供了高效的索引方式,不过索引信息以及数据记录都是保存在文件上的,确切说是存储在页结构中。另一方面,索引是在存储引擎中实现的,MySQL服务器上的存储引擎负责对表中数据的读…

第四次pta认证P测试

第一题 试题编号: 试题名称:整数排序 时间限制: 1.0s 内存限制: 128.0MB 【问题描述】 老师给定 10 个整数的序列,要求对其重新排序。排序要求: 1.奇数在前,偶数在后; 2.奇数按从大到小排序&am…

分享68个工作总结PPT,总有一款适合您

分享68个工作总结PPT,总有一款适合您 PPT下载链接:https://pan.baidu.com/s/1juus0gmesBFxJ-5KZgSMdQ?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不易。知识付…

C语言----每日五道选择题Day1

1.第一题 1、指出下列代码的缺陷&#xff08; &#xff09;【多选】 float f[10]; // 假设这里有对f进行初始化的代码 for(int i 0; i < 10;) {if(f[i] 0)break; } A: for(int i 0; i < 10;)这一行写错了 B: f是float型数据直接做相等判断有风险 C: f[i]应该是…

[MICROSAR Adaptive] --- autosar官方文档阅读建议

目前互联网上没有太多的 Adaptive AUTOSAR 的学习资料,官方文档是一个很不错的途径。看过官方文档才发现,目前很多关于 Adaptive AUTOSAR 的文章都是官方文档的简化翻译,不如直接看官方文档更全面深入。 1 Adaptive AUTOSAR 文档官方下载地址 https://www.autosar.org/sta…

微信小程序:实现多个按钮提交表单

效果 核心步骤 通过data-type给不同按钮进行设置&#xff0c;便于很好的区分不同按钮执行不同功能 data-type"" 完整代码 wxml <form action"" bindsubmit"formSubmit"><button style"margin-bottom:5%" data-type"pa…

SpringBoot 多组 Kafka 配置

SpringBoot 多组 Kafka 配置 单组 Kafka 配置 时隔多日&#xff0c;冒个泡吧。 场景 是 我在日常的开发过程中需要监听 kafka 的消息进行回调处理&#xff0c;但是呢&#xff0c;不同的三方服务他们用了不同的 kafka 集群&#xff0c;那么默认的 Spring 自动读取的 kafka 配…

[黑马程序员SpringBoot2]——运维实用篇

目录&#xff1a; 工程打包与运行打包插件Boot工程快速启动&#xff08;Linux版本&#xff09;临时属性配置文件4级分类自定义配置文件多环境开发(yaml版)多环境开发多文件版&#xff08;yaml版&#xff09;多环境开发多文件版&#xff08;properties版&#xff09;多环境分组…

vue如何实现视频全屏切换

最近项目开发中遇到一个视频窗口全屏切换功能&#xff0c;为此在这里做个记录。 具体的实现思路&#xff1a; <template><div class"content-box"><div class"container"><div id"screen" class"screen"><…

难题来了:分库分表后,查询太慢了,如何优化?

说在前面&#xff1a; 尼恩社群中&#xff0c;很多小伙伴反馈&#xff0c; Sharding-JDBC 分页查询的速度超级慢&#xff0c; 怎么处理&#xff1f; 反馈这个问题的小伙伴&#xff0c;很多很多。 而且这个问题&#xff0c;也是面试的核心难题。前段时间&#xff0c;有小伙伴…

MySQL数据库干货_13—— MySQL查询数据

MySQL查询数据 SELECT基本查询 SELECT语句的功能 SELECT 语句从数据库中返回信息。使用一个 SELECT 语句&#xff0c;可以做下面的事&#xff1a; 列选择&#xff1a;能够使用 SELECT 语句的列选择功能选择表中的列&#xff0c;这些列是想 要用查询返回的。当查询时&#xf…

vue-render函数的三个参数

第一个参数(必须) - {String | Object | Function} Vue.component(elem, {render: function(createElement) {return createElement(div);//一个HTML标签字符/*return createElement({template: <div></div>//组件选项对象});*//*var func function() {return {t…

使用electron ipcRenderer接收通信消息多次触发

使用electron ipcRenderer接收通信消息多次触发 在使用electron ipcRenderer.on接收ipcRenderer.send的返回值时&#xff0c;ipcRenderer.send发送一次信息&#xff0c; ipcRenderer.on会打印多个日志&#xff0c; renderer.once(get-file-path, (event: any, paths: any) &g…