VAE-变分自编码器(Variational Autoencoder,VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,结合了概率图模型与神经网络技术,广泛应用于数据生成、表示学习和数据压缩等领域。以下是对VAE的详细解释和理解:

基本概念

1. 自编码器(Autoencoder)

自编码器是一种无监督学习模型,通常用于降维和特征提取。它由两个主要部分组成:

  • 编码器(Encoder):将输入数据映射到一个低维隐变量空间。
  • 解码器(Decoder):从低维隐变量空间重建输入数据。
    自编码器的目标是使重建的数据尽可能与原始输入数据相似。

2. 变分自编码器(VAE)

VAE 是自编码器的一种扩展,它通过引入概率分布的概念来对隐变量空间进行建模。VAE 的目标不仅是重建输入数据,还要使隐变量遵循某种已知的概率分布(通常是标准正态分布)。这样可以通过采样隐变量来生成新数据。

VAE的工作原理

  1. 编码器
    在VAE中,编码器不是直接输出一个隐变量,而是输出隐变量的参数(均值 μ 和标准差 σ)。这些参数定义了隐变量的一个概率分布,通常假设为正态分布 N(μ, σ^2)。

  2. 重新参数化技巧(Reparameterization Trick)
    为了使模型能够通过梯度下降进行训练,VAE引入了重新参数化技巧。通过采样一个标准正态分布的变量 ε ~ N(0, 1),然后进行线性变换得到隐变量 z:
    在这里插入图片描述

这样,采样操作变成了一个确定性的操作,允许梯度反向传播。

  1. 解码器
    解码器接受从上述分布中采样的隐变量 z,并尝试重建输入数据。解码器的目标是最大化重建数据的概率。

损失函数

VAE 的损失函数由两部分组成:

  • 重构损失(Reconstruction Loss):衡量重建数据与原始数据的相似度,通常使用均方误差(MSE)或交叉熵损失。 KL

  • 散度(KL Divergence):衡量隐变量分布与标准正态分布的差异。通过最小化KL散度,使隐变量分布接近标准正态分布。

综合起来,VAE的损失函数为:

在这里插入图片描述

VAE的优点

  1. 生成能力:可以从隐变量空间采样生成新数据,具有良好的生成能力。
  2. 隐变量解释性:通过将隐变量空间约束为标准正态分布,隐变量具有一定的解释性和可操作性。
  3. 无监督学习:VAE是一种无监督学习模型,不需要标签数据即可进行训练。

VAE的缺点

  1. **生成质量有限:**生成数据的质量有时不如GAN(生成对抗网络)等其他生成模型。
  2. **训练复杂:**VAE的训练涉及到复杂的概率推断和优化过程。

总结

变分自编码器通过引入概率分布和重新参数化技巧,使得隐变量具有良好的生成能力和解释性。其核心思想是在保持重建数据质量的同时,使隐变量遵循标准正态分布,从而实现数据生成和表示学习。尽管存在一些缺点,但VAE在许多应用场景中仍然表现出色,并为生成模型的研究提供了重要的理论基础。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable# 定义VAE模型
class VAE(nn.Module):def __init__(self, input_dim, hidden_dim, latent_dim):super(VAE, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc21 = nn.Linear(hidden_dim, latent_dim)self.fc22 = nn.Linear(hidden_dim, latent_dim)self.fc3 = nn.Linear(latent_dim, hidden_dim)self.fc4 = nn.Linear(hidden_dim, input_dim)def encode(self, x):h1 = F.relu(self.fc1(x))return self.fc21(h1), self.fc22(h1)def reparameterize(self, mu, logvar):std = torch.exp(0.5*logvar)eps = torch.randn_like(std)return mu + eps*stddef decode(self, z):h3 = F.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3))def forward(self, x):mu, logvar = self.encode(x.view(-1, 784))z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 定义损失函数
def loss_function(recon_x, x, mu, logvar):BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return BCE + KLD# 加载MNIST数据集
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.ToTensor()),batch_size=128, shuffle=True)# 初始化模型
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=20)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)# 训练模型
def train(epoch):vae.train()train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):optimizer.zero_grad()recon_batch, mu, logvar = vae(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.item() / len(data)))print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))# 开始训练
for epoch in range(1, 11):train(epoch)

代码说明

  • 编码器和解码器:编码器将输入图像编码为潜在空间的均值和对数方差,解码器从潜在变量生成重建的图像。
  • Sampling层:这是实现重参数化技巧的关键部分,将均值和对数方差转换为潜在变量。
  • VAE类:组合编码器和解码器,并实现自定义训练步骤,包括计算重建损失和KL散度损失。
  • 数据准备和训练:加载MNIST数据集,对数据进行预处理,然后训练VAE模型。
    这个示例展示了一个简单的VAE模型。根据具体的应用需求,你可能需要调整网络结构和超参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/14041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 Milvus Cloud + LlamaIndex 实现初级 RAG

初级 RAG 初级 RAG 的定义 初级 RAG 研究范式代表了最早的方法论,在 ChatGPT 广泛采用后不久就取得了重要地位。初级 RAG 遵循传统的流程,包括索引创建(Indexing)、检索(Retrieval)和生成(Generation),常常被描绘成一个“检索—读取”框架,其工作流包括三个关键步…

AWS安全性身份和合规性之Key Management Service(KMS)

AWS Key Management Service(KMS)是一项用于创建和管理加密密钥的托管服务,可帮助客户保护其数据的安全性和机密性。 比如一家医疗保健公司需要在AWS上存储敏感的病人健康数据,需要对数据进行加密以确保数据的机密性。他们使用AW…

课时134:awk实践_逻辑控制_自定义函数

1.3.7 自定义函数 学习目标 这一节,我们从 基础知识、简单实践、小结 三个方面来学习。 基础知识 需求 虽然awk提供了内置的函数来实现相应的内置函数,但是有些功能场景,还是需要我们自己来设定,这就用到了awk的自定义函数功能…

WebSocket简介

参考:Java NIO实现WebSocket服务器_nio websocket-CSDN博客 WebSocket API是HTML5中的一大特色,能够使得建立连接的双方在任意时刻相互推送消息,这意味着不同于HTTP,服务器服务器也可以主动向客户端推送消息了。 WebSocket协议是…

使用TensorBoard记录功能时,添加SummaryWriter到callbacks,某些版本可能不适用该如何修改

如果发现将SummaryWriter直接添加到callbacks不被支持,您可以采取另一种方式来集成TensorBoard记录功能,即通过自定义回调函数来实现。Hugging Face Transformers库允许用户自定义训练回调,这可以用来在训练过程中向TensorBoard写入日志。 下…

配置yum源

以下是在 Linux 系统中配置新的 yum 源的一般步骤和命令示例(以 CentOS 系统为例): 备份原有 yum 源配置文件:mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.bak 创建新的 yum 源配置文件&#xff08…

【PB案例学习笔记】-08 控件拖动实现

写在前面 这是PB案例学习笔记系列文章的第8篇,该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习,提高编程技巧,以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码,小凡都上传到了gitee…

反序列化漏洞的入门知识总结

1.概念定义 序列化与反序列化的目的是让数据在传输和处理的时候更简单,更快,反序列化出现在多种同面向对象语言所开发的网站和软件上,比如java,php,python等等,如果有语言一个都没学的,可以先去…

1941springboot VUE 服务机构评估管理系统开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 springboot VUE服务机构评估管理系统是一套完善的完整信息管理类型系统,结合springboot框架和VUE完成本系统,对理解JSP java编程开发语言有帮助系统采用springboot框架(MVC模式开发),系统具有完整的源代…

【NOIP2014普及组复赛】题2:比例简化

题2:比例简化 【题目描述】 在社交媒体上,经常会看到针对某一个观点同意与否的民意调查以及结果。例如,对某一观点表示支持的有 1498 1498 1498 人,反对的有 902 902 902 人,那么赞同与反对的比例可以简单的记为 …

计算机-编程相关

在 Linux 中、一切都是文件、硬件设备是文件、管道是文件、网络套接字也是文件。 for https://juejin.cn/post/6844904103437582344 fork 进程的一些问题 fork 函数比较特殊、一次调用会返回两次。在父进程和子进程都会返回。 每个进程在内核中都是一个 taskstruct 结构、for…

ECMAScript、BOM与DOM:网页开发的三大基石

在深入Web开发的世界时,有三个核心概念构成了理解网页如何工作以及如何与之交互的基础:ECMAScript、BOM(Browser Object Model),以及DOM(Document Object Model)。本文旨在简要介绍这三个概念&a…

Thingsboard规则链:Entity Type Switch节点详解

在物联网(IoT)领域,随着设备数量的爆炸式增长和数据复杂性的增加,高效、灵活的数据处理机制变得至关重要。作为一款先进的物联网平台,ThingsBoard提供了强大的规则链(Rule Chains)功能&#xff…

第四节 Starter 加载时机和源码理解

tips:每个 springBoot 的版本不同,代码的实现存会存在不同。 上一章,我们聊到 mybatis-spring-boot-starter; 简单分析了它的结构。 这一章我们将着重分析 Starter 的加载机制,并结合源码进行分析理解。 一、加载实际…

问题与解决:element ui垂直菜单展开后显示不全

比如我这个垂直菜单展开后,其实系统管理下面还有其他子菜单,但是显示不出来了。 解决方法很简单,只需要在菜单外面包一层el-scrollbar,并且将高度设置为100vh。

Laravel 11 PHP8

一直都是用laravel 7 左右的,现在要求将项目升级到laravel 11 和使用PHP8,随手记录一些小问题,laravel 11的包是领导给的,没有使用composer 安装,所以我也不确定和官方的是否一致 遇到这问题 可以这样 env 中默认的数…

基于若依的旅游推荐管理系统(spring boot+vue+mybatis+Ajax)

一、项目目的 随着社会的高速发展,人们生活水平的不断提高,以及工作节奏的加快,旅游逐渐成为一个热门的话题,因为其形式的多样,涉及的面比较广,成为人们放松压力,调节情绪的首要选择。 传统的旅…

上位机图像处理和嵌入式模块部署(mcu的按键输入)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 做技术的同学,大部分都会把精力放在技术本身,却忽视了学的东西有什么实际的用途。就拿gpio来说,一般我们点灯也…

正确认识IP地址和子网掩码的联系

IP地址和子网掩码是计算机网络中两个非常重要的概念,它们共同确定了设备在局域网中的地址以及该地址所属的子网,只要两者结合,就能确定唯一地址IP66_ip归属地在线查询_免费ip查询_ip精准定位平台。 IP地址是用于标识计算机网络中的每台设备的…

Ajax用法总结(包括原生Ajax、Jquery、Axois)

HTTP知识 HTTP(hypertext transport protocol)协议『超文本传输协议』,协议详细规定了浏览器和万维网服务器之间互相通信的规则。 请求报文 请求行: GET、POST /s?ieutf-8...(url的一长串参数) HTTP/1.1 请求头…