Just read twice: closing the recall gap for recurrent language models

Simran Arora

Full team: Simran, Aman, Aaryan, Ben, Sabri, Xinyi (Jojo), Ashish, Atri, and Chris Ré.

ArXiv | Code | Models

TLDR: Today's efficient ML architectures struggle to perform a fundamental skill called associative recall (AR) as well as the de facto Transformer architecture. AR requires an LM to use information provided in long contexts when generating tokens. The issue is that efficient LMs need to predict which information from a long context to store in a limited amount of memory. We observe that if prompts are ordered in "the right way", or if we use non-causal variants of efficient LMs to process prompts, this prediction problem can become much easier! Our work departs from the modern causal language modeling orthodoxy, where LMs process text left-to-right in a fixed order, to close the AR gap between memory-efficient LMs and Transformer LMs.

Introduction

Recent work has made rapid progress in developing fixed-memory recurrent architectures (e.g., Mamba and RWKV) that are competitive with attention in language modeling perplexity. During inference, these models are more memory efficient and asymptotically faster than attention! Amazing, we have quality and efficiency! But what are we missing? As the saying goes, there's no free lunch!

While these LMs are efficient, they trade off "associative recall" (AR) ability: they struggle to use information provided in-context when generating responses. Why does this matter? Greater efficiency might unlock long-context applications, such as generating code given an entire codebase, but if the model can't actually refer to specific functions or variables in that context, the efficiency effectively doesn't matter. We find a popular 2.8Bn parameter Mamba LM trained on 300Bn tokens of the Pile underperforms a 1.3Bn parameter Transformer LM (2.2× smaller) trained on 50Bn tokens (6× fewer tokens) by 5 points, averaged across a suite of benchmarks that require AR abilities.

Unfortunately, this is a fundamental issue for memory-limited LMs: we prove models require $\Omega(N)$ space, in the input length $N$, to solve associative recall tests like the one below. We have a "context" of key-value token pair mappings on the left and "questions" on the right, for which the model should output the answers 4, 6, 1, 2, 3:

Is all lost? This begs the question of whether we can actually rely on LMs with $O(1)$-size recurrent states for language modeling...

Luckily, models often do not need to remember all information provided in-context to excel at a task! The challenge is predicting which subset of information is useful to store. LMs have gotten a lot better at this through innovations in architectural inductive biases like input-dependent decay (e.g., Mamba), LSTM-style gating, fast weights, the delta rule, etc. Other work, like the Based architecture, instead increases the recurrent state size in hardware-efficient and math-guided ways! 1 2 Efficient LMs have continued to extend the Pareto frontier of the AR-quality and efficiency tradeoff space, but still leave a lot to be desired.

Our insights. We make a simple observation: the order in which a recurrent LM reads input text drastically impacts the difficulty of predicting what to store in its limited memory. Suppose we ask questions $Q$ (e.g., "When did Galileo move to Florence?") over documents $D$ (e.g., the detailed Wikipedia page for Galileo Galilei). The model needs to remember just one fact from $D$ if the prompt is ordered $[Q, D]$, but needs to remember all facts when it is $[D, Q]$. The issue is that essentially all of our modern LLMs view text "causally", i.e., they process it in a fixed left-to-right order. They need to predict the future (that we'll ask about Galileo's move to Florence) when deciding what to store (while reading Galileo's page). Note that people don't read like this! We go back and forth and jump around - think back to SAT reading comprehension tests!

Our work goes beyond the causal-LM orthodoxy to close the AR gap to Transformers using efficient recurrent LMs! We propose two methods in this vein, detailed in the rest of this blogpost:

  1. Just read twice prompting (JRT-Prompt): This method repeats the prompt context multiple times before the model generates the answer. It works with any off-the-shelf causal or non-causal model, without modification! With this repeated-prompt strategy we can beat zero-shot Transformers in quality at comparable efficiency. You might think the extra context repetition makes things slow, but because these models are linear and hardware-efficient, they actually end up faster than the zero-shot Transformer! We apply JRT-Prompt to 16 off-the-shelf LMs overall (spanning Based, GLA, Mamba-1, and Mamba-2), resulting in an $11.0 \pm 1.3$ point improvement on average across the benchmarks. Based LMs with JRT-Prompt can process sequences of length 32768 at batch size 16, $11.9\times$ faster than Transformers with FlashAttention-2.

  2. Just read twice recurrent architecture (JRT-RNN): This architecture uses an encoder-decoder structure to non-causally process the prompt and then causally decode the answer. JRT-RNN provides 99% of Transformer quality at 360M params. / 30Bn tokens, averaged across the recall-intensive ICL benchmarks, a 46.7% improvement over Based and 78.8% over Mamba. At 1.3Bn params. / 50Bn tokens, JRT-RNN provides 96% of Transformer quality, a 16.2% improvement over Based and 34.5% over Mamba on average. Benchmarking our custom CUDA implementation, prefix linear attention can process sequences of length 32768 at batch size 16, $19.2\times$ faster than FlashAttention-2.

We think there's a lot to do in this space and are interested to hear your feedback on these ideas! At a high level, why are we directly trying to mimic the way we use Transformers with our efficient LMs? We can exploit the asymmetries: multiple linear passes are still asymptotically cheaper than one quadratic Transformer pass!

You can try out our models and hardware-efficient kernels: https://github.com/HazyResearch/prefix-linear-attention.

Understanding the role of data order on efficient models

To demonstrate the sensitivity of recurrent LMs to data ordering, we started by considering a simplified toy task: the model is given input sequences containing elements from two sets $A$ and $B$, and needs to output the tokens in the intersection of $A$ and $B$. E.g., for the example below, the correct output would be 6:

Note that the "set intersection" problem tested by this toy task is a quintessential problem in communication complexity theory that has been studied for decades. The hardness of many problems reduces to the hardness of set disjointness (are the two sets disjoint or not?). We also formally show that the associative recall problem reduces to the set disjointness problem in our paper!

In evaluating LMs on the toy task, we consider two slices of the test set: when $|A| > |B|$ ("long $A$ to short $B$") and when $|B| > |A|$ ("short $A$ to long $B$"). We find that causal models achieve better quality when $|B| > |A|$ than when $|A| > |B|$, which is intuitive since the LM needs to store all elements of set $A$ to find the intersection with $B$. Meanwhile, non-causal LMs outperform when $|A| > |B|$ and match the causal case when $|B| > |A|$! The causal LMs are far more sensitive to the data ordering (highlighted in the right-hand-side panel).

We theoretically formalize the synthetic observations, pointing to two paths forward to build more reliable $O(1)$-memory recurrent LMs:

  1. The causal LM needs $\Omega(\min(|A|, |B|))$ bits of memory to solve set disjointness. Broadly, we should put data in "the right order" in-context (e.g., placing the smaller of the two sets first in context). We use this idea in our first proposal, Just read twice prompting.

  2. We prove that non-causal LMs can solve set disjointness with $O(\min(|A|, |B|))$ bits of memory, regardless of the data order in context. We use this idea in our second proposal, the Just read twice recurrent architecture.

Just read twice prompting

Just read twice is very simple: we just repeat the prompt (e.g., document and question) multiple times in context before the model generates an answer. Intuitively, in the second (or subsequent) pass over the context, the LM can condition on the full context to decide what to store in its state -- we get around the issue of needing the data to be in "the right order"!
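To make the recipe concrete, here is a minimal sketch of how one might construct a JRT prompt. The separator, the number of repetitions, and the trailing "Answer:" cue are illustrative assumptions, not the exact templates from our paper:

```python
def jrt_prompt(context: str, question: str, n_repeats: int = 2) -> str:
    """Repeat the full (context + question) block before asking for the answer.

    Any off-the-shelf causal or non-causal LM can consume the resulting string;
    the formatting choices here are placeholders.
    """
    block = f"{context}\n{question}"
    repeated = "\n\n".join([block] * n_repeats)
    return f"{repeated}\nAnswer:"


# Toy usage with a hypothetical document/question pair.
prompt = jrt_prompt(
    context="Galileo moved to Florence in 1610.",
    question="When did Galileo move to Florence?",
)
print(prompt)
```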

We take 16 pretrained recurrent LMs from Hugging Face and evaluate them using standard zero-shot prompting versus JRT-Prompt on six recall-intensive tasks (document question answering, information extraction). Each cell below shows the standard / JRT scores respectively.

| Architecture | Params | Tokens | FDA | SWDE | NQ | SQuADv2 | TriviaQA | DROP |
|---|---|---|---|---|---|---|---|---|
| Transformer | 1.3B | 10B | 74.4 / 86.1 | 41.4 / 52.5 | 28.2 / 31.9 | 39.0 / 53.1 | 49.5 / 49.3 | 22.3 / 33.6 |
| Mamba | 1.3B | 10B | 23.3 / 40.3 | 15.5 / 31.8 | 19.4 / 25.8 | 26.6 / 48.5 | 46.4 / 51.1 | 21.3 / 32.1 |
| Based | 1.3B | 10B | 48.6 / 58.9 | 27.6 / 44.7 | 19.7 / 28.4 | 31.0 / 46.7 | 44.1 / 51.9 | 19.5 / 34.6 |
| Transformer | 1.3B | 50B | 83.7 / 89.2 | 50.8 / 65.0 | 32.8 / 37.5 | 41.1 / 58.1 | 56.6 / 58.8 | 21.5 / 37.9 |
| Mamba | 1.3B | 50B | 41.9 / 55.7 | 32.6 / 45.4 | 26.9 / 33.9 | 31.5 / 53.5 | 54.9 / 56.7 | 20.4 / 33.8 |
| Based | 1.3B | 50B | 60.2 / 68.3 | 37.1 / 54.0 | 29.4 / 35.2 | 38.9 / 56.3 | 54.5 / 57.6 | 21.7 / 39.1 |
| Mamba | 130M | 300B | 25.7 / 32.8 | 17.5 / 31.5 | 16.8 / 21.7 | 27.1 / 51.9 | 43.5 / 50.1 | 17.4 / 30.7 |
| Mamba | 370M | 300B | 41.9 / 58.3 | 27.6 / 42.2 | 23.8 / 31.1 | 34.9 / 51.0 | 53.6 / 51.7 | 19.3 / 33.2 |
| Mamba | 1.4B | 300B | 45.8 / 60.9 | 37.6 / 46.0 | 31.0 / 36.6 | 39.9 / 59.6 | 60.5 / 61.3 | 20.9 / 36.4 |
| Mamba | 2.8B | 300B | 54.3 / 66.6 | 38.9 / 48.9 | 33.5 / 40.1 | 43.9 / 59.4 | 66.2 / 63.9 | 19.8 / 36.9 |
| GLA | 1.3B | 100B | 48.3 / 68.6 | 37.7 / 53.6 | 26.6 / 31.1 | 34.7 / 54.8 | 55.5 / 54.6 | 17.4 / 30.7 |
| GLA | 2.7B | 100B | 47.1 / 65.8 | 43.6 / 54.5 | 27.1 / 32.9 | 37.2 / 55.7 | 57.9 / 57.0 | 22.2 / 34.0 |
| Mamba-2 | 130M | 300B | 32.2 / 50.9 | 29.5 / 43.3 | 20.6 / 28.9 | 30.4 / 47.0 | 43.7 / 47.2 | 18.0 / 34.0 |
| Mamba-2 | 370M | 300B | 60.8 / 76.7 | 38.3 / 52.1 | 26.6 / 33.6 | 35.3 / 51.8 | 54.6 / 54.7 | 22.4 / 36.3 |
| Mamba-2 | 1.3B | 300B | 66.8 / 74.7 | 50.0 / 59.6 | 33.6 / 40.5 | 42.9 / 59.6 | 63.8 / 62.4 | 23.2 / 36.6 |
| Mamba-2 | 2.7B | 300B | 68.7 / 81.6 | 55.2 / 60.8 | 34.4 / 41.7 | 45.4 / 59.4 | 66.4 / 66.5 | 23.0 / 42.5 |

This results in an $11.0 \pm 1.3$ point accuracy improvement on average for the non-Transformer LMs across the benchmarks. Check out our paper for more analysis: we show JRT-Prompt outperforms few-shot prompting, and the full repetition also appears to be more helpful than repeating only the question.3 These results highlight that LMs are brittle to varied data orderings in context.

JRT-Prompt with two repetitions does double the context length, but with recurrent LMs this is still asymptotically more efficient than the quadratic scaling of Transformers. Interestingly, a lot of our work on efficient LMs tries to directly mimic the experience of using Transformers, but we could instead exploit the asymmetry: multiple linear passes can still be cheaper than one quadratic pass in many regimes. We're excited for future methods along this line of thinking.

Benchmarking our hardware-efficient implementation of Based, written in the ThunderKittens framework, JRT-Prompt applied to Based models provides $11.9\times$ higher throughput than FlashAttention-2 at sequence length 32768 and batch size 16 on an NVIDIA H100 GPU. It is also worth re-emphasizing that JRT-Prompt can be used with any off-the-shelf LM.

Just read twice recurrent architecture

We have shown that the recall quality of causal recurrent LMs varies depending on the order in which information appears in context, making them brittle for in-context learning. In our set disjointness analysis, we found that non-causal recurrent LMs could help! If an LM can non-causally view the entire prompt when deciding what information to store or discard, it may be able to make better selection decisions.

A long line of work has demonstrated the strength of non-causal "bidirectional" neural networks in language modeling. However, it is challenging to use them for fast text generation because the context must be re-processed for each generated token to achieve good quality. Encoder-decoder architectures with a bidirectional encoder and causal decoder offer a way to achieve fast causal generation while reaping the benefits of bidirectional LMs. Nonetheless, decoder-only causal autoregressive LMs remain the norm and encoder-decoder architectures have received little attention in the context of sub-quadratic efficient LMs.

We propose a simple encoder-decoder architecture, JRT-RNN, that goes beyond the de-facto causal modeling! We'll start by defining some baseline efficient LMs for the causal setting and then describe our approach.

Preliminaries in the causal case

Sequence mixers are the part of an LM that determines how tokens in the input $x \in \mathbb{R}^{N \times d}$, for sequence length $N$ and model dimension $d$, interact with each other when computing the output representations $y \in \mathbb{R}^{N \times d}$ for the subsequent layer of a deep neural network.

First we'll start with the de facto attention sequence mixer (Bahdanau et al., 2014) and then describe a popular category of efficient LMs built on linear attention (Katharopoulos et al., 2020). Linear attention is closely related to another popular category of efficient LMs based on state space models (Gu et al., 2021).

Attention models mix tokens as follows:
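In one standard formulation (with projections $Q = xW_Q$, $K = xW_K$, $V = xW_V$, and with the $\sqrt{d}$ scaling and multi-head details omitted for brevity), the causal attention output at position $t$ is:

$$y_t = \sum_{i=0}^{t} \frac{\exp(Q_t K_i^{\top})}{\sum_{j=0}^{t} \exp(Q_t K_j^{\top})} V_i$$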

Attention takes $O(N^2 d)$ compute and linear $O(Nd)$ memory during training. At inference, remember that we generate one token at a time $t$: each new query $Q_t$ needs to interact with all of the prior keys $K_{0:t}$ and values $V_{0:t}$, so if we cache those keys and values, the compute and memory to generate a token scale with $t$.
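For concreteness, here is a minimal PyTorch sketch of that decode step with a KV cache (single head, no batching; the function and variable names are ours, not from any particular library):

```python
import torch

def attention_decode_step(q_t, k_t, v_t, k_cache, v_cache):
    """One decode step of causal softmax attention with a KV cache.

    q_t, k_t, v_t: (d,) query/key/value for the newly generated token.
    k_cache, v_cache: (t, d) cached keys/values -- they grow with every token.
    """
    k_cache = torch.cat([k_cache, k_t[None]], dim=0)      # now (t+1, d)
    v_cache = torch.cat([v_cache, v_t[None]], dim=0)
    scores = (k_cache @ q_t) / k_cache.shape[-1] ** 0.5   # (t+1,) attention logits
    y_t = torch.softmax(scores, dim=0) @ v_cache          # (d,) output representation
    return y_t, k_cache, v_cache


# Toy usage: the cache (and per-token cost) grows linearly with t.
d = 8
k_cache, v_cache = torch.zeros(0, d), torch.zeros(0, d)
for t in range(32):
    q, k, v = torch.randn(d), torch.randn(d), torch.randn(d)
    y, k_cache, v_cache = attention_decode_step(q, k, v, k_cache, v_cache)
```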

Linear attention models replace the $\exp$ with a "feature map" $\phi: \mathbb{R}^{N \times d'} \rightarrow \mathbb{R}^{N \times D}$, with "feature dimension" $D$. As before, $d$ is the model dimension, but we decouple $d'$ so $Q$ and $K$ can be projected to a different dimension than $V$ as needed.
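With the feature map in place, a standard way to write the (causal) linear attention output at position $t$ is:

$$y_t = \frac{\phi(Q_t) \sum_{i=0}^{t} \phi(K_i)^{\top} V_i}{\phi(Q_t) \sum_{i=0}^{t} \phi(K_i)^{\top}}$$

The key point is that the sums over $i$ no longer depend on $Q_t$, so they can be accumulated once and reused.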

Written this way, training now costs $O(NDd)$, linear in $N$. During inference, by maintaining the running sum $h_t = \sum_{i=0}^{t} \phi(K_i)^{\top} V_i$, which we can store flattened as $h_t \in \mathbb{R}^{1 \times Dd}$, and then multiplying it with $\phi(Q_t)$, observe that our compute and memory to generate a token stay constant, $O(1)$, as $t$ grows!
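Below is a minimal PyTorch sketch of this recurrent decode step (single head, a simple elementwise feature map as a placeholder, and an explicit normalizer; real models like Based use more structured feature maps and fused kernels):

```python
import torch

def linear_attention_decode_step(phi_q, phi_k, v, state, normalizer):
    """One decode step of linear attention with a constant-size recurrent state.

    phi_q, phi_k: (D,) feature-mapped query/key; v: (d,) value.
    state:        (D, d) running sum of phi(K_i)^T V_i -- size does not grow with t.
    normalizer:   (D,)   running sum of phi(K_i), used to normalize the output.
    """
    state = state + torch.outer(phi_k, v)              # update the D x d state
    normalizer = normalizer + phi_k
    y = (phi_q @ state) / (phi_q @ normalizer + 1e-6)  # (d,) output representation
    return y, state, normalizer


# Toy usage: stream 32 tokens through a constant O(Dd) state.
D, d = 16, 8
phi = torch.relu                                       # illustrative feature map
state, normalizer = torch.zeros(D, d), torch.zeros(D)
for t in range(32):
    q, k, v = torch.randn(D), torch.randn(D), torch.randn(d)
    y, state, normalizer = linear_attention_decode_step(phi(q), phi(k), v, state, normalizer)
```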

Our approach

Prefix linear attention. JRT-RNN first uses what we refer to as prefix linear attention (PLA), inspired by the class of prefix LM architectures. We split the sequence into two regions: a non-causal encoder region of length $M$, with $M < N$, and a causal decoder region of length $N-M$. Adding projections to compute $A_e$, $B_e$ on the encoder side, we compute the outputs as follows:
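The paper gives the exact equations; here is a simplified PyTorch sketch of the prefill computation under our reading of it, assuming the encoder region contributes through its own feature-mapped key/value-style projections (written `phi_a_e` / `b_e` for the $A_e$ / $B_e$ above) and the decoder runs standard causal linear attention on top of that encoder state. Normalization, batching, and the hardware-efficient tiling of our CUDA kernel are omitted:

```python
import torch

def prefix_linear_attention_prefill(phi_q, phi_k, v, phi_a_e, b_e, M):
    """Sketch of PLA prefill for one head over a length-N prompt.

    phi_q, phi_k, phi_a_e: (N, D) feature-mapped projections; v, b_e: (N, d) values.
    Positions [0, M) form the non-causal encoder region; [M, N) decodes causally.
    """
    N, D = phi_q.shape
    d = v.shape[1]
    # Encoder state: every encoder position sees the *whole* encoder region.
    enc_state = phi_a_e[:M].T @ b_e[:M]          # (D, d)
    y = torch.zeros(N, d)
    y[:M] = phi_q[:M] @ enc_state                # non-causal encoder outputs
    # Decoder region: causal linear attention, initialized with the encoder state
    # so decoded tokens can recall the full prompt.
    state = enc_state.clone()
    for t in range(M, N):
        state = state + torch.outer(phi_k[t], v[t])
        y[t] = phi_q[t] @ state
    return y


# Toy usage with random projections (shapes are illustrative only).
N, M, D, d = 64, 48, 16, 8
phi = torch.relu
y = prefix_linear_attention_prefill(
    phi(torch.randn(N, D)), phi(torch.randn(N, D)), torch.randn(N, d),
    phi(torch.randn(N, D)), torch.randn(N, d), M,
)
```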

Like standard linear attention, PLA is agnostic to the choice of feature map $\phi(\cdot)$. Observe that PLA retains the compute complexity and recurrent state size of causal linear attention during decoding, but processes the "prompt region" (first $M$ tokens) twice during prefill.

We introduce a hardware-efficient CUDA implementation of PLA using the ThunderKittens CUDA library, which takes just $1.24\times$ the time of the baseline decoder linear attention LM. In benchmarking, this implementation provides $19.2\times$ higher throughput than FlashAttention-2 for prefill of length 32K at batch size 16!

Pretraining loss. Decoder LMs are trained using a next-token prediction (NTP) loss: at each position $i$ in the sequence, the LM uses tokens $0$ through $i$ to predict token $i+1$. In PLA, we can't compute an NTP loss for tokens $1$ through $M$ since the model views them non-causally. Instead, we use a masked language modeling (MLM) loss for tokens in the encoder region (i.e., replacing / "masking" some words in the sequence and measuring how well the model predicts the original masked words), plus the next-token prediction loss in the causal decoder region (last $N-M$ tokens). Note that if we mask roughly 15% of tokens, set $M = \frac{N}{2}$, and normalize for the total number of input sequences, the non-causal JRT LM computes losses on roughly 65% of the number of tokens that the causal LMs do. However, how best to normalize training across causal vs. non-causal models is an open question that we hope future work will continue to engage with.
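To make the objective concrete, here is a small sketch of how such a combined loss could be wired up. The masking conventions (a `-100` ignore index, labels precomputed outside the function) are common PyTorch choices and illustrative assumptions, not the exact recipe from our training code:

```python
import torch
import torch.nn.functional as F

def jrt_rnn_loss(logits, input_ids, mlm_labels, M):
    """Combined objective: MLM over the encoder region + NTP over the decoder region.

    logits:     (N, vocab) model outputs for one sequence.
    input_ids:  (N,) input token ids, with some encoder tokens replaced by a mask token.
    mlm_labels: (N,) original ids at masked encoder positions, -100 elsewhere
                (at least one encoder position should be masked).
    M:          length of the non-causal encoder region.
    """
    # Encoder region [0, M): predict the original token only at masked positions.
    mlm_loss = F.cross_entropy(logits[:M], mlm_labels[:M], ignore_index=-100)
    # Decoder region [M, N): standard next-token prediction (shift targets by one).
    ntp_loss = F.cross_entropy(logits[M:-1], input_ids[M + 1:])
    return mlm_loss + ntp_loss
```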

Results

On the set of recall-intensive benchmarks, JRT-RNN provides consistent improvements over strong decoder-only recurrent LM baselines, achieving 96% of the quality of the very strong Transformer++ (Llama architecture) baseline! We compare to two very strong $O(1)$-memory recurrent LM baselines: JRT-RNN gives a 16.2% improvement over the Based baseline and a 34.5% improvement over the Mamba baseline at 1.3Bn parameters and 50Bn tokens. Check out our paper for lots more results and analysis!

| Architecture | Params / Tokens | FDA ($N=512$) | FDA ($N=512$) | SWDE ($N=512$) | SWDE ($N=512$) | NQ ($N=512$) | NQ ($N=512$) | SQuADv2 | TriviaQA | DROP | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Transformer | 1.3B / 50B | 85.6 | 83.5 | 55.7 | 56.0 | 33.4 | 29.9 | 40.1 | 56.6 | 21.4 | 51.4 |
| Mamba | 1.3B / 50B | 55.4 | 40.1 | 44.0 | 33.7 | 27.6 | 23.2 | 32.2 | 54.5 | 20.7 | 36.8 |
| Based | 1.3B / 50B | 69.3 | 58.8 | 47.6 | 40.4 | 29.1 | 24.4 | 38.5 | 54.3 | 20.8 | 42.6 |
| JRT-RNN | 1.3B / 50B | 86.7 | 67.7 | 49.4 | 45.7 | 38.3 | 25.4 | 50.4 | 53.0 | 29.3 | 49.5 |

Overall, nearly all of the focus on efficient attention-alternative LMs has been on decoder-only models. We're very excited about the potential of (1) non-causal modeling and (2) making multiple linear passes over the input prompt as two ways to help make efficient LMs competitive with Transformers on the important recall skill, and we think there are many more ideas to consider along these directions. 5

If you’re interested in chatting more about any of these topics, please feel free to reach out at: simarora@stanford.edu. We'd like to thank Hazy Research, Cartesia, and Together for supporting our work!

Finally, a big thank you to labmates including Michael Zhang, Dan Fu, Neel Guha, and Jon Saad-Falcon for helpful feedback on this blogpost and the overall work!


  1. Mamba-2 recently followed this approach as well!
  2. There are also several combinations of linear recurrent models and attention layers -- H3, Jamba, Striped Hyena, Evo, Mistral, and Griffin -- resorting to $O(N)$ KV-caches, with smaller constants.
  3. Echo embeddings is another very exciting paper you should check out along this line! That work repeats context prior to extracting text embeddings from autoregressive LMs, leading to improvements over the case without repetition.
  4. Ilya's seminal seq-to-seq paper exploits the observation that data order matters in the context of encoder-decoder recurrent LMs, proposing to reverse the order of the tokens in the source text!