High-level Outline, built around Key Figures
Phase 1: Obtaining an "entangled mixture model"
- Train a model on a synthetic distribution corresponding to a "mixture of strategies".
- Ground truth: $p(y \mid x) = \sum_{s} p(s \mid x)\, p(y \mid x, s)$, where $x$ is the input (task instance), $y$ is the solution, and $s$ is a "strategy" (e.g., "insertion sort" for list sorting).
- Idea: this simulates, in a simple synthetic setting, real multi-strategy language models (whose output distributions are multi-modal). Because we train the models ourselves from scratch, we have ground-truth strategies to evaluate against. This lets us measure whether our "disentangled factorization" / "latent strategy learning" method actually uncovers these latent strategies.
- We consider a number of synthetic tasks: multi-digit addition, sorting, solving linear equations, grid path finding, base conversion, …
- This suite can serve as a testbed for evaluating "disentanglement" methods. Future work can use it as a benchmark and build on the methods and approaches we explore.
- Key Figure 1: Decodability of the ground-truth strategy from the model's embeddings as a function of token position. At what point while generating the solution does the model "decide" on, or "know", its strategy?
- Can also show how this evolves over the course of training, as a GIF/movie or as a heatmap.
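As a concrete, hypothetical instance of the Phase 1 data distribution, the sketch below samples list-sorting examples from a two-strategy mixture. The choice of insertion sort and selection sort, the trace format (the list state after each outer-loop pass), and the names `insertion_sort_trace`, `selection_sort_trace`, `sample_example`, and `to_text` are illustrative assumptions, not the actual task suite. Here the strategy is sampled independently of the input list, which is the simplest special case of a mixture.

```python
import random

def insertion_sort_trace(xs):
    """List state after each outer-loop pass of insertion sort."""
    a, trace = list(xs), []
    for i in range(1, len(a)):
        key, j = a[i], i - 1
        while j >= 0 and a[j] > key:  # shift larger elements right
            a[j + 1] = a[j]
            j -= 1
        a[j + 1] = key
        trace.append(tuple(a))
    return trace

def selection_sort_trace(xs):
    """List state after each outer-loop pass of selection sort."""
    a, trace = list(xs), []
    for i in range(len(a) - 1):
        m = min(range(i, len(a)), key=a.__getitem__)  # index of smallest remaining item
        a[i], a[m] = a[m], a[i]
        trace.append(tuple(a))
    return trace

# Hypothetical strategy set for the sorting task.
STRATEGIES = {"insertion": insertion_sort_trace, "selection": selection_sort_trace}

def sample_example(rng, n=5, weights=(0.5, 0.5)):
    """Draw an input list, then a strategy s (independent of the input), then its trace y."""
    xs = [rng.randint(0, 9) for _ in range(n)]
    s = rng.choices(list(STRATEGIES), weights=weights)[0]
    return {"x": xs, "s": s, "y": STRATEGIES[s](xs)}

def to_text(ex):
    """Serialize as 'input | state ; state ; ...' (hypothetical format; s is not included)."""
    def fmt(t):
        return " ".join(map(str, t))
    return fmt(ex["x"]) + " | " + " ; ".join(fmt(st) for st in ex["y"])
```

Both strategies end in the same sorted list, but their intermediate states differ (for `[2, 3, 1]`, insertion sort passes through `(2, 3, 1)` while selection sort passes through `(1, 3, 2)`), so the strategy is recoverable from the trace even though it never appears in the serialized text.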
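Key Figure 1 could be computed by fitting a separate linear probe at each token position and reporting held-out accuracy. A minimal sketch, assuming hidden states have already been extracted into an array of shape (N examples, T positions, D dims) with integer strategy labels. The probe is a closed-form ridge least-squares classifier on one-hot targets, a simple stand-in for logistic regression, and `probe_accuracy_by_position` is a hypothetical name.

```python
import numpy as np

def probe_accuracy_by_position(h, s, train_frac=0.8, ridge=1e-2, seed=0):
    """Held-out accuracy of a linear probe predicting strategy s at each position.

    h: float array (N, T, D) of hidden states; s: int array (N,) of strategy labels.
    """
    n, t, d = h.shape
    k = int(s.max()) + 1
    idx = np.random.default_rng(seed).permutation(n)
    n_tr = int(train_frac * n)
    tr, te = idx[:n_tr], idx[n_tr:]
    y = np.eye(k)[s]  # one-hot targets, shape (N, K)
    accs = np.empty(t)
    for pos in range(t):
        x = np.concatenate([h[:, pos, :], np.ones((n, 1))], axis=1)  # append bias column
        xtr = x[tr]
        # Ridge least squares: w = (X^T X + ridge * I)^{-1} X^T Y
        w = np.linalg.solve(xtr.T @ xtr + ridge * np.eye(d + 1), xtr.T @ y[tr])
        pred = (x[te] @ w).argmax(axis=1)
        accs[pos] = (pred == s[te]).mean()
    return accs
```

Plotting the returned accuracies against position gives one curve; repeating this over training checkpoints and stacking the curves gives the heatmap or animation frames.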
Phase 2: Learning a disentangled factorization of the base model $p_{\text{base}}(y \mid x)$.
- Apply our methodology to the entangled mixture model and attempt to learn a disentangled factorization $p_{\text{base}}(y \mid x) \approx \int_{\mathcal{Z}} p_\phi(z \mid x)\, p_\psi(y \mid x, z)\, dz$ (a sum if $\mathcal{Z}$ is discrete):
- $p_\phi(z \mid x)$ is called the strategy router (routes inputs/tasks to strategies by defining a distribution over the latent strategy space $\mathcal{Z}$).
- $p_\psi(y \mid x, z)$ is called the strategy-conditioned generator (generates solutions $y$ given the input $x$, guided by the strategy $z$).
- Key Figure 2: dynamics of the latent variable $z$. Shows that $z$ controls the strategy by showing the mapping/correspondence between $z$ and the ground-truth strategy $s$ under the learned model.
- Animated scatter (or frames in a horizontal facet): scatter of $z$ (if continuous), color-coded by $s$.
- Linear decodability of $s$ from $z$.
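- Key Figure 3: Posterior strategy-centroid distance and posterior strategy-centroid variance, i.e., how far apart the per-strategy centroids of the posterior over the latent strategy are, and how tightly the posterior concentrates around each strategy's centroid.
- Key Figure 4: Dynamics of the posterior over the latent strategy during training.
- Animated scatter (GIF or frames).
- Linear decodability of the ground-truth strategy over the course of training.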
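For a discrete latent space, writing the router as $p_\phi(z \mid x)$ and the generator as $p_\psi(y \mid x, z)$, the marginal is $\sum_z p_\phi(z \mid x)\, p_\psi(y \mid x, z)$ and the posterior over strategies follows from Bayes' rule. The toy tables below (for one fixed input `[2, 3, 1]`, with solutions serialized as traces) are hypothetical, and the exact posterior computed here may differ from whatever approximate posterior the method actually learns.

```python
# Toy, hypothetical tables for one fixed input x = [2, 3, 1].
# Router p(z | x): two latent strategies, equally likely.
router = {"z0": 0.5, "z1": 0.5}
# Generator p(y | x, z): distributions over serialized solution traces.
generator = {
    "z0": {"2 3 1 ; 1 2 3": 1.0},                        # insertion-sort trace only
    "z1": {"1 3 2 ; 1 2 3": 0.9, "2 3 1 ; 1 2 3": 0.1},  # mostly selection-sort trace
}

def marginal(router, generator, y):
    """p(y | x) = sum_z p(z | x) * p(y | x, z) over a discrete latent space."""
    return sum(pz * generator[z].get(y, 0.0) for z, pz in router.items())

def posterior(router, generator, y):
    """p(z | x, y) = p(z | x) * p(y | x, z) / p(y | x), i.e. Bayes' rule."""
    py = marginal(router, generator, y)
    if py == 0.0:
        raise ValueError("y has zero probability under the model")
    return {z: pz * generator[z].get(y, 0.0) / py for z, pz in router.items()}
```

For the insertion-sort trace `"2 3 1 ; 1 2 3"`, the marginal is 0.55 and the posterior puts about 0.91 on `z0` and 0.09 on `z1`. The 10% leakage of `z1` onto the insertion-sort trace is exactly the kind of entanglement that keeps the posterior from giving a clean one-to-one mapping between latents and strategies.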
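The two Key Figure 3 quantities could be computed from posterior means (or samples) of a continuous latent, grouped by ground-truth strategy. A minimal sketch, assuming the latents are collected into an (N, D) array. "Distance" is taken here as the mean pairwise Euclidean distance between strategy centroids and "variance" as the mean squared distance of each point to its own strategy's centroid, averaged over strategies; these definitions and the name `centroid_stats` are assumptions.

```python
import numpy as np

def centroid_stats(z, s):
    """Strategy-centroid separation and spread of posterior latents.

    z: (N, D) array of posterior means (or samples) of the latent strategy variable.
    s: (N,) integer ground-truth strategy labels (at least two distinct values).
    Returns (mean pairwise distance between strategy centroids,
             mean within-strategy variance around each centroid).
    """
    labels = np.unique(s)
    if len(labels) < 2:
        raise ValueError("need at least two strategies")
    cents = np.stack([z[s == c].mean(axis=0) for c in labels])
    # Within-strategy variance: mean squared Euclidean distance to own centroid.
    within = np.mean([((z[s == c] - cents[i]) ** 2).sum(axis=1).mean()
                      for i, c in enumerate(labels)])
    # Pairwise centroid distances; keep each unordered pair once.
    diffs = cents[:, None, :] - cents[None, :, :]
    dist = np.sqrt((diffs ** 2).sum(axis=-1))
    between = dist[np.triu_indices(len(labels), k=1)].mean()
    return float(between), float(within)
```

Evaluating this at each training checkpoint yields the time series for Key Figure 4: a well-disentangled posterior should show the centroid distance growing while the within-strategy variance shrinks.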