第G7周:Semi-Supervised GAN 理论与实战

 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  🍦 参考文章:365天深度学习训练营-第G7周:Semi-Supervised GAN 理论与实战(训练营内部成员可读)

  🍖 原作者:K同学啊|接辅导、项目定制

 🏡 运行环境:
电脑系统:Windows 10
语言环境:python 3.10
编译器:Pycharm 2022.1.1
深度学习环境:Pytorch  


目录

一、理论知识讲解

二、代码实现

1、配置代码 

 2、初始化权重

3、定义算法模型

4、配置模型

 5、训练模型


一、理论知识讲解

该算法将产生式对抗网络(GAN) 拓展到半监督学习,通过强制判别器D来输出类别标签。我们
在一个数据集上训练一个生成器G以及一个判别器D,输入是N类当中的一个。在训练的时候,判别器D被用于预测输入是属于N+1类中的哪一个,这个N+1是对应了生成器G的输出,这里的判别器
D同时也充当起了分类器C的效果。这种方法可以用于训练效果更好的判别器D,并且可以比普通的GAN产性更加高质量的样本。Semi-Supervised GAN有如下优点:
(1)作者对GANs做了一个新的扩展,允许它同时学习一个生成模型和一个分类器。我们把这个 扩展叫做半监督GAN或SGAN
(2)论文实验结果表明,SGAN在有限数据集比没有生成部分的基准分类器提升了分类性能
(3)论文实验结果表明,SGAN可以显著地提升生成样本的质量并降低生成器的训练时间。 

二、代码实现

1、配置代码 
import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)cuda = True if torch.cuda.is_available() else False
Namespace(n_epochs=2, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=2, latent_dim=100, num_classes=10, img_size=32, channels=1, sample_interval=400)
 2、初始化权重
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)
3、定义算法模型
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)self.init_size = opt.img_size // 4  # Initial size before upsamplingself.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise):out = self.l1(noise)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):"""Returns layers of each discriminator block"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4# Output layersself.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label
4、配置模型
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\rawDownloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
 5、训练模型
# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# Adversarial ground truthsvalid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise and labels as generator inputz = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorvalidity, _ = discriminator(gen_imgs)g_loss = adversarial_loss(validity, valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Loss for real imagesreal_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# Loss for fake imagesfake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2# Total discriminator lossd_loss = (d_real_loss + d_fake_loss) / 2# Calculate discriminator accuracypred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))
[Epoch 0/2] [Batch 937/938] [D loss: 1.358861, acc: 50%] [G loss: 0.671799]
[Epoch 1/2] [Batch 937/938] [D loss: 1.343094, acc: 50%] [G loss: 0.681119]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/131309.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centralized Feature Pyramid for Object Detection解读

Centralized Feature Pyramid for Object Detection 问题 主流的特征金字塔集中于层间特征交互,而忽略了层内特征规则。尽管一些方法试图在注意力机制或视觉变换器的帮助下学习紧凑的层内特征表示,但它们忽略了对密集预测任务非常重要的被忽略的角点区…

【论文精读】PlanT: Explainable Planning Transformers via Object-Level Representations

1 基本信息 院校:德国的图宾根大学 网站:https://www.katrinrenz.de/plant 2 论文背景 2.1 现有问题 现在的基于学习的方法使用高精地图和BEV,认为准确的(达到像素级的pixel-level)场景理解是鲁棒的输出的关键。re…

Java自学第1课:安装JDK+Eclipse

1 引言 在学习前,我想说一句,那就是为什么要学习Java。 每个人的出发点都不同,对于做信息化的工程技术人员来说,java不懂,就没法干项目。 尽管有c和matlab等基础,但java看起来与这些语言都不太一样。 做…

基于单片机的智能饮水机系统

收藏和点赞,您的关注是我创作的动力 文章目录 概要 一、系统设计方案分析2.1 设计功能及性能分析2.2设计方案分析 二、系统的硬件设计3.1 系统设计框图系统软件设计4.1 总体介绍原理图 四、 结论 概要 现在很多学校以及家庭使用的饮水机的功能都是比较单一的&#…

MySQL InnoDB数据存储结构

1. 数据库的存储结构:页 索引结构给我们提供了高效的索引方式,不过索引信息以及数据记录都是保存在文件上的,确切说是存储在页结构中。另一方面,索引是在存储引擎中实现的,MySQL服务器上的存储引擎负责对表中数据的读…

分享68个工作总结PPT,总有一款适合您

分享68个工作总结PPT,总有一款适合您 PPT下载链接:https://pan.baidu.com/s/1juus0gmesBFxJ-5KZgSMdQ?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不易。知识付…

C语言----每日五道选择题Day1

1.第一题 1、指出下列代码的缺陷&#xff08; &#xff09;【多选】 float f[10]; // 假设这里有对f进行初始化的代码 for(int i 0; i < 10;) {if(f[i] 0)break; } A: for(int i 0; i < 10;)这一行写错了 B: f是float型数据直接做相等判断有风险 C: f[i]应该是…

[MICROSAR Adaptive] --- autosar官方文档阅读建议

目前互联网上没有太多的 Adaptive AUTOSAR 的学习资料,官方文档是一个很不错的途径。看过官方文档才发现,目前很多关于 Adaptive AUTOSAR 的文章都是官方文档的简化翻译,不如直接看官方文档更全面深入。 1 Adaptive AUTOSAR 文档官方下载地址 https://www.autosar.org/sta…

微信小程序:实现多个按钮提交表单

效果 核心步骤 通过data-type给不同按钮进行设置&#xff0c;便于很好的区分不同按钮执行不同功能 data-type"" 完整代码 wxml <form action"" bindsubmit"formSubmit"><button style"margin-bottom:5%" data-type"pa…

[黑马程序员SpringBoot2]——运维实用篇

目录&#xff1a; 工程打包与运行打包插件Boot工程快速启动&#xff08;Linux版本&#xff09;临时属性配置文件4级分类自定义配置文件多环境开发(yaml版)多环境开发多文件版&#xff08;yaml版&#xff09;多环境开发多文件版&#xff08;properties版&#xff09;多环境分组…

难题来了:分库分表后,查询太慢了,如何优化?

说在前面&#xff1a; 尼恩社群中&#xff0c;很多小伙伴反馈&#xff0c; Sharding-JDBC 分页查询的速度超级慢&#xff0c; 怎么处理&#xff1f; 反馈这个问题的小伙伴&#xff0c;很多很多。 而且这个问题&#xff0c;也是面试的核心难题。前段时间&#xff0c;有小伙伴…

windows 用vs创建cmake工程并编译opencv应用项目生成exe流程简述

目录 前言一、安装opencv&#xff08;1&#xff09;下载&#xff08;2&#xff09;双击安装&#xff08;3&#xff09;环境变量和system文件夹设置 二、打开vs创建项目三、编辑cpp&#xff0c;.h&#xff0c;cmakelist.txt文件&#xff08;1&#xff09;h文件&#xff08;2&…

【Python从入门到进阶】41、有关requests代理的使用

接上篇《40、requests的基本使用》 上一篇我们介绍了requests库的基本使用&#xff0c;本篇我们来学习requests的代理。 一、引言 在网络爬虫和数据抓取的过程中&#xff0c;我们经常需要发送HTTP请求来获取网页内容或与远程服务器进行通信。然而&#xff0c;在某些情况下&…

通过在Z平面放置零极点的来设计数字滤波器

文章来源地址&#xff1a;https://www.yii666.com/blog/393376.html 通过在Z平面放置零极点的来设计数字滤波器 要求&#xff1a;设计一款高通滤波器&#xff0c;用在音频信号处理过程中&#xff0c;滤掉100Hz以下的信号。 实现方法&#xff1a;通过在Z平面放置零极点的来设…

数据结构与算法【02】—线性表

CSDN系列专栏&#xff1a;数据结构与算法专栏 针对以前写的数据结构与算法系列重写(针对文字描述、图片、错误修复)&#xff0c;改动会比较大&#xff0c;一直到更新完为止 前言 通过前面数据结构与算法基础知识我们知道了数据结构的一些概念和重要性&#xff0c;那么本章总结…

【UE 材质】简单的闪闪发光材质

效果 节点 参考视频&#xff1a; https://www.bilibili.com/video/BV1uK411y737/?vd_source36a3e35639c44bb339f59760641390a8

MySQL(8):聚合函数

聚合函数介绍 聚合函数&#xff1a; 对一组数据进行汇总的函数&#xff0c;输入的是一组数据的集合&#xff0c;输出的是单个值。 聚合函数类型&#xff1a;AVG(),SUM(),MAX(),MIN(),COUNT() AVG / SUM 只适用于数值类型的字段&#xff08;或变量&#xff09; SELECT AVG(…

【LeetCode】每日一题 2023_11_4 数组中两个数的最大异或值

文章目录 刷题前唠嗑题目&#xff1a;数组中两个数的最大异或值题目描述代码与解题思路 结语 刷题前唠嗑 LeetCode? 启动&#xff01;&#xff01;&#xff01; 题目&#xff1a;数组中两个数的最大异或值 题目链接&#xff1a;421. 数组中两个数的最大异或值 题目描述 代…

前端埋点方式

前言&#xff1a; 想要了解用户在系统中所做的操作&#xff0c;从而得出用户在本系统中最常用的模块、在系统中停留的时间。对于了解用户的行为、分析用户的需求有很大的帮助&#xff0c;想实现这种需求可以通过前端埋点的方式。 埋点方式&#xff1a; 1.什么是埋点&#xff1f…

基于Jenkins实现接口自动化持续集成,学完涨薪5k

一、JOB项目配置 1、添加描述 可选选项可填可不填 2、限制项目的运行节点 节点中要有运行环境所需的配置 节点配置教程&#xff1a;https://blog.csdn.net/YZL40514131/article/details/131504280 3、源码管理 需要将脚本推送到远程仓库中 4、构建触发器 可以选择定时构建…