## One-line definition
Initialize each layer’s weights from a distribution whose variance is set so the variance of activations (forward pass) and gradients (backward pass) stays approximately constant from layer to layer. Two standard schemes: Xavier/Glorot for tanh/sigmoid layers, Kaiming/He for ReLU-family layers.
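In PyTorch terms, the two schemes are one call each (a minimal sketch; the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

fan_in, fan_out = 512, 2048  # arbitrary layer sizes for illustration
w = torch.empty(fan_out, fan_in)

# Xavier/Glorot: Var(W) = 2 / (fan_in + fan_out), for tanh/sigmoid layers.
nn.init.xavier_normal_(w)

# Kaiming/He: Var(W) = 2 / fan_in, for ReLU-family layers.
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
print(w.std())  # ~ sqrt(2 / 512) ≈ 0.0625
```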
## Why it matters
If weights are too small, activations shrink toward zero through depth and gradients vanish. If too large, activations explode and gradients blow up. With either failure, training stalls or diverges in the first few hundred steps.
A correct init lets a 24-layer transformer train to convergence with vanilla SGD or Adam; an incorrect init makes the same architecture untrainable without ad-hoc fixes (warmup hacks, smaller LR, etc.).
## The variance argument
For a linear layer $y = Wx$, with $x$ zero-mean with variance $\sigma_x^2$ and $W_{ij}$ drawn i.i.d. with mean 0 and variance $\sigma_w^2$:

$$\operatorname{Var}(y_i) = n_{\text{in}}\,\sigma_w^2\,\sigma_x^2$$

To preserve variance ($\operatorname{Var}(y_i) = \sigma_x^2$), pick $\sigma_w^2 = 1/n_{\text{in}}$.

The same argument on the backward pass gives $\sigma_w^2 = 1/n_{\text{out}}$. Compromise (Xavier/Glorot):

$$\sigma_w^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}}$$

For ReLU activations, half the activations are zeroed out, halving the variance. Compensate with a factor of 2 (Kaiming/He):

$$\sigma_w^2 = \frac{2}{n_{\text{in}}}$$
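The argument is easy to check empirically (a sketch; depth, width, and batch size are arbitrary): push one batch through a deep ReLU stack and compare a naive small init against the Kaiming fan-in variance.

```python
import torch

torch.manual_seed(0)
depth, width = 50, 1024
x = torch.randn(4096, width)

# Naive std=0.01 vs. Kaiming fan-in std=sqrt(2/width).
for init_std in (0.01, (2.0 / width) ** 0.5):
    h = x
    for _ in range(depth):
        w = torch.randn(width, width) * init_std
        h = torch.relu(h @ w)
    print(f"init_std={init_std:.4f} -> layer-{depth} activation std: {h.std():.3e}")
```

With the naive init the activation std collapses toward zero by layer 50; with the Kaiming variance it stays O(1) all the way through.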
## Practical defaults
| Layer type | Init |
|---|---|
| Linear, ReLU/GELU activation | Kaiming-normal, fan-in |
| Linear, tanh/sigmoid | Xavier-uniform |
| Conv, ReLU | Kaiming-normal, fan-in |
| Embeddings | $\mathcal{N}(0,\,0.02)$ for transformers; $\mathcal{N}(0,\,1)$ when followed by LayerNorm |
| LayerNorm gain | 1 |
| LayerNorm bias | 0 |
| Bias | 0 |
PyTorch's default for `nn.Linear` is Kaiming-uniform with gain $a=\sqrt{5}$, which works out to $U(-1/\sqrt{n_{\text{in}}},\, 1/\sqrt{n_{\text{in}}})$. For transformers, GPT-style models often add a per-residual scaling on the output projections to keep residual-stream variance bounded with depth.
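A minimal sketch of how the table's defaults might be applied with `Module.apply` (the Kaiming branch assumes ReLU/GELU layers; swap in Xavier-uniform for tanh/sigmoid):

```python
import torch.nn as nn

def init_weights(module: nn.Module) -> None:
    """Apply the table's defaults; use via model.apply(init_weights)."""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # ReLU/GELU path: Kaiming-normal, fan-in.
        nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)  # transformer convention
    elif isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)  # gain = 1
        nn.init.zeros_(module.bias)   # bias = 0

model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
model.apply(init_weights)
```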
## Special cases
- Residual connections: with $N$ residual blocks, the residual stream’s variance grows linearly with depth unless the contributions from each block are downscaled. GPT-2 / GPT-3 scale output projections by $1/\sqrt{N}$ at init; see the sketch after this list.
- Identity init for recurrent nets (Le et al., 2015): initialize the recurrent weight matrix to the identity so a ReLU RNN behaves like a feed-forward network at initialization.
- Orthogonal init: weight matrices initialized to orthogonal matrices preserve norms exactly. Used in some RL policy networks.
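A sketch of the GPT-2-style residual scaling from the first bullet, assuming output projections can be found by parameter name (the `out_proj` naming is an assumption; codebases differ):

```python
import math
import torch
import torch.nn as nn

@torch.no_grad()
def scale_residual_projections(model: nn.Module, n_layers: int) -> None:
    # Downscale each residual branch's output projection by 1/sqrt(N)
    # so the sum of N block contributions keeps the stream's variance bounded.
    for name, param in model.named_parameters():
        if name.endswith("out_proj.weight"):  # assumed naming convention
            param.mul_(1.0 / math.sqrt(n_layers))
```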
## Common pitfalls
- Using PyTorch’s default `nn.Linear` init for a transformer without checking it. The default is Kaiming-uniform with a nonstandard gain ($a=\sqrt{5}$); many transformer codebases override it with $\mathcal{N}(0,\,0.02)$.
- Initializing biases to nonzero values. Almost never helps; symmetry is already broken by the random weights.
- Forgetting to scale residual outputs. Without it, deep transformers produce huge residual-stream values at init.
- Trusting “it trains” as proof of correct init. A mis-initialized run can still converge, just more slowly than a properly initialized one.
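One cheap guard against the last pitfall: hook every leaf module and print per-layer activation std on a dummy batch right after init (the toy model and shapes below are illustrative). Healthy inits show stds of roughly the same order at every depth; steady growth or decay across layers signals a bad init.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def check_init(model: nn.Module, x: torch.Tensor) -> None:
    # Forward hooks that report each leaf module's output std on one batch.
    hooks = []
    for name, mod in model.named_modules():
        if len(list(mod.children())) == 0:  # leaf modules only
            hooks.append(mod.register_forward_hook(
                lambda m, inp, out, name=name: print(f"{name:20s} std={out.std():.3f}")
            ))
    model(x)
    for h in hooks:
        h.remove()

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
check_init(model, torch.randn(256, 512))
```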