Baseline-Normalized Reconstruction Loss

This note describes a modification to the standard ELBO (reconstruction + KL) loss tailored to the problem of learning a disentangled factorization of an entangled-initialized base model.

Standard ELBO Objective

For one supervised pair $(x, y)$, the conditional ELBO is

$$\mathcal{L}_\mathrm{cELBO}(x, y) = \mathbb{E}_{z \sim q_\xi(\cdot \mid x, y)}[\log p_\psi(y \mid x, z)] - \mathrm{KL}(q_\xi(z \mid x, y), p_\phi(z \mid x)).$$

With a $\beta$-weighted KL:

$$\mathcal{L}_{\beta}(x, y) = \mathbb{E}_{z \sim q_\xi(\cdot \mid x, y)}[\log p_\psi(y \mid x, z)] - \beta\, \mathrm{KL}(q_\xi(z \mid x, y), p_\phi(z \mid x)).$$

Equivalently, the minimized loss is

$$\mathcal{J}(x, y) = R(x, y) + \beta K(x, y) + \mathcal{L}_\mathrm{aux}(x, y),$$

where the first term is the reconstruction loss

$$R(x, y) = -\mathbb{E}_{z \sim q_\xi(\cdot \mid x, y)}[\log p_\psi(y \mid x, z)],$$

and the second term is the KL divergence tethering the posterior to the prior/router:

$$K(x, y) = \mathrm{KL}(q_\xi(z \mid x, y), p_\phi(z \mid x)).$$

In a standard VAE trained from scratch, $R(x, y)$ is initially large, so reconstruction gradients dominate early. In the current project this is not the regime: $p_\psi(y \mid x, z)$ is initialized from an already strong entangled model $p_\theta(y \mid x)$, so the initial y-region NLL is often already very small. The KL term therefore starts on a comparable or larger scale and can push $q_\xi$ toward $p_\phi$ before $z$ becomes informative.
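As an illustrative sketch (not the project's actual training code), assuming diagonal-Gaussian $q_\xi$ and $p_\phi$, the per-example loss $\mathcal{J}(x, y) = R + \beta K + \mathcal{L}_\mathrm{aux}$ decomposes as:

```python
import numpy as np

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """K(x, y): KL(q || p) between diagonal Gaussians, summed over latent dims."""
    return 0.5 * np.sum(
        logvar_p - logvar_q
        + (np.exp(logvar_q) + (mu_q - mu_p) ** 2) / np.exp(logvar_p)
        - 1.0
    )

def total_loss(recon_nll, mu_q, logvar_q, mu_p, logvar_p, beta=1.0, aux=0.0):
    """J(x, y) = R + beta * K + L_aux for one supervised pair.

    recon_nll is R(x, y), the y-region NLL under p_psi(y | x, z);
    (mu_q, logvar_q) parametrize q_xi and (mu_p, logvar_p) parametrize p_phi.
    """
    return recon_nll + beta * gaussian_kl(mu_q, logvar_q, mu_p, logvar_p) + aux
```

When the posterior sits exactly on the prior the KL term vanishes and only the reconstruction NLL (plus any auxiliary losses) remains, which is the collapse direction the note is worried about.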

Core issue

The failure mode is that the reconstruction loss is already small at initialization (the model can predict $y$ from $x$ without $z$), so the encoder-side gradient that would make $z$ useful is weak from step 0.

Normalized Reconstruction Objective

Let $c_\mathrm{base} > 0$ denote the frozen entangled model's end-of-training validation loss on the same y-region masking used by the disentangled trainer. This is a single scalar loaded from the entangled run artifacts.

Define the normalized reconstruction term

$$R_\mathrm{norm}(x, y) = \frac{R(x, y)}{c_\mathrm{base}}.$$

The optimized objective becomes

$$\mathcal{J}_\mathrm{norm}(x, y) = R_\mathrm{norm}(x, y) + \beta K(x, y) + \mathcal{L}_\mathrm{aux}(x, y) = \frac{1}{c_\mathrm{base}} R(x, y) + \beta K(x, y) + \mathcal{L}_\mathrm{aux}(x, y).$$

So normalization does not alter the shape of the reconstruction objective; it changes only its global weight. When the entangled baseline loss is tiny, the effective reconstruction coefficient $1 / c_\mathrm{base}$ becomes large, restoring a better reconstruction-to-KL balance.
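A minimal sketch of the rescaling, with `c_base` standing in for the scalar loaded from the entangled run artifacts (function and argument names here are hypothetical):

```python
def normalized_objective(recon_nll, kl, aux, c_base, beta):
    """J_norm(x, y) = R / c_base + beta * K + L_aux.

    c_base > 0 is the frozen entangled baseline's final validation loss;
    dividing by it only rescales R, leaving its shape unchanged.
    """
    if c_base <= 0:
        raise ValueError("c_base must be positive")
    return recon_nll / c_base + beta * kl + aux

# With a tiny baseline loss such as c_base = 0.05, the effective
# reconstruction coefficient is 1 / 0.05 = 20, so a reconstruction
# NLL of 0.04 contributes 0.8 to the objective rather than 0.04.
```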

Interpretation

The corresponding validation metric

$$\rho_\mathrm{recon} = R_\mathrm{val} / c_\mathrm{base}$$

has a direct interpretation:

  • $\rho_\mathrm{recon} = 1$: the disentangled model matches the entangled baseline.
  • $\rho_\mathrm{recon} < 1$: the latent-conditioned model improves on the baseline (implying $z$ is informative).
  • $\rho_\mathrm{recon} > 1$: the factorized model is worse than the baseline at predicting $y$ from $(x, z)$.

This is useful both as an optimization rescaling and as a comparison metric, because it reports reconstruction quality relative to the actual entangled initializer rather than on an absolute NLL scale that varies by task.
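The ratio and its three regimes can be sketched as follows (function names are illustrative, not from the training code; the tolerance band is an added convention for floating-point comparison):

```python
def recon_ratio(r_val, c_base):
    """rho_recon = R_val / c_base: reconstruction quality relative to
    the entangled initializer, rather than on an absolute NLL scale."""
    return r_val / c_base

def interpret_ratio(rho, tol=1e-3):
    """Map the normalized validation ratio onto the three regimes."""
    if rho < 1.0 - tol:
        return "beats entangled baseline: z is carrying information"
    if rho > 1.0 + tol:
        return "worse than entangled baseline at predicting y from (x, z)"
    return "matches entangled baseline"
```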

Relationship to Other Interventions

Token-Weighted Reconstruction

Token-weighted reconstruction changes the distribution of the reconstruction gradient across token positions by upweighting positions where the z-free model is uncertain. Baseline normalization changes only the global magnitude of the standard reconstruction term. These are complementary:

  • normalization says "make reconstruction matter again at all,"
  • token weighting says "within reconstruction, focus on strategy-sensitive tokens."

Both can be used simultaneously.
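A hypothetical sketch of combining the two, assuming per-position NLLs and weights are available; the mean-one weight renormalization is one possible convention (not prescribed by this note) for letting the weights reshape the gradient while $1 / c_\mathrm{base}$ alone owns the global scale:

```python
import numpy as np

def weighted_normalized_recon(token_nll, token_weights, c_base):
    """Token-weighted, baseline-normalized reconstruction loss.

    token_nll: per-position NLL over the y-region.
    token_weights: upweights positions where the z-free model is uncertain.
    """
    w = np.asarray(token_weights, dtype=float)
    w = w / w.mean()  # mean-one weights: reshape the distribution, not the scale
    return float(np.mean(w * np.asarray(token_nll, dtype=float)) / c_base)
```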

Inter-Latent JSD

Inter-latent JSD injects an exclusivity signal: different latent values should induce different token distributions. It does not by itself solve the weak-reconstruction problem. Normalization and JSD therefore address different bottlenecks:

  • normalization encourages $z$ to become useful,
  • JSD encourages the information in $z$ to be strategy-separating.

InfoVAE-Style Beta Reduction

Lowering $\beta$ weakens the KL pressure directly. Normalization strengthens reconstruction directly. Both rebalance the same competition, but in opposite ways:

  • beta reduction rescales $K(x, y)$ downward,
  • normalization rescales $R(x, y)$ upward.
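One way to make the relationship precise: multiplying $\mathcal{J}_\mathrm{norm}$ by the positive constant $c_\mathrm{base}$ (which leaves the minimizer unchanged) gives

```latex
c_\mathrm{base}\,\mathcal{J}_\mathrm{norm}(x, y)
  = R(x, y) + \beta\, c_\mathrm{base}\, K(x, y)
  + c_\mathrm{base}\, \mathcal{L}_\mathrm{aux}(x, y),
```

so, up to the rescaled auxiliary term, normalization is equivalent to running the unnormalized objective with an effective KL weight $\beta\, c_\mathrm{base}$, which is small exactly when the baseline loss is small.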

Normalization is especially natural in the entangled-initialized regime because its scale is tied to the observed baseline loss rather than chosen manually.

Diagnostics to Watch

The main diagnostics are:

  • raw reconstruction_loss: did the conditional generator remain strong?
  • baseline_normalized_reconstruction_loss: is the factorized model beating the entangled baseline?
  • kl_prior_loss: is the posterior carrying information instead of collapsing?
  • loss_contrib/fraction_abs/reconstruction_loss: did the optimization balance actually shift toward reconstruction?
  • downstream structure metrics such as alignment or probe accuracy: does the stronger reconstruction signal translate into more informative latents?

Expected signature

If normalization is working, the reconstruction contribution should grow, the normalized reconstruction ratio should fall below $1$, and the latent variables should become more informative without a large drop in task correctness.
