Attention with Linear Biases, aka ALiBi (Press et al.), is a simple and widely used method to improve Transformer performance and its ability to extrapolate to sequences longer than those it was trained on. It’s used in popular models including MPT (from MosaicML), Replit, and BLOOM (from BigScience), as well as in some of the Cerebras and Falcon models. It’s also used in encoder-only (non-causal) models such as MosaicBERT, and even in audio models like Meta’s VoiceBox and AudioBox.
However, ALiBi was previously not as efficient as standard attention, since it was harder to leverage optimized implementations such as FlashAttention. To unlock ALiBi for large-scale training and for new domains with long context, we have been working on optimizing attention with ALiBi. We are excited to share that ALiBi is now supported in FlashAttention v2.4, with a 4-5x speedup compared to a naive PyTorch implementation and a 2-3x speedup over torch’s SDPA (which currently dispatches to an implementation from xformers when there is an attention bias). In this blog post, we share some background on ALiBi and its benefits, how to implement ALiBi in a hardware-efficient manner, and some benchmark results.
Background on ALiBi
Transformers typically use position embedding vectors, such as sinusoidal, learned, or rotary embeddings, to indicate the position of each of their inputs. But these methods cause Transformer models to overfit to those position indicators, leaving them unable to extrapolate to sequences longer than those they were trained on.
ALiBi is a positioning method for Transformers that does not use position embeddings at all, thereby alleviating the issues mentioned above. Instead, ALiBi indicates positionality by simply modifying the attention mechanism: it biases the attention scores so that tokens far from each other interact less than tokens that are nearby.
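Concretely, for each head the bias added to the score between query position i and key position j is m · (j − i), where m is a fixed head-specific slope. A minimal PyTorch sketch of the causal ALiBi bias (assuming the head count is a power of 2 and using the slope schedule from the ALiBi paper; `alibi_bias` is an illustrative helper, not a library function):

```python
import torch

def alibi_bias(nheads: int, seqlen: int) -> torch.Tensor:
    """Build the (nheads, seqlen, seqlen) ALiBi bias for causal attention."""
    # Head-specific slopes: a geometric sequence starting at 2^(-8/nheads),
    # as prescribed in the ALiBi paper for power-of-2 head counts.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / nheads) for h in range(nheads)])
    pos = torch.arange(seqlen)
    rel = (pos[None, :] - pos[:, None]).float()  # j - i: <= 0 at/below the diagonal
    # Entries above the diagonal come out positive, but the causal mask removes them.
    return slopes[:, None, None] * rel
```

The bias is simply added to the pre-softmax scores, i.e. `scores = q @ k.transpose(-2, -1) / sqrt(headdim) + alibi_bias(nheads, seqlen)`, followed by the usual causal mask and softmax.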
This approach leads to other benefits, such as the model being able to attend to previously computed (cached) keys and values.
Implementing ALiBi in FlashAttention
ALiBi simply adds a bias to the attention scores. Typically, the bias is precomputed in high-bandwidth memory (HBM, typically the GPU memory) and loaded on-chip as needed. However, loading from HBM to on-chip memory is an expensive operation, and in the case of ALiBi, the cost grows with the sequence length: O(seqlen^2) for non-causal ALiBi, or O(seqlen) for causal ALiBi (where the bias depends only on the query-key distance).
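To see the scale of this cost, here is an illustrative back-of-the-envelope count (the shapes are hypothetical, chosen to match the largest benchmark configuration below; this is a counting exercise, not a measurement):

```python
# Elements loaded from HBM for a precomputed non-causal ALiBi bias
# versus just the per-head slopes (hypothetical shapes for illustration).
nheads, seqlen = 16, 16384
full_bias_elems = nheads * seqlen * seqlen  # (nheads, seqlen, seqlen) bias tensor
slope_elems = nheads                        # (nheads,) slopes only
full_bias_bytes = full_bias_elems * 4       # fp32: 16 GiB for this configuration
print(full_bias_elems // slope_elems)       # the load shrinks by a factor of seqlen^2
```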
To solve this problem, we need to minimize the amount of data loaded. Instead of loading the entire ALiBi bias, we load only the ALiBi slopes of shape (nheads) or (batch_size, nheads), whose size does not depend on the sequence length, and then generate the ALiBi bias inside the kernel. Since batch_size and nheads are typically much smaller than seqlen, this minimizes the cost of loading the bias. As a result, we achieve about 94% of the efficiency of the baseline without ALiBi.
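The idea can be sketched outside the kernel as follows: given only the slopes and a tile’s position, the bias for that tile is reconstructed from index arithmetic (`tile_alibi_bias` is a hypothetical helper illustrating what the CUDA kernel does per block, not the actual kernel code):

```python
import torch

def tile_alibi_bias(slopes: torch.Tensor, row_start: int, col_start: int,
                    block_m: int, block_n: int) -> torch.Tensor:
    """Generate the ALiBi bias for one (block_m, block_n) attention tile
    from the per-head slopes alone, mimicking on-chip generation."""
    rows = row_start + torch.arange(block_m)  # query positions covered by this tile
    cols = col_start + torch.arange(block_n)  # key positions covered by this tile
    rel = (cols[None, :] - rows[:, None]).float()  # j - i within the tile
    return slopes[:, None, None] * rel  # (nheads, block_m, block_n)
```

Only the (nheads,) slopes ever cross from HBM to on-chip memory; the O(block_m · block_n) bias values are recomputed on the fly for each tile.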
Benchmark results: 3-5x speedup for the attention operation
We benchmark the implementation of ALiBi in FlashAttention 2.4 and compare to (1) a naive implementation in PyTorch, and (2) torch’s scaled_dot_product_attention (SDPA), which, as of PyTorch 2.1 and 2.2, dispatches to an implementation from xformers when there is attention bias, and dispatches to FlashAttention-2 when there is no attention bias. On A100 80GB SXM4, we see an average speedup of 4.8x compared to naive PyTorch, and 3.0x compared to torch’s SDPA/xformers. The benchmarking script is here.
We vary (batch size, sequence length) from (32, 512) to (1, 16k). We set hidden dimension to 2048, with head dimension 128 (i.e. 16 heads).
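For reference, the naive baseline in these comparisons can be sketched in pure PyTorch: materialize the full bias and hand it to SDPA as an additive float mask (tensor shapes here are illustrative and smaller than the benchmarked ones; FlashAttention v2.4 instead takes only the per-head slopes, via the `alibi_slopes` argument of `flash_attn_func`):

```python
import math
import torch
import torch.nn.functional as F

batch, nheads, seqlen, headdim = 1, 8, 128, 64
q, k, v = (torch.randn(batch, nheads, seqlen, headdim) for _ in range(3))

# Full causal ALiBi bias, materialized in memory (the expensive part).
slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / nheads) for h in range(nheads)])
pos = torch.arange(seqlen)
bias = slopes[:, None, None] * (pos[None, :] - pos[:, None]).float()
causal = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
mask = bias.masked_fill(~causal, float("-inf"))  # additive mask: ALiBi + causal

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```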
We additionally compare attention with rotary embedding and attention with ALiBi, both using optimized implementations. Previously, attention with rotary embedding was faster, as it was easier to leverage fast attention implementations. With our ALiBi implementation in FlashAttention v2.4, the two methods now run at around the same speed.
Modeling improvements such as ALiBi, when paired with hardware-aware optimizations such as FlashAttention, can significantly improve foundation models’ capabilities and efficiency. We are excited about integrating these improvements into popular frameworks, and about future research on length extrapolation and LLM efficiency.