Gradient Accumulation and Gradient Checkpointing Explained

Gradient accumulation and gradient checkpointing are two of the most useful memory management techniques for training large models, and they’re frequently confused because they both involve trading something for reduced memory usage. They solve different problems: gradient accumulation lets you simulate large batch sizes on hardware that can’t fit them, while gradient checkpointing reduces activation memory at the cost of extra computation. Understanding both precisely lets you apply them correctly — and combine them when you need to.

Gradient Accumulation

The gradient of the loss with respect to model parameters is additive across samples in a batch. If you can’t fit a batch of 128 samples in GPU memory, you can run 8 forward-backward passes with batch size 16, accumulate the gradients from each pass without taking an optimizer step, and step once after all 8 passes. The optimizer sees a gradient mathematically equivalent to what you’d get from a single batch of 128; the accumulation is exact, not an approximation.

accumulation_steps = 8
optimizer.zero_grad()

for step, batch in enumerate(dataloader):
    outputs = model(**batch)
    loss = outputs.loss / accumulation_steps  # normalize by accumulation steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

The loss normalization is important and often missed. Without dividing by accumulation_steps, the accumulated gradient magnitude grows proportionally with accumulation steps, effectively scaling your learning rate. Dividing before the backward pass keeps gradient magnitude equivalent to a single large-batch forward pass.

With fp16 mixed precision training, wrap each micro-batch forward pass in autocast and only call scaler.step() and scaler.update() on the actual optimizer step, not on every micro-batch. (bf16 has the same exponent range as fp32 and doesn’t need a GradScaler; with bf16, the plain loop above works unchanged inside an autocast context.)

scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()

for step, batch in enumerate(dataloader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(**batch)
        loss = outputs.loss / accumulation_steps

    scaler.scale(loss).backward()

    if (step + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

What Gradient Accumulation Doesn’t Do

Gradient accumulation does not reduce peak memory during training. The memory required for a single forward-backward pass is determined by the micro-batch size. If your micro-batch is 16 samples, you need the activation memory for 16 samples during the backward pass regardless of how many steps you accumulate. Gradient accumulation helps when your constraint is effective batch size — too small a batch causes training instability or slow convergence — not when you’re running out of GPU memory on a single sample.
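
As a quick sanity check, you can measure peak memory and confirm it stays flat as you change the number of accumulation steps; a minimal sketch, assuming the model, optimizer, dataloader, and accumulation_steps from the loop above:

import torch

torch.cuda.reset_peak_memory_stats()

for step, batch in enumerate(dataloader):
    outputs = model(**batch)
    loss = outputs.loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        # Peak memory is set by the micro-batch size, not by accumulation_steps:
        # this number stays roughly constant whether you accumulate 1 step or 64.
        print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")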

One subtle gotcha: gradient accumulation interacts poorly with batch normalization, because BN statistics are computed over each micro-batch rather than the full effective batch, so accumulated training is not exactly equivalent to true large-batch training. Transformer models using layer norm are unaffected.

Gradient Checkpointing

During a standard backward pass, PyTorch needs the activations computed during the forward pass to calculate gradients. For a transformer with L layers, this means storing intermediate activations for every layer simultaneously — memory that scales linearly with both sequence length and depth. For a 7B parameter model on 4,096-token sequences, activation memory can easily exceed parameter memory.

Gradient checkpointing discards activations during the forward pass and recomputes them during the backward pass when needed. You store activations only at segment boundaries (the checkpoints) and recompute everything inside each segment on demand. Activation memory drops from O(L) to O(sqrt(L)) with optimal granularity, at the cost of roughly one extra forward pass, about 33% more compute, since each segment is forward-passed twice. In practice the overhead is often closer to 20–25%, because recomputed passes benefit from fused kernels and the reduced memory pressure can allow larger batch sizes.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    use_cache=False,  # must disable the KV cache with gradient checkpointing
)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

use_cache=False is required: the KV cache is unnecessary during training and incompatible with checkpointing. use_reentrant=False is recommended for PyTorch 2.0+ — the non-reentrant implementation handles nested autograd operations more robustly.

For custom modules, use torch.utils.checkpoint.checkpoint directly:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerLayer(nn.Module):
    def forward(self, x, use_ckpt=False):
        # self.attn and self.ffn are assumed to be defined in __init__
        fn = lambda x: self.attn(x) + self.ffn(x)
        return checkpoint(fn, x, use_reentrant=False) if use_ckpt else fn(x)

Combining Both

The two techniques are orthogonal and combine cleanly. Gradient checkpointing reduces the memory cost of a single forward-backward pass; gradient accumulation then lets you run many of those passes to build up to the effective batch size you need. With HuggingFace Trainer:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,        # micro-batch per GPU
    gradient_accumulation_steps=64,       # effective batch = 4 * 64 = 256 per GPU
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
)

Memory Accounting

When debugging OOM errors, it helps to know what’s actually consuming GPU memory during a training step: model parameters, optimizer states (AdamW stores two extra tensors per parameter — momentum and variance), gradients, and activations. For a 7B model in bf16: parameters are roughly 14GB. AdamW optimizer states in fp32 are roughly 56GB — usually the largest component and why 7B full fine-tuning requires 80GB+ GPUs. Gradients add another 14GB. Activations depend on batch size and sequence length.
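
A back-of-the-envelope calculation makes this concrete; a rough sketch, assuming bf16 parameters and gradients with fp32 AdamW states (setups that also keep an fp32 master copy of the weights add more on top):

params = 7e9

param_mem = params * 2        # bf16 parameters: 2 bytes each, ~14 GB
grad_mem  = params * 2        # bf16 gradients, ~14 GB
optim_mem = params * 4 * 2    # fp32 momentum + variance, ~56 GB

print(f"~{(param_mem + grad_mem + optim_mem) / 1e9:.0f} GB before activations")  # ~84 GB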

Gradient checkpointing only reduces activations. If you’re still OOM after enabling it, the next levers are reducing optimizer state memory (8-bit Adam via bitsandbytes, paged AdamW, or Adafactor), reducing micro-batch size, or switching to parameter-efficient methods like LoRA that reduce trainable parameter count and eliminate most optimizer state memory.
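
As an example of the first lever, switching to an 8-bit optimizer is a small change; a sketch assuming bitsandbytes is installed, with an illustrative learning rate:

import bitsandbytes as bnb

# AdamW8bit keeps the momentum and variance states in 8-bit blocks,
# cutting optimizer-state memory to roughly a quarter of fp32 AdamW.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-5)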

Gradient Accumulation in Distributed Training

When using DDP (DistributedDataParallel), gradient accumulation requires one additional consideration: by default, DDP synchronizes gradients across all processes after every backward call. With gradient accumulation, you only want synchronization at the actual optimizer step, not at every micro-batch — otherwise you’re paying the AllReduce communication cost accumulation_steps times per optimizer step instead of once.

PyTorch provides the no_sync() context manager specifically for this. Wrap all micro-batch backward passes except the last in model.no_sync(), and let the final micro-batch trigger the actual AllReduce:

import contextlib

# `model` is the DDP-wrapped module here; no_sync() is a DDP method.
for step, batch in enumerate(dataloader):
    is_sync_step = ((step + 1) % accumulation_steps == 0)

    ctx = contextlib.nullcontext() if is_sync_step else model.no_sync()

    with ctx:
        outputs = model(**batch)
        loss = outputs.loss / accumulation_steps
        loss.backward()

    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()

Without no_sync(), DDP gradient accumulation is functionally correct but wastes significant communication bandwidth. The overhead depends on model size and network bandwidth: on a high-bandwidth NVLink setup it may be negligible, but on multi-node training over Ethernet it can double the time per optimizer step.

HuggingFace Trainer and Accelerate handle this automatically when you configure gradient_accumulation_steps — they wrap the no_sync context for you. If you’re writing your own training loop with DDP, add no_sync explicitly.
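
For reference, this is what the Accelerate version looks like; a sketch assuming a recent Accelerate release, where accumulate() handles both loss scaling and skipping gradient synchronization on non-final micro-batches:

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):
        outputs = model(**batch)
        # accelerator.backward scales the loss by the accumulation steps;
        # optimizer.step() only takes effect on the final micro-batch.
        accelerator.backward(outputs.loss)
        optimizer.step()
        optimizer.zero_grad()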

Checkpointing Granularity Trade-offs

HuggingFace’s gradient_checkpointing_enable() applies checkpointing at the transformer layer granularity — each transformer block is a checkpoint segment. This is a sensible default but not always optimal. Coarser checkpointing (fewer segments, larger recomputed blocks) uses less memory for the checkpoint tensors themselves but requires more recomputation. Finer checkpointing (more segments, smaller recomputed blocks) reduces recomputation but increases checkpoint storage and the overhead of the checkpointing machinery itself.

For most transformer models, per-layer checkpointing hits the right balance. If you’re finding that gradient checkpointing still doesn’t reduce memory enough, you can apply checkpointing at a finer granularity — for example, checkpointing the attention and FFN sublayers separately rather than the full transformer block. This requires manually wrapping sublayers with torch.utils.checkpoint.checkpoint, but can reduce peak activation memory by another 30–40% at the cost of proportionally more recomputation.
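
A sketch of what sublayer-level wrapping can look like, assuming a pre-norm block with attn, ffn, norm1, and norm2 submodules defined elsewhere (real blocks also pass attention masks and other arguments through):

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    # norm1, attn, norm2, ffn are assumed to be defined in __init__
    def forward(self, x):
        # Each sublayer is its own checkpoint segment, so its internal
        # activations are recomputed independently during the backward pass.
        x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
        x = x + checkpoint(self.ffn, self.norm2(x), use_reentrant=False)
        return x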

The other dimension is which layers to checkpoint. Not all layers are equal memory consumers. Attention layers with long sequences have large activation tensors (the attention scores scale quadratically with sequence length). FFN layers with wide hidden dimensions have large intermediate activations. If memory is very tight and you want to minimize recomputation, selectively checkpointing only the attention layers while leaving FFN activations stored can be a useful middle ground — though this requires custom per-layer wrapping rather than the model-level enable call.

When to Use Each Technique

Use gradient accumulation when your training requires a large effective batch size for convergence stability (common in pre-training and contrastive learning), but your micro-batch is already as large as your GPU memory allows. The technique costs no extra memory; each optimizer step simply takes proportionally longer, while throughput per sample stays roughly the same.

Use gradient checkpointing when you’re running OOM during training and the culprit is activation memory — specifically when sequence lengths are long, models are deep, or you need to fit a larger micro-batch to improve GPU utilization. The 20–25% compute overhead is usually worth the 4–8x reduction in activation memory it provides.

In practice, both are enabled together in most serious fine-tuning and pre-training setups. Gradient checkpointing allows larger micro-batches given a memory budget; gradient accumulation then stacks micro-batches up to the effective batch size the optimizer needs. They’re complementary tools that address different constraints in the same training setup.

Choosing an Effective Batch Size

The right effective batch size depends on the task and optimization dynamics, not just what fits in memory. For language model fine-tuning, the common guidance is that larger batch sizes require proportionally higher learning rates (the linear scaling rule) and may converge to sharper minima — though the empirical relationship is nuanced and task-dependent. For contrastive learning (where the number of in-batch negatives matters), larger batches directly improve training signal quality. For supervised fine-tuning of LLMs on small instruction datasets, smaller effective batches often work well and let you iterate faster.

A practical approach: start with an effective batch size in the range that published work on similar tasks has found effective (often 64–256 for LLM fine-tuning), then measure validation loss trajectory and adjust. Gradient accumulation makes it cheap to experiment — doubling accumulation_steps doubles effective batch size with no infrastructure changes, just twice the wall-clock time per optimizer step. Use this flexibility to empirically find what works for your specific dataset and model rather than treating batch size as fixed.

Verifying Correctness

One way to verify that your gradient accumulation implementation is correct is to compare the parameter updates between a single large-batch step and an equivalent number of accumulated micro-batch steps. On a small model with a fixed random seed, the parameter values after one step should be nearly identical (within floating point tolerance) between the two approaches. Any significant divergence indicates a bug — usually either missing loss normalization or incorrect optimizer/scaler call placement. Running this sanity check once when setting up a new training loop is worth the few minutes it takes.
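
A self-contained sketch of that check on a toy linear model; the sizes, learning rate, and tolerance are illustrative:

import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
model_a = nn.Linear(32, 4)                    # large-batch reference
model_b = copy.deepcopy(model_a)              # accumulated copy
opt_a = torch.optim.AdamW(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)

big_batch = torch.randn(8, 32)

# One step on the full batch (mean loss over all 8 samples).
model_a(big_batch).mean().backward()
opt_a.step()

# Four accumulated micro-batches of 2, each loss divided by 4.
for micro_batch in big_batch.chunk(4):
    (model_b(micro_batch).mean() / 4).backward()
opt_b.step()

# Parameters should agree to floating point tolerance.
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(p_a, p_b, atol=1e-6)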

Similarly for gradient checkpointing: verify that loss values and gradient norms match between checkpointed and non-checkpointed runs on the same inputs before committing to checkpointing in a long training run. Numerical differences should be within typical floating point variance. Larger differences can indicate incompatibilities between checkpointing and custom layers — custom attention implementations, for example, sometimes require special handling to be compatible with the non-reentrant checkpoint API.
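
A sketch of the checkpointing comparison, assuming a HuggingFace-style model and a fixed batch; the seed reset keeps dropout masks identical between the two runs:

import torch

torch.manual_seed(0)
loss_ref = model(**batch).loss
loss_ref.backward()
norm_ref = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
model.zero_grad()

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
torch.manual_seed(0)
loss_ckpt = model(**batch).loss
loss_ckpt.backward()
norm_ckpt = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))

# Both pairs should agree to within floating point variance.
print(loss_ref.item(), loss_ckpt.item())
print(norm_ref.item(), norm_ckpt.item())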

Selective Checkpointing Strategies

Not all transformer layers are equal candidates for gradient checkpointing. The layers that consume the most activation memory are those with large intermediate tensors: attention layers on long sequences (where the attention score matrix is seq_len × seq_len) and FFN layers with wide intermediate projections (typically 4× the hidden dimension). Layers near the beginning of the network also hold their activations the longest, since the backward pass only reaches them at the very end.

A selective checkpointing strategy applies checkpointing only to the highest-memory layers rather than uniformly to all layers. For a 32-layer model where only 16 layers are checkpointed (every other layer, for example, or the layers with the largest activations), the recomputation overhead is roughly half that of full checkpointing, in exchange for a correspondingly smaller but still substantial memory saving. PyTorch’s checkpoint API makes this straightforward: you apply the checkpoint wrapper selectively per layer rather than using the model-level enable call, as sketched below.
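
A minimal sketch of the selective wrapping, assuming a plain list of transformer blocks that take only hidden states (real blocks usually also take attention masks and position ids that need to be passed through):

from torch.utils.checkpoint import checkpoint

def run_layers(layers, hidden_states, checkpoint_every=2):
    for idx, layer in enumerate(layers):
        if idx % checkpoint_every == 0:
            # Checkpointed layers recompute their activations during backward.
            hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            # The remaining layers keep their activations stored as usual.
            hidden_states = layer(hidden_states)
    return hidden_states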

For very long sequences where attention activation memory dominates (the attention score tensor grows as seq_len^2), checkpointing the attention sublayers while leaving the FFN activations stored can be a useful pattern. The attention score tensor for a 32K sequence in bf16 is 32768 × 32768 × 2 bytes ≈ 2GB per attention head per sample (assuming the score matrix is actually materialized; fused kernels such as FlashAttention avoid storing it), which multiplied across heads easily dominates activation memory. Checkpointing attention specifically eliminates this cost while the smaller FFN activations are retained, reducing total recomputation overhead compared to full layer checkpointing.

Training Stability with Gradient Accumulation

Large effective batch sizes affect optimization dynamics in ways that go beyond memory. The sharpness of loss landscape minima tends to correlate with batch size — larger batches converge to sharper minima with worse generalization in some settings (though this relationship is debated and task-dependent). The practical implication is that simply scaling effective batch size with gradient accumulation without adjusting learning rate and warmup schedule can produce suboptimal results even when the training loss converges.

The linear scaling rule, originally formulated for image classification but widely applied in LLM fine-tuning, suggests scaling learning rate proportionally with effective batch size: if you double the effective batch from 128 to 256, double the peak learning rate. This keeps the per-sample learning rate signal consistent. In practice, strict linear scaling often overshoots for large batch sizes, and a square-root scaling rule (scale LR by sqrt(batch_size_ratio)) is more conservative and often more stable.

Warmup schedule length should also scale with effective batch size. The warmup period is designed to let the optimizer accumulate reliable gradient statistics before taking large steps. With a larger effective batch, each optimizer step covers more data and gradient estimates are more reliable sooner — but the learning rate is also higher, making early instability more damaging. A common heuristic is to keep the warmup duration in terms of effective samples (not steps) constant when changing batch size, which means fewer warmup steps at larger batch sizes but the same amount of data seen during warmup.
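
A small worked example of the scaling heuristics; the base values are arbitrary:

import math

base_lr = 2e-5
base_batch = 128
warmup_samples = 128 * 500            # warmup budget measured in samples

new_batch = 256
ratio = new_batch / base_batch

linear_lr = base_lr * ratio           # linear scaling rule: 4e-5
sqrt_lr = base_lr * math.sqrt(ratio)  # square-root scaling: ~2.8e-5

# Keep warmup constant in samples seen, so fewer steps at the larger batch.
warmup_steps = warmup_samples // new_batch   # 250 steps instead of 500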

Profiling Memory During Training

When debugging OOM errors or trying to understand where memory is going, PyTorch’s memory profiler provides precise attribution. torch.cuda.memory_summary() gives a snapshot of current allocation and peak usage. For a step-by-step breakdown, the PyTorch memory profiler can record every allocation and free during a training step and produce a timeline showing exactly when peak memory occurs and which operations are responsible.

The most useful diagnostic is often the difference between allocated memory and reserved memory. PyTorch’s CUDA caching allocator reserves memory from the driver in large chunks and sub-allocates from those chunks to avoid frequent system calls. reserved_memory – allocated_memory is memory that PyTorch holds in its cache but isn’t currently using; when new allocations can’t be served from those cached blocks, that gap is effectively fragmentation. If reserved memory is much larger than allocated memory and you’re close to OOM, you can call torch.cuda.empty_cache() to release reserved-but-unused memory back to the driver (though this doesn’t free memory backing live tensors, and the allocator will re-request it on the next large allocation). Persistent large gaps between allocated and reserved memory suggest fragmentation that can be mitigated by changing allocation patterns, for example allocating all large tensors before small ones.
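
A sketch of the basic diagnostics around a single training step; the print formatting is illustrative:

import torch

torch.cuda.reset_peak_memory_stats()

# ... run one forward-backward-step cycle here ...

allocated = torch.cuda.memory_allocated()   # bytes backing live tensors
reserved = torch.cuda.memory_reserved()     # bytes held by the caching allocator
peak = torch.cuda.max_memory_allocated()    # high-water mark since the reset

print(f"allocated {allocated / 1e9:.2f} GB, "
      f"reserved {reserved / 1e9:.2f} GB, "
      f"peak {peak / 1e9:.2f} GB")
print(torch.cuda.memory_summary(abbreviated=True))  # per-pool allocator breakdown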
