Token-Weighted Excess Reconstruction Loss for Discrete CVAE Disentanglement

This note identifies a gradient-scaling failure mode in the standard conditional ELBO when applied to teacher-forced autoregressive models, and derives a token-weighted reconstruction loss that corrects it.

The setting is the shared-parameter discrete CVAE described in the companion note *CVAE Post-Training Methodology for Latent Strategy Disentanglement*.

Motivation: Teacher-Forcing Dilution

Setup and notation

Consider a discrete CVAE with $K$ latent values trained under teacher forcing. For one example $(x, y)$ with $y = (y_1, \dots, y_T)$, define:

  • per-token z-conditioned NLL: $\ell_t(z) = -\log p_\phi(y_t \mid x, z, y_{<t})$,
  • per-token z-free baseline NLL: $b_t = -\log p_\phi(y_t \mid x, y_{<t})$.

Both use the same backbone $\phi$; the baseline simply omits the `<z> z_k </z>` section from the input.

The current reconstruction loss averages uniformly over all supervised tokens $\mathcal{S}$ with $|\mathcal{S}| = T$:

$$R(z) = \frac{1}{T} \sum_{t \in \mathcal{S}} \ell_t(z).$$

Two-regime structure under teacher forcing

Under teacher forcing, the solution sequence has two regimes:

  • Strategy-sensitive prefix ($t \le t_0$). The teacher-forced history $y_{<t}$ does not yet determine the strategy. The base model has non-trivial uncertainty: $b_t > 0$. The z-conditioned model can reduce this uncertainty when $z$ matches the strategy used in $y$, so $\ell_t(z)$ varies across latent values.

  • Strategy-determined suffix ($t > t_0$). The partial solution $y_{<t}$ reveals the strategy. Continuation is nearly deterministic: $b_t \approx 0$ and $\ell_t(z) \approx 0$ for all $z$, because the teacher-forced prefix renders $z$ redundant.

The boundary $t_0$ is not a sharp cutoff but a useful abstraction. For synthetic-sequence tasks, it corresponds roughly to the first few solution tokens where the procedural strategy is manifested.

The dilution

Under the two-regime model, the token-averaged reconstruction becomes:

$$R(z) = \frac{1}{T} \sum_{t=1}^{t_0} \ell_t(z) + \underbrace{\frac{1}{T} \sum_{t=t_0+1}^{T} \ell_t(z)}_{\approx\, 0}.$$

The variation of $R(z)$ across latent values:

$$R(z) - R(z') = \frac{1}{T} \sum_{t=1}^{t_0} \left[\ell_t(z) - \ell_t(z')\right] = \frac{t_0}{T}\,\delta_{zz'},$$

where $\delta_{zz'} = \frac{1}{t_0} \sum_{t \le t_0} [\ell_t(z) - \ell_t(z')]$ is the mean per-informative-token difference. The factor $t_0/T$ is the dilution: the $T - t_0$ easy tokens contribute zero to the numerator but inflate the denominator.
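The dilution factor can be checked numerically. A minimal sketch with synthetic per-token losses (all values below are illustrative, chosen only to realize the two-regime structure):

```python
import numpy as np

T, t0 = 200, 5                     # sequence length, strategy-sensitive prefix

# Synthetic per-token NLLs for two latent values: they differ only on the
# first t0 tokens; the suffix is near-deterministic (~0) under both.
ell_z = np.zeros(T)
ell_zp = np.zeros(T)
ell_z[:t0] = 1.0                   # matching z: residual uncertainty
ell_zp[:t0] = 2.0                  # mismatched z': higher NLL on the prefix

delta = (ell_z[:t0] - ell_zp[:t0]).mean()   # mean per-informative-token gap
gap = ell_z.mean() - ell_zp.mean()          # variation of the uniform average R(z)

print(gap, (t0 / T) * delta)       # identical: the gap is diluted by t0/T
```

With $t_0 = 5$ and $T = 200$, a per-token gap of $1$ nat shrinks to a loss gap of $0.025$, which is the signal the posterior gradient has to work with.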

Gradient Analysis

Posterior parameters $\xi$

The reconstruction gradient on the posterior parameters $\xi$ of $q_\xi(z \mid x, y)$ is:

$$\frac{\partial \mathcal{L}_\mathrm{recon}}{\partial \xi} = \sum_{z} \frac{\partial q_\xi}{\partial \xi} \cdot R(z).$$

Because $\sum_{z} \partial q_\xi / \partial \xi = 0$ (the probabilities sum to one), any component of $R(z)$ that is constant across $z$ cancels. The effective gradient magnitude is proportional to the variation:

$$\left\lVert \frac{\partial \mathcal{L}_\mathrm{recon}}{\partial \xi} \right\rVert \sim \max_{z \neq z'} |R(z) - R(z')| = \frac{t_0}{T}\, |\delta|.$$

The KL gradient on $\xi$ is $O(\beta)$, independent of $T$. The reconstruction signal dominates when:

$$\frac{t_0}{T}\, |\delta| > \beta \cdot C$$

for a constant $C$ set by the KL curvature. When $t_0 / T \ll 1$ and $\beta = O(1)$, the KL wins and the posterior collapses to the prior.

Why a scalar baseline does not help

Subtracting a stop-gradiented scalar $C = \operatorname{sg}\!\left(\frac{1}{T} \sum_{t} b_t\right)$ from $R(z)$ gives $R(z) - C$. Since $C$ is constant across $z$, it cancels in $\sum_{z} (\partial q / \partial \xi) \cdot [R(z) - C]$ by the sum-to-one constraint. The gradient is identical to the unmodified loss. More generally, any additive term that does not vary with $z$ is a no-op for the posterior gradient.
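The cancellation is easy to verify with the analytic softmax gradient. A small sketch (the logits, losses, and baseline value are all made up for illustration):

```python
import numpy as np

def posterior_grad(logits, R):
    """Gradient of sum_z q(z) * R(z) w.r.t. softmax posterior logits.
    Uses d/d xi_j E_q[R] = q_j * (R_j - E_q[R])."""
    q = np.exp(logits - logits.max())
    q /= q.sum()
    return q * (R - q @ R)

logits = np.array([0.3, -0.1, 0.5, 0.0])
R = np.array([1.2, 0.9, 1.1, 1.4])

g = posterior_grad(logits, R)
g_shifted = posterior_grad(logits, R - 0.7)   # subtract a z-constant "baseline"
print(np.allclose(g, g_shifted))              # True: the constant cancels
```

The shift by $0.7$ subtracts from both $R_j$ and $\mathbb{E}_q[R]$, so the gradient is bitwise unchanged.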

z-dependent parameters $\psi$

The z-token embeddings $\psi$ receive gradient only from reconstruction (the KL does not involve $\psi$):

$$\frac{\partial \mathcal{L}_\mathrm{recon}}{\partial \psi} = \sum_{z} q(z) \cdot \frac{1}{T} \sum_{t} \frac{\partial \ell_t(z)}{\partial \psi}.$$

For early tokens, $\partial \ell_t / \partial \psi$ is nonzero because the z-token influences predictions before the teacher-forced prefix dominates. For late tokens, it is near zero. The effective gradient magnitude is:

$$\left\lVert \frac{\partial \mathcal{L}_\mathrm{recon}}{\partial \psi} \right\rVert \sim \frac{t_0}{T} \cdot \lVert \nabla_{\psi} \rVert_{\mathrm{early}}.$$

The dilution also degrades the signal-to-noise ratio: the $T - t_0$ late tokens contribute small residual gradients in random directions, mixing with the $t_0$ meaningful gradient contributions from early tokens.

Token-Weighted Reconstruction Loss

Definition

Define per-token weights as the stop-gradiented z-free baseline NLL:

$$w_t = \operatorname{sg}(b_t) = \operatorname{sg}\!\left(-\log p_\phi(y_t \mid x, y_{<t})\right).$$

Replace the uniform token average with a weighted average:

$$R^w(z) = \frac{\sum_{t \in \mathcal{S}} w_t \cdot \ell_t(z)}{\sum_{t \in \mathcal{S}} w_t + \epsilon},$$

where $\epsilon > 0$ prevents division by zero.

The full weighted reconstruction loss under the posterior:

$$\mathcal{L}_\mathrm{recon}^w = \sum_{z=1}^{K} q_\xi(z \mid x, y) \cdot R^w(z).$$

The total objective:

$$\mathcal{J} = \mathcal{L}_\mathrm{recon}^w + \beta \cdot \mathrm{KL}\!\left(q_\xi(z \mid x, y),\, p_\phi(z \mid x)\right) + \text{auxiliary terms}.$$
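As a sanity check, the weighted objective can be assembled end to end on synthetic tensors. Everything below (shapes, the uniform posterior and prior, $\beta$, and the random losses) is illustrative, not the repository's training code:

```python
import numpy as np

rng = np.random.default_rng(1)
K, T = 4, 12

ell = rng.uniform(0.0, 2.0, size=(K, T))   # per-token NLL under each latent value
b = rng.uniform(0.0, 1.5, size=T)          # z-free baseline NLL (already detached)
eps = 1e-8

w = b                                      # w_t = sg(b_t)
Rw = (ell * w).sum(axis=1) / (w.sum() + eps)   # R^w(z), one value per z

q = np.full(K, 1.0 / K)                    # toy posterior q_xi(z | x, y)
p = np.full(K, 1.0 / K)                    # toy prior p_phi(z | x)
beta = 0.1

recon = q @ Rw                             # E_q[R^w(z)]
kl = (q * np.log(q / p)).sum()             # zero here: q equals the prior
J = recon + beta * kl
print(J)
```

With a uniform posterior the KL term vanishes and $\mathcal{J}$ reduces to the expected weighted reconstruction, which is the collapsed regime the gradient analysis above is trying to escape.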

Intuition

Tokens where the base model is uncertain ($b_t$ large) are precisely the positions where z-conditioning has room to reduce NLL: the strategy has not yet been revealed by the teacher-forced prefix. Tokens where the base model is already confident ($b_t \approx 0$) are positions where $z$ is redundant. Weighting by $b_t$ focuses the reconstruction objective on the z-informative prefix without requiring task-specific knowledge of $t_0$.

Gradient improvement

Since $w_t \approx 0$ for $t > t_0$, the weighted reconstruction concentrates on the strategy-sensitive prefix:

$$R^w(z) \approx \frac{\sum_{t=1}^{t_0} b_t \cdot \ell_t(z)}{\sum_{t=1}^{t_0} b_t}.$$

The variation across latent values:

$$R^w(z) - R^w(z') = \frac{\sum_{t=1}^{t_0} b_t \cdot [\ell_t(z) - \ell_t(z')]}{\sum_{t=1}^{t_0} b_t} = O(|\delta|),$$

with no $t_0/T$ dilution factor. The posterior gradient is amplified by $T/t_0$ relative to the uniform-average loss. The z-parameter gradient is similarly amplified: the denominator is $\sum_{t \le t_0} b_t$ instead of $T$, and late-token noise gradients are suppressed.
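The $T/t_0$ amplification can be observed directly in the same two-regime toy setup used for the dilution (synthetic values, for illustration only):

```python
import numpy as np

T, t0 = 200, 5
ell_z, ell_zp = np.zeros(T), np.zeros(T)
ell_z[:t0], ell_zp[:t0] = 1.0, 2.0         # NLL gap only on the informative prefix

b = np.zeros(T)
b[:t0] = 1.0                               # base model uncertain only early on
w = b                                      # w_t = sg(b_t)

uniform_gap = abs(ell_z.mean() - ell_zp.mean())
weighted_gap = abs((w @ ell_z) / w.sum() - (w @ ell_zp) / w.sum())

print(weighted_gap / uniform_gap)          # T / t0 = 40 in this toy setup
```

The uniform gap is $t_0/T \cdot |\delta| = 0.025$ while the weighted gap recovers the full $|\delta| = 1$, an amplification of exactly $T/t_0$ when the suffix losses and weights are exactly zero.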

The reconstruction gradient on $\xi$ now dominates the KL when:

$$|\delta| > \beta \cdot C,$$

which is $T/t_0$ times easier to satisfy than the uniform-average condition.

Comparison With Scalar Baselines

It is important to distinguish token-level weighting from scalar baseline subtraction. Several candidate formulations fail to change the optimization:

Scalar stop-gradient baseline (no-op)

$$\mathcal{L}^{\mathrm{exc,\ scalar}} = \sum_{z} q(z) \cdot \left[R(z) - \operatorname{sg}(\bar{b})\right] = \mathcal{L}_\mathrm{recon} - \mathrm{const}.$$

All gradients are identical to the standard ELBO because the subtracted term is constant with respect to all parameters.

Per-token stop-gradient baseline, uniform average (no-op for $\xi$ and $\psi$)

$$\mathcal{L}^{\mathrm{exc,\ per\text{-}token}} = \sum_{z} q(z) \cdot \frac{1}{T} \sum_{t} \left[\ell_t(z) - \operatorname{sg}(b_t)\right].$$

For $\xi$: the $\operatorname{sg}(b_t)$ terms factor out by $\sum_{z} \partial q / \partial \xi = 0$. Gradient unchanged.

For $\psi$: the $\operatorname{sg}(b_t)$ terms do not depend on $\psi$. Gradient unchanged.

For the backbone $\phi$ (without stop-gradient on $b_t$): the gradients partially cancel, which removes reconstruction pressure on $\phi$ without creating z-usage pressure. This is harmful: the backbone quality degrades with no compensating benefit.

Token weighting (this proposal)

The critical difference is that weighting changes the denominator of the average, not the numerator. This alters the variation of Rw(z)R^w (z) across latent values, which is the quantity that drives the posterior gradient. Subtracting a baseline — whether scalar or per-token — does not change this variation.

Conditions For Effectiveness

When it works

  • Nonzero initial z-sensitivity. At initialization, z-token embeddings are random, producing small random differences in $\ell_t(z)$ across $z$ at early positions. Under the standard loss, this signal is diluted by $t_0/T$ and may be too weak to overcome the KL pressure. Under the weighted loss, the signal is preserved at $O(|\delta|)$, potentially sufficient for the optimizer to bootstrap z-usage.

  • Non-trivial base-model uncertainty at early positions. The synthetic-sequence tasks have genuine strategy ambiguity: different strategies produce different early solution tokens for the same input $x$. The base model (initialized from the entangled checkpoint) should have $b_t > 0$ at these positions.

  • Self-adapting behavior. Early in training, the base model may be uncertain at many positions, making $w_t$ roughly uniform. As the backbone improves and late tokens become easy, the weights concentrate on the strategy-sensitive prefix. The reweighting strengthens exactly when dilution becomes the binding constraint.

When it may not suffice

  • Complete z-insensitivity. If the decoder produces $\ell_t(z) = \ell_t(z')$ for all $z, z'$ at every position (including early tokens), the weighted variation is still zero. Reweighting amplifies existing z-variation but cannot create it from nothing. This limitation can be addressed by pairing the weighted loss with an architectural change that ensures nonzero $\partial \ell_t / \partial \psi$ (for example, z-conditioned layer normalization or a multi-token z-prefix).

  • Degenerate weights. If $b_t \approx 0$ for all $t$ (including early tokens), the denominator $\sum_t w_t \approx 0$ and the loss degenerates. This would indicate that the task has no token-level strategy ambiguity, in which case disentanglement via reconstruction alone may be impossible.

  • Single-token z-conditioning. The z-token is a single token in the input sequence. Even with correct gradient scaling, its influence on early-position predictions is mediated by attention and may remain small. Reweighting is necessary for correct gradient balance but may not be sufficient to overcome an architectural bottleneck.

Implementation

Computing the baseline $b_t$

The z-free baseline requires one additional forward pass per batch through the z-free input template:

<bos> <x> {x} </x> <y> {y} </y> <eos>

This is the original entangled-model format. The per-token NLL $b_t$ is computed under teacher forcing at supervised positions and detached before use as weights.
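A minimal sketch of the per-token NLL computation from decoder logits, in numpy for illustration. The function name, shapes, and toy values are assumptions; in practice this would use the training framework's cross-entropy and its detach mechanism:

```python
import numpy as np

def per_token_nll(logits, targets):
    """Per-token NLL -log p(y_t | ...) from decoder logits.
    logits: (T, V) unnormalized scores at supervised positions;
    targets: (T,) gold token ids. Illustrative shapes, not repo API."""
    logits = logits - logits.max(axis=-1, keepdims=True)            # stable
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -logp[np.arange(len(targets)), targets]

# Toy check: a confident correct prediction gives near-zero b_t,
# a uniform prediction over 3 tokens gives b_t = log 3.
logits = np.array([[10.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])
targets = np.array([0, 2])
b = per_token_nll(logits, targets)   # detach before using as weights w_t
print(b)
```

The first position (peaked, correct) yields $b_t \approx 0$ and contributes almost no weight; the second (uniform over three tokens) yields $b_t = \log 3$ and dominates, which is exactly the concentration behavior the weighting relies on.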

Two implementation policies are possible:

  • Frozen policy. Compute $b_t$ from a frozen reference model initialized from the entangled checkpoint. This keeps the weighting signal stable while the disentangled model is trained.

  • Dynamic policy. Compute $b_t$ from the current training backbone at each step (still stop-gradiented before use), so weights evolve with the model.

Current implementation status in this repository: the frozen policy is implemented; the dynamic policy is intentionally deferred.

Modifying `masked_token_nll_per_example`

The current implementation computes:

denom = active.sum(dim=1).clamp_min(1.0)
return (token_loss * active).sum(dim=1) / denom

The weighted variant replaces the uniform `active` mask with `active * weights`, where `weights` is the detached per-token baseline NLL:

weighted_active = active * weights.detach()
denom = weighted_active.sum(dim=1).clamp_min(eps)
return (token_loss * weighted_active).sum(dim=1) / denom

Cost

One additional forward pass per batch (z-free baseline). For the exact discrete CVAE with $K$ latent values, the training loop already performs $K + 1$ forward passes per batch ($K$ latent-conditioned and $1$ posterior). The baseline adds one more, increasing cost by $1/(K+1)$.

Fallback behavior

When the denominator $\sum_t w_t$ is below a threshold (base model is confident everywhere), the loss should fall back to the standard uniform-average reconstruction to avoid numerical instability. A simple implementation:

$$R^{\mathrm{eff}}(z) = \begin{cases} R^w(z) & \text{if } \sum_{t \in \mathcal{S}} w_t > \tau, \\ R(z) & \text{otherwise}, \end{cases}$$

where $\tau$ is a small positive threshold (for example, $10^{-4}$).
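A self-contained numpy sketch of the weighted average with this fallback (function name, shapes, and the toy inputs are illustrative; the repository code operates on framework tensors):

```python
import numpy as np

def weighted_or_uniform_recon(token_loss, active, weights, tau=1e-4, eps=1e-8):
    """Per-example reconstruction: baseline-weighted average over supervised
    tokens, falling back to the uniform average when the summed weights are
    below tau. Shapes (B, T); weights assumed already detached."""
    w = active * weights
    wsum = w.sum(axis=1)
    weighted = (token_loss * w).sum(axis=1) / np.maximum(wsum, eps)
    uniform = (token_loss * active).sum(axis=1) / np.maximum(active.sum(axis=1), 1.0)
    return np.where(wsum > tau, weighted, uniform)

tl = np.array([[1.0, 2.0, 3.0, 4.0]])      # toy per-token losses
act = np.ones((1, 4))
print(weighted_or_uniform_recon(tl, act, np.array([[1.0, 0.0, 0.0, 0.0]])))  # [1.]
print(weighted_or_uniform_recon(tl, act, np.zeros((1, 4))))                  # [2.5]
```

With all weight on the first token the loss is that token's NLL alone; with degenerate all-zero weights the function silently reverts to the uniform average rather than dividing by $\epsilon$.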
