VAE 理论推导及代码实现

VAE 理论推导及代码实现

熵、交叉熵、KL 散度的概念

熵(Entropy)

假设 p (x)是一个分布函数,满足在 x 上的积分为 1,那么 p ( x ) p(x) p(x)的熵定义为 H ( p ( x ) ) H (p (x)) H(p(x)),这里我们简写为 H ( p ) H(p) H(p)
H ( p ) = ∫ p ( x ) log ⁡ 1 p ( x ) d x H(p)=\int p(x) \log \frac{1}{p(x)} dx H(p)=p(x)logp(x)1dx
直观上,越分散的分布函数熵越大。越集中的分布函数熵越小。熵的最小值为 0.

从信息论的角度来说,熵又叫信息熵,它的大小表示信息量的多少,分散的分布函数可能性多、拿到 p (x)后对于 x 的推断不确定性大,即信息量大,而对于 p © =1 这种情况拿到分布函数直接就拿到了结果,因此信息量为 0

交叉熵(Cross-Entropy)

假设 p ( x ) p(x) p(x) q ( x ) q(x) q(x)是两个分布函数,交叉熵的小大评价了这两个分布函数的相似与否。 p p p q q q 的交叉熵记为 H ( p , q ) H(p, q) H(p,q)
H ( p , q ) = ∫ p ( x ) log ⁡ 1 q ( x ) d x H(p, q)=\int p(x) \log \frac{1}{q(x)} d x H(p,q)=p(x)logq(x)1dx

交叉熵小一分布相似;交叉熵大一分布不相似。交叉熵最大为无穷大,最小为 p p p 的熵 H ( p ) H (p) H(p)

KL 散度

假设 p ( x ) p(x) p(x) q ( x ) q (x) q(x)是两个分布函数,KL 散度的小大评价了这两个分布函数的相似与否,同时考虑了 K L ( x ) KL(x) KL(x)这个分布的信息量。记为 K L ( p , q ) KL(p, q) KL(p,q)。注意: K L ( p , q ) KL (p, q) KL(p,q)也不一定等于 K L ( q , p ) KL (q, p) KL(q,p)
K L ( p , q ) = H ( p , q ) − H ( p ) K L(p, q)=H(p, q)-H(p) KL(p,q)=H(p,q)H(p)
∫ p ( x ) log ⁡ 1 q ( x ) d x − ∫ p ( x ) log ⁡ 1 p ( x ) d x = ∫ p ( x ) log ⁡ p ( x ) q ( x ) d x \begin{aligned} & \int p(x) \log \frac{1}{q(x)} d x-\int p(x) \log \frac{1}{p(x)} d x \\ & =\int p(x) \log \frac{p(x)}{q(x)} d x \end{aligned} p(x)logq(x)1dxp(x)logp(x)1dx=p(x)logq(x)p(x)dx

KL散度小—分布相似 & [ p ( x ) [p(x) [p(x) 分散 | p ( x ) p(x) p(x) 信息量大]。
K L \mathrm{KL} KL 散度大–分布不相似 & [ p ( x ) [p(x) [p(x) 集中 ∣ p ( x ) \mid p(x) p(x) 信息量小]。
K L \mathrm{KL} KL 散度最小值为 0 : p ( x ) 0: p(x) 0:p(x) q ( x ) q(x) q(x) 完全相同时。

概率知识

将p(x)其改写为包含了传入参数的形式
p ( x ) = ∑ z p ( x ∣ z ) p ( z ) p(x)=\sum_z p(x \mid z) p(z) p(x)=zp(xz)p(z)

连续分布时,该式就变成了
p ( x ) = ∫ z ⁡ p ( x ∣ z ) p ( z ) d z p(x)=\int_z^{\operatorname{}} p(x \mid z) p(z) d z p(x)=zp(xz)p(z)dz

p ( z ) p(z) p(z)可以是任意分布,在VAE中我们常常假设p(z)服从标准正态分布。

变分方法

Intractability:
p θ ( z ∣ x ) = p θ ( x ∣ z ) p θ ( z ) / p θ ( x ) p θ ( x ) = ∫ p θ ( z ) p θ ( x ˙ ∣ z ) d z \begin{aligned} p_{\boldsymbol{\theta}}(\mathbf{z} \mid \mathbf{x}) & =p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z}) p_{\boldsymbol{\theta}}(\mathbf{z}) / p_{\boldsymbol{\theta}}(\mathbf{x}) \\ p_{\boldsymbol{\theta}}(\mathbf{x}) & =\int p_{\boldsymbol{\theta}}(\mathbf{z}) p_{\boldsymbol{\theta}}(\dot{\mathbf{x}} \mid \mathbf{z}) d \mathbf{z} \end{aligned} pθ(zx)pθ(x)=pθ(xz)pθ(z)/pθ(x)=pθ(z)pθ(x˙z)dz
p ( z ∣ x ( i ) ) = p ( z , x ( i ) ) p ( x ( i ) ) = p ( x = x ( i ) ∣ z = z ( i ) ) p ( z = z ( i ) ) ∫ z ( i ) p ( x = x ( i ) ∣ z = z ( i ) ) p ( z = z ( i ) ) d z ( i ) = p ( x ( i ) ∣ z ) p ( z ) ∫ z p ( x ( i ) ∣ z ) p ( z ) d z \begin{aligned} p\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) & =\frac{p\left(\mathbf{z}, \mathbf{x}^{(i)}\right)}{p\left(\mathbf{x}^{(i)}\right)} \\ & =\frac{p\left(\mathbf{x}=\mathbf{x}^{(i)} \mid \mathbf{z}=\mathbf{z}^{(i)}\right) p\left(\mathbf{z}=\mathbf{z}^{(i)}\right)}{\int_{\mathbf{z}^{(i)}} p\left(\mathbf{x}=\mathbf{x}^{(i)} \mid \mathbf{z}=\mathbf{z}^{(i)}\right) p\left(\mathbf{z}=\mathbf{z}^{(i)}\right) d \mathbf{z}^{(i)}} \\ & =\frac{p\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right) p(\mathbf{z})}{\int_{\mathbf{z}} p\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right) p(\mathbf{z}) d \mathbf{z}} \end{aligned} p(zx(i))=p(x(i))p(z,x(i))=z(i)p(x=x(i)z=z(i))p(z=z(i))dz(i)p(x=x(i)z=z(i))p(z=z(i))=zp(x(i)z)p(z)dzp(x(i)z)p(z)
参考:https://zhuanlan.zhihu.com/p/519448634

如果假设参数 θ \theta θ 已知, 那么先验分布 p θ ( z ) p_\theta(\mathbf{z}) pθ(z) 和条件似然函数 p θ ( x ( i ) ∣ z ) p_\theta\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right) pθ(x(i)z) 就都是已知的。理论上 来说, 只要把分母里的积分项 ∫ z p θ ( x ( i ) ∣ z ) p ( z ) d z \int_{\mathbf{z}} p_\theta\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right) p(\mathbf{z}) d \mathbf{z} zpθ(x(i)z)p(z)dz 计算出来, 那整个后验分布 p ( z ∣ x ( i ) ) p\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) p(zx(i)) 就 可以求了, 后验推断问题也就解决了。但是, 现实很骨感, 在没有对 p θ ( z ) p_\theta(\mathbf{z}) pθ(z) p θ ( x ( i ) ∣ z ) p_\theta\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right) pθ(x(i)z) 作任何 简化假设的前提下, 这个积分基本上是没有解析解的。你想硬着头皮解, 那么基本意味着你要穷举 隐变量 z \mathbf{z} z 的所有可能取值, 假设 z \mathbf{z} z k k k 个维度, 每个维度采样 n n n 个取值, 那么这个穷举过程的复 杂度就是 O ( n k ) O\left(n^k\right) O(nk)

当然也有人用MCMC来做积分项的估计,虽然这个方案做采样估计很精准,但是费时费力,很难适用于大数据场景。所以一般更常见的方案是采用变分方法(variational method),它可以绕过对积分项的求解,通过把统计推断问题转化成参数优化问题来实现“降维打击”。

首先变分方法会设置一个新的参数化分布 q ϕ ( z ∣ x ( i ) ) q_\phi\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) qϕ(zx(i)), 它的参数是 ϕ \phi ϕ, 我们把它称作"识别模型" (原文记作recognition model) 。变分方法的核心思想是:直接让“识别模型”去拟合后验分布 p θ ( z ∣ x ( i ) ) p_\theta\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) pθ(zx(i)), 只要近似到位, 那么采用 q ϕ ( z ∣ x ( i ) ) q_\phi\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) qϕ(zx(i)) 作为后验推断的结果就行了。如何做近似呢? 很简单, 直接最小化 q ϕ ( z ∣ x ( i ) ) q_\phi\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) qϕ(zx(i)) p θ ( z ∣ x ( i ) ) p_\theta\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) pθ(zx(i)) 两者间的KL散度即可。

就这样,变分方法把原来的统计推断问题转化成了优化问题:

Approximation p θ ( z ∣ x ) ≅ q ϕ ( z ∣ x ) \quad p_\theta(z \mid x) \cong q_\phi(z \mid x) pθ(zx)qϕ(zx)
D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) = − ∑ decoder  q ϕ ( z ∣ x ) log ⁡ ( p θ ( z ∣ x ) q ϕ ( z ∣ x ) ) = − ∑ z q ϕ ( z ∣ x ) log ⁡ ( p θ ( x , z ) p θ ( x ) q ϕ ( z ∣ x ) ) = − ∑ z q ϕ ( z ∣ x ) [ log ⁡ ( p θ ( x , z ) q ϕ ( z ∣ x ) ) − log ⁡ ( p θ ( x ) ) ‾ ]  non-negative  log ⁡ ( p θ ( x ) ) = K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) + ∑ z q ϕ ( z ∣ x ) log ⁡ ( p θ ( x , z ) q ϕ ( z ∣ x ) ) = D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ) + L ( θ , ϕ ; x )  Variational lower bound  \begin{aligned} & D_{K L}\left(q_\phi(z \mid x) \| p_\theta(z \mid x)\right)=-\sum_{\text {decoder }} q_\phi(z \mid x) \log \left(\frac{p_\theta(z \mid x)}{q_\phi(z \mid x)}\right)=-\sum_z q_\phi(z \mid x) \log \left(\frac{\frac{p_\theta(x, z)}{p_\theta(x)}}{q_\phi(z \mid x)}\right) \\ & =-\sum_z q_\phi(z \mid x)\left[\log \left(\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right)-\underline{\log \left(p_\theta(x)\right)}\right] \\ & \begin{array}{c} \text { non-negative } \\ \log \left(p_\theta(x)\right) \end{array}=K L\left(q_\phi(z \mid x) \| p_\theta(z \mid x)\right)+\sum_z q_\phi(z \mid x) \log \left(\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right) \\ & =D_{K L}\left(q_\phi(z \mid x)|| p_\theta(z \mid x)\right)+\frac{L(\theta, \phi ; x)}{\text { Variational lower bound }} \\ & \end{aligned} DKL(qϕ(zx)pθ(zx))=decoder qϕ(zx)log(qϕ(zx)pθ(zx))=zqϕ(zx)log qϕ(zx)pθ(x)pθ(x,z) =zqϕ(zx)[log(qϕ(zx)pθ(x,z))log(pθ(x))] non-negative log(pθ(x))=KL(qϕ(zx)pθ(zx))+zqϕ(zx)log(qϕ(zx)pθ(x,z))=DKL(qϕ(zx)∣∣pθ(zx))+ Variational lower bound L(θ,ϕ;x)

Maximize the lower bound
L ( θ , ϕ ; x ) = ∑ z q ϕ ( z ∣ x ) log ⁡ ( p θ ( x , z ) q ϕ ( z ∣ x ) ) = ∑ z q ϕ ( z ∣ x ) log ⁡ ( p θ ( x ∣ z ) p θ ( z ) q ϕ ( z ∣ x ) ) = ∑ z q ϕ ( z ∣ x ) [ log ⁡ ( p θ ( x ∣ z ) ) + log ⁡ ( p θ ( z ) q ϕ ( z ∣ x ) ) ] = E q ϕ ( z ∣ x ) [ log ⁡ ( p θ ( x ∣ z ) ) ]  Reconstruction Loss  − D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ) )  Regularization Loss  \begin{aligned} & L(\theta, \phi ; x)=\sum_z q_\phi(z \mid x) \log \left(\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right)=\sum_z q_\phi(z \mid x) \log \left(\frac{p_\theta(x \mid z) p_\theta(z)}{q_\phi(z \mid x)}\right) \\ &= \sum_z q_\phi(z \mid x)\left[\log \left(p_\theta(x \mid z)\right)+\log \left(\frac{p_\theta(z)}{q_\phi(z \mid x)}\right)\right] \\ &= \frac{E_{q_\phi(z \mid x)}\left[\log \left(p_\theta(x \mid z)\right)\right]}{\text { Reconstruction Loss }}-\frac{D_{K L}\left(q_\phi(z \mid x) \| p_\theta(z)\right)}{\text { Regularization Loss }} \end{aligned} L(θ,ϕ;x)=zqϕ(zx)log(qϕ(zx)pθ(x,z))=zqϕ(zx)log(qϕ(zx)pθ(xz)pθ(z))=zqϕ(zx)[log(pθ(xz))+log(qϕ(zx)pθ(z))]= Reconstruction Loss Eqϕ(zx)[log(pθ(xz))] Regularization Loss DKL(qϕ(zx)pθ(z))

L ( θ , ϕ ; x ) = ∑ z q ϕ ( z ∣ x ) log ⁡ ( p θ ( x , z ) q ϕ ( z ∣ x ) ) = ∑ z q ϕ ( z ∣ x ) log ⁡ ( p θ ( x ∣ z ) p θ ( z ) q ϕ ( z ∣ x ) ) = ∑ z q ϕ ( z ∣ x ) [ log ⁡ ( p θ ( x ∣ z ) ) + log ⁡ ( p θ ( z ) q ϕ ( z ∣ x ) ) ] \begin{gathered} L(\theta, \phi ; x)=\sum_z q_\phi(z \mid x) \log \left(\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right)=\sum_z q_\phi(z \mid x) \log \left(\frac{p_\theta(x \mid z) p_\theta(z)}{q_\phi(z \mid x)}\right) \\ =\sum_z q_\phi(z \mid x)\left[\log \left(p_\theta(x \mid z)\right)+\log \left(\frac{p_\theta(z)}{q_\phi(z \mid x)}\right)\right] \end{gathered} L(θ,ϕ;x)=zqϕ(zx)log(qϕ(zx)pθ(x,z))=zqϕ(zx)log(qϕ(zx)pθ(xz)pθ(z))=zqϕ(zx)[log(pθ(xz))+log(qϕ(zx)pθ(z))]

Regularization Loss( 重参数化)

而在实践中, 一般不对 q ϕ ( z ∣ x ( i ) ) q_\phi\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) qϕ(zx(i)) 直接作采样, 采用 reparameterization trick 来简化操作, 我们 设 z ( i , l ) = g ϕ ( ϵ ( i ; l ) ; x ( i ) ) \mathbf{z}^{(i, l)}=g_\phi\left(\epsilon^{(i ; l)} ; \mathbf{x}^{(i)}\right) z(i,l)=gϕ(ϵ(i;l);x(i)), 其中 g ϕ g_\phi gϕ 是一个拟合函数 (e.g. 神经网络) , 而噪声 ϵ ( i ; l ) \epsilon^{(i ; l)} ϵ(i;l) 可以通过 采样得到, 一般直接采样自简单的标准正态分布。

∫ q θ ( z ∣ x ) log ⁡ p ( z ) d z = ∫ N ( z ; μ , σ 2 ) log ⁡ N ( z ; 0 , I ) d z \int q_\theta(z \mid x) \log p(z) d z=\int N\left(z ; \mu, \sigma^2\right) \log N(z ; 0, I) dz qθ(zx)logp(z)dz=N(z;μ,σ2)logN(z;0,I)dz
f ( x ) = 1 σ 2 π e − 1 2 ( x − μ σ ) 2 f(x)=\frac{1}{\sigma \sqrt{2 \pi}} e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2} f(x)=σ2π 1e21(σxμ)2

= ∫ N ( z ; μ , σ 2 ) ( − 1 2 z 2 − 1 2 log ⁡ ( 2 π ) ) d z = − 1 2 ∫ N ( z ; μ , σ 2 ) z 2 d z − J 2 log ⁡ ( 2 π ) = − J 2 log ⁡ ( 2 π ) − 1 2 E z ∼ N ( z ; μ , σ 2 ) [ Z 2 ] = − J 2 log ⁡ ( 2 π ) − 1 2 ( E z ∼ N ( z ; μ , σ 2 ) [ Z ] 2 + Var ⁡ ( Z ) ) = − J 2 log ⁡ ( 2 π ) − 1 2 ∑ j = 1 J ( μ j 2 + σ j 2 )  Let  J  be the dimensionality of  z \begin{aligned} & =\int N\left(z ; \mu, \sigma^2\right)\left(-\frac{1}{2} z^2-\frac{1}{2} \log (2 \pi)\right) d z=-\frac{1}{2} \int N\left(z ; \mu, \sigma^2\right) z^2 d z-\frac{J}{2} \log (2 \pi) \\ & =-\frac{J}{2} \log (2 \pi)-\frac{1}{2} E_{z \sim N\left(z ; \mu, \sigma^2\right)}\left[Z^2\right] \\ & =-\frac{J}{2} \log (2 \pi)-\frac{1}{2}\left(E_{z \sim N\left(z ; \mu, \sigma^2\right)}[Z]^2+\operatorname{Var}(Z)\right) \\ & =-\frac{J}{2} \log (2 \pi)-\frac{1}{2} \sum_{j=1}^J\left(\mu_j^2+\sigma_j^2\right) \quad \text { Let } J \text { be the dimensionality of } z \end{aligned} =N(z;μ,σ2)(21z221log(2π))dz=21N(z;μ,σ2)z2dz2Jlog(2π)=2Jlog(2π)21EzN(z;μ,σ2)[Z2]=2Jlog(2π)21(EzN(z;μ,σ2)[Z]2+Var(Z))=2Jlog(2π)21j=1J(μj2+σj2) Let J be the dimensionality of z

L1用于最小化 K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) KL(q(z|x) || p(z)) KL(q(zx)∣∣p(z)),VAE假设 q ( z ∣ x ) q(z|x) q(zx)的分布为正态分布,而 p ( z ) p(z) p(z)为标准正态分布。计算两个正态分布之间的KL散度的公式如下:
K L ( N ( μ 1 , σ 1 2 ) , N ( μ 2 , σ 2 2 ) ) = log ⁡ σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 K L\left(N\left(\mu_1, \sigma_1^2\right), N\left(\mu_2, \sigma_2^2\right)\right)=\log \frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+\left(\mu_1-\mu_2\right)^2}{2 \sigma_2^2}-\frac{1}{2} KL(N(μ1,σ12),N(μ2,σ22))=logσ1σ2+2σ22σ12+(μ1μ2)221

由于此处p(z)为标准正态分布,因此其μ为0,σ为1,那么我们带入后可得
L 1 = − 1 2 ( log ⁡ σ 2 − σ 2 − μ 2 + 1 ) L_1=-\frac{1}{2}\left(\log \sigma^2-\sigma^2-\mu^2+1\right) L1=21(logσ2σ2μ2+1)

采用reparameterization trick有两大好处:

  • 由于分布 q ϕ ( z ∣ x ( i ) ) q_\phi\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) qϕ(zx(i)) 可能是一个比较复杂的函数, 直接采样操作费时费力, 而且采样方差可能很 大, 不利于收玫, 通过reparameterization可以简化操作, 提高效率, 提高数值上的稳定性;
  • 假设我们不考虑采样难度, 直接对 q ϕ ( z ∣ x ( i ) ) q_\phi\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) qϕ(zx(i)) 采样, 那么梯度反向传播的时候, 损失函数中的 1 L ∑ l = 1 L [ log ⁡ p θ ( x ( i ) ∣ z ( i , l ) ) ] \frac{1}{L} \sum_{l=1}^L\left[\log p_\theta\left(\mathbf{x}^{(i)} \mid \mathbf{z}^{(i, l)}\right)\right] L1l=1L[logpθ(x(i)z(i,l))] 是没法对 ϕ \phi ϕ 求导的, 这样损失函数 L ( ϕ , θ , x ( i ) ) \mathcal{L}\left(\phi, \theta, \mathbf{x}^{(\mathbf{i})}\right) L(ϕ,θ,x(i)) 只能通过KL散度 的梯度对 ϕ \phi ϕ 做优化, 这和我们做联合参数优化的意图是违背的。所以使用reparameterization trick 让 z ( i , l ) = g ϕ ( ϵ ( i ; l ) ; x ( i ) ) \mathbf{z}^{(i, l)}=g_\phi\left(\epsilon^{(i ; l)} ; \mathbf{x}^{(i)}\right) z(i,l)=gϕ(ϵ(i;l);x(i)), 实际上是让参数 θ , ϕ \theta, \phi θ,ϕ 可以同时得到期望项和KL散度项的反传 梯度进行优化, 让模型学得更好。

Reconstruction Loss

L ( θ , ϕ ; x ( i ) ) = − D K L ( q ϕ ( z ∣ x ( i ) ) ∥ p θ ( z ) ) ‾ + E q ϕ ( z ∣ x ( i ) ) [ log ⁡ p θ ( x ( i ) ∣ z ) ] ‾ − D K L ( ( q ϕ ( z ) ∥ p θ ( z ) ) = ∫ q θ ( z ) ( log ⁡ p θ ( z ) − log ⁡ q θ ( z ) ) d z = 1 2 ∑ j = 1 J ( 1 + log ⁡ ( ( σ j ) 2 ) − ( μ j ) 2 − ( σ j ) 2 ) f ∗ = arg ⁡ max ⁡ f ∈ F E z ∼ q x ∗ ( log ⁡ p ( x ∣ z ) ) = arg ⁡ max ⁡ f ∈ F E z ∼ q x ∗ ( − ∥ x − f ( z ) ∥ 2 2 c ) \begin{aligned} & \mathcal{L}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right)=\underline{-D_{K L}\left(q_{\boldsymbol{\phi}}\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)}+\underline{\mathbb{E}_{q_{\boldsymbol{\phi}}\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right)}\left[\log p_{\boldsymbol{\theta}}\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right)\right]} \\ & -D_{K L}\left(\left(q_{\boldsymbol{\phi}}(\mathbf{z}) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)=\int q_{\boldsymbol{\theta}}(\mathbf{z})\left(\log p_{\boldsymbol{\theta}}(\mathbf{z})-\log q_{\boldsymbol{\theta}}(\mathbf{z})\right) d \mathbf{z}\right. \\ & =\frac{1}{2} \sum_{j=1}^J\left(1+\log \left(\left(\sigma_j\right)^2\right)-\left(\mu_j\right)^2-\left(\sigma_j\right)^2\right) \\ & f^*=\underset{f \in F}{\arg \max } \mathbb{E}_{z \sim q_x^*}(\log p(x \mid z)) \\ & =\underset{f \in F}{\arg \max } \mathbb{E}_{z \sim q_x^*}\left(-\frac{\|x-f(z)\|^2}{2 c}\right) \\ & \end{aligned} L(θ,ϕ;x(i))=DKL(qϕ(zx(i))pθ(z))+Eqϕ(zx(i))[logpθ(x(i)z)]DKL((qϕ(z)pθ(z))=qθ(z)(logpθ(z)logqθ(z))dz=21j=1J(1+log((σj)2)(μj)2(σj)2)f=fFargmaxEzqx(logp(xz))=fFargmaxEzqx(2cxf(z)2)

L ( θ , ϕ ; x ( i ) ) = − D K L ( q ϕ ( z ∣ x ( i ) ) ∥ p θ ( z ) ) + E q ϕ ( z ∣ x ( i ) ) [ log ⁡ p θ ( x ( i ) ∣ z ) ] \mathcal{L}\left(\boldsymbol{\theta}, \boldsymbol{\phi} ; \mathbf{x}^{(i)}\right)=-D_{K L}\left(q_{\boldsymbol{\phi}}\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right) \| p_{\boldsymbol{\theta}}(\mathbf{z})\right)+\mathbb{E}_{q_{\boldsymbol{\phi}}\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right)}\left[\log p_{\boldsymbol{\theta}}\left(\mathbf{x}^{(i)} \mid \mathbf{z}\right)\right] L(θ,ϕ;x(i))=DKL(qϕ(zx(i))pθ(z))+Eqϕ(zx(i))[logpθ(x(i)z)]


import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image


parser = argparse.ArgumentParser(description='VAE MNIST Example with Different Losses')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100000, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./datasets', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./datasets', train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

### 1.
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function_original(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

### 2. 
# using the loss function which only consider reconstruction term.
def loss_function_only_recon(recon_x, x):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    return BCE

### 3. 
# be careful of the way two losses calculated.
# the only difference of this loss function is that the second term - KLD
# is "mean".
def loss_function_o1(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

### 4.
def loss_function_o2(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='mean')

    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

### 5.
def loss_function_kld(recon_x, x, mu, logvar):
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return KLD

### 6.
# apply the l1 loss
def loss_function_l1(recon_x, x, mu, logvar):
    L1 = F.l1_loss(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return L1 + KLD

### 7.
# apply the MSE loss
def loss_function_l2(recon_x, x, mu, logvar):
    L1 = F.mse_loss(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return L1 + KLD

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        which_loss = 7
        
        if which_loss==1:
            loss = loss_function_original(recon_batch, data, mu, logvar)
        elif which_loss==2:
            loss = loss_function_only_recon(recon_batch, data)
        elif which_loss==3:
            loss = loss_function_o1(recon_batch, data, mu, logvar)
        elif which_loss==4:
            loss = loss_function_o2(recon_batch, data, mu, logvar)
        elif which_loss==5:
            loss = loss_function_kld(recon_batch, data, mu, logvar)
        elif which_loss==6:
            loss = loss_function_l1(recon_batch, data, mu, logvar)
        elif which_loss==7:
            loss = loss_function_l2(recon_batch, data, mu, logvar)
            
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            if (i == 0) and (epoch % 10 == 0):
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'vae_img/7_m_reconstruction_' + str(epoch) + '.png', nrow=n)

if __name__ == "__main__":
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        if epoch%10 == 0:
            with torch.no_grad():
                sample = torch.randn(64, 20).to(device)
                sample = model.decode(sample).cpu()
                save_image(sample.view(64, 1, 28, 28),
                        'vae_img/7_m_sample_' + str(epoch) + '.png')

参考

https://zhuanlan.zhihu.com/p/345360992

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/14883.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

干货!ICLR 2023 | 更稳定高效的因果发现方法-自适应加权

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! ╱ 个人简介╱ 张岸 新加坡国立大学NExT实验室博士后,主要研究Robust & Trustable AI。 个人主页:https://anzhang314.github.io/ 01 内容简介 可微分的因果发现方法,是从…

input 各类事件汇总触发时机触发顺序

今天梳理了一下input框的各类事件,简单介绍一下吧 目录 1.click 2.focus 3.blur 4.change 5.input 6.keydown 7.keyup 8.select 1.click 点击事件,简单易理解,点击触发,等下跟focus事件一起比较 2.focus 获取焦点事件…

每日学术速递4.24

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.CV 1.Collaborative Diffusion for Multi-Modal Face Generation and Editing(CVPR 2023) 标题:多模态人脸生成和编辑的协同扩散 作者:Ziqi Huang, Kelvin C.K. …

RabbitMQ之发布确认

1. 发布确认原理 ​ 生产者将信道设置成 confirm 模式,一旦信道进入 confirm 模式,所有在该信道上面发布的 消息都将会被指派一个唯一的 ID(从 1 开始),一旦消息被投递到所有匹配的队列之后,broker 就会发送一个确认给生产者(包含…

Hibernate多表关联——(一对多关系)

Hibernate多表关联——(一对多关系) 文章目录 Hibernate多表关联——(一对多关系)1.分别在类中添加属性:2.hibernate建表3.使用测试类在表中添加数据 hibernate是连接数据库使得更容易操作数据库数据的一个框架&#x…

ASEMI代理亚德诺AD8130ARZ-REEL7芯片应用与参数分析

编辑-Z 本文将对AD8130ARZ-REEL7芯片进行详细的应用与参数分析,包括其主要特征、接口定义、电气特性以及使用注意事项等方面,旨在为广大读者提供对该芯片更全面的了解。 1、主要特征 AD8130ARZ-REEL7芯片是一种用于高速、低功耗差分信号放大的电路&…

R语言 | 因子

目录 一、使用factor()函数或as.factor()函数建立因子 二、指定缺失的Levels值 三、labels参数 四、因子的转换 五、数值型因子转换时常见的错误 六、再看levels参数 七、有序因子 八、table()函数 九、认识系统内建的数据集 在类别数据中,有些数据是可以排序…

使用binding时,LayoutSubscribeFragmentBinding报错

LayoutRecommendFragmentBinding是一个DataBinding类,它由编译器自动生成,用于访问布局文件中的视图。如果你在代码中看到LayoutRecommendFragmentBinding报红(提示未解析的引用),可能有以下原因: 1. 检查…

软件工程开发文档写作教程(04)—开发文档的编制策略

本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl本文参考资料:电子工业出版社《软件文档写作教程》 马平,黄冬梅编著 开发文档编制策略 文档策略是由上级(资深)管理者制订的,对下级开发单位或开发人…

【C++ Metaprogramming】0. 在C++中实现类似C#的泛型类

两年前,笔者因为项目原因刚开始接触C,当时就在想,如果C有类似C#中的泛型限定就好了,能让代码简单许多。我也一度认为: 虽然C有模板类,但是却没办法实现C#中泛型特有的 where 关键词: public c…

胜叔说SI_PI_EMC

第一课 分享的目的 书籍推荐 第二课 什么是理论分析 仿真不是目的,仿真是验证理论分析的方法 测试不是目的,测试是验证理论分析的方法 第三课 信号完整性简介 小型化、高功率、高密度 传输线理论:传输线是由 信号路径和返回路径共同组…

OSI七层模型、TCP/IP四层模型

OSI七层模型和TCP/IP四层模型 OSI七层模型 物理层:底层数据传输,如网线、网卡标准数据链路层:定义数据基本格式,如何传输如何标识;如网卡MAC地址网络层:定义IP地址,定义路由功能;如…

温度调制式差示扫描量热法(MTDSC)中的正弦波温度控制技术

摘要:在调制温度式差式扫描量热仪(MTDSC)中,关键技术之一是正弦波加热温度的实现,此技术是制约目前国内无法生产MTDSC量热仪的重要障碍,这主要是因为现有的PID温控技术根本无法实现不同幅值和频率正弦波这样…

中文版gpt-最新的人工智能gpt

最新的人工智能gpt 什么是GPT? GPT是一种自然语言处理和语言生成技术,它能够学习和理解自然语言,并生成高质量的文本。GPT是由OpenAI开发的,它采用了最新的深度学习技术,具备了强大的自我学习能力和语言理解能力。它…

【PaddleNLP-kie】关键信息抽取2:UIE模型做图片信息提取全流程

文章目录 本文参考UIE理论部分step0、UIEX原始模型使用网页体验本机安装使用环境安装使用docker的环境安装快速开始 step1、UIEX模型微调(小样本学习)数据标注(label_studio)导出数据转换微调训练:评估:定制…

第二节 ogre sdk 配置使用

上一节,我们介绍过了ogre源码的编译学习,在实际项目中,我们并不需要如此复杂的编译安装过程,可以直接使用官网提供的sdk库进行项目环境配置。下面简单介绍下配置过程。 一 OgreSDK下载 https://dl.cloudsmith.io/public/ogrecav…

Centos安全加固策略

目录 密码安全策略 设置密码的有效期和最小长度 设置用户密码的复杂度 登录安全策略 设置用户远程登录的安全策略 安全的远程管理方式 访问控制 限制root用户登录 修改ssh 22端口 设置登录超时时间 限制IP访问 安全审计 审核策略开启 日志属性设置 查看系统登录…

Google Play编写长描述的最佳实践

在我们为应用编写详细说明时,要遵循以下建议: 我们作为应用营销人员,要了解受众群体的需求和顾虑,如果不知道用户关心什么,那么我们可以查看关键词的搜索量、每个关键词的 Google Play 安装报告、当前关键字排名等等。…

element+vue小技巧和报错解决(持续更新)

目录 1-关于el-table复选框中表头和内容不对齐问题 2-日期选择器传值给后端格式不对 3-获取表格的当前行数据#default"{row}" 1-关于el-table复选框中表头和内容不对齐问题 <el-table:data"tableData"stripestyle"width: 100%"tooltip-ef…

Django框架之自定义管理页面

Django框架Admin站点管理一些默认的显示和功能包括语言都可以自定义设置处理&#xff0c;以贴近我们的实际业务。 属性说明 列表页属性 配置文件myapp/admin.py from django.contrib import admin from .models import Grades, Students# Register your models here.# 注册班…
最新文章