Machine learning entirely upends the traditional tradeoffs in code optimisation. A kernel in a machine learning model is run trillions of times even on small examples, across arrays of GPUs drawing multiple gigawatts. While most developers can remain blissfully unaware of memory hierarchies, in machine learning even the terabytes per second of bandwidth between the global and shared memory of an H100 are not enough. Machine learning has unforgiving performance requirements, and traditional tools and ways of thinking have not caught up.

Thankfully it is not all grim. Kernels in machine learning models are tiny in comparison to most other practical software, and we can afford to spend far more time and resources on optimising them: hours of autotuning, or months of tuning by highly specialised engineers. One of my projects here at Zefram is to explore how much of the research on superoptimisation, and how many of the tricks discovered throughout the literature and practice of running machine learning models at scale, we can cram into one compiler that does it all for you.

FlashAttention¹ is a fascinating case study. It is not a different model architecture, nor is it an approximation. FlashAttention computes exactly the same softmax attention as previous implementations. It is much faster because it changes how the result is computed rather than what is computed, aligning better with the engineering constraints posed by real hardware.

In this post we explore a perhaps unconventional view on what FlashAttention does. This is to illustrate a vision of how FlashAttention could have been discovered automatically by a sufficiently smart compiler.

Computing Softmax

The softmax operation is at the core of the attention mechanism defining the transformer architecture, from its inception to the state of the art of LLMs today. Given a vector $x \in \mathbb{R}^d$ the softmax is defined by

$$\text{softmax}(x)_i := \frac{e^{x_i}}{\sum_{j = 1}^d e^{x_j}}$$

While the softmax operation is deceptively simple, computing it efficiently in modern LLMs is not. Implemented naively, softmax struggles with numerical stability. This is especially a problem when working with the low-precision floating point formats used in modern LLMs. The issue arises when the exponentials, and hence the sum in the denominator, become huge, exceeding the limits of what the floating point format can represent.
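To make the failure mode concrete, here is a minimal NumPy sketch of the naive formula blowing up (the function name `naive_softmax` is purely illustrative):

```python
import numpy as np

def naive_softmax(x):
    """Direct transcription of the definition: exp(x_i) / sum_j exp(x_j)."""
    e = np.exp(x)          # overflows to inf for large inputs
    return e / e.sum()

x = np.array([10.0, 1000.0])
print(naive_softmax(x))    # [ 0. nan] -- exp(1000) is inf, and inf / inf is nan
```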

The safe softmax algorithm fixes the numerical stability problems, based on the following trick. By applying some basic algebra, we see that we can compute the same value as the softmax when we add an arbitrary constant $c$ to every element:

$$\frac{e^{x_i + c}}{\sum_{j = 1}^d e^{x_j + c}} = \frac{e^{x_i} e^c}{\sum_{j = 1}^d e^{x_j} e^c} = \frac{e^c}{e^c} \frac{e^{x_i}}{\sum_{j = 1}^d e^{x_j}} = \frac{e^{x_i}}{\sum_{j = 1}^d e^{x_j}}$$

By choosing $c := -\max(x)$ and subtracting the maximum of the data from each data point, we ensure that $x_i + c \leq 0$ and therefore

$$0 < e^{x_j + c} \leq 1$$

This fixes the issue of the exploding sum in the denominator, as it is now bounded by the number of data points:

$$0 < \sum_{j = 1}^d e^{x_j + c} \leq d$$
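A minimal sketch of safe softmax along these lines (names again illustrative); note that it makes one pass over the data to find the maximum and another to accumulate the sum:

```python
import numpy as np

def safe_softmax(x):
    """Numerically stable softmax: shift by the maximum before exponentiating."""
    c = -np.max(x)         # first traversal: find the maximum
    e = np.exp(x + c)      # every exponent is <= 0, so every term lies in (0, 1]
    return e / e.sum()     # second traversal: sum the adjusted exponentials

x = np.array([10.0, 1000.0])
print(safe_softmax(x))     # [0. 1.] -- no overflow this time
```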

But now we have introduced a different problem: Naive softmax can be computed in just one traversal of the data to compute the sum of the exponentials. Safe softmax appears to require two traversals: one to compute the maximum, and one to sum the adjusted exponentials.

LLMs are limited by the speed at which data can be transferred within the memory hierarchy. GPUs have a big and slow pool of memory (global memory) together with a small but fast memory (shared memory) for every streaming multiprocessor (SM). To compute anything on a set of data, the GPU has to transfer the data from global memory into the shared memory of the SM that performs the computation.

Because the shared memory is not large enough to fit all of the data required by attention at once, the data needs to be streamed in chunks. But that means that, if we implemented safe softmax with two traversals, we would have to stream the data from global to shared memory twice. To make matters worse, once we have computed the softmax, we use it to weight the values in the attention mechanism, thus incurring a third traversal of the data.

Reductions in Monoids

So the problem that we are facing is to take an aggregate of exponentially weighted values in a way that is both numerically stable and can be efficiently implemented on GPUs. Let us take a step back by studying the theory of aggregation in general. In particular, let us look at monoids.

Definition. A monoid is a set $M$ together with a binary operation

$$+ : M \times M \to M$$

that satisfies the following conditions:

  • Associativity: For all $a, b, c \in M$ we have $(a + b) + c = a + (b + c)$.
  • Unitality: There exists an element $0 \in M$ (called the unit or neutral element) such that for all $m \in M$ we have $0 + m = m + 0 = m$.

Examples of monoids are abundant. Here are some of the most familiar:

  • Real numbers $\mathbb{R}$ with addition $+$ and $0$.
  • Any vector space $V$ with addition $+$ and $0$.
  • Real numbers $\mathbb{R}$ with multiplication $*$ and $1$.
  • Extended real numbers $\mathbb{R} \cup \{ -\infty \}$ with $\max$ and $-\infty$.
  • Square matrices $\mathbb{R}^{n \times n}$ with matrix multiplication and the identity matrix.

The associativity law for monoids has a direct computational interpretation: it allows us to compute an aggregate in a monoid in many different ways without changing the end result. Consider for example the sum of 8 elements:

$$((((((x_1 + x_2) + x_3) + x_4) + x_5) + x_6) + x_7) + x_8$$

Reading the parentheses operationally, we begin by taking the sum of $x_1$ and $x_2$. We then add $x_3$ to this intermediate result, then $x_4$, etc. This bracketing therefore corresponds to computing the sum sequentially, from left to right. Alternatively, we could have chosen the following bracketing:

$$x_1 + (x_2 + (x_3 + (x_4 + (x_5 + (x_6 + (x_7 + x_8))))))$$

Here we start by adding up $x_7$ and $x_8$, then add $x_6$ on the left, then $x_5$, etc. So this corresponds to a sequential computation, but this time from right to left. But we could also bracket the sum like this:

$$(((x_1 + x_2) + x_3) + x_4) + (((x_5 + x_6) + x_7) + x_8)$$

Here we have two sequential left-to-right sums, one for each half of the data set, which could be computed independently and in parallel. At the end, we add up the two intermediate results. Each of these bracketings of the sum yields the same result because the composition operation is associative. We can indicate this by writing the sum without any parentheses at all:

$$x_1 + x_2 + x_3 + x_4 + x_5 + x_6 + x_7 + x_8$$

The associativity law gives us implementation freedom. By describing an algorithm as an aggregate in a monoid, we describe what is computed without prematurely restricting how it is computed. On GPUs we can pick chunk sizes so that each partial reduction fits within the shared memory of an SM, optimising for memory bandwidth, or within the fixed dimensions of a matrix multiply performed by a tensor core. On CPUs the chunk size can be chosen to align with the width of the available SIMD operations. When the dataset does not fit into memory all at once, we can split it into parts, send them to multiple machines, and aggregate the results at the end. Alternatively, we can process the aggregation as a stream.
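As a toy illustration of this freedom, here is a sketch in plain Python (the helper names are made up) that splits a monoid reduction into chunks which could be reduced independently, and then combines the partial results:

```python
from functools import reduce

def chunked_reduce(op, xs, chunk_size):
    """Reduce each chunk independently, then reduce the partial results.
    Associativity of `op` guarantees the answer matches a sequential reduction."""
    chunks = [xs[i:i + chunk_size] for i in range(0, len(xs), chunk_size)]
    partials = [reduce(op, chunk) for chunk in chunks]  # could run in parallel
    return reduce(op, partials)

xs = [1, 2, 3, 4, 5, 6, 7, 8]
add = lambda a, b: a + b
assert chunked_reduce(add, xs, chunk_size=4) == sum(xs)
```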

The monoid unit lets us deal with corner cases more elegantly. The unit allows us to talk about aggregates without requiring the data set to be non-empty. It allows us to filter data points even in operations that cannot change the sequence length. When data does not distribute evenly across compute units, we can pad it with units, which can sometimes be more efficient than branching logic that deals with variable-length sequences. Many semigroups (associative operations) that occur in practice have a unit, and if they don't, they can be freely extended with one. So the monoid unit is very convenient to have while not restricting what computations we can express.
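And a small sketch of padding uneven chunks with the unit instead of branching on chunk length, here for the $(\max, -\infty)$ monoid:

```python
from functools import reduce

NEG_INF = float("-inf")    # the unit of the (max, -inf) monoid

def padded_chunks(xs, chunk_size, unit):
    """Split into fixed-size chunks, padding the last one with the unit."""
    chunks = [xs[i:i + chunk_size] for i in range(0, len(xs), chunk_size)]
    return [chunk + [unit] * (chunk_size - len(chunk)) for chunk in chunks]

xs = [3.0, 1.0, 4.0, 1.0, 5.0]
chunks = padded_chunks(xs, chunk_size=4, unit=NEG_INF)   # [[3, 1, 4, 1], [5, -inf, -inf, -inf]]
assert reduce(max, (reduce(max, c) for c in chunks)) == max(xs)
```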

Computing Softmax via Monoids

Tying this back to softmax, we can compute the numerically safe sum of exponentials in a single pass over the data by formulating it as a monoid.

Definition. For any vector space $V$ we can equip the set

$$\text{SE}(V) := V \times (\mathbb{R} \cup \{-\infty\})$$

with a composition operation $\odot$ defined by

$$(v_1, m_1) \odot (v_2, m_2) = (e^{m_1 - m} v_1 + e^{m_2 - m} v_2, m)$$

where $m = \max(m_1, m_2)$. We write

$$\pi : \text{SE}(V) \to V$$

for the projection map $\pi(v, m) = v$.

Lemma. Let $V$ be a vector space. Then $\odot$ is an associative operation on $\text{SE}(V)$ and $(0, -\infty)$ is the neutral element for $\odot$ (using the convention $e^{-\infty} := 0$). In particular $(\text{SE}(V), \odot, (0, -\infty))$ is a monoid.
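One possible rendering of this monoid in Python, for scalar $v$ (the names `se_combine` and `SE_UNIT` are mine; the special case for two units encodes the convention $e^{-\infty} = 0$):

```python
import math

SE_UNIT = (0.0, -math.inf)            # the neutral element (0, -inf)

def se_combine(a, b):
    """(v1, m1) composed with (v2, m2) gives (e^{m1-m} v1 + e^{m2-m} v2, m), m = max(m1, m2)."""
    v1, m1 = a
    v2, m2 = b
    m = max(m1, m2)
    if m == -math.inf:                # both arguments are the unit
        return SE_UNIT
    return (math.exp(m1 - m) * v1 + math.exp(m2 - m) * v2, m)

x = (2.0, 1.0)
assert se_combine(x, SE_UNIT) == x and se_combine(SE_UNIT, x) == x
```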

Theorem. Let $V$ be a vector space, $v_1, \ldots, v_d \in V$, and $x_1, \ldots, x_d \in \mathbb{R}$. Then

$$e^{\max(x)} \pi\left(\bigodot_{i = 1}^d (v_i, x_i)\right) = \sum_{i = 1}^d e^{x_i} v_i$$

Proof. By induction on $d \geq 0$. When $d = 0$ both sides are $0$. For the induction step, suppose the claim holds for some $d \geq 0$ and let $v_{d + 1} \in V$, $x_{d + 1} \in \mathbb{R}$. We can then calculate:

$$\begin{aligned}
&\ e^{\max(x)} \pi\left( \bigodot_{i = 1}^{d + 1} (v_i, x_i) \right) \\
=&\ e^{\max(x)} \pi\left( \bigodot_{i = 1}^{d} (v_i, x_i) \odot (v_{d + 1}, x_{d + 1}) \right) \\
=&\ e^{\max(x)} \left( e^{\max(x_{1:d}) - \max(x)} \pi\left( \bigodot_{i = 1}^d (v_i, x_i) \right) + e^{x_{d + 1} - \max(x)} v_{d + 1} \right) \\
=&\ e^{\max(x_{1:d})} \pi\left( \bigodot_{i = 1}^d (v_i, x_i) \right) + e^{x_{d + 1}} v_{d + 1}
\end{aligned}$$

Here we used that the second component of $\bigodot_{i = 1}^d (v_i, x_i)$ is $\max(x_{1:d})$ and that $\max(x) = \max(\max(x_{1:d}), x_{d + 1})$. Applying the induction hypothesis to the left summand, we have:

$$\cdots = \sum_{i = 1}^d e^{x_i} v_i + e^{x_{d + 1}} v_{d + 1} = \sum_{i = 1}^{d + 1} e^{x_i} v_i$$
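A quick numerical sanity check of the theorem, reusing `se_combine` and `SE_UNIT` from the sketch above:

```python
import math
from functools import reduce

xs = [0.3, -1.2, 2.5, 0.0]
vs = [1.0, 2.0, 3.0, 4.0]

v, _ = reduce(se_combine, zip(vs, xs), SE_UNIT)            # aggregate the pairs (v_i, x_i)
lhs = math.exp(max(xs)) * v                                # e^{max(x)} * pi(aggregate)
rhs = sum(math.exp(x_i) * v_i for v_i, x_i in zip(vs, xs))
assert math.isclose(lhs, rhs)
```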

We can aggregate values in the monoid $\text{SE}(\mathbb{R})$ to compute the scalar normalisation factor in the denominator of the softmax, and a separate aggregation in $\text{SE}(\mathbb{R}^d)$ to obtain the vector-valued numerator. By dividing the numerator by the denominator in the end, we have calculated the safe softmax. Denoting by $u_i$ the $i$th one-hot vector in $\mathbb{R}^d$, we have

$$\text{softmax}(x) = \frac{\sum_{i = 1}^d e^{x_i} u_i}{\sum_{i = 1}^d e^{x_i}} = \frac{e^{\max(x)} \pi\left( \bigodot_{i = 1}^d (u_i, x_i) \right)}{e^{\max(x)} \pi\left( \bigodot_{i = 1}^d (1, x_i) \right)} = \frac{\pi\left( \bigodot_{i = 1}^d (u_i, x_i) \right)}{\pi\left( \bigodot_{i = 1}^d (1, x_i) \right)}$$
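In code, again reusing `se_combine` and `SE_UNIT` from above (and with `monoid_softmax` as an illustrative name), the quotient of the vector-valued and scalar-valued aggregations looks like this:

```python
import numpy as np
from functools import reduce

def monoid_softmax(x):
    """Softmax as the quotient of two SE aggregations: one over one-hot vectors, one over scalars."""
    d = len(x)
    num, _ = reduce(se_combine, [(np.eye(d)[i], x[i]) for i in range(d)], SE_UNIT)
    den, _ = reduce(se_combine, [(1.0, x[i]) for i in range(d)], SE_UNIT)
    return num / den

x = np.array([10.0, 1000.0, -3.0])
print(monoid_softmax(x))   # [0. 1. 0.] -- matches safe softmax, one pass per aggregate
```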

Attention

By accumulating in the value vector space and weighting with the attention logits, this gives us a way to compute softmax attention as the quotient of two aggregations. In particular, we can see FlashAttention emerge by reduction factorisation: splitting the aggregation into blocks that fit in shared memory.

Let $Q \in \mathbb{R}^{L \times d_k}$, $K \in \mathbb{R}^{L \times d_k}$ and $V \in \mathbb{R}^{L \times d_v}$. Then

$$\text{attention}(Q, K, V)_{q, :} = \frac{\pi\left(\bigodot_i \left(V_{i, :}, \frac{Q_{q, :}^\top K_{i, :}}{\sqrt{d_k}}\right)\right)}{\pi\left(\bigodot_i \left(1, \frac{Q_{q, :}^\top K_{i, :}}{\sqrt{d_k}}\right)\right)}$$

We note that the weights for the numerator and denominator are the same. A sufficiently smart compiler would notice this fact and avoid computing them twice, effectively performing the aggregation in the pullback monoid $\text{SE}(\mathbb{R}^{d_v}) \times_{\mathbb{R}_{-\infty}} \text{SE}(\mathbb{R})$.
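To make this concrete, here is an illustrative single-pass accumulation for one query row. It is a sketch of the idea rather than the actual FlashAttention kernel, and the names (`attention_row`, `block`, and so on) are mine: the state `(acc, den, m)` plays the role of an element of the pullback monoid, so the block weights and the running maximum are computed once and shared between numerator and denominator.

```python
import numpy as np

def attention_row(q, K, V, block=4):
    """Single-pass, block-streamed attention for one query row q."""
    d_k = K.shape[1]
    acc = np.zeros(V.shape[1])             # numerator accumulator in R^{d_v}
    den = 0.0                              # denominator accumulator
    m = -np.inf                            # running maximum of the logits
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Kb @ q / np.sqrt(d_k)          # logits for this block
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)          # rescale the previous aggregates once
        w = np.exp(s - m_new)              # block weights, shared by numerator and denominator
        acc = scale * acc + w @ Vb
        den = scale * den + w.sum()
        m = m_new
    return acc / den

# Sanity check against the direct (materialise-all-logits) formula
rng = np.random.default_rng(0)
L, d_k, d_v = 10, 4, 3
q, K, V = rng.normal(size=d_k), rng.normal(size=(L, d_k)), rng.normal(size=(L, d_v))
s = K @ q / np.sqrt(d_k)
ref = np.exp(s - s.max()) @ V / np.exp(s - s.max()).sum()
assert np.allclose(attention_row(q, K, V), ref)
```

Streaming `K` and `V` block by block is what keeps the working set inside shared memory; the rescaling factor `scale` is exactly the $e^{m_1 - m}$ term of the $\odot$ composition applied to whole blocks at once.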

Teaching the Compiler

We can apply the same trick of accumulating in $\text{SE}(V)$ to any exponentially weighted sum in any vector space $V$, and those are rather common across machine learning. Tiling an accumulation, sharing subparts, cancelling out factors, and eliminating common subexpressions are all general techniques. A compiler equipped with these optimisations could have discovered FlashAttention automatically. Looking to the future, we can distill a library of generalisable tricks from the literature and teach them to a compiler, so that when new architectures come around we can save the months of research required to optimise them.

Footnotes

  1. Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022).