7/27/2024
The most common neural network architecture for sequence modeling is the transformer, which has largely replaced recurrent neural networks (RNNs). To output $n$ tokens, a typical causal transformer takes $\bO{n^2}$ time (taking $\bO{i}$ time to decode the $i$th token). An RNN does not face this problem, achieving $\bO{n}$ time (taking $\bO{1}$ time per decoding step). However, RNNs do not effectively leverage parallelism during training and do not match transformer performance at a given training compute budget.
The elegant work of Katharopoulos et al., 2020 finds that a slightly modified transformer is equivalent to an RNN. This specifies an architecture that achieves the training parallelism of transformers with the decoding complexity of RNNs. This connection underlies the success of state space models (Gu et al., 2022), currently the strongest competitors to transformers (refer to Gu and Dao, 2024 for a rigorous and up-to-date treatment of this connection).
In this post, we will first review a vanilla transformer and RNN. Then, we will see how transformers with linear attention are equivalent to RNNs. This post very closely follows the elegant exposition of the original paper.
An RNN operates on a sequence token by token. For the $i$th token $x_i \in \R^d$, it produces a hidden state $h_i \in \R^k$ that 1) predicts the next token and 2) determines the next hidden state. To predict the $(i+1)$th token $x_{i+1} \in \R^d$, it uses the following recurrence
\[ \begin{aligned} h_{i+1} &= \RNN(x_i, h_i) &\text{[$\RNN : \R^d \times \R^k \to \R^k$]}\\ x_{i+1} &= \Head(h_{i+1}) &\text{[$\Head : \R^k \to [V]$]} \end{aligned} \]
where $\RNN$ is a deep neural network and $\Head$ (potentially randomly) selects the next token from $V$ candidates using the information in the hidden state. This recurrence enables fast inference since it takes $\bO{1}$ time to generate each token. However, its sequential nature prohibits fast parallel training, limiting its use in current models.
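To make the recurrence concrete, here is a minimal NumPy sketch of RNN decoding. The single-layer tanh cell, the greedy `head`, and all shapes and parameter names are made up for illustration; a real $\RNN$ is a much deeper network.

```python
import numpy as np

# Toy sizes: d = token embedding dim, k = hidden dim, V = vocabulary size.
d, k, V = 16, 32, 100
rng = np.random.default_rng(0)
W_xh, W_hh = rng.normal(size=(d, k)), rng.normal(size=(k, k))  # RNN parameters (illustrative)
W_head = rng.normal(size=(k, V))                               # Head parameters (illustrative)
embed = rng.normal(size=(V, d))                                # token embedding table

def rnn(x_i, h_i):
    # One recurrence step: O(1) work, independent of the position i.
    return np.tanh(x_i @ W_xh + h_i @ W_hh)

def head(h):
    # Greedily select the next token id from the hidden state.
    return int(np.argmax(h @ W_head))

h = np.zeros(k)
token = 0                    # start token id
generated = []
for _ in range(10):          # decoding n = 10 tokens takes O(n) total
    h = rnn(embed[token], h)
    token = head(h)
    generated.append(token)
print(generated)
```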
Suppose the input is $x \in \R^{n \times d}$ for $n$ input tokens. A transformer produces an output in $\R^{n \times d}$, where the $i$th row of the output predicts the $(i+1)$th token. A transformer is composed of multiple blocks, each a map $\R^{n \times d} \to \R^{n \times d}$, implemented as
\[ \begin{aligned} \Block(x) &= \MLP(\Attn(x) + x) &\text{[$\MLP ,\Attn ,\Block : \R^{n \times d} \to \R^{n \times d}$]} \\ \Transformer &= \Block_1 \circ \cdots \circ \Block_m &\text{[$\Transformer : \R^{n \times d} \to \R^{n \times d}$]} \\ h &= \Transformer(x) &\text{[$h \in \R^{n \times d}$]} \\ x_{i+1} &= \Head(h_i) &\text{[$\Head : \R^d \to [V]$]} \end{aligned} \]
$\MLP$ applies a fixed neural network $\R^d \to \R^d$ to each token independently and only takes constant compute per token. $\Attn : \R^{n\times d} \to \R^{n \times d}$ operates across all the tokens and is where the model spends quadratic time at generation. To generate $n$ tokens, we iteratively generate the $i+1$th token from the previous $i$ tokens.
We will focus on efficiently implementing an attention layer for a single block. Attention first computes a similarity matrix between all the tokens. This similarity matrix is then used to linearly combine the value representations of the tokens. More concretely, define weight matrices $W_Q \in \R^{d \times k}$, $W_K \in \R^{d \times k}$, and $W_V \in \R^{d \times d}$. Then, compute the attention as
\[ \begin{aligned} Q &= xW_Q &\text{[$Q \in \R^{n \times k}$]} \\ K &= xW_K &\text{[$K \in \R^{n \times k}$]} \\ V &= xW_V &\text{[$V \in \R^{n \times d}$]} \\ \Attn(x) &= \text{softmax}\left(\frac{QK^{\top}}{\sqrt{k}}\right)V &\text{[$\Attn(x) \in \R^{n \times d}$]} \end{aligned} \]
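For concreteness, here is a direct NumPy transcription of these equations (non-causal for now; the causal masking discussed next is omitted). The shapes and random weights are made up for illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attn(x, W_Q, W_K, W_V):
    # Non-causal attention, exactly as written above.
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V      # (n, k), (n, k), (n, d)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])  # (n, n) similarity matrix
    return softmax(scores, axis=-1) @ V      # (n, d)

# Toy example: n = 5 tokens, d = 8, k = 4.
rng = np.random.default_rng(0)
n, d, k = 5, 8, 4
x = rng.normal(size=(n, d))
W_Q, W_K = rng.normal(size=(d, k)), rng.normal(size=(d, k))
W_V = rng.normal(size=(d, d))
print(attn(x, W_Q, W_K, W_V).shape)          # (5, 8)
```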
Most transformers are causal, which means that the output for token $i$ only depends on token $i$ and the tokens that precede it. This is achieved by masking the similarity matrix before applying the softmax. With this masking, the $i$th row of the attention output is
\[ \Attn(x)_i = \frac{\sum_{j\in[i]}{\text{sim}(Q_i, K_j)V_j}}{\sum_{j\in[i]}{\text{sim}(Q_i, K_j)}} \]
for an appropriate $\text{sim}$ (softmax corresponds to $\text{sim}(q, k) = \exp(q^{\top}k / \sqrt{k})$). Naively computing the $i$th row of the output $h_i$ from the transformer would take $\bO{i^2}$ time, since we would also have to recompute every previous key $K_j$ and value $V_j$. However, since keys and values depend only on previous tokens, we can cache them from decoding the previous $i-1$ tokens (referred to as KV caching). With this cache, decoding the $i$th token takes $\bO{i}$ time, and decoding $n$ tokens takes $\bO{n^2}$ time.
To decode $n$ tokens from a transformer starting from a context of, say, one token, one feeds in that token, takes the output associated with it, and passes it through $\Head$ to determine the second token. One then feeds in these two tokens, looks at the output associated with the second token, and uses it to determine the third token. This is repeated until $n$ tokens are obtained.
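Here is a minimal sketch of this KV-cached decoding loop for a single attention layer (NumPy; `decode_step`, the caches, and all shapes are illustrative stand-ins rather than a full transformer).

```python
import numpy as np

def decode_step(x_i, K_cache, V_cache, W_Q, W_K, W_V):
    # Compute only the i-th token's query/key/value, reusing cached keys/values.
    q, key, val = x_i @ W_Q, x_i @ W_K, x_i @ W_V
    K_cache.append(key)
    V_cache.append(val)
    K, V = np.stack(K_cache), np.stack(V_cache)        # i cached keys (i, k) and values (i, d)
    sim = np.exp(K @ q / np.sqrt(q.shape[-1]))         # sim(Q_i, K_j) for all j <= i: O(i)
    return (sim[:, None] * V).sum(axis=0) / sim.sum()  # i-th row of Attn(x)

rng = np.random.default_rng(0)
d, k = 8, 4
W_Q, W_K = rng.normal(size=(d, k)), rng.normal(size=(d, k))
W_V = rng.normal(size=(d, d))
K_cache, V_cache = [], []
for i in range(5):                  # decoding n tokens costs O(1 + 2 + ... + n) = O(n^2)
    x_i = rng.normal(size=d)        # stand-in for the i-th token's representation
    out_i = decode_step(x_i, K_cache, V_cache, W_Q, W_K, W_V)
print(out_i.shape)                  # (8,)
```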
Real transformers have many necessary additions such as positional encodings and multiple attention heads. We'll intentionally ignore these to clean up the exposition.
Consider dropping the $\text{softmax}$ in $\Attn$. Then,
\[ \begin{aligned} \Attn(x) &= (QK^{\top})V \\ \Attn(x)_i &= \frac{\sum_{j\in[i]}{(Q_i^{\top}K_j)V_j}}{\sum_{j\in[i]}{(Q_i^{\top}K_j)}} \end{aligned} \]
At first glance, this still seems like $\bO{i}$ compute for the $i$th token. However, a simple application of the associative property solves all of our problems :))
\[ \begin{aligned} \Attn(x) &= Q(K^{\top}V) \\ \Attn(x)_i &= \frac{Q_i^{\top}\sum_{j\in[i]}{K_jV_j^{\top}}}{Q_i^{\top}\sum_{j\in[i]}{K_j}} \end{aligned} \]
Notably, we can efficiently compute $\sum_{j\in[i]}{K_jV_j^{\top}}$ and $\sum_{j\in[i]}{K_j}$ by saving $\sum_{j\in[i-1]}{K_jV_j^{\top}}$ and $\sum_{j\in[i-1]}{K_j}$ from decoding the previous token. This takes decoding the $i$th token from $\bO{i}$ time to $\bO{1}$ time! Now, at the loss of the softmax non-linearity, we get the training efficiency of transformers with the decoding speed of RNNs!
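A minimal sketch of this recurrent view (NumPy, identity feature map, illustrative names): the state `S` accumulates $\sum_j K_j V_j^{\top}$ and `z` accumulates $\sum_j K_j$, so each decoding step does a constant amount of work regardless of position.

```python
import numpy as np

def linear_decode_step(x_i, S, z, W_Q, W_K, W_V):
    # O(1) in the sequence length: update the running sums, then read out with the query.
    q, key, val = x_i @ W_Q, x_i @ W_K, x_i @ W_V
    S = S + np.outer(key, val)      # running sum of K_j V_j^T, shape (k, d)
    z = z + key                     # running sum of K_j, shape (k,)
    return (q @ S) / (q @ z), S, z  # i-th row of Attn(x)

rng = np.random.default_rng(0)
d, k = 8, 4
W_Q, W_K = rng.normal(size=(d, k)), rng.normal(size=(d, k))
W_V = rng.normal(size=(d, d))
S, z = np.zeros((k, d)), np.zeros(k)
for i in range(5):                  # decoding n tokens now costs O(n) total
    x_i = rng.normal(size=d)
    out_i, S, z = linear_decode_step(x_i, S, z, W_Q, W_K, W_V)
print(out_i.shape)                  # (8,)
```

(With random weights the normalizer $q^{\top}z$ can be close to zero or even negative; the positive feature map introduced below avoids this.)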
Our victory only required that the similarity factors so that we can multiply keys $K_j$ and values $V_j$ together before involving the query. Therefore, we can still keep a non-linearity as long as $\text{sim}$ can be applied separately to queries $Q_i$ and keys $K_j$. We can generalize our fast linear attention by defining the similarity as $\text{sim}(q, k) = \phi(q)^{\top} \phi(k)$ for a kernel feature map $\phi : \R^{k} \to \R^{k'}$ via
\[ \begin{aligned} \Attn(x) &= \phi(Q)(\phi(K)^{\top}V) \\ \Attn(x)_i &= \frac{\phi(Q_i)^{\top}\sum_{j\in[i]}{\phi(K_j)V_j^{\top}}}{\phi(Q_i)^{\top}\sum_{j\in[i]}{\phi(K_j)}} \end{aligned} \]
The paper finds that $\phi(x) = \text{elu}(x) + 1$ performs well empirically ($\text{elu}$ is defined in Clevert et al., 2016).
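As a small illustration (the $\text{elu}$ here is the standard exponential linear unit; the drop-in lines refer to the sketch above and are my reading, not the paper's code), this feature map keeps every entry strictly positive, so the normalizer $\phi(Q_i)^{\top}\sum_j \phi(K_j)$ cannot vanish or change sign.

```python
import numpy as np

def elu(x, alpha=1.0):
    # Standard exponential linear unit: x for x > 0, alpha * (exp(x) - 1) otherwise.
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

def phi(x):
    return elu(x) + 1.0              # strictly positive features

# Drop-in change to the recurrent update sketched above:
#   S = S + np.outer(phi(key), val)
#   z = z + phi(key)
#   out = (phi(q) @ S) / (phi(q) @ z)
print(phi(np.array([-2.0, 0.0, 3.0])))   # all entries > 0
```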
Transformers are generally considered to be very different from RNNs and are slower at inference. Interestingly, a simple modification of the attention mechanism collapses a transformer into an RNN that takes constant time to decode each token while preserving training parallelism.
The original work has some (now dated) evidence that this is faster than standard transformer inference. The line of work on SSMs and Mamba has scaled this conceptual approach to real workloads, making it feasible to train models that are much faster for autoregressive inference. This is especially promising for tasks that require long context lengths, such as code generation and DNA sequence modeling.
From a scientific perspective, this shows that transformers may not be as "special" as previously thought. For example, Ahn et al., 2024 show that transformer optimization is qualitatively similar under linear attention, which we now know is effectively an RNN. This adds to the growing evidence that one should select architectures that most effectively leverage compute instead of searching for special implicit biases.
I'm a huge fan of the simplicity and exposition of this paper, which clearly communicates a profound observation that has inspired subsequent architecture research. Thank you for reading, and feel free to reach out with any questions or thoughts!