Sequence-Level Inter-Latent Divergence via Policy Gradient Estimation

This note formalizes a policy-gradient approach to directly optimizing full-sequence inter-latent divergence, complementing the token-level JSD proxy (see inter_latent_divergence.typ).

Status: Proposal. Not yet implemented.

Problem Statement

We want to maximize the full-sequence KL divergence between latent-conditioned distributions, summed over ordered pairs of latents (KL is asymmetric, so both directions appear):

$$\mathcal{L}_\mathrm{seq-div} = \sum_{i \neq j} \mathrm{KL}(p(y \mid x, z_i), p(y \mid x, z_j)).$$

Under the autoregressive decomposition:

$$\mathrm{KL}(p(y \mid x, z_i), p(y \mid x, z_j)) = \mathbb{E}_{y \sim p(\cdot \mid x, z_i)} \left[\sum_{t} \log p(y_t \mid x, z_i, y_{<t}) - \log p(y_t \mid x, z_j, y_{<t})\right].$$
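To make the decomposition concrete, here is a minimal sketch that computes the sequence-level KL exactly for a hypothetical toy autoregressive model (binary vocabulary, length 3, each token depending only on the latent and the previous token) by enumerating every sequence. `P_NEXT`, `seq_kl`, and `seq_div` are illustrative names, not part of any existing codebase, and the conditioning on $x$ is left implicit.

```python
import itertools
import math

# Hypothetical toy model: binary vocabulary, fixed length T, and a first-order
# dependence p(y_t | z, y_{t-1}). P_NEXT[z][prev] is the probability that
# y_t = 1; prev is None for the first token.
T = 3
P_NEXT = {
    0: {None: 0.9, 0: 0.2, 1: 0.8},
    1: {None: 0.3, 0: 0.6, 1: 0.4},
    2: {None: 0.5, 0: 0.5, 1: 0.5},
}


def token_logprob(z, prev, tok):
    """log p(y_t = tok | x, z, y_{<t}) for the toy model."""
    p1 = P_NEXT[z][prev]
    return math.log(p1 if tok == 1 else 1.0 - p1)


def seq_logprob(z, y):
    """Autoregressive decomposition: sum of per-token log-probabilities."""
    return sum(token_logprob(z, y[t - 1] if t > 0 else None, y[t]) for t in range(len(y)))


def seq_kl(zi, zj):
    """KL(p(y|x,z_i), p(y|x,z_j)) = E_{y ~ p_i}[sum_t log p_i(y_t|...) - log p_j(y_t|...)],
    computed exactly by enumerating all 2**T sequences."""
    total = 0.0
    for y in itertools.product((0, 1), repeat=T):
        lp_i = seq_logprob(zi, y)
        total += math.exp(lp_i) * (lp_i - seq_logprob(zj, y))
    return total


def seq_div(latents):
    """L_seq-div: sum of sequence-level KL over ordered pairs i != j."""
    return sum(seq_kl(a, b) for a in latents for b in latents if a != b)


print(seq_kl(0, 1), seq_div([0, 1, 2]))
```

Enumeration is only feasible for tiny vocabularies and lengths; for a real model the expectation has to be estimated from sampled sequences, which is the obstacle described next.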

The expectation over $y \sim p(\cdot \mid x, z_i)$ (free-running generation) is the key obstacle: it requires autoregressive sampling, which is sequential, and the sampled tokens are discrete, so gradients cannot be backpropagated through the sampling step.

Proposed Approach: Prefix Rollout + REINFORCE

Setup

For each example $x$ and each latent pair $(z_i, z_j)$:

  • Generate a short prefix rollout of $N$ tokens from $z_i$: $y^{(i)}_{1:N} \sim p(\cdot \mid x, z_i)$, sampled autoregressively.

  • Compute the per-token log-probability differences: $r_t = \log p(y^{(i)}_t \mid x, z_i, y^{(i)}_{<t}) - \log p(y^{(i)}_t \mid x, z_j, y^{(i)}_{<t})$.

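  • The total reward for the rollout is $R = \sum_{t=1}^{N} r_t$. Its expectation is the KL divergence between the $N$-token prefix distributions under $z_i$ and $z_j$, which by the chain rule for KL is a lower bound on the full-sequence KL.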

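  • Apply REINFORCE to estimate the gradient. Because $R$ itself depends on the model parameters, the single-sample estimate adds the direct gradient $\nabla R$ (taken with the sampled tokens held fixed) to the score-function term; without it, the estimate ignores how the parameters change $\log p(y^{(i)}_t \mid x, z_j, y^{(i)}_{<t})$ and is biased. Each pair uses its own reward $R$, while the rollout from $z_i$ can be shared across all $j$:
$$\nabla \mathcal{L}_\mathrm{seq-div} \approx \sum_{i \neq j} \left( R \cdot \nabla \sum_{t=1}^{N} \log p(y^{(i)}_t \mid x, z_i, y^{(i)}_{<t}) + \nabla R \right).$$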
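The rollout-and-reward steps above can be sketched with a hypothetical tabular softmax model in which each (latent, previous token) pair has its own logit vector, so gradients can be written by hand instead of with an autodiff library. The names `init_params`, `sample_prefix`, `logprob_and_grad`, and `pair_gradient_estimate` are illustrative stand-ins for a real latent-conditioned decoder. Because $R$ depends on the parameters, the sketch adds $\nabla R$ (with the sampled tokens held fixed) to the score term $R \cdot \nabla \log p(y^{(i)} \mid x, z_i)$; this is what makes the average of the estimates match the exact gradient of the prefix KL.

```python
import math
import random
from collections import defaultdict

V = 2  # vocabulary size of the hypothetical toy model


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def init_params(latents, seed=0):
    """Hypothetical tabular model: one logit vector per (latent, previous token)."""
    rng = random.Random(seed)
    return {(z, prev): [rng.gauss(0.0, 1.0) for _ in range(V)]
            for z in latents for prev in [None, *range(V)]}


def sample_prefix(theta, z, n, rng):
    """Free-running rollout y_{1:N} ~ p(. | x, z), sampled one token at a time."""
    y, prev = [], None
    for _ in range(n):
        tok = rng.choices(range(V), weights=softmax(theta[(z, prev)]))[0]
        y.append(tok)
        prev = tok
    return y


def logprob_and_grad(theta, z, y):
    """Teacher-forced log p(y | x, z) for fixed tokens y, and its gradient with
    respect to the logits (d log softmax(l)[tok] / dl = onehot(tok) - softmax(l))."""
    lp, grad, prev = 0.0, defaultdict(float), None
    for tok in y:
        probs = softmax(theta[(z, prev)])
        lp += math.log(probs[tok])
        for v in range(V):
            grad[(z, prev, v)] += (1.0 if v == tok else 0.0) - probs[v]
        prev = tok
    return lp, grad


def pair_gradient_estimate(theta, zi, zj, n, rng):
    """One-sample estimate of the gradient of the N-token prefix KL:
    R * grad log p(y | z_i) + grad R, with y ~ p(. | z_i) held fixed."""
    y = sample_prefix(theta, zi, n, rng)
    lp_i, g_i = logprob_and_grad(theta, zi, y)
    lp_j, g_j = logprob_and_grad(theta, zj, y)
    reward = lp_i - lp_j                  # R = sum_t r_t
    est = defaultdict(float)
    for key, g in g_i.items():
        est[key] += reward * g + g        # score term + (grad log p_i) part of grad R
    for key, g in g_j.items():
        est[key] -= g                     # -(grad log p_j) part of grad R
    return reward, est


theta = init_params(latents=[0, 1])
reward, grad_est = pair_gradient_estimate(theta, zi=0, zj=1, n=4, rng=random.Random(1))
```

In a real model the logits share parameters across latents and `logprob_and_grad` would be a teacher-forced forward pass plus autodiff; the tabular form only keeps the example dependency-free and lets the estimator be checked against finite differences of the exact prefix KL.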

Variance Reduction

The token-level JSD proxy (from inter_latent_divergence.typ) can serve as a control variate:

$$\hat{R} = R - \operatorname{sg}(\mathcal{L}_\mathrm{JSD-prefix})$$

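where $\operatorname{sg}$ denotes stop-gradient, $\mathcal{L}_\mathrm{JSD-prefix}$ is the token-weighted JSD computed over the prefix under teacher forcing, and $\hat{R}$ is used in place of $R$. Provided the JSD is computed on the ground-truth prefix rather than on the sampled tokens $y^{(i)}$, it does not depend on the sample, so subtracting it leaves the expected gradient unchanged. Since the token-level JSD approximates the same quantity (divergence at strategy-sensitive positions), it should be correlated with $R$ and reduce variance.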
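A minimal sketch of why a sample-independent baseline is safe: for a single categorical choice (a stand-in for one token decision), it computes the exact mean and variance of the score-function estimator with and without a baseline. The rewards and the baseline value 2.0 are made-up numbers; the baseline plays the role of $\operatorname{sg}(\mathcal{L}_\mathrm{JSD-prefix})$, and `score_estimator_moments` is a hypothetical helper.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def score_estimator_moments(logits, reward, baseline):
    """Exact mean vector and total variance of the score-function estimator
    g(y) = (reward[y] - baseline) * grad_logits log p(y),  y ~ softmax(logits),
    computed by enumerating y. `baseline` is a constant that must not depend on y."""
    probs = softmax(logits)
    n = len(probs)

    def g(y):
        # d log softmax(logits)[y] / d logits = onehot(y) - probs
        return [(reward[y] - baseline) * ((1.0 if v == y else 0.0) - probs[v]) for v in range(n)]

    values = [g(y) for y in range(n)]
    mean = [sum(probs[y] * values[y][v] for y in range(n)) for v in range(n)]
    second = [sum(probs[y] * values[y][v] ** 2 for y in range(n)) for v in range(n)]
    total_var = sum(s - m * m for s, m in zip(second, mean))
    return mean, total_var


logits = [0.2, -0.4, 1.0]
reward = [2.1, 1.7, 2.4]  # hypothetical per-outcome rewards R
mean0, var0 = score_estimator_moments(logits, reward, baseline=0.0)
mean_b, var_b = score_estimator_moments(logits, reward, baseline=2.0)  # stand-in for sg(L_JSD-prefix)
```

The two mean vectors agree because $\mathbb{E}_y[\nabla \log p(y)] = 0$, while the variance drops because the baseline is close to the rewards. A baseline computed from the sampled tokens would lose the first property.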

Connections

  • To scheduled sampling: Both approaches break teacher forcing during training. Scheduled sampling replaces ground-truth tokens with model predictions to improve reconstruction; the policy gradient approach generates free-running prefixes to estimate divergence. The policy gradient targets divergence specifically rather than reconstruction quality.

  • To the token-level JSD: The token-level JSD can be seen as a teacher-forced approximation of the sequence-level divergence. The policy gradient approach removes the teacher-forcing assumption but introduces sampling noise and higher variance.

Implementation Challenges

Sequential generation

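Generating $N$-token prefixes for each of $K$ latent values during each training step is expensive. For $K = 4$, $N = 20$, and batch size $B = 128$, running the $K$ rollouts one after another requires $4 \times 20 = 80$ sequential forward passes per batch (each over $B$ sequences), in addition to the standard teacher-forced passes. Stacking the $K$ latents along the batch dimension reduces the sequential depth to $N = 20$ passes over $K B = 512$ sequences, but not the total compute. Computing $r_t$ also requires scoring each rollout under the other $K - 1$ latents, which adds $K(K-1) = 12$ teacher-forced passes that are parallel over positions.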
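The pass counts can be written as a small cost model. This is a sketch under stated assumptions (one rollout per latent per example, one new token per sampling step with a KV cache); `rollout_cost` is a hypothetical helper that counts passes only and ignores per-pass cost.

```python
def rollout_cost(K, N, B, batch_latents=False):
    """Count forward passes for one training step's prefix rollouts (hypothetical model).

    Returns (sequential_steps, sequences_per_step, scoring_passes):
      sequential_steps: autoregressive passes that must run one after another
      sequences_per_step: sequences processed in each of those passes
      scoring_passes: teacher-forced passes that score each latent's rollout
        under the other K - 1 latents (parallel over positions)
    """
    if batch_latents:
        # Stack the K latents along the batch dimension: fewer, larger passes.
        sequential_steps, sequences_per_step = N, K * B
    else:
        # Run the K latents' rollouts one after another.
        sequential_steps, sequences_per_step = K * N, B
    scoring_passes = K * (K - 1)
    return sequential_steps, sequences_per_step, scoring_passes


print(rollout_cost(4, 20, 128))                      # (80, 128, 12)
print(rollout_cost(4, 20, 128, batch_latents=True))  # (20, 512, 12)
```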

Gradient variance

REINFORCE gradients are notoriously high-variance. Even with the JSD-based control variate, the estimator may need many rollouts per example or large batches to be useful. The DReG (doubly reparameterized gradient) estimator may help if the latents are continuous; it relies on reparameterization, so it does not apply to the discrete token samples themselves.

Training stability

Early in training, model predictions are poor, so the generated prefixes may not reflect meaningful strategy differences. The reward signal may be noisy or misleading until the model has learned basic sequence structure. A warmup period (e.g., only enable after the token-level JSD loss has been active for some steps) may be needed.
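One way to realize the warmup is a step-based gate on the loss weight. This is a sketch under assumptions: the linear ramp and the parameters `ramp_steps` and `max_weight` are illustrative choices, not part of the proposal, and `seq_div_weight` is a hypothetical helper.

```python
def seq_div_weight(step, jsd_start, warmup_steps, ramp_steps, max_weight=1.0):
    """Weight for the sequence-level divergence loss at a given training step.

    Hypothetical schedule: zero until the token-level JSD loss (enabled at
    `jsd_start`) has been active for `warmup_steps`, then a linear ramp to
    `max_weight` over `ramp_steps` steps.
    """
    start = jsd_start + warmup_steps
    if step < start:
        return 0.0
    if ramp_steps <= 0:
        return max_weight
    return max_weight * min(1.0, (step - start) / ramp_steps)


# Example: JSD loss turned on at step 2000, 1000-step warmup, 500-step ramp.
for step in (0, 2999, 3000, 3250, 4000):
    print(step, seq_div_weight(step, jsd_start=2000, warmup_steps=1000, ramp_steps=500))
```

Tying the start to the step at which the JSD loss is enabled, rather than to a fixed step, keeps the two schedules consistent if the JSD start moves.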

When to Revisit

This approach should be considered if the token-level JSD proxy proves insufficient because:

  • Strategy differences emerge primarily at the sequence level (e.g., strategies that produce similar individual tokens but different trajectories), or
  • The teacher-forcing assumption causes the token-level JSD to miss important divergence that exists under free-running generation.

The token-level JSD should be evaluated first, since it is much cheaper and may be sufficient for the current synthetic tasks, where strategies diverge in early tokens.
