
Discrete Latent VAEs

This note records the training-specific considerations that arise when the latent variable is discrete rather than continuous.

A discrete-state VAE uses categorical latent variables instead of Gaussian latents.

Model and parameterization

Given data $x$ or conditional context $c$:

  • prior: $p(z)$, often uniform categorical, or a conditional prior $p_\phi(z \mid c)$
  • encoder / inference model: logits $\ell_\psi(x) \in \mathbb{R}^K$ with $q_\psi(z = k \mid x) = \operatorname{softmax}(\ell_\psi(x))_k$
  • decoder: $p_\theta(x \mid z)$ or $p_\theta(y \mid x, z)$, with $z$ provided through an embedding or other conditioning interface

ELBO objective

The negative ELBO is

$$\mathcal{L}(\theta, \psi) = - \mathbb{E}_{z \sim q_\psi(\cdot \mid x)}[\log p_\theta(x \mid z)] + \beta \, \mathrm{KL}\!\left(q_\psi(z \mid x)\,\|\,p(z)\right).$$

In the conditional case, replace $p(z)$ with a conditional prior such as $p_\phi(z \mid c)$ and condition the decoder on $c$ as well.
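When the latent vocabulary is small enough to enumerate, the whole objective fits in a few lines. The sketch below assumes a uniform prior and exact enumeration of the latent states; `negative_elbo` and its inputs are illustrative names, not from this note:

```python
import numpy as np

def negative_elbo(logits, log_px_given_z, beta=1.0):
    """Negative ELBO for one datapoint with a K-way categorical latent.

    logits         : encoder logits, shape (K,)
    log_px_given_z : log p(x | z = k) for each state k, shape (K,)

    Assumes a uniform prior p(z) = 1/K, an analytic categorical KL, and
    exact enumeration of the reconstruction term over all K states.
    """
    q = np.exp(logits - logits.max())
    q = q / q.sum()                                   # q_psi(z | x)
    K = len(q)
    recon = np.sum(q * log_px_given_z)                # E_q[log p(x | z)]
    kl = np.sum(q * (np.log(q) - np.log(1.0 / K)))    # KL(q || uniform)
    return float(-recon + beta * kl)

loss = negative_elbo(np.array([2.0, 0.0, -1.0]),
                     np.array([-1.2, -3.4, -5.6]))
```

With uniform logits the KL term vanishes and the loss reduces to the expected negative log-likelihood, which is a quick sanity check on the implementation.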

KL for categorical latents

For categorical $q$ and $p$ with probabilities $q_k$ and $p_k$,

$$\mathrm{KL}\!\left(q \,\|\, p\right) = \sum_{k=1}^{K} q_k \log \frac{q_k}{p_k}.$$

This term is analytic, so no sampling is needed for the KL itself.
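As a sketch, the analytic KL is essentially a one-liner; the function name and the `eps` smoothing guard are illustrative choices, not from this note:

```python
import numpy as np

def categorical_kl(q, p, eps=1e-12):
    """Analytic KL(q || p) between two categorical distributions.

    q, p : probability vectors over the same K outcomes. The eps guard
    avoids log(0) when a component is exactly zero.
    """
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    return float(np.sum(q * (np.log(q + eps) - np.log(p + eps))))

# A peaked posterior against a uniform prior has strictly positive KL.
kl = categorical_kl([0.7, 0.1, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25])
```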

Reconstruction expectation

The main practical issue is the reconstruction expectation

$$\mathbb{E}_{z \sim q_\psi(\cdot \mid x)}[\log p_\theta(x \mid z)].$$

Two main regimes matter:

  1. Exact marginalization for small $K$: $\mathbb{E}_{z}[\log p_\theta(x \mid z)] = \sum_{k=1}^{K} q_k \log p_\theta(x \mid z = k)$. This is low-variance, unbiased, and differentiable.
  2. Monte Carlo for large $K$: $\mathbb{E}_{z}[\log p_\theta(x \mid z)] \approx \frac{1}{S} \sum_{s=1}^{S} \log p_\theta(x \mid z^{(s)})$ with $z^{(s)} \sim q_\psi(\cdot \mid x)$. This is straightforward for decoder parameters but raises the usual discrete-gradient issue for the encoder.
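The two regimes can be compared directly. In this sketch, `q` and the per-state log-likelihoods are made-up stand-ins for the encoder and decoder outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4  # small enough to enumerate

# Hypothetical stand-ins for q_psi(z | x) and log p(x | z = k).
q = np.array([0.7, 0.1, 0.1, 0.1])
log_px_given_z = rng.normal(loc=-5.0, scale=1.0, size=K)

# Regime 1: exact marginalization, sum_k q_k * log p(x | z = k).
exact = float(np.sum(q * log_px_given_z))

# Regime 2: Monte Carlo with S samples z^(s) ~ q.
S = 100_000
z = rng.choice(K, size=S, p=q)
mc = float(np.mean(log_px_given_z[z]))
# mc converges to exact as S grows, but each sample is non-differentiable
# with respect to the encoder parameters.
```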

Differentiating through discrete sampling

Sampling a categorical latent is not pathwise differentiable. Common options are:

  • REINFORCE / score-function estimators: $\nabla_\psi \mathbb{E}_{z}[f(z)] = \mathbb{E}_{z}\!\left[f(z)\,\nabla_\psi \log q_\psi(z \mid x)\right]$, often with a baseline for variance reduction
  • Gumbel-Softmax / Concrete relaxations: use a soft one-hot approximation in the backward pass
  • control-variate estimators such as VIMCO, RELAX, or REBAR: more complex, but potentially lower variance

Practical rule

For small or moderate $K$, prefer exact marginalization. For larger or structured discrete latent spaces, a Gumbel-Softmax-style relaxation is often the simplest practical choice.
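The forward pass of such a relaxation can be sketched in NumPy; the temperature `tau` and the straight-through argmax are the usual knobs, and the names here are illustrative (the backward-pass wiring needs an autograd framework and is only described in comments):

```python
import numpy as np

def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed categorical sample: softmax((logits + Gumbel noise) / tau).

    As tau -> 0 the sample approaches a hard one-hot draw from
    softmax(logits); larger tau gives smoother, lower-variance samples.
    """
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))  # Gumbel(0, 1)
    y = (logits + g) / tau
    y = np.exp(y - y.max())
    return y / y.sum()

rng = np.random.default_rng(0)
soft = gumbel_softmax_sample(np.array([2.0, 0.5, 0.1, -1.0]), tau=0.5, rng=rng)

# Straight-through variant: use the hard one-hot in the forward pass while
# backpropagating through the soft sample in the backward pass.
hard = np.eye(len(soft))[np.argmax(soft)]
```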

Why this note matters here

The main methodology cluster already covers the discrete-latent interfaces. What this note adds is the training-specific reminder that discrete latents change the reconstruction expectation and gradient-estimation story, even when the high-level CVAE factorization looks the same.
