Token-Weighted Excess Reconstruction Loss for Discrete CVAE Disentanglement
This note identifies a gradient-scaling failure mode in the standard conditional ELBO when applied to teacher-forced autoregressive models, and derives a token-weighted reconstruction loss that corrects it.
The setting is the shared-parameter discrete CVAE described in the companion note CVAE Post-Training Methodology for Latent Strategy Disentanglement.
Motivation: Teacher-Forcing Dilution
Setup and notation
Consider a discrete CVAE with $K$ latent values $z \in \{1, \dots, K\}$ trained under teacher forcing. For one example $(x, y)$ with solution tokens $y_1, \dots, y_T$, define:
- per-token z-conditioned NLL: \ell_{t} (z) = -\log p_\phi (y_{t} | x, z, y_{< t}),
- per-token z-free baseline NLL: b_{t} = -\log p_\phi (y_{t} | x, y_{< t}).
Both use the same backbone $p_\phi$; the baseline simply omits the z-token section from the input.
The current reconstruction loss averages uniformly over the supervised token set $\mathcal{S}$ (here $|\mathcal{S}| = T$):

R(z) = \frac{1}{T} \sum_{t \in \mathcal{S}} \ell_{t}(z).
Two-regime structure under teacher forcing
Under teacher forcing, the solution sequence has two regimes:
- Strategy-sensitive prefix ($t \le t_0$). The teacher-forced history $y_{<t}$ does not yet determine the strategy. The base model has non-trivial uncertainty: $b_t > 0$. The z-conditioned model can reduce this uncertainty when $z$ matches the strategy used in $y$, so $\ell_t(z)$ varies across latent values.
- Strategy-determined suffix ($t > t_0$). The partial solution $y_{<t}$ reveals the strategy. Continuation is nearly deterministic: $b_t \approx 0$ and $\ell_t(z) \approx 0$ for all $z$, because the teacher-forced prefix renders $z$ redundant.
The boundary $t_0$ is not a sharp cutoff but a useful abstraction. For synthetic-sequence tasks, it corresponds roughly to the first few solution tokens where the procedural strategy is manifested.
The dilution
Under the two-regime model, the token-averaged reconstruction becomes:

R(z) = \frac{1}{T} \left[ \sum_{t \le t_0} \ell_{t}(z) + \sum_{t > t_0} \ell_{t}(z) \right] \approx \frac{1}{T} \sum_{t \le t_0} \ell_{t}(z).
The variation of $R(z)$ across latent values is then

\max_z R(z) - \min_z R(z) \approx \frac{t_0}{T}\,\Delta,

where $\Delta$ is the mean per-informative-token difference. The factor $t_0 / T$ is the dilution: the easy tokens contribute zero to the numerator but inflate the denominator.
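A toy calculation makes the dilution concrete. The numbers below are hypothetical (they are not from the companion note): a 4-token informative prefix in a 100-token solution, with a 1-nat per-token gap between matching and mismatched latents.

```python
# Toy illustration of teacher-forcing dilution (hypothetical numbers).
T, t0, delta = 100, 4, 1.0   # total tokens, informative prefix length, per-token gap

# The matching latent lowers each informative-token NLL by delta; every latent
# agrees on the strategy-determined suffix, where NLL ~ 0 under teacher forcing.
ell_match    = [2.0 - delta] * t0 + [0.0] * (T - t0)   # z matches the strategy in y
ell_mismatch = [2.0] * t0 + [0.0] * (T - t0)           # z does not match

R_match = sum(ell_match) / T
R_mismatch = sum(ell_mismatch) / T
variation = R_mismatch - R_match   # ~ (t0 / T) * delta = 0.04: the gap is diluted 25x
```

A per-token gap of a full nat shows up as only a 0.04-nat difference in the token-averaged loss.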
Gradient Analysis
Posterior parameters
The reconstruction gradient on the parameters $\theta$ of the posterior $q_\theta(z \mid x, y)$ is:

\nabla_\theta \, \mathbb{E}_{z \sim q_\theta}\!\left[ R(z) \right] = \sum_z R(z)\, \nabla_\theta \, q_\theta(z \mid x, y).

Because $\sum_z \nabla_\theta \, q_\theta(z \mid x, y) = 0$ (the probabilities sum to one), any component of $R(z)$ that is constant across $z$ cancels. The effective gradient magnitude is proportional to the variation:

\left\| \nabla_\theta \, \mathbb{E}_{z \sim q_\theta}[R(z)] \right\| \propto \max_z R(z) - \min_z R(z) \approx \frac{t_0}{T}\,\Delta.
The KL gradient on $\theta$ is $O(1)$, independent of $t_0 / T$. The reconstruction signal dominates when:

\frac{t_0}{T}\,\Delta \gtrsim c

for a constant $c$ set by the KL curvature. When $t_0 / T \ll 1$ and $\Delta = O(1)$, the KL wins and the posterior collapses to the prior.
Subtracting a stop-gradiented scalar $s$ from $R(z)$ gives $R(z) - \operatorname{sg}(s)$. Since $\operatorname{sg}(s)$ is constant across $z$, it cancels in $\nabla_\theta \sum_z q_\theta(z \mid x, y)\,\big[R(z) - \operatorname{sg}(s)\big]$ by the sum-to-one constraint. The gradient is identical to that of the unmodified loss. More generally, any additive term that does not vary with $z$ is a no-op for the posterior gradient.
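A finite-difference check illustrates this no-op. The sketch below (pure Python, hypothetical logits and losses) compares the gradient of the expected reconstruction loss with respect to softmax posterior logits, with and without a scalar baseline:

```python
import math

def softmax(logits):
    # numerically stable softmax over a list of logits
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def expected_loss(logits, R, c=0.0):
    # E_{z ~ q}[R(z) - c] with q = softmax(logits); c plays the role of sg(s)
    q = softmax(logits)
    return sum(qz * (Rz - c) for qz, Rz in zip(q, R))

def grad_fd(logits, R, c=0.0, h=1e-6):
    # central finite differences w.r.t. each posterior logit
    g = []
    for k in range(len(logits)):
        lp, lm = logits[:], logits[:]
        lp[k] += h
        lm[k] -= h
        g.append((expected_loss(lp, R, c) - expected_loss(lm, R, c)) / (2 * h))
    return g

logits = [0.3, -0.1, 0.5, 0.0]   # hypothetical posterior logits over K = 4 latents
R = [1.2, 0.9, 1.5, 1.1]         # hypothetical per-latent reconstruction losses

g_plain = grad_fd(logits, R)           # gradient of E_q[R(z)]
g_shift = grad_fd(logits, R, c=0.7)    # gradient after subtracting a scalar baseline
assert max(abs(a - b) for a, b in zip(g_plain, g_shift)) < 1e-6
```

Any constant shift of the per-latent losses leaves the logit gradient untouched, to finite-difference precision.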
z-dependent parameters
The z-token embeddings $e_z$ receive gradient only from reconstruction (the KL does not involve $e_z$):

\nabla_{e_z} R(z) = \frac{1}{T} \sum_{t=1}^{T} \nabla_{e_z} \ell_{t}(z).

For early tokens ($t \le t_0$), $\nabla_{e_z} \ell_{t}(z)$ is nonzero because the z-token influences predictions before the teacher-forced prefix dominates. For late tokens, it is near zero. The effective gradient magnitude is therefore suppressed by the same $t_0 / T$ factor.
The dilution also degrades the signal-to-noise ratio: the late tokens contribute small residual gradients in random directions, mixing with the meaningful gradient contributions from early tokens.
Token-Weighted Reconstruction Loss
Definition
Define per-token weights $w_t$ as the stop-gradiented z-free baseline NLL:
w_{t} = \operatorname{sg}(b_{t}) = \operatorname{sg}(-\log p_\phi (y_{t} | x, y_{< t})).
Replace the uniform token average with a weighted average:

R^{w}(z) = \frac{\sum_{t \in \mathcal{S}} w_{t}\, \ell_{t}(z)}{\sum_{t \in \mathcal{S}} w_{t} + \varepsilon},

where $\varepsilon > 0$ prevents division by zero.
The full weighted reconstruction loss under the posterior:

\mathcal{L}^{w}_{\mathrm{rec}} = \mathbb{E}_{z \sim q_\theta(z \mid x, y)}\!\left[ R^{w}(z) \right] = \sum_z q_\theta(z \mid x, y)\, R^{w}(z).

The total objective:

\mathcal{L} = \mathcal{L}^{w}_{\mathrm{rec}} + \mathrm{KL}\!\left( q_\theta(z \mid x, y) \,\middle\|\, p(z \mid x) \right).
Intuition
Tokens where the base model is uncertain ($b_t$ large) are precisely the positions where z-conditioning has room to reduce NLL: the strategy has not yet been revealed by the teacher-forced prefix. Tokens where the base model is already confident ($b_t \approx 0$) are positions where $z$ is redundant. Weighting by $w_t$ focuses the reconstruction objective on the z-informative prefix without requiring task-specific knowledge of $t_0$.
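A toy computation with hypothetical per-token values shows the weighted average recovering the undiluted per-token gap:

```python
# Hypothetical per-token values: baseline NLL b_t is large on the strategy-
# sensitive prefix and near zero on the strategy-determined suffix.
T, t0, eps = 100, 4, 1e-8
b = [2.0] * t0 + [1e-4] * (T - t0)             # z-free baseline NLLs -> weights
ell_match    = [1.0] * t0 + [1e-4] * (T - t0)  # matching latent
ell_mismatch = [2.0] * t0 + [1e-4] * (T - t0)  # mismatched latent

def weighted_avg(ell, w):
    # R^w(z): baseline-NLL-weighted token average with eps in the denominator
    return sum(wi * li for wi, li in zip(w, ell)) / (sum(w) + eps)

variation_uniform  = sum(ell_mismatch) / T - sum(ell_match) / T   # ~0.04 (diluted)
variation_weighted = weighted_avg(ell_mismatch, b) - weighted_avg(ell_match, b)  # ~1.0
```

The uniform average sees a 0.04-nat difference between latents; the weighted average sees nearly the full 1-nat per-token gap.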
Gradient improvement
Since $w_t \approx 0$ for $t > t_0$, the weighted reconstruction concentrates on the strategy-sensitive prefix:

R^{w}(z) \approx \frac{\sum_{t \le t_0} w_{t}\, \ell_{t}(z)}{\sum_{t \le t_0} w_{t}}.
The variation across latent values is

\max_z R^{w}(z) - \min_z R^{w}(z) \approx \Delta,

with no dilution factor. The posterior gradient is amplified by $T / t_0$ relative to the uniform-average loss. The z-parameter gradient is similarly amplified: the denominator is $\sum_{t \le t_0} w_t$ instead of $T$, and late-token noise gradients are suppressed.
The reconstruction gradient on $\theta$ now dominates the KL when:

\Delta \gtrsim c,

which is $T / t_0$ times easier to satisfy than the uniform-average condition.
Comparison With Scalar Baselines
It is important to distinguish token-level weighting from scalar baseline subtraction. Several candidate formulations fail to change the optimization:
Scalar stop-gradient baseline (no-op)
Subtracting a single stop-gradiented scalar from the loss, $R(z) - \operatorname{sg}(s)$, leaves all gradients identical to the standard ELBO because the subtracted term is constant with respect to every parameter.
Per-token stop-gradient baseline, uniform average (no-op for $\theta$ and $e_z$)

R(z) = \frac{1}{T} \sum_{t=1}^{T} \big( \ell_{t}(z) - \operatorname{sg}(b_{t}) \big).

For $\theta$: the $\operatorname{sg}(b_t)$ terms are constant across $z$ and factor out by $\sum_z \nabla_\theta \, q_\theta(z \mid x, y) = 0$. Gradient unchanged.

For $e_z$: the $b_t$ terms do not depend on $e_z$. Gradient unchanged.

For the backbone $\phi$ (without stop-gradient on $b_t$): the $\ell_t$ and $b_t$ gradients partially cancel, which removes reconstruction pressure on $\phi$ without creating z-usage pressure. This is harmful: backbone quality degrades with no compensating benefit.
Token weighting (this proposal)
The critical difference is that weighting rescales each token's contribution multiplicatively (and changes the normalizer of the average) rather than shifting the loss additively. This alters the variation of $R(z)$ across latent values, which is the quantity that drives the posterior gradient. Subtracting a baseline, whether scalar or per-token, does not change this variation.
Conditions For Effectiveness
When it works
- Nonzero initial z-sensitivity. At initialization, z-token embeddings are random, producing small random differences in $\ell_t(z)$ across $z$ at early positions. Under the standard loss, this signal is diluted by $t_0 / T$ and may be too weak to overcome the KL pressure. Under the weighted loss, the signal is preserved at its undiluted magnitude, potentially sufficient for the optimizer to bootstrap z-usage.
- Non-trivial base-model uncertainty at early positions. The synthetic-sequence tasks have genuine strategy ambiguity: different strategies produce different early solution tokens for the same input $x$. The base model (initialized from the entangled checkpoint) should have $b_t > 0$ at these positions.
- Self-adapting behavior. Early in training, the base model may be uncertain at many positions, making $w_t$ roughly uniform. As the backbone improves and late tokens become easy, the weights concentrate on the strategy-sensitive prefix. The reweighting strengthens exactly when dilution becomes the binding constraint.
When it may not suffice
- Complete z-insensitivity. If the decoder produces identical $\ell_t(z)$ for all $z$ at every position (including early tokens), the weighted variation is still zero. Reweighting amplifies existing z-variation but cannot create it from nothing. This limitation can be addressed by pairing the weighted loss with an architectural change that ensures nonzero z-sensitivity (for example, z-conditioned layer normalization or a multi-token z-prefix).
- Degenerate weights. If $w_t \approx 0$ for all $t$ (including early tokens), the denominator collapses to $\varepsilon$ and the loss degenerates. This would indicate that the task has no token-level strategy ambiguity, in which case disentanglement via reconstruction alone may be impossible.
- Single-token z-conditioning. The z-token is a single token in the input sequence. Even with correct gradient scaling, its influence on early-position predictions is mediated by attention and may remain small. Reweighting is necessary for correct gradient balance but may not be sufficient to overcome an architectural bottleneck.
Implementation
Computing the baseline
The z-free baseline requires one additional forward pass per batch through the z-free input template:
<bos> <x> {x} </x> <y> {y} </y> <eos>
This is the original entangled-model format. The per-token NLL is computed under teacher forcing at supervised positions and detached before use as weights.
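The per-position NLL is a log-softmax lookup at the target token. A minimal pure-Python sketch (illustrative only, not the repository's implementation, which operates on model logits from a teacher-forced forward pass):

```python
import math

def per_token_nll(logits, targets):
    """b_t = -log p(y_t | x, y_<t) from next-token logits, one row per
    supervised position. Pure-Python sketch; in training, the result would
    be detached (stop-gradiented) before use as weights.
    """
    nlls = []
    for row, y in zip(logits, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(l - m) for l in row))  # log partition
        nlls.append(log_z - row[y])                              # -log softmax(row)[y]
    return nlls
```

For a position where the base model is split evenly between two continuations, this yields $b_t = \log 2$; for a near-deterministic position, $b_t \approx 0$.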
Two implementation policies are possible:
- Frozen policy. Compute $b_t$ from a frozen reference model initialized from the entangled checkpoint. This keeps the weighting signal stable while the disentangled model is trained.
- Dynamic policy. Compute $b_t$ from the current training backbone at each step (still stop-gradiented before use), so the weights evolve with the model.
Current implementation status in this repository: the frozen policy is implemented; the dynamic policy is intentionally deferred.
Modifying masked_token_nll_per_example
The current implementation computes:
denom = active.sum(dim=1).clamp_min(1.0)
return (token_loss * active).sum(dim=1) / denom
The weighted variant replaces the uniform active mask with active * weights, where weights is the detached per-token baseline NLL:
weighted_active = active * weights.detach()
denom = weighted_active.sum(dim=1).clamp_min(eps)
return (token_loss * weighted_active).sum(dim=1) / denom
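For reference, a pure-Python analogue of the weighted variant (illustrative names; the repository code operates on torch tensors):

```python
def weighted_masked_nll_per_example(token_loss, active, weights, eps=1e-8):
    """Per-example weighted average of token_loss over active positions.

    token_loss, active, weights: batch x time nested lists; `weights` is
    assumed to be the already-detached per-token baseline NLL.
    """
    out = []
    for losses, mask, w in zip(token_loss, active, weights):
        weighted_active = [m * wi for m, wi in zip(mask, w)]     # active * weights
        denom = max(sum(weighted_active), eps)                   # clamp_min(eps)
        out.append(sum(l * a for l, a in zip(losses, weighted_active)) / denom)
    return out
```

With all weights equal, this reduces to the uniform masked mean of the original implementation.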
Cost
One additional forward pass per batch (the z-free baseline). For the exact discrete CVAE with $K$ latent values, the training loop already performs $K$ forward passes per batch (one per latent value; the exact posterior is computed from these). The baseline adds one more, increasing cost by roughly $1/K$.
Fallback behavior
When the denominator is below a threshold (base model is confident everywhere), the loss should fall back to the standard uniform-average reconstruction to avoid numerical instability. A simple implementation:
R^{\mathrm{eff}}(z) =
\begin{cases}
R^{w}(z) & \text{if } \sum_{t \in \mathcal{S}} w_{t} > \tau, \\
R(z) & \text{otherwise},
\end{cases}

where $\tau$ is a small positive threshold.
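The selection logic can be sketched in a few lines (names are illustrative, not from the repository):

```python
def effective_recon(R_w, R_uniform, weight_sum, tau):
    # Fall back to the uniform-average reconstruction loss when the weights
    # carry almost no mass (base model confident everywhere).
    return R_w if weight_sum > tau else R_uniform
```

In a batched implementation the same switch would be applied per example, since weight mass can vary across examples within a batch.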