In my previous post on Transformer Attention, we explored the mathematical foundations of attention. The key limitation? Quadratic memory complexity $O(n^2)$ makes long sequences prohibitively expensive. Flash Attention solves this.

[!NOTE] Flash Attention achieves 2-4x speedup and dramatically reduces memory usage without any approximation — it’s mathematically identical to standard attention.

The Memory Bottleneck Problem

Standard attention computes and stores the full $n \times n$ attention matrix:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
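To make the cost concrete, here is a deliberately naive pure-Python sketch of the formula. It uses lists of lists with no batching or heads, and the function names are ours. It materializes the full $n \times n$ score matrix before doing anything else:

```python
import math

def softmax(row):
    """Numerically safe softmax of a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def standard_attention(Q, K, V):
    """Naive attention: Q and K are n x d_k, V is n x d_v (lists of lists).

    The full n x n score matrix S is built first; that is the O(n^2) memory cost.
    """
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    # Full n x n matrix of scaled scores QK^T / sqrt(d_k)
    S = [[scale * sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in K] for q in Q]
    # Row-wise softmax gives the n x n attention weights P
    P = [softmax(row) for row in S]
    # Output = P V, one weighted sum of value rows per query
    d_v = len(V[0])
    return [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(d_v)] for p_row in P]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(standard_attention(Q, K, V))
```

Every row of `P` is a probability distribution over all $n$ keys, so memory for `S` and `P` grows with $n^2$.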

For a sequence of length 8192 with batch size 1:

  • Attention matrix (fp32, 4 bytes per element): $8192^2 \times 4$ bytes = 256 MiB per head
  • With 32 heads: 8 GiB just for the attention weights of a single layer!
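The arithmetic is easy to reproduce. The helper below is a hypothetical name of ours; it assumes fp32 by default and counts only the attention weights of a single layer:

```python
def attention_matrix_bytes(seq_len, n_heads=1, batch=1, bytes_per_elem=4):
    """Bytes needed to hold one layer's n x n attention matrices.

    bytes_per_elem=4 assumes fp32; pass 2 for fp16/bf16.
    """
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

MiB = 2**20
GiB = 2**30

per_head = attention_matrix_bytes(8192)
all_heads = attention_matrix_bytes(8192, n_heads=32)
print(per_head / MiB, "MiB per head")       # 256.0 MiB per head
print(all_heads / GiB, "GiB for 32 heads")  # 8.0 GiB for 32 heads
```

Halving `bytes_per_elem` (fp16) only halves these numbers; doubling the sequence length quadruples them.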

[!CRITICAL] This quadratic scaling is a major reason GPT-3 was limited to 2048 tokens, and why long-context models like Claude and GPT-4 required architectural innovations.

Flash Attention: The Key Insight

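Flash Attention exploits the memory hierarchy of modern GPUs (figures for an NVIDIA A100):

| Memory Type | Size | Bandwidth |
|---|---|---|
| SRAM (on-chip) | ~20 MB total (192 KB per SM × 108 SMs) | ~19 TB/s |
| HBM (GPU RAM) | 40-80 GB | ~1.5-2.0 TB/s |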

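The insight: for attention, memory I/O (traffic to and from HBM) is the bottleneck, not arithmetic. Standard attention runs as separate steps that each round-trip through HBM:

  1. Loads Q, K from HBM → computes $S = QK^T$ → writes $S$ to HBM
  2. Loads $S$ from HBM → computes $P = \text{softmax}(S)$ → writes $P$ to HBM
  3. Loads $P$ and V from HBM → computes $PV$ → writes the output to HBM

Flash Attention fuses all operations into a single kernel that keeps intermediate results in fast SRAM, so the $n \times n$ matrices are never written to HBM.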
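A rough count of HBM traffic for the three unfused steps shows where the time goes. This is a simplification: one head, counted in elements rather than bytes, ignoring caches and the separate scaling op. The function names are ours.

```python
def standard_attention_hbm_elements(n, d):
    """Elements moved to/from HBM by the three unfused steps (one head).

    Step 1: read Q, K (2nd), write S (n^2)
    Step 2: read S (n^2), write P (n^2)
    Step 3: read P (n^2) and V (nd), write O (nd)
    """
    reads = (2 * n * d) + (n * n) + (n * n + n * d)
    writes = (n * n) + (n * n) + (n * d)
    return reads + writes

def intermediate_share(n, d):
    """Fraction of that traffic spent reading/writing the n x n matrices S and P."""
    return 4 * n * n / standard_attention_hbm_elements(n, d)

total = standard_attention_hbm_elements(8192, 64)
print(total)                                  # 270532608
print(f"{intermediate_share(8192, 64):.1%}")  # 99.2%
```

At $n = 8192$, $d = 64$, over 99% of the traffic is the $n \times n$ intermediates, which is exactly what a fused kernel avoids writing out.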

The Tiling Algorithm

Flash Attention processes the attention matrix in tiles that fit in SRAM:

# Reference version of the Flash Attention forward pass in plain PyTorch.
# It reproduces the tiling and online-softmax math; the real CUDA kernel
# additionally keeps each tile in on-chip SRAM.
import math
import torch

def flash_attention(Q, K, V, block_size=64):
    """
    Tiled attention computation.
    Q, K, V: (batch, seq_len, d_head)
    """
    batch, n, d_head = Q.shape
    scale = 1.0 / math.sqrt(d_head)
    output = torch.empty_like(Q)

    # Outer loop: one block of queries at a time
    for i in range(0, n, block_size):
        q_block = Q[:, i:i+block_size]
        br = q_block.shape[1]  # the last block may be shorter than block_size

        # Running row max (m), softmax normalizer (l), and unnormalized output (o)
        m_i = torch.full((batch, br), float('-inf'), device=Q.device, dtype=Q.dtype)
        l_i = torch.zeros((batch, br), device=Q.device, dtype=Q.dtype)
        o_i = torch.zeros_like(q_block)

        # Inner loop: stream over blocks of keys and values
        for j in range(0, n, block_size):
            k_block = K[:, j:j+block_size]
            v_block = V[:, j:j+block_size]

            # Attention scores for this tile: (batch, br, bc)
            scores = q_block @ k_block.transpose(-1, -2) * scale

            # Online softmax update
            m_ij = torch.max(scores, dim=-1).values
            m_new = torch.maximum(m_i, m_ij)
            p_ij = torch.exp(scores - m_ij.unsqueeze(-1))

            # Rescale old statistics (alpha) and this tile (beta) to the new max
            alpha = torch.exp(m_i - m_new)
            beta = torch.exp(m_ij - m_new)

            l_i = alpha * l_i + beta * p_ij.sum(dim=-1)
            o_i = alpha.unsqueeze(-1) * o_i + beta.unsqueeze(-1) * (p_ij @ v_block)
            m_i = m_new

        # Final normalization
        output[:, i:i+block_size] = o_i / l_i.unsqueeze(-1)

    return output

[!TIP] The magic is the online softmax algorithm — we can compute exact softmax incrementally without storing the full attention matrix!

Online Softmax: The Mathematical Trick

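Standard (numerically safe) softmax makes separate passes over each row of scores:

  1. Find the row max for numerical stability
  2. Compute the shifted exponentials and their sum
  3. Divide by the sum to normalize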
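Online softmax computes the max and the normalizer together in a single pass using running statistics, and Flash Attention accumulates the output in that same pass, dividing by the normalizer only once at the end: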

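For key/value block $j$, let $S^{(j)} = Q_i K_j^T / \sqrt{d_k}$ be the tile of scores for query block $Q_i$. Each update is applied row by row, starting from $m^{(0)} = -\infty$, $\ell^{(0)} = 0$, $O^{(0)} = 0$:

\[m^{(j)} = \max\left(m^{(j-1)}, \operatorname{rowmax}\left(S^{(j)}\right)\right)\]

\[\ell^{(j)} = e^{m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \operatorname{rowsum}\left(e^{S^{(j)} - m^{(j)}}\right)\]

\[O^{(j)} = e^{m^{(j-1)} - m^{(j)}}\,O^{(j-1)} + e^{S^{(j)} - m^{(j)}} V_j\]

After the last block $T$, the output for query block $i$ is $O^{(T)} / \ell^{(T)}$.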

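[!PROOF] Correctness: Let $S^{(1:j)}$ and $V_{1:j}$ denote the scores and values for key blocks $1$ through $j$. By induction on $j$, $\ell^{(j)} = \operatorname{rowsum}\left(e^{S^{(1:j)} - m^{(j)}}\right)$ and $O^{(j)} = e^{S^{(1:j)} - m^{(j)}} V_{1:j}$, because each step multiplies the old terms by $e^{m^{(j-1)} - m^{(j)}}$, moving their reference point from $m^{(j-1)}$ to $m^{(j)}$. After the last block, the common factor $e^{-m^{(T)}}$ cancels in $O^{(T)} / \ell^{(T)}$, leaving exactly $\text{softmax}(S)V$. ∎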
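The recurrences can be checked numerically. The sketch below is pure Python for a single query row, and the names are ours. It streams over key/value blocks keeping only the running max, normalizer and unnormalized output, then compares the result against a softmax over the whole row:

```python
import math

def online_attention_row(scores, values, block_size=2):
    """One query row of attention computed block by block with online softmax.

    scores: list of n floats (the row q K^T / sqrt(d_k))
    values: list of n vectors (rows of V), each a list of d_v floats
    Only the running max m, normalizer l and unnormalized output o are kept.
    """
    d_v = len(values[0])
    m, l, o = float("-inf"), 0.0, [0.0] * d_v
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescales old statistics; 0.0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]
        l = alpha * l + sum(p)
        o = [alpha * o_c + sum(p_k * v[c] for p_k, v in zip(p, v_blk))
             for c, o_c in enumerate(o)]
        m = m_new
    return [o_c / l for o_c in o]  # single division at the end

def direct_attention_row(scores, values):
    """Reference: full softmax over the row, then the weighted sum of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(w_k * v[c] for w_k, v in zip(w, values)) / total
            for c in range(len(values[0]))]

scores = [0.5, 2.0, -1.0, 3.0, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.5, 0.5]]
print(online_attention_row(scores, values))
print(direct_attention_row(scores, values))  # equal up to floating-point rounding
```

The block size only changes how the row is chunked, not the result, which is what makes tiling exact rather than approximate.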

Memory Complexity Comparison

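| Algorithm | Extra memory | HBM accesses (I/O complexity) |
|---|---|---|
| Standard Attention | $O(n^2)$ | $O(nd + n^2)$ |
| Flash Attention | $O(n)$ | $O(n^2 d^2 / M)$ |

Where $d$ is the head dimension (typically 64-128) and $M$ is the on-chip SRAM size, measured in elements. In this analysis $M$ is the SRAM available to one streaming multiprocessor (up to 192 KB per SM on an A100), not the ~20 MB aggregate. Because $d^2$ is several times smaller than $M$, Flash Attention makes several times fewer HBM accesses than standard attention.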
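Plugging in numbers gives a feel for the gap. This is a rough sketch: both formulas are asymptotic with constants dropped, and $M$ is assumed to be about 98,304 fp16 elements (192 KB per SM). Only the order of magnitude is meaningful, and the function names are ours.

```python
def hbm_accesses_standard(n, d):
    """Asymptotic HBM accesses of standard attention, n*d + n^2 (constants dropped)."""
    return n * d + n * n

def hbm_accesses_flash(n, d, M):
    """Asymptotic HBM accesses of Flash Attention, n^2 d^2 / M (constants dropped)."""
    return n * n * d * d / M

# Assumption: M ~ 98,304 fp16 elements (192 KB of SRAM per SM on an A100)
M = 192 * 1024 // 2
for d in (64, 128):
    ratio = hbm_accesses_standard(8192, d) / hbm_accesses_flash(8192, d, M)
    print(f"d={d}: standard / flash ~ {ratio:.1f}x")  # 24.2x for d=64, 6.1x for d=128
```

Larger head dimensions shrink the advantage, since each tile of $d$-wide vectors leaves less room in SRAM for rows.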

[!SUCCESS] Flash Attention's extra memory is linear in sequence length instead of quadratic, and for typical head dimensions ($d^2$ well below $M$) it also makes far fewer HBM accesses!

Practical Performance

import time
import torch
from flash_attn import flash_attn_func

def time_cuda(fn, iters=100, warmup=10):
    """Average seconds per call of fn(), measured after warmup."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for queued kernels before reading the clock
    return (time.perf_counter() - start) / iters

def standard_attention(q, k, v):
    """Unfused attention; q, k, v: (batch, n_heads, seq_len, d_head)."""
    attn = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    attn = torch.softmax(attn, dim=-1)
    return torch.matmul(attn, v)

def benchmark_attention(seq_len, n_heads=32, d_head=64, batch=1):
    """Compare standard vs Flash Attention"""
    device = 'cuda'
    # flash_attn_func expects (batch, seq_len, n_heads, d_head) in fp16/bf16
    q, k, v = (torch.randn(batch, seq_len, n_heads, d_head, device=device, dtype=torch.float16)
               for _ in range(3))

    # Flash Attention
    flash_time = time_cuda(lambda: flash_attn_func(q, k, v))

    # Standard attention uses the (batch, n_heads, seq_len, d_head) layout
    q_std, k_std, v_std = (t.transpose(1, 2) for t in (q, k, v))
    try:
        std_time = time_cuda(lambda: standard_attention(q_std, k_std, v_std))
    except torch.cuda.OutOfMemoryError:
        # The n x n intermediates may not fit at long sequence lengths
        print(f"Seq len {seq_len}: Flash={flash_time*1000:.2f}ms, Std=OOM")
        return

    print(f"Seq len {seq_len}: Flash={flash_time*1000:.2f}ms, Std={std_time*1000:.2f}ms, Speedup={std_time/flash_time:.2f}x")

# Run benchmarks
for seq_len in [1024, 2048, 4096, 8192]:
    benchmark_attention(seq_len)

Typical results on A100:

| Sequence Length | Standard | Flash Attention | Speedup |
|---|---|---|---|
| 1024 | 1.2 ms | 0.4 ms | 3.0x |
| 2048 | 4.8 ms | 1.2 ms | 4.0x |
| 4096 | 19.2 ms | 4.1 ms | 4.7x |
| 8192 | OOM | 15.3 ms | n/a |

Using Flash Attention in Practice

With Hugging Face Transformers

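import torch
from transformers import AutoModelForCausalLM

# Flash Attention 2 is opt-in: request it with attn_implementation
# (requires the flash-attn package, a supported GPU, and fp16/bf16 weights;
# move the model to a CUDA device before running it)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
)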

Direct Usage with flash-attn Library

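import torch
from flash_attn import flash_attn_func

# Q, K, V shape: (batch, seq_len, n_heads, head_dim), fp16 or bf16 on a CUDA device
q, k, v = (torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16) for _ in range(3))
output = flash_attn_func(q, k, v, causal=True)  # causal=True applies an autoregressive mask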

[!WARNING] The flash-attn 2.x package requires an NVIDIA GPU with compute capability 8.0 or higher (Ampere or newer) and fp16/bf16 inputs. For older GPUs, consider xFormers or PyTorch's built-in torch.nn.functional.scaled_dot_product_attention.

Flash Attention 2 & 3

Flash Attention has evolved:

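Flash Attention 2 (2023):

  • Better work partitioning between warps within each thread block
  • Roughly 2x faster than FA1
  • Parallelizes over the sequence-length dimension as well, improving GPU utilization for small batch sizes

Flash Attention 3 (2024):

  • Exploits the Hopper architecture (H100)
  • Overlaps matmuls, softmax, and data movement using Hopper's asynchronous features (warp specialization, TMA)
  • 1.5-2x faster than FA2 on H100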

Key Takeaways

[!TIP] Summary: Flash Attention is IO-aware — it minimizes memory transfers between GPU HBM and SRAM. By tiling the computation and using online softmax, it achieves linear memory with no approximation.

  1. Memory I/O is the bottleneck, not compute — Flash Attention optimizes for this
  2. Tiling + online softmax = exact attention with linear memory
  3. 2-4x speedup with 10-20x memory reduction for long sequences
  4. Drop-in replacement — mathematically identical to standard attention

Previous: Transformer Attention: A Mathematical Deep Dive