Flash Attention Version 1

Authors: Vishal Padia (@KyrieBlunders) and Sriram Govindan (@s_gowindone)
Introduction
For our LLM Inference Engine, the first optimization we tackled was Flash Attention. Originally introduced by Tri Dao in this paper, it delivered a 2-5x speedup over vanilla MHA and laid the groundwork for everything that followed. As of writing, there are now 4 versions of Flash Attention.
- Flash Attention 1 runs on any GPU
- Flash Attention 2 is optimized for Ampere
- Flash Attention 3 is optimized for Hopper
- Flash Attention 4 is optimized for Blackwell
In this article we will walk through a hybrid of FA1 and FA2 that uses CUDA cores for computation. Our goal is to eventually build our way up to a SOTA-competitive FA4 kernel.
- Flash Attention Theory
- CUDA Implementation
Flash Attention Theory
This is the attention formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

Q, K, and V are the Query, Key, and Value matrices respectively, each of shape [Sequence_length, head_dim], where $d$ denotes head_dim.
Now the softmax formula is:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

So we can expand the attention formula into elementwise form:

$$O_i = \sum_j \frac{e^{q_i \cdot k_j / \sqrt{d}}}{\sum_l e^{q_i \cdot k_l / \sqrt{d}}}\, v_j$$

where $q_i$, $k_j$, and $v_j$ are rows of Q, K, and V, and $O_i$ is row $i$ of the output.
The Numerical Stability Problem
Exponential functions can produce extremely large values relative to the values passed in.
As we can see from this graph, once x reaches roughly 88 the result of $e^x$ overflows the float32 range and can no longer be represented. Thus to safely use the exponential function we must introduce max scaling.
We perform the scaling by subtracting the largest number seen in the vector from every number in the vector. After the subtraction the largest number in the vector will be 0 and everything else will be negative, so the largest exponential we ever compute is $e^0 = 1$.
Here is the numerically stable attention formula:

$$O_i = \sum_j \frac{e^{q_i \cdot k_j / \sqrt{d} - m_i}}{\sum_l e^{q_i \cdot k_l / \sqrt{d} - m_i}}\, v_j, \qquad m_i = \max_j \frac{q_i \cdot k_j}{\sqrt{d}}$$

Mathematically this produces the exact same answer, as can be seen through the quotient rule of exponents:

$$\frac{e^{a - m}}{e^{b - m}} = \frac{e^a \, e^{-m}}{e^b \, e^{-m}} = \frac{e^a}{e^b}$$
A Tiny Example
Let's work through a trivial example to illustrate what's going on.
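The original example here was a figure; as a text stand-in, here is a tiny host-side C++ sketch (our own illustration, not project code) showing that max scaling leaves the softmax unchanged while avoiding overflow:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax: subtract the vector max before exponentiating,
// so the largest exponent we ever take is exp(0) = 1.
std::vector<float> stable_softmax(const std::vector<float>& x) {
    float m = *std::max_element(x.begin(), x.end());
    std::vector<float> out(x.size());
    float sum = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - m);  // never overflows
        sum += out[i];
    }
    for (float& v : out) v /= sum;
    return out;
}
```

Feeding in `{1000, 1001, 1002}` would overflow a naive float32 softmax, but the stable version returns the same (finite) probabilities as `{0, 1, 2}`, since only differences between entries matter.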
The Materialization Problem
On the programming side of things, vanilla attention is typically implemented using 3 kernels: a matmul kernel for $S = QK^\top / \sqrt{d}$, a softmax kernel for $P = \mathrm{softmax}(S)$, and a final matmul kernel for $O = PV$. While this approach seems logical, it imposes one large problem: it's absolutely dominated by memory-bound operations.
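To make the three-kernel structure concrete, here is a minimal host-side C++ sketch (our own illustration; the real kernels run on the GPU). Note that the score and probability matrices are fully materialized intermediates, which is exactly the memory traffic Flash Attention will later eliminate:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// "Kernel" 1: S = Q K^T / sqrt(d), materialized as an [N, M] matrix.
Mat scores(const Mat& Q, const Mat& K) {
    size_t N = Q.size(), M = K.size(), d = Q[0].size();
    Mat S(N, std::vector<float>(M, 0.0f));
    for (size_t i = 0; i < N; ++i)
        for (size_t j = 0; j < M; ++j) {
            for (size_t k = 0; k < d; ++k) S[i][j] += Q[i][k] * K[j][k];
            S[i][j] /= std::sqrt(static_cast<float>(d));
        }
    return S;
}

// "Kernel" 2: P = softmax(S), row-wise and max-scaled, again materialized.
Mat softmax_rows(const Mat& S) {
    Mat P = S;
    for (auto& row : P) {
        float m = *std::max_element(row.begin(), row.end());
        float sum = 0.0f;
        for (float& v : row) { v = std::exp(v - m); sum += v; }
        for (float& v : row) v /= sum;
    }
    return P;
}

// "Kernel" 3: O = P V.
Mat apply_values(const Mat& P, const Mat& V) {
    size_t N = P.size(), M = V.size(), d = V[0].size();
    Mat O(N, std::vector<float>(d, 0.0f));
    for (size_t i = 0; i < N; ++i)
        for (size_t j = 0; j < M; ++j)
            for (size_t k = 0; k < d; ++k) O[i][k] += P[i][j] * V[j][k];
    return O;
}
```

Each "kernel" must write its full result to memory and the next must read it back, which is where the traffic goes.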
| GPU (BF16) | Ridge point (FLOP/byte) | Attention arithmetic intensity, d=64 | d=128 | d=512 |
|---|---|---|---|---|
| A100 SXM (Ampere) | 156.0 | 32.1 | 62.7 | 228.1 |
| H100 SXM (Hopper) | 295.5 | 32.1 | 62.7 | 228.1 |
| H200 (Hopper) | 206.2 | 32.1 | 62.7 | 228.1 |
| B200 HGX (Blackwell) | 562.5 | 32.1 | 62.7 | 228.1 |
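The ridge point is simply peak compute divided by memory bandwidth; kernels whose arithmetic intensity falls below it are memory-bound. A quick sketch using the commonly quoted A100 SXM specs (312 BF16 TFLOP/s, 2.0 TB/s of HBM bandwidth; our own numbers, consistent with the table's first row):

```cpp
// Ridge point = peak FLOP/s divided by memory bandwidth (bytes/s).
// TFLOP/s over TB/s conveniently cancels to FLOP/byte.
double ridge_point(double peak_tflops, double bandwidth_tbs) {
    return peak_tflops / bandwidth_tbs;
}
```

Reading the d columns as attention's arithmetic intensity, d=64 and d=128 sit far below every ridge point in the table, i.e. the three-kernel version is firmly memory-bound on all of these GPUs.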
Flash Attention
Flash Attention reduces the number of memory ops through the use of the online softmax algorithm. Online softmax lets us compute a partial softmax using whatever max we have seen so far, and rescale later when we come across a new one. This is great for three reasons:
- We can do the entire MHA in one kernel.
- We don't have to read and write the materialized score matrices to gmem.
- Data that lives in smem can stay in smem for the entire duration of the kernel.
Now, to utilize the online softmax we need to make one more adjustment to our equation:

$$O_i = \frac{1}{\sum_l e^{s_{il} - m_i}} \sum_j e^{s_{ij} - m_i}\, v_j, \qquad s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d}}$$

The math is exactly the same as before, but it allows us to visualize that we can decouple the summation ($\sum_l e^{s_{il} - m_i}$) from the rest of the equation: we can accumulate the weighted values and the sum separately, and divide once at the end.
Rescaling Explained

Now you may be wondering: how do we adjust the old values? The idea is quite simple. Thanks to the quotient rule we know that this is true:

$$e^{a - m_{\text{new}}} = e^{a - m_{\text{old}}} \cdot e^{m_{\text{old}} - m_{\text{new}}}$$

so we can apply the same property to rescaling the running sum $l$ and the running output $O$.

Consider $l$ computed with the old max:

$$l_{\text{old}} = \sum_j e^{s_j - m_{\text{old}}}$$

We want to convert this to use $m_{\text{new}}$. Applying the quotient rule to each term:

$$l_{\text{new}} = \sum_j e^{s_j - m_{\text{new}}} = e^{m_{\text{old}} - m_{\text{new}}} \sum_j e^{s_j - m_{\text{old}}} = l_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}}$$

The exact same logic applies to $O$, since the value weights don't affect the exponential rescaling:

$$O_{\text{new}} = O_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}}$$
In the actual program this is maintained using running statistics: per-row values that track the running sum and running max, both of which get updated whenever a new max appears.
Another Tiny Example
Here is a visual guide to explain this better:
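In code form, the update can be sketched as follows (host-side C++ with our own names `m`, `l`, `o` mirroring the math above, not the actual kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

struct RunningStats {
    float m = -INFINITY;   // running row max (-inf so any real value wins)
    float l = 0.0f;        // running row sum of exponentials
    std::vector<float> o;  // running (unnormalized) output row
};

// Fold one chunk of scores s[j] with value rows v[j] into the statistics.
void online_softmax_step(RunningStats& st, const std::vector<float>& s,
                         const std::vector<std::vector<float>>& v) {
    float m_new = st.m;
    for (float x : s) m_new = std::max(m_new, x);
    // Quotient-rule rescale of the old sum and old output.
    float scale_old = std::exp(st.m - m_new);
    st.l *= scale_old;
    for (float& x : st.o) x *= scale_old;
    // Accumulate this chunk's exponentials and weighted values.
    for (size_t j = 0; j < s.size(); ++j) {
        float p = std::exp(s[j] - m_new);
        st.l += p;
        for (size_t k = 0; k < st.o.size(); ++k) st.o[k] += p * v[j][k];
    }
    st.m = m_new;
}
// The final attention row is st.o[k] / st.l.
```

Processing the scores in chunks this way gives (up to rounding) the same result as one pass over all of them at once, which is exactly why the kernel can stream over K/V tiles.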
CUDA Implementation
Now that we understand the math, let's walk through the actual CUDA kernel. We use CuTe (a layout algebra library from CUTLASS) to describe our data layouts and tiling. Our implementation runs entirely on CUDA cores; no tensor cores are involved.
The FMHA Struct
The FMHA struct is the top-level interface. It's templated on the key parameters that define the shape and behavior of the kernel:
template <int head_count, int head_dim, int B_r, int B_c, typename DType,
int thread_count = 128, bool causal_mask = false,
bool qkv_contigous_buffer = false>
struct FMHA {
- head_count and head_dim describe the multi-head attention geometry
- B_r is how many query rows we process per thread block
- B_c is how many key/value rows we process per iteration
- thread_count defaults to 128 threads per block
- causal_mask enables the autoregressive mask (tokens can only attend to previous tokens)
- qkv_contigous_buffer toggles between interleaved and separate Q, K, V memory layouts
Launching the Kernel
The operator() at the bottom of the struct is how we actually launch the kernel from the host:
void operator()(DType *Q, DType *K, DType *V, DType *O, uint32_t batch_size,
uint32_t N_q, uint32_t N_kv, uint32_t start_pos) {
dim3 grid_dim{head_count, batch_size, ceil_div(N_q, B_r)};
dim3 block_dim{thread_count};
...
}
The grid is 3-dimensional: one block per head (x), one per batch element (y), and one per chunk of query rows (z). So each thread block is responsible for computing B_r rows of output for a single head in a single batch element.
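ceil_div is the usual round-up integer division; a standalone version (our own, the project's helper may differ) shows how the z-dimension is sized:

```cpp
#include <cstdint>

// Round-up division: how many B_r-row tiles are needed to cover N_q rows.
// The final tile may be partially filled, which is why the kernel handles
// it with predication.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) {
    return (a + b - 1) / b;
}
```

For example, 59 query rows with B_r = 32 need 2 blocks, the second covering only 27 valid rows.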
We also tell the GPU to carve out as much shared memory as possible by setting the carveout to 100%:
cudaFuncSetAttribute(kernel_fptr,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
This is important because our kernel is shared-memory heavy: we're storing Q, K/V, and the scores matrix P all in SMEM.
Shared Memory Layout
Now this is where things get interesting. The SharedStorage struct defines what lives in shared memory:
struct SharedStorage {
ArrayEngine<DType, B_r * head_dim> Q;
ArrayEngine<DType, B_c * head_dim> KV;
ArrayEngine<DType, B_r * B_c> P;
...
};
We have three buffers: Q holds the query tile, KV holds either the current key or value tile (they share the same buffer since we never need both at the same time), and P holds the scores tile (the result of $QK^\top / \sqrt{d}$ after softmax).
Notice that K and V share the KV buffer. This is why you'll see __syncthreads() calls between the K matmul and the V matmul in the kernel: we need to make sure every thread is done reading K before we overwrite it with V.
Swizzled Layouts
The shared memory layouts use swizzling to avoid bank conflicts:
using swizzle_atom = decltype(composition(
Swizzle<3, 2, 3>{},
Layout<Shape<_8, Shape<_4, _8>>, Stride<_32, Stride<_1, _4>>>{}));
Without swizzling, threads in the same warp that access the same column of a matrix would hit the same shared memory bank, causing serialization. The Swizzle<3, 2, 3> remaps addresses so that consecutive rows land in different banks. The layout is then tiled out to the full matrix size using tile_to_shape.
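To see why this helps, here is a deliberately simplified host-side model of shared memory banking (a plain XOR swizzle on a 32-float-wide tile; illustrative only, not CuTe's exact Swizzle<3, 2, 3> mapping):

```cpp
#include <set>

// Simplified model: 32 banks of 4-byte words, row-major 32-float-wide tile.
// The bank of an element is its word index mod 32.
int bank_of(int row, int col) { return (row * 32 + col) % 32; }

// XOR swizzle: fold the row bits into the column bits before addressing,
// so the same logical column lands in a different bank on every row.
int bank_of_swizzled(int row, int col) { return (row * 32 + (col ^ row)) % 32; }
```

Without the swizzle, 32 threads reading the same column hit one bank (a 32-way conflict, fully serialized); with it, they hit 32 distinct banks and the access completes in one transaction.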
For V specifically, we need a transposed view because V gets multiplied on the right side ($O = PV$). So we define both a VLayoutType (for writing V into shared memory in its natural row-major order) and a VTransposedLayoutType (for reading V during the matmul as if it were transposed):
using VTransposedLayoutType = decltype(tile_to_shape(
swizzle_atom{}, make_shape(HeadDimType{}, KVColsType{})));
using VLayoutType = decltype(tile_to_shape(
swizzle_atom_T{}, make_shape(KVColsType{}, HeadDimType{}),
LayoutRight{}));
This way we write V normally but read it transposed, all without actually moving data around.
The Tiled Copy
To move data from global memory to shared memory, we create a TiledCopy:
template <typename LoadType> static constexpr auto get_tiled_copy() {
constexpr int elements_per_load{sizeof(LoadType) / sizeof(DType)};
constexpr int threads_per_row{head_dim / elements_per_load};
constexpr int rows{thread_count / threads_per_row};
...
return make_tiled_copy(
Copy_Atom<UniversalCopy<LoadType>, DType>{},
Layout<Shape<RowType, TPRType>, Stride<TPRType, _1>>{},
Layout<Shape<_1, EPLType>>{});
}
The idea is straightforward: we distribute threads across the matrix so that each thread loads elements_per_load consecutive elements from one row. The thread layout is (rows, threads_per_row) where each thread handles a contiguous chunk along the head dimension. This gives us coalesced global memory accesses since adjacent threads read adjacent memory addresses.
We create two different tiled copies: one for Q and K using uint128_t (128-bit vectorized loads, 4 floats at a time), and one for V using uint32_t (scalar loads). V uses scalar loads because its writes will be transposed (data won't be contiguous across rows).
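For concreteness, the mapping arithmetic for the Q/K copy works out as follows (plain C++ sketch with representative values, assuming float data and 16-byte loads; names are ours):

```cpp
// Thread-to-element mapping for the vectorized gmem -> smem copy.
// Each thread loads `elements_per_load` consecutive elements from one row.
struct CopyShape {
    int elements_per_load;  // elements moved per vector load
    int threads_per_row;    // threads covering one row of head_dim
    int rows;               // rows covered per copy step
};

constexpr CopyShape copy_shape(int head_dim, int thread_count,
                               int load_bytes, int dtype_bytes) {
    int epl = load_bytes / dtype_bytes;  // e.g. 16 B / 4 B = 4 floats
    int tpr = head_dim / epl;            // threads needed per row
    return {epl, tpr, thread_count / tpr};
}
```

With head_dim = 64, 128 threads, and 128-bit loads of float: 4 elements per load, 16 threads per row, 8 rows per step. Adjacent thread IDs cover adjacent 16-byte segments of a row, which is what makes the global loads coalesced.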
The Tiled MMA
For the actual computation, we set up a TiledMMA:
static constexpr auto get_tiled_mma() {
using RowType = Int<thread_count / 32>;
auto t_mma{
make_tiled_mma(UniversalFMA<DType, DType, DType>{},
Layout<Shape<RowType, _32>, Stride<_32, _1>>{})};
return t_mma;
}
This tells CuTe how threads map to the output matrix. We use UniversalFMA (fused multiply-add on CUDA cores). The thread layout assigns one warp (32 threads) per row of the output.
Inside the Kernel
Now let's look at the kernel itself. The first thing we do is figure out which slice of the input this thread block is responsible for:
Tensor q_head{MHAType::slice_head(Q, batch_size, N_q)};
const Tensor k_head{MHAType::slice_head(K, batch_size, N_kv)};
const Tensor v_head{MHAType::slice_head(V, batch_size, N_kv)};
Tensor o_head{MHAType::slice_head<true>(O, batch_size, N_q)};
slice_head uses blockIdx.x (head index) and blockIdx.y (batch index) to slice into the 4D tensor [batch, N, num_heads, head_dim], giving us a 2D view [N, head_dim] for this specific head and batch element.
Next we create iterators that tile across the sequence dimension:
Tensor q_iterator{local_tile(q_head, q_tiler, q_coord)}; // (B_r, d)
Tensor k_iterator{
local_tile(k_head, kv_tiler, kv_coord)}; // (B_c, d, ceil(N_kv / B_c))
Tensor v_iterator{
local_tile(v_head, kv_tiler, kv_coord)}; // (B_c, d, ceil(N_kv / B_c))
Q is tiled once (this block handles exactly B_r query rows, selected by blockIdx.z). K and V are tiled with an extra dimension for iteration, we'll loop over ceil(N_kv / B_c) chunks.
Initializing the Running Statistics
Before the main loop, we set up the accumulators for the online softmax:
auto m{make_tensor<DType>(mma_m)};
auto l{make_tensor<DType>(mma_m)};
fill(m, -INFINITY);
clear(l);
m tracks the running row-wise maximum (initialized to $-\infty$ so any real value will be larger). l tracks the running row-wise sum (initialized to 0). These are per-thread tensors sized by mma_m, the number of output rows this thread owns.
The Main Loop
The main loop iterates over K/V chunks. The last chunk is handled first with predication (bounds checking), then the rest run without:
// Do the block that needs predication first
MHAType::matmul<true>(tK_global_part_iter(_, _, _, iters - 1), ...);
MHAType::update_statistics<true>(m, l, r_scores_mma, ...);
__syncthreads();
MHAType::matmul<true>(tV_global_part_iter(_, _, _, iters - 1), ...);
// do the rest of blocks that don't need predication
for (int iter{static_cast<int>(iters) - 2}; iter > -1; --iter) {
__syncthreads();
MHAType::matmul(tK_global_part_iter(_, _, _, iter), ...);
MHAType::update_statistics(m, l, r_scores_mma, ...);
__syncthreads();
MHAType::matmul(tV_global_part_iter(_, _, _, iter), ...);
}
For each iteration we: compute $QK^\top / \sqrt{d}$, update the online softmax statistics (m, l) and rescale, then accumulate $PV$ into the output. The __syncthreads() between the K and V matmuls ensures the shared KV buffer is safe to overwrite.
The last chunk is processed first because when N_kv isn't perfectly divisible by B_c, that final chunk may be partially filled. By handling it first with predication we avoid branching inside the hot inner loop.
The Matmul Function
The matmul function is where the actual matrix multiplication happens. We write it manually instead of utilizing the CuTe MMA atom in order to take advantage of ILP and vectorized loads. Here's how it works:
float4 a_vecs[mma_m_size];
float4 b_vecs[mma_n_len];
#pragma unroll 8
for (size_t k{0}; k < mma_k_len; k += elements_per_load) {
// 1. Load A vectors (from shared Q or P)
// 2. Load B vectors (from shared K or V)
// 3. FMAs - all .x first, then .y, then .z, then .w
}
We load 4 floats at a time using float4 (128-bit reads from shared memory). The FMA (fused multiply-add) loop is structured to process all .x components first, then .y, .z, .w. This ordering matters because it spaces out consecutive FMA instructions that write to the same accumulator register, so independent FMAs can issue back-to-back while the dependent ones wait.
The #pragma unroll 8 is a deliberate tradeoff: too little unrolling hurts ILP, but fully unrolling can cause register spills (the compiler runs out of registers and starts spilling to local memory, which is slow).
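The component ordering can be mimicked on the host; here is a scalar stand-in for the float4 version (our own sketch of a 4x4 outer-product accumulation step, not the kernel's code):

```cpp
// Scalar stand-in for CUDA's float4.
struct Vec4 { float x, y, z, w; };

// One k-step of C[4][4] += a_i . b_j, one component at a time:
// all .x FMAs first, then .y, then .z, then .w. Between two updates of
// the same accumulator C[i][j] there are 15 independent FMAs, which is
// what gives the scheduler latency to hide.
void fma_step(float C[4][4], const Vec4 a[4], const Vec4 b[4]) {
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) C[i][j] += a[i].x * b[j].x;
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) C[i][j] += a[i].y * b[j].y;
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) C[i][j] += a[i].z * b[j].z;
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) C[i][j] += a[i].w * b[j].w;
}
```

Regardless of the issue order, each C[i][j] ends up as the 4-component dot product of a[i] and b[j].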
Updating the Online Softmax Statistics
The update_statistics function is the heart of the online softmax. After each chunk is computed, we need to:
- Scale the scores by $1/\sqrt{d}$
- Find the new row-wise maximum
- Rescale the old statistics
- Compute the exponentials $e^{s - m}$ and accumulate the sum
- Write the probabilities to shared memory for the $PV$ matmul
- Rescale the running output
// Scale and find local max
for (size_t idx{0}; idx < slice_size; ++idx) {
r_score_slice(idx) = r_score_slice(idx) * scale;
current_max = cuda::std::max(r_score_slice(idx), current_max);
}
// Reduce max across the warp
current_max = warp_max(current_max);
Since each thread only holds a few elements of each row, we first find the local maximum per thread, then use warp_max to reduce across all 32 threads in the warp. This gives every thread the true row-wise maximum.
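warp_max is typically built from shuffle instructions; here is a host-side simulation of the XOR-butterfly pattern over 32 "lanes" (our own model of what a __shfl_xor_sync-style reduction does, not the kernel's code):

```cpp
#include <algorithm>
#include <array>

// Butterfly (XOR) reduction: in each of log2(32) = 5 rounds, lane i
// combines with lane i ^ offset. After the last round every lane holds
// the maximum, so no broadcast step is needed.
std::array<float, 32> warp_max_sim(std::array<float, 32> lanes) {
    for (int offset = 16; offset > 0; offset /= 2) {
        std::array<float, 32> next = lanes;
        for (int i = 0; i < 32; ++i)
            next[i] = std::max(lanes[i], lanes[i ^ offset]);
        lanes = next;
    }
    return lanes;
}
```

On the GPU each round is a single shuffle plus a max, so the whole row-wise reduction costs only 5 instructions per thread.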
DType scale_old = expf(old_max - current_max);
current_sum = current_sum * scale_old;
Next we rescale using the formula from earlier.
for (size_t idx{0}; idx < slice_size; ++idx) {
auto p_score = expf(r_score_slice(idx) - current_max);
local_sum += p_score;
p_slice(idx) = p_score;
r_score_slice(idx) = 0; // reset for next iteration
}
current_sum += warp_sum(local_sum);
Then we compute the exponentials and write them to the shared P buffer.
We also reset the scores register to 0, since the matmul function accumulates into it and we need a clean slate for the next K chunk.
for (size_t i{0}; i < size(o_slice); i++) {
o_slice(i) *= scale_old;
}
Finally we rescale the running output. This is the rescaling we discussed in the math section.
Causal Masking
When causal_mask is enabled, we prevent tokens from attending to future positions:
if constexpr (causal_mask) {
adjusted_bound = get<0>(scores_idty_slice(0)) + start_pos + 1;
}
The identity tensor gives us the absolute row index of this query token. Adding start_pos handles the KV-cache case: during decoding, the query might be at position 500 in the full sequence even though N_q is just 1. Any key with an index >= the query's position + 1 gets its score set to $-\infty$, which zeroes it out after the exponential.
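The mask therefore reduces to a simple bound check; a standalone sketch (our own helper mirroring the snippet above):

```cpp
#include <cstdint>

// A key at absolute position k_pos is visible to a query at relative
// index q_idx (absolute position q_idx + start_pos) iff it is not in
// the future. Masked scores are set to -INFINITY before the exponential.
bool is_masked(uint32_t q_idx, uint32_t start_pos, uint32_t k_pos) {
    uint32_t adjusted_bound = q_idx + start_pos + 1;
    return k_pos >= adjusted_bound;
}
```

During single-token decoding (q_idx = 0, start_pos = 500) the query attends to positions 0 through 500 and nothing beyond.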
The Final Normalization
After all iterations complete, we divide by the accumulated sum l to get valid probabilities:
for (size_t m_row{0}; m_row < m_rows; ++m_row) {
auto out_slice{r_out_mma(_, m_row, _)};
for (size_t idx{0}; idx < size(out_slice); ++idx) {
out_slice(idx) = out_slice(idx) / l(m_row);
}
}
During the loop, we accumulated $O = \sum_j e^{s_j - m}\, v_j$ and $l = \sum_j e^{s_j - m}$. Dividing by $l$ gives us the correctly weighted attention output.
Finally we write the result back to global memory with bounds checking:
for (size_t i{0}; i < write_rows; ++i) {
auto seq_idx{get<1>(o_mma_idty(0, i, 0))};
if (seq_idx < N_q)
copy(r_out_mma(_, i, _), g_out_mma(_, i, _));
}
The predicate seq_idx < N_q ensures we don't write past the end of the sequence when N_q isn't perfectly divisible by B_r.
Benchmarks
Now let's see if all of this actually pays off. We benchmarked CobraML's FA1 kernel against vanilla PyTorch multi-head attention on an A10G GPU. All configs use 16 heads with a head dimension of 64.
Python Benchmarks: CobraML FA1 vs Vanilla PyTorch MHA
| Config (B, N, causal) | CobraML | Vanilla | Speedup |
|---|---|---|---|
| 4, 512, non-causal | 0.527 ms | 1.198 ms | 2.27x |
| 4, 512, causal | 0.660 ms | 1.771 ms | 2.68x |
| 8, 59, non-causal | 0.044 ms | 0.170 ms | 3.90x |
| 8, 59, causal | 0.047 ms | 0.266 ms | 5.65x |
| 1, 2048, non-causal | 1.670 ms | 4.009 ms | 2.40x |
| 1, 2048, causal | 1.663 ms | 3.868 ms | 2.33x |
The short sequence causal case (B=8, N=59) is where we see the biggest wins: 5.65x over vanilla PyTorch. This makes sense: shorter sequences mean fewer iterations in the main loop, and causal masking lets us skip a large chunk of the score matrix entirely. The longer sequences (N=2048) settle around 2.3-2.4x, which is still a solid improvement.
C++ Benchmarks: Raw Kernel Performance
These numbers measure the kernel in isolation, stripped of any Python overhead:
| Config (H, d, B, N, causal) | Kernel Time | GFLOP/s |
|---|---|---|
| 16, 64, 4, 512 | 0.393 ms | 10,925 |
| 16, 64, 4, 512, causal | 0.034 ms | 3,973 |
| 16, 64, 8, 59 | 0.032 ms | 3,537 |
| 16, 64, 8, 59, causal | 0.032 ms | 3,514 |
| 16, 128, 4, 512 | 0.918 ms | 9,354 |
| 16, 128, 4, 512, causal | 0.857 ms | 10,018 |
Peak throughput lands around 11 TFLOP/s on the A10G. The larger head dimension (d=128) configs push close to 10 TFLOP/s as well.
Next Steps
This is only our first implementation of Flash Attention. Expect an FA2 implementation very soon, along with articles covering more components of our inference engine, including matmul kernels, kv caching, ...
Please reach out if you have any questions!