Step-by-Step Guide to Creating a Transformer from Scratch in PyTorch

Building a Transformer model from scratch is one of the most rewarding experiences for any deep learning practitioner. The Transformer architecture, introduced in the groundbreaking paper “Attention Is All You Need,” revolutionized natural language processing and became the foundation for modern language models like GPT and BERT. In this comprehensive guide, we’ll walk through implementing every component of a Transformer model using PyTorch, giving you a deep understanding of how these powerful models work under the hood.

🏗️ What We’ll Build

A complete Transformer encoder-decoder architecture with multi-head attention, positional encoding, and feed-forward networks

Understanding the Transformer Architecture

The Transformer architecture consists of two main components: an encoder that processes the input sequence and a decoder that generates the output sequence. Each encoder and decoder layer contains several sub-components that work together to capture complex relationships in sequential data.

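The key innovation of Transformers is the self-attention mechanism, which allows the model to weigh the importance of different positions in the sequence when processing each element. Every position can attend to every other position directly, so a dependency between distant tokens is a single attention step away. An RNN would instead have to carry that information through a long chain of recurrent updates. This lets the model capture long-range dependencies more effectively than traditional RNN-based approaches.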
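To make the idea concrete, here is a toy sketch of self-attention on random vectors. As a simplifying assumption, it skips the learned query/key/value projections, so queries, keys and values are all the raw input `x`. The full MultiHeadAttention class later in the article adds those projections and multiple heads.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)                       # 4 toy tokens with 8 features each

# Simplification: queries, keys and values are all x (no learned projections)
scores = x @ x.T / math.sqrt(x.size(-1))    # (4, 4): similarity of token i with token j
weights = F.softmax(scores, dim=-1)         # row i: how much token i attends to each token
out = weights @ x                           # row i is a weighted average of all token vectors

print(weights.sum(dim=-1))                  # every row sums to 1
```

Each output row mixes information from every position in one step, no matter how far apart the positions are.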

Essential Components Overview

Before diving into implementation, let’s understand the core components we need to build:

  • Multi-Head Attention: The heart of the Transformer, allowing the model to focus on different parts of the sequence simultaneously
  • Positional Encoding: Since Transformers don’t have inherent sequence ordering, we add positional information
  • Feed-Forward Networks: Dense layers that process the attention outputs
  • Layer Normalization: Stabilizes training and improves convergence
  • Encoder and Decoder Layers: Stack these components to form the complete architecture

Building the Multi-Head Attention Mechanism

The multi-head attention mechanism is the core innovation of Transformers. It allows the model to jointly attend to information from different representation subspaces at different positions.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
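        # Apply mask if provided (mask == 0 marks positions that must not be attended to).
        # finfo(...).min stays representable in fp16, unlike -1e9, so this also works
        # under automatic mixed precision
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)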
        
        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply weights to values
        output = torch.matmul(attention_weights, V)
        return output, attention_weights
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections and reshape for multi-head attention
        Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Apply attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and apply final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        return self.w_o(attention_output)

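Multi-head attention projects the inputs and splits the d_model features into num_heads "heads" of size d_k = d_model // num_heads. Each head runs attention independently, so different heads can capture different relationships between sequence elements. Scaled dot-product attention computes similarity scores between queries and keys and divides them by √d_k, which keeps large dot products from pushing the softmax into regions with tiny gradients. It then uses the resulting weights to average the values.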
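The head splitting in `forward` is pure reshaping, which is easy to check in isolation. The sketch below uses random tensors and illustrative sizes as assumptions. It shows that each head owns a contiguous slice of the features and that the merge step exactly undoes the split.

```python
import torch

batch, seq_len, d_model, num_heads = 2, 5, 32, 4
d_k = d_model // num_heads                       # 8 features per head
x = torch.randn(batch, seq_len, d_model)         # stands in for a projected Q, K or V

# Split, as in MultiHeadAttention.forward: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch, -1, num_heads, d_k).transpose(1, 2)
print(heads.shape)                               # torch.Size([2, 4, 5, 8])

# Head h owns the contiguous feature slice [h * d_k, (h + 1) * d_k)
h = 1
print(torch.equal(heads[:, h], x[:, :, h * d_k:(h + 1) * d_k]))   # True

# Merge: transpose back, make memory contiguous, then flatten the heads again
merged = heads.transpose(1, 2).contiguous().view(batch, -1, d_model)
print(torch.equal(merged, x))                    # True
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` cannot flatten the head dimension of such a tensor.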

Implementing Positional Encoding

Since Transformers process all positions simultaneously, they need explicit positional information to understand sequence order. Positional encoding adds sinusoidal patterns to the input embeddings.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        
        # Calculate div_term for sinusoidal encoding
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
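        # Add a batch dimension -> (1, max_seq_length, d_model) and register it as a
        # buffer: it moves with the model (.to(device)) but is not a trainable parameter
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: (batch, seq_len, d_model), the batch-first layout produced by nn.Embedding
        # in the Transformer below, so the sequence length is dimension 1
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]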

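The sinusoidal positional encoding gives each pair of dimensions its own frequency, with wavelengths forming a geometric progression from 2π to 10000·2π. The original paper chose this form because, for any fixed offset k, PE(pos + k) is a linear function of PE(pos). That property should make it easy for the model to learn to attend by relative position. The encoding is added to the input embeddings before they enter the Transformer layers.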
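The linear-offset property can be checked numerically. Each (sin, cos) pair is rotated by an angle that depends only on the offset k, so a single block-diagonal rotation matrix maps every row of the table to the row k positions later. The sketch below rebuilds the table in float64 for precision. The helper names `sinusoidal_pe` and `offset_matrix` are hypothetical.

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """The same (max_len, d_model) table that PositionalEncoding builds, in float64."""
    position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

def offset_matrix(k, d_model):
    """Hypothetical helper: block-diagonal M_k with PE[pos + k] = M_k @ PE[pos]."""
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                      * (-math.log(10000.0) / d_model))
    m = torch.zeros(d_model, d_model, dtype=torch.float64)
    for i, w in enumerate(freqs.tolist()):
        c, s = math.cos(w * k), math.sin(w * k)
        j = 2 * i
        # Rotate the (sin, cos) pair at dims (j, j + 1) by the angle w * k
        m[j, j], m[j, j + 1] = c, s
        m[j + 1, j], m[j + 1, j + 1] = -s, c
    return m

pe = sinusoidal_pe(max_len=100, d_model=16)
m3 = offset_matrix(3, d_model=16)
# One matrix serves every position: row pos of pe[:-3] @ m3.T equals M_3 @ PE[pos]
print(torch.allclose(pe[:-3] @ m3.T, pe[3:], atol=1e-9))   # True
```

This check only confirms that such a linear map exists. Whether a trained model actually exploits it is a separate, empirical question.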

Creating the Feed-Forward Network

Each Transformer layer includes a position-wise feed-forward network that processes each position independently. This network consists of two linear transformations with a ReLU activation in between.

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

The feed-forward network typically expands the dimensionality by a factor of 4 (d_ff = 4 * d_model), applies ReLU activation, then projects back to the original dimension. This allows the model to learn complex non-linear transformations of the attention outputs.
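"Position-wise" means the same weights are applied to every position separately, because `nn.Linear` acts only on the last dimension. The sketch below mirrors the FeedForward layers with an `nn.Sequential`, leaving out dropout, which is a no-op here anyway. It checks that claim and counts the parameters for d_model = 512, d_ff = 2048.

```python
import torch
import torch.nn as nn

d_model, d_ff = 512, 4 * 512            # d_ff = 2048, the usual 4x expansion
# Same layer structure as FeedForward above, without dropout
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 10, d_model)          # (batch, seq_len, d_model)
with torch.no_grad():
    all_positions = ffn(x)               # nn.Linear acts on the last dim only
    position_3 = ffn(x[:, 3])            # the same network applied to one position alone

print(torch.allclose(all_positions[:, 3], position_3, atol=1e-5))   # True
print(sum(p.numel() for p in ffn.parameters()))                     # 2099712
```

That is roughly 2.1M parameters per layer, about twice the 4 × 512 × 512 weights (plus biases) of one MultiHeadAttention block.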

Building the Encoder Layer

The encoder layer combines multi-head attention with the feed-forward network, using residual connections and layer normalization for stable training.

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
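Both sublayers in EncoderLayer follow the same post-norm pattern as the original paper, LayerNorm(x + Dropout(Sublayer(x))). The sketch below factors that pattern into a small wrapper. `PostNormSublayer` is a hypothetical name, not part of the code above. The sketch also shows that each output position ends up with roughly zero mean and unit variance even when the input is badly scaled.

```python
import torch
import torch.nn as nn

class PostNormSublayer(nn.Module):
    """Hypothetical helper: LayerNorm(x + Dropout(sublayer(x)))."""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))

torch.manual_seed(0)
d_model = 64
wrapper = PostNormSublayer(d_model).eval()            # eval() turns dropout off
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
x = torch.randn(2, 7, d_model) * 5 + 3                 # deliberately badly scaled input

with torch.no_grad():
    y = wrapper(x, ffn)

# Per-position statistics over the feature dimension
mean = y.mean(dim=-1)
var = ((y - mean.unsqueeze(-1)) ** 2).mean(dim=-1)
print(y.shape)                                         # torch.Size([2, 7, 64])
print(mean.abs().max().item() < 1e-5, (var - 1).abs().max().item() < 1e-3)   # True True
```

At initialization LayerNorm's learnable scale is 1 and its bias is 0, which is why the statistics come out this clean. After training, those parameters can shift them.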

Building the Decoder Layer

The decoder layer is more complex than the encoder, featuring both self-attention and encoder-decoder attention mechanisms.

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        attn_output = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Encoder-decoder attention
        attn_output = self.encoder_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # Feed-forward network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x

The decoder layer includes masked self-attention to prevent the model from seeing future tokens during training, and encoder-decoder attention to incorporate information from the source sequence.
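In encoder-decoder attention, the queries come from the decoder while the keys and values come from the encoder output. The sketch below shows the resulting shapes, simplified to a single head without learned projections. The sizes are arbitrary assumptions.

```python
import math
import torch
import torch.nn.functional as F

batch, src_len, tgt_len, d_model = 2, 7, 4, 16
encoder_output = torch.randn(batch, src_len, d_model)   # keys and values come from the encoder
decoder_state = torch.randn(batch, tgt_len, d_model)    # queries come from the decoder

# Single head, no learned projections (a simplification of MultiHeadAttention)
scores = decoder_state @ encoder_output.transpose(-2, -1) / math.sqrt(d_model)
weights = F.softmax(scores, dim=-1)     # (batch, tgt_len, src_len)
context = weights @ encoder_output      # (batch, tgt_len, d_model)
print(weights.shape, context.shape)
```

The output length follows the queries. Each decoder position therefore gets one context vector summarizing the whole source sentence, whatever the source length is.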

💡 Key Implementation Insight

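The decoder uses causal masking in its self-attention to preserve the autoregressive property: decoder position i may attend only to positions up to and including i. Because the decoder input is the target sequence shifted right by one position, the prediction for position i can depend only on target tokens at positions less than i.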
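The sketch below pairs the causal mask with the right shift. The token ids are hypothetical, and it assumes the common teacher-forcing convention of calling `model(src, tgt[:, :-1])` and computing the loss against `tgt[:, 1:]`. The original code does not enforce this convention; it is an assumption about how the model is used.

```python
import torch

# Hypothetical token ids: 1 = <bos>, 2 = <eos>, 0 would be padding
tgt = torch.tensor([[1, 5, 6, 7, 2]])
decoder_input = tgt[:, :-1]    # fed to the decoder: [1, 5, 6, 7]
labels = tgt[:, 1:]            # compared with the output at each position: [5, 6, 7, 2]

# Same construction as nopeak_mask in generate_mask (1/True = may attend)
seq_length = decoder_input.size(1)
nopeak = (1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()
print(nopeak.int().tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
```

Row i lets decoder position i see input tokens 0 to i only, and the label at that position is token i + 1. The model therefore never sees the token it is asked to predict.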

Assembling the Complete Transformer

Now we can combine all components into the complete Transformer architecture:

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, 
                 num_layers, d_ff, max_seq_length, dropout=0.1):
        super(Transformer, self).__init__()
        
        # Embedding layers
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        
        # Encoder and decoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        
        # Final projection layer
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
    
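    def generate_mask(self, src, tgt):
        # Padding masks (token id 0 = padding); True = may attend
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)   # (batch, 1, 1, src_len)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)   # (batch, 1, tgt_len, 1)
        seq_length = tgt.size(1)
        # Causal mask, created on tgt's device so the model also runs on GPU
        nopeak_mask = (1 - torch.triu(
            torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1
        )).bool()
        tgt_mask = tgt_mask & nopeak_mask                 # (batch, 1, tgt_len, tgt_len)
        return src_mask, tgt_mask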
    
    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        
        # Encoder
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        encoder_output = src_embedded
        for layer in self.encoder_layers:
            encoder_output = layer(encoder_output, src_mask)
        
        # Decoder
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        decoder_output = tgt_embedded
        for layer in self.decoder_layers:
            decoder_output = layer(decoder_output, encoder_output, src_mask, tgt_mask)
        
        # Final projection
        output = self.fc(decoder_output)
        return output
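The masks from `generate_mask` rely on broadcasting. `src_mask` has shape (batch, 1, 1, src_len), so it applies to every head and every query row of the (batch, heads, tgt_len, src_len) score tensor. The sketch below uses random stand-in scores and a made-up batch (an assumption) to confirm that padded source positions receive zero attention weight.

```python
import torch
import torch.nn.functional as F

# Two source sentences; 0 is the padding id, as generate_mask assumes
src = torch.tensor([[7, 3, 9, 0, 0],
                    [4, 8, 2, 6, 1]])
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)   # (batch, 1, 1, src_len)

batch, heads, tgt_len, src_len = 2, 8, 4, 5
scores = torch.randn(batch, heads, tgt_len, src_len)   # stand-in for Q K^T / sqrt(d_k)
# The mask broadcasts over every head (dim 1) and every query row (dim 2)
weights = F.softmax(scores.masked_fill(src_mask == 0, -1e9), dim=-1)

print(src_mask.shape)                      # torch.Size([2, 1, 1, 5])
print(weights[0, :, :, 3:].max().item())   # 0.0: padded keys get no attention
```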

Training Considerations and Usage

To use your Transformer model effectively, you’ll need to handle several training aspects:

# Initialize the model
model = Transformer(
    src_vocab_size=10000,
    tgt_vocab_size=10000,
    d_model=512,
    num_heads=8,
    num_layers=6,
    d_ff=2048,
    max_seq_length=5000,
    dropout=0.1
)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

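# Learning rate scheduling (schedule from the original paper). Steps count from 1,
# since step ** -0.5 is undefined at step 0. If this is used as a LambdaLR
# multiplier, set the optimizer's base lr to 1.0 so the value is the actual rate
def get_lr_scale(step, d_model, warmup_steps=4000):
    step = max(step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)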

The model requires careful hyperparameter tuning, particularly for the learning rate schedule. The original paper used a warm-up schedule that increases the learning rate linearly for the first warmup_steps, then decreases it proportionally to the inverse square root of the step number.
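One way to wire this schedule into training is `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by the function's value. The sketch below assumes a base lr of 1.0, so the function's value becomes the learning rate directly. It also uses a stand-in `nn.Linear` model and a hypothetical `noam_lr` helper equivalent to `get_lr_scale`.

```python
import torch
import torch.nn as nn

def noam_lr(step, d_model=512, warmup_steps=4000):
    step = max(step, 1)   # LambdaLR evaluates step 0 once at construction
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

model = nn.Linear(8, 8)   # stand-in for the Transformer
# Base lr of 1.0, so the multiplier returned by noam_lr is the learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lr)

lrs = []
for _ in range(8000):
    optimizer.step()      # a real loop would compute the loss and call backward() first
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs.index(max(lrs)))    # 3999, i.e. the 4000th step, where warm-up ends
print(f"{max(lrs):.2e}")      # 6.99e-04
```

With d_model = 512 and 4000 warm-up steps, the rate peaks at 1/√(512 · 4000) ≈ 7e-4 and then decays as 1/√step.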

Optimizing Performance and Memory Usage

When implementing Transformers from scratch, consider these optimization strategies:

  • Gradient Accumulation: When the batch size you want does not fit in memory, accumulate gradients over several smaller batches before each parameter update
  • Mixed Precision Training: Use automatic mixed precision (AMP) to reduce memory usage and accelerate training
  • Attention Optimization: Implement efficient attention mechanisms like Flash Attention for longer sequences
  • Layer Dropout: Apply dropout to entire layers during training to improve generalization
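For the attention-optimization point, PyTorch 2.0 and later provide `torch.nn.functional.scaled_dot_product_attention`. It computes the same thing as the hand-written attention and can dispatch to fused kernels, including FlashAttention on supported GPUs, depending on dtype, hardware and version. The sketch below checks that the two agree on random inputs with a causal mask. `manual_attention` is a hypothetical re-statement of the article's method and uses -inf as the fill value.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    # Same steps as MultiHeadAttention.scaled_dot_product_attention above
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 6, 8
q, k, v = (torch.randn(batch, heads, seq_len, d_k) for _ in range(3))
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # True = may attend

out_manual = manual_attention(q, k, v, causal)
# A boolean attn_mask uses the same "True = attend" convention as the article's masks
out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_manual, out_fused, atol=1e-5))    # True
print(torch.allclose(out_manual, out_causal, atol=1e-5))   # True
```

Swapping this call into `scaled_dot_product_attention` would drop the returned attention weights, which the article's `forward` discards anyway.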

The attention mechanism has quadratic complexity with respect to sequence length, so be mindful of memory constraints when processing long sequences. Consider techniques like attention windowing or sparse attention patterns for very long sequences.
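A back-of-the-envelope sketch of that quadratic growth follows. The score tensor in one attention layer holds batch × num_heads × seq_len × seq_len values. The hypothetical helper below counts only that one tensor in fp32, for batch size 1 and 8 heads. Training also keeps the softmax output and other activations for the backward pass, so real memory use is higher.

```python
def attention_scores_bytes(batch, num_heads, seq_len, bytes_per_elem=4):
    """Bytes for one layer's (batch, num_heads, seq_len, seq_len) score tensor."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem

for seq_len in (512, 1024, 2048, 4096):
    mib = attention_scores_bytes(batch=1, num_heads=8, seq_len=seq_len) / 2**20
    print(f"seq_len={seq_len}: {mib:.0f} MiB")
# seq_len=512: 8 MiB
# seq_len=1024: 32 MiB
# seq_len=2048: 128 MiB
# seq_len=4096: 512 MiB
```

Each doubling of the sequence length quadruples this tensor, per layer and per example in the batch.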

Conclusion

Building a Transformer from scratch provides invaluable insights into the mechanics of modern deep learning architectures. Through this implementation, you’ve gained hands-on experience with attention mechanisms, positional encoding, and the intricate interplay between encoder and decoder components. The modular design of our implementation makes it easy to experiment with different configurations and understand how each component contributes to the model’s overall performance.

This foundation will serve you well as you explore more advanced architectures like GPT, BERT, and other Transformer variants. The principles you’ve learned here—attention, residual connections, layer normalization, and careful architectural design—are fundamental to understanding and building state-of-the-art language models and other sequence-to-sequence applications.
