CGAN|生成手势图像|可控制生成

  •     🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

一、前期工作

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import matplotlib.pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')batch_size = 128
train_transform = transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])train_dataset = datasets.ImageFolder(root="H:/G3/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,num_workers=6)
def show_images(dl):for images, _ in dl:fig, ax = plt.subplots(figsize=(10, 10))ax.set_xticks([]); ax.set_yticks([])ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))breakshow_images(train_loader)

 

二、构建模型 

latent_dim = 100
n_classes = 3
embedding_dim = 100def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find('BatchNorm') != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_conditioned_generator = nn.Sequential(nn.Embedding(n_classes, embedding_dim), nn.Linear(embedding_dim, 16)            )self.latent = nn.Sequential(nn.Linear(latent_dim, 4*4*512),  nn.LeakyReLU(0.2, inplace=True)  )self.model = nn.Sequential( nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  nn.ReLU(True),            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.ReLU(True),     nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),nn.ReLU(True),       nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),nn.ReLU(True),       nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),nn.Tanh()  )def forward(self, inputs):noise_vector, label = inputs  label_output = self.label_conditioned_generator(label)     label_output = label_output.view(-1, 1, 4, 4)        latent_output = self.latent(noise_vector)     latent_output = latent_output.view(-1, 512, 4, 4) concat = torch.cat((latent_output, label_output), dim=1)image = self.model(concat)return imagegenerator = Generator().to(device)
generator.apply(weights_init)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)import torch
import torch.nn as nnclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.label_condition_disc = nn.Sequential(nn.Embedding(n_classes, embedding_dim),     nn.Linear(embedding_dim, 3*128*128)         )self.model = nn.Sequential(nn.Conv2d(6, 64, 4, 2, 1, bias=False),      nn.LeakyReLU(0.2, inplace=True),             nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Flatten(),                               nn.Dropout(0.4),                            nn.Linear(4608, 1),                         nn.Sigmoid()                                )def forward(self, inputs):img, label = inputslabel_output = self.label_condition_disc(label)label_output = label_output.view(-1, 3, 128, 128)concat = torch.cat((img, label_output), dim=1)output = self.model(concat)return outputa = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)c = discriminator((a,b))

 三、训练模型及可视化

 这一部分主要定义初始化权重,构建鉴别器和生成器。

# 定义损失函数
adversarial_loss = nn.BCELoss() def generator_loss(fake_output, label):gen_loss = adversarial_loss(fake_output, label)return gen_lossdef discriminator_loss(output, label):disc_loss = adversarial_loss(output, label)return disc_loss
learning_rate = 0.0002G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []# 循环进行训练
for epoch in range(1, num_epochs + 1):# 初始化每轮训练中判别器和生成器损失的临时列表D_loss_list, G_loss_list = [], []# 遍历训练数据加载器中的数据for index, (real_images, labels) in enumerate(train_loader):# 清空判别器的梯度缓存D_optimizer.zero_grad()# 将真实图像数据和标签转移到GPU(如果可用)real_images = real_images.to(device)labels      = labels.to(device)# 将标签的形状从一维向量转换为二维张量(用于后续计算)labels = labels.unsqueeze(1).long()# 创建真实目标和虚假目标的张量(用于判别器损失函数)real_target = Variable(torch.ones(real_images.size(0), 1).to(device))fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))# 计算判别器对真实图像的损失D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)# 从噪声向量中生成假图像(生成器的输入)noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)noise_vector = noise_vector.to(device)generated_image = generator((noise_vector, labels))# 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)output = discriminator((generated_image.detach(), labels))D_fake_loss = discriminator_loss(output, fake_target)# 计算判别器总体损失(真实图像损失和假图像损失的平均值)D_total_loss = (D_real_loss + D_fake_loss) / 2D_loss_list.append(D_total_loss)# 反向传播更新判别器的参数D_total_loss.backward()D_optimizer.step()# 清空生成器的梯度缓存G_optimizer.zero_grad()# 计算生成器的损失G_loss = generator_loss(discriminator((generated_image, labels)), real_target)G_loss_list.append(G_loss)# 反向传播更新生成器的参数G_loss.backward()G_optimizer.step()# 打印当前轮次的判别器和生成器的平均损失print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ((epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list))))# 将当前轮次的判别器和生成器的平均损失保存到列表中D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))if epoch%10 == 0:# 将生成的假图像保存为图片文件save_image(generated_image.data[:50], './sample_%d' % epoch + '.png', nrow=5, normalize=True)# 将当前轮次的生成器和判别器的权重保存到文件torch.save(generator.state_dict(), './generator_epoch_%d.pth' % (epoch))torch.save(discriminator.state_dict(), './discriminator_epoch_%d.pth' % (epoch))

 from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot as plt, gridspec
import numpy as np# Assuming 'generator' and 'device' are defined earlier in your codegenerator.load_state_dict(torch.load('./generator_epoch_100.pth'), strict=False)
generator.eval()interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100plt.figure(figsize=(8, 3))pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

 代码中的操作将预测结果的值加1(这样所有的值都变为非负数),然后乘以127.5,最终得到的值就在0到255之间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/839914.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytorch-13_2 模型结构选择策略:层数、激活函数、神经元个数

一、拟合度概念 在所有的模型优化问题中,最基础的也是最核心的问题,就是关于模型拟合程度的探讨与优化。根据此前的讨论,模型如果能很好的捕捉总体规律,就能够有较好的未知数据的预测效果。但限制模型捕捉总体规律的原因主要有两点…

C++:vector基础讲解

hello,各位小伙伴,本篇文章跟大家一起学习《C:vector基础讲解》,感谢大家对我上一篇的支持,如有什么问题,还请多多指教 ! 如果本篇文章对你有帮助,还请各位点点赞!&#…

day15|各种遍历的应用

相关题目&#xff1a; 层次遍历会一打十 反转二叉树 对称二叉树 层次遍历会一打十 自底向上的层序遍历 实现思路&#xff1a;层次遍历二叉树&#xff0c;将遍历后的结果revers即可 public List<List<Integer>> levelOrderBottom(TreeNode root) {List<List&l…

框架学习之SpringMVC学习笔记(一)

一、SpringMVC简介 1-介绍 Spring Web MVC是基于Servlet API构建的原始Web框架&#xff0c;从一开始就包含在Spring Framework中。正式名称“Spring Web MVC”来自其源模块的名称&#xff08; spring-webmvc &#xff09;&#xff0c;但它通常被称为“Spring MVC”。 在控制层…

一文深度剖析 ColBERT

近年来&#xff0c;向量搜索领域经历了爆炸性增长&#xff0c;尤其是在大型语言模型&#xff08;LLMs&#xff09;问世后。学术界开始重点关注如何通过扩展训练数据、采用先进的训练方法和新的架构等方法来增强 embedding 向量模型。 在之前的文章中&#xff0c;我们已经深入探…

记录踩坑事件 分页查询order by出现重复数据bug

MySQL排序小坑_mysql order by name相同导致排序混乱-CSDN博客 1、问题描述 列表页分页查询出现重复数据。 2、问题排查 排查最终执行sql日志。 select * from tableA where (start_time>2024-04-17 00:00:00) AND (start_time<2024-05-18 00:00:00) ORDER BY sta…

AIGC基础教学:AI+建筑设计,一场划时代变革的序幕已经拉开

2015年9月&#xff0c;美的集团本着把艺术融入民间的理念&#xff0c;邀请了安藤忠雄设计正在筹建中的美术馆。 在历经长达近120天的设计工作之后&#xff0c;美术馆于同年12月动工。这座具有岭南建筑文化意境的美术馆&#xff0c;后来荣获2020年美国建筑大师奖(Architecture …

【ArcGIS微课1000例】0112:沿线(面)按距离或百分比生成点

文章目录 一、沿线生成点工具介绍二、线状案例三、面状案例一、沿线生成点工具介绍 位置:工具箱→数据管理工具→采样→沿线生成点 摘要:沿线或面以固定间隔或百分比创建点要素。 用法:输入要素的属性将保留在输出要素类中。向输出要素类添加新字段 ORIG_FID,并设置为输…

Java.lang.InterruptedException被中止异常解决方案

大家好&#xff01;我是咕噜铁蛋&#xff01;在Java编程的世界里&#xff0c;java.lang.InterruptedException是一个常见的异常&#xff0c;尤其是在处理多线程和并发任务时。这个异常通常表示一个线程在等待、休眠或其他占用时间不长的操作时被中断。作为一个资深的Java开发者…

Navicat 连接 OceanBase 快速入门 | 社区版

Navicat Premium&#xff08;16.1.9或更高版本&#xff09;正式支持 OceanBase全线数据库产品。OceanBase为现代数据架构打造的开源分布式数据库。兼容 MySQL 的单机分布式一体化国产开源数据库&#xff0c;具有原生分布式架构&#xff0c;支持金融级高可用、透明水平扩展、分布…

CCF CAT- 全国算法精英大赛(2024第二场)往届真题练习 2 | 珂学家

前言 这是第二场CCF的练习赛&#xff0c;找找手感&#xff0c;顺便熟悉下赛氪OJ平台。 当前就做了5题&#xff0c;感觉还可以&#xff0c;部分题目质量蛮高的&#xff0c;但是易错。 第1题dp入门题&#xff0c; 第5属于诈骗题&#xff0c;第2和第3挺有难度的&#xff0c;第四…

【杂七杂八】Huawei Gt runner手表系统降级

文章目录 Step1&#xff1a;下载安装修改版华为运动与健康Step2&#xff1a;在APP里进行配置Step3&#xff1a;更新固件(时间会很长) 目前在使用用鸿蒙4 111版本的手表系统&#xff0c;但是感觉睡眠检测和运动心率检测一言难尽&#xff0c;于是想到是否能回退到以前的版本&…

设计模式14——组合模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用&#xff0c;主要是下面的UML图可以起到大作用&#xff0c;在你学习过一遍以后可能会遗忘&#xff0c;忘记了不要紧&#xff0c;只要看一眼UML图就能想起来了。同时也请大家多多指教。 组合模式&#xff08;Composit…

LeetCode199二叉树的右视图

题目描述 给定一个二叉树的 根节点 root&#xff0c;想象自己站在它的右侧&#xff0c;按照从顶部到底部的顺序&#xff0c;返回从右侧所能看到的节点值。 解析 这一题的关键其实就是找到怎么去得到当前是哪一层级&#xff0c;可以利用队列对二叉树进行层次遍历&#xff0c;但…

ICRA 2024: NVIDIA 联合多伦多大学、加州大学伯克利分校、苏黎世联邦理工学院等研究人员开发了精细操作的手术机器人

英伟达&#xff08;NVIDIA&#xff09;正与学术研究人员合作&#xff0c;研究手术机器人。 NVIDIA 联合多伦多大学、加州大学伯克利分校、苏黎世联邦理工学院和佐治亚理工学院的研究人员开发了 ORBIT-Surgical&#xff0c;一个训练机器人的模拟框架&#xff0c;可以提高手术团…

vue3的api风格

Vue的组件有两种不同的风格&#xff1a;组合式API 和 选项式API 选项式api 选项式API&#xff0c;可以用包含多个选项的对象来描述组件的逻辑&#xff0c;如&#xff1a;data&#xff0c;methods&#xff0c;mounted等。 组合式api setup&#xff1a;是一个标识&#xff0c;告…

图像上下文学习|多模态基础模型中的多镜头情境学习

【原文】众所周知&#xff0c;大型语言模型在小样本上下文学习&#xff08;ICL&#xff09;方面非常有效。多模态基础模型的最新进展实现了前所未有的长上下文窗口&#xff0c;为探索其执行 ICL 的能力提供了机会&#xff0c;并提供了更多演示示例。在这项工作中&#xff0c;我…

Docker简单使用

1.简单认识 软件的打包技术&#xff0c;就是将打乱的多个文件打包为一个整体&#xff0c;比如想使用nginx&#xff0c;需要先有一台linux的虚拟机&#xff0c;然后在虚拟机上安装nginx.比如虚拟机大小1G&#xff0c;nginx100M。当有了docker后我们可以下载nginx 的镜像文件&am…

【openlayers系统学习】1.6下载要素,将要素数据序列化为 GeoJSON并下载

六、下载要素 下载要素 上传数据并编辑后&#xff0c;我们想让用户下载结果。为此&#xff0c;我们将要素数据序列化为 GeoJSON&#xff0c;并创建一个带有 download​ 属性的 <a>​ 元素&#xff0c;该属性会触发浏览器的文件保存对话框。同时&#xff0c;我们将在地图…

Linux--07---查看CPU、内存、磁盘

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 常用命令1.查看CPU使用率1.1 top 命令第一行是任务队列信息&#xff1a; top第二行为进程的信息 Tasks第三行为CPU的信息Mem:Swap 1.2 vmstat命令参数详解每个参数的…