PyTorch Lightning is a high-level training framework that wraps raw PyTorch and handles the engineering boilerplate — training loops, validation, checkpointing, multi-GPU distribution, mixed precision, and logging — while keeping the model code in pure PyTorch. For ML engineers who find themselves rewriting the same training loop scaffolding across projects, Lightning provides a standardised structure that is reproducible, testable, and ready for multi-GPU scaling without changes to the model code. This article covers how Lightning’s core abstractions work, how to structure a fine-tuning job for an LLM, and where Lightning adds genuine value versus where it adds indirection you may not want.
LightningModule: Organising Model and Training Logic
The core abstraction in PyTorch Lightning is LightningModule, a subclass of nn.Module that adds lifecycle hooks for training, validation, and testing steps. Instead of writing a training loop with explicit forward passes, loss computation, and optimizer steps, you define training_step and configure_optimizers methods, and the Trainer calls them at the right times. The model remains plain PyTorch throughout: there is no Lightning-specific model format, and because a LightningModule is itself an nn.Module, it (or the model it wraps, such as self.model below) can be used for inference without the Trainer.
```python
import torch
import lightning as L
from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup


class LLMFinetuner(L.LightningModule):
    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.2-1B",
        lr: float = 2e-5,
        weight_decay: float = 0.01,
        warmup_steps: int = 100,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

    def training_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning already runs validation in eval mode without gradient tracking
        outputs = self.model(**batch)
        self.log("val/loss", outputs.loss, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        # Parameter groups: no weight decay for biases and norm params
        decay, no_decay = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.ndim <= 1 or "bias" in name or "norm" in name.lower():
                no_decay.append(param)
            else:
                decay.append(param)
        optimizer = torch.optim.AdamW(
            [{"params": decay, "weight_decay": self.hparams.weight_decay},
             {"params": no_decay, "weight_decay": 0.0}],
            lr=self.hparams.lr, betas=(0.9, 0.95),
        )
        # estimated_stepping_batches = total optimizer steps across all epochs
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}
```
The save_hyperparameters() call in __init__ records all constructor arguments in self.hparams and writes them into every checkpoint (under the "hyper_parameters" key), enabling reproducible resumption and easy hyperparameter logging in experiment trackers. The sync_dist=True flag in validation logging averages the loss across GPUs in multi-GPU training. Omitting it means the logged value, and the value that ModelCheckpoint and EarlyStopping monitor, reflects only each rank's local shard of the validation set, so it is noisier than the true validation loss and can differ between ranks.
LightningDataModule: Reproducible Data Pipelines
Lightning's LightningDataModule encapsulates dataset loading, splitting, and DataLoader creation in a single object that can be reused across experiments. Separating data logic into a DataModule makes it easy to swap datasets without touching model code and ensures the same data pipeline (including tokenization and preprocessing) is used consistently across training, validation, and inference.
```python
from torch.utils.data import DataLoader
import lightning as L
from datasets import load_dataset
from transformers import AutoTokenizer


class InstructionDataModule(L.LightningDataModule):
    def __init__(self, model_name: str, dataset_name: str,
                 max_length: int = 512, batch_size: int = 4):
        super().__init__()
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.max_length = max_length
        self.batch_size = batch_size

    def setup(self, stage: str):
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        tokenizer.pad_token = tokenizer.eos_token
        # setup() runs on every rank: a fixed seed keeps the train/val split identical everywhere
        raw = load_dataset(self.dataset_name, split="train").train_test_split(test_size=0.05, seed=42)

        def tokenize(examples):
            tokens = tokenizer(
                examples["text"],
                max_length=self.max_length,
                truncation=True,
                padding="max_length",
            )
            # Labels = input_ids, with padding positions set to -100 so the loss ignores them
            tokens["labels"] = [
                [tok if keep else -100 for tok, keep in zip(ids, mask)]
                for ids, mask in zip(tokens["input_ids"], tokens["attention_mask"])
            ]
            return tokens

        self.train_ds = raw["train"].map(tokenize, batched=True, remove_columns=raw["train"].column_names)
        self.val_ds = raw["test"].map(tokenize, batched=True, remove_columns=raw["test"].column_names)
        self.train_ds.set_format("torch")
        self.val_ds.set_format("torch")

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=4, pin_memory=True)
```
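One pipeline detail worth checking whenever labels are built by copying input_ids is how padding positions enter the loss. Hugging Face causal LM models shift labels by one position and skip any label equal to -100. The sketch below reproduces that computation with plain torch on made-up logits (vocabulary size, token ids, and the use of 0 as the pad id are all assumptions for illustration) to show that unmasked padding changes the loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, seq = 10, 6
logits = torch.randn(1, seq, vocab)                  # stand-in for model output
input_ids = torch.tensor([[3, 5, 7, 2, 0, 0]])      # 0 plays the pad token here (assumption)
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])

labels_naive = input_ids.clone()                                   # padding counted as targets
labels_masked = input_ids.masked_fill(attention_mask == 0, -100)  # padding ignored


def causal_lm_loss(logits, labels):
    # Position t predicts token t + 1, so drop the last logit and the first label
    shift_logits = logits[:, :-1].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)


naive = causal_lm_loss(logits, labels_naive)
masked = causal_lm_loss(logits, labels_masked)
# The masked loss averages over the three real next-token targets only
manual = F.cross_entropy(logits[0, :3], input_ids[0, 1:4])
print(naive.item(), masked.item(), manual.item())
```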
Trainer: Multi-GPU, Mixed Precision, and Checkpointing
The Trainer is where Lightning earns its keep. Switching from single-GPU to multi-GPU training, enabling mixed precision, and configuring checkpointing are all Trainer arguments — the model and data code do not change. This is the primary reason teams adopt Lightning: the same LightningModule that ran on one GPU in development runs on 8 GPUs with FSDP in production by changing two Trainer arguments.
```python
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import FSDPStrategy

# Single GPU, bf16 (development); weights are already loaded in bf16 by LLMFinetuner
trainer_dev = L.Trainer(
    max_epochs=3,
    accelerator="gpu",
    devices=1,
    precision="bf16-true",
    val_check_interval=0.25,  # validate 4x per epoch
)

# Multi-GPU with FSDP (production)
fsdp_strategy = FSDPStrategy(
    auto_wrap_policy=None,  # for large models, pass the decoder block class, e.g. {LlamaDecoderLayer}
    activation_checkpointing_policy=None,
    cpu_offload=False,
)

trainer_prod = L.Trainer(
    max_epochs=3,
    accelerator="gpu",
    devices=8,  # 8 GPUs on one node
    strategy=fsdp_strategy,
    precision="bf16-true",
    gradient_clip_val=1.0,  # global norm clipping
    accumulate_grad_batches=4,  # effective batch = 4 * 8 * batch_size
    callbacks=[
        ModelCheckpoint(
            dirpath="checkpoints/",
            # "val/loss" contains a slash; auto-inserted metric names would create a subfolder
            filename="epoch{epoch}-val_loss{val/loss:.4f}",
            auto_insert_metric_name=False,
            monitor="val/loss",
            save_top_k=3,
            mode="min",
            save_last=True,  # always keep latest checkpoint for resumption
        ),
        LearningRateMonitor(logging_interval="step"),
        EarlyStopping(monitor="val/loss", patience=3, mode="min"),
    ],
    logger=WandbLogger(project="llm-finetune", log_model=False),
    log_every_n_steps=10,
)

# Launch training
model = LLMFinetuner(model_name="meta-llama/Llama-3.2-1B", lr=2e-5)
datamodule = InstructionDataModule("meta-llama/Llama-3.2-1B", "tatsu-lab/alpaca")
trainer_prod.fit(model, datamodule=datamodule)

# To resume an interrupted run (e.g. after preemption), relaunch with the last checkpoint
trainer_prod.fit(model, datamodule=datamodule, ckpt_path="checkpoints/last.ckpt")
```
The accumulate_grad_batches argument handles gradient accumulation automatically: Lightning accumulates gradients over the specified number of batches, divides the loss by the accumulation factor so the accumulated gradient is an average rather than a sum, and calls the optimizer step at the right interval. Combined with devices=8, the effective batch size is accumulate_grad_batches × devices × per_device_batch_size. The same accounting drives trainer.estimated_stepping_batches, which reports the total number of optimizer steps (not batches) across all epochs, so passing it to the scheduler as num_training_steps makes the cosine schedule finish when training does. Getting this calculation right manually in a custom training loop is a common source of subtle bugs; Lightning handles it by default.
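The arithmetic can be checked by hand. The helpers below are hypothetical (not Lightning's API) and only approximate the accounting behind estimated_stepping_batches: each rank gets ceil(n / devices) samples from a distributed sampler, those become ceil(samples / batch) batches, and every accumulate_grad_batches batches produce one optimizer step. The dataset size is an assumption: roughly 49,401 training rows, i.e. a 95/5 split of a 52,002-row dataset such as tatsu-lab/alpaca.

```python
import math


def effective_batch_size(per_device_batch: int, devices: int, accumulate: int) -> int:
    """Samples contributing to one optimizer step."""
    return per_device_batch * devices * accumulate


def approx_optimizer_steps(n_train: int, per_device_batch: int, devices: int,
                           accumulate: int, epochs: int) -> int:
    """Rough estimate of total optimizer steps (a sketch, not Lightning's exact code)."""
    samples_per_rank = math.ceil(n_train / devices)           # distributed sampler pads to equal shards
    batches_per_rank = math.ceil(samples_per_rank / per_device_batch)
    steps_per_epoch = math.ceil(batches_per_rank / accumulate)  # a partial final group still steps
    return steps_per_epoch * epochs


# Assumed setup: batch_size=4 per device, 8 GPUs, accumulate_grad_batches=4, max_epochs=3
print(effective_batch_size(4, 8, 4))               # 128
print(approx_optimizer_steps(49_401, 4, 8, 4, 3))  # 1158
```

With these numbers the cosine schedule in LLMFinetuner would be built over about 1,158 steps, of which the first 100 are warmup.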
LoRA Fine-Tuning with Lightning and PEFT
PEFT's LoRA adapters integrate cleanly with Lightning. The pattern below wraps the pretrained model with get_peft_model inside the LightningModule's __init__, after which training proceeds normally: only the LoRA parameters are updated and the base weights stay frozen. Calling save_pretrained() on the PEFT model writes only the adapter weights, typically megabytes rather than the gigabytes of a full checkpoint. Note that Lightning's own .ckpt files still contain the full state_dict, base weights included, unless you filter it, so for large base models the exported adapter is the artefact worth keeping.
```python
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM
import lightning as L
import torch


class LoRAFinetuner(L.LightningModule):
    def __init__(self, model_name: str = "meta-llama/Llama-3.2-1B",
                 lora_r: int = 16, lora_alpha: int = 32,
                 lora_dropout: float = 0.05, lr: float = 2e-4):
        super().__init__()
        self.save_hyperparameters()
        base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none",
        )
        self.model = get_peft_model(base_model, lora_config)
        self.model.print_trainable_parameters()  # confirm only LoRA params are trainable

    def training_step(self, batch, batch_idx):
        loss = self.model(**batch).loss
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Required if ModelCheckpoint / EarlyStopping monitor "val/loss"
        loss = self.model(**batch).loss
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        # Only pass trainable parameters to the optimizer
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.hparams.lr, weight_decay=0.01)

    def on_save_checkpoint(self, checkpoint):
        # Also export the LoRA adapter (no base weights) in PEFT format.
        # This hook runs on every rank, so write from rank 0 only.
        if self.trainer.is_global_zero:
            self.model.save_pretrained(f"lora-adapter-epoch{self.current_epoch}")
```
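As a sanity check on what print_trainable_parameters() should report, the LoRA parameter count can be worked out by hand: each adapted d_out × d_in weight gains an A matrix (r × d_in) and a B matrix (d_out × r), i.e. r · (d_in + d_out) parameters. The model shapes below are assumptions (hidden size 2048, 8 key/value heads of dimension 64, 16 decoder layers); check the model's config.json before relying on them.

```python
def lora_params_per_linear(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) for each adapted weight matrix."""
    return r * (d_in + d_out)


def lora_trainable_params(hidden: int, kv_dim: int, n_layers: int, r: int) -> int:
    """Trainable LoRA params for q/k/v/o projections in a grouped-query-attention decoder."""
    per_layer = (
        lora_params_per_linear(hidden, hidden, r)    # q_proj
        + lora_params_per_linear(hidden, kv_dim, r)  # k_proj (fewer KV heads)
        + lora_params_per_linear(hidden, kv_dim, r)  # v_proj
        + lora_params_per_linear(hidden, hidden, r)  # o_proj
    )
    return per_layer * n_layers


# Assumed shapes; r=16 matches LoRAFinetuner's default
print(lora_trainable_params(hidden=2048, kv_dim=512, n_layers=16, r=16))  # 3407872
```

About 3.4M trainable parameters, or roughly 7MB in bf16, is why an exported adapter is so much smaller than the base model.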
When Lightning Helps and When It Gets in the Way
Lightning adds the most value when you need multi-GPU training, systematic experiment tracking, or reproducible checkpointing with minimal boilerplate. If you are regularly running experiments across single-GPU, multi-GPU, and TPU environments, or if you have a team where different engineers need to read and modify each other's training code, Lightning's standardised structure reduces friction meaningfully. The separation of model logic from training infrastructure also makes it easier to write unit tests for training steps — you can call training_step directly on a batch without running a full training loop.
Lightning gets in the way when you need precise control over the training loop that does not fit its abstractions — custom gradient manipulation between steps, non-standard optimizer scheduling tied to validation metrics, or multi-step rollout training as in RL-based fine-tuning. In these cases, the overhead of working around Trainer's assumptions outweighs its convenience. Hugging Face Accelerate is a lighter-weight alternative that handles device placement and distributed training without imposing a class hierarchy on your model — it wraps your existing training loop rather than replacing it, making it the better choice when you need distributed training but want to keep a custom loop. Raw PyTorch with DDP or FSDP is the right choice when maximum control and minimal framework overhead are the priority and you are willing to maintain the boilerplate yourself.
Callbacks: Adding Behaviour Without Changing Model Code
Lightning's callback system lets you inject behaviour at any point in the training lifecycle — on batch start, on epoch end, on validation end — without modifying the LightningModule. This is the right place for custom logging, gradient monitoring, evaluation metrics, and learning rate warmup schedules that are experiment-specific rather than model-specific. Keeping this logic in callbacks rather than in training_step keeps the model code clean and makes it easy to enable or disable behaviours by adding or removing callbacks from the Trainer.
The most useful built-in callbacks for LLM work are ModelCheckpoint, which saves the best N checkpoints by a monitored metric and always keeps the latest checkpoint for resumption; LearningRateMonitor, which logs learning rate to your experiment tracker at each step and is essential for verifying that warmup and decay are behaving as expected; and EarlyStopping, which halts training when validation loss stops improving and is particularly useful for fine-tuning runs where you want to stop before overfitting without specifying an exact epoch count upfront. Writing custom callbacks is straightforward — subclass L.Callback and override whichever hook methods you need. Common uses include logging gradient norms every N steps, running custom evaluation tasks on a held-out dataset after each epoch, and sending notifications when training completes or when validation loss reaches a target threshold.
Profiling and Debugging Training with Lightning
Lightning provides a built-in profiler that measures time spent in each lifecycle hook, making it easy to identify whether training is bottlenecked on the model forward pass, the data loading, the optimizer step, or validation. Pass profiler="simple" to the Trainer for a high-level summary printed at the end of training, or profiler="advanced" for function-level statistics collected with Python's cProfile. For GPU utilisation profiling, profiler="pytorch" wraps PyTorch's built-in profiler and generates a trace that can be visualised in TensorBoard's trace viewer or Perfetto.
Debugging training dynamics is easier in Lightning than in a custom loop because the standardised lifecycle hooks give you consistent places to add inspection code without worrying about whether it will interfere with gradient accumulation or distributed training. A pattern that works well is to override on_before_optimizer_step in a custom callback to log gradient norms per parameter group. This hook fires after gradients are accumulated but before gradient clipping and the optimizer step, so it gives you a clean signal for whether gradients are flowing correctly and whether any parameter groups are receiving unexpectedly large or small updates. Logging gradient norms for the first 100–200 steps of a new fine-tuning run is a reliable way to catch issues with learning rate, weight decay configuration, or frozen parameter groups before they compound into a bad checkpoint.
Fabric: Lightning Without the Opinionated Structure
Lightning Fabric is a lower-level API in the same library that provides distributed training, mixed precision, and device management without requiring you to subclass LightningModule or use the Trainer. You write a standard Python training loop, wrap your model and optimizer with fabric.setup(), and replace the backward call with fabric.backward(loss). Fabric handles the distribution strategy, gradient synchronisation, and precision casting transparently. This makes it a natural middle ground between raw DDP (which requires substantial boilerplate for multi-GPU) and full Lightning (which imposes a class hierarchy). If you want Lightning's multi-GPU capabilities but prefer to keep explicit control over your training loop, Fabric is worth evaluating as an alternative to Accelerate.
Practical Gotchas When Using Lightning with LLMs
Several issues arise specifically when using Lightning with large language models that do not appear in typical image classification or regression examples. The first is dtype handling: if you load a pretrained model in bfloat16 and pass precision="bf16-mixed" to the Trainer, Lightning will attempt to autocast operations, which is correct for a model initialised in float32 but redundant and occasionally problematic for a model already in bfloat16. For models loaded in bfloat16 directly (which is the recommended approach for LLMs to avoid a float32 intermediate), use precision="bf16-true" rather than precision="bf16-mixed" to tell Lightning not to apply automatic mixed precision casting on top of your already-bfloat16 model.
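One way to keep the precision setting in sync with how the weights were loaded is to derive it from the parameters themselves. The helper below is hypothetical (not part of Lightning); it returns "bf16-true" for weights already in bfloat16 and "bf16-mixed" for float32 weights, following the rule described above.

```python
import torch
import torch.nn as nn


def pick_precision(model: nn.Module) -> str:
    """Hypothetical helper: choose a Lightning `precision` string from the weights' dtype."""
    dtype = next(model.parameters()).dtype
    if dtype == torch.bfloat16:
        return "bf16-true"   # already bf16: no autocast layered on top
    if dtype == torch.float32:
        return "bf16-mixed"  # fp32 weights, bf16 autocast for compute
    raise ValueError(f"no rule for parameter dtype {dtype}")


fp32_model = nn.Linear(16, 16)
bf16_model = nn.Linear(16, 16).to(torch.bfloat16)
print(pick_precision(fp32_model))  # bf16-mixed
print(pick_precision(bf16_model))  # bf16-true
```

Usage would look like L.Trainer(precision=pick_precision(module.model), ...), where module.model is the Hugging Face model inside your LightningModule.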
The second common issue is checkpoint size with FSDP. By default, Lightning's FSDPStrategy gathers the full, consolidated state on rank 0 and ModelCheckpoint writes it to a single file. For a 7B-parameter model in bfloat16 the weights alone are about 14GB, and because a training checkpoint also includes optimizer state (AdamW keeps two moment tensors per parameter), the file written at every checkpoint interval is several times larger. With FSDP, sharded checkpoints, where each rank saves its own shard of weights and optimizer state, are much faster to write and load, but they produce a directory of per-rank files rather than one portable file, and depending on your PyTorch and Lightning versions, loading them may require the same number of GPUs as were used during training. Lightning supports both modes via the FSDPStrategy state_dict_type argument: use state_dict_type="full" for portable single-file checkpoints and state_dict_type="sharded" for fast checkpointing during long training runs where you are checkpointing to guard against preemption rather than for portability.
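A back-of-envelope estimator makes the size difference concrete. It is a hypothetical helper that counts only weights plus AdamW's two moment tensors and ignores small extras (scheduler state, hyperparameters); the dtype of the optimizer states is an assumption that depends on your precision setup.

```python
def checkpoint_size_gb(n_params: float, bytes_per_param: int,
                       optimizer_state_bytes_per_param: int = 0) -> float:
    """Approximate checkpoint size in GB (1 GB = 1e9 bytes).

    Weights cost n_params * bytes_per_param; AdamW's exp_avg and exp_avg_sq
    add n_params * optimizer_state_bytes_per_param on top.
    """
    return n_params * (bytes_per_param + optimizer_state_bytes_per_param) / 1e9


print(checkpoint_size_gb(7e9, 2))         # bf16 weights only: 14.0
print(checkpoint_size_gb(7e9, 2, 2 * 2))  # + two AdamW moments in bf16: 42.0
print(checkpoint_size_gb(7e9, 2, 2 * 4))  # + two AdamW moments in fp32: 70.0
```

If a checkpoint only needs to serve evaluation, ModelCheckpoint(save_weights_only=True) drops the optimizer state, at the cost of not being able to resume training exactly from that file.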
The third issue is learning rate scheduler step frequency. Lightning defaults to calling the scheduler at the epoch level, but LLM fine-tuning typically uses step-level scheduling (cosine decay with linear warmup over training steps, not epochs). Always specify "interval": "step" in the scheduler dictionary returned from configure_optimizers, and verify the scheduler is stepping correctly by checking the logged learning rate in the first few hundred steps. A scheduler built with num_training_steps counted in optimizer steps but stepped once per epoch advances only a handful of times over the whole run: it never leaves the warmup phase, so the learning rate stays at a small fraction of its peak (and, with warmup starting from zero, is zero for the entire first epoch), producing a run that barely trains.
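The difference is easy to see numerically. The function below is a sketch written to follow the shape of transformers' get_cosine_schedule_with_warmup with its default half cosine cycle (treat it as an approximation, not that library's code). The numbers are assumptions carried over from earlier: peak 2e-5, 100 warmup steps, about 1,158 total optimizer steps, 3 epochs.

```python
import math


def warmup_cosine_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear warmup from 0, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


peak_lr, warmup, total, epochs = 2e-5, 100, 1158, 3

# "interval": "step": one scheduler step per optimizer step
print(peak_lr * warmup_cosine_multiplier(warmup, warmup, total))  # reaches the peak
print(peak_lr * warmup_cosine_multiplier(total, warmup, total))   # decays to 0 at the end

# "interval": "epoch" (the default): the LR used during each epoch
print([peak_lr * warmup_cosine_multiplier(e, warmup, total) for e in range(epochs)])
# multipliers 0.0, 0.01, 0.02: the run never gets past 2% of the peak learning rate
```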
Choosing Between Lightning, Accelerate, and Raw PyTorch
Lightning is the right choice when your team values standardised, readable training code that works across single-GPU development and multi-GPU production without modification, and when the abstractions (LightningModule, DataModule, Trainer) fit your workflow cleanly. Accelerate is the right choice when you have an existing custom training loop you want to distribute without refactoring into Lightning's class structure — it wraps your loop with minimal changes and supports all the same distributed strategies. Raw PyTorch DDP or FSDP is the right choice when you need maximum performance, want zero framework overhead, and are willing to manage the training loop boilerplate yourself. For most LLM fine-tuning work, the practical difference in training speed between the three approaches is negligible — the choice comes down to how much you value structured code versus explicit control.