Sequence-Level Inter-Latent Divergence via Policy Gradient Estimation
This note formalizes a policy-gradient approach to directly optimizing
full-sequence inter-latent divergence, complementing the token-level JSD
proxy (see inter_latent_divergence.typ).
Status: Proposal. Not yet implemented.
Problem Statement
We want to maximize the full-sequence KL divergence between latent-conditioned distributions:

$D_{\mathrm{KL}}\big(p_\theta(y \mid x, z_i) \;\|\; p_\theta(y \mid x, z_j)\big)$

Under the autoregressive decomposition:

$D_{\mathrm{KL}} = \mathbb{E}_{y \sim p_\theta(\cdot \mid x, z_i)}\left[\sum_t \log \frac{p_\theta(y_t \mid y_{<t}, x, z_i)}{p_\theta(y_t \mid y_{<t}, x, z_j)}\right]$

The expectation over $y \sim p_\theta(\cdot \mid x, z_i)$ (free-running generation) is the key obstacle: it requires autoregressive sampling, which is sequential and non-differentiable.
Proposed Approach: Prefix Rollout + REINFORCE
Setup
For each example $x$ and each latent pair $(z_i, z_j)$:

- Generate a short prefix rollout $\hat{y}_{1:T}$ of $T$ tokens from $p_\theta(\cdot \mid x, z_i)$ autoregressively.
- Compute the per-token log-probability differences: $r_t = \log p_\theta(\hat{y}_t \mid \hat{y}_{<t}, x, z_i) - \log p_\theta(\hat{y}_t \mid \hat{y}_{<t}, x, z_j)$.
- The total reward for the rollout is: $R = \sum_{t=1}^{T} r_t$.
- Apply REINFORCE to estimate the gradient: $\nabla_\theta \mathbb{E}[R] = \mathbb{E}\big[R \,\nabla_\theta \log p_\theta(\hat{y}_{1:T} \mid x, z_i) + \nabla_\theta R\big]$, where the first (score-function) term accounts for the sampling distribution and the second is the ordinary pathwise gradient of the rewards.
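As a sanity check on the reward definition, the rollout-and-reward computation can be sketched with a toy memoryless model in which the next-token distribution depends only on the latent (the distributions and names below are illustrative, not part of the proposal). For such a model the expected reward equals $T$ times the per-token KL, which the Monte Carlo average should recover:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: two latent values induce two fixed categorical
# next-token distributions (a memoryless "model", so the sequence-level
# KL factorizes as T * per-token KL).
p_i = np.array([0.6, 0.3, 0.1])
p_j = np.array([0.2, 0.5, 0.3])
T = 5  # prefix rollout length

def rollout_reward(rng):
    """Sample a T-token prefix from p(. | z_i) and accumulate the per-token
    log-probability differences r_t = log p_i(y_t) - log p_j(y_t)."""
    tokens = rng.choice(len(p_i), size=T, p=p_i)
    r = np.log(p_i[tokens]) - np.log(p_j[tokens])
    return r.sum()

# Monte Carlo estimate of E[R]; for this memoryless model it should
# approach T * KL(p_i || p_j).
rewards = np.array([rollout_reward(rng) for _ in range(20000)])
kl = np.sum(p_i * (np.log(p_i) - np.log(p_j)))
print(rewards.mean(), T * kl)
```

In the real setting the two distributions would come from the decoder conditioned on $z_i$ and $z_j$, and sampling would be sequential; the reward bookkeeping is unchanged.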
Variance Reduction
The token-level JSD proxy (from inter_latent_divergence.typ) can serve as a control variate:

$\nabla_\theta \mathbb{E}[R] \approx \mathbb{E}\big[(R - b)\,\nabla_\theta \log p_\theta(\hat{y}_{1:T} \mid x, z_i)\big]$

where $b$ is the token-weighted JSD computed over the prefix under teacher forcing. Because $b$ does not depend on the sampled rollout, subtracting it leaves the estimator unbiased; and since the token-level JSD approximates the same quantity (divergence at strategy-sensitive positions), it should be correlated with $R$ and reduce variance.
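The unbiasedness and variance-reduction claims can be checked numerically on a hypothetical single-step example: a softmax policy over three tokens, a fixed reward per token, and the exact mean reward standing in for the teacher-forced JSD as the baseline (all values below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical single-step example: policy p = softmax(theta) over 3 tokens,
# fixed per-token rewards R. REINFORCE estimates grad_theta E[R].
theta = np.array([0.5, -0.2, 0.1])
p = np.exp(theta) / np.exp(theta).sum()
R = np.array([2.0, 0.5, 1.0])

def reinforce_grads(baseline, n=50000):
    """Per-sample score-function gradient estimates (R - baseline) * grad log p."""
    y = rng.choice(3, size=n, p=p)
    # grad_theta log p(y) = onehot(y) - p for a softmax policy
    score = np.eye(3)[y] - p
    return (R[y] - baseline)[:, None] * score

analytic = p * (R - p @ R)        # exact gradient of E[R] w.r.t. theta
g_plain = reinforce_grads(0.0)    # no baseline
g_base = reinforce_grads(p @ R)   # baseline = E[R], a valid control variate

print(analytic, g_plain.mean(axis=0), g_base.mean(axis=0))
print(g_plain.var(axis=0).sum(), g_base.var(axis=0).sum())
```

Both estimators should agree with the analytic gradient in expectation, while the baseline-subtracted one should show markedly lower variance; the teacher-forced JSD would play the role of `baseline` in the actual training loop.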
Connections
- To scheduled sampling: Both approaches break teacher forcing during training. Scheduled sampling replaces ground-truth tokens with model predictions to improve reconstruction; the policy-gradient approach generates free-running prefixes to estimate divergence. The policy gradient targets divergence specifically rather than reconstruction quality.
- To the token-level JSD: The token-level JSD can be seen as a teacher-forced approximation of the sequence-level divergence. The policy-gradient approach removes the teacher-forcing assumption but introduces sampling noise and higher variance.
Implementation Challenges
Generating $T$-token prefixes for each of $K$ latent values during each training step is expensive: each rollout requires $T$ sequential forward passes, so this adds on the order of $K \cdot T$ sequential passes per batch on top of the standard teacher-forced passes.
REINFORCE gradients are notoriously high-variance. Even with the JSD-based control variate, the estimator may require many samples or large batches to be useful. The DReG (doubly reparameterized gradient) estimator may help in the continuous-latent case.
Early in training, model predictions are poor, so generated prefixes may not reflect meaningful strategy differences; the reward signal may be noisy or misleading until the model has learned basic sequence structure. A warmup period (e.g., enabling the policy-gradient loss only after the token-level JSD loss has been active for some number of steps) may be needed.
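Such a warmup could be implemented as a simple gating schedule; a minimal sketch, where the function name and step counts are illustrative rather than fixed choices:

```python
def divergence_loss_weight(step, jsd_warmup_steps=1000, ramp_steps=500):
    """Hypothetical schedule: keep the policy-gradient divergence loss off
    until the token-level JSD loss has run for jsd_warmup_steps, then ramp
    its weight linearly from 0 to 1 over ramp_steps."""
    if step < jsd_warmup_steps:
        return 0.0
    return min(1.0, (step - jsd_warmup_steps) / ramp_steps)
```

The total loss at a given step would then be the reconstruction and JSD terms plus `divergence_loss_weight(step)` times the REINFORCE surrogate.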
When to Revisit
This approach should be considered if the token-level JSD proxy proves insufficient because:
- Strategy differences emerge primarily at the sequence level (e.g., strategies that produce similar individual tokens but different trajectories), or
- The teacher-forcing assumption causes the token-level JSD to miss important divergence that exists under free-running generation.
The token-level JSD should be evaluated first: it is much cheaper and may be sufficient for the current synthetic tasks, where strategies diverge in early tokens.