Machine learning entirely upends the traditional tradeoffs of code optimisation. A kernel in a machine learning model is run trillions of times, even on small examples, across fleets of GPUs whose combined power draw is measured in gigawatts. While most developers can remain blissfully unaware of memory hierarchies, in machine learning even the terabytes per second of bandwidth between global and shared memory on an H100 are not enough. The performance requirements are unforgiving, and traditional tools and ways of thinking have not caught up.
Thankfully, it is not all grim. Kernels in machine learning models are tiny in comparison to most other practical software, and we can afford to spend far more time and resources on optimisation: hours of autotuning, or months of tuning by highly specialised engineers. One of my projects here at Zefram is to explore how much of the research on superoptimisation, and how many of the tricks discovered throughout the literature and practice of running machine learning models at scale, we can cram into one compiler that does it all for you.
FlashAttention is a fascinating case study. It is not a different model architecture, nor is it an approximation: FlashAttention computes exactly the same softmax attention as previous implementations. It is much faster because it changes how the result is computed rather than what is computed, aligning better with the engineering constraints posed by real hardware.
In this post we explore a perhaps unconventional view of what FlashAttention does, to illustrate a vision of how FlashAttention could have been discovered automatically by a sufficiently smart compiler.
Computing Softmax
The softmax operation is at the core of the attention mechanism defining the transformer architecture, from its inception to the state of the art of LLMs today. Given a vector $x \in \mathbb{R}^n$, the softmax is defined by
$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}.$$
While the softmax operation is deceptively simple, computing it efficiently in modern LLMs is not. Implemented naively, softmax struggles with numerical stability, which is especially a problem when we do not have much precision in our floating point numbers to work with. The issue arises when the exponentials in the denominator become huge, reaching the limits of what the floating point format can represent.
The safe softmax algorithm fixes the numerical stability problems, based on the following trick. By applying some basic algebra, we see that we can compute the same value as the softmax when we add an arbitrary constant $c$ to every element:
$$\mathrm{softmax}(x + c\mathbf{1})_i = \frac{e^{x_i + c}}{\sum_{j=1}^{n} e^{x_j + c}} = \frac{e^{c}\, e^{x_i}}{e^{c} \sum_{j=1}^{n} e^{x_j}} = \mathrm{softmax}(x)_i.$$
By choosing $c = -\max_j x_j$, i.e. subtracting the maximum of the data from each data point, we ensure that $x_i + c \le 0$ and therefore
$$e^{x_i + c} \le 1.$$
This fixes the issue of the exploding sum in the denominator, as it is now bounded by the number of data points:
$$\sum_{j=1}^{n} e^{x_j - \max_k x_k} \le n.$$
But now we have introduced a different problem. Naive softmax can be computed in a single traversal of the data, accumulating the sum of the exponentials as we go. Safe softmax appears to require two traversals: one to compute the maximum, and another to sum the adjusted exponentials.
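To make the stability issue concrete, here is a minimal NumPy sketch (my own illustration, not code from the post) contrasting the naive and safe formulas in half precision:

```python
import numpy as np

def naive_softmax(x):
    # Direct transcription of the definition: the exponentials can overflow.
    e = np.exp(x)
    return e / e.sum()

def safe_softmax(x):
    # Subtract the maximum first, so every exponent is <= 0 and every e^(...) is <= 1.
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([10.0, 20.0, 30.0], dtype=np.float16)
print(naive_softmax(x))  # exp(20) and exp(30) overflow float16: [0., nan, nan]
print(safe_softmax(x))   # well-behaved: roughly [0., 0.00005, 1.]
```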
LLMs are limited by the speed at which data can be transferred within the memory hierarchy. GPUs have a big but slow pool of memory (global memory) together with a small but fast memory (shared memory) for every streaming multiprocessor (SM). To compute anything on a set of data, the GPU has to transfer that data from global memory into the shared memory of the SM that performs the computation.
Because the shared memory is not large enough to fit all of the data required by attention at once, the data needs to be streamed in chunks. But that means that, if we implemented safe softmax with two traversals, we would have to stream the data from global to shared memory twice. To make matters worse, once we have computed the softmax, we use it to weight the values in the attention mechanism, thus incurring a third traversal of the data.
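To see where the traversals come from, here is a sketch (mine, not from the post; plain NumPy standing in for the GPU kernel) of safe softmax attention for a single query, written as three explicit passes over the keys and values:

```python
import numpy as np

def attention_row_three_passes(q, K, V):
    # Pass 1: traverse the keys to find the maximum logit.
    m = max(float(q @ k) for k in K)
    # Pass 2: traverse the keys again to accumulate the safe denominator.
    denom = sum(np.exp(float(q @ k) - m) for k in K)
    # Pass 3: traverse keys and values to accumulate the weighted output.
    out = sum(np.exp(float(q @ k) - m) * v for k, v in zip(K, V))
    return out / denom
```

Each of the three loops corresponds to streaming the same data through the memory hierarchy once.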
Reductions in Monoids
So the problem we are facing is to compute an aggregate of exponentially weighted values in a way that is both numerically stable and can be efficiently implemented on GPUs. Let us take a step back and study the theory of aggregation in general. In particular, let us look at monoids.
Definition: A monoid is a set $M$ together with a binary operation
$$\oplus : M \times M \to M$$
that satisfies the following conditions:
- Associativity: For all $a, b, c \in M$ we have $(a \oplus b) \oplus c = a \oplus (b \oplus c)$.
- Unitality: There exists an element $e \in M$ (called the unit or neutral element) such that for all $a \in M$ we have $e \oplus a = a = a \oplus e$.
Examples of monoids are abundant. Here are some of the most familiar:
- The real numbers $\mathbb{R}$ with addition and $0$.
- Any vector space with vector addition and the zero vector.
- The real numbers $\mathbb{R}$ with multiplication and $1$.
- The extended real numbers $\mathbb{R} \cup \{-\infty\}$ with $\max$ and $-\infty$.
- Square matrices with matrix multiplication and the identity matrix.
The associativity law for monoids has a direct computational interpretation: it allows us to compute an aggregate in a monoid in many different ways without changing the end result. Consider for example the sum of 8 elements:
$$((((((a_1 \oplus a_2) \oplus a_3) \oplus a_4) \oplus a_5) \oplus a_6) \oplus a_7) \oplus a_8$$
Reading the parentheses operationally, we begin by taking the sum of $a_1$ and $a_2$. We then add $a_3$ to this intermediate result, then $a_4$, etc. This bracketing therefore corresponds to computing the sum sequentially, from left to right. Alternatively, we could have chosen the following bracketing:
$$a_1 \oplus (a_2 \oplus (a_3 \oplus (a_4 \oplus (a_5 \oplus (a_6 \oplus (a_7 \oplus a_8))))))$$
Here we start by adding up $a_7$ and $a_8$, then add $a_6$ on the left, then $a_5$, etc. So this corresponds to a sequential computation, but this time from right to left. But we could also bracket the sum like this:
$$(((a_1 \oplus a_2) \oplus a_3) \oplus a_4) \oplus (((a_5 \oplus a_6) \oplus a_7) \oplus a_8)$$
Here we have two sequential left-to-right sums, one for each half of the data set, that could be computed independently and in parallel. Then in the end we add up the intermediate results for each half. Each of these options of bracketing the sum yields the same result when the composition operation is associative. We can indicate this by writing the sum without any parentheses at all:
$$a_1 \oplus a_2 \oplus a_3 \oplus a_4 \oplus a_5 \oplus a_6 \oplus a_7 \oplus a_8$$
The associativity law gives us implementation freedom. By describing an algorithm as an aggregate in a monoid, we describe what is computed without prematurely restricting how it is computed. On GPUs we can pick chunks so that the sum can be performed within the shared memory of an SM, optimising for memory bandwidth, or so that they fit within the fixed size of a matrix multiply performed by a tensor core. On CPUs the chunk size can be chosen to align with the width of the available SIMD operations. When the dataset does not fit into memory all at once, we can split it into parts, send them to multiple machines, and aggregate at the end. Alternatively, we can process the aggregation as a stream.
The monoid unit allows us to deal with corner cases more elegantly. The unit lets us talk about aggregates without requiring the data set to be non-empty. It allows us to filter data points even in operations that cannot change the sequence length. When data does not distribute evenly across compute units, we can pad it with copies of the unit; this can sometimes be more efficient than branching logic that deals with variable-length sequences. Many semigroups (associative operations) that occur in practice have a unit, and if they don't, they can be freely extended with one. So the monoid unit is very convenient to have while not restricting what computations we can do.
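As a sketch of this freedom (my own code; the chunk size and examples are arbitrary), here is a generic monoid reduction that pads the data with the unit, reduces fixed-size chunks independently, and then combines the partial results:

```python
from functools import reduce

def monoid_reduce(op, unit, xs, chunk_size=4):
    """Reduce xs with an associative op and its unit, chunk by chunk."""
    # Pad with the unit so every chunk has the same size (no branching on the tail).
    padded = list(xs) + [unit] * (-len(xs) % chunk_size)
    chunks = [padded[i:i + chunk_size] for i in range(0, len(padded), chunk_size)]
    # Each chunk could be handled by a different SM, SIMD lane, or machine.
    partials = [reduce(op, chunk, unit) for chunk in chunks]
    # Combine the per-chunk results; associativity guarantees the same answer.
    return reduce(op, partials, unit)

print(monoid_reduce(lambda a, b: a + b, 0, range(10)))      # 45, same as sum(range(10))
print(monoid_reduce(max, float("-inf"), [3.0, 7.0, 1.0]))   # 7.0
```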
Computing Softmax via Monoids
Tying this back to softmax, we can compute the numerically safe sum of exponentials in a single pass over the data by formulating it as an aggregation in a suitable monoid.
Definition: For any vector space $V$ we can equip the set
$$M_V = (\mathbb{R} \cup \{-\infty\}) \times V$$
with a composition operation $\oplus$ defined by
$$(m_1, v_1) \oplus (m_2, v_2) = \bigl(m,\ e^{m_1 - m} v_1 + e^{m_2 - m} v_2\bigr),$$
where $m = \max(m_1, m_2)$. We write
$$\pi : M_V \to V$$
for the projection map $(m, v) \mapsto v$.
Lemma: Let $V$ be a vector space. Then $\oplus$ is an associative operation on $M_V$ and $(-\infty, 0)$ is the neutral element for $\oplus$. In particular $(M_V, \oplus, (-\infty, 0))$ is a monoid.
Theorem: Let $V$ be a vector space, $x_1, \dots, x_n \in \mathbb{R}$, and $v_1, \dots, v_n \in V$. Then
$$(x_1, v_1) \oplus (x_2, v_2) \oplus \cdots \oplus (x_n, v_n) = \Bigl(\max_{1 \le i \le n} x_i,\ \sum_{i=1}^{n} e^{x_i - \max_{1 \le j \le n} x_j}\, v_i\Bigr).$$
Proof: By induction on $n$. When $n = 1$ both sides are $(x_1, v_1)$. For the induction step, suppose the claim holds for some $n$ and let $x_{n+1} \in \mathbb{R}$, $v_{n+1} \in V$. We can then calculate:
$$(x_1, v_1) \oplus \cdots \oplus (x_{n+1}, v_{n+1}) = \bigl((x_1, v_1) \oplus \cdots \oplus (x_n, v_n)\bigr) \oplus (x_{n+1}, v_{n+1}).$$
Applying the induction hypothesis to the left summand, and writing $M = \max_{1 \le i \le n} x_i$ and $M' = \max(M, x_{n+1}) = \max_{1 \le i \le n+1} x_i$, we have:
$$\Bigl(M,\ \sum_{i=1}^{n} e^{x_i - M} v_i\Bigr) \oplus (x_{n+1}, v_{n+1}) = \Bigl(M',\ e^{M - M'} \sum_{i=1}^{n} e^{x_i - M} v_i + e^{x_{n+1} - M'} v_{n+1}\Bigr) = \Bigl(M',\ \sum_{i=1}^{n+1} e^{x_i - M'} v_i\Bigr),$$
which is exactly the claimed formula for $n + 1$.
We can aggregate values in the monoid $M_{\mathbb{R}}$ to compute the scalar normalisation factor in the denominator of the softmax, and run a separate aggregation in $M_{\mathbb{R}^n}$ to obtain the vector-valued numerator. By dividing the numerator by the denominator in the end, we have calculated the safe softmax. Denoting by $e_i$ the $i$th one-hot vector in $\mathbb{R}^n$, we have
$$\mathrm{softmax}(x) = \frac{\pi\bigl((x_1, e_1) \oplus (x_2, e_2) \oplus \cdots \oplus (x_n, e_n)\bigr)}{\pi\bigl((x_1, 1) \oplus (x_2, 1) \oplus \cdots \oplus (x_n, 1)\bigr)}.$$
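Here is the composition $\oplus$ and the two aggregations written out in NumPy (a sketch of mine, not the post's code), checked against the safe softmax from before:

```python
import numpy as np

def compose(a, b):
    """The monoid operation on M_V: pairs (m, v) of a running max and a rescaled accumulator."""
    (m1, v1), (m2, v2) = a, b
    m = max(m1, m2)
    return m, np.exp(m1 - m) * v1 + np.exp(m2 - m) * v2

UNIT = (-np.inf, 0.0)  # the neutral element (-inf, 0)

def softmax_via_monoid(x):
    n = len(x)
    num = den = UNIT
    for i, xi in enumerate(x):
        num = compose(num, (xi, np.eye(n)[i]))  # aggregate in M_{R^n} (one-hot values)
        den = compose(den, (xi, 1.0))           # aggregate in M_R (weights only)
    return num[1] / den[1]                      # divide the projections at the end

x = np.array([10.0, 20.0, 30.0])
assert np.allclose(softmax_via_monoid(x), np.exp(x - x.max()) / np.exp(x - x.max()).sum())
```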
Attention
By accumulating in the space of value vectors and weighting with the attention logits, this gives us a way to compute softmax attention as the quotient of two aggregations. In particular, we can see FlashAttention emerge by reduction factorisation.
Let $q \in \mathbb{R}^d$ be a query vector, $k_1, \dots, k_n \in \mathbb{R}^d$ the key vectors, and $v_1, \dots, v_n \in \mathbb{R}^{d_v}$ the value vectors. Then the attention output for $q$ is
$$\frac{\sum_{i=1}^{n} e^{q \cdot k_i}\, v_i}{\sum_{i=1}^{n} e^{q \cdot k_i}} = \frac{\pi\bigl((q \cdot k_1, v_1) \oplus \cdots \oplus (q \cdot k_n, v_n)\bigr)}{\pi\bigl((q \cdot k_1, 1) \oplus \cdots \oplus (q \cdot k_n, 1)\bigr)}.$$
We note that the weights for the numerator and the denominator are the same. A sufficiently smart compiler would notice this fact and avoid computing them twice, effectively performing a single aggregation in the pullback monoid $M_{\mathbb{R}^{d_v}} \times_{\mathbb{R} \cup \{-\infty\}} M_{\mathbb{R}}$, whose elements $(m, v, s)$ carry a shared running maximum, the accumulated weighted values, and the accumulated weights.
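Putting it together, here is a single-pass, block-streamed sketch (mine; the shapes and block size are assumptions, and I omit the usual $1/\sqrt{d}$ scaling) that carries exactly this $(m, v, s)$ accumulator for one query row:

```python
import numpy as np

def streaming_attention_row(q, K, V, block_size=128):
    """One pass over K, V in blocks, keeping a shared running max, a weighted-value
    accumulator, and a weight sum: the (m, v, s) accumulator described above."""
    m = -np.inf                    # running maximum of the logits
    acc = np.zeros(V.shape[1])     # running sum of exp-weighted values
    s = 0.0                        # running sum of exp weights
    for start in range(0, K.shape[0], block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        logits = Kb @ q            # logits for this block
        m_new = max(m, logits.max())
        scale = np.exp(m - m_new)  # rescale previous accumulators to the new max
        w = np.exp(logits - m_new) # weights for this block
        acc = scale * acc + w @ Vb
        s = scale * s + w.sum()
        m = m_new
    return acc / s

# Matches the naive reference on random data.
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=16), rng.normal(size=(1000, 16)), rng.normal(size=(1000, 8))
w = np.exp(K @ q - (K @ q).max())
assert np.allclose(streaming_attention_row(q, K, V), (w @ V) / w.sum())
```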
Teaching the Compiler
We can apply the same trick of accumulating in $M_V$ to any exponentially weighted sum in any vector space $V$, and such sums are rather common across machine learning. Tiling an accumulation, sharing subcomputations, cancelling out factors, and eliminating common subexpressions are all general techniques. A compiler equipped with these optimisations could have discovered FlashAttention automatically. Looking into the future, we can distill a library of generalisable tricks from the literature and teach them to a compiler, so that when new architectures come around we can save the months of research required to optimise them.