Lazy vs. rich training dynamics: it's all in the output multiplier
One of the most important ideas in deep learning theory from the last few years is that, in some settings, training a deep network can behave as if the network were linear in its parameters. In this “lazy” regime, the weights (and hence the learned features) barely change during training. Remarkably, lazy training can be induced simply by adding an output multiplier to the network.
Here's an example from Chizat et al. (2018). To start, we're going to make a “teacher” function of the form $$f^*(\mathbf{x}) = \frac{1}{\sqrt{k}} \sum_{i=1}^k a^*_i \,\mathrm{ReLU}\!\left(\langle \mathbf{w}^*_i, \mathbf{x} \rangle\right).$$ This teacher is just a ReLU network with random weights. We could have chosen another teacher function, but this one will allow nice visualizations. We let the teacher have $k = 3$ ReLU neurons. The input is two-dimensional: $\mathbf{x} \in \mathbb{R}^2$.
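As a concrete reference, here is a minimal NumPy sketch of such a teacher. The choice of standard normal weights (and the seed) is an assumption for illustration; the text only says the weights are random.

```python
import numpy as np

rng = np.random.default_rng(0)
k, d = 3, 2  # teacher neurons and input dimension, as in the text

# Teacher parameters a*_i and w*_i (standard normal is an assumed choice).
a_star = rng.standard_normal(k)
W_star = rng.standard_normal((k, d))

def teacher(x):
    """f*(x) = (1/sqrt(k)) * sum_i a*_i ReLU(<w*_i, x>); x has shape (batch, d)."""
    pre = x @ W_star.T              # inner products <w*_i, x>, shape (batch, k)
    return (np.maximum(pre, 0) @ a_star) / np.sqrt(k)
```

Note that a ReLU network with no biases is positively homogeneous: scaling the input by $c > 0$ scales the output by $c$.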
The “student” network is the one we will train. Its function is $$\hat{f}(\mathbf{x}) = \frac{\alpha}{n} \sum_{i=1}^n a_i \,\mathrm{ReLU}\!\left(\langle \mathbf{w}_i, \mathbf{x} \rangle\right).$$ We let the number of student neurons $n = 200$. We train the parameters $(a_i, \mathbf{w}_i)_{i=1}^n$ to make the student represent the teacher, minimizing the squared loss over Gaussian data $$\mathcal{L} = \mathbb{E}_{\mathbf{x} \sim \mathcal{N}(0,\mathbf{I}_2)}\!\left[\left(\hat{f}(\mathbf{x}) - f^*(\mathbf{x})\right)^2\right].$$
The parameterization and the output multiplier. The student network has a factor of $\frac{\alpha}{n}$ in front of it. The factor of $\frac{1}{n}$ puts the network in mean-field parameterization. We initialize all student parameters from $\mathcal{N}(0,1)$ and train with ordinary gradient descent. The extra factor of $\alpha$, however, is our new addition. Making this output multiplier large, small, or order one will fundamentally change how training happens and where the student weights end up.
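This setup can be sketched end to end in NumPy, with the gradients of the squared loss written out by hand. The learning rate, batch size, and step count are illustrative choices, and the population loss over $\mathcal{N}(0, \mathbf{I}_2)$ is estimated by Monte Carlo on fresh Gaussian samples.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, n = 2, 3, 200      # input dim, teacher neurons, student neurons
alpha = 1.0              # output multiplier

# Fixed random teacher (the exact weight distribution is an assumption here).
a_star = rng.standard_normal(k)
W_star = rng.standard_normal((k, d))

def teacher(x):
    return (np.maximum(x @ W_star.T, 0) @ a_star) / np.sqrt(k)

# Student parameters, all drawn from N(0, 1) as in the text.
a = rng.standard_normal(n)
W = rng.standard_normal((n, d))

def student(x):
    """f_hat(x) = (alpha/n) * sum_i a_i ReLU(<w_i, x>), mean-field scaling."""
    return (np.maximum(x @ W.T, 0) @ a) * (alpha / n)

def loss_estimate(num_samples=20_000):
    """Monte Carlo estimate of E_{x ~ N(0, I_2)}[(f_hat(x) - f*(x))^2]."""
    x = rng.standard_normal((num_samples, d))
    return np.mean((student(x) - teacher(x)) ** 2)

loss_before = loss_estimate()

# Gradient descent on fresh Gaussian batches (a stochastic stand-in for the
# population loss), with the squared-loss gradients computed by hand.
lr, steps, batch = 0.5, 500, 512
for _ in range(steps):
    x = rng.standard_normal((batch, d))
    pre = x @ W.T                                  # (batch, n) pre-activations
    act = np.maximum(pre, 0)
    r = (act @ a) * (alpha / n) - teacher(x)       # residuals, shape (batch,)
    a -= lr * (2 * alpha / (n * batch)) * (act.T @ r)
    W -= lr * (2 * alpha / (n * batch)) * (((pre > 0) * a * r[:, None]).T @ x)

loss_after = loss_estimate()
```

The `(pre > 0)` mask is the ReLU derivative, so `grad_W` is exactly the chain rule applied to $\frac{\alpha}{n} a_i \,\mathrm{ReLU}(\langle \mathbf{w}_i, \mathbf{x} \rangle)$ averaged over the batch.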
Below is an interactive module that lets you see how training changes with different values of $\alpha$. First, try running the simulation at $\alpha = 1$. You will see the loss drop in the left plot, and in the right plot the student weights (dots) will drift toward the teacher features (dashed lines). If you use $\alpha \gg 1$, you will find that the dynamics are “lazy”: the student weights need not move much to drive the loss down. A useful mental model is a gear ratio: a large $\alpha$ acts like a large gear ratio, so tiny movements of the weights produce big changes in the output, and the network can fit the target without its features changing appreciably. By contrast, when $\alpha \ll 1$, the dynamics are “ultra-rich”: the student weights must change enormously, and they grow to strikingly align with the teacher features.
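The lazy/rich contrast can also be checked numerically by measuring how far the student's input weights travel during training. In the sketch below, the step size is rescaled by $1/\alpha^2$, following the lazy-training analysis of Chizat et al., so that the function-space dynamics are comparable across $\alpha$; the specific values of $\alpha$, the learning rate, and the step count are illustrative assumptions.

```python
import numpy as np

def run(alpha, lr=0.05, steps=400, batch=512, n=200, d=2, seed=0):
    """Train the mean-field student at output scale alpha; return the relative
    movement ||W_final - W_init|| / ||W_init|| of the input weights."""
    rng = np.random.default_rng(seed)
    # Fixed random teacher with k = 3 ReLU neurons (distribution assumed).
    k = 3
    a_star = rng.standard_normal(k)
    W_star = rng.standard_normal((k, d))
    teacher = lambda x: (np.maximum(x @ W_star.T, 0) @ a_star) / np.sqrt(k)

    a = rng.standard_normal(n)
    W = rng.standard_normal((n, d))
    W0 = W.copy()

    # Rescale the step size by 1/alpha^2 (as in the lazy-training analysis)
    # so that training is stable for both large and small alpha.
    eta = lr / alpha**2
    for _ in range(steps):
        x = rng.standard_normal((batch, d))
        pre = x @ W.T
        act = np.maximum(pre, 0)
        r = (act @ a) * (alpha / n) - teacher(x)
        a -= eta * (2 * alpha / (n * batch)) * (act.T @ r)
        W -= eta * (2 * alpha / (n * batch)) * (((pre > 0) * a * r[:, None]).T @ x)
    return np.linalg.norm(W - W0) / np.linalg.norm(W0)

lazy_move = run(alpha=30.0)  # large alpha: weights barely move
rich_move = run(alpha=0.2)   # small alpha: weights travel much farther
```

Comparing `lazy_move` and `rich_move` reproduces the qualitative picture from the interactive module: the large-$\alpha$ run fits by nudging its weights slightly, while the small-$\alpha$ run has to move them substantially.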