How to Use PyTorch CUDA Memory Management to Avoid OOM Errors

PyTorch’s CUDA memory allocator sits between your Python code and the GPU hardware, and many OOM errors come from misunderstanding how it works rather than from genuinely running out of memory. The allocator maintains a cache of previously allocated blocks and reuses them rather than calling cudaMalloc and cudaFree on every tensor allocation. This is necessary because cudaMalloc is slow (microseconds to milliseconds) and calling it per tensor would dominate training time. The consequence is that the per-process memory displayed by nvidia-smi reflects the total reserved memory (live tensors plus all cached blocks, plus the CUDA context itself), not the memory actually occupied by live tensors. A process can show 20GB reserved in nvidia-smi while only using 8GB for live tensors, with 12GB sitting in the cache ready for reuse. Understanding this distinction is the foundation of effective CUDA memory management.

The Memory Stats API

torch.cuda.memory_stats() returns a detailed breakdown of the allocator’s internal state. The most useful fields for diagnosing OOM errors:

import torch

# After a training step or forward pass
stats = torch.cuda.memory_stats(device=0)

# Memory currently occupied by live tensors
allocated = stats["allocated_bytes.all.current"] / 1e9
print(f"Allocated: {allocated:.2f} GB")

# Memory reserved by the allocator (allocated + cached free blocks)
reserved = stats["reserved_bytes.all.current"] / 1e9
print(f"Reserved:  {reserved:.2f} GB")

# Peak allocated since last reset
peak = stats["allocated_bytes.all.peak"] / 1e9
print(f"Peak allocated: {peak:.2f} GB")

# Number of OOM retries (allocator tried to release cache before failing)
retries = stats["num_alloc_retries"]
print(f"OOM retries: {retries}")

# Simpler summary
print(torch.cuda.memory_summary(device=0, abbreviated=True))

The num_alloc_retries field is particularly useful — a non-zero value means the allocator hit a point where it couldn’t satisfy a request from its cache, released cached memory back to CUDA, and retried. This often happens before a hard OOM and indicates you’re operating near the memory limit. If you see retries increasing during training, you’re in the danger zone and should investigate which tensors are consuming memory unexpectedly.
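One way to act on this is to compare the retry counter between steps instead of reading it once. The sketch below is a minimal illustration, not part of PyTorch: RetryMonitor and cached_free_gb are hypothetical helpers, and the sample dictionary is hand-written to stand in for the result of torch.cuda.memory_stats(device=0).

```python
class RetryMonitor:
    """Tracks growth of the allocator's retry counter between training steps."""

    def __init__(self):
        self.last_retries = 0

    def update(self, stats):
        """Return how many new retries happened since the previous call.

        `stats` is a dict shaped like the result of torch.cuda.memory_stats().
        """
        retries = stats.get("num_alloc_retries", 0)
        new = retries - self.last_retries
        self.last_retries = retries
        return new


def cached_free_gb(stats):
    """Reserved-but-unallocated memory in GB (cached blocks not backing live tensors)."""
    reserved = stats["reserved_bytes.all.current"]
    allocated = stats["allocated_bytes.all.current"]
    return (reserved - allocated) / 1e9


# Hand-written sample values standing in for torch.cuda.memory_stats(device=0)
sample = {
    "allocated_bytes.all.current": 8_000_000_000,
    "reserved_bytes.all.current": 20_000_000_000,
    "num_alloc_retries": 3,
}

monitor = RetryMonitor()
new_retries = monitor.update(sample)
if new_retries > 0:
    print(f"{new_retries} new allocator retries; cached free: {cached_free_gb(sample):.1f} GB")
```

In a real loop you would call monitor.update(torch.cuda.memory_stats()) once per step and log a warning whenever it returns a non-zero value.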

Finding Memory Leaks with Snapshots

Memory leaks in PyTorch training loops are usually caused by accidentally accumulating tensors: keeping references to intermediate activations, appending loss tensors to a list without calling .item(), or storing model outputs that retain computation graphs. The memory snapshot API captures the allocation history and lets you visualise which allocations are holding memory. Its functions are private (underscore-prefixed) and have changed between releases; the call signature used here is the one documented from PyTorch 2.1 onward:

import torch
import pickle

# Start recording allocations
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run your training loop for a few steps
# (assumes the dataloader yields (inputs, targets) pairs)
for i, (inputs, targets) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if i >= 10:
        break

# Save snapshot
snapshot = torch.cuda.memory._snapshot()
with open("memory_snapshot.pkl", "wb") as f:
    pickle.dump(snapshot, f)

torch.cuda.memory._record_memory_history(enabled=None)  # stop recording

Open https://pytorch.org/memory_viz and drag the .pkl file onto the page to get an interactive timeline of allocations, showing the size of each allocation and the Python stack trace that created it. The snapshot records sizes and stack traces rather than tensor shapes or dtypes, so the stack trace is what identifies a tensor. This makes it immediately obvious when a list is accumulating loss tensors across batches, or when a large activation from a forward pass is still alive during the next batch because something holds a reference to it.
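The snapshot is an ordinary Python dictionary, so it can also be inspected without the web viewer. The sketch below assumes the layout described in the PyTorch documentation (a "segments" list whose entries hold "blocks" with "size", "state", and "frames" keys); these are private structures that may differ between versions. largest_live_blocks is a hypothetical helper, and the snapshot here is a small hand-built stand-in rather than real output.

```python
def largest_live_blocks(snapshot, top=5):
    """Return (size_bytes, location) pairs for the biggest live blocks."""
    found = []
    for segment in snapshot.get("segments", []):
        for block in segment.get("blocks", []):
            # Skip cached free blocks; only live allocations are of interest
            if block.get("state") != "active_allocated":
                continue
            frames = block.get("frames") or []
            if frames:
                first = frames[0]  # first recorded stack frame for this block
                where = f"{first['filename']}:{first['line']} in {first['name']}"
            else:
                where = "<no stack recorded>"
            found.append((block["size"], where))
    found.sort(key=lambda item: item[0], reverse=True)
    return found[:top]


# Hand-built stand-in for the dict returned by torch.cuda.memory._snapshot()
fake_snapshot = {
    "segments": [
        {"blocks": [
            {"size": 4096, "state": "active_allocated",
             "frames": [{"filename": "train.py", "line": 42, "name": "training_step"}]},
            {"size": 8192, "state": "inactive", "frames": []},
        ]},
        {"blocks": [
            {"size": 1024, "state": "active_allocated", "frames": []},
        ]},
    ]
}

top_blocks = largest_live_blocks(fake_snapshot)
for size, where in top_blocks:
    print(f"{size:>10} bytes  {where}")
```

Running the same function on snapshots taken a few steps apart is a quick way to see whether the same call site keeps adding live blocks.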

torch.cuda.empty_cache() — What It Does and Doesn’t Do

empty_cache() releases the allocator’s cached free blocks back to CUDA, reducing the reserved memory shown by nvidia-smi. It does not free memory occupied by live tensors — it only clears the cache of blocks that are free but held by the allocator for potential reuse. Calling it does not prevent OOM errors caused by live tensor accumulation; it only helps when reserved but unused cache is the issue. The most common misuse is calling empty_cache() in a training loop expecting it to prevent OOM — if OOM is caused by gradients or activations not being freed, empty_cache() has no effect. It is legitimately useful after a large one-off operation (loading a big checkpoint, running a large batch for evaluation) where you want to release the cache before switching to a different workload that needs different memory patterns.

# Legitimate use: release cache after evaluation before resuming training
model.eval()
with torch.no_grad():
    for batch in eval_loader:
        outputs = model(batch)
        # ... compute metrics

# Release cache from large eval batches before resuming smaller training batches
torch.cuda.empty_cache()

model.train()
for batch in train_loader:
    # Training step with different memory layout
    ...

Memory Fragmentation and the expandable_segments Backend

Memory fragmentation is a less obvious cause of OOM that affects long training runs. The allocator manages memory in blocks of varying sizes, and after many allocations and frees of different sizes, the free blocks become fragmented: you have enough total free memory to satisfy a request but no single contiguous block large enough. This shows up as OOM errors during training that weren’t present early in the run, and the memory summary will show high reserved memory with low allocated memory (lots of free blocks, but fragmented). PyTorch 2.1 added an expandable_segments option to the caching allocator that substantially reduces this kind of fragmentation by using CUDA virtual memory management to grow an existing segment in place instead of allocating many separate fixed-size segments:

import os
# Enable before importing torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch

At the time of writing, the PyTorch documentation describes expandable segments as experimental and off by default, so if you see fragmentation-related OOM, enabling it explicitly is the first thing to try. Other useful PYTORCH_CUDA_ALLOC_CONF settings: max_split_size_mb stops the allocator from splitting cached blocks larger than the given size (reducing fragmentation at the cost of some memory overhead), and garbage_collection_threshold sets the fraction of GPU memory capacity in use above which the allocator starts reclaiming unused cached blocks, which avoids the more expensive release-everything-and-retry path.
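Several options can be combined in one PYTORCH_CUDA_ALLOC_CONF value as comma-separated key:value pairs. The sketch below builds such a string; build_alloc_conf is a hypothetical helper and the 0.8 threshold is an illustrative value, not a recommendation.

```python
import os


def build_alloc_conf(options):
    """Join options into the comma-separated key:value form the allocator reads."""
    return ",".join(f"{key}:{value}" for key, value in options.items())


conf = build_alloc_conf({
    "expandable_segments": True,
    "garbage_collection_threshold": 0.8,  # illustrative value
})

# Must be set before the process makes its first CUDA allocation;
# setting it before `import torch` is the safest place.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = conf
print(conf)
```

Setting the variable in the shell that launches the job has the same effect and avoids any dependence on import order.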

Pinning Down Problematic Allocations with gc and weakref

When memory snapshots reveal that unexpected tensors are keeping memory alive, the culprit is usually a Python reference cycle or a closure that captures a tensor. The standard fix is to store scalars rather than tensors (use .item() to extract values from loss tensors before storing them), to explicitly delete large tensors you no longer need, and to call gc.collect() only when you suspect a reference cycle, since a full collection is slow:

import gc

losses = []  # BAD: accumulates tensors with computation graphs
losses_scalar = []  # GOOD: accumulates Python floats

for batch in dataloader:
    loss = model(batch)

    losses_scalar.append(loss.item())  # detach from graph, copy to CPU
    # NOT: losses.append(loss)  # this keeps the entire computation graph alive

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # set_to_none=True frees gradient memory

    # Explicitly delete large intermediates if needed
    del loss

# Only call gc.collect() if you suspect reference cycles; it's slow
# gc.collect()
# torch.cuda.empty_cache()

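set_to_none=True in optimizer.zero_grad() is worth highlighting separately. It sets gradient tensors to None rather than filling them with zeros, so the gradient memory is actually released between the optimizer step and the next backward pass instead of being held as an allocation. That frees memory roughly equal to the size of the gradients during the forward pass, which for a 7B-parameter model is about 28GB in fp32 (7B parameters × 4 bytes). It has been the default since PyTorch 2.0; on older versions, pass it explicitly. Training results are unchanged, but code that reads .grad directly must handle None.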

Diagnosing OOM at Training Time

When you hit an OOM during training, recent PyTorch versions put the key figures in the exception message itself: the size of the failed allocation, the GPU’s total and free memory, how much is allocated by PyTorch, and how much is reserved by PyTorch but unallocated. Read them carefully before reaching for batch size reductions or gradient checkpointing. The key question is whether allocated memory is close to the GPU capacity (genuine memory pressure: reduce batch size or enable gradient checkpointing) or whether reserved is close to capacity but allocated is much lower (fragmentation: try expandable_segments or empty_cache before the problematic step). A third pattern is allocated memory growing monotonically across batches (memory leak: use snapshots to find what’s accumulating). Each cause has a different fix, and confusing them wastes significant debugging time.
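The three patterns can be written down as a small decision function. This is a sketch only: diagnose is a hypothetical helper, the 0.9 and 0.5 thresholds are illustrative guesses rather than PyTorch constants, and the history is assumed to hold allocated memory sampled at the same point in each step.

```python
def diagnose(allocated_gb, reserved_gb, capacity_gb, allocated_history_gb=()):
    """Map memory figures to one of the three OOM patterns.

    Thresholds (0.9, 0.5) are illustrative assumptions, not PyTorch constants.
    """
    history = list(allocated_history_gb)
    # Strictly increasing allocated memory over several steps suggests a leak
    growing = len(history) >= 3 and all(b > a for a, b in zip(history, history[1:]))
    if growing:
        return "leak"
    # Live tensors themselves nearly fill the GPU
    if allocated_gb >= 0.9 * capacity_gb:
        return "memory pressure"
    # The cache fills the GPU but much of it is not backing live tensors
    if reserved_gb >= 0.9 * capacity_gb and allocated_gb <= 0.5 * reserved_gb:
        return "fragmentation"
    return "unclear"


print(diagnose(75, 78, 80))                    # live tensors near capacity
print(diagnose(30, 76, 80))                    # large reserved, small allocated
print(diagnose(40, 50, 80, [10, 20, 30, 40]))  # allocated grows every step
```

The inputs map directly onto allocated_bytes.all.current and reserved_bytes.all.current from torch.cuda.memory_stats(), divided by 1e9.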

For models that reliably OOM on the first forward pass rather than after many steps, the issue is almost always that the batch size or sequence length pushes activation memory past what is left after the model’s static memory requirements. Estimate peak memory before training with a dry run at batch size 1 and scale up: peak memory scales roughly linearly with batch size for the activation component (the intermediate activations saved for the backward pass), while model weights and optimizer states are fixed overhead. If even batch size 1 OOMs, you need gradient checkpointing, mixed precision, or a smaller model; no memory management trick overcomes a model that simply doesn’t fit. Use torch.cuda.memory_summary() after a single forward and backward pass at batch size 1 to establish the baseline and plan accordingly.
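The linear scaling assumption turns two dry-run measurements into a batch size estimate. In the sketch below, max_batch_size is a hypothetical helper and the peak figures are made-up example numbers; in practice they would come from allocated_bytes.all.peak after dry runs at batch sizes 1 and 2.

```python
def max_batch_size(peak_b1_gb, peak_b2_gb, budget_gb):
    """Extrapolate the largest batch size that fits within budget_gb.

    Assumes peak memory = fixed overhead + batch_size * per-sample cost.
    """
    per_sample = peak_b2_gb - peak_b1_gb   # activation cost of one extra sample
    if per_sample <= 0:
        raise ValueError("peak memory did not grow with batch size; measure again")
    fixed = peak_b1_gb - per_sample        # weights, gradients, optimizer states
    return max(0, int((budget_gb - fixed) // per_sample))


# Made-up measurements: 13.5 GB at batch size 1, 15.0 GB at batch size 2,
# and a 72 GB budget (90% of an 80 GB card)
estimate = max_batch_size(13.5, 15.0, 72.0)
print(estimate)  # 40
```

Treat the result as a starting point and confirm it with a real step, since allocator rounding and fragmentation make the relationship only approximately linear.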

Mixed Precision and Its Memory Implications

Automatic mixed precision (AMP) with torch.autocast reduces activation memory by roughly 50% by computing forward passes in bfloat16 or float16 while keeping master weights and optimizer states in float32. This is the highest-leverage single change for reducing training memory — typically enabling batch sizes 1.5–2x larger — and should be the first thing you try before gradient checkpointing or other more invasive techniques. The GradScaler is required when using float16 (to prevent gradient underflow) but not when using bfloat16, which has a wider dynamic range. On Ampere GPUs (A100, A10) and newer, bfloat16 is strongly preferred over float16: it has the same exponent range as float32 so it never needs gradient scaling, and the hardware runs bfloat16 tensor cores at full throughput.
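A minimal sketch of one AMP training step follows. The tiny linear model, random data, and SGD optimizer are placeholders chosen only to keep the example self-contained. It falls back to CPU autocast when no GPU is present, and it uses torch.cuda.amp.GradScaler, which newer releases also expose as torch.amp.GradScaler.

```python
import torch
from torch import nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
# Prefer bfloat16 where the hardware supports it; fall back to float16 on older GPUs
if device == "cuda" and not torch.cuda.is_bf16_supported():
    amp_dtype = torch.float16
else:
    amp_dtype = torch.bfloat16

model = nn.Linear(16, 4).to(device)  # parameters stay in float32
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Loss scaling is only needed for float16; when disabled the scaler is a pass-through
scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))

inputs = torch.randn(8, 16, device=device)
targets = torch.randn(8, 4, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    outputs = model(inputs)                      # matmul runs in the reduced-precision dtype
    loss = F.mse_loss(outputs.float(), targets)  # loss computed in float32

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

print(outputs.dtype, model.weight.dtype)
```

The printed dtypes show the split the paragraph describes: activations in the reduced-precision type, weights still in float32.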

AMP does not reduce the memory occupied by model parameters and optimizer states, which together account for roughly 16 bytes per parameter for Adam (4 bytes fp32 weights + 4 bytes fp32 gradients + 4 bytes first moment + 4 bytes second moment). For a 7B parameter model this is ~112GB, more than an 80GB A100 or H100 can hold. This is why large model training requires quantized optimizer states (8-bit Adam from bitsandbytes reduces this to ~10 bytes per parameter), ZeRO sharding across multiple GPUs, or offloading optimizer states to CPU memory. The memory management API described in this article addresses activation and intermediate tensor memory, which is the variable component that scales with batch size and sequence length; it won’t solve the fixed overhead of large model weights.
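The per-parameter arithmetic above is easy to reproduce for other model sizes. training_state_gb is a hypothetical helper that multiplies the parameter count by the bytes held per parameter, using the byte counts quoted in the text.

```python
def training_state_gb(n_params, weight_bytes=4, grad_bytes=4, optimizer_bytes=8):
    """Fixed training state in GB: weights + gradients + optimizer moments."""
    return n_params * (weight_bytes + grad_bytes + optimizer_bytes) / 1e9


adam_fp32 = training_state_gb(7e9)                     # 4 + 4 + 8 = 16 bytes per parameter
adam_8bit = training_state_gb(7e9, optimizer_bytes=2)  # 4 + 4 + 2 = 10 bytes per parameter
print(f"fp32 Adam: {adam_fp32:.0f} GB, 8-bit Adam: {adam_8bit:.0f} GB")
```

This counts only the fixed state; activations come on top and depend on batch size and sequence length.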

Custom Memory Allocators and CUDA Streams

For advanced use cases, PyTorch exposes hooks to replace the default caching allocator with a custom one via torch.cuda.memory.CUDAPluggableAllocator. This is rarely needed in practice, since the default allocator handles the vast majority of workloads well, but it becomes relevant when you’re integrating PyTorch with another CUDA library that has its own memory manager, or when you need deterministic memory layout for multi-process IPC. More commonly useful is understanding how CUDA streams interact with memory. The caching allocator keeps a separate pool of cached blocks per stream, so a block freed by a tensor that was allocated on a side stream is cached for that stream and is not reused for allocations on the default stream. If a tensor allocated on one stream is used on another, call tensor.record_stream(other_stream) so the allocator does not hand the block out again before the other stream’s work has finished. If you use custom CUDA streams for overlapping computation and data transfer, keep the number of streams small and let tensors created on them go out of scope (or delete them) once the stream has synchronised; otherwise reserved memory grows with per-stream caches that the default stream cannot draw on.

The practical memory management workflow for a new training setup: first, run a single forward and backward pass at batch size 1 and call torch.cuda.memory_summary() to establish the baseline memory floor. Second, enable AMP and repeat — this establishes the reduced activation baseline. Third, increase batch size until you’re using 85–90% of GPU memory (leaving headroom for allocation variability). Fourth, if you need larger effective batch sizes, add gradient accumulation rather than increasing batch size further. Fifth, if the model still doesn’t fit at batch size 1, add gradient checkpointing (trading compute for activation memory) and/or ZeRO optimizer state sharding. Only resort to empty_cache calls and explicit gc.collect() if snapshots reveal a genuine accumulation problem — adding these speculatively adds overhead without addressing the root cause.

Multi-GPU Memory Considerations

On multi-GPU setups, each GPU has its own independent memory allocator and cache. Memory pressure on one GPU does not trigger cache releases on others; in the usual one-process-per-GPU setup, each process has to call empty_cache itself if needed. With DDP, all GPUs run identical model replicas and have near-identical memory usage, so an OOM on one GPU means all GPUs are at risk. With FSDP or DeepSpeed ZeRO, model parameters and optimizer states are sharded across GPUs, so individual GPU memory usage is lower, but the all-gather operations that reconstitute full parameters for each forward and backward pass create temporary memory spikes that can be 2–3x the shard size or more, depending on how the model is wrapped. Profile FSDP memory with torch.cuda.memory_summary() on each rank separately to understand the per-GPU allocation pattern, since the peak during an all-gather may be substantially higher than the steady-state shard size. The FSDP cpu_offload parameter offloads parameters and gradients to CPU between uses, dramatically reducing GPU memory at the cost of PCIe traffic and therefore training throughput.
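A back-of-envelope estimate helps set expectations before profiling. The sketch below is an assumption-laden approximation, not how FSDP accounts for memory internally: fsdp_per_gpu_gb is a hypothetical helper, it assumes 16 bytes of sharded state per parameter and a single fully gathered unit held in bf16 (2 bytes per parameter), and it ignores activations and prefetching of additional units.

```python
def fsdp_per_gpu_gb(n_params, world_size, largest_unit_params,
                    state_bytes=16, gathered_bytes=2):
    """Rough per-GPU estimate: (steady sharded state, peak with one gathered unit)."""
    steady = n_params * state_bytes / world_size / 1e9    # this rank's shard of the state
    spike = largest_unit_params * gathered_bytes / 1e9    # one wrapped unit, fully gathered
    return steady, steady + spike


# Illustrative numbers: 7B parameters on 8 GPUs, largest wrapped unit of 250M parameters
steady_gb, peak_gb = fsdp_per_gpu_gb(7e9, 8, 250e6)
print(f"steady: {steady_gb:.1f} GB, peak during all-gather: {peak_gb:.1f} GB")
```

Comparing this estimate with the measured per-rank peak shows how much of the real spike comes from activations, prefetching, or communication buffers rather than from the gathered parameters alone.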
