Tutorial 5: Fused Softmax#

Softmax is a multi-step operation:

softmax(x)_i = exp(x_i - max(x)) / Σ exp(x_j - max(x))

A naive implementation would use separate kernels for each step. Kernel fusion combines all steps into one kernel to avoid redundant memory traffic. Each kernel launch has overhead (~10–20 microseconds), and each round-trip to global memory is slow.

Comparison of unfused (4 kernels) vs fused (1 kernel) softmax


Why Subtract the Max?#

Subtracting the row maximum before calling exp prevents overflow:

x = [100, 101, 102]
exp(x) = [2.7×10⁴³, 7.3×10⁴³, 2.0×10⁴⁴]  ← Overflow! These are inf in float32

With max subtraction:

x = [100, 101, 102]
max = 102
x - max = [-2, -1, 0]
exp(x - max) = [0.135, 0.368, 1.0]  ← Safe values

The result is mathematically identical:

exp(x_i - max) / Σ exp(x_j - max)
= [exp(x_i) / exp(max)] / [Σ exp(x_j) / exp(max)]
= exp(x_i) / Σ exp(x_j)

The Code#

use cuda_async::device_operation::DeviceOperation;
use cuda_core::CudaContext;
use cutile;
use cutile::api::arange;
use cutile::error::Error;
use cutile::tensor::{IntoPartition, Tensor, ToHostVec, Unpartition};
use std::sync::Arc;

#[cutile::module]
mod my_module {
    use cutile::core::*;

    #[cutile::entry()]
    fn softmax<const BM: i32, const BN: i32>(
        x: &Tensor<f32, { [-1, -1] }>,
        y: &mut Tensor<f32, { [BM, BN] }>,
    ) {
        let tile_x: Tile<f32, { [BM, BN] }> = load_tile_like_2d(x, y);

        // Find max per row (for numerical stability)
        let tile_x_max: Tile<f32, { [BM] }> = reduce_max(tile_x, 1i32);
        let tile_x_max: Tile<f32, { [BM, BN] }> =
            tile_x_max.reshape(const_shape![BM, 1]).broadcast(y.shape());

        // Subtract max and exponentiate
        let num: Tile<f32, { [BM, BN] }> = exp(tile_x - tile_x_max);

        // Sum per row
        let denom: Tile<f32, { [BM] }> = reduce_sum(num, 1);
        let denom = denom.reshape(const_shape![BM, 1]).broadcast(y.shape());

        // Divide
        y.store(num / denom);
    }
}

use my_module::softmax;

fn main() -> Result<(), Error> {
    let ctx = CudaContext::new(0)?;
    let stream = ctx.new_stream()?;

    let (m, n) = (4usize, 8usize);
    let (bm, bn) = (2i32, n as i32);

    let input: Arc<Tensor<f32>> = arange(m * n).sync_on(&stream)?.into();
    let x: Arc<Tensor<f32>> = input.copy_sync(&stream)?.reshape([m, n]).into();
    let y = input.copy_sync(&stream)?.reshape([m, n]).partition([bm, bn]);

    let (_x, y) = softmax(x, y).sync_on(&stream)?;
    let y_host: Vec<f32> = y.unpartition().to_host_vec().sync_on(&stream)?;

    // Each row should sum to 1.0
    for i in (0..y_host.len()).step_by(n) {
        let row_sum: f32 = y_host[i..i+n].iter().sum();
        println!("softmax row sum = {} (expected 1.0)", row_sum);
    }

    Ok(())
}

Output:

softmax row sum = 1 (expected 1.0)
softmax row sum = 1 (expected 1.0)
softmax row sum = 1 (expected 1.0)
softmax row sum = 1 (expected 1.0)

Each row sums to 1.0 — a probability distribution.


Reduction Operations#

Reductions collapse a dimension by applying an operation:

let tile_x_max: Tile<f32, { [BM] }> = reduce_max(tile_x, 1i32);

Reduction operation showing how reduce_max collapses a dimension

After reduction, reshape and broadcast to match the original tile:

// [BM] → [BM, 1] → [BM, BN]
let tile_x_max: Tile<f32, { [BM, BN] }> =
    tile_x_max
    .reshape(const_shape![BM, 1])   // [2] → [2, 1]
    .broadcast(y.shape());           // [2, 1] → [2, 8]

The Fusion Pattern#

Fused kernels load once, compute everything in registers, and store once:

// 1. LOAD once
let tile = load_tile_like_2d(input, output);

// 2. ALL COMPUTATION in registers
let step1 = reduce_max(tile, axis);
let step2 = tile - step1.broadcast(...);
let step3 = exp(step2);
let step4 = reduce_sum(step3, axis);
let result = step3 / step4.broadcast(...);

// 3. STORE once
output.store(result);

All intermediate values stay in registers — no global memory traffic between steps.


Key Takeaways#

Concept

What It Means

Kernel fusion

Combine multiple ops into one kernel to avoid memory traffic

Numerical stability

Subtract max before exp to prevent overflow

Reduction ops

Collapse a dimension: reduce_max, reduce_sum

Reshape + Broadcast

After reduction, reshape to use with the original data


Exercise 1: Column-wise Softmax#

Modify to compute softmax along axis 0 (columns) instead of axis 1 (rows):

let col_max: Tile<f32, { [BN] }> = reduce_max(tile_x, 0i32);

What changes need to be made to the reshape and broadcast calls?

Exercise 2: Temperature Scaling#

Add a temperature parameter for softer or sharper distributions:

fn softmax_with_temp<const BM: i32, const BN: i32>(
    x: &Tensor<f32, {[-1, -1]}>,
    y: &mut Tensor<f32, {[BM, BN]}>,
    temperature: f32,  // Higher = more uniform, Lower = more peaked
) {
    let tile_x = load_tile_like_2d(x, y);
    let scaled = tile_x / temperature.broadcast(y.shape());
    // ... rest of softmax ...
}

Exercise 3: Log-Softmax#

Many ML frameworks use log-softmax because it is more numerically stable. Can you implement it?

Hint
log_softmax(x)_i = x_i - max(x) - log(Σ exp(x_j - max(x)))

You can fuse this too.