FlashAttention and IO-Aware Algorithm Design: Recomputation, Tiling, and the Memory Hierarchy of Efficient Transformers

Abstract

The standard self-attention mechanism in transformer models incurs $O(N^2)$ time and space complexity with respect to sequence length $N$, creating a fundamental throughput bottleneck rooted not in arithmetic operations but in memory bandwidth. FlashAttention (Dao et al., 2022) reframed attention as an IO-aware algorithmic problem: by carefully tiling the computation to exploit the GPU SRAM hierarchy and avoiding materialization of the full $N \times N$ attention matrix in HBM, it achieves exact attention with 2–4× wall-clock speedups and a memory footprint linear in sequence length. FlashAttention-2 (Dao, 2023) extended these gains through improved work partitioning across warps and thread blocks. This paper provides a detailed technical analysis of the FlashAttention algorithm family—covering the online softmax normalization trick, the tiling schedule, the backward-pass recomputation strategy, and the extension to multi-query and grouped-query attention. We situate these contributions within the broader landscape of IO-aware algorithm design and discuss implications for long-context inference, training throughput, and hardware co-design.

1. Introduction

Modern transformer-based language models devote a substantial fraction of their compute time to self-attention. For a sequence of length $N$ and model dimension $d$, the standard attention operation

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

requires constructing an $N \times N$ matrix before applying softmax and contracting with $V$. At $N = 4096$, batch size 4, and 8 attention heads with 32-bit floats, the attention matrices alone occupy roughly 2 GB—well beyond the tens of megabytes of on-chip SRAM available in contemporary A100/H100 GPUs. The consequence is that attention is memory-bandwidth-bound rather than compute-bound: the bottleneck is the round-trip cost of reading and writing large tensors to high-bandwidth memory (HBM), not the floating-point throughput of the tensor cores.
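The arithmetic behind the 2 GB figure can be checked directly; a back-of-envelope sketch assuming 8 attention heads (a parameter, like batch size, that scales the total linearly):

```python
# Attention-matrix footprint: batch * heads * N^2 scores in fp32.
# The head count of 8 is an illustrative assumption.
N, batch, heads, fp32_bytes = 4096, 4, 8, 4
attn_bytes = batch * heads * N * N * fp32_bytes
attn_gib = attn_bytes / 2**30   # comes out to 2.0 GiB
```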

This diagnosis reframes the optimization target. Rather than reducing FLOPs, what is needed is a reduction in HBM reads and writes—an objective that demands algorithm redesign at the level of the memory hierarchy. The classical tool for such redesign is tiling: partitioning a computation into blocks small enough to fit in fast memory, performing the computation entirely in fast memory, and writing only the final result back to slow memory. Cache-oblivious algorithms and blocked matrix multiplication exemplify this principle in the classical setting.

Dao et al. (2022) applied this principle to attention with one critical complication: softmax is a global reduction over the key dimension, seemingly incompatible with blockwise computation. The key insight of FlashAttention is that softmax normalization can be maintained incrementally using an online normalization trick (Milakov & Gimelshein, 2018), allowing the algorithm to process keys and values in tiles while accumulating a running maximum and denominator—never materializing the full attention matrix.

This paper provides a thorough technical exposition of FlashAttention and its successors, analyzes the IO complexity formally, and discusses the broader implications for transformer training and inference at scale.

2. Related Work

Dao et al. (2022) introduced FlashAttention, demonstrating exact attention with $O(N)$ memory and significant wall-clock speedups on GPT-2 and BERT training. The work formalizes GPU memory hierarchy costs and proves an IO complexity lower bound for exact attention.

Dao (2023) presented FlashAttention-2, refining the thread-block and warp-level parallelism to better utilize the A100’s tensor cores, achieving up to 2× speedup over FlashAttention on forward passes and improved backward-pass efficiency.

Rabe & Staats (2021) independently proposed a memory-efficient attention algorithm using gradient checkpointing in the sequence dimension. It attains $O(\sqrt{N})$ memory, but is slower in wall-clock time than FlashAttention due to less aggressive tiling and the resulting extra HBM traffic.

Child et al. (2019) proposed Sparse Transformer, restricting attention to local windows and strided global patterns to reduce complexity to $O(N \sqrt{N})$. This trades exactness for efficiency, whereas FlashAttention achieves efficiency without approximation.

Kitaev et al. (2020) introduced Reformer, using locality-sensitive hashing to group queries and keys into buckets, achieving $O(N \log N)$ approximate attention. The approximation quality depends on hash function design and bucket granularity.

Choromanski et al. (2021) proposed Performer, approximating the softmax kernel with random feature maps to obtain linear attention. While theoretically elegant, random feature approximations can degrade significantly on tasks requiring precise attention over distant tokens.

Shazeer (2019) introduced Multi-Query Attention (MQA), sharing key and value heads across query heads to reduce KV cache size during autoregressive decoding—a technique later complemented by Ainslie et al. (2023) with Grouped-Query Attention (GQA), which FlashAttention-2 explicitly supports.

3. Technical Analysis

3.1 Memory Hierarchy and IO Complexity

We model the GPU as a two-level memory system: HBM of size $M_{\text{HBM}}$ with bandwidth $B_{\text{HBM}}$, and SRAM of size $M_{\text{SRAM}} \ll M_{\text{HBM}}$ with bandwidth $B_{\text{SRAM}} \gg B_{\text{HBM}}$. On an A100-80GB, $M_{\text{HBM}} \approx 80\,\text{GB}$, $B_{\text{HBM}} \approx 2\,\text{TB/s}$, $M_{\text{SRAM}} \approx 20\,\text{MB}$ (192 KB per SM across 108 SMs), $B_{\text{SRAM}} \approx 19\,\text{TB/s}$.

The IO cost of an algorithm is measured in the number of elements read from or written to HBM. For standard attention, the steps are:

  1. Load $Q, K \in \mathbb{R}^{N \times d}$ from HBM: $2Nd$ reads
  2. Compute $S = QK^\top / \sqrt{d} \in \mathbb{R}^{N \times N}$: write $N^2$ elements
  3. Load $S$, compute $P = \text{softmax}(S)$: $N^2$ reads + $N^2$ writes
  4. Load $P, V$, compute $O = PV$: $N^2 + Nd$ reads, $Nd$ writes

Total HBM accesses: $\Theta(N^2 + Nd)$. For $N \gg d$ (typical in practice), this is $\Theta(N^2)$.
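The four steps above can be written directly as a memory-naive reference implementation; a NumPy sketch (single head, illustrative sizes) with the HBM traffic of each step noted in comments:

```python
import numpy as np

# Memory-naive reference attention matching steps 1-4 above; comments count
# the HBM traffic each step would generate on a GPU.
def naive_attention(Q, K, V):
    N, d = Q.shape
    S = Q @ K.T / np.sqrt(d)                 # step 2: write N^2 scores
    S = S - S.max(axis=1, keepdims=True)     # stabilize before exponentiating
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)     # step 3: N^2 reads + N^2 writes
    return P @ V                             # step 4: N^2 + N*d reads, N*d writes

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
O = naive_attention(Q, K, V)
```

Both the $S$ and $P$ matrices here are full $N \times N$ intermediates, which is exactly what the tiled algorithm in Section 3.2 avoids materializing.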

Dao et al. (2022) prove that FlashAttention performs $\Theta(N^2 d^2 / M_{\text{SRAM}})$ HBM accesses, and establish a matching lower bound: for any SRAM size $M_{\text{SRAM}} \in [d, Nd]$, no exact attention algorithm can asymptotically reduce this count. Since $d^2 \ll M_{\text{SRAM}}$ in practice, this is far below the $\Theta(N^2)$ accesses of the standard implementation.

3.2 Online Softmax and the Tiling Schedule

The fundamental obstacle to tiling attention is that softmax requires global statistics. For a row $\mathbf{s} \in \mathbb{R}^N$:

$$\text{softmax}(\mathbf{s})_i = \frac{e^{s_i}}{\sum_{j=1}^N e^{s_j}}$$

Computing this in blocks requires knowing $\sum_j e^{s_j}$ before writing any output—seemingly requiring two passes over the keys. The online softmax trick of Milakov & Gimelshein (2018) resolves this with running statistics. Define for the first $t$ tiles:

$$m_t = \max_{j \leq t} s_j, \quad \ell_t = \sum_{j=1}^t e^{s_j - m_t}$$

When a new tile arrives with local maximum $\tilde{m}$, the statistics update as:

$$m_{t+1} = \max(m_t, \tilde{m})$$
$$\ell_{t+1} = e^{m_t - m_{t+1}} \ell_t + e^{\tilde{m} - m_{t+1}} \tilde{\ell}$$

and the accumulated output is rescaled accordingly:

$$O_{t+1} = \frac{e^{m_t - m_{t+1}} \ell_t}{\ell_{t+1}}\, O_t + \frac{e^{\tilde{m} - m_{t+1}}}{\ell_{t+1}}\, \tilde{O}, \quad \text{where } \tilde{O} = \sum_{j \in \text{tile}} e^{s_j - \tilde{m}} v_j$$

This recurrence allows exact softmax-weighted accumulation in a single forward pass over tiles, with only $O(N)$ HBM writes for the output matrix $O$.
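The recurrence can be exercised on a single row in a few lines; a NumPy sketch (tile size and names are illustrative, and the division by $\ell$ is deferred to the very end, in the style FlashAttention-2 adopts):

```python
import numpy as np

# Online softmax over one score row s with values v, processed in tiles,
# maintaining running max m, denominator ell, and unnormalized output o.
def online_softmax_row(s, v, tile=32):
    m, ell = -np.inf, 0.0
    o = np.zeros(v.shape[1])
    for t in range(0, len(s), tile):
        s_t, v_t = s[t:t+tile], v[t:t+tile]
        m_new = max(m, s_t.max())
        p_t = np.exp(s_t - m_new)                  # tile's unnormalized weights
        ell = np.exp(m - m_new) * ell + p_t.sum()  # denominator update
        o = np.exp(m - m_new) * o + p_t @ v_t      # rescale old accumulator
        m = m_new
    return o / ell                                 # normalize once at the end

rng = np.random.default_rng(1)
s = rng.standard_normal(100)
v = rng.standard_normal((100, 8))
out = online_softmax_row(s, v)
```

The result is bit-for-bit the exact softmax-weighted average, not an approximation; only the order of normalization changes.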

The full tiling schedule partitions $Q$ into row-tiles of size $B_r$ and $K, V$ into column-tiles of size $B_c$, chosen so that the SRAM working set fits: roughly $B_c \approx M_{\text{SRAM}} / (4d)$ and $B_r \approx \min(M_{\text{SRAM}} / (4d),\, d)$. For each query tile, the algorithm iterates over all key-value tiles, updating $(m, \ell, O)$ in SRAM before writing the final $O$ tile back to HBM.
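Putting the schedule and the recurrence together, the forward pass can be sketched in pure NumPy (no real SRAM management; block sizes are illustrative). It also returns the per-row statistics that Section 3.3's backward pass reuses:

```python
import numpy as np

# Tiled exact-attention forward pass: outer loop over query row-tiles,
# inner loop over key/value column-tiles, running (m, ell, O) per row.
def flash_attention_forward(Q, K, V, Br=32, Bc=32):
    N, d = Q.shape
    O = np.zeros((N, d))
    M = np.full(N, -np.inf)          # running row maxima
    L = np.zeros(N)                  # running softmax denominators
    for i in range(0, N, Br):        # query row-tiles
        Qi = Q[i:i+Br]
        mi, li, oi = M[i:i+Br], L[i:i+Br], O[i:i+Br]
        for j in range(0, N, Bc):    # key/value column-tiles
            S = Qi @ K[j:j+Bc].T / np.sqrt(d)        # Br x Bc score tile
            m_new = np.maximum(mi, S.max(axis=1))
            P = np.exp(S - m_new[:, None])           # unnormalized tile weights
            scale = np.exp(mi - m_new)               # rescale old accumulators
            li = scale * li + P.sum(axis=1)
            oi = scale[:, None] * oi + P @ V[j:j+Bc]
            mi = m_new
        O[i:i+Br] = oi / li[:, None]                 # normalize once per tile
        M[i:i+Br], L[i:i+Br] = mi, li
    return O, M, L

rng = np.random.default_rng(4)
Q, K, V = (rng.standard_normal((96, 16)) for _ in range(3))
O, m_stat, l_stat = flash_attention_forward(Q, K, V)
```

On a GPU, $Q_i$, the running statistics, and each score tile live in SRAM; only the final output tiles and the $O(N)$ statistics ever reach HBM.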

3.3 Backward Pass via Recomputation

Standard backpropagation through attention requires storing the $N \times N$ attention matrix $P$ for the backward pass, incurring $O(N^2)$ memory. FlashAttention avoids this by recomputing $P$ from the stored softmax statistics $(m, \ell)$ during the backward pass—a form of gradient checkpointing applied at the operation level rather than the layer level.

Given stored $Q, K, V, O, \ell, m$ (all $O(Nd)$ except $\ell, m$ which are $O(N)$), the backward pass recomputes each tile of $P$ on-the-fly:

$$P_{ij} = \frac{e^{s_{ij} - m_i}}{\ell_i}$$

then uses $P$ to compute $dV, dK, dQ$ for that tile before discarding it. The total backward IO cost is $\Theta(N^2 d / M_{\text{SRAM}})$—matching the forward pass, but with a higher constant factor due to the three gradient computations.
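The reconstruction can be checked numerically; a sketch (illustrative shapes, gradient computation omitted) showing that a tile of $P$ rebuilt from the stored $(m, \ell)$ matches the directly computed softmax:

```python
import numpy as np

# Rebuilding a tile of P from the O(N) statistics (m, ell) alone, as the
# backward pass does. The full S here is only for demonstration; the real
# forward pass accumulates m and ell tile by tile without storing S.
rng = np.random.default_rng(2)
N, d = 64, 16
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))

S = Q @ K.T / np.sqrt(d)
m = S.max(axis=1)                          # stored forward statistics
ell = np.exp(S - m[:, None]).sum(axis=1)

# Backward pass: recompute one Br x Bc tile of P on the fly, then discard it.
i, j, Br, Bc = 0, 32, 32, 32
S_tile = Q[i:i+Br] @ K[j:j+Bc].T / np.sqrt(d)
P_tile = np.exp(S_tile - m[i:i+Br, None]) / ell[i:i+Br, None]
```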

The memory footprint drops from $O(N^2)$ for standard attention to $O(N)$ for FlashAttention (excluding activations), enabling training with significantly longer sequences on fixed hardware.

3.4 FlashAttention-2: Work Partitioning

FlashAttention-2 identifies that the original implementation underutilizes GPU parallelism: FA-1 launches one thread block per (batch, head) pair, and its outer loop iterates over key-value tiles, so each KV tile must update partial results for every query row. The key change is to swap the loop order—making query tiles the outer loop—and to additionally parallelize over query tiles in the sequence dimension, keeping occupancy high even at small batch sizes and long sequences.

Specifically, FA-2 assigns different query blocks to different thread blocks and sequences the KV iteration within each, eliminating the need for inter-block synchronization during the running statistics update. Additionally, FA-2 reduces the number of non-GEMM operations (rescaling of $O$ by the softmax denominator) and improves warp-level parallelism by assigning non-overlapping query slices to different warps within a block.

These changes yield up to 70–75% utilization of the A100’s theoretical FLOP throughput on the forward pass, compared to approximately 30–40% for FA-1.

3.5 Extension to Multi-Query and Grouped-Query Attention

In MQA, a single head of $K$ and $V$ is shared across $h$ query heads. In GQA with $g$ groups, $K$ and $V$ have $g$ heads, each shared by $h/g$ query heads. FlashAttention-2 handles both by broadcasting the KV tiles across the relevant query heads within the tiling loop, avoiding redundant HBM reads of $K$ and $V$. This is particularly important for autoregressive decoding, where the KV cache must be loaded repeatedly for each generated token.

The IO reduction for MQA relative to full multi-head attention (MHA) during the KV-loading phase is a factor of $h$, substantially reducing bandwidth pressure at long context lengths.
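The KV broadcasting can be sketched in NumPy (shapes and names are illustrative; a real kernel indexes shared KV tiles rather than reshaping whole tensors):

```python
import numpy as np

# GQA with h query heads and g KV heads: each KV head serves h/g query
# heads. Grouping query heads by their KV head lets K and V broadcast
# without being physically copied.
def gqa_attention(Q, K, V):                  # Q: (h, N, d); K, V: (g, N, d)
    h, N, d = Q.shape
    g = K.shape[0]
    Qg = Q.reshape(g, h // g, N, d)          # group query heads by KV head
    S = np.einsum('gqnd,gmd->gqnm', Qg, K) / np.sqrt(d)
    P = np.exp(S - S.max(-1, keepdims=True))
    P = P / P.sum(-1, keepdims=True)
    O = np.einsum('gqnm,gmd->gqnd', P, V)
    return O.reshape(h, N, d)

rng = np.random.default_rng(3)
h, g, N, d = 8, 2, 32, 16
Q = rng.standard_normal((h, N, d))
K = rng.standard_normal((g, N, d))
V = rng.standard_normal((g, N, d))
O = gqa_attention(Q, K, V)
```

Setting $g = 1$ recovers MQA; setting $g = h$ recovers standard multi-head attention.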

3.6 Causal Masking and Variable-Length Sequences

Causal (autoregressive) masking is incorporated by skipping the computation of tiles that are entirely in the masked-out region (future positions). For query tile $i$ and key tile $j$ with $j > i$ (in block-index terms), the tile contributes zero to the output and is skipped. Tiles on the diagonal require partial masking, handled by zeroing the relevant elements before the online softmax update.
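The tile-level decision logic is simple; a sketch (equal row and column tile sizes assumed, names illustrative):

```python
import numpy as np

# With equal-size row tile i and column tile j: j > i is entirely in the
# future (skipped), j < i is entirely visible (no mask), j == i straddles
# the diagonal and needs an elementwise mask.
def tile_action(i, j):
    if j > i:
        return "skip"
    if j < i:
        return "full"
    return "masked"

def causal_mask_tile(S_tile, row0, col0):
    """Set future positions of a diagonal score tile to -inf before softmax."""
    Br, Bc = S_tile.shape
    rows = row0 + np.arange(Br)[:, None]     # absolute query positions
    cols = col0 + np.arange(Bc)[None, :]     # absolute key positions
    return np.where(cols <= rows, S_tile, -np.inf)
```

Skipping the strictly-upper tiles roughly halves the work for causal attention, which is why causal FlashAttention is nearly twice as fast as the bidirectional case at long sequence lengths.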

Variable-length sequences (VarLen attention) pack multiple sequences of different lengths into a single batch dimension, using a cumulative sequence-length array to index into the query and key matrices. FA-2 supports this natively, enabling efficient training on datasets with high length variance without padding overhead.
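The cumulative-length indexing can be sketched as follows (a sketch of the packing convention, not the FA-2 API itself):

```python
import numpy as np

# Three sequences of different lengths packed along one dimension; the
# cumulative-length array marks where each sequence starts and ends.
lengths = [5, 3, 7]
cu_seqlens = np.concatenate([[0], np.cumsum(lengths)])   # [0, 5, 8, 15]
packed_q = np.arange(cu_seqlens[-1])                     # stand-in for packed rows

def sequence_slice(k):
    """Rows of the packed tensor belonging to sequence k."""
    return packed_q[cu_seqlens[k]:cu_seqlens[k + 1]]
```

Attention is then computed independently within each `[cu_seqlens[k], cu_seqlens[k+1])` range, so no token attends across sequence boundaries and no padding tokens are processed at all.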

4. Discussion

4.1 Impact on Long-Context Training and Inference

The $O(N)$ memory footprint of FlashAttention is the primary enabler of long-context transformers. Models such as GPT-4 (128K context), Claude 3 (200K), and Gemini 1.5 (1M) would not be feasible to train efficiently without IO-aware attention—not because the FLOPs are prohibitive, but because the HBM bandwidth cost of materializing the attention matrix would dominate training time.

In autoregressive inference, attention over the KV cache is bandwidth-bound rather than compute-bound: the GPU is idle waiting for data from HBM. FlashAttention’s tiling helps here by reducing the volume of intermediate data written to HBM, but the KV cache itself grows linearly with context length, creating a separate memory management problem addressed by techniques like PagedAttention (Kwon et al., 2023) and sliding-window cache eviction.

4.2 Implications for Hardware Co-Design

FlashAttention’s success highlights a broader principle: for memory-bandwidth-bound operations, algorithm design must account for the memory hierarchy explicitly. This has prompted hardware vendors to revisit the SRAM-HBM capacity ratio. NVIDIA’s H100 increases SRAM per SM from 192 KB to 256 KB, directly benefiting FlashAttention’s tile sizes.

There is also growing interest in architectural alternatives that sidestep the attention bottleneck altogether: linear recurrences (Gu & Dao, 2023; Mamba), linear attention (Katharopoulos et al., 2020), and hybrid architectures that use attention only for a subset of layers. The existence of FlashAttention, however, has raised the competitive bar for these alternatives—efficient exact attention is now a strong baseline, not an easy target.

4.3 Relationship to Kernel Fusion and Graph Optimization

FlashAttention is a form of operator fusion: it fuses the matrix multiplications, softmax, and weighted value aggregation into a single CUDA kernel, eliminating intermediate HBM round-trips. This is conceptually related to XLA’s fusion passes and TVM’s operator scheduling, but implemented as a hand-tuned kernel rather than auto-generated code.

The success of handcrafted kernels over auto-tuned approaches for attention reflects the difficulty of automatically discovering the online softmax trick from a high-level computation graph. Recent work on FlashDecoding (Dao et al., 2023) extends the tiling strategy to the inference-time, multi-query, variable-KV-cache setting, further demonstrating the generative power of the IO-aware design principle.

4.4 Limitations and Open Problems

Despite its success, FlashAttention has several limitations. First, the CUDA implementation is not trivially portable: adapting it to AMD ROCm, Intel XPUs, or custom accelerators requires significant engineering effort. Triton-based implementations (e.g., the Triton FA-2 kernel) partially address this through a hardware-agnostic intermediate representation.

Second, the tiling schedule is sensitive to the ratio $N / d$: for very small $d$ (e.g., $d = 16$) the tiles become too small to amortize overhead, while very large $d$ (e.g., $d = 256$ in some multi-head configurations) may exhaust SRAM. FlashAttention-3 (Shah et al., 2024) addresses these issues on H100 by exploiting asynchronous memory copy engines and the tensor memory accelerator (TMA).

Third, the recomputation strategy in the backward pass, while memory-efficient, increases the backward-pass wall-clock time by approximately 30% relative to storing the attention matrix. For training regimes where backward-pass throughput is the bottleneck, this may be suboptimal.

5. Conclusion

FlashAttention represents one of the most impactful algorithm-level contributions to practical deep learning since the introduction of the transformer. By reframing attention as an IO-aware problem and exploiting the GPU memory hierarchy through tiling and online normalization, it achieves exact computation with linear memory and substantial wall-clock speedups—without any approximation or model modification.

The core lesson is methodological: performance at scale is increasingly determined by memory bandwidth and data movement, not raw FLOP counts. Algorithms that respect this constraint—through tiling, fusion, and recomputation—will continue to outperform those that optimize for FLOPs alone. FlashAttention’s success has already catalyzed a broader research agenda around IO-aware design for softmax, layer normalization, and attention variants (e.g., sliding window, ALiBi, RoPE), suggesting that the principle is far from exhausted.

As context lengths extend toward millions of tokens and model dimensions grow, the gap between memory-aware and memory-naive implementations will only widen. The algorithmic foundations laid by FlashAttention—IO complexity analysis, online statistical accumulation, and SRAM-resident tiling—provide the conceptual vocabulary for that next generation of efficient transformer kernels.
