June 3, 2024

The theoretical framework of structured state space duality (SSD) (see Parts I and II of this blog post series) connects SSMs and (linear) attention through structured matrices. As mentioned in Part I, this connection allows us to derive new algorithms for selective SSMs that are faster than the parallel associative scan in Mamba-1 by leveraging matrix multiplication as a primitive. Moreover, the connection brings system optimizations (tensor parallelism, sequence parallelism, variable sequence length) originally developed for Transformers to SSM land.

The SSD Algorithm

Even though we already developed optimized scan implementations for Mamba-1, we were limited to a small state expansion (typically N=16) because the algorithm and implementation did not use tensor cores (specialized hardware units that perform matrix multiplication). Matrix multiplication (matmul) FLOPs are typically much faster (up to 16x) than non-matmul FLOPs: the A100 GPU has 312 TFLOPS of BF16 matmul but only 19 TFLOPS of FP32 arithmetic, and the H100 has 989 TFLOPS of BF16 matmul but only 67 TFLOPS of FP32 arithmetic. One of our primary goals with Mamba-2 is to leverage tensor cores to speed up the SSM.

Since SSD (Part I of this series) connects SSMs and structured matrices, efficient algorithms for computing SSMs or linear attention correspond directly to different decompositions of the “token-mixing” or “sequence-mixing” matrix. After tying parameters and introducing the head structure, the SSM of Mamba-1 turns into a more restrictive form corresponding to a 1-semiseparable matrix (Part I, Efficiency).

The block decomposition of this matrix corresponds to the SSD algorithm, which has 4 steps. There are two completely different interpretations of this algorithm!

(The paper contains a detailed description of each step; here we just summarize and provide some intuition for each one.)

SSD Algorithm: Block Matrix Decomposition

We first partition the SSM (semiseparable) matrix into blocks of size $\mathtt{Q} \times \mathtt{Q}$. Then, we use the properties of semiseparable matrices to factorize each off-diagonal block, which is low-rank (a small numerical check after the list below illustrates this).

  1. (Orange) Each diagonal block is a smaller semiseparable matrix; we can compute this multiplication however we like, in particular using the quadratic (attention-like) form of SSD.
  2. (Green) There are only $\mathtt{T} / \mathtt{Q}$ distinct green blocks because many of them are shared. These can be computed with batched matmuls.
  3. (Yellow) Notice that the yellow terms themselves form a 1-semiseparable matrix; in other words, this step is equivalent to an SSM scan (on some modified $A$ factors)!
  4. (Blue) Similar to the green blocks, these can be computed with a batched matmul.
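
To make the low-rank claim concrete, here is a small numerical check (purely illustrative; the shapes and names are ours, not from the released code). We materialize the SSD sequence-mixing matrix $M[t,s] = C_t^\top A_{t:s} B_s$ for a toy scalar-decay head and confirm that an off-diagonal $\mathtt{Q} \times \mathtt{Q}$ block has rank at most $\mathtt{N}$:

import torch

# Toy check: build M[t, s] = C_t^T A_{t:s} B_s for a scalar-decay head and verify that
# an off-diagonal Q x Q block has rank at most N (the state dimension), i.e. it is low-rank.
T, Q, N = 8, 4, 2
a = torch.rand(T) * 0.5 + 0.5          # per-step scalar decay in (0.5, 1)
B, C = torch.randn(T, N), torch.randn(T, N)

M = torch.zeros(T, T)
for t in range(T):
    for s in range(t + 1):
        decay = torch.prod(a[s + 1 : t + 1])   # A_{t:s} = a_{s+1} * ... * a_t
        M[t, s] = decay * (C[t] @ B[s])

off_block = M[Q:2 * Q, 0:Q]                    # a block strictly below the diagonal
print(torch.linalg.matrix_rank(off_block))     # at most N = 2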

SSD Algorithm: Chunking and State Passing

An alternative interpretation of the algorithm involves reasoning about how the SSM operates on the sequence. We first split the input sequence into blocks (or chunks) of length $\mathtt{Q}$. The steps then have the following interpretation:

  1. Intra-chunk outputs: compute the local output of each chunk in parallel. This can be interpreted as: what is the output per chunk, supposing that the initial state (to the chunk) is 0?
  2. Chunk state: compute the final state of each chunk in parallel (using matmul). This can be interpreted as: what is the final state per chunk, supposing that the initial state (to the chunk) is 0?
  3. Pass state: compute a recurrence on each of the chunks’ final states (using any desired algorithm, e.g. a parallel or sequential scan). This can be interpreted as: what is the actual final state per chunk, taking into account all previous inputs?
  4. Output state: for each chunk, given its actual initial state from step 3, compute the contribution to the output coming just from that initial state.

We see that most of the algorithm (steps 1, 2, and 4) leverages matmuls (and hence tensor cores), and can be computed completely in parallel!

Only step 3 requires a scan, but it operates on a much shorter sequence and usually only takes a small fraction of the time.

In the “SSD Minimal” code that we provide in the paper (Listing 1) and the code release, we delineate each of these four steps. As promised, this algorithm is much easier to implement than the original selective scan of Mamba.
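
For intuition, here is a condensed single-head sketch of those four steps. This is not the paper’s Listing 1: the function name, shapes, and the single-head/scalar-decay simplification are ours, and the released code additionally handles batching, heads, and the stable segsum discussed later in this post.

import torch

def ssd_chunked_reference(x, a, B, C, Q):
    """Illustrative single-head sketch of the four-step chunked SSD computation.
    Computes y_t = sum_{s <= t} C_t^T (a_{s+1} * ... * a_t) B_s x_s.
      x: (T, P) inputs     a: (T,) scalar decays in (0, 1]
      B, C: (T, N)         Q: chunk length (must divide T)
    """
    T, P = x.shape
    N = B.shape[-1]
    nc = T // Q
    xc, ac = x.view(nc, Q, P), a.view(nc, Q)
    Bc, Cc = B.view(nc, Q, N), C.view(nc, Q, N)

    # Cumulative log-decay within each chunk. (For clarity we form decays by subtracting
    # cumsums here; see the Stability section below for the numerically robust segsum.)
    cum = torch.cumsum(torch.log(ac), dim=-1)                     # (nc, Q)

    # Step 1 (intra-chunk outputs / diagonal blocks): quadratic, attention-like form.
    L = torch.tril(torch.exp(cum[:, :, None] - cum[:, None, :]))  # decay a_{s+1..t}
    G = torch.einsum("ctn,csn->cts", Cc, Bc) * L
    y_diag = torch.einsum("cts,csp->ctp", G, xc)

    # Step 2 (chunk states): final state of each chunk, assuming zero initial state.
    decay_to_end = torch.exp(cum[:, -1:] - cum)
    states = torch.einsum("ctn,ct,ctp->cnp", Bc, decay_to_end, xc)

    # Step 3 (pass states): short recurrence over the nc chunk states (the only scan).
    chunk_decay = torch.exp(cum[:, -1])
    h = torch.zeros(N, P, dtype=x.dtype, device=x.device)
    init_states = []
    for c in range(nc):
        init_states.append(h)                                     # state entering chunk c
        h = chunk_decay[c] * h + states[c]
    init_states = torch.stack(init_states)

    # Step 4 (output from states): contribution of the carried-in state to each output.
    y_off = torch.einsum("ctn,ct,cnp->ctp", Cc, torch.exp(cum), init_states)

    return (y_diag + y_off).reshape(T, P)

The loop in step 3 stands in for the scan over chunk states; in an optimized implementation it would be a parallel scan or a 1-SS matmul, as discussed below.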

Special Cases

We note that special cases of this algorithm have been seen before. In particular, RetNet [CITE], which we showed in Part II to be a special case of SSD, mentions a “chunkwise” algorithm which computes the quadratic form on one chunk of the input at a time and passes the final state to the next chunk. This turns out to be essentially equivalent to the SSD algorithm (in the special case of RetNet, or a decay matrix mask $L$). Our derivation comes from a different direction (the block matrix decomposition), which also makes it more obvious how to parallelize this algorithm and make it really fast in practice.

Other forms of “chunkwise” recurrences have become popular, such as in [Gated Linear Attention (GLA)] [CITE].

The Details

Let’s talk about a couple of additional details in the implementation (these don’t even appear in the full paper, so pay attention!) that unpack some of the choices in the above code.

The SSM Scan

In the above code, we utilized the connection between scalar SSM recurrences

$$h_{t+1} = A_t h_t + B_t x_t$$

and matrix multiplication by 1-semiseparable matrices

$$L = \begin{bmatrix} 1 & & & & \\ a_1 & 1 & & & \\ a_2a_1 & a_2 & 1 & & \\ \vdots & \vdots & \ddots & \ddots & \\ a_{\mathtt{T}-1}\cdots a_1 & a_{\mathtt{T}-1}\cdots a_2 & \cdots & a_{\mathtt{T}-1} & 1 \end{bmatrix}$$

which we covered in Part II (and Section 3 of the paper).
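
As a quick sanity check (ours, not code from the release), multiplying by a materialized $L$ does reproduce the scalar recurrence:

import torch

# Small check that (L @ u)_t equals h_t from the recurrence h_t = a_t * h_{t-1} + u_t.
T = 5
a = torch.rand(T) * 0.5 + 0.5
u = torch.randn(T)

# Build L explicitly: L[t, s] = a_{s+1} * ... * a_t for t >= s, else 0.
L = torch.zeros(T, T)
for t in range(T):
    for s in range(t + 1):
        L[t, s] = torch.prod(a[s + 1 : t + 1])

# Sequential scan for reference.
h, hs = torch.zeros(()), []
for t in range(T):
    h = a[t] * h + u[t]
    hs.append(h)

print(torch.allclose(L @ u, torch.stack(hs)))   # True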

We use this version for several reasons:

  1. Code-wise, it’s simpler to materialize and multiply by this matrix than to actually implement a parallel associative scan.
  2. Because of the block decomposition of the SSM matrix, the sequence length this step operates on is reduced by a factor of $\approx 100$, so doing it in time $O((\mathtt{T}/\mathtt{Q})^2)$ instead of $O(\mathtt{T}/\mathtt{Q})$ isn’t too bad.
  3. We have to materialize a 1-SS matrix anyways for Step 1 of the algorithm (the diagonal blocks), so might as well reuse the code 🤷

While this example code is simpler and reasonably efficient on GPU (and probably TPU as well!), it’s no longer truly linear at long sequences. Our most optimized implementation does replace the 1-SS multiplication in Step 3 of the SSD algorithm with an actual associative scan.

Stability

There’s still a subtlety with materializing the 1-semiseparable matrix – how should we do this in a simple and fast way?

Attempt 1: Ratios of cumprods

The first naive attempt may be to notice that the entries of this matrix are cumulative products 

$$a_{i:j}^\times = a_i \times \cdots \times a_{j-1} = \frac{a_{i:\mathtt{T}}^\times}{a_{j:\mathtt{T}}^\times}$$

However, this runs into severe numerical issues because these products can get really tiny (imagine $a_t \approx 0.9$ and powering it up for a sequence length $\mathtt{T}$ in the thousands!)
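
A tiny experiment (illustrative) makes the problem obvious:

import torch

# Even a mild decay underflows to zero over a long sequence, so the ratio-of-cumprods
# trick ends up dividing 0 by 0.
a = torch.full((4096,), 0.9)
suffix = torch.cumprod(a.flip(0), 0).flip(0)   # suffix[i] = a_i * a_{i+1} * ... * a_{T-1}
print(suffix[0])                                # tensor(0.) -- underflowed in float32
print(suffix[0] / suffix[1])                    # tensor(nan) instead of the true 0.9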

Fix 1: The segment sum (segsum)

The second attempt would be to do all of this in log-space, because all the $a_t$ are positive; the products become additions, and instead of `cumprod`s we have `cumsum`s to deal with. Then, in order to compute the 1-SS matrix, we just have to compute the sums $\log a_i + \dots + \log a_{j-1}$ for every segment $[i:j]$. We call this the segment sum (segsum) primitive, analogous to the cumulative sum (cumsum).

Attempt 2: Differences of cumsums

The obvious way to do this is

$$a_{i:j}^\times = \exp\left( \log a_i + \cdots + \log a_{j-1}\right) = \exp\left( (\log a)_{i:\mathtt{T}}^+ - (\log a)_{j:\mathtt{T}}^+ \right)$$

where we compute a single cumulative sum of $\log a$ along the time axis, and then subtract. In code, we can do this with

import torch

def segsum_unstable(x):
    """Naive segment sum: differences of a single cumulative sum (numerically unstable)."""
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    # Pairwise differences give the sum over each segment (j, i].
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    # Mask out the upper triangle (non-causal segments) with -inf.
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum

(and then the 1-semiseparable matrix is just the exponential of this output)

Sums are generally pretty stable, so this should work – right?

Fix 2: No subtractions

Unfortunately, it turns out this still doesn’t work.

The values of this 1-SS matrix roughly represent the SSM dynamics, which are very sensitive to these values of $a_t$. And even in log space, these cumsums can be fairly large, which runs into catastrophic cancellation when subtracted. So we really have to find a way to compute this matrix with only additions, while still vectorizing everything…

Attempt 3: Stable Segsum

This leads to the helper function in the [above SSD code]. Instead of computing a single cumsum and then subtracting, we find a way to use a batch of independent cumsums that immediately produces the right answer without subtraction.
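
For reference, here is one way to write such a stable segsum (a sketch consistent with the naive version above; the helper in the released code may differ in details):

import torch

def segsum(x):
    """Stable segment sum: out[..., i, j] = x[..., j+1] + ... + x[..., i] for i >= j,
    and -inf above the diagonal. Uses a batch of independent cumsums, no subtraction."""
    T = x.size(-1)
    x = x.unsqueeze(-1).expand(*x.shape, T)                 # (..., T, T); entry [i, j] = x[i]
    # Zero out everything on and above the diagonal, so column j only sees x[j+1:].
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    # Each column j is now an independent cumsum of x[j+1], x[j+2], ...
    x_segsum = torch.cumsum(x, dim=-2)
    # Finally mark non-causal segments (i < j) as -inf, matching segsum_unstable above.
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=0)
    return x_segsum.masked_fill(~mask, -torch.inf)

Exponentiating this output gives the 1-SS matrix, exactly as with the unstable version, but every entry is now computed with additions only.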

These details do matter! Without the right implementation of these primitives, the basic SSD algorithm produces NaNs immediately during training.

Discretization

The lineage of structured state space models developed from [S4] and [its] [predecessors] [cite], which were viewed as continuous-time systems.

In Mamba, however, we don’t actually view the model as continuous anymore. In fact, as mentioned in the Discussion (Section 5) of the original paper, Mamba trades off with S4 on modeling different types of data:

  • S4 is a continuous-time model that excels at modeling continuous data (e.g. perceptual signals such as audio waveforms and pixel-level vision)
  • Mamba is a discrete-time model that excels at modeling discrete data (e.g. tokenized data such as language)

However, the parameterization of Mamba still used the same discretization step as in prior structured SSMs, where there is an additional parameter $\Delta$ being modeled. We do this because the discretization step has other side effects, such as properly normalizing the activations ⚠️, which is important for performance.

The initializations and parameterizations from the previous [theory on structured SSMs][HTTYH] still work out-of-the-box, so why fix what’s not broken?

Despite this, we’re pretty sure that the discretization step isn’t necessary for Mamba. In the Mamba-2 paper, we chose to work directly with the “discrete parameters” $A$ and $B$, which in all previous structured SSM papers (including Mamba-1) were denoted $\bar{A}$ and $\bar{B}$ and defined by

$$\begin{align*}\bar{A} &= \exp(\Delta A) \\ \bar{B} &= (\exp(\Delta A) - I) A^{-1} B \end{align*}$$

In the code, we also kept the same parameterization and discretization step as in Mamba (again, why fix what’s not broken?), but we hypothesize that “discrete-centric” variants (such as the “gamma normalization” of [LRU] and [Griffin]) should work equally well. In order to use the continuous parameterization, simply transform the parameters through the above formulas before plugging them into the SSD code above!
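
Concretely, for a scalar-$A$ SSD head this transformation is only a couple of lines (an illustrative sketch with made-up variable names, not the released code):

import torch

# Illustrative conversion for a scalar-A head: continuous parameters (A, B_t, Delta_t)
# to the discrete decay abar_t and effective input that a discrete-time SSD routine
# (e.g. the chunked sketch above) consumes.
T, P, N = 16, 4, 8
x = torch.randn(T, P)
B, C = torch.randn(T, N), torch.randn(T, N)
dt = torch.nn.functional.softplus(torch.randn(T))    # Delta_t > 0
A = -(torch.rand(()) + 0.1)                          # scalar continuous A < 0

a_bar = torch.exp(dt * A)                            # Abar_t = exp(Delta_t A)
x_bar = ((a_bar - 1.0) / A)[:, None] * x             # fold Bbar_t = (exp(Delta_t A) - 1) A^{-1} B_t
                                                     # into the input, keeping B_t as-is
# a_bar, x_bar, B, C are now in the "discrete" form used throughout this post.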

Is Discretization Necessary?

Probably not. But it’s just a simple invertible transformation, so use either version as you like! 

What’s Next?

In the [final part of this series], we’ll continue talking about the implementation of Mamba-2, but on a more macroscopic level – about the entire neural network, instead of just details of the core SSD layer.

We’ll also talk about the actual speed of the algorithm covered in this post.


Mamba-2: The Systems (and Results)

Transformers have benefited from 7 years of systems optimization from the whole research community and large companies. The SSD framework draws connections between SSMs and attention, and allows us to implement many of these optimizations for models like Mamba-2 as well. We focus on tensor parallelism and sequence parallelism for large-scale training, and on variable-length sequences for efficient finetuning and inference.

Tensor parallelism

One difficulty with large-scale training of Mamba-1 using tensor parallelism (TP) is that it requires 2 all-reduces per layer, compared to just 1 all-reduce per attention or MLP layer in Transformers. This is because some of the SSM parameters are functions of the inner activations, not of the input to the layer. In Mamba-2, with the “parallel projection” structure, all SSM parameters are functions of the input to the layer, and we can easily apply TP to the input projection:

  1. We split the input projection and output projection matrices into 2, 4, or 8 shards, depending on the TP degree.
  2. We use a grouped norm with the number of groups divisible by the TP degree, so that normalization is done separately per GPU.

These changes result in 1 all-reduce per layer, instead of 2.
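
Point 2 is easy to sanity-check on a single device (a toy check with made-up sizes, not the training code):

import torch
import torch.nn as nn

# Toy check: a grouped norm whose number of groups is divisible by the TP degree can be
# computed shard-locally (one shard per GPU), with no cross-GPU communication.
D, G, tp = 64, 8, 4                       # model dim, number of groups, TP degree
x = torch.randn(2, D)                     # (batch, channels)

y_full = nn.GroupNorm(G, D, affine=False)(x)

shards = x.chunk(tp, dim=-1)              # each "GPU" holds D/tp channels = G/tp groups
local_norm = nn.GroupNorm(G // tp, D // tp, affine=False)
y_sharded = torch.cat([local_norm(s) for s in shards], dim=-1)

print(torch.allclose(y_full, y_sharded, atol=1e-6))   # True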

Sequence parallelism

When training on very long sequences, we might need to split along the sequence length and assign different parts to different devices. There are two main forms of sequence parallelism (SP):

  1. For the residual and normalization operations: this replaces the all-reduce in TP with a reduce-scatter, residual + normalization, then all-gather. Since Mamba-2 uses the same residual and normalization structure as Transformers, this form of SP applies directly with no modification.
  2. For the attention or SSM operation, a.k.a. context parallelism (CP). For attention, one could use Ring attention to split it up along the sequence dimension. For Mamba-2, the SSD framework comes to our aid once again: using the same block decomposition, we can have each GPU compute its local output and its final state, then pass the states between GPUs (using send/receive communication primitives) before updating the final output on each GPU.

Variable length

For finetuning and inference, we often have sequences of different lengths in the same batch. For Transformers, one would usually pad so that all sequences have the same length (wasting computation), or implement attention specifically for variable-length sequences with careful load balancing.

With SSMs, we can simply treat the whole batch as one long “sequence”, and avoid passing states between the different sequences in the batch by setting the state transition $A_t$ to 0 for tokens at the end of each sequence.
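
As a toy illustration (ours; the exact boundary convention depends on how $A_t$ is indexed in the implementation), zeroing the decay at the boundary between two packed sequences makes the second sequence’s outputs identical to running it on its own:

import torch

# Pack two sequences (lengths 3 and 3) into one "sequence" of length 6 and zero the decay
# at the boundary so no state flows from the first sequence into the second.
T, N, P = 6, 3, 2
a = torch.rand(T) * 0.5 + 0.5
B, C, x = torch.randn(T, N), torch.randn(T, N), torch.randn(T, P)
a[3] = 0.0                                # boundary between the two packed sequences

def scan(a, B, C, x):
    """Sequential reference recurrence h_t = a_t h_{t-1} + B_t x_t, y_t = C_t^T h_t."""
    h, ys = torch.zeros(N, P), []
    for t in range(len(a)):
        h = a[t] * h + B[t][:, None] * x[t][None, :]
        ys.append(C[t] @ h)
    return torch.stack(ys)

y_packed = scan(a, B, C, x)
y_second = scan(a[3:], B[3:], C[3:], x[3:])       # second sequence on its own
print(torch.allclose(y_packed[3:], y_second))     # True: no state leaked across the boundary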

Results

How well do these optimizations work? The faster SSD algorithm allows us to increase the state dimension (N=64 or 128 compared to N=16 in Mamba-1). Even though technically Mamba-2 is more restricted than Mamba-1 for the same N, the larger state dimensions generally improve model quality. Here we show results for models trained on 300B tokens on the Pile, with Mamba-2 outperforming Mamba-1 and Pythia.

What about hybrid models? We have seen from recent and concurrent work (Jamba, Zamba) that combining Mamba layers with attention layers can improve over a pure Transformer or Mamba. We validate at the 2.7B-parameter, 300B-token scale that a hybrid model with just 6 attention layers (and 58 SSD layers) outperforms one with 64 SSD layers, as well as Transformer++ (32 gated MLP and 32 attention layers).

We also validated that the SSD algorithm is significantly faster than the parallel associative scan from Mamba-1 for the same state-dimension (see figure below, with sequence length 2k). Getting those tensor cores to go brrr is the key!

Future Directions

With SSD, we have connected (linear) attention and SSMs, allowing us to design faster algorithms and implement system optimizations for SSMs. There are still tons of exciting directions that we (and hopefully the community) want to tackle:

  1. Understanding: hybrid models with a few (4-6) attention layers perform very well, even better than pure Mamba(-2) or Transformer++. What are these attention layers doing? Can they be replaced with another mechanism?
  2. Training optimizations: though SSD might be faster than attention, Mamba-2 as a whole might still be slower than Transformers at short (e.g. 2K) sequence length, since the MLP layers in Transformers are very hardware-friendly. Our implementation of SSD does not specifically take advantage of new features on H100 GPUs, and we look forward to future optimizations that would make SSMs faster to train than Transformers for large-scale pretraining at 2-4K sequence lengths.
  3. Inference optimizations: there’s a whole suite of optimizations tailored to Transformers, in particular handling the KV cache (quantization, speculative decoding). How would the inference landscape change if model states (e.g. SSM states) no longer scale with context length, and KV cache is no longer the bottleneck?