Adaptive GEMM Grouping¶
Sparse convolution processes each kernel offset independently: for a 3x3x3 kernel, 26 non-identity offsets each produce a separate GEMM call (or CUDA kernel launch). When individual offsets have small pair counts, these launches become inefficient due to low GPU utilization and per-launch overhead.
Adaptive GEMM grouping addresses this by batching similar-sized offsets together, executing them as a single operation instead of many small ones.
How It Works¶
Offset Classification¶
Each kernel offset k has M_k coordinate pairs. Offsets are classified as:
- Large (M_k >=
saturation_m): Already GPU-saturating. Processed individually using the standard per-offset path. - Small (M_k <
saturation_m): Grouped into buckets for batched execution.
Bucketing Algorithm¶
Small offsets are sorted by pair count and greedily merged into buckets. A new offset is added to the current bucket only if the padding waste stays below a threshold:
where B is the bucket size and M_max is the largest pair count in the bucket. The default threshold is 10%, meaning at most 10% of the padded computation is wasted.
Example Distribution (500K points, 3x3x3 kernel)¶
| Offset Type | Count | Pair Count Range |
|---|---|---|
| Identity | 1 | ~500,000 |
| Face | 6 | 5,000 - 50,000 |
| Edge | 12 | 500 - 5,000 |
| Corner | 8 | 50 - 2,000 |
With saturation_m=5000, the 6 face offsets are processed individually while the 20 edge+corner offsets are grouped into a few buckets.
Grouped Backends¶
Four grouped variants are available, each applying the grouping strategy to its respective backend.
Explicit GEMM Grouped (EXPLICIT_GEMM_GROUPED)¶
Gathers input features into a padded [B, M_max, C_in] buffer using vectorized indexing, executes torch.bmm, then scatter-adds results back to the output. No Python for-loops; the gather and scatter use pre-computed flat indices.
Forward:
features[cat_in_map] → padded buffer → torch.bmm(buffer, stacked_weights) → scatter_add to output
Implicit GEMM Grouped (IMPLICIT_GEMM_GROUPED)¶
Uses a custom CUDA kernel (implicit_gemm_grouped) that processes all pairs from a bucket in a single launch. Each pair selects its weight matrix via a weight_idx array, achieving zero padding waste.
The kernel uses a dual-path optimization:
- Uniform blocks (~98% of blocks): All rows in the thread block share the same weight index. Both input features (A) and weights (B) are tiled in shared memory, matching the performance of the ungrouped kernel.
- Boundary blocks (~2%): Rows at weight-group transitions use different weights. A is tiled in shared memory; B is loaded per-thread from L2 cache.
CUDA kernel: A[in_map[i]] @ B[weight_idx[i]] → atomicAdd to C[out_map[i]]
CUTLASS Hybrid Grouped (CUTLASS_GROUPED_HYBRID)¶
Combines CUTLASS fused gather-GEMM-scatter for large offsets with torch.bmm for grouped small offsets. Amortizes CUTLASS per-offset setup cost.
CuTe Grouped (CUTE_GROUPED)¶
Uses the CuTe 3.x programming model for grouped GEMM with batched offset execution. Auto-tunes the MMA tile shape (mma_tile parameter). This is the most frequently optimal forward backend overall (44% win rate), dominating at channels >= 96. It processes all offsets in a single grouped GEMM call, avoiding per-offset launch overhead entirely.
CuTe 3.x grouped GEMM: A[in_map] @ B[weight_idx] → C[out_map]
Algorithm Selection¶
The grouped variants are included in the unified autotuning system. When using AUTO mode, the benchmarking system automatically evaluates grouped variants alongside ungrouped ones and selects the fastest for your data characteristics.
You can also select grouped algorithms explicitly:
from warpconvnet.nn.functional.sparse_conv import (
SPARSE_CONV_FWD_ALGO_MODE,
SPARSE_CONV_BWD_ALGO_MODE,
)
output = spatially_sparse_conv(
input_voxels,
weight,
kernel_size=3,
fwd_algo=SPARSE_CONV_FWD_ALGO_MODE.CUTLASS_GROUPED_HYBRID,
bwd_algo=SPARSE_CONV_BWD_ALGO_MODE.EXPLICIT_GEMM_GROUPED,
)
Or via environment variables:
export WARPCONVNET_FWD_ALGO_MODE=cutlass_grouped_hybrid
export WARPCONVNET_BWD_ALGO_MODE=explicit_gemm_grouped
Grouping Parameters¶
The grouping behavior is controlled by saturation_m (passed via the autotuning parameter grid):
saturation_m |
Effect |
|---|---|
| 1000 | More offsets grouped; beneficial when most offsets are small |
| 5000 (default) | Balanced; groups edge/corner offsets, leaves face offsets individual |
| 20000 | Fewer offsets grouped; only smallest offsets are batched |
The grouping_threshold (default 0.1) controls maximum padding waste per bucket.
Performance Characteristics¶
Grouping helps most when:
- Many small offsets with low pair counts (high per-launch overhead relative to compute)
- Large channel dimensions (C >= 64), where batched GEMM is more efficient
- CUTLASS backend, which has higher per-offset setup cost
Grouping helps least when:
- Dense kernel maps where most offsets are GPU-saturating
- Small channel dimensions (C \<= 16), where launch overhead is already low
- Very few offsets (e.g., 3x1x1 kernel with only 2 non-identity offsets)
Benchmark (RTX 6000 Ada, 3x3x3 kernel)¶
| Backend | 10K coords | 100K coords | 500K coords |
|---|---|---|---|
| Explicit Grouped (C=64) | 1.32x | 1.03x | 1.01x |
| Implicit Grouped (C=32) | 0.65x | 1.22x | 0.44x |
| CUTLASS Hybrid (C=64) | 0.91x | 0.67x | 6.04x |
The CUTLASS hybrid achieves 6x speedup at 500K coordinates with C=64, where per-offset CUTLASS launch overhead dominates the ungrouped path.
Source Files¶
| File | Contents |
|---|---|
warpconvnet/nn/functional/sparse_conv/detail/grouping.py |
Bucketing algorithm, GroupedKernelMap, prepare_grouped_kernel_map() |
warpconvnet/nn/functional/sparse_conv/detail/explicit.py |
_explicit_gemm_forward_grouped(), _explicit_gemm_backward_grouped() |
warpconvnet/nn/functional/sparse_conv/detail/implicit_direct.py |
_implicit_gemm_forward_grouped(), _implicit_gemm_backward_grouped() |
warpconvnet/nn/functional/sparse_conv/detail/cutlass.py |
_cutlass_implicit_gemm_forward_grouped(), _cutlass_implicit_gemm_backward_grouped() |
warpconvnet/nn/functional/sparse_conv/detail/cute_grouped.py |
_cute_grouped_forward(), _cute_grouped_backward() (CuTe 3.x) |
warpconvnet/csrc/implicit_gemm.cu |
implicit_gemm_grouped CUDA kernel |
warpconvnet/csrc/bindings/gemm_bindings.cpp |
pybind11 binding for _C.gemm.implicit_gemm_grouped |
warpconvnet/nn/functional/sparse_conv/detail/unified.py |
Autotuning integration |