【TensorFlow-windows】学习笔记六——变分自编码器

#前言

对理论没兴趣的直接看代码吧,理论一堆,而且还有点复杂,我自己的描述也不一定准确,但是代码就两三句话搞定了。

国际惯例,参考博文

论文:Tutorial on Variational Autoencoders

【干货】一文读懂什么是变分自编码器

CS598LAZ - Variational Autoencoders

MusicVAE: Creating a palette for musical scores with machine learning

【Learning Notes】变分自编码器(Variational Auto-Encoder,VAE)

花式解释AutoEncoder与VAE

#理论

##基础知识

似然函数(引自百度百科)

似然函数是关于统计模型中的参数的函数,表示模型参数的似然性。在给定输出xxx时,关于参数θ\thetaθ的似然函数L(θ∣x)L(\theta|x)L(θx)在数值上等于给定参数θ\thetaθ后变量XXX的概率:
L(θ∣x)=P(X=x∣θ)L(\theta|x)=P(X=x|\theta) L(θx)=P(X=xθ)
有两个比较有趣的说法来区分概率与似然的关系,比如抛硬币的例子:

  • 概率说法:对于“一枚正反对称的硬币上抛十次”这种事件,问硬币落地时十次都是正面向上的“概率”是多少
  • 似然说法:对于“一枚硬币上抛十次”,问这枚硬币正反面对称的“似然”程度是多少。

极大似然估计

(摘自西瓜书)两大学派:

  • 频率主义学:参数是固定的,通过优化似然函数来确定参数
  • 贝叶斯学派:参数是变化的,且本身具有某种分布,先假设参数服从某个先验分布,然后基于观测到的数据来计算参数的后验分布

极大似然估计(Maximum Likelihood Estimation,MLE)源自频率主义学。

假设DDD是第ccc类样本的集合,比如所有的数字333的图片集合,假设它们是独立同分布的,则参数θ\thetaθ对于数据集DDD的似然是:
P(D∣θ)=∏x∈DP(x∣θ)P(D|\theta)=\prod_{x\in D}P(x|\theta) P(Dθ)=xDP(xθ)
极大似然估计就是寻找一个θ\thetaθ使得样本xxx出现的概率最大

但是上面的连乘比较难算,这就出现了对数似然
L(θ)=log⁡P(D∣θ)=∑x∈Dlog⁡P(x∣D)L(\theta)=\log P(D|\theta)=\sum_{x\in D}\log P(x|D) L(θ)=logP(Dθ)=xDlogP(xD)
我们的目标就是求参数θ\thetaθ的极大似然估计θ^\hat{\theta}θ^
θ^=arg⁡max⁡θL(θ)\hat{\theta}=\arg \max_{\theta}L(\theta) θ^=argθmaxL(θ)
例子:在连续属性情况下,如果样本集合概率密度函数p(x∣c)∼N(μ,σ2)p(x|c)\sim N(\mu,\sigma^2)p(xc)N(μ,σ2),那么参数μ,σ2\mu,\sigma^2μ,σ2的极大似然估计就是
μ^=1∣D∣∑x∈Dxσ^2=1∣D∣∑x∈D(x−μ^)(x−μ^)T\begin{aligned} \hat{\mu}&=\frac{1}{|D|}\sum_{x\in D}x\\ \hat{\sigma}^2&=\frac{1}{|D|}\sum_{x \in D}(x-\hat{\mu})(x-\hat{\mu})^T \end{aligned} μ^σ^2=D1xDx=D1xD(xμ^)(xμ^)T
其实就是计算均值和方差了。这样想,这些样本就服从这个高斯分布,那么把高斯分布直接当做参数,一定能够大概率得到此类样本,也就是说用333的样本所服从的高斯分布作为模型参数一定能使333出现的概率P(x∣θ)P(x|\theta)P(xθ)最大。

###期望值最大化算法(EM)

这一部分简单说一下即可,详细的在我前面的博客HMM——前向后向算法中有介绍,主要有两步:

  • E步:求Q函数Q(θ,θ(i))Q(\theta,\theta^{(i)})Q(θ,θ(i)),这个θ(i)\theta^{(i)}θ(i)就是当前迭代次数iii对应的参数值,Q函数实际就是对数联合似然函数log⁡P(X,Z∣θ)\log P(X,Z|\theta)logP(X,Zθ)在分布P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))下的期望
    Q(θ,θ(i))=EZ∣X,θ(i)L(θ∣X,Z)=Ez[log⁡P(X,Z∣θ)∣X,θ(i)]=∑ZP(Z∣X,θ(i))log⁡P(X,Z∣θ)\begin{aligned} Q(\theta,\theta^{(i)})&=E_{Z|X,\theta^{(i)}}L(\theta|X,Z)\\ &=E_z\left[\log P(X,Z|\theta)|X,\theta^{(i)}\right]\\ &=\sum_Z P(Z|X,\theta^{(i)})\log P(X,Z|\theta) \end{aligned} Q(θ,θ(i))=EZX,θ(i)L(θX,Z)=Ez[logP(X,Zθ)X,θ(i)]=ZP(ZX,θ(i))logP(X,Zθ)

  • M步:求使得Q函数最大化的参数θ​\theta​θ,并将其作为下一步的θ(i)​\theta^{(i)}​θ(i)
    θ(i+1)=arg⁡max⁡θQ(θ,θ(i))\theta^{(i+1)}=\arg\max_\theta Q(\theta,\theta^{(i)}) θ(i+1)=argθmaxQ(θ,θ(i))

从西瓜书上再摘点主要内容过来:

有时候样本的一些属性可以观测到,而另一些属性观测不到,所以就定义未观测变量为隐变量,设XXX为可观测变量,ZZZ为隐变量,θ\thetaθ为模型参数,则可写出对数似然:
L(θ∣X,Z)=ln⁡P(X,Z∣θ)L(\theta|X,Z)=\ln P(X,Z|\theta) L(θX,Z)=lnP(X,Zθ)
但是ZZZ又不知道,所以采用边缘化(marginal)方法消除它
L(θ∣X)=ln⁡P(X∣θ)=ln⁡∑ZP(X,Z∣θ)=∑i=1Nln⁡{∑ZP(xi,Z∣θ)}L(\theta|X)=\ln P(X|\theta)=\ln\sum_Z P(X,Z|\theta)=\sum_{i=1}^N\ln\left\{\sum_Z P(x_i,Z|\theta)\right\} L(θX)=lnP(Xθ)=lnZP(X,Zθ)=i=1Nln{ZP(xi,Zθ)}
使用EM算法求解参数的方法是:

  • 基于θ(i)\theta^{(i)}θ(i)推断隐变量ZZZ的期望,记为ZtZ^tZt
  • 基于已观测变量XXXZtZ^tZt对参数θ\thetaθ做极大似然估计,求得θ(i+1)\theta^{(i+1)}θ(i+1)

【注】是不是感觉很像坐标下降法

变分推断

(摘自西瓜书)

变分推断是通过使用已知简单分布来逼近需推断的复杂分布,并通过限制近似分布的类型,从而得到一种局部最优,但具有确定解的近似后验分布。

继续看上面的EM算法的M步,我们得到了:
θ(i+1)=arg⁡max⁡θQ(θ,θ(i))=arg⁡max⁡θ∑ZP(Z∣x,θ(i))ln⁡P(x,Z∣θ)\begin{aligned} \theta^{(i+1)}&=\arg\max_\theta Q(\theta,\theta^{(i)})\\ &=\arg\max_\theta \sum_ZP\left(Z|x,\theta^{(i)}\right)\ln P(x,Z|\theta) \end{aligned} θ(i+1)=argθmaxQ(θ,θ(i))=argθmaxZP(Zx,θ(i))lnP(x,Zθ)
还记得QQQ函数的意义吧,对数联合似然函数ln⁡P(X,Z∣θ)\ln P(X,Z|\theta)lnP(X,Zθ)在分布P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))下的期望。当分布P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))与变量ZZZ的真实后验分布相等的时候,QQQ函数就近似于对数似然函数,因而EM算法能够获得稳定的参数θ\thetaθ,且隐变量ZZZ的分布也能通过该参数获得。

但是通常情况下,P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))只是隐变量ZZZ所服从的真实分布的近似,若用Q(Z)Q(Z)Q(Z)表示,则
ln⁡P(X)=L(Q)+KL(Q∣∣P)\ln P(X)=L(Q)+KL(Q||P) lnP(X)=L(Q)+KL(QP)
其中
L(Q)=∫Q(Z)ln⁡{P(X,Z)Q(Z)}dZKL(Q∣∣P)=−∫Q(Z)ln⁡P(Z∣X)Q(Z)dZL(Q)=\int Q(Z)\ln\left\{\frac{P(X,Z)}{Q(Z)}\right\}dZ\\ KL(Q||P)=-\int Q(Z)\ln \frac{P(Z|X)}{Q(Z)}dZ L(Q)=Q(Z)ln{Q(Z)P(X,Z)}dZKL(QP)=Q(Z)lnQ(Z)P(ZX)dZ
但是,这个ZZZ模型可能很复杂,导致E步的P(Z∣X,θ)P(Z|X,\theta)P(ZX,θ)比较难推断,这时候就借用变分推断了,假设ZZZ服从分布
Q(Z)=∏i=1MQi(Zi)Q(Z)=\prod_{i=1}^MQ_i(Z_i) Q(Z)=i=1MQi(Zi)
也就是说多变量ZZZ可拆解为一系列相互独立的多变量ZiZ_iZi,可以另QiQ_iQi是非常简单的分布。

【PS】浅尝辄止了,经过层层理论已经引出了变分自编码的主要思想,变分推断,使用简单分布逼近复杂分布,实际上,变分自编码所使用的简单分布就是高斯分布,用多个高斯分布来逼近隐变量,随后利用服从这些分布的隐变量重构我们想要的数据。

变分自编码

先看优化目标ELBO(Evidence Lower Bound):
ELBO=log⁡p(x)−KL[q(z∣x)∣∣p(z∣x)]ELBO=\log p(x)-KL\left[q(z|x)||p(z|x)\right] ELBO=logp(x)KL[q(zx)p(zx)]
其中qqq是假设分布,ppp是真实分布,我们希望最大化第一项而最小化KL距离,所以整个规则就是最大化ELBOELBOELBO,但是这里面有个p(z∣x)p(z|x)p(zx)代表隐变量的真实分布,这个是无法求解的,所以需要简化:

简化结果是:
log⁡p(x)−KL(q(z)∣∣p(z∣x))=Ez∼q[log⁡P(x∣z)]−KL(q(z)∣∣p(z))\log p(x)-KL(q(z)||p(z|x))=E_{z\sim q}\left[\log P(x|z)\right]-KL(q(z)||p(z)) logp(x)KL(q(z)p(zx))=Ezq[logP(xz)]KL(q(z)p(z))
证明:

假设KL距离为KL(q(z)∣∣p(z∣x))=Ez∼q[log⁡q(z)−log⁡p(z∣x)]KL(q(z)||p(z|x))=E_{z\sim q}\left[\log q(z)-\log p(z|x)\right]KL(q(z)p(zx))=Ezq[logq(z)logp(zx)]

那么直接使用贝叶斯准则:

  • p(z∣x)=p(x∣z)p(z)p(x)p(z|x)=\frac{p(x|z)p(z)}{p(x)}p(zx)=p(x)p(xz)p(z)
  • log⁡p(z∣x)=log⁡p(x∣z)+log⁡p(z)−log⁡p(x)\log p(z|x)=\log p(x|z)+\log p(z)-\log p(x)logp(zx)=logp(xz)+logp(z)logp(x)
  • p(x)p(x)p(x)不依赖于z

可以得到:
KL(q(z)∣∣p(z∣x))=Ez∼q(log⁡q(z)−log⁡p(x∣z)−log⁡p(z))+log⁡p(x)KL(q(z)||p(z|x))=E_{z\sim q}(\log q(z)-\log p(x|z)-\log p(z))+\log p(x) KL(q(z)p(zx))=Ezq(logq(z)logp(xz)logp(z))+logp(x)
其中Ez∼q(log⁡q(z)−log⁡p(z))=KL(q(z)∣∣p(z))E_{z\sim q}(\log q(z)-\log p(z))=KL(q(z)||p(z))Ezq(logq(z)logp(z))=KL(q(z)p(z))

所以就能继续简化那个KL(q(z)∣∣p(z∣x))KL(q(z)||p(z|x))KL(q(z)p(zx))了:
log⁡p(x)−KL(q(z)∣p(z∣x))=Ez∼q[log⁡p(x∣z)]−KL(q(z)∣∣p(z))\log p(x)-KL(q(z)|p(z|x))=E_{z\sim q}[\log p(x|z)]-KL(q(z)||p(z)) logp(x)KL(q(z)p(zx))=Ezq[logp(xz)]KL(q(z)p(z))
证毕

但是我们的优化目标是
ELBO=log⁡p(x)−KL[q(z∣x)∣∣p(z∣x)]ELBO=\log p(x)-KL\left[q(z|x)||p(z|x)\right] ELBO=logp(x)KL[q(zx)p(zx)]
发现一个是q(z)q(z)q(z)一个是q(z∣x)q(z|x)q(zx),怎么办呢?看论文第8页有这样一句话:

Note that X is fixed, and Q can be any distribution, not just a distribution which does a good job mapping X to the z’s that can produce X. Since we’re interested in inferring P(X), it makes sense to construct a Q which does depend on X, and in particular, one which makes D [Q(z)k|P(z|X)] small

翻译一下意思就是:XXX是固定的(因为它是样本集),Q也可是任意分布,并非仅是能够生成XXX的分布,因为我们想推断P(X)P(X)P(X),那么构建一个依赖于XXXQQQ分布是可行的,还能让KL(Q(z)∣∣P(z∣X))KL(Q(z)||P(z|X))KL(Q(z)P(zX))较小:
log⁡p(x)−KL(q(z∣x)∣p(z∣x))=Ez∼q[log⁡p(x∣z)]−KL(q(z∣x)∣∣p(z))\log p(x)-KL(q(z|x)|p(z|x))=E_{z\sim q}[\log p(x|z)]-KL(q(z|x)||p(z)) logp(x)KL(q(zx)p(zx))=Ezq[logp(xz)]KL(q(zx)p(z))
这个式子就是变分自编码的核心了

这样我们就知道了优化目标(等号右边的时候),我们看看变换后的式子为什么能够计算?

首先没了p(z∣x)p(z|x)p(zx),其次每一项都能计算,我们挨个来看:

  • 如何计算q(z∣x)q(z|x)q(zx)
    我们可以使用神经网络逼近q(z∣x)q(z|x)q(zx),假设q(z∣x)q(z|x)q(zx)服从高斯分布N(μ,σ)N(\mu,\sigma)N(μ,σ)

    • 神经网络的输出就是均值μ\muμ和方差σ\sigmaσ
    • 输入是图片,输出是分布

    计算q(z∣x)q(z|x)q(zx)就是编码过程了

  • 如果计算p(x∣z)?用一个神经网络去逼近p(x|z)? 用一个神经网络去逼近p(xz)p(x|z),假设神经网络输出是,假设神经网络输出是f(z)$
    假设p(x∣z)p(x|z)p(xz)服从另一种高斯分布

    • x=f(z)+ηx=f(z)+\etax=f(z)+η,其中η∼N(0,I)\eta\sim N(0,I)ηN(0,I)
    • 简化成l2l_2l2损失:∣∣X−f(z)∣∣2||X-f(z)||^2Xf(z)2

    计算p(x∣z)p(x|z)p(xz)就是解码过程了

最终损失就是
L=∣∣X−f(z)∣∣2−λ⋅KL(q(z∣x)∣∣p(z))L=||X-f(z)||^2-\lambda\cdot KL(q(z|x)||p(z)) L=Xf(z)2λKL(q(zx)p(z))
在这里,我们先不看这个最终损失的式子,我们去瞅瞅未经过l2l_2l2简化的的优化目标
ELBO=Ez∼q[log⁡p(x∣z)]−KL(q(z∣x)∣∣p(z))ELBO=E_{z\sim q}[\log p(x|z)]-KL(q(z|x)||p(z)) ELBO=Ezq[logp(xz)]KL(q(zx)p(z))

  • 计算第二项的KL散度
    我们经常选择q(z∣x)=N(z∣μ(x;θ),Σ(x;θ))q(z|x)=N(z|\mu(x;\theta),\Sigma(x;\theta))q(zx)=N(zμ(x;θ),Σ(x;θ)),这里面μ,Σ\mu,\Sigmaμ,Σ通常是任意确定的函数,且其参数θ\thetaθ能够从数据中学习。通常通过神经网络获取,并且Σ\SigmaΣ被限制为一个对角阵。这样选择的好处是便于计算,仅此而已,那么右边的KL(q(z∣x)∣∣p(z))KL(q(z|x)||p(z))KL(q(zx)p(z))就编程了两个多元高斯分布的KL距离,有闭式解为:
    D(N(μ0,Σ0)∣∣N(μ1,Σ1))=12(tr(Σ1−1Σ0)+(μ1−μ0)TΣ1−1(μ1−μ0)−k+log⁡(det⁡Σ1det⁡Σ0)D(N(\mu_0,\Sigma_0)||N(\mu_1,\Sigma_1))=\\ \frac{1}{2}\left(tr(\Sigma^{-1}_1\Sigma_0\right)+(\mu_1-\mu_0)^T\Sigma_1^{-1}(\mu_1-\mu_0)-k+\log(\frac{\det \Sigma_1}{\det \Sigma_0}) D(N(μ0,Σ0)N(μ1,Σ1))=21(tr(Σ11Σ0)+(μ1μ0)TΣ11(μ1μ0)k+log(detΣ0detΣ1)
    其中kkk是分布的维数,而在变分推断中,经常又被简化成
    D(N(μ(x),Σ(x))∣∣N(0,I))=12(tr(Σ(x))+(μ(x))T(μ(x))−k+log⁡det⁡(Σ(x))D(N(\mu(x),\Sigma(x))||N(0,I))=\\ \frac{1}{2}\left(tr(\Sigma(x)\right)+(\mu(x))^T(\mu(x))-k+\log\det(\Sigma(x)) D(N(μ(x),Σ(x))N(0,I))=21(tr(Σ(x))+(μ(x))T(μ(x))k+logdet(Σ(x))

  • 计算第一项Ez∼qE_{z\sim q}Ezq
    论文中说这一项的计算有点小技巧(tricky),本来是可以通过采样的方法估计Ez∼q(log⁡p(x∣z))E_{z\sim q}(\log p(x|z))Ezq(logp(xz)),但是只有将很多的zzz通过fff式子(解码部分)输出以后才能得到较好的估计结果,这个计算量很大,因此想到了随机梯度下降,我们可以拿一个样本zzz,将p(x∣z)p(x|z)p(xz)作为Ez∼q(log⁡p(x∣z))E_{z\sim q}(\log p(x|z))Ezq(logp(xz))的估计,所以式子又变成了:
    Ex∼D(log⁡p(x)−KL(q(z∣x)∣∣p(z∣x)))=Ex∼D[Ez∼q[log⁡p(x∣z)]]−KL(q(z∣x)∣∣p(z))E_{x\sim D}(\log p(x)-KL(q(z|x)||p(z|x)))=\\ E_{x\sim D}\left[E_{z\sim q}\left[\log p(x|z)\right]\right]-KL(q(z|x)||p(z)) ExD(logp(x)KL(q(zx)p(zx)))=ExD[Ezq[logp(xz)]]KL(q(zx)p(z))
    意思就是我们从样本集合DDD中取一个样本xxx来计算,所以对于单个,可以计算下式梯度:
    log⁡p(x∣z)−KL(q(z∣x)∣∣p(z))\log p(x|z)-KL(q(z|x)||p(z)) logp(xz)KL(q(zx)p(z))
    这样消除了Ez∼qE_{z\sim q}Ezq中对qqq的依赖。

    论文中有个图很好

    这里写图片描述
    其实log⁡p(x∣z)−KL(q(z∣x)∣∣p(z))\log p(x|z)-KL(q(z|x)||p(z))logp(xz)KL(q(zx)p(z))刚好就是左图,主要就是反传的时候没法计算梯度,看左图红框部分,这一部分是随机采样,是无法计算梯度的,那么文中就说了一个技巧:重新参数化(reparameterization trick),给定了μ(x),Σ(x)\mu(x),\Sigma(x)μ(x),Σ(x)也就是Q(z∣x)Q(z|x)Q(zx)的均值和方差,我们先从N(0,I)N(0,I)N(0,I)中采样,然后计算z=μ(x)+Σ12∗ϵz=\mu(x)+\Sigma^{\frac{1}{2}}*\epsilonz=μ(x)+Σ21ϵ,所以我们又可以计算下式的梯度了:
    Ex∼D[Eϵ∼N(0,I)[log⁡p(x∣z=μ(x)+Σ1/2(x)∗ϵ)]−KL(q(z∣x)∣∣p(z))]E_{x\sim D}\left[E_{\epsilon\sim N(0,I)}\left[\log p(x|z=\mu(x)+\Sigma^{1/2}(x)*\epsilon)\right]-KL(q(z|x)||p(z)) \right] ExD[EϵN(0,I)[logp(xz=μ(x)+Σ1/2(x)ϵ)]KL(q(zx)p(z))]
    这就完成了从左图到右图的转变。

代码实现-模型训练及保存

理论很复杂,但是我们看着右图就能实现,无需看理论,理论只是让我们知道为什么会有右图这种网络结构。按照标准流程来书写代码:

  • 读数据
  • 初始化相关参数
  • 定义数据接收接口以便测试使用
  • 初始化权重和偏置
  • 定义基本模块:编码器、采样器、解码器
  • 构建模型
  • 定义预测函数、损失函数、优化器
  • 训练

整个代码很简单,我就只贴部分重点的:

初始化权重偏置

#初始化权重、偏置
def glorot_init(shape):return tf.random_normal(shape=shape,stddev=1./tf.sqrt(shape[0]/2.0))
#权重
weights={'encoder_h1':tf.Variable(glorot_init([num_input,hidden_dim])),'z_mean':tf.Variable(glorot_init([hidden_dim,latent_dim])),'z_std':tf.Variable(glorot_init([hidden_dim,latent_dim])),'decoder_h1':tf.Variable(glorot_init([latent_dim,hidden_dim])),'decoder_out':tf.Variable(glorot_init([hidden_dim,num_input]))
}
#偏置
biases={'encoder_b1':tf.Variable(glorot_init([hidden_dim])),'z_mean':tf.Variable(glorot_init([latent_dim])),'z_std':tf.Variable(glorot_init([latent_dim])),'decoder_b1':tf.Variable(glorot_init([hidden_dim])),'decoder_out':tf.Variable(glorot_init([num_input]))
}

注意这里使用了另一种初始化方法,说是Xavier初始化方法,因为直接使用上一篇博客的方法

tf.Variable(tf.random_normal([num_input,num_hidden1])),

训练时候一直给我弹出loss:nan,我也是醉了,以后还是用之前学theano时候采用的fan_in-fan_out方法初始化权重算了。

定义基本模块

注意需要定义编码器、采样器、解码器

#定义编码器
def encoder(x):encoder=tf.matmul(x,weights['encoder_h1'])+biases['encoder_b1']encoder=tf.nn.tanh(encoder)z_mean=tf.matmul(encoder,weights['z_mean'])+biases['z_mean']z_std=tf.matmul(encoder,weights['z_std'])+biases['z_std']return z_mean,z_std#定义采样器
def sampler(z_mean,z_std):eps=tf.random_normal(tf.shape(z_std),dtype=tf.float32,mean=0,stddev=1.0,name='epsilon')z=z_mean+tf.exp(z_std/2)*epsreturn z#定义解码器
def decoder(x):decoder=tf.matmul(x,weights['decoder_h1'])+biases['decoder_b1']decoder=tf.nn.tanh(decoder)decoder=tf.matmul(decoder,weights['decoder_out'])+biases['decoder_out']decoder=tf.nn.sigmoid(decoder)return decoder

构建模型

#构建模型
[z_mean,z_std]=encoder(X)#计算均值方差
sample_latent=sampler(z_mean,z_std)#采样隐空间
decoder_op=decoder(sample_latent)#重构输出

预测函数和损失

#预测函数
y_pred=decoder_op
y_true=X
tf.add_to_collection('recon',y_pred)
#定义损失函数和优化器
def vae_loss(x_reconstructed,x_true,z_mean,zstd):#重构损失encode_decode_loss=x_true*tf.log(1e-10+x_reconstructed)\+(1-x_true)*tf.log(1e-10+1-x_reconstructed)encode_decode_loss=-tf.reduce_sum(encode_decode_loss,1)#KL损失kl_div_loss=1+z_std-tf.square(z_mean)-tf.exp(z_std)kl_div_loss=-0.5*tf.reduce_sum(kl_div_loss,1)return tf.reduce_mean(encode_decode_loss+kl_div_loss)
loss_op=vae_loss(decoder_op,y_true,z_mean,z_std)
optimizer=tf.train.RMSPropOptimizer(learning_rate=learning_rate)
train_op=optimizer.minimize(loss_op)

注意这里损失的第一项类似于交叉熵损失:y×log⁡(y^)+(1−y)×(1−log⁡y^)y\times \log(\hat{y})+(1-y)\times(1-\log \hat y)y×log(y^)+(1y)×(1logy^)

关于交叉熵损失和均方差损失的区别,可以看我前面的博客:损失函数梯度对比-均方差和交叉熵

训练和保存模型

#参数初始化
init=tf.global_variables_initializer()
input_image,input_label=read_images('./mnist/train_labels.txt',batch_size)
#训练和保存模型
saver=tf.train.Saver()
with tf.Session() as sess:sess.run(init)coord=tf.train.Coordinator()tf.train.start_queue_runners(sess=sess,coord=coord)for step in range(1,num_steps):batch_x,batch_y=sess.run([input_image,tf.one_hot(input_label,1,0)])sess.run(train_op,feed_dict={X:batch_x})if step%disp_step==0 or step==1:loss=sess.run(loss_op,feed_dict={X:batch_x})print('step '+str(step)+' ,loss '+'{:.4f}'.format(loss))coord.request_stop()coord.join()print('optimization finished')saver.save(sess,'./VAE_mnist_model/VAE_mnist')

常规的保存方法,没什么说的,训练日志:

step 1 ,loss 616.3002
step 1000 ,loss 169.6044
step 2000 ,loss 163.5006
step 3000 ,loss 166.1648
step 4000 ,loss 161.2366
step 5000 ,loss 155.1714
step 6000 ,loss 153.2840
step 7000 ,loss 161.5571
step 8000 ,loss 152.0021
step 9000 ,loss 159.5550
step 10000 ,loss 154.6315
step 11000 ,loss 153.8298
step 12000 ,loss 141.5825
step 13000 ,loss 149.7792
step 14000 ,loss 150.9575
step 15000 ,loss 151.2249
step 16000 ,loss 159.3878
step 17000 ,loss 148.7136
step 18000 ,loss 148.5801
step 19000 ,loss 150.6678
step 20000 ,loss 146.3471
step 21000 ,loss 156.4142
step 22000 ,loss 148.7607
step 23000 ,loss 145.4101
step 24000 ,loss 153.3523
step 25000 ,loss 157.8997
step 26000 ,loss 136.9668
step 27000 ,loss 155.7835
step 28000 ,loss 137.7291
step 29000 ,loss 153.1723
optimization finished

【很尴尬的事情】偷偷说一句,上面保存错东东了,别打我,但是也不必重新训练,看接下来的蛇皮操作。

#代码实现-模型加载及测试

老样子,先载入模型

sess=tf.Session()
new_saver=tf.train.import_meta_graph('./VAE_mnist_model/VAE_mnist.meta')
new_saver.restore(sess,'./VAE_mnist_model/VAE_mnist')

获取计算图:

graph=tf.get_default_graph()

看看保存了啥

print (graph.get_all_collection_keys())
#['queue_runners', 'recon', 'summaries', 'train_op', 'trainable_variables', 'variables']

准备调用recon函数重构数据。

等等,“重构”?搞错了,这里应该是依据噪声来生成数据的,不是输入一个数据然后重构它,这是AE的做法,我们在VAE中应该称之为生成了,然而不幸的是,我们保存的recon函数接收的是图片输入,无法指定decoder部分所需的分布参数,也就是均值方差,怎么办?手动选择性载入模型,这篇博客有介绍怎么在测试阶段定义网络权重和载入训练好的权重,但是我好像没成功,懒得试了,按照我自己的想法来做。

眼尖的童鞋会发现,我们之前一直只关注recon函数了,忽视了其它keys,很快发现最后两个trainable_variablesvariables貌似与我们想要的模型参数有关,我们来输出一下这两个东东里面都保存了啥:

第一个trainable_variables

for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):print(i)

输出

<tf.Variable 'Variable:0' shape=(784, 512) dtype=float32_ref>
<tf.Variable 'Variable_1:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_2:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_3:0' shape=(2, 512) dtype=float32_ref>
<tf.Variable 'Variable_4:0' shape=(512, 784) dtype=float32_ref>
<tf.Variable 'Variable_5:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_6:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_7:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_8:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_9:0' shape=(784,) dtype=float32_ref>

第二个:GLOBAL_VARIABLES

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):print(i)
<tf.Variable 'Variable:0' shape=(784, 512) dtype=float32_ref>
<tf.Variable 'Variable_1:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_2:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_3:0' shape=(2, 512) dtype=float32_ref>
<tf.Variable 'Variable_4:0' shape=(512, 784) dtype=float32_ref>
<tf.Variable 'Variable_5:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_6:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_7:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_8:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_9:0' shape=(784,) dtype=float32_ref>
<tf.Variable 'Variable/RMSProp:0' shape=(784, 512) dtype=float32_ref>
<tf.Variable 'Variable/RMSProp_1:0' shape=(784, 512) dtype=float32_ref>
<tf.Variable 'Variable_1/RMSProp:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_1/RMSProp_1:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_2/RMSProp:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_2/RMSProp_1:0' shape=(512, 2) dtype=float32_ref>
<tf.Variable 'Variable_3/RMSProp:0' shape=(2, 512) dtype=float32_ref>
<tf.Variable 'Variable_3/RMSProp_1:0' shape=(2, 512) dtype=float32_ref>
<tf.Variable 'Variable_4/RMSProp:0' shape=(512, 784) dtype=float32_ref>
<tf.Variable 'Variable_4/RMSProp_1:0' shape=(512, 784) dtype=float32_ref>
<tf.Variable 'Variable_5/RMSProp:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_5/RMSProp_1:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_6/RMSProp:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_6/RMSProp_1:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_7/RMSProp:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_7/RMSProp_1:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable_8/RMSProp:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_8/RMSProp_1:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'Variable_9/RMSProp:0' shape=(784,) dtype=float32_ref>
<tf.Variable 'Variable_9/RMSProp_1:0' shape=(784,) dtype=float32_ref>

很容易发现我们只需要从可训练的参数集中获取权重,还记得之前说过的么,我们啥都往sess.run中丢试试,看看能不能取出来值:

a=sess.run(graph.get_collection('trainable_variables'))
for i in a:print(i.shape)
(784, 512)
(512, 2)
(512, 2)
(2, 512)
(512, 784)
(512,)
(2,)
(2,)
(512,)
(784,)

可以看出参数可以从a中通过索引取出来了。

接下来简单,重新定义一下模型的decoder计算:

latent_dim=2
noise_input=tf.placeholder(tf.float32,shape=[None,latent_dim])
decoder=tf.matmul(noise_input,a[3])+a[8]
decoder=tf.nn.tanh(decoder)
decoder=tf.matmul(decoder,a[4])+a[9]
decoder=tf.nn.sigmoid(decoder)

然后就可以尝试丢进去一个均值和方差去预测了:

generate=sess.run(decoder,feed_dict={noise_input:[[3,3]]})

可视化

generate=generate*255.0
gen_img=generate.reshape(28,28)
plt.imshow(gen_img)
plt.show()

这里写图片描述
效果还不错,再测试几个,丢[0,0][0,0][0,0]试试:

这里写图片描述

[0,5][0,5][0,5]试试:

这里写图片描述

好了,不玩了,随机性太大了,我都不知道啥噪声能输出啥数字,只有输出的时候才知道。

后记

好玩是好玩,但是我不知道哪个数字对应哪种噪声输入,还是有点郁闷,下一篇我们就去看看有名的搞基网GAN

本文训练代码:链接:https://pan.baidu.com/s/19QSNfT7fgWrU68CV7lXSZA 密码:6c7l

本文测试代码:链接:https://pan.baidu.com/s/1CAxPGnmCTg-OT8TqnXQdVw 密码:tmep

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/246605.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【TensorFlow-windows】学习笔记七——生成对抗网络

前言 既然学习了变分自编码(VAE)&#xff0c;那也必须来一波生成对抗网络(GAN)。 国际惯例&#xff0c;参考网址&#xff1a; 论文: Generative Adversarial Nets PPT:Generative Adversarial Networks (GANs) Generative Adversarial Nets in TensorFlow GAN原理学习笔记…

Openpose——windows编译(炒鸡简单)

前言 最近准备看看rtpose的代码&#xff0c;发现已经由openpose这个项目维护着了&#xff0c;由于经常在windows下调试代码&#xff0c;所以尝试了一下如何在windows下编译openpose源码&#xff0c;整体来说非常简单的。 国际惯例&#xff0c;参考博客&#xff1a; [OpenPos…

强化学习——Qlearning

前言 在控制决策领域里面强化学习还是占很重比例的&#xff0c;最近出了几篇角色控制的论文需要研究&#xff0c;其中部分涉及到强化学习&#xff0c;都有开源&#xff0c;有兴趣可以点开看看&#xff1a; A Deep Learning Framework For Character Motion Synthesis and Edit…

【TensorFlow-windows】keras接口学习——线性回归与简单的分类

前言 之前有写过几篇TensorFlow相关文章&#xff0c;但是用的比较底层的写法&#xff0c;比如tf.nn和tf.layers&#xff0c;也写了部分基本模型如自编码和对抗网络等&#xff0c;感觉写起来不太舒服&#xff0c;最近看官方文档发现它的教程基本都使用的keras API&#xff0c;这…

【TensorFlow-windows】keras接口——卷积手写数字识别,模型保存和调用

前言 上一节学习了以TensorFlow为底端的keras接口最简单的使用&#xff0c;这里就继续学习怎么写卷积分类模型和各种保存方法(仅保存权重、权重和网络结构同时保存) 国际惯例&#xff0c;参考博客&#xff1a; 官方教程 【注】其实不用看博客&#xff0c;直接翻到文末看我的c…

【TensorFlow-windows】keras接口——BatchNorm和ResNet

前言 之前学习利用Keras简单地堆叠卷积网络去构建分类模型的方法&#xff0c;但是对于很深的网络结构很难保证梯度在各层能够正常传播&#xff0c;经常发生梯度消失、梯度爆炸或者其它奇奇怪怪的问题。为了解决这类问题&#xff0c;大佬们想了各种办法&#xff0c;比如最原始的…

【TensorFlow-windows】keras接口——卷积核可视化

前言 在机器之心上看到了关于卷积核可视化相关理论&#xff0c;但是作者的源代码是基于fastai写的&#xff0c;而fastai的底层是pytorch&#xff0c;本来准备自己用Keras复现一遍的&#xff0c;但是尴尬地发现Keras还没玩熟练&#xff0c;随后发现了一个keras-vis包可以用于做…

【TensorFlow-windows】投影变换

前言 没什么重要的&#xff0c;就是想测试一下tensorflow的投影变换函数tf.contrib.image.transform中每个参数的含义 国际惯例&#xff0c;参考文档 官方文档 描述 调用方法与默认参数&#xff1a; tf.contrib.image.transform(images,transforms,interpolationNEAREST,…

【TensorFlow-windows】扩展层之STN

前言 读TensorFlow相关代码看到了STN的应用&#xff0c;搜索以后发现可替代池化&#xff0c;增强网络对图像变换(旋转、缩放、偏移等)的抗干扰能力&#xff0c;简单说就是提高卷积神经网络的空间不变性。 国际惯例&#xff0c;参考博客&#xff1a; 理解Spatial Transformer…

【TensorFlow-windows】MobileNet理论概览与实现

前言 轻量级神经网络中&#xff0c;比较重要的有MobileNet和ShuffleNet&#xff0c;其实还有其它的&#xff0c;比如SqueezeNet、Xception等。 本博客为MobileNet的前两个版本的理论简介与Keras中封装好的模块的对应实现方案。 国际惯例&#xff0c;参考博客&#xff1a; 纵…

【TensorFlow-windows】keras接口——ImageDataGenerator裁剪

前言 Keras中有一个图像数据处理器ImageDataGenerator&#xff0c;能够很方便地进行数据增强&#xff0c;并且从文件中批量加载图片&#xff0c;避免数据集过大时&#xff0c;一下子加载进内存会崩掉。但是从官方文档发现&#xff0c;并没有一个比较重要的图像增强方式&#x…

【TensorFlow-windows】TensorBoard可视化

前言 紧接上一篇博客&#xff0c;学习tensorboard可视化训练过程。 国际惯例&#xff0c;参考博客&#xff1a; MNIST机器学习入门 Tensorboard 详解&#xff08;上篇&#xff09; Tensorboard 可视化好帮手 2 tf-dev-summit-tensorboard-tutorial tensorflow官方mnist_…

深度学习特征归一化方法——BN、LN、IN、GN

前言 最近看到Group Normalization的论文&#xff0c;主要提到了四个特征归一化方法&#xff1a;Batch Norm、Layer Norm、Instance Norm、Group Norm。此外&#xff0c;论文还提到了Local Response Normalization(LRN)、Weight Normalization(WN)、Batch Renormalization(BR)…

【TensorFlow-windows】keras接口——利用tensorflow的方法加载数据

前言 之前使用tensorflow和keras的时候&#xff0c;都各自有一套数据读取方法&#xff0c;但是遇到一个问题就是&#xff0c;在训练的时候&#xff0c;GPU的利用率忽高忽低&#xff0c;极大可能是由于训练过程中读取每个batch数据造成的&#xff0c;所以又看了tensorflow官方的…

骨骼动画——论文与代码精读《Phase-Functioned Neural Networks for Character Control》

前言 最近一直玩CV&#xff0c;对之前学的动捕知识都忘得差不多了&#xff0c;最近要好好总结一下一直以来学习的内容&#xff0c;不能学了忘。对2017年的SIGGRAPH论文《Phase-Functioned Neural Networks for Character Control》进行一波深入剖析吧&#xff0c;结合源码。 额…

颜色协调模型Color Harmoniztion

前言 最近做换脸&#xff0c;在肤色调整的那一块&#xff0c;看到一个有意思的文章&#xff0c;复现一波玩玩。不过最后一步掉链子了&#xff0c;有兴趣的可以一起讨论把链子补上。 主要是github上大佬的那个复现代码和原文有点差异&#xff0c;而且代码复杂度过高&#xff0…

Openpose推断阶段原理

前言 之前出过一个关于openpose配置的博客&#xff0c;不过那个代码虽然写的很好&#xff0c;而且是官方的&#xff0c;但是分析起来很困难&#xff0c;然后再opencv相关博客中找到了比较清晰的实现&#xff0c;这里分析一波openpose的推断过程。 国际惯例&#xff0c;参考博…

换脸系列——眼鼻口替换

前言 想着整理一下换脸相关的技术方法&#xff0c;免得以后忘记了&#xff0c;最近脑袋越来越不好使了。应该会包含三个系列&#xff1a; 仅换眼口鼻&#xff1b;换整个面部&#xff1b;3D换脸 先看看2D换脸吧&#xff0c;网上已经有现成的教程了&#xff0c;这里拿过来整理一…

换脸系列——整脸替换

前言 前面介绍了仅替换五官的方法&#xff0c;这里介绍整张脸的方法。 国际惯例&#xff0c;参考博客&#xff1a; [图形算法]Delaunay三角剖分算法 维诺图&#xff08;Voronoi Diagram&#xff09;分析与实现 Delaunay Triangulation and Voronoi Diagram using OpenCV (…

3D人脸重建——PRNet网络输出的理解

前言 之前有款换脸软件不是叫ZAO么&#xff0c;分析了一下&#xff0c;它的实现原理绝对是3D人脸重建&#xff0c;而非deepfake方法&#xff0c;找了一篇3D重建的论文和源码看看。这里对源码中的部分函数做了自己的理解和改写。 国际惯例&#xff0c;参考博客&#xff1a; 什…