ACGAN

CGAN通过在生成器和判别器中均使用标签信息进行训练,不仅能产生特定标签的数据,还能够提高生成数据的质量;SGAN(Semi-Supervised GAN)通过使判别器/分类器重建标签信息来提高生成数据的质量。既然这两种思路都可以提高生成数据的质量,于是ACGAN综合了以上两种思路,既使用标签信息进行训练,同时也重建标签信息,结合CGAN和SGAN的优点,从而进一步提升生成样本的质量,并且还能根据指定的标签相应的样本。

1. ACGAN的网络结构为:

ACGAN的网络结构框图

        生成器输入包含C_vector和Noise_data两个部分,其中C_vector为训练数据标签信息的One-hot编码张量,其形状为:(batch_size, num_class) ;Noise_data的形状为:(batch_size, latent_dim)。然后将两者进行拼接,拼接完成后,得到的输入张量为:(batch_size, num_class + latent_dim)。生成器的的输出张量为:(batch_size, channel, Height, Width)。

        判别器输入为:(batch_size, channel, Height, Width); 判别的器的输出为两部分,一部分是源数据真假的判断,形状为:(batch_size, 1),一部分是输入数据的分类结果,形状为:(batch_size, class_num)。因此判别器的最后一层有两个并列的全连接层,分别得到这两部分的输出结果,即判别器的输出有两个张量(真假判断张量和分类结果张量)。

2. ACGAN的损失函数:

        对于判别器而言,既希望分类正确,又希望能正确分辨数据的真假;对于生成器而言,也希望能够分类正确,当时希望判别器不能正确分辨假数据。

D_real, C_real = Discriminator( real_imgs)         # real_img 为输入的真实训练图片

D_real_loss = torch.nn.BCELoss(D_real, Y_real)          #  Y_real为真实数据的标签,真数据都为-1,假数据都为+1

C_real_loss = torch.nn.CrossEntropyLoss(C_real, Y_vec)        # Y_vec为训练数据One-hot编码的标签张量

gen_imgs = Generator(noise, Y_vec)

D_fake, C_fake = Discriminator(gen_imgs)

D_fake_loss = torch.nn.BCELoss(D_fake, Y_fake)

C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)

D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss

生成器的损失函数:  

gen_imgs = Generator(noise, Y_vec)

D_fake, C_fake = Discriminator(gen_imgs)

D_fake_loss = torch.nn.BCELoss(D_fake, Y_real)

C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)

G_loss = D_fake_loss + C_fake_loss

class Discriminator(nn.Module):  # 定义判别器def __init__(self, img_size=(64, 64), num_classes=2):  # 初始化方法super(Discriminator, self).__init__()  # 继承初始化方法self.img_size = img_size  # 图片尺寸,默认为(64.64)三通道图片self.num_classes = num_classes  # 类别数self.conv1 = nn.Conv2d(3, 128, 4, 2, 1)  # conv操作self.conv2 = nn.Conv2d(128, 256, 4, 2, 1)  # conv操作self.bn2 = nn.BatchNorm2d(256)  # bn操作self.conv3 = nn.Conv2d(256, 512, 4, 2, 1)  # conv操作self.bn3 = nn.BatchNorm2d(512)  # bn操作self.conv4 = nn.Conv2d(512, 1024, 4, 2, 1)  # conv操作self.bn4 = nn.BatchNorm2d(1024)  # bn操作self.leakyrelu = nn.LeakyReLU(0.2)  # leakyrelu激活函数self.linear1 = nn.Linear(int(1024 * (self.img_size[0] / 2 ** 4) * (self.img_size[1] / 2 ** 4)), 1)  # linear映射self.linear2 = nn.Linear(int(1024 * (self.img_size[0] / 2 ** 4) * (self.img_size[1] / 2 ** 4)),self.num_classes)  # linear映射self.sigmoid = nn.Sigmoid()  # sigmoid激活函数self.softmax = nn.Softmax(dim=1)  # softmax激活函数self._init_weitghts()  # 模型权重初始化def _init_weitghts(self):  # 定义模型权重初始化方法for m in self.modules():  # 遍历模型结构if isinstance(m, nn.Conv2d):  # 如果当前结构是convnn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化nn.init.constant_(m.bias, 0)  # b设为0elif isinstance(m, nn.BatchNorm2d):  # 如果当前结构是bnnn.init.constant_(m.weight, 1)  # w设为1nn.init.constant_(m.bias, 0)  # b设为0elif isinstance(m, nn.Linear):  # 如果当前结构是linearnn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化nn.init.constant_(m.bias, 0)  # b设为0def forward(self, x):  # 前传函数x = self.conv1(x)  # conv,(n,3,64,64)-->(n,128,32,32)x = self.leakyrelu(x)  # leakyrelu激活函数x = self.conv2(x)  # conv,(n,128,32,32)-->(n,256,16,16)x = self.bn2(x)  # bn操作x = self.leakyrelu(x)  # leakyrelu激活函数x = self.conv3(x)  # conv,(n,256,16,16)-->(n,512,8,8)x = self.bn3(x)  # bn操作x = self.leakyrelu(x)  # leakyrelu激活函数x = self.conv4(x)  # conv,(n,512,8,8)-->(n,1024,4,4)x = self.bn4(x)  # bn操作x = self.leakyrelu(x)  # leakyrelu激活函数x = torch.flatten(x, 1)  # 三维特征压缩至一位特征向量,(n,1024,4,4)-->(n,1024*4*4)# 根据特征向量x,计算图片真假的得分validity = self.linear1(x)  # linear映射,(n,1024*4*4)-->(n,1)validity = self.sigmoid(validity)  # sigmoid激活函数,将输出压缩至(0,1)# 根据特征向量x,计算图片分类的标签label = self.linear2(x)  # linear映射,(n,1024*4*4)-->(n,2)label = self.softmax(label)  # softmax激活函数,将输出压缩至(0,1)return (validity, label)  # 返回(图像真假的得分,图片分类的标签)class Generator(nn.Module):  # 定义生成器def __init__(self, img_size=(64, 64), num_classes=2, latent_dim=100):  # 初始化方法super(Generator, self).__init__()  # 继承初始化方法self.img_size = img_size  # 图片尺寸,默认为(64.64)三通道图片self.num_classes = num_classes  # 类别数self.latent_dim = latent_dim  # 输入噪声长度,默认为100self.linear = nn.Linear(self.latent_dim, 4 * 4 * 1024)  # linear映射self.bn0 = nn.BatchNorm2d(1024)  # bn操作self.deconv1 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)  # transconv操作self.bn1 = nn.BatchNorm2d(512)  # bn操作self.deconv2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)  # transconv操作self.bn2 = nn.BatchNorm2d(256)  # bn操作self.deconv3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  # transconv操作self.bn3 = nn.BatchNorm2d(128)  # bn操作self.deconv4 = nn.ConvTranspose2d(128, 3, 4, 2, 1)  # transconv操作self.relu = nn.ReLU(inplace=True)  # relu激活函数self.tanh = nn.Tanh()  # tanh激活函数self.embedding = nn.Embedding(self.num_classes, self.latent_dim)  # embedding操作self._init_weitghts()  # 模型权重初始化def _init_weitghts(self):  # 定义模型权重初始化方法for m in self.modules():  # 遍历模型结构if isinstance(m, nn.ConvTranspose2d):  # 如果当前结构是transconvnn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化nn.init.constant_(m.bias, 0)  # b设为0elif isinstance(m, nn.BatchNorm2d):  # 如果当前结构是bnnn.init.constant_(m.weight, 1)  # w设为1nn.init.constant_(m.bias, 0)  # b设为0elif isinstance(m, nn.Linear):  # 如果当前结构是linearnn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化nn.init.constant_(m.bias, 0)  # b设为0def forward(self, input: tuple):  # 前传函数noise, label = input  # 从输入的元组中获取噪声向量和标签信息label = self.embedding(label)  # 标签信息经过embedding操作,变成与噪声向量尺寸相同的稠密向量z = torch.multiply(noise, label)  # 噪声向量与标签稠密向量相乘,得到带有标签信息的噪声向量z = self.linear(z)  # linear映射,(n,100)-->(n,1024*4*4)z = z.view((-1, 1024, int(self.img_size[0] / 2 ** 4),int(self.img_size[1] / 2 ** 4)))  # 一维特征向量扩展至三维特征,(n,1024*4*4)-->(n,1024,4,4)z = self.bn0(z)  # bn操作z = self.relu(z)  # relu激活函数z = self.deconv1(z)  # trainsconv操作,(n,1024,4,4)-->(n,512,8,8)z = self.bn1(z)  # bn操作z = self.relu(z)  # relu激活函数z = self.deconv2(z)  # trainsconv操作,(n,512,8,8)-->(n,256,16,16)z = self.bn2(z)  # bn操作z = self.relu(z)  # relu激活函数z = self.deconv3(z)  # trainsconv操作,(n,256,16,16)-->(n,128,32,32)z = self.bn3(z)  # bn操作z = self.relu(z)  # relu激活函数z = self.deconv4(z)  # trainsconv操作,(n,128,32,32)-->(n,3,64,64)z = self.tanh(z)  # tanh激活函数,将输出压缩至(-1,1)return z  # 返回生成图像

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/90714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 使用kotlin+注解+反射+泛型实现MVP架构

一,MVP模式的定义 ①Model:用于存储数据。它负责处理领域逻辑以及与数据库或网络层的通信。 ②View:UI层,提供数据可视化界面,并跟踪用户的操作,以便通知presenter。 ③Presenter:从Model层获…

FPGA的数字钟带校时闹钟报时功能VHDL

名称:基于FPGA的数字钟具有校时闹钟报时功能 软件:Quartus 语言:VHDL 要求: 1、计时功能:这是数字钟设计的基本功能,每秒钟更新一次,并且能在显示屏上显示当前的时间。 2、闹钟功能:如果当前的时间与闹钟设置的时…

【AI视野·今日NLP 自然语言处理论文速览 第四十四期】Fri, 29 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Fri, 29 Sep 2023 Totally 45 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers MindShift: Leveraging Large Language Models for Mental-States-Based Problematic Smartphone Use Interve…

软件测试(测试用例攻略)—写用例无压力

一、概念 测试用例的基本概念: 测试用例(Test Case)是为了实施测试而向被测试的系统提供的一组集合,这组集合包含:测试环境、操作步骤、测试数据、预期结果等要素 。 主要步骤: 测试环境——测试步骤—…

C理解(二):指针,数组,字符串,函数

本文主要探讨指针,数组,字符串,函数 指针 int *p; 未绑定:*表示p为指针变量,占4字节 int a 1;p &a; 绑定:p与a地址绑定即p中存放a的地址 *p *p 1; 解引用:p间接访问a的存储空间…

【接口测试】HTTP协议

一、HTTP 协议基础 HTTP 简介 HTTP 是一个客户端终端(用户)和服务器端(网站)请求和应答的标准(TCP)。通常是由客户端发起一个请求,创建一个到服务器的 TCP 连接,当服务器监听到客户…

第十四届蓝桥杯大赛软件赛决赛 C/C++ 大学 B 组 试题 C: 班级活动

[蓝桥杯 2023 国 B] 班级活动 【问题描述】 小明的老师准备组织一次班级活动。班上一共有 n n n 名( n n n 为偶数)同学,老师想把所有的同学进行分组,每两名同学一组。为了公平,老师给每名同学随机分配了一个 n n …

以太坊智能合约的历史里程碑: 从DAO到数据隐私的技术演进

文章目录 系列文章目录前言一、时间线 项目介绍总结 前言 在短短的几年内,以太坊不仅成为了去中心化应用和智能合约的主导平台,而且也见证了区块链技术和应用的多次重大革命。本文详细回顾了自2016年至今,以太坊生态所经历的几个关键时刻与技…

leetcodetop100(29) K 个一组翻转链表

K 个一组翻转链表 给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍,那么请将最后剩余的节点保持原有顺序。 你不能只是单纯的改…

React Native搭建Android开发环境

React Native搭建Android开发环境 搭建Android开发环境一、下载JDK二、安装Android Studio2.1 配置 ANDROID_HOME 环境变量 三、初始化项目 搭建Android开发环境 我的电脑是windows系统,所以只能搭建Android,如果电脑是mac,既可以搭建Androi…

修改sqlmap-Tamper脚本

修改sqlmap-Tamper脚本 文章目录 修改sqlmap-Tamper脚本1 sqlmap官网2 sql注入漏洞注入尝试3 环境:sqli-labs/Less-26a/3.1 尝试宽字节注入: 3.2 sqlmap使用3.3准备修改sqlmap使用 4 sqlmap中-tamper工厂(输入输出)4.1 [参考文章:…

蓝桥杯 题库 简单 每日十题 day11

01 质数 质数 题目描述 给定一个正整数N,请你输出N以内(不包含N)的质数以及质数的个数。 输入描述 输入一行,包含一个正整数N。1≤N≤10^3 输出描述 共两行。 第1行包含若干个素数,每两个素数之间用一个空格隔开&…

rust生命期

一、生命期是什么 生命期,又叫生存期,就是变量的有效期。 实例1 {let r;{let x 5;r &x;}println!("r: {}", r); }编译错误,原因是r所引用的值已经被释放。 上图中的绿色范围’a表示r的生命期,蓝色范围’b表示…

pygame实现跳跃发射子弹打怪效果

import pygame import sys,time,random from pygame.locals import * pygame.init() # 设置按下鼠标的时候一直触发 pygame.key.set_repeat(10, 10) # 加载背景图片 bg pygame.image.load(./img/bg.png) # 加载左方向行走和站立图片 heroLStand pygame.image.load(img/heroLs…

传统遗产与技术相遇,古彝文的数字化与保护

古彝文是中国彝族的传统文字,具有悠久的历史和文化价值。然而,由于古彝文的形状复杂且没有标准化的字符集,对其进行文字识别一直是一项具有挑战性的任务。本文介绍了古彝文合合信息的文字识别技术,旨在提高古彝文的自动识别准确性…

十七,IBL-打印各个Mipmap级别的hdr环境贴图

预滤波环境贴图类似于辐照度图,是预先计算的环境卷积贴图,但这次考虑了粗糙度。因为随着粗糙度的增加,参与环境贴图卷积的采样向量会更分散,导致反射更模糊,所以对于卷积的每个粗糙度级别,我们将按顺序把模…

【单片机】11-步进电机和直流电机

1.直流电机 1.什么是电机 电能转换为动能 2.常见电机 (1)交流电机【大功率】:两相【200W左右】,三相【1000W左右】 (2)直流电机【小功率】:永磁【真正的磁铁】,励磁【电磁铁】 &…

3种Renko图表形态FPmarkets3秒轻松判断价格走势

Renko图表形态在交易中的应用并不逊色于其他技术分析方法。相较于普通的烛台图表,使用Renko图表时,有些经典模式更容易被发现和识别,FPmarkets总结这些模式包括: 首先是头和肩膀形态。这是一种价格反转形态,由两个较小…

华为智能企业远程办公安全解决方案(1)

华为智能企业远程办公安全解决方案(1) 课程地址方案背景需求分析企业远程办公业务概述企业远程办公安全风险分析企业远程办公环境搭建需求分析 方案设计组网架构设备选型方案亮点 课程地址 本方案相关课程资源已在华为O3社区发布,可按照以下…

Java编码技巧:验证码

目录 1.1、EasyCaptcha(优选,支持种类多,样式多,使用简单)1.1.1、作用1.1.2、官方信息1.1.3、使用案例1.1.4、依赖1.1.5、代码1.1.6、效果1.1.7、拓展 1.2、kaptcha1.2.1、作用1.2.2、官方信息1.2.3、使用案例1.2.4、依…