Tutorial 6: Fused Multihead Attention#
Attention is a performance-critical operation at the heart of transformer models (BERT, GPT, etc.). It computes a weighted combination of values, where the weights reflect the relevance of each position in the sequence. Given parameters Q, K, and V constructed from an input sequence, attention is computed as:
Attention(Q, K, V) = softmax(Q @ K^T / √d) @ V
Where:
- Q (Query): “What am I looking for?”
- K (Key): “What do I contain?”
- V (Value): “What information do I have?”
- d: The embedding dimension of Q and K. This may differ from the embedding dimension of V.
The softmax produces attention weights — a probability distribution over positions in the sequence.
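To make the formula concrete, here is a minimal plain-Rust CPU reference of single-head attention (the function name `naive_attention` and the row-major `Vec<f32>` layout are illustrative choices, not part of the cuTile API). It materializes the full score row per query, which is exactly the memory behavior the fused kernel avoids:

```rust
// Naive single-head attention on the CPU, for reference only.
// q, k, v are n x d matrices stored row-major in flat slices.
fn naive_attention(q: &[f32], k: &[f32], v: &[f32], n: usize, d: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![0.0f32; n * d];
    for i in 0..n {
        // scores[j] = (q_i . k_j) / sqrt(d)
        let mut scores = vec![0.0f32; n];
        for j in 0..n {
            let mut dot = 0.0f32;
            for t in 0..d {
                dot += q[i * d + t] * k[j * d + t];
            }
            scores[j] = dot * scale;
        }
        // Softmax over the row (subtract the max for numerical stability).
        let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for s in scores.iter_mut() {
            *s = (*s - m).exp();
            sum += *s;
        }
        // out_i = sum_j softmax(scores)_j * v_j
        for j in 0..n {
            let w = scores[j] / sum;
            for t in 0..d {
                out[i * d + t] += w * v[j * d + t];
            }
        }
    }
    out
}
```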
Note: In this tutorial, we write a fused multihead attention (FMHA) kernel using cuTile Rust’s tile abstractions. The programmer expresses the algorithm — tiled Q/K/V access, online softmax, streaming accumulation — while the compiler handles the heavy lifting that makes this a “Flash Attention”-caliber implementation: staging data through shared memory, mapping operations onto Tensor Cores, managing the register file, and coalescing memory accesses. In a traditional CUDA C++ Flash Attention kernel, these low-level details dominate the code; here, the tile programming model abstracts them away. In parts of this project outside of this tutorial, we may refer to Flash Attention and Fused MHA interchangeably.
The Memory Problem#
The intermediate attention scores matrix is quadratic in the sequence length:
Q shape: (batch, heads, seq_len, head_dim) = (1, 1, N, 64)
K shape: (batch, heads, seq_len, head_dim) = (1, 1, N, 64)
Q @ K^T → shape: (1, 1, N, N)
For N=65k, the scores matrix alone requires over 4 billion elements (roughly 17 GB in fp32). Long sequences simply do not fit in GPU memory with naive attention.
The Softmax Challenge#
Softmax requires access to all values in a row to compute reduce_max and reduce_sum:
let max_x = reduce_max(x, axis);
let exp_x = exp(x - max_x);
let sum_exp = reduce_sum(exp_x, axis);
let result = exp_x / sum_exp;
In standard attention, each row has N elements, so the entire N×N matrix must be stored before computing softmax.
Online Softmax#
Softmax can be computed incrementally by processing one chunk at a time and maintaining running statistics.
When a new chunk introduces a larger maximum, previous results must be corrected:
Before: We computed exp(x₁ - m₁) where m₁ was the old max.
After: We need exp(x₁ - m₂) where m₂ is the new (larger) max.
Correction: exp(x₁ - m₂) = exp(x₁ - m₁) × exp(m₁ - m₂)
The first factor is what we had already computed; the second factor, exp(m₁ - m₂), is the correction factor α.
The algorithm maintains three running values:
- m: Running maximum.
- l: Running sum of exponentials.
- acc: Running output accumulator.
When a new maximum appears, all previous results are rescaled by exp(old_max - new_max).
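The rescaling rule can be sketched in a few lines of plain Rust. This computes only the softmax denominator over a stream of chunks (the function name `online_softmax_denominator` is illustrative; the fused kernel additionally maintains the output accumulator):

```rust
// Online softmax statistics over a stream of non-empty chunks:
// maintains the running max m and running sum l of exp(x - m),
// rescaling l by alpha = exp(old_max - new_max) when the max grows.
fn online_softmax_denominator(chunks: &[&[f32]]) -> (f32, f32) {
    let mut m = f32::NEG_INFINITY; // running maximum
    let mut l = 0.0f32;            // running sum of exponentials
    for chunk in chunks {
        let chunk_max = chunk.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_m = m.max(chunk_max);
        let alpha = (m - new_m).exp(); // correction factor for the old sum
        l = l * alpha + chunk.iter().map(|&x| (x - new_m).exp()).sum::<f32>();
        m = new_m;
    }
    (m, l)
}
```

The final (m, l) pair is identical to what a single pass over the concatenated data would produce, which is why the chunks can be streamed in any order.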
Memory Savings#
At any point the kernel stores:
- One Q tile: BM × D elements (e.g., 64 × 64 = 4,096).
- One K tile: BN × D elements (e.g., 32 × 64 = 2,048).
- One V tile: BN × D elements (e.g., 32 × 64 = 2,048).
- Running stats: BM × 1 for max and sum (64 each).
- Output accumulator: BM × D elements (4,096).
Total: ~12,000 elements per tile, regardless of sequence length.
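The tally is plain arithmetic; a sketch with the tutorial's tile sizes (the helper name `working_set_elems` is illustrative):

```rust
// Per-block working-set tally for the example tile sizes BM=64, BN=32, D=64.
const BM: usize = 64;
const BN: usize = 32;
const D: usize = 64;

fn working_set_elems() -> usize {
    let q_tile = BM * D; // 4,096
    let k_tile = BN * D; // 2,048
    let v_tile = BN * D; // 2,048
    let stats = 2 * BM;  // running max + running sum, 64 each
    let acc = BM * D;    // 4,096
    q_tile + k_tile + v_tile + stats + acc // 12,416, independent of seq_len
}
```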
The Algorithm#
For each Q tile (row block of the output):
Initialize: max = -∞, sum = 0, output = 0
For each K,V tile (streaming through the sequence):
1. Compute attention scores: scores = Q_tile @ K_tile^T
2. Scale: scores = scores / √d
3. Update running max: new_max = max(current_max, max(scores))
4. Compute correction factor: α = exp(old_max - new_max)
5. Rescale previous results: output *= α, sum *= α
6. Compute new exponentials: P = exp(scores - new_max)
7. Update sum: sum += sum(P)
8. Accumulate: output += P @ V_tile
9. Update max: current_max = new_max
Normalize: output = output / sum
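The nine steps above can be traced on the CPU for a single query row. This plain-Rust sketch (function name `streaming_attention_row` is mine, not cuTile API) streams K,V in chunks of size `bn` and keeps only the running statistics and the accumulator:

```rust
// CPU sketch of the streaming loop for one query row q (length d):
// process K,V in chunks of bn rows, keeping only (m, l) and acc.
fn streaming_attention_row(q: &[f32], k: &[f32], v: &[f32],
                           n: usize, d: usize, bn: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    let mut m = f32::NEG_INFINITY; // running max
    let mut l = 0.0f32;            // running sum of exponentials
    let mut acc = vec![0.0f32; d]; // running output accumulator
    for start in (0..n).step_by(bn) {
        let end = (start + bn).min(n);
        // Steps 1-2: scaled attention scores for this chunk.
        let scores: Vec<f32> = (start..end)
            .map(|j| (0..d).map(|t| q[t] * k[j * d + t]).sum::<f32>() * scale)
            .collect();
        // Step 3: update the running max.
        let new_m = scores.iter().cloned().fold(m, f32::max);
        // Steps 4-5: correction factor; rescale old sum and accumulator.
        let alpha = (m - new_m).exp();
        l *= alpha;
        for a in acc.iter_mut() {
            *a *= alpha;
        }
        // Steps 6-8: new exponentials, sum update, accumulate P @ V.
        for (jj, j) in (start..end).enumerate() {
            let p = (scores[jj] - new_m).exp();
            l += p;
            for t in 0..d {
                acc[t] += p * v[j * d + t];
            }
        }
        // Step 9: carry the max forward.
        m = new_m;
    }
    // Final normalization.
    acc.iter().map(|a| a / l).collect()
}
```

The result matches a full-matrix softmax exactly (up to floating-point rounding), for any chunk size `bn`.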
Implementation note: The code below uses exp2 instead of exp as a performance optimization — exp2 is faster on GPU hardware. To compensate, the scale factor is divided by ln(2), so that exp2(x / ln(2)) = exp(x). The correction factor α and the softmax numerator P are both computed with exp2 using this adjusted scale.
The Code#
use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use std::sync::Arc;
use cutile;
use cutile::api::{randn_f32, zeros};
use cutile::error::Error;
use cutile::tensor::{CopyToHost, IntoPartition, Partition, Tensor, Unpartition};
use cutile::tile_kernel::TileKernel;
#[cutile::module]
mod fmha_module {
use cutile::core::*;
#[cutile::entry(print_ir=false)]
fn fmha<
const BM: i32, // Q tile size (rows of output we compute)
const BN: i32, // K,V tile size (how many K,V we process at once)
const D: i32, // Head dimension
>(
q: &Tensor<f32, { [-1, -1, -1, -1] }>, // (B, H, M, D)
k: &Tensor<f32, { [-1, -1, -1, -1] }>, // (B, H, N, D)
v: &Tensor<f32, { [-1, -1, -1, -1] }>, // (B, H, N, D)
out: &mut Tensor<f32, { [1, BM, D] }>,
qk_scale: f32,
) {
let pid: (i32, i32, i32) = get_tile_block_id();
let h = get_shape_dim(q.shape(), 1i32);
let batch_idx = pid.0 / h;
let head_idx = pid.0 % h;
let q_m_idx = pid.1;
// Convert to exp2-friendly scale (exp2 is faster than exp on GPU)
let two: Tile<f32, { [] }> = constant(2.0f32, const_shape![]);
let log2: f32 = tile_to_scalar(log(two));
let qk_scale: f32 = qk_scale / log2;
let qk_scale: Tile<f32, { [BM, BN] }> = qk_scale.broadcast(const_shape![BM, BN]);
// Online softmax state
let mut m_i: Tile<f32, { [BM, 1] }> = constant(f32::NEG_INFINITY, const_shape![BM, 1]);
let mut l_i: Tile<f32, { [BM, 1] }> = constant(0.0f32, const_shape![BM, 1]);
let mut acc: Tile<f32, { [BM, D] }> = constant(0.0f32, const_shape![BM, D]);
// Load Q tile once and reuse for all K,V tiles
let q_part: Partition<f32, { [1, 1, BM, D] }> = q.partition(const_shape![1, 1, BM, D]);
let tq: Tile<f32, { [1, 1, BM, D] }> = q_part.load([batch_idx, head_idx, q_m_idx, 0i32]);
let tq: Tile<f32, { [BM, D] }> = tq.reshape(const_shape![BM, D]);
let n: i32 = get_shape_dim(k.shape(), 2i32);
let num_tiles: i32 = ceil_div(n, BN);
let k_part = k.partition(const_shape![1, 1, BN, D]);
let v_part = v.partition(const_shape![1, 1, BN, D]);
let transpose: Array<{ [1, 0] }> = Array::<{ [1, 0] }> { dims: &[1i32, 0i32] };
// Stream through K,V tiles
for j in 0i32..num_tiles {
// Q @ K^T
let k_tile: Tile<f32, { [BN, D] }> = k_part
.load([batch_idx, head_idx, j, 0i32])
.reshape(const_shape![BN, D]);
let k_tile_trans: Tile<f32, { [D, BN] }> = permute(k_tile, transpose);
let qk: Tile<f32, { [BM, BN] }> = constant(0.0f32, const_shape![BM, BN]);
let qk: Tile<f32, { [BM, BN] }> = mma(tq, k_tile_trans, qk);
let qk: Tile<f32, { [BM, BN] }> = qk * qk_scale;
// Update running max
let qk_max: Tile<f32, { [BM] }> = reduce_max(qk, 1);
let qk_max: Tile<f32, { [BM, 1] }> = qk_max.reshape(const_shape![BM, 1]);
let m_ij: Tile<f32, { [BM, 1] }> = max_tile(m_i, qk_max);
let qk = qk - m_ij.broadcast(const_shape![BM, BN]);
// Softmax numerator and correction factor
let p: Tile<f32, { [BM, BN] }> = exp2(qk);
let l_ij: Tile<f32, { [BM] }> = reduce_sum(p, 1);
let l_ij: Tile<f32, { [BM, 1] }> = l_ij.reshape(const_shape![BM, 1]);
let alpha: Tile<f32, { [BM, 1] }> = exp2(m_i - m_ij);
// Update running sum and rescale accumulator
l_i = l_i * alpha + l_ij;
let alpha: Tile<f32, { [BM, D] }> = alpha.broadcast(const_shape![BM, D]);
acc = acc * alpha;
// Accumulate P @ V
let v_tile: Tile<f32, { [1, 1, BN, D] }> = v_part.load([batch_idx, head_idx, j, 0i32]);
let v_tile: Tile<f32, { [BN, D] }> = v_tile.reshape(const_shape![BN, D]);
acc = mma(p, v_tile, acc);
m_i = m_ij;
}
// Final normalization
acc = true_div(acc, l_i.broadcast(const_shape![BM, D]));
let acc = acc.reshape(const_shape![1, BM, D]);
out.store(acc);
}
}
use fmha_module::fmha;
fn main() -> Result<(), Error> {
let ctx = CudaContext::new(0)?;
let stream = ctx.new_stream()?;
let (batch, heads, seq_len, head_dim) = (2, 4, 128, 64);
let (bm, bn) = (64, 32);
let seed = 42u64;
let q: Arc<Tensor<f32>> = randn_f32(0., 1., [batch, heads, seq_len, head_dim], Some(seed))
.sync_on(&stream)?.into();
let k: Arc<Tensor<f32>> = randn_f32(0., 1., [batch, heads, seq_len, head_dim], Some(seed + 1))
.sync_on(&stream)?.into();
let v: Arc<Tensor<f32>> = randn_f32(0., 1., [batch, heads, seq_len, head_dim], Some(seed + 2))
.sync_on(&stream)?.into();
let out: Partition<Tensor<f32>> = zeros([batch * heads, seq_len, head_dim])
.sync_on(&stream)?
.partition([1, bm, head_dim as i32]);
let qk_scale = 1.0 / f32::sqrt(head_dim as f32);
let generics = vec![bm.to_string(), bn.to_string(), head_dim.to_string()];
let (_, _, _, out, _) = fmha(q, k, v, out, qk_scale)
.generics(generics)
.sync_on(&stream)?;
let out_host = out.unpartition().copy_to_host().sync_on(&stream)?;
println!("Output shape: {:?}", out_host.shape());
Ok(())
}
Output:
Output shape: [8, 128, 64]
Key Takeaways#
| Concept | Standard Attention | Fused Multihead Attention |
|---|---|---|
| Memory for scores | O(N²) — store full matrix | O(BM × BN) — one tile at a time |
| Softmax approach | Compute all, then normalize | Online: update as we go |
| When max changes | N/A (have all values) | Rescale previous results |
| K,V access pattern | Load all at once | Stream tile by tile |
| Low-level optimization | Manual (shared memory, warps, coalescing) | Handled by the compiler |
This fused kernel trades extra compute (rescaling) for dramatically less memory, achieving Flash Attention-level performance. The programmer writes the algorithm at the tile level, while the compiler generates the shared memory staging, Tensor Core mappings, and memory coalescing that would otherwise require hundreds of lines of CUDA C++. For long sequences, this means running workloads that would otherwise not fit in GPU memory.
Full Production Example#
A complete implementation with Multi-Query Attention (MQA) support and reference validation:
cargo run --example flash_attention
out_host.shape() = [128, 1024, 64]
diff near zero? true: 5.96e-8
diff near zero? true: 2.98e-8
... (validates against reference for all batch×head combinations)
Exercise 1: Trace the Memory#
Calculate the memory usage for:
Standard attention with N=1024.
Fused multihead attention with N=1024, BM=64, BN=32.
How many times less memory does the fused kernel use?
Exercise 2: Add Causal Masking#
For autoregressive models (like GPT), we only attend to previous positions. Modify the kernel to skip computing attention scores where key_position > query_position.