A Guide to Effective Initialization of Neural Networks

ML Fundamentals Jun 3, 2024

The Problem of Vanishing/Exploding Gradients

During the training of deep neural networks, the mean and variance of activations can quickly grow to very large values or shrink toward zero as signals pass through successive layers, causing the gradients to overflow to NaN or vanish to zero. This prevents the network from learning effectively. There are several techniques to handle this vanishing/exploding gradient problem:

  1. Proper Weight Initialization: Initialization methods like Xavier/Glorot or Kaiming/He scale the weights based on the number of input and/or output units, so that the variance of the activations (and of the gradients) is roughly preserved across layers. This helps prevent the activations from exploding or vanishing in the early stages of training.
  2. Gradient Clipping: After backpropagation and before the optimizer step, the gradients are clipped so they cannot exceed a chosen threshold. Clipping can be applied element-wise to individual gradient values or to the global norm of all gradients, and it improves training stability.
  3. Normalization Layers: Batch Normalization and Layer Normalization normalize the activations of a layer to zero mean and unit variance (over the batch or over the features, respectively), followed by a learnable scale and shift. Batch Normalization also acts as a mild regularizer, which can improve the model's generalization.
  4. Activation functions: Activation functions whose gradients do not saturate for positive inputs, such as ReLU (Rectified Linear Unit) or its variants (Leaky ReLU, ELU), help mitigate the vanishing gradient problem that affects saturating functions like sigmoid and tanh.
  5. Model Architecture: Residual connections, as used in ResNet architectures, provide an alternative (identity) path for the gradients to flow through the network. In recurrent networks, the gating mechanisms of LSTM/GRU cells address the vanishing gradient problem of traditional RNNs.
  6. Advanced Optimizers: LARS and LAMB scale each layer's update by a layer-wise trust ratio, and Lion uses the sign of a momentum term so that every update has a fixed magnitude; both approaches keep step sizes bounded, which helps mitigate the exploding gradient problem. A proper learning-rate warmup schedule also helps stabilize early training.
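Gradient clipping and warmup (items 2 and 6) take only a few lines in a training loop. The sketch below is a minimal illustration under stated assumptions: the toy model, the deliberately large synthetic inputs, `max_norm = 1.0` and the 100-step linear warmup are arbitrary choices. It uses `torch.nn.utils.clip_grad_norm_` for global-norm clipping and `LambdaLR` for the warmup:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy model and synthetic data, chosen only for illustration.
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Linear warmup: the learning-rate multiplier grows from 1/warmup_steps up to 1.
warmup_steps = 100
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)

x = torch.randn(32, 16) * 100.0  # deliberately large inputs produce large gradients
y = torch.randn(32, 1)

optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()

# Rescale all gradients so that their combined (global) L2 norm is at most max_norm.
# clip_grad_norm_ returns the total norm measured *before* clipping.
max_norm = 1.0
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
# (torch.nn.utils.clip_grad_value_ would instead clamp each gradient element.)

# Recompute the global norm to confirm that the clipped gradients respect the bound.
norm_after = torch.linalg.vector_norm(
    torch.stack([p.grad.norm() for p in model.parameters()])
).item()

optimizer.step()
scheduler.step()
print(f"grad norm before: {norm_before:.1f}, after: {norm_after:.4f}")
print(f"learning rate after one step: {optimizer.param_groups[0]['lr']:.4f}")
```

Clipping by global norm preserves the direction of the overall gradient and only shrinks its length, which is why it is usually preferred over clamping individual values.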

Although there are multiple approaches, good initial parameters are essential. Combined with the other techniques, a good initialization improves training efficiency, resulting in better models at lower cost.

Common Initialization Strategies

Zero Initialization

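Setting all weights and biases to zero is a bad strategy. Every hidden unit in a layer then computes the same output and receives the same gradient, so the units never become different from one another (symmetry is never broken). In addition, with zero weights the gradients flowing back to earlier layers are zero, which halts gradient flow.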
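Both failure modes show up in a tiny network. This sketch (the 4-8-1 layer sizes, the tanh activation, the constant 0.5 and the random data are arbitrary illustration choices) sets every parameter to a constant and inspects the gradient of the first layer after one backward pass:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def first_layer_grad(value):
    """Build a tiny 4-8-1 tanh MLP, set every parameter to `value`,
    run one backward pass and return the first layer's weight gradient."""
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
    for p in model.parameters():
        nn.init.constant_(p, value)
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    F.mse_loss(model(x), y).backward()
    return model[0].weight.grad  # shape (8, 4): one row per hidden unit

g_zero = first_layer_grad(0.0)
g_const = first_layer_grad(0.5)

# All-zero init: no gradient reaches the first layer, so it cannot learn.
print(g_zero.abs().max().item())  # 0.0
# Constant init: every hidden unit (row) receives the same gradient and stays identical.
print(torch.allclose(g_const, g_const[0].expand_as(g_const)))  # True
```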

Random Initialization

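Initializing weights with random numbers from a normal distribution with a fixed standard deviation (for example 1 or 0.01) breaks symmetry, but unless that scale is matched to the layer width, the activations still shrink toward zero or blow up as the network gets deeper.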
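The effect of the weight scale is easy to measure. This sketch (a plain stack of bias-free linear layers with no activation, width 512 and depth 50, all assumed for illustration) draws every weight matrix from a normal distribution with a fixed standard deviation and reports the standard deviation of the final activations:

```python
import math
import torch

torch.manual_seed(0)

def output_std(depth, width, weight_std):
    """Push a random batch through `depth` bias-free linear layers (no activation)
    with weights drawn from N(0, weight_std^2) and return the final activation std."""
    x = torch.randn(256, width)
    for _ in range(depth):
        w = torch.randn(width, width) * weight_std
        x = x @ w  # each layer multiplies the std by roughly weight_std * sqrt(width)
    return x.std().item()

depth, width = 50, 512
too_large = output_std(depth, width, 1.0)    # ~sqrt(512) growth per layer: overflows to inf/nan
too_small = output_std(depth, width, 0.01)   # ~0.23x per layer: collapses toward 0
width_scaled = output_std(depth, width, 1 / math.sqrt(width))  # stays on the order of 1
print(too_large, too_small, width_scaled)
```

Only the last setting, where the standard deviation is tied to the layer width, keeps the signal at a usable scale; this observation is what the Xavier and Kaiming schemes formalize.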

Xavier/Glorot Initialization

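Proposed by Glorot and Bengio in 2010, this initialization sets the weight variance to $2/(fan_{in}+fan_{out})$, a compromise between keeping the variance of the activations constant in the forward pass (which requires $1/fan_{in}$) and keeping the variance of the gradients constant in the backward pass (which requires $1/fan_{out}$). When $fan_{in} = fan_{out} = n$, this is the same as scaling the weights by $\sqrt{1/n}$. It is designed to keep the variance roughly constant across layers for tanh/sigmoid activations. The uniform version samples each weight as: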

$$W_{ij} \sim U\left[-\sqrt{\frac{6}{fan_{in}+fan_{out}}},\ \sqrt{\frac{6}{fan_{in}+fan_{out}}}\right]$$

Where π‘ˆ is a uniform distribution and π‘“π‘Žπ‘›π‘–π‘› is the size of the previous layer (number of columns in π‘Š) and π‘“π‘Žπ‘›π‘œπ‘’π‘‘ is the size of the current layer.

Kaiming/He Initialization

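For ReLU activations, Xavier initialization is not optimal: ReLU zeroes out roughly half of its inputs, which halves the mean squared activation passed to the next layer, so the signal shrinks at every layer. Kaiming initialization compensates with an extra factor of 2 and scales the weights by $\sqrt{2/n}$, where $n$ is the number of input units (fan-in), which keeps the variance stable across layers. This implies an initialization scheme of: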

π‘€π‘™βˆΌπ‘(0,2/𝑛)

That is, a zero-centered Gaussian with variance $2/n$ (the second parameter in the equation above), i.e. a standard deviation of $\sqrt{2/n}$. Biases are initialized to 0.
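The factor of 2 matters once the network is deep. This sketch (a bias-free stack of 50 Linear+ReLU layers of width 256 and a random input batch, all assumed for illustration) compares `torch.nn.init.xavier_normal_` with `torch.nn.init.kaiming_normal_(..., mode="fan_in", nonlinearity="relu")`, which samples from $\mathcal{N}(0, 2/n)$ with $n$ the fan-in:

```python
import math
import torch
import torch.nn as nn

def relu_stack_output_std(init_fn, depth=50, width=256):
    """Apply `depth` bias-free Linear+ReLU layers, each initialized in place by
    `init_fn`, to a random batch and return the std of the final activations."""
    x = torch.randn(512, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)
            x = torch.relu(layer(x))
    return x.std().item()

torch.manual_seed(0)
xavier_std = relu_stack_output_std(nn.init.xavier_normal_)
kaiming_std = relu_stack_output_std(
    lambda w: nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
)
print(f"after 50 ReLU layers: Xavier std = {xavier_std:.2e}, Kaiming std = {kaiming_std:.2f}")

# The sampled Kaiming weights match N(0, 2/n) with n = fan_in.
w = torch.empty(256, 1024)  # nn.Linear layout: fan_in = 1024 columns
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
print(f"weight std = {w.std().item():.5f}, sqrt(2/n) = {math.sqrt(2 / 1024):.5f}")
```

With Xavier weights each ReLU layer roughly halves the mean squared activation, so after 50 layers the signal has all but vanished; with Kaiming weights it stays near its initial scale.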

LSUV Initialization

Layer-Sequential Unit-Variance (LSUV) initialization, introduced in the paper "All you need is a good init", is a simple, data-driven method for initializing deep networks. It consists of two steps:

  1. First, pre-initialize weights of each convolution or inner-product layer with orthonormal matrices.
  2. Second, using a minibatch of training data, proceed from the first to the final layer, rescaling each layer's weights so that the variance of its output is equal to one.
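The two steps can be sketched in PyTorch as follows. This is a minimal illustration, not the paper's reference code: the helper name `lsuv_init`, the tolerance, the iteration cap, zeroing the biases and the small MLP are all assumptions. A forward hook captures each layer's output on the batch so that its weights can be rescaled by the observed standard deviation:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def lsuv_init(model, batch, tol=0.05, max_iters=10):
    """Hypothetical LSUV sketch: orthonormal pre-init, then rescale each Linear/Conv2d
    layer, first to last, until its output variance on `batch` is close to 1."""
    layers = [m for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]
    # Step 1: orthonormal pre-initialization (zero biases are an assumption of this sketch).
    for layer in layers:
        nn.init.orthogonal_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    # Step 2: sequentially normalize each layer's output variance to 1.
    for layer in layers:
        captured = {}
        hook = layer.register_forward_hook(lambda mod, inp, out: captured.update(out=out))
        for _ in range(max_iters):
            model(batch)
            var = captured["out"].var().item()
            if abs(var - 1.0) < tol:
                break
            layer.weight /= var ** 0.5  # scaling W by c scales the output variance by c^2
        hook.remove()
    return model

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(64, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),
)
batch = torch.randn(512, 64)
lsuv_init(model, batch)

# Check: record every Linear layer's output variance in one forward pass.
variances = []
hooks = [m.register_forward_hook(lambda mod, inp, out: variances.append(out.var().item()))
         for m in model if isinstance(m, nn.Linear)]
with torch.no_grad():
    model(batch)
for h in hooks:
    h.remove()
print([round(v, 3) for v in variances])  # each value should be close to 1.0
```

Processing the layers in order matters: rescaling a layer changes the inputs of every later layer, so each layer is only normalized after all layers before it are fixed.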

Fixup and T-Fixup Initialization

Fixup (Fixed-Update) initialization aims to train very deep residual networks stably at the maximal learning rate without normalization layers. It addresses a variance-scaling problem caused by residual connections. Kaiming initialization keeps the variance stable through each branch, i.e. $Var(F(x)) \approx Var(x)$. With a residual connection, however, $Var(F(x)+x) = Var(F(x)) + Var(x) \approx 2\,Var(x)$ when $F(x)$ and $x$ are uncorrelated, so the variance grows with every block.

The steps are as follows:

  1. Initialize the classification layer and the last layer of each residual branch to 0.
  2. Initialize every other layer using a standard method, e.g. Kaiming Initialization, and scale only the weight layers inside residual branches by $L^{-\frac{1}{2m-2}}$, where $L$ is the number of residual branches and $m$ is the number of layers in each branch.
  3. Add a scalar multiplier (initialized at 1) in every branch and a scalar bias (initialized at 0) before each convolution, linear, and element-wise activation layer.

Zero-initializing weights normally raises two concerns: gradient flow can halt, and units can fail to break symmetry. Fixup avoids both. The residual (identity) connections carry signals and gradients through every block, so the zero-initialized layers do not block gradient flow, and the gradient of each zero-initialized layer is generally non-zero because its inputs, produced by the non-zero layers before it, are non-zero. Symmetry is broken by the standard non-zero initialization of those other layers, including the first convolutional layer, and this asymmetry propagates through the network via the residual connections.
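The three Fixup steps can be sketched for a small fully connected residual network. Everything here is a hypothetical illustration: `FixupBlock` and `FixupMLP` are made-up names, each branch has $m = 2$ linear layers (so the first layer of each branch is rescaled by $L^{-1/(2m-2)} = L^{-1/2}$, which shrinks the weights as the number of branches $L$ grows), and the placement of the scalar biases loosely follows the residual-block layout described for Fixup ResNets rather than any specific reference implementation:

```python
import torch
import torch.nn as nn

class FixupBlock(nn.Module):
    """Hypothetical fully connected residual block with m = 2 weight layers,
    initialized with the three Fixup rules (no normalization layers)."""

    def __init__(self, width, num_branches):
        super().__init__()
        m = 2  # weight layers per residual branch
        self.fc1 = nn.Linear(width, width, bias=False)
        self.fc2 = nn.Linear(width, width, bias=False)
        # Rule 3: scalar biases (init 0) before each layer/activation, scalar multiplier (init 1).
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        # Rule 2: Kaiming init for the first branch layer, rescaled by L^(-1/(2m-2)).
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity="relu")
        with torch.no_grad():
            self.fc1.weight.mul_(num_branches ** (-1.0 / (2 * m - 2)))
        # Rule 1: the last layer of each residual branch starts at zero.
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        out = torch.relu(self.fc1(x + self.bias1a) + self.bias1b)
        out = self.fc2(out + self.bias2a) * self.scale + self.bias2b
        return torch.relu(out + x)

class FixupMLP(nn.Module):
    """Hypothetical residual MLP: stem -> L Fixup blocks -> zero-initialized classifier."""

    def __init__(self, in_dim, width, num_classes, num_blocks):
        super().__init__()
        self.stem = nn.Linear(in_dim, width)
        nn.init.kaiming_normal_(self.stem.weight, nonlinearity="relu")
        self.blocks = nn.Sequential(*[FixupBlock(width, num_blocks) for _ in range(num_blocks)])
        self.bias_out = nn.Parameter(torch.zeros(1))
        self.head = nn.Linear(width, num_classes)
        # Rule 1: the classification layer also starts at zero.
        nn.init.zeros_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        x = torch.relu(self.stem(x))
        return self.head(self.blocks(x) + self.bias_out)

torch.manual_seed(0)
net = FixupMLP(in_dim=20, width=64, num_classes=5, num_blocks=8)
x = torch.randn(16, 20)
logits = net(x)
print(logits.abs().max().item())  # 0.0: the zero-initialized head makes every logit start at 0
```

Because the last layer of every branch is zero, each block starts out as the identity map, so the depth of the network does not change the scale of the signal at initialization; the branches only start to contribute once training updates the zero-initialized layers.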

T-Fixup extends this concept further to transformer models.

Calculating Fan-in and Fan-out

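For dense layers, fan-in is the number of inputs and fan-out is the number of outputs. For convolutional layers, where receptive_field_size is the number of elements in one kernel (kernel_height * kernel_width for a 2-D convolution):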

fan_in = num_input_feature_maps * receptive_field_size
fan_out = num_output_feature_maps * receptive_field_size
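For a concrete case, take a 3×3 convolution from 32 to 64 channels: receptive_field_size = 9, so fan_in = 32 × 9 = 288 and fan_out = 64 × 9 = 576. The sketch below computes this for an `nn.Conv2d` (the helper `conv_fans` is hypothetical and assumes `groups=1`) and cross-checks it against `torch.nn.init._calculate_fan_in_and_fan_out`, an internal, undocumented PyTorch helper that may change between versions:

```python
import torch
import torch.nn as nn

def conv_fans(conv):
    """Hypothetical helper: fan_in / fan_out of an nn.Conv2d from the formulas above.
    Assumes groups=1 (grouped convolutions divide in_channels by groups)."""
    receptive_field_size = conv.kernel_size[0] * conv.kernel_size[1]
    fan_in = conv.in_channels * receptive_field_size
    fan_out = conv.out_channels * receptive_field_size
    return fan_in, fan_out

conv = nn.Conv2d(32, 64, kernel_size=3, padding=1)
print(conv_fans(conv))  # (288, 576)

# Cross-check with PyTorch's private helper; the weight has shape (64, 32, 3, 3).
print(nn.init._calculate_fan_in_and_fan_out(conv.weight))  # (288, 576)

# Dense layer: weight shape (10, 128), so fan_in = 128 inputs and fan_out = 10 outputs.
print(nn.init._calculate_fan_in_and_fan_out(nn.Linear(128, 10).weight))  # (128, 10)
```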

PyTorch Implementation

PyTorch's default initialization for layers such as nn.Linear and nn.Conv2d is kept largely for backward compatibility and is not tuned for a particular activation function such as ReLU. You can explicitly initialize the weights using torch.nn.init functions for proper initialization:

import torch
import torch.nn as nn

# Helper function for initializing weights
def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Kaiming initialization for convolutional and linear layers
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if getattr(m, 'special_init', False):
            # Zero-initialize flagged layers (e.g. the last layer of a residual branch, as in Fixup)
            nn.init.constant_(m.weight, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # Batch normalization layer initialization: unit scale, zero shift
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# Example model
class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Flag the 2nd convolutional layer for special (zero) init
        self.conv2.special_init = True

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        return self.conv2(x)

# Initialize the model weights
model = ExampleModel()
model.apply(init_weights)
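To see what an explicit initialization changes, compare a freshly constructed `nn.Linear` with the same layer after `nn.init.kaiming_normal_`. This sketch assumes a recent PyTorch version, where the default weight initialization for `nn.Linear` is `kaiming_uniform_` with `a=sqrt(5)`, which works out to $U[-1/\sqrt{fan_{in}}, 1/\sqrt{fan_{in}}]$; the 1024-unit layer size is an arbitrary choice:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(1024, 1024)  # default PyTorch initialization
fan_in = layer.in_features

# Default init: U(-1/sqrt(fan_in), 1/sqrt(fan_in)), i.e. std = 1/sqrt(3 * fan_in).
default_bound = 1 / math.sqrt(fan_in)
default_max = layer.weight.abs().max().item()
default_std = layer.weight.std().item()
print(f"default: max |w| = {default_max:.5f} (bound {default_bound:.5f}), "
      f"std = {default_std:.5f} (expected {1 / math.sqrt(3 * fan_in):.5f})")

# Explicit Kaiming (fan_in mode) for ReLU: std = sqrt(2 / fan_in), about 2.45x larger.
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
kaiming_std = layer.weight.std().item()
print(f"kaiming: std = {kaiming_std:.5f} (expected {math.sqrt(2 / fan_in):.5f})")
```

The ratio between the two standard deviations is $\sqrt{6} \approx 2.45$, which is why a deep ReLU network built with the defaults tends to start with shrinking activations.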

Different initialization strategies suit different activation functions and network architectures. Understanding the rationale behind each one helps you choose a suitable scheme and train your networks more efficiently. For further reading, these are some good references:

  1. https://adityassrana.github.io/blog/theory/2020/08/26/Weight-Init.html#Problem:-Why-does-good-initialization-matter?
  2. https://paperswithcode.com/methods/category/initialization
  3. µTransfer: https://decentdescent.org/tp5.html
  4. https://pavisj.medium.com/convolutions-and-backpropagations-46026a8f5d2c
