Dao 2022 - Flash Attention

Paper Link

This paper argues that the attention mechanism is slow because of the reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. The authors hence create a block-wise attention algorithm that minimizes such IO reads/writes and speeds up attention significantly, especially when the sequence length is long.

Brief Overview of Attention

Suppose we have an input sequence of $N$ token embeddings $x_1, \dots, x_N$ where $x_i \in \mathbb{R}^d$, stacked into a matrix $X \in \mathbb{R}^{N \times d}$. Naively, we can compute activations by $Y = XW$, where $W \in \mathbb{R}^{d \times d}$, such that $Y \in \mathbb{R}^{N \times d}$. However, this naive way of encoding our input sequence does not allow interaction between inputs at different positions (say $x_1$ with $x_5$). We can see this by observing that the first row of $Y$ is only affected by the first row of $X$ (i.e. the first encoding $x_1$), and likewise for all the other positions.
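
To make the lack of interaction concrete, here is a small numpy check (the shapes and the perturbed position are arbitrary choices for illustration): perturbing one input position changes only the corresponding row of $Y$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 8, 4                      # sequence length, embedding dim (illustrative)
X = rng.normal(size=(N, d))      # input embeddings, one row per position
W = rng.normal(size=(d, d))      # position-independent weight matrix

Y = X @ W                        # naive encoding: row i of Y depends only on row i of X

X2 = X.copy()
X2[5] += 1.0                     # perturb the embedding at position 5 only
Y2 = X2 @ W

# every row except row 5 is unchanged -> no interaction across positions
assert np.allclose(np.delete(Y, 5, axis=0), np.delete(Y2, 5, axis=0))
```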

Attention addresses this problem by adding an interaction mechanism. Besides $W_V \in \mathbb{R}^{d \times d}$ (playing the role of $W$ above), we also create weight parameters $W_Q, W_K \in \mathbb{R}^{d \times d}$. Given an input $X$, we compute the queries, keys and values as follows:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V, \qquad Q, K, V \in \mathbb{R}^{N \times d}$$

We then create an interaction matrix $S = QK^\top \in \mathbb{R}^{N \times N}$, and apply row-wise softmax to get $P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}$. $S_{ij}$ can be thought of as a pairwise similarity between the encoding at position $i$ and the encoding at position $j$ that captures their degree of interaction. For example, in the sentence "the economy has been in decline", the value of $S_{1,5}$ (assuming 0-indexing) measuring the interaction between "economy" and "decline" might be high.

Finally, we produce the output $O = PV \in \mathbb{R}^{N \times d}$, which is an activation output of the input sequence that captures the interactions between tokens at different positions. This simple mechanism has led to significant improvements in language modelling.
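
Putting the pieces together, here is a minimal numpy sketch of the mechanism described above (a toy sketch: single head, no masking, and I omit the usual $1/\sqrt{d}$ scaling; all names and sizes are illustrative).

```python
import numpy as np

def softmax_rows(S):
    """Row-wise softmax with the usual max subtraction for stability."""
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V      # (N, d) each
    S = Q @ K.T                              # (N, N) interaction matrix
    P = softmax_rows(S)                      # (N, N) row-wise softmax
    return P @ V                             # (N, d) output O

# toy usage
rng = np.random.default_rng(0)
N, d = 8, 4
X = rng.normal(size=(N, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
O = attention(X, W_Q, W_K, W_V)
print(O.shape)  # (8, 4)
```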

GPU Memory Hierarchy

The memory hierarchy is such that read/write speed is very fast on SRAM but its capacity is highly limited. Hence, the N x N attention matrix is written/read repeatedly to/from HBM, resulting in IO being a bottleneck. The numbers on an A100 GPU are as follows:

  • SRAM: 19 TB/s (20 MB RAM)
  • HBM: 1.5 TB/s (40 GB RAM)

Naive Attention Algorithm

The naive attention algorithm has many reads and writes to HBM, as the steps below show. (ps: I am not sure why we cannot simply persist the intermediate matrices in SRAM and complete the computation there, but in any case the naive algorithm requires materializing the full $N \times N$ matrices, which would quickly flood SRAM. For example, a sequence length of 2,048 at float32 already takes up about 33 MB for the $S$ and $P$ matrices (roughly 16.8 MB each), against the 20 MB of available SRAM.)

  1. Load $Q, K$ from HBM, compute $S = QK^\top$, write $S$ to HBM
  2. Read $S$ from HBM, compute $P = \mathrm{softmax}(S)$, write $P$ to HBM
  3. Load $P, V$ from HBM, compute $O = PV$, write $O$ to HBM
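
To see why the $N \times N$ intermediates dominate, here is a rough tally of the HBM bytes moved by these three steps (my own back-of-envelope, not from the paper; $N = 2048$ and $d = 64$ are illustrative, float32 throughout).

```python
N, d, itemsize = 2048, 64, 4   # sequence length, head dim, float32 bytes (illustrative)

qkv = N * d * itemsize         # one of Q, K, V, O
nxn = N * N * itemsize         # one of S, P

step1 = 2 * qkv + nxn          # read Q, K  -> write S
step2 = nxn + nxn              # read S     -> write P
step3 = nxn + 2 * qkv          # read P, V  -> write O

total = step1 + step2 + step3
print(f"{total / 1e6:.0f} MB moved, of which {4 * nxn / total:.0%} is N x N traffic")
```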

Flash Attention

The main idea is quite simple: instead of computing the full attention matrix at once, we use block-wise tiling to compute parts of it at a time. Each block is small enough to fit in SRAM, so the computation for a block can be done entirely on-chip while minimizing the IO to and from HBM, leading to faster compute and lower memory usage (the full $N \times N$ matrices are never materialized). The difficulty is in devising a block-wise softmax algorithm that yields the exact same result as computing it all at once.

Consider the naive softmax algorithm on an arbitrary vector $x \in \mathbb{R}^B$:

$$
\begin{aligned}
m(x) &:= \max_i \; x_i \\
f(x) &:= \left[ e^{x_1 - m(x)}, \; \dots, \; e^{x_B - m(x)} \right] \\
\ell(x) &:= \sum_i f(x)_i \\
\mathrm{softmax}(x) &:= \frac{f(x)}{\ell(x)}
\end{aligned}
$$

Note that the maximum value $m(x)$ is subtracted for numerical stability to avoid overflow (underflow is ok because $e^{x_i - m(x)}$ simply rounds to $0$). $f(x)$ is the numerator and $\ell(x)$ is the sum of all elements in $f(x)$.
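
A direct numpy transcription of the above, with variable names following $m$, $f$ and $\ell$ (purely illustrative):

```python
import numpy as np

def naive_softmax(x):
    m = x.max()                 # m(x): maximum, subtracted for stability
    f = np.exp(x - m)           # f(x): stabilized numerator
    l = f.sum()                 # l(x): denominator
    return f / l

x = np.array([1e3, 1.0, -1e3])  # naive np.exp(x) / np.exp(x).sum() would overflow here
print(naive_softmax(x))         # approximately [1., 0., 0.], no overflow
```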

Now, the problem with the naive softmax algorithm in the context of attention is that we need an entire row of $S$ ($N$ elements) to perform the row-wise softmax computation. This will not be available if we are performing block-wise computation, since we are splitting $K$ row-wise into blocks $K_1, \dots, K_T$ of $B$ rows each. When we compute $QK_j^\top$, only an $N \times B$ column block of $S$ is materialized in each pass, i.e. each row of $S$ only has $B$ of its $N$ entries available at a time.

Hence, we need a modified algorithm that allows us to compute chunks of the final output $O$ at a time by iterating block-wise through $K$ and $V$, such that combining the new chunk at each step with the already written intermediate $O$ gives the correct result at the end. The key to realizing this algorithm is in decomposing the softmax step, as shown below.

Consider two vectors $x^{(1)}, x^{(2)} \in \mathbb{R}^B$. We can decompose the softmax of their concatenated vector $x = \left[ x^{(1)} \;\, x^{(2)} \right] \in \mathbb{R}^{2B}$ as follows:

$$
\begin{aligned}
m(x) &= \max\!\left( m(x^{(1)}), \; m(x^{(2)}) \right) \\
f(x) &= \left[ e^{m(x^{(1)}) - m(x)} f(x^{(1)}), \;\; e^{m(x^{(2)}) - m(x)} f(x^{(2)}) \right] \\
\ell(x) &= e^{m(x^{(1)}) - m(x)} \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \ell(x^{(2)}) \\
\mathrm{softmax}(x) &= \frac{f(x)}{\ell(x)}
\end{aligned}
$$

The first line of the above simply notes that the maximum of $x$ is the maximum over the subvector maximums $m(x^{(1)})$ and $m(x^{(2)})$. The second line notes that we previously multiplied each element of the numerators by a factor, e.g. $e^{-m(x^{(1)})}$ for those in $x^{(1)}$. To get the correct multiplier for the full vector $x$, we need to divide away the previous multiplier and apply the new one, i.e. $e^{x_i - m(x^{(1)})} \cdot e^{m(x^{(1)}) - m(x)} = e^{x_i - m(x)}$. The third line notes that the new denominator is the sum over the subvector sums, after we apply the correct multiplier from line 2.

The decomposition is simple but powerful. It implies that so long as we keep track of the intermediate statistics $m(x)$ and $\ell(x)$, we can compute the softmax of a long vector by splitting it into subvectors and operating over one subvector at a time.
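
A quick numerical check of the decomposition: compute the statistics of two chunks separately, merge them as above, and compare against the softmax over the full vector (a sketch; variable names follow the equations).

```python
import numpy as np

rng = np.random.default_rng(0)
B = 4
x1, x2 = rng.normal(size=B), rng.normal(size=B)
x = np.concatenate([x1, x2])

def stats(v):
    """Per-chunk statistics: max m(v), numerator f(v), sum l(v)."""
    m = v.max()
    f = np.exp(v - m)
    return m, f, f.sum()

m1, f1, l1 = stats(x1)
m2, f2, l2 = stats(x2)

# combine the chunk statistics exactly as in the decomposition
m_new = max(m1, m2)
f_new = np.concatenate([np.exp(m1 - m_new) * f1, np.exp(m2 - m_new) * f2])
l_new = np.exp(m1 - m_new) * l1 + np.exp(m2 - m_new) * l2

# matches the softmax computed over the full vector in one shot
assert np.allclose(f_new / l_new, np.exp(x - x.max()) / np.exp(x - x.max()).sum())
```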

Now we are ready for Algorithm 1 (FlashAttention) of the paper.

(Algorithm 1: FlashAttention, reproduced from the paper, with the corresponding equation numbers of this post shown in parentheses on the right.)

Note that we use $\mathbf{0}_{N \times d}$ and $\mathbf{0}_N$ to denote a zero matrix of size $N \times d$ and a zero array of size $N$ respectively (the initial values of $O$ and $\ell$). For simplicity, we divide $Q$, $K$ and $V$ into equal $B$-sized blocks, but the paper allows different block sizes for $Q$ and for $K, V$. The equation numbers on the right in parentheses show which equations above the lines correspond to. Line 15 is a bit confusing because it combines multiple steps together; the next few paragraphs try to unpack it.

Firstly, note that we are using the $\circ$ operator to denote an element-wise broadcasted multiplication. For a vector $a \in \mathbb{R}^N$ and matrices $B \in \mathbb{R}^{N \times N}$, $C \in \mathbb{R}^{N \times d}$, observe the associative property $a \circ (BC) = (a \circ B)\,C$, since each element of $a$ only scales the corresponding row of the final matrix. This allows us to apply the scaling to either $BC$ or $B$ and the result will be the same.
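
A one-line numpy check of this property (shapes are illustrative): scaling the rows of $BC$ gives the same result as scaling the rows of $B$ first.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 3
a = rng.normal(size=(N, 1))          # per-row scaling factors, stored as a column so they broadcast
B = rng.normal(size=(N, N))
C = rng.normal(size=(N, d))

assert np.allclose(a * (B @ C), (a * B) @ C)   # a ∘ (BC) == (a ∘ B) C
```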

Next, see that the term $e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{P}_{ij} V_j$ is simply the corrected numerator of the softmax multiplied with $V_j$. Dividing this term by $\ell_i^{\text{new}}$ gives the output block for this particular $(i, j)$ pair.

Similarly, the other term $\ell_i \, e^{m_i - m_i^{\text{new}}} \circ O_i$ is the existing output $O_i$ that has been accumulated from the previous steps. Due to the associative property, we can also directly apply the scaling correction to $O_i$. The $\ell_i$ and $e^{m_i - m_i^{\text{new}}}$ are scaling factors according to equations (6), (8) to correct the scaling of previous steps.

Finally, we should understand why there is a $+$ in line 15. I find it easier to visualize if we set the block size $B = 1$. If we trace the matrix multiplications, we will observe that row $i$ of $O$ is only affected by row $i$ of $Q$, i.e. it corresponds to only the query token in position $i$. Now, $O_i$ represents the weighted average over all positions of the $V$ matrix, where the weights are determined by the softmax over the interaction between $Q_i$ (representing one token) and all positions of the $K$ matrix. This weighted average is why it is a $+$: we are accumulating the weighted sum over the $V_j$ blocks into $O_i$. The only complication is that we are applying the scaling corrections at each step.

Hopefully these explanations provide some intuition for the FlashAttention algorithm, which is quite a simple idea but makes a ton of difference in practice. It should be easy to implement this algorithm in numpy if the reader wishes to understand it better.
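
As a starting point, here is a minimal numpy sketch of the forward pass along these lines (single head, no masking or scaling, one block size $B$ that divides $N$; a toy illustration of the idea, not the paper's fused CUDA kernel).

```python
import numpy as np

def flash_attention_forward(Q, K, V, B=64):
    """Block-wise exact attention O = softmax(Q K^T) V without materializing the N x N matrices."""
    N, d = Q.shape
    O = np.zeros((N, d))                 # running (already normalized) output
    l = np.zeros(N)                      # running softmax denominators, one per row
    m = np.full(N, -np.inf)              # running row maxima

    for j in range(0, N, B):             # outer loop over K, V blocks
        Kj, Vj = K[j:j + B], V[j:j + B]
        for i in range(0, N, B):         # inner loop over Q (and O, l, m) blocks
            Sij = Q[i:i + B] @ Kj.T                       # (B, B) block of S
            m_tilde = Sij.max(axis=1)                     # block row maxima
            P_tilde = np.exp(Sij - m_tilde[:, None])      # block numerator
            l_tilde = P_tilde.sum(axis=1)                 # block row sums

            m_new = np.maximum(m[i:i + B], m_tilde)       # merge statistics: new maxima
            l_new = (np.exp(m[i:i + B] - m_new) * l[i:i + B]
                     + np.exp(m_tilde - m_new) * l_tilde)  # merge statistics: new denominators

            # line-15-style update: rescale the old accumulator, add the new block, renormalize
            O[i:i + B] = (np.exp(m[i:i + B] - m_new)[:, None] * l[i:i + B, None] * O[i:i + B]
                          + np.exp(m_tilde - m_new)[:, None] * (P_tilde @ Vj)) / l_new[:, None]
            m[i:i + B], l[i:i + B] = m_new, l_new
    return O

# compare against the naive reference
rng = np.random.default_rng(0)
N, d = 256, 16
Q, K, V = rng.normal(size=(N, d)), rng.normal(size=(N, d)), rng.normal(size=(N, d))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P = P / P.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), P @ V)
```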