Sparse Convolutions¶
WarpConvNet implements spatially sparse convolutions on voxel grids using multiple CUDA backends with automatic algorithm selection.
Overview¶
WarpConvNet provides two types of sparse convolutions:
- Regular Sparse Convolution: General-purpose convolution for feature learning
- Depthwise Sparse Convolution: Channel-wise convolution for efficient feature processing
Both include a unified auto-tuning system that benchmarks algorithm candidates at runtime and caches the best configuration per problem shape.
Two GEMM Operations¶
A sparse convolution's forward and backward passes decompose into two mathematically distinct GEMM operations:
| Operation | Math | Used By | Cache Namespace |
|---|---|---|---|
| AB gather-scatter | D[scatter] = A[gather] @ B | Forward, dgrad | AB_gather_scatter |
| AtB gather-gather | D = A[gather]^T @ B[gather] | Wgrad | AtB_gather_gather |
Forward and dgrad share the same kernel (gather input, dense weight, scatter to output). Wgrad uses a reduction kernel (gather both operands, dense output per offset). Each operation is auto-tuned independently.
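The two operations in the table can be written as a small pure-Python reference for a single kernel offset. This is only a sketch for intuition, not the library's kernels: the function names, the list-of-rows layout, and the toy index lists are assumptions made for the example.

```python
# Illustrative reference for the two GEMM shapes (one kernel offset).
# Matrices are lists of rows; all names here are hypothetical.

def matmul(a, b):
    """Dense (m x k) @ (k x n) product on lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def ab_gather_scatter(a, b, gather, scatter, num_out_rows):
    """D[scatter] += A[gather] @ B  (the shape used by forward and dgrad)."""
    d = [[0.0] * len(b[0]) for _ in range(num_out_rows)]
    prod = matmul([a[i] for i in gather], b)
    for row, out_idx in zip(prod, scatter):
        for c, v in enumerate(row):
            d[out_idx][c] += v  # scatter-add into the selected output row
    return d


def atb_gather_gather(a, b, gather_a, gather_b):
    """D = A[gather_a]^T @ B[gather_b]  (the shape used by wgrad)."""
    a_g = [a[i] for i in gather_a]
    b_g = [b[i] for i in gather_b]
    a_t = [list(col) for col in zip(*a_g)]  # transpose the gathered rows
    return matmul(a_t, b_g)


# Toy problem: 3 input voxels with 2 channels, 2 output voxels with 1 channel.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w_k = [[1.0], [10.0]]              # weight slice for one offset, shape (C_in, C_out)
in_idx, out_idx = [0, 2], [1, 0]   # (input row, output row) pairs for this offset

y = ab_gather_scatter(x, w_k, in_idx, out_idx, num_out_rows=2)   # forward
grad_y = [[1.0], [2.0]]
grad_w_k = atb_gather_gather(x, grad_y, in_idx, out_idx)         # wgrad
```

In a full convolution the AB operation would be repeated for every kernel offset, accumulating into the same output, while the AtB operation yields one dense (C_in, C_out) weight gradient per offset.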
Convolution Kernel Backends¶
Per-Offset Backends¶
These backends process each kernel offset as a separate GEMM call:
| Backend | Implementation | Strengths |
|---|---|---|
| explicit_gemm | Gather features into a dense buffer, call torch.mm, scatter-add results | Simple, reliable fallback. No CUDA alignment requirements. |
| implicit_gemm | Custom CUDA kernel that fuses gather, GEMM, and scatter-add into one launch | Best at small channels (C <= 64) where launch overhead matters less. |
| cutlass_implicit_gemm | CUTLASS fused gather-GEMM-scatter kernel | High throughput at large channels. Auto-pads unaligned channels internally. |
| cute_implicit_gemm | CuTe 3.x fused gather-GEMM-scatter kernel | Vectorized A-operand loads (cp.async). Competitive at small-medium channels. |
Fused Multi-Offset Backends¶
These backends process multiple (or all) kernel offsets in a single launch:
| Backend | Implementation | Strengths |
|---|---|---|
| cute_grouped | CuTe 3.x grouped GEMM: all offsets in one launch via binary-search dispatch | Dominant wgrad winner (64%). Amortizes launch overhead at medium-large channels. |
| cutlass_grouped_hybrid | CUTLASS for large offsets + torch.bmm for grouped small offsets | Strong at large N with medium-large channels. |
| mask_implicit_gemm | Mask-based fused kernel: iterates all K offsets per output row using bitmask skipping | Dominant AB winner (56% fwd, 74% dgrad). No atomicAdd. CuTe tensor core MMA. |
How mask_implicit_gemm Works¶
Unlike per-offset and grouped backends that launch separate work per offset, the mask kernel processes all K offsets in a single launch. For each output row:
- Look up which offsets are active via a bitmask (pair_mask)
- For each active offset, gather from input and accumulate with the offset's weight
- Write output directly; no atomicAdd is needed since each output row is exclusive
For dgrad, a reverse pair_table is constructed so the same forward kernel can be reused with swapped dimensions, avoiding atomicAdd entirely (~2x speedup over the old atomicAdd dgrad).
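The per-row traversal and the reverse table can be sketched in plain Python. The layouts below are assumptions chosen for illustration (pair_table[k][row] holds the source row for offset k or -1, and bit k of pair_mask[row] marks offset k as active), and the sketch assumes each input row maps to at most one output row per offset. The actual CUDA kernel and its data layout may differ.

```python
# Sketch of the mask-based traversal; table and mask layouts are hypothetical.

def mask_conv(x, weights, pair_table, pair_mask, c_out):
    """For each output row, visit only the offsets whose bit is set in pair_mask."""
    out = []
    for row, mask in enumerate(pair_mask):
        acc = [0.0] * c_out
        for k in range(len(weights)):
            if not (mask >> k) & 1:
                continue  # bitmask skip: offset k has no neighbour for this row
            src = x[pair_table[k][row]]
            for ci, xv in enumerate(src):
                for co in range(c_out):
                    acc[co] += xv * weights[k][ci][co]
        out.append(acc)  # each row is written exactly once, so no atomics
    return out


def reverse_pairs(pair_table, num_in_rows):
    """Build an input-indexed table and mask so dgrad can reuse mask_conv."""
    rev_table = [[-1] * num_in_rows for _ in pair_table]
    rev_mask = [0] * num_in_rows
    for k, row_map in enumerate(pair_table):
        for out_row, in_row in enumerate(row_map):
            if in_row >= 0:
                rev_table[k][in_row] = out_row
                rev_mask[in_row] |= 1 << k
    return rev_table, rev_mask


# Toy problem: K=2 offsets, 3 input rows (C_in=2), 2 output rows (C_out=1).
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [[[1.0], [1.0]], [[2.0], [0.0]]]   # shape (K, C_in, C_out)
pair_table = [[0, 1], [2, -1]]               # offset 1 has no pair for output row 1
pair_mask = [0b11, 0b01]
y = mask_conv(x, weights, pair_table, pair_mask, c_out=1)

# dgrad: same loop, with roles of input and output swapped and weights transposed.
grad_y = [[1.0], [2.0]]
weights_t = [[list(col) for col in zip(*w)] for w in weights]   # (K, C_out, C_in)
rev_table, rev_mask = reverse_pairs(pair_table, num_in_rows=3)
grad_x = mask_conv(grad_y, weights_t, rev_table, rev_mask, c_out=2)
```

Because every row of grad_x is owned by a single loop iteration, the dgrad pass needs no atomic accumulation, which is the property the reverse table is built to provide.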
Auto-Tuning System¶
How It Works¶
On the first forward (or backward) pass for a new problem shape, WarpConvNet:
- Selects a set of candidate algorithms based on the convolution dimensions (N, C_in, C_out, K)
- Runs each candidate with warmup + timed iterations
- Picks the fastest and caches the result keyed by (log10(N_in), log10(N_out), C_in, C_out, K, dtype, SM)
- Serves subsequent calls with the same shape directly from the cache
Results are persisted to ~/.cache/warpconvnet/benchmark_cache_generic.msgpack and survive across Python sessions.
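The benchmark-then-cache loop described above might look roughly like the following. This is a minimal CPU-only sketch, not WarpConvNet's implementation: the function names, the in-memory dictionary standing in for the msgpack file, and the choice to round log10 up to an integer bucket are all assumptions.

```python
import math
import time

_cache = {}  # in-memory stand-in for the persisted benchmark cache


def cache_key(n_in, n_out, c_in, c_out, k, dtype, sm):
    # Assumption: log10 values are rounded up to integer buckets so that
    # nearby problem sizes share an entry; the real rounding may differ.
    return (math.ceil(math.log10(n_in)), math.ceil(math.log10(n_out)),
            c_in, c_out, k, dtype, sm)


def autotune(key, candidates, warmup=2, iters=5):
    """Return the name of the fastest candidate, benchmarking only on a cache miss."""
    if key in _cache:
        return _cache[key]
    timings = {}
    for name, fn in candidates.items():
        for _ in range(warmup):
            fn()  # warmup runs are not timed
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        timings[name] = (time.perf_counter() - start) / iters
    best = min(timings, key=timings.get)
    _cache[key] = best
    return best


# Two stand-in "algorithms": one trivially fast, one that sleeps.
candidates = {"fast": lambda: None, "slow": lambda: time.sleep(0.002)}
key = cache_key(5000, 4800, 64, 128, 27, "float16", 89)
first = autotune(key, candidates)    # benchmarks both candidates
second = autotune(key, {})           # cache hit: nothing is benchmarked
```

Timing real GPU kernels would additionally require synchronizing the device (or using CUDA events) around the timed region, since kernel launches return before the work completes.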
Adaptive Candidate Selection¶
The candidate set adapts to the problem dimensions. Based on benchmark analysis of 148 configs (SM 8.9, cuBLAS 12.9.1.4):
AB gather-scatter (forward + dgrad) — 7-11 candidates:
| N range | ch <= 256 | ch > 256 |
|---|---|---|
| Small (N <= 10K) | mask (92-100%) | cute_grouped (58%), mask (25%) |
| Medium (10K-100K) | mask (69%), cutlass (27%) | cutlass_grouped (67%) |
| Large (N > 100K) | mask/cutlass_grouped/cutlass | cutlass (100%) |
AtB gather-gather (wgrad) — 5-8 candidates:
| N range | ch <= 64 | ch > 64 |
|---|---|---|
| Small (N <= 10K) | cute_grouped (57%), implicit_gemm (36%) | cute_grouped (100%) |
| Medium (10K-100K) | cute_grouped (57%), explicit_grouped (43%) | cute_grouped (77%) |
| Large (N > 100K) | cutlass_grouped (57%), explicit_grouped (36%) | cute_grouped (100%) |
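As a rough illustration of how a shape maps onto the AB table above, the lookup below encodes only the leading winners per cell, using the table's shorthand names. It is not the library's candidate-selection code (the real candidate sets contain 7-11 entries), and treating "ch" as the larger of C_in and C_out is an assumption.

```python
# Hypothetical lookup mirroring the AB gather-scatter table (leading winners only).

def n_bucket(n):
    """Classify the voxel count using the table's 10K / 100K boundaries."""
    if n <= 10_000:
        return "small"
    if n <= 100_000:
        return "medium"
    return "large"


AB_WINNERS = {
    ("small", False): ["mask"],
    ("small", True): ["cute_grouped", "mask"],
    ("medium", False): ["mask", "cutlass"],
    ("medium", True): ["cutlass_grouped"],
    ("large", False): ["mask", "cutlass_grouped", "cutlass"],
    ("large", True): ["cutlass"],
}


def top_ab_winners(n, c_in, c_out):
    # Assumption: "ch" in the table refers to max(C_in, C_out).
    wide = max(c_in, c_out) > 256
    return AB_WINNERS[(n_bucket(n), wide)]


small_narrow = top_ab_winners(5_000, 64, 64)
large_wide = top_ab_winners(200_000, 512, 512)
```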
Algorithm Modes¶
| Mode | AB Candidates | AtB Candidates | Use Case |
|---|---|---|---|
| auto (default) | 7-11 (adaptive) | 5-8 (adaptive) | Normal usage. Covers all winning algorithms. |
| trimmed | 11 | 27 | Broader search, excludes known dead-weight. |
| all | 23 | 35 | Exhaustive. For benchmarking or new hardware. |
# Default: adaptive reduced set (recommended)
export WARPCONVNET_FWD_ALGO_MODE=auto
# Exhaustive: benchmark every algorithm variant
export WARPCONVNET_FWD_ALGO_MODE=all
# Specific algorithm (no benchmarking, just use it)
export WARPCONVNET_FWD_ALGO_MODE=mask_implicit_gemm
# Algorithm list (benchmark only these)
export WARPCONVNET_FWD_ALGO_MODE="[mask_implicit_gemm,cutlass_implicit_gemm]"
The same options apply to WARPCONVNET_BWD_ALGO_MODE, which controls the wgrad (AtB) algorithm.
Usage¶
Basic Usage¶
from warpconvnet.nn.modules.sparse_conv import SpatiallySparseConv
# Auto mode (default) -- auto-tunes on first call, cached thereafter
conv = SpatiallySparseConv(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
)
output = conv(input_voxels)
Functional API¶
from warpconvnet.nn.functional import spatially_sparse_conv
output = spatially_sparse_conv(
    input_voxels,
    weight,
    kernel_size=3,
)
Specifying Algorithms¶
Forward, dgrad, and wgrad can each be controlled independently:
# Different algorithms for each operation
output = spatially_sparse_conv(
    input_voxels, weight, kernel_size=3,
    fwd_algo="mask_implicit_gemm",    # AB gather-scatter for forward
    dgrad_algo="mask_implicit_gemm",  # AB gather-scatter for dgrad
    wgrad_algo="cute_grouped",        # AtB gather-gather for wgrad
)
# Algorithm list -- benchmarks only these
output = spatially_sparse_conv(
    input_voxels, weight, kernel_size=3,
    fwd_algo=["mask_implicit_gemm", "cutlass_implicit_gemm"],
    dgrad_algo=["mask_implicit_gemm", "cute_grouped"],
    wgrad_algo=["cute_grouped", "cutlass_grouped_hybrid"],
)
Depthwise Convolution¶
from warpconvnet.nn.functional import spatially_sparse_depthwise_conv
output = spatially_sparse_depthwise_conv(
    input_features,
    depthwise_weight,
    kernel_map,
    num_out_coords,
)
Depthwise convolution has its own algorithm modes (explicit_gemm, implicit_gemm, auto) controlled by:
export WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE=auto
export WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE=auto
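For intuition about what a channel-wise (depthwise) sparse convolution computes, here is a pure-Python reference. It is a sketch under stated assumptions: the kernel map is represented as a hypothetical list of (input row, output row) pairs per offset, and the weight as one scalar per offset and channel. The real kernel_map object and weight layout in WarpConvNet may be different.

```python
# Reference semantics only; pairs_per_offset is a hypothetical stand-in for kernel_map.

def depthwise_reference(features, weight, pairs_per_offset, num_out_coords):
    """out[o][c] += features[i][c] * weight[k][c] for every (i, o) pair of offset k."""
    channels = len(features[0])
    out = [[0.0] * channels for _ in range(num_out_coords)]
    for k, pairs in enumerate(pairs_per_offset):
        for in_row, out_row in pairs:
            for c in range(channels):
                # Each channel is scaled independently; channels never mix.
                out[out_row][c] += features[in_row][c] * weight[k][c]
    return out


features = [[1.0, 2.0], [3.0, 4.0]]          # 2 voxels, 2 channels
weight = [[1.0, 10.0], [0.5, 0.0]]           # shape (K, C)
pairs_per_offset = [[(0, 0), (1, 1)], [(1, 0)]]
result = depthwise_reference(features, weight, pairs_per_offset, num_out_coords=2)
```

Unlike the regular convolution, no GEMM over channels is involved: the per-offset work is an elementwise multiply-accumulate.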
Environment Variables¶
Algorithm Selection¶
# AB gather-scatter algorithm for forward and dgrad (default: auto)
export WARPCONVNET_FWD_ALGO_MODE=auto
# AtB gather-gather algorithm for wgrad (default: auto)
export WARPCONVNET_BWD_ALGO_MODE=auto
Accepted values: auto, all, trimmed, any single algorithm name, or a bracket list like [algo1,algo2].
Valid algorithm names: explicit_gemm, implicit_gemm, cutlass_implicit_gemm, cute_implicit_gemm, explicit_gemm_grouped, implicit_gemm_grouped, cutlass_grouped_hybrid, cute_grouped, mask_implicit_gemm.
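The accepted value forms can be illustrated with a small parser. This is a hypothetical sketch of how such a value could be interpreted, not the code in warpconvnet/constants.py: parse_algo_mode is an invented name, and the sketch does not check which algorithms are valid for AB versus AtB.

```python
import os

MODES = {"auto", "all", "trimmed"}
VALID_ALGOS = {
    "explicit_gemm", "implicit_gemm", "cutlass_implicit_gemm",
    "cute_implicit_gemm", "explicit_gemm_grouped", "implicit_gemm_grouped",
    "cutlass_grouped_hybrid", "cute_grouped", "mask_implicit_gemm",
}


def parse_algo_mode(raw):
    """Return a mode string, a single algorithm name, or a list of names."""
    value = raw.strip()
    if value in MODES:
        return value
    is_list = value.startswith("[") and value.endswith("]")
    if is_list:
        names = [n.strip() for n in value[1:-1].split(",") if n.strip()]
    else:
        names = [value]
    unknown = [n for n in names if n not in VALID_ALGOS]
    if unknown or not names:
        raise ValueError(f"invalid algorithm mode: {raw!r}")
    return names if is_list else names[0]


# Falls back to "auto" when the variable is unset, matching the documented default.
fwd_mode = parse_algo_mode(os.environ.get("WARPCONVNET_FWD_ALGO_MODE", "auto"))
pair = parse_algo_mode("[mask_implicit_gemm, cutlass_implicit_gemm]")
single = parse_algo_mode("mask_implicit_gemm")
```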
Cache and Logging¶
# Cache directory (default: ~/.cache/warpconvnet)
export WARPCONVNET_BENCHMARK_CACHE_DIR=~/.cache/warpconvnet
# Auto-tuning logs are enabled by default (true); set to false to suppress them
export WARPCONVNET_AUTOTUNE_LOG=false
Inspecting the Cache¶
Use scripts/inspect_benchmark_cache.py to view cached results:
python scripts/inspect_benchmark_cache.py
python scripts/inspect_benchmark_cache.py namespace=AB_gather_scatter --best-only
Use scripts/analyze_autotune_cache.py to generate statistical analysis of algorithm win rates:
python scripts/analyze_autotune_cache.py --markdown --output analysis.md
See Inspecting the Benchmark Cache for details.
Performance Characteristics¶
When Each Backend Wins¶
Based on empirical analysis on RTX 6000 Ada with cuBLAS 12.9.1.4:
| Condition | Best AB Backend | Best AtB Backend |
|---|---|---|
| ch <= 256, any N | mask_implicit_gemm | cute_grouped |
| ch > 256, small N | cute_grouped | cute_grouped |
| ch > 256, large N | cutlass_implicit_gemm | cute_grouped |
| ch <= 64, small N (wgrad) | n/a | implicit_gemm or explicit_grouped |
Troubleshooting¶
Slow first run: This is expected, because auto-tuning benchmarks the candidates on first use. Subsequent runs use the cache. Use auto mode (not all) to minimize tuning time. To skip auto-tuning entirely, pre-populate the cache before your first run, or set a single algorithm name, which is used without benchmarking.
Clear cache when switching GPUs:
rm -rf ~/.cache/warpconvnet/
CUTLASS not available: Some backends require a specific GPU compute capability. If they are unavailable on your GPU, fall back to:
export WARPCONVNET_FWD_ALGO_MODE="[explicit_gemm,implicit_gemm,mask_implicit_gemm]"
Source Files¶
| File | Contents |
|---|---|
| warpconvnet/nn/functional/sparse_conv/detail/unified.py | Auto-tuning dispatch, config construction |
| warpconvnet/nn/functional/sparse_conv/detail/algo_params.py | Adaptive candidate selection, algorithm enums |
| warpconvnet/nn/functional/sparse_conv/detail/autotune.py | Benchmark runners, cache init/merge |
| warpconvnet/nn/functional/sparse_conv/detail/dispatch.py | Algorithm execution dispatch |
| warpconvnet/nn/functional/sparse_conv/detail/mask_gemm.py | Mask-based fused kernel dispatch, reverse pair_table |
| warpconvnet/nn/functional/sparse_conv/detail/cute_grouped.py | CuTe grouped GEMM (AB + TrAB) |
| warpconvnet/nn/functional/sparse_conv/detail/cutlass.py | CUTLASS per-offset gather-scatter |
| warpconvnet/nn/functional/sparse_conv/detail/explicit.py | Explicit GEMM via cuBLAS |
| warpconvnet/nn/functional/sparse_conv/detail/implicit_direct.py | SIMT implicit GEMM |
| warpconvnet/utils/benchmark_cache.py | Generic benchmark cache with persistence |
| warpconvnet/constants.py | Environment variable parsing |