Online Hard Example Mining (OHEM): How It Works and When to Use It

When training on imbalanced or long-tailed datasets, the standard minibatch gradient update is dominated by easy examples — the samples the model already classifies correctly with high confidence. These examples contribute near-zero loss and near-zero gradient, so most of each training step’s compute is spent processing information the model has already learned. Online Hard Example Mining (OHEM) addresses this by selecting only the highest-loss samples within each batch for the backward pass, concentrating gradient signal on the examples that are actually informative at that point in training. This article covers how OHEM works mechanically, when it helps, how to implement it efficiently in PyTorch, and how it compares to related techniques like focal loss and class-weighted sampling.

How OHEM Works

Standard minibatch training computes the loss for every sample in the batch and averages them before backpropagating. OHEM modifies this by performing a forward pass on the full batch to compute per-sample losses, then selecting only the top-k highest-loss samples, and backpropagating only through those selected samples. The selection is adaptive and online — “online” means it happens during training using the model’s current predictions rather than requiring a pre-analysis pass over the dataset. At each batch, the hard examples are different because the model’s parameters have updated since the previous step, so what was hard last step may now be easy, and new hard examples may have emerged.
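
Stripped of safeguards, the whole mechanism fits in a few lines — a sketch assuming (N, C) logits and integer class targets are already available; the full implementation below adds ignore_index handling and a minimum-keep floor:

import torch.nn.functional as F

per_sample = F.cross_entropy(logits, targets, reduction='none')  # (N,) per-sample losses
n_keep = max(1, int(0.5 * per_sample.numel()))                   # keep the hardest 50%
hard, _ = per_sample.topk(n_keep, largest=True)
loss = hard.mean()                                               # gradient comes only from these
loss.backward()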

The original OHEM paper (Shrivastava et al., 2016) applied this to object detection with Fast R-CNN, where the ratio of background (easy) to foreground (hard) regions was extreme — roughly 1000:1. Backpropagating through all background examples wasted compute and diluted the gradients from the informative foreground examples. By backpropagating through only the highest-loss region proposals in each image, training converged faster and to better final detection accuracy. The same principle applies to any setting with high class imbalance, hard negatives, or long-tailed class distributions.

Basic OHEM Implementation in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class OHEMLoss(nn.Module):
    """Online Hard Example Mining loss for classification tasks.
    
    Selects the top-k highest-loss samples per batch and computes the
    loss only over those samples. Gradients flow only through selected samples.
    """
    def __init__(self, keep_ratio: float = 0.5, min_kept: int = 1,
                 ignore_index: int = -100):
        """
        keep_ratio:   fraction of samples to keep (0 < keep_ratio <= 1.0)
        min_kept:     minimum number of samples to always keep
        ignore_index: target value excluded from the loss and from selection
        """
        super().__init__()
        self.keep_ratio = keep_ratio
        self.min_kept = min_kept
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        logits:  (N, C) — raw unnormalised class scores
        targets: (N,)   — integer class indices
        """
        # Compute per-sample loss without reduction
        loss_per_sample = F.cross_entropy(
            logits, targets,
            ignore_index=self.ignore_index,
            reduction='none',
        )
        # Mask out ignored samples
        valid_mask = targets != self.ignore_index
        valid_losses = loss_per_sample[valid_mask]
        n_valid = valid_mask.sum().item()

        if n_valid == 0:
            return loss_per_sample.sum() * 0  # zero loss, maintains grad_fn

        # Determine number of samples to keep
        n_keep = max(self.min_kept, int(n_valid * self.keep_ratio))
        n_keep = min(n_keep, n_valid)

        # Select top-n_keep highest-loss samples
        # Use topk — more efficient than full sort for large batches
        _, top_indices = valid_losses.topk(n_keep, largest=True, sorted=False)

        # Backpropagate only through selected samples
        hard_losses = valid_losses[top_indices]
        return hard_losses.mean()

# Usage in training loop
criterion = OHEMLoss(keep_ratio=0.5, min_kept=32)
logits = model(inputs)          # (batch, num_classes)
loss = criterion(logits, targets)
loss.backward()

The key implementation detail is using reduction='none' to get per-sample losses before selection. Only the entries gathered by valid_losses[top_indices] feed into the final mean, so during the backward pass the indexing operation routes gradient exclusively to those entries: every non-selected sample receives a zero gradient and contributes nothing to the parameter update. PyTorch's autograd handles this automatically — no manual masking of the model output is required.
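
A quick way to confirm this behaviour, using random tensors and the OHEMLoss defined above:

logits = torch.randn(8, 5, requires_grad=True)
targets = torch.randint(0, 5, (8,))
loss = OHEMLoss(keep_ratio=0.25, min_kept=1)(logits, targets)
loss.backward()
# Only the selected rows of logits.grad are non-zero: 8 * 0.25 = 2 samples here
print((logits.grad.abs().sum(dim=1) > 0).sum())  # tensor(2)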

OHEM for Semantic Segmentation

Semantic segmentation is the domain where OHEM has the most consistent benefit, because each image contributes hundreds of thousands of pixels and the vast majority are background or easy-to-classify regions. Applying OHEM at the pixel level keeps only the hardest pixels per image for the gradient update.

class SegmentationOHEMLoss(nn.Module):
    """OHEM cross-entropy loss for semantic segmentation.
    
    Operates at the pixel level: selects the top-k highest-loss pixels
    across the batch for the backward pass.
    """
    def __init__(self, keep_ratio: float = 0.5, min_kept: int = 100_000,
                 thresh: float = 0.7, ignore_index: int = 255):
        """
        thresh: only keep pixels where max predicted prob < thresh
                (i.e., the model is uncertain). Optional alternative to top-k.
        """
        super().__init__()
        self.keep_ratio = keep_ratio
        self.min_kept = min_kept
        self.thresh = thresh
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        logits:  (N, C, H, W) — segmentation logits
        targets: (N, H, W)    — integer class labels per pixel
        """
        N, C, H, W = logits.shape
        # Reshape to (N*H*W, C) and (N*H*W,)
        logits_flat  = logits.permute(0, 2, 3, 1).reshape(-1, C)
        targets_flat = targets.reshape(-1)

        # Per-pixel loss
        pixel_losses = F.cross_entropy(
            logits_flat, targets_flat,
            ignore_index=self.ignore_index,
            reduction='none',
        )
        valid_mask = targets_flat != self.ignore_index
        valid_losses = pixel_losses[valid_mask]
        n_valid = valid_mask.sum().item()

        if n_valid == 0:
            return pixel_losses.sum() * 0

        # Threshold-based selection: keep pixels where model is uncertain
        with torch.no_grad():
            probs = F.softmax(logits_flat[valid_mask], dim=-1)
            max_probs, _ = probs.max(dim=-1)
            threshold_mask = max_probs < self.thresh

        n_thresh = threshold_mask.sum().item()
        n_keep = max(self.min_kept, int(n_valid * self.keep_ratio))

        if n_thresh >= n_keep:
            # Enough uncertain pixels — use threshold-based selection
            hard_losses = valid_losses[threshold_mask]
        else:
            # Fall back to top-k if not enough uncertain pixels
            n_keep = min(n_keep, n_valid)
            _, top_idx = valid_losses.topk(n_keep, largest=True, sorted=False)
            hard_losses = valid_losses[top_idx]

        return hard_losses.mean()
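
A usage sketch, where seg_model, images, and masks are placeholders for your own pipeline — the model is assumed to output (N, C, H, W) logits and the integer masks to use 255 for ignored pixels:

seg_criterion = SegmentationOHEMLoss(keep_ratio=0.25, min_kept=100_000, thresh=0.7)
logits = seg_model(images)           # (N, C, H, W)
loss = seg_criterion(logits, masks)  # masks: (N, H, W) integer labels, 255 = ignore
loss.backward()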

OHEM vs Focal Loss: When to Use Which

OHEM and focal loss both address the easy-example problem, but through different mechanisms, and they are appropriate in different situations. Focal loss (Lin et al., 2017) multiplies the cross-entropy loss by a modulating factor (1 - p_t)^γ that smoothly down-weights easy examples (high p_t) relative to hard ones (low p_t). Every sample still contributes to the gradient, but easy samples contribute less. OHEM takes a hard-cutoff approach instead: the bottom fraction of samples contributes nothing to the gradient at all.

Focal loss is differentiable everywhere and has no hyperparameter for how many samples to keep — the modulation is continuous. This makes it simpler to use and less sensitive to the exact γ value than OHEM is to the keep_ratio. Focal loss is the better default for object detection (where it was designed) and binary classification with class imbalance. OHEM tends to be more effective for segmentation tasks with extreme imbalance, for embedding learning with very large negative pools, and for any setting where you want to explicitly control what fraction of the batch contributes gradients rather than using a soft weighting.

The two approaches can be combined: apply focal loss as the per-sample loss function and then select only the top-k by focal loss value for the backward pass. This gives you the smooth modulation of focal loss plus the hard cutoff of OHEM, which can be useful when the dataset has both extreme imbalance and many moderately hard examples that focal loss alone would still upweight too much.

class FocalOHEMLoss(nn.Module):
    """Combined Focal + OHEM loss."""
    def __init__(self, gamma: float = 2.0, keep_ratio: float = 0.5,
                 min_kept: int = 1):
        super().__init__()
        self.gamma = gamma
        self.keep_ratio = keep_ratio
        self.min_kept = min_kept

    def forward(self, logits, targets):
        # Per-sample focal loss
        ce = F.cross_entropy(logits, targets, reduction='none')
        probs = F.softmax(logits, dim=-1)
        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        focal_loss = (1 - p_t) ** self.gamma * ce

        # OHEM selection on focal losses
        n_keep = max(self.min_kept, int(focal_loss.numel() * self.keep_ratio))
        n_keep = min(n_keep, focal_loss.numel())  # guard against min_kept > batch size
        _, top_idx = focal_loss.topk(n_keep, largest=True, sorted=False)
        return focal_loss[top_idx].mean()

OHEM for Metric Learning and Embedding Training

OHEM has a natural application in contrastive and metric learning: in a batch of triplets (anchor, positive, negative), many negatives are easy — already far from the anchor in embedding space. Gradients from these easy negatives are near zero and do not improve the model. Hard negative mining selects only the negatives closest to the anchor (hardest negatives) or negatives within a distance margin (semi-hard negatives) for each anchor. The article on hard negative mining in this series covers that in depth; OHEM generalises the same principle to arbitrary loss functions beyond triplet loss by operating on loss values rather than embedding distances. For embedding models trained with InfoNCE or MultipleNegativesRankingLoss, OHEM-style selection keeps only the top fraction of pairs by their contribution to the batch loss, which in practice means keeping the pairs where the model’s current representation is most confused.
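
As a concrete illustration — a sketch, not an implementation from any particular library — the same top-k selection can be bolted onto an in-batch contrastive loss of the InfoNCE / MultipleNegativesRankingLoss family, where each anchor's positive is the matching row and every other row in the batch acts as a negative. The class name, temperature, and keep_ratio values are illustrative assumptions:

class OHEMInfoNCELoss(nn.Module):
    """In-batch contrastive (InfoNCE-style) loss with OHEM selection over anchor-positive pairs."""
    def __init__(self, temperature: float = 0.05, keep_ratio: float = 0.5, min_kept: int = 1):
        super().__init__()
        self.temperature = temperature
        self.keep_ratio = keep_ratio
        self.min_kept = min_kept

    def forward(self, anchors: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
        # anchors, positives: (B, D) embeddings; positives[i] is the positive for anchors[i],
        # all other rows in the batch serve as in-batch negatives
        anchors = F.normalize(anchors, dim=-1)
        positives = F.normalize(positives, dim=-1)
        sims = anchors @ positives.T / self.temperature          # (B, B) similarity matrix
        pair_targets = torch.arange(sims.size(0), device=sims.device)
        per_pair_loss = F.cross_entropy(sims, pair_targets, reduction='none')

        # OHEM selection: keep only the pairs the current model is most confused about
        n_keep = max(self.min_kept, int(per_pair_loss.numel() * self.keep_ratio))
        n_keep = min(n_keep, per_pair_loss.numel())
        hard_losses, _ = per_pair_loss.topk(n_keep, largest=True, sorted=False)
        return hard_losses.mean()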

Tuning the Keep Ratio

The keep_ratio hyperparameter (what fraction of each batch to retain) is the main dial to tune. Too high (0.9+) and OHEM behaves nearly identically to standard training, providing no benefit. Too low (0.1 or below) and you may discard so many samples per batch that training becomes noisy and unstable, particularly early in training when the loss values are high and uninformative. The standard starting value is 0.5 — keep the hardest 50% of each batch. For heavily imbalanced datasets where the easy class dominates, 0.25 or even 0.1 can be appropriate, but monitor training stability (loss curves should decrease smoothly, not spike). The min_kept parameter is an important safeguard: it ensures that even if most of the batch is below the keep threshold, a minimum number of samples always contribute gradient, which prevents degenerate updates on batches that happen to be unusually easy.

When OHEM Actually Helps vs When It Does Not

OHEM provides consistent benefit in three specific settings. The first is severe class imbalance where one class dominates the dataset at ratios of 10:1 or higher — medical imaging (lesion vs background), object detection (background vs foreground), fraud detection, and similar tasks. In these settings, the majority class generates so many low-loss examples per batch that the gradient signal from the minority class is numerically swamped, even with class-weighted loss. OHEM directly fixes this by discarding the easy majority-class examples entirely rather than down-weighting them.

The second setting is curriculum learning scenarios where you intentionally want the model to focus on progressively harder examples as training proceeds. OHEM is naturally adaptive: early in training when the model is poor, almost every sample is hard, so the keep ratio selection does not change the effective training set much. As training progresses and the model improves, more samples become easy and OHEM filters them out, naturally shifting training focus toward the remaining hard cases without any explicit curriculum schedule.

The third is fine-tuning a strong pretrained model on a small dataset. When the pretrained model already handles most of the dataset with high confidence, standard fine-tuning wastes most of each batch on easy examples the model handles well from pretraining. OHEM concentrates the fine-tuning signal on the samples that are genuinely novel or difficult for the pretrained model on the new task.

OHEM does not help — and can actually hurt — in two common situations. The first is when the dataset is already hard throughout, meaning the model does not improve enough during training to create an easy/hard split. In this case, OHEM’s selection degenerates to near-random selection (all samples have similar loss values) and the reduced effective batch size from discarding samples adds noise without benefit. The second is early in training on a task with noisy labels: the highest-loss samples may be mislabelled examples rather than genuinely hard ones, and OHEM will concentrate training on the noise, accelerating overfitting to label errors. If you have label noise concerns, clean or filter your dataset first before applying OHEM.

Memory and Compute Considerations

A common misconception about OHEM is that it saves compute per iteration because you backpropagate through fewer samples. In the naive single-pass implementation above it saves essentially nothing: the forward pass processes the full batch, and the backward pass still propagates full-batch tensors through the network — the discarded samples simply carry zero gradient from the loss-indexing step onwards, and dense matrix operations cost the same whether or not some rows are zero. The single-pass version changes what the gradient contains, not how much it costs to compute.

A genuinely more compute- and memory-efficient variant used in some detection frameworks splits OHEM into a no-gradient forward pass (to rank samples) followed by a gradient-enabled forward-backward pass on only the selected samples. Because the first pass runs under torch.no_grad(), it stores no activations, and the second pass backpropagates through only the selected fraction of the batch. The price is running the model twice per batch, which is only worth it when the selected fraction is small (below roughly 0.3) and the model is large enough that backward compute or activation memory is the bottleneck.

class EfficientOHEMLoss(nn.Module):
    """Two-pass OHEM: no-grad forward for selection, then grad-enabled pass on hard samples only."""
    def __init__(self, keep_ratio: float = 0.5, min_kept: int = 1,
                 criterion: nn.Module = None):
        super().__init__()
        self.keep_ratio = keep_ratio
        self.min_kept = min_kept
        # Any per-sample loss (reduction='none') can be used for ranking and training
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss(reduction='none')

    def forward(self, model, inputs, targets):
        # Pass 1: no-grad forward to rank samples by loss (no activations are stored)
        with torch.no_grad():
            logits_nograd = model(inputs)
            losses = self.criterion(logits_nograd, targets)

        n_keep = max(self.min_kept, int(losses.numel() * self.keep_ratio))
        n_keep = min(n_keep, losses.numel())  # guard against min_kept > batch size
        _, hard_idx = losses.topk(n_keep, largest=True, sorted=False)

        # Pass 2: grad-enabled forward-backward on the selected hard samples only
        hard_logits = model(inputs[hard_idx])
        hard_loss = self.criterion(hard_logits, targets[hard_idx])
        return hard_loss.mean()
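
A brief usage sketch (model, inputs, and targets come from your own training loop; note that the two-pass loss takes the model and raw inputs rather than precomputed logits):

criterion = EfficientOHEMLoss(keep_ratio=0.25, min_kept=32)
loss = criterion(model, inputs, targets)
loss.backward()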

In practice, for most workloads the single-pass implementation is simpler and good enough — the two-pass version is worth the added complexity only when you have profiled the training loop and confirmed that full-batch backward compute or activation memory is actually the bottleneck. The single-pass approach with reduction='none' and topk selection is the standard implementation used in most open-source segmentation frameworks and is the right starting point for any new OHEM integration.

Decision Framework: OHEM, Focal Loss, or Class Weighting

When facing class imbalance or hard example problems, the choice between OHEM, focal loss, and class-weighted loss comes down to your dataset characteristics and implementation constraints. Use class-weighted loss as the simplest baseline — it requires no implementation beyond passing a weight tensor to nn.CrossEntropyLoss and handles static imbalance well. Move to focal loss when you need smooth, differentiable down-weighting of easy examples without discarding them, particularly for detection tasks where the loss landscape benefit of retaining all samples outweighs the added gradient noise. Use OHEM when the imbalance is extreme enough that even focal loss leaves too many easy examples in the gradient, when you want explicit control over the effective training set size per batch, or when you are fine-tuning a strong pretrained model where most of the dataset is already easy. The combined focal-OHEM approach is worth trying when both the distribution is heavily imbalanced and there are many moderately hard examples that focal loss alone insufficiently concentrates on — but start with the simpler options first and only add complexity when benchmarking shows a measurable benefit on your validation set.
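
For reference, the class-weighted baseline really is a one-liner — a sketch using inverse-frequency weights, where class_counts is a hypothetical per-class count tensor computed from your own training set:

class_counts = torch.tensor([9500., 400., 100.])                   # hypothetical per-class counts
weights = class_counts.sum() / (len(class_counts) * class_counts)  # inverse-frequency weighting
weighted_ce = nn.CrossEntropyLoss(weight=weights)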
