VAE-变分自编码器(Variational Autoencoder,VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,结合了概率图模型与神经网络技术,广泛应用于数据生成、表示学习和数据压缩等领域。以下是对VAE的详细解释和理解:

基本概念

1. 自编码器(Autoencoder)

自编码器是一种无监督学习模型,通常用于降维和特征提取。它由两个主要部分组成:

  • 编码器(Encoder):将输入数据映射到一个低维隐变量空间。
  • 解码器(Decoder):从低维隐变量空间重建输入数据。
    自编码器的目标是使重建的数据尽可能与原始输入数据相似。

2. 变分自编码器(VAE)

VAE 是自编码器的一种扩展,它通过引入概率分布的概念来对隐变量空间进行建模。VAE 的目标不仅是重建输入数据,还要使隐变量遵循某种已知的概率分布(通常是标准正态分布)。这样可以通过采样隐变量来生成新数据。

VAE的工作原理

  1. 编码器
    在VAE中,编码器不是直接输出一个隐变量,而是输出隐变量的参数(均值 μ 和标准差 σ)。这些参数定义了隐变量的一个概率分布,通常假设为正态分布 N(μ, σ^2)。

  2. 重新参数化技巧(Reparameterization Trick)
    为了使模型能够通过梯度下降进行训练,VAE引入了重新参数化技巧。通过采样一个标准正态分布的变量 ε ~ N(0, 1),然后进行线性变换得到隐变量 z:
    在这里插入图片描述

这样,采样操作变成了一个确定性的操作,允许梯度反向传播。

  1. 解码器
    解码器接受从上述分布中采样的隐变量 z,并尝试重建输入数据。解码器的目标是最大化重建数据的概率。

损失函数

VAE 的损失函数由两部分组成:

  • 重构损失(Reconstruction Loss):衡量重建数据与原始数据的相似度,通常使用均方误差(MSE)或交叉熵损失。 KL

  • 散度(KL Divergence):衡量隐变量分布与标准正态分布的差异。通过最小化KL散度,使隐变量分布接近标准正态分布。

综合起来,VAE的损失函数为:

在这里插入图片描述

VAE的优点

  1. 生成能力:可以从隐变量空间采样生成新数据,具有良好的生成能力。
  2. 隐变量解释性:通过将隐变量空间约束为标准正态分布,隐变量具有一定的解释性和可操作性。
  3. 无监督学习:VAE是一种无监督学习模型,不需要标签数据即可进行训练。

VAE的缺点

  1. **生成质量有限:**生成数据的质量有时不如GAN(生成对抗网络)等其他生成模型。
  2. **训练复杂:**VAE的训练涉及到复杂的概率推断和优化过程。

总结

变分自编码器通过引入概率分布和重新参数化技巧,使得隐变量具有良好的生成能力和解释性。其核心思想是在保持重建数据质量的同时,使隐变量遵循标准正态分布,从而实现数据生成和表示学习。尽管存在一些缺点,但VAE在许多应用场景中仍然表现出色,并为生成模型的研究提供了重要的理论基础。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable# 定义VAE模型
class VAE(nn.Module):def __init__(self, input_dim, hidden_dim, latent_dim):super(VAE, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc21 = nn.Linear(hidden_dim, latent_dim)self.fc22 = nn.Linear(hidden_dim, latent_dim)self.fc3 = nn.Linear(latent_dim, hidden_dim)self.fc4 = nn.Linear(hidden_dim, input_dim)def encode(self, x):h1 = F.relu(self.fc1(x))return self.fc21(h1), self.fc22(h1)def reparameterize(self, mu, logvar):std = torch.exp(0.5*logvar)eps = torch.randn_like(std)return mu + eps*stddef decode(self, z):h3 = F.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3))def forward(self, x):mu, logvar = self.encode(x.view(-1, 784))z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 定义损失函数
def loss_function(recon_x, x, mu, logvar):BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return BCE + KLD# 加载MNIST数据集
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.ToTensor()),batch_size=128, shuffle=True)# 初始化模型
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=20)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)# 训练模型
def train(epoch):vae.train()train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):optimizer.zero_grad()recon_batch, mu, logvar = vae(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.item() / len(data)))print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))# 开始训练
for epoch in range(1, 11):train(epoch)

代码说明

  • 编码器和解码器:编码器将输入图像编码为潜在空间的均值和对数方差,解码器从潜在变量生成重建的图像。
  • Sampling层:这是实现重参数化技巧的关键部分,将均值和对数方差转换为潜在变量。
  • VAE类:组合编码器和解码器,并实现自定义训练步骤,包括计算重建损失和KL散度损失。
  • 数据准备和训练:加载MNIST数据集,对数据进行预处理,然后训练VAE模型。
    这个示例展示了一个简单的VAE模型。根据具体的应用需求,你可能需要调整网络结构和超参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/14041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 Milvus Cloud + LlamaIndex 实现初级 RAG

初级 RAG 初级 RAG 的定义 初级 RAG 研究范式代表了最早的方法论,在 ChatGPT 广泛采用后不久就取得了重要地位。初级 RAG 遵循传统的流程,包括索引创建(Indexing)、检索(Retrieval)和生成(Generation),常常被描绘成一个“检索—读取”框架,其工作流包括三个关键步…

WebSocket简介

参考:Java NIO实现WebSocket服务器_nio websocket-CSDN博客 WebSocket API是HTML5中的一大特色,能够使得建立连接的双方在任意时刻相互推送消息,这意味着不同于HTTP,服务器服务器也可以主动向客户端推送消息了。 WebSocket协议是…

【PB案例学习笔记】-08 控件拖动实现

写在前面 这是PB案例学习笔记系列文章的第8篇,该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习,提高编程技巧,以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码,小凡都上传到了gitee…

1941springboot VUE 服务机构评估管理系统开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 springboot VUE服务机构评估管理系统是一套完善的完整信息管理类型系统,结合springboot框架和VUE完成本系统,对理解JSP java编程开发语言有帮助系统采用springboot框架(MVC模式开发),系统具有完整的源代…

计算机-编程相关

在 Linux 中、一切都是文件、硬件设备是文件、管道是文件、网络套接字也是文件。 for https://juejin.cn/post/6844904103437582344 fork 进程的一些问题 fork 函数比较特殊、一次调用会返回两次。在父进程和子进程都会返回。 每个进程在内核中都是一个 taskstruct 结构、for…

Thingsboard规则链:Entity Type Switch节点详解

在物联网(IoT)领域,随着设备数量的爆炸式增长和数据复杂性的增加,高效、灵活的数据处理机制变得至关重要。作为一款先进的物联网平台,ThingsBoard提供了强大的规则链(Rule Chains)功能&#xff…

第四节 Starter 加载时机和源码理解

tips:每个 springBoot 的版本不同,代码的实现存会存在不同。 上一章,我们聊到 mybatis-spring-boot-starter; 简单分析了它的结构。 这一章我们将着重分析 Starter 的加载机制,并结合源码进行分析理解。 一、加载实际…

问题与解决:element ui垂直菜单展开后显示不全

比如我这个垂直菜单展开后,其实系统管理下面还有其他子菜单,但是显示不出来了。 解决方法很简单,只需要在菜单外面包一层el-scrollbar,并且将高度设置为100vh。

Laravel 11 PHP8

一直都是用laravel 7 左右的,现在要求将项目升级到laravel 11 和使用PHP8,随手记录一些小问题,laravel 11的包是领导给的,没有使用composer 安装,所以我也不确定和官方的是否一致 遇到这问题 可以这样 env 中默认的数…

基于若依的旅游推荐管理系统(spring boot+vue+mybatis+Ajax)

一、项目目的 随着社会的高速发展,人们生活水平的不断提高,以及工作节奏的加快,旅游逐渐成为一个热门的话题,因为其形式的多样,涉及的面比较广,成为人们放松压力,调节情绪的首要选择。 传统的旅…

上位机图像处理和嵌入式模块部署(mcu的按键输入)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 做技术的同学,大部分都会把精力放在技术本身,却忽视了学的东西有什么实际的用途。就拿gpio来说,一般我们点灯也…

正确认识IP地址和子网掩码的联系

IP地址和子网掩码是计算机网络中两个非常重要的概念,它们共同确定了设备在局域网中的地址以及该地址所属的子网,只要两者结合,就能确定唯一地址IP66_ip归属地在线查询_免费ip查询_ip精准定位平台。 IP地址是用于标识计算机网络中的每台设备的…

Ajax用法总结(包括原生Ajax、Jquery、Axois)

HTTP知识 HTTP(hypertext transport protocol)协议『超文本传输协议』,协议详细规定了浏览器和万维网服务器之间互相通信的规则。 请求报文 请求行: GET、POST /s?ieutf-8...(url的一长串参数) HTTP/1.1 请求头…

Mac安装 Intellij IDEA,亲测有效M1、M2可用

引言 最近开始学习使用spring boot写一个简单的后端项目,使用Intellij IDEA软件,Intellij IDEA为新用户提供了30天的免费试用。 方案 1.官网下载Intellij IDEA IntelliJ IDEA – the Leading Java and Kotlin IDE 或者直接网盘连接下载:…

第一份工资

当我拿到我人生的第一份工资时,那是一种难以言表的激动。我记得那个下午,阳光透过窗户洒在了我的办公桌上,我看着那张支票,心中满是欣喜和自豪。那是我独立生活的开始,也是我对自己能力的一种肯定。 我记得我是如何支配…

SQL注入:pikachu靶场中的SQL注入通关

目录 1、数字型注入(post) 2、字符型注入(get) 3、搜索型注入 4、XX型注入 5、"insert/update"注入 Insert: update: 6、"delete"注入 7、"http header"注入 8、盲…

【Linux安全】Firewalld防火墙

目录 一.Firewalld概述 二.Firewalld和iptables的关系 1.firewalld和iptables的联系 2.firewalld和iptables的区别 三.Firewalld区域 1.概念 2.九个区域 3.区域介绍 4.Firewalld数据处理流程 四.Firewalld-cmd命令行操作 1.查看 2.增加 3.删除 4.修改 五.Firewa…

arping 一键检测网络设备连通性(KALI工具系列二)

目录 1、KALI LINUX简介 2、arping工具简介 3、在KALI中使用arping 3.1 目标主机IP(win) 3.2 KALI的IP 4、操作示例 4.1 IP测试 4.2 ARP测试 4.3 根据存活情况返回 5、总结 1、KALI LINUX简介 Kali Linux 是一个功能强大、多才多艺的 Linux 发…

【机器学习与大模型】驱动下的电子商务应用

摘要: 随着信息技术的飞速发展,电子商务已经成为当今商业领域中最为活跃和重要的部分之一。而机器学习和大模型的出现,为电子商务带来了新的机遇和挑战。本文深入探讨了机器学习与大模型在电子商务中的应用,包括个性化推荐、精准营…

基于双向长短期记忆 Bi-LSTM 对消费者投诉进行多类分类

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记…