4/10/2025
Deep neural networks have incredibly complex optimization trajectories that are difficult to build intuition for. For example, some optimization trajectories have long plateaus, while others learn functions of increasing complexity over training. To build better conceptual understanding, there has been a lot of work on simpler models of optimization that remain faithful to surprising empirical phenomena. Recently, I've been playing around with a simple toy model of optimization introduced in Saxe et al, 2018: vanilla gradient descent on a network of two linear layers. I was kind of in disbelief at how much this simple model reflects real training behavior. Even better, the model is simple enough to admit fully analytical derivations! The original presentation is grounded in neuroscience; we're going to focus more on the mathematical underpinnings. I will also take the liberty of making the framework more or less general where it helps. Hopefully, this post can convey the core intuition about why deeper models optimize differently.
We will assume we are given $n$ orthogonal inputs $x^{(i)} \in \R^d$ with outputs $y^{(i)} \in \R^d$ given by a true function $y^{(i)} = f^*(x^{(i)})$. We scale the inputs so that they are whitened, i.e. $\sum_{i=1}^n x^{(i)} x^{(i)T} = nI$ (this matches the identity input correlation assumption in the original paper, and it is what makes the derivations below exact). Our goal is to learn $f^*$ with a model $f$ by minimizing the mean squared error $\Eof{x^{(i)},y^{(i)}}{\|f(x^{(i)}) - y^{(i)}\|^2}$.
How do we parameterize $f$? In this post, we will consider two cases: a one layer linear network $f(x) = W x$, and a two layer linear network $f(x) = W^2 W^1 x$ (two weight matrices applied in sequence, with no nonlinearity in between).
A central quantity of interest is the cross-covariance matrix between the outputs and inputs, $\Syx = \Eof{x^{(i)},y^{(i)}}{y^{(i)} x^{(i)T}}$. Under our whitening assumption, $\Syx$ is exactly the weight matrix of the linear map that minimizes the mean squared error between predictions and outputs.
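To make the setup concrete, here is a minimal numpy sketch (my own illustration with hypothetical names, not code from the paper) that builds orthogonal whitened inputs, computes $\Syx$, and checks that it is the least squares solution:

```python
import numpy as np

rng = np.random.default_rng(0)
n = d = 16   # n orthogonal inputs spanning R^d (so n = d here)

# Orthogonal inputs, scaled so that sum_i x_i x_i^T = n * I (whitened).
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
X = np.sqrt(n) * Q                   # rows are the inputs x^(i)
Y = rng.standard_normal((n, d))      # arbitrary targets y^(i) = f*(x^(i))

# Cross-covariance Sigma_yx = E[y x^T] = (1/n) sum_i y_i x_i^T.
Syx = Y.T @ X / n

# Check: Syx is the weight matrix of the least squares linear fit.
W_ls, *_ = np.linalg.lstsq(X, Y, rcond=None)   # solves min ||X B - Y||_F
assert np.allclose(W_ls.T, Syx)
```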
Note that both architectures are equivalent in their expressivity, i.e. they can both represent (and only represent) linear functions of the input. Why do we care about two layer networks if they are equivalent to one layer networks? Even though they share the same loss minimizer, the optimization trajectories of the two networks are different, even for the simplest form of gradient descent. In the next two sections, we will see that the singular values of $f$ converge to the singular values of $\Syx$ at different rates. On a technical level, we will solve differential equations for the singular values of $f$ over time, after a change of variables given by the SVD of $\Syx$.
What does the optimization trajectory of the one layer linear network look like? We work in continuous time (gradient flow): the instantaneous change to the weight matrix $\W{t}$ at time $t$ follows the negative gradient of the squared error summed over the dataset, scaled by the learning rate $\lr$ (we absorb the factor of 2 from the squared error into $\lr$). Therefore, the instantaneous weight update can be written as
\begin{aligned} \frac{d\W{t}}{dt} &= \lr \sum_{i=1}^n \p{y^{(i)} - \W{t} x^{(i)}}x^{(i)T} && \text{[gradient of MSE]}\\ & = \lr \sum_{i=1}^n y^{(i)} x^{(i)T} - \lr \W{t} \sum_{i=1}^n x^{(i)} x^{(i)T} && \text{[expand]}\\ & = \lr \sum_{i=1}^n y^{(i)} x^{(i)T} - n \lr \W{t} && \text{[whitening: $\sum_i x^{(i)} x^{(i)T} = nI$]} \\ \p{\frac{1}{n\lr}} \frac{d\W{t}}{dt} &= \Syx - \W{t} && \text{[divide by $n\lr$]} \end{aligned}
Intuitively, this says that, up to scaling factors, the strength of the update depends on how far the current weight matrix is from the cross-covariance matrix.
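As a quick sanity check (continuing the sketch above; all names are mine), one forward Euler step of this flow is exactly a full-batch gradient descent step, and its direction matches $n\lr\p{\Syx - \W{t}}$:

```python
lr, dt = 1e-2, 0.1
W = np.zeros((d, d))

# Instantaneous update: lr * sum_i (y_i - W x_i) x_i^T, in matrix form.
grad_flow = lr * (Y - X @ W.T).T @ X
assert np.allclose(grad_flow, n * lr * (Syx - W))  # the simplified form above

W = W + dt * grad_flow   # one discrete (forward Euler) gradient step
```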
The progress of this trajectory is difficult to interpret directly. However, we can track it by performing a change of basis and looking at the singular values. We first decompose the target $\Syx$ as $\Syx = U S V^T$, where $U$ and $V$ are orthogonal and $S$ is the diagonal matrix of singular values of $\Syx$. We also assume our initial model weights can be set to $\W{0} = U \St{0} V^T$ for some diagonal $\St{0}$ (one such initialization is $\W{0} = 0$). Interestingly, as long as this holds at initialization, the entire parameter trajectory can be written as $U \St{t} V^T$ for diagonal $\St{t}$ at time $t$. We can see this "inductively" by writing the update assuming $\W{t} = U \St{t} V^T$.
\begin{aligned} \p{\frac{1}{n\lr}} \frac{d\W{t}}{dt} &= \Syx - \W{t} \\ &= U S V^T - U \St{t} V^T \\ &= U \p{S - \St{t}} V^T \\ \p{\frac{1}{n\lr}} \frac{d}{dt} \St{t} &= S - \St{t} \end{aligned}
This is a differential equation for the singular values of $W$ over time. Since $S$ and $\St{t}$ are diagonal, this decouples into a separate differential equation for each singular value. We can solve the differential equation for the $i$th singular value $\St{t}_{i}$ in closed form.
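For completeness, here is the standard derivation via the integrating factor $e^{n\lr t}$ (rewriting the ODE as $\frac{d\St{t}_{i}}{dt} + n\lr \St{t}_{i} = n\lr S_i$):
\begin{aligned} \frac{d}{dt}\p{e^{n\lr t} \St{t}_{i}} &= e^{n\lr t}\p{\frac{d\St{t}_{i}}{dt} + n\lr \St{t}_{i}} = n\lr S_i e^{n\lr t} && \text{[product rule, then the ODE]} \\ e^{n\lr t} \St{t}_{i} - \St{0}_{i} &= S_i \p{e^{n\lr t} - 1} && \text{[integrate from $0$ to $t$]} \\ \end{aligned}
Dividing through by $e^{n\lr t}$ gives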
\begin{aligned} \St{t}_{i} &= \p{e^{-n\lr t}}\St{0}_{i} + \p{1 - e^{-n\lr t}} S_{i} \\ \end{aligned}
Using our initialization of $\St{0} = 0$, we get the trajectory
\begin{aligned} \St{t}_{i} &= S_i \p{1 - e^{-n\lr t}} \\ \end{aligned}
Awesome! We now have a closed-form trajectory for the singular values of $W$ over time. As time goes to infinity, $\St{t}$ converges to $S$, which means $W$ converges to the loss minimizer $\Syx$. Importantly, our result gives us the exact rate of convergence in terms of the hyperparameters: every singular value approaches its target at the same exponential rate $e^{-n\lr t}$, regardless of its magnitude $S_i$. We will use this rate later when we compare with the two layer network.
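As a sanity check, we can integrate the flow numerically and compare. With $\W{0} = 0$, the closed form says the whole matrix trajectory is simply $\W{t} = \p{1 - e^{-n\lr t}} \Syx$. Continuing the sketch (step sizes are arbitrary choices of mine):

```python
lr, dt, steps = 1e-2, 0.01, 2000
W = np.zeros((d, d))                  # the W0 = 0 initialization
for _ in range(steps):
    W += dt * n * lr * (Syx - W)      # forward Euler step on the flow
t = steps * dt
assert np.allclose(W, (1 - np.exp(-n * lr * t)) * Syx, atol=1e-3)
```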
Now that we've done the warmup, the two layer network derivation will feel very natural. Following similar steps as above, we can write the update to the weight matrices $\Wone{t}$ and $\Wtwo{t}$ as
\begin{aligned} \p{\frac{1}{n\lr}} \frac{d\Wone{t}}{dt} &= \p{\Wtwo{t}}^T \p{\Syx - \Wtwo{t} \Wone{t}} \\ \p{\frac{1}{n\lr}} \frac{d\Wtwo{t}}{dt} &= \p{\Syx - \Wtwo{t} \Wone{t}} \p{\Wone{t}}^T\\ \end{aligned}
Similar to the one layer case, the update is driven by the residual between the current function $\Wtwo{t} \Wone{t}$ and the target function $\Syx$. We would like to analyze this after a change of variables using the SVD of $\Syx = U S V^T$. Here, we assume that at initialization, we can write the weight matrices in terms of the SVD under an arbitrary rotation $R$. We also assume that the singular values of $W^1_0$ and $W^2_0$ are the same, i.e. the two layers are balanced (motivated by both starting near zero).
\begin{aligned} W^2_0 W^1_0 &= U S_0 V^T \\ W^1_0 &= R \sqrt{S_0} {V}^T \\ W^2_0 &= U \sqrt{S_0} R^T \\ \end{aligned}
The arbitrary rotation $R$ (satisfying $R^T R = I$) gives us the necessary flexibility to write the initial weight matrices in terms of $U$ and $V$. The diagonal matrix $S_0$ represents the initial effective singular values of $f = W^2 W^1$. Unlike the one layer case, we cannot initialize the weight matrices to be zero, since this would make the gradient update zero. Therefore, we assume an initialization that is small enough for these conditions to approximately hold, but large enough to escape the saddle point at the origin. In all honesty, these assumptions are quite strong, but they allow us to do clean derivations and likely do not introduce egregious errors in modeling the empirical phenomena.
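In the numpy sketch, one concrete way to realize this initialization (the hidden width `h` and scale `eps`, standing in for the small entries of $S_0$, are my own choices):

```python
h, eps = d, 1e-3      # hidden width; S_0 = eps * I is small but nonzero
U, S, Vt = np.linalg.svd(Syx)
R, _ = np.linalg.qr(rng.standard_normal((h, d)))   # random R with R^T R = I

W1 = np.sqrt(eps) * R @ Vt     # W1_0 = R sqrt(S_0) V^T
W2 = np.sqrt(eps) * U @ R.T    # W2_0 = U sqrt(S_0) R^T
assert np.allclose(W2 @ W1, eps * U @ Vt)   # product is U S_0 V^T
```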
What can we say with all this setup? It turns out that our initial change of basis continues to hold inductively. Let's simplify the gradient update in terms of the singular values.
\begin{aligned} \p{\frac{1}{n\lr}} \frac{d\Wone{t}}{dt} &= \p{\Wtwo{t}}^T \p{\Syx - \Wtwo{t} \Wone{t}} && \text{[earlier derivation]}\\ &= \p{U \sqrt{\St{t}} R^T}^T \p{U S V^T - U \sqrt{\St{t}} R^T R \sqrt{\St{t}} V^T} && \text{[SVD parameterization]} \\ &= \p{R \sqrt{\St{t}} U^T} \p{U S V^T - U \St{t} V^T} && \text{[simplify]} \\ &= R \sqrt{\St{t}} \p{S - \St{t}} V^T && \text{[$U^TU = I$]} \\ \p{\frac{1}{n\lr}} \frac{d\sqrt{\St{t}}}{dt} &= \sqrt{\St{t}} \p{S - \St{t}} && \text{[match coefficients: $\Wone{t} = R \sqrt{\St{t}} V^T$]} \\ \end{aligned}
We can repeat this derivation for $\Wtwo{t}$ as well.
\begin{aligned} \p{\frac{1}{n\lr}} \frac{d\Wtwo{t}}{dt} &= \p{\Syx - \Wtwo{t} \Wone{t}} \p{\Wone{t}}^T && \text{[earlier derivation]}\\ &= \p{U S V^T - U \sqrt{\St{t}} R^T R \sqrt{\St{t}} V^T} \p{R \sqrt{\St{t}} V^T}^T && \text{[SVD parameterization]} \\ &= \p{U S V^T - U \St{t} V^T} V \sqrt{\St{t}} R^{T} && \text{[simplify]} \\ &= U \p{S - \St{t}} \sqrt{\St{t}} R^{T} && \text{[$V^T V = I$]} \\ \p{\frac{1}{n\lr}} \frac{d\sqrt{\St{t}}}{dt} &= \p{S - \St{t}} \sqrt{\St{t}} && \text{[match coefficients: $\Wtwo{t} = U \sqrt{\St{t}} R^T$]} \\ \end{aligned}
Since $S$ and $\St{t}$ are diagonal, the updates to both weight matrices are identical. This means that the balanced SVD structure we assumed at initialization is preserved under the gradient updates, and it suffices to solve a scalar differential equation per singular value. Applying the chain rule $\frac{d\St{t}}{dt} = 2\sqrt{\St{t}}\,\frac{d\sqrt{\St{t}}}{dt}$ to the equation above, the effective singular values of the product evolve as $\p{\frac{1}{n\lr}} \frac{d\St{t}}{dt} = 2\,\St{t}\p{S - \St{t}}$. In comparison to the one layer case, the magnitude of the update now scales with the current singular value: larger singular values move faster, and smaller ones move slower.
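Since the right hand side factors, this equation is separable. Assuming $0 < \St{t}_{i} < S_i$ along the trajectory (which holds for a small positive initialization), we can integrate with partial fractions:
\begin{aligned} \int_{\St{0}_{i}}^{\St{t}_{i}} \frac{da}{2a\p{S_i - a}} &= \int_0^t n\lr \, dt' && \text{[separate variables]} \\ \frac{1}{2S_i} \ln \frac{\St{t}_{i}\p{S_i - \St{0}_{i}}}{\St{0}_{i}\p{S_i - \St{t}_{i}}} &= n\lr t && \text{[partial fractions]} \\ \frac{\St{t}_{i}}{S_i - \St{t}_{i}} &= \frac{\St{0}_{i}}{S_i - \St{0}_{i}} \, e^{2 S_i n\lr t} && \text{[exponentiate]} \\ \end{aligned}
Solving the last line for $\St{t}_{i}$ gives the trajectory for the $i$th singular value: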
\begin{aligned} \St{t}_{i} &= S_i \p{\frac{e^{2S_i n \lr t}}{e^{2S_i n \lr t} + \frac{S_i}{\St{0}_i} - 1}} \\ \end{aligned}
This trajectory, like the one layer case, starts at $\St{0}$ and converges to $S$. Since the exact trajectory is a bit hard to interpret directly, we will make a few approximations. First, we assume the initialization is small, so that $\frac{S_i}{\St{0}_i} \gg 1$. Second, once $t$ is moderately large, the exponential term dominates: $e^{2S_i n \lr t} \gg \frac{S_i}{\St{0}_i} \gg 1$. Under these assumptions, we can approximate the trajectory as
\begin{aligned} \St{t}_{i} &= S_i \p{\frac{e^{2S_i n \lr t}}{e^{2S_i n \lr t} + \frac{S_i}{\St{0}_i} - 1}} && \text{[derived above]} \\ &= S_i \p{1 - \frac{\frac{S_i}{\St{0}_i} - 1}{e^{2S_i n \lr t} + \frac{S_i}{\St{0}_i} - 1}} && \text{[rewrite]} \\ &\approx S_i \p{1 - \frac{\frac{S_i}{\St{0}_i}}{e^{2S_i n \lr t}}} && \text{[remove dominated terms]} \\ \end{aligned}
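As a numerical sanity check on the exact solution, we can continue the sketch: integrate the two layer flow from the balanced initialization above and compare the singular values of the product to the closed form (step sizes are again arbitrary choices):

```python
lr, dt, steps = 1e-2, 0.01, 3000
for _ in range(steps):
    resid = Syx - W2 @ W1          # shared residual drives both layers
    W1, W2 = (W1 + dt * n * lr * (W2.T @ resid),
              W2 + dt * n * lr * (resid @ W1.T))

t = steps * dt
grow = np.exp(2 * S * n * lr * t)           # e^{2 S_i n lr t} for each mode
a_pred = S * grow / (grow + S / eps - 1)    # exact trajectory derived above
a_sim = np.linalg.svd(W2 @ W1, compute_uv=False)
assert np.allclose(np.sort(a_sim), np.sort(a_pred), atol=1e-2)
```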
Now that we've worked out the math, we can see if this model is faithful to some aspects of practice.
How does the convergence rate of the two layer network compare to the one layer network? Let's take a look at their values across time.
\begin{aligned} \St{t}_{i} &\approx S_i \p{1 - e^{- n \lr t}} && \text{[one layer]} \\ \St{t}_{i} &\approx S_i \p{1 - \p{\frac{S_i}{\St{0}_i}} e^{-2S_i n \lr t}} && \text{[two layer]} \\ \end{aligned}
There are two main differences to highlight. First, the speed of the two layer network depends on the magnitude of the initialization, which enters as a constant in front of the exponential. Second, the exponent of the two layer network depends on $S_i$, which means that convergence is much faster for directions with larger singular values. Note also that the exact two layer trajectory is sigmoidal: with a small initialization, each singular value sits near zero for a long time before rising sharply, which is exactly the kind of plateau mentioned at the start. This is a concrete difference in the learning dynamics of one layer and two layer networks, even though their expressivity is the same.
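To make this concrete, we can invert the approximate trajectories and ask at what time $t_\epsilon$ each singular value reaches a $\p{1 - \epsilon}$ fraction of its target $S_i$:
\begin{aligned} t_\epsilon^{\text{one layer}} &= \frac{1}{n\lr} \ln\frac{1}{\epsilon} && \text{[same for every $S_i$]} \\ t_\epsilon^{\text{two layer}} &\approx \frac{1}{2 S_i n\lr} \ln\frac{S_i}{\epsilon \St{0}_i} && \text{[scales as $1/S_i$]} \\ \end{aligned}
Up to the logarithmic factor, the two layer network learns a direction with twice the singular value in half the time, while the one layer network treats every direction identically.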
It has been observed in many settings that models learn functions of increasing complexity over training (for some definition of complexity). For example, models early in training are well characterized by linear models (Nakkiran et al, 2019), low degree polynomials (Abbe et al, 2023), and classifiers using lower-order input statistics (Refinetti et al, 2022).
This toy model predicts this phenomenon for two layer networks. Specifically, since the two layer network learns the largest singular values first, it will start with the best rank 1 approximation, followed by the best rank 2 approximation, and so on. This gives a very simple model to study questions surrounding simplicity bias. I have personally found this useful for a research question I'm working on, paper soon hopefully :))
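Here is a quick illustration in the numpy sketch (the spectrum, stopping time, and all names are arbitrary choices of mine, picked so that the top mode saturates while the rest are still tiny):

```python
# A small target with well separated singular values, so ranks emerge in order.
k = 4
S_demo = 2.0 * 0.5 ** np.arange(k)                  # [2, 1, 0.5, 0.25]
U2, _ = np.linalg.qr(rng.standard_normal((k, k)))
V2, _ = np.linalg.qr(rng.standard_normal((k, k)))
target = U2 @ np.diag(S_demo) @ V2.T

nlr, dt, eps2 = 0.1, 0.05, 1e-6                     # lump n * lr into one constant
R2, _ = np.linalg.qr(rng.standard_normal((k, k)))
W1, W2 = np.sqrt(eps2) * R2 @ V2.T, np.sqrt(eps2) * U2 @ R2.T

for _ in range(1000):                               # stop mid-training (t = 50)
    resid = target - W2 @ W1
    W1, W2 = W1 + dt * nlr * (W2.T @ resid), W2 + dt * nlr * (resid @ W1.T)

rank1 = S_demo[0] * np.outer(U2[:, 0], V2[:, 0])    # best rank 1 approximation
print(np.linalg.norm(W2 @ W1 - rank1))              # small: only the top mode has grown
print(np.linalg.norm(target - rank1))               # the remaining modes, still unlearned
```

Training longer walks through the best rank 2, rank 3, ... approximations in order, provided the singular values are well separated.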
The toy model is a continuous time model. In practice, we take discrete gradient updates and time is not continuous. Does this model accurately reflect discrete time optimization? Fortunately, follow-up work has shown that the main ideas of this derivation can be rigorously extended to the discrete case when you take many steps with a small learning rate (Gidel et al, 2019).
Applying a variety of suspicious assumptions, we can show that two layer networks learn qualitatively differently from one layer networks, even though they have the same expressivity, loss minimizer, and limiting solution. Specifically, the two layer network learns in directions with larger singular values first, leading to faster initial learning of high-performing simple functions. I think this setup is amazing due to its surprising faithfulness to various empirical phenomena while being ridiculously mathematically tractable. Please refer to the original paper for more details, extensions, and empirical results. Thank you for reading, and feel free to reach out with any questions or thoughts!