<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://rohit.vision/blogs/feed.xml" rel="self" type="application/atom+xml" /><link href="https://rohit.vision/blogs/" rel="alternate" type="text/html" /><updated>2026-03-09T16:12:07+00:00</updated><id>https://rohit.vision/blogs/feed.xml</id><title type="html">Rohit Kumar | AI Research Blog</title><subtitle>Deep dives into Computer Vision, LLMs, Diffusion Models, and Agentic AI. Technical tutorials with math, code, and interactive visualizations.</subtitle><author><name>Rohit Kumar</name></author><entry><title type="html">Flash Attention: Making Transformers Scale</title><link href="https://rohit.vision/blogs/posts/flash-attention-scaling/" rel="alternate" type="text/html" title="Flash Attention: Making Transformers Scale" /><published>2024-12-18T00:00:00+00:00</published><updated>2024-12-18T00:00:00+00:00</updated><id>https://rohit.vision/blogs/posts/flash-attention-scaling</id><content type="html" xml:base="https://rohit.vision/blogs/posts/flash-attention-scaling/"><![CDATA[<p>In my <a href="/posts/transformer-attention-deep-dive/">previous post on Transformer Attention</a>, we explored the mathematical foundations of attention. The key limitation? <strong>Quadratic memory complexity</strong> $O(n^2)$ makes long sequences prohibitively expensive. Flash Attention solves this.</p>

<blockquote>
  <p>[!NOTE]
Flash Attention achieves <strong>2-4x speedup</strong> and dramatically reduces memory usage without any approximation — it’s mathematically identical to standard attention.</p>
</blockquote>

<h2 id="the-memory-bottleneck-problem">The Memory Bottleneck Problem</h2>

<p>Standard attention computes and stores the full $n \times n$ attention matrix:</p>

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

<p>For a sequence of length 8192 with batch size 1 (see the quick check below the list):</p>
<ul>
  <li>Attention matrix: $8192^2 \times 4$ bytes = <strong>256 MB</strong> per head</li>
  <li>With 32 heads: <strong>8 GB</strong> just for attention weights!</li>
</ul>
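
<p>A quick back-of-the-envelope check of those numbers (plain Python, float32 assumed):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>seq_len, n_heads, bytes_fp32 = 8192, 32, 4

per_head_mb = seq_len * seq_len * bytes_fp32 / 2**20
total_gb = per_head_mb * n_heads / 1024

print(f"{per_head_mb:.0f} MB per head")   # 256 MB
print(f"{total_gb:.0f} GB for 32 heads")  # 8 GB
</code></pre></div></div>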

<blockquote>
  <p>[!CRITICAL]
This quadratic scaling is why GPT-3 was limited to 2048 tokens, and why long-context models like Claude and GPT-4 required architectural innovations.</p>
</blockquote>

<h2 id="flash-attention-the-key-insight">Flash Attention: The Key Insight</h2>

<p>Flash Attention exploits the <strong>memory hierarchy</strong> of modern GPUs:</p>

<table>
  <thead>
    <tr>
      <th>Memory Type</th>
      <th>Size</th>
      <th>Bandwidth</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>SRAM (on-chip)</td>
      <td>~20 MB</td>
      <td>~19 TB/s</td>
    </tr>
    <tr>
      <td>HBM (GPU RAM)</td>
      <td>40-80 GB</td>
      <td>~1.5 TB/s</td>
    </tr>
  </tbody>
</table>

<p>The insight: <strong>Memory I/O is the bottleneck, not compute</strong>. Standard attention makes three round trips through HBM (sketched in code after the list):</p>
<ol>
  <li>Loads Q, K from HBM → computes $QK^T$ → writes to HBM</li>
  <li>Loads $QK^T$ from HBM → computes softmax → writes to HBM</li>
  <li>Loads softmax output from HBM → multiplies by V → writes to HBM</li>
</ol>
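
<p>In code, the unfused version looks like this (a sketch to show the materialized intermediates, not how the actual kernels are written):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import torch

def standard_attention(Q, K, V):
    """Naive attention: every intermediate is materialized in HBM."""
    d_k = Q.shape[-1]
    S = Q @ K.transpose(-1, -2) / math.sqrt(d_k)  # (n, n) scores  -&gt; HBM
    P = torch.softmax(S, dim=-1)                  # (n, n) weights -&gt; HBM
    return P @ V                                  # output         -&gt; HBM
</code></pre></div></div>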

<p>Flash Attention fuses all operations into a <strong>single kernel</strong> that keeps intermediate results in fast SRAM.</p>

<h2 id="the-tiling-algorithm">The Tiling Algorithm</h2>

<p>Flash Attention processes the attention matrix in <strong>tiles</strong> that fit in SRAM:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Pseudocode for Flash Attention forward pass
</span><span class="k">def</span> <span class="nf">flash_attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">block_size</span><span class="o">=</span><span class="mi">64</span><span class="p">):</span>
    <span class="s">"""
    Tiled attention computation.
    Q, K, V: (batch, seq_len, d_head)
    """</span>
    <span class="n">n</span> <span class="o">=</span> <span class="n">Q</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">Q</span><span class="p">)</span>
    
    <span class="c1"># Process in blocks
</span>    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">block_size</span><span class="p">):</span>
        <span class="n">q_block</span> <span class="o">=</span> <span class="n">Q</span><span class="p">[:,</span> <span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="n">block_size</span><span class="p">]</span>
        
        <span class="c1"># Track running max and normalizer for numerical stability
</span>        <span class="n">m_i</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">full</span><span class="p">((</span><span class="n">q_block</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">block_size</span><span class="p">),</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>
        <span class="n">l_i</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">q_block</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">block_size</span><span class="p">))</span>
        <span class="n">o_i</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">q_block</span><span class="p">)</span>
        
        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">block_size</span><span class="p">):</span>
            <span class="n">k_block</span> <span class="o">=</span> <span class="n">K</span><span class="p">[:,</span> <span class="n">j</span><span class="p">:</span><span class="n">j</span><span class="o">+</span><span class="n">block_size</span><span class="p">]</span>
            <span class="n">v_block</span> <span class="o">=</span> <span class="n">V</span><span class="p">[:,</span> <span class="n">j</span><span class="p">:</span><span class="n">j</span><span class="o">+</span><span class="n">block_size</span><span class="p">]</span>
            
            <span class="c1"># Compute attention scores for this tile
</span>            <span class="n">scores</span> <span class="o">=</span> <span class="n">q_block</span> <span class="o">@</span> <span class="n">k_block</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">)</span>
            
            <span class="c1"># Online softmax update
</span>            <span class="n">m_ij</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">values</span>
            <span class="n">m_new</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">m_i</span><span class="p">,</span> <span class="n">m_ij</span><span class="p">)</span>
            
            <span class="c1"># Rescale and accumulate
</span>            <span class="n">alpha</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_i</span> <span class="o">-</span> <span class="n">m_new</span><span class="p">)</span>
            <span class="n">beta</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">m_ij</span> <span class="o">-</span> <span class="n">m_new</span><span class="p">)</span>
            
            <span class="n">l_i</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">l_i</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">scores</span> <span class="o">-</span> <span class="n">m_ij</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">o_i</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">o_i</span> <span class="o">+</span> <span class="n">beta</span> <span class="o">*</span> <span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">scores</span> <span class="o">-</span> <span class="n">m_ij</span><span class="p">)</span> <span class="o">@</span> <span class="n">v_block</span><span class="p">)</span>
            <span class="n">m_i</span> <span class="o">=</span> <span class="n">m_new</span>
        
        <span class="c1"># Final normalization
</span>        <span class="n">output</span><span class="p">[:,</span> <span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="n">block_size</span><span class="p">]</span> <span class="o">=</span> <span class="n">o_i</span> <span class="o">/</span> <span class="n">l_i</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    
    <span class="k">return</span> <span class="n">output</span>
</code></pre></div></div>
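
<p>To convince yourself the tiling is exact, compare it against the direct computation (a quick CPU check using the sketch above):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.manual_seed(0)
Q, K, V = (torch.randn(1, 256, 64) for _ in range(3))

# Reference: materialize the full attention matrix
ref = torch.softmax(Q @ K.transpose(-1, -2) / 64 ** 0.5, dim=-1) @ V
assert torch.allclose(flash_attention(Q, K, V), ref, atol=1e-5)
</code></pre></div></div>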

<blockquote>
  <p>[!TIP]
The magic is the <strong>online softmax</strong> algorithm — we can compute exact softmax incrementally without storing the full attention matrix!</p>
</blockquote>

<h2 id="online-softmax-the-mathematical-trick">Online Softmax: The Mathematical Trick</h2>

<p>Standard softmax requires two passes:</p>
<ol>
  <li>Find max for numerical stability</li>
  <li>Compute exp and normalize</li>
</ol>

<p>Online softmax does it in <strong>one pass</strong> using running statistics:</p>

\[m^{(j)} = \max(m^{(j-1)}, \max(S_{:,j}))\]

\[\ell^{(j)} = e^{m^{(j-1)} - m^{(j)}} \ell^{(j-1)} + \sum_i e^{S_{i,j} - m^{(j)}}\]

\[O^{(j)} = e^{m^{(j-1)} - m^{(j)}} O^{(j-1)} + e^{S_{:,j} - m^{(j)}} V_j\]

<blockquote>
  <p>[!PROOF]
<strong>Correctness</strong>: After the final block, $O/\ell$ equals the exact attention output. The argument is a simple induction: each update rescales the accumulated numerator $O$ and denominator $\ell$ to the new running max, so after every block they equal the exact partial sums over the keys seen so far. ∎</p>
</blockquote>
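
<p>Here is the normalizer update on a single row, split into two blocks (an illustrative numerical check, not part of the proof):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

torch.manual_seed(0)
s = torch.randn(8)       # one row of attention scores
s1, s2 = s[:4], s[4:]    # two key blocks

m1 = s1.max()                     # block-1 running max
l1 = torch.exp(s1 - m1).sum()     # block-1 normalizer

m2 = torch.maximum(m1, s2.max())  # updated running max
l2 = torch.exp(m1 - m2) * l1 + torch.exp(s2 - m2).sum()

# Matches the two-pass normalizer computed over the full row
assert torch.allclose(l2, torch.exp(s - s.max()).sum())
</code></pre></div></div>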

<h2 id="memory-complexity-comparison">Memory Complexity Comparison</h2>

<table>
  <thead>
    <tr>
      <th>Algorithm</th>
      <th>Memory</th>
      <th>I/O Complexity</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Standard Attention</td>
      <td>$O(n^2)$</td>
<td>$O(n d + n^2)$</td>
    </tr>
    <tr>
      <td>Flash Attention</td>
      <td>$O(n)$</td>
      <td>$O(n^2 d^2 / M)$</td>
    </tr>
  </tbody>
</table>

<p>Where $M$ is the SRAM size (~20 MB) and $d$ is the head dimension (~64-128). Since $d^2 \ll M$ in practice, Flash Attention performs far fewer HBM accesses than the standard algorithm.</p>

<blockquote>
  <p>[!SUCCESS]
For typical transformer configs, Flash Attention reduces memory from <strong>quadratic to linear</strong> in sequence length!</p>
</blockquote>

<h2 id="practical-performance">Practical Performance</h2>

<div class="collapsible" data-label="Show Benchmark Results">

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import time

import torch
from flash_attn import flash_attn_func

def benchmark_attention(seq_len, n_heads=32, d_head=64, batch=1):
    """Compare standard vs Flash Attention"""

    device = 'cuda'
    q = torch.randn(batch, seq_len, n_heads, d_head, device=device, dtype=torch.float16)
    k = torch.randn(batch, seq_len, n_heads, d_head, device=device, dtype=torch.float16)
    v = torch.randn(batch, seq_len, n_heads, d_head, device=device, dtype=torch.float16)

    # Warmup
    for _ in range(10):
        _ = flash_attn_func(q, k, v)
    torch.cuda.synchronize()

    # Flash Attention
    start = time.time()
    for _ in range(100):
        _ = flash_attn_func(q, k, v)
    torch.cuda.synchronize()
    flash_time = (time.time() - start) / 100

    # Standard attention baseline (expects (batch, heads, seq, dim);
    # will OOM at long sequence lengths)
    q_std = q.transpose(1, 2)
    k_std = k.transpose(1, 2)
    v_std = v.transpose(1, 2)

    start = time.time()
    for _ in range(100):
        attn = torch.matmul(q_std, k_std.transpose(-2, -1)) / (d_head ** 0.5)
        attn = torch.softmax(attn, dim=-1)
        _ = torch.matmul(attn, v_std)
    torch.cuda.synchronize()
    std_time = (time.time() - start) / 100

    print(f"Seq len {seq_len}: Flash={flash_time*1000:.2f}ms, "
          f"Std={std_time*1000:.2f}ms, Speedup={std_time/flash_time:.2f}x")

# Run benchmarks
for seq_len in [1024, 2048, 4096, 8192]:
    benchmark_attention(seq_len)
</code></pre></div></div>

<p><strong>Typical results on A100:</strong></p>

<table>
  <thead>
    <tr>
      <th>Sequence Length</th>
      <th>Standard</th>
      <th>Flash Attention</th>
      <th>Speedup</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>1024</td>
      <td>1.2 ms</td>
      <td>0.4 ms</td>
      <td>3.0x</td>
    </tr>
    <tr>
      <td>2048</td>
      <td>4.8 ms</td>
      <td>1.2 ms</td>
      <td>4.0x</td>
    </tr>
    <tr>
      <td>4096</td>
      <td>19.2 ms</td>
      <td>4.1 ms</td>
      <td>4.7x</td>
    </tr>
    <tr>
      <td>8192</td>
      <td>OOM</td>
      <td>15.3 ms</td>
      <td>∞</td>
    </tr>
  </tbody>
</table>

</div>

<h2 id="using-flash-attention-in-practice">Using Flash Attention in Practice</h2>

<h3 id="with-hugging-face-transformers">With Hugging Face Transformers</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoModelForCausalLM</span>

<span class="c1"># Flash Attention 2 is enabled automatically for supported models
</span><span class="n">model</span> <span class="o">=</span> <span class="n">AutoModelForCausalLM</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span>
    <span class="s">"meta-llama/Llama-2-7b-hf"</span><span class="p">,</span>
    <span class="n">torch_dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float16</span><span class="p">,</span>
    <span class="n">attn_implementation</span><span class="o">=</span><span class="s">"flash_attention_2"</span>
<span class="p">)</span>
</code></pre></div></div>

<h3 id="direct-usage-with-flash-attn-library">Direct Usage with flash-attn Library</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">flash_attn</span> <span class="kn">import</span> <span class="n">flash_attn_func</span>

<span class="c1"># Q, K, V shape: (batch, seq_len, n_heads, head_dim)
</span><span class="n">output</span> <span class="o">=</span> <span class="n">flash_attn_func</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">causal</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>

<blockquote>
  <p>[!WARNING]
Flash Attention requires <strong>GPU with compute capability &gt;= 8.0</strong> (Ampere or newer). For older GPUs, consider xFormers or PyTorch’s built-in <code class="language-plaintext highlighter-rouge">scaled_dot_product_attention</code>.</p>
</blockquote>
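
<p>For those older GPUs, PyTorch 2.x ships a built-in that automatically picks the best available kernel. A minimal sketch on a CUDA device; note it expects <code class="language-plaintext highlighter-rouge">(batch, heads, seq, head_dim)</code>, unlike <code class="language-plaintext highlighter-rouge">flash_attn_func</code>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 32, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 32, 1024, 64, device='cuda', dtype=torch.float16)

# Dispatches to Flash, memory-efficient, or math kernels as available
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
</code></pre></div></div>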

<h2 id="flash-attention-2--3">Flash Attention 2 &amp; 3</h2>

<p>Flash Attention has evolved:</p>

<p><strong>Flash Attention 2</strong> (2023):</p>
<ul>
  <li>Better work partitioning across GPU threads</li>
  <li>2x faster than FA1</li>
  <li>Better parallelism for small batch sizes</li>
</ul>

<p><strong>Flash Attention 3</strong> (2024):</p>
<ul>
  <li>Exploits Hopper architecture (H100)</li>
  <li>Asynchronous operations</li>
  <li>1.5-2x faster than FA2 on H100</li>
</ul>

<h2 id="key-takeaways">Key Takeaways</h2>

<blockquote>
  <p>[!TIP]
<strong>Summary</strong>: Flash Attention is <strong>IO-aware</strong> — it minimizes memory transfers between GPU HBM and SRAM. By tiling the computation and using online softmax, it achieves linear memory with no approximation.</p>
</blockquote>

<ol>
  <li><strong>Memory I/O is the bottleneck</strong>, not compute — Flash Attention optimizes for this</li>
  <li><strong>Tiling + online softmax</strong> = exact attention with linear memory</li>
  <li><strong>2-4x speedup</strong> with <strong>10-20x memory reduction</strong> for long sequences</li>
  <li><strong>Drop-in replacement</strong> — mathematically identical to standard attention</li>
</ol>

<hr />

<p><em>Previous: <a href="/posts/transformer-attention-deep-dive/">Transformer Attention: A Mathematical Deep Dive</a></em></p>]]></content><author><name>Rohit Kumar</name></author><category term="transformers" /><category term="attention" /><category term="optimization" /><category term="flash-attention" /><category term="gpu" /><summary type="html"><![CDATA[In my previous post on Transformer Attention, we explored the mathematical foundations of attention. The key limitation? Quadratic memory complexity $O(n^2)$ makes long sequences prohibitively expensive. Flash Attention solves this.]]></summary></entry><entry><title type="html">Transformer Attention: A Mathematical Deep Dive</title><link href="https://rohit.vision/blogs/posts/transformer-attention-deep-dive/" rel="alternate" type="text/html" title="Transformer Attention: A Mathematical Deep Dive" /><published>2024-12-17T00:00:00+00:00</published><updated>2024-12-17T00:00:00+00:00</updated><id>https://rohit.vision/blogs/posts/transformer-attention-deep-dive</id><content type="html" xml:base="https://rohit.vision/blogs/posts/transformer-attention-deep-dive/"><![CDATA[<p>The attention mechanism is the core innovation behind transformers. Let’s break it down mathematically and implement it from scratch.</p>

<blockquote>
  <p>[!NOTE]
This post assumes familiarity with basic linear algebra and neural networks. If you’re new to these topics, check out my <a href="#">prerequisites guide</a>.</p>
</blockquote>

<h2 id="the-attention-formula">The Attention Formula</h2>

<p>At its heart, attention computes a weighted sum of values based on query-key similarity:</p>

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

<p>Where:</p>
<ul>
  <li>$Q$ = Query matrix of shape $(n, d_k)$</li>
  <li>$K$ = Key matrix of shape $(m, d_k)$</li>
  <li>$V$ = Value matrix of shape $(m, d_v)$</li>
  <li>$d_k$ = Key/query dimension (scaling factor)</li>
</ul>

<blockquote>
  <p>[!TIP]
Think of attention as a “soft” dictionary lookup: queries find relevant keys, and retrieve their associated values.</p>
</blockquote>

<h2 id="why-scale-by-sqrtd_k">Why Scale by $\sqrt{d_k}$?</h2>

<p>The entries of $QK^T$ have variance that grows linearly with $d_k$. For large $d_k$, the logits become large, the softmax saturates toward one-hot vectors, and gradients vanish. Scaling by $\sqrt{d_k}$ keeps the variance stable:</p>

\[\text{Var}(q \cdot k) = d_k \cdot \text{Var}(q_i) \cdot \text{Var}(k_i) = d_k\]

<p>After scaling: $\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1$</p>
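
<p>You can verify this empirically in a few lines (a quick sketch):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

dots = (q * k).sum(dim=-1)        # 10k sample dot products
print(dots.var())                 # ~512, i.e. ~d_k
print((dots / d_k ** 0.5).var())  # ~1 after scaling
</code></pre></div></div>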

<blockquote>
  <p>[!WARNING]
Forgetting this scaling factor is a common bug! Without it, gradients vanish for $d_k &gt; 64$.</p>
</blockquote>

<h2 id="pytorch-implementation">PyTorch Implementation</h2>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>

<span class="k">class</span> <span class="nc">ScaledDotProductAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">d_k</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span> 
        <span class="n">query</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span>  <span class="c1"># (batch, n, d_k)
</span>        <span class="n">key</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span>    <span class="c1"># (batch, m, d_k)
</span>        <span class="n">value</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span>  <span class="c1"># (batch, m, d_v)
</span>        <span class="n">mask</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="bp">None</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">]:</span>
        <span class="c1"># Compute attention scores
</span>        <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">scale</span>
        
        <span class="c1"># Apply mask (for causal attention or padding)
</span>        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>
        
        <span class="c1"># Softmax over keys
</span>        <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">attn_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">)</span>
        
        <span class="c1"># Weighted sum of values
</span>        <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
        
        <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">attn_weights</span>
</code></pre></div></div>
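
<p>A quick smoke test of the module above, including a causal mask (shapes are illustrative):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>attn = ScaledDotProductAttention(d_k=64)
attn.eval()  # disable dropout for a deterministic check

q = torch.randn(2, 10, 64)   # (batch, n, d_k)
k = torch.randn(2, 10, 64)   # (batch, m, d_k)
v = torch.randn(2, 10, 64)   # (batch, m, d_v)

# Lower-triangular causal mask: position i may attend to j &lt;= i
mask = torch.tril(torch.ones(10, 10))

out, weights = attn(q, k, v, mask)
print(out.shape, weights.shape)  # (2, 10, 64), (2, 10, 10)
</code></pre></div></div>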

<h2 id="multi-head-attention">Multi-Head Attention</h2>

<p>Instead of a single attention function, transformers use multiple “heads” in parallel:</p>

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]

<p>where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$</p>

<blockquote>
  <p>[!QUESTION]
Why use multiple heads instead of one large attention? Answer: Each head can attend to different aspects of the input (syntax, semantics, position, etc.).</p>
</blockquote>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">d_model</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">n_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_o</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">ScaledDotProductAttention</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">query</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        
        <span class="c1"># Linear projections and reshape for multi-head
</span>        <span class="n">Q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span><span class="p">(</span><span class="n">query</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">K</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span><span class="p">(</span><span class="n">key</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">V</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span><span class="p">(</span><span class="n">value</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        
        <span class="c1"># Apply attention
</span>        <span class="n">x</span><span class="p">,</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
        
        <span class="c1"># Concatenate heads and project
</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_o</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">attn</span>
</code></pre></div></div>
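
<p>And a sanity check of the full module (self-attention, so query, key, and value are the same tensor):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>mha = MultiHeadAttention(d_model=512, n_heads=8)

x = torch.randn(2, 100, 512)  # (batch, seq_len, d_model)
out, attn = mha(x, x, x)

print(out.shape)   # torch.Size([2, 100, 512])
print(attn.shape)  # torch.Size([2, 8, 100, 100]) -- one map per head
</code></pre></div></div>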

<h2 id="complexity-analysis">Complexity Analysis</h2>

<table>
  <thead>
    <tr>
      <th>Operation</th>
      <th>Time Complexity</th>
      <th>Space Complexity</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>$QK^T$ computation</td>
      <td>$O(n \cdot m \cdot d_k)$</td>
      <td>$O(n \cdot m)$</td>
    </tr>
    <tr>
      <td>Softmax</td>
      <td>$O(n \cdot m)$</td>
      <td>$O(n \cdot m)$</td>
    </tr>
    <tr>
      <td>Attention × Value</td>
      <td>$O(n \cdot m \cdot d_v)$</td>
      <td>$O(n \cdot d_v)$</td>
    </tr>
    <tr>
      <td><strong>Total</strong></td>
      <td>$O(n \cdot m \cdot d)$</td>
      <td>$O(n \cdot m)$</td>
    </tr>
  </tbody>
</table>

<p>For self-attention ($n = m$), this is <strong>quadratic</strong> in sequence length — the main bottleneck for long sequences.</p>
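
<p>The growth is easy to see numerically (float32 assumed, per attention head):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def attn_matrix_mb(n, bytes_per_el=4):
    """Memory for one (n x n) float32 attention matrix, in MiB."""
    return n * n * bytes_per_el / 2**20

for n in [1024, 2048, 4096, 8192]:
    print(f"{n:&gt;5}: {attn_matrix_mb(n):6.0f} MB")  # 4, 16, 64, 256 MB
</code></pre></div></div>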

<h2 id="interactive-demo">Interactive Demo</h2>

<p>Try this attention visualization on Hugging Face:</p>

<div class="hf-space" data-src="exbert-project/exbert" data-height="600"></div>

<h2 id="try-it-yourself">Try It Yourself</h2>

<p>Run this simple attention calculation directly in your browser:</p>

<div class="runnable" data-lang="python">

  <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="c1"># Simple attention example (no PyTorch needed!)
</span><span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="n">exp_x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">exp_x</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">exp_x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">):</span>
    <span class="n">d_k</span> <span class="o">=</span> <span class="n">K</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    <span class="n">scores</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">@</span> <span class="n">K</span><span class="p">.</span><span class="n">T</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">)</span>
    <span class="n">weights</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">weights</span> <span class="o">@</span> <span class="n">V</span><span class="p">,</span> <span class="n">weights</span>

<span class="c1"># Create sample query, key, value vectors
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>  <span class="c1"># 2 queries, dim 4
</span><span class="n">K</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>  <span class="c1"># 3 keys
</span><span class="n">V</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>  <span class="c1"># 3 values
</span>
<span class="n">output</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Query shape:"</span><span class="p">,</span> <span class="n">Q</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Key shape:"</span><span class="p">,</span> <span class="n">K</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Value shape:"</span><span class="p">,</span> <span class="n">V</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Attention weights (which keys each query attends to):"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">Output shape:"</span><span class="p">,</span> <span class="n">output</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
</code></pre></div>  </div>

</div>

<h2 id="key-takeaways">Key Takeaways</h2>

<blockquote>
  <p>[!TIP]
<strong>Summary</strong>: Attention is a “soft dictionary” — queries find keys, retrieve values. Multi-head attention learns multiple perspectives. The $\sqrt{d_k}$ scaling prevents gradient issues.</p>
</blockquote>

<ol>
  <li><strong>Attention is a soft dictionary lookup</strong>: queries find relevant keys, retrieve values</li>
  <li><strong>Scaling prevents gradient vanishing</strong> in high dimensions</li>
  <li><strong>Multi-head = multiple perspectives</strong> on the same input</li>
  <li><strong>Quadratic complexity</strong> motivates efficient variants (Flash Attention, Linear Attention)</li>
</ol>

<hr />

<h2 id="new-feature-showcase">New Feature Showcase</h2>

<p>This section demonstrates the new “Second Brain” features added to the blog.</p>

<h3 id="enhanced-callouts">Enhanced Callouts</h3>

<blockquote>
  <p>[!ABSTRACT]
This post provides a mathematical deep-dive into the attention mechanism, the core innovation behind transformer architectures. We cover scaled dot-product attention, multi-head attention, complexity analysis, and provide interactive implementations.</p>
</blockquote>

<blockquote>
  <p>[!DEFINITION]
<strong>Scaled Dot-Product Attention</strong> is a function that maps a query and a set of key-value pairs to an output, where the output is computed as a weighted sum of the values, with weights determined by the compatibility of the query with the corresponding keys.</p>
</blockquote>

<blockquote>
  <p>[!PROOF]
<strong>Variance Stability Proof</strong>: Let $q_i, k_i \sim \mathcal{N}(0, 1)$ be i.i.d. standard normal. Then $\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$. Dividing by $\sqrt{d_k}$ gives $\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1$. ∎</p>
</blockquote>

<blockquote>
  <p>[!EXAMPLE]
For a 3-token sequence “The cat sat”, self-attention allows “sat” to attend to “cat” (subject) with high weight, while “The” attends mostly to itself since it’s a common determiner.</p>
</blockquote>

<blockquote>
  <p>[!CRITICAL]
<strong>GPU Memory Warning</strong>: Attention’s $O(n^2)$ space complexity means a sequence of length 8192 requires ~256MB just for the attention matrix (float32). This is why Flash Attention and memory-efficient variants are essential for long-context models!</p>
</blockquote>

<blockquote>
  <p>[!SUCCESS]
After implementing attention correctly with proper scaling, you should see smooth training curves and stable gradients even with $d_k = 512$ or higher.</p>
</blockquote>

<h3 id="collapsible-code-block">Collapsible Code Block</h3>

<p>The full multi-head attention implementation is hidden by default to keep the article clean:</p>

<div class="collapsible" data-label="Show Full Transformer Block">

  <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">TransformerBlock</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s">"""Complete transformer block with attention, FFN, and residual connections."""</span>
    
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        
        <span class="c1"># Multi-head attention
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span>
        
        <span class="c1"># Feed-forward network
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">GELU</span><span class="p">(),</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span>
        
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="c1"># Self-attention with residual
</span>        <span class="n">attn_out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">attention</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">attn_out</span><span class="p">)</span>
        
        <span class="c1"># FFN with residual
</span>        <span class="n">ffn_out</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">ffn_out</span><span class="p">)</span>
        
        <span class="k">return</span> <span class="n">x</span>

<span class="c1"># Example usage
</span><span class="n">block</span> <span class="o">=</span> <span class="n">TransformerBlock</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">n_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">d_ff</span><span class="o">=</span><span class="mi">2048</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span>  <span class="c1"># batch=2, seq_len=100, d_model=512
</span><span class="n">output</span> <span class="o">=</span> <span class="n">block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Input: </span><span class="si">{</span><span class="n">x</span><span class="p">.</span><span class="n">shape</span><span class="si">}</span><span class="s"> -&gt; Output: </span><span class="si">{</span><span class="n">output</span><span class="p">.</span><span class="n">shape</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div>  </div>

</div>

<h3 id="video-demo-example-syntax">Video Demo (Example Syntax)</h3>

<p>Here’s how you can embed videos showing your model in action:</p>

<div class="video-embed" data-src="https://www.youtube.com/watch?v=kCc8FmEb1nY" data-caption="Andrej Karpathy's excellent 'Let's build GPT' tutorial"></div>

<h3 id="image-comparison-slider">Image Comparison Slider</h3>

<p>Drag the slider to compare raw vs processed attention patterns:</p>

<div class="image-compare" data-before="/assets/images/attention_before.png" data-after="/assets/images/attention_after.png">
  <span class="compare-label-before">Raw</span>
  <span class="compare-label-after">Processed</span>
</div>

<p><strong>Usage syntax:</strong></p>
<div class="language-html highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nt">&lt;div</span> <span class="na">class=</span><span class="s">"image-compare"</span> 
     <span class="na">data-before=</span><span class="s">"/path/to/before.png"</span> 
     <span class="na">data-after=</span><span class="s">"/path/to/after.png"</span><span class="nt">&gt;</span>
  <span class="nt">&lt;span</span> <span class="na">class=</span><span class="s">"compare-label-before"</span><span class="nt">&gt;</span>Before<span class="nt">&lt;/span&gt;</span>
  <span class="nt">&lt;span</span> <span class="na">class=</span><span class="s">"compare-label-after"</span><span class="nt">&gt;</span>After<span class="nt">&lt;/span&gt;</span>
<span class="nt">&lt;/div&gt;</span>
</code></pre></div></div>

<h3 id="multi-image-layouts">Multi-Image Layouts</h3>

<p><strong>Single Image (1x1)</strong> - Default layout:</p>

<figure class="single-image">
    <img src="https://picsum.photos/seed/single/600/300" alt="Single image with caption" /><figcaption>Single image with caption</figcaption></figure>

<p><strong>Two-Column (1x2)</strong> - Just add two images:</p>

<div class="image-row"><figure>
        <img src="https://picsum.photos/seed/attention1/400/300" alt="Self-Attention" /><figcaption>Self-Attention</figcaption></figure><figure>
        <img src="https://picsum.photos/seed/attention2/400/300" alt="Multi-Head Attention" /><figcaption>Multi-Head Attention</figcaption></figure></div>

<p><strong>Grid (3+)</strong> - Automatically creates grid:</p>

<div class="image-grid cols-3 "><figure>
        <img src="https://picsum.photos/seed/layer1/300/200" alt="Layer 1" /><figcaption>Layer 1</figcaption></figure><figure>
        <img src="https://picsum.photos/seed/layer2/300/200" alt="Layer 2" /><figcaption>Layer 2</figcaption></figure><figure>
        <img src="https://picsum.photos/seed/layer3/300/200" alt="Layer 3" /><figcaption>Layer 3</figcaption></figure><figure>
        <img src="https://picsum.photos/seed/layer4/300/200" alt="Layer 4" /><figcaption>Layer 4</figcaption></figure><figure>
        <img src="https://picsum.photos/seed/layer5/300/200" alt="Layer 5" /><figcaption>Layer 5</figcaption></figure><figure>
        <img src="https://picsum.photos/seed/layer6/300/200" alt="Layer 6" /><figcaption>Layer 6</figcaption></figure></div>

<p><strong>Usage syntax:</strong></p>
<div class="language-liquid highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
{# Single image #}
<span class="p">{%</span><span class="w"> </span><span class="nt">include</span><span class="w"> </span>img.html<span class="w"> </span><span class="na">src</span><span class="o">=</span><span class="s2">"/path/image.png"</span><span class="w"> </span><span class="p">%}</span>
<span class="p">{%</span><span class="w"> </span><span class="nt">include</span><span class="w"> </span>img.html<span class="w"> </span><span class="na">src</span><span class="o">=</span><span class="s2">"/path/image.png"</span><span class="w"> </span><span class="na">cap</span><span class="o">=</span><span class="s2">"With caption"</span><span class="w"> </span><span class="p">%}</span>

{# Two-column #}
<span class="p">{%</span><span class="w"> </span><span class="nt">include</span><span class="w"> </span>img.html<span class="w"> </span><span class="na">src</span><span class="o">=</span><span class="s2">"/path/1.png, /path/2.png"</span><span class="w"> </span><span class="na">cap</span><span class="o">=</span><span class="s2">"Left, Right"</span><span class="w"> </span><span class="p">%}</span>

{# Grid (3+ images) #}
<span class="p">{%</span><span class="w"> </span><span class="nt">include</span><span class="w"> </span>img.html<span class="w"> </span><span class="na">src</span><span class="o">=</span><span class="s2">"/1.png, /2.png, /3.png"</span><span class="w"> </span><span class="na">cap</span><span class="o">=</span><span class="s2">"A, B, C"</span><span class="w"> </span><span class="na">cols</span><span class="o">=</span><span class="s2">"3"</span><span class="w"> </span><span class="p">%}</span>

</code></pre></div></div>

<h3 id="summary-of-new-markdown-syntax">Summary of New Markdown Syntax</h3>

<table>
  <thead>
    <tr>
      <th>Feature</th>
      <th>Syntax</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Abstract</td>
      <td><code class="language-plaintext highlighter-rouge">&gt; [!ABSTRACT]</code></td>
    </tr>
    <tr>
      <td>Definition</td>
      <td><code class="language-plaintext highlighter-rouge">&gt; [!DEFINITION]</code></td>
    </tr>
    <tr>
      <td>Proof</td>
      <td><code class="language-plaintext highlighter-rouge">&gt; [!PROOF]</code></td>
    </tr>
    <tr>
      <td>Example</td>
      <td><code class="language-plaintext highlighter-rouge">&gt; [!EXAMPLE]</code></td>
    </tr>
    <tr>
      <td>Critical</td>
      <td><code class="language-plaintext highlighter-rouge">&gt; [!CRITICAL]</code></td>
    </tr>
    <tr>
      <td>Success</td>
      <td><code class="language-plaintext highlighter-rouge">&gt; [!SUCCESS]</code></td>
    </tr>
    <tr>
      <td>Collapsible Code</td>
      <td><code class="language-plaintext highlighter-rouge">&lt;div class="collapsible"&gt;...&lt;/div&gt;</code></td>
    </tr>
    <tr>
      <td>Video Embed</td>
      <td><code class="language-plaintext highlighter-rouge">&lt;div class="video-embed" data-src="URL"&gt;</code></td>
    </tr>
    <tr>
      <td>Image Compare</td>
      <td><code class="language-plaintext highlighter-rouge">&lt;div class="image-compare" data-before="..." data-after="..."&gt;</code></td>
    </tr>
    <tr>
      <td>Single Image</td>
      <td><code class="language-plaintext highlighter-rouge">{% include img.html src="/path.png" %}</code></td>
    </tr>
    <tr>
      <td>Two Images</td>
      <td><code class="language-plaintext highlighter-rouge">{% include img.html src="/1.png, /2.png" %}</code></td>
    </tr>
    <tr>
      <td>Image Grid</td>
      <td><code class="language-plaintext highlighter-rouge">{% include img.html src="/1.png, /2.png, /3.png" cols="3" %}</code></td>
    </tr>
  </tbody>
</table>

<hr />

<p><em>Next post: We’ll implement Flash Attention and benchmark against naive attention.</em></p>]]></content><author><name>Rohit Kumar</name></author><category term="transformers" /><category term="attention" /><category term="deep-learning" /><category term="tutorial" /><summary type="html"><![CDATA[The attention mechanism is the core innovation behind transformers. Let’s break it down mathematically and implement it from scratch.]]></summary></entry></feed>