How to Use einops for Cleaner Tensor Operations in PyTorch

Tensor reshaping in PyTorch is one of those things that works but does not read well. A sequence of .view(), .permute(), .unsqueeze(), and .expand() calls might be correct but is nearly impossible to verify at a glance: you have to mentally simulate the shape transformations step by step. einops solves this with a notation that makes tensor operations self-documenting. The input and output shapes are written explicitly as named dimensions in the operation string, so the transformation is legible without any mental simulation. This article covers the core einops API (rearrange, reduce, repeat, and einsum) with concrete examples drawn from real model code.

Installation and Basic rearrange

# Install from a shell first: pip install einops
import torch
from einops import rearrange, reduce, repeat, einsum

# rearrange: reshape and transpose in one readable operation
x = torch.randn(2, 8, 64)   # (batch, seq_len, hidden)

# Equivalent to x.transpose(1, 2) — but the intent is explicit
x_transposed = rearrange(x, 'b s h -> b h s')

# Split a dimension — e.g., split heads for multi-head attention
# 64 hidden = 8 heads * 8 head_dim
x_heads = rearrange(x, 'b s (h d) -> b h s d', h=8)
print(x_heads.shape)  # (2, 8, 8, 8) = (batch, heads, seq_len, head_dim)

# Merge dimensions — inverse of the above
x_merged = rearrange(x_heads, 'b h s d -> b s (h d)')
print(x_merged.shape)  # (2, 8, 64) — back to original

# Flatten a batch of images into a sequence of patches (ViT-style)
images = torch.randn(4, 3, 224, 224)  # (batch, channels, H, W)
patch_size = 16
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                    p1=patch_size, p2=patch_size)
print(patches.shape)  # (4, 196, 768) = (batch, num_patches, patch_dim)
# Compare to a native PyTorch equivalent, which is much harder to read:
# images.unfold(2, 16, 16).unfold(3, 16, 16).permute(0, 2, 3, 4, 5, 1).reshape(4, 196, 768)

The named dimension syntax is the key feature: dimensions in the operation string have names (b for batch, s for sequence, h for heads, and so on), and parentheses denote merging or splitting dimensions. When an axis is split, every component size except one must be passed as a keyword argument, and einops infers the remaining one (h=8 in the example above, from which d = 64 / 8 = 8 is inferred). This makes einops operations self-validating: if the named dimensions do not match the tensor’s actual shape, you get a clear error rather than a silently wrong tensor.

rearrange in Attention Implementations

Multi-head attention is the most common place einops improves readability significantly. The projections, head-splitting, and output merging steps that normally require several lines of view and transpose calls collapse into two rearrange operations.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = self.d_head ** -0.5
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor,
                mask: torch.Tensor | None = None) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        qkv = self.qkv(x)  # (batch, seq_len, 3*d_model)

        # Split into Q, K, V and reshape into heads — one readable line
        q, k, v = rearrange(qkv, 'b s (three h d) -> three b h s d',
                             three=3, h=self.n_heads, d=self.d_head).unbind(0)
        # q, k, v: each (batch, heads, seq_len, head_dim)

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)  # (batch, heads, seq_len, head_dim)

        # Merge heads back — inverse of the split above
        out = rearrange(out, 'b h s d -> b s (h d)')
        return self.out(out)

reduce: Pooling and Aggregation

reduce applies a reduction operation (mean, sum, max, min) along named dimensions, making pooling operations explicit about which dimensions are being collapsed.

from einops import reduce

feature_maps = torch.randn(8, 512, 7, 7)  # (batch, channels, H, W)

# Global average pooling — collapse H and W
gap = reduce(feature_maps, 'b c h w -> b c', 'mean')
print(gap.shape)  # (8, 512)
# Equivalent but less readable: feature_maps.mean(dim=[-2, -1])

# Spatial max pooling with a 2x2 window (like nn.MaxPool2d(2))
# The pooled axes must divide evenly: on the 7x7 map above einops raises an error, so use 8x8
feature_maps_8 = torch.randn(8, 512, 8, 8)
pooled = reduce(feature_maps_8, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)
print(pooled.shape)  # (8, 512, 4, 4)

# Mean over the sequence dimension for sequence classification
hidden_states = torch.randn(4, 128, 768)  # (batch, seq_len, hidden)
pooled_seq = reduce(hidden_states, 'b s h -> b h', 'mean')
print(pooled_seq.shape)  # (4, 768)

# Mean over the batch, e.g. for per-position activation statistics
batch_mean = reduce(hidden_states, 'b s h -> s h', 'mean')
print(batch_mean.shape)  # (128, 768)

repeat: Broadcasting Without the Confusion

repeat expands a tensor along new or existing dimensions, replacing .unsqueeze().expand() chains with a single explicit operation.

from einops import repeat

# Add a batch dimension and repeat
single = torch.randn(64)  # single embedding
batched = repeat(single, 'd -> b d', b=8)
print(batched.shape)  # (8, 64)
# Equivalent: single.unsqueeze(0).expand(8, -1)

# Repeat a positional encoding for each item in a batch
pos_enc = torch.randn(128, 768)  # (seq_len, hidden)
pos_enc_batched = repeat(pos_enc, 's h -> b s h', b=4)
print(pos_enc_batched.shape)  # (4, 128, 768)

# Tile a small mask across heads in attention
mask = torch.ones(4, 128, 128)  # (batch, seq, seq)
mask_heads = repeat(mask, 'b s1 s2 -> b h s1 s2', h=8)
print(mask_heads.shape)  # (4, 8, 128, 128)

# Repeat each sequence in the batch k times, keeping the copies adjacent (e.g., for beam search expansion)
beam_k = 5
embeddings = torch.randn(4, 128, 768)  # (batch, seq, hidden)
expanded = repeat(embeddings, 'b s h -> (b k) s h', k=beam_k)
print(expanded.shape)  # (20, 128, 768)

einsum with Named Dimensions

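einops also provides an einsum function that uses the same named-dimension syntax. It is more readable than standard torch.einsum for complex contractions because axis names can be whole words rather than single letters. Note the argument order: einops.einsum takes the tensors first and the pattern string last, the reverse of torch.einsum.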

from einops import einsum

# Matrix multiplication
A = torch.randn(4, 64, 32)   # (batch, seq, d_k)
B = torch.randn(4, 32, 16)   # (batch, d_k, d_v)

# Standard torch.einsum: 'bik,bkj->bij' — requires knowing index convention
# einops einsum: explicit named dimensions
result = einsum(A, B, 'batch seq dk, batch dk dv -> batch seq dv')
print(result.shape)  # (4, 64, 16)

# Attention score computation
queries = torch.randn(4, 8, 64, 32)   # (batch, heads, seq_q, d_head)
keys    = torch.randn(4, 8, 64, 32)   # (batch, heads, seq_k, d_head)
scores  = einsum(queries, keys,
                 'b h sq d, b h sk d -> b h sq sk')
print(scores.shape)  # (4, 8, 64, 64)

# Cross-attention between two sequences
enc = torch.randn(4, 32, 512)   # (batch, enc_seq, hidden)
dec = torch.randn(4, 16, 512)   # (batch, dec_seq, hidden)
cross_attn = einsum(dec, enc, 'b ds h, b es h -> b ds es')
print(cross_attn.shape)  # (4, 16, 32)

Practical Patterns: ViT Patch Embedding

Vision Transformer patch embedding is one of the most instructive real-world uses of einops because the operation — splitting an image into a grid of patches and flattening each patch — involves three dimensions being reorganised simultaneously.

import torch
import torch.nn as nn
from einops import repeat
from einops.layers.torch import Rearrange  # nn.Module wrapper for use in nn.Sequential

class PatchEmbedding(nn.Module):
    """ViT patch embedding using einops layers."""
    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        self.projection = nn.Sequential(
            # Rearrange is an nn.Module — works in Sequential and gets traced by torch.compile
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_size, p2=patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

    def forward(self, x):
        b = x.shape[0]
        x = self.projection(x)                          # (b, num_patches, embed_dim)
        cls = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat([cls, x], dim=1)                  # (b, num_patches+1, embed_dim)
        return x + self.pos_embed

The Rearrange layer from einops.layers.torch is an nn.Module that wraps a rearrange operation, making it composable with nn.Sequential and compatible with torch.compile, torch.jit.script, and model export. This means you can use einops notation throughout your model architecture without special handling at inference time. If you plan to script a module, keep its einops operations in layer form: functional calls such as the repeat in forward above are not scriptable.

When to Use einops vs Native PyTorch

einops is most valuable in three situations: multi-head attention and transformer implementations where head splitting and merging are frequent, ViT and convolutional architectures where spatial dimensions are reorganised repeatedly, and any code that will be read or reviewed by someone who did not write it. The notation overhead is minimal — einops operations are typically one line versus two to four lines of native PyTorch — and the readability benefit compounds when the same patterns appear many times in a codebase.

Native PyTorch is preferable for simple operations that are already readable: x.mean(dim=-1) is clearer than reduce(x, 'b s h -> b s', 'mean'), and x.unsqueeze(0) is clearer than repeat(x, '... -> 1 ...'). The rule of thumb is to reach for einops when you need to name dimensions to understand the operation, and to use native PyTorch for simple axis operations where the intent is immediately obvious. Runtime performance is essentially identical: einops operations dispatch to the same reshape, permute, and expand primitives you would write by hand, and parsed patterns are cached, so the remaining overhead is a small amount of Python-side shape bookkeeping that torch.compile removes when it captures the graph.

Using einops in Convolutional Architectures

Convolutional networks involve frequent reshaping between spatial and channel dimensions — converting feature maps to sequences for attention, reorganising outputs for skip connections, and implementing operations like depth-wise separable convolution or space-to-depth transformations. einops handles all of these more readably than the equivalent PyTorch operations.

from einops import rearrange, reduce
import torch

# Space-to-depth: reorganise spatial pixels into channel dimension
# Used in efficient CNN backbones to downsample without strided conv
def space_to_depth(x: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    """Convert spatial dimensions into channel dimension."""
    return rearrange(x, 'b c (h s1) (w s2) -> b (c s1 s2) h w',
                     s1=block_size, s2=block_size)

# Depth-to-space: inverse — used in super-resolution (pixel shuffle)
def depth_to_space(x: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    return rearrange(x, 'b (c s1 s2) h w -> b c (h s1) (w s2)',
                     s1=block_size, s2=block_size)

feature_map = torch.randn(4, 64, 56, 56)
downsampled = space_to_depth(feature_map, block_size=2)
print(downsampled.shape)  # (4, 256, 28, 28)
restored = depth_to_space(downsampled, block_size=2)
print(restored.shape)  # (4, 64, 56, 56)

# Convert CNN feature map to sequence for a transformer encoder layer
def cnn_features_to_sequence(features: torch.Tensor) -> torch.Tensor:
    """Flatten spatial dims into sequence for transformer processing."""
    return rearrange(features, 'b c h w -> b (h w) c')

def sequence_to_cnn_features(seq: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Restore spatial structure after transformer processing."""
    return rearrange(seq, 'b (h w) c -> b c h w', h=h, w=w)

feats = torch.randn(4, 256, 14, 14)
seq = cnn_features_to_sequence(feats)   # (4, 196, 256)
restored_feats = sequence_to_cnn_features(seq, 14, 14)  # (4, 256, 14, 14)

Debugging Shape Errors with einops

One of the practical advantages of einops over native PyTorch reshaping is the quality of error messages when shapes do not match. Native PyTorch’s view and reshape produce errors like “shape ‘[4, 10, 8, 8]’ is invalid for input of size 2600”: correct, but unhelpful for diagnosing which dimension is wrong. einops errors quote the pattern, the input shape, and the axis lengths you supplied, and say which axis could not be split, which makes debugging significantly faster.

import torch
from einops import EinopsError, rearrange

# Intentionally wrong: trying to split 65 into 8 heads
x = torch.randn(4, 10, 65)  # hidden_dim=65, not divisible by 8
try:
    out = rearrange(x, 'b s (h d) -> b h s d', h=8)
except EinopsError as e:
    print(e)
# The message (exact wording varies between einops versions) quotes the pattern
# 'b s (h d) -> b h s d', the input shape (4, 10, 65) and the supplied h=8, and
# reports that an axis of length 65 cannot be divided into chunks of 8.
# Compare x.view(4, 10, 8, 8): "shape '[4, 10, 8, 8]' is invalid for input of size 2600".

# Correct version
x = torch.randn(4, 10, 64)
out = rearrange(x, 'b s (h d) -> b h s d', h=8)
print(out.shape)  # (4, 8, 10, 8)

This error clarity is especially valuable during model development when you are iterating on architecture changes and tensor shapes are in flux. The named dimension syntax forces you to think explicitly about what each dimension represents, which often surfaces shape bugs before they manifest as wrong gradients or silent failures downstream.

einops with torch.compile and TorchScript

A practical concern when adopting any third-party library is compatibility with PyTorch’s compilation and export tooling. Recent versions of einops work with torch.compile: rearrange, reduce, and repeat trace into ordinary reshape and permute operations that the TorchInductor backend can optimise. The Rearrange and Reduce nn.Module layers support both torch.jit.trace and torch.jit.script for TorchScript export; the functional forms can be traced but not scripted, because they parse pattern strings at call time. For inference deployment, the recommended path is torch.compile, or ONNX export, which traces the model with an example input.

import torch
from einops.layers.torch import Rearrange

model = torch.nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16),
    torch.nn.Linear(768, 768),
)

# torch.compile works directly — einops layers are transparent to the compiler
compiled_model = torch.compile(model)
x = torch.randn(4, 3, 224, 224)
out = compiled_model(x)
print(out.shape)  # (4, 196, 768)

# ONNX export via trace
torch.onnx.export(
    model,
    x,
    "patch_embed.onnx",
    input_names=["images"],
    output_names=["patches"],
    dynamic_axes={"images": {0: "batch"}},
)

Common Patterns Worth Memorising

A handful of einops patterns cover the majority of what you will need in practice. The batch dimension pass-through using ... (ellipsis) lets you write operations that work regardless of leading batch dimensions: rearrange(x, '... h w -> ... (h w)') flattens the last two dimensions however many dimensions precede them. The decompose-then-operate pattern (split a composite dimension, do some operation, merge it back) is the core of multi-head attention, grouped convolution, and any strided processing. Space-to-batch and batch-to-space transformations use parenthesis grouping in rearrange to move blocks of a spatial grid into the batch axis and back, so every block can be processed independently in one batched call instead of an explicit loop.

If you are working on a codebase that mixes einops and native PyTorch, consistency matters more than which tool you use for any specific operation. Mixing rearrange for some operations and view().permute() for others in the same file makes the code harder to follow than either approach alone. The most readable codebases pick einops as the default for anything involving dimension names and use native PyTorch only for simple single-axis operations where the PyTorch API is already self-explanatory.

Why Named Dimensions Matter for Team Code

The biggest long-term benefit of einops is not personal productivity — it is what happens when someone else reads your code six months later, or when you return to a model implementation after a long gap. Native PyTorch tensor operations are correct but opaque: a line like x.view(b, -1, self.n_heads, self.head_dim).transpose(1, 2) requires reconstructing what the tensor’s dimensions mean before you can verify that the transformation is correct. With einops, that same operation is rearrange(x, 'b s (h d) -> b h s d', h=self.n_heads) — the intent is encoded directly in the syntax. You do not need to trace through the preceding code to understand what dimension 1 and dimension 2 represent; the operation string tells you.

This readability advantage compounds in larger codebases. A transformer implementation with ten attention layers, each with QKV projections, head splits, and output merges, involves thirty or more tensor reshaping operations. In pure PyTorch, each of those operations is a potential source of confusion for a reviewer trying to verify the implementation against a paper. In einops, each operation is self-documenting, and a reviewer can check the paper’s dimension convention against the operation strings directly. In practice, this reduces the time to review a model implementation and makes bugs in dimension handling much easier to catch during code review, before they show up as subtle training failures.

Integrating einops into an Existing Codebase

Adding einops to an existing PyTorch project is low-risk. It is a pure-Python package with no required dependencies that dispatches to whichever framework the tensor comes from, and it adds nothing to the computation graph beyond the equivalent native PyTorch operations. The recommended integration strategy is incremental: start by replacing the most complex reshaping sequences in your model code, specifically any sequence involving three or more view, permute, unsqueeze, or expand calls chained together. These are the operations that benefit most from named dimensions and are most likely to have subtle errors that einops would have caught at the shape-validation step.

Do not replace simple single-axis operations — x.mean(-1), x.unsqueeze(0), x.flatten(1) — with einops equivalents. The native PyTorch operations are already readable for simple cases, and replacing them adds verbosity without clarity. The goal is to use einops where it genuinely improves comprehension, not to adopt it as a wholesale replacement for the PyTorch tensor API. A codebase that uses einops selectively for complex multi-dimensional operations and native PyTorch for simple ones is more readable than one that uses either exclusively.

For new projects, the approach most teams find productive is to establish einops as the default for all attention-related reshaping from the start, use the Rearrange and Reduce nn.Module layers for any reshaping that lives inside nn.Sequential blocks, and adopt native PyTorch for everything else unless a specific operation is clearer with einops notation. This keeps the dependency lightweight while capturing most of the readability benefit where it matters most — in the attention and embedding code that is typically the most complex and most frequently revisited part of a deep learning codebase.

einops and Type Checkers

One underappreciated benefit of einops is how well it pairs with shape-annotation tools. Because tensor shapes are written out in strings rather than left implicit in a sequence of operations, libraries like jaxtyping (which, despite its name, also annotates PyTorch tensors) used with a runtime checker such as Beartype can add shape assertions alongside einops that catch mismatches at the point of the call, rather than producing cryptic errors many lines later. The named dimensions also serve as implicit documentation: when you write rearrange(x, 'batch seq hidden -> batch hidden seq'), the dimension names act as lightweight shape annotations that make the expected tensor semantics clear to any reader, human or tool. This is particularly valuable in codebases moving toward more rigorous type safety, where einops offers a natural on-ramp to explicit shape documentation without requiring a full migration to a tensor typing library.

Clear shape notation in operation strings, combined with runtime validation, makes dimension errors surface immediately and with full context, which removes a large share of the debugging time ML engineers spend on complex multi-dimensional models. For teams adopting einops today, the library is stable, actively maintained, and widely used in both research and production transformer code. Adding einops to a project’s requirements file is one of the highest return-on-investment changes you can make.
