深度学习训练营之CGAN生成手势图像

深度学习训练营之CGAN生成手势

  • 原文链接
  • CGAN简单介绍
  • 环境介绍
  • 前置工作
    • 数据
    • 导入所需的包
    • 加载数据
    • 创建数据集
    • 查看数据集
  • 模型设置
    • 初始化模型的权重
    • 定义生成器
    • 构造判别器
  • 模型训练
    • 定义损失函数
    • 设置超参数
    • 正式开始训练
  • 结果可视化

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第G3周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

CGAN简单介绍

和前面所提及的文章一样,在 CGAN 中,生成器(Generator)判别器(Discriminato都接收条件信息。生成器的目标是生成与条件信息相关的合成样本,而判别器的目标是将生成的样本与真实样本区分开来。二者之间通过不断的学习,不断提高自身的判别能力以及生成能力
当生成器和判别器通过反馈循环不断地进行训练时,生成器会逐渐学会如何生成符合条件信息的样本,而判别器则会逐渐变得更加准确。在二者达到一个平衡点时就能停止训练

在这里插入图片描述

根据网络结构可知,条件信息 y y y作为额外的输入被引入对抗网络中,与生成器 G G G中的噪声 z z z合并作为隐含层表达;
而在判别器 D D D中,条件信息y则与原始数据x合并作为判别函数的输入。

环境介绍

  • 语言环境:Python3.9.12
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

数据

百度网盘:https://pan.baidu.com/s/1thsYtu_1bzd2yel0K97hYA?pwd=7e3q
将下载的数据集放到运行文件所在目录的data文件夹当中,data文件夹当中设置rps文件夹来放置图片
,同时设置文件夹images_GAN3为后续的训练做准备
在这里插入图片描述
运行文件生成手势图像.ipynb的目录下创建一个training_weights的文件夹
在这里插入图片描述

导入所需的包

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt

加载数据

dataroot = "./data/rps"  # 数据路径
batch_size = 128  # 训练过程中的批次大小
image_size = 128   # 图像的尺寸(宽度和高度)
image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))
latent_dim = 100
n_classes = 3     # 条件标签的总数
embedding_dim = 100device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

创建数据集

生成训练所用的数据集,并且创建数据加载器

train_dataset = datasets.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(image_size),        # 调整图像大小transforms.ToTensor(),                # 将图像转换为张量transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量(0.5, 0.5, 0.5)),]))# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,  # 批量大小shuffle=True,           # 是否打乱数据集num_workers=6 # 使用多个线程加载数据的工作进程数)

查看数据集

def show_images(images):fig, ax = plt.subplots(figsize=(20, 20))ax.set_xticks([]); ax.set_yticks([])ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))def show_batch(dl):for images, _ in dl:show_images(images)breakshow_batch(train_loader)

请添加图片描述

模型设置

初始化模型的权重

# 自定义权重初始化函数,用于初始化生成器和判别器的权重
def weights_init(m):# 获取当前层的类名classname = m.__class__.__name__# 如果当前层是卷积层(类名中包含 'Conv' )if classname.find('Conv') != -1:# 使用正态分布随机初始化权重,均值为0,标准差为0.02torch.nn.init.normal_(m.weight, 0.0, 0.02)# 如果当前层是批归一化层(类名中包含 'BatchNorm' )elif classname.find('BatchNorm') != -1:# 使用正态分布随机初始化权重,均值为1,标准差为0.02torch.nn.init.normal_(m.weight, 1.0, 0.02)# 将偏置项初始化为全零torch.nn.init.zeros_(m.bias)

定义生成器

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 定义条件标签的生成器部分,用于将标签映射到嵌入空间中# n_classes:条件标签的总数# embedding_dim:嵌入空间的维度self.label_conditioned_generator = nn.Sequential(nn.Embedding(n_classes, embedding_dim),  # 使用Embedding层将条件标签映射为稠密向量nn.Linear(embedding_dim, 16)             # 使用线性层将稠密向量转换为更高维度)# 定义潜在向量的生成器部分,用于将噪声向量映射到图像空间中# latent_dim:潜在向量的维度self.latent = nn.Sequential(nn.Linear(latent_dim, 4*4*512),  # 使用线性层将潜在向量转换为更高维度nn.LeakyReLU(0.2, inplace=True)  # 使用LeakyReLU激活函数进行非线性映射)# 定义生成器的主要结构,将条件标签和潜在向量合并成生成的图像self.model = nn.Sequential(# 反卷积层1:将合并后的向量映射为64x8x8的特征图nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  # 批标准化nn.ReLU(True),  # ReLU激活函数# 反卷积层2:将64x8x8的特征图映射为64x4x4的特征图nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.ReLU(True),# 反卷积层3:将64x4x4的特征图映射为64x2x2的特征图nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),nn.ReLU(True),# 反卷积层4:将64x2x2的特征图映射为64x1x1的特征图nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),nn.ReLU(True),# 反卷积层5:将64x1x1的特征图映射为3x64x64的RGB图像nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),nn.Tanh()  # 使用Tanh激活函数将生成的图像像素值映射到[-1, 1]范围内)def forward(self, inputs):noise_vector, label = inputs# 通过条件标签生成器将标签映射为嵌入向量label_output = self.label_conditioned_generator(label)# 将嵌入向量的形状变为(batch_size, 1, 4, 4),以便与潜在向量进行合并label_output = label_output.view(-1, 1, 4, 4)# 通过潜在向量生成器将噪声向量映射为潜在向量latent_output = self.latent(noise_vector)# 将潜在向量的形状变为(batch_size, 512, 4, 4),以便与条件标签进行合并latent_output = latent_output.view(-1, 512, 4, 4)# 将条件标签和潜在向量在通道维度上进行合并,得到合并后的特征图concat = torch.cat((latent_output, label_output), dim=1)# 通过生成器的主要结构将合并后的特征图生成为RGB图像image = self.model(concat)return imagegenerator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).Generator((label_conditioned_generator): Sequential((0): Embedding(3, 100)(1): Linear(in_features=100, out_features=16, bias=True))(latent): Sequential((0): Linear(in_features=100, out_features=8192, bias=True)(1): LeakyReLU(negative_slope=0.2, inplace=True))(model): Sequential((0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU(inplace=True)(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(11): ReLU(inplace=True)(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(13): Tanh())
)

构造判别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义一个条件标签的嵌入层,用于将类别标签转换为特征向量self.label_condition_disc = nn.Sequential(nn.Embedding(n_classes, embedding_dim),     # 嵌入层将类别标签编码为固定长度的向量nn.Linear(embedding_dim, 3*128*128)         # 线性层将嵌入的向量转换为与图像尺寸相匹配的特征张量)# 定义主要的鉴别器模型self.model = nn.Sequential(nn.Conv2d(6, 64, 4, 2, 1, bias=False),       # 输入通道为6(包含图像和标签的通道数),输出通道为64,4x4的卷积核,步长为2,padding为1nn.LeakyReLU(0.2, inplace=True),             # LeakyReLU激活函数,带有负斜率,增加模型对输入中的负值的感知能力nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    # 输入通道为64,输出通道为64*2,4x4的卷积核,步长为3,padding为2nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  # 批量归一化层,有利于训练稳定性和收敛速度nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  # 输入通道为64*2,输出通道为64*4,4x4的卷积核,步长为3,padding为2nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  # 输入通道为64*4,输出通道为64*8,4x4的卷积核,步长为3,padding为2nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Flatten(),                               # 将特征图展平为一维向量,用于后续全连接层处理nn.Dropout(0.4),                            # 随机失活层,用于减少过拟合风险nn.Linear(4608, 1),                         # 全连接层,将特征向量映射到输出维度为1的向量nn.Sigmoid()                                # Sigmoid激活函数,用于输出范围限制在0到1之间的概率值)def forward(self, inputs):img, label = inputs# 将类别标签转换为特征向量label_output = self.label_condition_disc(label)# 重塑特征向量为与图像尺寸相匹配的特征张量label_output = label_output.view(-1, 3, 128, 128)# 将图像特征和标签特征拼接在一起作为鉴别器的输入concat = torch.cat((img, label_output), dim=1)# 将拼接后的输入通过鉴别器模型进行前向传播,得到输出结果output = self.model(concat)return outputdiscriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

在这里插入图片描述

模型训练

定义损失函数

使用的是BCELoss()
L o s s = − w ∗ [ p ∗ l o g ( q ) + ( 1 − p ) ∗ l o g ( 1 − q ) ] Loss = -w * [p * log(q) + (1-p) * log(1-q)] Loss=w[plog(q)+(1p)log(1q)],其中 p p p q q q分别为理论标签、实际预测值, w w w为权重。这里的 l o g log log对应数学上的 l n ln ln

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction=‘mean’)
计算目标值和预测值之间的二进制交叉熵损失函数。有四个可选参数:weight、size_average、reduce、reduction
adversarial_loss = nn.BCELoss() def generator_loss(fake_output, label):gen_loss = adversarial_loss(fake_output, label)return gen_lossdef discriminator_loss(output, label):disc_loss = adversarial_loss(output, label)return disc_loss

设置超参数

设置学习率以及优化器

learning_rate = 0.0002G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

正式开始训练

每10轮就进行对生成图像的保存

# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []# 循环进行训练
for epoch in range(1, num_epochs + 1):# 初始化每轮训练中判别器和生成器损失的临时列表D_loss_list, G_loss_list = [], []# 遍历训练数据加载器中的数据for index, (real_images, labels) in enumerate(train_loader):# 清空判别器的梯度缓存D_optimizer.zero_grad()# 将真实图像数据和标签转移到GPU(如果可用)real_images = real_images.to(device)labels      = labels.to(device)# 将标签的形状从一维向量转换为二维张量(用于后续计算)labels = labels.unsqueeze(1).long()# 创建真实目标和虚假目标的张量(用于判别器损失函数)real_target = Variable(torch.ones(real_images.size(0), 1).to(device))fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))# 计算判别器对真实图像的损失D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)# 从噪声向量中生成假图像(生成器的输入)noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)noise_vector = noise_vector.to(device)generated_image = generator((noise_vector, labels))# 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)output = discriminator((generated_image.detach(), labels))D_fake_loss = discriminator_loss(output, fake_target)# 计算判别器总体损失(真实图像损失和假图像损失的平均值)D_total_loss = (D_real_loss + D_fake_loss) / 2D_loss_list.append(D_total_loss)# 反向传播更新判别器的参数D_total_loss.backward()D_optimizer.step()# 清空生成器的梯度缓存G_optimizer.zero_grad()# 计算生成器的损失G_loss = generator_loss(discriminator((generated_image, labels)), real_target)G_loss_list.append(G_loss)# 反向传播更新生成器的参数G_loss.backward()G_optimizer.step()# 打印当前轮次的判别器和生成器的平均损失print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ((epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list))))# 将当前轮次的判别器和生成器的平均损失保存到列表中D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))if epoch%10 == 0:# 将生成的假图像保存为图片文件save_image(generated_image.data[:50], './data/images_GAN3/sample_%d' % epoch + '.png', nrow=5, normalize=True)# 将当前轮次的生成器和判别器的权重保存到文件torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

在这里插入图片描述

结果可视化

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_loss_plot,label="G")
plt.plot(D_loss_plot,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

这是我训练出来的结果
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/23116.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

怎样选择适合的爬虫ip服务商?

在当今数字化的时代,越来越多的企业和个人需要采集和分析大量的数据来进行市场调研、竞品分析、舆情监测等工作。而为了保护其数据和资源,很多网站采取了反爬虫措施,限制了普通用户和爬虫程序的访问。为了应对这种限制,许多人开始…

websocket服务端大报文发送连接自动断开分析

概述 当前springboot版本&#xff1a;2.7.4 使用依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency>现象概述&#xff1a; 客户端和服务端已经有心跳…

Collections工具类(java)

文章目录 7.1 常用方法 参考操作数组的工具类&#xff1a;Arrays&#xff0c;Collections 是一个操作 Set、List 和 Map 等集合的工具类。 7.1 常用方法 Collections 中提供了一系列静态的方法对集合元素进行排序、查询和修改等操作&#xff0c;还提供了对集合对象设置不可变、…

C++基础

目录 在Ubuntu 下编写CC简介C环境设置编写一个简单的C程序 C基础C的新特性C的输入输出方式C之命名空间namespaceC面向对象类和对象构造函数与析构函数this 指针 继承重载函数重载运算符重载 多态数据封装数据抽象接口&#xff08;抽象类&#xff09; 在Ubuntu 下编写C 在Ubunt…

k8s nginx+ingress 配置

1 nginx> ingress 配置&#xff1a; 2 nginx >service 配置 3 nginx pod配置&#xff1a; 4 nginx.conf 配置文件&#xff1a; # web端v1server{listen 30006;add_header Strict-Transport-Security "max-age31536000; includeSubDomains";#add_header Content…

Intellij IDEA运行报Command line is too long的解决办法

想哭&#xff0c;vue前端运行起来&#xff0c;对应的后端也得起服务。 后端出的这个bug&#xff0c;下面的博客写的第二种方法&#xff0c;完整截图是下面这个。 ​​​​​​​​​​​​​​​​​​​​Intellij IDEA运行报Command line is too long的解决办法 - 知乎 (zh…

【CSS】旋转中的视差效果

效果 index.html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"/><meta http-equiv"X-UA-Compatible" content"IEedge"/><meta name"viewport" content"widthdevice-…

docker【安装、存储、镜像、仓库、网络、监控】

docker-0110.0.0.51docker-0210.0.0.52docker-0310.0.0.53 【1】docker安装 docker-01 [rootdocker-01 ~]# vim /etc/yum.conf [main] cachedir/var/cache/yum/$basearch/$releasever keepcache1 debuglevel2 logfile/var/log/yum.log exactarch1 obsoletes1 gpgcheck1 plugin…

HDFS的QJM方案

Quorum Journal Manager仲裁日志管理器 介绍主备切换&#xff0c;脑裂问题解决---ZKFailoverController&#xff08;zkfc&#xff09;主备切换&#xff0c;脑裂问题解决-- Fencing&#xff08;隔离&#xff09;机制主备数据状态同步问题解决 HA集群搭建集群基础环境准备HA集群规…

Unity 3D ScrollRect和ScrollView回弹问题的解决

你是否是这样&#xff1f; Content高度 < 全部Cell加在一起的总高 他就认为你的全部Cell加起来就跟Content一样大&#xff0c;所以才出现了这种完全回弹 我该怎么办&#xff1f; 很简单&#xff0c;改变Content的长度跟所有Cell的和一样大 void RefreshSize(){float allD…

【BASH】回顾与知识点梳理(四)

【BASH】回顾与知识点梳理 四 四. Bash Shell 的操作环境4.1 路径与指令搜寻顺序4.2 bash 的进站与欢迎讯息&#xff1a; /etc/issue, /etc/motd4.3 bash 的环境配置文件login与non-login shell/etc/profile (login shell 才会读)~/.bash_profile (login shell 才会读)source &…

好奇心驱使下试验了 chatGPT 的 js 代码的能力

手边的项目中有个函数&#xff0c;主要实现图片分片裁剪功能。可以优化一下。 也想看看 chatGPT 的代码理解能力&#xff0c;优化能力&#xff0c;实现能力&#xff0c;用例能力。 于是有了这篇文章。 实验结果总结&#xff1a; chatGPT 确实强大&#xff0c;提供的答案可以借…

opencv基础40-礼帽运算(原始图像减去其开运算)cv2.MORPH_TOPHAT

礼帽运算是用原始图像减去其开运算图像的操作。礼帽运算能够获取图像的噪声信息&#xff0c;或者得到比原始图像的边缘更亮的边缘信息。 例如&#xff0c;图 8-22 是一个礼帽运算示例&#xff0c;其中&#xff1a; 左图是原始图像。中间的图是开运算图像。右图是原始图像减开运…

python数据处理程序代码,如何用python处理数据

大家好&#xff0c;给大家分享一下python数据处理程序代码&#xff0c;很多人还不知道这一点。下面详细解释一下。现在让我们来看看&#xff01; 要求&#xff1a;分别以james&#xff0c;julie&#xff0c;mikey&#xff0c;sarah四个学生的名字建立文本文件&#xff0c;分别存…

shell指令的应用

整理思维导图判断家目录下&#xff0c;普通文件的个数和目录文件的个数输入一个文件名&#xff0c;判断是否为shell脚本文件&#xff0c;如果是脚本文件&#xff0c;判断是否有可执行权限&#xff0c;如果有可执行权限&#xff0c;运行文件&#xff0c;如果没有可执行权限&…

测试平台——项目模块模型类设计

这里写目录标题 一、项目应用1、项目包含接口&#xff1a;2、创建子应用3、项目模块设计a、模型类设计b、序列化器类设计c、视图类设计 4、接口模块设计a、模型类设计b、序列化器类设计c、视图类设计 5、环境模块设计6、DRF中的通用过滤6.1、设置过滤器后端 一、项目应用 1、项…

SpringBoot统一功能处理(拦截器)

1.用户登录权限校验 1.1自定义拦截器 写一个类去实现HandlerInterceptor接口表示当前类是一个拦截器,再重写HandlerInterceptor接口中的方法,preHandle为在方法执行前拦截,postHandle为方法执行中拦截,afterCompletion为方法执行中拦截.需要在什么时候拦截就重写什么方法 Co…

百度智能云“千帆大模型平台”最新升级:接入Llama 2等33个模型!

今年3月&#xff0c;百度智能云推出“千帆大模型平台”。作为全球首个一站式的企业级大模型平台&#xff0c;千帆不但提供包括文心一言在内的大模型服务及第三方大模型服务&#xff0c;还提供大模型开发和应用的整套工具链&#xff0c;能够帮助企业解决大模型开发和应用过程中的…

人工智能可解释性分析导论(初稿)

目录 思维导图 1.黑箱所带来的问题 2.从应用面论述为什么要进行可解释性分析 2.1可解释性分析指什么 2.2可解释性分析的必要性 2.3可解释性分析应用实例 2.4 可解释性分析的脑回路&#xff08;以可视化为例如何&#xff09; 3.如何研究可解释性分析 3.1使用好解释的模型 3…

ClickHouse SQL与引擎--基本使用(一)

1.查看所有的数据库 show databases; 2.创建库 CREATE DATABASE zabbix ENGINE Ordinary; ATTACH DATABASE ck_test ENGINE Ordinary;3.创建本地表 CREATE TABLE IF NOT EXISTS test01(id UInt64,name String,time UInt64,age UInt8,flag UInt8 ) ENGINE MergeTree PARTI…