Flash Attention Version 1

Introduction

For our LLM inference engine, the first optimization we tackled was Flash Attention. Originally introduced by Tri Dao in this paper, it delivered a 2-5x speedup over vanilla MHA and laid the groundwork for everything that followed. As of this writing, there are four versions of Flash Attention:

  1. Flash Attention 1: works on any GPU
  2. Flash Attention 2: optimized for Ampere
  3. Flash Attention 3: optimized for Hopper
  4. Flash Attention 4: optimized for Blackwell

In this article we will walk through a hybrid of FA1 and FA2 that uses CUDA cores for computation. Our goal is to eventually build our way up to a SOTA-competitive FA4 kernel.

Flash Attention Theory

This is the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Q, K, and V are the Query, Key, and Value matrices respectively, each of shape [sequence_length, head_dim].

Now the softmax formula is:

$$\text{softmax}(z_i) = \frac{\exp(z_i)}{\sum_l \exp(z_l)}$$

So we can expand the Attention formula into its elementwise form:

$$\text{Attention}(Q, K, V)_i = \sum_j \frac{\exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}}\right)}{\sum_l \exp\!\left(\frac{q_i \cdot k_l}{\sqrt{d_k}}\right)} v_j$$

The Numerical Stability Problem

Exponential functions can produce extremely large values relative to the values passed in.

In float32, exp(x) overflows to infinity once x exceeds roughly 88.7 (the largest representable float32 is about 3.4 × 10^38), so past that point we can no longer represent the result. Thus, to safely use the exponential function we must introduce max scaling.

We perform the scaling by subtracting the largest value in the vector from every element. After the subtraction the largest element in the vector is 0, and everything else is negative.

Here is the numerically stable attention formula:

$$\text{Attention}(Q, K, V)_i = \sum_j \frac{\exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}} - m_i\right)}{\sum_l \exp\!\left(\frac{q_i \cdot k_l}{\sqrt{d_k}} - m_i\right)} v_j, \quad m_i = \max_j \frac{q_i \cdot k_j}{\sqrt{d_k}}$$

Mathematically this produces the exact same answer, as can be seen through the quotient rule of exponents:

$$\frac{\exp(z_j - m)}{\sum_l \exp(z_l - m)} = \frac{\exp(z_j)/\exp(m)}{\sum_l \exp(z_l)/\exp(m)} = \frac{\exp(z_j)}{\sum_l \exp(z_l)}$$

A Tiny Example

Let's work through a trivial example to illustrate what's going on.

Inputs: Q, K, V
Our three input matrices Q, K, and V, each 4×3, with a head dim of 3, a sequence length of 4, and d_k of 1 for simplicity.

$$Q = \begin{bmatrix} 5.2 & 4.8 & 5.1 \\ 4.9 & 5.3 & 5.0 \\ 5.1 & 4.7 & 5.2 \\ 5.0 & 5.1 & 4.8 \end{bmatrix} \quad K = \begin{bmatrix} 5.0 & 5.2 & 4.9 \\ 5.1 & 4.8 & 5.3 \\ 4.8 & 5.1 & 5.0 \\ 5.2 & 5.0 & 5.1 \end{bmatrix} \quad V = \begin{bmatrix} 1 & 3 & 2 \\ 4 & 1 & 5 \\ 2 & 6 & 1 \\ 1 & 1 & 3 \end{bmatrix}$$

The Materialization Problem

On the programming side of things, vanilla attention is typically implemented using three kernels: a matmul kernel for $S = QK^\top$, a softmax kernel for $P = \text{softmax}(S)$, and a final matmul kernel for $O = PV$. While this approach seems logical, it imposes one large problem: it's absolutely dominated by memory-bound operations.

[Diagram: the three-kernel pipeline against HBM (global memory). Kernel 1 (matmul, S = QK^T) reads Q and K and writes the N×N matrix S; kernel 2 (softmax, P = softmax(S)) reads S and writes the N×N matrix P; kernel 3 (matmul, O = PV) reads P and V and writes O. Each kernel launch is a full round trip to slow global memory, for a total of 4Nd + 4N² HBM accesses. Since S and P are N×N, they dominate as the sequence length grows.]
$$\begin{aligned} \text{FLOPs}_1 &= 2N^2 d && \text{(compute } S = QK^\top\text{)} \\ \text{FLOPs}_2 &\approx 5N^2 && \text{(softmax)} \\ \text{FLOPs}_3 &= 2N^2 d && \text{(compute } O = PV\text{)} \\ \text{FLOPs}_{\text{total}} &= 4N^2 d + 5N^2 \end{aligned}$$
$$\begin{aligned} \text{Step 1:} \quad & \underbrace{Nd}_{\text{read } Q} + \underbrace{Nd}_{\text{read } K} + \underbrace{N^2}_{\text{write } S} &&= 2Nd + N^2 \\ \text{Step 2:} \quad & \underbrace{N^2}_{\text{read } S} + \underbrace{N^2}_{\text{write } P} &&= 2N^2 \\ \text{Step 3:} \quad & \underbrace{N^2}_{\text{read } P} + \underbrace{Nd}_{\text{read } V} + \underbrace{Nd}_{\text{write } O} &&= N^2 + 2Nd \\ \text{Total:} \quad & &&= 4N^2 + 4Nd \text{ elements} \end{aligned}$$
$$\text{AI}_{\text{standard}} = \frac{4N^2 d + 5N^2}{(4N^2 + 4Nd) \cdot B_{\text{elem}}}$$
| GPU (BF16) | Ridge Point (FLOPs/byte) | AI, d=64 | AI, d=128 | AI, d=512 |
|---|---|---|---|---|
| A100 SXM (Ampere) | 156.0 | 32.1 | 62.7 | 228.1 |
| H100 SXM (Hopper) | 295.5 | 32.1 | 62.7 | 228.1 |
| H200 (Hopper) | 206.2 | 32.1 | 62.7 | 228.1 |
| B200 HGX (Blackwell) | 562.5 | 32.1 | 62.7 | 228.1 |

Flash Attention

Flash Attention reduces the number of memory operations through the use of the online softmax algorithm. The online softmax allows us to compute a partial softmax using whatever max we have seen so far, and rescale later when we come across a new one. This is great for three reasons:

  1. We can do the entire MHA in one kernel.
  2. We don't have to read and write the materialized matrices to GMEM.
  3. Data that lives in SMEM can stay in SMEM for the entire duration of the kernel.

Now to utilize the online softmax we need to make one more adjustment to our equation:

$$\tilde{o}_i = \sum_j \exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}} - m_i\right) v_j \qquad \ell_i = \sum_l \exp\!\left(\frac{q_i \cdot k_l}{\sqrt{d_k}} - m_i\right) \qquad o_i = \frac{\tilde{o}_i}{\ell_i}$$

The math is exactly the same as before, but writing it this way makes clear that we can decouple the normalizer $\ell_i$ from the rest of the equation.

Rescaling explained

Now you may be wondering: how do we adjust old values? The idea is quite simple. Thanks to the quotient rule we know that $e^{x - y} = \frac{e^x}{e^y}$, so we can apply the same property to rescaling $\ell_i$ and $\tilde{o}_i$.

Consider $\ell_i$ computed with the old max:

$$\ell_i = \sum_j \exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}} - m_i^{\text{old}}\right)$$

We want to convert this to use $m_i^{\text{new}}$. Applying the quotient rule to each term:

$$\exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}} - m_i^{\text{new}}\right) = \frac{\exp\!\left(\frac{q_i \cdot k_j}{\sqrt{d_k}} - m_i^{\text{old}}\right)}{\exp\!\left(m_i^{\text{new}} - m_i^{\text{old}}\right)}$$

The exact same logic applies to $\tilde{o}_i$, since the $v_j$ weights don't affect the exponential rescaling:

$$\tilde{o}_i^{\text{new}} = \frac{\tilde{o}_i}{\exp\!\left(m_i^{\text{new}} - m_i^{\text{old}}\right)}$$

In the actual program this is maintained using running statistics: values that track the per-row running sum and max, all of which get updated when a new max appears.

Another Tiny Example

Here is a visual guide to explain this better:

Setup: track m, l, and out
We track three running values per row: m (running max), l (running sum), and out (output accumulator). We chunk K and V so iterations = N / B_N; in this case B_N = 2.

$$Q = \begin{bmatrix} 5.2 & 4.8 & 5.1 \\ 4.9 & 5.3 & 5.0 \\ 5.1 & 4.7 & 5.2 \\ 5.0 & 5.1 & 4.8 \end{bmatrix} \quad K^\top = \begin{bmatrix} 5.0 & 5.1 & 4.8 & 5.2 \\ 5.2 & 4.8 & 5.1 & 5.0 \\ 4.9 & 5.3 & 5.0 & 5.1 \end{bmatrix} \quad V = \begin{bmatrix} 1.0 & 3.0 & 2.0 \\ 4.0 & 1.0 & 5.0 \\ 2.0 & 6.0 & 1.0 \\ 1.0 & 1.0 & 3.0 \end{bmatrix}$$

Chunk 1: tokens 0-1
Initial state

| row | m | l | out |
|---|---|---|---|
| 0 | -∞ | 0.000 | [0.000, 0.000, 0.000] |
| 1 | -∞ | 0.000 | [0.000, 0.000, 0.000] |
| 2 | -∞ | 0.000 | [0.000, 0.000, 0.000] |
| 3 | -∞ | 0.000 | [0.000, 0.000, 0.000] |

CUDA Implementation

Now that we understand the math, let's walk through the actual CUDA kernel. We use CuTe (a layout algebra library from CUTLASS) to describe our data layouts and tiling. Our implementation runs entirely on CUDA cores; no Tensor Cores involved.

The FMHA Struct

The FMHA struct is the top-level interface. It's templated on the key parameters that define the shape and behavior of the kernel:

template <int head_count, int head_dim, int B_r, int B_c, typename DType,
          int thread_count = 128, bool causal_mask = false,
          bool qkv_contigous_buffer = false>
struct FMHA {
  • head_count and head_dim describe the multi-head attention geometry
  • B_r is how many query rows we process per thread block
  • B_c is how many key/value rows we process per iteration
  • thread_count defaults to 128 threads per block
  • causal_mask enables the autoregressive mask (tokens can only attend to previous tokens)
  • qkv_contigous_buffer toggles between interleaved and separate Q, K, V memory layouts

Launching the Kernel

The operator() at the bottom of the struct is how we actually launch the kernel from the host:

void operator()(DType *Q, DType *K, DType *V, DType *O, uint32_t batch_size,
                uint32_t N_q, uint32_t N_kv, uint32_t start_pos) {
  dim3 grid_dim{head_count, batch_size, ceil_div(N_q, B_r)};
  dim3 block_dim{thread_count};
  ...
}

The grid is 3-dimensional: one block per head (x), one per batch element (y), and one per chunk of query rows (z). So each thread block is responsible for computing B_r rows of output for a single head in a single batch element.

We also tell the GPU to carve out as much shared memory as possible by setting the carveout to 100%:

cudaFuncSetAttribute(kernel_fptr,
                     cudaFuncAttributePreferredSharedMemoryCarveout, 100);

This is important because our kernel is shared-memory heavy: we're storing Q, K/V, and the scores matrix P all in SMEM.

Shared Memory Layout

Now this is where things get interesting. The SharedStorage struct defines what lives in shared memory:

struct SharedStorage {
  ArrayEngine<DType, B_r * head_dim> Q;
  ArrayEngine<DType, B_c * head_dim> KV;
  ArrayEngine<DType, B_r * B_c> P;
  ...
};

We have three buffers: Q holds the query tile, KV holds either the current key or value tile (they share the same buffer since we never need both at the same time), and P holds the scores tile (the result of $QK^\top$ after softmax).

Notice that K and V share the KV buffer. This is why you'll see __syncthreads() calls between the K matmul and the V matmul in the kernel: we need to make sure every thread is done reading K before we overwrite it with V.

Swizzled Layouts

The shared memory layouts use swizzling to avoid bank conflicts:

using swizzle_atom = decltype(composition(
    Swizzle<3, 2, 3>{},
    Layout<Shape<_8, Shape<_4, _8>>, Stride<_32, Stride<_1, _4>>>{}));

Without swizzling, threads in the same warp that access the same column of a matrix would hit the same shared memory bank, causing serialization. The Swizzle<3, 2, 3> remaps addresses so that consecutive rows land in different banks. The layout is then tiled out to the full matrix size using tile_to_shape.

For V specifically, we need a transposed view because V gets multiplied on the right side ($PV$). So we define both a VLayoutType (for writing V into shared memory in its natural row-major order) and a VTransposedLayoutType (for reading V during the matmul as if it were transposed):

using VTransposedLayoutType = decltype(tile_to_shape(
    swizzle_atom{}, make_shape(HeadDimType{}, KVColsType{})));
using VLayoutType = decltype(tile_to_shape(
    swizzle_atom_T{}, make_shape(KVColsType{}, HeadDimType{}),
    LayoutRight{}));

This way we write V normally but read it transposed, all without actually moving data around.

The Tiled Copy

To move data from global memory to shared memory, we create a TiledCopy:

template <typename LoadType> static constexpr auto get_tiled_copy() {
  constexpr int elements_per_load{sizeof(LoadType) / sizeof(DType)};
  constexpr int threads_per_row{head_dim / elements_per_load};
  constexpr int rows{thread_count / threads_per_row};
  ...
  return make_tiled_copy(
      Copy_Atom<UniversalCopy<LoadType>, DType>{},
      Layout<Shape<RowType, TPRType>, Stride<TPRType, _1>>{},
      Layout<Shape<_1, EPLType>>{});
}

The idea is straightforward: we distribute threads across the matrix so that each thread loads elements_per_load consecutive elements from one row. The thread layout is (rows, threads_per_row) where each thread handles a contiguous chunk along the head dimension. This gives us coalesced global memory accesses since adjacent threads read adjacent memory addresses.

We create two different tiled copies: one for Q and K using uint128_t (128-bit vectorized loads, 4 floats at a time), and one for V using uint32_t (scalar loads). V uses scalar loads because its writes will be transposed (the data won't be contiguous across rows).

The Tiled MMA

For the actual computation, we set up a TiledMMA:

static constexpr auto get_tiled_mma() {
  using RowType = Int<thread_count / 32>;
  auto t_mma{
      make_tiled_mma(UniversalFMA<DType, DType, DType>{},
                     Layout<Shape<RowType, _32>, Stride<_32, _1>>{})};
  return t_mma;
}

This tells CuTe how threads map to the output matrix. We use UniversalFMA (fused multiply-add on CUDA cores). The thread layout assigns one warp (32 threads) per row of the output.

Inside the Kernel

Now let's look at the kernel itself. The first thing we do is figure out which slice of the input this thread block is responsible for:

Tensor q_head{MHAType::slice_head(Q, batch_size, N_q)};
const Tensor k_head{MHAType::slice_head(K, batch_size, N_kv)};
const Tensor v_head{MHAType::slice_head(V, batch_size, N_kv)};
Tensor o_head{MHAType::slice_head<true>(O, batch_size, N_q)};

slice_head uses blockIdx.x (head index) and blockIdx.y (batch index) to slice into the 4D tensor [batch, N, num_heads, head_dim], giving us a 2D view [N, head_dim] for this specific head and batch element.

Next we create iterators that tile across the sequence dimension:

Tensor q_iterator{local_tile(q_head, q_tiler, q_coord)}; // (B_r, d)
Tensor k_iterator{
    local_tile(k_head, kv_tiler, kv_coord)}; // (B_c, d, ceil(N_kv / B_c))
Tensor v_iterator{
    local_tile(v_head, kv_tiler, kv_coord)}; // (B_c, d, ceil(N_kv / B_c))

Q is tiled once (this block handles exactly B_r query rows, selected by blockIdx.z). K and V are tiled with an extra dimension for iteration, we'll loop over ceil(N_kv / B_c) chunks.

Initializing the Running Statistics

Before the main loop, we set up the accumulators for the online softmax:

auto m{make_tensor<DType>(mma_m)};
auto l{make_tensor<DType>(mma_m)};
fill(m, -INFINITY);
clear(l);

m tracks the running row-wise maximum (initialized to $-\infty$ so any real value will be larger). l tracks the running row-wise sum (initialized to 0). These are per-thread tensors sized by mma_m, the number of output rows this thread owns.

The Main Loop

The main loop iterates over K/V chunks. The last chunk is handled first with predication (bounds checking), then the rest run without:

// Do the block that needs predication first
MHAType::matmul<true>(tK_global_part_iter(_, _, _, iters - 1), ...);
MHAType::update_statistics<true>(m, l, r_scores_mma, ...);
__syncthreads();
MHAType::matmul<true>(tV_global_part_iter(_, _, _, iters - 1), ...);

// do the rest of blocks that don't need predication
for (int iter{static_cast<int>(iters) - 2}; iter > -1; --iter) {
  __syncthreads();
  MHAType::matmul(tK_global_part_iter(_, _, _, iter), ...);
  MHAType::update_statistics(m, l, r_scores_mma, ...);
  __syncthreads();
  MHAType::matmul(tV_global_part_iter(_, _, _, iter), ...);
}

For each iteration we: compute $S_{\text{chunk}} = Q \cdot K_{\text{chunk}}^\top$, update the online softmax statistics (m, l) and rescale, then accumulate $O \mathrel{+}= P_{\text{chunk}} \cdot V_{\text{chunk}}$. The __syncthreads() between the K and V matmuls ensures the shared KV buffer is safe to overwrite.

The last chunk is processed first because when N_kv isn't perfectly divisible by B_c, that final chunk may be partially filled. By handling it first with predication we avoid branching inside the hot inner loop.

The Matmul Function

The matmul function is where the actual matrix multiplication happens. We write it manually instead of using the CuTe MMA atom in order to take advantage of ILP and vectorized loads. Here's how it works:

float4 a_vecs[mma_m_size];
float4 b_vecs[mma_n_len];

#pragma unroll 8
for (size_t k{0}; k < mma_k_len; k += elements_per_load) {
  // 1. Load A vectors (from shared Q or P)
  // 2. Load B vectors (from shared K or V)
  // 3. FMAs - all .x first, then .y, then .z, then .w
}

We load 4 floats at a time using float4 (128-bit reads from shared memory). The FMA (fused multiply-add) loop is structured to process all .x components first, then .y, .z, .w. This ordering matters because it keeps the same accumulator register live across consecutive FMA instructions.

The #pragma unroll 8 is a deliberate tradeoff: too little unrolling hurts ILP, but fully unrolling can cause register spills (the compiler runs out of registers and starts spilling to local memory, which is slow).

Updating the Online Softmax Statistics

The update_statistics function is the heart of the online softmax. After each $QK^\top$ chunk is computed, we need to:

  1. Scale scores by $\frac{1}{\sqrt{d_k}}$
  2. Find the new row-wise maximum
  3. Rescale old statistics
  4. Compute $\exp(\text{score} - m^{\text{new}})$ and accumulate the sum
  5. Write probabilities to shared memory for the $PV$ matmul
  6. Rescale the running output

// Scale and find local max
for (size_t idx{0}; idx < slice_size; ++idx) {
  r_score_slice(idx) = r_score_slice(idx) * scale;
  current_max = cuda::std::max(r_score_slice(idx), current_max);
}

// Reduce max across the warp
current_max = warp_max(current_max);

Since each thread only holds a few elements of each row, we first find the local maximum per thread, then use warp_max to reduce across all 32 threads in the warp. This gives every thread the true row-wise maximum.

DType scale_old = expf(old_max - current_max);
current_sum = current_sum * scale_old;

Next we rescale the old sum using the formula from earlier: multiplying by $\exp(m^{\text{old}} - m^{\text{new}})$ is the same as dividing by $\exp(m^{\text{new}} - m^{\text{old}})$.

for (size_t idx{0}; idx < slice_size; ++idx) {
  auto p_score = expf(r_score_slice(idx) - current_max);
  local_sum += p_score;
  p_slice(idx) = p_score;
  r_score_slice(idx) = 0; // reset for next iteration
}
current_sum += warp_sum(local_sum);

Then we compute the exponentials and write them to the shared P buffer.

We also reset the scores register to 0, since the matmul function accumulates into it and we need a clean slate for the next K chunk.

for (size_t i{0}; i < size(o_slice); i++) {
  o_slice(i) *= scale_old;
}

Finally we rescale the running output. This is the $\tilde{o}_i^{\text{new}} = \frac{\tilde{o}_i}{\exp\left(m_i^{\text{new}} - m_i^{\text{old}}\right)}$ rescaling we discussed in the math section.

Causal Masking

When causal_mask is enabled, we prevent tokens from attending to future positions:

if constexpr (causal_mask) {
  adjusted_bound = get<0>(scores_idty_slice(0)) + start_pos + 1;
}

The identity tensor gives us the absolute row index of this query token. Adding start_pos handles the KV-cache case: during decoding, the query might be at position 500 in the full sequence even though N_q is just 1. Any key with an index greater than or equal to the query's absolute position + 1 gets its score set to $-\infty$, which zeroes it out after the exponential.

The Final Normalization

After all iterations complete, we divide by the accumulated sum l to get valid probabilities:

for (size_t m_row{0}; m_row < m_rows; ++m_row) {
  auto out_slice{r_out_mma(_, m_row, _)};
  for (size_t idx{0}; idx < size(out_slice); ++idx) {
    out_slice(idx) = out_slice(idx) / l(m_row);
  }
}

During the loop, we accumulated $O = \sum_j \exp(s_j - m) \cdot v_j$ and $l = \sum_j \exp(s_j - m)$. Dividing $O$ by $l$ gives us the correctly weighted attention output.

Finally we write the result back to global memory with bounds checking:

for (size_t i{0}; i < write_rows; ++i) {
  auto seq_idx{get<1>(o_mma_idty(0, i, 0))};
  if (seq_idx < N_q)
    copy(r_out_mma(_, i, _), g_out_mma(_, i, _));
}

The predicate seq_idx < N_q ensures we don't write past the end of the sequence when N_q isn't perfectly divisible by B_r.

Benchmarks

Now let's see if all of this actually pays off. We benchmarked CobraML's FA1 kernel against vanilla PyTorch multi-head attention on an A10G GPU. All configs use 16 heads with a head dimension of 64.

Python Benchmarks โ€” CobraML FA1 vs Vanilla PyTorch MHA

| Config (B, N, causal) | CobraML | Vanilla | Speedup |
|---|---|---|---|
| 4, 512, non-causal | 0.527 ms | 1.198 ms | 2.27x |
| 4, 512, causal | 0.660 ms | 1.771 ms | 2.68x |
| 8, 59, non-causal | 0.044 ms | 0.170 ms | 3.90x |
| 8, 59, causal | 0.047 ms | 0.266 ms | 5.65x |
| 1, 2048, non-causal | 1.670 ms | 4.009 ms | 2.40x |
| 1, 2048, causal | 1.663 ms | 3.868 ms | 2.33x |

The short sequence causal case (B=8, N=59) is where we see the biggest wins โ€” 5.65x over vanilla PyTorch. This makes sense: shorter sequences mean fewer iterations in the main loop, and causal masking lets us skip a large chunk of the score matrix entirely. The longer sequences (N=2048) settle around 2.3โ€“2.4x, which is still a solid improvement.

C++ Benchmarks โ€” Raw Kernel Performance

These numbers measure the kernel in isolation, stripped of any Python overhead:

| Config (H, d, B, N, causal) | Kernel Time | GFLOPs |
|---|---|---|
| 16, 64, 4, 512 | 0.393 ms | 10,925 |
| 16, 64, 4, 512, causal | 0.034 ms | 3,973 |
| 16, 64, 8, 59 | 0.032 ms | 3,537 |
| 16, 64, 8, 59, causal | 0.032 ms | 3,514 |
| 16, 128, 4, 512 | 0.918 ms | 9,354 |
| 16, 128, 4, 512, causal | 0.857 ms | 10,018 |

Peak throughput lands around 11 TFLOPs on the A10G. The larger head dimension (d=128) configs push close to 10 TFLOPs as well.

Next Steps

This is only our first implementation of Flash Attention. Expect an FA2 implementation very soon, along with articles covering more components of our inference engine, including matmul kernels, KV caching, ...

Please reach out if you have any questions!