Weight initialization is one of those topics that practitioners often treat as boilerplate — copy the default, move on — but poor initialization causes real problems: slow convergence, vanishing or exploding gradients, and training instabilities that are indistinguishable from architecture bugs. Understanding why Xavier and Kaiming initialization work, when to use each, and what modern LLM training does differently gives you a reliable mental model for diagnosing initialization-related issues and making informed choices when default settings do not apply.
The Problem: Variance Explosions and Collapses
When a layer has fan_in input connections and weights initialised from a zero-mean distribution with variance σ², the output variance is approximately fan_in × σ² times the input variance (assuming zero-mean weights and independent inputs). In a deep network, if this factor is greater than 1, activation variance grows exponentially with depth — exploding activations and gradients. If it is less than 1, variance shrinks exponentially — vanishing gradients. The goal of a principled initialization scheme is to choose σ² such that activations maintain roughly constant variance as they propagate forward through the network, and gradients maintain roughly constant variance as they propagate backward.
import torch
import torch.nn as nn

def measure_activation_variance(init_fn, n_layers=20, width=512, batch=64):
    """Measure how activation variance evolves through a deep network."""
    variances = []
    x = torch.randn(batch, width)
    for i in range(n_layers):
        layer = nn.Linear(width, width, bias=False)
        init_fn(layer.weight)
        with torch.no_grad():
            x = torch.relu(layer(x))
        variances.append(x.var().item())
    return variances

# Random normal with fixed std — variance explodes or vanishes
std_too_large = lambda w: nn.init.normal_(w, std=0.1)
std_default = lambda w: nn.init.normal_(w, std=1.0 / (512 ** 0.5))  # simple heuristic
kaiming = lambda w: nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

vars_large = measure_activation_variance(std_too_large)
vars_default = measure_activation_variance(std_default)
vars_kaiming = measure_activation_variance(kaiming)

print(f"Large std final variance: {vars_large[-1]:.4e}")
print(f"Simple heuristic variance: {vars_default[-1]:.4e}")
print(f"Kaiming final variance: {vars_kaiming[-1]:.4f}")  # stays roughly constant with depth
Xavier (Glorot) Initialization
Xavier initialization, introduced by Glorot and Bengio in 2010, derives the correct initialization variance for layers with symmetric activation functions (tanh, sigmoid, linear). The derivation requires that the variance of activations is preserved in the forward pass AND that the variance of gradients is preserved in the backward pass. These two constraints cannot both be satisfied simultaneously unless fan_in equals fan_out, so Xavier uses the harmonic mean: σ² = 2 / (fan_in + fan_out). This produces a uniform or normal distribution with that variance, and it is the correct default for linear layers, embedding projections, and any layer followed by a tanh or linear activation.
import torch
import torch.nn as nn
import math

def xavier_uniform_(tensor: torch.Tensor) -> torch.Tensor:
    """Xavier uniform initialization — equivalent to nn.init.xavier_uniform_."""
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    # Uniform distribution U(-a, a) where a = sqrt(6 / (fan_in + fan_out))
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return tensor.uniform_(-a, a)

def xavier_normal_(tensor: torch.Tensor) -> torch.Tensor:
    """Xavier normal initialization — equivalent to nn.init.xavier_normal_."""
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    std = math.sqrt(2.0 / (fan_in + fan_out))
    return tensor.normal_(0, std)

# Usage: initialise a transformer projection layer
linear = nn.Linear(512, 512)
xavier_normal_(linear.weight)
nn.init.zeros_(linear.bias)  # biases initialised to zero

# PyTorch default for nn.Linear is Kaiming uniform — often wrong for transformers;
# always initialise explicitly rather than relying on defaults
print(f"Weight std after xavier_normal_: {linear.weight.std():.4f}")
print(f"Expected: {math.sqrt(2.0 / (512 + 512)):.4f}")
Kaiming (He) Initialization
Xavier initialization assumes that the activation function preserves variance symmetrically. ReLU violates this assumption — it zeros out half of all activations on average, which halves the output variance compared to a linear activation. Kaiming He’s 2015 paper derives the correct correction: use σ² = 2 / fan_in for ReLU (doubling the variance to account for the zeroing), or σ² = 2 / (fan_in × (1 + slope²)) for leaky ReLU. Kaiming initialization is the correct default for any network using ReLU or its variants, which covers most modern CNNs and residual networks.
import torch
import torch.nn as nn
import math

def kaiming_normal_(tensor: torch.Tensor, nonlinearity: str = 'relu') -> torch.Tensor:
    """Kaiming normal — equivalent to nn.init.kaiming_normal_."""
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tensor)
    # gain corrects for different nonlinearities
    gain = nn.init.calculate_gain(nonlinearity)
    std = gain / math.sqrt(fan_in)
    return tensor.normal_(0, std)

# gains for common activations:
#   relu:       sqrt(2) ≈ 1.414
#   leaky_relu: sqrt(2 / (1 + slope^2))
#   tanh:       5/3 ≈ 1.667
#   sigmoid:    1.0
#   linear:     1.0
#   selu:       3/4 = 0.75 (approximation)

class ResNetBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self._init_weights()

    def _init_weights(self):
        # Kaiming normal for the ReLU-facing convolutions
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, nonlinearity='relu')
        # Zero-init the scale (gamma) of the last BatchNorm in each residual
        # block, so the block starts as the identity — improves early training
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        return x + self.bn2(self.conv2(torch.relu(self.bn1(self.conv1(x)))))
Initialization for Transformers and LLMs
Transformer initialization has evolved significantly from the original “Attention Is All You Need” paper. The standard modern approach for LLMs uses a small normal distribution (std ≈ 0.02) for all weight matrices, with special treatment for residual projection layers. GPT-2 introduced scaling the residual projection weights by 1/√(2N) where N is the number of layers, so that the contribution of each residual block to the residual stream stays approximately constant regardless of depth. Llama and most subsequent open-weight models follow the same pattern: normal initialization with std=0.02 for all weights, with the output projections of attention and MLP blocks scaled down to prevent the residual stream from growing too large in early training.
import torch
import torch.nn as nn
import math

class LLMLayer(nn.Module):
    """Illustrates GPT-2 style weight initialization for transformer layers."""
    def __init__(self, d_model: int, n_layers: int):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)   # residual output
        self.gate = nn.Linear(d_model, d_model * 4, bias=False)
        self.up = nn.Linear(d_model, d_model * 4, bias=False)
        self.down = nn.Linear(d_model * 4, d_model, bias=False)  # residual output
        self._init_weights()

    def _init_weights(self):
        std = 0.02
        # Standard init for non-residual projections
        for proj in [self.q_proj, self.k_proj, self.v_proj, self.gate, self.up]:
            nn.init.normal_(proj.weight, mean=0.0, std=std)
        # Scaled-down init for residual projections (GPT-2 pattern)
        residual_std = std / math.sqrt(2 * self.n_layers)
        for proj in [self.o_proj, self.down]:
            nn.init.normal_(proj.weight, mean=0.0, std=residual_std)
The residual scaling matters most at large depth. For a 32-layer transformer, the residual projection std is reduced by a factor of 8 (√64 = 8) compared to the base std of 0.02 — so o_proj and down_proj are initialised with std ≈ 0.0025. Without this scaling, the variance of the residual stream grows linearly with depth in early training before the model has learned to regulate it, causing gradient instabilities that manifest as loss spikes in the first few hundred steps.
Embedding Initialization
Token and positional embeddings are typically initialised from a normal distribution with std=0.02 or std=1/√(d_model), matching the weight matrix initialization. The choice matters more for smaller vocabularies and lower-dimensional embeddings — at d_model=4096 and vocab size=32K, both schemes are effectively equivalent. Tying the input embedding and output (lm_head) weights, as done in many LLMs, implicitly constrains the initialization and eliminates the need to think about them separately.
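A quick sketch of the two embedding schemes (the vocabulary size here is deliberately small just to keep the example light; only d_model matters for the comparison):

```python
import math
import torch
import torch.nn as nn

d_model = 4096
emb = nn.Embedding(1024, d_model)  # small vocab, illustrative only

# Scheme 1: fixed std, the GPT-2 / Llama convention
nn.init.normal_(emb.weight, mean=0.0, std=0.02)

# Scheme 2: width-dependent std; nearly identical at this scale
nn.init.normal_(emb.weight, mean=0.0, std=1.0 / math.sqrt(d_model))

print(f"fixed: 0.0200, width-dependent: {1.0 / math.sqrt(d_model):.4f}")
```

At d_model=4096 the width-dependent scheme gives std ≈ 0.0156, close enough to 0.02 that the choice is largely cosmetic.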
Practical Initialization Checklist
- Using ReLU or leaky ReLU activations in a CNN? Use Kaiming normal with the appropriate nonlinearity gain; PyTorch's nn.init.kaiming_normal_ handles this correctly out of the box.
- Using tanh, sigmoid, or linear activations? Use Xavier normal or uniform.
- Building a transformer from scratch? Use normal(0, 0.02) for all weight matrices and scale the residual output projections down by 1/√(2N).
- Always zero-init biases.
- For residual networks, zero-init the last layer of each residual block (the final BatchNorm gamma when one follows the second conv, or the down-projection in an MLP block) — this initialises each block as the identity transformation and lets the network learn residuals incrementally, which substantially improves early training stability.
- Do not rely on PyTorch's default initialization for custom architectures: nn.Linear defaults to Kaiming uniform with fan_in mode, which is correct for ReLU but wrong for transformers, and nn.Embedding defaults to normal(0, 1), which is too large for most LLM embedding dimensions. Always call your own _init_weights method explicitly.
Why PyTorch’s Default Initialization Is Often Wrong
PyTorch’s built-in layers use Kaiming uniform with fan_in mode as the default for nn.Linear and nn.Conv2d. This is the correct choice for networks using ReLU activations with no special structure, but it is wrong in at least three common situations. First, transformer attention and MLP projection layers typically use GELU or SiLU activations, not ReLU — Kaiming was derived specifically for ReLU’s variance halving property, and using it for GELU introduces a mild but unnecessary variance mismatch. In practice this rarely causes catastrophic failure, but the correct choice for GELU is closer to Xavier normal. Second, residual output projections in transformers should be scaled down as described above, and the default initialization does not do this. Third, nn.Embedding defaults to normal(0, 1), which has unit variance — far too large for a 4096-dimensional embedding layer in a modern LLM, where std=0.02 is standard. The unit-variance default means the initial embedding magnitude is roughly 64 (√4096), which immediately creates imbalanced contributions when the embedding is added to positional encodings or passed through the first transformer layer.
The practical implication is that for any non-trivial architecture — anything beyond a simple ReLU classifier — you should implement an explicit _init_weights method rather than relying on PyTorch defaults. The pattern used in transformers is to apply the initialization in a self.apply(self._init_weights) call in __init__, which recursively visits every module and applies the correct initialization based on module type. This is the approach used in the GPT-2 reference implementation, Hugging Face Transformers, and essentially every serious LLM training codebase.
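A minimal sketch of the self.apply pattern (the model and its sizes are hypothetical; the dispatch-on-module-type structure is the point):

```python
import math
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical minimal model illustrating the self.apply init pattern."""
    def __init__(self, vocab_size: int = 1000, d_model: int = 64, n_layers: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)  # residual output
        # Recursively visit every submodule and override PyTorch defaults
        self.apply(self._init_weights)
        # Residual projections get the scaled-down std after the generic pass
        nn.init.normal_(self.o_proj.weight, mean=0.0,
                        std=0.02 / math.sqrt(2 * n_layers))

    def _init_weights(self, module):
        # Dispatch on module type, as in the GPT-2 reference implementation
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

model = TinyModel()
```

Applying the scaled residual initialization after the self.apply call is a simple way to handle the exception without complicating the type-based dispatch.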
Orthogonal Initialization and When to Use It
Orthogonal initialization initialises weight matrices as random orthogonal matrices — matrices where all columns (or rows) are orthonormal. The key property is that orthogonal matrices preserve the Euclidean norm of vectors they are multiplied with, so they provably prevent variance explosion or collapse at initialization regardless of network depth. Orthogonal initialization consistently outperforms Gaussian initialization for very deep networks (50+ layers) without skip connections, where Kaiming and Xavier can still produce suboptimal variance propagation due to accumulated multiplicative errors. For networks with residual connections (ResNets, transformers), the skip connections already stabilise variance propagation, making orthogonal initialization less necessary — Kaiming or the LLM-style normal initialization is sufficient. Orthogonal initialization is available in PyTorch as nn.init.orthogonal_ and can be applied to any 2D weight tensor; for Conv2d layers, PyTorch reshapes the weight to 2D before orthogonalising. The main cost is computational — computing an orthogonal matrix via QR decomposition is more expensive than sampling from a normal distribution, which matters for large embedding matrices or wide MLP layers.
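The norm-preservation property is easy to verify directly; a small sketch:

```python
import torch
import torch.nn as nn

w = torch.empty(512, 512)
nn.init.orthogonal_(w)  # QR-based orthogonal init

# An orthogonal matrix preserves the Euclidean norm of any vector it multiplies
x = torch.randn(512)
print(x.norm().item(), (w @ x).norm().item())  # the two norms match
```

Because ||Wx||² = xᵀWᵀWx = ||x||² when WᵀW = I, stacking such layers cannot blow up or collapse activation norms at initialization (before any nonlinearity is applied).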
Initializing Custom Architectures: A Worked Example
The most reliable approach for a new architecture is to measure activation variance at initialization and confirm it stays close to 1.0 throughout the forward pass before training begins. This takes ten lines of code and immediately reveals initialization bugs that would otherwise appear as slow convergence or NaN losses hundreds of training steps later. Instrument your model’s forward pass to print the mean and variance of activations after each major sublayer, run a single forward pass with a random batch, and verify that the variance stays in the range 0.5–2.0 throughout. Any layer where variance jumps by more than 5x in either direction is a candidate for corrected initialization. This diagnostic is especially important for novel architectures that mix attention, convolution, and MLP components, where the interaction between normalization layers and initialization choices is harder to reason about analytically.
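The diagnostic above can be sketched with forward hooks; `report_activation_stats` is a hypothetical helper, not a library function:

```python
import torch
import torch.nn as nn

def report_activation_stats(model: nn.Module, x: torch.Tensor):
    """Print mean/variance after each Linear/Conv2d layer and flag any layer
    where variance jumps by more than 5x relative to the previous one."""
    stats = []
    def hook(module, inputs, output):
        stats.append((module.__class__.__name__,
                      output.mean().item(), output.var().item()))
    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    prev = x.var().item()
    for name, mean, var in stats:
        flag = "  <-- variance jump" if var > 5 * prev or var < prev / 5 else ""
        print(f"{name}: mean={mean:+.3f} var={var:.3f}{flag}")
        prev = var
    return stats

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
stats = report_activation_stats(model, torch.randn(32, 256))
```

Hooks keep the diagnostic out of the model code itself, so the same helper works unchanged on any architecture.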
Initialization and Learning Rate Interactions
Weight initialization and learning rate are deeply coupled: the optimal learning rate depends on the scale of the initial weights, and the stability of training depends on both simultaneously. This is why learning rate warmup — starting with a very small learning rate and increasing it over the first few hundred or thousand steps — is standard for large transformer training. At the start of training, the model is far from any useful solution and gradients are large relative to the weights; a high learning rate at this stage causes weight updates to overshoot stable regions. Warmup gives the optimizer time to accumulate gradient history and establish reliable update directions before taking large steps. Models initialised with smaller weights (as with the scaled residual initializations in LLMs) tend to need less warmup because the initial gradient magnitudes are smaller. If you find that a model trains stably with one initialization but shows loss spikes early in training with another, check whether your warmup schedule is calibrated to the weight scale — a warmup that is too short for larger initial weights is a common culprit. As a rule of thumb, increase warmup steps proportionally when you increase initial weight scale, and decrease them when you use a smaller initialization.
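A linear warmup schedule can be sketched with PyTorch's LambdaLR; the base learning rate and step count below are placeholders to be calibrated to the weight scale as discussed:

```python
import torch

model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

warmup_steps = 1000  # placeholder; longer for larger initial weight scales
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda step: min(1.0, (step + 1) / warmup_steps))

lrs = []
for _ in range(5):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
# the learning rate ramps linearly toward 3e-4 over the first 1000 steps
```

In a real run the warmup lambda is usually combined with a decay phase (cosine or linear) after warmup_steps; the multiplicative-factor formulation above composes cleanly with either.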
The interaction also extends to weight decay. Weight decay regularises training by pulling weights toward zero, which competes with the initialization scale — aggressively regularised models effectively end up with smaller effective weights regardless of how they were initialised. Standard values (1e-2 for AdamW in most LLM recipes) are calibrated for the typical 0.02 std initialization; if you use a non-standard initialization scale, verify that weight decay is not overwhelming your initialised scale in the first few thousand steps by monitoring the L2 norm of weight matrices at regular intervals during training.
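The norm-monitoring check takes only a few lines; `weight_matrix_norms` below is a hypothetical helper:

```python
import torch
import torch.nn as nn

def weight_matrix_norms(model: nn.Module) -> dict:
    """L2 norm of every weight matrix: a cheap decay-vs-initialization check."""
    return {name: p.norm().item()
            for name, p in model.named_parameters() if p.dim() >= 2}

layer = nn.Linear(128, 128)
nn.init.normal_(layer.weight, std=0.02)
# expected norm at init: 0.02 * sqrt(128 * 128) = 2.56
print(weight_matrix_norms(layer))
```

Logging these norms every few hundred steps makes it obvious when weight decay is shrinking matrices far below their initialised scale.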
Muon and Recent Advances in Initialization-Aware Optimizers
Recent work on optimizers has renewed interest in the relationship between initialization and optimization dynamics. Muon (MomentUm Orthogonalized by Newton-Schulz), developed in 2024 and used in several frontier model training runs, applies Nesterov momentum and then orthogonalizes the resulting weight-update matrix using Newton-Schulz iterations. The orthogonalization step ensures that each update moves the weight matrix in a well-conditioned direction — effectively correcting for the anisotropic gradient scaling that standard SGD and Adam apply. The practical implication for initialization is that Muon is less sensitive to initial weight scale than Adam, because the orthogonalization step normalises the update regardless of gradient magnitude. If you are experimenting with Muon or similar spectral optimizers, your existing Kaiming or LLM-style initializations transfer well and do not need to be redesigned for the new optimizer.
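To make the orthogonalization step concrete, here is a simplified cubic Newton-Schulz iteration; Muon itself uses a tuned quintic polynomial (and runs it in bfloat16), so treat this as an illustration of the idea rather than the production variant:

```python
import torch

def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 50) -> torch.Tensor:
    """Drive a matrix toward its nearest orthogonal matrix using only matmuls.
    Simplified cubic variant for illustration; Muon uses a tuned quintic."""
    # Normalise so all singular values fall inside (0, sqrt(3)),
    # the convergence basin of the cubic iteration
    x = g / (g.norm() + 1e-7)
    for _ in range(steps):
        # each step pushes every singular value toward 1: s <- 1.5s - 0.5s^3
        x = 1.5 * x - 0.5 * (x @ x.T) @ x
    return x

torch.manual_seed(0)
g = torch.randn(64, 64)
q = newton_schulz_orthogonalize(g)
# q @ q.T is now approximately the identity matrix
```

The appeal over an exact QR or SVD is that the whole procedure is a handful of matrix multiplications, which run efficiently on accelerators and in low precision.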