PyTorch Lightning Trainer Example: A Hands-On Guide

PyTorch Lightning has become one of the most popular frameworks for scaling PyTorch deep learning models while simplifying training code. At the heart of this framework lies the Trainer class, a powerful abstraction that automates everything from GPU/TPU acceleration to logging and checkpointing.

In this detailed guide, we’ll walk through a PyTorch Lightning Trainer example from scratch. You’ll learn how to structure your project using LightningModule, create clean data pipelines with LightningDataModule, and train your model using the Trainer. This hands-on approach will help you see the practical benefits of Lightning in action while adhering to best practices.


Why Use PyTorch Lightning Trainer?

Before jumping into code, let’s quickly recap why the Trainer class is so valuable:

  • Removes boilerplate: No need to manually write training loops.
  • Hardware agnostic: Easily switch from CPU to GPU to TPU.
  • Integrated logging and checkpointing: Supports TensorBoard, WandB, CSV logs, etc.
  • Built-in features: Includes gradient accumulation, mixed precision, and distributed training.

With just a few lines of configuration, the Trainer handles the rest.
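As a quick taste, a minimal setup might look like the sketch below, assuming model is any LightningModule and train_loader is a standard PyTorch DataLoader:

from pytorch_lightning import Trainer

# Lightning builds the training loop; you only describe the run.
trainer = Trainer(max_epochs=5, accelerator="auto", devices="auto")
trainer.fit(model, train_loader)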


PyTorch Lightning Trainer Example: Project Setup

Getting started with PyTorch Lightning means rethinking how you structure a deep learning project. Rather than relying on a single monolithic script with messy training loops, Lightning encourages a clean, modular design that separates data handling, model logic, and training orchestration. This makes your code more maintainable, easier to debug, and simpler to scale across hardware platforms.

For our example, we’ll walk through a complete setup for training a Convolutional Neural Network (CNN) on the MNIST dataset—a classic image classification task involving 28×28 grayscale images of handwritten digits. Even though MNIST is a small dataset, the structure we follow is the same used for large-scale production pipelines.

Here’s a recommended project layout:

my_lightning_project/
├── model.py               # Contains the LightningModule (model architecture and training logic)
├── data_module.py         # Contains the LightningDataModule (data loading and preprocessing)
├── train.py               # Main training script using the Trainer class
├── requirements.txt       # List of required Python packages

Let’s break down each of these components:

model.py: Define the Model and Training Steps

In this file, you’ll implement your deep learning model by subclassing LightningModule, the base class Lightning provides for models. It encapsulates not only the model architecture but also how the model trains, validates, and optimizes, promoting a clean separation of model logic from training orchestration.

Your LightningModule should include:

  • The model layers and forward() pass
  • A training_step() that runs a single batch through the model
  • A validation_step() for performance tracking
  • A configure_optimizers() method that returns an optimizer (and optionally a scheduler)

data_module.py: Standardize Data Loading

Lightning encourages the use of the LightningDataModule, which abstracts data loading into a clean interface. This class lets you define:

  • prepare_data() to download or preprocess the data once
  • setup() to initialize datasets
  • train_dataloader() and val_dataloader() for batch generation

Keeping data handling separate from model logic improves reproducibility and allows easy swapping of datasets for experimentation.

train.py: Orchestrate Training with the Trainer

Once your model and data modules are defined, train.py acts as the entry point. This script initializes:

  • The model and data modules
  • Callbacks like early stopping and model checkpointing
  • The Trainer, which manages the training and validation loops

The Trainer is where PyTorch Lightning truly shines. With a single Trainer() call, it supports features like the following (a configuration sketch appears right after this list):

  • Training on CPU, GPU, or TPU
  • Mixed precision (AMP) training
  • Logging to TensorBoard, WandB, or CSV
  • Gradient clipping, accumulation, and checkpointing
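Here is a rough sketch of how those features map onto Trainer arguments; the specific values (clipping threshold, accumulation factor, log directory) are illustrative assumptions, not recommendations:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

trainer = Trainer(
    accelerator="auto",               # CPU, GPU, or TPU, chosen automatically
    devices=1,
    precision="16-mixed",             # automatic mixed precision (AMP)
    logger=TensorBoardLogger("logs"), # or WandbLogger / CSVLogger
    gradient_clip_val=1.0,            # gradient clipping
    accumulate_grad_batches=4,        # gradient accumulation
)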

Why This Structure Is Scalable

This modular approach allows you to:

  • Focus on the model and data without worrying about engineering details
  • Easily test variations by swapping models, datasets, or hyperparameters
  • Integrate into production systems and CI/CD pipelines
  • Collaborate more effectively within teams

Moreover, this setup aligns with MLOps principles, making it easier to deploy models, maintain experiments, and track metrics.

Required Dependencies

To run this example, you’ll need to install the following packages:

pip install pytorch-lightning torch torchvision

GPU and AMP support ship with PyTorch itself; for optional extras (for example, the dependencies used by LightningCLI), you can install the extended Lightning package:

pip install lightning[extra]

Make sure your Python version is 3.8 or higher and that CUDA is configured if you plan to train on a GPU.
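A quick, optional sanity check of the environment:

import torch
import pytorch_lightning as pl

print(pl.__version__)             # installed Lightning version
print(torch.cuda.is_available())  # True if a CUDA GPU is visible to PyTorch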

Bonus Tip: Use requirements.txt

For reproducibility, include your dependencies in a requirements.txt file:

pytorch-lightning==2.1.1
torch==2.1.0
torchvision==0.16.0

Then install them with:

pip install -r requirements.txt


Step 1: Create the Model (LightningModule)

In model.py, define the neural network using a subclass of LightningModule.

import torch
from torch import nn
from pytorch_lightning import LightningModule

class LitCNN(LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

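Before wiring the model into the Trainer, a quick sanity check of the forward pass can catch shape mistakes early; the batch size of 4 below is arbitrary:

model = LitCNN()
dummy_batch = torch.randn(4, 1, 28, 28)  # MNIST-shaped dummy input
print(model(dummy_batch).shape)          # expected: torch.Size([4, 10])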

Step 2: Prepare the Data (LightningDataModule)

In data_module.py, define how data is loaded and preprocessed using LightningDataModule.

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # Download only
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        self.mnist_train = datasets.MNIST(self.data_dir, train=True, transform=transform)
        self.mnist_val = datasets.MNIST(self.data_dir, train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
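You can exercise the datamodule on its own before training; the shapes printed below assume the default batch size of 32:

dm = MNISTDataModule()
dm.prepare_data()
dm.setup()
images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])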


Step 3: Configure the Trainer and Train the Model

In train.py, combine everything together using the Trainer class.

from pytorch_lightning import Trainer
from model import LitCNN
from data_module import MNISTDataModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Callbacks
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
early_stopping = EarlyStopping(monitor="val_loss", patience=3)

# Initialize modules
model = LitCNN()
data_module = MNISTDataModule(batch_size=64)

# Initialize trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="gpu",  # or "cpu"
    devices=1,
    callbacks=[checkpoint_callback, early_stopping],
    log_every_n_steps=10
)

# Train
trainer.fit(model, datamodule=data_module)
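After fit() finishes, the ModelCheckpoint callback records where the best model (by val_loss) was saved, which is convenient for later evaluation:

print(checkpoint_callback.best_model_path)  # path to the best checkpoint on disk
best_model = LitCNN.load_from_checkpoint(checkpoint_callback.best_model_path)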


Step 4: Evaluate the Model

You can use the same trainer to evaluate the model:

trainer.validate(model, datamodule=data_module)

Or test on a separate dataset with trainer.test().
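trainer.test() expects a test_step() in the LightningModule and a test_dataloader() in the datamodule. A minimal sketch of the model-side method, mirroring validation_step() above, would be added inside LitCNN:

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)

With that in place, trainer.test(model, datamodule=data_module) runs the evaluation.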


Step 5: Save and Load the Model

Saving:

torch.save(model.state_dict(), "mnist_model.pt")

Loading:

model = LitCNN()
model.load_state_dict(torch.load("mnist_model.pt"))

To resume training:

trainer.fit(model, datamodule=data_module, ckpt_path="path/to/checkpoint.ckpt")
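Alternatively, Lightning’s own checkpoint format stores hyperparameters alongside the weights. A sketch of saving and restoring that way (the file name is arbitrary):

trainer.save_checkpoint("mnist_model.ckpt")
model = LitCNN.load_from_checkpoint("mnist_model.ckpt")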


Benefits Illustrated by This Example

🚀 Rapid Prototyping

You can go from concept to training in under 100 lines of code.

🔁 Reusability

Your model and datamodule are modular and can be reused across different experiments.

⚡ Hardware Agnosticism

With just a change to the Trainer arguments, you can run on CPU, GPU, or TPU.

📊 Built-In Logging

Logging and monitoring are seamlessly integrated, with support for many logging frameworks.


Extending the Example: Tips for Production Use

  • Use TensorBoardLogger or WandbLogger for advanced logging.
  • Implement testing logic in test_step() and call trainer.test().
  • Use precision="16-mixed" (precision=16 on older versions) for mixed-precision training, which saves memory and often speeds up training.
  • Scale training with strategy="ddp" and devices=4.
  • Export the model to ONNX or TorchScript for deployment (a combined sketch follows this list).
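A sketch that combines several of these tips; the log directory, experiment name, and device count are illustrative assumptions:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

trainer = Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=4,                 # assumes a 4-GPU machine
    strategy="ddp",            # distributed data parallel
    precision="16-mixed",      # mixed precision training
    logger=TensorBoardLogger("tb_logs", name="mnist_cnn"),
)
trainer.fit(model, datamodule=data_module)

# Export the trained model for deployment (ONNX shown here).
model.to_onnx("mnist_model.onnx", input_sample=torch.randn(1, 1, 28, 28))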

Conclusion

This PyTorch Lightning Trainer example demonstrates how to build a deep learning pipeline with modular components, automatic hardware acceleration, and clean architecture. Whether you’re working on a research project, building a production system, or teaching ML, Lightning’s Trainer provides a robust, scalable, and flexible framework for training models.

Instead of reinventing the wheel with every project, Lightning allows you to focus on the parts that matter most—your data and your model.


FAQs

Q: Is PyTorch Lightning faster than raw PyTorch?
Lightning itself adds little overhead on top of raw PyTorch; the practical speedups come from how easily it lets you enable mixed precision, distributed training, and other optimizations that can significantly shorten training time.

Q: Can I use my existing PyTorch model?
Yes, just wrap it in a LightningModule and define the required methods.
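For example, an existing network (a hypothetical my_net below) can be wrapped roughly like this:

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule

class LitWrapper(LightningModule):
    def __init__(self, net):
        super().__init__()
        self.net = net  # any existing torch.nn.Module

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = LitWrapper(my_net)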

Q: Can I deploy a model trained with Lightning?
Yes, models are standard PyTorch models and can be exported or served using TorchServe, ONNX, or other frameworks.

Q: Is this suitable for non-image tasks?
Absolutely. Lightning works with text, tabular, time series, and audio models as well.

Q: Can I use PyTorch Lightning in Jupyter notebooks?
Yes, Lightning is fully compatible with notebooks, Google Colab, and Kaggle kernels.
