BatchNorm vs LayerNorm: Theory, Assumptions, and Dynamics
Abstract
Normalization layers enforce consistent activation statistics; the original motivation was mitigating internal covariate shift, though subsequent analysis attributes much of the benefit to a smoother optimization landscape (Santurkar et al., 2018). Batch Normalization (BatchNorm) and Layer Normalization (LayerNorm) differ fundamentally in the axis over which moments are computed, leading to distinct inductive biases. This post provides a rigorous comparison of their formulations, underlying statistical assumptions, implications for training dynamics, gradient behavior, and dynamic range preservation.
Core Mathematical Setup
Consider a mini-batch of activations \(\mathbf{X} \in \mathbb{R}^{B \times d}\), where \(B\) is the batch size and \(d\) is the feature dimension. The goal of normalization is to transform \(\mathbf{X}\) such that its (estimated) mean is zero and variance is one:
\(\mathbb{E}[\hat{\mathbf{X}}] \approx \mathbf{0}, \quad \mathrm{Var}(\hat{\mathbf{X}}) \approx 1\)
Batch Normalization: Batch-Wise Statistics
BatchNorm computes statistics independently for each feature dimension \(j = 1, \dots, d\) across the batch axis:
\(\mu_j = \frac{1}{B} \sum_{i=1}^{B} x_{ij}, \qquad \sigma_j^2 = \frac{1}{B} \sum_{i=1}^{B} (x_{ij} - \mu_j)^2\) \(\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}, \qquad y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j\)
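The per-feature computation above can be sketched in pure Python. This is a minimal illustration of the forward pass only (the function name is ours; the running statistics a real BatchNorm tracks for inference are omitted):

```python
import math

def batch_norm(X, gamma, beta, eps=1e-5):
    """Normalize each feature (column) of X across the batch (rows)."""
    B, d = len(X), len(X[0])
    out = [[0.0] * d for _ in range(B)]
    for j in range(d):
        col = [X[i][j] for i in range(B)]
        mu = sum(col) / B                               # batch mean of feature j
        var = sum((x - mu) ** 2 for x in col) / B       # batch variance of feature j
        for i in range(B):
            out[i][j] = gamma[j] * (X[i][j] - mu) / math.sqrt(var + eps) + beta[j]
    return out

X = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]
Y = batch_norm(X, gamma=[1.0, 1.0], beta=[0.0, 0.0])
# each column of Y now has (approximately) zero mean and unit variance
```

Note that each feature's statistics are shared across the whole batch: a sample's normalized value depends on the other samples it happens to be batched with.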
Why BatchNorm Aligns with Convolutional Architectures
In image data, the statistics of a given feature channel (e.g., an edge detector's response) are largely invariant across samples, so the batch estimate is a good proxy for the population moment. Formally:
\(\mathbb{E}_{\text{batch}}[x_{\cdot j}] \approx \mathbb{E}_{\text{data}}[x_{\cdot j}]\)
Layer Normalization: Feature-Wise (Per-Sample) Statistics
LayerNorm computes statistics independently for each sample \(i\) across its feature vector:
\(\mu_i = \frac{1}{d} \sum_{j=1}^{d} x_{ij}, \qquad \sigma_i^2 = \frac{1}{d} \sum_{j=1}^{d} (x_{ij} - \mu_i)^2\) \(\hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j\)
Note that the moments are per-sample (indexed by \(i\)), while the learned affine parameters \(\gamma_j, \beta_j\) remain per-feature, shared across samples.
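The per-sample computation is even simpler to sketch, since it never looks beyond one feature vector (again a minimal forward-pass illustration, not a framework implementation):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize a single sample's feature vector x (length d)."""
    d = len(x)
    mu = sum(x) / d                              # per-sample mean
    var = sum((v - mu) ** 2 for v in x) / d      # per-sample variance
    return [gamma[j] * (x[j] - mu) / math.sqrt(var + eps) + beta[j]
            for j in range(d)]

y = layer_norm([1.0, 2.0, 3.0, 4.0], gamma=[1.0] * 4, beta=[0.0] * 4)
# y has zero mean and unit variance regardless of what else is in the batch
```

Unlike BatchNorm, the output for a sample is fully determined by that sample, so train-time and inference-time behavior are identical and batch size is irrelevant.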
Why Transformers Require LayerNorm
In transformers, the batch axis mixes tokens drawn from unrelated sequences of varying length and content, so the batch-wise moments of any given feature fluctuate from batch to batch and carry no semantic meaning:
\(\mathrm{Var}_{\text{batch}}(\mu_j) \text{ is large and semantically meaningless}\)
Consequences of Swapping Normalization Axes
The mismatch degrades both regimes. Applying BatchNorm in a transformer ties each token's normalization to unrelated tokens that happen to share its batch, so small or variable-length batches yield unstable estimates; applying LayerNorm in a CNN discards the cross-sample statistics, and the mild regularization from batch-level noise, that convolutional features benefit from.
Training Stability and Gradient Dynamics
Normalization also controls gradient magnitude. For a linear layer \(\mathbf{y} = \mathbf{W} \hat{\mathbf{x}}\) applied to a normalized input \(\hat{\mathbf{x}} = (\mathbf{x} - \mu)/\sigma\), backpropagation through the normalization rescales the gradient by \(1/\sigma\) (omitting the centering terms from \(\partial\mu/\partial\mathbf{x}\) and \(\partial\sigma/\partial\mathbf{x}\)):
\(\frac{\partial \mathcal{L}}{\partial \mathbf{x}} \propto \frac{1}{\sigma}\, \mathbf{W}^\top \frac{\partial \mathcal{L}}{\partial \mathbf{y}} \quad \Rightarrow \quad \left\| \frac{\partial \mathcal{L}}{\partial \mathbf{x}} \right\| \approx O(1)\)
So if activations grow by some factor, \(\sigma\) grows by the same factor and the backward pass is automatically rescaled, keeping gradient magnitudes stable.
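This rescaling can be checked numerically: because the normalized output is invariant to rescaling its input, any downstream loss's gradient shrinks by exactly the factor the input grew by. A finite-difference sketch with a toy loss and made-up values (\(\epsilon = 0\) here so the invariance is exact):

```python
import math

def norm(x, eps=0.0):
    """Zero-mean, unit-variance normalization of one vector (eps=0: exactly scale-invariant)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def loss(x, w):
    # toy downstream loss: weighted sum of the normalized vector
    return sum(wi * hi for wi, hi in zip(w, norm(x)))

def grad(f, x, h=1e-6):
    """Central finite-difference gradient of f at x."""
    g = []
    for k in range(len(x)):
        xp, xm = x[:], x[:]
        xp[k] += h
        xm[k] -= h
        g.append((f(xp) - f(xm)) / (2 * h))
    return g

x = [0.5, -1.0, 2.0, 0.3]
w = [1.0, -2.0, 0.5, 1.5]
g1 = grad(lambda v: loss(v, w), x)
g10 = grad(lambda v: loss(v, w), [10 * v for v in x])
# scaling the input by 10 scales sigma by 10 and the gradient by 1/10
```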
Preservation of Dynamic Range Across Layers
Without normalization, activation variance can grow or shrink exponentially with depth. Normalization resets the moments at every layer:
\(\mathrm{Var}(\hat{\mathbf{h}}^{(\ell)}) \approx 1 \quad \Rightarrow \quad \mathrm{Var}(\mathbf{W} \hat{\mathbf{h}}^{(\ell)}) \text{ depends only on } \mathbf{W}\)
The affine parameters \((\gamma, \beta)\) then allow the network to recover any desired scale and shift, keeping activations in the sensitive region of the nonlinearity.
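The resetting effect can be illustrated numerically under assumed toy settings: random Gaussian linear layers with a deliberately mis-scaled initialization, with and without a per-layer normalization step (a LayerNorm-style reset without the affine parameters):

```python
import math
import random

random.seed(0)
d, depth = 128, 10
scale = 1.2  # deliberately mis-scaled init: variance grows ~scale**2 per layer

def normalize(h, eps=1e-5):
    """Reset a vector to zero mean, unit variance."""
    mu = sum(h) / len(h)
    var = sum((v - mu) ** 2 for v in h) / len(h)
    return [(v - mu) / math.sqrt(var + eps) for v in h]

def linear(h):
    """Random linear layer with i.i.d. N(0, scale**2 / d) weights."""
    return [sum(random.gauss(0, scale / math.sqrt(d)) * v for v in h)
            for _ in range(d)]

def second_moment(h):
    return sum(v * v for v in h) / len(h)

h_plain = h_norm = [random.gauss(0, 1) for _ in range(d)]
for _ in range(depth):
    h_plain = linear(h_plain)             # variance drifts multiplicatively
    h_norm = normalize(linear(h_norm))    # variance reset to ~1 every layer

# second_moment(h_plain) has drifted far from 1; second_moment(h_norm) stays ~1
```

Without normalization the variance compounds roughly as \(\text{scale}^{2\ell}\) with depth \(\ell\); with the per-layer reset it stays pinned near 1 no matter how badly the weights are scaled.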
Global vs Local Statistics and Biological Inspiration
BatchNorm’s global statistics can suppress rare signals. Biological systems use local divisive normalization, e.g.:
\(\text{Output} \approx \frac{\text{center response}}{\sqrt{\text{center}^2 + \sum \text{surrounding}^2 + \epsilon}}\)
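A minimal sketch of this center-surround divisive scheme (the function and values are illustrative, not a model of any specific neural circuit):

```python
import math

def divisive_norm(center, surround, eps=1e-5):
    """Local contrast normalization: center response divided by local energy."""
    energy = center ** 2 + sum(s ** 2 for s in surround)
    return center / math.sqrt(energy + eps)

# a strong center in a quiet neighborhood keeps most of its amplitude,
strong = divisive_norm(1.0, [0.1, 0.1, 0.1])
# while the same center in a high-energy neighborhood is suppressed
suppressed = divisive_norm(1.0, [2.0, 2.0, 2.0])
```

The normalizer here is computed from a spatial neighborhood of the unit itself, not from other samples: the statistic is local in both the biological and the LayerNorm sense.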
Final Theoretical Insight
\(\begin{align*} \text{BatchNorm} &\colon \text{align statistics across data (global)} \\ \text{LayerNorm} &\colon \text{stabilize geometry within sample (local)} \\ \text{Biological vision} &\colon \text{local contrast normalization (center-surround)} \end{align*}\)
Key References
Ioffe, S. & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
Ba, J., Kiros, J. & Hinton, G. (2016). Layer Normalization.
Vaswani, A. et al. (2017). Attention Is All You Need.
Santurkar, S. et al. (2018). How Does Batch Normalization Help Optimization?
