深入解析与实现:变分自编码器(VAE)完整代码详解

VAE理论上一篇已经详细讲完了,虽然VAE已经是过去的东西了,但是它对后面强大的生成模型是很有指导意义的。接下来,我们简单实现一下其代码吧。

1 VAE在minist数据集上的实现

完整的代码如下,没有什么特别好讲的。

import cv2
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary""" 就用线性层构造最简单的vae吧"""class VAE(nn.Module):def __init__(self, image_size=28*28, hidden1=400, hidden2=100, latent_dims=40):super().__init__()# encoderself.encoder = nn.Sequential(nn.Linear(image_size, hidden1),nn.ReLU(),nn.Linear(hidden1, hidden2),nn.ReLU(),)self.mu = nn.Sequential(nn.Linear(hidden2, latent_dims),)self.logvar = nn.Sequential(nn.Linear(hidden2, latent_dims),)   # 由于方差是非负的,因此预测方差对数# decoderself.decoder = nn.Sequential(nn.Linear(latent_dims, hidden2),nn.ReLU(),nn.Linear(hidden2, hidden1),nn.ReLU(),nn.Linear(hidden1, image_size),nn.Tanh())# 重参数,为了可以反向传播def reparametrization(self, mu, logvar):# sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))std = 0.5 * torch.exp(logvar)# N(mu, std^2) = N(0, 1) * std + muz = torch.randn(std.size(), device=mu.device) * std + mureturn zdef forward(self, x):en = self.encoder(x)mu = self.mu(en)logvar = self.logvar(en)z = self.reparametrization(mu, logvar)return self.decoder(z), mu, logvardef loss_function(fake_imgs, real_imgs, mu, logvar):kl = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)reconstruction = ((real_imgs - fake_imgs)**2).sum()return kl, reconstructiondef train(num_epoch):write_fake = SummaryWriter(f'logs/fake')device = torch.device("cuda:0")trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)vae = VAE().to(device)optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)vae.train()step = 0for epoch in range(num_epoch):for batch_indx, (inputs, _) in enumerate(trainloader):inputs = inputs.to(device)real_imgs = torch.flatten(inputs, start_dim=1)fake_imgs, mu, logvar = vae(real_imgs)loss_kl, loss_re = loss_function(fake_imgs, real_imgs, mu, logvar)loss_all = loss_kl + loss_reoptimizer.zero_grad()loss_all.backward()optimizer.step()print(f"epoch:{epoch}, loss kl:{loss_kl.item()}, loss re:{loss_re.item()}, loss all:{loss_all.item()}")if batch_indx == 0:with torch.no_grad():x = torch.randn((32, 40)).to(device)fake = vae.decoder(x).reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)write_fake.add_image("Mnist Fake Image", img_grid_fake, global_step=step)step += 1if __name__ == "__main__":summary(VAE(), input_size=(1, 784))train(1000)

模型结构打印如下:
VAE [1, 784] –
├─Sequential: 1-1 [1, 100] –
│ └─Linear: 2-1 [1, 400] 314,000
│ └─ReLU: 2-2 [1, 400] –
│ └─Linear: 2-3 [1, 100] 40,100
│ └─ReLU: 2-4 [1, 100] –
├─Sequential: 1-2 [1, 40] –
│ └─Linear: 2-5 [1, 40] 4,040
├─Sequential: 1-3 [1, 40] –
│ └─Linear: 2-6 [1, 40] 4,040
├─Sequential: 1-4 [1, 784] –
│ └─Linear: 2-7 [1, 100] 4,100
│ └─ReLU: 2-8 [1, 100] –
│ └─Linear: 2-9 [1, 400] 40,400
│ └─ReLU: 2-10 [1, 400] –
│ └─Linear: 2-11 [1, 784] 314,384
│ └─Tanh: 2-12 [1, 784] –

训练结果,从结果上来看,是不如GAN的,主要原因在于其在KL散度和重建损失之间很难做到平衡,所以很难训练得好,当然原因是多方面的。
在这里插入图片描述

2 VAE的缺陷

变分自编码器(VAE, Variational Autoencoder)作为一种强大的深度学习模型,在生成建模领域有着广泛的应用,但它也存在一些缺陷,主要包括:

  • 生成样本质量:与生成对抗网络(GANs)相比,VAE生成的样本可能显得较为模糊或缺乏清晰度。尽管VAE能够生成连续且有结构的潜在空间,其生成的样本在某些情况下可能不够真实或细节不够丰富。

  • 潜在空间的连续性问题:虽然VAE设计用于学习连续的潜在空间,以允许插值和生成流畅的变化序列,但在实践中,这种连续性可能不如理论中那样完美。潜在空间中可能会出现空洞或不连贯区域,影响样本生成的质量和连续性变换的效果。

  • KL散度的平衡问题:VAE通过在其损失函数中加入KL散度项来约束潜在变量的分布,以确保它接近先验分布(通常是标准正态分布)。然而,KL散度的权重难以选择,如果设置不当,可能导致模型过分关注重构损失而忽视了潜在空间的平滑性和多样性,或者相反。

  • 训练难度与稳定性:VAE的训练过程比一些其他模型更为复杂,涉及到优化 Evidence Lower Bound (ELBO),这可能导致训练过程较为不稳定,需要更多的计算资源和更长的训练时间。特别是优化过程中对似然的近似以及对数似然的下界处理增加了训练的复杂度。

  • 表达能力与模型容量:由于VAE的编码器和解码器结构相对简单(通常为全连接层或简单的卷积层),在处理高度复杂的高维数据时,其表达能力可能受限,影响生成样本的质量和多样性。

这些缺陷提示研究者和实践者在使用VAE时需要仔细调整模型架构、损失函数的平衡以及训练策略,以最大化其生成能力和实用性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/842071.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【代码随想录】【算法训练营】【第20天】 [654]最大二叉树 [617]合并二叉树 [700]二叉搜索树中的搜索 [98]验证二叉搜索树

前言 思路及算法思维,指路 代码随想录。 题目来自 LeetCode。 day 19,一个愉快的周日~ day 20,一个悲伤的周一~ 题目详情 [654] 最大二叉树 题目描述 654 最大二叉树 解题思路 前提:构造二叉树 思路:寻找根节…

如何设置XHSC(华大)单片机的IO口中断

XHSC(华大)单片机IO口中断使用 一、代码说明 华大单片机的历程在华大或者小华的官网上都可以下载到,但是我们下载的历程基本注释都是非常简单,有的还没有注释;再加上小华跟华大的历程在代码架构上有所区别,所以新手在直接调用华大或者小华历程后,历程代码的可读性并不…

内网安全--域渗透准备知识

目录 知识点: 0x01 0x02 0x03 系列点: Linux主机信息收集 windows主机信息收集 知识点: 0、域产生原因 1、内网域的区别 2、如何判断在域内 3、域内常见信息收集 4、域内自动化工具收集 -局域网&工作组&域环境区别 -域…

# LLM高效微调详解-从Adpter、PrefixTuning到LoRA

一、背景 目前NLP主流范式是在大量通用数据上进行预训练语言模型训练,然后再针对特定下游任务进行微调,达到领域适应(迁移学习)的目的。 Context Learning v.s. SFT 指令微调是预训练语言模型微调的主流范式,其目的是…

通用代码生成器应用场景三,遗留项目反向工程

通用代码生成器应用场景三,遗留项目反向工程 如果您有一个遗留项目,要重新开发,或者源代码遗失,或者需要重新开发,但是希望复用原来的数据,并加快开发。 如果您的项目是通用代码生成器生成的,…

阿里云产品DTU评测报告(二)

阿里云产品DTU评测报告(二) 问题回顾问题处理继续执行 问题回顾 基于上一次DTU评测,在评测过程中遇到了windows系统情况下执行amp命令失败的情况,失败情况如图 导致后续命令无法执行,一时之间不知如何处理&#xff0…

python 两个表格字段列名称值,对比字段差异

支持xlsx,xls文件,相互对比字段列 输出两个表格文件相同字段,置底色为绿色 存在差异的不同字段,输出两个新的表格文件,差异字段,置底色为红色 注意点:读取的文件仅支持xlsx格式,头列需要删除…

【AD21】Gerber文件的输出

Gerber文件是对接生产的文件,该文件包含了PCB的所有层的信息,如铜层、焊盘、丝印层、阻焊层等。板厂使用这些文件来准备生产工艺。虽然可以将PCB发给板厂去打板,但是对于公司而言,直接发PCB会有泄密风险,Gerber文件会相…

《宝贵的人生建议》

致读者 2024/05/25 发表想法 简练表达,发散(灵活)运用。 原文:在写作过程中,我的主要精力是用在这个方面:把这些重要的经验教训浓缩为尽可能紧凑简炼、易于传播的语言。我鼓励读者在阅读时扩展这些“种子”…

不能错过的AI知识学习神器「Mo卡片」

1. 「Mo卡片」——知识点的另一种承载方式 1.1 产品特点 📱一款专为渴望理解和掌握人工智能知识的小伙伴量身打造的轻量级 App。 🏷AI 知识卡片集 Mo卡片内置了 26 套卡片集,总计 1387 张卡片,每张卡片都能获得 1 个核心知识。…

GpuMall智算云:AUTOMATIC1111/stable-diffusion-webui/stable-diffusion-webui-v1.8.0

配置环境介绍 目前平台集成了 Stable Diffusion WebUI 的官方镜像,该镜像中整合如下资源: GpuMall智算云 | 省钱、好用、弹性。租GPU就上GpuMall,面向AI开发者的GPU云平台 Stable Diffusion WebUI版本:v1.8.0 Python版本:3.10.…

nginx与nginx-rtmp-module安装

nginx与nginx-rtmp-module安装 画了好几天图,实在有些乏力,找点有意思的事情做做 觉得视频流传输挺有意思,B站找了些视频,但感觉有些大同小异,讲得不是很清楚 FFmpeg/RTMP/webRTC丨90分钟搞定直播逻辑-推流-流媒体服…

半年不在csdn写博客,总结一下这半年的学习经历,coderfun的一些碎碎念.

前言 自从自己建站一来,就不在csdn写博客了,但是后来自己的网站因为资金问题不能继续维护下去,所以便放弃了自建博客网站来写博客,等到以后找到稳定,打算满意的工作再来做自己的博客网站。此篇博客用来记录自己在csdn…

Git Large File Storage (LFS) 的安装与使用

Git Large File Storage [LFS] 的安装与使用 1. An open source Git extension for versioning large files2. Installing on Linux using packagecloud3. Getting Started4. Error: Failed to call git rev-parse --git-dir: exit status 128References 1. An open source Git…

Android Studio 获取 SHA1

以 debug.keystore 调试密钥库为例。 步骤1:明确 debug.keystore 位置 debug.keystore 在 .android 目录下: Windows 用户:C:\Users\用户名\.android\debug.keystore Mac 用户:/Users/用户名/.android/debug.keystore 假设我的…

【云原生】用 Helm 来简化 K8s 应用管理

用 Helm 来简化 K8s 应用管理 1.诞生背景2.主要功能3.相关概念4.工作原理5.架构演变6.Helm 常用命令7.推荐仓库8.Charts8.1 目录结构8.2 构建一个无状态应用模版 charts Helm 对于 Kubernetes 来说就相当于 Yum 对于 Centos 来说,如果没有 Yum 的话,我们…

旅游推荐管理系统

代码位置:旅游管理系统: 根据若依模版的一个旅游管理系统 - Gitee.com 分支dev 项目介绍 项目目的 随着社会的高速发展,人们生活水平的不断提高,以及工作节奏的加快,旅游逐渐成为一个热门的话题,因为其形式的多样,涉…

linux经典定时任务

在使用时记得替换为自己的脚本路径。请在相应的脚本第一行加上#!/bin/bash,否则脚本在定时任务中无法执行。 1、在每天凌晨2点执行 0 2 * * * /bin/sh bashup.sh 2、每天执行两次 下面的示例命令将在每天上午5点和下午5点执行。您可以通过逗号分隔指定多个时间戳…

IO多路复用模型原理

在linux没有实现epoll事件驱动机制之前,常规的手段是选择select和poll等IO多路复用的方法来实现并发服务程序。但是在大数据、高并发、集群情况下,select和poll的性能瓶颈就出现了,于是epoll就诞生了 Select select函数监视的文件描述符分三类:writefds、readfds和exceptf…