Abstract
The uniform application of computational resources across all tokens in a transformer sequence is a fundamental inefficiency: not every token requires the same depth of processing. Mixture-of-Depths (MoD) is a class of methods that extend the sparse computation paradigm from the feedforward width dimension (as in Mixture-of-Experts) to the depth dimension, enabling individual tokens to bypass selected layers entirely. This post provides a rigorous technical analysis of MoD architectures, examining the token routing mechanisms, the auxiliary losses required to prevent routing collapse, and the theoretical connections to conditional computation and adaptive inference. We analyze the empirical evidence from recent work showing that MoD models can match dense transformer performance with significantly reduced FLOPs per forward pass, and situate this work within the broader literature on efficient transformers, early exit networks, and learned computation allocation. We conclude by identifying open problems including training instability, the interaction between depth routing and attention-based context integration, and the path toward unified sparse computation across both width and depth.
1. Introduction
The transformer architecture (Vaswani et al., 2017) applies a fixed sequence of operations to every token in its input: self-attention followed by a feedforward network, repeated $L$ times for a model with $L$ layers. This uniformity is architecturally elegant and hardware-friendly — it maps directly onto the kind of regular, batched computation that modern GPUs and TPUs are designed to exploit. However, it encodes a strong and largely unjustified inductive bias: the assumption that every token, at every position, requires exactly the same number of transformations to be represented usefully.
Empirically, this assumption is questionable. Work on mechanistic interpretability (Elhage et al., 2021) has shown that different layers in transformers perform qualitatively different operations, with early layers handling syntactic and local patterns while later layers encode more abstract semantic content. Logit lens analyses (nostalgebraist, 2020) demonstrate that the model’s prediction often stabilizes after only a fraction of the total layers, with the remaining layers providing marginal refinement. Probing studies have confirmed that for many tokens — particularly predictable function words, punctuation, and short common phrases — the representations reach their final semantic form well before the last layer. The per-token application of all $L$ layers is thus computationally wasteful for a significant fraction of the forward pass.
This observation motivates a class of methods collectively termed adaptive computation or conditional computation: the idea that the amount of computation applied to a given input should depend on the input itself. Early work in this direction (Graves, 2016; Bengio et al., 2015) explored scalar halting variables and stochastic depth, but the field has seen renewed interest with the success of Mixture-of-Experts (MoE) architectures (Shazeer et al., 2017; Fedus et al., 2022), which implement input-conditional sparsity in the width dimension by routing tokens to selected expert feedforward networks.
Mixture-of-Depths (MoD) extends this logic to the depth dimension. Rather than routing tokens to different experts within a single layer, MoD allows tokens to skip entire layers, processing only a subset of the full layer stack. The practical implication is significant: a model with 24 layers and 50% token participation at each layer performs, on average, the computational work of a 12-layer model, while retaining the expressive representational capacity of a 24-layer model for tokens that require full depth processing.
This post proceeds as follows. Section 2 reviews the relevant literature on adaptive computation, early exit networks, and sparse computation. Section 3 provides a detailed technical analysis of MoD mechanisms, routing formulations, and training objectives. Section 4 discusses empirical results, architectural variants, and comparisons to dense baselines. Section 5 examines open problems and failure modes. Section 6 concludes.
2. Related Work
Adaptive Computation Time (Graves, 2016) introduced the first explicit adaptive computation framework for RNNs, allowing the network to apply a variable number of processing steps to each input via a learned scalar halting probability. The model learns to allocate more computation to difficult inputs by minimizing a ponder cost that penalizes excessive computation. While influential, ACT suffered from training instability and was difficult to extend to the transformer architecture.
Stochastic Depth (Huang et al., 2016) proposed randomly dropping entire residual blocks during training as a regularization strategy, with all layers active at inference. This is superficially similar to MoD but differs fundamentally: stochastic depth is input-agnostic (the dropping decision does not depend on the token), and it is disabled at test time. MoD, by contrast, makes input-dependent routing decisions that remain active during inference.
Mixture-of-Experts (Shazeer et al., 2017; Fedus et al., 2022) established the dominant paradigm for sparse computation in transformers by replacing dense feedforward layers with a collection of expert networks and a learned router that directs each token to a small subset of experts. The Switch Transformer (Fedus et al., 2022) scaled this to 1.6 trillion parameters with a top-1 routing scheme, demonstrating that sparse models can match dense performance with substantially reduced per-example FLOPs. MoD is conceptually dual to MoE: where MoE introduces sparsity in the expert (width) dimension at a fixed depth, MoD introduces sparsity in the depth dimension at a fixed width.
Early Exit Networks (Teerapittayanon et al., 2016; Schwartz et al., 2020) attach auxiliary classifiers to intermediate layers and allow predictions to be returned from earlier exits when a confidence threshold is met. While effective for classification tasks, early exit mechanisms do not generalize naturally to autoregressive language modeling: an exit is terminal, so a token that exits at layer $l$ produces no hidden states at any deeper layer and cannot be attended to there by subsequent tokens, and a confidence threshold tuned for a classification head has no clean analogue for per-position next-token prediction. MoD addresses this by routing tokens through a learned, per-layer sparse selection rather than a terminal threshold-based exit.
Mixture-of-Depths (Raposo et al., 2024) is the most direct instantiation of depth-wise conditional computation in modern transformers. Raposo et al. demonstrate that a router operating at each layer can learn to selectively process tokens, and that the resulting models achieve comparable perplexity to dense baselines at 50% of the FLOP budget per token. Their work establishes that depth routing is learnable and that the router does not collapse to trivial solutions when properly regularized.
LayerSkip (Elhoushi et al., 2024) explores a related direction from the inference efficiency perspective, training models with a self-speculative decoding strategy that leverages early exit for draft generation. While not strictly MoD in the routing sense, it reinforces the empirical finding that transformer depth is highly redundant for a significant fraction of tokens.
3. Technical Analysis
3.1 Formal Setup
Consider a standard transformer with $L$ layers. Let $\mathbf{x}_l^{(t)} \in \mathbb{R}^d$ denote the hidden state of token $t$ at layer $l$. In the dense transformer, the update rule is:
$$\mathbf{h}_l^{(t)} = \mathbf{x}_l^{(t)} + \text{Attn}_l\left(\mathbf{X}_l\right)^{(t)}, \qquad \mathbf{x}_{l+1}^{(t)} = \mathbf{h}_l^{(t)} + \text{FFN}_l\left(\mathbf{h}_l^{(t)}\right)$$
where $\mathbf{X}_l = [\mathbf{x}_l^{(1)}, \ldots, \mathbf{x}_l^{(T)}]$ is the full sequence of hidden states at layer $l$. In Mixture-of-Depths, each layer $l$ is associated with a router $r_l : \mathbb{R}^d \to [0, 1]$ that computes a routing score for each token. A subset $\mathcal{S}_l \subseteq \{1, \ldots, T\}$ of tokens is selected to participate in layer $l$, typically via top-$k$ selection:
$$\mathcal{S}_l = \text{top-k}_{t}\left\{r_l(\mathbf{x}_l^{(t)})\right\}, \quad |\mathcal{S}_l| = \lfloor C \cdot T \rfloor$$
where $C \in (0, 1]$ is the capacity factor controlling the fraction of tokens that participate in the layer. Tokens not selected for layer $l$ bypass it via the residual connection:
$$\mathbf{x}_{l+1}^{(t)} = \begin{cases} \mathbf{x}_l^{(t)} + \text{Layer}_l(\mathbf{X}_l)^{(t)} & \text{if } t \in \mathcal{S}_l \\ \mathbf{x}_l^{(t)} & \text{if } t \notin \mathcal{S}_l \end{cases}$$
This formulation makes the average FLOPs per token a function of $C$ and $L$. For a model with uniform capacity $C$ at all layers, the expected compute per token is a fraction $C$ of the dense model's, while the total parameter count remains identical.
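To make the selection-and-bypass rule concrete, here is a minimal PyTorch sketch of a routed block. The names (`MoDBlock`, `layer_fn`) are hypothetical, selection is batch-wise top-$k$, and the block is illustrative of the equations above rather than a reproduction of any published implementation.

```python
import torch
import torch.nn as nn

class MoDBlock(nn.Module):
    """Illustrative Mixture-of-Depths wrapper: a scalar router picks the top-k
    tokens to run through `layer_fn`; the rest bypass it on the residual path."""

    def __init__(self, d_model: int, layer_fn: nn.Module, capacity: float = 0.5):
        super().__init__()
        self.router = nn.Linear(d_model, 1)   # r_l(x): one score per token
        self.layer_fn = layer_fn               # attention + FFN sublayer, no internal residual
        self.capacity = capacity               # C in the text

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape                          # (batch, seq_len, d_model)
        k = max(1, int(self.capacity * T))         # |S_l| = floor(C * T)
        scores = self.router(x).squeeze(-1)        # (B, T) routing scores
        idx = torch.topk(scores, k, dim=-1).indices        # S_l per batch element
        gather_idx = idx.unsqueeze(-1).expand(-1, -1, D)
        selected = torch.gather(x, 1, gather_idx)           # tokens in S_l
        updated = selected + self.layer_fn(selected)        # x + Layer(x) for t in S_l
        # Scatter processed tokens back; unselected tokens keep x_{l+1} = x_l.
        return x.scatter(1, gather_idx, updated)
```

Note that `layer_fn` is applied only to the selected tokens, so when it contains self-attention, keys and values come only from $\mathcal{S}_l$; this is exactly the interaction discussed in Section 3.3.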
3.2 Router Architecture
The router in MoD is typically a lightweight learned linear projection $r_l(\mathbf{x}) = \sigma(\mathbf{w}_l^\top \mathbf{x})$ where $\mathbf{w}_l \in \mathbb{R}^d$ is a learned weight vector and $\sigma$ is a sigmoid or softmax activation. The scalar output is used as the routing score, and top-$k$ selection identifies the $\lfloor C \cdot T \rfloor$ tokens with highest scores at each layer.
A critical design decision is whether routing scores participate in the weighted combination of layer outputs. Raposo et al. (2024) propose multiplying the layer output by the routing weight for selected tokens:
$$\mathbf{x}_{l+1}^{(t)} = \mathbf{x}_l^{(t)} + r_l(\mathbf{x}_l^{(t)}) \cdot \text{Layer}_l(\mathbf{X}_l)^{(t)}, \quad t \in \mathcal{S}_l$$
This “routing weight” formulation provides a gradient signal through the router that is tightly coupled to the layer’s output contribution, facilitating learning of meaningful routing decisions. Without this weighting, the router receives gradient only through the discrete top-$k$ selection, which is non-differentiable and requires techniques such as straight-through estimators or auxiliary losses.
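Continuing the hypothetical `MoDBlock` sketch above, the gating corresponds to replacing the plain residual update for selected tokens with a gated one (again a sketch, not the reference implementation):

```python
# Inside MoDBlock.forward, replacing the plain residual update for selected tokens:
gate = torch.sigmoid(torch.gather(scores, 1, idx)).unsqueeze(-1)   # r_l(x^{(t)}) for t in S_l
updated = selected + gate * self.layer_fn(selected)                 # x + r_l(x) * Layer(x)
```

Because `gate` is a differentiable function of the router weights, the language modeling loss now supplies gradient to the router through every selected token, independent of the non-differentiable top-$k$ step.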
3.3 Attention and the Routing-Context Interaction
A subtle but important complication arises when MoD is applied to the self-attention sublayer rather than just the FFN. In the dense transformer, every token attends to every other token (within the context window). If token $t$ is routed away from attention at layer $l$, it neither attends to nor is attended to by other tokens at that layer. This breaks the assumption that every token contributes keys and values to the attention computation at every layer.
Let $\mathbf{q}_l^{(t)} = \mathbf{x}_l^{(t)} W_Q$, $\mathbf{K}_l = \mathbf{X}_l W_K$, and $\mathbf{V}_l = \mathbf{X}_l W_V$ be the query, key, and value projections at layer $l$. If only tokens in $\mathcal{S}_l$ participate in attention, then the effective key-value set is $\{\mathbf{k}_l^{(t)}, \mathbf{v}_l^{(t)}\}_{t \in \mathcal{S}_l}$, and the attention computation for a participating token $t$ becomes:
$$\text{Attn}_l(\mathbf{x}_l^{(t)}) = \text{softmax}\left(\frac{\mathbf{q}_l^{(t)} (\mathbf{K}_l^{\mathcal{S}_l})^\top}{\sqrt{d_k}}\right) \mathbf{V}_l^{\mathcal{S}_l}$$
This sparse attention pattern has implications for information flow: tokens that are consistently routed away from attention-heavy layers may fail to integrate relevant contextual information, potentially degrading performance on long-range dependency tasks. Empirical analysis by Raposo et al. (2024) suggests that routing primarily affects the FFN sublayer in practice, with the model learning to keep attention participation rates higher than FFN participation rates, consistent with the information-aggregation role of attention.
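The sketch below shows this restricted key-value set in code: single-head, batch-first, with causal masking omitted for brevity. The function name `routed_self_attention` is illustrative, and queries are computed for all positions here purely for simplicity; in MoD only the routed tokens would issue queries at that layer.

```python
import math
import torch

def routed_self_attention(q, k, v, routed_idx):
    """Single-head attention in which only tokens in S_l contribute keys/values.
    q, k, v: (batch, seq_len, d_k); routed_idx: (batch, |S_l|) token indices."""
    d_k = q.size(-1)
    gather_idx = routed_idx.unsqueeze(-1).expand(-1, -1, d_k)
    k_s = torch.gather(k, 1, gather_idx)                   # K_l^{S_l}
    v_s = torch.gather(v, 1, gather_idx)                   # V_l^{S_l}
    scores = q @ k_s.transpose(-2, -1) / math.sqrt(d_k)    # (batch, seq_len, |S_l|)
    return torch.softmax(scores, dim=-1) @ v_s             # each query attends only over S_l
```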
3.4 Load Balancing and Collapse Prevention
A persistent challenge in learned routing is routing collapse: the degenerate solution where the router directs all tokens to the same subset of layers (or, in the extreme case, routes all tokens through all layers with no effective sparsity). To prevent this, training typically incorporates an auxiliary load balancing loss analogous to that used in MoE training.
Let $f_l^{(t)} = \mathbf{1}[t \in \mathcal{S}_l]$ be the binary routing decision, and define the fraction of tokens routed through layer $l$ as $\hat{p}_l = \frac{1}{T} \sum_t f_l^{(t)}$. The load balancing loss encourages $\hat{p}_l \approx C$ for all layers:
$$\mathcal{L}_{\text{bal}} = \alpha \sum_l \left(\hat{p}_l - C\right)^2$$
where $\alpha$ is a small coefficient (typically $10^{-3}$ to $10^{-2}$) that prevents the auxiliary loss from dominating the language modeling objective $\mathcal{L}_{\text{LM}}$. The total training loss is $\mathcal{L} = \mathcal{L}_{\text{LM}} + \mathcal{L}_{\text{bal}}$. Note that under strict top-$k$ selection the hard fraction $\hat{p}_l = \lfloor C \cdot T \rfloor / T$ holds by construction, so in practice the penalty is applied to a differentiable surrogate such as the mean router score, or used with threshold-based routing variants in which the participating fraction can actually drift.
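A minimal sketch of one such differentiable surrogate, assuming each router emits pre-sigmoid scores; the function name and the choice of a mean-score surrogate are assumptions for illustration, not the published formulation.

```python
import torch

def load_balance_loss(router_scores, capacity, alpha=1e-2):
    """Surrogate balance penalty: push the mean sigmoid routing score at each
    routed layer toward the capacity C. `router_scores` is a list of (B, T)
    pre-activation score tensors, one per routed layer. Under strict top-k the
    hard fraction is fixed, so the soft mean score stands in for p_hat_l."""
    loss = torch.zeros(())
    for s in router_scores:
        p_hat = torch.sigmoid(s).mean()          # differentiable stand-in for p_hat_l
        loss = loss + (p_hat - capacity) ** 2    # (p_hat_l - C)^2
    return alpha * loss
```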
An alternative approach is to use a fixed, non-learned routing schedule (e.g., alternating MoD and dense layers), sacrificing input-adaptivity for training stability. This “semi-adaptive” variant has been explored as a lower-variance baseline that still achieves significant FLOP reduction.
3.5 FLOP Analysis
For a transformer with hidden dimension $d$, FFN expansion ratio $r$, sequence length $T$, and $L$ layers, the dense matrix-multiplication FLOP count for a forward pass over $T$ tokens (counting one multiply-accumulate as one FLOP, and ignoring the attention score computation, which scales quadratically with sequence length) is approximately:
$$\text{FLOPs}_{\text{dense}} \approx L \cdot (4d^2 + 2rd^2) \cdot T = L \cdot 2d^2(2 + r) \cdot T$$
For a MoD model with capacity $C$ and all $L$ layers subject to routing:
$$\text{FLOPs}_{\text{MoD}} \approx C \cdot L \cdot 2d^2(2 + r) \cdot T = C \cdot \text{FLOPs}_{\text{dense}}$$
In practice, not all layers are subject to routing — practitioners often leave the first and last few layers dense to preserve input encoding and output projection quality — and the routing computation itself adds a small overhead of $O(Td)$ per layer. For large $d$, this overhead is negligible.
The key insight is that MoD achieves FLOP reduction without reducing the parameter count or model capacity. A 7B parameter MoD model with $C = 0.5$ performs approximately the same FLOPs as a 3.5B parameter dense model during inference, but has access to the full 7B parameter representational capacity for the tokens it selects to process deeply.
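A small worked example of this accounting, with one multiply-accumulate counted as one FLOP; the helper name and the specific choice of four dense layers are illustrative assumptions.

```python
def transformer_matmul_flops(d, r, L, T, capacity=1.0, dense_layers=0):
    """Approximate matmul FLOPs for one forward pass over T tokens, following
    FLOPs ~= L * 2*d^2*(2 + r) * T, with `dense_layers` exempt from routing."""
    per_layer = 2 * d * d * (2 + r) * T
    return dense_layers * per_layer + capacity * (L - dense_layers) * per_layer

dense = transformer_matmul_flops(d=4096, r=4, L=32, T=2048)
mod = transformer_matmul_flops(d=4096, r=4, L=32, T=2048, capacity=0.5, dense_layers=4)
print(mod / dense)   # ~0.56: close to C, pulled up slightly by the 4 dense layers
```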
4. Discussion
4.1 What Does the Router Learn?
A natural question is whether the routing decisions made by MoD models are interpretable and consistent with linguistic intuitions about token difficulty. Raposo et al. (2024) provide some analysis showing that high-frequency, semantically predictable tokens (common function words, punctuation, whitespace) tend to be routed away from deep processing, while content words and tokens in syntactically complex positions receive more layers. This is qualitatively consistent with the logit lens findings mentioned in the introduction: easy tokens stabilize early.
However, the routing behavior is not simply a function of token identity — it depends on context. The same token may receive different depth allocations depending on its position in the sentence, its surrounding tokens, and the task context. This context-sensitivity is what distinguishes MoD from simpler token-type-based pruning approaches and is a key source of its effectiveness.
4.2 Interaction with Speculative Decoding
MoD interacts synergistically with speculative decoding (Leviathan et al., 2023; Chen et al., 2023). In standard speculative decoding, a small draft model generates candidate tokens which are then verified in parallel by the larger target model. With MoD, the target model itself behaves like a variable-cost verifier: tokens that the router identifies as easy can be processed cheaply, effectively achieving a form of adaptive-cost verification that reduces latency beyond what standard speculative decoding achieves.
The LayerSkip approach (Elhoushi et al., 2024) formalizes this connection by training a single model that serves as both its own draft (via early exit) and its own verifier (via full depth), exploiting the same adaptive computation logic as MoD but in a sequential rather than parallel routing framework.
4.3 Training Stability and Curriculum
MoD training exhibits sensitivity to the point at which routing is introduced. Training from scratch with full routing from step one often leads to unstable dynamics in which the router oscillates before finding a good allocation. A common mitigation is a training curriculum that begins with dense (no routing) pretraining for some fraction of total compute, then introduces routing gradually by annealing $C$ from 1.0 down to the target capacity. This warm-start approach stabilizes training but complicates the compute budget analysis, as the warm-start phase incurs dense-model costs.
An alternative is to initialize the router weights to produce near-uniform scores, ensuring that early in training all tokens receive approximately equal depth allocation and the routing gradually differentiates as the model develops preferences. Careful initialization of $\mathbf{w}_l$ to have small norm achieves this approximately.
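Two small sketches of these mitigations, assuming a linear router module and a step-indexed training loop; the names and the linear schedule shape are illustrative choices, not prescriptions from the literature.

```python
import torch.nn as nn

def init_router_near_uniform(router: nn.Linear, scale: float = 1e-3) -> None:
    """Small-norm init: all tokens start with near-identical routing scores, so
    early top-k selection is effectively random and differentiates over training."""
    nn.init.normal_(router.weight, mean=0.0, std=scale)
    if router.bias is not None:
        nn.init.zeros_(router.bias)

def capacity_schedule(step: int, warmup_steps: int, anneal_steps: int,
                      target_c: float = 0.5) -> float:
    """Dense warm-up followed by a linear anneal of C from 1.0 down to target_c."""
    if step < warmup_steps:
        return 1.0
    frac = min(1.0, (step - warmup_steps) / max(1, anneal_steps))
    return 1.0 + frac * (target_c - 1.0)
```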
4.4 MoD + MoE: Unified Sparse Computation
The most ambitious direction in this space is the combination of MoD (depth sparsity) and MoE (width sparsity) into a single architecture. The intuition is that these two forms of sparsity are largely orthogonal: MoE selects which expert processes a token at a given layer, while MoD selects which layers process a given token. A model that applies both can potentially achieve much greater FLOP reduction than either alone.
Raposo et al. (2024) experiment with combined MoD+MoE models and find that the two sparsity mechanisms compose reasonably well, with combined models achieving FLOP reductions multiplicative in the two capacity factors. However, the interaction between the two routing systems introduces additional complexity in load balancing: a token that MoE routes to an overloaded expert within a layer while MoD simultaneously routes it away from that layer receives inconsistent gradient signals that require careful handling.
5. Conclusion
Mixture-of-Depths represents a principled and empirically validated approach to reducing the computational cost of transformer inference without sacrificing model capacity. By learning to route individual tokens through variable subsets of the layer stack, MoD models exploit the empirical observation that transformers over-invest computation in predictable tokens, allocating uniform depth where adaptive depth would suffice.
The key technical contributions in this space — the routing-weight formulation that enables gradient flow through discrete top-$k$ selection, the load balancing auxiliary objective that prevents routing collapse, and the FLOP analysis that quantifies the theoretical efficiency gains — provide a solid foundation for building production-grade adaptive-depth transformers.
Open problems remain substantial. The interaction between depth routing and attention-mediated context integration is not fully understood, particularly for long documents where tokens skipping attention layers may miss critical long-range dependencies. The training curriculum sensitivity suggests that MoD training dynamics are not yet fully characterized theoretically. And the combination of MoD with MoE, while promising, introduces load balancing interactions that require further study.
As language models continue to scale and inference cost becomes an increasingly important constraint, adaptive computation methods like MoD will play a central role in making large models economically viable. The depth dimension has been underexplored relative to the width dimension in the sparsity literature; MoD establishes that this was an oversight worth correcting.
References
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
- Graves, A. (2016). Adaptive computation time for recurrent neural networks. arXiv preprint arXiv:1603.08983.
- Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. (2016). Deep networks with stochastic depth. European Conference on Computer Vision, 646–661.
- Shazeer, N., Mirhoseini, A., Maziarz, K., Davis, A., Le, Q., Hinton, G., & Dean, J. (2017). Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. International Conference on Learning Representations.
- Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. Journal of Machine Learning Research, 23(1), 5232–5270.
- Elhage, N., Nanda, N., Olsson, C., Henighan, T., Joseph, N., Mann, B., … & Olah, C. (2021). A mathematical framework for transformer circuits. Transformer Circuits Thread.
- Raposo, D., Ritter, S., Richards, B., Lillicrap, T., Humphreys, P. C., & Santoro, A. (2024). Mixture-of-depths: Dynamically allocating compute in transformer models. arXiv preprint arXiv:2404.02258.
- Elhoushi, M., Shrivastava, A., Liskovich, D., Hosmer, B., Wasti, B., Lai, L., … & Wu, C.-J. (2024). LayerSkip: Enabling early exit inference and self-speculative decoding. arXiv preprint arXiv:2404.16710.
- Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast inference from transformers via speculative decoding. International Conference on Machine Learning.
- Chen, C., Borgeaud, S., Irving, G., Lespiau, J.-B., Sifre, L., & Jumper, J. (2023). Accelerating large language model decoding with speculative sampling. arXiv preprint arXiv:2302.01318.
- Teerapittayanon, S., McDanel, B., & Kung, H. T. (2016). BranchyNet: Fast inference via early exiting from deep neural networks. International Conference on Pattern Recognition, 2464–2469.
- Schwartz, R., Stanovsky, G., Swayamdipta, S., Dodge, J., & Smith, N. A. (2020). The right tool for the job: Matching model and instance complexities. Proceedings of ACL, 6640–6651.
- Bengio, E., Bacon, P.-L., Pineau, J., & Precup, D. (2015). Conditional computation in neural networks for faster models. arXiv preprint arXiv:1511.06297.