Building a Transformer model from scratch is one of the most rewarding experiences for any deep learning practitioner. The Transformer architecture, introduced in the groundbreaking paper “Attention Is All You Need,” revolutionized natural language processing and became the foundation for modern language models like GPT and BERT. In this comprehensive guide, we’ll walk through implementing every component of a Transformer model using PyTorch, giving you a deep understanding of how these powerful models work under the hood.
🏗️ What We’ll Build
A complete Transformer encoder-decoder architecture with multi-head attention, positional encoding, and feed-forward networks
Understanding the Transformer Architecture
The Transformer architecture consists of two main components: an encoder that processes the input sequence and a decoder that generates the output sequence. Each encoder and decoder layer contains several sub-components that work together to capture complex relationships in sequential data.
The key innovation of Transformers is the self-attention mechanism, which allows the model to weigh the importance of different positions in the sequence when processing each element. This enables the model to capture long-range dependencies more effectively than traditional RNN-based approaches.
Essential Components Overview
Before diving into implementation, let’s understand the core components we need to build:
- Multi-Head Attention: The heart of the Transformer, allowing the model to focus on different parts of the sequence simultaneously
- Positional Encoding: Since Transformers don’t have inherent sequence ordering, we add positional information
- Feed-Forward Networks: Dense layers that process the attention outputs
- Layer Normalization: Stabilizes training and improves convergence
- Encoder and Decoder Layers: Stacks of the above components that combine to form the complete architecture
Building the Multi-Head Attention Mechanism
The multi-head attention mechanism is the core innovation of Transformers. It allows the model to jointly attend to information from different representation subspaces at different positions.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Apply softmax
        attention_weights = F.softmax(scores, dim=-1)
        # Apply weights to values
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections and reshape for multi-head attention
        Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads and apply final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        return self.w_o(attention_output)
The multi-head attention splits the input into multiple “heads,” each focusing on different aspects of the relationships between sequence elements. The scaled dot-product attention computes similarity scores between queries and keys, then uses these scores to weight the values.
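As a quick sanity check before moving on, the short snippet below runs a random batch through the attention module in a self-attention configuration (query, key, and value are the same tensor). The batch size and sequence length are arbitrary illustrative values.

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # (batch_size, seq_length, d_model) -- illustrative shapes
output = mha(x, x, x)         # self-attention: the sequence attends to itself
print(output.shape)           # torch.Size([2, 10, 512]) -- shape is preserved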
Implementing Positional Encoding
Since Transformers process all positions simultaneously, they need explicit positional information to understand sequence order. Positional encoding adds sinusoidal patterns to the input embeddings.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # Calculate div_term for sinusoidal encoding
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add batch dimension (1, max_seq_length, d_model) and register as buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x has shape (batch_size, seq_length, d_model)
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]
The sinusoidal positional encoding uses different frequencies for different dimensions, allowing the model to learn to attend to relative positions. This encoding is added to the input embeddings before they enter the Transformer layers.
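To make the pattern concrete, here is a small illustrative check (the dimensions are arbitrary): the encoding leaves the tensor shape unchanged, and position 0 alternates sin(0) = 0 and cos(0) = 1 across the feature dimension.

pos_enc = PositionalEncoding(d_model=512)
embeddings = torch.randn(2, 10, 512)    # (batch_size, seq_length, d_model)
print(pos_enc(embeddings).shape)        # torch.Size([2, 10, 512])
print(pos_enc.pe[0, 0, :4])             # tensor([0., 1., 0., 1.]) -- encoding at position 0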
Creating the Feed-Forward Network
Each Transformer layer includes a position-wise feed-forward network that processes each position independently. This network consists of two linear transformations with a ReLU activation in between.
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Expand to d_ff, apply ReLU and dropout, then project back to d_model
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
The feed-forward network typically expands the dimensionality by a factor of 4 (d_ff = 4 * d_model), applies ReLU activation, then projects back to the original dimension. This allows the model to learn complex non-linear transformations of the attention outputs.
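A quick illustrative check, using the dimensions from the original paper, confirms that the block preserves the model dimension:

ff = FeedForward(d_model=512, d_ff=2048)   # d_ff = 4 * d_model
x = torch.randn(2, 10, 512)
print(ff(x).shape)                         # torch.Size([2, 10, 512])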
Building the Encoder Layer
The encoder layer combines multi-head attention with the feed-forward network, using residual connections and layer normalization for stable training.
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
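The sketch below pushes a toy batch through a single encoder layer together with a padding mask; the token ids and shapes are placeholders chosen only for illustration.

encoder_layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048)
tokens = torch.randint(1, 100, (2, 10))                  # toy token ids (0 = padding)
padding_mask = (tokens != 0).unsqueeze(1).unsqueeze(2)   # (batch, 1, 1, seq_length)
x = torch.randn(2, 10, 512)                              # stand-in for embedded tokens
print(encoder_layer(x, padding_mask).shape)              # torch.Size([2, 10, 512])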
Building the Decoder Layer
The decoder layer is more complex than the encoder, featuring both self-attention and encoder-decoder attention mechanisms.
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        attn_output = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Encoder-decoder attention
        attn_output = self.encoder_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        # Feed-forward network
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
The decoder layer includes masked self-attention to prevent the model from seeing future tokens during training, and encoder-decoder attention to incorporate information from the source sequence.
💡 Key Implementation Insight
The decoder uses causal masking during self-attention to maintain the autoregressive property, ensuring that predictions for position i can only depend on positions less than i.
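To see what this mask looks like, the snippet below builds a small no-peek mask with torch.triu, the same construction used in generate_mask later in this guide. A value of 1 means the position may be attended to.

seq_length = 4
nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
print(nopeak_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)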
Assembling the Complete Transformer
Now we can combine all components into the complete Transformer architecture:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads,
                 num_layers, d_ff, max_seq_length, dropout=0.1):
        super(Transformer, self).__init__()
        # Embedding layers
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        # Encoder and decoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # Final projection layer
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        # Padding masks (token id 0 is treated as padding)
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Causal (no-peek) mask so the decoder cannot attend to future positions
        nopeak_mask = (1 - torch.triu(
            torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        # Encoder
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        encoder_output = src_embedded
        for layer in self.encoder_layers:
            encoder_output = layer(encoder_output, src_mask)
        # Decoder
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
        decoder_output = tgt_embedded
        for layer in self.decoder_layers:
            decoder_output = layer(decoder_output, encoder_output, src_mask, tgt_mask)
        # Final projection to target vocabulary logits
        output = self.fc(decoder_output)
        return output
Training Considerations and Usage
To use your Transformer model effectively, you’ll need to handle several training aspects:
# Initialize the model
model = Transformer(
    src_vocab_size=10000,
    tgt_vocab_size=10000,
    d_model=512,
    num_heads=8,
    num_layers=6,
    d_ff=2048,
    max_seq_length=5000,
    dropout=0.1
)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# Learning rate scheduling (warm-up schedule from the original paper)
def get_lr_scale(step, d_model, warmup_steps=4000):
    step = max(step, 1)  # avoid division by zero on the very first step
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
The model requires careful hyperparameter tuning, particularly for the learning rate schedule. The original paper used a warm-up schedule that increases the learning rate linearly for the first warmup_steps, then decreases it proportionally to the inverse square root of the step number.
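One way to wire this schedule into training is through torch.optim.lr_scheduler.LambdaLR. The sketch below assumes the optimizer's base learning rate is set to 1.0 so that the returned scale is the effective learning rate, and it passes step + 1 because LambdaLR starts counting at zero.

optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: get_lr_scale(step + 1, d_model=512)
)
# In the training loop, call optimizer.step() followed by scheduler.step() once per batch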
Optimizing Performance and Memory Usage
When implementing Transformers from scratch, consider these optimization strategies:
- Gradient Accumulation: For large models, accumulate gradients across multiple batches before updating parameters
- Mixed Precision Training: Use automatic mixed precision (AMP) to reduce memory usage and accelerate training
- Attention Optimization: Implement efficient attention mechanisms like Flash Attention for longer sequences
- Layer Dropout: Apply dropout to entire layers during training to improve generalization
The attention mechanism has quadratic complexity with respect to sequence length, so be mindful of memory constraints when processing long sequences. Consider techniques like attention windowing or sparse attention patterns for very long sequences.
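As an illustration of the first two strategies, here is a minimal mixed-precision training step with gradient accumulation. It is a sketch under assumptions: train_loader is a placeholder DataLoader yielding (src, tgt) batches, and the targets are shifted by one position for teacher forcing.

scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 4

for step, (src, tgt) in enumerate(train_loader):     # train_loader is assumed to exist
    with torch.cuda.amp.autocast():
        output = model(src, tgt[:, :-1])             # teacher forcing: predict the next token
        loss = criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss = loss / accumulation_steps             # scale loss for accumulation
    scaler.scale(loss).backward()
    if (step + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()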
Conclusion
Building a Transformer from scratch provides invaluable insights into the mechanics of modern deep learning architectures. Through this implementation, you’ve gained hands-on experience with attention mechanisms, positional encoding, and the intricate interplay between encoder and decoder components. The modular design of our implementation makes it easy to experiment with different configurations and understand how each component contributes to the model’s overall performance.
This foundation will serve you well as you explore more advanced architectures like GPT, BERT, and other Transformer variants. The principles you’ve learned here—attention, residual connections, layer normalization, and careful architectural design—are fundamental to understanding and building state-of-the-art language models and other sequence-to-sequence applications.