Chapter 3 — Modern LLM Architectures

Part 2: Architecture Improvements — The Engineering That Makes It Fast

These are the things that differentiate "reads like a paper" knowledge from "works in production" knowledge. Know these cold.

Flash Attention — Making O(n²) Manageable

The Problem

Standard attention materializes a full (T × T) attention matrix in GPU memory. For T=8192 tokens:

  • Attention matrix size: 8192² × 2 bytes (FP16) = 128MB per head per layer
  • GPT-3 has 96 layers × 96 heads = 9,216 heads
  • Total: 9,216 × 128 MB ≈ 1.2 TB for a single sequence — impossible

Even for smaller models, the O(n²) memory is the bottleneck for long contexts.

The Flash Attention Solution (Dao et al., 2022)

Flash Attention doesn't reduce the computational complexity — it's still O(n²) FLOPs. What it reduces is memory access complexity.

The key insight: GPU HBM (High-Bandwidth Memory, the main GPU RAM) is large but comparatively slow (~2 TB/s on an A100). GPU SRAM (on-chip cache / shared memory) is roughly an order of magnitude faster but tiny (~20 MB total on an A100).

Standard attention:

1. Load Q, K from HBM → SRAM
2. Compute QK^T → store attention matrix to HBM (SLOW!)
3. Load attention matrix from HBM → SRAM
4. Compute softmax → store back to HBM (SLOW!)
5. Load attention weights + V from HBM
6. Compute weighted sum → store output to HBM

Many slow HBM reads and writes.
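
For reference, here is the naive computation that causes those round-trips: a single-head sketch (no masking) in which the full (T × T) score matrix is materialized:

import torch

def naive_attention(q, k, v):
    # q, k, v: (T, d_head); `scores` is the full (T, T) matrix written to HBM
    scores = (q @ k.T) / k.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v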

Flash Attention:

1. Tile the Q, K, V matrices into blocks that fit in SRAM
2. For each tile:
   a. Load Q_tile, K_tile, V_tile into SRAM (fast)
   b. Compute partial attention scores in SRAM
   c. Accumulate into a running output, updating softmax incrementally via a running max and normalizer, the "online softmax" trick (no full matrix materialized!)
3. Only write final output to HBM

No O(n²) matrix ever written to HBM. Memory usage drops from O(n²) to O(n). Speed improvement: 2-4× on standard attention, more for longer sequences.
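
The incremental softmax is the subtle part: softmax normally needs a whole row of scores before it can normalize. Below is a sketch of the streaming computation in plain PyTorch, tiling over K/V only (the real kernel also tiles Q and runs fused in SRAM); it produces exactly the same output as naive attention:

import torch

def flash_like_attention(q, k, v, block=128):
    # Processes K/V in tiles, maintaining a running row max (m) and
    # normalizer (l) per query; the (T, T) matrix is never materialized
    T, d = q.shape
    o = torch.zeros_like(q)
    m = torch.full((T, 1), float("-inf"))
    l = torch.zeros(T, 1)
    for s in range(0, T, block):
        kt, vt = k[s:s+block], v[s:s+block]
        scores = (q @ kt.T) / d ** 0.5                       # (T, block) tile only
        m_new = torch.maximum(m, scores.max(-1, keepdim=True).values)
        p = torch.exp(scores - m_new)
        alpha = torch.exp(m - m_new)                         # rescale old statistics
        l = alpha * l + p.sum(-1, keepdim=True)
        o = alpha * o + p @ vt
        m = m_new
    return o / l                                             # (T, d_head)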

Interview corner case 🎯: "Flash Attention is described as 'IO-aware' — what does that mean?" — It's aware of the GPU memory hierarchy (HBM vs SRAM). The algorithm is designed to minimize slow HBM reads/writes, not to minimize FLOPs. The FLOPs are the same; the memory bandwidth utilization is dramatically better.

Flash Attention 2 (2023) — Further Improvements

  • Parallelizes across the sequence dimension (better GPU utilization)
  • Reduces non-matrix multiply operations (warp-level optimizations)
  • ~2× faster than Flash Attention 1

Flash Attention 3 (2024) — H100 Optimized

  • Overlaps computation and memory transfers using async operations
  • Uses H100's FP8 Tensor Cores
  • ~1.5-2× faster than FA2 on H100s

In practice, you get Flash Attention for free in modern frameworks:

import torch.nn.functional as F

# PyTorch 2.0+: F.scaled_dot_product_attention dispatches to a Flash Attention
# kernel when the inputs allow it; an arbitrary attn_mask can force a slower
# fallback, so prefer is_causal=True for standard causal LM attention
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.1)

RoPE — The Math Behind Rotary Position Embeddings

The Motivation

We want a position encoding where:

  • The dot product q(pos_m) · k(pos_n) is a function only of (m-n) — the relative position
  • This relative information should naturally emerge from the attention computation
  • It should work for any sequence length (not just up to max_seq_len)

The Construction

Consider a 2D case first. We encode position m by rotating a vector v by angle m·θ:

$$\text{Rotate}(v, m \cdot \theta) = \begin{bmatrix} \cos(m \cdot \theta) & -\sin(m \cdot \theta) \\ \sin(m \cdot \theta) & \cos(m \cdot \theta) \end{bmatrix} \begin{bmatrix} v_0 \\ v_1 \end{bmatrix}$$

Now compute the dot product of rotated vectors at positions m and n:

$$\text{Rotate}(q, m \cdot \theta) \cdot \text{Rotate}(k, n \cdot \theta) = q^\top R(m \cdot \theta)^\top R(n \cdot \theta)\, k = q^\top R\big((n-m) \cdot \theta\big)\, k$$

The dot product only depends on (m-n), not on absolute positions!
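
You can check this shift-invariance numerically; both position pairs below share the offset n − m = 4 and give the same dot product:

import torch

def rot(v, ang):
    # rotate a 2-D vector by angle `ang`
    c, s = torch.cos(ang), torch.sin(ang)
    return torch.stack([c * v[0] - s * v[1], s * v[0] + c * v[1]])

q, k, theta = torch.randn(2), torch.randn(2), torch.tensor(0.3)
for m, n in [(3, 7), (103, 107)]:  # same relative offset, different positions
    print(torch.dot(rot(q, m * theta), rot(k, n * theta)))  # identical values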

For d_head dimensions, we apply this rotation to d_head/2 pairs of dimensions, each with a different frequency θ_i:

$$\theta_i = \frac{1}{10000^{2i/d_{\text{head}}}} \quad \text{for } i = 0, 1, \ldots, \frac{d_{\text{head}}}{2} - 1$$

This gives us a spectrum of frequencies — from very fast rotation (θ_0 = 1) to very slow (θ_{d_head/2-1} ≈ 1/10000). Different dimensions encode relative position at different granularities.

import torch

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    # frequencies θ_i = theta^(-2i/dim), one per pair of dimensions
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, freqs)  # (max_seq_len, dim/2): angles m·θ_i
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex e^{i·m·θ_i}
    return freqs_cis
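
Applying the rotation is then one complex multiply per pair of dimensions. Continuing the snippet above, a minimal sketch in the style of the LLaMA reference code (the (batch, seq_len, n_heads, d_head) layout is an assumption):

def apply_rotary_emb(x, freqs_cis):
    # x: (batch, seq_len, n_heads, d_head); pair adjacent dims as complex numbers
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # broadcast e^{i·m·θ} over batch and heads: (1, seq_len, 1, d_head/2)
    f = freqs_cis[: x_c.shape[1]].view(1, x_c.shape[1], 1, x_c.shape[-1])
    return torch.view_as_real(x_c * f).flatten(-2).type_as(x)  # rotated q or k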

Interview corner case 🎯: "How does the YaRN method extend context length for RoPE models?" — YaRN (Yet Another RoPE extensioN) applies different scaling to different frequency components. Low-frequency components (slow rotation, encoding long-range positions) are scaled down more aggressively than high-frequency ones. This lets the model extrapolate to longer sequences without fine-tuning, or with minimal fine-tuning.


Tokenizers — From BPE to SentencePiece to TikToken

Why the Tokenizer Matters

The tokenizer is layer 0 of any LLM system. Its quality directly affects:

  • Context efficiency: Better tokenizers compress more text into fewer tokens
  • Language coverage: Byte-level BPE handles all languages; word-level fails on rare words
  • Math/code performance: how you tokenize "3+5=8" affects arithmetic ability (tokenizers that split numbers into single digits tend to make arithmetic easier to learn)

BPE (Byte Pair Encoding) — GPT-2, GPT-3, LLaMA 3

Training process:

1. Start with individual bytes as the vocabulary (256 tokens)
2. Count all adjacent byte pairs in the training corpus
3. Merge the most frequent pair into a new token
4. Repeat until vocab_size is reached (e.g., 50,257 for GPT-2; 128,256 for LLaMA 3)

Example:

Corpus: "low lower lowest"
After merging (l, o): "lo" is a token
After merging (lo, w): "low" is a token
...

The vocabulary of common words and subwords emerges naturally.
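
A toy version of the training loop makes the algorithm concrete. This sketch works on lists of symbols rather than raw bytes, and the helper names are illustrative:

from collections import Counter

def bpe_train(words, num_merges):
    # words: list of symbol lists, e.g. [list("low"), list("lower"), list("lowest")]
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for w in words:
            pairs.update(zip(w, w[1:]))   # count adjacent pairs
        if not pairs:
            break
        (a, b), _ = pairs.most_common(1)[0]   # most frequent pair
        merges.append((a, b))
        words = [merge_pair(w, a, b) for w in words]
    return merges

def merge_pair(w, a, b):
    out, i = [], 0
    while i < len(w):
        if i + 1 < len(w) and (w[i], w[i + 1]) == (a, b):
            out.append(a + b); i += 2
        else:
            out.append(w[i]); i += 1
    return out

Running it on the corpus above reproduces the (l, o) and (lo, w) merges first.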

GPT-4 / LLaMA 3 upgrade: Larger vocab sizes (100K–128K) → more efficient tokenization, fewer tokens per sentence, better multilingual coverage.

SentencePiece — LLaMA 1 and 2

Operates on raw Unicode text rather than bytes, treating whitespace as an ordinary symbol (the ▁ meta-character), so no language-specific pre-tokenization is needed. Handles Chinese, Japanese, and Arabic naturally; LLaMA's tokenizer additionally enables byte fallback for characters outside the vocabulary.

TikToken — OpenAI's Fast BPE

GPT-3.5/4 use TikToken: same BPE algorithm but implemented in Rust for speed. Can tokenize millions of tokens per second.
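
To see this in practice (pip install tiktoken):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # the GPT-4 / GPT-3.5 encoding
ids = enc.encode("low lower lowest")
print(ids)              # token ids
print(enc.decode(ids))  # round-trips to the original string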

Interview corner case 🎯: "Why might a model perform worse on non-English text?" — If the tokenizer was trained mostly on English data, non-English text gets tokenized into more tokens (less efficiently). Chinese characters might each become multiple tokens while English words are single tokens. This means the model sees less meaningful context per token for non-English text. LLaMA 3's 128K vocabulary was specifically designed to be more multilingual-efficient.


KV Cache — The Secret to Fast Inference

What Is It?

During autoregressive generation, you process tokens one at a time:

  • Step 1: Process tokens [T1]
  • Step 2: Process tokens [T1, T2]
  • Step 3: Process tokens [T1, T2, T3]

Without caching, at step N you recompute Keys and Values for ALL previous tokens — O(N²) total compute.

With KV caching: at step N, you only compute K and V for the new token TN. All previous K, V are cached.

import torch

# Pseudo-code for one decode step with a KV cache (single head per layer;
# W_q, W_k, W_v are the per-layer projection matrices)
kv_cache = {}  # layer_id → (past_keys: (T, d_head), past_values: (T, d_head))

def decode_step(x, layer):
    # x: (1, d_model) hidden state of the NEW token only
    q     = x @ W_q[layer]   # (1, d_head)
    k_new = x @ W_k[layer]   # (1, d_head)
    v_new = x @ W_v[layer]   # (1, d_head)

    # Append the new K/V to the cache (create it on the first step)
    if layer in kv_cache:
        k_cache = torch.cat([kv_cache[layer][0], k_new], dim=0)
        v_cache = torch.cat([kv_cache[layer][1], v_new], dim=0)
    else:
        k_cache, v_cache = k_new, v_new
    kv_cache[layer] = (k_cache, v_cache)

    # Attend over all cached keys/values
    # q: (1, d_head), k_cache: (T, d_head) → attention scores: (1, T)
    scores = (q @ k_cache.T) / k_cache.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v_cache  # (1, d_head)

KV Cache Memory

For LLaMA 2 7B:

  • 32 layers, 32 heads, d_head=128
  • Each token: 2 × 32 × 32 × 128 = 262,144 floats × 2 bytes = 512 KB per token
  • For 4096-token context: 4096 × 512KB = 2GB

This is why GQA and MQA are important — they reduce the number of KV heads, directly cutting KV cache size.
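
A quick way to sanity-check those numbers, and to see the GQA effect, is a small illustrative helper:

def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2):
    # 2 tensors (K and V) per layer, per KV head, per cached token
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem

print(kv_cache_bytes(32, 32, 128, 4096) / 2**30)  # LLaMA 2 7B (MHA): 2.0 GiB
print(kv_cache_bytes(32, 8, 128, 4096) / 2**30)   # same shape with 8 KV heads: 0.5 GiB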

Interview corner case 🎯: "What is paged attention, and why does vLLM use it?" — Standard KV cache allocates contiguous GPU memory for the full max sequence length per request. Most requests are much shorter → wasted memory → fewer concurrent users. PagedAttention (used in vLLM) stores KV cache in non-contiguous "pages" (like OS virtual memory), allowing it to serve 10-100× more concurrent requests on the same GPU.


Speculative Decoding — 3× Faster Generation

The Problem

LLM generation is memory bandwidth bound — the GPU spends most time loading model weights from memory, not computing. A single forward pass generates exactly ONE token. You have 7B parameters to load for every single token.
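
The back-of-envelope arithmetic makes the bound vivid (the bandwidth figure is an assumed A100-class number):

params, bytes_per_param = 7e9, 2   # 7B parameters in FP16
hbm_bandwidth = 2e12               # ~2 TB/s HBM (assumption)
ms_per_token = params * bytes_per_param / hbm_bandwidth * 1e3
print(ms_per_token)                # ~7 ms/token: a ~140 tokens/s ceiling, regardless of FLOPs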

The Insight

What if we had a tiny "draft" model that's 10× faster? It generates multiple tokens as a guess. Then we verify all of them in parallel with the large model (one forward pass). If the draft was right, we've generated N tokens for the cost of roughly one large-model pass.

Draft model: generates tokens T1, T2, T3, T4, T5 (fast)
Large model: verify all 5 in one parallel forward pass
             Accept T1 ✓, T2 ✓, T3 ✓, T4 ✗  → accept 3, reject from T4

Accepted tokens per large-model pass: ~2-4 on average → 2-4× throughput improvement with an identical output distribution. Even on a rejection the pass still yields one token (sampled from a corrected distribution), so no large-model pass is wasted.

Interview corner case 🎯: "Does speculative decoding change the output distribution?" — No! The acceptance/rejection criterion is designed so the final output distribution is mathematically identical to running just the large model. It's not an approximation — it's a provably exact algorithm.


Mixture of Experts (MoE) — Brief Preview

(Full chapter in Chapter 8, but you'll see it referenced here)

Standard FFN: every token goes through the same FFN at each layer.

MoE FFN: N "expert" FFNs, a router selects K of them for each token.

import torch

# Simplified MoE forward pass for a single token
# x: (d_model,), router_weights: (d_model, n_experts), experts: list of FFN modules
gate_logits = x @ router_weights                   # (n_experts,)
top_vals, top_idx = torch.topk(gate_logits, k=2)   # select 2 experts per token
gate = torch.softmax(top_vals, dim=-1)             # renormalize over the chosen 2
output = sum(gate[j] * experts[i](x) for j, i in enumerate(top_idx.tolist()))

Mixtral 8×7B: 8 experts, 2 active per token.

  • Total params: ~47B (only the FFN experts are replicated; attention, embeddings, and norms are shared, so the total is well under 8 × 7B)
  • Active params per token: ~13B (the 2 selected experts' FFNs plus the shared weights)
  • Computation: comparable to a dense ~13B model (sparse activation)
  • Quality: approaches 70B model performance

Interview Corner Cases — Chapter 3 Full List 🎯

  • "Why does LLaMA use SentencePiece and LLaMA 3 switch to TikToken-style BPE?" → Larger vocabulary (128K vs 32K) is more efficient, especially for multilingual text and code. At 128K vocab, most common words fit in one token.
  • "What is the difference between Pre-LN and Post-LN, and why did the field switch?" → Post-LN (original paper) places LayerNorm after the residual add. Gradients passing through LayerNorm during backprop can become very small (LayerNorm "re-scales" them). Pre-LN places LayerNorm before — gradients flow unimpeded through the residual path. Pre-LN requires less careful LR warmup and trains stably to much larger depths.
  • "Why does Mistral use 8 KV heads when it has 32 Q heads?" → GQA: 4 Q heads share each KV head. This reduces the KV cache by 4× at the cost of very minor quality degradation. The quality-efficiency tradeoff is extremely favorable for inference-heavy deployments.
  • "What is the benefit of weight tying, and is it always beneficial?" → Weight tying (sharing token embedding matrix with LM head) reduces parameters by ~vocab_size × d_model and often improves performance. The theory: if a token has a high-magnitude embedding vector, the model should predict it with high probability when its internal representation aligns with that direction. Not always beneficial — very large vocab or very small d_model might see degradation.
  • "How does sparse attention help with long-context understanding?" → By focusing each token's attention on nearby tokens (plus potentially a few "global" tokens), you reduce the O(n²) cost while retaining most of the long-range signal through stacking (each layer can propagate info from the previous layer's window, building a growing receptive field with depth).
  • "If GQA has fewer KV heads, doesn't that hurt quality?" → Surprisingly little. Research (Ainslie et al., 2023) shows GQA with 1/4 of KV heads loses minimal quality versus MHA. The Q heads apparently don't need unique KV representations to learn different attention patterns — the differences between heads come mainly from the Q projections.