PyTorch Lightning has become one of the most popular frameworks for scaling PyTorch deep learning models while simplifying training code. At the heart of this framework lies the Trainer class, a powerful abstraction that automates everything from GPU/TPU acceleration to logging and checkpointing.
In this detailed guide, we’ll walk through a PyTorch Lightning Trainer example from scratch. You’ll learn how to structure your project using LightningModule, create clean data pipelines with LightningDataModule, and train your model using the Trainer. This hands-on approach will help you see the practical benefits of Lightning in action while adhering to best practices.
Why Use PyTorch Lightning Trainer?
Before jumping into code, let’s quickly recap why the Trainer class is so valuable:
- Removes boilerplate: No need to manually write training loops.
- Hardware agnostic: Easily switch from CPU to GPU to TPU.
- Integrated logging and checkpointing: Supports TensorBoard, WandB, CSV logs, etc.
- Built-in features: Includes gradient accumulation, mixed precision, and distributed training.
With just a few lines of configuration, the Trainer handles the rest, as the short sketch below shows.
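As a minimal sketch (MyModel and MyDataModule are placeholder names, not part of the example built later in this guide), replacing a hand-written training loop can look like this:

import pytorch_lightning as pl

# MyModel is a LightningModule and MyDataModule a LightningDataModule,
# both assumed to be defined elsewhere (placeholder names for this sketch).
model = MyModel()
data = MyDataModule()

# One Trainer call replaces the manual epoch/batch loop, device moves, and logging glue.
trainer = pl.Trainer(max_epochs=5, accelerator="auto", devices="auto")
trainer.fit(model, datamodule=data)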
PyTorch Lightning Trainer Example: Project Setup
Getting started with PyTorch Lightning means rethinking how you structure a deep learning project. Rather than relying on a single monolithic script with messy training loops, Lightning encourages a clean, modular design that separates data handling, model logic, and training orchestration. This makes your code more maintainable, easier to debug, and simpler to scale across hardware platforms.
For our example, we’ll walk through a complete setup for training a Convolutional Neural Network (CNN) on the MNIST dataset—a classic image classification task involving 28×28 grayscale images of handwritten digits. Even though MNIST is a small dataset, the structure we follow is the same used for large-scale production pipelines.
Here’s a recommended project layout:
my_lightning_project/
├── model.py # Contains the LightningModule (model architecture and training logic)
├── data_module.py # Contains the LightningDataModule (data loading and preprocessing)
├── train.py # Main training script using the Trainer class
├── requirements.txt # List of required Python packages
Let’s break down each of these components:
model.py: Define the Model and Training Steps
In this file, you’ll implement your deep learning model by subclassing LightningModule, a base class provided by Lightning. It encapsulates not only the model architecture but also how the model trains, validates, and optimizes, keeping model logic cleanly separated from training orchestration.
Your LightningModule should include the following (a bare-bones skeleton follows the list):
- The model layers and the forward() pass
- A training_step() that runs a single batch through the model
- A validation_step() for performance tracking
- A configure_optimizers() method that returns an optimizer (and optionally a scheduler)
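Put together, a bare-bones skeleton of such a module might look like this (the single linear layer and hard-coded learning rate are placeholders; Step 1 below builds the real CNN):

import torch
from torch import nn
from pytorch_lightning import LightningModule

class SkeletonModule(LightningModule):
    def __init__(self):
        super().__init__()
        # Placeholder architecture; swap in your own layers.
        self.net = nn.Linear(28 * 28, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # Flatten the image and run it through the placeholder layer.
        return self.net(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        return loss  # Lightning handles backward() and optimizer.step()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", self.loss_fn(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)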
data_module.py: Standardize Data Loading
Lightning encourages use of the LightningDataModule, which abstracts data loading into a clean interface. This class allows you to define:
- prepare_data() to download or preprocess the data once
- setup() to initialize datasets
- train_dataloader() and val_dataloader() for batch generation
Keeping data handling separate from model logic improves reproducibility and allows easy swapping of datasets for experimentation.
train.py: Orchestrate Training with the Trainer
Once your model and data modules are defined, train.py acts as the entry point. This script initializes:
- The model and data modules
- Callbacks like early stopping and model checkpointing
- The Trainer, which manages the training and validation loops
The Trainer is where PyTorch Lightning truly shines. With just one line of code, it supports features like the following (a configuration sketch follows this list):
- Training on CPU, GPU, or TPU
- Mixed precision (AMP) training
- Logging to TensorBoard, WandB, or CSV
- Gradient clipping, accumulation, and checkpointing
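As a sketch of how little configuration this takes (the values below are illustrative, and CSVLogger is just one of the supported loggers):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

# Illustrative settings only; tune them for your own workload.
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",          # picks CPU, GPU, or TPU automatically
    devices="auto",
    precision="16-mixed",        # automatic mixed precision
    accumulate_grad_batches=4,   # gradient accumulation
    gradient_clip_val=1.0,       # gradient clipping
    logger=CSVLogger("logs/"),   # metrics written to a CSV file
)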
Why This Structure Is Scalable
This modular approach allows you to:
- Focus on the model and data without worrying about engineering details
- Easily test variations by swapping models, datasets, or hyperparameters
- Integrate into production systems and CI/CD pipelines
- Collaborate more effectively within teams
Moreover, this setup aligns with MLOps principles, making it easier to deploy models, maintain experiments, and track metrics.
Required Dependencies
To run this example, you’ll need to install the following packages:
pip install pytorch-lightning torch torchvision
For optional extras such as experiment-tracking and CLI integrations, you can install the unified lightning package with its extras (quoted so the brackets survive your shell):
pip install "lightning[extra]"
Make sure your Python version is 3.8 or higher and that CUDA is configured if you plan to train on a GPU.
Bonus Tip: Use requirements.txt
For reproducibility, include your dependencies in a requirements.txt file:
pytorch-lightning==2.1.1
torch==2.1.0
torchvision==0.16.0
Then install them with:
pip install -r requirements.txt
Step 1: Create the Model (LightningModule)
In model.py, define the neural network as a subclass of LightningModule.
import torch
from torch import nn
from pytorch_lightning import LightningModule
class LitCNN(LightningModule):
def __init__(self, learning_rate=1e-3):
super().__init__()
self.learning_rate = learning_rate
self.model = nn.Sequential(  # two conv/pool blocks followed by a small classifier head
nn.Conv2d(1, 32, 3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
acc = (logits.argmax(dim=1) == y).float().mean()
self.log("val_loss", loss, prog_bar=True)
self.log("val_acc", acc, prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
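Before wiring up the Trainer, you can sanity-check the module with a fake batch (a quick check, not part of the original script; it assumes the file above is saved as model.py):

import torch
from model import LitCNN

# Four fake 28x28 grayscale images, shaped like an MNIST batch.
model = LitCNN()
dummy = torch.randn(4, 1, 28, 28)
logits = model(dummy)
print(logits.shape)  # expected: torch.Size([4, 10])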
Step 2: Prepare the Data (LightningDataModule)
In data_module.py, define how data is loaded and preprocessed using a LightningDataModule.
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class MNISTDataModule(LightningDataModule):
def __init__(self, data_dir="./data", batch_size=32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def prepare_data(self):
# Download only
datasets.MNIST(self.data_dir, train=True, download=True)
datasets.MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
self.mnist_train = datasets.MNIST(self.data_dir, train=True, transform=transform)
self.mnist_val = datasets.MNIST(self.data_dir, train=False, transform=transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
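The datamodule can be exercised in isolation the same way (a quick check, assuming the file above is saved as data_module.py; the first call downloads MNIST):

from data_module import MNISTDataModule

dm = MNISTDataModule(batch_size=64)
dm.prepare_data()   # downloads MNIST into ./data on the first run
dm.setup()
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # expected: torch.Size([64, 1, 28, 28]) torch.Size([64])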
Step 3: Configure the Trainer and Train the Model
In train.py, bring everything together using the Trainer class.
from pytorch_lightning import Trainer
from model import LitCNN
from data_module import MNISTDataModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# Callbacks
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
early_stopping = EarlyStopping(monitor="val_loss", patience=3)
# Initialize modules
model = LitCNN()
data_module = MNISTDataModule(batch_size=64)
# Initialize trainer
trainer = Trainer(
max_epochs=10,
accelerator="gpu", # or "cpu"
devices=1,
callbacks=[checkpoint_callback, early_stopping],
log_every_n_steps=10
)
# Train
trainer.fit(model, datamodule=data_module)
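If you are iterating locally or don't have a GPU yet, a debug-friendly variant of the same script (an optional sketch, not part of the original example) lets Lightning pick the hardware and smoke-tests the pipeline on a single batch before a full run:

# Runs one training batch and one validation batch, then exits.
debug_trainer = Trainer(
    accelerator="auto",
    devices="auto",
    fast_dev_run=True,
)
debug_trainer.fit(model, datamodule=data_module)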
Step 4: Evaluate the Model
You can use the same trainer to evaluate the model:
trainer.validate(model, datamodule=data_module)
Or test on a separate dataset with trainer.test().
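Note that trainer.test() expects a test_step() in the LightningModule and a test_dataloader() in the datamodule, neither of which the example above defines. A sketch of the additions (for illustration it simply reuses the MNIST validation split as the test set):

# In LitCNN (model.py): mirror validation_step for the test set.
def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    self.log("test_acc", (logits.argmax(dim=1) == y).float().mean())

# In MNISTDataModule (data_module.py): expose a test DataLoader.
def test_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=self.batch_size)

# Then, in train.py:
# trainer.test(model, datamodule=data_module)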
Step 5: Save and Load the Model
Saving:
torch.save(model.state_dict(), "mnist_model.pt")
Loading:
model = LitCNN()
model.load_state_dict(torch.load("mnist_model.pt"))
To resume training:
trainer.fit(model, datamodule=data_module, ckpt_path="path/to/checkpoint.ckpt")
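The .ckpt files written by ModelCheckpoint bundle the model weights as well, so a trained LitCNN can also be restored directly from one (the checkpoint path below is a placeholder):

# Restore weights from a Lightning checkpoint (path is hypothetical).
model = LitCNN.load_from_checkpoint("path/to/checkpoint.ckpt")
model.eval()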
Benefits Illustrated by This Example
🚀 Rapid Prototyping
You can go from concept to training in under 100 lines of code.
🔁 Reusability
Your model and datamodule are modular and can be reused across different experiments.
⚡ Hardware Agnosticism
With just a change in the Trainer configuration, you can run on CPU, GPU, or TPU.
📊 Built-In Logging
Logging and monitoring are seamlessly integrated, with support for many logging frameworks.
Extending the Example: Tips for Production Use
- Use TensorBoardLogger or WandbLogger for advanced logging (see the sketch after this list).
- Implement testing logic in test_step() and call trainer.test().
- Use precision="16-mixed" (or the shorthand precision=16) for mixed precision, which saves memory and speeds up training.
- Scale training with strategy="ddp" and devices=4.
- Export the model to ONNX or TorchScript for deployment.
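A sketch combining several of these tips in one Trainer (values are illustrative; TensorBoardLogger needs the tensorboard package, strategy="ddp" with devices=4 assumes four GPUs, and the export calls are shown commented out with placeholder file names):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="mnist_cnn")

trainer = Trainer(
    max_epochs=10,
    logger=logger,
    precision="16-mixed",   # mixed precision training
    strategy="ddp",         # distributed data parallel
    devices=4,              # e.g. four GPUs
)

# After training, export for deployment (input sample shaped like one MNIST image):
# model.to_onnx("mnist_model.onnx", input_sample=torch.randn(1, 1, 28, 28))
# scripted = model.to_torchscript()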
Conclusion
This PyTorch Lightning Trainer example demonstrates how to build a deep learning pipeline with modular components, automatic hardware acceleration, and clean architecture. Whether you’re working on a research project, building a production system, or teaching ML, Lightning’s Trainer provides a robust, scalable, and flexible framework for training models.
Instead of reinventing the wheel with every project, Lightning allows you to focus on the parts that matter most—your data and your model.
FAQs
Q: Is PyTorch Lightning faster than raw PyTorch?
The Trainer itself adds very little overhead, and its native support for mixed precision and distributed training can significantly speed up training compared with a hand-written loop.
Q: Can I use my existing PyTorch model?
Yes, just wrap it in a LightningModule and define the required methods, as sketched below.
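A minimal sketch of such a wrapper (YourExistingNet is a placeholder for your own nn.Module):

import torch
from torch import nn
from pytorch_lightning import LightningModule

class WrappedModel(LightningModule):
    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net  # your existing, unmodified PyTorch model
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Usage (YourExistingNet is hypothetical):
# model = WrappedModel(YourExistingNet())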
Q: Can I deploy a model trained with Lightning?
Yes, models are standard PyTorch models and can be exported or served using TorchServe, ONNX, or other frameworks.
Q: Is this suitable for non-image tasks?
Absolutely. Lightning works with text, tabular, time series, and audio models as well.
Q: Can I use PyTorch Lightning in Jupyter notebooks?
Yes, Lightning is fully compatible with notebooks, Google Colab, and Kaggle kernels.