变分自动编码器 (VAE)02/2 PyTorch 教程

一、说明

        在自动编码器中,来自输入数据的信息被映射到固定的潜在表示中。当我们旨在训练模型以生成确定性预测时,这特别有用。相比之下,变分自动编码器(VAE)将输入数据转换为变分表示向量(顾名思义),其中该向量的元素表示有关输入数据分布的不同属性。VAE的这种概率特性使其成为一个生成模型。VAE中的潜在表示由最能定义输入数据的概率分布(μ,σ)组成。

二、前提知识

        要了解有关VAE直觉的更多信息,我建议您阅读了解变分自动编码器(VAE)和什么是变分自动编码器?

三、VAE的PyTorch实现

        在本文中,我们只关注 PyTorch 中的简单 VAE,并在 MNIST 数据集上训练后可视化其潜在表示。让我们从导入一些库开始:

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid

并下载 MNIST 数据集并制作数据加载器:

# create a transofrm to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, transform=transform, download=True)
test_dataset  = MNIST(path, transform=transform, download=True)# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

让我们可视化一些示例训练数据:

# get 25 sample training images for visualization
dataiter = iter(train_loader)
image = dataiter.next()num_samples = 25
sample_images = [image[0][i,0] for i in range(num_samples)] fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)for ax, im in zip(grid, sample_images):ax.imshow(im, cmap='gray')ax.axis('off')plt.show()
25 张样本训练图像

        现在,我们创建一个简单的VAE,它具有完全连接的编码器和解码器。输入维度为 784,这是 MNIST 图像 (28×28) 的扁平维度。在编码器中,均值 (μ) 和方差 (σ²) 向量是我们的变分表示向量 (size=200)。请注意,我们将潜在方差与 epsilon (ε) 参数相乘,以便在解码之前重新参数化。这使我们能够执行反向传播并解决节点随机性。在此处阅读有关重新参数化的更多信息。

        此外,我们的最终编码器维度为 2,即μ和σ向量。这些连续向量定义了我们的潜在空间分布,使我们能够在VAE中对图像进行采样。

class VAE(nn.Module):def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):super(VAE, self).__init__()# encoderself.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Linear(hidden_dim, latent_dim),nn.LeakyReLU(0.2))# latent mean and variance self.mean_layer = nn.Linear(latent_dim, 2)self.logvar_layer = nn.Linear(latent_dim, 2)# decoderself.decoder = nn.Sequential(nn.Linear(2, latent_dim),nn.LeakyReLU(0.2),nn.Linear(latent_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Linear(hidden_dim, input_dim),nn.Sigmoid())def encode(self, x):x = self.encoder(x)mean, logvar = self.mean_layer(x), self.logvar_layer(x)return mean, logvardef reparameterization(self, mean, var):epsilon = torch.randn_like(var).to(device)      z = mean + var*epsilonreturn zdef decode(self, x):return self.decoder(x)def forward(self, x):mean, logvar = self.encode(x)z = self.reparameterization(mean, logvar)x_hat = self.decode(z)return x_hat, mean, log_vardef forward(self, x):mean, log_var = self.encode(x)z = self.reparameterization(mean, torch.exp(0.5 * log_var)) x_hat = self.decode(z)  return x_hat, mean, log_var

        现在,我们可以定义我们的模型和优化器:

model = VAE().to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

        VAE中的损失函数由再现损失和库尔巴克-莱布勒(KL)散度组成。KL 散度是用于衡量两个概率分布之间距离的指标。KL 散度是生成建模中的一个重要概念,但在本教程中,我们不会更详细地介绍。:了解KL背离的直观指南。

def loss_function(x, x_hat, mean, log_var):reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())return reproduction_loss + KLD

        最后,我们可以训练我们的模型:

def train(model, optimizer, epochs, device):model.train()for epoch in range(epochs):overall_loss = 0for batch_idx, (x, _) in enumerate(train_loader):x = x.view(batch_size, x_dim).to(device)optimizer.zero_grad()x_hat, mean, log_var = model(x)loss = loss_function(x, x_hat, mean, log_var)overall_loss += loss.item()loss.backward()optimizer.step()print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss/(batch_idx*batch_size))return overall_losstrain(model, optimizer, epochs=50, device=device)

        我们现在知道,从潜在空间生成图像所需要的只是两个浮点值(平均值和方差)。让我们从潜在空间生成一些图像:

def generate_digit(mean, var):z_sample = torch.tensor([[mean, var]], dtype=torch.float).to(device)x_decoded = model.decode(z_sample)digit = x_decoded.detach().cpu().reshape(28, 28) # reshape vector to 2d arrayplt.imshow(digit, cmap='gray')plt.axis('off')plt.show()generate_digit(0.0, 1.0), generate_digit(1.0, 0.0)

                有趣!更令人印象深刻的潜在空间视图:

def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):# display a n*n 2D manifold of digitsfigure = np.zeros((digit_size * n, digit_size * n))# construct a grid grid_x = np.linspace(-scale, scale, n)grid_y = np.linspace(-scale, scale, n)[::-1]for i, yi in enumerate(grid_y):for j, xi in enumerate(grid_x):z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)x_decoded = model.decode(z_sample)digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size,] = digitplt.figure(figsize=(figsize, figsize))plt.title('VAE Latent Space Visualization')start_range = digit_size // 2end_range = n * digit_size + start_rangepixel_range = np.arange(start_range, end_range, digit_size)sample_range_x = np.round(grid_x, 1)sample_range_y = np.round(grid_y, 1)plt.xticks(pixel_range, sample_range_x)plt.yticks(pixel_range, sample_range_y)plt.xlabel("mean, z [0]")plt.ylabel("var, z [1]")plt.imshow(figure, cmap="Greys_r")plt.show()plot_latent_space(model)

        潜在空间可视化,范围:[-1.0、1.0]

        这就是介于 -1.0 和 1.0 之间的均值和方差值的潜在空间的外观。如果我们将此比例更改为 -5.0 和 5.0 会发生什么?

潜在空间可视化,范围:[-5.0、5.0]

        再次,有趣!现在,我们可以看到大多数数字表示形式所在的均值和方差值的范围。现在,我们知道如何从头开始构建一个简单的VAE,对图像进行采样并可视化潜在空间。但VAE并不止于此,还有更先进的技术使表征学习更加迷人。我将在以后的文章中探讨它们。在Github,LinkedIn和 Google Scholar上找到我。

四、在Google Colab中尝试代码:

import torch
import numpy as np
import torch.nn as nn
from torch.optim import Adam
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid# create a transofrm to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, transform=transform, download=True)
test_dataset  = MNIST(path, transform=transform, download=True)# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# get 25 sample training images for visualization
dataiter = iter(train_loader)
image = dataiter.next()num_samples = 25
sample_images = [image[0][i,0] for i in range(num_samples)] fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)for ax, im in zip(grid, sample_images):ax.imshow(im, cmap='gray')ax.axis('off')plt.show()
class Encoder(nn.Module):def __init__(self, input_dim=784, hidden_dim=512, latent_dim=256):super(Encoder, self).__init__()self.linear1 = nn.Linear(input_dim, hidden_dim)self.linear2 = nn.Linear(hidden_dim, hidden_dim)self.mean = nn.Linear(hidden_dim, latent_dim)self.var = nn.Linear (hidden_dim, latent_dim)self.LeakyReLU = nn.LeakyReLU(0.2)self.training = Truedef forward(self, x):x = self.LeakyReLU(self.linear1(x))x = self.LeakyReLU(self.linear2(x))mean = self.mean(x)log_var = self.var(x)                     return mean, log_varclass Decoder(nn.Module):def __init__(self, output_dim=784, hidden_dim=512, latent_dim=256):super(Decoder, self).__init__()self.linear2 = nn.Linear(latent_dim, hidden_dim)self.linear1 = nn.Linear(hidden_dim, hidden_dim)self.output = nn.Linear(hidden_dim, output_dim)self.LeakyReLU = nn.LeakyReLU(0.2)def forward(self, x):x = self.LeakyReLU(self.linear2(x))x = self.LeakyReLU(self.linear1(x))x_hat = torch.sigmoid(self.output(x))return x_hat
class VAE(nn.Module):def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):super(VAE, self).__init__()# encoderself.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Linear(hidden_dim, latent_dim),nn.LeakyReLU(0.2))# latent mean and variance self.mean_layer = nn.Linear(latent_dim, 2)self.logvar_layer = nn.Linear(latent_dim, 2)# decoderself.decoder = nn.Sequential(nn.Linear(2, latent_dim),nn.LeakyReLU(0.2),nn.Linear(latent_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Linear(hidden_dim, input_dim),nn.Sigmoid())def encode(self, x):x = self.encoder(x)mean, logvar = self.mean_layer(x), self.logvar_layer(x)return mean, logvardef reparameterization(self, mean, var):epsilon = torch.randn_like(var).to(device)      z = mean + var*epsilonreturn zdef decode(self, x):return self.decoder(x)def forward(self, x):mean, logvar = self.encode(x)z = self.reparameterization(mean, logvar)x_hat = self.decode(z)return x_hat, mean, log_vardef forward(self, x):mean, log_var = self.encode(x)z = self.reparameterization(mean, torch.exp(0.5 * log_var)) x_hat = self.decode(z)  return x_hat, mean, log_var
model = VAE().to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
def loss_function(x, x_hat, mean, log_var):reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())return reproduction_loss + KLDdef train(model, optimizer, epochs, device, x_dim=784):model.train()for epoch in range(epochs):overall_loss = 0for batch_idx, (x, _) in enumerate(train_loader):x = x.view(batch_size, x_dim).to(device)optimizer.zero_grad()x_hat, mean, log_var = model(x)loss = loss_function(x, x_hat, mean, log_var)overall_loss += loss.item()loss.backward()optimizer.step()print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss/(batch_idx*batch_size))return overall_loss
train(model, optimizer, epochs=50, device=device)
Epoch 1 	Average Loss:  180.83752029750104Epoch 2 	Average Loss:  163.2696611703099Epoch 3 	Average Loss:  158.45957443721306Epoch 4 	Average Loss:  155.22525566699707Epoch 5 	Average Loss:  153.2932642294449Epoch 6 	Average Loss:  151.85221148202734Epoch 7 	Average Loss:  150.7171567847454Epoch 8 	Average Loss:  149.69312638577316Epoch 9 	Average Loss:  148.78454667284015Epoch 10 	Average Loss:  148.15143693264815
def generate_digit(mean, var):z_sample = torch.tensor([[mean, var]], dtype=torch.float).to(device)x_decoded = model.decode(z_sample)digit = x_decoded.detach().cpu().reshape(28, 28) # reshape vector to 2d arrayplt.title(f'[{mean},{var}]')plt.imshow(digit, cmap='gray')plt.axis('off')plt.show()#img1: mean0, var1 / img2: mean1, var0
generate_digit(0.0, 1.0), generate_digit(1.0, 0.0)

def plot_latent_space(model, scale=5.0, n=25, digit_size=28, figsize=15):# display a n*n 2D manifold of digitsfigure = np.zeros((digit_size * n, digit_size * n))# construct a grid grid_x = np.linspace(-scale, scale, n)grid_y = np.linspace(-scale, scale, n)[::-1]for i, yi in enumerate(grid_y):for j, xi in enumerate(grid_x):z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)x_decoded = model.decode(z_sample)digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size,] = digitplt.figure(figsize=(figsize, figsize))plt.title('VAE Latent Space Visualization')start_range = digit_size // 2end_range = n * digit_size + start_rangepixel_range = np.arange(start_range, end_range, digit_size)sample_range_x = np.round(grid_x, 1)sample_range_y = np.round(grid_y, 1)plt.xticks(pixel_range, sample_range_x)plt.yticks(pixel_range, sample_range_y)plt.xlabel("mean, z [0]")plt.ylabel("var, z [1]")plt.imshow(figure, cmap="Greys_r")plt.show()plot_latent_space(model, scale=1.0)
plot_latent_space(model, scale=5.0)

礼萨·卡兰塔尔

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/106196.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python 练习--更新

1.判断一个列表中的数值是否全部小于某个数 方法一:利用if函数 (只要列表中有一个数字比大 就可以终止比较) n int(input("请输入需要比较的数字:")) arr1 [1,3,4,5,8] index 0 for i in arr1:if i > n:index 1continue…

服务器崩溃前的数据拯救实践

前言 在服务器的VMWARE ESXi系统环境中,我们经常需要创建虚拟机来运行各种应用程序。然而,服务器如果偶尔出现自动重启以及紫屏报错的问题,说明服务器内部出现了故障,一般情况下重启机器能够解决问题,但时间一长&…

[23] IPDreamer: Appearance-Controllable 3D Object Generation with Image Prompts

pdf Text-to-3D任务中,对3D模型外观的控制不强,本文提出IPDreamer来解决该问题。在NeRF Training阶段,IPDreamer根据文本用ControlNet生成参考图,并将参考图作为Zero 1-to-3的控制条件,用基于Zero 1-to-3的SDS损失生成…

一、K8S第一步搭建

一、初始化操作 1.1、关闭防火墙 systemctl stop firewalld systemctl disable firewalld关闭交换空间 swapoff -a # 临时 sed -ri s/.*swap.*/#&/ /etc/fstab # 永久重启才能生效 根据规划设置主机名 hostnamectl set-hostname <hostname>映射主机 cat >>…

网络类型与数据链路层协议

目录 整体大纲图 一、网络类型 二、数据链路层协议 1、MA网络 2、P2P网络 1&#xff09;HDLC协议 2&#xff09;PPP协议 a、特点及其数据帧封装结构 b、组成及其工作过程 c、ppp会话流程及ppp验证 d、ppp配置命令 f、ppp mp 整体大纲图 一、网络类型 二、数据链路层…

系统文件IO、文件描述符fd、重定向、文件系统、动态库和静态库

目录 C文件接口系统文件I/O系统调用和库函数文件描述符0 & 1 & 2FILE和fd的关系文件描述符的分配规则 重定向重定向的本质输出重定向输入重定向追加重定向 dup2函数 FILE理解文件系统了解磁盘的物理结构逻辑抽象文件系统文件系统的图解和解析通过文件系统来理解ls -al通…

MySQL 3 环境搭建 MySQL 5.7版本的安装、配置

MySQL5.7.43官网下载地址 MySQL :: Download MySQL Community Server 这里选5.7.43&#xff0c;Windows版本&#xff0c;然后点击Go to Download Page&#xff0c;下载msi安装包的版本 MSI安装包版本比ZIP压缩包版本的安装过程要简单的多&#xff0c;过程更加清楚直观&#x…

MATLAB——径向基神经网络预测程序

欢迎关注公众号“电击小子程高兴的MATLAB小屋” %% 学习目标&#xff1a;径向基神经网络 %% 可以以任意精度逼近任意连续函数 clear all; close all; P1:10; T[2.523 2.434 3.356 4.115 5.834 6.967 7.098 8.315 9.387 9.928]; netnewrbe(P,T,2); %建立精确的径向基…

KMP 算法 + 详细笔记

给两个字符串&#xff0c;T"AAAAAAAAB"&#xff0c;P"AAAAB"; 可以暴力匹配&#xff0c;但是太费时和效率不太好。于是KMP问世&#xff0c;我们一起来探究一下吧&#xff01;&#xff01;&#xff01; &#xff08;一&#xff09;最长公共前后缀 D[i] p[…

【C/C++数据结构 - 2】:稳定性与优化揭秘,揭开插入排序、希尔排序和快速排序的神秘面纱!

文章目录 排序的稳定性插入排序插入排序的优化 希尔排序快速排序 排序的稳定性 稳定排序&#xff1a;排序前2个相等的数在序列中的前后位置顺序和排序后它们2个的前后位置顺序相同。&#xff08;比如&#xff1a;冒泡、插入、基数、归并&#xff09; 非稳定排序&#xff1a;排…

UVa658 It’s not a Bug, it’s a Feature!(Dijkstra)

题意 给出一个包含n个bug的应用程序&#xff0c;以及m个补丁&#xff0c;每个补丁使用两个字符串表示&#xff0c;第一个串表示补丁针对bug的情况&#xff0c;即哪些bug存在&#xff0c;以及哪些bug不存在&#xff0c;第二个串表示补丁对bug的修复情况&#xff0c;即修复了哪些…

进化算法------微生物进化算法(MGA)

前言 该文章写在GA算法之后&#xff1a;GA算法 遗传算法 (GA)的问题在于没有有效保留好的父母 (Elitism), 让好的父母不会消失掉. Microbial GA (后面统称 MGA) 就是一个很好的保留 Elitism 的算法. 一句话来概括: 在袋子里抽两个球, 对比两个球, 把球大的放回袋子里, 把球小…

ARMv5架构对齐访问异常问题

strh非对齐访问 在ARMv5架构中&#xff0c;对于strh指令&#xff08;Store Halfword&#xff09;&#xff0c;通常是要求对地址进行对齐访问的。ARMv5架构对于半字&#xff08;Halfword&#xff09;的存储操作有对齐要求&#xff0c;即地址必须是2的倍数。 如果尝试使用strh指…

2024北京国际光刻设备及光掩膜技术展览会

2024北京国际光刻设备及光掩膜技术展览会 Beijing Photolithography Equipment and Mask Application Technology Exhibition2024 基本信息 时间&#xff1a;2024年7月24-26日 地点&#xff1a;北京国家会议中心 展会简介 微电子技术的发展一直是光刻设备和技术变革的动力&…

二、使用DockerCompose部署RocketMQ

使用DockerCompose进行部署 RocketMQ的部署方式以及各自的特点 单master模式 只有一个 master 节点&#xff0c;如果master节点挂掉了&#xff0c;会导致整个服务不可用&#xff0c;线上不宜使用&#xff0c;适合个人学习使用。 多master模式 和kafka不一样&#xff0c;Rocke…

vue3 状态管理pinia

1. 什么是Pinia Pinia 是 Vue 的专属的最新状态管理库 &#xff0c;是 Vuex 状态管理工具的替代品 特点优势: 提供更加简单的API(去掉了mutation)提供符合组合式风格的API(和Vue3新语法统一)去掉modules的概念,每一个store都是一个独立的模块配合TypeScript更加友好,提供可靠的…

网站的常见攻击与防护方法

在互联网时代&#xff0c;几乎每个网站都存在着潜在的安全威胁。这些威胁可能来自人为失误&#xff0c;也可能源自网络犯罪团伙所发起的复杂攻击。无论攻击的本质如何&#xff0c;网络攻击者的主要动机通常是谋求经济利益。这意味着无论您经营的是电子商务项目还是小型商业网站…

【Redis】Set集合相关的命令

目录 命令SADDSMEMBERSSISMEMBERSCARDSPOPSMOVESREMSINTERSINTERSTORESUNIONSUNIONSTORESDIFFSDIFFSTORE 命令 SADD 将⼀个或者多个元素添加到set中。注意&#xff0c;重复的元素⽆法添加到set中。 SADD key member [member ...]SMEMBERS 获取⼀个set中的所有元素&#xff0…

vector Autosar someip和vsomeip协议调试总结

someip是现代车辆通信的主流通信协议知一&#xff1b; someip的主要涉及模型以及协议结构&#xff0c;我就不做多的做介绍了&#xff0c;如有需要请读者自行进行百度学些&#xff1b; 虽然someip协议已经基本成熟&#xff0c;但有多个实现版本&#xff0c;现在使用较多的主要…

④. GPT错误:导入import pandas as pd库,存储输入路径图片信息存储错误

꧂ 问题最初꧁ 用 import pandas as pd 可是你没有打印各种信息input输入图片路径 print图片尺寸 大小 长宽高 有颜色占比>0.001的按照大小排序将打印信息存储excel表格文件名 表格路径 图片大小 尺寸 颜色类型 占比信息input输入的是文件就处理文件 是文件夹&#x1f4c…