InfoVAE Connections to the Current Loss Structure

This note formalizes the relationship between the InfoVAE objective (Zhao et al., 2019) and the current disentanglement loss stack, showing that the existing infrastructure already supports InfoVAE-style training through independent tuning of existing hyperparameters.

The Two Roles of the KL Term

The standard ELBO's KL term $\mathrm{KL}(q(z \mid x, y), p(z \mid x))$ simultaneously serves two distinct roles:

  • Mutual information penalty: reduces $I_q(z; (x, y))$, discouraging the posterior from encoding information about the data in $z$. This hurts disentanglement by pushing the posterior toward the prior and erasing strategy information.

  • Marginal matching: pushes the aggregate posterior $q(z) = \mathbb{E}_{x,y}[q(z \mid x, y)]$ toward the prior $p(z)$, ensuring that the prior is a good approximation of the marginal posterior for generation. This helps generation quality.

These roles conflict: reducing the KL improves generation-time prior quality but simultaneously discourages the posterior from being informative.

The InfoVAE Decomposition

InfoVAE separates these roles explicitly. The KL can be decomposed as:

$$\mathbb{E}_{x,y}\big[\mathrm{KL}(q(z \mid x, y), p(z))\big] = I_q(z; (x, y)) + \mathrm{KL}(q(z), p(z)),$$

where $q(z) = \mathbb{E}_{x,y}[q(z \mid x, y)]$ is the aggregate posterior. The identity holds for the per-example KL averaged over the data, which is how the term enters the training objective.
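The decomposition can be checked numerically on a small discrete example. The sketch below is illustrative only: the three data points, their posteriors, and the uniform prior are made-up numbers, and the helper `kl` is a hypothetical name. It compares the data-averaged per-example KL against the sum of the mutual information and the marginal KL.

```python
import math

def kl(p, q):
    """KL(p || q) in nats for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Toy setup (all numbers are illustrative): 3 equally likely data points,
# each with a posterior q(z | x, y) over 2 latent values.
data_probs = [1 / 3, 1 / 3, 1 / 3]
posteriors = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]
prior = [0.5, 0.5]

# Aggregate posterior q(z) = E_{x,y}[q(z | x, y)].
agg = [sum(w * q[k] for w, q in zip(data_probs, posteriors)) for k in range(2)]

# Left-hand side: per-example KL to the prior, averaged over the data.
expected_kl = sum(w * kl(q, prior) for w, q in zip(data_probs, posteriors))

# Right-hand side: mutual information I_q(z; (x, y)) plus marginal KL.
mutual_info = sum(w * kl(q, agg) for w, q in zip(data_probs, posteriors))
marginal_kl = kl(agg, prior)

# The two printed values agree up to floating-point error.
print(expected_kl, mutual_info + marginal_kl)
```

In this toy case most of the averaged KL is mutual information, which is the part a $\beta$-style penalty would suppress.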

InfoVAE replaces the standard ELBO with:

$$\mathcal{L}_\mathrm{InfoVAE} = \mathbb{E}_{q}[\log p(y \mid x, z)] - (1 - \alpha) \cdot \mathrm{KL}(q(z \mid x, y), p(z \mid x)) - (\alpha + \lambda - 1) \cdot D(q(z), p(z)),$$

where $D$ is a divergence measure (e.g., MMD) and $\alpha, \lambda$ are hyperparameters.

Key regimes:

  • $\alpha = 0, \lambda = 1$: standard VAE (MI penalty active).
  • $\alpha \to 1$: remove MI penalty, keep only marginal matching via $D$. This is the "MMD-VAE" regime.
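  • $\alpha = 0, \lambda > 1$: standard VAE with extra marginal regularization (per-example KL weight 1, divergence weight $\lambda - 1 > 0$). A pure $\beta$-VAE instead corresponds to $\alpha + \lambda - 1 = 0$, where the per-example KL weight is $1 - \alpha = \lambda = \beta$.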

Mapping to Current Project Infrastructure

The current loss structure already has two independent controls that map onto the InfoVAE decomposition:

Parameter mapping

  • beta corresponds to $(1 - \alpha)$: weight on the per-example KL term $\mathrm{KL}(q(z \mid x, y), p(z \mid x))$. Controls the MI penalty.
  • router_marginal_kl_to_prior_weight corresponds to $(\alpha + \lambda - 1)$: weight on the marginal matching term $\mathrm{KL}(\mathbb{E}_{x}[p_\phi(z \mid x)], \mathrm{Uniform})$. Controls how well the prior covers the latent space.

The existing infrastructure thus supports InfoVAE-style objectives by independently tuning these two weights. No code changes are required.
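The parameter mapping can be written as a pair of conversion helpers. This is a minimal sketch: the function names are hypothetical and not part of the project, while the two dictionary keys are the hyperparameters named above. The correspondence is between weights only, since the project's marginal term is a KL to uniform on the aggregate prior rather than a generic divergence $D(q(z), p(z))$.

```python
def infovae_to_project_weights(alpha, lam):
    """Map InfoVAE (alpha, lambda) to the two project loss weights.

    Hypothetical helper: beta = 1 - alpha, marginal weight = alpha + lambda - 1.
    """
    return {
        "beta": 1.0 - alpha,
        "router_marginal_kl_to_prior_weight": alpha + lam - 1.0,
    }

def project_weights_to_infovae(beta, marginal_weight):
    """Inverse mapping: recover (alpha, lambda) from the project weights."""
    alpha = 1.0 - beta
    lam = marginal_weight + beta  # from alpha + lambda - 1 = marginal_weight
    return alpha, lam

# Standard VAE: full per-example KL, no extra marginal term.
print(infovae_to_project_weights(alpha=0.0, lam=1.0))
# MI penalty removed, marginal matching at weight 1.
print(infovae_to_project_weights(alpha=1.0, lam=1.0))
```

Under this mapping, a setting such as beta = 0 with a marginal weight of 1.0 corresponds to $\alpha = 1, \lambda = 1$.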

The InfoVAE Regime

The specific regime of interest is:

$$\beta \approx 0, \quad \mathtt{router\_marginal\_kl\_to\_prior\_weight} \geq 1.0.$$

Effect of $\beta \approx 0$: Removes the MI penalty. The posterior $q(z \mid x, y)$ is free to encode maximum information about $(x, y)$ without being pushed toward the prior. This should prevent the observed sampled-z probe decay during training, which is consistent with the KL term eroding strategy information that the posterior had learned.

Effect of strong marginal matching: Keeps the aggregate prior $\mathbb{E}_{x}[p_\phi(z \mid x)]$ close to uniform. This ensures that at generation time, sampling $z \sim p_\phi(z \mid x)$ explores all latent values, not just a collapsed subset.
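For a discrete latent, the marginal matching term can be sketched as follows. This is an illustration under assumptions, not the project's implementation: `marginal_kl_to_uniform` is a hypothetical name, and the batch mean stands in for the expectation over $x$. The example shows that the term penalizes a collapsed aggregate prior while leaving sharply peaked per-example priors unpenalized, as long as they cover all latent values in aggregate.

```python
import math

def marginal_kl_to_uniform(prior_probs):
    """KL(mean over batch of p(z | x) || Uniform) for categorical priors.

    prior_probs: list of per-example probability vectors over K latents.
    """
    batch = len(prior_probs)
    num_latents = len(prior_probs[0])
    agg = [sum(row[k] for row in prior_probs) / batch for k in range(num_latents)]
    # KL(agg || Uniform) = sum_k agg_k * log(agg_k * K)
    return sum(a * math.log(a * num_latents) for a in agg if a > 0)

# Each example picks a different latent: peaked per example, uniform in aggregate.
balanced = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Every example picks the same latent: collapsed aggregate.
collapsed = [[0.98, 0.01, 0.01]] * 3

print(marginal_kl_to_uniform(balanced))   # close to zero
print(marginal_kl_to_uniform(collapsed))  # clearly positive
```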

Prior Experimental Gap

The Mar19 sparse loss sweep tested marginal KL weights at 0.1, 0.5, 1.0 but always alongside $\beta = 0.1$. The interaction of very low beta with strong marginal matching was not explored. This is precisely the InfoVAE regime.

Compound Intervention: InfoVAE + Inter-Latent JSD

The InfoVAE regime and the inter-latent JSD loss (see inter_latent_divergence.typ) address complementary failure modes:

  • InfoVAE addresses "how much information flows through $z$": by removing the MI penalty, the posterior is free to encode strategy information without being pushed toward the prior.

  • Inter-latent JSD addresses "what kind of information flows through $z$": by maximizing divergence between latent-conditioned distributions, the model is incentivized to use $z$ for strategy-discriminating information specifically, not just any predictive partition.
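One common way to measure divergence among several latent-conditioned distributions is the generalized Jensen-Shannon divergence with uniform weights, $H(\text{mixture}) - \text{mean of } H(\text{components})$. The sketch below assumes this form for illustration; the definition actually used by the project is the one in inter_latent_divergence.typ and may differ. Function names are hypothetical.

```python
import math

def entropy(p):
    """Shannon entropy in nats of a discrete distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def generalized_jsd(dists):
    """Uniform-weight JSD: H(mean of dists) - mean of H(dist).

    dists: list of output distributions, one per latent value z.
    """
    n = len(dists)
    mixture = [sum(d[j] for d in dists) / n for j in range(len(dists[0]))]
    return entropy(mixture) - sum(entropy(d) for d in dists) / n

# Latent values that induce identical outputs carry no distinguishing signal.
print(generalized_jsd([[0.5, 0.5], [0.5, 0.5]]))
# Latent values that induce disjoint outputs reach the maximum, log(2) for two latents.
print(generalized_jsd([[1.0, 0.0], [0.0, 1.0]]))
```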

The combined objective for the discrete case can be written as

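$$\mathcal{L} = \mathbb{E}_{z \sim q}[\log p(y \mid x, z)] - \beta \cdot \mathrm{KL}(q(z \mid x, y), p(z \mid x)) - \lambda_\mathrm{marginal} \cdot \mathrm{KL}(\mathbb{E}_{x}[p(z \mid x)], \mathrm{Unif}) + \lambda_\mathrm{JSD} \cdot \mathcal{L}_\mathrm{JSD}.$$

Here $\mathcal{L}$ is maximized, so the marginal term enters with a negative sign, matching the divergence term in the InfoVAE objective. The positive sign on the JSD term assumes $\mathcal{L}_\mathrm{JSD}$ is defined so that larger values mean greater inter-latent divergence.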

MMD Alternative for Continuous Latents

For continuous latent variables, the marginal KL term $\mathrm{KL}(\mathbb{E}_{x,y}[q(z \mid x, y)], \mathcal{N}(0, I))$ may not fully capture the aggregate posterior's distributional properties (it only matches first and second moments via the moment-matching approximation).

InfoVAE suggests replacing marginal KL with Maximum Mean Discrepancy (MMD):

$$\mathrm{MMD}^2(q(z), p(z)) = \mathbb{E}_{z, z' \sim q}[k(z, z')] - 2\, \mathbb{E}_{z \sim q, \tilde z \sim p}[k(z, \tilde z)] + \mathbb{E}_{\tilde z, \tilde z' \sim p}[k(\tilde z, \tilde z')],$$

where $p(z) = \mathcal{N}(0, I)$ and $k$ is a positive-definite kernel (e.g., Gaussian or inverse multiquadratic).

Implementation: compute kernel-based MMD on minibatch samples. For a batch of $B$ examples with posterior samples $z_i \sim q(z \mid x_i, y_i)$, $i = 1, \dots, B$, and prior samples $\tilde z_1, \dots, \tilde z_B \sim \mathcal{N}(0, I)$, the U-statistic estimator of $\mathrm{MMD}^2$ costs $O(B^2)$ kernel evaluations and is unbiased.
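A minimal pure-Python sketch of the U-statistic estimator follows. The Gaussian kernel with bandwidth 1.0, the function names, and the toy sample sets are assumptions made for illustration; a training implementation would use batched tensor operations instead of nested loops.

```python
import math
import random

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian (RBF) kernel between two vectors given as lists."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))

def mmd2_unbiased(zq, zp, kernel=gaussian_kernel):
    """Unbiased U-statistic estimate of MMD^2 between samples zq ~ q and zp ~ p.

    Within-set sums skip the diagonal (i == j), which is what removes the bias.
    The estimate can be slightly negative when q and p are close.
    """
    m, n = len(zq), len(zp)
    within_q = sum(kernel(zq[i], zq[j]) for i in range(m) for j in range(m) if i != j)
    within_p = sum(kernel(zp[i], zp[j]) for i in range(n) for j in range(n) if i != j)
    cross = sum(kernel(a, b) for a in zq for b in zp)
    return within_q / (m * (m - 1)) + within_p / (n * (n - 1)) - 2.0 * cross / (m * n)

rng = random.Random(0)
dim, batch = 2, 64
prior = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(batch)]
matched = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(batch)]
shifted = [[rng.gauss(3, 1) for _ in range(dim)] for _ in range(batch)]

mmd_matched = mmd2_unbiased(matched, prior)  # aggregate posterior matches the prior
mmd_shifted = mmd2_unbiased(shifted, prior)  # aggregate posterior is off-prior
print(mmd_matched, mmd_shifted)
```

The matched set gives an estimate near zero, while the shifted set gives a clearly larger value, which is the signal the marginal penalty would act on.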

This is a natural extension when the continuous CVAE experiments move to the InfoVAE regime.

References

  • Zhao, Song, Ermon. "InfoVAE: Balancing Learning and Inference in Variational Autoencoders." AAAI 2019.
  • Tolstikhin, Bousquet, Gelly, Schoelkopf. "Wasserstein Auto-Encoders." ICLR 2018. (related MMD-based approach)