Inter-Latent Divergence via Token-Weighted Jensen-Shannon Divergence

This note formalizes an auxiliary loss that explicitly encourages different discrete latent values $z_k$ to produce different output distributions, addressing the ELBO's underdetermination of latent semantics.

Motivation

The standard conditional ELBO objective $\mathcal{L}_\mathrm{cELBO} = \mathbb{E}_{z \sim q(z \mid x, y)}[\log p_\psi(y \mid x, z)] - \beta\, \mathrm{KL}(q(z \mid x, y), p_\phi(z \mid x))$ rewards any latent partition that improves reconstruction and has bounded KL. Nothing in this objective forces different latent values to produce different outputs.

Experimental evidence (Mar17, Mar19 collections) confirms this: across a wide range of $\beta$, posterior entropy, router marginal KL, and router support weights, task accuracy remains high (0.985–1.0) but latent-strategy alignment is typically 0.16–0.27. The occupancy-oriented losses shape the distribution of latent usage without injecting strategy semantics.

The fundamental gap is that the ELBO lacks an inter-latent exclusivity signal: a pressure for $p(y \mid x, z_i) \neq p(y \mid x, z_j)$ when $i \neq j$.

Full-Sequence Inter-Latent Divergence

The ideal objective would maximize the pairwise KL between full-sequence conditional distributions:

$$\mathrm{KL}(p(y \mid x, z_i), p(y \mid x, z_j)) = \mathbb{E}_{y \sim p(\cdot \mid x, z_i)} \left[\sum_{t} \log p(y_t \mid x, z_i, y_{<t}) - \log p(y_t \mid x, z_j, y_{<t})\right].$$

This is intractable for two reasons:

  • The expectation requires autoregressive sampling from $p(\cdot \mid x, z_i)$, which is sequential and non-differentiable.
  • Under teacher forcing (replacing the expectation with the ground-truth $y$), the late-token contributions are $\approx 0$: the teacher-forced prefix $y_{<t}$ reveals the strategy, making both $p(y_t \mid x, z_i, y_{<t})$ and $p(y_t \mid x, z_j, y_{<t})$ converge to the same (nearly deterministic) continuation. This is the same teacher-forcing dilution identified in the token-weighted reconstruction analysis.

Token-Weighted Jensen-Shannon Divergence

Definition

We propose a tractable proxy: the per-position Jensen-Shannon divergence across all $K$ latent-conditioned token distributions, weighted by the z-free baseline NLL:

$$\mathcal{L}_\mathrm{JSD} = - \sum_{t} w_t \cdot \operatorname{JSD}\!\left(p(y_t \mid x, z_1, y_{<t}), \ldots, p(y_t \mid x, z_K, y_{<t})\right),$$

where the JSD of $K$ distributions under uniform mixture weights is:

$$\operatorname{JSD}(p_1, \ldots, p_K) = H\!\left(\frac{1}{K} \sum_{k=1}^{K} p_k\right) - \frac{1}{K} \sum_{k=1}^{K} H(p_k),$$

and $w_t = \operatorname{sg}(b_t)$ is the stop-gradient of the z-free baseline NLL, a weight that focuses the loss on strategy-sensitive positions.

The loss is negated so that minimizing $\mathcal{L}_\mathrm{JSD}$ maximizes inter-latent divergence.
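The definition above can be sketched in a few lines. The following is a minimal NumPy illustration, not the repository's `inter_latent_jsd_loss`: the function names and the probability-space inputs are assumptions for clarity, and passing the baseline NLL as a plain constant array plays the role of the stop-gradient $\operatorname{sg}(b_t)$.

```python
import numpy as np

def multi_jsd(probs):
    """JSD of K distributions under uniform mixture weights, in nats.

    probs: array of shape (K, V), one next-token distribution per latent.
    Returns H(mean_k p_k) - mean_k H(p_k).
    """
    eps = 1e-12  # guards log(0) for zero-probability tokens
    mixture = probs.mean(axis=0)
    h_mixture = -np.sum(mixture * np.log(mixture + eps))
    h_each = -np.sum(probs * np.log(probs + eps), axis=1)
    return h_mixture - h_each.mean()

def inter_latent_jsd_loss(probs_by_latent, baseline_nll):
    """Negative token-weighted JSD over a sequence.

    probs_by_latent: (K, T, V) latent-conditioned token distributions.
    baseline_nll: (T,) z-free baseline NLL weights w_t; treating them as
    constants here stands in for the stop-gradient sg(b_t).
    """
    T = probs_by_latent.shape[1]
    jsd_per_pos = np.array([multi_jsd(probs_by_latent[:, t]) for t in range(T)])
    return -np.sum(baseline_nll * jsd_per_pos)
```

In an actual training loop the same computation would run on the per-latent logits after a softmax, with the weights detached from the autograd graph.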

Connection to Mutual Information

Under uniform mixture weights $\pi_k = 1/K$, the per-position JSD equals the conditional mutual information between the latent $z$ and the predicted token $y_t$ given the context:

$$\operatorname{JSD}(p_1, \ldots, p_K) = I(z; y_t \mid x, y_{<t}) \quad \text{under } z \sim \mathrm{Unif}(\{1, \ldots, K\}).$$

The token-weighted sum therefore maximizes a weighted version of the per-step mutual information, focusing on positions where the z-free model is most uncertain (and thus most likely to be strategy-sensitive).
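This identity is easy to check numerically: under a uniform prior over $K$ latent values, the mutual information computed from the joint $p(z, y) = \tfrac{1}{K}\, p(y \mid z)$ matches the JSD formula. A small self-contained check (illustrative NumPy, not project code):

```python
import numpy as np

rng = np.random.default_rng(0)
K, V = 4, 10
eps = 1e-12
probs = rng.dirichlet(np.ones(V), size=K)  # p(y | z = k), shape (K, V)

# JSD under uniform mixture weights: H(mixture) - mean entropy.
mix = probs.mean(axis=0)
jsd = (-np.sum(mix * np.log(mix + eps))
       + np.mean(np.sum(probs * np.log(probs + eps), axis=1)))

# Mutual information I(z; y) from the joint p(z, y) = (1/K) p(y | z).
joint = probs / K
pz = joint.sum(axis=1, keepdims=True)  # uniform marginal over z
py = joint.sum(axis=0, keepdims=True)  # mixture marginal over y
mi = np.sum(joint * (np.log(joint + eps) - np.log(pz * py + eps)))

assert np.isclose(jsd, mi)
```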

Key Properties

Boundedness

$\operatorname{JSD} \in [0, \log K]$ by construction. This provides natural gradient stability and makes the loss weight $\lambda_\mathrm{JSD}$ interpretable relative to other loss terms. No gradient clipping or normalization is required for the JSD term itself.
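The bound can be verified empirically: random distributions never leave $[0, \log K]$, and fully disjoint supports attain the upper bound. A quick NumPy check (illustrative helper, not project code):

```python
import numpy as np

def multi_jsd(probs):
    # probs: (K, V); JSD under uniform mixture weights, in nats.
    eps = 1e-12
    mix = probs.mean(axis=0)
    return (-np.sum(mix * np.log(mix + eps))
            + np.mean(np.sum(probs * np.log(probs + eps), axis=1)))

rng = np.random.default_rng(1)
K, V = 5, 20

# Random distributions stay inside [0, log K].
for _ in range(100):
    jsd = multi_jsd(rng.dirichlet(np.ones(V), size=K))
    assert -1e-9 <= jsd <= np.log(K) + 1e-9

# Fully disjoint supports (one-hot rows) attain the bound log K.
assert np.isclose(multi_jsd(np.eye(K, V)), np.log(K))
```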

Computational cost

The discrete exact-enumeration training loop already computes logits for all $K$ latent values: `logits_by_latent` $\in \mathbb{R}^{B \times K \times T \times V}$. Computing JSD from these logits requires only softmax, log-sum-exp, and elementwise operations, with no additional forward passes. The marginal cost is negligible relative to the $K$ generator forward passes.

Symmetry

JSD is symmetric in its arguments, unlike asymmetric KL. This means the loss does not privilege one latent value over another, which is desirable for unsupervised disentanglement where we have no prior on which latent should correspond to which strategy.
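The contrast with KL is easy to demonstrate for two distributions: JSD is invariant to swapping its arguments, while the two KL directions generally differ. A minimal NumPy check (illustrative helper names):

```python
import numpy as np

eps = 1e-12
h = lambda d: -np.sum(d * np.log(d + eps))  # Shannon entropy in nats

def jsd2(p, q):
    # Two-distribution JSD: entropy of the midpoint minus mean entropy.
    m = 0.5 * (p + q)
    return h(m) - 0.5 * (h(p) + h(q))

def kl(p, q):
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)))

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.1, 0.3, 0.6])

assert np.isclose(jsd2(p, q), jsd2(q, p))  # symmetric
assert not np.isclose(kl(p, q), kl(q, p))  # direction-dependent
```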

Relationship to Token-Weighted Reconstruction

The token-weighted reconstruction loss (see excess_reconstruction.typ) focuses the reconstruction gradient on strategy-sensitive tokens, counteracting teacher-forcing dilution. The inter-latent JSD loss is complementary:

  • Token-weighted reconstruction says: "at strategy-sensitive positions, the model should predict the correct next token given the assigned latent $z$."
  • Inter-latent JSD says: "at strategy-sensitive positions, different latent values should produce different next-token distributions."

The first targets reconstruction quality conditioned on $z$; the second targets exclusivity across $z$ values. Both use the same z-free baseline NLL weights, reinforcing each other's focus on the right positions.

Combination with InfoVAE

The inter-latent JSD is particularly effective when combined with an InfoVAE-style objective that reduces the per-example KL weight ($\beta \approx 0$) while maintaining strong marginal matching (`router_marginal_kl_to_prior_weight` > 0):

  • InfoVAE removes the MI penalty, allowing maximum information to flow through $z$ (the posterior is free to encode strategy without being pushed toward the prior).
  • Inter-latent JSD ensures the information flowing through $z$ is strategy-discriminating, not just generically predictive.

Together: "let $z$ be maximally informative" (InfoVAE) about "strategy specifically" (JSD).
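As a sketch only, the combined recipe might be expressed as the following weight settings. The field names follow this note's configuration surface; the numeric values are illustrative assumptions, not tuned recommendations.

```python
# Illustrative weights for the InfoVAE + inter-latent JSD combination.
# Field names follow this note's config surface; values are assumptions.
combined_objective = {
    "beta": 0.0,                                # InfoVAE: drop the per-example KL weight
    "router_marginal_kl_to_prior_weight": 1.0,  # keep strong marginal matching
    "inter_latent_divergence_weight": 0.5,      # token-weighted JSD term
}
```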

Implementation Reference

  • Loss function: path:core/disentanglement/losses.py::inter_latent_jsd_loss
  • Config field: DisentanglementLossConfig.inter_latent_divergence_weight
  • Integration: SharedDiscreteCVAE.compute_inter_latent_divergence_loss
  • Diagnostic key: diagnostics/inter_latent_jsd/weighted_mean