Shallow MLP Gym
Train a two-layer MLP on a synthetic regression task. Inputs are drawn $\mathbf{x} \sim \mathcal{N}(0, I_d)$; the target is a staircase of low-rank monomials, and the student is trained to match it via SGD.
Setup
Student network
$$\hat{f}(\mathbf{x}) = \frac{1}{n} \sum_{i=1}^n a_i\, \sigma\!\left(\frac{\mathbf{w}_i^\top \mathbf{x}}{\sqrt{d}}\right)$$
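A minimal NumPy sketch of this forward pass, assuming $\sigma$ is ReLU (the demo lets you choose the nonlinearity; function and variable names here are illustrative):

```python
import numpy as np

def forward(X, W, a):
    """Mean-field two-layer MLP: f(x) = (1/n) * sum_i a_i * sigma(w_i . x / sqrt(d)).

    X: (B, d) batch of inputs, W: (n, d) first-layer weights, a: (n,) readout.
    sigma is taken to be ReLU here (an assumption; the demo's sigma is selectable).
    """
    d = X.shape[1]
    pre = X @ W.T / np.sqrt(d)   # (B, n) preactivations with 1/sqrt(d) scaling
    act = np.maximum(pre, 0.0)   # ReLU nonlinearity
    return act @ a / W.shape[0]  # (B,) outputs with 1/n readout scaling
```

With the $\mu$P initialization below ($a_i(0) = 0$), this forward pass outputs exactly zero on every input at step 0.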
Initialization ($\mu$P)
$$W_{ij}(0) \sim \mathcal{N}(0,\alpha^2)$$
$$a_i(0) = 0$$
Updates ($\mu$P)
$$\Delta\mathbf{W} = -\eta \cdot n\, \nabla_{\!\mathbf{W}}\,\mathcal{L}$$
$$\Delta\mathbf{a} = -\eta \cdot n\, \nabla_{\!\mathbf{a}}\,\mathcal{L}$$
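The two update rules above can be sketched as one SGD step — a sketch assuming ReLU for $\sigma$ and batch MSE for $\mathcal{L}$; the gradients are written out by hand so the $n$-scaled $\mu$P updates are explicit:

```python
import numpy as np

def mup_sgd_step(X, y, W, a, eta):
    """One SGD step with muP scaling: Delta W = -eta * n * grad_W L,
    Delta a = -eta * n * grad_a L, where L = (1/2) * mean((f_hat - y)^2)
    over the batch. ReLU sigma assumed (the demo's sigma is selectable)."""
    n, d = W.shape
    B = X.shape[0]
    pre = X @ W.T / np.sqrt(d)                   # (B, n)
    act = np.maximum(pre, 0.0)                   # sigma(pre)
    f_hat = act @ a / n                          # (B,) network outputs
    err = (f_hat - y) / B                        # dL/df_hat, averaged over batch
    grad_a = act.T @ err / n                     # (n,) readout gradient
    grad_pre = np.outer(err, a / n) * (pre > 0)  # (B, n) backprop through ReLU
    grad_W = grad_pre.T @ X / np.sqrt(d)         # (n, d) first-layer gradient
    return W - eta * n * grad_W, a - eta * n * grad_a
```

Note that with $a_i(0) = 0$, the first-layer gradient vanishes on the very first step, so only the readout moves until $\mathbf{a}$ becomes nonzero.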
Target function
The target $f^*$ is a staircase function of order $T$ (its type and order are selectable in the demo). Training minimizes the population squared error:
$$\mathcal{L} = \tfrac{1}{2}\,\mathbb{E}\!\left[(\hat{f} - f^*)^2\right]$$
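A sketch of the target and an empirical estimate of this loss. The specific staircase form below, $f^*(\mathbf{x}) = \sum_{t=1}^{T} \prod_{j=1}^{t} x_j$, is an assumption — the demo's exact target may differ:

```python
import numpy as np

def staircase_target(X, T):
    """Order-T staircase: f*(x) = x_1 + x_1 x_2 + ... + x_1 x_2 ... x_T.
    This specific form is an assumption; the demo's target may differ."""
    # Cumulative products over the first T coordinates, then summed.
    return np.cumprod(X[:, :T], axis=1).sum(axis=1)

def mse_loss(f_hat, f_star):
    """Batch estimate of L = (1/2) E[(f_hat - f*)^2]."""
    return 0.5 * np.mean((f_hat - f_star) ** 2)
```

For example, `staircase_target` on the single input $(1, 2, 3)$ with $T = 2$ returns $1 + 1 \cdot 2 = 3$.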
Hyperparameters

| parameter       | symbol   | default |
|-----------------|----------|---------|
| input dimension | $d$      | 10      |
| width           | $n$      | 100     |
| init scale      | $\alpha$ | 1       |
| learning rate   | $\eta$   | 0.01    |
| batch size      | $B$      | 200     |
| nonlinearity    | $\sigma$ | —       |
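Putting the pieces together, a self-contained training loop with the default hyperparameters above — still a sketch: ReLU for $\sigma$, an order-2 staircase target, and the online-batch setup are all assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, alpha, eta, B = 10, 100, 1.0, 0.01, 200  # defaults from the table above
T = 2                                          # staircase order (an assumption)

W = alpha * rng.normal(size=(n, d))            # W_ij(0) ~ N(0, alpha^2)
a = np.zeros(n)                                # a_i(0) = 0

def step(W, a):
    """One online muP-SGD step on a fresh batch; returns updated params and loss."""
    X = rng.normal(size=(B, d))                      # x ~ N(0, I_d)
    y = np.cumprod(X[:, :T], axis=1).sum(axis=1)     # staircase target (assumed form)
    pre = X @ W.T / np.sqrt(d)
    act = np.maximum(pre, 0.0)                       # ReLU (assumed sigma)
    f_hat = act @ a / n
    loss = 0.5 * np.mean((f_hat - y) ** 2)
    err = (f_hat - y) / B
    grad_a = act.T @ err / n
    grad_W = (np.outer(err, a / n) * (pre > 0)).T @ X / np.sqrt(d)
    return W - eta * n * grad_W, a - eta * n * grad_a, loss

losses = []
for _ in range(500):
    W, a, loss = step(W, a)
    losses.append(loss)
```

Each step draws a fresh batch, so `losses` tracks an online estimate of the population loss as training progresses.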
[Interactive panels: Network, Simulation (steps/sec counter), Plots]