Dec 11, 2023 · 24 min read
Zoology (Blogpost 1): Measuring and Improving Recall in Efficient Language Models
There is tons of excitement around the new attention-free efficient language model architectures threatening Attention’s reign (e.g., Approximate Attention Methods, S4, Liquid S4, MEGA, GSS, H3, BIGS, Hyena, S5, RWKV, RetNet, Monarch Mixer, Mamba, and many more). Check out Sasha Rush's awesome talk at MLSys 2023: Do We Need Attention?. When recent work started showing that some of these new architectures can match attention in language quality, we decided to dig in.
We started benchmarking a popular class of sub-quadratic Transformer alternatives, collectively referred to as gated-convolutions (e.g. Hyena, H3, and RWKV). We found that these sub-quadratic architectures are much worse than Transformers at recalling information previously mentioned in the prompt (Table 1, Section 1: Mind the Perplexity Gap!). On this task, termed associative recall (AR), we found that a 70M parameter attention model outperforms a 1.4 billion parameter gated-convolution model. Although this might sound like an esoteric task, AR has a long history in machine learning, and prior work shows that the ability to solve AR tasks is highly correlated with enticing capabilities like in-context learning.
To understand the gap on AR, we trained models on a simple synthetic task proposed in prior work for testing AR capability. But we were totally stumped when all the gated-convolutions solved the synthetic task: what then accounted for the huge gap we saw in the real language models? Instead of treating success on the task as binary, we started measuring performance as a function of sequence length and model dimension. A striking pattern emerged: to solve AR on longer and longer sequences, gated-convolutions needed more and more dimensionality (and, thus, FLOPs) while attention did not (Section 2: Synthetics Help Explain the Gap). We also derived theoretical AR solutions (e.g. by manually setting the weights of the architecture) and showed that these solutions required dimensionality to scale as we observed in the experiments (Section 3: Solving Synthetics by Hand).
Finally, we shed light on what’s needed to close the gap to attention while remaining sub-quadratic (Section 4: Closing the Gap)! For all the details, check out our paper titled Zoology: Measuring and Improving Recall in Efficient Language Models, a nod to the wonderful Hyenas, Hippos, Mambas, etc. that we studied. This work was possible due to an incredible team of collaborators: Aman Timalsina, Isys Johnson, Michael Poli, James Zou, Atri Rudra, and Christopher Ré!
Broader Implications. In the literature, an architecture's efficiency is typically measured in terms of the asymptotic cost of a layer. If there’s one takeaway lesson from our work, it’s this: taken alone, the complexity of a layer is an inadequate measure of efficiency (after all, by this metric, this layer blows everyone else out of the water).
What if we measured an architecture’s efficiency in terms of the FLOPs required to solve a specific task, instead of the FLOPs required per layer?
We think that simple, synthetic tasks are critical for measuring efficiency in this way. But formulating the “right” synthetic task can be challenging (we spent weeks iterating on our associative recall synthetic). We’re excited to share our testbed for synthetics in this simple GitHub repo: HazyResearch/zoology. Hopefully it can serve as a starting point for others using synthetics to measure efficiency!
Section 1. Mind the perplexity gap!
tl;dr Gated convolutions struggle with recall on real language.
We pretrain from scratch 17 language models that span 4 parameter scales (70M, 150M, 360M, and 1.4Bn) and a range of architectures, including state-of-the-art gated convolutions (Hyena, H3, RWKV):
H3: each layer contains a short convolution followed by a long convolution, sandwiched by elementwise gating.
Hyena: similar to H3, differing in the parametrization and placement of the convolutions.
RWKV: often referred to as an RNN, but it can be rewritten as a gated convolution (at a high level, differing from H3/Hyena in the parametrization for the convolution filters, ordering of gating, convolutions, and projections). See our paper for a formal write-up.
Pure long convolutions (S4-like): uses a global/long convolution, with no gating. We include this reference point to underscore the benefit of gating.
Llama-style Transformers: a modern Transformer baseline with rotary embeddings.
We train all architectures on uniform infrastructure and data (10B tokens of the Pile, or 50B tokens for the 1.4Bn-parameter models) using the EleutherAI GPT-NeoX codebase. For all architectures, a single model layer consists of a sequence mixer followed by a SwiGLU block.
Overall Perplexity: We find that there is a consistent perplexity gap between the SoTA attention-free models and Transformers (extended table in the paper). Results are also linked in this wandb report.
Fine-Grained Analysis: We find a single simple issue is responsible for >82% of the overall perplexity gap: the gated convolution models’ poor quality on tokens that require associative recall (AR). In fact, a 1.4Bn parameter Hyena model still underperforms a 70M parameter attention model by over a perplexity point on the AR slice.
How do we measure quality specifically on the AR slice, though? It's unclear how best to measure recall in language models trained on real data. We propose a simple proxy: evaluating the model’s perplexity on the subset of next-token predictions that complete repeated bigrams (termed AR Hits). In the examples below, “crosscut complex” and “Maison Bergey” form bigrams. The second occurrences of “complex” and “Bergey” are deemed AR Hits because they can be predicted by recalling the prior occurrences of the bigrams in the context.
Some bigrams (e.g. Barack Obama) can be memorized during training and don't require in-context recall. In the plot below, we plot perplexity against the frequency of the bigram in the training data. Strikingly, for AR hits that appear infrequently in the training data, there's a huge gap between the gated convolution models and attention. On other tokens and AR hits that appear frequently in the training data, there is virtually no gap between the models.
This suggests that in-context recall is the fundamental issue separating the gated convolution models from attention.
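The AR Hits proxy and the training-frequency split can be sketched in a few lines. This is a minimal sketch, assuming whitespace-split word tokens and a precomputed dictionary of training-set bigram counts; the function names, the toy sentence, the counts, and the frequency threshold are all hypothetical and are not our evaluation code.

```python
def ar_hits(tokens):
    """Positions i where the bigram (tokens[i-1], tokens[i]) already occurred
    earlier in the same sequence, so tokens[i] can be predicted by recall."""
    seen = set()
    hits = []
    for i in range(1, len(tokens)):
        bigram = (tokens[i - 1], tokens[i])
        if bigram in seen:
            hits.append(i)
        seen.add(bigram)
    return hits


def split_by_train_frequency(tokens, hits, train_bigram_counts, threshold):
    """Separate AR hits into bigrams that are rare vs. frequent in the training
    data; frequent ones may simply be memorized rather than recalled."""
    rare, frequent = [], []
    for i in hits:
        count = train_bigram_counts.get((tokens[i - 1], tokens[i]), 0)
        (rare if count < threshold else frequent).append(i)
    return rare, frequent


tokens = ("a crosscut complex near Maison Bergey . "
          "the crosscut complex and Maison Bergey").split()
hits = ar_hits(tokens)
# Hypothetical training-set counts, for illustration only.
counts = {("crosscut", "complex"): 500, ("Maison", "Bergey"): 2}
rare, frequent = split_by_train_frequency(tokens, hits, counts, threshold=100)
print(hits, rare, frequent)  # -> [9, 12] [12] [9]
```

Perplexity would then be computed separately on the rare and frequent AR hits and on all other tokens.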
Curious whether these findings hold for larger models? Below is an experiment comparing the associative recall abilities of RWKV-Raven 7B and Llama 2 7B.
Associative Recall at 7B parameters: RWKV vs. Llama
We evaluate the RWKV-Raven model downloaded from https://huggingface.co/docs/transformers/model_doc/rwkv and the Llama 2 model downloaded from https://github.com/facebookresearch/llama. Both are popular models that took a significant amount of effort to train for maximum quality. We find that there is a gap between RWKV and attention at the 7B scale, and that it grows as the model needs to perform more recalls per input sequence.
We summarize the experimental protocol below.
Data. Since frequent bigrams may just be memorized by the model and not require in-context recall, our work measures AR quality on infrequent bigrams in validation sequences. We do not have access to the custom training data mixtures used in training RWKV or Llama 2 to measure bigram frequencies, so we use a synthetic test to fairly measure the AR capabilities. Taking a tokenized sequence from the Pile, we take the set of occurring tokens and form key-value pairs at random – each key-value pair occurs exactly twice in the sequence. We initialize a new sequence of length 1024 and insert P of these pairs at random positions. We fill the remaining positions with a fixed token ID (0).
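The probe construction can be sketched as follows. This is a minimal sketch under stated assumptions: the function name is hypothetical, pairs are placed in aligned two-token slots so they never overlap, and all keys and values are distinct so each key has exactly one answer. The original protocol may differ in these details.

```python
import random


def build_recall_probe(source_tokens, num_pairs, seq_len=1024, fill_id=0, seed=0):
    """Build a recall probe: num_pairs random key-value pairs drawn from the
    tokens of a source sequence, each pair inserted exactly twice at random
    positions, with every other position set to fill_id.

    Returns (sequence, pairs, targets), where targets lists the positions of
    the value token at each pair's second occurrence.
    """
    rng = random.Random(seed)
    vocab = sorted(set(source_tokens) - {fill_id})
    if 2 * num_pairs > len(vocab):
        raise ValueError("not enough distinct tokens for disjoint pairs")
    if 4 * num_pairs > seq_len:
        raise ValueError("sequence too short for 2 * num_pairs occurrences")

    picked = rng.sample(vocab, 2 * num_pairs)   # distinct keys and values
    pairs = list(zip(picked[0::2], picked[1::2]))

    # Simplification: place pairs in aligned, non-overlapping 2-token slots.
    slots = rng.sample(range(seq_len // 2), 2 * num_pairs)
    occurrences = pairs + pairs                  # every pair appears twice
    rng.shuffle(occurrences)

    seq = [fill_id] * seq_len
    for slot, (key, value) in zip(slots, occurrences):
        seq[2 * slot], seq[2 * slot + 1] = key, value

    # The value after a key's second occurrence can only be predicted by recall.
    seen_keys, targets = set(), []
    for pos in range(0, seq_len, 2):
        key = seq[pos]
        if key == fill_id:
            continue
        if key in seen_keys:
            targets.append(pos + 1)
        seen_keys.add(key)
    return seq, pairs, targets


source = list(range(1, 500))  # stand-in for a tokenized Pile sequence
seq, pairs, targets = build_recall_probe(source, num_pairs=16)
print(len(seq), len(targets), seq.count(0))  # -> 1024 16 960
```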
Measurement. On the second occurrence of an MQAR key, the model should look back to the prior occurrence to output the corresponding MQAR value token. We measure the AR perplexity of the model based on its ability to predict the correct MQAR value token as the next word for these repeated keys. Again, all sequences are constructed using the models' vocabulary tokens. We evaluate (inference only, no training) sequences containing $P$ key-value pairs for several values of $P$, using a fixed number of samples per value of $P$. The positions that do not contain a key or value are simply filled with a fixed token ID (so this token is repeated frequently in the sequence). We plot perplexity for AR and non-AR tokens (the fixed token) vs. $P$. We find that RWKV quality degrades with $P$ on the AR slice (blue line), while all other lines remain flat. MQAR remains problematic for the gated convolution model at scale.
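Computing the AR and non-AR perplexities amounts to averaging negative log-likelihoods over different subsets of positions. A minimal sketch, assuming you already have the log-probability the model assigned to each true next token; the function name and the toy numbers are illustrative only.

```python
import math


def slice_perplexity(target_logprobs, positions):
    """Perplexity restricted to a subset of positions.

    target_logprobs[i] is the model's natural-log probability of the true
    token at position i, predicted from the tokens before i.
    """
    nll = [-target_logprobs[i] for i in positions]
    return math.exp(sum(nll) / len(nll))


# Toy log-probabilities: confident on the repeated fill token, unsure on values.
logprobs = [math.log(0.9), math.log(0.9), math.log(0.25), math.log(0.9), math.log(0.25)]
ar_positions = [2, 4]       # value tokens after repeated keys
fill_positions = [0, 1, 3]  # the fixed filler token
print(slice_perplexity(logprobs, ar_positions), slice_perplexity(logprobs, fill_positions))
```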
Section 2. Synthetic data help explain the gap!
tl;dr Gated convolutions perform associative recall less efficiently than attention.
There is a divide between our understanding of how gated convolutions solve AR and the downstream AR quality gaps. In particular, in the H3 and Hyena work, we see gated convolutions can solve synthetic tasks that are meant to test AR ability, as perfectly as attention. But since there’s still a real world AR gap, what’s the catch?
A New Test for AR Ability. By measuring recall on real data, we learned that the key disparity is that prior synthetic formulations assume one query per input, at a fixed position in the sequence, with tokens drawn from a small vocabulary (e.g. $|V| < 50$, less than the model dimension). Yet language modeling often requires performing multiple recalls in a single forward pass, e.g. for both “centrif -ug -ation” and “is -olation” in the above example. The recalls need to be performed at varying positions, with tokens that come from a large vocabulary (larger than the model dimension).
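To make the contrast concrete, here is a minimal generator for multi-query associative recall (MQAR) examples: unlike the prior single-query setup, it issues many queries at random positions and draws tokens from a vocabulary far larger than a typical model dimension. The function name, the split of the vocabulary into key and value halves, and the default sizes are assumptions for illustration; the HazyResearch/zoology repo has its own generator, whose API this sketch does not reproduce.

```python
import random


def make_mqar_example(vocab_size=8192, seq_len=64, num_kv_pairs=8, seed=0):
    """One multi-query associative recall (MQAR) example.

    The sequence starts with key-value pairs k1 v1 k2 v2 ...; later, every key
    is queried once at a random position and must be followed by its value.
    Returns (sequence, targets) with targets mapping position -> expected token.
    """
    rng = random.Random(seed)
    half = vocab_size // 2
    keys = rng.sample(range(1, half), num_kv_pairs)           # 0 is padding
    values = rng.sample(range(half, vocab_size), num_kv_pairs)
    context = [tok for kv in zip(keys, values) for tok in kv]

    tail_len = seq_len - len(context)
    if tail_len // 2 < num_kv_pairs:
        raise ValueError("seq_len too short for this many queries")
    seq = context + [0] * tail_len

    # Queries land at random (non-overlapping) positions: many recalls, one pass.
    slots = rng.sample(range(tail_len // 2), num_kv_pairs)
    targets = {}
    for slot, key, value in zip(slots, keys, values):
        pos = len(context) + 2 * slot
        seq[pos] = key
        seq[pos + 1] = value             # the token the model must recall
        targets[pos + 1] = value
    return seq, targets


seq, targets = make_mqar_example()
print(len(seq), len(targets))  # -> 64 8
```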
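The crisp difference between the prior setup and ours, Multi-Query Associative Recall (MQAR), is that MQAR requires multiple recalls per input, at varying positions, over a vocabulary larger than the model dimension.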
Compared to the prior AR formulations, MQAR better captures the persisting quality gaps on synthetic and real world data. However, it is not clear why MQAR elucidates the gap.
It’s all a matter of scaling! We study why gated convolutions face this gap. With theory and experiments, we find:
Even though gated convolutions are sub-quadratic in sequence length, they need larger model widths (dimensionality, hidden size) than attention to solve MQAR.
On MQAR synthetic data, we validate the theoretical results. We construct BaseConv, a canonical, simplified gated convolution model that can provably simulate all other models built from gating and convolution primitives, including H3, Hyena, and RWKV (find more details in the paper). Across these gated convolutions, we see that the model dimension needs to grow with sequence length to match attention on MQAR, whereas attention's does not!
So all the architectures can solve the task. But, the scaling needed to do so presents an issue.
Section 3. Theory deepens our understanding!
tl;dr We can manually set the weights of a gated convolution to solve associative recall and see for ourselves why gated convolutions exhibit poor scaling.
In our paper, we analyze solutions to MQAR that gated-convolutions could learn in theory. Specifically, we use a rich theory on polynomials to prove that any model built from gating and convolution primitives can solve MQAR with model dimension that scales with sequence length, while attention can solve MQAR with dimension that is independent of sequence length. Shoutout to Aman Timalsina and Isys Johnson for their work on this theory!
Details on the theoretical results
The broad strokes of our analysis are as follows:
- BaseConv. We defined a minimal representation of a gated convolution, termed BaseConv, which provably simulates all other architectures that one could design from the gating and convolution primitives (including H3, Hyena, RWKV, RetNet, etc.), within poly-log factors! In code, BaseConv is very simple:
v = conv(u) // convolution
w = linear(u) // projection
y = v * w // gating
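- Studying the representational capacity required to solve MQAR. To theoretically analyze gated convolutions, we note that convolution and gating are both operations that take two polynomials as input and output a polynomial.
A convolution is defined between two discrete sequences $u$ and $k$ as $(u * k)[i] = \sum_{j=0}^{i} u[j]\,k[i-j]$. Writing $u(X) = \sum_i u[i]\,X^i$, these are exactly the coefficients of the polynomial product $u(X)\,k(X)$, truncated to the sequence length.
Gating (a Hadamard product) is defined between two discrete sequences $u$ and $w$ as $(u \odot w)[i] = u[i]\,w[i]$.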
So we can view an overall gated convolution model as some complex polynomial. There’s a rich literature that studies the polynomial complexity that is required to represent solutions to different problems (“arithmetic circuits”). We draw on this literature to reason about the complexity of gated convolutions required to solve the MQAR problem!
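The three-line BaseConv recipe and the polynomial view can be checked together in a small sketch. Assumptions: a single channel (so the linear projection is a scalar weight and bias), a short filter, and plain Python lists; a real layer applies a $d \times d$ projection and one long filter per channel. The helper names are hypothetical.

```python
def causal_conv(u, k):
    """(u * k)[i] = sum_j u[j] * k[i - j], truncated to len(u) outputs."""
    return [
        sum(u[j] * k[i - j] for j in range(i + 1) if i - j < len(k))
        for i in range(len(u))
    ]


def poly_mul(a, b):
    """Coefficients of the product polynomial a(X) * b(X)."""
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] += x * y
    return out


def gate(v, w):
    """Gating: elementwise (Hadamard) product of two sequences."""
    return [x * y for x, y in zip(v, w)]


def baseconv_1d(u, k, weight, bias):
    """Single-channel BaseConv sketch: y = conv(u) * linear(u)."""
    v = causal_conv(u, k)               # convolution
    w = [weight * x + bias for x in u]  # projection (scalar, one channel)
    return gate(v, w)                   # gating


u, k = [1, 2, 3], [1, 1]
# A causal convolution is a truncated polynomial product.
assert causal_conv(u, k) == poly_mul(u, k)[: len(u)]
print(baseconv_1d(u, k, weight=2.0, bias=0.0))  # -> [2.0, 12.0, 30.0]
```

Because both operations map polynomials to polynomials, stacking BaseConv layers yields a (more complex) polynomial in the input, which is what makes the arithmetic-circuit analysis possible.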
Through this analysis, we were able to convince ourselves that the scaling of gated convolutions is not an artifact of our experiments, but rather a fundamental property of the architecture. In this section, we provide an intuitive explanation of the theoretical MQAR solutions attention and gated-convolutions might be learning. Hopefully readers will come away with an understanding of the fundamental differences between attention and gated convolutions that lead to the scaling differences we observe in practice.
The Sequence Mixing Matrix. Both attention and convolutions take as input a sequence of embeddings $u \in \mathbb{R}^{N \times d}$ and output a sequence of embeddings $y \in \mathbb{R}^{N \times d}$ of the same shape. Crucially, both work by applying a linear transform $y = Mu$, with $M \in \mathbb{R}^{N \times N}$, that “mixes” the sequence of embeddings together.
The fundamental difference lies in how the matrix $M$ is defined. In attention, $M = \mathrm{softmax}(QK^\top)$, where $Q = uW_Q$ and $K = uW_K$. Note that $M$ is a function of the input $u$! Meanwhile, in a gated convolution, $M$ is a convolution matrix that is defined by the model parameters. It is NOT a function of the input $u$.
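The difference between the two kinds of mixing matrix $M$ can be seen in a small sketch. Assumptions: identity projection weights, a causal mask, no $1/\sqrt{d}$ scaling, and plain Python lists; the helper names are hypothetical.

```python
import math


def attention_matrix(u, w_q, w_k):
    """Causal softmax(Q K^T) for embeddings u (N rows of length d).
    The usual 1/sqrt(d) scaling is omitted for brevity."""
    def project(x, w):
        return [sum(x[a] * w[a][b] for a in range(len(x))) for b in range(len(w[0]))]

    q = [project(x, w_q) for x in u]
    k = [project(x, w_k) for x in u]
    n = len(u)
    m = []
    for i in range(n):
        scores = [sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        z = sum(exps)
        m.append([e / z for e in exps] + [0.0] * (n - i - 1))
    return m


def conv_matrix(filt, n):
    """Lower-triangular Toeplitz matrix with M[i][j] = filt[i - j]."""
    return [[filt[i - j] if 0 <= i - j < len(filt) else 0.0 for j in range(n)]
            for i in range(n)]


def mix(m, u):
    """Sequence mixing y = M u."""
    return [[sum(m[i][j] * u[j][c] for j in range(len(u))) for c in range(len(u[0]))]
            for i in range(len(m))]


identity = [[1.0, 0.0], [0.0, 1.0]]
u1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
u2 = [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]

# Attention: different inputs give different mixing matrices.
print(attention_matrix(u1, identity, identity) != attention_matrix(u2, identity, identity))
# Convolution: M depends only on the filter, never on u.
m_conv = conv_matrix([0.5, 0.5], 3)
print(mix(m_conv, u1))  # -> [[0.5, 0.0], [0.5, 0.5], [0.5, 0.5]]
```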
How does Attention solve MQAR? We’ll begin by briefly describing how attention can solve MQAR. This solution uses two layers. The first layer performs a shift-by-one that combines neighboring MQAR keys and values into a single embedding. The second layer performs long range lookups using the MQAR key part of the embedding as the attention keys and the MQAR value part of the embedding as the attention queries and values. The description below focuses on intuition and is rather hand-wavy; if you’re interested in a more precise discussion, see Theorem H.7.2 in our paper.
One source of confusion is the overloading of the terms key and value in the MQAR task and attention layer. In this section, we’ll try to always use qualifiers like MQAR key and attention key to distinguish.
Layer 1: Shift-by-one. The first layer shifts each key over by one so that the output embedding contains both the MQAR key and the value (e.g. by storing the MQAR key in the first half of the dimensions and the MQAR value in the second half). The corresponding attention matrix (visualized below) can be constructed using position embeddings or ALiBi biases.
Layer 2: Lookup. In the second layer, we shift MQAR values forward wherever there are matching MQAR keys (second half of the embedding). Attention makes it easy to find the matching MQAR keys, since we compute the pairwise similarity of the attention query and key embeddings. This means that we can capture the positions where the “C” and “A” previously occurred in the sequence! The matrix is visualized below – lighting up at positions [1, 1] where “C” is the MQAR key for “8” and [4, 4] where “A” is the MQAR key for “3”.
Then attention can put the MQAR values for “C” (i.e. “8”) and “A” (i.e. “3”) into the output sequence as shown below!
Recall how the matrix $M$ is computed in attention, given the projections of the input, $Q = uW_Q$ and $K = uW_K$: $M = \mathrm{softmax}(QK^\top)$.
To construct the necessary attention matrix above, we’ll use the first half of the embedding as the attention key and the second half as the attention query. Below we visualize these comparisons for the 5th and 7th rows of the attention matrix $M$, demonstrating how the attention score spikes at the appropriate position.
With the attention matrix constructed from these comparisons, we can successfully shift the correct values forward and solve the task. Critically, we’re able to do this with a model dimension with size independent of the sequence length (it just needs to be big enough to store two token representations).
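A toy version of this two-layer construction can be written with one-hot embeddings. This sketch follows the intuition above rather than the precise construction in the paper, and rests on explicit assumptions: the embedding is two one-hot blocks over a tiny vocabulary (so its width, $2|V|$, does not depend on sequence length), layer 1 is hard-coded as a shift-by-one instead of being built from position embeddings or ALiBi, and a large inverse temperature `beta` (hypothetical) stands in for learned attention weights.

```python
import math


def one_hot(tok, vocab):
    return [1.0 if t == tok else 0.0 for t in vocab]


def layer1_shift_by_one(seq, vocab):
    """Embedding i = [one_hot(seq[i-1]) | one_hot(seq[i])]: previous token in the
    first half, current token in the second half (zeros before position 0)."""
    d = len(vocab)
    return [(one_hot(seq[i - 1], vocab) if i > 0 else [0.0] * d) + one_hot(tok, vocab)
            for i, tok in enumerate(seq)]


def layer2_lookup(h, d, beta=20.0):
    """Causal attention with key = first half, query = value = second half.
    beta is a large inverse temperature that makes the softmax nearly one-hot."""
    out = []
    for i in range(len(h)):
        q = h[i][d:]
        scores = [beta * sum(a * b for a, b in zip(q, h[j][:d])) for j in range(i + 1)]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        z = sum(weights)
        out.append([sum(weights[j] / z * h[j][d + c] for j in range(i + 1)) for c in range(d)])
    return out


vocab = ["A", "C", "D", "1", "3", "8"]
seq = ["A", "3", "C", "8", "D", "1", "C", "A"]
out = layer2_lookup(layer1_shift_by_one(seq, vocab), len(vocab))
preds = [vocab[max(range(len(vocab)), key=row.__getitem__)] for row in out]
print(preds[6], preds[7])  # -> 8 3
```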
How do Gated Convolutions solve MQAR? Gated convolutions can solve MQAR, but as we’ll see below, all known solutions with a constant number of layers require the model dimension to scale with the sequence length. Below, we describe one such solution that uses two layers. The first uses the convolution and gating to compare each token to all the other tokens to find matching MQAR keys. The second outputs the MQAR values.
Layer 1: Finding Matching MQAR Keys. To understand how the first layer works, we first need to dive back into the differences between attention and convolutions. In the second layer of the attention solution, each row of $M$ is a one-hot vector that lights up in the location of the matching key. This is visualized below.
However, in a convolution, $M$ is restricted to be a fixed, diagonal-constant (i.e. Toeplitz) matrix. This means a token can “look back” $s$ tokens only if all tokens “look back” $s$ tokens, so we have less flexibility. In the illustration below, each row of $M$ is one-hot and performs a shift of four tokens.
This restriction spells trouble for MQAR because there may be shifts of different distances in a single MQAR sequence. For example, solving MQAR in the sequence above requires performing a shift of 2 tokens for A and 6 tokens for C. The ideal matrix that performs these shifts is not diagonal-constant, so it cannot be modeled by a single convolution.
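The diagonal-constant restriction is easy to test directly. A minimal sketch with hypothetical helper names and toy positions: it builds the one-hot mixing matrix for a given set of look-back distances and checks whether it could be a (Toeplitz) convolution matrix.

```python
def shift_matrix(n, lookbacks):
    """N x N matrix whose row i is one-hot at column i - s, for s = lookbacks[i]."""
    m = [[0] * n for _ in range(n)]
    for row, s in lookbacks.items():
        m[row][row - s] = 1
    return m


def is_toeplitz(m):
    """True if every diagonal is constant, i.e. M[i][j] == M[i+1][j+1]."""
    n = len(m)
    return all(m[i][j] == m[i + 1][j + 1] for i in range(n - 1) for j in range(n - 1))


# A single convolution: every row looks back the same 4 tokens.
uniform = shift_matrix(8, {i: 4 for i in range(4, 8)})
# MQAR: one query needs a look-back of 2 (like "A"), another needs 6 (like "C").
mixed = shift_matrix(8, {5: 2, 7: 6})
print(is_toeplitz(uniform), is_toeplitz(mixed))  # -> True False
```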
To get around this, we’ll use the fact that in most gated convolution models, we actually apply different convolutions to different dimensions of the input.
A candidate solution. With $N$ such convolutions, we can cover the $N$ possible gaps that may appear in an MQAR sequence of length $N$! If dimension $i$ stores information for token “C” and dimension $j$ stores information for token “A”, etc., then the corresponding convolutions $k_i$ and $k_j$ can perform the shifts of distance 6 and 2, respectively! See the schematic below.
Again, when we apply a shift to the sequence, it shifts all token embeddings by the same distance. So when we apply all these convolutions to the sequence, covering all the shift distances, we get an embedding for token $i$ that contains information from all of the previous tokens. These embeddings are represented by the pretty rainbow-looking embedding in the visualization below.
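The bank of shifts can be sketched with scalar token ids standing in for embeddings. Assumptions: each channel holds a single number rather than an embedding, filters are exact one-hot shifts, and the helper names are hypothetical. Note that the embedding width is $N - 1$, growing with the sequence length.

```python
def causal_conv(u, k):
    """(u * k)[i] = sum_j u[j] * k[i - j], truncated to len(u) outputs."""
    return [sum(u[j] * k[i - j] for j in range(i + 1) if i - j < len(k))
            for i in range(len(u))]


def shift_filter(s):
    """Convolution filter that shifts a sequence forward by s positions."""
    return [1 if t == s else 0 for t in range(s + 1)]


def shift_bank(x):
    """One channel per shift distance s = 1..N-1. Position i's embedding ends up
    holding x[i-1], x[i-2], ..., x[i-(N-1)] (0 where out of range)."""
    n = len(x)
    channels = [causal_conv(x, shift_filter(s)) for s in range(1, n)]
    return [[channel[i] for channel in channels] for i in range(n)]


x = [5, 7, 9, 11]        # token ids; 0 marks "nothing there"
emb = shift_bank(x)
print(len(emb[0]), emb[3])  # -> 3 [9, 7, 5]
```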
Next, we’ll use the gating (i.e. element-wise multiplication) and MLP to produce a mask for token $i$ that spikes only at the positions whose token matches token $i$. In the example below, there is a match with the prior occurrence of A (shown in orange)!
In the interest of brevity, we’re brushing over some details of how we set the weights of the MLP to accomplish this. The curious reader might check out our paper to learn more or do the exercise of convincing themselves that the weights of the MLP can be set appropriately.
Layer 2: Shift and mask. With this mask in hand, we will now apply it to the sequence to isolate the correct MQAR value for each MQAR key. To do so, we perform a similar shift as in layer one, except this time with an off-by-one (since the MQAR value is the token that comes one position after the MQAR key). This is visualized below.
We can then apply the gating we produced in the previous layer to isolate the correct token!
As an aside, you might be wondering why the input to layer two in the visualization is the original token when layer one outputs the mask. This can be handled by the residual connections and MLPs in the overall architecture.
Finally, we’ve isolated the correct value, 3, which can be predicted with the language model’s classification head.
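Putting the two layers together gives a toy, end-to-end version of this solution. A sketch under loud assumptions: tokens are scalar ids instead of embeddings, the mask is computed by a direct equality test that stands in for the gating and MLP weights we glossed over, and the helper names are hypothetical. The number of channels is $N - 1$, which is exactly the dimensionality cost discussed next.

```python
def causal_conv(u, k):
    """(u * k)[i] = sum_j u[j] * k[i - j], truncated to len(u) outputs."""
    return [sum(u[j] * k[i - j] for j in range(i + 1) if i - j < len(k))
            for i in range(len(u))]


def shift(x, s):
    """Shift x forward by s positions via convolution with a one-hot filter."""
    return causal_conv(x, [1 if t == s else 0 for t in range(s + 1)])


def gated_conv_mqar(x):
    """Two-layer toy solution: one channel per look-back distance s = 1..N-1."""
    n = len(x)
    dists = list(range(1, n))
    # Layer 1: shift by s and compare with the current token to build a mask.
    # (The direct equality test stands in for the gating + MLP step.)
    looked_back = [shift(x, s) for s in dists]
    mask = [[1 if i >= s and looked_back[c][i] == x[i] else 0 for i in range(n)]
            for c, s in enumerate(dists)]
    # Layer 2: shift by s - 1 (off-by-one) to expose the token after each match,
    # then gate with the mask so only the matching channel survives.
    values = [shift(x, s - 1) for s in dists]
    preds = []
    for i in range(n):
        hits = sum(mask[c][i] for c in range(len(dists)))
        gated = sum(mask[c][i] * values[c][i] for c in range(len(dists)))
        preds.append(gated if hits == 1 else None)
    return preds


# Token ids: A=1, C=2, D=4 are keys; 13 ("3"), 18 ("8"), 11 ("1") are values.
x = [1, 13, 2, 18, 4, 11, 2, 1]
print(gated_conv_mqar(x))  # -> [None, None, None, None, None, None, 18, 13]
```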
And that’s it! Something should feel pretty unsatisfying about this solution though…
The bottom line. So we saw in the prior section that gated convolutions require dimensionality that scales with the sequence length to solve MQAR. So what’s the reconciliation? Well, we made a big assumption in our gated convolution solution: that the model dimension was big enough to store a full copy of the sequence in a single embedding. In other words, for this solution to work, the model dimension $d$ must scale linearly with the sequence length $N$. This explains the scaling we saw in the synthetic experiments of Section 2.
Section 4. Closing the Associative Recall Gap!
tl;dr Selective sparse attention keeps us sub-quadratic.
Attention uses quadratic interactions between tokens to identify the matching MQAR key tokens, then efficiently extracts the MQAR value. That is, attention first compares tokens, then shifts information (the MQAR value) forward. The convolution architecture, in contrast, reverses the order: the shift comes first and then the compare. Every other token j needs to be shifted to position i in order to compare tokens i and j. The challenge is that there is often insufficient dimensionality (too few convolutions) to perform all the token i and j comparisons required to solve MQAR. We need to increase the dimensionality to solve MQAR with convolutions, hurting efficiency.
Input-dependence is key! In our work, we prove that if our architecture instead had input-dependent convolutions, we could solve MQAR as efficiently as attention – with dimensionality independent of sequence length. Through input-dependence, we could determine which shifts are required for the input and only support those token interactions. If the number of required shifts is not too large, we don’t need large dimensionality!
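The sense in which input-dependence helps can be illustrated by computing, from the input itself, which shifts a given sequence actually needs. This is an illustrative sketch, not the construction from our proof: the helper is hypothetical and simply measures the distance from each repeated token back to its previous occurrence. A fixed-filter solution must provision all $N - 1$ shifts, whereas this input needs only as many as there are repeated keys.

```python
def required_shifts(x):
    """For each position whose token occurred before, the look-back distance to
    its most recent previous occurrence. Only these shifts need supporting."""
    last_seen, needed = {}, {}
    for i, tok in enumerate(x):
        if tok in last_seen:
            needed[i] = i - last_seen[tok]
        last_seen[tok] = i
    return needed


x = [1, 13, 2, 18, 4, 11, 2, 1]   # A 3 C 8 D 1 C A as token ids
needed = required_shifts(x)
print(needed, len(x) - 1)  # -> {6: 4, 7: 7} 7
```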
Expository input-dependent architectures on the Pile. We use these insights around input-dependence to design architectures that efficiently close the MQAR gap. Here, we present two proof-of-concept architectures, which highlight how input-dependence closes the MQAR gap, along with a control. We take our canonical gated convolution model (BaseConv) and simply replace 3 layers with each of the following alternative layers:
Programmatic Selection: We identify all positions in the sequence where the token previously occurred in the sequence (by simply causally inspecting the raw input token IDs). We sparsely use attention at these positions. E.g. if the sentence contains “Sesame Street and Spy Kids .. Sesame”, attention would be used at the second “Sesame”. This simple method closes most of the overall and AR gap!
Random Selection (Control!). Programmatic Selection results in a variable amount of attention per sequence. As a control, we take the same amount of attention and randomly distribute it across positions in the same input sequence. This clearly underperforms Programmatic Selection!
Learned Selection. We simply use a linear layer and sigmoid at each position to decide whether to use attention. We only use attention on the top-k positions to ensure that the layer is sub-quadratic.
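The three selection rules can be sketched as functions that return the positions where attention would be applied. A minimal sketch: the function names, the per-position feature vectors, and the weights for learned selection are hypothetical, and the attention layer that consumes these positions is omitted. As written, the top-k is taken over the whole sequence; this sketch does not address how that interacts with causal training.

```python
import math
import random


def programmatic_selection(token_ids):
    """Positions whose token already appeared earlier (causal: only looks back)."""
    seen, picked = set(), []
    for i, tok in enumerate(token_ids):
        if tok in seen:
            picked.append(i)
        seen.add(tok)
    return picked


def random_selection(token_ids, budget, seed=0):
    """Control: the same attention budget, placed at random positions."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(len(token_ids)), budget))


def learned_selection(features, weights, bias, k):
    """Score each position with sigmoid(linear(x_i)); keep the top-k positions."""
    gates = [1.0 / (1.0 + math.exp(-(sum(f * w for f, w in zip(x, weights)) + bias)))
             for x in features]
    top = sorted(range(len(gates)), key=lambda i: gates[i], reverse=True)[:k]
    return sorted(top)


tokens = "Sesame Street and Spy Kids .. Sesame".split()
chosen = programmatic_selection(tokens)
control = random_selection(tokens, budget=len(chosen))
print(chosen, len(control))  # -> [6] 1
print(learned_selection([[0.1], [2.0], [-1.0]], weights=[1.0], bias=0.0, k=1))  # -> [1]
```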
In our next post, we present Based, which generalizes these expository results in a strong, unified architecture. We show that Based outperforms the strong Llama-style Transformer baseline in quality on the Pile, while remaining sub-quadratic in sequence length during training and allowing O(1) inference complexity! Based is also competitive with late-breaking concurrent architectures like Mamba at the evaluated scales.
Extended Analysis: Understanding concurrent input-dependent architectures.
In the final days of this work, we were excited to see new releases like Mamba and Striped Hyena, which also share insights on the importance of input-dependence. We evaluated these models downstream and on the MQAR synthetics, and briefly share some findings on their strengths and potential tradeoffs.
In the last week, we’ve seen a couple interesting ideas that improve a sub-quadratic architecture’s input-dependence:
Mamba. Mamba uses input-dependent A (via discretization), B, and C matrices in the State Space Model. We note that the prior RetNet model uses input-dependent B and C matrices. These are recurrent models, which capture the prior context in a hidden state of fixed dimension. We find that RetNet and Mamba perform much better than gated convolutions on MQAR.
However, we noticed Mamba started dropping off at longer sequence lengths and decided to study this further. We find Mamba quality falls off as the number of MQAR keys and values (KVs) per sequence grows large. Intuitively, it is difficult to “store” all the keys and values in a single RNN hidden state. We theoretically prove in our paper that the hidden state for RetNet needs to grow with the number of MQAR KV pairs to solve the task. We can see these behaviors translate to the empirical results below for Mamba:
Based appears to outperform Mamba as the number of KVs per sequence increases on synthetic data.
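The intuition that a fixed-size hidden state runs out of room can be illustrated with a deliberately simple toy memory. This is not Mamba's or RetNet's recurrence: we assume one-hot keys assigned to slots by `key % state_dim` and an additive outer-product write, so every detail here is hypothetical. Once the number of pairs exceeds the state dimension, keys must share slots and their values collide.

```python
def recall_accuracy(num_pairs, state_dim, value_vocab=64):
    """Toy fixed-size memory: a state_dim x value_vocab matrix updated as
    S += one_hot(slot(key)) outer one_hot(value), with slot(key) = key % state_dim.
    Returns the fraction of keys whose value is uniquely recovered from S."""
    state = [[0] * value_vocab for _ in range(state_dim)]
    pairs = [(key, (7 * key + 3) % value_vocab) for key in range(num_pairs)]
    for key, value in pairs:
        state[key % state_dim][value] += 1       # write
    correct = 0
    for key, value in pairs:
        row = state[key % state_dim]             # read with the same key
        best = max(row)
        if row[value] == best and row.count(best) == 1:
            correct += 1
    return correct / num_pairs


for num_pairs in (8, 16, 32):
    print(num_pairs, recall_accuracy(num_pairs, state_dim=16))
```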
Multi-head convolutions. Striped Hyena uses attention, so it is not a pure sub-quadratic model. However, their analysis shows interesting results with respect to increasing the number of convolution heads! This idea was first proposed in the H3 paper, and a theorem in the Laughing Hyena work shows that increasing the number of heads should improve the gated convolution architecture’s ability to solve recall with lower dimensionality. We can see this translate to the empirical results on our MQAR task below:
An important note here is that Hyena (and thus Striped Hyena and Multi-Head Hyena) uses a long convolution, which makes it a challenging choice for high-throughput, low-latency inference (since the KV cache is as large as the input sequence length).
In conclusion, we learn that recently popularized gated convolutions require undesirable scaling relative to attention to solve associative recall in language modeling, contrary to common belief. The scaling could be improved by increasing the effective dimensionality of the convolution step (multi-head convolutions) at the cost of additional FLOPs. We could also use input-dependent sequence mixing to recover scaling that matches attention. We propose MQAR as a simple synthetic for evaluating your new architecture's AR ability; in our work, we demonstrate that MQAR synthetic results correlate with downstream language modeling AR quality. In the next post, we'll put this analysis to further use by proposing an architecture that closes the AR gap to attention and outperforms the Transformer baseline overall, while remaining sub-quadratic in sequence length!
We would like to thank our wonderful collaborators and labmates including Chris Ré, Dan Fu, Michael Poli, Jerry Liu, Michael Wornow, and Atri Rudra for discussion and feedback on this blogpost. We are grateful for the support of Together.ai.