AdamW vs Adafactor vs Lion: Choosing an Optimizer for LLM Training

Choosing an optimizer for LLM training involves tradeoffs that do not appear at smaller scales: optimizer state memory can rival model parameter memory, different optimizers interact differently with learning rate schedules and weight decay, and some optimizers that work well at fine-tuning scale fail to converge on pretraining workloads. AdamW remains the industry default, but Adafactor, Lion, and 8-bit Adam each address specific constraints that make AdamW impractical in some settings. Understanding the mechanics behind each optimizer makes it easier to choose correctly and diagnose convergence problems when they arise.

AdamW: The Default and Why It Works

AdamW maintains two moment estimates per parameter: a first moment (exponential moving average of gradients, analogous to momentum) and a second moment (exponential moving average of squared gradients, used for per-parameter learning rate scaling). The update rule scales each parameter’s learning rate inversely by the square root of its second moment, giving parameters with consistently large gradients a smaller effective learning rate and parameters with small or infrequent gradients a larger effective learning rate. This per-parameter adaptivity is what makes Adam converge faster than SGD on most deep learning tasks — different layers and different parameter types in a transformer have very different gradient magnitudes, and a single global learning rate is a poor fit for all of them simultaneously.
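
As a concrete reference, here is a minimal, illustrative sketch of that update for a single parameter tensor. The standalone function and its name are ours; PyTorch's built-in AdamW fuses and optimises these steps, so treat this as a reading aid rather than an implementation.

import torch

def adamw_update(p, grad, m, v, step, lr=3e-4, betas=(0.9, 0.95),
                 eps=1e-8, weight_decay=0.1):
    """One AdamW step for a single parameter tensor (illustrative sketch)."""
    beta1, beta2 = betas
    m.mul_(beta1).add_(grad, alpha=1 - beta1)             # first moment: EMA of gradients
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)    # second moment: EMA of squared gradients
    m_hat = m / (1 - beta1 ** step)                        # bias correction for early steps
    v_hat = v / (1 - beta2 ** step)
    p.mul_(1 - lr * weight_decay)                          # decoupled weight decay (the 'W' in AdamW)
    p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)        # per-parameter adaptive step
    return p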

import torch
import torch.nn as nn
from torch.optim import AdamW

def configure_adamw(model: nn.Module, lr: float = 3e-4, weight_decay: float = 0.1):
    """AdamW with correct weight decay application.
    
    Key insight: weight decay should NOT be applied to biases, LayerNorm/RMSNorm
    parameters, or embeddings — only to weight matrices. Applying it to all
    parameters (as naive implementations do) regularises the wrong things.
    """
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # No weight decay for 1D params (biases, norm weights) or embeddings
        if param.ndim <= 1 or 'bias' in name or 'norm' in name.lower() or 'embed' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {'params': decay_params,    'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    return AdamW(param_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)

# Standard LLM training hyperparameters:
# lr: 1e-4 to 3e-4 for pretraining, 1e-5 to 2e-5 for SFT fine-tuning
# betas: (0.9, 0.95) — lower beta2 than the default (0.999); the second moment adapts faster, which helps stability at large batch sizes
# weight_decay: 0.1 for pretraining, 0.0 to 0.01 for fine-tuning
# eps: 1e-8 (default) — increase to 1e-6 if you see NaN losses early in training

The 'W' in AdamW stands for decoupled weight decay — the critical difference from the original Adam optimizer. In standard L2-regularised Adam, weight decay is applied to the gradient before the adaptive scaling, meaning parameters with large second moments receive less regularisation. AdamW applies weight decay directly to the parameter after the update step, independently of the gradient scaling. Empirically, decoupled weight decay consistently improves generalisation on language modelling tasks, and it is the reason virtually every LLM training recipe uses AdamW rather than Adam. The practical implication is that you should always use AdamW over Adam for transformer training, and you should configure the parameter groups to exclude biases and normalization parameters from weight decay as shown above.
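
Reusing the names from the sketch above, the difference is only where the decay term enters (illustrative pseudocode, not a library API):

# L2-regularised Adam: decay is folded into the gradient, so it passes through
# the moment estimates and is rescaled by the adaptive denominator.
grad = grad + weight_decay * p
# ... m, v, m_hat, v_hat are then computed from this modified grad ...
p = p - lr * m_hat / (v_hat.sqrt() + eps)

# AdamW: the raw gradient drives the adaptive step; decay acts on the parameter directly.
p = p * (1 - lr * weight_decay)
p = p - lr * m_hat / (v_hat.sqrt() + eps)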

The Memory Problem: AdamW at Scale

AdamW's Achilles heel is memory. For a model with P parameters stored in bfloat16 (2 bytes each), AdamW requires 8 additional bytes per parameter for the two moment estimates stored in float32 — four times the model's own memory footprint just for optimizer state. A 7B parameter model needs roughly 14GB for parameters in bfloat16, but AdamW adds another 56GB for optimizer state, for a total of 70GB just for weights and optimizer — before activations, gradients, or the data batch. At 70B parameters, optimizer state alone exceeds 500GB. This is why memory-efficient alternatives to AdamW become necessary as model scale increases.
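
A quick back-of-the-envelope helper (ours, using decimal gigabytes) reproduces those numbers:

def optimizer_state_gb(n_params: float, bytes_per_param: float) -> float:
    """Rough optimizer state memory in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9

print(optimizer_state_gb(7e9, 8))    # AdamW, two fp32 moments:       ~56 GB
print(optimizer_state_gb(70e9, 8))   # the same at 70B parameters:    ~560 GB
print(optimizer_state_gb(7e9, 2))    # 8-bit Adam, two int8 moments:  ~14 GB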

Adafactor: Memory-Efficient Adaptive Optimization

Adafactor addresses AdamW's memory cost by factorising the second moment matrix. Instead of storing one float32 second moment per parameter, Adafactor stores row and column statistics separately for 2D weight matrices, which reduces the second moment storage from O(rows × cols) to O(rows + cols). For a 4096×4096 weight matrix, this reduces second moment memory from 64MB to 32KB — a 2000x reduction. Combined with the absence of a first moment (Adafactor tracks only second-moment statistics in its default configuration), the optimizer state shrinks from four times the model size to a small fraction of it.
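
A rough sketch of the factored bookkeeping for a single 2D weight follows. The function and variable names are ours, and it omits the pieces the real optimizer adds (update clipping, relative step sizes, handling of 1D parameters):

import torch

def adafactor_factored_update(grad, row_ema, col_ema, beta2=0.999, eps=1e-30):
    """Factored second-moment update for a 2D gradient (illustrative sketch).

    Keep EMAs of the row means and column means of grad**2 instead of a full
    (rows x cols) second moment, then reconstruct an approximate per-element
    second moment as a rank-1 outer product.
    """
    sq = grad.float().pow(2) + eps
    row_ema.mul_(beta2).add_(sq.mean(dim=-1), alpha=1 - beta2)   # shape (rows,)
    col_ema.mul_(beta2).add_(sq.mean(dim=-2), alpha=1 - beta2)   # shape (cols,)
    v_hat = torch.outer(row_ema, col_ema) / row_ema.mean()       # (rows, cols) approximation
    return grad / v_hat.sqrt()                                    # RMS-scaled update direction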

from transformers import Adafactor
import torch.nn as nn

def configure_adafactor(model: nn.Module, lr: float = None):
    """Adafactor configuration for memory-constrained LLM training.
    
    With lr=None, Adafactor uses its built-in learning rate schedule
    (1/sqrt(step)), which works well for pretraining but can be too
    aggressive for fine-tuning. For fine-tuning, set an explicit lr.
    """
    return Adafactor(
        model.parameters(),
        lr=lr,
        relative_step=(lr is None),       # use built-in schedule if no lr given
        scale_parameter=(lr is None),      # scale lr by parameter RMS
        warmup_init=(lr is None),          # warmup if using built-in schedule
        weight_decay=0.0,                  # Adafactor handles regularisation differently
        clip_threshold=1.0,                # gradient clipping built in
    )

# Memory comparison (approx, float32 optimizer state):
# Model: 7B params at bf16 = ~14GB
# AdamW optimizer state: 7B * 8 bytes = ~56GB
# Adafactor optimizer state (no 1st moment, factored 2nd moment): negligible,
#   O(rows + cols) per weight matrix (megabytes rather than gigabytes)
# 8-bit Adam optimizer state: 7B * 2 bytes = ~14GB (two moments quantised to int8)

The tradeoff with Adafactor is convergence quality and training stability. Without a first moment, Adafactor lacks the momentum that helps AdamW navigate noisy loss landscapes and escape flat regions. On large pretraining runs with many tokens, Adafactor with its default learning rate schedule converges to a loss comparable to AdamW's, but it can be less stable in the early training phase and more sensitive to batch size. For fine-tuning, use Adafactor with an explicit learning rate rather than its built-in schedule — the built-in schedule decays too aggressively for fine-tuning workloads where you typically want a small constant or cosine-decayed learning rate rather than a 1/sqrt(t) schedule. T5, FLAN-T5, and several other Google Research models were trained with Adafactor, validating its pretraining quality at scale.
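
For the fine-tuning case, a hypothetical setup with an explicit learning rate and an external cosine schedule (step counts here are placeholders) looks like this, reusing the configure_adafactor helper above:

from transformers import get_cosine_schedule_with_warmup

optimizer = configure_adafactor(model, lr=1e-5)   # explicit lr disables the built-in 1/sqrt(t) schedule
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=10_000
)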

8-Bit Adam: AdamW Quality at a Quarter of the Optimizer Memory

8-bit Adam, implemented in the bitsandbytes library, keeps AdamW's two moment estimates but stores them in 8-bit integers instead of float32. This reduces optimizer state memory from 8 bytes per parameter to 2 bytes — the same size as the model parameters in bfloat16. The quantisation is block-wise with a dynamic data type, which is more numerically stable than plain fixed-point int8 across the wide dynamic range of optimizer moments. In practice, 8-bit Adam produces training curves nearly identical to 32-bit AdamW with essentially no quality loss, and it is the simplest drop-in for reducing optimizer memory when AdamW is too large for your GPU budget.

import bitsandbytes as bnb
import torch.nn as nn

def configure_8bit_adam(model: nn.Module, lr: float = 2e-5, weight_decay: float = 0.01):
    """8-bit AdamW — same quality as AdamW, half the optimizer memory.
    
    Requires: pip install bitsandbytes
    Works on CUDA GPUs; CPU and MPS not supported.
    """
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or 'bias' in name or 'norm' in name.lower():
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return bnb.optim.AdamW8bit(
        [
            {'params': decay_params,    'weight_decay': weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ],
        lr=lr,
        betas=(0.9, 0.999),
        eps=1e-8,
    )

# Paged Adam: same as 8-bit Adam but offloads optimizer state to CPU RAM
# when GPU memory is full, then pages it back as needed. Useful for very
# large models on single-GPU setups.
paged_adam = bnb.optim.PagedAdamW8bit(model.parameters(), lr=2e-5)

Lion: Sign-Based Updates for Memory and Speed

Lion (EvoLved Sign Momentum), discovered via program search in 2023, updates parameters with the sign of an interpolation between the momentum and the current gradient rather than with the gradient magnitude. The update is: c ← β₁m + (1−β₁)g; θ ← θ − lr × (sign(c) + λθ), where λ is the decoupled weight decay; the momentum is then refreshed as m ← β₂m + (1−β₂)g. Because the sign replaces Adam's second-moment scaling, Lion stores only one moment vector (the momentum) rather than AdamW's two, halving optimizer state relative to AdamW. Lion also applies a uniform effective step size to all parameters regardless of gradient magnitude — effectively a sign-based SGD with momentum — which can converge in fewer steps but requires a smaller learning rate (roughly 10x smaller than AdamW's for equivalent convergence).

import torch
import torch.nn as nn

class Lion(torch.optim.Optimizer):
    """Lion optimizer — memory-efficient, sign-based updates."""
    def __init__(self, params, lr: float = 1e-4, betas=(0.9, 0.99), weight_decay: float = 0.0):
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            wd = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']
                # Update step: sign of interpolated momentum + current grad
                update = exp_avg.lerp(grad, 1 - beta1).sign_()
                p.add_(update * lr + p * wd * lr, alpha=-1)
                # Update momentum with higher beta2
                exp_avg.lerp_(grad, 1 - beta2)
        return loss

# Lion hyperparameters: lr ~3e-5 to 1e-4 (10x smaller than AdamW)
# betas: (0.9, 0.99) — higher beta2 than AdamW's typical (0.9, 0.95)
# weight_decay: 0.1 to 1.0 — Lion typically needs stronger weight decay than AdamW
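
When porting an existing AdamW recipe, a reasonable heuristic (following the Lion paper's guidance, not a guarantee) is to shrink the learning rate by roughly 10x and raise weight decay so that the product lr × weight_decay stays in a similar range:

adamw_lr, adamw_wd = 3e-4, 0.1                     # existing AdamW recipe
optimizer = Lion(model.parameters(),
                 lr=adamw_lr / 10,                 # ~10x smaller learning rate
                 betas=(0.9, 0.99),
                 weight_decay=adamw_wd * 10)       # stronger decay keeps lr * wd comparable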

Choosing Your Optimizer: Decision Framework

Use AdamW as your default for fine-tuning tasks at any scale. It is the most well-validated optimizer for transformer fine-tuning, and the parameter group configuration (excluding biases and norm params from weight decay) is a small one-time effort with meaningful quality benefit. Be realistic about memory, though: for full fine-tuning of a 7B model on a single 80GB GPU, the arithmetic above (roughly 14GB of bfloat16 weights, 14GB of gradients, and 56GB of float32 optimizer state) already saturates the card before activations, and a 13B model does not fit at all; that is exactly the point where the memory-saving options below come in.

Switch to 8-bit Adam when AdamW's optimizer state is pushing you to smaller batch sizes or requiring CPU offloading. It is the lowest-friction memory optimisation available — same training code, same hyperparameters, a quarter of the optimizer memory. The bitsandbytes library installs cleanly on most CUDA setups and integrates with Hugging Face Trainer via a single flag: optim="adamw_bnb_8bit" in TrainingArguments.
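
An illustrative TrainingArguments configuration (other arguments omitted):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    optim="adamw_bnb_8bit",      # Trainer constructs the bitsandbytes 8-bit AdamW internally
    learning_rate=2e-5,
    weight_decay=0.01,
    bf16=True,
)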

Use Adafactor for pretraining runs at 30B+ parameter scale where even 8-bit Adam is too large, or when training on hardware with very limited memory per accelerator. Be prepared to tune the learning rate schedule explicitly rather than relying on Adafactor's built-in schedule for fine-tuning tasks. Use Lion when you want AdamW-quality convergence with 50% less optimizer memory and you are willing to tune a smaller learning rate and stronger weight decay. Lion's quality on language modelling is competitive with AdamW, but its hyperparameter sensitivity — particularly the need for a much smaller learning rate — makes it less forgiving for practitioners who are adapting existing AdamW recipes without careful retuning.

Gradient Clipping and Its Interaction with Optimizers

Gradient clipping is an almost universal component of LLM training and interacts differently with each optimizer. The standard approach is global norm clipping: compute the global L2 norm of all gradients concatenated, and if it exceeds a threshold (typically 1.0), scale all gradients down proportionally. This prevents any single large gradient update from destabilising training without zeroing out the gradient direction entirely. For AdamW, clipping is applied to the gradients before the optimizer step, so the moment estimates accumulate the clipped gradients rather than the raw outliers; this is the standard and recommended setup. Adafactor has gradient clipping built into the optimizer via its clip_threshold parameter, which clips the update by its RMS rather than the gradient norm. For Lion, clip gradients before the step as well: the sign operation makes the update itself insensitive to outlier magnitudes, but unclipped outliers still flow into the momentum buffer, and clipping keeps them from distorting the momentum accumulation.
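
A sketch of where clipping sits in the training loop, assuming a Hugging Face-style model whose forward pass returns a .loss attribute:

import torch

def train_step(model, batch, optimizer, max_norm=1.0):
    loss = model(**batch).loss
    loss.backward()
    # Clip BEFORE optimizer.step() for AdamW and Lion. Adafactor clips its own
    # update via clip_threshold, so an external clip is usually unnecessary there.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
    return loss.item(), grad_norm.item()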

When training loss suddenly spikes after hundreds of stable steps, the gradient norm trace (relative to the clip threshold) is usually the first thing to check. A norm spike that briefly exceeds the clip threshold by 5–10x suggests the model encountered an anomalous batch or a numerically unstable region — raise the clip threshold temporarily and see whether the spike recurs, or examine the loss contribution of individual data batches around the spike. Persistent gradient norm growth that clips every step suggests the learning rate is too high or weight decay is insufficient; reducing learning rate by 2–3x usually resolves this.
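
The pre-clipping norm returned by clip_grad_norm_ in the sketch above is the signal to watch. A hypothetical monitor that flags spikes against a running average might look like:

ema_norm, spike_factor = None, 5.0
for step, batch in enumerate(dataloader):              # dataloader: your training data
    loss, grad_norm = train_step(model, batch, optimizer)
    ema_norm = grad_norm if ema_norm is None else 0.99 * ema_norm + 0.01 * grad_norm
    if grad_norm > spike_factor * ema_norm:
        print(f"step {step}: grad norm {grad_norm:.2f} vs EMA {ema_norm:.2f}, loss {loss:.3f}")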

Optimizer State Checkpointing

Optimizer state is the most frequently overlooked component of LLM training checkpoints. A complete checkpoint for resumable training must include model weights, optimizer state, the learning rate scheduler state, and the random number generator state. Omitting optimizer state and resuming from a model-only checkpoint resets the moment estimates to zero, which causes the optimizer to behave as if training is starting fresh and can produce a transient loss spike while the moments re-accumulate. For AdamW, this spike typically lasts 100–500 steps; for 8-bit Adam it behaves similarly. For Adafactor, the factored second moments accumulate faster and the spike is usually shorter but can be more pronounced if the learning rate schedule is also reset.
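
A minimal sketch of a complete, resumable checkpoint (function names are ours; Accelerate and DeepSpeed do the equivalent for sharded state):

import torch

def save_checkpoint(path, model, optimizer, scheduler, step):
    torch.save({
        'step': step,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),      # moment estimates live here
        'scheduler': scheduler.state_dict(),
        'rng_cpu': torch.get_rng_state(),
        'rng_cuda': torch.cuda.get_rng_state_all(),
    }, path)

def load_checkpoint(path, model, optimizer, scheduler):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    torch.set_rng_state(ckpt['rng_cpu'])
    torch.cuda.set_rng_state_all(ckpt['rng_cuda'])
    return ckpt['step']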

The practical consequence is that partial checkpoint saves (saving only model weights to reduce storage) are fine for inference but not for training resumption. Always save complete checkpoints — model weights, optimizer state, scheduler state — at regular intervals when running long training jobs. For very large models where optimizer state is hundreds of gigabytes, consider saving full checkpoints every N steps and model-only checkpoints at intermediate steps, ensuring you always have at least one recent complete checkpoint available for resumption. Tools like Hugging Face Accelerate and DeepSpeed handle multi-GPU optimizer state sharding and checkpointing correctly out of the box, but if you are managing checkpoints manually, verify that your checkpoint includes all required components by running a resume-from-checkpoint test before launching a long run.

Optimizer Quick Reference

AdamW: the right default for 99% of fine-tuning tasks. Configure parameter groups to exclude biases and norm parameters from weight decay. Use betas (0.9, 0.95) for pretraining and (0.9, 0.999) for fine-tuning. Weight decay 0.1 for pretraining, 0.01 or less for fine-tuning.

8-bit Adam: identical hyperparameters to AdamW, a quarter of the optimizer memory, no quality tradeoff. The best first step when AdamW runs out of GPU memory. Install bitsandbytes and pass a single flag to Hugging Face Trainer.

Adafactor: use for pretraining runs above 30B parameters where optimizer memory is the hard constraint. Use an explicit learning rate for fine-tuning; the built-in schedule is designed for pretraining. Be prepared for slightly more sensitivity to batch size and learning rate warmup.

Lion: 50% less optimizer memory than AdamW, competitive quality, but requires a roughly 10x smaller learning rate and stronger weight decay. Not a drop-in for AdamW — needs deliberate hyperparameter retuning. Most beneficial when you want near-AdamW quality with a meaningfully smaller optimizer footprint and are willing to invest in careful tuning.
