irhum.github.io - LoRA and Weight Decay

LoRA (Hu et al., 2021) is now a popular alternative to full finetuning of Large Language Models (LLMs): instead of tuning the billions of weights of the full model, we add small “adapter” weight matrices that modify the original weight matrices, and tune those instead.

This blogpost dives into a curious behavior: although LoRA is commonly treated as a drop-in replacement for full finetuning, its interaction with weight decay means it solves a different optimization problem. Namely, one where the solution weights are regularized towards the frozen base model \((W \rightarrow W_{\text{init}})\), instead of towards zero \((W \rightarrow 0)\) as in full finetuning. This means that even as we give LoRA more resources (up to those of full finetuning), it does not approximate full finetuning increasingly well, because its objective function is implicitly different from that of full finetuning. Depending on the use case, this can be seen as either a bug or a feature, but it is something practitioners should explicitly account for.

Recap: Finetuning

With LLMs, we typically finetune an initial model (one that is “good” on a wide range of text-to-text tasks) to boost performance on a specific task of interest (e.g. generating database queries from natural language). We do this in a two-step process:

1. Create a finetuning training dataset \(\{(x_i, y_i)\}_{i=1}^{n}\), which contains pairs of inputs \(x\) and targets \(y\).¹
2. Optimize the weights of the initial model such that our finetuning training dataset becomes more “probable”.

The idea here is that a model that is more likely to generate the correct answers \(y\) for the \(x\)’s in our training set will generalize, and also be more likely to generate the correct \(y\)’s for new \(x\)’s.

Full Finetuning

Full finetuning means we tune all the weights of the model. For a model such as GPT-3 175B (Brown et al., 2020), this means giving our optimization algorithm 175 billion numbers it can “dial” up and down as needed to make our finetuning training data more “probable”.

Let’s dig a bit deeper and define more concretely what we mean by weights. Each layer in a Transformer is primarily made up of two components: a multihead attention network, followed by a feedforward network. This means the bulk of the “weights” in each layer are stored in six matrices²: the query, key, value and output projections of the attention network, and the two matrices of the feedforward network. We write e.g. \(W_q^1\) for the query matrix of layer 1, and use \(\theta\) as shorthand for all the weights, stored in all the matrices across all the layers of the model.

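In full finetuning, every single weight in \(\theta\) is opened up for updating. Our aim is to find weights that minimize the negative log likelihood (NLL) of the targets given the inputs³:
\[\min_{\color{YellowOrange}{\theta}} \biggl[\underbrace{-\log P_{\color{YellowOrange}{\theta}}({\color{PineGreen}{y}} \mid {\color{RoyalBlue}{x}})}_{\color{BrickRed}{L}}\biggr]\]
There is no closed-form way to get the “optimal” weights, so we solve the optimization problem iteratively, applying many steps of gradient descent with learning rate \(\alpha\) to every weight \(w\) in \(\theta\):
\[{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}\]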
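A minimal sketch of this loop in NumPy, assuming a toy logistic-regression model stands in for the LLM, so that \(P_w(y \mid x)\) is a sigmoid and the gradient can be written by hand. The function names nll, nll_grad and gradient_descent are hypothetical; a real LLM would use automatic differentiation instead:

```python
import numpy as np

def nll(w, X, y):
    """Mean negative log likelihood -log P_w(y | x) of a toy logistic model."""
    p = 1.0 / (1.0 + np.exp(-X @ w))
    eps = 1e-12  # guards against log(0)
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

def nll_grad(w, X, y):
    """Analytic gradient dL/dw of the mean NLL above."""
    p = 1.0 / (1.0 + np.exp(-X @ w))
    return X.T @ (p - y) / len(y)

def gradient_descent(w, X, y, alpha=0.5, steps=200):
    """Repeatedly apply the update w <- w - alpha * dL/dw."""
    for _ in range(steps):
        w = w - alpha * nll_grad(w, X, y)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = (X @ np.array([1.0, -2.0, 0.5]) > 0).astype(float)  # toy targets
w0 = np.zeros(3)
w_final = gradient_descent(w0, X, y)
print(nll(w0, X, y), nll(w_final, X, y))  # the NLL drops from log(2)
```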

Now, directly doing gradient descent this way can quickly lead to overfitting⁴, so we usually regularize the problem. With LLMs, the regularization tool of choice is usually weight decay. Specifically, when using vanilla SGD⁵, weight decay is equivalent to adding a term to the loss equal to the sum of the squared weights: \[R(\theta)=\sum_i \sum_j[W_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}]_{ij}^2+\cdots\] Hence, the overall objective is now as follows (where \(\lambda\) is a hyperparameter controlling the strength of the weight decay): \[\min_{\color{YellowOrange}{\theta}} \biggl[\underbrace{-\log P_{\color{YellowOrange}{\theta}}({\color{PineGreen}{y}} \mid {\color{RoyalBlue}{x}})}_{\color{BrickRed}{L}} + \frac{\lambda}{2} R({\color{YellowOrange}{\theta}})\biggr]\] Differentiating this objective to get the gradient, we notice the gradient update has two distinct terms⁶: the first corresponds to minimizing the negative log likelihood as before, and a new second term \(-\alpha\lambda w\) pushes the weight towards the origin \(0\):
\[
% https://tex.stackexchange.com/a/9477
\def\mathunderline#1#2{\color{#1}\underline{{\color{black}#2}}\color{black}}
\begin{align*}
&{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \left(\mathunderline{BrickRed}{\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}} + \mathunderline{LimeGreen}{\frac{\lambda}{2} \frac{\partial R}{\partial \color{YellowOrange}{w}}} \right)\\
\Rightarrow &{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \left(\mathunderline{BrickRed}{\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}} + \mathunderline{LimeGreen}{\lambda {\color{YellowOrange}{w}}} \right)\\
\Rightarrow &{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \mathunderline{BrickRed}{\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}}} - \alpha \mathunderline{LimeGreen}{\lambda {\color{YellowOrange}{w}}}
\end{align*}
\]
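A quick numerical check of this equivalence, as a NumPy sketch with a random stand-in for \(\partial L / \partial w\) (the helper names are hypothetical). It also shows that, with the loss gradient switched off, each step simply multiplies the weights by \((1 - \alpha\lambda)\), shrinking them towards 0. The equivalence holds for plain SGD; with adaptive optimizers such as Adam, an L2 term in the loss and decoupled weight decay behave differently:

```python
import numpy as np

def step_with_l2_loss(w, grad_L, alpha, lam):
    """One SGD step on L + (lam/2) * R(w), where R(w) = sum(w**2)."""
    grad_penalty = lam * w  # (lam/2) * dR/dw = (lam/2) * 2w
    return w - alpha * (grad_L + grad_penalty)

def step_with_weight_decay(w, grad_L, alpha, lam):
    """Plain SGD step on L, then subtract a scaled copy of the weights."""
    return w - alpha * grad_L - alpha * lam * w

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4))
grad_L = rng.normal(size=(4, 4))  # random stand-in for dL/dw
a = step_with_l2_loss(w, grad_L, alpha=0.1, lam=0.01)
b = step_with_weight_decay(w, grad_L, alpha=0.1, lam=0.01)
print(np.allclose(a, b))  # True

# With no loss gradient, every step scales w by (1 - alpha * lam).
w_t = w.copy()
for _ in range(1000):
    w_t = step_with_weight_decay(w_t, np.zeros_like(w_t), alpha=0.1, lam=0.1)
print(np.allclose(w_t, (1 - 0.1 * 0.1) ** 1000 * w))  # True
```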

In summary, adding a sum-of-squared-weights penalty to the loss is equivalent to subtracting a scaled copy of each weight at each gradient descent step. This shifts the minimum towards weights closer to \(0\)⁷, so no single weight can have an extremely large effect on the model’s predictions.

Full finetuning is highly flexible, but also extremely memory intensive: you generally need at least 3x the memory⁸ required for the model itself, to account for its gradients and optimizer state. This was not an issue when models had \(O(100M)\) params, but it certainly is today, when they regularly have \(O(10B)\) to \(O(100B)\) params. Moreover, if you have 10 sub-tasks in your application (and tune the model separately for each), full finetuning requires you to host 10 versions of the model (as if hosting 1 isn’t expensive enough!).
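As back-of-the-envelope arithmetic for these costs, here is a sketch assuming fp32 values (4 bytes each) and exactly three full-size copies per weight (weights, gradients, one optimizer-state slot), ignoring activations. Adam, for example, keeps two moment estimates per weight, which would make it 4x. The helper names are hypothetical:

```python
def full_finetune_memory_gb(n_params, bytes_per_value=4, copies=3):
    """Rough training-memory lower bound for full finetuning, in GB.

    copies=3 counts the weights, their gradients and one optimizer-state
    slot; activations are ignored.
    """
    return n_params * bytes_per_value * copies / 1e9

def hosting_memory_gb(n_params, n_tasks, bytes_per_value=4):
    """Weights-only memory, in GB, to serve one fully finetuned copy per task."""
    return n_params * bytes_per_value * n_tasks / 1e9

gpt3_params = 175_000_000_000
print(full_finetune_memory_gb(gpt3_params))        # 2100.0
print(hosting_memory_gb(gpt3_params, n_tasks=10))  # 7000.0
```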

LoRA finetuning

LoRA (Low-Rank Adaptation) finetuning takes a different approach: instead of tuning the massive weight matrices of an LLM directly, we use a pair of small adapter matrices for each weight matrix we want to tune, of the following form:
\[W = W_{\text{init}} + \Delta W = W_{\text{init}} + AB\]

That is, for each initial, frozen weight matrix \(W_{\text{init}}\), we have adapter matrices \(A\) and \(B\). These two matrices are multiplied together to form \(\Delta W\), a low-rank “adjustment” matrix that is added to \(W_{\text{init}}\) to form the adapted, tuned matrix \(W\). This cuts the number of free parameters significantly: assume the original matrix \(W_{\text{init}}\) is \(4{,}096 \times 16{,}384\). With full finetuning, we’d have 67 million parameters to tune for this one weight matrix alone: \[4{,}096 \times 16{,}384 = 67{,}108{,}864 \approx 67 \text{ million}\] With LoRA at rank \(r=4\), \(A\) is \(4{,}096 \times 4\) and \(B\) is \(4 \times 16{,}384\), so we only have: \[4{,}096 \times 4 + 4 \times 16{,}384 = 81{,}920\] This is roughly 0.12% of the original number of parameters; the added overhead of storing 3 variants of these values (weights, gradients and optimizer states) is negligible compared to the memory used by the model itself. Moreover, since the initial weights are “shared” across all the finetuning runs, at inference time we only need to load one copy of the initial model, shared across many finetuned versions, with each task using its own adapter matrices. This makes having a “per-task” tuned LLM in an application not only viable, but easy.
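A small NumPy sketch of the adapted matrix and the parameter count above. The shapes follow the text (\(A\) is \(d \times r\), \(B\) is \(r \times k\)); the helper names are hypothetical, and starting the left factor at zero, so that \(W = W_{\text{init}}\) before training, is an assumption here that mirrors the initialization described by Hu et al. (2021):

```python
import numpy as np

def lora_weight(W_init, A, B):
    """Adapted weight W = W_init + A @ B, with A of shape (d, r), B of shape (r, k)."""
    return W_init + A @ B

def lora_param_count(d, k, r):
    """Number of trainable adapter parameters: d*r entries in A plus r*k in B."""
    return d * r + r * k

d, k, r = 4096, 16384, 4
full_params = d * k
lora_params = lora_param_count(d, k, r)
print(full_params, lora_params)         # 67108864 81920
print(100 * lora_params / full_params)  # 0.1220703125 (percent)

# Forward check on small shapes; a float64 4096 x 16384 matrix would take ~0.5 GB.
rng = np.random.default_rng(0)
W_init = rng.normal(size=(8, 12))
A = np.zeros((8, 2))             # left factor starts at zero (assumption, see above)
B = rng.normal(size=(2, 12))
print(np.allclose(lora_weight(W_init, A, B), W_init))  # True: Delta W = A @ B = 0
```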

The Interaction

Now that we’ve covered what LoRA is, we can discuss how it interacts with weight decay to produce this feature/bug. Since \(A\) and \(B\) are the matrices we’re actually performing gradient descent on, the weight decay term now penalizes the adapter matrices, moving the minimum towards where \(A\) and \(B\) are close to 0: \[R(\theta)=\sum_i \sum_j[A_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}]_{ij}^2+ \sum_i \sum_j[B_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}]_{ij}^2+ \cdots\] Let’s contrast this with the formulation in full finetuning.

In full finetuning, we have \(W \rightarrow 0\): the weight decays to 0 directly. In LoRA, however, \(A\) and \(B\) decay to 0, so their product \(\Delta W = AB\) does too, and in effect we have \(W \rightarrow W_{\text{init}}\) instead.

This means LoRA solutions are biased towards the original frozen weight matrices, unlike in full finetuning, where they’re biased towards zero. And this behavior does not go away with increasing the LoRA rank \(r\): one could increase it all the way to full rank (!), where \(AB\) can express any update to \(W_{\text{init}}\), and the optimization process would still be biased towards the original frozen weights instead of zero. That is, even in the limit, LoRA does not approximate full finetuning; it optimizes a different objective.
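A toy simulation of this bias, as a NumPy sketch: the loss gradient is set to zero so that only the regularizer acts, and the adapters are given full rank \(r = \min(d, k)\). Plain weight decay on \(W\) (full finetuning) drives \(W\) to 0, while the same decay applied to \(A\) and \(B\) (LoRA) drives \(W\) back to \(W_{\text{init}}\). Shapes, step size and decay strength are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 6, 5
r = min(d, k)                      # full-rank adapters: A @ B can be any d x k matrix
alpha, lam, steps = 0.1, 0.5, 2000

W_init = rng.normal(size=(d, k))
A = rng.normal(size=(d, r))
B = rng.normal(size=(r, k))
W_start = W_init + A @ B           # same starting point for both runs

# Full finetuning: weight decay acts on W directly (loss gradient set to 0).
W_full = W_start.copy()
for _ in range(steps):
    W_full = W_full - alpha * lam * W_full

# LoRA: weight decay acts on A and B only; W_init stays frozen.
for _ in range(steps):
    A, B = A - alpha * lam * A, B - alpha * lam * B
W_lora = W_init + A @ B

print(np.abs(W_full).max())           # ~0: W decayed towards 0
print(np.abs(W_lora - W_init).max())  # ~0: W decayed towards W_init instead
```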

A fix

If we wanted the full adapted matrix to go towards zero (as would happen in full finetuning), we’d need a regularization term that pushes the entire adapted weight matrix towards zero, as follows: \[\begin{align*} R(\theta)&=\sum_i \sum_j[W_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}]_{ij}^2+\cdots\\ &=\sum_i \sum_j[W_{{\color{RoyalBlue}q\color{Black}\text{,init}}}^{\color{PineGreen}{1}} + A_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}B_{{\color{RoyalBlue}q}}^{\color{PineGreen}{1}}]_{ij}^2+\cdots \end{align*}\] This is actually straightforward to derive, and yields a pair of update equations that can be implemented much like standard weight decay. First, start from the core definition of weight decay, which involves calculating the gradient of the regularization term with respect to the weight: \[{\color{YellowOrange}{w}} \leftarrow {\color{YellowOrange}{w}} - \alpha \left(\frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{w}} + \frac{\lambda}{2} \frac{\partial R}{\partial \color{YellowOrange}{w}} \right)\] Second, compute the gradient of the “corrected” \(R(\theta)\) above with respect to \(A\) and \(B\)⁹. This yields: \[\begin{align*} \frac{\partial R}{\partial \color{YellowOrange}{A}}&=2 (W_{\text{init}} + {\color{YellowOrange}{A}}{\color{PineGreen}{B}}) {\color{PineGreen}{B^T}}\\ \frac{\partial R}{\partial \color{YellowOrange}{B}}&=2 {\color{PineGreen}{A^T}}(W_{\text{init}} + {\color{PineGreen}{A}}{\color{YellowOrange}{B}}) \end{align*}\] Inserting these back into the definition of weight decay (the factor of 2 cancels with \(\frac{\lambda}{2}\)), we get the following concrete update equations for \(A\) and \(B\): \[\begin{align*} {\color{YellowOrange}{A}} &\leftarrow {\color{YellowOrange}{A}} - \alpha \frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{A}} - \alpha \lambda (W_{\text{init}} + {\color{YellowOrange}{A}}{\color{PineGreen}{B}}) {\color{PineGreen}{B^T}}\\ {\color{YellowOrange}{B}} &\leftarrow {\color{YellowOrange}{B}} - \alpha \frac{\partial \color{BrickRed}{L}}{\partial \color{YellowOrange}{B}} - \alpha \lambda {\color{PineGreen}{A^T}}(W_{\text{init}} + {\color{PineGreen}{A}}{\color{YellowOrange}{B}}) \end{align*}\]
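These gradients can be sanity-checked numerically. A minimal NumPy sketch, assuming small random matrices and central finite differences (the helper names are hypothetical), comparing the closed-form \(\partial R / \partial A\) and \(\partial R / \partial B\) above against numerical estimates for a single weight matrix:

```python
import numpy as np

def corrected_R(W_init, A, B):
    """R = sum of squared entries of the full adapted matrix W_init + A @ B."""
    return np.sum((W_init + A @ B) ** 2)

def corrected_R_grads(W_init, A, B):
    """Closed-form gradients: dR/dA = 2 W B^T and dR/dB = 2 A^T W."""
    W = W_init + A @ B
    return 2 * W @ B.T, 2 * A.T @ W

def numerical_grad(f, X, h=1e-6):
    """Central finite differences of scalar f with respect to every entry of X."""
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += h
        Xm[idx] -= h
        G[idx] = (f(Xp) - f(Xm)) / (2 * h)
    return G

rng = np.random.default_rng(0)
W_init = rng.normal(size=(5, 4))
A = rng.normal(size=(5, 2))
B = rng.normal(size=(2, 4))

dA, dB = corrected_R_grads(W_init, A, B)
dA_num = numerical_grad(lambda X: corrected_R(W_init, X, B), A)
dB_num = numerical_grad(lambda X: corrected_R(W_init, A, X), B)
print(np.allclose(dA, dA_num, atol=1e-5), np.allclose(dB, dB_num, atol=1e-5))  # True True
```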

In code

The standard formulation of weight decay in the Optax (Babuschkin et al., 2020) library is quite clean: it adds a weight_decay (\(\lambda\)) scaled version of the parameter p to its current update g¹⁰.
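Optax ships the standard version as optax.add_decayed_weights. Below is a framework-free NumPy sketch, not an Optax API, of a hypothetical LoRA-aware counterpart that implements the corrected update equations above: it adds \(\lambda (W_{\text{init}} + AB) B^T\) and \(\lambda A^T (W_{\text{init}} + AB)\) to the updates of \(A\) and \(B\), following the same “add to the update g” pattern. To keep the dynamics easy to follow, the toy uses \(W_{\text{init}} = I\), adapters that start at \(W = 0.75\,I\), and a zero loss gradient, so only the regularizer acts:

```python
import functools
import numpy as np

def standard_decay(gA, gB, A, B, W_init, weight_decay):
    """Standard weight decay on the adapters: g + weight_decay * p (W_init unused)."""
    return gA + weight_decay * A, gB + weight_decay * B

def lora_full_decay(gA, gB, A, B, W_init, weight_decay):
    """Decay the full adapted matrix W = W_init + A @ B towards 0.

    Adds (weight_decay / 2) * dR/dA and dR/dB for R = ||W_init + A @ B||^2.
    """
    W = W_init + A @ B
    return gA + weight_decay * W @ B.T, gB + weight_decay * A.T @ W

def run(decay_fn, steps=1000, alpha=0.1):
    """Plain SGD on A and B with a zero loss gradient, so only the decay acts."""
    W_init = np.eye(3)
    A, B = 0.5 * np.eye(3), -0.5 * np.eye(3)  # start at W = 0.75 * I
    zero_grad = np.zeros((3, 3))
    for _ in range(steps):
        gA, gB = decay_fn(zero_grad, zero_grad, A, B, W_init)
        A, B = A - alpha * gA, B - alpha * gB
    return W_init + A @ B

W_std = run(functools.partial(standard_decay, weight_decay=1.0))
W_fix = run(functools.partial(lora_full_decay, weight_decay=1.0))
print(np.round(W_std, 6))  # ≈ identity: W returns to W_init
print(np.round(W_fix, 6))  # ≈ zeros: W decays towards 0, as in full finetuning
```

In a real Optax setup, this would become a custom GradientTransformation whose update function also has access to the frozen \(W_{\text{init}}\) for each adapter pair.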