Fully Sharded Data Parallel (FSDP) is PyTorch's native framework for training large models across multiple GPUs. Unlike DDP, which replicates the full model on each GPU, FSDP shards model parameters, gradients, and optimizer states across all workers, letting you train models that would never fit in a single GPU's memory. Fine-tuning the 8B and 70B Llama 3 models is a common use case where FSDP becomes necessary. A 70B model in bfloat16 needs roughly 140GB just for parameters. Adding bfloat16 gradients and AdamW's two float32 moments brings the total to about 840GB, or around 560GB if the optimizer state is kept in bfloat16. Only sharding across many GPUs makes that tractable.
This guide covers the complete setup for fine-tuning Llama 3 with FSDP: environment configuration, the FSDP wrapping strategy for transformer models, memory-efficient training settings, gradient checkpointing integration, and checkpoint saving and loading.
Environment and Dependencies
FSDP is built into PyTorch (2.0+) and requires no additional installation. For Llama 3, you need access through Hugging Face (gated model, requires accepting Meta’s license) and the transformers library:
```bash
pip install torch transformers accelerate sentencepiece
huggingface-cli login  # required for gated model access
```
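FSDP training is launched with torchrun rather than python. torchrun starts one process per GPU and sets the environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that dist.init_process_group reads when it sets up the distributed process group: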
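```bash
# Single node, 8 GPUs
torchrun --nproc_per_node=8 train_fsdp.py

# Multi-node (2 nodes, 8 GPUs each): run this on node 0, and the same command
# with --node_rank=1 on node 1
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 --master_addr=<MASTER_NODE_IP> --master_port=29500 train_fsdp.py
```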
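torchrun hands each worker its identity through environment variables. The sketch below reads them and shows how, with the static `--node_rank` setup used above, a worker's global rank follows from its node rank and local rank. The helper names and the fallback defaults (for running outside torchrun) are illustrative assumptions, not part of torchrun itself.

```python
import os


def dist_env():
    """Read the per-worker variables that torchrun sets.

    The fallback defaults are assumptions for running outside torchrun.
    """
    return {
        "rank": int(os.environ.get("RANK", "0")),
        "local_rank": int(os.environ.get("LOCAL_RANK", "0")),
        "world_size": int(os.environ.get("WORLD_SIZE", "1")),
        "master_addr": os.environ.get("MASTER_ADDR", "localhost"),
        "master_port": int(os.environ.get("MASTER_PORT", "29500")),
    }


def global_rank(node_rank, nproc_per_node, local_rank):
    """Global rank under static rendezvous: each node owns a contiguous block."""
    return node_rank * nproc_per_node + local_rank


# 2 nodes x 8 GPUs: node 0 holds ranks 0-7, node 1 holds ranks 8-15
print([global_rank(1, 8, lr) for lr in range(8)])  # [8, 9, 10, 11, 12, 13, 14, 15]
```

LOCAL_RANK is what the training script passes to torch.cuda.set_device, while RANK identifies the worker across the whole job.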
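The NCCL backend handles inter-GPU communication. It performs best when GPUs communicate over NVLink within a node and InfiniBand or RoCE across nodes; it can fall back to PCIe or TCP sockets, but throughput drops sharply. On AWS, p4d.24xlarge (A100) or p5.48xlarge (H100) instances with Elastic Fabric Adapter (EFA) give the best inter-node bandwidth. Avoid V100-based instances such as p3dn.24xlarge for this setup, since V100 GPUs lack native bfloat16 support.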
FSDP Wrapping Strategy for Transformers
The most important FSDP configuration decision is the wrapping policy: which submodules to wrap as independent FSDP units. Each wrapped unit is sharded independently. In the forward pass, FSDP all-gathers a unit's full parameters just before the unit runs and frees them afterwards. In the backward pass it gathers them again, computes the gradients, and reduce-scatters those gradients so each rank keeps only its shard. Wrapping granularity trades memory against communication: finer wrapping lowers peak memory, because only a small unit is fully materialized at any moment, but it issues more (and smaller) all-gather operations.
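To make the trade-off concrete, here is a rough sketch that counts Llama 3 8B parameters from its architecture (hidden size 4096, MLP size 14336, 32 layers, 32 query heads with 8 key/value heads, a 128,256-token vocabulary, untied output head; treat these as assumptions to check against your model's config) and compares per-layer wrapping with wrapping the whole model as a single unit. The peak model is deliberately simple: each rank holds its 1/N shard of everything plus one fully gathered unit at a time, ignoring prefetching, gradients, and the root unit's special handling.

```python
def llama_unit_sizes(hidden=4096, intermediate=14336, n_layers=32,
                     n_heads=32, n_kv_heads=8, vocab=128256):
    """Parameter counts of FSDP units when each decoder layer is wrapped.

    Returns (layer_params, n_layers, root_params). Defaults assume the
    Llama 3 8B architecture with an untied output head.
    """
    head_dim = hidden // n_heads
    attn = 2 * hidden * hidden + 2 * hidden * n_kv_heads * head_dim  # q,o + k,v
    mlp = 3 * hidden * intermediate  # gate, up and down projections
    layer = attn + mlp + 2 * hidden  # plus two RMSNorm weight vectors
    # Parameters outside any decoder layer stay in the root FSDP unit:
    # token embeddings, lm_head and the final RMSNorm
    root = 2 * vocab * hidden + hidden
    return layer, n_layers, root


def peak_param_gb(unit_sizes, n_gpus, bytes_per_param=2):
    """Sharded copy of every unit plus the largest unit fully gathered."""
    total = sum(unit_sizes)
    return (total / n_gpus + max(unit_sizes)) * bytes_per_param / 1e9


layer, n_layers, root = llama_unit_sizes()
per_layer_units = [root] + [layer] * n_layers
whole_model_unit = [sum(per_layer_units)]

for name, units in [("per-layer", per_layer_units), ("whole model", whole_model_unit)]:
    print(f"{name}: {len(units)} all-gathers per forward, "
          f"~{peak_param_gb(units, n_gpus=8):.1f} GB peak parameters per GPU")
# per-layer: 33 all-gathers per forward, ~4.1 GB peak parameters per GPU
# whole model: 1 all-gathers per forward, ~18.1 GB peak parameters per GPU
```

In this estimate the largest unit under per-layer wrapping is the root (embeddings plus lm_head, about 1.05B parameters), not a decoder layer (about 218M parameters each).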
For transformer models, the standard approach is to wrap each transformer layer (LlamaDecoderLayer for Llama 3) as a separate FSDP unit. PyTorch provides transformer_auto_wrap_policy for this:
```python
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"


def setup():
    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def get_model():
    if dist.get_rank() == 0:
        # Only rank 0 loads the real pretrained weights, into CPU memory
        model = LlamaForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
    else:
        # Other ranks build the same architecture on the meta device:
        # shapes and dtypes only, no storage allocated
        config = LlamaConfig.from_pretrained(MODEL_NAME)
        with torch.device("meta"):
            model = LlamaForCausalLM(config).to(torch.bfloat16)

    # FSDP wrap policy: shard each decoder layer independently
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,  # broadcast rank 0's weights to all ranks
        # Non-zero ranks allocate uninitialized GPU storage for their meta
        # tensors; the broadcast from rank 0 then fills it in
        param_init_fn=(
            (lambda module: module.to_empty(
                device=torch.device("cuda"), recurse=False))
            if dist.get_rank() != 0
            else None
        ),
    )
    return model
```
Only rank 0 loads the real pretrained weights, and it keeps them in CPU memory. Every other rank builds the model under torch.device("meta"), which records shapes and dtypes without allocating any tensor storage. This avoids materializing a full copy of the model on every rank (or on any single GPU) before sharding. During FSDP initialization, param_init_fn allocates empty GPU storage for the meta tensors on the non-zero ranks, and sync_module_states=True broadcasts rank 0's weights into it. Rank 0 still needs enough host RAM to hold the full model: about 16GB for 8B and 140GB for 70B in bfloat16.
Sharding Strategies
FSDP offers four main sharding strategies that trade memory savings against communication volume. FULL_SHARD (the ZeRO-3 equivalent) shards parameters, gradients, and optimizer states across all workers: maximum memory savings, highest communication volume. SHARD_GRAD_OP (the ZeRO-2 equivalent) shards gradients and optimizer states but keeps the full parameters gathered from the forward pass through the backward pass: less communication, higher peak memory. NO_SHARD makes FSDP behave like DDP, which is useful for testing. HYBRID_SHARD (available in PyTorch 2.1+) shards fully within each node and replicates across nodes. It suits multi-node setups where inter-node bandwidth is the bottleneck: parameter all-gathers stay inside the node and only the gradient all-reduce crosses the slower inter-node links, while each node still gets within-node sharding for memory efficiency.
For fine-tuning Llama 3 8B on 8x A100 80GB GPUs, FULL_SHARD with bfloat16 mixed precision fits comfortably. For Llama 3 70B on the same setup, FULL_SHARD is required: SHARD_GRAD_OP would OOM because each GPU must hold the full 140GB of bfloat16 parameters during the forward and backward passes. Even FULL_SHARD only fits with bfloat16 optimizer state or more GPUs, as the memory budget below shows.
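The per-GPU arithmetic behind these recommendations can be written down directly. This is a rough sketch: it counts only parameters, gradients, and AdamW's two float32 moments (no activations, buffers, or fragmentation), assumes 8 GPUs per node for HYBRID_SHARD, and models SHARD_GRAD_OP as holding the full unsharded parameters during compute. The function is illustrative, not a PyTorch API.

```python
GB = 1e9


def per_gpu_state_gb(n_params, n_gpus, strategy, gpus_per_node=8,
                     param_bytes=2, grad_bytes=2, optim_bytes=8):
    """Approximate per-GPU GB for parameters, gradients and optimizer state."""
    p = n_params * param_bytes / GB
    g = n_params * grad_bytes / GB
    o = n_params * optim_bytes / GB  # AdamW: two float32 moments = 8 bytes/param
    if strategy == "FULL_SHARD":
        return (p + g + o) / n_gpus
    if strategy == "SHARD_GRAD_OP":
        # Full parameters stay gathered from forward through backward
        return p + (g + o) / n_gpus
    if strategy == "HYBRID_SHARD":
        # Sharded within a node, replicated across nodes
        return (p + g + o) / min(n_gpus, gpus_per_node)
    if strategy == "NO_SHARD":
        return p + g + o
    raise ValueError(f"unknown strategy: {strategy}")


for strategy in ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"]:
    print(strategy, per_gpu_state_gb(70e9, 8, strategy))
# FULL_SHARD 105.0
# SHARD_GRAD_OP 227.5
# NO_SHARD 840.0
```

For 8B on 8 GPUs the same FULL_SHARD formula gives 12GB per GPU; for 70B it lands above 80GB even with full sharding when the AdamW moments are kept in float32.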
Gradient Checkpointing
Gradient checkpointing (activation checkpointing) trades compute for memory: instead of storing intermediate activations during the forward pass, it recomputes them during the backward pass. Once FSDP has sharded the parameters, gradients, and optimizer state, activations are typically the main remaining memory consumer, and they grow linearly with batch size and sequence length. Checkpointing each decoder layer cuts activation memory to roughly one stored input tensor per layer, at the cost of about 33% more compute (one extra forward pass per step). For fine-tuning large models on long sequences, it is often necessary for fitting any reasonable batch size:
```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Apply activation checkpointing to each LlamaDecoderLayer
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    ),
    check_fn=check_fn,
)
```
Apply activation checkpointing after FSDP wrapping. CheckpointImpl.NO_REENTRANT is preferred over the older reentrant implementation: it works correctly with FSDP and avoids the reentrant version's limitations, such as requiring at least one input with requires_grad=True. After wrapping, each LlamaDecoderLayer recomputes its activations during the backward pass rather than storing them from the forward pass.
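Both sides of the checkpointing trade-off can be estimated on the back of an envelope. This sketch assumes that with per-layer checkpointing each decoder layer keeps only its bfloat16 input hidden states for the backward pass, and approximates training compute as 6 FLOPs per parameter per token (about 2 forward, 4 backward), with recomputation adding one extra forward pass. It ignores attention FLOPs, the layer currently being recomputed, and the logits tensor.

```python
def checkpointed_activation_gb(batch, seq_len, hidden, n_layers, bytes_per_elem=2):
    """Memory kept between forward and backward with per-layer checkpointing:
    one (batch, seq_len, hidden) input tensor per checkpointed layer."""
    return batch * seq_len * hidden * n_layers * bytes_per_elem / 1e9


def train_flops_per_token(n_params, checkpointing):
    """~2N forward + ~4N backward; recomputation adds another ~2N forward."""
    return (8 if checkpointing else 6) * n_params


# Llama 3 8B shapes (assumed): hidden size 4096, 32 decoder layers
print(f"{checkpointed_activation_gb(4, 2048, 4096, 32):.2f} GB kept per GPU")
overhead = train_flops_per_token(8e9, True) / train_flops_per_token(8e9, False) - 1
print(f"recompute overhead: {overhead:.0%}")
# 2.15 GB kept per GPU
# recompute overhead: 33%
```

The stored inputs are per GPU because each rank processes its own micro-batch; FSDP shards parameters, gradients, and optimizer state, not activations.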
Training Loop
The training loop with FSDP looks nearly identical to a standard DDP training loop; the sharding is transparent once the model is wrapped. The key differences are in the optimizer and in gradient clipping. Create torch.optim.AdamW after wrapping, from the FSDP-wrapped model's parameters, so it references the sharded parameters rather than the original unwrapped ones:
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Create the optimizer AFTER FSDP wrapping so it sees the sharded parameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    betas=(0.9, 0.95),
)
scaler = None  # No GradScaler needed with bfloat16

for step, batch in enumerate(dataloader):
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()
    attention_mask = batch["attention_mask"].cuda()

    with torch.autocast("cuda", dtype=torch.bfloat16):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

    loss.backward()

    # Gradient clipping with FSDP: call clip_grad_norm_ on the FSDP-wrapped model
    model.clip_grad_norm_(max_norm=1.0)

    optimizer.step()
    optimizer.zero_grad()

    # Logs rank 0's local loss only
    if dist.get_rank() == 0 and step % 10 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")
```
Two important notes. First, gradient clipping must be called as model.clip_grad_norm_(max_norm) on the FSDP-wrapped model, not torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm). With FSDP, model.parameters() on each rank returns only that rank's shard, so the latter would compute a different, partial norm on every rank and scale each shard by a different factor; FSDP's method combines the local norms across ranks into the true global norm before clipping. Second, bfloat16 doesn't need a GradScaler (unlike float16) because bfloat16 has the same exponent range as float32, so loss scaling is not required. Using bfloat16 with FSDP is the standard approach for Llama fine-tuning and avoids the extra complexity of mixed-precision gradient scaling.
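Why local clipping goes wrong can be shown without any GPUs. In this pure-Python sketch, the two-shard gradient and max_norm=1.0 are made-up illustration values. The global L2 norm is recovered from per-shard norms as the square root of the sum of their squares (essentially what FSDP all-reduces before clipping), whereas clipping each shard by its own norm scales the shards by different factors and changes the gradient's direction.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def clip(values, max_norm, total_norm, eps=1e-6):
    """Scale values by min(1, max_norm / total_norm), as clip_grad_norm_ does."""
    coef = min(1.0, max_norm / (total_norm + eps))
    return [v * coef for v in values]


# One gradient vector split across two ranks (hypothetical values)
shards = [[3.0, 0.0], [0.0, 4.0]]
max_norm = 1.0

# Correct: combine squared local norms into one global norm, then clip every shard
global_norm = math.sqrt(sum(l2_norm(s) ** 2 for s in shards))
correct = [clip(s, max_norm, global_norm) for s in shards]

# Wrong: each rank clips by the norm of its own shard only
local_only = [clip(s, max_norm, l2_norm(s)) for s in shards]

print(global_norm)                             # 5.0
print(l2_norm(correct[0] + correct[1]))        # ~1.0
print(l2_norm(local_only[0] + local_only[1]))  # ~1.414, and the direction changed
```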
Saving and Loading Checkpoints
Checkpointing with FSDP requires care because the model parameters are sharded across all ranks — you can’t simply call torch.save(model.state_dict()) from rank 0, as rank 0 only holds its shard. There are two approaches: full state dict (gather all parameters to rank 0 and save a single file) and sharded state dict (each rank saves its own shard). Full state dict is simpler and produces a checkpoint compatible with single-GPU inference; sharded state dict is faster for large models since there’s no gather overhead.
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)


def save_full_checkpoint(model, optimizer, step, path):
    """Gather the full model and optimizer state and save it from rank 0."""
    model_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, model_cfg, optim_cfg):
        # Both calls are collectives: every rank must execute them
        model_state = model.state_dict()
        # optimizer.state_dict() would hold only this rank's shard
        optim_state = FSDP.optim_state_dict(model, optimizer)
    if dist.get_rank() == 0:
        torch.save({
            "step": step,
            "model_state_dict": model_state,
            "optimizer_state_dict": optim_state,
        }, path)
    dist.barrier()  # Wait for rank 0 to finish saving


def load_checkpoint(model, optimizer, path):
    """Load a full checkpoint; assumes `path` is on storage visible to all ranks."""
    # Every rank reads the file into CPU memory; broadcasting the whole
    # checkpoint as a pickled object would stage it on each GPU
    checkpoint = torch.load(path, map_location="cpu")
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model.load_state_dict(checkpoint["model_state_dict"])
        # Re-shard the full optimizer state for this rank (PyTorch 2.1+ signature)
        optim_state = FSDP.optim_state_dict_to_load(
            model, optimizer, checkpoint["optimizer_state_dict"]
        )
    optimizer.load_state_dict(optim_state)
    return checkpoint["step"]
```
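The offload_to_cpu=True option moves gathered tensors to CPU as they are collected, preventing OOM when the full model is larger than a single GPU's memory, and rank0_only=True keeps the full copy on rank 0 alone. For 70B models even this can be tight, because rank 0 needs host RAM for the full bfloat16 weights (about 140GB) plus the full optimizer state (another 560GB with float32 AdamW moments). In that case use StateDictType.SHARDED_STATE_DICT, typically together with torch.distributed.checkpoint: each rank saves its own shard, which avoids the full gather entirely at the cost of more complex loading logic.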
Memory Budget for Common Configurations
Estimating your memory budget before starting a fine-tuning run saves debugging time. For Llama 3 8B in bfloat16 across N GPUs with FULL_SHARD, parameters take 16GB total, sharded to 16/N GB per GPU. Gradients match the parameters: 16/N GB per GPU. AdamW's optimizer state (two float32 moments) takes 64GB total, sharded to 64/N GB per GPU. Activations with gradient checkpointing depend on batch size and sequence length; for batch size 4 at 2048-token sequences, expect a few GB per GPU. On 8x A100 80GB, the total per-GPU footprint is approximately (16+16+64)/8 + activations = 12GB + activations, leaving substantial headroom for larger batch sizes. For Llama 3 70B on the same setup, the per-GPU parameter + gradient + optimizer footprint is (140+140+560)/8 = 105GB, which does not fit on 80GB GPUs. You would need to keep the optimizer states in bfloat16 (cutting the optimizer budget from 560GB to 280GB and the per-GPU total to roughly 70GB, which leaves only about 10GB for activations) or move to 16 GPUs (840/16 = 52.5GB per GPU).
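The same budget can be turned around to ask how many GPUs a configuration needs. This is a rough sketch under the section's assumptions (bfloat16 parameters and gradients, FULL_SHARD, whole 8-GPU nodes); the per-GPU activation allowance is a number you supply, and the 10GB used here is only an example.

```python
import math


def min_gpus_full_shard(n_params, gpu_mem_gb, activation_gb, gpus_per_node=8,
                        param_bytes=2, grad_bytes=2, optim_bytes=8):
    """Smallest whole-node GPU count whose FULL_SHARD state fits next to activations."""
    state_gb = n_params * (param_bytes + grad_bytes + optim_bytes) / 1e9
    budget_gb = gpu_mem_gb - activation_gb
    if budget_gb <= 0:
        raise ValueError("activation allowance leaves no room for model state")
    gpus = math.ceil(state_gb / budget_gb)
    # Round up to whole nodes
    return math.ceil(gpus / gpus_per_node) * gpus_per_node


print(min_gpus_full_shard(8e9, 80, 10))                  # 8
print(min_gpus_full_shard(70e9, 80, 10))                 # 16 (float32 AdamW moments)
print(min_gpus_full_shard(70e9, 80, 10, optim_bytes=4))  # 8 (bfloat16 moments)
```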
Practical Checklist
Before launching a long fine-tuning run with FSDP, verify these things on a short smoke test (2–4 steps). Confirm that the loss is finite and decreasing over the first few steps, which catches weight-initialization and data-pipeline bugs early. Check that the ranks really see different data: per-rank losses should differ slightly because each rank processes its own micro-batch, and identical per-rank losses usually mean every rank is reading the same examples (for example, because the DistributedSampler is missing). Verify the checkpoint save/load round trip by saving after step 1, reloading, and confirming the loss continues from where it left off. Profile memory with torch.cuda.max_memory_allocated() to make sure you have headroom before scaling to full batch sizes. Finally, FSDP's communication overhead relative to DDP is typically 15–25% within a node and can reach 50% across nodes, depending on the interconnect, so benchmark your throughput in tokens per second before committing to a long run and compare it against what your hardware's inter-GPU bandwidth should allow.
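The memory-profiling step can feed a quick extrapolation. This sketch assumes that peak memory grows roughly linearly with per-GPU batch size (fixed model state plus per-sample activations), which is only an approximation; you measure the peak at two small batch sizes, for example with torch.cuda.max_memory_allocated() after a few steps, and the function estimates the largest batch that stays under a safety margin. The function name and the 0.9 safety factor are illustrative choices.

```python
def max_batch_size(peak_gb_at, capacity_gb, safety=0.9):
    """Extrapolate the largest per-GPU batch from two smoke-test measurements.

    peak_gb_at: {batch_size: peak GB}, e.g. torch.cuda.max_memory_allocated() / 1e9
    measured after a few steps at each batch size. Assumes peak memory is
    linear in batch size: fixed state + per_sample * batch_size.
    """
    if len(peak_gb_at) < 2:
        raise ValueError("need measurements at two batch sizes")
    (b1, m1), (b2, m2) = sorted(peak_gb_at.items())[:2]
    per_sample = (m2 - m1) / (b2 - b1)
    if per_sample <= 0:
        raise ValueError("peak memory did not grow with batch size")
    fixed = m1 - per_sample * b1
    return int((capacity_gb * safety - fixed) // per_sample)


# Hypothetical smoke-test peaks: 14.0GB at batch 1, 16.0GB at batch 2
print(max_batch_size({1: 14.0, 2: 16.0}, capacity_gb=80))  # 30
```

Treat the result as an upper bound to approach gradually, since allocator fragmentation and longer sequences can break the linear assumption.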
Data Pipeline for Distributed Training
With FSDP, each GPU processes a different micro-batch — the global batch size is the per-GPU batch size multiplied by the number of GPUs. This means your DataLoader needs to distribute data correctly across ranks. PyTorch’s DistributedSampler handles this by partitioning the dataset so each rank sees a non-overlapping subset:
```python
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(
    dataset,
    num_replicas=dist.get_world_size(),
    rank=dist.get_rank(),
    shuffle=True,
    seed=42,
)
dataloader = DataLoader(
    dataset,
    batch_size=per_gpu_batch_size,
    sampler=sampler,  # don't also pass shuffle=True; the sampler shuffles
    num_workers=4,
    pin_memory=True,
)

# CRITICAL: call set_epoch at the start of each epoch. The sampler seeds its
# shuffle with seed + epoch, so without it every epoch uses the same order.
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in dataloader:
        ...
```
The sampler.set_epoch(epoch) call is easy to forget, and omitting it causes a subtle bug: the sampler reuses the same seed every epoch, so every epoch sees the data in the same order. The gradients are still correct, but the model sees the same sequence of examples each epoch, which can hurt generalization. For a single pass over an instruction dataset this hardly matters, but for longer fine-tuning runs over multiple epochs the shuffle matters.
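To see what the sampler does, here is a pure-Python re-implementation of its default index logic: shuffle with a seed of seed + epoch, pad by repeating indices until the dataset divides evenly across ranks, then give each rank every world_size-th index starting at its rank. It uses Python's random module instead of torch.randperm, so the actual orderings differ from PyTorch's; it is an illustration, not a drop-in replacement.

```python
import math
import random


def sampler_indices(dataset_len, world_size, rank, epoch, seed=42, shuffle=True):
    """Indices one rank sees in one epoch (mirrors DistributedSampler, drop_last=False)."""
    indices = list(range(dataset_len))
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)  # new order per epoch
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    # Pad by repeating indices so every rank gets the same number of samples
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided split: rank r takes positions r, r + world_size, ...
    return indices[rank:total_size:world_size]


ranks = [sampler_indices(10, 4, r, epoch=0, shuffle=False) for r in range(4)]
print(ranks)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

With 10 examples on 4 ranks, two indices are repeated so each rank gets 3 samples; calling the function with a different epoch (the role of set_epoch) produces a different shuffled split.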
For tokenization, it's worth pre-tokenizing your dataset and packing multiple examples into fixed-length sequences (typically 2048 or 4096 tokens) rather than padding each example to the maximum length. Padding wastes compute on masked tokens, while packing achieves near-100% token efficiency. The Hugging Face datasets library's Dataset.map with batched=True and num_proc=8 tokenizes large datasets in parallel, and because datasets are backed by memory-mapped Arrow files, it does so without loading everything into RAM.
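A minimal sketch of packing, assuming examples are already tokenized into lists of token IDs: concatenate them with an end-of-sequence separator and cut the stream into fixed-length blocks, dropping the short remainder. The eos_id value and the toy examples are placeholders. The second helper measures how much of a padded batch would be real tokens, which is the waste packing removes.

```python
def pack_examples(tokenized, seq_len, eos_id):
    """Concatenate token lists (each followed by eos_id) and split into seq_len blocks."""
    stream = []
    for ids in tokenized:
        stream.extend(ids)
        stream.append(eos_id)
    n_blocks = len(stream) // seq_len  # the trailing partial block is dropped
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_blocks)]


def padded_token_efficiency(tokenized, max_len):
    """Fraction of positions holding real tokens when each example is padded to max_len."""
    real = sum(min(len(ids), max_len) for ids in tokenized)
    return real / (max_len * len(tokenized))


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_examples(examples, seq_len=4, eos_id=0))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
print(padded_token_efficiency(examples, max_len=4))  # 0.75
```

With plain causal attention, tokens in a packed block can attend across example boundaries; the EOS separator marks the boundary but does not mask it.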
Choosing Between FSDP and DeepSpeed
FSDP and DeepSpeed ZeRO-3 are functionally similar: both shard parameters, gradients, and optimizer states across all workers. The practical differences come down to ecosystem integration and feature coverage. FSDP is native PyTorch, requires no additional dependencies, and works with torch.compile, the PyTorch profiler, and the rest of the PyTorch ecosystem. DeepSpeed has a longer history, more mature documentation for edge cases (very large models, offloading), and ZeRO-Infinity, which can offload parameters and optimizer state to CPU memory or NVMe for models that don't fit in GPU memory even with sharding.

For fine-tuning Llama 3 on a standard multi-GPU setup where the model fits in GPU memory with sharding, FSDP is the right choice: simpler setup, no additional dependencies, and native integration with torch.compile for training speedups. FSDP also supports CPU offloading through CPUOffload(offload_params=True), but if GPU memory is genuinely insufficient even with full sharding and you need NVMe offloading, DeepSpeed ZeRO-Infinity is worth the added complexity. The Hugging Face Accelerate library provides an abstraction layer over both and lets you switch between them with a config file change, which is useful for benchmarking both on your specific setup without rewriting training code.
Monitoring Training Progress
With multi-GPU FSDP training, logging requires care to avoid duplicate output from every rank. The standard pattern is to average the loss across all ranks to get the true global loss, then gate the actual logging behind a rank check. The all-reduce is a collective operation, so every rank must call it; calling it inside the rank-0 check would hang the job:
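```python
import torch.distributed as dist
import wandb  # assumes wandb.init(...) was called on rank 0


def reduce_loss(loss):
    """Average loss across all FSDP ranks. Collective: every rank must call it."""
    loss_tensor = loss.detach().clone()
    dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)  # AVG requires the NCCL backend
    return loss_tensor.item()


# Inside the training loop, on every rank:
global_loss = reduce_loss(loss)
if dist.get_rank() == 0:
    print(f"Step {step}: loss={global_loss:.4f}")
    # Log to wandb, mlflow, etc.
    wandb.log({"train/loss": global_loss, "step": step})
```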
Logging throughput in tokens per second makes it easy to compare configurations (different batch sizes, different numbers of GPUs, FSDP vs DDP) on equal footing on the same hardware. Calculate it as global_batch_tokens / step_time_seconds, where global_batch_tokens is per_gpu_batch_size * sequence_length * world_size. A well-configured FSDP run on 8x A100 80GB GPUs should reach roughly 2,000–4,000 tokens per second per GPU for Llama 3 8B fine-tuning at batch size 4 per GPU and 2048-token sequences, depending on whether gradient checkpointing is enabled. Enabling torch.compile on top of FSDP can improve this by 15–30% without changes to the training logic.
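The tokens-per-second formula pairs well with model FLOPs utilization (MFU) as a sanity check on whether a number is plausible for your hardware. This sketch assumes the common approximation of 6 FLOPs per parameter per token for a training step (recomputation from checkpointing is deliberately not counted in MFU) and the A100's 312 TFLOPS dense bfloat16 peak; the 4-second step time is a made-up input.

```python
def tokens_per_second(per_gpu_batch_size, seq_len, world_size, step_time_s):
    global_batch_tokens = per_gpu_batch_size * seq_len * world_size
    return global_batch_tokens / step_time_s


def mfu(tokens_per_s, n_params, n_gpus, peak_flops_per_gpu):
    """Model FLOPs utilization: achieved model FLOPs / hardware peak.
    Uses ~6 FLOPs per parameter per token and ignores attention FLOPs."""
    return tokens_per_s * 6 * n_params / (n_gpus * peak_flops_per_gpu)


A100_BF16_PEAK = 312e12  # dense bfloat16 FLOP/s per A100

tps = tokens_per_second(per_gpu_batch_size=4, seq_len=2048, world_size=8, step_time_s=4.0)
print(f"{tps:.0f} tokens/s, {tps / 8:.0f} per GPU")
print(f"MFU ~ {mfu(tps, 8e9, 8, A100_BF16_PEAK):.0%}")
# 16384 tokens/s, 2048 per GPU
# MFU ~ 32%
```

Because MFU normalizes by peak hardware throughput, it is the number to use when comparing across GPU types; tokens per second stays the simpler metric for comparing configurations on the same cluster.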