PyTorch Lightning Fabric is the lightweight core of the Lightning ecosystem — it handles device placement, mixed precision, distributed training strategy, and gradient synchronisation, while leaving you in full control of your training loop. Unlike full PyTorch Lightning, which wraps your entire training logic in a LightningModule, Fabric is a thin wrapper you drop into an existing training loop with minimal changes. This article covers how Fabric works, how to migrate a plain PyTorch training loop to Fabric, and how to use it for multi-GPU and multi-node training without rewriting your model code.
What Fabric Handles vs What You Keep
In a standard PyTorch DDP setup, you write boilerplate for: initialising the process group, wrapping your model in DistributedDataParallel, configuring a DistributedSampler for your dataloader, moving tensors to the right device, setting up a GradScaler for mixed precision, and tearing everything down at the end. Fabric replaces all of that boilerplate with a single Fabric object while keeping your forward pass, loss computation, and optimizer step exactly as you wrote them. The mental model is: Fabric owns the infrastructure layer, you own the training logic.
Migrating a Plain PyTorch Loop to Fabric
```python
# BEFORE: plain PyTorch single-GPU training loop
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for epoch in range(10):
    for batch in train_loader:
        inputs, targets = batch
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = nn.CrossEntropyLoss()(outputs, targets)
        loss.backward()
        optimizer.step()
```
```python
# AFTER: the same loop with Fabric; only the setup lines and the backward call change
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lightning.fabric import Fabric

# Configure accelerator, device count, strategy, and precision here, not in the loop
fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp",
                precision="bf16-mixed")
fabric.launch()

model = MyModel()  # no .cuda(): Fabric handles device placement
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# setup() moves the model to its device, wraps it in DDP, and wraps the optimizer
# so that optimizer.step() goes through Fabric's strategy and precision plugins
model, optimizer = fabric.setup(model, optimizer)
# setup_dataloaders() adds a DistributedSampler and moves each batch to the device
train_loader = fabric.setup_dataloaders(train_loader)

for epoch in range(10):
    for batch in train_loader:
        inputs, targets = batch  # already on the correct device
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = nn.CrossEntropyLoss()(outputs, targets)
        fabric.backward(loss)  # replaces loss.backward(); applies loss scaling for "16-mixed"
        optimizer.step()
```
The key changes are as follows:

1. Create a Fabric instance with your desired configuration.
2. Call fabric.launch() to start the worker processes and initialise the distributed environment.
3. Pass your model and optimizer through fabric.setup().
4. Pass your dataloader through fabric.setup_dataloaders().
5. Replace loss.backward() with fabric.backward(loss).
6. Delete the manual .cuda() calls, because Fabric now handles device placement.

Everything else stays exactly the same: the forward pass, loss computation, optimizer step, and logging.
Switching Between Strategies Without Code Changes
The main practical advantage of Fabric over a raw DDP/FSDP setup is that you can switch training strategies by changing the arguments to the Fabric constructor. You do not have to rewrite the distributed setup code. The same training loop runs on a single GPU, on multiple GPUs with DDP, or on multiple GPUs with FSDP. Only the strategy argument changes: either a string such as "ddp", or a configured strategy object such as FSDPStrategy.
```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy
import torch.nn as nn

# Alternative configurations: create only ONE of these in a given script

# Single GPU: no strategy needed
fabric_single = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")

# Multi-GPU DDP: same loop, 4 GPUs
fabric_ddp = Fabric(accelerator="cuda", devices=4, strategy="ddp",
                    precision="bf16-mixed")

# Multi-GPU FSDP: for models too large for DDP
fsdp_strategy = FSDPStrategy(
    auto_wrap_policy={nn.TransformerDecoderLayer},  # wrap each decoder layer
    activation_checkpointing_policy={nn.TransformerDecoderLayer},
    cpu_offload=False,
)
fabric_fsdp = Fabric(accelerator="cuda", devices=4, strategy=fsdp_strategy,
                     precision="bf16-mixed")

# Multi-node: 2 nodes x 4 GPUs = 8 GPUs total
fabric_multinode = Fabric(accelerator="cuda", devices=4, num_nodes=2,
                          strategy="ddp", precision="bf16-mixed")

# All of them drive the SAME training loop; only the Fabric constructor changes
```
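Only the constructor changes, so a common pattern is to pick the configuration from a command-line flag or config file. The sketch below uses a hypothetical helper, fabric_kwargs, which is not part of Fabric. It maps a mode name to constructor arguments using Fabric's built-in "ddp" and "fsdp" strategy strings. The plain "fsdp" string does not carry the wrap and checkpointing policies shown above. When you need those policies, pass a configured FSDPStrategy instead.

```python
def fabric_kwargs(mode: str, devices: int = 4, num_nodes: int = 1) -> dict:
    """Map a hypothetical --mode flag to keyword arguments for Fabric(...)."""
    common = {"accelerator": "cuda", "precision": "bf16-mixed"}
    if mode == "single":
        # One GPU: Fabric chooses a single-device strategy on its own
        return {**common, "devices": 1}
    if mode in ("ddp", "fsdp"):
        # Same loop, different strategy string; num_nodes > 1 for multi-node runs
        return {**common, "devices": devices, "num_nodes": num_nodes, "strategy": mode}
    raise ValueError(f"unknown mode: {mode!r}")


# Usage in a training script (requires lightning):
#   fabric = Fabric(**fabric_kwargs("fsdp", devices=8))
#   fabric.launch()
```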
Gradient Accumulation and Clipping with Fabric
```python
from lightning.fabric import Fabric
import torch

# model, optimizer, train_loader, criterion, and num_epochs are assumed to be
# defined as in the earlier examples
fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp",
                precision="bf16-mixed")
fabric.launch()
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)

accumulation_steps = 4  # effective batch = batch_size * 4 * num_gpus
max_grad_norm = 1.0

for epoch in range(num_epochs):
    optimizer.zero_grad()
    for step, batch in enumerate(train_loader):
        inputs, targets = batch
        is_accumulating = (step + 1) % accumulation_steps != 0

        # no_backward_sync skips the gradient all-reduce until the last accumulation step.
        # This is the Fabric equivalent of DDP's no_sync() context manager.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            outputs = model(inputs)
            loss = criterion(outputs, targets) / accumulation_steps
            fabric.backward(loss)

        if not is_accumulating:
            # Gradient clipping: Fabric handles unscaling for mixed precision
            fabric.clip_gradients(model, optimizer, max_norm=max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
```
fabric.no_backward_sync() is the Fabric equivalent of PyTorch DDP’s model.no_sync() context manager — it defers the gradient all-reduce until the final accumulation step, which avoids paying the communication cost on every micro-step. fabric.clip_gradients() handles the interaction between gradient clipping and the mixed precision scaler automatically, which is one of the trickier manual steps in a raw PyTorch mixed precision setup.
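You can check the accumulation schedule without any GPUs. The hypothetical helper below mirrors the loop's is_accumulating condition and returns the batch indices that trigger the all-reduce and the optimizer step. It also exposes an edge case the loop above does not handle. When the number of batches per epoch is not a multiple of accumulation_steps, the trailing micro-batches accumulate gradients and are then discarded by the zero_grad() at the start of the next epoch.

```python
def accumulation_schedule(num_batches: int, accumulation_steps: int) -> list[int]:
    """0-based batch indices on which the loop above all-reduces and steps.

    Mirrors the loop's condition: is_accumulating = (step + 1) % accumulation_steps != 0
    """
    return [step for step in range(num_batches)
            if (step + 1) % accumulation_steps == 0]


def effective_batch_size(per_device_batch: int, accumulation_steps: int,
                         num_devices: int) -> int:
    """Samples contributing to each optimizer step across all devices."""
    return per_device_batch * accumulation_steps * num_devices


# Hypothetical: 10 batches per rank per epoch, accumulation_steps=4
print(accumulation_schedule(10, 4))    # [3, 7]; batches 8 and 9 are never stepped
print(effective_batch_size(32, 4, 4))  # 512
```

If the leftover micro-batches matter, one option is to also step on the last batch of the epoch, for example by adding `and step + 1 != len(train_loader)` to the is_accumulating condition. That final step then uses fewer micro-batches than the division by accumulation_steps assumes.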
Checkpointing with Fabric
```python
import os

def save_checkpoint(fabric, model, optimizer, epoch, step, path):
    """Save a distributed-safe checkpoint with Fabric. Must be called on every rank."""
    state = {
        "model": model,  # Fabric handles DDP/FSDP state dict extraction
        "optimizer": optimizer,
        "epoch": epoch,
        "step": step,
    }
    fabric.save(path, state)

def load_checkpoint(fabric, model, optimizer, path):
    """Restore model and optimizer in place; return the saved epoch and step."""
    state = {
        "model": model,
        "optimizer": optimizer,
        "epoch": 0,
        "step": 0,
    }
    # fabric.load() restores model/optimizer in place and overwrites the plain
    # values ("epoch", "step") in `state`; it returns any checkpoint entries
    # that were not requested
    fabric.load(path, state)
    return state["epoch"], state["step"]

# Training loop with checkpointing
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
for epoch in range(num_epochs):
    for step, batch in enumerate(train_loader):
        # ... training step ...
        if step % 500 == 0:
            # No is_global_zero guard here: every rank must call fabric.save(),
            # which decides internally which rank(s) write to disk
            save_checkpoint(fabric, model, optimizer, epoch, step,
                            f"{ckpt_dir}/epoch{epoch}_step{step}.ckpt")

# Resume from checkpoint
start_epoch, start_step = load_checkpoint(
    fabric, model, optimizer, "checkpoints/epoch2_step1000.ckpt"
)
```

fabric.save() and fabric.load() handle the distributed complexity of checkpointing. For DDP, only rank 0 writes the file, since every rank holds identical weights. For FSDP, Fabric writes either a sharded checkpoint (one shard per rank, the default) or a consolidated full state dict, depending on the strategy's state_dict_type setting. Both cases require coordination between ranks, so fabric.save() must be called on all processes. Putting it behind a fabric.is_global_zero check can hang the job. Keep that guard for your own rank-0-only side effects, such as printing or writing plain log files.
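The resume call above hardcodes a checkpoint path. Below is a minimal sketch of finding the newest checkpoint instead. It assumes the epoch{epoch}_step{step}.ckpt naming scheme used above, and pick_latest and latest_checkpoint are hypothetical helpers. The sketch compares epoch and step numerically, because a plain string sort would put epoch10 before epoch9.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Matches names produced by the loop above: f"{ckpt_dir}/epoch{epoch}_step{step}.ckpt"
CKPT_PATTERN = re.compile(r"^epoch(\d+)_step(\d+)\.ckpt$")


def pick_latest(names: Iterable[str]) -> Optional[str]:
    """Return the name with the highest (epoch, step), comparing numbers, not strings."""
    best_key, best_name = None, None
    for name in names:
        match = CKPT_PATTERN.match(name)
        if match is None:
            continue  # ignore unrelated files
        key = (int(match.group(1)), int(match.group(2)))
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name


def latest_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Newest checkpoint in ckpt_dir, or None if the directory is missing or has none.

    Matches both single-file checkpoints and sharded FSDP checkpoints (directories).
    """
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    name = pick_latest(p.name for p in directory.iterdir())
    return directory / name if name is not None else None
```

Every rank can call latest_checkpoint() and pass the result to load_checkpoint(), since all ranks should resume from the same file. On a multi-node cluster, this assumes the checkpoint directory is on a shared filesystem.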
Logging and Metrics Across Ranks
```python
from lightning.fabric import Fabric
import torch

fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp")
fabric.launch()
# ... setup ...

running_loss = torch.tensor(0.0, device=fabric.device)
n_correct = torch.tensor(0, device=fabric.device)
n_total = torch.tensor(0, device=fabric.device)

for batch in train_loader:
    inputs, targets = batch
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    running_loss += loss.detach()
    n_correct += (outputs.argmax(dim=-1) == targets).sum()
    n_total += targets.size(0)

# all_reduce aggregates values across all GPUs
running_loss = fabric.all_reduce(running_loss, reduce_op="mean")
n_correct = fabric.all_reduce(n_correct, reduce_op="sum")
n_total = fabric.all_reduce(n_total, reduce_op="sum")

if fabric.is_global_zero:
    epoch_loss = running_loss.item() / len(train_loader)  # len() = batches per rank
    epoch_acc = n_correct.item() / n_total.item()
    print(f"Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}")
```
fabric.all_reduce() aggregates tensors across all distributed ranks, making it straightforward to compute global metrics without managing the underlying torch.distributed.all_reduce calls directly. The is_global_zero guard then ensures that logging, printing, and checkpoint saving only happen once even when running across many processes.
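Why reduce the raw counts with "sum" instead of averaging a per-rank accuracy with "mean"? A pure-Python simulation with hypothetical numbers and no GPUs makes the difference concrete. Ranks can see different numbers of samples, for example with a custom sampler or with token-level metrics where the number of non-padding tokens varies. In that case, the mean of per-rank accuracies weights every rank equally, while summing counts weights every sample equally.

```python
def accuracy_from_counts(correct_per_rank, total_per_rank):
    """What reducing the counts with reduce_op="sum" computes: every sample weighs the same."""
    return sum(correct_per_rank) / sum(total_per_rank)


def mean_of_rank_accuracies(correct_per_rank, total_per_rank):
    """What all-reducing a per-rank accuracy with reduce_op="mean" would compute instead."""
    accuracies = [c / t for c, t in zip(correct_per_rank, total_per_rank)]
    return sum(accuracies) / len(accuracies)


# Hypothetical counts from two ranks that saw different numbers of samples
correct = [90, 10]
total = [100, 20]
print(accuracy_from_counts(correct, total))     # 100 / 120, about 0.833
print(mean_of_rank_accuracies(correct, total))  # (0.9 + 0.5) / 2 = 0.7
```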
When to Use Fabric vs Full PyTorch Lightning
Fabric is the right choice when you have an existing training loop you want to scale without restructuring it around a LightningModule. Research code, custom training loops with unusual control flow, and codebases where the training logic is tightly coupled to the model architecture all benefit from Fabric’s minimal-change approach.

The tradeoff is that you give up the Trainer's built-in machinery:

- ready-made callbacks such as ModelCheckpoint and EarlyStopping;
- automatic metric logging;
- tuning utilities such as the learning rate finder.

Fabric does accept loggers and user-defined callbacks, but you call fabric.log() and trigger hooks with fabric.call() yourself.

If you are starting a new project from scratch and expect to want those features, full Lightning is worth the upfront cost of structuring your code around LightningModule. If you are taking an existing codebase and need to add multi-GPU support without a major refactor, Fabric is the path of least resistance. Migrating a single-GPU PyTorch loop to a Fabric-based multi-GPU loop is usually a small, localised change. Full Lightning requires a larger restructuring.
Fabric vs Raw DDP: What You Actually Save
To appreciate what Fabric removes, it helps to see what raw DDP setup looks like for a multi-GPU training job. A standard DDP launch script requires you to:

- import torch.distributed;
- call dist.init_process_group with the right backend and timeout;
- pin each process to its GPU with torch.cuda.set_device;
- wrap the model in DistributedDataParallel with the device IDs;
- create a DistributedSampler, pass it to the DataLoader, and call its set_epoch() at the start of every epoch so that shuffling differs between epochs;
- create a GradScaler if you use float16 mixed precision;
- call scaler.scale(loss).backward() instead of loss.backward();
- call scaler.unscale_(optimizer) before gradient clipping;
- call scaler.step(optimizer) and scaler.update() instead of just optimizer.step();
- call dist.destroy_process_group() at the end.

Each of these steps is a potential source of subtle bugs. Forgetting to unscale before clipping is a common mistake: the still-scaled gradients get clipped against a threshold meant for unscaled ones, with no error. Forgetting the DistributedSampler means every GPU processes the same data rather than its own shard.
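To see what that sampler does, here is a pure-Python sketch of how torch.utils.data.DistributedSampler assigns indices when shuffle=False and drop_last=False. With shuffle=True, the sampler first permutes the indices using a seed plus the epoch number; this sketch omits that step. Each rank takes a strided slice of the indices. The index list is padded by wrapping around, so every rank gets the same number of samples. When you run on more than one process, fabric.setup_dataloaders() inserts this sampler for you.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Indices seen by `rank`, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    if dataset_len == 0:
        return []
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)  # per-rank count
    total_size = num_samples * num_replicas
    padding = total_size - dataset_len
    # Wrap around to the start of the dataset so the total divides evenly
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Strided slice: rank 0 gets 0, R, 2R, ...; rank 1 gets 1, R+1, ...
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

In this example, samples 0 and 1 are seen twice per epoch because of the padding. Without the sampler, every rank would iterate over all ten indices.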
Fabric eliminates all of this without hiding the training loop. The model, optimizer, and dataloader are the same objects you created — Fabric just wraps them in the appropriate distributed containers and registers the right hooks. This transparency makes debugging easier: if something goes wrong, you can inspect the wrapped model or optimizer directly, and the training loop logic is in one place rather than distributed across DDP boilerplate.
Using Fabric with Custom Training Loops and Research Code
One of Fabric’s strongest use cases is research code where the training loop has non-standard control flow — alternating between multiple models (GANs, RLHF with separate actor and critic, distillation with teacher and student), manual learning rate scheduling at non-epoch boundaries, or custom gradient manipulation between the backward pass and the optimizer step. Full Lightning’s Trainer makes these patterns awkward because the Trainer owns the training loop and exposes hooks for customisation rather than letting you write the loop directly. Fabric keeps you in control.
```python
from lightning.fabric import Fabric
import torch

# GAN training with Fabric: two models, two optimizers, custom alternation.
# Generator, Discriminator, train_loader, criterion (e.g. nn.BCEWithLogitsLoss),
# latent_dim, and num_epochs are assumed to be defined elsewhere.
fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp",
                precision="bf16-mixed")
fabric.launch()

generator = Generator()
discriminator = Discriminator()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Set up both models and both optimizers
generator, g_optimizer = fabric.setup(generator, g_optimizer)
discriminator, d_optimizer = fabric.setup(discriminator, d_optimizer)
train_loader = fabric.setup_dataloaders(train_loader)

for epoch in range(num_epochs):
    for real_images, _ in train_loader:
        batch_size = real_images.size(0)

        # --- Train discriminator ---
        d_optimizer.zero_grad()
        real_preds = discriminator(real_images)
        d_real_loss = criterion(real_preds, torch.ones(batch_size, 1, device=fabric.device))
        noise = torch.randn(batch_size, latent_dim, device=fabric.device)
        fake_images = generator(noise).detach()  # no generator gradients in this step
        fake_preds = discriminator(fake_images)
        d_fake_loss = criterion(fake_preds, torch.zeros(batch_size, 1, device=fabric.device))
        d_loss = (d_real_loss + d_fake_loss) / 2
        fabric.backward(d_loss)
        d_optimizer.step()

        # --- Train generator ---
        g_optimizer.zero_grad()
        noise = torch.randn(batch_size, latent_dim, device=fabric.device)
        fake_images = generator(noise)
        fake_preds = discriminator(fake_images)
        g_loss = criterion(fake_preds, torch.ones(batch_size, 1, device=fabric.device))
        fabric.backward(g_loss)  # also fills discriminator grads; cleared by d_optimizer.zero_grad()
        g_optimizer.step()
```
Running Fabric from the Command Line
Fabric integrates with the fabric run command-line launcher, which handles multi-GPU and multi-node process spawning without requiring you to modify your script with torchrun-specific boilerplate.
```bash
# Single node, 4 GPUs
fabric run --accelerator cuda --devices 4 train.py

# Multi-node: 2 nodes, 4 GPUs each. Run this on every node, with --node-rank 0
# on the main node (10.0.0.1) and --node-rank 1 on the other node
fabric run --accelerator cuda --devices 4 --num-nodes 2 --node-rank 0 --main-address 10.0.0.1 train.py

# Equivalent torchrun command for node 0 (more verbose)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=10.0.0.1 --master_port=12345 train.py
```
The fabric run launcher spawns the processes, much like torchrun. It forwards its settings to your script through environment variables, and the Fabric object in the script picks them up. You can therefore leave those arguments out of the Fabric(...) call instead of specifying them twice. On a cluster with a job scheduler like SLURM, multi-node jobs typically don't need a separate launcher at all. When the script is started with srun, Fabric detects the SLURM environment variables and configures the ranks and the main address from them, provided devices and num_nodes match the allocation. This removes another common source of configuration errors in multi-node training.
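Each process gets its identity from simple arithmetic, assuming every node runs the same number of processes (one per GPU) and node ranks are assigned statically as in the commands above. Fabric exposes these values as fabric.global_rank, fabric.local_rank, fabric.node_rank, and fabric.world_size, and fabric.is_global_zero is true only for global rank 0. The sketch below reproduces the numbering in plain Python.

```python
def world_size(num_nodes: int, devices_per_node: int) -> int:
    """Total number of processes (one per GPU)."""
    return num_nodes * devices_per_node


def global_rank(node_rank: int, local_rank: int, devices_per_node: int) -> int:
    """Global rank of a process, assuming every node runs devices_per_node processes."""
    return node_rank * devices_per_node + local_rank


# The 2-node x 4-GPU launch above
for node_rank in range(2):
    ranks = [global_rank(node_rank, local_rank, 4) for local_rank in range(4)]
    print(f"node {node_rank}: global ranks {ranks}")
print("world size:", world_size(2, 4))
```

This prints global ranks 0 to 3 for node 0 and 4 to 7 for node 1, and a world size of 8. Only the process with local rank 0 on node 0 is global zero.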
Practical Tips for Fabric Adoption
A few patterns make Fabric adoption smoother in practice. First, always call fabric.seed_everything(seed) at the start of your script rather than manually seeding each library — it seeds Python, NumPy, PyTorch, and CUDA in a way that is consistent across distributed ranks, which is important for reproducibility in multi-GPU experiments. Second, use fabric.print() instead of plain print() for training logs — it suppresses output from all ranks except rank 0, so you do not get four copies of every log line in a 4-GPU run. Third, when profiling, wrap just the section you want to measure with fabric.barrier() before and after to ensure all ranks are synchronised at the measurement points, which prevents false readings from stragglers. These small habits make distributed training with Fabric significantly less confusing and the resulting logs and metrics much easier to interpret.
Fabric with FSDP: Practical Considerations
When you switch to the FSDP strategy in Fabric, the key difference from DDP is that model parameters, gradients, and optimizer states are sharded across GPUs rather than replicated. Each GPU holds only a fraction of the model state at any given time. A layer's full parameters are gathered from the other GPUs just before that layer's forward or backward pass and freed afterwards.

In practice, this means models too large to train with DDP (where every GPU needs a full copy) can be trained with FSDP across multiple GPUs. For a 7B parameter model, the 16-bit weights alone take about 14 GB, and 4-way sharding cuts that to about 3.5 GB per GPU. Training needs much more. With AdamW in mixed precision, a common rule of thumb is about 16 bytes per parameter once gradients and optimizer states are included. That is roughly 112 GB for the full model, which cannot fit on a single 80 GB A100 under DDP. With 4-way FSDP, it comes to about 28 GB per GPU plus activations.
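The arithmetic behind these memory estimates is simple enough to script. The sketch below uses decimal gigabytes. For training it applies the rule of thumb of roughly 16 bytes per parameter for mixed-precision AdamW, for example fp32 weights, fp32 gradients, and two fp32 Adam moments at 4 bytes each. That figure is an approximation, and the sketch ignores activations, temporary buffers, and the unsharded layer that FSDP gathers at any given moment.

```python
GB = 1e9  # decimal gigabytes, matching "14 GB for 7B parameters in 16-bit"


def per_gpu_gb(num_params: float, bytes_per_param: float, shards: int = 1) -> float:
    """Model-state memory per GPU, ignoring activations and buffers."""
    return num_params * bytes_per_param / shards / GB


params = 7e9
print(per_gpu_gb(params, 2))             # 14.0  -> 16-bit weights, full copy (DDP)
print(per_gpu_gb(params, 2, shards=4))   # 3.5   -> 16-bit weights, 4-way FSDP
# Mixed-precision AdamW training: ~16 bytes/param for weights, grads, and Adam moments
print(per_gpu_gb(params, 16))            # 112.0 -> a full replica under DDP
print(per_gpu_gb(params, 16, shards=4))  # 28.0  -> per GPU with 4-way FSDP
```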
The auto_wrap_policy argument to FSDPStrategy controls which modules get their own FSDP wrapper — equivalently, which layers are sharded independently. Setting this to the transformer layer class (e.g., nn.TransformerDecoderLayer or your custom layer class) is the standard choice: it means each transformer layer is a separate FSDP unit that gathers and discards its parameters independently. Wrapping too coarsely (e.g., wrapping the entire model as one unit) loses the memory benefit because the whole model is gathered for every forward pass. Wrapping too finely (e.g., wrapping individual linear layers) adds excessive communication overhead. Transformer layer granularity is the right default for most LLM training setups.
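You can see the effect of wrap granularity with a back-of-the-envelope model. This is a simplification, not how FSDP accounts for memory internally. Every FSDP unit is stored sharded, and on top of that the unit currently being computed is materialised in full. The model below is hypothetical: 32 layers of 200M parameters each, plus 600M parameters (embeddings and output head) left in the root unit. Setting units_gathered=2 roughly approximates FSDP prefetching the next unit while the current one runs.

```python
def peak_param_gb(unit_sizes, shards, bytes_per_param=2, units_gathered=1):
    """Rough peak parameter memory per GPU under FSDP full sharding.

    Every unit is stored sharded (1/shards of it per GPU); on top of that, the
    largest `units_gathered` units are materialised in full at the same time.
    Ignores gradients, optimizer state, activations, and allocator overhead.
    """
    sharded = sum(unit_sizes) / shards
    gathered = sum(sorted(unit_sizes, reverse=True)[:units_gathered])
    return (sharded + gathered) * bytes_per_param / 1e9


# Hypothetical 7B model: 32 layers x 200M params, plus 600M in the root unit
per_layer_units = [600_000_000] + [200_000_000] * 32
whole_model_unit = [7_000_000_000]

print(peak_param_gb(per_layer_units, shards=4))   # per-layer wrapping
print(peak_param_gb(whole_model_unit, shards=4))  # one unit for the whole model
```

The per-layer version peaks at about 4.7 GB of 16-bit parameters per GPU. The single-unit version peaks at about 17.5 GB, because all 7B parameters must be materialised at once on top of the shards. That is more than the 14 GB a plain DDP replica holds.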
Fabric for Fine-Tuning with LoRA
Fabric works naturally with PEFT libraries like Hugging Face’s peft for LoRA fine-tuning. The setup pattern is the same: wrap the peft model and optimizer with fabric.setup(). get_peft_model freezes the base model parameters (requires_grad=False), so they receive no gradients. Filtering on requires_grad when building the optimizer keeps them out of the optimizer state, so only the LoRA adapter weights are trained. Under DDP, only parameters that require gradients are synchronised, so the frozen base weights add no communication cost.
```python
from lightning.fabric import Fabric
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import torch

fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp",
                precision="bf16-mixed")
fabric.launch()

# Load base model and wrap with LoRA
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"],
                         lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # trainable vs. total counts; adapters are well under 1%

# Only trainable (adapter) params go to the optimizer
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4
)
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)  # yields dicts of tensors

for batch in train_loader:
    input_ids = batch["input_ids"]
    labels = batch["labels"]
    outputs = model(input_ids=input_ids, labels=labels)
    fabric.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()
```
The combination of Fabric and LoRA is particularly useful for fine-tuning runs that need to scale across multiple GPUs without the overhead of a full Lightning training setup. You get multi-GPU gradient synchronisation, mixed precision, and distributed-safe checkpointing with minimal boilerplate, while the frozen base model weights are handled transparently. Note that fabric.save() stores the model's full state dict by default, including the frozen base weights. Filter the state if you only want to keep the adapter weights.
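You can count LoRA parameters by hand to sanity-check the number that print_trainable_parameters() reports. For a linear layer mapping d_in to d_out features, LoRA learns A (r x d_in) and B (d_out x r). With bias="none", that adds r * (d_in + d_out) trainable parameters per layer. The shapes below are assumed values for Llama-3.2-3B (hidden size 3072, 8 key/value heads of dimension 128, 28 layers). Verify them against the model's config.json before relying on the total.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer (bias="none").

    LoRA learns A with shape (r, d_in) and B with shape (d_out, r).
    """
    return r * (d_in + d_out)


# Assumed Llama-3.2-3B shapes: check them against the model's config.json
hidden_size = 3072
kv_dim = 8 * 128   # num_key_value_heads * head_dim, the output size of v_proj
num_layers = 28
r = 16

q_proj = lora_params(hidden_size, hidden_size, r)  # q_proj: 3072 -> 3072
v_proj = lora_params(hidden_size, kv_dim, r)       # v_proj: 3072 -> 1024
per_layer = q_proj + v_proj
print(per_layer, num_layers * per_layer)  # 163840 4587520
```

Under these assumptions, the adapters add about 4.6M trainable parameters to a model of roughly 3.2B, which is around 0.14%.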
Installation and Version Notes
Fabric is included in the lightning package, which bundles what used to be separate pytorch-lightning and lightning-fabric installs; both standalone packages are still published. Install it with pip install lightning and import from lightning.fabric. Current Fabric releases require PyTorch 2.x. The exact minimum version changes between releases, so check the installation notes for the version you pin. PyTorch 2.x is also what you need for torch.compile, which Fabric supports for compiled training. If you are on an older codebase still using PyTorch 1.x, an older pinned lightning-fabric release will work, but with a reduced feature set.

For most active projects, upgrading to PyTorch 2.x before adopting Fabric is the right move. Together, the two upgrades give you Fabric's distributed abstraction layer plus torch.compile, which are complementary and straightforward to enable together. For teams already using full PyTorch Lightning who want to extract just the distributed training layer for a specific project, Fabric exposes exactly that subset. It is the same infrastructure that the Lightning Trainer uses internally, now available as a standalone tool without the Trainer coupling. The end result is a training codebase that scales from a laptop to a multi-node GPU cluster by changing the Fabric constructor arguments or launcher flags, with the training loop itself left untouched.