Baseline-Normalized Reconstruction Loss
This note describes a modification to the standard ELBO (reconstruction + KL) loss tailored to the problem of learning a disentangled factorization of an entangled-initialized base model.
Standard ELBO Objective
For one supervised pair $(x, y)$, the conditional ELBO is

$$\log p_\theta(y \mid x) \;\ge\; \mathbb{E}_{q_\phi(z \mid x, y)}\bigl[\log p_\theta(y \mid x, z)\bigr] - \mathrm{KL}\bigl(q_\phi(z \mid x, y) \,\|\, p_\psi(z \mid x)\bigr).$$

With a beta-weighted KL:

$$\mathcal{J}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x, y)}\bigl[\log p_\theta(y \mid x, z)\bigr] - \beta\,\mathrm{KL}\bigl(q_\phi(z \mid x, y) \,\|\, p_\psi(z \mid x)\bigr).$$

Equivalently, the minimized loss is

$$\mathcal{L} = \mathcal{L}_\mathrm{recon} + \beta\,\mathcal{L}_\mathrm{KL},$$

where the first term is the reconstruction loss

$$\mathcal{L}_\mathrm{recon} = -\,\mathbb{E}_{q_\phi(z \mid x, y)}\bigl[\log p_\theta(y \mid x, z)\bigr]$$

and the second term is the KL divergence tethering the posterior to the prior/router:

$$\mathcal{L}_\mathrm{KL} = \mathrm{KL}\bigl(q_\phi(z \mid x, y) \,\|\, p_\psi(z \mid x)\bigr).$$
In a standard VAE trained from scratch, $\mathcal{L}_\mathrm{recon}$ is initially large, so reconstruction gradients dominate early. In the current project this is not the regime: $p_\theta$ is initialized from an already strong entangled model, so the initial y-region NLL is often already very small. The KL term therefore starts on a comparable or larger scale and can push the posterior $q_\phi$ toward the prior before $z$ becomes informative.
The failure mode is that the reconstruction loss is already small at initialization (the model can predict $y$ from $x$ without $z$), so the encoder-side gradient that would make $z$ useful is weak from step 0.
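The imbalance can be made concrete with a toy calculation. All loss magnitudes below are invented for illustration and are not taken from any run:

```python
# Toy illustration of the scale imbalance between the two training regimes;
# every number here is made up.

def total_loss(recon: float, kl: float, beta: float = 1.0) -> float:
    """Standard beta-weighted loss: reconstruction + beta * KL."""
    return recon + beta * kl

# From-scratch regime: reconstruction starts large and dominates the gradient.
scratch_recon, scratch_kl = 4.0, 0.5
# Entangled-initialized regime: reconstruction starts tiny, KL is comparable or larger.
init_recon, init_kl = 0.05, 0.5

scratch_frac = scratch_recon / total_loss(scratch_recon, scratch_kl)
init_frac = init_recon / total_loss(init_recon, init_kl)
print(f"from scratch:   recon fraction of total loss = {scratch_frac:.2f}")  # 0.89
print(f"entangled init: recon fraction of total loss = {init_frac:.2f}")     # 0.09
```

With an entangled initializer the reconstruction term supplies under a tenth of the loss from step 0, which is exactly the weak-gradient regime described above.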
Normalized Reconstruction Objective
Let $\mathcal{L}_\mathrm{base}$ denote the frozen entangled model's end-of-training validation loss on the same y-region masking used by the disentangled trainer. This is a single scalar loaded from the entangled run artifacts.
Define the normalized reconstruction term

$$\mathcal{L}_\mathrm{recon}^\mathrm{norm} = \frac{\mathcal{L}_\mathrm{recon}}{\mathcal{L}_\mathrm{base}}.$$

The optimized objective becomes

$$\mathcal{L} = \frac{\mathcal{L}_\mathrm{recon}}{\mathcal{L}_\mathrm{base}} + \beta\,\mathcal{L}_\mathrm{KL}.$$
So normalization does not alter the shape of the reconstruction objective; it changes only its global weight. When the entangled baseline loss is tiny, the effective reconstruction coefficient becomes large, restoring a better reconstruction-to-KL balance.
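In code the change is a one-line rescaling. The function and variable names below are illustrative; `baseline_loss` stands in for the scalar loaded from the entangled run artifacts:

```python
def normalized_loss(recon_loss: float, kl_loss: float,
                    baseline_loss: float, beta: float) -> float:
    """Baseline-normalized objective: recon / baseline + beta * KL.

    baseline_loss is a constant, so dividing by it leaves the gradient
    direction of the reconstruction term unchanged and rescales only its
    global weight (effective coefficient 1 / baseline_loss).
    """
    return recon_loss / baseline_loss + beta * kl_loss

# With a tiny entangled baseline, reconstruction matters again:
# the unnormalized term would contribute 0.05; normalized, it contributes 1.0.
print(normalized_loss(recon_loss=0.05, kl_loss=0.5, baseline_loss=0.05, beta=1.0))  # 1.5
```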
Interpretation
The corresponding validation metric

$$\rho_\mathrm{recon} = \frac{\mathcal{L}_\mathrm{recon}}{\mathcal{L}_\mathrm{base}}$$

has a direct interpretation:

- $\rho_\mathrm{recon} = 1$: the disentangled model matches the entangled baseline.
- $\rho_\mathrm{recon} < 1$: the latent-conditioned model improves on the baseline (implying z informativeness).
- $\rho_\mathrm{recon} > 1$: the factorized model is worse than the baseline at predicting y from $(x, z)$.
This is useful both as an optimization rescaling and as a comparison metric, because it reports reconstruction quality relative to the actual entangled initializer rather than on an absolute NLL scale that varies by task.
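A minimal sketch of the metric and its reading; the helper names are hypothetical and not part of any trainer:

```python
def rho_recon(recon_loss: float, baseline_loss: float) -> float:
    """Disentangled reconstruction loss over the frozen entangled baseline,
    both measured on the same y-region masking."""
    return recon_loss / baseline_loss

def interpret(rho: float) -> str:
    # Hypothetical helper mapping the ratio to the three cases above.
    if rho < 1.0:
        return "improves on baseline (z is informative)"
    if rho > 1.0:
        return "worse than baseline at predicting y from (x, z)"
    return "matches baseline"

print(interpret(rho_recon(0.04, 0.05)))  # improves on baseline (z is informative)
```

Because the ratio is relative to this run's own initializer, the same thresholds apply regardless of the task's absolute NLL scale.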
Relationship to Other Interventions
Token-Weighted Reconstruction
Token-weighted reconstruction changes the distribution of the reconstruction gradient across token positions by upweighting positions where the z-free model is uncertain. Baseline normalization changes only the global magnitude of the standard reconstruction term. These are complementary:
- normalization says "make reconstruction matter again at all,"
- token weighting says "within reconstruction, focus on strategy-sensitive tokens."
Both can be used simultaneously.
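Using both at once can be sketched as below, assuming per-token NLLs and per-position weights (e.g. derived from the z-free model's uncertainty) are already computed; the names and numbers are illustrative:

```python
def combined_recon_loss(token_nlls, token_weights, baseline_loss):
    """Token-weighted reconstruction, then globally rescaled by the baseline.

    - token_weights shifts gradient mass toward strategy-sensitive positions,
    - dividing by baseline_loss restores the term's global magnitude.
    """
    weighted_mean = (sum(w * nll for w, nll in zip(token_weights, token_nlls))
                     / sum(token_weights))
    return weighted_mean / baseline_loss

nlls = [0.01, 0.02, 1.5]    # two easy positions, one strategy-sensitive one
weights = [0.1, 0.1, 1.0]   # upweight the position the z-free model is unsure about
print(combined_recon_loss(nlls, weights, baseline_loss=0.2))
```

With uniform weights the function reduces to the plain baseline-normalized mean NLL, so the two interventions compose without interfering.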
Inter-Latent JSD
Inter-latent JSD injects an exclusivity signal: different latent values should induce different token distributions. It does not by itself solve the weak-reconstruction problem. Normalization and JSD therefore address different bottlenecks:
- normalization encourages $z$ to become useful,
- JSD encourages the information in $z$ to be strategy-separating.
InfoVAE-Style Beta Reduction
Lowering $\beta$ weakens the KL pressure directly. Normalization strengthens reconstruction directly. Both rebalance the same competition, but in opposite ways:
- beta reduction rescales the KL term downward,
- normalization rescales the reconstruction term upward.
Normalization is especially natural in the entangled-initialized regime because its scale is tied to the observed baseline loss rather than chosen manually.
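The sense in which both knobs turn the same dial can be made explicit with a one-line rearrangement (writing $\mathcal{L}_\mathrm{recon}$ for the reconstruction term, $\mathcal{L}_\mathrm{KL}$ for the KL term, and $\mathcal{L}_\mathrm{base}$ for the frozen baseline scalar):

```latex
% L_base is a positive constant, so factoring it out rescales the whole loss
% without changing its minimizer:
\[
\frac{\mathcal{L}_\mathrm{recon}}{\mathcal{L}_\mathrm{base}} + \beta\,\mathcal{L}_\mathrm{KL}
\;=\;
\frac{1}{\mathcal{L}_\mathrm{base}}
\left( \mathcal{L}_\mathrm{recon}
       + \beta\,\mathcal{L}_\mathrm{base}\,\mathcal{L}_\mathrm{KL} \right),
\]
% i.e. normalization behaves like an effective beta of beta * L_base.
```

So up to a global rescaling, normalizing by a small baseline is equivalent to an aggressive beta reduction, except that the effective beta is set by the observed baseline loss rather than tuned by hand.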
Diagnostics to Watch
The main diagnostics are:

- raw `reconstruction_loss`: did the conditional generator remain strong?
- `baseline_normalized_reconstruction_loss`: is the factorized model beating the entangled baseline?
- `kl_prior_loss`: is the posterior carrying information instead of collapsing?
- `loss_contrib/fraction_abs/reconstruction_loss`: did the optimization balance actually shift toward reconstruction?
- downstream structure metrics such as alignment or probe accuracy: does the stronger reconstruction signal translate into more informative latents?
If normalization is working, the reconstruction contribution should grow, the normalized reconstruction ratio should fall below $1$, and the latent variables should become more informative without a large drop in task correctness.