High-level Outline, built around Key Figures

Phase 1: Obtaining an "entangled mixture model"

  • Train a model on a synthetic distribution corresponding to a "mixture of strategies".
  • Ground truth: $p(y|x) = \int_z p(y|x,z) \, p(z|x) \, d\mu(z)$, where $z$ is a "strategy" (e.g., "insertion sort" for list sorting)
  • Idea: this simulates, in a simple synthetic setting, truly multi-modal, multi-strategy language models. Because we train the models ourselves from scratch, we have ground-truth strategies to evaluate against. This lets us measure whether our "disentangled factorization"/"latent strategy learning" method effectively uncovers these latent strategies.
  • We consider a number of synthetic tasks: multi-digit addition, sorting, solving linear equations, grid path finding, base conversion, …
  • This can provide a test bed for evaluating methods for "disentanglement". Future efforts can use this as a benchmark and build upon the methods and approaches we explore.
  • Key Figure 1: Decodability of strategy from embeddings vs. position. At what point in the generation of the solution $y$ does the model "decide" on, or "know", the strategy?
    • Can also consider the dynamics of this over time, represented as a gif/movie or as a heatmap.
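
The ground-truth mixture above can be sketched as a two-stage sampler: draw a strategy $z$ from $p(z|x)$, then generate the solution $y$ with that strategy. This is only a minimal illustration for the list-sorting task. The choice of selection sort as the second strategy, the uniform `strategy_prior`, and all function names are assumptions made for the example, not details taken from the outline.

```python
import random


def insertion_sort_trace(xs):
    """Return the list state after each insertion step."""
    xs = list(xs)
    trace = []
    for i in range(1, len(xs)):
        key = xs[i]
        j = i - 1
        # Shift larger elements right until the slot for `key` is found.
        while j >= 0 and xs[j] > key:
            xs[j + 1] = xs[j]
            j -= 1
        xs[j + 1] = key
        trace.append(list(xs))
    return trace


def selection_sort_trace(xs):
    """Return the list state after each selection-and-swap step."""
    xs = list(xs)
    trace = []
    for i in range(len(xs) - 1):
        m = min(range(i, len(xs)), key=xs.__getitem__)
        xs[i], xs[m] = xs[m], xs[i]
        trace.append(list(xs))
    return trace


# Hypothetical strategy space Z for the sorting task.
STRATEGIES = {
    "insertion_sort": insertion_sort_trace,
    "selection_sort": selection_sort_trace,
}


def strategy_prior(x):
    """Hypothetical p(z|x); uniform here, but it may depend on x."""
    return {name: 1.0 / len(STRATEGIES) for name in STRATEGIES}


def sample_example(rng, length=5):
    """Sample (x, z, y) with z ~ p(z|x) and y generated by strategy z."""
    x = [rng.randrange(100) for _ in range(length)]
    prior = strategy_prior(x)
    names = list(prior)
    z = rng.choices(names, weights=[prior[n] for n in names], k=1)[0]
    y = STRATEGIES[z](x)
    return x, z, y


rng = random.Random(0)
x, z, y = sample_example(rng)
print(x, z, y[-1])
```

Keeping the label `z` next to each `(x, y)` pair is what later makes it possible to score a learned router or probe against the true strategy. The model itself would be trained on `(x, y)` only.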

Phase 2: Learning a disentangled factorization of the base model $p_{base}(y|x)$.

  • Apply our methodology to the entangled mixture model and attempt to learn a disentangled factorization $(p(z|x), p(y|x,z))$:
    • $p(z|x)$ is called the strategy router (routes inputs/tasks $x$ to strategies $z$ by defining a distribution over latent strategy space $\mathcal{Z}$)
    • $p(y|x,z)$ is called the strategy-conditioned generator (generates solutions $y$ given the input $x$, guided by the strategy $z$)
  • Key Figure 2: dynamics of latent variable $z$: shows that $z$ controls strategy by showing the mapping/correspondence between $z$ and $strat(y)$ under $z \sim p(z|x), y \sim p(y|x,z)$.
    • Animated scatter (or frames on a horizontal facet): scatter of $z$ (if continuous), color-coded by $strat(y)$, under $z \sim p(z|x), y \sim p(y|x,z)$.
    • Linear decodability of $strat(y)$ from $z$ under $z \sim p(z|x), y \sim p(y|x,z)$.
  • Key Figure 3: Posterior strategy-centroid distance. Posterior strategy-centroid variance.
  • Key Figure 4: Dynamics of posterior
    • Animated scatter (gif or frames)
    • Linear decodability over the course of training
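
The outline does not define the posterior strategy-centroid metrics, so the sketch below shows one plausible reading, stated as an assumption. It groups posterior latents $z$ by ground-truth strategy, takes each group's mean as its centroid, and reports the mean pairwise Euclidean distance between centroids and the within-strategy variance (mean squared distance to the centroid). The function name `centroid_stats` and the toy latents are hypothetical.

```python
import math
from collections import defaultdict
from itertools import combinations


def centroid_stats(latents, strategies):
    """Compute per-strategy centroids, within-strategy variance,
    and the mean pairwise distance between centroids.

    latents:    sequence of equal-length numeric vectors (posterior z samples)
    strategies: sequence of ground-truth strategy labels, one per latent
    """
    groups = defaultdict(list)
    for z, s in zip(latents, strategies):
        groups[s].append(z)

    # Centroid = coordinate-wise mean of the latents assigned to a strategy.
    centroids = {
        s: [sum(coord) / len(zs) for coord in zip(*zs)]
        for s, zs in groups.items()
    }

    # Variance = mean squared Euclidean distance to the strategy's centroid.
    variance = {
        s: sum(math.dist(z, centroids[s]) ** 2 for z in zs) / len(zs)
        for s, zs in groups.items()
    }

    # Separation = mean Euclidean distance over all pairs of centroids.
    pairs = list(combinations(sorted(centroids), 2))
    mean_distance = (
        sum(math.dist(centroids[a], centroids[b]) for a, b in pairs) / len(pairs)
        if pairs
        else 0.0
    )
    return centroids, variance, mean_distance


# Toy 2-D latents for two strategies (illustrative values only).
latents = [(0.0, 0.0), (0.0, 2.0), (4.0, 0.0), (4.0, 2.0)]
labels = ["insertion_sort", "insertion_sort", "selection_sort", "selection_sort"]
centroids, variance, mean_distance = centroid_stats(latents, labels)
print(centroids, variance, mean_distance)
```

Under this reading, a well-disentangled latent space would show large centroid distance relative to within-strategy variance. Evaluating both at successive training checkpoints would give the dynamics view.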