How to Write Tests for ML Models with pytest

Most ML codebases have no tests. Data preprocessing pipelines, model forward passes, training loops, and inference endpoints all run without any automated verification, which means bugs surface as degraded metrics weeks later or as production incidents. Writing tests for ML code is harder than for standard software because the outputs are probabilistic and the “correct” behaviour is often statistical rather than exact, but the patterns for doing it are well established. This article covers how to structure a pytest test suite for an ML project, what to actually test at each layer of the stack, and how to handle the specific challenges that make ML testing different from ordinary unit testing.

Project Structure and Test Organisation

A well-structured ML project separates tests by what they are testing and how expensive they are to run. Fast unit tests (milliseconds each) should run on every commit; slow integration tests (minutes each, requiring GPU or large data) should run on a schedule or before merges to main.

project/
├── src/
│   ├── data/          # dataset, preprocessing, augmentation
│   ├── models/        # architecture, forward pass
│   ├── training/      # loss functions, training loop, optimiser setup
│   └── inference/     # prediction, postprocessing, serving logic
└── tests/
    ├── conftest.py    # shared fixtures: small datasets, tiny models, device setup
    ├── unit/
    │   ├── test_data.py        # preprocessing, tokenisation, augmentation
    │   ├── test_models.py      # shape checks, output range, gradient flow
    │   └── test_losses.py      # loss function correctness
    ├── integration/
    │   ├── test_training.py    # training loop converges on a toy problem
    │   └── test_inference.py   # end-to-end prediction pipeline
    └── regression/
        └── test_metrics.py     # known inputs produce expected metric values

# conftest.py — shared fixtures available to all tests
import pytest
import torch
import torch.nn as nn

@pytest.fixture(scope="session")
def device():
    """Use GPU if available, otherwise CPU. Session-scoped for efficiency."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

@pytest.fixture
def tiny_model():
    """A minimal 2-layer MLP for fast shape and gradient tests."""
    return nn.Sequential(
        nn.Linear(32, 16),
        nn.ReLU(),
        nn.Linear(16, 4),
    )

@pytest.fixture
def small_batch():
    """A small fixed batch for deterministic tests."""
    torch.manual_seed(42)
    return {
        "inputs": torch.randn(8, 32),
        "labels": torch.randint(0, 4, (8,)),
    }

Testing Data Pipelines

Data preprocessing is the most common source of silent bugs — incorrect normalisation, wrong tokenisation, label encoding errors, and shape mismatches that only show up as slow convergence or slightly degraded metrics. Test these deterministically by fixing inputs and verifying exact outputs.

import pytest
import torch
import numpy as np
from src.data import normalize_features, TextDataset, build_tokenizer

class TestNormalization:
    def test_output_range(self):
        """After normalization, all values should be in a reasonable range."""
        raw = torch.randn(100, 10) * 50 + 20   # mean≈20, std≈50
        normed = normalize_features(raw)
        assert normed.mean(dim=0).abs().max() < 0.1, "Mean should be near zero"
        assert (normed.std(dim=0) - 1.0).abs().max() < 0.1, "Std should be near one"

    def test_deterministic(self):
        """Same input always produces same output."""
        x = torch.randn(10, 5)
        assert torch.allclose(normalize_features(x), normalize_features(x))

    def test_no_nans(self):
        """Normalization should not produce NaNs even with near-zero std."""
        x = torch.zeros(10, 5)   # zero std — division-by-zero risk
        result = normalize_features(x)
        assert not torch.isnan(result).any()

class TestDataset:
    def test_length(self):
        """Dataset length matches the number of samples."""
        dataset = TextDataset(["sample one", "sample two", "sample three"])
        assert len(dataset) == 3

    def test_output_shapes(self):
        """Each item has expected keys and tensor shapes."""
        dataset = TextDataset(["hello world"] * 4, max_length=16)
        item = dataset[0]
        assert "input_ids" in item
        assert item["input_ids"].shape == (16,)
        assert item["attention_mask"].shape == (16,)

    def test_labels_are_valid(self):
        """All labels should be within the valid class range."""
        dataset = TextDataset(["text"] * 10, labels=[0, 1, 2, 0, 1, 2, 0, 1, 2, 0], num_classes=3)
        for i in range(len(dataset)):
            label = dataset[i]["label"]
            assert 0 <= label < 3, f"Invalid label {label} at index {i}"

Testing Model Forward Passes

Shape tests and gradient flow tests are the two most valuable model tests. Shape tests catch the majority of architectural bugs — mismatched dimensions, wrong output heads, incorrect batch dimension handling — before you have wasted GPU time on a training run. Gradient flow tests catch dead layers, detached computation graphs, and incorrect use of .detach() that silently prevent parts of the model from training.

import pytest
import torch
import torch.nn as nn
from src.models import MyTransformerModel

class TestModelForwardPass:
    @pytest.fixture
    def model(self):
        return MyTransformerModel(vocab_size=1000, hidden_size=64, num_layers=2, num_classes=5)

    def test_output_shape(self, model):
        """Output shape should match (batch_size, num_classes)."""
        x = torch.randint(0, 1000, (4, 32))   # batch=4, seq_len=32
        out = model(x)
        assert out.shape == (4, 5), f"Expected (4, 5), got {out.shape}"

    def test_output_is_finite(self, model):
        """Forward pass should not produce NaN or Inf."""
        x = torch.randint(0, 1000, (4, 32))
        out = model(x)
        assert torch.isfinite(out).all(), "Model output contains NaN or Inf"

    def test_output_changes_with_input(self, model):
        """Different inputs should (almost always) produce different outputs."""
        x1 = torch.randint(0, 1000, (4, 32))
        x2 = torch.randint(0, 1000, (4, 32))
        assert not torch.allclose(model(x1), model(x2)), "Model ignores input"

    def test_gradient_flows_to_all_parameters(self, model):
        """Every parameter with requires_grad=True should receive a gradient."""
        x = torch.randint(0, 1000, (4, 32))
        labels = torch.randint(0, 5, (4,))
        loss = nn.CrossEntropyLoss()(model(x), labels)
        loss.backward()
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert param.grad is not None, f"No gradient for {name}"
                assert param.grad.abs().sum() > 0, f"Zero gradient for {name}"

    def test_eval_mode_disables_dropout(self, model):
        """In eval mode, repeated forward passes should be deterministic."""
        model.eval()
        x = torch.randint(0, 1000, (4, 32))
        with torch.no_grad():
            out1 = model(x)
            out2 = model(x)
        assert torch.allclose(out1, out2), "Eval mode outputs are not deterministic"

    def test_batch_independence(self, model):
        """Changing one sample in the batch should not affect other samples' outputs."""
        model.eval()
        x = torch.randint(0, 1000, (4, 32))
        with torch.no_grad():
            out_full = model(x)
            x_modified = x.clone()
            x_modified[2] = torch.randint(0, 1000, (32,))
            out_modified = model(x_modified)
        # Samples 0, 1, 3 should be unchanged; sample 2 will differ
        assert torch.allclose(out_full[0], out_modified[0])
        assert torch.allclose(out_full[1], out_modified[1])
        assert torch.allclose(out_full[3], out_modified[3])

The batch independence test is particularly valuable for catching incorrect use of batch normalisation during inference (where the running statistics should be used, not the batch statistics) and for catching attention mask bugs that allow one sequence to attend to a different sequence's tokens.

Testing Loss Functions
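
Loss functions are deterministic functions of their inputs, which makes them straightforward to test precisely: check the value on a known-perfect input, check invariants such as non-negativity, and check that hyperparameters shift the loss in the direction they are supposed to.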

import pytest
import torch
import torch.nn as nn
from src.training import FocalLoss, ContrastiveLoss

class TestFocalLoss:
    def test_zero_loss_on_perfect_predictions(self):
        """Perfect predictions should produce near-zero loss."""
        loss_fn = FocalLoss(gamma=2.0)
        # Construct nearly-perfect logits: correct class gets very high logit
        logits = torch.full((4, 5), -10.0)
        labels = torch.tensor([0, 1, 2, 3])
        for i, label in enumerate(labels):
            logits[i, label] = 10.0
        loss = loss_fn(logits, labels)
        assert loss.item() < 0.01, f"Loss on perfect predictions too high: {loss.item()}"

    def test_loss_is_nonnegative(self):
        """Loss should never be negative."""
        loss_fn = FocalLoss(gamma=2.0)
        for _ in range(10):
            logits = torch.randn(8, 5)
            labels = torch.randint(0, 5, (8,))
            assert loss_fn(logits, labels).item() >= 0

    def test_higher_gamma_reduces_easy_example_weight(self):
        """Higher gamma should downweight well-classified examples more."""
        easy_logits = torch.zeros(4, 2)
        easy_logits[:, 0] = 5.0   # very confident correct predictions
        labels = torch.zeros(4, dtype=torch.long)
        loss_gamma0 = FocalLoss(gamma=0.0)(easy_logits, labels)
        loss_gamma2 = FocalLoss(gamma=2.0)(easy_logits, labels)
        assert loss_gamma2 < loss_gamma0, "Higher gamma should reduce easy example loss"

class TestTrainingConvergence:
    def test_model_can_overfit_tiny_dataset(self):
        """A model should be able to memorise a tiny dataset — if it cannot, something is broken."""
        model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        criterion = nn.CrossEntropyLoss()
        # 4 fixed samples — model should memorise these perfectly
        x = torch.randn(4, 8)
        y = torch.tensor([0, 1, 0, 1])
        initial_loss = criterion(model(x), y).item()
        for _ in range(200):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        final_loss = criterion(model(x), y).item()
        assert final_loss < 0.01, f"Model failed to overfit: initial={initial_loss:.3f}, final={final_loss:.3f}"

The overfit test is one of the most reliable ML-specific tests: if a model cannot memorise a tiny fixed dataset, there is almost certainly a bug in the forward pass, the loss function, or the optimiser setup. This test catches gradient sign errors, detached computation graphs, and incorrect loss reductions, all of which are invisible during a normal training run, where they show up only as a slowly degrading validation metric rather than as a training loss that refuses to decrease.

Running Tests Efficiently with Markers and CI

# Mark slow tests so you can skip them during fast development iteration
import pytest
import torch

@pytest.mark.slow
def test_full_training_loop():
    # ... expensive test requiring GPU
    pass

@pytest.mark.gpu
def test_cuda_memory_management():
    if not torch.cuda.is_available():
        pytest.skip("No GPU available")
    # ... GPU-specific test
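
Custom markers like slow and gpu are not registered by default, so pytest will warn about unknown marks. One way to register them (a minimal sketch; the descriptions are only suggestions) is the pytest_configure hook in conftest.py:

# conftest.py — register custom markers so pytest does not warn about unknown marks
def pytest_configure(config):
    config.addinivalue_line("markers", "slow: long-running tests, run nightly or pre-merge")
    config.addinivalue_line("markers", "gpu: tests that require a CUDA device")

With the markers registered, selecting subsets of the suite from the command line looks like this:
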
# Run only fast unit tests (default during development)
pytest tests/unit/ -v

# Run everything except slow tests
pytest tests/ -m "not slow" -v

# Run only GPU tests
pytest tests/ -m "gpu" -v

# Run full suite with coverage report
pytest tests/ --cov=src --cov-report=html -v
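
In CI, a minimal GitHub Actions workflow that runs the fast CPU-only unit tests on every push and pull request looks like this:
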
# .github/workflows/test.yml
name: ML Tests
on: [push, pull_request]
jobs:
  unit-tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with: {python-version: "3.11"}
      - run: pip install -e ".[dev]"
      - run: pytest tests/unit/ -m "not gpu" -v --tb=short

The decision of what to put in CI versus what to run on a schedule comes down to execution time and resource requirements. Unit tests that run in under two minutes and require only CPU belong in CI on every push — they catch regressions immediately when they are easiest to fix. GPU integration tests and convergence tests that take 10–30 minutes belong on a nightly or pre-merge schedule. The key discipline is keeping the fast unit test suite actually fast: a test suite that takes 20 minutes discourages running it locally and defeats the purpose of having tests at all.

What Makes ML Testing Different from Standard Software Testing

Testing ML code has three properties that have no equivalent in ordinary software testing, and misunderstanding them leads to either under-testing (writing no tests because "you can't test randomness") or over-testing (writing brittle tests that break on every run because they check floating point outputs exactly).

The first difference is non-determinism. Model outputs depend on random initialisation, random data sampling, and random augmentation. Tests that check exact output values will fail on different hardware or different random seeds. The solution is to test properties rather than exact values: instead of asserting that a logit equals 2.347, assert that all logits are finite, that the output has the right shape, that the softmax probabilities sum to one, and that a model in eval mode produces the same output on two identical inputs. These property-based assertions are robust to the non-determinism inherent in ML and catch the bugs that actually matter.
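
As a concrete sketch of this style, reusing the tiny_model fixture from the conftest.py above (its final linear layer produces raw logits for four classes):

import torch

def test_output_is_a_valid_distribution(tiny_model):
    """Assert properties of the output rather than exact values."""
    x = torch.randn(8, 32)
    logits = tiny_model(x)
    probs = torch.softmax(logits, dim=-1)
    assert logits.shape == (8, 4)                 # correct shape
    assert torch.isfinite(logits).all()           # no NaN or Inf
    assert torch.allclose(probs.sum(dim=-1), torch.ones(8), atol=1e-5)  # rows sum to one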

The second difference is statistical correctness. Some tests need to verify that a metric meets a threshold, not that it equals an exact value. A tokeniser test might assert that encoding followed by decoding recovers the original string for 99.9% of test cases rather than 100%, because the tokeniser may be lossy for certain rare Unicode characters. A model calibration test might assert that confidence scores have an expected calibration error (ECE) below 0.1 on a fixed validation sample rather than checking exact probabilities. Writing these tests requires deciding what the acceptable range is, which forces you to think about what "correct" actually means for your task — a valuable exercise in itself.
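
A sketch of the threshold style for the tokeniser case, assuming build_tokenizer from src.data returns an object with encode and decode methods (that API and the tiny sample list here are illustrative; a real suite would use a large corpus and a threshold such as 0.999):

from src.data import build_tokenizer  # assumed to return an object with encode()/decode()

def test_tokenizer_round_trip_rate():
    """Encode-then-decode should recover the original string for almost all inputs."""
    tokenizer = build_tokenizer()
    samples = ["plain ascii text", "text with numbers 123", "café", "emoji 🙂", "tabs\tand\nnewlines"]
    recovered = sum(tokenizer.decode(tokenizer.encode(s)) == s for s in samples)
    # A statistical threshold, not exact equality: a lossy edge case is tolerated
    assert recovered / len(samples) >= 0.8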

The third difference is the role of the overfit test. In standard software, you do not test that a sorting algorithm can sort a four-element list because that is trivially verifiable by reading the code. In ML, the overfit test — verifying that your model can memorise a tiny fixed dataset — is a genuine, high-value test because the failure mode (training loss does not decrease even on a tiny dataset) is easy to accidentally introduce and hard to diagnose from slow validation metrics alone. Any time you add a new model component, change the loss function, or modify the training loop, running the overfit test first takes seconds and immediately confirms or rules out the most common classes of implementation error.

Testing Inference and Serving Code

Inference pipelines have different failure modes than training code. The most common are: preprocessing that differs between training time and serving time (different normalisation, different tokenisation, different image resizing), incorrect handling of edge cases in input validation (empty strings, sequences longer than max_length, images with unexpected aspect ratios), and performance regressions after model updates. All of these are testable and should be.

import pytest
import torch
from src.inference import Predictor

class TestInferencePipeline:
    @pytest.fixture
    def predictor(self, tmp_path):
        # Save and reload a model to test the full serialisation round-trip.
        # build_and_train_tiny_model is a project helper (not shown here) that
        # returns a small trained nn.Module compatible with Predictor.
        model = build_and_train_tiny_model()
        model_path = tmp_path / "model.pt"
        torch.save(model.state_dict(), model_path)
        return Predictor(model_path=str(model_path))

    def test_single_prediction_shape(self, predictor):
        result = predictor.predict("sample input text")
        assert "label" in result
        assert "confidence" in result
        assert 0.0 <= result["confidence"] <= 1.0

    def test_batch_prediction_matches_single(self, predictor):
        """Batch prediction should give same results as single prediction."""
        inputs = ["text one", "text two", "text three"]
        batch_results = predictor.predict_batch(inputs)
        for i, text in enumerate(inputs):
            single_result = predictor.predict(text)
            assert batch_results[i]["label"] == single_result["label"]

    def test_empty_input_raises_clearly(self, predictor):
        with pytest.raises(ValueError, match="Input cannot be empty"):
            predictor.predict("")

    def test_long_input_is_truncated_not_errored(self, predictor):
        long_text = "word " * 10000
        result = predictor.predict(long_text)  # should truncate, not raise
        assert result is not None

The batch-matches-single test is important because batch inference implementations often take shortcuts — sharing attention mask computations, using different padding strategies — that can cause the predictions for a given sample to depend on what other samples are in the same batch. Catching this bug before it reaches production is straightforward with this test and very expensive to diagnose after the fact. The empty input and long input edge case tests catch the most common input validation failures and are cheap to write because they are entirely deterministic regardless of model weights.

Regression Tests for Model Updates

When you update a model — new training data, new architecture, new preprocessing — you want to know whether the change improved, degraded, or had no effect on a fixed set of representative inputs. Regression tests serve this function by computing predictions on a fixed golden dataset and comparing the outputs or metrics to a stored baseline. The implementation pattern is: compute predictions on the golden dataset once, store the results as a fixture file, then assert in the test that new predictions match the baseline within a tolerance. For classification, you might assert that the top-1 accuracy on the golden set is within 1% of the baseline. For regression, assert that the RMSE has not increased by more than a threshold. Any significant degradation on the golden set surfaces as a failing test, triggering a review before the new model reaches production.
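
A sketch of this pattern, assuming a golden dataset checked in as a JSON file of {"text": ..., "label": ...} records, a recorded baseline accuracy, and a predictor fixture like the one in the inference tests above (the path, field names, and baseline value are illustrative):

import json
from pathlib import Path

GOLDEN_PATH = Path("tests/fixtures/golden_set.json")   # fixed, version-controlled inputs
BASELINE_ACCURACY = 0.92                               # accuracy recorded for the current baseline model
TOLERANCE = 0.01                                       # allowed absolute drop before the test fails

def test_golden_set_accuracy_has_not_regressed(predictor):
    """A new model version must stay within tolerance of the baseline on the golden set."""
    golden = json.loads(GOLDEN_PATH.read_text())
    correct = sum(
        predictor.predict(example["text"])["label"] == example["label"]
        for example in golden
    )
    accuracy = correct / len(golden)
    assert accuracy >= BASELINE_ACCURACY - TOLERANCE, (
        f"Golden-set accuracy regressed: {accuracy:.3f} vs baseline {BASELINE_ACCURACY:.3f}"
    )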

Handling Non-Determinism in Tests

Setting random seeds makes tests deterministic, which is sometimes what you want (reproducible shape and value checks) but requires care about scope. A test that sets a global seed at the start may interact with other tests in unexpected ways if the test runner executes them in a different order. The recommended practice is to use pytest fixtures that set the seed locally within each test that needs determinism, rather than setting it globally in conftest. This keeps each test self-contained and prevents order-dependent failures.

import pytest
import torch
import numpy as np
import random

@pytest.fixture
def fixed_seed():
    """Set all relevant random seeds for a deterministic test environment."""
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Optionally force deterministic CUDA ops (slower but fully reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    yield
    # Reset benchmark after test to avoid performance impact on other tests
    torch.backends.cudnn.benchmark = True

def test_model_output_is_reproducible(fixed_seed, tiny_model):
    """With a fixed seed, regenerating the same input gives identical outputs."""
    # Seed immediately before sampling: the tiny_model fixture consumes RNG state
    # when it initialises its weights, so the fixture-level seed alone is not enough.
    torch.manual_seed(42)
    x = torch.randn(4, 32)
    out1 = tiny_model(x)
    # Reset the seed and regenerate the same input
    torch.manual_seed(42)
    x2 = torch.randn(4, 32)
    out2 = tiny_model(x2)
    assert torch.allclose(x, x2), "Inputs should match with same seed"
    assert torch.allclose(out1, out2), "Outputs should match with same seed"

For tests that intentionally exercise probabilistic behaviour — verifying that dropout actually drops neurons, or that data augmentation produces varied outputs — do not pin a single seed and check exact values. Instead, run the check across several seeds (for example via pytest's parametrize) and assert the statistical property rather than the exact outcome. A dropout test might assert that the fraction of zeros in the output is within a reasonable range of the expected dropout probability across a large batch, rather than checking the exact positions of the zero elements.
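
A sketch of such a test, parametrised over several seeds and checking the zero fraction statistically rather than the exact positions of the zeros:

import pytest
import torch
import torch.nn as nn

@pytest.mark.parametrize("seed", [0, 1, 2, 3])
def test_dropout_zero_fraction_is_near_p(seed):
    """In train mode, roughly a fraction p of activations should be zeroed."""
    torch.manual_seed(seed)
    p = 0.3
    dropout = nn.Dropout(p=p)
    dropout.train()
    x = torch.ones(10_000)                      # large sample for a stable estimate
    zero_fraction = (dropout(x) == 0).float().mean().item()
    assert abs(zero_fraction - p) < 0.03, f"Zero fraction {zero_fraction:.3f} is far from p={p}"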

Measuring What Your Tests Actually Cover

Coverage reports tell you which lines of your source code are executed by your tests. For ML code, 100% line coverage is neither achievable nor the right goal — GPU-specific code paths, rare error handling branches, and distributed training code are hard to cover in unit tests. A more practical target is 80–90% line coverage for your data processing and model code, with known gaps documented. Running pytest --cov=src --cov-report=term-missing shows exactly which lines are not covered, which often surfaces entire functions or classes that have no tests at all — the most common state for a codebase that has grown organically without a testing culture.

Starting from zero tests, the highest-value coverage to add first is: normalisation and preprocessing functions (deterministic, easy to test, high bug surface), model shape and gradient flow tests (catch the most common architecture bugs), and the overfit convergence test (catches training loop bugs). Everything else follows from there, and even this minimal set provides far more confidence than the no-test baseline that most ML codebases start from.

pytest Plugins Worth Knowing

Three pytest plugins are particularly useful for ML projects. pytest-xdist runs tests in parallel across CPU cores, which can cut unit test suite time by 4x on a typical laptop when tests are CPU-bound and independent. Run it with pytest -n auto to use all available cores. pytest-benchmark tracks the execution time of specific operations across runs and fails the test if a function regresses beyond a threshold — useful for catching performance regressions in preprocessing or inference code that does not show up in correctness tests. pytest-randomly randomises test execution order on each run, which surfaces order-dependent test failures caused by tests that share mutable global state (a common problem when models are defined as module-level globals). Using all three in development eliminates a class of testing infrastructure problems before they compound.
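
As a small sketch of pytest-benchmark usage, assuming the normalize_features function from the data tests above (in practice you record a baseline with --benchmark-autosave and then fail on regressions with flags such as --benchmark-compare and --benchmark-compare-fail=mean:10%):

import torch
from src.data import normalize_features

def test_normalize_features_speed(benchmark):
    """Track preprocessing speed across runs via the pytest-benchmark fixture."""
    x = torch.randn(10_000, 32)
    result = benchmark(normalize_features, x)   # benchmark() times repeated calls and returns the result
    assert result.shape == x.shape              # sanity check on the benchmarked output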

The broader point about ML testing culture is worth stating plainly: the cost of writing a shape test or an overfit test is ten minutes; the cost of diagnosing a silent training bug that has been present for three weeks of GPU runs is measured in days. Testing ML code is not harder than testing other software — it requires a slightly different set of patterns for handling non-determinism and statistical correctness — but the leverage is enormous because the feedback loops without tests are so much slower. The first 20 tests in an ML codebase deliver more value per hour of investment than almost anything else you can do to the development workflow.
