变分自编码器(VAE)PyTorch Lightning 实现

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • VAE 简介
      • 基本原理
      • 应用与优点
      • 缺点与挑战
    • 使用 VAE 生成 MNIST 手写数字
      • 忽略警告
      • 导入必要的库
      • 设置随机种子
      • cuDNN 设置
      • 超参数设置
      • 数据加载
      • 定义 VAE 模型
      • 定义损失函数
      • 定义 Lightning 模型
      • 训练模型
      • 绘制训练过程
      • 随机生成新样本
      • 根据潜变量插值生成新样本


VAE 简介

变分自编码器(Variational Autoencoder,VAE)是一种深度学习中的生成模型,它结合了自编码器(Autoencoder, AE)和概率建模的思想,在无监督学习环境中表现出了强大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成为生成模型领域的重要组成部分。

基本原理

自编码器(AE)基础:
自编码器是一种神经网络结构,通常由两部分组成:编码器(Encoder)和解码器(Decoder)。原始数据通过编码器映射到一个低维的潜在空间(或称为隐空间),这个低维向量被称为潜变量(latent variable)。然后,潜变量再通过解码器重构回原始数据的近似版本。在训练过程中,自编码器的目标是使得输入数据经过编码-解码过程后能够尽可能地恢复原貌,从而学习到数据的有效表示。

VAE的引入与扩展:
VAE 将自编码器的概念推广到了概率框架下。在 VAE 中,潜变量不再是确定性的,而是被赋予了概率分布。具体来说,对于给定的输入数据,编码器不直接输出一个点估计值,而是输出潜变量的均值和方差(假设潜变量服从高斯分布)。这样,每个输入数据可以被视为是从某个潜在的概率分布中采样得到的。

变分推断(Variational Inference):
训练 VA E时,由于真实的后验概率分布难以直接计算,因此采用变分推断来近似后验分布。编码器实际上输出的是一个参数化的概率分布 q ( z ∣ x ) q(z|x) q(zx),即给定输入 x x x 时潜变量 z z z 的概率分布。然后通过最小化 KL 散度(Kullback-Leibler divergence)来优化这个近似分布,使其尽可能接近真实的后验分布 p ( z ∣ x ) p(z|x) p(zx)

目标函数 - Evidence Lower Bound (ELBO):
VAE 的目标函数是证据下界(ELBO),它是原始数据 log-likelihood 的下界。优化该目标函数既鼓励编码器找到数据的高效潜在表示,又促使解码器基于这些表示重建出类似原始数据的新样本。

数学表达上,ELBO 通常分解为两个部分:

  1. 重构损失(Reconstruction Loss):衡量从潜变量重构出来的数据与原始数据之间的差异。
  2. KL散度损失(KL Divergence Loss):衡量编码器产生的潜变量分布与预设的标准正态分布(或其他先验分布)之间的距离。

应用与优点

  • VAE 可以用于生成新数据,例如图像、文本、音频等。
  • 由于其对潜变量进行概率建模,所以它可以提供连续的数据生成,并且能够探索数据的不同模式。
  • 在处理连续和离散数据时具有一定的灵活性。
  • 可以用于特征学习,提取数据的有效低维表示。

缺点与挑战

  • 训练 VAE 可能需要大量的计算资源和时间。
  • 生成的样本有时可能不够清晰或细节模糊,尤其是在复杂数据集上。
  • 对于某些复杂的分布形式,VAE 可能无法完美捕获所有细节。

使用 VAE 生成 MNIST 手写数字

下面我们将使用 PyTorch Lightning 来实现一个简单的 VAE 模型,并使用 MNIST 数据集来进行训练和生成。

在线 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist

忽略警告

import warnings
warnings.filterwarnings("ignore")

导入必要的库

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as snssns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasetsimport lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

设置随机种子

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cuDNN 设置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

超参数设置

batch_size = 64epochs = 10
KLD_weight = 1
lr = 0.001input_dim = 784  # 28 * 28
h_dim = 256  # 隐藏层维度  
z_dim = 2  # 潜变量维度

数据加载

train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

定义 VAE 模型

class VAE(nn.Module):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.input_dim = input_dimself.h_dim = h_dimself.z_dim = z_dim# Encoderself.fc1 = nn.Linear(input_dim, h_dim)self.fc21 = nn.Linear(h_dim, z_dim)  # muself.fc22 = nn.Linear(h_dim, z_dim)  # log_var# Decoderself.fc3 = nn.Linear(z_dim, h_dim)self.fc4 = nn.Linear(h_dim, input_dim)def encode(self, x):h = torch.relu(self.fc1(x))mean = self.fc21(h)log_var = self.fc22(h)return mean, log_vardef reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = torch.relu(self.fc3(z))out = torch.sigmoid(self.fc4(h))return outdef forward(self, x):mean, log_var = self.encode(x)z = self.reparameterize(mean, log_var)reconstructed_x = self.decode(z)return reconstructed_x, mean, log_varvae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])

定义损失函数

def loss_function(x_hat, x, mu, log_var, KLD_weight=1):BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重构损失KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度损失loss = BCE_loss + KLD_loss * KLD_weightreturn loss, BCE_loss, KLD_loss

定义 Lightning 模型

class LitModel(pl.LightningModule):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super().__init__()self.model = VAE(input_dim, h_dim, z_dim)def forward(self, x):x = self.model(x)return xdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5)return optimizerdef training_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)reconstructed_x, mean, log_var = self(x)loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log_dict({"BCE_loss": BCE_loss,"KLD_loss": KLD_loss,},on_step=False,on_epoch=True,logger=True,)return lossdef decode(self, z):out = self.model.decode(z)return out

训练模型

model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(max_epochs=epochs,enable_progress_bar=True,logger=logger,callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)

绘制训练过程

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

训练过程

随机生成新样本

row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()plt.figure(figsize=(col, row))
for i in range(row * col):plt.subplot(row, col, i + 1)plt.imshow(random_res[i].squeeze(), cmap="gray")plt.xticks([])plt.yticks([])plt.axis("off")
plt.show()

随机生成新样本

根据潜变量插值生成新样本

from scipy.stats import normn = 15
digit_size = 28grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):for j, xi in enumerate(grid_x):t = [xi, yi]z_sampled = torch.FloatTensor(t)with torch.no_grad():decode = model.decode(z_sampled)digit = decode.view((digit_size, digit_size))figure[i * digit_size : (i + 1) * digit_size,j * digit_size : (j + 1) * digit_size,] = digitplt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()

根据潜变量插值生成新样本

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/685255.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP PP学习笔记- 豆知识02 - 品目要谁来维护?怎么决定更不更新品目的数量金额?

其实都是在品目类型的Customize中设定的。 咱们这里简单试着说一下什么场景使用。 1,SAP中品目有很多View,都要由哪些部门来维护呢? 其实就是谁用谁维护呗。 在新建一个品目的时候,品目Type本身就决定了该品目要由哪些部门来维…

数据结构(4) 链表(链式存储)

链表(链式存储) 单链表定义基本操作的实现单链表的插入按位序插入指定节点的前插指定节点的后插 单链表的删除 小结 单链表 定义 顺序表优点:可随机存取,存储密度高,缺点:要求大片连续空间,改变容量不方便。 单链表优…

高数总结(4

目录 1.总结:小结: 1.总结: 小结: 关注我给大家分享更多有趣的知识,以下是个人公众号,提供 ||代码兼职|| ||代码问题求解|| 由于本号流量还不足以发表推广,搜我的公众号即可:

【ASP.NET Core 基础知识】--最佳实践和进阶主题--性能调优和缓存

一、性能调优 在 ASP.NET Core 中进行性能调优,代码优化是至关重要的一部分。以下是一些常见的 ASP.NET Core 代码优化技巧: 减少数据库查询: 尽可能地减少数据库查询次数,可以通过使用合适的 ORM(对象关系映射&…

Codeforces Round 926 F. Sasha and the Wedding Binary Search Tree

F. Sasha and the Wedding Binary Search Tree 题意 给定一颗二叉搜索树,规定树上的所有点的点权都在范围 [ 1 , C ] [1, C] [1,C] 内,树上的某些节点点权已知,某些节点点权未知,求出合法的二叉搜索树的数量 思路 由于是二叉搜…

这应该是全网第一篇全面解读OpenAI Sora报告的文章,精读报告:Video generation models as world simulators

今天是2024年2月16号,大年初七,年还没过完,早晨起来朋友圈就被Sora刷屏了。本来以为没啥,都是公众号或者视屏啥的,都没点开看,直到看到我导也发了Sora的文章,我就知道这个事情不简单了。 先来看…

基于Qt数据库项目实现(Sqlite3为例)|考查数据库、表格(QTableView 显示)(进阶)

01 数据库表格(QTableView 显示) 本小节设计一个生活中的例子,使用数据库修改/查询员工的编号、姓名、年龄、性别与照片信息。 本例将数据库的内容显示到 QTableView 上。如果只是简单的显示数据库的内容到QTableView 上,可以使用下面的方法,此方法 QTableView 上可以看…

支付宝沙箱版模拟网站在线完整支付流程(无营业无费用)内网穿透+局域网测试

文章目录 一、介绍1. 支付2. 支付结果 二、前提准备1. 支付宝开放平台2. 内网穿透3. 局域网 三、order微服务1. 依赖、配置2. 工具类1. 二维码生成2. AlipayConfig 3. 端点代码1. /generatepaycode2. /requestpay3. /payresult4. /receivenotify 环境如下 Version手机安卓支付…

力扣代码学习日记二

Problem: 28. 找出字符串中第一个匹配项的下标 文章目录 思路解题方法复杂度代码 思路 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的…

红队笔记Day4 -->多层代理(模拟企业拓扑)

声明:本机文章只用于教育用途,无不良引导,禁止用于从事任何违法活动 前几天的红队笔记的网络拓扑都比较简单,今天就来模拟一下企业的真实网络拓扑,以及攻击方法 一般的大企业的网络拓扑如下::…

NodeJs事件循环

NodeJs事件循环 能力有限,暂时理解不了。 网上众说纷纭,而且不同的电脑、操作环境、不同的 NodeJS 版本,结果有可能不一样。 建议自己学习的时候别用 Windows 系统,实在没条件使用虚拟机或者 WSL。

基于Springboot+Vue实现的宿舍管理系统

基于SpringbootVue的宿舍管理系统 1.系统相关性介绍1.1 系统架构1.2 设计思路 2.功能模块介绍2.1 用户信息模块2.2 宿舍管理模块2.3 信息管理模块 3. 源码获取以及远程部署 前言: 在现代教育环境中,学生宿舍的管理显得尤为重要,需要一套能…

如何在Windows中配置多个显示器?这里提供详细步骤

Windows可以通过多种方式使用多个显示器,扩展或复制主显示器。你甚至可以关闭主显示器。以下是如何使用简单的键盘快捷键更改辅助显示设置。 使用WindowsP投影菜单 要快速更改Windows 10处理多个显示器的方式,请按WindowsP。屏幕右侧会弹出一个名为“投…

单点登录(SSO,Single Sign-On)

单点登录(SSO,Single Sign-On)是一种用户身份验证服务,允许用户使用一组登录凭据(如用户名和密码)访问多个应用程序或系统。SSO旨在提高用户体验和安全性,通过减少用户需要记住的密码数量&#…

CSP-201903-1-小中大

CSP-201903-1-小中大 【坑!】 注意中位数的输出,当中位数为整数时不保留小数点。当两个整数相加后再除以 2 时,结果会自动向下取整,而不是得到一个准确的浮点数结果。需要确保至少有一个操作数是浮点数类型,以便执行浮点数除法&…

图表示学习 Graph Representation Learning chapter1 引言

图表示学习 Graph Representation Learning chapter1 引言 前言1.1图的定义1.1.1多关系图1.1.2特征信息 1.2机器学习在图中的应用1.2.1 节点分类1.2.2 关系预测1.2.3 聚类和组织检测1.2.4 图分类、回归、聚类 前言 虽然我并不研究图神经网络,但是我认为图高效的表示…

阅读笔记12:(全文翻译)Oocyte quality and aging

Oocyte quality and aging 作者:Ali Reza Eftekhari Moghadam, Mahin Taheri Moghadam, Masoud Hemadi, and Ghasem Saki 发表期刊:JBRA Assist Reprod. 发表时间:April 15, 2021 摘要 众所周知,由于卵子的质量和数量随年龄相关的变化,女性的生育能力在生命的第四个十年期…

电脑重装系统之Windows 10 企业版 LTSC 2021

简介 Windows 10 22H2对于我来说太不简洁,最受不了的一点是微软强行硬塞给我一些并没有什么luan用的应用和功能,比如:天气,Onedrive......以及臃肿的ui设计。而且强行进行自动更新,我是真的受不了这个,看着…

leetcode:343.整数拆分

解题思路: 拆分的越多越好(暂且认为),尽可能拆成m个近似相等的数,会使得乘积最大 dp含义:将i进行拆分得到最大的积为dp[i] 递推公式:j x dp[i-j](固定j,只通过凑dp[i-j]进而实现所…

optee UTA加载

流程 动态TA按照存储位置的不同分为REE filesystem TA:存放在REE侧文件系统里的TA; Early TA:被嵌入到optee os里的在supplicant启动之前就可用了。 这里我们讲的是常规的存放在REE侧文件系统里的TA。 通过GP标准调用的与TA通信的命令(opens…