Chapter 8 — Frontier & Future

Part 2: State Space Models & Mamba — The O(n) Alternative to Attention

Mamba is arguably the most exciting architecture to emerge since the Transformer itself. It trains in O(n) rather than O(n²) in sequence length, which matters enormously for very long sequences. This part gives you both the intuition and the math.

The Transformer's Fundamental Bottleneck

Attention is O(n²) in sequence length. For:

  • n = 4K tokens: 16M attention pairs (fine)
  • n = 64K tokens: 4B attention pairs (expensive)
  • n = 1M tokens: 1T attention pairs (intractable)

This is why even the most advanced models cap at 128K–1M context. And even at 128K, attention is extremely slow and memory-intensive.
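
These counts are just n²; a quick script confirms the arithmetic in the list above:

for n in (4_096, 65_536, 1_048_576):
    print(f"{n:>9,} tokens -> {n * n:>21,} attention pairs")
#     4,096 tokens ->            16,777,216 attention pairs
#    65,536 tokens ->         4,294,967,296 attention pairs
# 1,048,576 tokens ->     1,099,511,627,776 attention pairs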

What if we could have a model that:

  • Handles arbitrary sequence length
  • Uses O(n) memory and compute
  • Still captures long-range dependencies?

The State Space Model Intuition

A State Space Model (SSM) is a dynamical system:

$$h'(t) = A h(t) + B x(t)$$
$$y(t) = C h(t) + D x(t)$$

Where:

  • x(t) = input at time t
  • h(t) = hidden state at time t (captures all past information)
  • y(t) = output at time t
  • A, B, C, D = learned parameters

The key insight: This is a linear recurrence. You can compute the hidden state at any time t using just h(t-1) and x(t). This is O(n) — no need to look at all past tokens.

But wait — doesn't this have the same problem as RNNs? The hidden state is a fixed-size vector. Won't it bottleneck?

Yes — but SSMs have a mathematical trick that makes them fundamentally different from LSTMs.


From Continuous to Discrete SSM

The equations above are continuous-time (differential equations). For sequence modeling, we discretize with step size Δ:

$$h_t = \bar{A} h_{t-1} + \bar{B} x_t$$
$$y_t = C h_t$$

(The $D x_t$ term acts as a simple skip connection and is usually handled separately, so we omit it here.)

Where:

$$\bar{A} = \exp(\Delta A)$$
$$\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \Delta B$$

This discrete form runs like an RNN at inference: O(1) per step, O(n) total.
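
For intuition, here is a minimal sketch of the discretization and the resulting O(n) recurrence, assuming a diagonal A (as S4/Mamba-style models use in practice) and a scalar input stream; all names and shapes are illustrative:

import torch

def discretize(A, B, delta):
    # Zero-order hold: Ā = exp(ΔA), B̄ = (ΔA)^(-1)(exp(ΔA) - I)ΔB.
    # With diagonal A, both reduce to elementwise operations.
    A_bar = torch.exp(delta * A)       # (d_state,)
    B_bar = (A_bar - 1.0) / A * B      # (d_state,)
    return A_bar, B_bar

def ssm_recurrence(xs, A_bar, B_bar, C):
    # O(n) total: one fixed-size state vector, one update per token.
    h = torch.zeros_like(A_bar)
    ys = []
    for x_t in xs:                     # xs: (seq_len,) scalar inputs
        h = A_bar * h + B_bar * x_t    # h_t = Ā h_{t-1} + B̄ x_t
        ys.append(torch.dot(C, h))     # y_t = C h_t
    return torch.stack(ys)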

But during training, the discrete SSM can be computed as a convolution:

$$y = \text{SSM}(x) = x * k$$

Where k is the SSM's "impulse response" kernel:

$$k_i = C \bar{A}^i \bar{B} \quad\Longrightarrow\quad k = (C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^2\bar{B},\; \ldots)$$

Convolutions can be computed with FFT in O(n log n) during training. This is much faster than O(n²) attention!
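
Here is a sketch of the convolutional view for the same diagonal-A, scalar-input case (the FFT-based causal convolution is standard; names are illustrative). Feeding the same inputs through this path and through the recurrence above should agree up to floating-point error, which is exactly the duality discussed below:

import torch

def ssm_kernel(A_bar, B_bar, C, L):
    # k_i = C Ā^i B̄ for i = 0..L-1 (elementwise powers, since Ā is diagonal)
    powers = A_bar[None, :] ** torch.arange(L)[:, None]   # (L, d_state)
    return powers @ (B_bar * C)                           # (L,)

def causal_conv_fft(x, k):
    # y_t = Σ_{i≤t} k_i x_{t-i}, computed in O(L log L) via FFT
    n = 2 * x.shape[-1]
    y = torch.fft.irfft(torch.fft.rfft(x, n) * torch.fft.rfft(k, n), n)
    return y[..., :x.shape[-1]]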

So SSMs get the best of both worlds:

  • Training: O(n log n) via convolution
  • Inference: O(n) via recurrence

Interview corner case 🎯: "What is the key mathematical duality of SSMs?" — SSMs can be viewed as either recurrences (sequential, O(n) inference) or as convolutions (parallel, O(n log n) training via FFT). This duality is what makes them efficient for both training and inference.


The Weakness of Linear SSMs: No Input-Dependent Dynamics

The original SSM (A, B, C are fixed matrices) has a fatal flaw for language modeling:

The dynamics don't depend on the input. Whether the input token is "the" or "quantum entanglement," the state transition is the same. The model can't selectively remember or forget based on content.

LSTMs solved this with gates (input-dependent transitions). Can SSMs do the same?


Mamba: Selective State Spaces (2023)

Paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" — Gu and Dao

The key innovation: Make B, C, and Δ input-dependent:

# Linear SSM (fixed):
A = some_matrix          # Fixed, independent of input
B = some_matrix          # Fixed
C = some_matrix          # Fixed
Δ = some_scalar          # Fixed step size

# Mamba (selective):
A = some_matrix          # Still fixed (but learned); only B, C, Δ are input-dependent
B = linear(x)            # B depends on the input!
C = linear(x)            # C depends on the input!
Δ = softplus(linear(x))  # Δ depends on the input! (controls how fast the state updates)

Now the state transition is content-aware. As the numerical check after this list shows, the model can learn to:

  • Retain past information (small Δ → Ā ≈ I → the state barely changes, so earlier context persists and the current token barely registers)
  • Absorb the current token and discard stale context (large Δ → Ā ≈ 0 → the state resets toward the new input)
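
A two-line numerical check of this, using A = −1 as a stand-in for Mamba's negative-real A (values are illustrative):

import torch

A = torch.tensor(-1.0)
for delta in (0.01, 5.0):
    print(f"Δ = {delta}: Ā = exp(ΔA) = {torch.exp(delta * A).item():.4f}")
# Δ = 0.01: Ā ≈ 0.9900  (state persists; the current token barely registers)
# Δ = 5.0:  Ā ≈ 0.0067  (state resets toward the current input)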

This is analogous to LSTM gates, but within the SSM framework.

The tradeoff: Input-dependent transitions break the convolution interpretation (you can't precompute a single FFT kernel anymore, because the kernel changes for every input). Instead, Mamba uses a hardware-aware parallel scan ("selective scan") implemented as a fused CUDA kernel, which keeps training parallel and memory-efficient.


Mamba Architecture

Input x
  ↓
[Linear projection: expand to d_inner]
  ↓
Split into two paths:
  Path 1: x' → [Conv1d] → [SiLU] → [SSM scan] → output
  Path 2: x → [SiLU gating]
  ↓
Element-wise multiply (SiLU gating × SSM output)
  ↓
[Linear projection: back to d_model]
  ↓
Output

This is the "Mamba block" — analogous to the transformer's attention + FFN.

A full Mamba model stacks many Mamba blocks with residual connections:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_state = d_state
        self.d_inner = int(expand * d_model)  # expand = 2 typically

        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # Depthwise 1D convolution (local, short-range mixing)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
                                padding=d_conv - 1, groups=self.d_inner)

        # Projects each position to input-dependent B, C, and a scalar Δ seed
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner)

        # A is input-independent and negative real; log-initialized for stability
        A = torch.arange(1, d_state + 1, dtype=torch.float).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x):
        # Reference (slow) sequential scan. The official Mamba repo fuses this
        # loop into a single "selective scan" CUDA kernel.
        b, l, _ = x.shape
        x, gate = self.in_proj(x).chunk(2, dim=-1)       # each (b, l, d_inner)

        # Causal depthwise conv over time, then nonlinearity
        x = F.silu(self.conv1d(x.transpose(1, 2))[..., :l].transpose(1, 2))

        # Input-dependent SSM parameters
        B, C, dt = torch.split(self.x_proj(x),
                               [self.d_state, self.d_state, 1], dim=-1)
        dt = F.softplus(self.dt_proj(dt))                # (b, l, d_inner)
        A = -torch.exp(self.A_log)                       # (d_inner, d_state)

        h = x.new_zeros(b, self.d_inner, self.d_state)   # fixed-size state
        ys = []
        for t in range(l):
            dA = torch.exp(dt[:, t, :, None] * A)        # Ā_t = exp(Δ_t A)
            dB = dt[:, t, :, None] * B[:, t, None, :]    # B̄_t ≈ Δ_t B_t
            h = dA * h + dB * x[:, t, :, None]           # h_t = Ā_t h_{t-1} + B̄_t x_t
            ys.append((h @ C[:, t, :, None]).squeeze(-1))  # y_t = h_t C_t
        y = torch.stack(ys, dim=1) + self.D * x          # D is a skip connection

        return self.out_proj(F.silu(gate) * y)
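
A quick shape check of the sketch above (sizes are arbitrary):

block = MambaBlock(d_model=256, d_state=16, d_conv=4, expand=2)
x = torch.randn(2, 64, 256)    # (batch, seq_len, d_model)
print(block(x).shape)          # torch.Size([2, 64, 256])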

Mamba vs. Transformer: The Honest Comparison

| Aspect | Transformer | Mamba |
|---|---|---|
| Training complexity | O(n²) per layer | O(n) per layer |
| Inference per token | O(n) (grows with context) | O(1) (constant) |
| Memory for context | O(n) KV cache | O(d_state) (fixed) |
| Long-range dependency | Direct (any token to any) | Indirect (through state) |
| Training speed (short seqs) | Fast (highly optimized) | Similar |
| Training speed (long seqs) | Slow, O(n²) | Fast, O(n) |
| Recall (exact retrieval) | Excellent | Struggles |
| Generalization | Very good | Good, improving |
| Hardware optimization | Massive (years of work) | Early stage |

Where Mamba wins: Very long sequences (>32K tokens), streaming inference (O(1) state per step), memory-constrained devices.

Where Transformers win: Tasks requiring exact recall, copy tasks, retrieval tasks, well-studied with years of optimization. FlashAttention has narrowed the efficiency gap significantly.


Hybrid Models: Best of Both Worlds (2024)

The field is moving toward hybrid architectures that combine Mamba and attention:

Jamba (AI21 Labs): Alternates transformer and Mamba blocks:

Block 1: Transformer attention
Block 2: Mamba
Block 3: Transformer attention
Block 4: Mamba
...
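
A hypothetical sketch of this alternating pattern, reusing the MambaBlock from earlier and PyTorch's stock transformer layer as a stand-in for the attention blocks (layer count and sizes are illustrative):

import torch.nn as nn

d_model, n_layers = 512, 8
layers = nn.ModuleList(
    nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
    if i % 2 == 0 else MambaBlock(d_model)
    for i in range(n_layers)
)
# A real model wraps each layer in residual connections and normalization,
# and the attention layers would use a causal mask.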

Zamba (Zyphra): 7B hybrid model that matches Mistral 7B quality at lower compute cost for long sequences.

Samba (Microsoft): Combines sliding window attention + Mamba + MoE FFN. Gets complementary strengths from each.

The emerging consensus: pure attention for short sequences, Mamba or hybrid for very long sequences.


Beyond Mamba: RWKV and Other Alternatives

RWKV (Receptance Weighted Key Value): An RNN that can be trained in parallel like a transformer (via a linear-attention-style formulation). Runs at O(1) per step at inference. Used in models up to 14B parameters.

RetNet (Microsoft): Retention, a mathematical variant of attention with a parallel form (for training), a recurrent form (for O(1)-per-token inference), and a chunkwise form in between.

GLA (Gated Linear Attention): Linear attention with input-dependent gating, combining the efficiency of linear attention with the selectivity of gated models.

The trend: Every few months, a new variant claims to match or beat Mamba/transformer at some specific task. The landscape is still evolving.


Interview Corner Cases — SSMs & Mamba 🎯

  • "Why can't you just use an LSTM instead of Mamba?" → LSTMs are also O(n) inference, but (1) they can't be efficiently parallelized during training (fundamental sequential dependency), (2) they use discrete gates rather than continuous state updates, (3) they have much smaller state capacity. Mamba can be trained in parallel via the selective scan algorithm while still running recurrently at inference.
  • "What is the selective scan in Mamba, and why is it called 'selective'?" → Standard SSM scan: state updates are input-independent (the A, B matrices are fixed). Selective: A, B, C, Δ depend on the input at each step. This allows the model to dynamically adjust how fast/strongly it updates its state based on content. A CUDA kernel implements this efficiently without materializing the full sequence.
  • "Can Mamba perform in-context learning like transformers?" → This is an active research question. Mamba can do some in-context learning but generally underperforms transformers on tasks that require exact recall or copying from the context. Tasks like "repeat the exact text I just gave you" are hard because the compressed state may lose exact information.
  • "What is the 'state' in Mamba and how big is it?" → The state is an (d_inner × d_state) matrix per layer per sequence position. For Mamba-3B: d_inner=5120, d_state=16 → 81,920 floats per layer. Compare to KV cache: at 4K context, 81,920 vs. 4096 × d_head × n_heads ≈ 4096 × 128 × 32 = 16.7M. Mamba's state is dramatically smaller.