Chapter 3 — Modern LLM Architectures

Part 1: From GPT-2 to LLaMA — The Architecture Evolution

Read time: ~30 minutes

Why this matters: Every architecture choice between GPT-2 (2019) and LLaMA 3 (2024) was made to improve training stability, inference efficiency, or model quality. Knowing why each choice was made tells you how to think about new architectures.

The Genealogy

Transformer (2017) — "Attention Is All You Need"
    ├── BERT (2018, Google) — Encoder-only, MLM
    │       └── RoBERTa, DeBERTa, ALBERT...
    │
    └── GPT (2018, OpenAI) — Decoder-only, causal LM
            │
            GPT-2 (2019) — 1.5B params, BPE tokenizer, weight tying
            │
            GPT-3 (2020) — 175B, few-shot, sparse attention variant
            │
            GPT-NeoX / GPT-J (2021-2022, EleutherAI) — Open-source GPT-3 style
            │     └── Popularizes RoPE positional embeddings (GPT-J, GPT-NeoX-20B)
            │
            OPT (2022, Meta) — Open GPT-3 scale models
            │
            Chinchilla (2022, DeepMind) — Scaling laws insight
            │
            LLaMA (2023, Meta) — Chinchilla-optimal, open weights
            │       ├── RoPE positional encoding
            │       ├── SwiGLU activation
            │       ├── RMSNorm (faster than LayerNorm)
            │       └── Pre-norm architecture (also in GPT-2/3)
            │
            LLaMA 2 (2023) — 2T tokens, RLHF, commercial license
            │
            Mistral 7B (2023) — GQA + Sliding Window Attention
            │
            Mixtral 8×7B (2023) — Mixture of Experts on Mistral
            │
            LLaMA 3 (2024) — GQA at all sizes, 128K-vocab tokenizer (8K context at launch)
            │
            LLaMA 3.1/3.2/3.3 (2024-2025) — 128K context, multilingual, multimodal variants

What Changed from GPT-2 to LLaMA? A Diff

Let's look at the exact changes, one by one:

Change 1: Normalization — LayerNorm → RMSNorm

GPT-2 uses: LayerNorm(x) — normalize by mean and variance:

# LayerNorm: center by the mean, scale by the std, then learned scale + shift
mean = x.mean(dim=-1, keepdim=True)
var  = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, as in nn.LayerNorm
x_norm = (x - mean) / torch.sqrt(var + eps)
out = gamma * x_norm + beta  # learned scale and shift

LLaMA uses: RMSNorm(x) — normalize only by root mean square:

# RMSNorm (simpler, faster): no mean subtraction, no shift
rms = torch.sqrt((x**2).mean(dim=-1, keepdim=True) + eps)
x_norm = x / rms
out = gamma * x_norm  # no bias (beta) needed!

Why RMSNorm?

  • 7-15% faster than LayerNorm (no mean computation, no bias parameter)
  • Empirically similar quality
  • Less memory (no beta parameter)
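
Putting the fragment above into a self-contained module, as a minimal sketch (the class name and eps default are ours; the math matches LLaMA's RMSNorm):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm: divide by the root mean square, apply a learned gain."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))  # scale only; no beta/bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms)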

Interview corner case 🎯: "What's the difference between LayerNorm and BatchNorm?" — BatchNorm normalizes across the batch dimension (requires large batches, bad for small batches). LayerNorm normalizes across the feature dimension per sample (works with any batch size, which is why it's used in transformers).


Change 2: Positional Encoding — Absolute Learned → RoPE

GPT-2 uses: Learned absolute positional embeddings.

pos_emb = nn.Embedding(max_seq_len, d_model)
x = token_emb(tokens) + pos_emb(positions)

Problem: Extrapolates poorly beyond max_seq_len.

LLaMA uses: Rotary Position Embedding (RoPE).

RoPE encodes position by rotating the Q and K vectors in 2D subspaces:

def apply_rotary_emb(q, k, freqs_cis):
    """
    q, k: (B, T, n_heads, d_head)
    freqs_cis: (T, d_head/2) complex rotation factors e^(i · pos · theta)

    Rotate each (even, odd) pair of q and k dimensions by an angle
    proportional to position. Key property: the dot product
    q_rot[m] · k_rot[n] depends only on q, k, and the offset m - n,
    so attention scores naturally encode relative position.
    """
    # Treat d_head dimensions as complex pairs: (d0, d1) → d0 + i*d1
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    # Reshape for broadcasting over batch and heads: (T, d/2) → (1, T, 1, d/2)
    freqs_cis = freqs_cis.view(1, q_.shape[1], 1, q_.shape[-1])
    # Multiply by rotation factors (complex multiplication = rotation!)
    q_rot = torch.view_as_real(q_ * freqs_cis).flatten(-2)
    k_rot = torch.view_as_real(k_ * freqs_cis).flatten(-2)
    return q_rot.type_as(q), k_rot.type_as(k)

Why RoPE is better:

  • Encodes relative positions, not absolute → more generalizable
  • Can be extended to longer sequences by adjusting the rotation frequencies
  • Works with causal attention naturally
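
The function above consumes freqs_cis but doesn't show how it's built. Here is a sketch of the standard precomputation (the name precompute_freqs_cis follows the reference LLaMA code; the defaults are the usual ones):

import torch

def precompute_freqs_cis(d_head: int, max_seq_len: int, base: float = 10000.0):
    """Complex rotation factors e^(i · pos · theta_j), shape (T, d_head/2)."""
    # One frequency per 2D pair: theta_j = base^(-2j / d_head)
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    pos = torch.arange(max_seq_len).float()
    angles = torch.outer(pos, inv_freq)                  # (T, d_head/2)
    return torch.polar(torch.ones_like(angles), angles)  # unit-magnitude complex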

Change 3: Activation Function — GELU → SwiGLU

GPT-2 uses: GELU (Gaussian Error Linear Unit)

def gelu(x):
    # tanh approximation used by GPT-2; 0.7978845608 ≈ sqrt(2/π)
    return x * 0.5 * (1 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))

LLaMA uses: SwiGLU — a gated activation:

# SwiGLU uses 3 weight matrices instead of 2
# Shapes: W1, W2: (d_model, hidden_dim)   W3: (hidden_dim, d_model)
def swiglu_ffn(x):
    gate = F.silu(x @ W1)     # Swish(x·W1) — the gate
    up   = x @ W2             # linear path
    return (gate * up) @ W3   # element-wise multiply, then project back down
    # Note: hidden_dim ≈ (8/3) · d_model keeps the parameter count equal to a 4× FFN

Where SiLU(x) = x * sigmoid(x) (Swish activation).

Why SwiGLU? The gating mechanism allows the network to selectively pass or block information per hidden dimension. In "GLU Variants Improve Transformer" (Shazeer, 2020), gated variants such as SwiGLU consistently achieved lower perplexity than GELU at matched parameter counts; PaLM adopted SwiGLU on that evidence.

Interview corner case 🎯: "Why does SwiGLU use 3 weight matrices when standard FFN uses 2?" — Because the gating structure introduces a third projection. To keep the total parameter count equal to a standard 4×-FFN, LLaMA uses hidden_dim = 8/3 × d_model instead of 4 × d_model.
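
To make that bookkeeping concrete, here is a small sketch of the rounding LLaMA applies (multiple_of = 256 matches the LLaMA 2 config; the function name is ours):

def llama_ffn_hidden_dim(d_model: int, multiple_of: int = 256) -> int:
    """SwiGLU hidden size: ~8/3 × d_model, rounded up to a multiple of 256."""
    hidden = int(2 * (4 * d_model) / 3)  # 8/3 × d_model
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(llama_ffn_hidden_dim(4096))  # 11008, LLaMA 2 7B's d_ffn

Three 4096×11008 matrices come to ≈135M parameters, close to the two 4096×16384 matrices (≈134M) of a standard 4× GELU FFN.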


Change 4: Attention — Full → Grouped Query Attention (GQA)

Standard Multi-Head Attention (MHA): Every head has its own Q, K, V.

8 heads → 8 Q projections, 8 K projections, 8 V projections

During inference, all 8 K and V heads must be cached per layer (the KV cache).

Grouped Query Attention (GQA): Multiple Q heads share K and V heads.

8 Q heads, 2 K heads, 2 V heads  (4 Q heads share each K/V head)

The KV cache is 4× smaller! This is huge for long-context inference.

Multi-Query Attention (MQA): All Q heads share ONE K and V head.

8 Q heads, 1 K head, 1 V head   (maximum KV cache savings)

Side by side, drawn with 4 Q heads for brevity:

MHA:          Q₁K₁V₁  Q₂K₂V₂  Q₃K₃V₃  Q₄K₄V₄   (4 K/V pairs per layer)
GQA (g=2):    Q₁Q₂K₁V₁  Q₃Q₄K₂V₂                (2 K/V pairs per layer)
MQA:          Q₁Q₂Q₃Q₄K₁V₁                       (1 K/V pair per layer)

Real configurations:

  • LLaMA 2 7B:  32 Q heads, 32 K/V heads → plain MHA
  • LLaMA 3 8B:  32 Q heads, 8 K/V heads → GQA, group size 4
  • LLaMA 3 70B: 64 Q heads, 8 K/V heads → GQA, group size 8
  • Mistral 7B:  32 Q heads, 8 K/V heads → GQA, group size 4
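
In code, GQA is usually implemented by computing the few K/V heads and then repeating them to line up with the Q heads before the standard attention computation. A sketch mirroring the repeat_kv helper in the reference LLaMA code:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(B, T, n_kv_heads, d_head) → (B, T, n_kv_heads × n_rep, d_head)."""
    if n_rep == 1:
        return x  # group size 1 = plain MHA, nothing to repeat
    B, T, H_kv, D = x.shape
    return (x[:, :, :, None, :]
            .expand(B, T, H_kv, n_rep, D)
            .reshape(B, T, H_kv * n_rep, D))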

Interview corner case 🎯: "How much memory does GQA save, and when does it matter?" — GQA with G KV-head groups shrinks the KV cache by a factor of H/G, where H is the number of Q heads. For LLaMA-2-7B at 4K context (per sequence, fp16): MHA KV cache = 2 (K and V) × 32 heads × 4096 tokens × 128 dims × 2 bytes × 32 layers ≈ 2GB. GQA with 8 KV heads cuts this to ~0.5GB. This matters enormously when serving many concurrent users or very long contexts.
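
A back-of-the-envelope helper for that arithmetic (a sketch; the function name is ours):

def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Per-sequence KV cache: K and V tensors for every layer."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem

print(kv_cache_bytes(32, 32, 128, 4096) / 2**30)  # ≈ 2.0 GiB (MHA, LLaMA-2-7B)
print(kv_cache_bytes(32,  8, 128, 4096) / 2**30)  # ≈ 0.5 GiB (GQA, 8 KV heads)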


Change 5: Context Length

Model       | Context Length              | Method
------------|-----------------------------|--------------------------
GPT-2       | 1,024 tokens                | Learned absolute pos emb
GPT-3       | 2,048 tokens                | Learned absolute pos emb
LLaMA 1     | 2,048 tokens                | RoPE
LLaMA 2     | 4,096 tokens                | RoPE
Mistral 7B  | 8,192 tokens (window 4,096) | Sliding Window + RoPE
LLaMA 3.1   | 128,000 tokens              | RoPE + rope_scaling
GPT-4       | 128,000 tokens              | Undisclosed

How to extend context with RoPE: The rotation frequencies in RoPE use a base of 10000. By increasing this base (e.g., to 500000 in LLaMA 3), you spread the rotation over a wider range, effectively giving the model a longer "ruler" to measure positions. Then fine-tune on longer sequences with the new base.
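
Using the precompute_freqs_cis sketch from the RoPE section, the base change is a single argument (the configurations in the comments are indicative, not exact training recipes):

# LLaMA 1/2-style table: base 10,000 over 4K positions
freqs_short = precompute_freqs_cis(d_head=128, max_seq_len=4096, base=10000.0)

# LLaMA 3-style table: base 500,000, stretching the "ruler" toward 128K positions
freqs_long = precompute_freqs_cis(d_head=128, max_seq_len=131072, base=500000.0)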


The Full LLaMA 2 Architecture Specification

For reference, here's exactly what a LLaMA 2 7B block looks like:

d_model  = 4096
n_layers = 32
n_heads  = 32    (Q heads)
n_kv_heads = 32  (K/V heads, same as Q for LLaMA-2-7B)
d_head   = 128   (4096 / 32)
d_ffn    = 11008 (≈ 8/3 × 4096 = 10923, rounded up to a multiple of 256)
vocab_size = 32000 (SentencePiece BPE)
max_seq_len = 4096
norm_eps = 1e-5 (RMSNorm epsilon)

One forward pass:

input: (B, T) token IDs
→ token_embedding → (B, T, 4096)
→ 32 × [RMSNorm → Attention(GQA, RoPE) → residual → RMSNorm → SwiGLU FFN → residual]
→ RMSNorm
→ Linear(4096 → 32000)
→ logits (B, T, 32000)

Parameter count breakdown (7B):

Embeddings:     32000 × 4096 = 131M
Per layer:
  Attention:    4 × 4096 × 4096 = 67M    (Q, K, V, O projections)
  FFN (SwiGLU): 3 × 4096 × 11008 = 135M  (gate, up, down)
  Norms:        2 × 4096 = 8K (tiny)
  Per layer:    ~202M
32 layers:      32 × 202M = 6.46B
LM head:        32000 × 4096 = 131M (not tied to the embeddings in LLaMA; GPT-2 ties them)
Total:          ~6.74B ≈ 7B ✓
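
The same arithmetic as a quick script (the final RMSNorm's 4,096 parameters are included for completeness):

d_model, n_layers, d_ffn, vocab = 4096, 32, 11008, 32000

attn      = 4 * d_model * d_model   # Q, K, V, O projections
ffn       = 3 * d_model * d_ffn     # gate, up, down
norms     = 2 * d_model             # two RMSNorms per block
per_layer = attn + ffn + norms

total  = vocab * d_model            # token embeddings
total += n_layers * per_layer       # 32 transformer blocks
total += d_model                    # final RMSNorm
total += vocab * d_model            # untied LM head

print(f"{total / 1e9:.2f}B")        # 6.74B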

Mistral 7B: Two Key Innovations

Sliding Window Attention (SWA)

Standard attention: token T attends to tokens 1..T → O(T²) cost.

Sliding Window Attention: token T attends to tokens max(0, T-W)..T where W is the window size.

Standard:      T1 T2 T3 T4 T5 T6  ← all tokens attend to all past tokens
               ○  ○  ○  ○  ○  ○   (for T6)

Sliding (W=3): T1 T2 T3 T4 T5 T6
                        ○  ○  ○   (for T6, only sees T4, T5, T6)
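
In mask form, SWA is just the causal mask with an extra lower bound. A minimal sketch (the function name is ours):

import torch

def sliding_window_mask(T: int, W: int) -> torch.Tensor:
    """True where attention is allowed: causal AND within the last W tokens."""
    i = torch.arange(T)[:, None]  # query positions
    j = torch.arange(T)[None, :]  # key positions
    return (j <= i) & (j > i - W)

print(sliding_window_mask(6, 3).int())  # last row allows only T4, T5, T6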

But wait — how does SWA capture long-range context? Through depth! At layer 1, position T attends back W tokens. But each of those positions already aggregated its own W-token window at the previous layer, so by layer 2 information from up to ~2W tokens back reaches position T, by layer 3 up to ~3W, and so on. With L layers and window W, the effective receptive field is roughly L×W.

Mistral uses W=4096 with L=32 layers → effective context = ~131K tokens, despite the 8K sequence limit.

Interview corner case 🎯: "How does sliding window attention relate to the O(n²) problem?" — SWA costs O(n×W); with W fixed, that's O(n) in sequence length rather than O(n²), which is what makes much longer sequences affordable. It also composes with FlashAttention, whose kernels support sliding-window masks.

Rolling KV Cache

For SWA, you only need to cache the last W=4096 K/V pairs per layer; older entries can be evicted. This bounds the KV cache at W × n_layers × 2 × n_kv_heads × d_head × bytes_per_element, regardless of how long generation runs.
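
A minimal ring-buffer sketch for one layer's cache (class and method names are ours, not Mistral's):

import torch

class RollingKVCache:
    """Fixed-size per-layer K/V buffer for sliding-window attention."""
    def __init__(self, window: int, n_kv_heads: int, d_head: int,
                 dtype=torch.float16):
        self.window = window
        self.k = torch.zeros(window, n_kv_heads, d_head, dtype=dtype)
        self.v = torch.zeros(window, n_kv_heads, d_head, dtype=dtype)
        self.pos = 0  # total tokens seen so far

    def append(self, k_t: torch.Tensor, v_t: torch.Tensor) -> None:
        """k_t, v_t: (n_kv_heads, d_head) for the newest token."""
        slot = self.pos % self.window  # overwrite the oldest entry
        self.k[slot] = k_t
        self.v[slot] = v_t
        self.pos += 1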


Phi-3 and the Data Quality Revolution

Microsoft's Phi models introduced an important insight: data quality matters more than data quantity at small scales.

Phi-1 (1.3B params) was trained on ~7B tokens of curated Python exercises and textbooks. It outperformed models trained on 10× more tokens of raw web data on coding benchmarks.

Phi-3 mini (3.8B params) trained on 3.3T tokens of highly filtered web data + synthetic data, achieving performance near LLaMA 2 70B on many tasks.

The lesson: don't just throw more data at a small model. Curate it carefully.

"Textbooks Are All You Need" — the philosophy behind Phi.


Summary Table: Architecture Features by Model

Feature      | GPT-2       | GPT-3        | LLaMA 1     | LLaMA 2     | Mistral 7B  | LLaMA 3
-------------|-------------|--------------|-------------|-------------|-------------|----------------
Norm         | Pre-LN      | Pre-LN       | Pre-RMSNorm | Pre-RMSNorm | Pre-RMSNorm | Pre-RMSNorm
Pos encoding | Learned abs | Learned abs  | RoPE        | RoPE        | RoPE        | RoPE
Attention    | MHA         | MHA + sparse | MHA         | MHA         | GQA         | GQA
Activation   | GELU        | GELU         | SwiGLU      | SwiGLU      | SwiGLU      | SwiGLU
Context      | 1K          | 2K           | 2K          | 4K          | 8K          | 8K (3.1: 128K)
Tokenizer    | BPE (50K)   | BPE (50K)    | SP (32K)    | SP (32K)    | SP (32K)    | BPE (128K)
Open weights | Yes         | No           | Yes         | Yes         | Yes         | Yes

Next: Architecture Improvements That Actually Matter — Flash Attention, RoPE, GQA, and SwiGLU explained clearly.