变分自编码器(VAE)PyTorch Lightning 实现

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • VAE 简介
      • 基本原理
      • 应用与优点
      • 缺点与挑战
    • 使用 VAE 生成 MNIST 手写数字
      • 忽略警告
      • 导入必要的库
      • 设置随机种子
      • cuDNN 设置
      • 超参数设置
      • 数据加载
      • 定义 VAE 模型
      • 定义损失函数
      • 定义 Lightning 模型
      • 训练模型
      • 绘制训练过程
      • 随机生成新样本
      • 根据潜变量插值生成新样本


VAE 简介

变分自编码器(Variational Autoencoder,VAE)是一种深度学习中的生成模型,它结合了自编码器(Autoencoder, AE)和概率建模的思想,在无监督学习环境中表现出了强大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成为生成模型领域的重要组成部分。

基本原理

自编码器(AE)基础:
自编码器是一种神经网络结构,通常由两部分组成:编码器(Encoder)和解码器(Decoder)。原始数据通过编码器映射到一个低维的潜在空间(或称为隐空间),这个低维向量被称为潜变量(latent variable)。然后,潜变量再通过解码器重构回原始数据的近似版本。在训练过程中,自编码器的目标是使得输入数据经过编码-解码过程后能够尽可能地恢复原貌,从而学习到数据的有效表示。

VAE的引入与扩展:
VAE 将自编码器的概念推广到了概率框架下。在 VAE 中,潜变量不再是确定性的,而是被赋予了概率分布。具体来说,对于给定的输入数据,编码器不直接输出一个点估计值,而是输出潜变量的均值和方差(假设潜变量服从高斯分布)。这样,每个输入数据可以被视为是从某个潜在的概率分布中采样得到的。

变分推断(Variational Inference):
训练 VA E时,由于真实的后验概率分布难以直接计算,因此采用变分推断来近似后验分布。编码器实际上输出的是一个参数化的概率分布 q ( z ∣ x ) q(z|x) q(zx),即给定输入 x x x 时潜变量 z z z 的概率分布。然后通过最小化 KL 散度(Kullback-Leibler divergence)来优化这个近似分布,使其尽可能接近真实的后验分布 p ( z ∣ x ) p(z|x) p(zx)

目标函数 - Evidence Lower Bound (ELBO):
VAE 的目标函数是证据下界(ELBO),它是原始数据 log-likelihood 的下界。优化该目标函数既鼓励编码器找到数据的高效潜在表示,又促使解码器基于这些表示重建出类似原始数据的新样本。

数学表达上,ELBO 通常分解为两个部分:

  1. 重构损失(Reconstruction Loss):衡量从潜变量重构出来的数据与原始数据之间的差异。
  2. KL散度损失(KL Divergence Loss):衡量编码器产生的潜变量分布与预设的标准正态分布(或其他先验分布)之间的距离。

应用与优点

  • VAE 可以用于生成新数据,例如图像、文本、音频等。
  • 由于其对潜变量进行概率建模,所以它可以提供连续的数据生成,并且能够探索数据的不同模式。
  • 在处理连续和离散数据时具有一定的灵活性。
  • 可以用于特征学习,提取数据的有效低维表示。

缺点与挑战

  • 训练 VAE 可能需要大量的计算资源和时间。
  • 生成的样本有时可能不够清晰或细节模糊,尤其是在复杂数据集上。
  • 对于某些复杂的分布形式,VAE 可能无法完美捕获所有细节。

使用 VAE 生成 MNIST 手写数字

下面我们将使用 PyTorch Lightning 来实现一个简单的 VAE 模型,并使用 MNIST 数据集来进行训练和生成。

在线 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist

忽略警告

import warnings
warnings.filterwarnings("ignore")

导入必要的库

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as snssns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasetsimport lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

设置随机种子

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cuDNN 设置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

超参数设置

batch_size = 64epochs = 10
KLD_weight = 1
lr = 0.001input_dim = 784  # 28 * 28
h_dim = 256  # 隐藏层维度  
z_dim = 2  # 潜变量维度

数据加载

train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

定义 VAE 模型

class VAE(nn.Module):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.input_dim = input_dimself.h_dim = h_dimself.z_dim = z_dim# Encoderself.fc1 = nn.Linear(input_dim, h_dim)self.fc21 = nn.Linear(h_dim, z_dim)  # muself.fc22 = nn.Linear(h_dim, z_dim)  # log_var# Decoderself.fc3 = nn.Linear(z_dim, h_dim)self.fc4 = nn.Linear(h_dim, input_dim)def encode(self, x):h = torch.relu(self.fc1(x))mean = self.fc21(h)log_var = self.fc22(h)return mean, log_vardef reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = torch.relu(self.fc3(z))out = torch.sigmoid(self.fc4(h))return outdef forward(self, x):mean, log_var = self.encode(x)z = self.reparameterize(mean, log_var)reconstructed_x = self.decode(z)return reconstructed_x, mean, log_varvae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])

定义损失函数

def loss_function(x_hat, x, mu, log_var, KLD_weight=1):BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重构损失KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度损失loss = BCE_loss + KLD_loss * KLD_weightreturn loss, BCE_loss, KLD_loss

定义 Lightning 模型

class LitModel(pl.LightningModule):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super().__init__()self.model = VAE(input_dim, h_dim, z_dim)def forward(self, x):x = self.model(x)return xdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5)return optimizerdef training_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)reconstructed_x, mean, log_var = self(x)loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log_dict({"BCE_loss": BCE_loss,"KLD_loss": KLD_loss,},on_step=False,on_epoch=True,logger=True,)return lossdef decode(self, z):out = self.model.decode(z)return out

训练模型

model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(max_epochs=epochs,enable_progress_bar=True,logger=logger,callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)

绘制训练过程

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

训练过程

随机生成新样本

row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()plt.figure(figsize=(col, row))
for i in range(row * col):plt.subplot(row, col, i + 1)plt.imshow(random_res[i].squeeze(), cmap="gray")plt.xticks([])plt.yticks([])plt.axis("off")
plt.show()

随机生成新样本

根据潜变量插值生成新样本

from scipy.stats import normn = 15
digit_size = 28grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):for j, xi in enumerate(grid_x):t = [xi, yi]z_sampled = torch.FloatTensor(t)with torch.no_grad():decode = model.decode(z_sampled)digit = decode.view((digit_size, digit_size))figure[i * digit_size : (i + 1) * digit_size,j * digit_size : (j + 1) * digit_size,] = digitplt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()

根据潜变量插值生成新样本

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/685255.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP PP学习笔记- 豆知识02 - 品目要谁来维护?怎么决定更不更新品目的数量金额?

其实都是在品目类型的Customize中设定的。 咱们这里简单试着说一下什么场景使用。 1,SAP中品目有很多View,都要由哪些部门来维护呢? 其实就是谁用谁维护呗。 在新建一个品目的时候,品目Type本身就决定了该品目要由哪些部门来维…

数据结构(4) 链表(链式存储)

链表(链式存储) 单链表定义基本操作的实现单链表的插入按位序插入指定节点的前插指定节点的后插 单链表的删除 小结 单链表 定义 顺序表优点:可随机存取,存储密度高,缺点:要求大片连续空间,改变容量不方便。 单链表优…

高数总结(4

目录 1.总结:小结: 1.总结: 小结: 关注我给大家分享更多有趣的知识,以下是个人公众号,提供 ||代码兼职|| ||代码问题求解|| 由于本号流量还不足以发表推广,搜我的公众号即可:

Codeforces Round 926 F. Sasha and the Wedding Binary Search Tree

F. Sasha and the Wedding Binary Search Tree 题意 给定一颗二叉搜索树,规定树上的所有点的点权都在范围 [ 1 , C ] [1, C] [1,C] 内,树上的某些节点点权已知,某些节点点权未知,求出合法的二叉搜索树的数量 思路 由于是二叉搜…

这应该是全网第一篇全面解读OpenAI Sora报告的文章,精读报告:Video generation models as world simulators

今天是2024年2月16号,大年初七,年还没过完,早晨起来朋友圈就被Sora刷屏了。本来以为没啥,都是公众号或者视屏啥的,都没点开看,直到看到我导也发了Sora的文章,我就知道这个事情不简单了。 先来看…

支付宝沙箱版模拟网站在线完整支付流程(无营业无费用)内网穿透+局域网测试

文章目录 一、介绍1. 支付2. 支付结果 二、前提准备1. 支付宝开放平台2. 内网穿透3. 局域网 三、order微服务1. 依赖、配置2. 工具类1. 二维码生成2. AlipayConfig 3. 端点代码1. /generatepaycode2. /requestpay3. /payresult4. /receivenotify 环境如下 Version手机安卓支付…

红队笔记Day4 -->多层代理(模拟企业拓扑)

声明:本机文章只用于教育用途,无不良引导,禁止用于从事任何违法活动 前几天的红队笔记的网络拓扑都比较简单,今天就来模拟一下企业的真实网络拓扑,以及攻击方法 一般的大企业的网络拓扑如下::…

基于Springboot+Vue实现的宿舍管理系统

基于SpringbootVue的宿舍管理系统 1.系统相关性介绍1.1 系统架构1.2 设计思路 2.功能模块介绍2.1 用户信息模块2.2 宿舍管理模块2.3 信息管理模块 3. 源码获取以及远程部署 前言: 在现代教育环境中,学生宿舍的管理显得尤为重要,需要一套能…

如何在Windows中配置多个显示器?这里提供详细步骤

Windows可以通过多种方式使用多个显示器,扩展或复制主显示器。你甚至可以关闭主显示器。以下是如何使用简单的键盘快捷键更改辅助显示设置。 使用WindowsP投影菜单 要快速更改Windows 10处理多个显示器的方式,请按WindowsP。屏幕右侧会弹出一个名为“投…

图表示学习 Graph Representation Learning chapter1 引言

图表示学习 Graph Representation Learning chapter1 引言 前言1.1图的定义1.1.1多关系图1.1.2特征信息 1.2机器学习在图中的应用1.2.1 节点分类1.2.2 关系预测1.2.3 聚类和组织检测1.2.4 图分类、回归、聚类 前言 虽然我并不研究图神经网络,但是我认为图高效的表示…

电脑重装系统之Windows 10 企业版 LTSC 2021

简介 Windows 10 22H2对于我来说太不简洁,最受不了的一点是微软强行硬塞给我一些并没有什么luan用的应用和功能,比如:天气,Onedrive......以及臃肿的ui设计。而且强行进行自动更新,我是真的受不了这个,看着…

leetcode:343.整数拆分

解题思路: 拆分的越多越好(暂且认为),尽可能拆成m个近似相等的数,会使得乘积最大 dp含义:将i进行拆分得到最大的积为dp[i] 递推公式:j x dp[i-j](固定j,只通过凑dp[i-j]进而实现所…

optee UTA加载

流程 动态TA按照存储位置的不同分为REE filesystem TA:存放在REE侧文件系统里的TA; Early TA:被嵌入到optee os里的在supplicant启动之前就可用了。 这里我们讲的是常规的存放在REE侧文件系统里的TA。 通过GP标准调用的与TA通信的命令(opens…

【AIGC】Stable Diffusion的ControlNet参数入门

Stable Diffusion 中的 ControlNet 是一种用于控制图像生成过程的技术,它可以指导模型生成特定风格、内容或属性的图像。下面是关于 ControlNet 的界面参数的详细解释: 低显存模式 是一种在深度学习任务中用于处理显存受限设备的技术。在这种模式下&am…

Vue的一些基础设置

1.浏览器控制台显示Vue 设置找到扩展,搜索Vue 下载这个 然后 点击扩展按钮 点击详细信息 选择这个,然后重启一下就好了 ——————————————————————————————————————————— 2.优化工程结构 src的components里要…

MySQL数据库基础第四篇(多表查询与事务)

文章目录 一、多表关系二、多表查询三、内连接查询四、外连接查询五、自连接查询六、联合查询 union, union all七、子查询1.标量子查询2.列子查询3.行子查询4.表子查询 八、事务八、事务的四大特性九、并发事务问题十、事务隔离级级别 在这篇文章中,我们将深入探讨…

数据结构之队的实现

𝙉𝙞𝙘𝙚!!👏🏻‧✧̣̥̇‧✦👏🏻‧✧̣̥̇‧✦ 👏🏻‧✧̣̥̇:Solitary-walk ⸝⋆ ━━━┓ - 个性标签 - :来于“云”的“羽球人”。…

用HTML5 Canvas创造视觉盛宴——动态彩色线条效果

目录 一、程序代码 二、代码原理 三、运行效果 一、程序代码 <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <!-- 声明文档类型为XHTML 1.0 Transitional -…

Python四级考试笔记

Python四级考试笔记【源源老师】 四级标准 一、 理解函数及过程、函数的参数、函数的返回值、变量作用域等概念。 二、 能够创建简单的自定义函数。 三、 理解算法以及算法性能、效率的概念&#xff0c;初步认识算法优化 效率的方法。 四、 理解基本算法中递归的概念。 五、 掌…

【开源】SpringBoot框架开发独居老人物资配送系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 查询社区4.2 新增物资4.3 查询物资4.4 查询物资配送4.5 新增物资配送 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBootMySQL的独居老人物资配送系统&#xff0c;包含了社区档案、…