June 10, 2024

Naman Agarwal, Daniel Suo, Xinyi Chen, Elad Hazan

One of the biggest challenges for the Transformer architecture, which powers modern large language models, is its computational inefficiency on long sequences. The computational complexity of inference in the attention module scales quadratically with the context length, which is therefore limited in practice.

The computational bottleneck of transformers gave rise to recent interest in state space models. These alternative architectures are variants of recurrent neural networks that are much faster at inference. Recent work has shown their promise on a variety of tasks requiring long context, such as the long-range arena benchmark.

In this post we describe a recent methodological advancement for state space models: spectral filtering. We describe the theoretical foundations of this technique and how it gives rise to provable methods with long memory for learning linear dynamical systems. We then describe how the spectral filtering algorithm can be used to design neural architectures, and present preliminary results on long-range tasks.

Fundamentals of Sequence Prediction

Stateful time series prediction has been used to model a plethora of phenomena. Consider a sequence of inputs and outputs, or measurements, coming from some real-world system:

π‘₯1,π‘₯2,…,π‘₯T,…  β‡’   y1,y2,y3,…,yT,…

For example, these can be:

  • y_t are measurements of ocean temperature, in a study of global warming; x_t are various environmental knowns, such as the time of year, water acidity, etc.
  • Language modeling or translation: x_t are words in Hebrew, and y_t are their English translations.
  • x_t are the physical controls of a robot, i.e. forces in various directions, and y_t is the location to which it moves.

Such systems are generally called dynamical systems, and the simplest type of dynamics is linear dynamics. A linear dynamical system has a particularly intuitive interpretation as a (configurable) vector field as depicted below (from Wikipedia):

A series of tables depicting dynamical systems (vector fields).

For linear dynamical systems, the output is generated according to the following equations: 

β„Žt+1= π΄β„Žt+𝐡π‘₯t+Ξ·t yt+1= Cht+Dπ‘₯𝑑+΢𝑑 

Here h_t is a hidden state, which can have very large or even infinite dimensionality; A, B, C, D are linear transformations; and η_t, ζ_t are noise vectors.
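To make the recursion concrete, here is a minimal NumPy sketch that simulates such a system. The dimensions, the spectral-radius rescaling of A, and the noise scales are illustrative choices, not values from the post:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative dimensions: hidden state, input, output, sequence length.
d_h, d_in, d_out, T = 4, 2, 3, 100

# Random system matrices; A is rescaled so its spectral radius is < 1.
A = rng.normal(size=(d_h, d_h))
A *= 0.9 / max(abs(np.linalg.eigvals(A)))
B = rng.normal(size=(d_h, d_in))
C = rng.normal(size=(d_out, d_h))
D = rng.normal(size=(d_out, d_in))

h = np.zeros(d_h)
ys = []
for t in range(T):
    x = rng.normal(size=d_in)                          # input x_t
    h = A @ h + B @ x + 0.01 * rng.normal(size=d_h)    # h_{t+1} = A h_t + B x_t + eta_t
    y = C @ h + D @ x + 0.01 * rng.normal(size=d_out)  # y_{t+1} = C h_{t+1} + D x_t + zeta_t
    ys.append(y)

print(np.array(ys).shape)  # (100, 3)
```

Note that the hidden state h is the only quantity carried between steps, which is what makes inference linear in the sequence length.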

This setting is general enough to capture a variety of machine learning models previously considered in isolation, such as hidden Markov models, principal and independent component analysis, Gaussian mixture models, and many more; see this survey by Roweis and Ghahramani.


The Memory of Linear Dynamics

The linear equations governing the dynamics are recursive in nature. An input is multiplied by the system matrix A before affecting the output. The recursive nature of the dynamics means that an input's effect k steps into the future is multiplied by A^k, and thus depends exponentially on the eigenvalues of the matrix A.

The matrix A is asymmetric in general and can have complex eigenvalues. If the magnitude of these eigenvalues is greater than 1, then the output y_t can grow without bound. Such a system is called "explosive."

In a well-behaved system, the eigenvalues of A have magnitude less than 1. If the magnitudes are bounded away from one, say at most 1 − δ for some δ > 0, then after roughly 1/δ steps an input's effect decays so much that it no longer has a substantial influence on the output. The quantity 1/δ is therefore referred to as the effective memory of the system. In general, the parameter δ is unknown a priori and can become arbitrarily small as we approach systems with long-range dependencies, leading to instability when training linear dynamical systems with long context. This issue is specifically highlighted in the work of Orvieto et al., who observe that on long-range tasks, learning an LDS directly does not succeed and requires interventions such as stable exponential parameterizations and specific normalizations, which have been used repeatedly, either implicitly or explicitly, in the SSM literature.
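A quick back-of-the-envelope check of this decay, with δ = 0.05 chosen arbitrarily for illustration:

```python
delta = 0.05                  # distance of the eigenvalue magnitudes from 1
a = 1.0 - delta               # a scalar "system" with eigenvalue 1 - delta
memory = int(1 / delta)       # predicted effective memory: 1/delta = 20 steps

# The effect of an input after k steps is scaled by a**k = (1 - delta)**k.
print(a ** memory)            # roughly e^{-1} ≈ 0.37: noticeably attenuated
print(a ** (10 * memory))     # about 3.5e-5: the input is effectively forgotten
```

Shrinking δ pushes the horizon out, but the decay past 1/δ steps is always exponential.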

Spectral Filtering

A notable deviation from the standard theory of linear dynamical systems, one that allows efficient learning in the presence of arbitrarily long memory, is the technique of spectral filtering. The idea is to project the sequence of inputs onto a small subspace that is constructed using a special structure of discrete linear dynamical systems, in which successive powers of the system matrix appear in the impulse response function. The output can then be represented as a linear combination of spectral filters: sequence-length-sized vectors that, given the target sequence length, can be computed offline. These spectral filters are the eigenvectors of a special matrix, constructed as the average of outer products of the discrete impulse-response functions. The structure of this matrix implies that it has a very concentrated spectrum, and very few filters suffice to accurately reproduce any signal. This magical mathematical fact is explained in the next section; the filters themselves are depicted here:

A graph depicting an example of spectral transform units.

A schematic of the basic neural architecture, called the Spectral Transform Unit (STU), is depicted in the following figure. More sophisticated variants, including hybrid attention-STU architectures, have also been implemented and tested by the team.

A more sophisticated graph of STU architecture

Where do the filters come from?

This section is a bit more mathematical; it gives only the gist of how the filters arise. The subspace we would like to span is the set of all vectors of the form μ_α = [1, α, α², …], since these vectors arise naturally in the recursive application of a linear dynamical system. We thus consider a uniform distribution over these vectors and form the matrix Z = ∫_0^1 μ_α μ_α^⊤ dα. This is a fixed matrix, unrelated to the data, that arises naturally from the structure of linear dynamics. It has a special property: it is a Hankel matrix, depicted below, and known theorems in mathematics show that its spectrum decays exponentially.
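This construction is easy to check numerically. In the sketch below (sequence length T = 64 is an arbitrary choice), the integral gives Z[i, j] = ∫_0^1 α^(i+j) dα = 1/(i + j + 1) in 0-indexed notation, and the computed spectrum collapses within a handful of eigenvalues:

```python
import numpy as np

T = 64
i, j = np.indices((T, T))
# Z[i, j] = ∫_0^1 alpha^i * alpha^j d(alpha) = 1 / (i + j + 1): the entries are
# constant along anti-diagonals (depending only on i + j), i.e. Z is Hankel.
Z = 1.0 / (i + j + 1)

eigvals = np.linalg.eigvalsh(Z)[::-1]   # eigenvalues, largest first
print(eigvals[:5] / eigvals[0])         # the spectrum decays exponentially fast
```

The printed ratios shrink by a roughly constant factor per index, which is the concentration property the filters rely on.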

A graph depicting a Hankel matrix.

The filters we use are the eigenvectors corresponding to the largest eigenvalues of this matrix. For more details on how to extend this intuition to higher dimensions see this paper.
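As a sanity check of the one-dimensional intuition, the sketch below (with illustrative choices T = 256 and k = 24) projects a vector μ_α with α = 1 − δ for very small δ onto the top-k eigenvectors; the reconstruction error is tiny even though k is far smaller than the system memory 1/δ:

```python
import numpy as np

T, k = 256, 24
i, j = np.indices((T, T))
w, V = np.linalg.eigh(1.0 / (i + j + 1))  # eigh returns eigenvalues in ascending order
filters = V[:, -k:]                       # eigenvectors of the k largest eigenvalues

# A direction with very long memory: mu_alpha with alpha = 1 - delta, where the
# system memory 1/delta = 10,000 far exceeds the sequence length T.
delta = 1e-4
mu = (1 - delta) ** np.arange(T)

proj = filters @ (filters.T @ mu)         # project mu onto the span of the filters
rel_err = np.linalg.norm(mu - proj) / np.linalg.norm(mu)
print(rel_err)                            # tiny, even though k << 1/delta
```

The same experiment with any other α in [0, 1] gives a similarly small error, which is what makes a fixed filter bank usable without knowing δ.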

Why is Spectral Filtering Important for Longer Memory?

The main advantage of spectral filtering is that for certain types of linear dynamical systems, in particular those with symmetric matrices, the effective memory (measured by the number of filters) required to represent an observation at any point in the sequence in the spectral basis is independent of the system memory parameter 1/δ. This guarantee indicates that if we featurize the input in the spectral basis, we can design models that efficiently and stably represent systems with extremely long memory, even as 1/δ → ∞. This striking fact motivates our derivation of the recurrent spectral architecture and is the underlying justification for the performance and training-stability gains we see in experiments.
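A hedged sketch of what featurizing the input into the spectral basis could look like for a scalar input sequence. The actual STU architecture adds autoregressive terms and learned projections (see the paper), so this only illustrates the core operation, a causal convolution of the input with each filter:

```python
import numpy as np

T, k = 128, 16
i, j = np.indices((T, T))
_, V = np.linalg.eigh(1.0 / (i + j + 1))     # the Hankel matrix described above
filters = V[:, -k:]                          # top-k spectral filters, shape (T, k)

x = np.random.default_rng(1).normal(size=T)  # an arbitrary scalar input sequence

# Spectral featurization: at each time t, correlate the input history
# x_t, x_{t-1}, ..., x_0 with each filter.
U = np.zeros((T, k))
for t in range(T):
    U[t] = filters[: t + 1].T @ x[t::-1]     # most recent input first

# U[t] is a k-dimensional summary of the whole input history; a learned linear
# readout on such features replaces learning the system matrix A directly.
print(U.shape)  # (128, 16)
```

Because the filters are fixed, this featurization has no unstable parameters to train, which is the source of the stability gains discussed above.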

Experiments with neural architectures that make use of spectral filtering, which we call the Spectral Transform Unit (STU), show promise on the long-range arena benchmark:

STU experiment

For more details on the STU neural architecture, the mathematical details of how our filters are designed, and their theoretical properties, check out our recent paper!

Want to read more from the Hazan Lab? Stay up to date with their research at the lab's website here.