BatchNorm vs LayerNorm: Theory, Assumptions, and Dynamics


Abstract

Normalization layers were originally motivated as a remedy for internal covariate shift by enforcing consistent activation statistics, though later analyses (e.g. Santurkar et al., 2018) attribute much of their benefit to smoothing the optimization landscape. Batch Normalization (BatchNorm) and Layer Normalization (LayerNorm) differ fundamentally in the axis over which moments are computed, leading to distinct inductive biases. This post provides a rigorous comparison of their formulations, underlying statistical assumptions, implications for training dynamics, gradient behavior, and dynamic range preservation.

Core Mathematical Setup

Consider a mini-batch of activations \(\mathbf{X} \in \mathbb{R}^{B \times d}\), where \(B\) is the batch size and \(d\) is the feature dimension. The goal of normalization is to transform \(\mathbf{X}\) such that its (estimated) mean is zero and variance is one:

\(\mathbb{E}[\hat{\mathbf{X}}] \approx \mathbf{0}, \quad \mathrm{Var}(\hat{\mathbf{X}}) \approx 1\)

Batch Normalization: Batch-Wise Statistics

BatchNorm computes statistics independently for each feature dimension \(j = 1, \dots, d\) across the batch axis:

\(\mu_j = \frac{1}{B} \sum_{i=1}^{B} x_{ij}, \qquad \sigma_j^2 = \frac{1}{B} \sum_{i=1}^{B} (x_{ij} - \mu_j)^2\) \(\hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}, \qquad y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j\)

BatchNorm (per channel j)

    x_{1j}    x_{2j}    …    x_{Bj}
      │         │              │
      ────── batch mean & variance ──────
                    │
        normalize + affine (γⱼ, βⱼ)
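A minimal NumPy sketch of these equations (the `eps` value and initialization are illustrative; a real implementation would also track running statistics for inference):

```python
import numpy as np

def batch_norm(X, gamma, beta, eps=1e-5):
    """Normalize each feature (column) across the batch axis."""
    mu = X.mean(axis=0)                    # mu_j, shape (d,)
    var = X.var(axis=0)                    # sigma_j^2, shape (d,)
    X_hat = (X - mu) / np.sqrt(var + eps)  # zero mean, unit variance per feature
    return gamma * X_hat + beta            # learnable per-feature scale and shift

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(64, 8))
Y = batch_norm(X, gamma=np.ones(8), beta=np.zeros(8))
# Each column of Y now has mean ~0 and variance ~1.
```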

Why BatchNorm Aligns with Convolutional Architectures

In image data, low-level per-channel statistics (edges, textures, local contrast) are largely invariant across samples, so the batch estimate of each channel's moments closely tracks the population value:

\(\mathbb{E}_{\text{batch}}[x_{\cdot j}] \approx \mathbb{E}_{\text{data}}[x_{\cdot j}]\)

This is what makes training-time batch statistics and inference-time running averages interchangeable.
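For convolutional features this assumption is exploited even further: BatchNorm pools the statistics over the batch *and* spatial axes, yielding one \((\mu, \sigma^2)\) pair per channel. A sketch (the NCHW layout and `eps` value are illustrative):

```python
import numpy as np

def batch_norm_2d(X, eps=1e-5):
    """BatchNorm for conv features, X of shape (N, C, H, W):
    one mean/variance per channel, pooled over batch and spatial axes."""
    mu = X.mean(axis=(0, 2, 3), keepdims=True)   # shape (1, C, 1, 1)
    var = X.var(axis=(0, 2, 3), keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(1)
X = rng.normal(2.0, 3.0, size=(8, 4, 5, 5))
Y = batch_norm_2d(X)
# Each channel of Y has mean ~0 and variance ~1 over (N, H, W).
```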

Layer Normalization: Feature-Wise (Per-Sample) Statistics

LayerNorm computes statistics independently for each sample \(i\) across its feature vector:

\(\mu_i = \frac{1}{d} \sum_{j=1}^{d} x_{ij}, \qquad \sigma_i^2 = \frac{1}{d} \sum_{j=1}^{d} (x_{ij} - \mu_i)^2\) \(\hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}, \qquad y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j\)

Note that although the statistics are per-sample, the learnable affine parameters \((\gamma_j, \beta_j)\) remain per-feature, shared across samples.

LayerNorm (per sample/token i)

    x_{i1}    x_{i2}    …    x_{id}
      │         │              │
      ────── feature mean & variance ──────
                    │
        normalize + affine (γⱼ, βⱼ)
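The same sketch for LayerNorm, with the reduction axis flipped to the feature dimension (again with illustrative `eps` and affine initialization):

```python
import numpy as np

def layer_norm(X, gamma, beta, eps=1e-5):
    """Normalize each sample (row) across its feature axis."""
    mu = X.mean(axis=-1, keepdims=True)     # mu_i, shape (B, 1)
    var = X.var(axis=-1, keepdims=True)     # sigma_i^2, shape (B, 1)
    X_hat = (X - mu) / np.sqrt(var + eps)
    return gamma * X_hat + beta             # per-feature affine (gamma_j, beta_j)

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 16))
Y = layer_norm(X, gamma=np.ones(16), beta=np.zeros(16))
# Each row of Y has mean ~0 and variance ~1, independent of the other rows.
```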

Why Transformers Require LayerNorm

In transformers, the activations at a given feature dimension mix tokens from unrelated positions and sequences, and the effective batch composition varies with sequence length and padding. Batch-wise moments therefore become highly noisy and uninformative:

\(\mathrm{Var}_{\text{batch}}(\mu_j) \text{ is large and semantically meaningless,}\)

whereas each token's feature vector has well-defined per-sample statistics that LayerNorm can exploit.

Consequences of Swapping Normalization Axes

The mismatch leads to instability or loss of regularization: applying BatchNorm to transformer activations couples unrelated tokens through noisy batch statistics and degrades at small batch sizes, while applying LayerNorm to convolutional features discards the per-channel contrast information that the batch-wise statistics would have preserved.
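The batch-coupling can be made concrete with a small sketch (no affine parameters, illustrative values): under BatchNorm, a fixed sample's output changes when its batch-mates change; under LayerNorm it does not.

```python
import numpy as np

def bn(X, eps=1e-5):  # batch-axis normalization, no affine
    return (X - X.mean(axis=0)) / np.sqrt(X.var(axis=0) + eps)

def ln(X, eps=1e-5):  # feature-axis normalization, no affine
    return (X - X.mean(axis=1, keepdims=True)) / np.sqrt(X.var(axis=1, keepdims=True) + eps)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(1, 8))                  # one fixed sample
mates_a = rng.normal(0.0, 1.0, size=(31, 8))  # two different sets of batch-mates
mates_b = rng.normal(5.0, 3.0, size=(31, 8))

bn_a = bn(np.vstack([x0, mates_a]))[0]
bn_b = bn(np.vstack([x0, mates_b]))[0]
ln_a = ln(np.vstack([x0, mates_a]))[0]
ln_b = ln(np.vstack([x0, mates_b]))[0]
# bn_a != bn_b (output depends on batch-mates), but ln_a == ln_b.
```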

Training Stability and Gradient Dynamics

Normalization controls gradient magnitude. Because \(\hat{x} = (x - \mu)/\sqrt{\sigma^2 + \epsilon}\), the backward pass through the normalization carries a factor of \(1/\sigma\):

\(\left\| \frac{\partial \mathcal{L}}{\partial x_{ij}} \right\| \propto \frac{1}{\sigma} \left\| \frac{\partial \mathcal{L}}{\partial \hat{x}_{ij}} \right\|\)

so if upstream activations grow by a factor \(c\), the gradient flowing back to them shrinks by \(1/c\), keeping effective gradient magnitudes \(O(1)\).
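The dual statement in the forward pass is scale invariance: normalization maps \(c \cdot \mathbf{x}\) and \(\mathbf{x}\) to (nearly) the same output, which by the chain rule is exactly where the \(1/c\) gradient factor comes from. A quick numerical check (`eps` chosen small so the invariance is nearly exact):

```python
import numpy as np

def ln(x, eps=1e-12):
    return (x - x.mean()) / np.sqrt(x.var() + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=16)

y = ln(x)
y_scaled = ln(1000.0 * x)  # the 1/sigma factor absorbs the input scale
# y and y_scaled are (numerically) identical, so the Jacobian at 1000*x
# must carry a 1/1000 factor relative to the Jacobian at x.
```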

Preservation of Dynamic Range Across Layers

Without normalization, activation variance can grow or shrink exponentially with depth. Normalization resets the moments at every layer:

\(\mathrm{Var}(\hat{\mathbf{h}}^{(\ell)}) \approx 1 \quad \Rightarrow \quad \mathrm{Var}(\mathbf{W} \hat{\mathbf{h}}^{(\ell)}) \text{ depends only on } \mathbf{W}\)

The affine parameters \((\gamma, \beta)\) then allow the network to recover any desired scale and shift, keeping activations in the sensitive region of the nonlinearity.
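This reset can be demonstrated with a toy deep stack of random linear layers whose gain is slightly too large (the depth, width, and gain of 1.1 are illustrative): without normalization the activation variance grows exponentially with depth; with a per-layer reset it stays near one.

```python
import numpy as np

rng = np.random.default_rng(0)
d, depth = 64, 50
h_plain = rng.normal(size=d)
h_norm = h_plain.copy()

for _ in range(depth):
    # Gain 1.1 is slightly too large: unnormalized variance grows ~1.21x per layer.
    W = rng.normal(scale=1.1 / np.sqrt(d), size=(d, d))
    h_plain = W @ h_plain                              # no normalization
    h = W @ h_norm
    h_norm = (h - h.mean()) / np.sqrt(h.var() + 1e-5)  # reset moments each layer

# h_plain.var() has blown up after 50 layers; h_norm.var() stays ~1.
```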

Global vs Local Statistics and Biological Inspiration

BatchNorm’s global statistics can suppress rare signals. Biological systems use local divisive normalization, e.g.:

\(\text{Output} \approx \frac{\text{center response}}{\sqrt{\text{center}^2 + \sum \text{surrounding}^2 + \epsilon}}\)
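A toy sketch of this center-surround rule (scalar center, small surround pool; all values illustrative) shows the key property: the same center response is passed through when its local context is quiet and suppressed when the context is strong.

```python
import numpy as np

def divisive_norm(center, surround, eps=1e-5):
    """Local contrast normalization: divide the center response by the
    pooled energy of the center and its surround."""
    energy = np.sqrt(center**2 + np.sum(surround**2) + eps)
    return center / energy

weak_context = divisive_norm(3.0, np.array([0.1, 0.2, 0.1]))    # quiet surround
strong_context = divisive_norm(3.0, np.array([4.0, 5.0, 4.0]))  # busy surround
# weak_context is near 1; strong_context is heavily suppressed.
```

Because the pool is local rather than batch-wide, a rare but strong signal is not averaged away by unrelated inputs, which is the contrast the section draws with BatchNorm's global statistics.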

Final Theoretical Insight

\(\begin{align*} \text{BatchNorm} &\colon \text{align statistics across data (global)} \\ \text{LayerNorm} &\colon \text{stabilize geometry within sample (local)} \\ \text{Biological vision} &\colon \text{local contrast normalization (center-surround)} \end{align*}\)

Key References:

Ioffe & Szegedy (2015). Batch Normalization.
Ba, Kiros & Hinton (2016). Layer Normalization.
Vaswani et al. (2017). Attention Is All You Need.
Santurkar et al. (2018). How Does Batch Normalization Help Optimization?