
FlashAttention-2 Memory Optimization

FlashAttention-2 adds parallelization over the sequence length dimension, restoring high SM occupancy for long-context workloads.

Source: mortalapps.com
TL;DR
  • FlashAttention-2 adds parallelization over the sequence length dimension, restoring high SM occupancy for long-context workloads.
  • It changes warp-level work partitioning, splitting Queries across warps instead of Keys/Values, which minimizes shared memory (SRAM) reads/writes and inter-warp synchronization.
  • It restructures the online softmax to reduce non-matmul operations in the inner loop, keeping the Tensor Cores busy.
  • These memory and scheduling optimizations raise forward-pass throughput to up to 73% of the theoretical maximum FLOPs/s on A100 GPUs.

Why This Matters

As models scale to context windows beyond 8K tokens, batch sizes must shrink to fit within fixed GPU memory limits. FlashAttention-1 parallelized only over the batch size and the number of attention heads, launching one thread block per (batch, head) pair. For a long sequence with a small batch size, the number of thread blocks falls well below the number of available Streaming Multiprocessors (SMs), 108 on an A100, leaving the GPU severely underutilized. By also parallelizing over the sequence length, FlashAttention-2 restores high occupancy, which helps make long-context LLMs economically viable on existing hardware.
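To make the occupancy argument concrete, here is a minimal Python sketch that counts the thread blocks launched under each parallelization scheme. It assumes a deliberately simplified model in which an SM counts as busy if it receives at least one thread block (real occupancy also depends on how many blocks fit per SM and on scheduling waves), a hypothetical query tile size `block_q` of 128 rows, and the A100's 108 SMs. The function names are illustrative and do not come from any library.

```python
import math

A100_NUM_SMS = 108  # streaming multiprocessors on an A100


def fa1_grid_size(batch: int, heads: int) -> int:
    """Thread blocks when parallelizing over batch and heads only (FlashAttention-1 style)."""
    return batch * heads


def fa2_grid_size(batch: int, heads: int, seq_len: int, block_q: int = 128) -> int:
    """Thread blocks when each block additionally owns one tile of block_q query rows.

    block_q = 128 is an assumed tile size chosen for illustration.
    """
    return batch * heads * math.ceil(seq_len / block_q)


def occupancy(num_blocks: int, num_sms: int = A100_NUM_SMS) -> float:
    """Fraction of SMs that receive at least one block (simplified model, capped at 1.0)."""
    return min(num_blocks, num_sms) / num_sms


if __name__ == "__main__":
    batch, heads, seq_len = 1, 32, 16_384
    b1 = fa1_grid_size(batch, heads)
    b2 = fa2_grid_size(batch, heads, seq_len)
    print(f"FA1-style grid: {b1} blocks, {occupancy(b1):.0%} of SMs busy")
    print(f"FA2-style grid: {b2} blocks, {occupancy(b2):.0%} of SMs busy")
```

With batch size 1 and 32 heads, the FlashAttention-1-style grid launches only 32 blocks, so fewer than a third of the SMs get any work, while splitting a 16K-token sequence into 128-row query tiles yields 4,096 independent blocks.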

Core Intuition

FlashAttention-1 solved the HBM read/write problem via tiling but suffered from inefficiencies inside each SM. Imagine a factory (the SM) where four workers (warps) assemble parts. In version 1, each worker grabbed a different part of the Key/Value matrices, produced a partial result for every query, threw it onto a shared table (SRAM), waited for the others (synchronization), and only then were the partial results added up. In version 2, each worker takes its own slice of the Query matrix and reads the shared Key/Value matrices. Each worker completes the entire computation for its assigned query slice independently, so partial results never have to be combined on the shared table.
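The two partitioning schemes can be mimicked on the CPU with a small pure-Python sketch. Here a "warp" is just a loop iteration over a slice of rows: the hypothetical `split_k` stands in for the FlashAttention-1 scheme and `split_q` for the FlashAttention-2 scheme. Both compute the same softmax attention, but in `split_k` every query row ends up with partial results spread across warps that must be merged (the step that on a GPU goes through shared memory and a synchronization barrier), whereas in `split_q` each warp finishes its own rows alone. The sketch models only the data flow, not GPU performance.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def partial_attention(q: List[float], K: Matrix, V: Matrix) -> Tuple[float, float, List[float]]:
    """Softmax statistics for one query over a (possibly partial) set of keys.

    Returns (m, l, o): the max score, the sum of exp(score - m), and the
    unnormalized output sum_j exp(score_j - m) * V[j].
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in K]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    o = [sum(pj * v[c] for pj, v in zip(p, V)) for c in range(len(V[0]))]
    return m, sum(p), o


def split_q(Q: Matrix, K: Matrix, V: Matrix, num_warps: int) -> Matrix:
    """FlashAttention-2 style: each warp owns a slice of query rows and reads all of K/V.

    A warp's rows are finished entirely inside that warp; nothing is exchanged.
    """
    rows_per_warp = math.ceil(len(Q) / num_warps)
    out: Matrix = []
    for w in range(num_warps):
        for q in Q[w * rows_per_warp:(w + 1) * rows_per_warp]:
            _, l, o = partial_attention(q, K, V)
            out.append([x / l for x in o])
    return out


def split_k(Q: Matrix, K: Matrix, V: Matrix, num_warps: int) -> Matrix:
    """FlashAttention-1 style: each warp owns a slice of K/V rows and reads all of Q.

    Every query row gets one partial result per warp; these must be merged,
    which on a GPU means writing to shared memory and synchronizing.
    """
    cols = math.ceil(len(K) / num_warps)
    slices = [(K[w * cols:(w + 1) * cols], V[w * cols:(w + 1) * cols]) for w in range(num_warps)]
    slices = [(k, v) for k, v in slices if k]  # drop warps that received no keys
    out: Matrix = []
    for q in Q:
        partials = [partial_attention(q, k, v) for k, v in slices]
        # Cross-warp reduction: rescale each partial to a common max, then sum.
        m = max(pm for pm, _, _ in partials)
        l = sum(pl * math.exp(pm - m) for pm, pl, _ in partials)
        o = [sum(po[c] * math.exp(pm - m) for pm, _, po in partials) for c in range(len(V[0]))]
        out.append([x / l for x in o])
    return out
```

The reduction loop at the end of `split_k` is exactly the work that disappears in `split_q`: once queries rather than keys are partitioned, no warp ever holds only part of a row's answer.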

Technical Deep Dive

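FlashAttention-2 optimizes the inner-loop mechanics of the attention kernel through three primary changes.

First, it reverses the loop order. FlashAttention-2 iterates over Q blocks in the outer loop and over K and V blocks in the inner loop, so each thread block keeps its Q tile in shared memory (SRAM) while sweeping over K and V. Because the output rows of one Q block do not depend on any other Q block, different Q blocks can be assigned to different thread blocks; this is what makes parallelization over the sequence length possible.

Second, it changes warp-level partitioning. A thread block is divided into multiple warps (typically 4 or 8 warps of 32 threads). FlashAttention-1 split K and V across the warps, so each warp held only a partial result for every query row, and the warps had to write those partials to shared memory, synchronize, and sum them. FlashAttention-2 instead slices the Q block across the warps, while K and V remain accessible to all warps in shared memory. Each warp then produces finished output rows for its own queries, which eliminates the inter-warp reduction through shared memory.

Finally, it reduces non-matmul work. On an A100, matmul operations run on Tensor Cores at up to 312 TFLOPs/s (FP16/BF16), while non-matmul operations run on standard CUDA cores at only 19.5 TFLOPs/s (FP32), so each non-matmul FLOP is roughly 16 times as expensive as a matmul FLOP. FlashAttention-2 therefore keeps the output accumulator unnormalized inside the inner loop and divides by the softmax denominator only once at the end, stores a single logsumexp value per row for the backward pass, and, with causal masking, skips K/V blocks that lie entirely above the diagonal and applies the mask only to blocks that cross it. Together these changes maximize Tensor Core active time.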
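The three changes can be seen together in a minimal, single-threaded Python sketch of the forward pass. It follows the structure described above: an outer loop over Q tiles (each independent, so each could become its own thread block), an inner loop over K/V tiles with an online softmax, an output accumulator divided by the softmax denominator only once after the inner loop, a single logsumexp value kept per row, and, for causal attention, skipping tiles above the diagonal while masking only tiles that cross it. The function name `flash2_forward` and the tiny default tile sizes are hypothetical choices for illustration; the sketch does not model warps, shared memory, or Tensor Cores and is not the FlashAttention library's API.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def flash2_forward(Q: Matrix, K: Matrix, V: Matrix, block_q: int = 2,
                   block_kv: int = 2, causal: bool = False) -> Tuple[Matrix, List[float]]:
    """Tiled attention forward pass following the FlashAttention-2 loop order.

    Returns the output O and the per-row logsumexp (the only softmax
    statistic that would be saved for a backward pass).
    """
    n, dv = len(Q), len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))
    O: Matrix = [[0.0] * dv for _ in range(n)]
    lse = [0.0] * n

    # Outer loop over query tiles. Tiles are independent of each other, which is
    # what lets a GPU give each one its own thread block (sequence parallelism).
    for q0 in range(0, n, block_q):
        rows = range(q0, min(q0 + block_q, n))
        m = {i: -math.inf for i in rows}      # running row max
        l = {i: 0.0 for i in rows}            # running softmax denominator
        acc = {i: [0.0] * dv for i in rows}   # output accumulator, NOT normalized

        # Inner loop over key/value tiles.
        for k0 in range(0, len(K), block_kv):
            k1 = min(k0 + block_kv, len(K))
            if causal and k0 > rows[-1]:
                break  # this and all later tiles lie entirely above the diagonal
            needs_mask = causal and k1 - 1 > q0  # only diagonal-crossing tiles need masking
            for i in rows:
                s = [dot(Q[i], K[j]) * scale for j in range(k0, k1)]
                if needs_mask:
                    s = [x if j <= i else -math.inf for x, j in zip(s, range(k0, k1))]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)  # rescales the old state to the new max
                p = [math.exp(x - m_new) for x in s]
                l[i] = alpha * l[i] + sum(p)
                acc[i] = [alpha * a + sum(pj * V[j][c] for pj, j in zip(p, range(k0, k1)))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new

        # Epilogue: one division per row instead of one per inner iteration.
        for i in rows:
            O[i] = [a / l[i] for a in acc[i]]
            lse[i] = m[i] + math.log(l[i])
    return O, lse
```

In a real kernel the per-row Python loops become tile-level matrix multiplications on Tensor Cores; the point of the sketch is that the only non-matmul work left per inner iteration is the exponentials and the `alpha` rescale, while the division by `l` happens once per row.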

Key Takeaways

FlashAttention-2 parallelizes over the sequence dimension, restoring SM occupancy when long contexts force small batch sizes.
Warp-level work is reorganized to slice Q and share K/V, eliminating costly inter-warp synchronization through SRAM.
Non-matmul arithmetic in the inner loop is minimized to prevent Tensor Core starvation.
It reaches up to 73% of the A100's theoretical maximum FLOPs/s in the forward pass, roughly twice the speed of FlashAttention-1.