Pytorch从零开始实战19

Pytorch从零开始实战——生成手势图像

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——生成手势图像
    • 环境准备
    • 模型选择
    • 模型训练
    • 模型分析
    • 总结

环境准备

本文基于Jupyter notebook,使用Python3.8,Pytorch2.0.1+cu118,torchvision0.15.2,需读者自行配置好环境且有一些深度学习理论基础。本次实验的目的是了解并使用CGAN模型,完成手势图像生成。
第一步,导入常用包

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt

查看设备对象

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 128
device # device(type='cuda')

使用transform调整数据集图像,本次数据集使用手势图像,用于生成对抗网络的训练。其中数据集源于K同学

train_transform = transforms.Compose([transforms.Resize(128),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])train_dataset = datasets.ImageFolder(root='./data/Gan3/rps/', transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,num_workers=6)

查看数据集图像

def show_images(images):fig, ax = plt.subplots(figsize=(20, 20))ax.set_xticks([]); ax.set_yticks([])ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))def show_batch(dl):for images, _ in dl:show_images(images)breakshow_batch(train_loader)

在这里插入图片描述
设置超参数

image_shape = (3, 128, 128) # 生成器和判别器操作的图像的形状
image_dim = int(np.prod(image_shape)) # 图像的总维度
latent_dim = 100 # 生成器的输入随机噪声向量的维度n_classes = 3 # 条件生成对抗网络中的类别数量
embedding_dim = 100 # 类别嵌入向量的维度

模型选择

条件生成对抗网络(CGAN)是一种生成对抗网络(GAN)的变体,它引入了条件信息以指导生成器生成特定类型的输出。在标准的生成对抗网络中,生成器的输入是一个随机噪声向量,而在条件生成对抗网络中,生成器的输入不仅包括随机噪声向量,还包括一个额外的条件向量,用于指定所需输出的特征。
例如我们需要生成器G生成一张没有阴影的图像,此时判别器D就需要判断生成器所生成的图像是否是一张没有阴影的图像。条件生成对抗网络的本质是将额外添加的信息融入到生成器和判别器中,其中添加的信息可以是图像的类别、人脸表情和其他辅助信息等,旨在把无监督学习的GAN转化为有监督学习的CGAN,便于网络能够在我们的掌控下更好地进行训练。

使用weights_init(m)对神经网络的权重进行初始化

# 自定义权重初始化函数,用于初始化生成器和判别器的权重
def weights_init(m):# 获取当前层的类名classname = m.__class__.__name__# 如果当前层是卷积层(类名中包含 'Conv' )if classname.find('Conv') != -1:# 使用正态分布随机初始化权重,均值为0,标准差为0.02torch.nn.init.normal_(m.weight, 0.0, 0.02)# 如果当前层是批归一化层(类名中包含 'BatchNorm' )elif classname.find('BatchNorm') != -1:# 使用正态分布随机初始化权重,均值为1,标准差为0.02torch.nn.init.normal_(m.weight, 1.0, 0.02)# 将偏置项初始化为全零torch.nn.init.zeros_(m.bias)

构建生成器,下面代码实现了一个生成器模型,它接受一个随机噪声向量和一个条件标签作为输入,并生成一个与条件标签相关的合成图像。首先,条件标签被嵌入到一个稠密向量中,将离散的标签映射到连续的嵌入空间,以便与噪声向量进行合并。接下来,从随机噪声向量中生成潜在向量。将噪声向量映射到一个更高维度的表示空间中,最后,将生成的条件标签嵌入向量和潜在向量在通道维度上合并。然后,将合并后的特征图通过一系列反卷积层进行处理,逐渐将其转换为与所需输出图像相同尺寸的特征图。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 定义条件标签的生成器部分,用于将标签映射到嵌入空间中# n_classes:条件标签的总数# embedding_dim:嵌入空间的维度self.label_conditioned_generator = nn.Sequential(nn.Embedding(n_classes, embedding_dim),  # 使用Embedding层将条件标签映射为稠密向量nn.Linear(embedding_dim, 16)             # 使用线性层将稠密向量转换为更高维度)# 定义潜在向量的生成器部分,用于将噪声向量映射到图像空间中# latent_dim:潜在向量的维度self.latent = nn.Sequential(nn.Linear(latent_dim, 4*4*512),  # 使用线性层将潜在向量转换为更高维度nn.LeakyReLU(0.2, inplace=True)  # 使用LeakyReLU激活函数进行非线性映射)# 定义生成器的主要结构,将条件标签和潜在向量合并成生成的图像self.model = nn.Sequential(# 反卷积层1:将合并后的向量映射为64x8x8的特征图nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  # 批标准化nn.ReLU(True),  # ReLU激活函数# 反卷积层2:将64x8x8的特征图映射为64x4x4的特征图nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.ReLU(True),# 反卷积层3:将64x4x4的特征图映射为64x2x2的特征图nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),nn.ReLU(True),# 反卷积层4:将64x2x2的特征图映射为64x1x1的特征图nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),nn.ReLU(True),# 反卷积层5:将64x1x1的特征图映射为3x64x64的RGB图像nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),nn.Tanh()  # 使用Tanh激活函数将生成的图像像素值映射到[-1, 1]范围内)def forward(self, inputs):noise_vector, label = inputs# 通过条件标签生成器将标签映射为嵌入向量label_output = self.label_conditioned_generator(label)# 将嵌入向量的形状变为(batch_size, 1, 4, 4),以便与潜在向量进行合并label_output = label_output.view(-1, 1, 4, 4)# 通过潜在向量生成器将噪声向量映射为潜在向量latent_output = self.latent(noise_vector)# 将潜在向量的形状变为(batch_size, 512, 4, 4),以便与条件标签进行合并latent_output = latent_output.view(-1, 512, 4, 4)# 将条件标签和潜在向量在通道维度上进行合并,得到合并后的特征图concat = torch.cat((latent_output, label_output), dim=1)# 通过生成器的主要结构将合并后的特征图生成为RGB图像image = self.model(concat)return image

将模型导入GPU并初始化权重,查看模型

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

在这里插入图片描述
下面实现判别器模型,首先,条件标签(被嵌入到一个特征向量中。这一步类似于生成器中的过程,目的是将离散的标签映射到一个连续的嵌入空间,以便与图像特征进行合并。接下来,输入图像和嵌入的标签特征被拼接在一起作为鉴别器的输入。然后,这个输入通过一系列卷积层和批量归一化层进行处理,逐渐将其转换为一个用于区分真实和合成图像的特征表示。在每个卷积层之后,都跟随着 激活函数以及批量归一化层,最后,通过全连接层将特征向量映射到一个单一的输出。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义一个条件标签的嵌入层,用于将类别标签转换为特征向量self.label_condition_disc = nn.Sequential(nn.Embedding(n_classes, embedding_dim), # 嵌入层将类别标签编码为固定长度的向量nn.Linear(embedding_dim, 3*128*128)     # 线性层将嵌入的向量转换为与图像尺寸相匹配的特征张量)# 定义主要的鉴别器模型self.model = nn.Sequential(nn.Conv2d(6, 64, 4, 2, 1, bias=False),    # 输入通道为6(包含图像和标签的通道数),输出通道为64,4x4的卷积核,步长为2,padding为1nn.LeakyReLU(0.2, inplace=True),          # LeakyReLU激活函数,带有负斜率,增加模型对输入中的负值的感知能力nn.Conv2d(64, 64*2, 4, 3, 2, bias=False), # 输入通道为64,输出通道为64*2,4x4的卷积核,步长为3,padding为2nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8), # 批量归一化层,有利于训练稳定性和收敛速度nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  # 输入通道为64*2,输出通道为64*4,4x4的卷积核,步长为3,padding为2nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  # 输入通道为64*4,输出通道为64*8,4x4的卷积核,步长为3,padding为2nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),nn.LeakyReLU(0.2, inplace=True),nn.Flatten(),       # 将特征图展平为一维向量,用于后续全连接层处理nn.Dropout(0.4),    # 随机失活层,用于减少过拟合风险nn.Linear(4608, 1), # 全连接层,将特征向量映射到输出维度为1的向量nn.Sigmoid()        # Sigmoid激活函数,用于输出范围限制在0到1之间的概率值)def forward(self, inputs):img, label = inputs# 将类别标签转换为特征向量label_output = self.label_condition_disc(label)# 重塑特征向量为与图像尺寸相匹配的特征张量label_output = label_output.view(-1, 3, 128, 128)# 将图像特征和标签特征拼接在一起作为鉴别器的输入concat = torch.cat((img, label_output), dim=1)# 将拼接后的输入通过鉴别器模型进行前向传播,得到输出结果output = self.model(concat)return output

将模型导入GPU并初始化权重,查看模型

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

在这里插入图片描述

模型训练

设置损失函数、优化算法和学习率

adversarial_loss = nn.BCELoss() def generator_loss(fake_output, label):gen_loss = adversarial_loss(fake_output, label)return gen_lossdef discriminator_loss(output, label):disc_loss = adversarial_loss(output, label)return disc_losslearning_rate = 0.0002G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

开始训练

# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []# 循环进行训练
for epoch in range(1, num_epochs + 1):# 初始化每轮训练中判别器和生成器损失的临时列表D_loss_list, G_loss_list = [], []# 遍历训练数据加载器中的数据for index, (real_images, labels) in enumerate(train_loader):# 清空判别器的梯度缓存D_optimizer.zero_grad()# 将真实图像数据和标签转移到GPU(如果可用)real_images = real_images.to(device)labels      = labels.to(device)# 将标签的形状从一维向量转换为二维张量(用于后续计算)labels = labels.unsqueeze(1).long()# 创建真实目标和虚假目标的张量(用于判别器损失函数)real_target = Variable(torch.ones(real_images.size(0), 1).to(device))fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))# 计算判别器对真实图像的损失D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)# 从噪声向量中生成假图像(生成器的输入)noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)noise_vector = noise_vector.to(device)generated_image = generator((noise_vector, labels))# 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)output = discriminator((generated_image.detach(), labels))D_fake_loss = discriminator_loss(output, fake_target)# 计算判别器总体损失(真实图像损失和假图像损失的平均值)D_total_loss = (D_real_loss + D_fake_loss) / 2D_loss_list.append(D_total_loss)# 反向传播更新判别器的参数D_total_loss.backward()D_optimizer.step()# 清空生成器的梯度缓存G_optimizer.zero_grad()# 计算生成器的损失G_loss = generator_loss(discriminator((generated_image, labels)), real_target)G_loss_list.append(G_loss)# 反向传播更新生成器的参数G_loss.backward()G_optimizer.step()# 打印当前轮次的判别器和生成器的平均损失print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ((epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list))))# 将当前轮次的判别器和生成器的平均损失保存到列表中D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))if epoch%10 == 0:# 将生成的假图像保存为图片文件save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)# 将当前轮次的生成器和判别器的权重保存到文件torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

在这里插入图片描述

模型分析

加载模型

generator.load_state_dict(torch.load('./training_weights/generator_epoch_100.pth'), strict=False)
generator.eval()

在这里插入图片描述
下段代码实现了在两个潜在空间点之间进行插值,并使用生成器生成相应的图像。
使用generate_latent_points函数生成两个潜在空间的点。这些点是从标准正态分布中随机生成的,作为生成器的输入。
使用interpolate_points函数对两个潜在空间点进行线性插值,生成插值比率在0到1之间。这些比率用于在两个点之间生成一系列插值点,数量为n_steps。
对于三个类别的循环,分别进行插值和生成图片。对于每个类别,创建包含相同类别标签的张量,然后将其与插值点一起传递给生成器。生成器将生成对应类别的插值图像。生成的图像通过permute函数调整维度顺序以匹配预期的通道顺序。最后,将生成的插值图像连接起来形成一个输出。

# 导入所需的库
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from numpy import linspace
from matplotlib import pyplot
from matplotlib import gridspec# 生成潜在空间的点,作为生成器的输入
def generate_latent_points(latent_dim, n_samples, n_classes=3):# 从标准正态分布中生成潜在空间的点x_input = randn(latent_dim * n_samples)# 将生成的点整形成用于神经网络的输入的批量z_input = x_input.reshape(n_samples, latent_dim)return z_input# 在两个潜在空间点之间进行均匀插值
def interpolate_points(p1, p2, n_steps=10):# 在两个点之间进行插值,生成插值比率ratios = linspace(0, 1, num=n_steps)# 线性插值向量vectors = list()for ratio in ratios:v = (1.0 - ratio) * p1 + ratio * p2vectors.append(v)return asarray(vectors)# 生成两个潜在空间的点
pts = generate_latent_points(100, 2)
# 在两个潜在空间点之间进行插值
interpolated = interpolate_points(pts[0], pts[1])# 将数据转换为torch张量并将其移至GPU(假设device已正确声明为GPU)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)output = None
# 对于三个类别的循环,分别进行插值和生成图片
for label in range(3):# 创建包含相同类别标签的张量labels = torch.ones(10) * labellabels = labels.to(device)labels = labels.unsqueeze(1).long()print(labels.size())# 使用生成器生成插值结果predictions = generator((interpolated, labels))predictions = predictions.permute(0,2,3,1)pred = predictions.detach().cpu()if output is None:output = predelse:output = np.concatenate((output,pred))

查看图像

nrow = 3
ncol = 10fig = plt.figure(figsize=(15,4))
gs = gridspec.GridSpec(nrow, ncol) k = 0
for i in range(nrow):for j in range(ncol):pred = (output[k, :, :, :] + 1 ) * 127.5pred = np.array(pred)  ax= plt.subplot(gs[i,j])ax.imshow(pred.astype(np.uint8))ax.set_xticklabels([])ax.set_yticklabels([])ax.axis('off')k += 1   plt.show()

在这里插入图片描述

总结

与标准GAN相比,CGAN额外接收一个条件向量,用于指导生成器生成特定类型的输出。从中了解如何生成样本时使用额外的条件信息,未来掌握利用条件信息生成特定类别的图像或生成带有特定属性的图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/709703.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法|344.反转字符串 541. 反转字符串II 卡码网:54.替换数字 151.翻转字符串里的单词 卡码网:55.右旋转字符串

344.反转字符串 考察reverse, 也可以用其他方法 /** * param {character[]} s* return {void} Do not return anything, modify s in-place instead.*/ var reverseString function (s) {return s.reverse(); };541. 反转字符串 思路: 一般是i&#…

vue-router4 (七) 滚动行为(scrollBehavior )

应用场景: 从A组件进入B组件,再返回A组件后,想让A组件的页面回到进入B组件前的位置,或者自动刷新回到A组件顶部,就需配置路由的滚动行为(scrollBehavior )。 ①返回A组件时,让A组件…

webrtc

stun服务 阿里云服务器安全组添加端口开放 webrtc-streamer视屏流服务器搭建 - 简书

Prometheus-监控远程linux的主机

一、本地访问 1、访问 http://8.137.122.212:9090/2、查看监控的主机 默认只监控了本机一台主机 这里的IP地址原本是‘localhost’,为了方便我将‘localhost’换成了主机的IP地址 现在看只监控了本机一台主机 3、查看监控数据 通过http://8.137.122.212:9090/m…

RT-Thread studio上创建一个STM32F103的CAN通讯功能

前言 (1)如果有嵌入式企业需要招聘湖南区域日常实习生,任何区域的暑假Linux驱动实习岗位,可C站直接私聊,或者邮件:zhangyixu02gmail.com,此消息至2025年1月1日前均有效 (2&#xff0…

学Python如此简单--用Python实现一个超简单的学生信息管理系统

简介 系统名称:番茄系统 实现功能:增删查改 运用技术:python基础 代码 import time student_all [] print(欢迎进入番茄系统.center(30)) print(**36,end) while True:pee 请选择功能:1、添加学生2、删除学生3、修改学生4、…

1.1 编程环境的安装

汇编语言 汇编语言环境部署 第二个运行程序直接双击安装一直下一步即可MASM文件复制到D盘路径下找到dosbox安装路径:C:\Program Files (x86)\DOSBox-0.74找到该文件双击打开它,修改一下窗口大小 把这两行改成如下所示 运行dos,黑框中输入mou…

C#,数值计算,求解微分方程的吉尔(Gear)四阶方法与源代码

1 微分方程 微分方程,是指含有未知函数及其导数的关系式。解微分方程就是找出未知函数。 微分方程是伴随着微积分学一起发展起来的。微积分学的奠基人Newton和Leibniz的著作中都处理过与微分方程有关的问题。微分方程的应用十分广泛,可以解决许多与导数…

RTCA DO-178C 机载系统和设备认证中的软件注意事项-软件配置管理流程(七)

7.0 软件配置管理流程 SOFTWARE CONFIGURATION MANAGEMENT PROCESS 本节讨论软件配置管理 (SCM) 过程的目标和活动。 SCM 流程按照软件计划流程(参见 4)和软件配置管理计划(参见 11.4)的定义进行应用。 SCM 过程的输出记录在软件…

【探索AI】Sora - 探索AI视频模型的无限可能

Sora - 探索AI视频模型的无限可能 随着人工智能技术的飞速发展,AI视频模型已成为科技领域的新热点。而在这个浪潮中,OpenAI推出的首个AI视频模型Sora,以其卓越的性能和前瞻性的技术,引领着AI视频领域的创新发展。让我们将一起探讨…

【Web安全靶场】sqli-labs-master 21-37 Advanced-Injection

sqli-labs-master 21-37 Advanced-Injection 第一关到第二十关请见专栏 文章目录 sqli-labs-master 21-37 Advanced-Injection第二十一关-Cookie注入第二十二关-Cookie注入第二十三关-注释符过滤的报错注入第二十四关-二次注入第二十五关-过滤OR、AND双写绕过第二十五a关-过滤…

老卫带你学---leetcode刷题(190. 颠倒二进制位)

190. 颠倒二进制位 问题 颠倒给定的 32 位无符号整数的二进制位。 提示: 请注意,在某些语言(如 Java)中,没有无符号整数类型。在这种情况下,输入和输出都将被指定为有符号整数类型,并且不应…

《Flask入门教程》学习笔记

《Flask入门教程》官网:https://tutorial.helloflask.com/ 目录 第一章:准备工作第二章:Hello, Flask!第三章:模板第四章:静态文件第五章:数据库第六章:模板优化第七章:表单第八章&a…

【嵌入式——QT】日期与定时器

日期 QTime:时间数据类型,仅表示时间,如 16:16:16;QDate:日期数据类型,仅表示日期,如2024-1-22;QDateTime:日期时间数据类型,表示日期和时间,如2…

多个版本的Python如何不冲突?

转载文章,防止忘记或删除 转载于:电脑中存在多个版本的Python如何不冲突? - 知乎 (zhihu.com) 如何安装多版本的Python并与之共存? 如果你的工作涉及到Python多版本之间开发或测试,那么请收藏本文, 如果你…

【python】Python Turtle绘制流星雨动画效果【附源码】

在这篇技术博客中,我们将学习如何使用 Python 的 Turtle 模块绘制一个流星雨的动画效果。通过简单的代码实现,我们可以在画布上展现出流星闪耀的场景,为视觉带来一丝神秘与美感。 一、效果图: 二、准备工作 (1)、导入…

每日一题——LeetCode1544.整理字符串

方法一 字符串转数组删除元素 将字符串转为数组&#xff0c;遍历数组&#xff0c;如果碰到同一字母大写小写连续出现就原地删除这两个元素&#xff0c;最后把数组转回字符串并返回 var makeGood function(s) {let arrs.split()for(let i0;i<s.length-1;i){if(arr[i]!arr[…

【程序员的金三银四求职宝典】《春风拂面,代码在手:程序员的金三银四求职指南》

《春风拂面&#xff0c;代码在手&#xff1a;程序员的金三银四求职指南》 随着春风的轻拂&#xff0c;大地复苏&#xff0c;万物更新。在这个生机勃勃的季节&#xff0c;不仅自然界在迎接新生&#xff0c;对于广大的程序员朋友们而言&#xff0c;这也是一个全新的开始——金三…

关于HTML标签应用教程

简介 HTML&#xff08;HyperText Markup Language&#xff09;是用于创建网页结构的标记语言。在本教程中&#xff0c;我们将介绍一些常用的HTML标签&#xff0c;以及它们的用法和示例。 1. HTML基础结构 <!DOCTYPE html> <html> <head><title>页面…

windows U盘不能识别

windows U盘不能识别 1、问题描述2、问题分析解决3、把U盘插到windows电脑上试试能不能识别 1、问题描述 windwos u盘不能识别 u盘被拿到mac电脑上做了启动盘之后&#xff0c;就不能被windows识别了。题主很奇怪里面被mac电脑的同学放了什么&#xff0c;因此想到把优盘挂载到L…