Mixed precision training is one of the cheapest performance improvements available in PyTorch: you change three to five lines of code and get 1.5–2x training throughput with the same GPU, often with no change in final model quality. Despite this, the details of how it works — why fp16 needs a GradScaler while bf16 does not, when each format is appropriate, and what to do when mixed precision causes NaN losses — are poorly explained in most tutorials. This article covers the mechanics of both formats, the complete implementation with torch.amp, common failure modes and how to debug them, and the specific considerations for fine-tuning large models.
fp16 vs bf16: What the Difference Actually Means
Both fp16 and bf16 use 16 bits per value compared to 32 bits for float32, halving the memory of any tensor stored in those formats. The difference is how those 16 bits (after one sign bit) are split between the exponent, which controls range, and the mantissa, which controls precision. fp16 uses 5 exponent bits and 10 mantissa bits. That gives it finer precision but a narrow range: the largest representable value is 65,504, values below about 6.1e-5 fall into the subnormal range where precision degrades, and values below about 6e-8 round to zero. bf16 uses 8 exponent bits and 7 mantissa bits. Because that is the same exponent width as float32, it covers essentially the same range as float32 (up to about 3.4e38), at the cost of coarser precision per value.
In practice this means bf16 almost never overflows or underflows during training, because its range matches float32, so gradient scaling is unnecessary. fp16 behaves differently. Small gradient values routinely underflow to zero, and large activations or gradients can overflow to infinity, particularly in large models. For this reason, fp16 training uses a GradScaler, which multiplies the loss by a scale factor before the backward pass so that gradients land inside fp16's representable range. On Ampere GPUs and newer (A100, H100, RTX 30-series and later), both formats run on Tensor Cores with similar throughput. On Volta and Turing GPUs (V100, RTX 2080), only fp16 has Tensor Core support. bf16 has no native support on these GPUs and is emulated, so it typically runs no faster than float32. If your GPU is Ampere or newer, prefer bf16 for its stability. If you are on a V100 or older, use fp16 with a GradScaler.
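The range-versus-precision trade-off can be shown without a GPU. The following pure-Python sketch round-trips values through both formats. fp16 uses the standard library's struct "e" (binary16) format. bf16 uses a hand-written helper that keeps the top 16 bits of the float32 bit pattern with round-to-nearest-even, which is assumed here to mirror how frameworks round finite values. NaN handling is omitted.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 binary16 (fp16) and back."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct raises instead of returning inf; a hardware cast would produce inf
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round a finite float to bfloat16 and back (NaN handling omitted).

    bf16 is the top 16 bits of a float32, so drop the low 16 bits using
    round-to-nearest-even on the discarded part.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# A large activation, a tiny gradient, and a value needing fine precision
for value in (70000.0, 1e-8, 1.001):
    print(f"{value:>10g}  fp16={to_fp16(value):<14g} bf16={to_bf16(value):g}")
```

In the output, 70000 overflows to inf in fp16 but becomes 70144 in bf16. The value 1e-8 flushes to zero in fp16 but survives in bf16. The value 1.001 keeps its offset in fp16 (as 1.0009765625) but rounds to exactly 1.0 in bf16, which is the precision bf16 gives up in exchange for range.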
Basic AMP Training Loop
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

# MyModel and dataloader stand in for your own model and data pipeline
model = MyModel().cuda()
criterion = nn.CrossEntropyLoss()  # any loss function works here
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Choose the dtype from the GPU architecture: bf16 on Ampere (compute capability 8.x)
# and newer, fp16 on older GPUs. Some recent PyTorch versions report bf16 as
# "supported" on older GPUs via emulation, so is_bf16_supported() alone can mislead.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# GradScaler is only needed for fp16; with enabled=False every scaler call is a pass-through
scaler = GradScaler(device='cuda', enabled=(dtype == torch.float16))

for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    # autocast runs eligible ops in the chosen dtype automatically
    with autocast(device_type='cuda', dtype=dtype):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    # fp16: scale the loss before backward to prevent gradient underflow
    # bf16: scaler.scale() returns the loss unchanged when enabled=False
    scaler.scale(loss).backward()
    # Unscale gradients before clipping so the threshold applies to the true gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # scaler.step() skips the update if gradients contain inf/NaN
    scaler.step(optimizer)
    scaler.update()  # adjusts the scale factor based on whether inf/NaN appeared
The autocast context manager decides which operations run in the lower-precision dtype and which stay in float32. Matrix multiplications and convolutions are converted, since they are where the Tensor Core speedup comes from. Precision-sensitive operations such as softmax, layer norm, and most loss functions run in float32. You do not need to cast tensors manually. autocast casts each eligible op's inputs as the op runs, and the half-precision matmul kernels themselves typically accumulate in float32.
GradScaler Internals and Tuning
GradScaler works by multiplying the loss by a scale factor (65536, i.e. 2**16, by default) before the backward pass. Every gradient is proportional to the loss, so this shifts gradient values up and out of fp16's underflow range. Before the optimizer step, the gradients are divided by the same factor, either explicitly through scaler.unscale_ or inside scaler.step, so the update itself uses the true gradients. If any gradient is infinite or NaN, the scale was too large and caused an overflow. In that case the optimizer step is skipped for that batch, and the scale is multiplied by backoff_factor (0.5 by default). After growth_interval consecutive steps without overflow (2000 by default), the scale is multiplied by growth_factor (2.0 by default). The result is a scale that hovers just below the largest value that does not overflow for the current model and learning rate.
from torch.amp import GradScaler

# Default scaler: reasonable for most models
scaler = GradScaler(device='cuda')

# Tuned scaler for models prone to instability
scaler_conservative = GradScaler(
    device='cuda',
    init_scale=2**14,      # start lower than the default 2**16
    growth_factor=2.0,     # double the scale after growth_interval clean steps (default)
    backoff_factor=0.5,    # halve the scale on overflow (default)
    growth_interval=2000,  # clean steps required before growing (default: 2000)
)

# Monitor scaler behaviour during training
for step, batch in enumerate(dataloader):
    # ... forward / backward / scaler.step / scaler.update as above ...
    if step % 100 == 0:
        print(f"Step {step} | loss scale: {scaler.get_scale():.0f}")
        # Healthy: scale stays large (1000+), with occasional halvings followed by regrowth
        # Problem: scale keeps halving toward 1.0, indicating persistent overflow
A scale that keeps halving toward 1.0 means gradients are still non-finite even at a tiny scale. That usually points to a model or optimisation problem, such as missing normalisation or an excessive learning rate, rather than a precision problem, so first confirm that the run is stable in float32. If it is, switching to bf16 often resolves the issue immediately, since bf16 shares float32's range and only overflows where float32 would. If you must use fp16 and the scaler collapses, look for operations that produce very large activations: large embedding tables with no normalisation, attention logits without temperature scaling, or residual connections whose activation magnitudes compound across many layers.
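To see both behaviours side by side, the following pure-Python sketch models the update rule described in this section. The LossScaleSim class is hypothetical and is not PyTorch's implementation. Its overflow test (largest scaled gradient above fp16's maximum of 65,504) is a simplification, and growth_interval is shortened in the first case so the pattern fits in a few steps.

```python
from dataclasses import dataclass

FP16_MAX = 65504.0


@dataclass
class LossScaleSim:
    """Pure-Python model of dynamic loss scaling (defaults match GradScaler's)."""
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    good_steps: int = 0  # consecutive steps without inf/NaN

    def update(self, found_inf: bool) -> bool:
        """Record one step's outcome; return True if that optimizer step is skipped."""
        if found_inf:
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return True
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            self.scale *= self.growth_factor
            self.good_steps = 0
        return False


# Case 1: the largest raw gradient is 1.0 (hypothetical), so a step overflows
# whenever 1.0 * scale exceeds fp16's max; the scale oscillates between 2**15 and 2**16.
healthy = LossScaleSim(growth_interval=3)
healthy_history = []
for _ in range(8):
    healthy.update(found_inf=1.0 * healthy.scale > FP16_MAX)
    healthy_history.append(healthy.scale)
print("healthy:", healthy_history)

# Case 2: gradients are non-finite at any scale (e.g. an inf produced in the forward pass)
collapsing = LossScaleSim()
skipped = sum(collapsing.update(found_inf=True) for _ in range(16))
print("collapsing: scale =", collapsing.scale, "after", skipped, "skipped steps")
```

In the first case, each skipped step is a brief backoff followed by regrowth, which is what a healthy scale trace looks like. In the second, the scale halves on every step and reaches 1.0 after 16 skipped steps.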
Mixed Precision with HuggingFace Trainer
from transformers import TrainingArguments, Trainer

# bf16 (preferred on A100/H100)
args_bf16 = TrainingArguments(
    output_dir="./output",
    bf16=True,  # enables bf16 AMP; no GradScaler needed
    fp16=False,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
)

# fp16 (for older GPUs without bf16 Tensor Core support)
args_fp16 = TrainingArguments(
    output_dir="./output",
    fp16=True,  # enables fp16 AMP; Trainer creates and manages the GradScaler
    bf16=False,
    fp16_opt_level="O1",  # Apex AMP level; only used with half_precision_backend="apex"
    per_device_train_batch_size=16,
)

# For FSDP + bf16 (multi-GPU fine-tuning of large models)
args_fsdp = TrainingArguments(
    output_dir="./output",
    bf16=True,
    fsdp="full_shard",
    fsdp_config={"backward_prefetch": "backward_pre"},
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)
What autocast Does and Does Not Cover
Understanding which operations autocast converts and which it leaves in float32 prevents a common class of bugs where engineers assume their entire model runs in fp16/bf16 but certain layers remain in float32, causing unexpected memory usage or type mismatch errors.
import torch
import torch.nn as nn
from torch.amp import autocast

# These ops run in lower precision inside autocast:
# - nn.Linear (matmul)
# - nn.Conv1d/2d/3d
# - nn.MultiheadAttention (its internal projections and matmuls)
# - torch.matmul, torch.bmm, torch.mm
# These ops run in float32 even inside autocast:
# - nn.LayerNorm, nn.GroupNorm
# - nn.Softmax
# - Loss functions (cross_entropy, mse_loss, etc.)
# - torch.exp, torch.log, torch.pow
# nn.BatchNorm is on neither list: it runs in whatever dtype its input arrives in

# Verify which dtype your tensors are in at different points
model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 10)).cuda()
x = torch.randn(4, 64, device='cuda')
with autocast(device_type='cuda', dtype=torch.float16):
    h1 = model[0](x)    # Linear -> fp16
    h2 = model[1](h1)   # LayerNorm -> fp32 (autocast runs layer_norm in float32)
    out = model[2](h2)  # Linear -> fp16 (input cast back down)
print(f"Linear out: {h1.dtype}")     # torch.float16
print(f"LayerNorm out: {h2.dtype}")  # torch.float32
print(f"Final out: {out.dtype}")     # torch.float16

# Parameters stay in fp32: only activations are cast
for name, param in model.named_parameters():
    assert param.dtype == torch.float32, f"{name} is {param.dtype}"
Parameters are always stored in float32. The autocast context only casts activations as they flow through eligible ops — it does not change the dtype of the model’s .weight tensors. This is intentional: weight updates in float32 preserve the small gradient steps that would round to zero in float16. If you need to actually store parameters in half precision for inference (to reduce loaded model size), use model.half() or model.to(torch.bfloat16) explicitly after training.
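To make the point about float32 weight updates concrete, the following pure-Python emulation applies 100 hypothetical updates of 1e-4 to a weight of 1.0. It does this once with the weight stored in fp16 and once in float32. It uses only the standard library's struct module to round values to each format, so it illustrates the rounding behaviour rather than PyTorch's optimizer code.

```python
import struct


def fp16(x: float) -> float:
    """Round a float to IEEE 754 half precision (fp16) and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp32(x: float) -> float:
    """Round a float to IEEE 754 single precision (float32) and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


weight, update = 1.0, 1e-4  # hypothetical weight and per-step lr * gradient

w16, w32 = fp16(weight), fp32(weight)
for _ in range(100):
    # fp16 spacing just below 1.0 is ~4.9e-4, so a 1e-4 step rounds straight back to 1.0
    w16 = fp16(w16 - fp16(update))
    # float32 spacing near 1.0 is ~6e-8 to 1.2e-7, so each step is kept
    w32 = fp32(w32 - fp32(update))

print(f"after 100 steps: fp16 weight = {w16}, float32 weight = {w32:.6f}")
```

The fp16 weight never moves, while the float32 weight ends close to 0.99. Keeping the stored parameters in float32 avoids exactly this silent loss of small updates.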
Mixed Precision for Inference
import copy
import time
import torch
from torch.amp import autocast

# model and inputs are assumed to come from your own code
model.eval()
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# AMP inference: same autocast context, no GradScaler needed
with torch.no_grad():
    with autocast(device_type='cuda', dtype=dtype):
        outputs = model(inputs)

# For maximum inference throughput: a separate copy with weights converted to bf16.
# nn.Module.to() converts in place and returns the same module, so copy first
# if you still need the float32 model (the benchmark below does).
model_bf16 = copy.deepcopy(model).to(torch.bfloat16)
model_bf16.eval()
with torch.no_grad():
    # No autocast needed: weights and inputs are both already bf16
    outputs = model_bf16(inputs.to(torch.bfloat16))

# Benchmark the difference
def benchmark(fn, n=100):
    with torch.no_grad():  # inference only: don't build an autograd graph
        for _ in range(10):  # warmup
            fn()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(n):
            fn()
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
    return elapsed / n * 1000  # ms per call

x = torch.randn(32, 512, device='cuda')  # assumes the model takes 512-dim inputs
fp32_ms = benchmark(lambda: model(x))
bf16_ms = benchmark(lambda: model_bf16(x.to(torch.bfloat16)))
print(f"fp32: {fp32_ms:.2f}ms | bf16: {bf16_ms:.2f}ms | speedup: {fp32_ms/bf16_ms:.2f}x")
Debugging NaN Losses with Mixed Precision
NaN losses are the most common mixed precision failure and the least informative. Nothing raises an error: the loss simply turns into NaN, and nothing points at the operation that caused it. The systematic approach is to first disable mixed precision entirely and check whether the NaN still appears in float32. If it does, the problem is not precision-related. If float32 trains cleanly, the NaN most likely comes from an fp16 overflow in a specific operation, and you need to find which one.
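import torch
from torch.amp import autocast

# Register forward hooks to find which module first produces inf/NaN.
# model, criterion, inputs and targets come from your own training code.
nan_origins = []

def make_nan_hook(name):
    def hook(module, args, output):
        # Some modules (e.g. attention blocks) return tuples; check every tensor
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        for out in outputs:
            if isinstance(out, torch.Tensor) and not torch.isfinite(out).all():
                nan_origins.append(name)
                print(f"Non-finite output in: {name}")
                break
    return hook

# Register on every module. Hooks fire as each module finishes its forward,
# so the first entry recorded is the first module to finish with a bad output.
hooks = [module.register_forward_hook(make_nan_hook(name))
         for name, module in model.named_modules()]

# Run one forward pass with AMP
with autocast(device_type='cuda', dtype=torch.float16):
    output = model(inputs)
    loss = criterion(output, targets)

# Clean up hooks
for h in hooks:
    h.remove()

if nan_origins:
    print(f"First NaN/Inf at: {nan_origins[0]}")
    # Common fixes:
    # - Add a LayerNorm before that module
    # - Clamp activations within fp16 range: torch.clamp(x, min=-1e4, max=1e4)
    # - Run that module in fp32: wrap its call in autocast(device_type='cuda', enabled=False)
    #   and cast its inputs with .float(). module.to(torch.float32) alone does not help:
    #   the parameters are already float32 and autocast still casts its matmuls to fp16.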
The most common fix once you locate the offending operation is to exempt it from autocast. Wrap it in with autocast(device_type='cuda', enabled=False) and cast its inputs to float32 with .float(), since inside a disabled region ops run in whatever dtype they receive. That operation then runs in float32 while the rest of the model stays in fp16, preserving most of the throughput benefit while avoiding the overflow. Attention layers are a frequent culprit. The QK^T logits grow with head dimension and activation magnitude and can exceed fp16's range before the softmax is applied. Fused kernels such as FlashAttention avoid this failure mode because they keep the logits in float32 registers and apply the softmax blockwise with a running maximum, rather than writing an fp16 logit matrix to memory.
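The following sketch shows that exemption pattern. ToyBlock and its sensitive layer are hypothetical stand-ins for a real model and the module your hooks flagged. On a machine without a GPU it uses CPU autocast with bfloat16, which PyTorch supports, so the dtype behaviour can be checked anywhere.

```python
import torch
import torch.nn as nn
from torch.amp import autocast


class ToyBlock(nn.Module):
    """Hypothetical module: an ordinary Linear plus one op we want kept in float32."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.sensitive = nn.Linear(dim, dim)  # stand-in for the op the NaN hooks flagged

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h = self.proj(x)  # runs in the autocast dtype
        # Leave autocast for this op only. With autocast disabled, ops run in the
        # dtype of their inputs, so h must be cast back to float32 explicitly.
        with autocast(device_type=x.device.type, enabled=False):
            out = self.sensitive(h.float())
        return h, out


# fp16 on a GPU; bf16 on CPU, where autocast supports bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = ToyBlock().to(device)
x = torch.randn(4, 16, device=device)
with autocast(device_type=device, dtype=amp_dtype):
    h, out = model(x)

print(f"proj output: {h.dtype} | exempted op output: {out.dtype}")
```

Forgetting the .float() is the usual mistake here. Without it, the exempted layer would receive a half-precision tensor and fail with a dtype mismatch against its float32 weights.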
Mixed Precision with Gradient Accumulation
Gradient accumulation, which runs several small forward-backward passes before each optimizer step, interacts with AMP in one important way. You should call scaler.unscale_(), scaler.step() and scaler.update() only at the actual optimizer step, not at every micro-step. Every micro-batch in an accumulation window must be scaled by the same factor, because the gradients add up in .grad and unscale_() divides the total by the current scale. Calling update() between micro-steps can change the scale mid-window, and calling step() at every micro-step would update the weights with partial gradients.
# Continues from the basic loop above (model, criterion, optimizer, dataloader)
accumulation_steps = 4
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
scaler = GradScaler(device='cuda', enabled=(dtype == torch.float16))

optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    inputs, targets = batch[0].cuda(), batch[1].cuda()
    with autocast(device_type='cuda', dtype=dtype):
        outputs = model(inputs)
        # Divide the loss by accumulation_steps so the summed gradient is an average
        loss = criterion(outputs, targets) / accumulation_steps
    # Gradients from each micro-batch accumulate in .grad, all at the same scale
    scaler.scale(loss).backward()
    # Unscale, step and update only at the end of each accumulation window
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
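The scale-mixing problem is easy to see with plain arithmetic. The following sketch uses hypothetical gradient values and no PyTorch. It mimics what .grad holds after two scaled backward passes, and what unscale_ recovers when it divides by the current scale: first with the scale held fixed, then with the scale doubled between micro-steps.

```python
# Hypothetical true gradients of one parameter from two micro-batches
g1, g2 = 0.003, 0.005

# Correct: one scale for the whole accumulation window
s = 2.0 ** 16
grad_buffer = g1 * s + g2 * s  # what .grad holds after two backward() calls
fixed = grad_buffer / s        # unscale_ divides by the current scale
print("scale held fixed:        ", fixed, "expected:", g1 + g2)

# Incorrect: update() ran after the first micro-step and doubled the scale
s1, s2 = 2.0 ** 16, 2.0 ** 17
grad_buffer = g1 * s1 + g2 * s2
mixed = grad_buffer / s2       # only the latest scale is known at unscale time
print("scale changed mid-window:", mixed, "expected:", g1 + g2)
```

With the scale fixed, the unscaled sum is exactly g1 + g2. With the scale changed, the first micro-batch's contribution is silently halved, and nothing in the training loop reports it.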
Memory Savings in Practice
The memory reduction from mixed precision is real but often misunderstood. The savings come primarily from activations stored during the forward pass for use in the backward pass — these are stored in fp16/bf16 and consume roughly half the memory of float32 activations. Model parameters and optimizer states remain in float32. This means the total memory savings depend on the ratio of activation memory to parameter memory, which is higher for large batch sizes and long sequences than for small batches.
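For transformer models, where attention activations grow quadratically with sequence length in standard (non-fused) attention, mixed precision combined with gradient checkpointing provides compounding memory savings. Gradient checkpointing reduces the number of stored activations, and mixed precision halves the size of each stored activation. On an A100 with 80GB VRAM, a Llama-3 8B model fine-tuned at sequence length 4096 with batch size 4 fits in memory with bf16 + gradient checkpointing, but runs out of memory in float32 with the same settings. That comparison is about activations. With torch.amp, the float32 weights, gradients, and AdamW moments of an 8B-parameter model take about 16 bytes per parameter (roughly 128 GB) on their own. A full fine-tune therefore also relies on sharding that state across GPUs (for example with FSDP) or on training far fewer parameters (for example with LoRA). The interaction between mixed precision and gradient checkpointing is why the two are almost always used together in large model fine-tuning recipes.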
One memory detail surprises engineers new to mixed precision. The GradScaler itself keeps only a scale tensor and a little bookkeeping state, which is negligible (a few KB). The real hidden cost appears in schemes that keep a float32 "master" copy of the parameters alongside a half-precision training copy for the weight update. That adds 2 bytes per parameter on top of the 4-byte float32 copy, for 6 bytes instead of 4. PyTorch's native AMP does not do this; Apex's O2 level does. With PyTorch's built-in torch.amp, which you should use for new code, the float32 parameters are the only persistent copy. autocast casts them to half precision on the fly and caches those casts only while the autocast region is active, so your savings come from activations.
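The byte counts in the last two paragraphs can be checked with back-of-envelope arithmetic. The following sketch is a rough estimator under stated assumptions. The function names, the 8-billion parameter count, and the activation count are hypothetical. It counts only weights, gradients, AdamW state, and saved activations, and ignores framework overhead, temporary buffers, and allocator fragmentation. The master-weights option models only the extra half-precision copy described above.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_bytes(n_params: int, master_weights: bool = False) -> int:
    """Bytes for the parameters themselves.

    torch.amp keeps a single float32 copy; a master-weights scheme keeps a
    half-precision training copy in addition to the float32 master copy.
    """
    per_param = BYTES_PER_ELEMENT["float32"]
    if master_weights:
        per_param += BYTES_PER_ELEMENT["float16"]
    return n_params * per_param


def amp_adamw_static_bytes(n_params: int) -> int:
    """float32 weights + float32 grads + AdamW's two float32 moments = 16 B/param."""
    return n_params * BYTES_PER_ELEMENT["float32"] * (1 + 1 + 2)


def activation_bytes(n_elements: int, dtype: str) -> int:
    """Bytes for n_elements saved activation values stored in the given dtype."""
    return n_elements * BYTES_PER_ELEMENT[dtype]


n = 8_000_000_000  # hypothetical 8B-parameter model
print(f"weights, torch.amp:      {weight_bytes(n) / 1e9:.0f} GB")
print(f"weights, master copy:    {weight_bytes(n, master_weights=True) / 1e9:.0f} GB")
print(f"weights + grads + AdamW: {amp_adamw_static_bytes(n) / 1e9:.0f} GB")

acts = 10_000_000_000  # hypothetical number of saved activation values
print(f"activations: {activation_bytes(acts, 'float32') / 1e9:.0f} GB in float32 vs "
      f"{activation_bytes(acts, 'bfloat16') / 1e9:.0f} GB in bf16")
```

Only the activation line changes when you enable AMP, which is why the savings grow with batch size and sequence length rather than with parameter count.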
Choosing Between fp16 and bf16 in Practice
The decision is almost always: use bf16 on A100/H100 and newer hardware, and fp16 on V100 and older hardware. One possible reason to use fp16 on modern hardware is a model checkpoint originally trained in fp16 that shows quality differences when evaluated in bf16. This situation is rare, and it may indicate that the model is sensitive to bf16's coarser precision (7 mantissa bits against fp16's 10). For fine-tuning and training from scratch on any Ampere or Hopper GPU, bf16 is the better default, because it offers the same Tensor Core throughput as fp16, far more headroom against overflow, and no GradScaler to manage. For inference the same rule applies: bf16 on modern hardware, fp16 on older hardware, and float32 only when you need exact numerical reproducibility or are running on hardware without half-precision Tensor Cores. Because fp16 and bf16 use the same Tensor Core paths on A100, throughput alone is not a reason to choose fp16 there.
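The hardware rule above can be written as a small helper keyed on CUDA compute capability, where Volta is 7.0, Turing 7.5, Ampere 8.x, and Hopper 9.0. This is a sketch and the function name is hypothetical. It returns a dtype name as a string so the logic can be checked without a GPU. On a CUDA machine you would pass it torch.cuda.get_device_capability() and look the name up on the torch module, as the trailing comment shows.

```python
def preferred_amp_dtype(compute_capability: tuple[int, int]) -> str:
    """Map a CUDA compute capability (major, minor) to the recommended AMP dtype.

    Ampere (8.x) and newer have bf16 Tensor Cores; Volta (7.0) and Turing (7.5)
    only accelerate fp16, so they should use fp16 with a GradScaler.
    """
    major, _minor = compute_capability
    return "bfloat16" if major >= 8 else "float16"


for name, cap in [("V100", (7, 0)), ("RTX 2080", (7, 5)), ("A100", (8, 0)), ("H100", (9, 0))]:
    print(f"{name:<9} sm_{cap[0]}{cap[1]} -> {preferred_amp_dtype(cap)}")

# Usage on a CUDA machine (not executed here):
#   import torch
#   dtype = getattr(torch, preferred_amp_dtype(torch.cuda.get_device_capability()))
```

Keying on compute capability rather than a "bf16 supported" flag keeps the decision tied to native Tensor Core support, which is what the throughput argument depends on.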
Expected Speedup and When It Does Not Materialise
The typical training throughput improvement from enabling AMP is 1.5–2x on Tensor Core-equipped GPUs for transformer models. The speedup comes almost entirely from the matrix multiplications in attention and feed-forward layers, which run significantly faster on Tensor Cores in fp16/bf16 than in float32. If you enable AMP and see little or no speedup, the most common cause is that your training loop is bottlenecked somewhere other than GPU computation. There are three usual suspects. The DataLoader may not be prefetching fast enough (CPU-bound data loading). The batches may be so small that kernel-launch and optimizer-step overhead dominate. Or the model may have too many non-Tensor-Core operations (custom CUDA kernels, element-wise operations on small tensors) relative to the matmuls that benefit from lower precision. Profile with torch.profiler to confirm where the time is actually spent before concluding that AMP is not helping. In the vast majority of transformer training scenarios, it does.
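The following is a minimal torch.profiler sketch for that check. It profiles a few AMP training steps of a hypothetical toy model on the CPU, so it runs anywhere. For a real GPU run, move the model and data to 'cuda', use device_type='cuda' in autocast, add ProfilerActivity.CUDA to the activities, and sort the table by the CUDA time column instead.

```python
import torch
import torch.nn as nn
from torch.amp import autocast
from torch.profiler import ProfilerActivity, profile

# Hypothetical toy model standing in for your real one
model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(64, 256)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        with autocast(device_type="cpu", dtype=torch.bfloat16):
            loss = model(x).float().pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Aggregate by operator; matmul-type ops (aten::addmm, aten::mm) should dominate
# if the workload is compute-bound, while data loading or small element-wise ops
# near the top suggest the bottleneck lies elsewhere.
table = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
print(table)
```

Profile a handful of steps after warmup in real runs, since the first iterations include one-off costs such as memory allocation and kernel selection.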
The short checklist for enabling mixed precision correctly has five items. Pick bf16 on A100/H100 and fp16 with a GradScaler on older hardware. Wrap the forward pass and loss computation in autocast. Call scaler.unscale_ before gradient clipping. Call scaler.step and scaler.update at the optimizer step boundary, not at every accumulation micro-step. For fp16, monitor the scale factor for the first few hundred steps to confirm it is not collapsing. Together these cover a correct implementation for the vast majority of training workloads. For most transformer training jobs on modern hardware, enabling bf16 AMP is the single highest-return-on-effort optimisation available. It should be the first change you make, before considering quantisation, operator fusion, or other more complex techniques.