Transformers have dominated sequence modelling since 2017, but they carry a fundamental architectural cost: attention is quadratic in sequence length. Processing a sequence of N tokens requires computing N² attention scores, so doubling the context window quadruples the compute (and, in a naive implementation, the memory) required for attention. For most practical LLM deployments this is manageable, but it makes transformers poorly suited to tasks requiring very long context, such as genomic sequences, hour-long audio, full codebases, or large document collections, where sequence lengths reach tens or hundreds of thousands of tokens. State space models (SSMs), and Mamba in particular, take a different approach that scales linearly with sequence length while maintaining competitive performance on language tasks.
State Space Models: The Core Idea
State space models originate in control theory, where they describe dynamic systems through a hidden state that summarises the system’s history. In the SSM formulation for sequence modelling, the hidden state h(t) is a continuous-valued vector that is updated as each new input x(t) arrives, and the output y(t) is computed from h(t). The key parameters are the state transition matrix A (how the hidden state evolves over time), the input matrix B (how inputs update the state), and the output matrix C (how the state maps to outputs). Discretising this continuous-time system for use with discrete token sequences gives the recurrence:
import torch
import torch.nn as nn


class SimpleSSM(nn.Module):
    """Minimal diagonal SSM illustrating the core recurrence."""

    def __init__(self, d_model: int, d_state: int):
        super().__init__()
        self.d_state = d_state
        # Learnable SSM parameters (one independent diagonal SSM per channel)
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))  # state transition (log of decay rate)
        self.B = nn.Parameter(torch.randn(d_model, d_state))  # input projection
        self.C = nn.Parameter(torch.randn(d_model, d_state))  # output projection
        self.D = nn.Parameter(torch.ones(d_model))  # skip connection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        h = torch.zeros(batch, d_model, self.d_state, device=x.device, dtype=x.dtype)
        # Discretised transition: exp(-exp(A_log)) lies in (0, 1), which keeps
        # the recurrence stable. A plain exp(A) could exceed 1 and blow up.
        dA = torch.exp(-torch.exp(self.A_log))
        dB = self.B
        outputs = []
        # Sequential recurrence: O(L) steps, each costing O(N) per channel
        for t in range(seq_len):
            x_t = x[:, t, :]  # (batch, d_model)
            # Discretised state update
            h = h * dA.unsqueeze(0) + x_t.unsqueeze(-1) * dB.unsqueeze(0)
            y_t = (h * self.C.unsqueeze(0)).sum(-1) + x_t * self.D
            outputs.append(y_t)
        return torch.stack(outputs, dim=1)  # (batch, seq_len, d_model)
This recurrence runs in O(L·N) time where L is the sequence length and N is the state dimension — linear in L, compared to transformer attention’s O(L²). The hidden state h summarises all previous inputs in a fixed-size vector, enabling constant-memory inference regardless of sequence length. The tradeoff is that this fixed-size state is a lossy compression of history — the model cannot attend back to arbitrary past positions the way transformer attention can, which is why SSMs tend to struggle more than transformers on tasks requiring precise recall of specific earlier tokens.
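The constant-memory property can be seen in a dependency-free sketch that processes a stream one token at a time and keeps only the current state. The function names and parameter values below are hypothetical, and the decay factor exp(-exp(a_log)) is an assumption of this sketch, chosen so that every factor stays below one.

```python
import math


def ssm_step(h, x_t, a_log, b, c, d):
    """Advance a single-channel diagonal SSM by one token.

    h, a_log, b and c are lists of length N (the state size); x_t and d are floats.
    Returns the new state and the output for this step.
    """
    decay = [math.exp(-math.exp(a)) for a in a_log]  # each factor lies in (0, 1)
    h_new = [decay[n] * h[n] + b[n] * x_t for n in range(len(h))]
    y_t = sum(c[n] * h_new[n] for n in range(len(h))) + d * x_t
    return h_new, y_t


def run_stream(xs, a_log, b, c, d):
    """Process a stream token by token, keeping only the current state."""
    h = [0.0] * len(a_log)
    ys = []
    for x_t in xs:
        h, y_t = ssm_step(h, x_t, a_log, b, c, d)
        ys.append(y_t)
    return h, ys


# Hypothetical parameters for a state of size N = 2
a_log, b, c, d = [0.0, -1.0], [1.0, 0.5], [0.3, -0.2], 1.0
state_short, _ = run_stream([1.0] * 10, a_log, b, c, d)
state_long, _ = run_stream([1.0] * 10_000, a_log, b, c, d)
print(len(state_short), len(state_long))  # 2 2
```

The state after 10,000 tokens occupies the same two floats as the state after 10, which is also why older inputs can only survive in it in compressed form.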
Mamba: Selective State Spaces
The original SSM formulations (S4, S4D, H3) used time-invariant parameters: the A, B, C matrices were fixed regardless of the input. Mamba’s key innovation is making the parameters input-dependent. The input matrix B, the output matrix C, and the discretisation step size Δ are computed from the current input token; A itself remains a learned constant, but the discretised transition exp(ΔA) varies with the input through Δ. This lets the model selectively decide what information to retain in state and what to forget. The mechanism is analogous to the gating in LSTMs and GRUs, with the step size Δ playing the role of the gate, and it dramatically improves Mamba’s ability to perform selective recall: remembering specific earlier tokens when the current context makes them relevant.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """Simplified Mamba block illustrating the selective SSM mechanism."""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_inner = d_model * expand
        self.d_state = d_state
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                padding=d_conv - 1, groups=self.d_inner)
        # Input-dependent SSM parameter projections (the selective mechanism)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1)
                                            .float().repeat(self.d_inner, 1)))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        seq_len = x.shape[1]
        x = self.norm(x)
        # Split into SSM branch and gate branch
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)
        # Depthwise conv for local context; trimming the extra right-hand
        # outputs keeps it causal and restores the original sequence length
        x_ssm = self.conv1d(x_ssm.transpose(1, 2))[..., :seq_len].transpose(1, 2)
        x_ssm = F.silu(x_ssm)
        # Compute input-dependent B, C, dt (the selective parameters)
        x_dbl = self.x_proj(x_ssm)
        dt, B, C = x_dbl.split([1, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt))  # discretisation step size, (batch, seq_len, d_inner)
        # SSM computation: naive sequential selective scan for illustration
        # (real Mamba uses a fused parallel scan kernel)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state), always negative
        h = x_ssm.new_zeros(x_ssm.shape[0], self.d_inner, self.d_state)
        ys = []
        for t in range(seq_len):
            dA = torch.exp(dt[:, t].unsqueeze(-1) * A)  # (batch, d_inner, d_state)
            dB = dt[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)  # (batch, d_inner, d_state)
            h = dA * h + dB * x_ssm[:, t].unsqueeze(-1)
            ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # (batch, d_inner)
        y = torch.stack(ys, dim=1) + x_ssm * self.D  # SSM output plus skip connection
        y = y * F.silu(z)  # gating
        return self.out_proj(y) + residual
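For reference, the per-step recurrence that the selective scan computes can be written as below. This is a sketch under the simplified discretisation commonly used in Mamba implementations (exponential for A, first-order for B). The subscript t marks quantities that depend on the current token, Δ_t is the softplus-activated step size, and the state update is applied independently per channel with a diagonal A.

```latex
\begin{aligned}
\bar{A}_t &= \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t, \\
h_t &= \bar{A}_t \odot h_{t-1} + \bar{B}_t \, x_t, \\
y_t &= C_t^{\top} h_t + D \, x_t .
\end{aligned}
```

Because A is negative, a large Δ_t drives exp(Δ_t A) toward zero and overwrites the state with the current input, while a small Δ_t leaves the state almost unchanged and effectively skips the token.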
The selective scan is the computationally critical operation in Mamba. The naive sequential recurrence is slow on GPU because it cannot be parallelised across sequence positions. Mamba’s CUDA implementation uses a parallel scan algorithm (similar to prefix sum) that achieves both the linear scaling of the recurrence and the parallel efficiency of GPU hardware. This is why Mamba requires its custom CUDA kernels to achieve its claimed throughput — a pure PyTorch implementation of the selective scan falls back to sequential execution and is much slower.
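The reason a recurrence can be parallelised at all is that each step h → a·h + b is an affine map, and composing affine maps is associative. The sketch below illustrates this with scalar values in plain Python. It uses the simple Hillis-Steele scan pattern for brevity, which is an assumption of this example and not a description of Mamba’s kernel; the function names are hypothetical.

```python
def combine(left, right):
    """Compose two steps of h -> a*h + b, with `left` applied first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b):
    """Reference recurrence: h_t = a_t * h_{t-1} + b_t, starting from h = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def parallel_style_scan(a, b):
    """Inclusive scan in ceil(log2(L)) rounds (Hillis-Steele pattern).

    Every update inside one round is independent of the others, so a GPU
    could execute a whole round at once; here the rounds are simulated
    with ordinary Python loops.
    """
    elems = list(zip(a, b))
    offset = 1
    while offset < len(elems):
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(len(elems))
        ]
        offset *= 2
    # The b-component of each prefix is the state reached from h = 0
    return [b_t for _, b_t in elems]


a = [0.9, 0.5, 0.8, 0.7, 0.6]
b = [1.0, 2.0, -1.0, 0.5, 3.0]
seq = sequential_scan(a, b)
par = parallel_style_scan(a, b)
assert all(abs(s - p) < 1e-12 for s, p in zip(seq, par))
```

This variant does O(L log L) total work in exchange for its simplicity; work-efficient scans reach the same result with O(L) work, which is the property a production kernel needs.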
Mamba vs Transformer: Performance Comparison
On standard language modelling benchmarks at the same parameter count, Mamba matches or slightly exceeds transformer performance up to roughly the 1–3B parameter scale. The Mamba paper reported perplexity competitive with transformer baselines on The Pile at model sizes from 130M to 2.8B parameters, with inference throughput 5x higher than transformers at sequence length 16K. The throughput advantage grows with sequence length: at 1K tokens Mamba is roughly 2–3x faster than a transformer; at 16K tokens it is 5x faster; at 64K tokens the gap is even larger because transformer attention’s quadratic cost increasingly dominates.
Where Mamba underperforms transformers is on tasks requiring precise in-context retrieval: looking up a specific fact mentioned earlier in a long document, copying a specific string from the input, or answering questions about events that appeared early in a very long context. The fixed-size state cannot retain all earlier information with perfect fidelity, so Mamba’s recall degrades over long sequences in a way that transformer attention does not, because attention can always attend directly back to any earlier position. Hybrid architectures, which interleave Mamba layers with a small number of attention layers, address this tradeoff and have shown strong results: Jamba (a Mamba + transformer hybrid) and Zamba achieve near-transformer recall quality with substantially lower inference cost.
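As a rough illustration of what interleaving means in a hybrid stack, the sketch below builds a layer-type schedule with one attention layer for every few Mamba layers. The function name and the ratio are illustrative assumptions, not the published configuration of Jamba or Zamba.

```python
def hybrid_schedule(n_layers: int, attn_every: int) -> list[str]:
    """Return a layer-type list with one attention layer per `attn_every` layers."""
    if attn_every < 1:
        raise ValueError("attn_every must be at least 1")
    return [
        "attention" if (i + 1) % attn_every == 0 else "mamba"
        for i in range(n_layers)
    ]


# Hypothetical 16-layer stack with an attention layer in every 8th position
layers = hybrid_schedule(n_layers=16, attn_every=8)
print(layers.count("mamba"), layers.count("attention"))  # 14 2
```

Only the attention layers keep a KV cache, so with this schedule the cache grows at an eighth of the rate of a pure transformer with the same depth.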
Using Mamba in Practice
The mamba-ssm library provides the reference Mamba implementation with the optimised CUDA kernels. It requires a CUDA-capable GPU and a compatible PyTorch installation:
from mamba_ssm import Mamba
import torch

# Single Mamba layer
mamba_layer = Mamba(
    d_model=512,  # model dimension
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
).cuda()
x = torch.randn(2, 1024, 512).cuda()  # (batch, seq_len, d_model)
y = mamba_layer(x)  # same shape, O(L) compute

# Full Mamba language model via transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('state-spaces/mamba-2.8b-hf')
model = AutoModelForCausalLM.from_pretrained(
    'state-spaces/mamba-2.8b-hf',
    torch_dtype=torch.float16,
    device_map='auto'
)
inputs = tokenizer("The key advantage of state space models is", return_tensors='pt').to('cuda')
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Mamba inference is particularly efficient for long-sequence generation because the recurrent computation at each decoding step takes constant time: the model maintains a fixed-size hidden state rather than a growing KV cache. This makes Mamba well-suited to streaming inference and memory-constrained deployment scenarios where transformer KV cache memory grows linearly with sequence length and eventually becomes the binding constraint. For a 7B transformer with 32 layers and a hidden dimension of 4096, storing keys and values in float16 without grouped-query attention, generating 16K tokens requires roughly 8GB for the KV cache alone (about 512KB per token); a Mamba model of comparable size requires constant state memory regardless of sequence length.
Mamba 2 and SSM Variants
Mamba 2 (2024) refined the selective SSM by connecting it more explicitly to linear attention and structured state space duality, achieving better hardware utilisation through a larger state dimension and a restructured parallel scan that maps more efficiently to tensor cores. The practical implication is higher throughput than Mamba 1 at equivalent quality, particularly at the larger state dimensions needed for strong performance on harder language tasks. The mamba-ssm library supports both Mamba and Mamba2 layers; Mamba2 is generally preferred for new work.
Other SSM variants worth knowing: RWKV combines ideas from SSMs and transformers with a recurrent architecture that also runs in linear time, and has an active open-source community with models up to 14B parameters. RetNet proposes a retention mechanism with training-time parallelism and inference-time recurrence, similar in spirit to Mamba but with a different mathematical formulation. Griffin (DeepMind) interleaves recurrent layers with local attention to balance the recall tradeoff. The common thread across all these architectures is linear-time sequence processing with some form of selective gating — the research question is which specific formulation best balances the recall-efficiency tradeoff across diverse tasks.
When to Consider Mamba Over a Transformer
Use Mamba or a Mamba-transformer hybrid when sequence length is the primary constraint — tasks where you need to process or generate sequences of 32K tokens or more, where transformer attention’s quadratic cost is the binding compute or memory bottleneck. Audio and speech modelling (raw waveform processing at high sample rates), genomics (processing chromosome-length sequences), and time series with very long history windows are the strongest candidates. For standard LLM tasks at typical context lengths (4K–32K tokens), transformers remain the default because the ecosystem, tooling, and available pretrained models are substantially more mature, and the performance difference at these lengths is small enough that the engineering overhead of switching to Mamba is rarely justified. Monitor the hybrid architectures — Jamba-style models that combine Mamba efficiency with transformer recall quality at long context are the most likely path to Mamba-class architectures seeing widespread production deployment over the next 1–2 years.
The KV Cache Problem Mamba Solves
To appreciate why Mamba matters for inference, it helps to understand the concrete cost structure of transformer generation. During autoregressive decoding, a transformer must store the key and value projections for every token generated so far — the KV cache. For a model with hidden dimension d, L layers, and H attention heads, the KV cache size per token is 2 × L × H × (d/H) × 2 bytes (for float16) = 4Ld bytes. For a 7B parameter transformer with 32 layers and 4096 hidden dimension, each token adds roughly 512KB to the KV cache. At 8K tokens of generated context, the KV cache alone consumes 4GB — a significant fraction of a 24GB consumer GPU. At 32K tokens it is 16GB, leaving little room for model weights or batch parallelism.
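The arithmetic above can be checked with a few lines of Python. The helper below is a hypothetical sketch that assumes standard multi-head attention with no grouped-query sharing and float16 storage, matching the formula in the text.

```python
def kv_cache_bytes(n_tokens: int, n_layers: int, d_model: int,
                   bytes_per_value: int = 2) -> int:
    """KV cache size for standard multi-head attention (no grouped-query sharing).

    Per token and per layer, the cache stores one key vector and one value
    vector of width d_model, hence the leading factor of 2.
    """
    return 2 * n_layers * d_model * bytes_per_value * n_tokens


KIB, GIB = 1024, 1024 ** 3
per_token = kv_cache_bytes(1, n_layers=32, d_model=4096)
print(per_token // KIB)                           # 512 (KiB per token)
print(kv_cache_bytes(8 * 1024, 32, 4096) / GIB)   # 4.0
print(kv_cache_bytes(32 * 1024, 32, 4096) / GIB)  # 16.0
```

Models that use grouped-query or multi-query attention store fewer key and value heads, so their cache is smaller than this estimate by the corresponding factor.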
Mamba replaces this growing cache with a fixed-size recurrent state. The state for a Mamba layer is a tensor of shape (batch, d_inner, d_state). For a Mamba-2.8B model with d_inner=5120 and d_state=16, this is roughly 80K float32 values per layer, or about 320KB per sequence. Across the model’s 64 layers that comes to roughly 20MB, plus a much smaller buffer for the convolution inputs, regardless of sequence length. This constant memory footprint is what makes Mamba attractive for long-context and memory-constrained deployment, and it is why streaming generation from a Mamba model has a fundamentally different memory profile than streaming from a transformer.
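A matching back-of-the-envelope calculation for the recurrent state is sketched below. The helper name is hypothetical, the layer count of 64 is an assumption not stated in the text, and the small convolution buffer is ignored.

```python
def ssm_state_bytes(n_layers: int, d_inner: int, d_state: int,
                    batch: int = 1, bytes_per_value: int = 4) -> int:
    """Recurrent SSM state size: one (d_inner, d_state) tensor per layer per sequence."""
    return n_layers * batch * d_inner * d_state * bytes_per_value


values_per_layer = 5120 * 16
per_layer = ssm_state_bytes(1, d_inner=5120, d_state=16)
total = ssm_state_bytes(64, d_inner=5120, d_state=16)  # n_layers=64 is an assumption
print(values_per_layer)     # 81920
print(per_layer // 1024)    # 320 (KiB per layer)
print(total // 1024 ** 2)   # 20 (MiB for the whole model)
```

Under these assumptions the whole state is about 20 MiB per sequence in float32, and the figure does not change whether the model has seen 1K or 100K tokens.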
Training Mamba Models
Training Mamba uses the parallel scan mode: during training, the entire input sequence is processed at once with a parallel prefix scan. A work-efficient scan performs O(L) total work, the same order as the sequential recurrence, but arranges it into O(log L) dependent steps so that it parallelises efficiently across GPU threads. This means training throughput for Mamba is comparable to transformers at typical sequence lengths, even though inference runs the recurrence one token at a time. The practical training experience is similar to transformers: standard AdamW optimisation, cosine learning rate schedule, gradient clipping at 1.0, and mixed precision (bfloat16) work well. Mamba is generally less sensitive to learning rate than transformers and tolerates slightly higher learning rates without instability, though this advantage is modest and not a reason to prefer Mamba for training-dominated workloads.
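The schedule and clipping rule mentioned above are simple enough to sketch without a framework. The function names, the linear warmup, and the peak learning rate are assumptions of this example; the clipping factor follows the usual max_norm / (norm + eps) rule.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float,
              warmup_steps: int = 0, min_lr: float = 0.0) -> float:
    """Linear warmup followed by cosine decay from peak_lr to min_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_scale(grad_norm: float, max_norm: float = 1.0) -> float:
    """Factor by which all gradients are multiplied under norm clipping."""
    return min(1.0, max_norm / (grad_norm + 1e-6))


# Hypothetical run: 1,000 steps, 100 warmup steps, peak learning rate 1e-3
for step in (0, 99, 100, 550, 1000):
    print(step, cosine_lr(step, total_steps=1000, peak_lr=1e-3, warmup_steps=100))
```

In a PyTorch training loop the same behaviour would normally come from the framework's own scheduler and gradient-clipping utilities; the explicit versions here only make the formulas visible.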
Fine-tuning pretrained Mamba models follows the same patterns as transformer fine-tuning. LoRA adapters attach to the in_proj and out_proj linear layers of each Mamba block, and the same rank and alpha settings that work for transformer fine-tuning (rank 8–64, alpha equal to rank) transfer well to Mamba. Full fine-tuning is also straightforward and does not require special handling of the SSM parameters beyond standard gradient flow. The Hugging Face transformers port (the -hf checkpoints such as state-spaces/mamba-2.8b-hf) is the most direct route to the Trainer and PEFT libraries, so existing transformer fine-tuning pipelines require minimal modification. The mamba-ssm library’s own MambaLMHeadModel is a plain PyTorch module rather than a transformers model, so using it with Trainer typically needs a thin wrapper, for example to compute the loss from its logits.
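To get a feel for how small such adapters are, the sketch below counts the trainable parameters LoRA adds when attached to in_proj and out_proj of one block. The helper names are hypothetical, and the layer shapes follow the simplified block described earlier (in_proj maps d_model to 2 × d_inner, out_proj maps d_inner to d_model).

```python
def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer.

    The adapter is a pair of matrices: A with shape (rank, in_features)
    and B with shape (out_features, rank).
    """
    return rank * (in_features + out_features)


def mamba_block_lora_params(d_model: int, expand: int, rank: int) -> int:
    """LoRA parameters for in_proj and out_proj of one Mamba block."""
    d_inner = d_model * expand
    in_proj = lora_params(d_model, 2 * d_inner, rank)  # d_model -> 2 * d_inner
    out_proj = lora_params(d_inner, d_model, rank)     # d_inner -> d_model
    return in_proj + out_proj


print(mamba_block_lora_params(d_model=512, expand=2, rank=16))  # 65536
```

For comparison, the two frozen projection matrices in the same block hold 512 × 2048 + 1024 × 512 = 1,572,864 weights, so a rank-16 adapter trains roughly 4% as many parameters.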
Limitations and Open Questions
Mamba’s main unresolved limitation is in-context learning. Transformer models at scale develop strong few-shot in-context learning — the ability to learn a new task from a handful of examples in the prompt — and this capability is tightly linked to the attention mechanism’s ability to precisely retrieve and compare earlier examples. Mamba’s lossy state compression degrades few-shot performance relative to transformers of equivalent size, particularly on tasks with many in-context examples or tasks requiring exact copy from earlier in the context. Whether Mamba at sufficient scale or with hybrid attention layers can close this gap remains an active research question. For practitioners evaluating Mamba for production use, the recommendation is to benchmark specifically on few-shot and long-context retrieval tasks relevant to your application before committing — the aggregate benchmark numbers are favourable, but they average over many task types and may mask poor performance on the specific capabilities your application depends on.
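One inexpensive way to start such a benchmark is a synthetic key-value recall task, where the context length is controlled by the number of pairs. The sketch below is a hypothetical harness, not an established benchmark: the prompt format, function names, and scoring rule are all assumptions, and the model call itself is left to the reader.

```python
import random


def make_recall_prompt(n_pairs: int, seed: int = 0):
    """Build a key-value recall prompt and return (prompt, expected_answer)."""
    rng = random.Random(seed)
    keys = rng.sample(range(10_000, 100_000), n_pairs)    # unique 5-digit keys
    values = [rng.randrange(1000, 10_000) for _ in keys]  # 4-digit values
    lines = [f"key {k} -> value {v}" for k, v in zip(keys, values)]
    target = rng.randrange(n_pairs)
    prompt = "\n".join(lines) + f"\nWhat is the value for key {keys[target]}? "
    return prompt, str(values[target])


def exact_match(model_output: str, expected: str) -> bool:
    """Score a completion by whether it starts with the expected value."""
    return model_output.strip().startswith(expected)


prompt, answer = make_recall_prompt(n_pairs=200)
assert f"-> value {answer}" in prompt
```

Sweeping n_pairs upward and plotting exact-match accuracy for each candidate model gives a direct view of where recall starts to degrade, which aggregate benchmark scores tend to hide.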