# Flash Attention: Making Transformers Scale
In my previous post on Transformer Attention, we explored the mathematical foundations of attention. The key limitation? Quadratic memory complexity $O(n^2)$ makes long sequences prohibitively expensive. Flash Attention solves this.
[!NOTE] Flash Attention achieves 2-4x speedup and dramatically reduces memory usage without any approximation — it’s mathematically identical to standard attention.
## The Memory Bottleneck Problem
Standard attention computes and stores the full $n \times n$ attention matrix:
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
For a sequence of length 8192 with batch size 1:
- Attention matrix: $8192^2 \times 4$ bytes = 256 MB per head
- With 32 heads: 8 GB just for attention weights!
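These numbers are easy to verify (a quick sketch; assumes fp32 attention scores and binary megabytes, as in the figures above):

```python
# Back-of-the-envelope memory for the full attention matrix
seq_len = 8192
bytes_per_score = 4                     # fp32 scores
n_heads = 32

per_head = seq_len ** 2 * bytes_per_score
total = per_head * n_heads

print(f"Per head: {per_head / 2**20:.0f} MB")   # -> 256 MB
print(f"All heads: {total / 2**30:.0f} GB")     # -> 8 GB
```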
[!CRITICAL] This quadratic scaling is why GPT-3 was limited to 2048 tokens, and why long-context models like Claude and GPT-4 required architectural innovations.
## Flash Attention: The Key Insight
Flash Attention exploits the memory hierarchy of modern GPUs:
| Memory Type | Size | Bandwidth |
|---|---|---|
| SRAM (on-chip) | ~20 MB | ~19 TB/s |
| HBM (GPU RAM) | 40-80 GB | ~1.5 TB/s |
The insight: memory I/O is the bottleneck, not compute. Standard attention makes three round trips through HBM:
- Loads Q, K from HBM → computes $QK^T$ → writes to HBM
- Loads $QK^T$ from HBM → computes softmax → writes to HBM
- Loads softmax output from HBM → multiplies by V → writes to HBM
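In code, the unfused computation looks roughly like this (a minimal sketch; `standard_attention` is an illustrative name, and each commented intermediate is a full $n \times n$ tensor that round-trips through HBM):

```python
import math
import torch

def standard_attention(Q, K, V):
    """Naive attention: materializes the full n x n score matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-1, -2) / math.sqrt(d_k)  # n x n scores -> HBM
    weights = torch.softmax(scores, dim=-1)            # n x n weights -> HBM
    return weights @ V                                 # read back, multiply by V
```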
Flash Attention fuses all operations into a single kernel that keeps intermediate results in fast SRAM.
## The Tiling Algorithm
Flash Attention processes the attention matrix in tiles that fit in SRAM:
```python
import math
import torch

# Pseudocode for the Flash Attention forward pass
def flash_attention(Q, K, V, block_size=64):
    """
    Tiled attention computation.
    Q, K, V: (batch, seq_len, d_head)
    """
    n = Q.shape[1]
    d_k = Q.shape[-1]
    output = torch.zeros_like(Q)

    # Process queries in blocks that fit in SRAM
    for i in range(0, n, block_size):
        q_block = Q[:, i:i + block_size]
        bq = q_block.shape[1]

        # Running max and normalizer for numerical stability
        m_i = torch.full((Q.shape[0], bq), float('-inf'), device=Q.device)
        l_i = torch.zeros((Q.shape[0], bq), device=Q.device)
        o_i = torch.zeros_like(q_block)

        for j in range(0, n, block_size):
            k_block = K[:, j:j + block_size]
            v_block = V[:, j:j + block_size]

            # Attention scores for this tile
            scores = q_block @ k_block.transpose(-1, -2) / math.sqrt(d_k)

            # Online softmax update: fold this tile's row-max into the running max
            m_new = torch.maximum(m_i, scores.max(dim=-1).values)

            # Rescale previous accumulators, then add this tile's contribution
            alpha = torch.exp(m_i - m_new)                # corrects old terms
            p = torch.exp(scores - m_new.unsqueeze(-1))   # this tile under the new max
            l_i = alpha * l_i + p.sum(dim=-1)
            o_i = alpha.unsqueeze(-1) * o_i + p @ v_block
            m_i = m_new

        # Final normalization for this query block
        output[:, i:i + block_size] = o_i / l_i.unsqueeze(-1)
    return output
```
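Because the algorithm is exact, we can sanity-check it against the naive version (a quick sketch, reusing the `standard_attention` function from the earlier snippet):

```python
torch.manual_seed(0)
Q, K, V = [torch.randn(2, 256, 64) for _ in range(3)]

out_flash = flash_attention(Q, K, V, block_size=64)
out_naive = standard_attention(Q, K, V)

# Identical up to floating-point rounding: no approximation involved
assert torch.allclose(out_flash, out_naive, atol=1e-5)
```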
[!TIP] The magic is the online softmax algorithm — we can compute exact softmax incrementally without storing the full attention matrix!
## Online Softmax: The Mathematical Trick
Standard softmax requires two passes:
- Find max for numerical stability
- Compute exp and normalize
Online softmax does it in one pass using running statistics:
\[m^{(j)} = \max\left(m^{(j-1)}, \max(S_{:,j})\right)\]
\[\ell^{(j)} = e^{m^{(j-1)} - m^{(j)}} \ell^{(j-1)} + \sum_i e^{S_{i,j} - m^{(j)}}\]
\[O^{(j)} = e^{m^{(j-1)} - m^{(j)}} O^{(j-1)} + e^{S_{:,j} - m^{(j)}} V_j\]
[!PROOF] Correctness: At convergence, $O/\ell$ equals the exact attention output: the rescaling factors $e^{m^{(j-1)} - m^{(j)}}$ telescope, so after the final block every exponential is taken relative to the global max and $\ell$ is exactly the softmax denominator. ∎
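The recurrence is easy to check numerically on a plain vector (a minimal sketch, independent of the attention code above):

```python
import math
import torch

x = torch.randn(1024)
m, l = float('-inf'), 0.0

# One pass over blocks, carrying only the running max m and normalizer l
for block in x.split(128):
    m_new = max(m, block.max().item())
    l = l * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
    m = m_new

# l matches the denominator from the usual two-pass stable softmax
expected = torch.exp(x - x.max()).sum().item()
assert abs(l / expected - 1.0) < 1e-4
```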
## Memory Complexity Comparison
| Algorithm | Memory | I/O Complexity |
|---|---|---|
| Standard Attention | $O(n^2)$ | $O(nd + n^2)$ |
| Flash Attention | $O(n)$ | $O(n^2 d^2 / M)$ |
Where $M$ is the SRAM size (~20 MB) and $d$ is the head dimension (typically 64-128). Since $d^2 \ll M$ in practice, Flash Attention performs roughly a factor of $M/d^2$ fewer HBM accesses than the dominant $O(n^2)$ term of standard attention.
[!SUCCESS] For typical transformer configs, Flash Attention reduces memory from quadratic to linear in sequence length!
## Using Flash Attention in Practice
### With Hugging Face Transformers
```python
import torch
from transformers import AutoModelForCausalLM

# Request Flash Attention 2 explicitly (requires the flash-attn package)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```
### Direct Usage with the flash-attn Library
```python
from flash_attn import flash_attn_func

# q, k, v shape: (batch, seq_len, n_heads, head_dim)
output = flash_attn_func(q, k, v, causal=True)
```
[!WARNING] Flash Attention requires a GPU with compute capability >= 8.0 (Ampere or newer). For older GPUs, consider xFormers or PyTorch's built-in `scaled_dot_product_attention`.
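For that fallback, PyTorch's fused kernel is a one-liner (a sketch; SDPA conventionally takes inputs shaped (batch, n_heads, seq_len, head_dim) and picks the best available backend at runtime):

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(1, 8, 1024, 64) for _ in range(3)]

# Dispatches to Flash Attention, memory-efficient attention, or a plain
# math implementation depending on hardware and inputs
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```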
## Flash Attention 2 & 3
Flash Attention has evolved:
Flash Attention 2 (2023):
- Better work partitioning across GPU threads
- 2x faster than FA1
- Better parallelism for small batch sizes
Flash Attention 3 (2024):
- Exploits Hopper architecture (H100)
- Asynchronous operations
- 1.5-2x faster than FA2 on H100
## Key Takeaways
[!TIP] Summary: Flash Attention is IO-aware — it minimizes memory transfers between GPU HBM and SRAM. By tiling the computation and using online softmax, it achieves linear memory with no approximation.
- Memory I/O is the bottleneck, not compute — Flash Attention optimizes for this
- Tiling + online softmax = exact attention with linear memory
- 2-4x speedup with 10-20x memory reduction for long sequences
- Drop-in replacement — mathematically identical to standard attention
## Citation

```bibtex
@article{kumar2024flash-atte,
  author  = {Rohit Kumar},
  title   = {Flash Attention: Making Transformers Scale},
  journal = {Rohit Kumar's AI Research Blog},
  year    = {2024},
  month   = {December},
  url     = {https://rohit.vision/blogs/posts/flash-attention-scaling/}
}
```