深入解析与实现:变分自编码器(VAE)完整代码详解

VAE理论上一篇已经详细讲完了,虽然VAE已经是过去的东西了,但是它对后面强大的生成模型是很有指导意义的。接下来,我们简单实现一下其代码吧。

1 VAE在minist数据集上的实现

完整的代码如下,没有什么特别好讲的。

import cv2
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary""" 就用线性层构造最简单的vae吧"""class VAE(nn.Module):def __init__(self, image_size=28*28, hidden1=400, hidden2=100, latent_dims=40):super().__init__()# encoderself.encoder = nn.Sequential(nn.Linear(image_size, hidden1),nn.ReLU(),nn.Linear(hidden1, hidden2),nn.ReLU(),)self.mu = nn.Sequential(nn.Linear(hidden2, latent_dims),)self.logvar = nn.Sequential(nn.Linear(hidden2, latent_dims),)   # 由于方差是非负的,因此预测方差对数# decoderself.decoder = nn.Sequential(nn.Linear(latent_dims, hidden2),nn.ReLU(),nn.Linear(hidden2, hidden1),nn.ReLU(),nn.Linear(hidden1, image_size),nn.Tanh())# 重参数,为了可以反向传播def reparametrization(self, mu, logvar):# sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))std = 0.5 * torch.exp(logvar)# N(mu, std^2) = N(0, 1) * std + muz = torch.randn(std.size(), device=mu.device) * std + mureturn zdef forward(self, x):en = self.encoder(x)mu = self.mu(en)logvar = self.logvar(en)z = self.reparametrization(mu, logvar)return self.decoder(z), mu, logvardef loss_function(fake_imgs, real_imgs, mu, logvar):kl = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)reconstruction = ((real_imgs - fake_imgs)**2).sum()return kl, reconstructiondef train(num_epoch):write_fake = SummaryWriter(f'logs/fake')device = torch.device("cuda:0")trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)vae = VAE().to(device)optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)vae.train()step = 0for epoch in range(num_epoch):for batch_indx, (inputs, _) in enumerate(trainloader):inputs = inputs.to(device)real_imgs = torch.flatten(inputs, start_dim=1)fake_imgs, mu, logvar = vae(real_imgs)loss_kl, loss_re = loss_function(fake_imgs, real_imgs, mu, logvar)loss_all = loss_kl + loss_reoptimizer.zero_grad()loss_all.backward()optimizer.step()print(f"epoch:{epoch}, loss kl:{loss_kl.item()}, loss re:{loss_re.item()}, loss all:{loss_all.item()}")if batch_indx == 0:with torch.no_grad():x = torch.randn((32, 40)).to(device)fake = vae.decoder(x).reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)write_fake.add_image("Mnist Fake Image", img_grid_fake, global_step=step)step += 1if __name__ == "__main__":summary(VAE(), input_size=(1, 784))train(1000)

模型结构打印如下:
VAE [1, 784] –
├─Sequential: 1-1 [1, 100] –
│ └─Linear: 2-1 [1, 400] 314,000
│ └─ReLU: 2-2 [1, 400] –
│ └─Linear: 2-3 [1, 100] 40,100
│ └─ReLU: 2-4 [1, 100] –
├─Sequential: 1-2 [1, 40] –
│ └─Linear: 2-5 [1, 40] 4,040
├─Sequential: 1-3 [1, 40] –
│ └─Linear: 2-6 [1, 40] 4,040
├─Sequential: 1-4 [1, 784] –
│ └─Linear: 2-7 [1, 100] 4,100
│ └─ReLU: 2-8 [1, 100] –
│ └─Linear: 2-9 [1, 400] 40,400
│ └─ReLU: 2-10 [1, 400] –
│ └─Linear: 2-11 [1, 784] 314,384
│ └─Tanh: 2-12 [1, 784] –

训练结果,从结果上来看,是不如GAN的,主要原因在于其在KL散度和重建损失之间很难做到平衡,所以很难训练得好,当然原因是多方面的。
在这里插入图片描述

2 VAE的缺陷

变分自编码器(VAE, Variational Autoencoder)作为一种强大的深度学习模型,在生成建模领域有着广泛的应用,但它也存在一些缺陷,主要包括:

  • 生成样本质量:与生成对抗网络(GANs)相比,VAE生成的样本可能显得较为模糊或缺乏清晰度。尽管VAE能够生成连续且有结构的潜在空间,其生成的样本在某些情况下可能不够真实或细节不够丰富。

  • 潜在空间的连续性问题:虽然VAE设计用于学习连续的潜在空间,以允许插值和生成流畅的变化序列,但在实践中,这种连续性可能不如理论中那样完美。潜在空间中可能会出现空洞或不连贯区域,影响样本生成的质量和连续性变换的效果。

  • KL散度的平衡问题:VAE通过在其损失函数中加入KL散度项来约束潜在变量的分布,以确保它接近先验分布(通常是标准正态分布)。然而,KL散度的权重难以选择,如果设置不当,可能导致模型过分关注重构损失而忽视了潜在空间的平滑性和多样性,或者相反。

  • 训练难度与稳定性:VAE的训练过程比一些其他模型更为复杂,涉及到优化 Evidence Lower Bound (ELBO),这可能导致训练过程较为不稳定,需要更多的计算资源和更长的训练时间。特别是优化过程中对似然的近似以及对数似然的下界处理增加了训练的复杂度。

  • 表达能力与模型容量:由于VAE的编码器和解码器结构相对简单(通常为全连接层或简单的卷积层),在处理高度复杂的高维数据时,其表达能力可能受限,影响生成样本的质量和多样性。

这些缺陷提示研究者和实践者在使用VAE时需要仔细调整模型架构、损失函数的平衡以及训练策略,以最大化其生成能力和实用性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/842071.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营day31 | 491.递增子序列、46.全排列、47.全排列 II

491.递增子序列 未去重的代码 class Solution:def findSubsequences(self, nums: List[int]) -> List[List[int]]:result []self.backtracking(nums, result, [], 0)return resultdef backtracking(self, nums, result, path, startIndex):if len(path) > 2:result.ap…

【代码随想录】【算法训练营】【第20天】 [654]最大二叉树 [617]合并二叉树 [700]二叉搜索树中的搜索 [98]验证二叉搜索树

前言 思路及算法思维,指路 代码随想录。 题目来自 LeetCode。 day 19,一个愉快的周日~ day 20,一个悲伤的周一~ 题目详情 [654] 最大二叉树 题目描述 654 最大二叉树 解题思路 前提:构造二叉树 思路:寻找根节…

python两个列表如何取交集

在Python编程中,我们经常需要处理各种数据集合,包括列表(list)。有时候,我们可能想要找出两个列表中的共同元素,这通常被称为取交集。下面,我将介绍几种在Python中实现两个列表取交集的方法。 …

如何设置XHSC(华大)单片机的IO口中断

XHSC(华大)单片机IO口中断使用 一、代码说明 华大单片机的历程在华大或者小华的官网上都可以下载到,但是我们下载的历程基本注释都是非常简单,有的还没有注释;再加上小华跟华大的历程在代码架构上有所区别,所以新手在直接调用华大或者小华历程后,历程代码的可读性并不…

内网安全--域渗透准备知识

目录 知识点: 0x01 0x02 0x03 系列点: Linux主机信息收集 windows主机信息收集 知识点: 0、域产生原因 1、内网域的区别 2、如何判断在域内 3、域内常见信息收集 4、域内自动化工具收集 -局域网&工作组&域环境区别 -域…

Hinton揭秘GPT之父【Ilya】成长历程:Scaling Law是他学生时代就有的直觉

2003年夏天的一个周日,AI教父Hinton在多伦多大学的办公室里敲代码,突然响起略显莽撞的敲门声。门外站着一位年轻的学生,说自己整个夏天都在打工炸薯条,但更希望能加入Hinton的实验室工作。Hinton问,你咋不预约呢&#…

SQLite 如何导出某些SQLite3表的数据

https://deepinout.com/sqlite/sqlite-questions/44_sqlite_how_do_i_dump_the_data_of_some_sqlite3_tables.html 要导出整个SQLite3数据库的数据,可以使用SQLite3的.dump命令。首先,打开终端或命令提示符,并进入SQLite3终端会话。然后&…

# LLM高效微调详解-从Adpter、PrefixTuning到LoRA

一、背景 目前NLP主流范式是在大量通用数据上进行预训练语言模型训练,然后再针对特定下游任务进行微调,达到领域适应(迁移学习)的目的。 Context Learning v.s. SFT 指令微调是预训练语言模型微调的主流范式,其目的是…

嵌入式C语言--基础知识

嵌入式C语言–基础知识 嵌入式C语言--基础知识 嵌入式C语言--基础知识一. 含参数的宏与函数的不同点1)函数2)宏 二. scanf格式化输入的注意事项三. 指针1)指针变量(地址变量)2)指针常见含义 四. 数组五. 数组与指针的区…

解读 Nginx:构建高效反向代理和负载均衡的秘密

解读 Nginx:构建高效反向代理和负载均衡的秘密 一、简介 Nginx (Engine-X) 是一个高性能的 HTTP 和反向代理服务器,也是一个 IMAP/POP3/SMTP 代理服务器。Nginx 以其高并发、高可靠性、低内存消耗等特点,成为了众多互联网公司首选的服务器软…

通用代码生成器应用场景三,遗留项目反向工程

通用代码生成器应用场景三,遗留项目反向工程 如果您有一个遗留项目,要重新开发,或者源代码遗失,或者需要重新开发,但是希望复用原来的数据,并加快开发。 如果您的项目是通用代码生成器生成的,…

在智慧城市建设中,大数据发挥着怎样的关键作用?

在智慧城市建设中,大数据发挥着以下关键作用: 数据采集与监测:大数据技术能够帮助城市采集和监测各种数据,包括气象、环境、交通、能源等方面的数据。这些数据可以用来分析和预测城市的运行情况,并为城市的各个部门提供…

阿里云产品DTU评测报告(二)

阿里云产品DTU评测报告(二) 问题回顾问题处理继续执行 问题回顾 基于上一次DTU评测,在评测过程中遇到了windows系统情况下执行amp命令失败的情况,失败情况如图 导致后续命令无法执行,一时之间不知如何处理&#xff0…

20 道大模型面试问题(含答案)

大型语言模型在生成式人工智能(GenAI)和人工智能(AI)中正变得越来越有价值。这些复杂的算法增强了人类的能力,并在各个领域促进了效率和创造力。 节前,我们组织了一场算法岗技术&面试讨论会&#xff0…

python 两个表格字段列名称值,对比字段差异

支持xlsx,xls文件,相互对比字段列 输出两个表格文件相同字段,置底色为绿色 存在差异的不同字段,输出两个新的表格文件,差异字段,置底色为红色 注意点:读取的文件仅支持xlsx格式,头列需要删除…

【AD21】Gerber文件的输出

Gerber文件是对接生产的文件,该文件包含了PCB的所有层的信息,如铜层、焊盘、丝印层、阻焊层等。板厂使用这些文件来准备生产工艺。虽然可以将PCB发给板厂去打板,但是对于公司而言,直接发PCB会有泄密风险,Gerber文件会相…

《宝贵的人生建议》

致读者 2024/05/25 发表想法 简练表达,发散(灵活)运用。 原文:在写作过程中,我的主要精力是用在这个方面:把这些重要的经验教训浓缩为尽可能紧凑简炼、易于传播的语言。我鼓励读者在阅读时扩展这些“种子”…

不能错过的AI知识学习神器「Mo卡片」

1. 「Mo卡片」——知识点的另一种承载方式 1.1 产品特点 📱一款专为渴望理解和掌握人工智能知识的小伙伴量身打造的轻量级 App。 🏷AI 知识卡片集 Mo卡片内置了 26 套卡片集,总计 1387 张卡片,每张卡片都能获得 1 个核心知识。…

GpuMall智算云:AUTOMATIC1111/stable-diffusion-webui/stable-diffusion-webui-v1.8.0

配置环境介绍 目前平台集成了 Stable Diffusion WebUI 的官方镜像,该镜像中整合如下资源: GpuMall智算云 | 省钱、好用、弹性。租GPU就上GpuMall,面向AI开发者的GPU云平台 Stable Diffusion WebUI版本:v1.8.0 Python版本:3.10.…