Wednesday, 16 October 2019

What is a batch-norm in machine learning?


Batch normalization (BN) is a recent idea proposed in [1] to improve training procedures in deep neural networks (and related models). I’m a huge fan of this idea because (a) it’s ridiculously simple, yet incredibly powerful, and sheds a lot of light on the difficulties of training deep nets, and (b) the improvements in training it induces are quite outstanding.
Like I said, BN is embarrassingly simple: normalize the inputs to nonlinearities in every hidden layer. That’s it. Seriously.
To go into a little more detail: (almost) every hidden layer in a network is of the form
$$h = f(Wx)$$
where $f$ is some nonlinear function (say a ReLU), $W$ is the weight matrix associated with the layer, $x$ is the input to the layer, and $h$ is the output of that layer (dropping bias terms for simplicity). Let’s call $a = Wx$. BN proposes normalizing $a$ as
$$\hat{a} = \frac{a - \mu}{\sigma}$$
where $\mu$ and $\sigma$ are the mean and standard deviation of $a$ (that is, they come from its first and second moments), respectively. During training we use the empirical moments of every training batch. There are some extra elements used in practice to improve the expressiveness of a normalized batch and to allow this to work at test time, but the above is the core of BN.
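To make the core step concrete, here is a minimal NumPy sketch of normalizing the pre-activations of one layer over a batch. The function name `batch_norm`, the variable names, the batch and layer sizes, and the small constant `eps` (added for numerical stability) are assumptions made for illustration, not something taken from [1].

```python
import numpy as np

def batch_norm(a, eps=1e-5):
    """Normalize each feature (column) of a batch of pre-activations."""
    mu = a.mean(axis=0)    # empirical mean of each feature over the batch
    var = a.var(axis=0)    # empirical variance of each feature over the batch
    return (a - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 10))  # batch of 64 inputs, 10 features each
W = rng.normal(size=(10, 5))                       # weights of a layer with 5 units

a = x @ W                   # pre-activations (bias dropped for simplicity)
a_hat = batch_norm(a)       # normalized pre-activations
h = np.maximum(a_hat, 0.0)  # ReLU applied to the normalized values

print(a_hat.mean(axis=0))   # each entry is close to 0
print(a_hat.std(axis=0))    # each entry is close to 1
```

Note that the statistics are computed per feature across the batch dimension, so every hidden unit gets its own mean and standard deviation.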
The general intuition as to why BN is so effective is as follows. Training deep models is almost always done with gradient-based procedures. Every weight is adjusted according to its gradient, under the assumption that all other weights will not change. In practice, we change all weights in every iteration. Importantly, changing the weights of layer $l$ changes the distribution of the inputs to layer $l+1$, making any assumptions behind the gradient step for the weights of layer $l+1$ pretty weak, especially at the beginning of training, where changes can be dramatic. This greatly complicates training, and makes convergence difficult and slow. By applying BN between layers, we are in a sense enforcing that the inputs to every layer always have roughly zero mean and unit variance (loosely speaking, that they stay close to a standard normal distribution). This eases the dependencies between the layers, and means that the changes made are not counteracted by changes at previous layers. That is a very rough intuition, mind you; for more in-depth explanations I definitely recommend reading [1].
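The shift described above can be seen in a toy experiment: apply a large update to the weights of one layer and compare the scale of what the next layer receives, with and without normalization. This is only an illustrative sketch; the layer sizes and the deliberately exaggerated update `3.0 * W1 + 0.5` are made-up assumptions, and it only tracks the first two moments rather than the full distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 20))   # a fixed batch of inputs
W1 = rng.normal(size=(20, 20))   # weights of the earlier layer
W1_updated = 3.0 * W1 + 0.5      # hypothetical, deliberately large weight update

def next_layer_input(W, normalize, eps=1e-5):
    """What the following layer receives, with or without batch normalization."""
    a = x @ W
    if normalize:
        a = (a - a.mean(axis=0)) / np.sqrt(a.var(axis=0) + eps)
    return a

spread = {}
for normalize in (False, True):
    before = next_layer_input(W1, normalize).std()
    after = next_layer_input(W1_updated, normalize).std()
    spread[normalize] = (before, after)
    print(f"normalize={normalize}: std before={before:.3f}, after={after:.3f}")
```

Without normalization, the scale of the next layer's inputs changes substantially after the update; with normalization, it stays essentially fixed, which is the sense in which BN decouples the layers.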
This works unbelievably well. In practice, training can be an order of magnitude faster (measured in training epochs), is much smoother, and can be done with much larger learning rates. It’s really unbelievable how well this works. My own research has to do with deep generative models, where we typically train a number of dependent deep (Bayesian) nets jointly. In my experience, BN has been key to enabling us to train the bigger, more complex models.

In short, batch normalization normalizes the activations of the previous layer at each batch, i.e. it applies a transformation that keeps the mean activation close to 0 and the activation standard deviation close to 1.
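Packaged as a layer, this might look like the sketch below. It also includes the two practical additions alluded to earlier: a learnable scale and shift (to restore expressiveness) and running averages of the batch statistics (so the layer works at test time, even on a single example). The class name `BatchNorm1D`, the parameter names `gamma`, `beta` and `momentum`, and the default values are assumptions following common convention, not a reproduction of any particular library; the gradient updates for `gamma` and `beta` are omitted.

```python
import numpy as np

class BatchNorm1D:
    """Hypothetical batch-norm layer: forward pass only, no gradient updates."""

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.gamma = np.ones(num_features)   # learnable scale (kept fixed here)
        self.beta = np.zeros(num_features)   # learnable shift (kept fixed here)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum             # fraction of the old running value to keep
        self.eps = eps

    def __call__(self, a, training):
        if training:
            # Use the statistics of the current batch and update the running averages.
            mu, var = a.mean(axis=0), a.var(axis=0)
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mu
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            # At test time, fall back on the statistics accumulated during training.
            mu, var = self.running_mean, self.running_var
        a_hat = (a - mu) / np.sqrt(var + self.eps)
        return self.gamma * a_hat + self.beta

rng = np.random.default_rng(0)
bn = BatchNorm1D(num_features=4)

for _ in range(200):  # simulated training batches
    batch = rng.normal(loc=5.0, scale=3.0, size=(32, 4))
    out_train = bn(batch, training=True)

single = rng.normal(loc=5.0, scale=3.0, size=(1, 4))
out_test = bn(single, training=False)  # works even for a batch of one

print(bn.running_mean)  # roughly 5 for every feature
print(bn.running_var)   # roughly 9 for every feature
```

Libraries differ on whether `momentum` weights the old running value or the new batch statistic, so it is worth checking the convention before porting a setting from one framework to another.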
