Inter-Latent Divergence via Token-Weighted Jensen-Shannon Divergence
This note formalizes an auxiliary loss that explicitly encourages different discrete latent values to produce different output distributions, addressing the ELBO's underdetermination of latent semantics.
Motivation
The standard conditional ELBO objective rewards any latent partition that improves reconstruction and has bounded KL. Nothing in this objective forces different latent values to produce different outputs.
Experimental evidence (Mar17, Mar19 collections) confirms this: across a wide range of KL weights $\beta$, posterior-entropy targets, router-marginal-KL weights, and router support weights, task accuracy remains high (0.985–1.0) but latent-strategy alignment is typically only 0.16–0.27. The occupancy-oriented losses shape the distribution of latent usage without injecting strategy semantics.
The fundamental gap is that the ELBO lacks an inter-latent exclusivity signal: a pressure for $p_\theta(y \mid x, z{=}k) \neq p_\theta(y \mid x, z{=}k')$ when $k \neq k'$.
Full-Sequence Inter-Latent Divergence
The ideal objective would maximize, for a discrete latent $z \in \{1, \ldots, K\}$, the pairwise KL between full-sequence conditional distributions:

$$\mathcal{L}_{\text{div}} = \mathbb{E}_x \sum_{k \neq k'} \mathrm{KL}\!\left(p_\theta(y \mid x, z{=}k) \,\middle\|\, p_\theta(y \mid x, z{=}k')\right)$$
This is intractable for two reasons:
- The expectation requires autoregressive sampling from $p_\theta(y \mid x, z{=}k)$, which is sequential and non-differentiable.
- Under teacher forcing (replacing the expectation with the ground-truth $y$), the late-token contributions vanish: the teacher-forced prefix reveals the strategy, making both $p_\theta(y_t \mid y_{<t}, x, z{=}k)$ and $p_\theta(y_t \mid y_{<t}, x, z{=}k')$ converge to the same (nearly deterministic) continuation. This is the same teacher-forcing dilution identified in the token-weighted reconstruction analysis.
Token-Weighted Jensen-Shannon Divergence
Definition
We propose a tractable proxy: the per-position Jensen-Shannon divergence across all latent-conditioned token distributions, weighted by the z-free baseline NLL:

$$\mathcal{L}_{\text{JSD}} = -\sum_{t=1}^{T} w_t \,\mathrm{JSD}\!\left(p_\theta(\cdot \mid y_{<t}, x, z{=}1), \ldots, p_\theta(\cdot \mid y_{<t}, x, z{=}K)\right)$$

where the JSD of $K$ distributions $p_1, \ldots, p_K$ under uniform mixture weights is:

$$\mathrm{JSD}(p_1, \ldots, p_K) = H\!\left(\frac{1}{K}\sum_{k=1}^{K} p_k\right) - \frac{1}{K}\sum_{k=1}^{K} H(p_k)$$

and $w_t = \operatorname{sg}\!\left[-\log p_{\text{free}}(y_t \mid y_{<t}, x)\right]$ is the stop-gradiented z-free baseline NLL weight that focuses the loss on strategy-sensitive positions.
The loss is negated so that minimizing it maximizes inter-latent divergence.
Connection to Mutual Information
Under uniform mixture weights $\pi_k = 1/K$ (i.e., $Z \sim \mathrm{Uniform}\{1, \ldots, K\}$), the per-position JSD equals the conditional mutual information between the latent and the predicted token given the context:

$$\mathrm{JSD}\!\left(p_\theta(\cdot \mid y_{<t}, x, z{=}1), \ldots, p_\theta(\cdot \mid y_{<t}, x, z{=}K)\right) = I(Z; Y_t \mid y_{<t}, x)$$
The token-weighted sum therefore maximizes a weighted version of the per-step mutual information, focusing on positions where the z-free model is most uncertain (and thus most likely to be strategy-sensitive).
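The JSD–mutual-information identity follows from the entropy decomposition of mutual information; a short derivation, assuming $Z$ uniform over the $K$ latent values and writing $p_k = p_\theta(\cdot \mid y_{<t}, x, z{=}k)$:

```latex
\begin{aligned}
I(Z; Y_t \mid y_{<t}, x)
  &= H(Y_t \mid y_{<t}, x) - H(Y_t \mid Z, y_{<t}, x) \\
  &= H\!\left(\tfrac{1}{K}\textstyle\sum_{k=1}^{K} p_k\right)
     - \tfrac{1}{K}\textstyle\sum_{k=1}^{K} H(p_k) \\
  &= \mathrm{JSD}(p_1, \ldots, p_K).
\end{aligned}
```

The first line is the standard entropy form of MI; the second uses the fact that the marginal of $Y_t$ under a uniform latent is exactly the uniform mixture of the $K$ conditionals.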
Key Properties
The per-position JSD satisfies $0 \le \mathrm{JSD} \le \log K$ by construction. This provides natural gradient stability and makes the loss weight interpretable relative to other loss terms. No gradient clipping or normalization is required for the JSD term itself.
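The JSD of $K$ distributions under uniform mixture weights lies in $[0, \log K]$; a small stand-alone sketch (plain Python, illustrative only) checks both endpoints:

```python
import math

def jsd_uniform(dists):
    """JSD of K discrete distributions (lists of probabilities)
    under uniform mixture weights: H(mixture) - mean of H(p_k)."""
    K, V = len(dists), len(dists[0])
    entropy = lambda q: -sum(x * math.log(x) for x in q if x > 0.0)
    mixture = [sum(p[v] for p in dists) / K for v in range(V)]
    return entropy(mixture) - sum(entropy(p) for p in dists) / K

# Identical distributions attain the lower bound 0;
# K disjoint one-hot distributions attain the upper bound log K.
```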
The discrete exact-enumeration training loop already computes logits for all latent values: logits_by_latent $\in \mathbb{R}^{B \times K \times T \times V}$. Computing the JSD from these logits requires only softmax, log-sum-exp, and elementwise operations, with no additional forward passes. The marginal cost is negligible relative to the generator forward passes.
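A minimal PyTorch sketch of this computation, assuming the $B \times K \times T \times V$ logits layout above; the signature of the real inter_latent_jsd_loss in losses.py may differ:

```python
import math

import torch
import torch.nn.functional as F

def inter_latent_jsd_loss(logits_by_latent: torch.Tensor,
                          zfree_nll: torch.Tensor) -> torch.Tensor:
    """Negated token-weighted inter-latent JSD.

    logits_by_latent: [B, K, T, V] next-token logits for each latent value.
    zfree_nll:        [B, T] per-position NLL of the z-free baseline,
                      used (stop-gradiented) as position weights.
    """
    K = logits_by_latent.size(1)
    log_p = F.log_softmax(logits_by_latent, dim=-1)            # [B, K, T, V]

    # Uniform mixture over latent values, computed in log space.
    log_mix = torch.logsumexp(log_p, dim=1) - math.log(K)      # [B, T, V]

    # JSD = H(mixture) - (1/K) sum_k H(p_k), bounded in [0, log K].
    h_mix = -(log_mix.exp() * log_mix).sum(dim=-1)             # [B, T]
    h_each = -(log_p.exp() * log_p).sum(dim=-1).mean(dim=1)    # [B, T]
    jsd = h_mix - h_each

    w = zfree_nll.detach()                                     # stop-gradient
    # Negated: minimizing this loss maximizes inter-latent divergence.
    return -(w * jsd).sum() / w.sum().clamp_min(1e-8)
```

The weighted sum is normalized by the total weight so the loss scale stays comparable across batches; whether the real implementation normalizes this way is an assumption.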
JSD is symmetric in its arguments, unlike asymmetric KL. This means the loss does not privilege one latent value over another, which is desirable for unsupervised disentanglement where we have no prior on which latent should correspond to which strategy.
Relationship to Token-Weighted Reconstruction
The token-weighted reconstruction loss (see excess_reconstruction.typ) focuses the reconstruction gradient on strategy-sensitive tokens, counteracting teacher-forcing dilution. The inter-latent JSD loss is complementary:
- Token-weighted reconstruction says: "at strategy-sensitive positions, the model should predict the correct next token given the assigned latent $z$."
- Inter-latent JSD says: "at strategy-sensitive positions, different latent values should produce different next-token distributions."
The first targets reconstruction quality conditioned on $z$; the second targets exclusivity across $z$ values. Both use the same z-free baseline NLL weights, reinforcing each other's focus on the right positions.
Combination with InfoVAE
The inter-latent JSD is particularly effective when combined with an InfoVAE-style objective that reduces the per-example KL weight ($\beta \to 0$) while maintaining strong marginal matching (router_marginal_kl_to_prior_weight > 0):
- InfoVAE removes the MI penalty, allowing maximum information to flow through $z$ (the posterior is free to encode strategy without being pushed toward the prior).
- Inter-latent JSD ensures the information flowing through $z$ is strategy-discriminating, not just generically predictive.
Together: "let $z$ be maximally informative" (InfoVAE) about "strategy specifically" (JSD).
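A hypothetical configuration sketch of this combination. The field names router_marginal_kl_to_prior_weight and inter_latent_divergence_weight follow the Implementation Reference in this note; the stand-in dataclass, the beta field name, and all numeric values are illustrative assumptions, not the real class or its defaults:

```python
from dataclasses import dataclass

@dataclass
class DisentanglementLossConfig:  # minimal stand-in for the real config class
    beta: float = 0.0                                 # assumed name: per-example KL weight, driven to 0 (InfoVAE)
    router_marginal_kl_to_prior_weight: float = 1.0   # keep marginal matching strong
    inter_latent_divergence_weight: float = 0.5       # enable the token-weighted JSD term

cfg = DisentanglementLossConfig()
```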
Implementation Reference
- Loss function: core/disentanglement/losses.py::inter_latent_jsd_loss
- Config field: DisentanglementLossConfig.inter_latent_divergence_weight
- Integration: SharedDiscreteCVAE.compute_inter_latent_divergence_loss
- Diagnostic key: diagnostics/inter_latent_jsd/weighted_mean