Why AdamW matters

Adam does exactly that for you: big steps when the gradients do not change much and small steps when they vary rapidly, adapting the step size for each weight individually. Let us understand how Adam works (ignore the colored parts for now):

Taken from "Fixing Weight Decay Regularization in Adam" by Ilya Loshchilov and Frank Hutter.

Adam keeps track of exponential moving averages of the gradient (called the first moment, from now on denoted as m) and of the squared gradient (called the raw second moment, from now on denoted as v). In every time step the gradient g(t) = ∇f[x(t-1)] is calculated, followed by the moving averages:

m(t) = β1 · m(t-1) + (1-β1) · g(t)
v(t) = β2 · v(t-1) + (1-β2) · g(t)²

The parameters β1 (e.g. 0.9) and β2 (e.g. 0.999) control how quickly the averages decay, i.e. how far into the past you average over the gradients (squared). Read the equations the following way: "The new average is equal to 0.9 (or 0.999 for the squared gradients) times the old average plus 0.1 times the current gradient." With each time step, the old gradients are multiplied by 0.9 one additional time, so they contribute less and less to the moving average.

Please note that in lines 9 and 10 of the algorithm the averages are rescaled by dividing by (1 − β^t), where β is the corresponding decay parameter and t is the time step. To understand why this is necessary, consider the first time step and remember that m(0) and v(0) are initialized as 0. This means that the average after the first time step is m(1) = 0.9 · 0 + 0.1 · g(1) = 0.1 · g(1). However, the average after the first time step should be exactly g(1), which you get if you divide m(1) by (1 − 0.9¹) = 0.1.

We set the learning rate schedule multiplier η to 1 for simplicity and put everything together in line 12. When descending a step "down the hill", the step size is adapted by multiplying the learning rate α with m(t) and dividing by the square root of v(t) (let us ignore the hat ^ at this point):

x(t) = x(t-1) − α · m(t) / [sqrt(v(t)) + ϵ]

Remember that the variance of a random variable x is defined as Var(x) = <x²> − <x>², where < > denotes the expected value. The exponential moving average of the squared gradients v(t) is called the uncentered variance because we did not subtract the square of the mean of the gradients. The variance quantifies how much the gradients vary around their mean. If the gradients stay approximately constant because we are "walking down a meadow", the variance of the gradients is approximately 0 and the uncentered variance v(t) is approximately equal to m(t)². This means that m(t) / sqrt(v(t)) is around 1 and the step "down the hill" is on the order of α. If, on the other hand, the gradients are changing rapidly, sqrt(v(t)) is much larger than m(t) and the step "down the hill" is therefore much smaller than α.

Summing up, this means that Adam is able to adapt step sizes for each individual weight by estimating the first and second moments of the gradients. When the gradients do not change much and "we do not have to be careful walking down the hill", the step size is on the order of α; if they do and "we need to be careful not to walk in the wrong direction", the step size is much smaller.
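If you prefer code to equations, here is a minimal NumPy sketch of the update described above (moving averages, bias correction, adaptive step) with η fixed to 1; the function and variable names are my own illustration, not the authors' implementation:

```python
import numpy as np

def adam_step(x, grad, m, v, t, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update (with eta = 1) for parameters x, given the gradient at time step t (t starts at 1)."""
    # Exponential moving averages of the gradient and of the squared gradient
    m = beta1 * m + (1 - beta1) * grad       # first moment m(t)
    v = beta2 * v + (1 - beta2) * grad**2    # raw second moment v(t)
    # Bias correction: divide by (1 - beta^t) so the early averages are not biased towards 0
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # Adaptive step: roughly alpha when the gradients are steady, much smaller when they vary a lot
    x = x - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return x, m, v

# m and v start as zeros with the same shape as x; t is 1 for the first call and increases by 1 per step.
```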
In the next section, I will explain what L2 regularization is, and in the last section I will summarize the authors' findings on why Adam with L2 regularization yields models that generalize worse than models trained with SGD and how they propose to fix this problem.

2) L2 regularization and weight decay

The idea behind L2 regularization or weight decay is that networks with smaller weights (all other things being equal) are observed to overfit less and generalize better. I suggest that you read Michael Nielsen's great ebook if you are not familiar with the concept. Of course, large weights are still possible, but only if they significantly reduce the loss. The rate of weight decay per step, w, defines the relative importance of minimizing the original loss function (more important if a small w is chosen) and finding small weights (more important if a large w is chosen).

If you compare the update of the weights as explained before (new weight equals old weight minus learning rate times gradient)

x(t) = x(t-1) − α ∇f[x(t-1)]

to the version with weight decay

x(t) = (1-w) x(t-1) − α ∇f[x(t-1)]

you will notice the additional term −w x(t-1), which exponentially decays the weights x and thus forces the network to learn smaller weights.

Often, instead of performing weight decay, a regularized loss function is defined (L2 regularization):

f_reg[x(t-1)] = f[x(t-1)] + w'/2 · x(t-1)²

If you calculate the gradient of this regularized loss function

∇f_reg[x(t-1)] = ∇f[x(t-1)] + w' · x(t-1)

and update the weights

x(t) = x(t-1) − α ∇f_reg[x(t-1)]
     = x(t-1) − α ∇f[x(t-1)] − α · w' · x(t-1)

you will see that this is equivalent to weight decay if you define w' = w/α.

Common deep learning libraries usually implement the latter L2 regularization. However, the article shows that this equivalence only holds for SGD and not for adaptive optimizers like Adam!
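To make the equivalence concrete for SGD, here is a small NumPy sketch that performs one weight-decay step and one step on the L2-regularized loss of a toy quadratic; the toy loss, names, and numbers are my own illustration:

```python
import numpy as np

def grad_f(x, target):
    """Gradient of the toy loss f(x) = 0.5 * ||x - target||^2."""
    return x - target

target = np.array([1.0, -2.0, 3.0])
x0 = np.array([0.5, 0.5, 0.5])
alpha = 0.1          # learning rate
w = 0.01             # weight decay per step
w_prime = w / alpha  # equivalent L2 coefficient for SGD

# SGD with weight decay: x(t) = (1 - w) x(t-1) - alpha * grad f[x(t-1)]
x_decay = (1 - w) * x0 - alpha * grad_f(x0, target)

# SGD on the L2-regularized loss: grad f_reg = grad f + w' * x
x_l2 = x0 - alpha * (grad_f(x0, target) + w_prime * x0)

print(np.allclose(x_decay, x_l2))  # True: the two updates coincide for SGD
```

For Adam, the extra term w' · x(t-1) would instead be folded into the moving averages m and v, which is where this equivalence breaks down.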

In the last section of this post, I will explain why L2 regularization is not equivalent to weight decay for Adam, what the differences between Adam and the proposed AdamW are, and why using AdamW gives better generalizing models.

3) AdamW

Let us take another look at the Adam algorithm.

Taken from "Fixing Weight Decay Regularization in Adam" by Ilya Loshchilov and Frank Hutter.

The violet term in line 6 shows L2 regularization in Adam (not AdamW) as it is usually implemented in deep learning libraries.
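To see where the two approaches diverge, here is a minimal sketch (again with η = 1 and my own names, not the authors' reference code) contrasting L2 regularization inside Adam, where the penalty enters the gradient as in the violet term, with the decoupled weight decay of AdamW:

```python
import numpy as np

def adam_l2_step(x, grad_f, m, v, t, alpha, w_prime, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with L2 regularization: the penalty w' * x is added to the gradient (the violet term)
    and therefore flows through the moving averages m and v."""
    g = grad_f + w_prime * x
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat, v_hat = m / (1 - beta1**t), v / (1 - beta2**t)
    x = x - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return x, m, v

def adamw_step(x, grad_f, m, v, t, alpha, w, beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW: the moments see only the loss gradient; the weights are decayed directly."""
    m = beta1 * m + (1 - beta1) * grad_f
    v = beta2 * v + (1 - beta2) * grad_f**2
    m_hat, v_hat = m / (1 - beta1**t), v / (1 - beta2**t)
    x = x - alpha * m_hat / (np.sqrt(v_hat) + eps) - w * x
    return x, m, v
```

The decisive difference is that in adam_l2_step the penalty w_prime * x is divided by the square root of v_hat together with the rest of the gradient, whereas in adamw_step the decay w * x acts on the weights directly; this decoupling is the fix the authors propose.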