Matryoshka Representation Learning: How It Works and Why It Matters for RAG

Most embedding models produce a single fixed-size vector — 768 dimensions, 1536 dimensions, or similar — and that’s the only representation you get. If you need smaller embeddings for cost or latency reasons, you either use a different model or apply dimensionality reduction after the fact, both of which degrade quality. Matryoshka Representation Learning (MRL) solves this by training a single model to produce embeddings that are simultaneously useful at multiple truncated dimensions. You can take the first 64, 128, 256, or 512 dimensions of a 1024-dimensional MRL embedding and each truncated version is independently meaningful — the model has been explicitly trained to pack the most important semantic information into the early dimensions. This article covers how MRL works, how to train and fine-tune MRL embedding models with sentence-transformers, and the practical tradeoffs for RAG and vector search applications.

How Matryoshka Representation Learning Works

Standard embedding training optimises a single loss computed on the full embedding vector. MRL modifies the training objective to optimise losses at multiple embedding dimensions simultaneously. For a model that produces 1024-dimensional embeddings, MRL adds loss terms at a set of nested dimensions — for example, {64, 128, 256, 512, 1024} — and the total training loss is a weighted sum of the losses at each dimension. Each partial loss treats the first m dimensions of the embedding as a complete representation and evaluates how well those m dimensions capture the semantic relationships in the training batch.
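
Written out, the combined objective is just a weighted sum over nested prefixes of the embedding vector. The sketch below is conceptual pseudocode for that sum, not the actual sentence-transformers implementation shown later:

def matryoshka_objective(embeddings, batch, base_loss,
                         dims=(64, 128, 256, 512, 1024),
                         weights=(1, 1, 1, 1, 1)):
    """Sum a base contrastive loss over nested prefixes of the embeddings."""
    total = 0.0
    for w, m in zip(weights, dims):
        # The first m dimensions are scored as if they were the complete embedding
        total += w * base_loss(embeddings[:, :m], batch)
    return total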

The key insight is that by consistently penalising poor quality at smaller dimensions during training, the model learns to organise information hierarchically: the first 64 dimensions capture the coarsest, most important semantic distinctions, the next 64 add finer-grained information, and so on up to the full dimension. This is where the name comes from — the nested structure mirrors Matryoshka dolls, where each doll contains a complete smaller version of itself. The result is that truncating an MRL embedding to a smaller size degrades gracefully and predictably, rather than catastrophically as it does for standard embeddings not trained with this objective.

Training an MRL Embedding Model with sentence-transformers

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.trainer import SentenceTransformerTrainer
from datasets import Dataset

# Load a base model to fine-tune with MRL
model = SentenceTransformer("BAAI/bge-base-en-v1.5")

# Define the nested dimensions to train at: conventionally powers of 2,
# plus the full model dimension (bge-base produces 768-dimensional embeddings)
matryoshka_dims = [64, 128, 256, 512, 768]

# Wrap the base loss with MatryoshkaLoss
# MultipleNegativesRankingLoss is the standard choice for retrieval fine-tuning
base_loss = MultipleNegativesRankingLoss(model)
loss = MatryoshkaLoss(
    model,
    loss=base_loss,
    matryoshka_dims=matryoshka_dims,
    matryoshka_weights=[1, 1, 1, 1, 1],  # equal weight at each dimension
)

# Training data: pairs of (query, positive_passage)
train_data = Dataset.from_dict({
    "anchor":   ["what is gradient checkpointing", "how does LoRA work"],
    "positive": ["Gradient checkpointing saves memory by recomputing activations",
                 "LoRA adds low-rank adapter matrices to frozen model weights"],
})

args = SentenceTransformerTrainingArguments(
    output_dir="mrl-bge-base",
    num_train_epochs=3,
    per_device_train_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_data,
    loss=loss,
)
trainer.train()
model.save_pretrained("mrl-bge-base-finetuned")

The matryoshka_weights parameter controls the relative importance of each dimension in the combined loss. Setting all weights to 1 treats each dimension equally. In practice, you may want to upweight smaller dimensions slightly (e.g., 2 for 64d, 1.5 for 128d, 1 for larger) if your deployment target is primarily the smaller truncated sizes, to push the model to work harder at packing information into the early dimensions at the expense of marginal quality at full size.
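
For example, a configuration biased toward 64d and 128d deployment targets might look like this (the weights are illustrative starting points, not tuned values):

loss = MatryoshkaLoss(
    model,
    loss=base_loss,
    matryoshka_dims=[64, 128, 256, 512, 768],
    matryoshka_weights=[2.0, 1.5, 1.0, 1.0, 1.0],  # upweight the small dimensions
)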

Using MRL Embeddings at Different Dimensions

from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("mrl-bge-base-finetuned")

texts = [
    "How does gradient checkpointing reduce memory?",
    "Explain LoRA fine-tuning",
    "What is knowledge distillation?",
]

# Generate full embeddings — then truncate to the desired dimension
full_embeddings = model.encode(texts, normalize_embeddings=True)
print(f"Full embedding shape: {full_embeddings.shape}")  # (3, 768)

def get_mrl_embeddings(texts, dim, model):
    """Encode and truncate to a specific MRL dimension, then renormalise."""
    embeddings = model.encode(texts, normalize_embeddings=False)
    truncated = embeddings[:, :dim]
    # Renormalise after truncation — important for cosine similarity to work correctly
    norms = np.linalg.norm(truncated, axis=1, keepdims=True)
    return truncated / np.maximum(norms, 1e-12)

emb_64  = get_mrl_embeddings(texts, 64,  model)
emb_128 = get_mrl_embeddings(texts, 128, model)
emb_256 = get_mrl_embeddings(texts, 256, model)

print(f"64d shape:  {emb_64.shape}")   # (3, 64)
print(f"128d shape: {emb_128.shape}")  # (3, 128)

# Memory comparison: storing 1M vectors
dim_768_gb  = 1_000_000 * 768  * 4 / 1e9  # float32
dim_128_gb  = 1_000_000 * 128  * 4 / 1e9
dim_64_gb   = 1_000_000 * 64   * 4 / 1e9
print(f"768d: {dim_768_gb:.2f} GB | 128d: {dim_128_gb:.2f} GB | 64d: {dim_64_gb:.2f} GB")
# 768d: 3.07 GB | 128d: 0.51 GB | 64d: 0.26 GB

MRL for Two-Stage Retrieval in RAG

The most compelling production use case for MRL is two-stage retrieval: use a small-dimension embedding (64d or 128d) for the initial ANN search over the full corpus, retrieve a large candidate set (top 100–500), then re-score that candidate set using the full-dimension embedding. This gives you most of the quality of full-dimension retrieval at a fraction of the ANN index memory and query latency cost, because the expensive full-dimension comparison is only done on a small reranking pool rather than the entire corpus.

import numpy as np
import faiss

def build_two_stage_index(corpus_texts, model, initial_dim=128):
    """Build a two-stage MRL retrieval index."""
    print("Encoding corpus...")
    full_embeddings = model.encode(corpus_texts, normalize_embeddings=False,
                                   batch_size=256, show_progress_bar=True)
    # Store full embeddings for reranking
    full_normed = full_embeddings / np.linalg.norm(full_embeddings, axis=1, keepdims=True)

    # Build FAISS index on truncated embeddings for fast initial retrieval
    initial_embeddings = full_embeddings[:, :initial_dim].copy()
    norms = np.linalg.norm(initial_embeddings, axis=1, keepdims=True)
    initial_normed = initial_embeddings / np.maximum(norms, 1e-12)

    index = faiss.IndexFlatIP(initial_dim)  # inner product = cosine for normalised vecs
    index.add(initial_normed.astype(np.float32))

    return index, full_normed

def two_stage_search(query, model, index, full_embeddings,
                     initial_dim=128, initial_k=200, final_k=10):
    """Two-stage MRL retrieval: fast approximate search, then exact reranking."""
    # Stage 1: fast search with truncated query embedding
    query_full = model.encode([query], normalize_embeddings=False)
    # Copy so the in-place renormalisation below does not modify query_full (reused in stage 2)
    query_initial = query_full[:, :initial_dim].copy()
    query_initial /= np.maximum(np.linalg.norm(query_initial, axis=1, keepdims=True), 1e-12)

    _, candidate_ids = index.search(query_initial.astype(np.float32), initial_k)
    candidate_ids = candidate_ids[0]

    # Stage 2: rerank candidates with full-dimension embedding
    query_normed = query_full / np.maximum(np.linalg.norm(query_full, axis=1, keepdims=True), 1e-12)
    candidate_embeddings = full_embeddings[candidate_ids]  # (initial_k, full_dim)
    scores = (candidate_embeddings @ query_normed.T).squeeze()

    top_k_local = np.argsort(scores)[::-1][:final_k]
    return candidate_ids[top_k_local], scores[top_k_local]
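
A minimal usage sketch tying the two stages together. The corpus below is a toy placeholder and the model path refers to the fine-tuned checkpoint saved earlier; with only a handful of documents, initial_k is kept small so FAISS does not return padding ids:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("mrl-bge-base-finetuned")  # checkpoint saved above

corpus = [
    "Gradient checkpointing saves memory by recomputing activations in the backward pass",
    "LoRA adds low-rank adapter matrices to frozen model weights",
    "Knowledge distillation trains a small student model to mimic a larger teacher",
]

index, full_embeddings = build_two_stage_index(corpus, model, initial_dim=128)

doc_ids, scores = two_stage_search(
    "how do I reduce activation memory during training",
    model, index, full_embeddings,
    initial_dim=128, initial_k=3, final_k=2,  # small values for the toy corpus
)
for doc_id, score in zip(doc_ids, scores):
    print(f"{score:.3f}  {corpus[doc_id]}")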

Quality vs Dimension Tradeoffs in Practice

The quality degradation from truncation varies by model and training data, but a well-trained MRL model typically shows surprisingly small drops at moderate truncation. For the OpenAI text-embedding-3 models (which use MRL), quality at 256 dimensions is within 1–3% of full-dimension quality on MTEB benchmarks, while the memory footprint shrinks by 6x for text-embedding-3-small (1536d at full size) and 12x for text-embedding-3-large (3072d at full size). At 64 dimensions, quality drops more noticeably — typically 5–10% below full dimension on retrieval tasks — but is still competitive with older, smaller embedding models at their full dimension.

The right dimension to deploy depends on your tolerance for quality degradation and your infrastructure constraints. For RAG applications where the retrieval is followed by an LLM that can compensate for imperfect recall, 128d or 256d is often a good balance. For semantic search applications where the embedding score is the final signal (no reranking, no LLM), you should benchmark carefully on your specific domain data before truncating aggressively. Quality degradation is not uniform across domains — technical text and code tend to need more dimensions to discriminate well compared to general English prose.

from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.util import cos_sim

def benchmark_mrl_dimensions(model, queries, corpus, relevant_docs,
                             dims=(64, 128, 256, 512, 768)):
    """Measure NDCG@10 at each MRL dimension on your domain data."""
    # Requires sentence-transformers v3+, where evaluators accept truncate_dim
    results = {}
    for dim in dims:
        evaluator = InformationRetrievalEvaluator(
            queries=queries,              # dict: query_id -> query text
            corpus=corpus,                # dict: doc_id -> passage text
            relevant_docs=relevant_docs,  # dict: query_id -> set of relevant doc_ids
            name=f"dim_{dim}",
            truncate_dim=dim,             # slice embeddings to `dim` before scoring
            score_functions={"cosine": cos_sim},
        )
        metrics = evaluator(model)
        score = metrics[f"dim_{dim}_cosine_ndcg@10"]
        results[dim] = score
        print(f"  dim={dim:4d}: NDCG@10 = {score:.4f}")
    return results

MRL vs Post-hoc Dimensionality Reduction

The obvious alternative to MRL is to train a standard embedding model and then apply PCA or another dimensionality reduction technique to compress the embeddings. This is simpler to implement but consistently produces lower quality at the same reduced dimension compared to MRL. The reason is that PCA optimises for preserving variance in the embedding space, which is not the same as preserving the semantic relationships that matter for retrieval. MRL directly optimises retrieval quality at each dimension during training, which means the truncated dimensions are shaped by the actual task objective rather than by the geometry of the full embedding space.

A practical benchmark: on the MTEB retrieval benchmark, MRL embeddings truncated to 256d typically outperform PCA-compressed embeddings at 256d by 3–6 NDCG points, with the gap widening at more aggressive truncations (64d, 128d). For production systems where embedding dimension directly translates to index memory and query latency, that gap is meaningful enough that MRL is the right choice if you have the option. The main reason to use PCA instead is if you have an existing standard embedding model in production and cannot retrain — PCA gives you some dimension reduction with zero training cost, at the cost of retrieval quality relative to a purpose-trained MRL model.
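
If you want to measure this gap on your own corpus rather than take benchmark numbers on faith, the PCA baseline is cheap to produce. A minimal sketch using scikit-learn, assuming you already have full-dimension embeddings from a standard (non-MRL) model; fit the projection on a corpus sample, then evaluate with the same retrieval metrics you use for the MRL model:

import numpy as np
from sklearn.decomposition import PCA

def pca_compress(fit_embeddings, embeddings, dim=256):
    """Fit PCA on a sample of corpus embeddings, then project and renormalise."""
    pca = PCA(n_components=dim)
    pca.fit(fit_embeddings)              # learn the projection from a corpus sample
    reduced = pca.transform(embeddings)  # project to `dim` principal components
    norms = np.linalg.norm(reduced, axis=1, keepdims=True)
    return reduced / np.maximum(norms, 1e-12)

# Compare retrieval metrics for PCA-compressed 256d vectors against the same corpus
# embedded with an MRL model truncated to 256d (e.g. via benchmark_mrl_dimensions above).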

Which Models Support MRL Out of the Box

Several widely-used embedding models are already trained with MRL and can be truncated without any fine-tuning on your part. OpenAI’s text-embedding-3-small and text-embedding-3-large both use MRL — the API accepts a dimensions parameter that returns truncated embeddings directly, and OpenAI’s own benchmarks show text-embedding-3-large at 256d outperforms the older text-embedding-ada-002 at its full 1536d on most MTEB tasks. In the open-source ecosystem, the nomic-embed-text-v1.5 model from Nomic AI is MRL-trained and produces high-quality embeddings at 64, 128, 256, and 768 dimensions. Other open-source models increasingly advertise Matryoshka support on their model cards; verify that a model was explicitly trained with MRL before truncating it, because truncating a standard embedding degrades quality sharply. For most new projects starting from scratch, using one of these pretrained MRL models is the right first choice — fine-tuning an MRL model on domain-specific data only makes sense once you have established that the off-the-shelf model quality is insufficient for your retrieval task.

from openai import OpenAI

client = OpenAI()

def embed_with_mrl(texts: list[str], dim: int = 256,
                   model: str = "text-embedding-3-large") -> list[list[float]]:
    """Use OpenAI's MRL embeddings at a specified dimension."""
    response = client.embeddings.create(
        input=texts,
        model=model,
        dimensions=dim,   # MRL truncation happens server-side
    )
    return [item.embedding for item in response.data]

# 256d costs the same as full 3072d but uses 12x less index memory
embeddings_256 = embed_with_mrl(["sample text", "another document"], dim=256)
print(f"Embedding size: {len(embeddings_256[0])}")  # 256

Storage and Latency Impact at Scale

The practical impact of MRL on production RAG infrastructure depends on corpus size and query volume. For a corpus of 10 million documents, moving from 768d float32 embeddings to 128d reduces the raw vector storage from about 30 GB to 5 GB — enough to shift from requiring a dedicated high-memory vector database instance to fitting comfortably in a smaller one. More importantly, ANN search latency scales with embedding dimension: FAISS HNSW and IVF indexes both run meaningfully faster at lower dimensions, which reduces p99 query latency for high-QPS applications.

Quantisation and MRL are complementary, not alternatives. You can combine 128d MRL truncation with int8 quantisation to get a further 4x storage reduction (128d int8 uses about 128 bytes per vector versus 3,072 bytes for 768d float32 — a 24x reduction overall). At this level of compression, a 100 million document corpus fits in roughly 13 GB of RAM, which changes the economics of self-hosted vector search substantially. The quality tradeoff of combining aggressive truncation and quantisation needs to be benchmarked on your data, but for general English retrieval tasks the combined NDCG degradation relative to 768d float32 is typically in the 3–7% range — acceptable for most applications and recoverable with a reranking stage.
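
A sketch of the combination using FAISS scalar quantisation: truncate the MRL embeddings first, then let the index store the truncated vectors as int8. The index type here is one reasonable choice among several; measure the recall impact on your own data:

import faiss
import numpy as np

def build_quantized_mrl_index(full_embeddings, dim=128):
    """Truncate MRL embeddings to `dim`, renormalise, and index them as int8."""
    truncated = full_embeddings[:, :dim].astype(np.float32)
    truncated /= np.maximum(np.linalg.norm(truncated, axis=1, keepdims=True), 1e-12)

    # 8-bit scalar quantiser: ~dim bytes per vector instead of 4 * dim for float32
    index = faiss.IndexScalarQuantizer(
        dim, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_INNER_PRODUCT
    )
    index.train(truncated)  # learns per-dimension value ranges for quantisation
    index.add(truncated)
    return index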

When MRL Is Not the Right Tool

MRL is specifically designed for the use case where you want a single model to serve multiple deployment targets with different dimension budgets. If your entire deployment uses a single fixed embedding size and there is no need to trade off quality for storage or latency, MRL provides no benefit over standard embedding training — the multi-dimension training objective adds some training complexity without affecting the quality of the full-dimension output in any meaningful way. MRL also does not help with the orthogonal problem of domain adaptation: if your retrieval quality is poor because the model was not trained on domain-similar text, you need domain fine-tuning (using hard negative mining or contrastive loss on your domain data), not a different training objective for dimension flexibility. The two fine-tuning goals — domain adaptation and dimension flexibility — can be combined by training with both a domain-specific contrastive loss and the MatryoshkaLoss wrapper simultaneously, which is the recommended approach when you need both.

Choosing a Dimension for Your Application

The practical decision framework for choosing an MRL dimension is straightforward. Start by measuring the NDCG@10 or recall@K of your retrieval pipeline at the full model dimension on a representative sample of your production query distribution — this is your quality ceiling. Then benchmark at progressively smaller dimensions (512, 256, 128, 64) and find the smallest dimension where the quality drop is within your tolerance, typically defined as no more than 2–5 NDCG points below full-dimension quality. For most general-domain RAG applications, 256d hits this threshold comfortably. For highly specialised technical or scientific domains, 512d is often the smallest dimension that preserves acceptable quality.

Once you have a candidate dimension from benchmarking, measure the end-to-end impact on your application — not just retrieval metrics, but final answer quality if the retrieval feeds an LLM. Retrieval recall at 128d versus 256d may differ by 2%, but if your LLM can work with 20 retrieved passages rather than 10, the recall gap can be closed by retrieving more candidates at the smaller dimension, keeping quality constant while still realising the storage and latency benefits. This is the core advantage of two-stage MRL retrieval: the initial dimension controls your infrastructure costs, while the reranking dimension controls your quality ceiling, and you tune both independently based on your actual constraints.

MRL has quietly become the default approach for production embedding model deployment among teams that care about inference cost, and the ecosystem support in sentence-transformers, OpenAI’s API, and open-source MRL models like nomic-embed-text-v1.5 means you no longer need to implement it from scratch. The combination of a pretrained MRL model, a well-chosen deployment dimension, and a two-stage retrieval pipeline gives you a system that is simultaneously cheaper to run and easier to tune than the single-dimension alternatives that preceded it. If you are building or maintaining a RAG pipeline today and have not yet measured the quality-dimension curve on your corpus, that benchmark is the highest-value 30 minutes you can spend on your retrieval infrastructure.
