# Summary of Methodology and Experimental Setup for Initial Synthetic Sequences Experiments

## Problem Setting

The synthetic-sequences experiments start from an entangled conditional model $p_\theta(y \mid x)$ trained on synthetic solution traces. For a fixed input $x$, the output distribution is intentionally multi-modal: several valid solution strategies can lead to the same final answer while producing different intermediate traces.

The goal is to replace this entangled conditional distribution with a latent factorization

- discrete: $p(y \mid x) = \sum_{z} p_\psi(y \mid x, z)\, p_\phi(z \mid x)$,
- continuous: $p(y \mid x) = \int p_\psi(y \mid x, z)\, p_\phi(z \mid x)\, \mathrm{d}z$,

where $z$ should represent a high-level strategy variable.
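
Here is a toy illustration of the discrete factorization. All numbers are made up: there are $K = 3$ strategies for one input $x$ and one fixed trace $y$. The marginal $p(y \mid x)$ is the router-weighted sum of the strategy-conditioned likelihoods. Bayes' rule then gives the exact posterior over strategies, which $q_\xi(z \mid x, y)$ is later trained to approximate. The helper names are hypothetical.

```python
# Toy discrete factorization for one input x and one trace y (hypothetical numbers).
router = [0.5, 0.3, 0.2]         # p_phi(z | x) for z in {z1, z2, z3}
likelihood = [0.10, 0.40, 0.00]  # p_psi(y | x, z); this y is impossible under z3


def marginal_likelihood(router, likelihood):
    """p(y | x) = sum_z p_psi(y | x, z) * p_phi(z | x)."""
    return sum(p_y * p_z for p_y, p_z in zip(likelihood, router))


def exact_posterior(router, likelihood):
    """Bayes rule: p(z | x, y) is proportional to p_psi(y | x, z) * p_phi(z | x)."""
    evidence = marginal_likelihood(router, likelihood)
    return [p_y * p_z / evidence for p_y, p_z in zip(likelihood, router)]


print(round(marginal_likelihood(router, likelihood), 4))           # 0.17
print([round(p, 3) for p in exact_posterior(router, likelihood)])  # [0.294, 0.706, 0.0]
```

The trace rules out $z_3$ entirely, so all posterior mass moves to $z_1$ and $z_2$. A trace that pins down a single strategy would give a one-hot posterior.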

### Desired behavior

A good factorization should satisfy three properties simultaneously: (i) strong reconstruction and generation quality, (ii) informative routing, so that $z$ affects the solution trajectory, and (iii) stable strategy semantics, so that interventions on $z$ induce distinct, reusable behaviors.

## Parameterization

### Discrete Latents

The discrete implementation uses special latent tokens `<z1> ... <zK>` inside the canonical prompt format:

- router prompt: `<bos> <x> ... </x> <z>` defines $p_\phi(z \mid x)$ from the next-token distribution restricted to the latent-token set;
- generator prompt: `<bos> <x> ... </x> <z> z_i </z> <y> ... </y> <eos>` defines $p_\psi(y \mid x, z_i)$;
- posterior prompt: the canonical $(x, y)$ record is processed by a separate posterior backbone plus pooling head to define $q_\xi(z \mid x, y)$.
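
The following minimal sketch assembles the three prompts as token lists. It assumes that the latent token for $z_i$ is spelled `<zi>` (for example `<z2>`) inside the `<z> ... </z>` span. The function names and the toy tokenization are hypothetical. The final check shows why a single causal LM can play both roles: the router prompt is exactly a prefix of the generator prompt. The router's choice is therefore the next-token prediction at the position right after `<z>`.

```python
# Hypothetical prompt builders for the discrete latent format.
def router_prompt(x_tokens):
    """Prefix whose next-token distribution (after <z>) defines p_phi(z | x)."""
    return ["<bos>", "<x>", *x_tokens, "</x>", "<z>"]


def generator_prompt(x_tokens, z_index, y_tokens):
    """Sequence scored by p_psi(y | x, z_i); assumes z_i is spelled <z{i}>."""
    return ["<bos>", "<x>", *x_tokens, "</x>", "<z>", f"<z{z_index}>", "</z>",
            "<y>", *y_tokens, "</y>", "<eos>"]


def posterior_record(x_tokens, y_tokens):
    """Canonical (x, y) record fed to the separate posterior backbone."""
    return ["<bos>", "<x>", *x_tokens, "</x>", "<y>", *y_tokens, "</y>", "<eos>"]


x = ["3", "1", "4"]                                # toy list_summation input
y = ["3", "+", "1", "=", "4", "+", "4", "=", "8"]  # toy left-to-right trace
prefix = router_prompt(x)
full = generator_prompt(x, 2, y)
print(full[: len(prefix)] == prefix)  # True
print(full[len(prefix)])              # <z2>
```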

In the implementation, the router and generator are shared: the same adapted causal LM parameterizes both $p_\phi(z \mid x)$ and $p_\psi(y \mid x, z)$.
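
Because the router is read off the shared LM's full-vocabulary output, its logits at the routing position cover every token, not only `<z1> ... <zK>`. The sketch below takes one natural reading of "restricted to the latent-token set": a softmax over the latent-token logits alone. This is an assumption, not something confirmed from the trainer. The sketch also reports the mass that the unrestricted softmax places outside that set. That is the quantity the router-support control under Training Objective penalizes. Function names and token ids are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]


def router_distribution(next_token_logits, latent_ids):
    """p_phi(z | x): softmax over the latent-token logits only (renormalized)."""
    restricted = [next_token_logits[i] for i in latent_ids]
    return [math.exp(lp) for lp in log_softmax(restricted)]


def off_support_mass(next_token_logits, latent_ids):
    """Probability the full-vocabulary softmax assigns to non-latent tokens."""
    probs = [math.exp(lp) for lp in log_softmax(next_token_logits)]
    return 1.0 - sum(probs[i] for i in latent_ids)


# Toy vocabulary of 5 tokens; ids 2, 3, 4 stand in for <z1>, <z2>, <z3>.
logits = [1.0, 0.5, 2.0, 2.0, 0.0]
print([round(p, 3) for p in router_distribution(logits, [2, 3, 4])])
print(round(off_support_mass(logits, [2, 3, 4]), 3))
```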

### Continuous Latents

The continuous implementation keeps the same exposed prompt structure but replaces the discrete latent token with a vector injection.

- router: $p_\phi(z \mid x) = \mathcal{N}(\mu_\phi(x), \operatorname{diag}(\sigma_\phi(x)^2))$,
- posterior: $q_\xi(z \mid x, y) = \mathcal{N}(\mu_\xi(x, y), \operatorname{diag}(\sigma_\xi(x, y)^2))$,
- generator: a sampled latent vector $z \in \mathbb{R}^{d_z}$ is mapped to embedding space, $E(z) \in \mathbb{R}^{d_\mathrm{model}}$, and injected at the position of a placeholder token such as `<z_empty>`.

> [!info] Implementation detail
> To implement the continuous latent approach, we keep the same prompt format but designate a single latent token position as the injection point. The generator prompt is `<bos> <x> ... </x> <z> <z_empty> </z> <y> ... </y> <eos>`, and the model adds a learned projection $E(z)$ to the embedding at that single position.
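
Below is a minimal sketch of the injection step, written in pure Python so that it is self-contained. It assumes $E$ is an affine map $E(z) = Wz + b$; the text only says "learned projection". It also assumes the placeholder occurs exactly once. `project`, `inject_latent`, and their arguments are hypothetical names. Every other position keeps its ordinary token embedding, so with $E(z) = 0$ the generator sees the same input as with the plain `<z_empty>` token.

```python
def project(z, weight, bias):
    """Hypothetical affine projection E(z) = W z + b from R^{d_z} to R^{d_model}."""
    return [sum(w * zj for w, zj in zip(row, z)) + b for row, b in zip(weight, bias)]


def inject_latent(tokens, embeddings, z, weight, bias, placeholder="<z_empty>"):
    """Return a copy of the per-token embeddings with E(z) added at the placeholder."""
    pos = tokens.index(placeholder)          # raises ValueError if the placeholder is absent
    out = [list(vec) for vec in embeddings]  # leave the caller's embeddings untouched
    out[pos] = [h + e for h, e in zip(out[pos], project(z, weight, bias))]
    return out


tokens = ["<bos>", "<x>", "7", "</x>", "<z>", "<z_empty>", "</z>", "<y>"]
embeddings = [[0.0, 0.0, 0.0] for _ in tokens]  # toy d_model = 3
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # d_model x d_z with toy d_z = 2
bias = [0.0, 0.0, 0.5]
injected = inject_latent(tokens, embeddings, [0.2, -0.4], weight, bias)
print(injected[5])
```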

## Training Objective

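The common objective, minimized for each training pair $(x, y)$, is the negative conditional ELBO with a weight $\beta$ on the KL term, plus auxiliary losses: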

$$\mathcal{J}(x, y) = R(x, y) + \beta\, K(x, y) + \mathcal{L}_\mathrm{aux}(x, y),$$

where the first term is the reconstruction loss

$$R(x, y) = -\mathbb{E}_{z \sim q_\xi(\cdot \mid x, y)}\left[\log p_\psi(y \mid x, z)\right],$$

and the second term is the KL divergence tethering the posterior to the prior/router:

$$K(x, y) = \mathrm{KL}\left(q_\xi(z \mid x, y) \,\|\, p_\phi(z \mid x)\right).$$
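
Both KL terms have closed forms, so $K(x, y)$ needs no sampling. For categorical latents it is a sum over the $K$ values. For the diagonal Gaussians of the continuous case it is the standard per-dimension formula

$$\mathrm{KL} = \sum_{d} \left[\log\frac{\sigma_{\phi,d}}{\sigma_{\xi,d}} + \frac{\sigma_{\xi,d}^2 + (\mu_{\xi,d} - \mu_{\phi,d})^2}{2\sigma_{\phi,d}^2} - \frac{1}{2}\right].$$

The sketch below evaluates both in pure Python. Whether the trainer uses exactly these closed forms is an assumption, and the function names are hypothetical.

```python
import math


def kl_categorical(q, p):
    """KL(q || p) for discrete latents; terms with q(z) = 0 contribute nothing."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)


def kl_diag_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, diag sigma_q^2) || N(mu_p, diag sigma_p^2)), summed over dimensions."""
    return sum(
        math.log(sp / sq) + (sq**2 + (mq - mp) ** 2) / (2 * sp**2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p)
    )


# A fully committed posterior against a uniform router over K = 3 latents costs log 3 nats.
print(kl_categorical([1.0, 0.0, 0.0], [1 / 3, 1 / 3, 1 / 3]))
# Shifting one posterior mean by 1 (all standard deviations 1) costs 0.5 nats.
print(kl_diag_gaussian([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))
```

The first example shows the tension that $\beta$ controls. A posterior that commits to strategy $z$ pays $-\log p_\phi(z \mid x)$ nats, which is $\log K$ for a uniform router. That cost shrinks only if the router learns to predict the strategy from $x$ alone.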

The reconstruction expectation is estimated differently for the two latent types:

- discrete: exact summation over all latent values:

  $$\mathbb{E}_{z \sim q_\xi(\cdot \mid x, y)}\left[\log p_\psi(y \mid x, z)\right] = \sum_{z} q_\xi(z \mid x, y)\, \log p_\psi(y \mid x, z);$$

- continuous: reparameterized sampling from the posterior Gaussian:

  $$\mathbb{E}_{z \sim q_\xi(\cdot \mid x, y)}\left[\log p_\psi(y \mid x, z)\right] \approx \frac{1}{S} \sum_{s=1}^{S} \log p_\psi(y \mid x, z_s), \quad z_s = \mu_\xi(x, y) + \sigma_\xi(x, y) \odot \epsilon_s, \quad \epsilon_s \sim \mathcal{N}(0, I).$$
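
The following minimal sketch implements both estimators for a single $(x, y)$ pair, with the model calls abstracted away. `log_lik[k]` stands for $\log p_\psi(y \mid x, z_k)$, and `log_lik_fn(z)` stands for the continuous generator's log-likelihood; both are hypothetical. In the discrete case the expectation is an exact $q_\xi$-weighted sum, at the cost of one generator evaluation per latent value. In the continuous case, writing $z_s = \mu_\xi + \sigma_\xi \odot \epsilon_s$ is what lets gradients reach $\mu_\xi$ and $\sigma_\xi$ in an autodiff framework. The code uses plain floats, so it computes only the estimator values, not the gradients.

```python
import math
import random


def discrete_loss(log_lik, q, p, beta=1.0):
    """R + beta * K for one (x, y) pair, with R computed by exact enumeration.

    log_lik[k] = log p_psi(y | x, z_k); q = q_xi(. | x, y); p = p_phi(. | x).
    """
    recon = -sum(qk * ll for qk, ll in zip(q, log_lik))
    kl = sum(qk * math.log(qk / pk) for qk, pk in zip(q, p) if qk > 0)
    return recon + beta * kl


def continuous_recon(log_lik_fn, mu, sigma, num_samples, rng):
    """Monte Carlo estimate of R using z_s = mu + sigma * eps_s, eps_s ~ N(0, I)."""
    total = 0.0
    for _ in range(num_samples):
        z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
        total += log_lik_fn(z)
    return -total / num_samples


def toy_log_lik(z):
    """Hypothetical stand-in for log p_psi(y | x, z) in the continuous case."""
    return -0.5 * sum(v * v for v in z)


log_lik = [math.log(0.1), math.log(0.4), math.log(0.5)]
print(discrete_loss(log_lik, q=[0.2, 0.5, 0.3], p=[1 / 3] * 3, beta=0.5))
print(continuous_recon(toy_log_lik, [1.0, -1.0], [0.1, 0.1], num_samples=8, rng=random.Random(0)))
```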

The current trainer also supports the following auxiliary controls:

- either continuous or discrete:
  - token-weighted reconstruction: reweights reconstruction toward tokens where the $z$-free baseline is uncertain, with the goal of concentrating learning on strategy-sensitive positions.
  - baseline-normalized reconstruction: rescales the standard reconstruction term by the entangled baseline's $y$-loss, with the goal of restoring a useful reconstruction-to-KL balance when the initializer already "reconstructs" well.
- discrete only:
  - posterior entropy: penalizes high-entropy $q_\xi(z \mid x, y)$, with the goal of making posterior assignments sharper and more committed. Intuition: the solution $y$ should uniquely identify a single strategy $z$.
  - router support: penalizes probability mass outside the latent-token set at the routing position, with the goal of making $p_\phi(z \mid x)$ live on the intended discrete latent support. This is a subtle implementation detail arising from the shared parameterization of the router and generator, and it is not central to the method.
  - router marginal KL to prior: penalizes deviation of the batch-mean router distribution from the target prior, $\mathrm{KL}(\mathbb{E}_x[p_\phi(z \mid x)] \,\|\, p_\mathrm{prior}(z))$ with $p_\mathrm{prior}(z) = \mathrm{Unif}(\{z_1, \ldots, z_K\})$, with the goal of preventing latent-usage collapse at the aggregate level. Intuition: the router can direct different inputs to different strategies, but every strategy should be used at least some of the time rather than the router collapsing to a single dominant mode.
  - inter-latent token-level Jensen-Shannon divergence (JSD): maximizes the divergence among latent-conditioned token distributions, with the goal of making different latent values induce genuinely different behaviors. Intuition: each strategy should produce a different distribution over solution trajectories.
- continuous only:
  - posterior variance: penalizes large posterior variance, with the goal of tightening the sampled latent geometry around informative posterior means.
  - router marginal KL to prior: penalizes mismatch between the batch-mean continuous router moments and the target prior, with the goal of keeping sampled latents generation-ready rather than drifting arbitrarily.

Baseline normalization is especially relevant in this project because the generator is initialized from the entangled baseline, which makes the raw reconstruction loss small from the start. Companion notes provide the detailed arguments for token weighting, JSD, and normalized reconstruction. The weight on each auxiliary term is a hyperparameter.

## Training Algorithm

Given a base conditional model $p(y \mid x)$, the disentangling methodology can be summarized as:

1. Load the base model.
2. Initialize the factorization $(p_\phi(z \mid x), p_\psi(y \mid x, z))$ and the posterior network $q_\xi(z \mid x, y)$ from it according to the latent configuration.
3. For each optimization step, construct a batch by sampling inputs $x_1, \ldots, x_n$ and responses $y_i \sim p(y \mid x_i)$ from the base model.
4. Infer the posterior $q_\xi(z \mid x_i, y_i)$ and prior/router $p_\phi(z \mid x_i)$ for each pair $(x_i, y_i)$.
5. If $z$ is discrete, compute the exact ELBO reconstruction term by enumerating latent values; otherwise, estimate it with reparameterized latent samples.
6. Compute the full objective: reconstruction $+$ $\beta$-weighted KL $+$ any enabled auxiliary losses.
7. Take a gradient step with respect to the trainable parameters.
8. Repeat until convergence and retain the learned factorization $(p_\phi(z \mid x), p_\psi(y \mid x, z))$.

## Experiments

### Task Suite

The current task family is:

- `list_summation`: summing a short list; strategies = `left-to-right`, `right-to-left`, `pairwise`.
- `sorting_algorithms`: emitting a sorting trace; strategies = `bubble-sort`, `selection-sort`, `insertion-sort`, `merge-sort`, `heap-sort`.
- `grid_pathfinding`: path construction on a grid; strategies = `right-first`, `down-first`, `alternating`.
- `linear_equation_solving`: solving one-variable equations; strategies = `subtract-then-divide`, `divide-then-subtract`, `inverse-ops`.
- `base_conversion`: converting between number bases; strategies = `repeated-division`, `via-binary`, `decomposition`.
- `multidigit_addition`: explicit addition traces; strategies = `right-to-left-carry`, `left-to-right-partials`, `rounding-decomposition`.

All tasks use a canonical sequence format

```text
<bos> <x> ... </x> <y> ... </y> <eos>
```

with deterministic parsing and strategy attribution.

### Experimental Design

The synthetic-sequences experiments instantiate the general methodology with a simple two-stage design:

1. train an entangled base model on the synthetic task distribution, so that $p_\theta(y \mid x)$ captures the multi-strategy output distribution without an explicit latent factorization;
2. apply the disentangling methodology described above to that trained model to learn a factorized representation $(p_\phi(z \mid x), p_\psi(y \mid x, z))$.

This design makes the second stage a post-training factorization problem: the disentangled model starts from a base model that already solves the task well, and the main question is whether latent structure can be made explicit and controllable.

> [!info] Why synthetic sequences
> The setting gives exact task validators and known strategy taxonomies, so
> failures can be localized: does the posterior carry strategy information, does
> the router preserve it, and do controlled latent interventions induce the
> intended behaviors?
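
One detail of stage 2 is worth making explicit: where the training pairs come from. Following step 3 of the Training Algorithm, the responses $y_i \sim p(y \mid x_i)$ are drawn fresh from the stage-1 base model at every step. Below is a framework-agnostic skeleton of that loop. `sample_input`, `sample_base`, and `train_step` are hypothetical hooks. `train_step` is assumed to compute the objective on a batch, take one gradient step on $(\phi, \psi, \xi)$, and return the loss value. A fixed step budget stands in for "repeat until convergence".

```python
import random
from typing import Callable, List, Tuple

Batch = List[Tuple[object, object]]


def disentangle(
    sample_input: Callable[[random.Random], object],
    sample_base: Callable[[object, random.Random], object],
    train_step: Callable[[Batch], float],
    batch_size: int,
    num_steps: int,
    seed: int = 0,
) -> List[float]:
    """Stage-2 loop (steps 3-8): fresh base-model samples, then one update per step."""
    rng = random.Random(seed)
    losses = []
    for _ in range(num_steps):  # fixed budget in place of a convergence test
        xs = [sample_input(rng) for _ in range(batch_size)]
        batch = [(x, sample_base(x, rng)) for x in xs]  # y_i ~ p(y | x_i) from the base model
        losses.append(train_step(batch))  # steps 4-7 happen inside the hook
    return losses


# Toy usage with stand-in hooks (no real model involved).
seen = []


def toy_train_step(batch):
    seen.append(len(batch))
    return 1.0 / len(seen)


losses = disentangle(
    sample_input=lambda rng: [rng.randint(0, 9) for _ in range(3)],  # toy list_summation input
    sample_base=lambda x, rng: sum(x),  # stands in for a sampled base-model trace
    train_step=toy_train_step,
    batch_size=4,
    num_steps=3,
)
print(losses)
```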
### Evaluation

The main evaluation modes are:

- router-sampled generation: sample or draw $z \sim p_\phi(\cdot \mid x)$, then decode $y \sim p_\psi(\cdot \mid x, z)$;
- controlled latents: intervene on $z$ and decode one sample per latent value;
- posterior diagnostics: probe $q_\xi(z \mid x, y)$ directly, which is especially important in the continuous case.

The most informative summary metrics are:

- answer quality: parse success and final-answer accuracy;
- discrete structure: `alignment_one_to_one`, `local_strategy_coverage_mean`, `local_strategy_full_coverage_rate`;
- continuous structure: sampled-$z$ probe accuracy and centroid distance, together with posterior-mean ($\mu_\xi$) probe accuracy, which measures the gap between posterior means and sampled latents;
- optimization diagnostics: raw reconstruction loss, baseline-normalized reconstruction loss, KL, and loss-contribution fractions.
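
The discrete-only auxiliary controls listed under Training Objective are described in words only. The three with an obvious closed form can be sketched as follows. The exact reductions in the trainer (per-token vs. per-sequence averaging, batch construction, weights) are not specified here, so the choices below, like the function names, are assumptions. In particular, `token_jsd` assumes the $K$ inputs are next-token distributions at the same position of the same trace, one per latent value. Because the JSD is maximized, it would enter the loss with a negative sign.

```python
import math


def entropy(p):
    """Shannon entropy in nats, with 0 * log 0 treated as 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def posterior_entropy_penalty(q):
    """H(q_xi(z | x, y)); zero exactly when the posterior is one-hot."""
    return entropy(q)


def router_marginal_kl(router_batch):
    """KL(E_x[p_phi(z | x)] || Unif({z_1, ..., z_K})), with E_x taken over the batch."""
    k = len(router_batch[0])
    mean = [sum(p[j] for p in router_batch) / len(router_batch) for j in range(k)]
    return sum(m * math.log(m * k) for m in mean if m > 0)


def token_jsd(dists):
    """Equal-weight generalized JSD among K latent-conditioned next-token
    distributions at one position: H(mixture) - mean_k H(dist_k)."""
    k = len(dists)
    mixture = [sum(d[v] for d in dists) / k for v in range(len(dists[0]))]
    return entropy(mixture) - sum(entropy(d) for d in dists) / k


# Sharp but diverse routing: the aggregate penalty is zero.
print(router_marginal_kl([[1.0, 0.0], [0.0, 1.0]]))
# Every input routed to z1: the penalty is log 2.
print(router_marginal_kl([[1.0, 0.0], [1.0, 0.0]]))
```

The two router examples match the stated intuition for the marginal KL. Each individual router can be one-hot without penalty, provided that usage across the batch stays balanced.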