【PyTorch][chapter 18][李宏毅深度学习]【无监督学习][ VAE]

前言:

          VAE——Variational Auto-Encoder,变分自编码器,是由 Kingma 等人于 2014 年提出的基于变分贝叶斯(Variational Bayes,VB)推断的生成式网络结构。与传统的自编码器通过数值的方式描述潜在空间不同,它以概率的方式描述对潜在空间的观察,在数据生成方面表现出了巨大的应用价值。VAE一经提出就迅速获得了深度生成模型领域广泛的关注,并和生成对抗网络(Generative Adversarial Networks,GAN)被视为无监督式学习领域最具研究价值的方法之一,在深度生成模型领域得到越来越多的应用。

           Durk Kingma 目前也是 OpenAI 的研究科学家

   VAE 是我深度学习过程中偏难的一部分,涉及到的理论基础:

          极大似然估计, KL 散度 ,Bayes定理,蒙特卡洛重采样思想,VI变分思想,ELBO


目录:

  1.    AE 编码器缺陷
  2.    VAE 编码器 跟AE 编码器差异
  3.    VAE 编码器
  4.     VAE 思想
  5.      Python 代码例子

一 AE 编码器缺陷

   1.1 AE 简介

   输入一张图片 x

   编码器Encoder:

                 z=f(x)  通过神经网络得到低维度的特征空间Z

   解码器Decoder:

                 \hat{x}=g(z)  通过特征空间 重构输入的图像

   损失函数:

               J=mse(x,\hat{x})

   1.2 特征空间z

           单独使用解码器Decoder

           特征空间z 维度为10,固定其它维度参数. 取其中两维参数,产生不同的

值(如下图星座图),然后通过Decoder 生成不同的图片.就会发现该维度

跟图像的某些特征有关联.

1.3 通过特征空间z重构缺陷:泛化能力差

     

       如上图:

            假设通过AE 模型训练动物的图像,特征空间Z为一维度。

      两种狗分别对应特征向量z_1,z_3, 我们取一个特征向量z_2,期望通过

     解码器输出介于两种狗中间的一个样子的一种狗。

          实际输出: ,随机输出一些乱七八糟的图像。

     原因:

          因为训练的时候,模型对训练的图像和特征空间Z的映射是离散的,对特征空间z

中没有训练过的空间没有约束,所以通过解码器输出的图像也是随机的.


二  VAE 编码器 跟AE 编码器差异

        2.1  AE 编码器特征空间

      假设特征空间Z 为一维,

  通过编码器生成的特征空间为一维空间的一个离散点c,然后通过解码器重构输入x

2.2 VAE 编码器

      通过编码器产生一个均值为u,方差为\sigma的高斯分布,然后在该分布上采样得到

特征空间的一个点c, 通过解码器重构输入. 现在特征空间Z是一个高斯分布,

泛化能力更强


三 VAE 编码器

3.1 模型简介

 输入 :x

 经过编码器 生成一个服从高斯分布的特征空间 z \sim N(u,\sigma^2) ,

通过重参数采样技巧 采样出特征点 C=\begin{bmatrix} c_1,c_2,c_3 \end{bmatrix}

 把特征点 输入解码器,重构出输入x

3.2 标准差\sigma(黄色模块)设计原理

           方差 \sigma^2   标准差 \sigma 

           因为标准差是非负的,但是经过编码器输出的可能是负的值,所以

    认为其输出值为 a=log (\sigma) ,再经过 exp 操作,得到一个非负的标准差

        \sigma=e^{a}=\sigma

      很多博主用的\sigma^2,我理解是错误的,为什么直接用 标准差

       参考3.3  苏剑林的 重参数采样 原理画出来的。

3.3 为什么要重参数采样 reparameterization trick

        我们要从p(Z|X)中采样一个Z出来,尽管我们知道了p(Z|X)是正态分布,但是均值方差都是靠模型算出来的,我们要靠这个过程反过来优化均值方差的模型。
但是“采样”这个操作是不可导的,而采样的结果是可导的

p(Z|X) 的概率可以写成如下形式

   说明

    服从 N(0,1)的标准正态分布

   从N(u,\sigma^2)中采样一个Z,相当于从N(0,I)标准正态分布中采样一个e,然后让

    Z=u+e*\sigma

   我们将从采样N(u,\sigma^2)变成了从N(0,I)中采样,然后通过参数变换得服从N(u,\sigma^2)分布。这样一来,“采样”这个操作就不用参与梯度下降了,改为采样的结果参与,使得整个模型可训练了。其中 u,\sigma是求导参数,e 为已知道参数

3.4 损失函数

         J=J_1+J_2

         该模型有两个约束条件

         1   一个输入图像x和重构的图像\hat{x},mse 误差最小

                    J_1= ||x-\hat{x}||_2

         2   特征空间Z 要服从高斯分布(使用KL 散度)

                     J_2=KL(N(u,\sigma^2)||N(0,1))

                  该值越小越好

     KL 散度简化

3.5 伪代码


四  VAE 思想

        4.1 高斯混合模型

             我们重构出m张图片 X=\begin{Bmatrix} x_1 &x_2 & ... & x_m \end{Bmatrix}

              P(X)=\prod_i^{m} P(x_i),P(X) 很复杂无法求解.

            常用的思路是通过引入隐藏变量(latent variable) Z。

           寻找 Z空间到 X空间的映射,这样我们通过在Z空间采样映射到 X  空间就可以生成新的图片。

          P(X)=\int _z P(x|z)P(z)dz   

          我们使用多个高斯分布的P(z) 去拟合P(X)的分布,这里面P(z)为已知道

            在强化学习里面,蒙特卡罗重采样也是用了该方案.

例:

如上图 P(X=红色)=2/5  ,P(X=绿色)=3/5 

 我们可以通过高斯混合模型原理的方法求解

P(X=红色)=P(X=红色|Z=正方形)*P(Z=正方形)+ P(X=红色|Z=圆形)*P(Z=圆形)

                    

 P(X=绿色)也是一样

   4.2 极大似然估计

      目标:极大似然函数

            L= logP(x) 

      已知:

            编码器的概率分布\int_z q(z|x)dz=1

       则:

          L=L*\int_z q(z|x)dz(相当于乘以1)

              =\int_z q(z|x) log P(x)dz (因为P(x)跟z 无关,可以直接拿到积分里面)

             =\int_z q(z|x)log \frac{P(z,x)}{p(z|x)}

           贝叶斯定理:

         P(z,x)=p(x)p(z|x)

          =\int q(z|x)log \frac{p(z,x)}{p(z|x)}\frac{q(z|x)}{q(z|x)}

         =\int_z q(z|x)log \frac{q(z|x)}{p(z|x)}+\int_z q(z|x)log \frac{p(z,x)}{q(z|x)}

         =KL(q(z|x)||q(z|x))+\int_z q(z|x)log \frac{p(z,x)}{q(z|x)}

   1:  VAE叫做“变分自编码器”,它跟变分法有什么联系

固定概率分布p(x)(或q(x)的情况下,对于任意的概率分布q(x)(或p(x))),都有KL(p(x)||q(x))≥0,而且只有当p(x)=q(x)时才等于零。

因为KL(p(x)∥∥q(x))实际上是一个泛函,要对泛函求极值就要用到变分法

  \geq L_b=\int_z q(z|x)log\frac{p(z,x)}{q(z|x)}

ELBO:全称为 Evidence Lower Bound,即证据下界。

上面KL(q(z|x)||q(z|x)) 我们取了下界0

        =\int_z q(z|x)log \frac{p(z)p(x|z)}{q(z|x)}

  贝叶斯定理

   p(z,x)=p(x|z)p(z)

   注意: 这里面P(Z)在4.1 高斯混合模型 是已知道的概率分布,符合高斯分布

 =\int_z q(z|x)log p(z|x)+\int_z q(z|x)log \frac{p(z)}{p(z|x)}

=-KL(q(z|x)||p(z))+H(q(z|x)||p(x||z))

我们目标值是求L 的最大值

第一项:

因为KL 散度的非负性

-KL(q(z|x)||p(z))极大值点为 p(z)=q(z|x),因为p(z)是符合高斯分布的

所以通过编码器生成的q(z|x)也要跟它概率一致,符合高斯分布。

第二项:

H(q(z|x)||p(x||z)) 

 这部分代表重构误差,我们用mse(x,\hat{x}) 来训练该部分的误差


五 Python 代码

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 26 15:47:20 2024@author: chengxf2
"""import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms   # transforms用于数据预处理# 定义变分自编码器(VAE)模型
class VAE(nn.Module):def __init__(self, latent_dim):super(VAE, self).__init__()# Encoderself.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=latent_dim*2),  # 输出均值和方差nn.ReLU())# Decoderself.decoder = nn.Sequential(nn.Linear(in_features =latent_dim , out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def reparameterize(self, mu, logvar):std = torch.exp(logvar/2.0)  # 计算标准差,Encoder 出来的可能有负的值,标准差为非负值,所以要乘以expeps = torch.randn_like(std)  # 从标准正态分布中采样噪声z = mu + eps * std  # 重参数化技巧return zdef forward(self, x):# 编码[batch, latent_dim*2]encoded = self.encoder(x)#[ z = mu|logvar]mu, logvar = torch.chunk(encoded, 2, dim=1)  # 将输出分割为均值和方差z = self.reparameterize(mu, logvar)  # 重参数化# 解码decoded = self.decoder(z)return decoded, mu, logvar# 定义训练函数
def train_vae(model, train_loader, num_epochs, learning_rate):criterion = nn.BCELoss()  # 二元交叉熵损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器model.train()  # 设置模型为训练模式for epoch in range(num_epochs):total_loss = 0.0for data in train_loader:images, _ = dataimages = images.view(images.size(0), -1)  # 展平输入图像optimizer.zero_grad()# 前向传播outputs, mu, logvar = model(images)# 计算重构损失和KL散度reconstruction_loss = criterion(outputs, images)kl_divergence = 0.5 * torch.sum( -logvar +mu.pow(2) +logvar.exp()-1)# 计算总损失loss = reconstruction_loss + kl_divergence# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()# 输出当前训练轮次的损失print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, total_loss / len(train_loader)))print('Training finished.')# 示例用法
if __name__ == '__main__':# 设置超参数latent_dim = 32  # 潜在空间维度num_epochs = 1  # 训练轮次learning_rate = 1e-4  # 学习率# 加载MNIST数据集train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)# 创建VAE模型model = VAE(latent_dim)# 训练VAE模型train_vae(model, train_loader, num_epochs, learning_rate)


VAE到底在做什么?VAE原理讲解系列#1_哔哩哔哩_bilibili

VAE里面的概率知识。VAE原理讲解系列#2_哔哩哔哩_bilibili

vae损失函数怎么理解? - 知乎

如何搭建VQ-VAE模型(Pytorch代码)_哔哩哔哩_bilibili

变分自编码器(一):原来是这么一回事 - 科 学空间|Scientific Spaces

16: Unsupervised Learning - Auto-encoder_哔哩哔哩_bilibili

【生成模型VAE】十分钟带你了解变分自编码器及搭建VQ-VAE模型(Pytorch代码)!简单易懂!—GAN/机器学习/监督学习_哔哩哔哩_bilibili

[diffusion] 生成模型基础 VAE 原理及实现_哔哩哔哩_bilibili

[论文简析]VAE: Auto-encoding Variational Bayes[1312.6114]_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/704821.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

排序(9.17)

1.排序的概念及其运用 1.1排序的概念 排序 :所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性 :假定在待排序的记录序列中,存在多个具有相同的关键字的记…

实战 vue3 使用百度编辑器ueditor

前言 在开发项目由于需求vue自带对编辑器不能满足使用,所以改为百度编辑器,但是在网上搜索发现都讲得非常乱,所以写一篇使用流程的文章 提示:以下是本篇文章正文内容,下面案例可供参考 一、下载ueditor编辑器 一个“…

三数之和(哈希,双指针)

15. 三数之和 - 力扣&#xff08;LeetCode&#xff09; 方法1&#xff1a;哈希算法&#xff08;不推荐&#xff09; 缺点&#xff1a;时间复杂度O&#xff08;N^2&#xff09;&#xff0c;去重情况复杂 class Solution { public:vector<vector<int>> threeSum(ve…

【Java EE初阶二十五】简单的表白墙(一)

1. 前端部分 1.1 前端代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"wid…

2步破解官方sublime4

sublime简要破解流程 1.下载sublime官方最新版2. 破解流程 1.下载sublime官方最新版 打开 官方网站下载 portable version 版&#xff0c;省的安装。。解压到任意位置&#xff0c;备份 sublime_text.exe 文件 2. 破解流程 打开网址把文件 sublime_text.exe 拖入网页搜索替换…

【非递归版】归并排序算法(2)

目录 MergeSortNonR归并排序 非递归&归并排序VS快速排序 整体思想 图解分析​ 代码实现 时间复杂度 归并排序在硬盘上的应用&#xff08;外排序&#xff09; MergeSortNonR归并排序 前面的快速排序的非递归实现&#xff0c;我们借助栈实现。这里我们能否也借助栈去…

国产服务器操作系统

为何记录 最近的开发工作经常接触到国产服务器操作系统的业务&#xff0c;经常被x86、arm、龙芯、鲲鹏、欧拉...搞得一脸懵逼&#xff0c;遂记之&#xff01; 操作系统 这里按照应用场景分&#xff1a; 桌面操作系统&#xff1a;主要用于pc&#xff0c;如Windows、macOS、Li…

MATLAB练习题:电子管的更换策略问题

​讲解视频&#xff1a;可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇&#xff08;数学建模清风主讲&#xff0c;适合零基础同学观看&#xff09;_哔哩哔哩_bilibili 在一台设备上&#xff0c;安装有四只型号和规格完全相同的电子…

腾讯design vue项目 上传桶 腾讯云的桶 对象存储 打包web端项目上传dist

1.说明 将腾讯design 项目上传到 腾讯云的对象存储中 &#xff0c;但是发现 再这个腾讯design项目中 直接npm run build 打包以后 上传 发现 不能用 需要配置东西 2.解决 使用腾讯云的cos-nodejs-sdk-v5 插件 代码上传 cos-nodejs-sdk-v5 - npm npm i cos-nodejs-sdk-v5 …

[算法沉淀记录]排序算法 —— 快速排序

排序算法 —— 快速排序介绍 基本概念 快速排序&#xff08;Quicksort&#xff09;是一种排序算法&#xff0c;最早由东尼霍尔提出。在平均状况下&#xff0c;排序 n 个项目要 Ο(n log n) 次比较。在最坏状况下则需要 Ο(n2) 次比较&#xff0c;但这种状况并不常见。事实上&…

《论文阅读》一个基于情感原因的在线共情聊天机器人 SIGIR 2021

《论文阅读》一个基于情感原因的在线共情聊天机器人 前言简介数据集构建模型架构损失函数实验结果咨询策略总结前言 亲身阅读感受分享,细节画图解释,再也不用担心看不懂论文啦~ 无抄袭,无复制,纯手工敲击键盘~ 今天为大家带来的是《Towards an Online Empathetic Chatbot…

EMQX Enterprise 5.5 发布:新增 Elasticsearch 数据集成

EMQX Enterprise 5.5.0 版本已正式发布&#xff01; 在这个版本中&#xff0c;我们引入了一系列新的功能和改进&#xff0c;包括对 Elasticsearch 的集成、Apache IoTDB 和 OpenTSDB 数据集成优化、授权缓存支持排除主题等功能。此外&#xff0c;新版本还进行了多项改进以及 B…

设计模式(二)单例模式的七种写法

相关文章设计模式系列 面试的时候&#xff0c;问到许多年轻的Android开发他所会的设计模式是什么&#xff0c;基本上都会提到单例模式&#xff0c;但是对单例模式也是一知半解&#xff0c;在Android开发中我们经常会运用单例模式&#xff0c;所以我们还是要更了解单例模式才对…

vue3 使用qrcodejs2-fix生成二维码并可下载保存

直接上代码 <el-button click‘setEwm’>打开弹框二维码</el-button><el-dialog v-model"centerDialogVisible" align-center ><div class"code"><div class"content" id"qrCodeUrl" ref"qrCodeUrl&q…

【Vue】组件通信组件通信

&#x1f4dd;个人主页&#xff1a;五敷有你 &#x1f525;系列专栏&#xff1a;JVM ⛺️稳中求进&#xff0c;晒太阳 组件通信 组件通信&#xff0c;就是指组件与组件之间的数据传递 组件的数据是独立的&#xff0c;无法直接访问其他组件的数据想用其他组件的数据--&…

Qt5转Qt6笔记

背景 现在的主程序和扩展的dll库都是qt5环境下编译发布的。但是想以后用qt6。所以考虑是否能够在qt5中兼容qt6的动态链接库进行加载。于是...就开始吧 开始 2024-02-23 安装好qt6后&#xff0c;在vs2019中需要新增qt6版本的安装路径。目录在&#xff1a;扩展->QT VS Tools…

Linux笔记--硬链接与软链接

一、硬链接 1.inode和block 文件包含两部分数据&#xff1a;文件属性和实际内容&#xff0c;属性放在inode中&#xff0c;实际内容放在data block中。还有个超级区块&#xff08;superblock&#xff09;记录整个文件系统的整体信息&#xff0c;包括inode和block的总量&#x…

python 循环语句 while 循环

while循环 Python 编程中 while 语句用于循环执行程序&#xff0c;即在某条件下&#xff0c;循环执行某段程序&#xff0c;以处理需要重复处理的相同任务。其基本形式为&#xff1a; while 判断条件(condition)&#xff1a; 执行语句(statements)…… 执行语句可以是单个语句…

[Docker 教学] 常用的Docker 命令

Docker是一种流行的容器化技术。使用Docker可以将数据科学应用程序连同代码和所需的依赖关系打包成一个名为镜像的便携式工件。因此&#xff0c;Docker可以简化开发环境的复制&#xff0c;并使本地开发变得轻松。 以下是一些必备的Docker命令列表&#xff0c;这些命令将在你下一…

golang学习6,glang的web的restful接口传参

1.get传参 //get请求 返回json 接口传参r.GET("/getJson/:id", controller.GetUserInfo) 1.2.接收处理 package controllerimport "github.com/gin-gonic/gin"func GetUserInfo(c *gin.Context) {_ c.Param("id")ReturnSucess(c, 200, &quo…