Efficient language models as arithmetic circuits

Simran Arora, Sabri Eyuboglu, Atri Rudra

This blogpost was written as a resource in conjunction with our talk at STOC 2024. It focuses on work done with the following full team: Michael Zhang, Aman Timalsina, Michael Poli, Isys Johnson, Dylan Zinsley, Silas Alberti, James Zou, Christopher Ré.


The Transformer architecture, which powers most language models (LMs) in industry, requires compute and memory to grow quickly with the input size during training and inference, precluding dreams of modeling inputs containing millions of lines of code or the 3.2Bn nucleotide pairs in genome sequences. Alternative architectures (e.g. Mamba, RWKV), which seek to reduce the compute and memory requirements relative to Transformers, are emerging at a rapid pace. However, deviating from the Transformer orthodoxy is a major risk: models are served to millions of users and take billions of dollars to train, and we have limited insight into how switching to these alternatives might impact quality.

Thus, we need frameworks to reason about which architectural changes matter and why. We’ve been writing a series of papers that use arithmetic circuits to make some sense of the landscape of architectures (Zoology, Based, JRT). Our general approach is to view each architecture as a different arithmetic circuit, a composition of sums ($+$) and products ($\times$), and to reason about the efficiency of different LMs in terms of the number of operations required to perform tasks of interest. This post describes our theoretical results.

Part 1: A single skill, associative recall, causes a lot of trouble

It is difficult to reason about language modeling quality given the breadth of skills involved (e.g. fact memorization, in-context learning) and the many different levels of abstraction for skills (e.g. some work evaluates an LM’s grammar skills while other work evaluates an LM’s ability to “emotionally self regulate”). We decided to train a bunch of LMs across architectures and do error analysis to see what we’d find!1

Error analysis across efficient LM architectures

Surprisingly, we found that a single skill called associative recall (AR) glaringly stood out in the error analysis: Transformers crushed it and the efficient LMs struggled. An LM performs AR if it recalls information that it has seen earlier in the input. For instance, LMs need to refer to functions or variables that already exist in the inputted codebase when generating subsequent lines of code. Our prior posts here and here dive more into our empirical observations. As a recap:

  1. Gated convolutions. We found that 80%+ of the language modeling quality difference between Transformers and the class of “gated convolution” LMs (built from gating and convolution operations, e.g., H3, Hyena, RWKV) is attributed to AR. We find these LMs perform abysmally on in-context learning (ICL) tasks that require performing recall (e.g., document question answering, information extraction).

Next, training other classes of LM architectures, including selective state space LMs (e.g. Mamba, GateLoop) and linear attention LMs (e.g. GLA, Based), at 1.3Bn parameters, 50Bn tokens, and evaluating on ICL tasks that require performing recall (FDA, SWDE, NQ), we found:

  1. Selective state space models. For 1k and 2k length documents, Mamba underperforms the Transformer by 42% and 55% on average across tasks.
  2. Linear attention. There is excitement around linear attention since these models extend the Pareto frontier of the efficiency-recall quality tradeoff space beyond the above architecture classes. Still, Based underperforms the Transformer by 27% and 41% at 1k and 2k length documents on average.

This blogpost focuses on how we theoretically explore what's going on with AR.

Formalizing the associative recall problem

Based on our analysis, we defined the MQAR problem: An LM needs to refer back to key-value token mappings (on the left) to answer questions (on the right). Here the answers are 4, 6, 1, 2.

$$\textrm{A 4 B 3 C 6 F 1 E 2} \rightarrow \textrm{A ? C ? F ? E ?}$$

Formally, we are given an input sequence $x = (x_0, \dots, x_{N-1})$ where each $x_i \in C$ is a token drawn from a vocabulary of size $c = |C|$. The task is to check, for every query $1 \leq i < N$, whether there exists a $0 \leq j < i$ such that $x_i \equiv x_j$. If so, output $x_{j+1}$.2
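To make the definition concrete, here is a minimal brute-force reference solver (an illustrative sketch we wrote for this post, not the code from our repo):

```python
# Brute-force reference solver for MQAR (illustrative sketch, not from our repo).
# Tokens are integers; a query at position i is answered by scanning backwards for
# the most recent earlier position j with x[j] == x[i] and returning x[j + 1].

def mqar_solve(x: list[int]) -> dict[int, int]:
    """Return {query position i: recalled value x[j + 1]} for every matched query."""
    answers = {}
    for i in range(1, len(x)):
        for j in range(i - 1, -1, -1):           # scan earlier positions
            if x[j] == x[i] and j + 1 < len(x):  # found a matching key
                answers[i] = x[j + 1]
                break
    return answers

# Example from the figure above: "A 4 B 3 C 6 F 1 E 2" followed by queries A, C, F, E.
vocab = {t: i for i, t in enumerate("A 4 B 3 C 6 F 1 E 2".split())}
seq = [vocab[t] for t in "A 4 B 3 C 6 F 1 E 2 A C F E".split()]
print(mqar_solve(seq))  # decoding the recalled token ids gives 4, 6, 1, 2
```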

Theoretically reasoning about open-ended problems like language modeling is complicated, but the simplified MQAR task helped us make some progress. When we train LMs on synthetic MQAR data (like the example above), the trends correlate with real language modeling gaps between architectures trained to large scale. Prior work has extensively explored AR over multiple decades (Graves et al., 2014, Ba et al., 2016, Schlag et al., 2021, etc.). We find MQAR quality correlates better with real language modeling quality than previously used formulations (discussed in more detail in our prior blogpost).


Part 2: Background

Here we'll briefly recap the definitions of different sequence mixers and intuition on why our efficient LMs might struggle to perform MQAR.

Sequence mixers are the part of an LM that determines how words in the input $x \in \mathbb{R}^{N \times d}$, for sequence length $N$ and model dimension $d$, interact with each other when computing the output representations $y \in \mathbb{R}^{N \times d}$ for the subsequent layer of a deep neural network.

First we'll start with the de-facto attention sequence mixer, and then describe two (closely related) categories of efficient LMs: linear attention and state space models. The LMs shown in Figure 1 above all fall in these two categories.

Attention models mix tokens as follows:

Attention takes $O(N^2 d)$ compute and linear $O(Nd)$ memory during training. During inference, remember that we generate one token at a time: each new query $Q_t$ needs to interact with all the prior $K_{0:t}$ and $V_{0:t}$, so if we cache those keys and values, the compute and memory to generate a token scale with $t$.
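Here is a simplified (hypothetical) sketch of single-head attention decoding with a KV cache, just to make the per-token cost visible; real implementations batch, mask, and use multiple heads:

```python
import torch

# Minimal single-head attention decoding step with a KV cache (illustrative sketch).
# Cost per generated token grows with t: q_t attends over all cached K_{0:t}, V_{0:t}.

def attention_decode_step(q_t, k_t, v_t, K_cache, V_cache):
    """q_t, k_t, v_t: (1, d). K_cache, V_cache: (t, d). Returns (y_t, new caches)."""
    K = torch.cat([K_cache, k_t], dim=0)        # (t + 1, d) -- memory grows with t
    V = torch.cat([V_cache, v_t], dim=0)
    scores = (q_t @ K.T) / K.shape[-1] ** 0.5   # (1, t + 1) -- compute grows with t
    attn = torch.softmax(scores, dim=-1)
    y_t = attn @ V                              # (1, d)
    return y_t, K, V

d = 16
K_cache, V_cache = torch.empty(0, d), torch.empty(0, d)
for t in range(5):                              # generate 5 tokens
    q, k, v = torch.randn(3, 1, d).unbind(0)
    y, K_cache, V_cache = attention_decode_step(q, k, v, K_cache, V_cache)
```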

Linear attention models, proposed by Katharopoulos et al., 2020, replace $\exp$ with a “feature map” $\phi: \mathbb{R}^{N \times d'} \rightarrow \mathbb{R}^{N \times D}$, with “feature dimension” $D$. As before, $d$ is the model dimension, but we’ll decouple $d'$ so $Q$ and $K$ can be projected to a different dimension than $V$ as needed.

Shown in the figure, this is now $O(NDd)$, linear in $N$, during training. During inference, by summing the prior terms $h_t = \sum_{i=0}^{t} \phi(K_i^T) V_i$ such that $h_t \in \mathbb{R}^{1 \times Dd}$, then multiplying this with $Q_t$, observe that our compute and memory to generate a token is constant $O(1)$ as $t$ grows.
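And here is the corresponding (simplified, hypothetical) sketch of recurrent linear attention decoding; `feature_map` below is just a placeholder for $\phi$:

```python
import torch

# Recurrent linear attention decoding (illustrative sketch): the state h has fixed
# size (D, d), so compute and memory per generated token stay O(1) as t grows.

def feature_map(x):
    # placeholder phi; any map R^{d'} -> R^{D} works here (e.g., Taylor features later on)
    return torch.relu(x)

def linear_attention_decode_step(q_t, k_t, v_t, h, z):
    """q_t, k_t: (1, d'); v_t: (1, d); h: (D, d) running sum of phi(k_i)^T v_i; z: (D,)."""
    phi_k = feature_map(k_t)                                # (1, D)
    h = h + phi_k.T @ v_t                                   # (D, d) -- state size independent of t
    z = z + phi_k.squeeze(0)                                # (D,)  -- running normalizer
    phi_q = feature_map(q_t)                                # (1, D)
    y_t = (phi_q @ h) / (phi_q @ z.unsqueeze(1) + 1e-6)     # (1, d)
    return y_t, h, z

d_prime, d, D = 16, 16, 16   # D == d_prime here because our placeholder phi is elementwise
h, z = torch.zeros(D, d), torch.zeros(D)
for t in range(5):
    q, k = torch.randn(2, 1, d_prime).unbind(0)
    v = torch.randn(1, d)
    y, h, z = linear_attention_decode_step(q, k, v, h, z)
```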

State space models. The mixing proceeds as follows:

And many SSM variants also use Hadamard product (elementwise multiplication between $2$ vectors) operations in addition to the convolution.

This ignores that SSMs are continuous objects; for a more precise definition see prior blogs. Observe that we again have a recurrent view ($O(1)$ compute and memory) we can use during inference and a sub-quadratic parallel convolution view ($O(dN \log N)$ with the FFT) we can use during training.

Intuition on why gated convolutions may struggle on MQAR. Before diving into the theory, let's make sure we have some high-level intuition for why efficient LMs might struggle. MQAR requires us to compare tokens to find matches, and then shift the answer forward for the next-token prediction.

Attention computes inner products between all tokens, making it easy to find matches! In the above figure, note that the convolution is more restricted, a Toeplitz matrix, so a token can only be compared to the token $4$ positions prior if all tokens are compared to the tokens $4$ positions prior to them as well. The comparisons are accomplished by using a convolution filter that shifts the sequence $4$ positions. Once we shift, the model can compare the shifted and original sequences to identify matches using the aforementioned Hadamard products.

In our convolution models, we apply a unique convolution filter for each of the $d$ dimensions, so we can support multiple shifts. BUT, to support all $O(N)$ possible shifts in a sequence of length $N$, each hidden dimension effectively needs to store information about all tokens to be able to perform the comparison.
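Here is a toy numpy sketch of that shift-and-compare mechanism (an illustration we constructed, with one-hot encodings and a single hard-coded shift):

```python
import numpy as np

# Toy sketch (hypothetical): with one-hot token encodings, a convolution filter that
# shifts the sequence by s positions plus a Hadamard product detects "the token s
# positions back equals the current token" -- but a separate filter is needed per shift.

def shift_filter(N, s):
    h = np.zeros(N)
    h[s] = 1.0                         # convolving with this filter delays the sequence by s
    return h

N, vocab = 8, 4
tokens = np.random.randint(0, vocab, size=N)
x = np.eye(vocab)[tokens]              # (N, vocab) one-hot encodings

s = 4
h = shift_filter(N, s)
shifted = np.stack([np.convolve(x[:, c], h)[:N] for c in range(vocab)], axis=1)  # (N, vocab)

match = (x * shifted).sum(axis=1)      # Hadamard product + sum: 1 iff token i == token i - s
print(tokens, match)                   # match[i] == 1 exactly when tokens[i] == tokens[i - 4]
```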

Okay so we have some intuition on why convolutions might struggle -- the filters are input-independent and it's hard to support all the shifts with small model dimension. Attention's input-dependent mixing is helpful to find the matching tokens. We walk through this intuition more slowly in our prior blogpost.


Part 3: Unifying the messy architectural landscape using arithmetic circuits

Above, we walked through a toy construction of how LMs might solve MQAR, but there's a messy landscape of LMs with subtle modifications from architecture to architecture. We use theory to reason about the landscape systematically.

Towards this, we first observe that our efficient LM classes (gated convolutions, linear attentions) can be written as polynomials in terms of the input variables! The models we consider are maps $M:\mathbb{R}^{N\times d}\to \mathbb{R}^{N\times d}$ and each of the $Nd$ output values is a polynomial in the $Nd$ input variables $x[i,j]$ for $(i,j)\in [N]\times [d]$.

Polynomial view of the primitive operations in efficient language models

Gated convolutions. Let’s walk through how Hadamard product and convolution operations are expressed as polynomial operations. Suppose we have two vectors:

$$x = [x_0, x_1, \dots, x_{N-1}]$$

$$v = [v_0, v_1, \dots, v_{N-1}]$$

First, recall that a Hadamard product between vectors $x$ and $v$ computes: $y = [x_0 v_0, x_1 v_1, \dots, x_{N-1} v_{N-1}]$. Note that each of the $N$ outputs $x_i \cdot v_i$ is a degree-$2$ polynomial in the variables $x_i$ and $v_i$.

Next, a convolution between vectors $x$ and $v$ computes (for $0\le i\le 2N-2$):

$$(x \ast v)[i] = \sum_{j=0}^{i} x_{i-j}v_j.$$

We can see that each of the $2N-1$ outputs is a polynomial in the $2N$ input variables in $x$ and $v$. In models that use convolutions, typically $v$ is fixed (it is a learned convolution filter), so the convolution becomes a linear operator, i.e., for each of the output positions we compute a polynomial of degree at most $1$.
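As a quick, illustrative sanity check of these degree claims (using sympy; this is our own toy snippet, not code from the papers):

```python
import sympy as sp

# Outputs of a Hadamard product are degree-2 polynomials in the inputs, while a
# convolution with a *fixed* filter is linear (degree 1) in its input.

N = 3
x = sp.symbols(f"x0:{N}")              # input variables x0, x1, x2
v = sp.symbols(f"v0:{N}")              # second input (or filter) variables

hadamard = [x[i] * v[i] for i in range(N)]
conv = [sum(x[i - j] * v[j] for j in range(i + 1) if j < N and i - j < N)
        for i in range(2 * N - 1)]

print([sp.Poly(p, *x, *v).total_degree() for p in hadamard])   # [2, 2, 2]
print([sp.Poly(p, *x, *v).total_degree() for p in conv])       # [2, 2, 2, 2, 2]
fixed_v = {v[j]: j + 1 for j in range(N)}                       # fix the filter to constants
print([sp.Poly(p.subs(fixed_v), *x).total_degree() for p in conv])  # degree drops to [1, 1, 1, 1, 1]
```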

Linear attention models. We can express each entry $(i, j) \in [N] \times [d]$ of the linear attention output as a polynomial as follows. We let $D$ be the feature dimension and $W_Q, W_K \in \mathbb{R}^{d \times D}$ and $W_V \in \mathbb{R}^{d \times d}$ be our embedding projections.

$$h_{i,j}(x) = \sum_{m \in [D],\, n \in [N]} (W_Q x)[i, m] \cdot (W_K x)[n, m] \cdot (W_V x)[n, j]$$

We can see that this computes a degree-$3$ polynomial at each position (three variables from the input $x$ are being multiplied).

!! Note !! One limitation of the arithmetic circuit view is that we can’t exactly represent non-polynomial activation functions, like the $\exp$ in attention or Mamba. However, foreshadowing a bit, we will use function approximation theory to approximate $\exp$ with (small degree) Taylor polynomials.

Measuring the efficiency of models using arithmetic circuit complexity

All gated convolution models (Dauphin et al., 2016, H3, BiGS, Hyena, RWKV, M2, etc.) and linear attention models (with polynomial feature maps) are polynomials, which means they can be directly computed by arithmetic circuits! Efficient ML is all about the scaling needed to obtain some desired quality, and there are many ways to define efficiency (e.g., data, compute, parameters); here we use the complexity of arithmetic circuits as our measure of efficiency. Arithmetic circuits provide a natural and powerful computational model for studying the complexity of polynomials, with deep connections to fundamental problems in theoretical computer science. The complexity of a polynomial is determined by the size of the smallest arithmetic circuit that computes the polynomial.

What is an arithmetic circuit? An arithmetic circuit is a directed acyclic graph that represents a polynomial function of the circuit’s input variables. We can express our polynomial (gated convolution) as an arithmetic circuit with our input $x$ and model parameters $\theta$ as inputs. Each node in the circuit is an addition ($+$) or multiplication ($\times$) operation between inputs.

As an example, let’s express the Discrete Fourier Transform (DFT) as a linear arithmetic circuit for an input $v$ of length $n=4$ (Dao et al., 2020). Linear means the circuit uses just linear gates (i.e. gates that on inputs $x$ and $y$ compute $ax+by$ for constants $a$ and $b$), contrasting general arithmetic circuits, which also have multiplication gates (i.e. on inputs $x$ and $y$ compute $xy$).

An arithmetic circuit is an $(n, s, \Delta, \omega)$-circuit if it takes $n$ variables, has size $s$, depth at most $\Delta$, and width $\omega$. The size is the number of nodes in the circuit, the depth is the length of the longest path between input and output nodes, and the width is the maximum number of wires that cross any “horizontal cut” in the circuit. If we implement the DFT with the naive $O(n^2)$-time algorithm, then for the order-4 DFT matrix we have 4 rows (4 inner products), and each inner product needs 3 linear gates, for 12 linear gates overall. But here, we show the FFT algorithm, which uses fewer (8) linear gates.
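For concreteness, here is a sketch of the size-4 FFT written out as that 8-gate linear circuit (the gate counts in the comments are our own tally for this toy example):

```python
import numpy as np

# Size-4 DFT as a linear arithmetic circuit (illustrative sketch of the FFT butterfly):
# two stages of 4 linear gates each = 8 gates, vs. 12 gates for the naive
# matrix-vector product (4 inner products x 3 gates each).

def fft4(v):
    v0, v1, v2, v3 = v
    # Stage 1: size-2 DFTs of the even and odd halves (4 linear gates)
    a0, a1 = v0 + v2, v0 - v2
    b0, b1 = v1 + v3, v1 - v3
    # Stage 2: apply the twiddle factor (-1j) and combine (4 linear gates)
    return np.array([a0 + b0, a1 - 1j * b1, a0 - b0, a1 + 1j * b1])

v = np.random.randn(4)
assert np.allclose(fft4(v), np.fft.fft(v))
```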

Quick notes on arithmetic circuits, polynomials, and structured matrices

Feel free to skip this section if you're familiar with these concepts

Again, when we talk about arithmetic circuits being equivalent to polynomials, we are talking about outputting $Nd$ values $y[i,j]$ by computing polynomials with respect to the $Nd$ input variables (e.g., $x[i,j]$) in $x \in \mathbb{R}^{N \times d}$. Each polynomial has a single output value, so the model/circuit computes $Nd$ polynomials, which, as mentioned before, are each of degree at most $2$ in the case of convolutions plus Hadamard product, and degree $3$ for linear attention.

In our prior tutorial blogpost on polynomials and convolutions, we thought about polynomials in the coefficient space, whereas in all our discussion in this blogpost, our polynomials are in evaluation space. Why? Well, for the convolution case discussed in our prior blogpost, we were picking the evaluation points, namely we chose the complex roots of unity as our evaluation points, permitting us to go back and forth between coefficient and evaluation spaces (we could take the DFT of the coefficient vector in order to evaluate the convolution (polynomial)).

Why did we think about coefficients in the Monarch Mixer work? Well, we cared about causality, which is fundamentally about the positions of different tokens governing which other tokens they can interact with in the context -- so, we chose a polynomial basis to help us analyze that property.

Reminder of the polynomial basis considered in our prior work

Let’s also write $[x_0, x_1, \dots, x_{N-1}]$ as the coefficients of a polynomial $f(Z)$ and $[v_0, v_1, \dots, v_{N-1}]$ as the coefficients of $g(Z)$:

$$f(Z) = x_0 + x_1 Z + \dots + x_{N-1}Z^{N-1}$$

$$g(Z) = v_0 + v_1 Z + \dots + v_{N-1}Z^{N-1}$$

First, recall that a Hadamard product between vectors $x$ and $v$ computes:

$$y = [x_0 v_0, x_1 v_1, \dots, x_{N-1} v_{N-1}]$$

By definition, the Hadamard product of two polynomials is obtained by element-wise multiplying the coefficients of the polynomials. I.e.:

$$f(Z) \odot g(Z) = x_0 v_0 + x_1 v_1 Z + \dots + x_{N-1} v_{N-1} Z^{N-1}$$

The ZZ powers remain unaffected because they are part of the polynomial's structure (corresponding to the positions in the sequence), not the elements being operated on. The operation preserves the polynomial’s form, altering only the coefficients.

Second, a convolution between vectors $x$ and $v$ computes (for $0\le i\le 2N-2$):

$$(x \ast v)[i] = \sum_{j=0}^{i} x_{i-j}v_j$$

This is equivalent to multiplying the two polynomials whose coefficients are provided by $x$ and $v$ respectively: $h(Z) = f(Z)g(Z)$,

or

$$h(Z) = ( x_0 + x_1 Z + \dots + x_{N-1}Z^{N-1} ) ( v_0 + v_1 Z + \dots + v_{N-1}Z^{N-1} ),$$

which is the same as

$$h(Z) = \sum_{i=0}^{2N - 2} w_i Z^i.$$

In the above, each coefficient is:

$$w_i = (x \ast v)[i] = \sum_{j=0}^{i} x_{i-j}v_{j}.$$

Meanwhile, arithmetic circuits need to work for all potential evaluation points -- we can't pick our evaluation points. Taking in the $Nd$ input variables, arithmetic circuits can use arbitrary computations to compute the $Nd$ outputs (sharing computation across the $Nd$ polynomials within the circuit)!

Another way to think about things. We can move between coefficient and evaluation space in general: evaluating a polynomial at $N$ points is equal to multiplying its coefficient vector by a structured Vandermonde matrix built from the evaluation points! However, for general Vandermonde matrices, this conversion takes $O(N \log^2 N)$ arithmetic operations -- this is a bad way to express the circuits since it's a large number of operations and could result in numerical stability issues. In general, Vandermonde matrices are ill-conditioned, but the DFT, a special case of a Vandermonde matrix, has condition number $1$ (i.e. is very well conditioned).

Hopefully this helps you think about why we choose different polynomial bases across these works. Alright, let's get back to the main story!

Unifying efficient LM architectures via the polynomial view

This section unifies architectures using the polynomial view.

Gated convolutions. Using the polynomial view, we first prove that any gated convolution model (including H3, BiGS, Hyena, RWKV, M2, etc.) can be simulated by a single canonical representation, BaseConv, within tight (poly)logarithmic factors in model parameters and depth. Given an input $u \in \mathbb{R}^{N \times d}$, the BaseConv operator on the right is defined as:

where $h \in \mathbb{R}^{N \times d}$ contains learnable filters, $W \in \mathbb{R}^{d \times d}$ is a linear projection, and $b_1, b_2 \in \mathbb{R}^{N \times d}$ are bias terms. The $\odot$ represents a Hadamard product and the $\ast$ represents a convolution. Each BaseConv layer is a polynomial of degree at most $2$: the left and right sides are each linear (degree $1$) in the input $u$, and taking their Hadamard product multiplies two linear polynomials together.
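For reference, here is a minimal PyTorch sketch of one BaseConv layer following the definition above (an illustrative implementation we wrote for this post, not our optimized one):

```python
import torch
import torch.nn as nn

class BaseConv(nn.Module):
    """Minimal sketch of one BaseConv layer: y = (u @ W + b1) * (h conv u + b2).

    The linear map acts on the channel dimension d; the (causal, per-channel) long
    convolution acts on the sequence dimension N; the Hadamard product gates them.
    """

    def __init__(self, N: int, d: int):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d, d) / d ** 0.5)   # linear projection
        self.b1 = nn.Parameter(torch.zeros(N, d))             # bias terms in R^{N x d}
        self.b2 = nn.Parameter(torch.zeros(N, d))
        self.h = nn.Parameter(torch.randn(N, d) / N ** 0.5)   # learnable conv filters

    def forward(self, u: torch.Tensor) -> torch.Tensor:       # u: (B, N, d)
        N = u.shape[1]
        # Causal long convolution along the sequence dimension via FFT, per channel.
        u_f = torch.fft.rfft(u, n=2 * N, dim=1)
        h_f = torch.fft.rfft(self.h, n=2 * N, dim=0)
        conv = torch.fft.irfft(u_f * h_f, n=2 * N, dim=1)[:, :N, :]
        return (u @ self.W + self.b1) * (conv + self.b2)       # Hadamard product (gating)

layer = BaseConv(N=32, d=16)
y = layer(torch.randn(4, 32, 16))   # (4, 32, 16)
```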

Linear attentions. We prove that an $(N, O(\log^2(Nd)), d, O(N(d + D)), O(\max(d, D)))$-BaseConv (using $(n, s, \Delta, \omega)$-circuit notation) computes the output of a single linear attention layer with feature dimension $D$. This points to the relative efficiency of linear attention over gated convolutions: we need multiple BaseConv layers to represent one linear attention layer.

Moreover, we can convert any arithmetic circuit into a BaseConv model with only a poly-log loss in parameters:

Theorem (Arithmetic circuit equivalency). For every low-depth arithmetic circuit of size $s$ and depth $\Delta$ that takes $u \in \mathbb{R}^{N \times d}$ as input, there is an equivalent BaseConv operator that uses $\tilde{O}(s\Delta)$ parameters and $\tilde{O}(\Delta)$ layers.3

Here $\tilde{O}(\cdot)$ hides poly-logarithmic factors. For specific instantiations of gated convolutions (e.g., Hyena models), we show in our paper that we can even get rid of the poly-log factor.

Why is this theorem exciting to us? Let's walk through a few implications of this theorem:

  1. Arithmetic circuits are non-differentiable objects. However, BaseConv is differentiable, allowing us to reason about how things actually learn!

  2. Arithmetic circuits can do arbitrary computation. Meanwhile, notice that the left-hand-side (linear map) portion of BaseConv only operates on the channel dimension $d$ of the $N \times d$ input and the right-hand-side (convolution) only operates on the sequence dimension $N$. Despite being relatively restricted, BaseConv does not lose any power!

  3. The theorem generalizes prior results. A prior paper from our group, Kaleidoscope, shows results for linear arithmetic circuits and Butterfly matrices; here we extend to general arithmetic circuits.

Expand for an overview of how we prove the above theorem.

First, recall that an arithmetic circuit $C$ is layered, with each layer being only $\times$ or $+$ gates. The $\times$ gates we handle via the Hadamard product $\odot$ operation in BaseConv. So in some sense all that remains is to handle the $+$ gates – or more specifically, we need the capability to implement linear maps using BaseConv.

To handle linear maps, we use a result from [6] that shows that any linear map can be implemented as a product of $\tilde{O}(1)$ Butterfly (factor) matrices (a classical type of structured matrix that e.g. arises in the FFT). These matrices can be represented as a sum of shifts of 3 diagonal matrices, i.e. they are of the form $S^i\cdot D_1+D_2\cdot S^i+D_3$, where $S$ is the shift matrix and $D_1, D_2, D_3$ are diagonal matrices. Each of these three matrices can be implemented as $x \odot b$ for a suitably defined bias matrix $b$.4 There are other details that need to be handled (mainly being able to “remember” part of the input when applying a layer of BaseConv) but those can be done (see our paper [1] for the details).

We looked at the landscape of efficient LMs and distilled it to BaseConv. We did not lose any power (within poly-logarithmic factors) with this simplification, and we can now reason about classes of architectures more easily! Let's do this next.


Part 4: Solving MQAR with different classes of efficient architectures

BaseConv is exciting: we can reason about the simple object as a proxy for the messy zoo of efficient LMs.

Concretely, we can propose an arithmetic circuit that solves the MQAR task, draw an equivalence between the circuit and BaseConv, and derive an upper bound on the depth of BaseConv needed to represent the circuit. For instance, a naive algorithm for MQAR would use a sequential pass over the sequence and put the keys and values into a hash map, then for queries, search the hash map.

Unfortunately, converting this algorithm to a circuit requires a BaseConv of massive depth, $\Omega(N)$ in the sequence length $N$.

We also came up with a more efficient parallel arithmetic circuit by drawing inspiration from the parallel binary search algorithm, which has $O(\log N)$ depth (Algorithm 10). This allows us to obtain a tighter upper bound for BaseConv and MQAR. Given the equivalence between BaseConv and arithmetic circuits, we could conclude that a BaseConv model with $\tilde{O}(Nd)$ parameters and $\tilde{O}(1)$ layers can solve MQAR!

So this is awesome, we have this single powerful BaseConv tool!! Why don’t we just do everything with it? It’s easy to reason about! The issue is: we're still losing poly-log factors on the depth in our bound (recall $\tilde{O}(\cdot)$ hides poly-log factors) :( While poly-log is relatively tight in theory, in ML architectures we pay a lot if the depth is too large: it makes things hard to optimize. We really want to know if we can solve MQAR in $O(1)$ layers, like we can with attention.

Can we do better?

Lower bound for BaseConv and MQAR

Unfortunately, we find we can't get $O(1)$ layers with BaseConv :( We show a lower bound on the depth that depends on $N$.

Recapping the situation for one-hot encodings. In our toy constructions in Part 2, we had used one-hot encodings for each token in the sequence. We showed both Attention and BaseConv could solve MQAR with these encodings, but the size of the one-hot encoding is $d=N$, which is much too large.

Expand to recap attention and BaseConv's one-hot encoding solutions again

Here we describe the Part 2 intuition with math notation :)

We will assume that our input matrix $x\in\{0,1\}^{N\times d}$ uses 1-hot encoding, i.e. each row is a standard basis vector in $\{0,1\}^{d}$ (note that this implies $d=N$). Note that the attention query-key inner product, $A[i,j]=\langle Q[i,:],K[j,:]\rangle$, is $1$ if $Q[i,:]=K[j,:]$ and $0$ otherwise (here we crucially used the fact that each row is 1-hot encoded). $A$ tells us the positions of matching tokens in MQAR. Given the matching MQAR-key tokens, we want to output the corresponding MQAR-value, so we can let our attention values be $V = S \cdot x$, where $S$ is a matrix that shifts entries forward by one position. We need to use “positional encodings” to implement the shift. We can then compute $A \cdot V$ to solve the MQAR problem. Visualize this construction.
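Here is a small numpy sketch of that construction (illustrative; we hard-code the shift matrix that positional encodings would implement):

```python
import numpy as np

# Sketch of the one-hot attention construction for MQAR: A = Q K^T finds matching
# tokens, V = S x holds the *next* token's encoding at each position, and A V reads it out.
# (Here Q = K = x, and the shift matrix S stands in for what positional encodings implement.)

vocab = {t: i for i, t in enumerate("A 4 B 3 C 6 F 1 E 2".split())}
tokens = "A 4 B 3 C 6 F 1 E 2 A C F E".split()
N, d = len(tokens), len(vocab)
x = np.eye(d)[[vocab[t] for t in tokens]]   # (N, d) one-hot rows over the vocab

A = np.tril(x @ x.T, k=-1)                  # A[i, j] = 1 iff token j (j < i) matches token i
S = np.eye(N, k=1)                          # shift: (S @ x)[j] = x[j + 1]
V = S @ x                                   # value at position j is the next token's encoding
out = A @ V                                 # row i is the one-hot of the recalled value

inv = {i: t for t, i in vocab.items()}
for i in range(N):
    if out[i].sum() > 0:
        print(tokens[i], "->", inv[int(out[i].argmax())])   # A -> 4, C -> 6, F -> 1, E -> 2
```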

Moving on to BaseConv, note that all of the above operations for attention are polynomial operations, yielding an arithmetic circuit of size $O(N^3)$ and depth $O(1)$ that solves the MQAR problem. (The circuit size is $O(N^2 d)$, and here $d=N$, so $O(N^3)$ overall.) If we use our arithmetic circuit equivalency theorem, then we can get an $\tilde{O}(1)$-layer BaseConv implementation with $\tilde{O}(N^3)$ parameters. In other words, for 1-hot-encoded inputs there is no difference between Attention and BaseConv: both can solve MQAR in $O(1)$ layers! However, 1-hot encodings require the model dimension to scale with the sequence length $N$!

We want more compressed representations in ML... Things get interesting if we use a more compressed representation. For example, consider the case when $d=\log_2{|C|}$ for vocab $C$. Here we can still use Attention to solve MQAR in $O(1)$ layers. This is because $Q[i,:]=K[j,:]$ (for our $QK^T$ inner products) if and only if $\langle Q[i,:],K[j,:]\rangle=d$. The latter check is obtained with a “threshold” functionality, provided by a linear map + ReLU, Softmax, or other non-polynomial activation functions. The rest of the Attention construction remains the same.

However, in BaseConv, it is not possible to implement the above threshold “trick” with $O(1)$ layers. In fact, we prove a lower bound in the even simpler case of checking whether two vectors $y, x \in \{0,1\}^{\log{|C|}}$ are equal (when $x, y \in C$). The basic takeaway from our proof ends up being that to represent equality exactly, we need a polynomial of degree $d$, but an $O(1)$-layer BaseConv can only represent constant degree polynomials.

Note that this argument was tied to a specific encoding of $x$, but we are also able to show that regardless of how $x$ is encoded, with $d\le 2^{(\log{N})^{1-\epsilon}}$, a BaseConv model requires the number of layers to solve MQAR to scale with $N$.

Expand for an overview of the proof of this result.

This proof invokes the index problem. The index problem has two agents, Alice and Bob, where Alice has a string $x \in \{0, 1\}^n$ and Bob has an index $i \in [n]$. The goal for the players is to output the $i^{th}$ entry: $x_i$. We also require the communication to be one-way: only Alice is allowed to send a single message to Bob and Bob needs to output the answer. We use a reduction from the index problem to MQAR (described in Part 5).

  1. Properties of BaseConv. First we note that an $L$-layer BaseConv model that solves AR on input $x\in\mathbb{R}^{N\times d}$ is equivalent to a polynomial in $Nd$ variables (one corresponding to each entry of the matrix $x$) of degree at most $2^L$.5 Recall that each BaseConv layer is the same as a polynomial of degree at most $2$ -- composing $L$ such layers gives us a polynomial of degree at most $2^L$.

  2. We next invoke the index problem. Alice has the input $a=(1,x_1,\dots,n,x_n)$ (encoded as a matrix in $\mathbb{R}^{(N-1)\times d}$) and Bob has the input $b= i \in [n]$ (encoded as a vector in $\mathbb{R}^d$). If there exists an $L$-layer BaseConv model that solves AR, then there exists a polynomial $P(a,b)$ of degree at most $2^L$ that outputs $x_i$. Alice knows everything except $b$. She can substitute her values $a$ to get a polynomial $Q_a(Z)=P(a,Z)$, where $Z$ has $d$ variables as placeholders for Bob’s input, and still $\deg(Q_a)\le \deg(P)\le 2^L$. Alice sends Bob the coefficients of $Q_a(Z)$ and Bob outputs $Q_a(b)=P(a,b)=x_i$.

  3. Lower bound. Note that by the lower bound for the index problem, the number of bits Alice needs to encode all the coefficients is $\Omega(N)$. Since $Q_a(Z)$ is a $d$-variate polynomial of degree at most $2^L$, it has at most $d^{2^L}$ coefficients. If $B$ is the maximum number of bits we need to represent any of these coefficients, we have: $B \cdot d^{2^L} \geq \Omega(N)$.

    Using the fact that all of the original parameters in the BaseConv model use $O(\log{N})$ bits (as well as the fact that $d\le 2^{(\log{N})^{1-\epsilon}}$), a little bit of algebraic manipulation (sketched below) gives us our claimed lower bound of

$$L\ge \Omega\left(\epsilon\cdot\log\log{N}\right).$$
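For intuition, here is a rough version of that manipulation, suppressing the exact contribution of $B$ (a sketch, not the precise argument from the paper):

```latex
% Rough sketch of the manipulation (suppressing the lower-order contribution of B):
\begin{align*}
  B \cdot d^{2^L} \ge \Omega(N)
    &\;\Longrightarrow\; 2^L \log d \ge \log N - \log B = \Omega(\log N), \\
  \text{and } \log d \le (\log N)^{1-\epsilon}
    &\;\Longrightarrow\; 2^L \ge \Omega\!\big((\log N)^{\epsilon}\big)
     \;\Longrightarrow\; L \ge \Omega(\epsilon \cdot \log\log N).
\end{align*}
```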

Linear attention and MQAR

We have seen that Transformers can represent MQAR with constant model width and layers, independent of the sequence length $N$, while the scaling needed for gated convolutions to solve MQAR is undesirable. Can we mimic Transformers without the quadratic scaling in sequence length? Enter linear attention models! These retain input-dependent mixing, making it easy to compare tokens in the sequence (unlike convolutions), and give nice efficiency properties (sub-quadratic training and constant per-token inference).

We love polynomials. In the spirit of our polynomial view of efficient architectures, we consider a natural option for the feature map: a polynomial approximation to attention's exponential function, following Zhang et al., 2024. Recall the Taylor approximation of $\exp(x)$:

$$\exp(x) = 1 + x + \frac{x^2}{2!} + \dots + \frac{x^t}{t!} + \dots$$

Let’s let $\otimes$ represent an outer product operation and let the linear attention feature map, $\phi(\cdot): \mathbb{R}^{N \times d'} \rightarrow \mathbb{R}^{N \times D}$, be the Taylor polynomial for $\exp$ following Zhang et al., 2024:

$$\phi(q) = \mathrm{Concatenation}(1, q, (q \otimes q)/2!, \dots)$$

$$\phi(k) = \mathrm{Concatenation}(1, k, (k \otimes k)/2!, \dots)$$

Multiplying $\phi(q)$ and $\phi(k)$, we get for $q, k \in \mathbb{R}^{N \times d'}$: $\exp(qk) \approx \phi(q)\phi(k) = 1 + qk + (q \otimes q)(k \otimes k)/2! + \dots$

While we would need infinitely many terms (i.e., an infinite-dimensional tensor and amount of GPU memory consumption) to exactly represent attention’s $\exp(qk)$, we find that roughly order-$2$ (feature dimension $D = 1 + d' + d'^2$) seems to be empirically effective given the range of values that $q$ and $k$ fall into in large language modeling experiments.
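As an illustrative sketch (following the formulas above, not our exact kernel), the second-order Taylor feature map and the resulting approximation to $\exp(q \cdot k)$ look like this:

```python
import torch

# Second-order Taylor feature map (sketch): phi(x) = [1, x, flatten(x outer x)/sqrt(2!)],
# so that phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2! ~= exp(q.k).
# Feature dimension is D = 1 + d' + d'^2.

def taylor_feature_map(x: torch.Tensor) -> torch.Tensor:
    """x: (..., d') -> (..., 1 + d' + d'^2)."""
    ones = torch.ones(*x.shape[:-1], 1)
    second = torch.einsum("...i,...j->...ij", x, x).flatten(-2) / (2 ** 0.5)
    return torch.cat([ones, x, second], dim=-1)

d_prime = 8
q, k = torch.randn(2, d_prime) * 0.3   # small-ish values, as after normalization in practice
approx = taylor_feature_map(q) @ taylor_feature_map(k)
print(approx.item(), torch.exp(q @ k).item())   # close for moderate values of q.k
```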

Varying the value of $d'$ changes the amount of memory (the size of $h_t$) the model uses. While arithmetic circuit complexity typically focuses on the total size and depth of the circuit computing the polynomials, we also care about the amount of external memory a model uses to produce an answer, from a hardware-efficiency perspective. Keeping track of the memory starts to fall into the communication complexity model of computation, bringing us to our final section.


Part 5: Memory-limited learning

Efficient LMs seek to reduce the memory requirement relative to Transformers, which require massive “KV”-caches that grow with the input length.

Expand to recap the recurrent memory use by architecture!

Let’s briefly review the total state size (GPU memory consumption) $s$ used by different architectures as they process an $x \in \mathbb{R}^{N \times d}$ input (see Appendix E.2 for more discussion).

  • Attention: the size of the KV cache ($h_{K}$ and $h_{V}$) described above is $s = 2 \times d \times N$.
  • Linear attention: the state size is determined by $h_t$ as discussed above: $s = dD$, for feature dimension $D$ and model dimension $d$.
  • Gated convolution (H3): the state is determined by the state dimension $d_{state}$: $s = d \times d_{state}$.

Specifically, note that the efficient Transformer alternatives seek to use external memory that’s independent of the sequence length $N$.
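To make the comparison concrete, here is a tiny sketch with made-up but representative hyperparameters; only attention's state grows with $N$:

```python
# Recurrent state size (number of values cached at inference), per the formulas above.
# Hyperparameters are illustrative; only attention's state grows with sequence length N.

def attention_state(N, d):            return 2 * d * N      # KV cache
def linear_attention_state(d, D):     return d * D          # h_t
def gated_conv_state(d, d_state):     return d * d_state    # e.g., H3

d, D, d_state = 1024, 256, 64
for N in (1_024, 8_192, 65_536):
    print(N, attention_state(N, d), linear_attention_state(d, D), gated_conv_state(d, d_state))
```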

Autoregressive language modeling should remind you of one-pass streaming algorithms. In this vein, our theoretical analysis of the memory-efficient LMs invokes a prior lower bound for the index problem in the one-pass streaming setting: the one-way randomized communication complexity of the index problem for an $n$-length input bit string is $\Omega(n)$.

Index problem. The index problem has two agents, Alice and Bob, where Alice has a string $x \in \{0, 1\}^n$ and Bob has an index $i \in [n]$. The goal for the players is to output the $i^{th}$ entry: $x_i$. We also require the communication to be one-way: only Alice is allowed to send a single message to Bob and Bob needs to output the answer.

There is a pretty straightforward reduction from indexing to MQAR (and in fact to the simpler version AR, where the query comes at the end). Specifically, given the input $x\in\{0,1\}^n$ and $i\in [n]$, the corresponding input to AR is (a sequence; index pair): $1,x_1,\dots,j,x_j,\dots,n,x_n;i$. Note that if a model solves this AR instance, then it should match $i$ to the earlier occurrence of $i$ and output $x_i$.
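A tiny sketch of that reduction (with hypothetical token names; we tag positions and bits so they live in disjoint vocabularies):

```python
# Sketch of the reduction from the index problem to AR: Alice's bit string becomes
# interleaved (position, bit) pairs, and Bob's index i becomes the trailing query.

def index_to_ar(bits: list[int], i: int) -> list:
    seq = []
    for j, b in enumerate(bits, start=1):
        seq += [("pos", j), ("bit", b)]     # key token j followed by value token x_j
    return seq + [("pos", i)]               # the query: recall the value following key i

print(index_to_ar([1, 0, 1, 1], i=3))
# [('pos', 1), ('bit', 1), ('pos', 2), ('bit', 0), ('pos', 3), ('bit', 1),
#  ('pos', 4), ('bit', 1), ('pos', 3)]
# Any model that solves AR on this sequence must output x_3 = 1.
```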

Recurrent models and MQAR. We can now conclude that any recurrent model $\mathcal{M}$ that solves MQAR requires a state size of at least $\Omega(N)$ bits. To show this, suppose Alice runs $\mathcal{M}$ on the portion of the AR input above that she has (i.e. $1,x_1,\dots,j,x_j,\dots,n,x_n$) and sends the final state $h_{N-1}$ to Bob, who, using the state $h_{N-1}$ and his input $i$, can run $\mathcal{M}$ to output the value at position $i$. If the state were smaller, this protocol would contradict the lower bound for the index problem, so we conclude $\Omega(N)$ bits are required.6

Maybe all hope is not lost! Our subsequent work designs multi-pass recurrent models, which can use less memory than autoregressive models depending on the order of information in the input sequences.


Part 6: Tying the theory to experiments!

We evaluate several efficient LMs, from the architecture classes discussed above, on the MQAR problem. We vary the recurrent state size, or amount of memory each architecture uses during inference, on the $x$-axis.

We can see the theoretical scaling results reflected in the figure below.

  1. Attention baseline. Sliding window attention is a natural input-dependent and fixed-memory architecture. As we vary the window size (memory use), it faces a stark quality tradeoff.
  2. Input-independent convolutions. Gated convolutions (H3, Hyena) fall below the (sliding window) attention baseline.
  3. Input-dependent mixing. Input-dependent mixers (Based linear attention, Mamba’s selective SSMs) expand the Pareto frontier relative to attention!
  4. Memory limits. All these recurrent LMs face a memory vs. recall-quality tradeoff. Some models are better than others at using their allotted memory.

You can find more discussion of our empirical results in our prior blogposts and try out our MQAR synthetic here: https://github.com/HazyResearch/zoology. We used our insights to design new LM architectures.

Postscript: Related Works!

If you're interested in this line of work, we've compiled a quick reading list of works on complexity and language models to check out. This is a relatively biased list, please reach out if you have more pointers!

Expand for the related works discussion

Works related to arithmetic circuits and language models:

State space models:

Connections to other “traditional” computation models There is a rapidly growing body of work connecting language model architectures to well studied models of computation (other than arithmetic circuits, e.g. boolean circuits) mainly with the goal of showing limitations of these models. Since this body of work is slightly orthogonal to this blog post, we’ll just give a biased sample below:

Takeaways and future questions

We’ve had a ton of fun playing with MQAR and trying to build efficient LMs that solve it more efficiently over the past while. Overall, we hope the polynomial view of efficient LMs helps you reason about key ingredients like input-dependent sequence mixing and managing the quality-efficiency tradeoffs around recurrent state size.

A couple of questions that might be interesting to think about:

  1. What are other “toy tasks”, like MQAR, that both are empirically meaningful and enable theoretical analysis? Perhaps not only for the sequence mixer. Work from [11] shows that both Transformers and SSMs struggle with certain capabilities like state tracking. Can we better understand when state tracking is important in language modeling with simple synthetics?
  2. We also don’t use the full power of BaseConv to prove our results (we largely instantiate a few types of convolution kernels, like shift kernels). What are the minimal sufficient representations for the capabilities we need?
  3. Our results show that there exist BaseConv models that exactly solve MQAR. However, we do not have any proof that running gradient descent on MQAR will converge to a BaseConv model that can solve MQAR. It would be very nice to prove that this is indeed the case. As a first step, it would be nice to present any provably fast learning algorithm that, given data from MQAR, learns a BaseConv model that solves the problem.

Please feel free to reach out to us: simran@cs.stanford.edu and eyuboglu@stanford.edu.


  1. One view is to ignore these fine-grained skills and equate quality to perplexity. Efficient LMs have increasingly closed the perplexity gaps to Transformers! But, we find it could be useful to look at more fine-grained skills.
  2. Perhaps of interest to theory folks, we also prove that the hardness of MQAR reduces to the hardness of set disjointness, a quintessential and decades-old problem in complexity theory.
  3. We note that the above result is most interesting when $\Delta=\tilde{O}(1)$ -- luckily for us, many of the well known fast transforms (like the FFT) have poly-log depth arithmetic circuits.
  4. A subtle but crucial point here is that the linear maps that we want to implement are defined on matrices of length $O(Nd)$, but BaseConv does not have the ability to implement arbitrary linear maps over matrices in one layer (we are only allowed linear maps on the rows or columns of these matrices re-packaged as $O(N)\times O(d)$ matrices). However, since the matrices we need to compute maps on here are very structured (shift, diagonal matrices), we can use the bias matrices and Hadamard product to implement them.
  5. As discussed previously, this follows from the definition of BaseConv since each term in the parentheses in the definition of BaseConv -- $(u\cdot W+b_1)\odot(h\ast u+ b_2)$ -- is a linear (degree $1$) map (since $W, b_1, h$ and $b_2$ are fixed). Multiplying two linear maps gives us a polynomial of degree at most $2$.
  6. We note that the above reduction style is standard in data streaming lower bounds, and indeed the connection to communication complexity lower bounds has also been exploited in other works to prove similar lower bounds (Schlag et al., 2021, Jelassi et al., 2024, Bhattamishra et al., 2024, Zubić et al., 2024, etc.).