Neural Networks

warpconvnet.nn

BilateralFilter

Bases: Module

KNN/radius bilateral filter (Gaussian on xyz + feat).

Source code in warpconvnet/nn/modules/bilateral.py
class BilateralFilter(nn.Module):
    """KNN/radius bilateral filter (Gaussian on xyz + feat).

    The module only stores hyper-parameters; the filtering itself is done by
    the ``bilateral_filter`` function, which receives the per-call tensors
    together with the stored settings as keyword arguments.

    Parameters
    ----------
    sigma_xyz, sigma_feat : float, optional
        Forwarded as ``sigma_xyz`` / ``sigma_feat``.
    k : int, optional
        Forwarded as ``k`` (presumably the neighbor count in ``"knn"`` mode —
        confirm against ``bilateral_filter``).
    mode : str, optional
        Forwarded as ``mode``. Defaults to ``"knn"``.
    radius_mult : float, optional
        Forwarded as ``radius_mult``.
    chunk_size : int, optional
        Forwarded as ``chunk_size``.
    """

    def __init__(
        self,
        sigma_xyz: float = 0.05,
        sigma_feat: float = 20.0,
        k: int = 16,
        mode: str = "knn",
        radius_mult: float = 3.0,
        chunk_size: int = 32768,
    ):
        super().__init__()
        self.sigma_xyz, self.sigma_feat = sigma_xyz, sigma_feat
        self.k, self.mode = k, mode
        self.radius_mult, self.chunk_size = radius_mult, chunk_size

    def forward(
        self,
        src_xyz: Tensor,
        src_feat: Tensor,
        src_value: Tensor,
        query_xyz: Tensor | None = None,
        query_feat: Tensor | None = None,
    ) -> Tensor:
        """Run ``bilateral_filter`` on the given tensors with the stored settings."""
        settings = {
            "sigma_xyz": self.sigma_xyz,
            "sigma_feat": self.sigma_feat,
            "k": self.k,
            "mode": self.mode,
            "radius_mult": self.radius_mult,
            "chunk_size": self.chunk_size,
        }
        return bilateral_filter(
            src_xyz=src_xyz,
            src_feat=src_feat,
            src_value=src_value,
            query_xyz=query_xyz,
            query_feat=query_feat,
            **settings,
        )

BilateralFilterGrid

Bases: Module

Splat-blur-slice bilateral filter on a sparse d-cube grid (Barron-style).

Source code in warpconvnet/nn/modules/bilateral.py
class BilateralFilterGrid(nn.Module):
    """Splat-blur-slice bilateral filter on a sparse d-cube grid (Barron-style).

    Keeps the two sigma values and passes them to ``bilateral_filter_grid``
    on every call.
    """

    def __init__(self, sigma_xyz: float = 0.05, sigma_feat: float = 20.0):
        super().__init__()
        self.sigma_xyz, self.sigma_feat = sigma_xyz, sigma_feat

    def forward(self, src_xyz: Tensor, src_feat: Tensor, src_value: Tensor) -> Tensor:
        """Filter ``src_value`` via ``bilateral_filter_grid`` using the stored sigmas."""
        sigma_kwargs = {"sigma_xyz": self.sigma_xyz, "sigma_feat": self.sigma_feat}
        return bilateral_filter_grid(src_xyz, src_feat, src_value, **sigma_kwargs)

BilateralFilterGridCached

Bases: Module

Build-once / filter-many sparse d-cube bilateral grid.

Intended for cases where the positions (xyz and guide features) stay fixed across calls (e.g., per-frame in video) and only the values being filtered differ. Call build_grid once, then call forward repeatedly.

Source code in warpconvnet/nn/modules/bilateral.py
class BilateralFilterGridCached(nn.Module):
    """Build-once / filter-many sparse d-cube bilateral grid.

    ``build_grid`` fixes the grid positions from ``src_xyz`` and ``src_feat``
    (each divided by its sigma and concatenated on the last dim); ``forward``
    then filters any number of ``src_value`` tensors against that grid.
    """

    def __init__(self, sigma_xyz: float = 0.05, sigma_feat: float = 20.0):
        super().__init__()
        # No grid until build_grid() is called.
        self._grid: BilateralGrid | None = None
        self.sigma_xyz = sigma_xyz
        self.sigma_feat = sigma_feat

    def build_grid(self, src_xyz: Tensor, src_feat: Tensor) -> BilateralFilterGridCached:
        """Build and cache the grid for the given positions; returns ``self``."""
        scaled_xyz = src_xyz / self.sigma_xyz
        scaled_feat = src_feat / self.sigma_feat
        self._grid = BilateralGrid.build(torch.cat((scaled_xyz, scaled_feat), dim=-1))
        return self

    def forward(self, src_value: Tensor) -> Tensor:
        """Filter ``src_value`` with the cached grid (``normalize=True``).

        Raises
        ------
        RuntimeError
            If ``build_grid`` has not been called yet.
        """
        grid = self._grid
        if grid is not None:
            return grid.filter(src_value, normalize=True)
        raise RuntimeError("Call build_grid(src_xyz, src_feat) before forward().")

    @property
    def num_vertices(self) -> int:
        """Vertex count reported by the cached grid, or ``0`` when none is built."""
        return 0 if self._grid is None else self._grid.num_vertices

BilateralPermutohedralFilter

Bases: Module

Bilateral (xyz + color) permutohedral filter.

Lattice coords = concat(xyz / sigma_xyz, feat / sigma_feat). xyz alone is just a Gaussian blur; feat (e.g. RGB) is what makes it edge-preserving. Constraint: D_xyz + D_feat <= 6 (lattice axes capped at 7).

Source code in warpconvnet/nn/modules/permutohedral.py
class BilateralPermutohedralFilter(nn.Module):
    """Bilateral (xyz + color) permutohedral filter.

    Lattice coords = concat(xyz / sigma_xyz, feat / sigma_feat). xyz alone is
    just a Gaussian blur; feat (e.g. RGB) is what makes it edge-preserving.
    Constraint: D_xyz + D_feat <= 6 (lattice axes capped at 7).

    The module stores the two sigmas and delegates every call to
    ``bilateral_permutohedral_filter``.
    """

    def __init__(self, sigma_xyz: float = 0.05, sigma_feat: float = 20.0):
        super().__init__()
        self.sigma_xyz, self.sigma_feat = sigma_xyz, sigma_feat

    def forward(
        self,
        src_xyz: Tensor,
        src_feat: Tensor,
        src_value: Tensor,
        query_xyz: Tensor | None = None,
        query_feat: Tensor | None = None,
        *,
        normalize: bool = True,
    ) -> Tensor:
        """Forward all tensors plus the stored sigmas to the functional filter."""
        call_kwargs = dict(
            src_xyz=src_xyz,
            src_feat=src_feat,
            src_value=src_value,
            query_xyz=query_xyz,
            query_feat=query_feat,
            normalize=normalize,
        )
        call_kwargs["sigma_xyz"] = self.sigma_xyz
        call_kwargs["sigma_feat"] = self.sigma_feat
        return bilateral_permutohedral_filter(**call_kwargs)

BilateralPermutohedralFilterCached

Bases: Module

Build-once / filter-many bilateral permutohedral.

For iterative bilateral solving on fixed (xyz, feat). Call build_lattice with the source xyz + feat once, then call forward repeatedly with different value tensors.

Source code in warpconvnet/nn/modules/permutohedral.py
class BilateralPermutohedralFilterCached(nn.Module):
    """Build-once / filter-many bilateral permutohedral.

    Meant for iterative bilateral solving on fixed (xyz, feat): call
    ``build_lattice`` once with the source xyz + feat, then call ``forward``
    as often as needed with different value tensors.
    """

    def __init__(self, sigma_xyz: float = 0.05, sigma_feat: float = 20.0):
        super().__init__()
        # Populated by build_lattice(); None means "not built yet".
        self._lattice: PermutohedralLattice | None = None
        self.sigma_xyz = sigma_xyz
        self.sigma_feat = sigma_feat

    def build_lattice(
        self, src_xyz: Tensor, src_feat: Tensor
    ) -> BilateralPermutohedralFilterCached:
        """Build and cache the lattice from sigma-scaled ``[xyz, feat]``; returns ``self``.

        Raises
        ------
        ValueError
            If the combined width ``src_xyz.shape[1] + src_feat.shape[1]`` exceeds 6.
        """
        d_xyz, d_feat = src_xyz.shape[1], src_feat.shape[1]
        if d_xyz + d_feat > 6:
            raise ValueError(f"D_xyz + D_feat = {d_xyz + d_feat} > 6; lattice axes capped at 7.")
        scaled_xyz = src_xyz / self.sigma_xyz
        scaled_feat = src_feat / self.sigma_feat
        self._lattice = PermutohedralLattice.build(torch.cat((scaled_xyz, scaled_feat), dim=-1))
        return self

    def forward(
        self,
        src_value: Tensor,
        query_xyz: Tensor | None = None,
        query_feat: Tensor | None = None,
        *,
        normalize: bool = True,
    ) -> Tensor:
        """Filter ``src_value`` on the cached lattice.

        ``query_xyz`` and ``query_feat`` must be given together or not at all;
        when given they are scaled by the stored sigmas and concatenated into
        the query positions passed to the lattice.

        Raises
        ------
        RuntimeError
            If ``build_lattice`` has not been called.
        ValueError
            If exactly one of ``query_xyz`` / ``query_feat`` is provided.
        """
        lattice = self._lattice
        if lattice is None:
            raise RuntimeError("Call build_lattice(src_xyz, src_feat) before forward().")

        has_xyz = query_xyz is not None
        has_feat = query_feat is not None
        if has_xyz != has_feat:
            raise ValueError("Pass both query_xyz and query_feat, or neither.")

        qp = None
        if has_xyz:
            qp = torch.cat(
                (query_xyz / self.sigma_xyz, query_feat / self.sigma_feat),
                dim=-1,
            )
        return lattice.filter(src_value, query_positions=qp, normalize=normalize)

    @property
    def num_vertices(self) -> int:
        """Number of rows in the cached lattice's ``unique_keys`` (``0`` if unbuilt)."""
        lattice = self._lattice
        return 0 if lattice is None else int(lattice.unique_keys.shape[0])

FastBilateralSolver

Bases: Module

Confidence-weighted bilateral smoothing via PCG (Barron & Poole 2015).

Source code in warpconvnet/nn/modules/bilateral.py
class FastBilateralSolver(nn.Module):
    """Confidence-weighted bilateral smoothing via PCG (Barron & Poole 2015).

    Stores the solver settings and forwards them, with the per-call tensors,
    to ``fast_bilateral_solver``.
    """

    def __init__(
        self,
        sigma_xyz: float = 0.05,
        sigma_feat: float = 20.0,
        lam: float = 128.0,
        max_iters: int = 25,
        tol: float = 1e-5,
    ):
        super().__init__()
        self.sigma_xyz, self.sigma_feat = sigma_xyz, sigma_feat
        self.lam = lam
        self.max_iters, self.tol = max_iters, tol

    def forward(
        self,
        src_xyz: Tensor,
        src_feat: Tensor,
        target: Tensor,
        confidence: Tensor,
    ) -> Tensor:
        """Call ``fast_bilateral_solver`` with the inputs and the stored settings."""
        solver_settings = {
            "sigma_xyz": self.sigma_xyz,
            "sigma_feat": self.sigma_feat,
            "lam": self.lam,
            "max_iters": self.max_iters,
            "tol": self.tol,
        }
        return fast_bilateral_solver(
            src_xyz=src_xyz,
            src_feat=src_feat,
            target=target,
            confidence=confidence,
            **solver_settings,
        )

PermutohedralFilter

Bases: Module

Gaussian filter via permutohedral lattice (Adams, Baek, Davis 2010).

Pre-scales positions by sigmas (per-axis) or sigma (scalar) and runs splat -> blur -> slice. Lattice coords have d+1 axes so the input feature dim is bounded to d <= 6 by the underlying PackedHashTable128.

Source code in warpconvnet/nn/modules/permutohedral.py
class PermutohedralFilter(nn.Module):
    """Gaussian filter via permutohedral lattice (Adams, Baek, Davis 2010).

    Pre-scales positions by ``sigmas`` (per-axis) or ``sigma`` (scalar) and
    runs splat -> blur -> slice. Lattice coords have d+1 axes so the input
    feature dim is bounded to d <= 6 by the underlying PackedHashTable128.

    Exactly one of ``sigma`` / ``sigmas`` must be supplied. Per-axis sigmas
    are kept as a float32 buffer named ``sigmas``; otherwise ``self.sigmas``
    is ``None``.
    """

    def __init__(
        self,
        sigma: float | None = None,
        sigmas: Sequence[float] | None = None,
    ):
        super().__init__()
        scalar_given = sigma is not None
        per_axis_given = sigmas is not None
        if scalar_given == per_axis_given:
            raise ValueError("Pass exactly one of sigma (scalar) or sigmas (per-axis).")
        if per_axis_given:
            per_axis = torch.as_tensor(list(sigmas), dtype=torch.float32)
            self.register_buffer("sigmas", per_axis)
        else:
            self.sigmas = None
        self.sigma = sigma

    def forward(
        self,
        positions: Tensor,
        features: Tensor,
        query_positions: Tensor | None = None,
    ) -> Tensor:
        """Filter ``features`` at ``positions`` via ``permutohedral_filter``.

        Per-axis sigmas, when present, are first moved to the device and dtype
        of ``positions``.
        """
        per_axis = self.sigmas
        if per_axis is not None:
            per_axis = per_axis.to(device=positions.device, dtype=positions.dtype)
        return permutohedral_filter(
            positions=positions,
            features=features,
            sigmas=per_axis,
            sigma=self.sigma,
            query_positions=query_positions,
        )

PermutohedralFilterCached

Bases: Module

Build-once / filter-many permutohedral lattice.

For pipelines where positions are fixed (video frame sequence, iterative bilateral solving) and only features change. Call build_lattice once, then forward repeatedly with different feature tensors.

Source code in warpconvnet/nn/modules/permutohedral.py
class PermutohedralFilterCached(nn.Module):
    """Build-once / filter-many permutohedral lattice.

    For pipelines where positions are fixed (video frame sequence, iterative
    bilateral solving) and only features change. Call ``build_lattice`` once,
    then ``forward`` repeatedly with different feature tensors.

    Exactly one of ``sigma`` (scalar) / ``sigmas`` (per-axis) must be given;
    per-axis sigmas are stored as a float32 buffer named ``sigmas``.
    """

    def __init__(
        self,
        sigma: float | None = None,
        sigmas: Sequence[float] | None = None,
    ):
        super().__init__()
        scalar_given = sigma is not None
        per_axis_given = sigmas is not None
        if scalar_given == per_axis_given:
            raise ValueError("Pass exactly one of sigma (scalar) or sigmas (per-axis).")
        if per_axis_given:
            per_axis = torch.as_tensor(list(sigmas), dtype=torch.float32)
            self.register_buffer("sigmas", per_axis)
        else:
            self.sigmas = None
        self.sigma = sigma
        # Populated by build_lattice(); None means "not built yet".
        self._lattice: PermutohedralLattice | None = None

    def _scale_positions(self, pts: Tensor) -> Tensor:
        """Divide ``pts`` by the per-axis sigmas if set, else by the scalar sigma."""
        per_axis = self.sigmas
        if per_axis is None:
            return pts / self.sigma
        # Match the buffer to the incoming tensor before dividing.
        return pts / per_axis.to(device=pts.device, dtype=pts.dtype)

    def build_lattice(self, positions: Tensor) -> PermutohedralFilterCached:
        """Build and cache the lattice for sigma-scaled ``positions``; returns ``self``."""
        self._lattice = PermutohedralLattice.build(self._scale_positions(positions))
        return self

    def forward(
        self,
        features: Tensor,
        query_positions: Tensor | None = None,
    ) -> Tensor:
        """Filter ``features`` on the cached lattice.

        ``query_positions``, when given, are scaled with the same sigma(s) as
        the build positions before being passed to the lattice.

        Raises
        ------
        RuntimeError
            If ``build_lattice`` has not been called.
        """
        lattice = self._lattice
        if lattice is None:
            raise RuntimeError("Call build_lattice(positions) before forward().")
        if query_positions is not None:
            query_positions = self._scale_positions(query_positions)
        return lattice.filter(features, query_positions=query_positions)

    @property
    def num_vertices(self) -> int:
        """Number of rows in the cached lattice's ``unique_keys`` (``0`` if unbuilt)."""
        lattice = self._lattice
        return 0 if lattice is None else int(lattice.unique_keys.shape[0])

Modules

Activations

warpconvnet.nn.modules.activations

DropPath

Bases: BaseSpatialModule

Stochastic depth regularization.

Parameters:
  • drop_prob (float, default: 0.0 ) –

    Probability of dropping a sample. Defaults to 0.0.

  • scale_by_keep (bool, default: True ) –

    If True, the samples that are kept are divided by the keep probability (1 - drop_prob) so the expected output matches the input. Defaults to True.

Source code in warpconvnet/nn/modules/activations.py
class DropPath(BaseSpatialModule):
    """Stochastic depth regularization.

    Applies ``drop_path`` to a raw tensor, or to the feature tensor of a
    ``Geometry`` (returning a new geometry via ``replace``).

    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a sample. Defaults to ``0.0``.
    scale_by_keep : bool, optional
        If ``True`` the samples that are kept are divided by the keep
        probability ``1 - drop_prob``. Defaults to ``True``.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x: Union[Geometry, Tensor]):  # noqa: F821
        """Apply stochastic depth; only active while ``self.training`` is true."""
        if isinstance(x, Geometry):
            dropped = drop_path(
                x.feature_tensor, self.drop_prob, self.training, self.scale_by_keep
            )
            return x.replace(batched_features=dropped)
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        """Return ``drop_prob`` formatted to three decimals for ``repr``."""
        # The previous spec ": 0.3f" carried a space sign-flag and rendered
        # "drop_prob= 0.100"; without it the value follows the "=" directly.
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"

ELU

Bases: BaseSpatialModule

Applies the ELU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class ELU(BaseSpatialModule):
    """Applies the ELU activation to ``Geometry`` features.

    Stateless wrapper: ``forward`` hands its input straight to the ``elu``
    function used by this module.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``elu(input)``."""
        return elu(input)

GELU

Bases: BaseSpatialModule

Applies the GELU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class GELU(BaseSpatialModule):
    """Applies the GELU activation to ``Geometry`` features.

    Stateless wrapper: ``forward`` hands its input straight to the ``gelu``
    function used by this module.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``gelu(input)``."""
        return gelu(input)

LeakyReLU

Bases: Module

Applies the LeakyReLU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class LeakyReLU(nn.Module):
    """Applies the LeakyReLU activation to ``Geometry`` features.

    Stateless wrapper around the ``leaky_relu`` function. No slope argument is
    passed, so whatever default ``leaky_relu`` defines is used.
    """

    # NOTE(review): this class derives from ``nn.Module`` while the other
    # activation wrappers in this module derive from ``BaseSpatialModule`` —
    # confirm whether that difference is intentional.

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``leaky_relu(input)``."""
        return leaky_relu(input)

LogSoftmax

Bases: BaseSpatialModule

Applies the log_softmax activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class LogSoftmax(BaseSpatialModule):
    """Applies the ``log_softmax`` activation to ``Geometry`` features.

    Stateless wrapper around the ``log_softmax`` function. No dimension
    argument is passed, so the reduction axis is whatever ``log_softmax``
    defaults to — TODO confirm.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``log_softmax(input)``."""
        return log_softmax(input)

ReLU

Bases: BaseSpatialModule

Apply the ReLU activation to Geometry features.

Parameters:
  • inplace (bool, default: False ) –

    Whether to perform the operation in-place. Defaults to False.

Source code in warpconvnet/nn/modules/activations.py
class ReLU(BaseSpatialModule):
    """Apply the ReLU activation to ``Geometry`` features.

    Wraps a ``torch.nn.ReLU`` instance and applies it through
    ``apply_feature_transform``.

    Parameters
    ----------
    inplace : bool, optional
        Whether to perform the operation in-place. Defaults to ``False``.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def __repr__(self):
        """Return ``ClassName(inplace=<flag>)`` using the wrapped module's flag."""
        cls_name = type(self).__name__
        return f"{cls_name}(inplace={self.relu.inplace})"

    def forward(self, input: Geometry):  # noqa: F821
        """Run the wrapped ``nn.ReLU`` over ``input`` via ``apply_feature_transform``."""
        activation = self.relu
        return apply_feature_transform(input, activation)

SiLU

Bases: BaseSpatialModule

Applies the SiLU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class SiLU(BaseSpatialModule):
    """Applies the SiLU activation to ``Geometry`` features.

    Stateless wrapper: ``forward`` hands its input straight to the ``silu``
    function used by this module.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``silu(input)``."""
        return silu(input)

Sigmoid

Bases: BaseSpatialModule

Applies the sigmoid activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class Sigmoid(BaseSpatialModule):
    """Applies the sigmoid activation to ``Geometry`` features.

    Stateless wrapper: ``forward`` hands its input straight to the ``sigmoid``
    function used by this module.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``sigmoid(input)``."""
        return sigmoid(input)

Softmax

Bases: BaseSpatialModule

Applies the softmax activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class Softmax(BaseSpatialModule):
    """Applies the ``softmax`` activation to ``Geometry`` features.

    Stateless wrapper around the ``softmax`` function. No dimension argument
    is passed, so the reduction axis is whatever ``softmax`` defaults to —
    TODO confirm.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``softmax(input)``."""
        return softmax(input)

Tanh

Bases: BaseSpatialModule

Applies the tanh activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class Tanh(BaseSpatialModule):
    """Applies the ``tanh`` activation to ``Geometry`` features.

    Stateless wrapper: ``forward`` hands its input straight to the ``tanh``
    function used by this module.
    """

    def forward(self, input: Geometry):  # noqa: F821
        """Return ``tanh(input)``."""
        return tanh(input)

drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True)

Apply stochastic depth to the input tensor.

Parameters:
  • x (``torch.Tensor``) –

    Input tensor to apply stochastic depth to.

  • drop_prob (float, default: 0.0 ) –

    Probability of dropping a sample. Defaults to 0.0.

  • training (bool, default: False ) –

    Whether the module is in training mode. Defaults to False.

  • scale_by_keep (bool, default: True ) –

    If True, the samples that are kept are divided by the keep probability (1 - drop_prob) so the expected output matches the input. Defaults to True.

Source code in warpconvnet/nn/modules/activations.py
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Apply stochastic depth to the input tensor.

    One Bernoulli keep/drop decision is drawn per entry along dim 0 and
    broadcast over all remaining dims.

    Parameters
    ----------
    x : ``torch.Tensor``
        Input tensor to apply stochastic depth to.
    drop_prob : float, optional
        Probability of dropping a sample. Defaults to ``0.0``.
    training : bool, optional
        Whether the module is in training mode. Defaults to ``False``.
    scale_by_keep : bool, optional
        If ``True`` the kept samples are divided by the keep probability
        ``1 - drop_prob`` (skipped when that probability is zero).
        Defaults to ``True``.

    Returns
    -------
    ``torch.Tensor``
        ``x`` itself when inactive (``drop_prob == 0`` or not training),
        otherwise ``x`` multiplied by the per-sample mask.
    """
    active = training and drop_prob != 0.0
    if not active:
        return x

    keep_prob = 1 - drop_prob
    # Shape (N, 1, ..., 1): one decision per sample, broadcast over the rest.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask

Attention

warpconvnet.nn.modules.attention

Attention

Bases: Module

Source code in warpconvnet/nn/modules/attention.py
class Attention(nn.Module):
    """Multi-head self-attention over dense ``[B, N, C]`` features.

    Two execution paths exist:

    * flash (``enable_flash=True``): the packed QKV tensor is handed to
      ``flash_attn.flash_attn_qkvpacked_func``. ``pos_enc`` is added to the
      input before the QKV projection; ``mask`` is not used on this path.
    * non-flash: explicit softmax attention. ``pos_enc`` is split into heads
      and added to the queries and keys; ``mask`` is added to the attention
      logits.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        enable_flash: bool = True,
        use_batched_qkv: bool = True,
    ):
        """
        Attention module with optional batched QKV for Muon optimization.

        Args:
            dim: Input feature dimension
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in QKV projection
            qk_scale: Scale factor for attention scores
            attn_drop: Attention dropout rate
            proj_drop: Output projection dropout rate
            enable_flash: Whether to use flash attention
            use_batched_qkv: If True, uses separate Q, K, V matrices stacked as [3, dim, dim]
                           for Muon optimization. Muon can orthogonalize the [dim, dim] matrices
                           more effectively than the concatenated [dim, 3*dim] matrix.
        """
        super().__init__()
        # forward() reshapes channels into (num_heads, dim // num_heads); fail
        # at construction instead of with an opaque reshape error later.
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.enable_flash = enable_flash
        self.use_batched_qkv = use_batched_qkv

        if enable_flash:
            assert flash_attn is not None, "Make sure flash_attn is installed."
            # flash-attn takes the dropout probability as a call argument.
            self.attn_drop_p = attn_drop
        else:
            self.attn_drop = nn.Dropout(attn_drop)

        if use_batched_qkv:
            # Use BatchedLinear for Muon-friendly QKV projection
            self.qkv = BatchedLinear(dim, dim, num_matrices=3, bias=qkv_bias)
        else:
            # Original single linear layer approach
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: Float[Tensor, "B N C"],  # noqa: F821
        pos_enc: Optional[Float[Tensor, "B N C"]] = None,  # noqa: F821
        mask: Optional[Float[Tensor, "B N N"]] = None,  # noqa: F821
        num_points: Optional[Int[Tensor, "B"]] = None,  # noqa: F821
    ) -> Float[Tensor, "B N C"]:
        """Attend over the ``N`` axis of ``x`` and project back to ``C`` channels.

        Args:
            x: Input features, ``[B, N, C]``.
            pos_enc: Optional positional encoding with ``C`` channels. Added to
                ``x`` before the QKV projection on the flash path; split into
                heads and added to q and k on the non-flash path.
            mask: Optional tensor added to the attention logits (non-flash
                path only; it must broadcast against ``[B, heads, N, N]``).
            num_points: Optional per-batch valid counts; when given, the
                output is passed through ``zero_out_points``.

        Returns:
            Tensor of shape ``[B, N, C]``.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # Compute QKV with unified approach
        if pos_enc is not None and self.enable_flash:
            # Add positional encoding to input before QKV projection for flash attention
            qkv = self.qkv(x + pos_enc).reshape(B, N, 3, C)
        else:
            qkv = self.qkv(x).reshape(B, N, 3, C)

        # Reshape to [B, N, 3, num_heads, head_dim]
        qkv = qkv.reshape(B, N, 3, self.num_heads, head_dim)

        if not self.enable_flash:
            # [3, B, num_heads, N, head_dim]
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = (
                qkv[0],
                qkv[1],
                qkv[2],
            )  # make torchscript happy (cannot use tensor as tuple)

            # Apply positional encoding to the query and key (non-flash path)
            if pos_enc is not None:
                # q/k are [B, num_heads, N, head_dim]; split the C channels of
                # pos_enc into heads so shapes line up. (A plain unsqueeze(1)
                # gives [B, 1, N, C], which cannot broadcast when
                # num_heads > 1; for num_heads == 1 both forms are identical.)
                pos_heads = pos_enc.reshape(
                    *pos_enc.shape[:-1], self.num_heads, head_dim
                ).transpose(-2, -3)
                q = q + pos_heads
                k = k + pos_heads

            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                attn = attn + mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = attn @ v
            # [B, num_heads, N, head_dim] -> [B, N, C]
            x = x.transpose(1, 2).reshape(B, N, C)
        else:
            # Flash attention path
            # Flash attention - preserve original dtype if possible
            original_dtype = qkv.dtype
            if qkv.dtype not in [torch.float16, torch.bfloat16]:
                # Convert to half precision for flash attention
                qkv_flash = qkv.half()
            else:
                qkv_flash = qkv

            x = flash_attn.flash_attn_qkvpacked_func(
                qkv_flash,
                dropout_p=self.attn_drop_p if self.training else 0.0,
                softmax_scale=self.scale,
            ).reshape(B, N, C)

            # Convert back to original dtype if necessary
            if x.dtype != original_dtype:
                x = x.to(original_dtype)

        x = self.proj(x)
        x = self.proj_drop(x)

        if num_points is not None:
            x = zero_out_points(x, num_points)
        return x

PatchAttention

Bases: BaseSpatialModule

Source code in warpconvnet/nn/modules/attention.py
class PatchAttention(BaseSpatialModule):
    """Self-attention restricted to fixed-size patches of an ordered point set.

    Points are (re)ordered with ``encode`` when needed, split per batch item
    into consecutive patches of at most ``patch_size`` points, and attended
    within each patch using flash-attn's variable-length packed-QKV kernel.
    """

    def __init__(
        self,
        dim: int,
        patch_size: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        order: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
        use_batched_qkv: bool = True,
        use_rope: bool = False,
        rope_base: int = 10_000,
    ):
        """
        Patch attention module with optional batched QKV for Muon optimization.

        Args:
            dim: Input feature dimension
            patch_size: Size of patches for attention computation
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in QKV projection
            qk_scale: Scale factor for attention scores
            attn_drop: Attention dropout rate
            proj_drop: Output projection dropout rate
            order: Point ordering for patch generation
            use_batched_qkv: If True, uses separate Q, K, V matrices stacked as [3, dim, dim]
                           for Muon optimization. Muon can orthogonalize the [dim, dim] matrices
                           more effectively than the concatenated [dim, 3*dim] matrix.
            use_rope: If True, apply 3D RoPE to Q and K via the fused CUDA
                kernel. Uses point-cloud coordinates for the rotation phase.
            rope_base: RoPE base. Use
                `warpconvnet.nn.modules.rope.suggest_voxel_rope_base` for a
                window-aware default.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        head_dim = dim // num_heads
        # NOTE: ``or`` means a qk_scale of 0.0 also falls back to head_dim**-0.5.
        self.scale = qk_scale or head_dim**-0.5
        self.use_batched_qkv = use_batched_qkv

        if use_batched_qkv:
            # Use BatchedLinear for Muon-friendly QKV projection
            self.qkv = BatchedLinear(dim, dim, num_matrices=3, bias=qkv_bias)
        else:
            # Original single linear layer approach
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.order = order
        # This module has no non-flash fallback, so flash_attn is mandatory.
        assert flash_attn is not None, "Make sure flash_attn is installed."
        # flash-attn takes the dropout probability as a call argument.
        self.attn_drop_p = attn_drop

        self.use_rope = use_rope
        if use_rope:
            self.rope = VoxelRotaryPositionalEmbeddings(
                dim=dim,
                num_heads=num_heads,
                base=rope_base,
            )

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _offset_to_attn_offset(
        self, offsets: Int[Tensor, "B+1"], patch_size: Optional[int] = None
    ) -> Int[Tensor, "B"]:
        """
        Convert offsets to cumulative attention offsets required for flash attention.
        If the patch size is 8 and the offsets are [0, 3, 11, 40] (3 batches),
        the cumulative attention offsets are [0, 3, 3 + 8 = 11, 11 + 8, 11 + 8 + 8, 11 + 8 + 8 + 8, 40].

        Each batch item is cut into ceil(count / patch_size) patches; the last
        patch of an item may be shorter than ``patch_size``.

        Args:
            offsets: (B+1) cumulative point offsets per batch item.
            patch_size: Optional[int]; falls back to ``self.patch_size`` when
                not given (or 0).
        Returns:
            cum_seqlens: one start offset per patch followed by the final
            offset, i.e. (total number of patches + 1) entries. When there
            are no patches at all, ``offsets`` is returned unchanged.
        """
        patch_size = patch_size or self.patch_size
        counts = torch.diff(offsets)

        # Calculate number of patches per batch using ceil division
        num_patches_per_batch = (counts + patch_size - 1) // patch_size

        # Fast path: if no patches, return original offsets
        if num_patches_per_batch.sum() == 0:
            return offsets

        # Calculate how many elements each batch contributes (num_patches)
        # We generate the start indices for each patch. The final end point is added later.
        elements_per_batch = num_patches_per_batch

        # Create indices for which batch each element belongs to
        batch_indices = torch.repeat_interleave(
            torch.arange(len(offsets) - 1, device=offsets.device), elements_per_batch
        )

        # Create indices for position within each batch's sequence (0, 1, 2, ...)
        within_batch_indices = torch.cat(
            [
                torch.arange(n, device=offsets.device, dtype=offsets.dtype)
                for n in num_patches_per_batch
            ]
        )

        # Calculate the actual offsets: start_offset + patch_index * patch_size
        start_offsets = offsets[:-1][batch_indices]
        patch_contributions = within_batch_indices * patch_size
        result_middle = start_offsets + patch_contributions

        # Add the final offset
        result = torch.cat([result_middle, offsets[-1].unsqueeze(0)])

        return result.contiguous()

    def forward(self, x: Geometry, order: Optional[POINT_ORDERING] = None) -> Geometry:
        """Apply patch-wise attention to the features of ``x``.

        Args:
            x: Input geometry; its ``features``, ``coordinate_tensor`` and
                ``offsets`` are read.
            order: Optional ordering override; defaults to ``self.order``.

        Returns:
            ``x.replace(batched_features=...)`` with the attended features, in
            the same point order as the input.
        """
        # No serialization is asserted here: the points are re-sorted below
        # whenever ``x`` carries no ``order`` or it differs from the requested one.
        K = self.patch_size

        feats = x.features
        coords = x.coordinate_tensor
        M, C = feats.shape[:2]
        inverse_perm = None
        order = order or self.order
        if not hasattr(x, "order") or (order != x.order):
            # Generate new ordering and inverse permutation
            code_result = encode(
                x.coordinate_tensor,
                batch_offsets=x.offsets,
                order=order,
                return_perm=True,
                return_inverse=True,
            )
            feats = feats[code_result.perm]
            if self.use_rope:
                # Coordinates must follow the same permutation as the features.
                coords = coords[code_result.perm]
            inverse_perm = code_result.inverse_perm

        if self.use_rope:
            # NOTE(review): the rope module is presumably responsible for
            # returning the packed [M, 3, num_heads, head_dim] layout — confirm.
            qkv = self.rope(self.qkv(feats), coords)
        else:
            qkv = self.qkv(feats).reshape(M, 3, self.num_heads, C // self.num_heads)
        # The flash kernel is fed half precision; other dtypes are cast to float16.
        if qkv.dtype not in [torch.float16, torch.bfloat16]:
            qkv = qkv.to(torch.float16)

        attn_offsets = self._offset_to_attn_offset(x.offsets, K).to(
            device=qkv.device, dtype=torch.int32
        )
        # Warning: When the loss is NaN, this module will fail during backward with
        # index out of bounds error.
        # e.g. /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [192,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "
        # https://discuss.pytorch.org/t/scattergatherkernel-cu-assertion-idx-dim-0-idx-dim-index-size-index-out-of-bounds/195356
        out_feat = flash_attn.flash_attn_varlen_qkvpacked_func(
            qkv,
            attn_offsets,
            max_seqlen=K,
            dropout_p=self.attn_drop_p if self.training else 0.0,
            softmax_scale=self.scale,
        )
        # Merge heads and return to the input feature dtype before projecting.
        out_feat = out_feat.reshape(M, C).to(feats.dtype)

        out_feat = self.proj(out_feat)
        out_feat = self.proj_drop(out_feat)

        # Undo the re-sorting so outputs line up with the input point order.
        if inverse_perm is not None:
            out_feat = out_feat[inverse_perm]

        return x.replace(batched_features=out_feat.to(feats.dtype))

offset_to_mask(x: Float[Tensor, 'B M C'], offsets: Float[Tensor, B + 1], max_num_points: int, dtype: torch.dtype = torch.bool) -> Float[Tensor, 'B 1 M M']

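Create a boolean mask of shape (B, 1, max_num_points, max_num_points) in which, for each batch item, the first num_points columns (the valid points, taken from the offsets) are True and the remaining padded columns are False.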

Source code in warpconvnet/nn/modules/attention.py
def offset_to_mask(
    x: Float[Tensor, "B M C"],  # noqa: F821
    offsets: Float[Tensor, "B+1"],  # noqa: F821
    max_num_points: int,  # noqa: F821
    dtype: torch.dtype = torch.bool,
) -> Float[Tensor, "B 1 M M"]:  # noqa: F821
    """
    Create a mask for the points in the batch.

    For batch item ``b`` the first ``offsets[b + 1] - offsets[b]`` columns of
    every row are set to ``True``; all other entries stay ``False``. ``x`` is
    only consulted for its batch size and device.

    Raises:
        ValueError: If ``dtype`` is anything other than ``torch.bool``.
    """
    batch_size = x.shape[0]
    assert batch_size == offsets.shape[0] - 1
    mask_shape = (batch_size, 1, max_num_points, max_num_points)
    mask = torch.zeros(mask_shape, dtype=dtype, device=x.device)
    counts = torch.diff(offsets)
    if dtype != torch.bool:
        raise ValueError(f"Unsupported dtype: {dtype}")
    for batch_idx in range(batch_size):
        # Mark the valid key columns for every query row of this batch item.
        mask[batch_idx, :, :, : counts[batch_idx]] = True
    return mask

zero_out_points(x: Float[Tensor, 'B N C'], num_points: Int[Tensor, B]) -> Float[Tensor, 'B N C']

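Zero out, in place, the padded points of each batch item: every row at index num_points[b] or beyond is set to 0. The modified input tensor is returned.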

Source code in warpconvnet/nn/modules/attention.py
def zero_out_points(
    x: Float[Tensor, "B N C"], num_points: Int[Tensor, "B"]  # noqa: F821
) -> Float[Tensor, "B N C"]:  # noqa: F821
    """
    Zero out the points in the batch.

    For each batch item ``b``, rows ``num_points[b]:`` of ``x[b]`` are set to
    zero. The write happens in place and the same tensor is returned.
    """
    batch_size = num_points.shape[0]
    for batch_idx in range(batch_size):
        first_padded = num_points[batch_idx]
        x[batch_idx, first_padded:] = 0
    return x

Base module

warpconvnet.nn.modules.base_module

BaseSpatialModel

Bases: BaseSpatialModule

Base model class.

Source code in warpconvnet/nn/modules/base_module.py
class BaseSpatialModel(BaseSpatialModule):
    """Base model class.

    Declares the hooks of a model. Every method defined here raises
    ``NotImplementedError``; concrete behavior has to come from overrides.
    """

    def data_dict_to_input(self, data_dict, **kwargs) -> Any:
        """Convert data dictionary to appropriate input for the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def loss_dict(self, data_dict, **kwargs) -> Dict:
        """Compute the loss dictionary for the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    # The decorator wraps only this base method; an override replaces it and
    # has to apply ``torch.no_grad()`` itself if it wants the same behavior.
    @torch.no_grad()
    def eval_dict(self, data_dict, **kwargs) -> Dict:
        """Compute the evaluation dictionary for the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]:
        """Compute the image dict and pointcloud dict for the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

data_dict_to_input(data_dict, **kwargs) -> Any

Convert data dictionary to appropriate input for the model.

Source code in warpconvnet/nn/modules/base_module.py
def data_dict_to_input(self, data_dict, **kwargs) -> Any:
    """Convert data dictionary to appropriate input for the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

eval_dict(data_dict, **kwargs) -> Dict

Compute the evaluation dictionary for the model.

Source code in warpconvnet/nn/modules/base_module.py
# The decorator wraps only this base method; an override replaces it and has
# to apply ``torch.no_grad()`` itself if it wants the same behavior.
@torch.no_grad()
def eval_dict(self, data_dict, **kwargs) -> Dict:
    """Compute the evaluation dictionary for the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

image_pointcloud_dict(data_dict, datamodule) -> Tuple[Dict, Dict]

Compute the image dict and pointcloud dict for the model.

Source code in warpconvnet/nn/modules/base_module.py
def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]:
    """Compute the image dict and pointcloud dict for the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

loss_dict(data_dict, **kwargs) -> Dict

Compute the loss dictionary for the model.

Source code in warpconvnet/nn/modules/base_module.py
def loss_dict(self, data_dict, **kwargs) -> Dict:
    """Compute the loss dictionary for the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

BaseSpatialModule

Bases: Module

Base module for spatial features. The input must be an instance of BatchedSpatialFeatures.

Source code in warpconvnet/nn/modules/base_module.py
class BaseSpatialModule(nn.Module):
    """Base module for spatial features. The input must be an instance of `BatchedSpatialFeatures`."""

    @property
    def device(self):
        """Returns the device that the model is on.

        The device of the first parameter is reported. A module without
        parameters falls back to its first buffer, and a module with neither
        reports CPU instead of raising ``StopIteration``.
        """
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        if tensor is None:
            return torch.device("cpu")
        return tensor.device

    def forward(self, x: Geometry):
        """Forward pass.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

device property

Returns the device that the model is on.

forward(x: Geometry)

Forward pass.

Source code in warpconvnet/nn/modules/base_module.py
def forward(self, x: Geometry):
    """Forward pass.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

Factor grid

warpconvnet.nn.modules.factor_grid

Neural network modules for FactorGrid operations.

This module provides neural network layers and operations specifically designed for working with FactorGrid geometries in the FIGConvNet architecture.

FactorGridCat

Bases: BaseSpatialModule

Concatenate features of two FactorGrid objects.

This is equivalent to GridFeatureGroupCat but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridCat(BaseSpatialModule):
    """Concatenate features of two FactorGrid objects.

    This is equivalent to GridFeatureGroupCat but works with FactorGrid objects.

    The module sets no attributes of its own; ``forward`` delegates to
    ``factor_grid_cat``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
        """Concatenate features from two FactorGrid objects.

        Returns the result of ``factor_grid_cat(factor_grid1, factor_grid2)``.
        """
        return factor_grid_cat(factor_grid1, factor_grid2)

forward(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid

Concatenate features from two FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
    """Concatenate features from two FactorGrid objects.

    Returns the result of ``factor_grid_cat(factor_grid1, factor_grid2)``.
    """
    return factor_grid_cat(factor_grid1, factor_grid2)

FactorGridGlobalConv

Bases: BaseSpatialModule

Global convolution with intra-communication for FactorGrid.

This is equivalent to GridFeatureConv2DBlocksAndIntraCommunication but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridGlobalConv(BaseSpatialModule):
    """Global convolution with intra-communication for FactorGrid.

    This is equivalent to GridFeatureConv2DBlocksAndIntraCommunication but works with FactorGrid objects.

    One ``_FactorGridConvNormAct`` block is created per
    (compressed spatial dim, memory format) pair. Their outputs go through
    ``FactorGridIntraCommunication`` and, when more than one communication
    type is configured, through a ``FactorGridProjection``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        compressed_spatial_dims: Tuple[int, ...],
        compressed_memory_formats: Tuple[GridMemoryFormat, ...],
        stride: int = 1,
        up_stride: Optional[int] = None,
        communication_types: Optional[List[Literal["sum", "mul"]]] = None,
        norm: Type[nn.Module] = nn.BatchNorm2d,
        activation: Type[nn.Module] = nn.GELU,
    ):
        """Build the per-grid conv blocks, the communication step and the projection.

        Args:
            in_channels: Forwarded to every conv block.
            out_channels: Forwarded to every conv block; also the projection's
                output channel count.
            kernel_size: Forwarded to every conv block.
            compressed_spatial_dims: One entry per grid.
            compressed_memory_formats: One entry per grid; must have the same
                length as ``compressed_spatial_dims``.
            stride: Forwarded to every conv block.
            up_stride: Forwarded to every conv block.
            communication_types: Communication types for the intra-communication
                step. ``None`` means ``["sum"]``.
            norm: Normalization layer class forwarded to every conv block.
            activation: Activation layer class forwarded to every conv block.
        """
        super().__init__()
        assert len(compressed_spatial_dims) == len(
            compressed_memory_formats
        ), "Number of compressed spatial dimensions and compressed memory formats must match"

        # Build a fresh list per instance: the value is stored by
        # FactorGridIntraCommunication, so a shared default list would let one
        # module's mutation leak into every other module.
        if communication_types is None:
            communication_types = ["sum"]

        # Create convolution blocks for each compressed spatial dimension
        self.conv_blocks = nn.ModuleList(
            _FactorGridConvNormAct(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                compressed_spatial_dim=compressed_spatial_dim,
                compressed_memory_format=compressed_memory_format,
                stride=stride,
                up_stride=up_stride,
                norm=norm,
                activation=activation,
            )
            for compressed_spatial_dim, compressed_memory_format in zip(
                compressed_spatial_dims, compressed_memory_formats
            )
        )

        # Intra-communication module
        self.intra_communications = FactorGridIntraCommunication(
            communication_types=communication_types
        )

        # Projection layer if multiple communication types: its input width is
        # out_channels multiplied by the number of communication types.
        if len(communication_types) > 1:
            self.proj = FactorGridProjection(
                in_channels=out_channels * len(communication_types),
                out_channels=out_channels,
                kernel_size=1,
                compressed_spatial_dims=compressed_spatial_dims,
                compressed_memory_formats=compressed_memory_formats,
                stride=1,
            )
        else:
            self.proj = nn.Identity()

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Forward pass through the global convolution module.

        Applies the i-th conv block to the i-th grid, then the
        intra-communication step, then the projection (identity when a single
        communication type is configured).
        """
        assert len(factor_grid) == len(
            self.conv_blocks
        ), f"Expected {len(self.conv_blocks)} grids, got {len(factor_grid)}"

        # Apply convolution blocks to each grid
        convolved_grids = [
            conv_block(grid) for grid, conv_block in zip(factor_grid, self.conv_blocks)
        ]

        # Apply intra-communication
        factor_grid = self.intra_communications(FactorGrid(convolved_grids))

        # Apply projection if needed
        return self.proj(factor_grid)

forward(factor_grid: FactorGrid) -> FactorGrid

Forward pass through the global convolution module.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Run conv blocks, intra-communication and projection on a FactorGrid.

    The i-th block of ``self.conv_blocks`` is applied to the i-th grid; the
    results are wrapped in a new ``FactorGrid`` and passed through
    ``self.intra_communications`` and then ``self.proj``.
    """
    assert len(factor_grid) == len(
        self.conv_blocks
    ), f"Expected {len(self.conv_blocks)} grids, got {len(factor_grid)}"

    per_grid_outputs = [
        block(single_grid)
        for single_grid, block in zip(factor_grid, self.conv_blocks)
    ]
    communicated = self.intra_communications(FactorGrid(per_grid_outputs))
    return self.proj(communicated)

FactorGridIntraCommunication

Bases: BaseSpatialModule

Intra-communication between grids in a FactorGrid.

This is equivalent to GridFeaturesGroupIntraCommunication but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridIntraCommunication(BaseSpatialModule):
    """Intra-communication between grids in a FactorGrid.

    This is equivalent to GridFeaturesGroupIntraCommunication but works with FactorGrid objects.
    """

    def __init__(
        self, communication_types: Optional[List[Literal["sum", "mul"]]] = None
    ) -> None:
        """Store the communication types used by ``forward``.

        Args:
            communication_types: One or two communication types. ``None``
                means ``["sum"]``.
        """
        super().__init__()
        # Build a fresh list per instance; the list is stored on ``self``, so a
        # shared default would be aliased by every default-constructed module.
        if communication_types is None:
            communication_types = ["sum"]
        assert (
            len(communication_types) > 0
        ), "At least one communication type must be provided"
        assert (
            len(communication_types) <= 2
        ), "At most two communication types can be provided"
        self.communication_types = communication_types

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Perform intra-communication between grids in the FactorGrid.

        Delegates to ``factor_grid_intra_communication`` with the stored
        communication types.
        """
        return factor_grid_intra_communication(factor_grid, self.communication_types)

forward(factor_grid: FactorGrid) -> FactorGrid

Perform intra-communication between grids in the FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Perform intra-communication between grids in the FactorGrid.

    Delegates to ``factor_grid_intra_communication`` with
    ``self.communication_types``.
    """
    return factor_grid_intra_communication(factor_grid, self.communication_types)

FactorGridPadToMatch

Bases: BaseSpatialModule

Pad FactorGrid features to match spatial dimensions for UNet skip connections.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridPadToMatch(BaseSpatialModule):
    """Pad FactorGrid features to match spatial dimensions for UNet skip connections."""

    def __init__(self):
        super().__init__()

    def forward(
        self, up_factor_grid: FactorGrid, down_factor_grid: FactorGrid
    ) -> FactorGrid:
        """Grow each grid of ``up_factor_grid`` to the size of its counterpart.

        Dimensions from index 2 onward of each feature tensor are treated as
        spatial. A grid whose spatial shape already matches is passed through
        unchanged; otherwise it is padded at the trailing end of each spatial
        axis with ``mode="replicate"``. Axes are only ever grown, never cropped.
        """
        assert len(up_factor_grid) == len(
            down_factor_grid
        ), "FactorGrids must have same number of grids"

        result = []
        for up_grid, down_grid in zip(up_factor_grid, down_factor_grid):
            source = up_grid.grid_features.batched_tensor
            target = down_grid.grid_features.batched_tensor
            source_dims = source.shape[2:]
            target_dims = target.shape[2:]

            if source_dims == target_dims:
                result.append(up_grid)
                continue

            grow_first = max(0, target_dims[0] - source_dims[0])
            grow_second = max(0, target_dims[1] - source_dims[1])

            # F.pad reads (before, after) pairs starting from the LAST axis,
            # so the amounts are listed in reverse axis order.
            if len(target_dims) == 2:
                pad_spec = (0, grow_second, 0, grow_first)
            else:
                if len(target_dims) > 2:
                    grow_third = max(0, target_dims[2] - source_dims[2])
                else:
                    grow_third = 0
                pad_spec = (0, grow_third, 0, grow_second, 0, grow_first)

            grown = F.pad(source, pad_spec, mode="replicate")
            result.append(up_grid.replace(batched_features=grown))

        return FactorGrid(result)

forward(up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid

Pad up_factor_grid to match down_factor_grid spatial dimensions.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(
    self, up_factor_grid: FactorGrid, down_factor_grid: FactorGrid
) -> FactorGrid:
    """Grow each grid of ``up_factor_grid`` to the size of its counterpart.

    Dimensions from index 2 onward of each feature tensor are treated as
    spatial. A grid whose spatial shape already matches is passed through
    unchanged; otherwise it is padded at the trailing end of each spatial axis
    with ``mode="replicate"``. Axes are only ever grown, never cropped.
    """
    assert len(up_factor_grid) == len(
        down_factor_grid
    ), "FactorGrids must have same number of grids"

    result = []
    for up_grid, down_grid in zip(up_factor_grid, down_factor_grid):
        source = up_grid.grid_features.batched_tensor
        target = down_grid.grid_features.batched_tensor
        source_dims = source.shape[2:]
        target_dims = target.shape[2:]

        if source_dims == target_dims:
            result.append(up_grid)
            continue

        grow_first = max(0, target_dims[0] - source_dims[0])
        grow_second = max(0, target_dims[1] - source_dims[1])

        # F.pad reads (before, after) pairs starting from the LAST axis, so
        # the amounts are listed in reverse axis order.
        if len(target_dims) == 2:
            pad_spec = (0, grow_second, 0, grow_first)
        else:
            if len(target_dims) > 2:
                grow_third = max(0, target_dims[2] - source_dims[2])
            else:
                grow_third = 0
            pad_spec = (0, grow_third, 0, grow_second, 0, grow_first)

        grown = F.pad(source, pad_spec, mode="replicate")
        result.append(up_grid.replace(batched_features=grown))

    return FactorGrid(result)

FactorGridPool

Bases: BaseSpatialModule

Pooling operation for FactorGrid.

This is equivalent to GridFeatureGroupPool but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridPool(BaseSpatialModule):
    """Pooling operation for FactorGrid.

    This is equivalent to GridFeatureGroupPool but works with FactorGrid objects.
    """

    def __init__(
        self,
        pooling_type: Literal["max", "mean", "attention"] = "max",
    ):
        """Select the pooling operator for ``pooling_type``.

        ``"max"`` and ``"mean"`` create an adaptive 1-D pool with output size 1.
        ``"attention"`` leaves both ``pool_op`` and ``attention`` as ``None``.

        Raises:
            ValueError: If ``pooling_type`` is none of the three options.
        """
        super().__init__()
        self.pooling_type = pooling_type

        if pooling_type == "attention":
            # No layer is built here; the attention layer stays unset and only
            # this branch defines the ``attention`` attribute at all.
            self.attention = None
            self.pool_op = None
        elif pooling_type == "mean":
            self.pool_op = nn.AdaptiveAvgPool1d(1)
        elif pooling_type == "max":
            self.pool_op = nn.AdaptiveMaxPool1d(1)
        else:
            raise ValueError(f"Unsupported pooling type: {pooling_type}")

    def forward(self, factor_grid: FactorGrid) -> Tensor:
        """Pool features from FactorGrid to a single tensor."""
        # ``attention`` exists only for the "attention" pooling type, hence the
        # defaulted lookup.
        attention_layer = getattr(self, "attention", None)
        return factor_grid_pool(
            factor_grid,
            self.pooling_type,
            pool_op=self.pool_op,
            attention_layer=attention_layer,
        )

forward(factor_grid: FactorGrid) -> Tensor

Pool features from FactorGrid to a single tensor.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> Tensor:
    """Pool features from FactorGrid to a single tensor.

    Delegates to ``factor_grid_pool`` with the configured pooling type and
    pooling operator.
    """
    return factor_grid_pool(
        factor_grid,
        self.pooling_type,
        pool_op=self.pool_op,
        # Defaulted lookup: passes None when no ``attention`` attribute exists.
        attention_layer=getattr(self, "attention", None),
    )

FactorGridProjection

Bases: BaseSpatialModule

Projection operation for FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridProjection(BaseSpatialModule):
    """Projection operation for FactorGrid.

    Holds one ``Conv2d -> norm -> activation`` block per grid. Channel counts
    of each block are multiplied by that grid's compressed spatial dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        compressed_spatial_dims: Tuple[int, ...],
        compressed_memory_formats: Tuple[GridMemoryFormat, ...],
        stride: int = 1,
        norm: Type[nn.Module] = nn.BatchNorm2d,
        activation: Type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        conv_padding = (kernel_size - 1) // 2
        blocks = []
        # The memory formats only bound the number of blocks (via zip); their
        # values are not used when building a block.
        for compressed_spatial_dim, _memory_format in zip(
            compressed_spatial_dims, compressed_memory_formats
        ):
            conv_in = in_channels * compressed_spatial_dim
            conv_out = out_channels * compressed_spatial_dim
            blocks.append(
                nn.Sequential(
                    nn.Conv2d(
                        conv_in,
                        conv_out,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=conv_padding,
                        bias=True,
                    ),
                    norm(conv_out),
                    activation(),
                )
            )
        self.convs = nn.ModuleList(blocks)

    def forward(self, grid: FactorGrid) -> FactorGrid:
        """Apply the i-th projection block to the i-th grid of ``grid``."""
        # NOTE(review): each block is called on the grid object itself, not on
        # a feature tensor -- confirm that is what the Conv2d stack expects.
        return FactorGrid(
            [block(single_grid) for single_grid, block in zip(grid, self.convs)]
        )

FactorGridToPoint

Bases: BaseSpatialModule

Convert FactorGrid features back to point features using trilinear interpolation.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridToPoint(BaseSpatialModule):
    """Convert FactorGrid features back to point features using trilinear interpolation."""

    def __init__(
        self,
        grid_in_channels: int,
        point_in_channels: int,
        num_grids: int,
        out_channels: int,
        use_rel_pos: bool = True,
        use_rel_pos_embed: bool = False,
        pos_embed_dim: int = 32,
        sample_method: Literal["graphconv", "interp"] = "interp",
        neighbor_search_type: Literal["radius", "knn"] = "radius",
        knn_k: int = 16,
        reductions: Optional[List[str]] = None,
    ):
        """Store the configuration and build the combining MLP.

        Args:
            grid_in_channels: Channels sampled from each grid.
            point_in_channels: Channels of the incoming point features.
            num_grids: Number of grids whose samples are concatenated.
            out_channels: Output width of the combining MLP.
            use_rel_pos: Append the 3 raw relative-position channels. Ignored
                when ``use_rel_pos_embed`` is set.
            use_rel_pos_embed: Append a sinusoidal encoding of the relative
                position instead (counted as ``pos_embed_dim * 3`` channels).
            pos_embed_dim: Encoding width passed to ``sinusoidal_encoding``.
            sample_method: Stored on the module; not read by the methods of
                this class.
            neighbor_search_type: Stored on the module; not read by the methods
                of this class.
            knn_k: Stored on the module; not read by the methods of this class.
            reductions: Stored on the module; not read by the methods of this
                class. ``None`` means ``["mean"]``.
        """
        super().__init__()
        self.grid_in_channels = grid_in_channels
        self.point_in_channels = point_in_channels
        self.out_channels = out_channels
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_embed = use_rel_pos_embed
        self.pos_embed_dim = pos_embed_dim
        self.sample_method = sample_method
        self.neighbor_search_type = neighbor_search_type
        self.knn_k = knn_k
        # Build a fresh list per instance; the value is stored on ``self``, so
        # a shared default list would be aliased across modules.
        self.reductions = ["mean"] if reductions is None else reductions
        self.freqs = get_freqs(pos_embed_dim)

        # Calculate combined channels for MLP
        combined_channels = grid_in_channels * num_grids + point_in_channels
        if use_rel_pos_embed:
            combined_channels += pos_embed_dim * 3
        elif use_rel_pos:
            combined_channels += 3

        self.combine_mlp = nn.Sequential(
            nn.Linear(combined_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.GELU(),
            nn.Linear(out_channels, out_channels),
        )

    def _normalize_coordinates(self, coords: Tensor, grid: Grid) -> Tensor:
        """Normalize coordinates to [-1, 1] range for grid_sample using grid bounds."""
        # Get bounds from the grid
        bounds_min, bounds_max = grid.bounds

        # Normalize to [0, 1] first
        # NOTE(review): an axis with bounds_max == bounds_min divides by zero.
        normalized = (coords - bounds_min) / (bounds_max - bounds_min)

        # Convert to [-1, 1] range for grid_sample
        normalized = 2.0 * normalized - 1.0

        # Points outside the bounds are clamped onto the boundary.
        normalized = torch.clamp(normalized, -1.0, 1.0)

        return normalized

    def _sample_from_grid(self, grid: Grid, point_coords: Tensor) -> Tensor:
        """Sample features from grid using trilinear interpolation.

        Returns a tensor with one row per point and one column per channel.
        """
        grid_features_tensor = grid.grid_features.batched_tensor
        batch_size = grid_features_tensor.shape[0]

        # Normalize point coordinates using grid bounds
        normalized_coords = self._normalize_coordinates(point_coords, grid)

        # For a 5-D input, grid_sample takes sampling locations shaped
        # (B, D_out, H_out, W_out, 3); the points occupy the D_out axis here.
        # NOTE(review): the view splits the points evenly across the batch --
        # confirm every batch element carries the same number of points.
        normalized_coords = normalized_coords.view(batch_size, -1, 1, 1, 3)

        # Convert grid to standard format for grid_sample
        grid_reshaped = grid.grid_features.to_memory_format(GridMemoryFormat.b_c_x_y_z)

        # NOTE(review): grid_sample maps the first coordinate component to the
        # LAST spatial axis of its input. Confirm the component order of
        # point_coords matches the b_c_x_y_z layout.
        sampled_features = F.grid_sample(
            grid_reshaped.batched_tensor,
            normalized_coords,
            mode="bilinear",  # For 3D, this becomes trilinear
            padding_mode="border",
            align_corners=True,
        )  # Shape: B, C, N, 1, 1

        # Reshape to (B*N, C)
        sampled_features = (
            sampled_features.squeeze(-1).squeeze(-1).transpose(1, 2)
        )  # B, N, C
        sampled_features = sampled_features.flatten(start_dim=0, end_dim=1)  # B*N, C
        return sampled_features

    def forward(self, factor_grid: FactorGrid, point_features: Points) -> Points:
        """Convert FactorGrid features to point features using trilinear interpolation.

        Samples every grid at the point coordinates, concatenates the samples
        with the point features (plus optional relative-position channels) and
        passes the result through ``combine_mlp``.

        Returns:
            A new ``Points`` sharing the input's batched coordinates and
            carrying the MLP output as features.
        """
        point_coords = point_features.coordinate_tensor
        point_feats = point_features.feature_tensor

        # Sample features from all grids
        all_grid_features = [
            self._sample_from_grid(grid, point_coords) for grid in factor_grid
        ]

        # Concatenate features from all grids
        if len(all_grid_features) > 1:
            grid_feat_per_point = torch.cat(all_grid_features, dim=-1)
        else:
            grid_feat_per_point = all_grid_features[0]

        pieces = [point_feats, grid_feat_per_point]

        # Add relative position features if requested; the embedding takes
        # precedence over the raw offsets.
        if self.use_rel_pos_embed or self.use_rel_pos:
            # Offsets are measured from the centre of the first grid's bounds.
            bounds_min, bounds_max = factor_grid[0].bounds
            rel_pos = point_coords - ((bounds_max + bounds_min) / 2.0)
            if self.use_rel_pos_embed:
                pieces.append(
                    sinusoidal_encoding(rel_pos, self.pos_embed_dim, data_range=2)
                )
            else:
                pieces.append(rel_pos)

        combined_features = torch.cat(pieces, dim=-1)

        # Apply MLP
        output_features = self.combine_mlp(combined_features)

        # Create new Points object
        return Points(
            batched_coordinates=point_features.batched_coordinates,
            batched_features=output_features,
        )

forward(factor_grid: FactorGrid, point_features: Points) -> Points

Convert FactorGrid features to point features using trilinear interpolation.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid, point_features: Points) -> Points:
    """Convert FactorGrid features to point features using trilinear interpolation.

    Samples every grid at the point coordinates, concatenates the samples with
    the point features (plus optional relative-position channels) and passes
    the result through ``combine_mlp``.

    Returns:
        A new ``Points`` sharing the input's batched coordinates and carrying
        the MLP output as features.
    """
    point_coords = point_features.coordinate_tensor
    point_feats = point_features.feature_tensor

    # Sample features from all grids
    all_grid_features = [
        self._sample_from_grid(grid, point_coords) for grid in factor_grid
    ]

    # Concatenate features from all grids
    if len(all_grid_features) > 1:
        grid_feat_per_point = torch.cat(all_grid_features, dim=-1)
    else:
        grid_feat_per_point = all_grid_features[0]

    pieces = [point_feats, grid_feat_per_point]

    # Add relative position features if requested; the embedding takes
    # precedence over the raw offsets.
    if self.use_rel_pos_embed or self.use_rel_pos:
        # Offsets are measured from the centre of the first grid's bounds.
        bounds_min, bounds_max = factor_grid[0].bounds
        rel_pos = point_coords - ((bounds_max + bounds_min) / 2.0)
        if self.use_rel_pos_embed:
            pieces.append(
                sinusoidal_encoding(rel_pos, self.pos_embed_dim, data_range=2)
            )
        else:
            pieces.append(rel_pos)

    combined_features = torch.cat(pieces, dim=-1)

    # Apply MLP
    output_features = self.combine_mlp(combined_features)

    # Create new Points object
    return Points(
        batched_coordinates=point_features.batched_coordinates,
        batched_features=output_features,
    )

FactorGridTransform

Bases: BaseSpatialModule

Apply a transform operation to all grids in a FactorGrid.

This is equivalent to GridFeatureGroupTransform but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridTransform(BaseSpatialModule):
    """Apply a transform operation to all grids in a FactorGrid.

    This is equivalent to GridFeatureGroupTransform but works with FactorGrid objects.

    The transform is assigned as an attribute, so it is registered as a
    submodule of this module.
    """

    def __init__(self, transform: nn.Module, in_place: bool = True) -> None:
        """Store the transform and the ``in_place`` flag.

        Args:
            transform: Module handed to ``factor_grid_transform``.
            in_place: Flag handed to ``factor_grid_transform`` unchanged; its
                exact meaning is defined by that function.
        """
        super().__init__()
        self.transform = transform
        self.in_place = in_place

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Apply transform to all grids in the FactorGrid.

        Delegates to ``factor_grid_transform``.
        """
        return factor_grid_transform(factor_grid, self.transform, self.in_place)

forward(factor_grid: FactorGrid) -> FactorGrid

Apply transform to all grids in the FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Apply transform to all grids in the FactorGrid.

    Delegates to ``factor_grid_transform`` with ``self.transform`` and
    ``self.in_place``.
    """
    return factor_grid_transform(factor_grid, self.transform, self.in_place)

PointToFactorGrid

Bases: BaseSpatialModule

Convert point features to FactorGrid representation.

Source code in warpconvnet/nn/modules/factor_grid.py
class PointToFactorGrid(BaseSpatialModule):
    """Convert point features to FactorGrid representation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_shapes: List[Tuple[int, int, int]],
        memory_formats: List[GridMemoryFormat],
        aabb_max: Tuple[float, float, float],
        aabb_min: Tuple[float, float, float],
        use_rel_pos: bool = True,
        use_rel_pos_embed: bool = True,
        pos_encode_dim: int = 32,
        search_radius: Optional[float] = None,
        k: int = 8,
        search_type: Literal["radius", "knn", "voxel"] = "radius",
        reduction: str = "mean",
    ):
        """Store the conversion settings and build one projection per grid.

        Args:
            in_channels: Input width of each projection layer.
            out_channels: Per-slice output width of each projection layer.
            grid_shapes: One shape per grid, indexed as (X, Y, Z).
            memory_formats: One compressed memory format per grid.
            aabb_max: Upper corner of the bounding box passed as bounds.
            aabb_min: Lower corner of the bounding box passed as bounds.
            use_rel_pos: Stored on the module; not read by ``forward``.
            use_rel_pos_embed: Stored on the module; not read by ``forward``.
            pos_encode_dim: Stored on the module; not read by ``forward``.
            search_radius: Forwarded to ``points_to_factor_grid``.
            k: Forwarded to ``points_to_factor_grid``.
            search_type: Forwarded to ``points_to_factor_grid``.
            reduction: Forwarded to ``points_to_factor_grid``.

        Raises:
            ValueError: If a memory format is not one of the three compressed
                formats handled here.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_shapes = grid_shapes
        self.memory_formats = memory_formats
        self.aabb_max = aabb_max
        self.aabb_min = aabb_min
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_embed = use_rel_pos_embed
        self.pos_encode_dim = pos_encode_dim
        self.search_radius = search_radius
        self.k = k
        self.search_type = search_type
        self.reduction = reduction

        # Calculate compressed spatial dimensions: the extent of the axis that
        # each memory format folds into the channel dimension.
        self.compressed_spatial_dims = []
        for grid_shape, memory_format in zip(grid_shapes, memory_formats):
            if memory_format == GridMemoryFormat.b_zc_x_y:
                # Z is the third entry; index 0 would read the X extent.
                compressed_dim = grid_shape[2]  # Z dimension
            elif memory_format == GridMemoryFormat.b_xc_y_z:
                compressed_dim = grid_shape[0]  # X dimension
            elif memory_format == GridMemoryFormat.b_yc_x_z:
                compressed_dim = grid_shape[1]  # Y dimension
            else:
                raise ValueError(f"Unsupported memory format: {memory_format}")

            self.compressed_spatial_dims.append(compressed_dim)

        # Create projection layers for each grid
        # NOTE(review): ``forward`` does not apply these layers, so their
        # parameters receive no gradient from this module's own forward pass.
        self.projections = nn.ModuleList(
            nn.Linear(in_channels, out_channels * compressed_dim)
            for compressed_dim in self.compressed_spatial_dims
        )

    def forward(self, points: Points) -> FactorGrid:
        """Convert point features to FactorGrid.

        Delegates to ``points_to_factor_grid`` and returns its result without
        applying ``self.projections``.
        """
        # Imported at call time rather than at module import time.
        from warpconvnet.geometry.types.conversion.to_factor_grid import (
            points_to_factor_grid,
        )

        # Convert points to FactorGrid using the existing function.
        # The bounds tensors are rebuilt on every call with no explicit device.
        factor_grid = points_to_factor_grid(
            points=points,
            grid_shapes=self.grid_shapes,
            memory_formats=self.memory_formats,
            bounds=(torch.tensor(self.aabb_min), torch.tensor(self.aabb_max)),
            search_radius=self.search_radius,
            k=self.k,
            search_type=self.search_type,
            reduction=self.reduction,
        )

        # For now, return the factor grid as is without projections
        # The projections can be applied later in the pipeline if needed
        return factor_grid

forward(points: Points) -> FactorGrid

Convert point features to FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, points: Points) -> FactorGrid:
    """Convert point features to FactorGrid.

    Delegates to ``points_to_factor_grid`` with the settings stored on the
    module and returns its result unchanged.
    """
    # Imported at call time rather than at module import time.
    from warpconvnet.geometry.types.conversion.to_factor_grid import (
        points_to_factor_grid,
    )

    # Convert points to FactorGrid using the existing function.
    # The bounds tensors are rebuilt on every call with no explicit device.
    factor_grid = points_to_factor_grid(
        points=points,
        grid_shapes=self.grid_shapes,
        memory_formats=self.memory_formats,
        bounds=(torch.tensor(self.aabb_min), torch.tensor(self.aabb_max)),
        search_radius=self.search_radius,
        k=self.k,
        search_type=self.search_type,
        reduction=self.reduction,
    )

    # For now, return the factor grid as is without projections
    # The projections can be applied later in the pipeline if needed
    return factor_grid

Grid convolution

warpconvnet.nn.modules.grid_conv

GridConv

Bases: BaseSpatialModule

Convolutional layer for warpconvnet.geometry.types.grid.Grid data.

Parameters mirror those of torch.nn.Conv3d but operate on a Grid object instead of plain tensors.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Stride of the convolution. Defaults to 1.

  • padding (int or tuple of int, default: 0 ) –

    Zero-padding added to both sides of each spatial dimension of the input. Defaults to 0.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True, adds a learnable bias to the output. Defaults to True.

  • num_spatial_dims (int, default: 3 ) –

    Number of spatial dimensions. Defaults to 3.

Source code in warpconvnet/nn/modules/grid_conv.py
class GridConv(BaseSpatialModule):
    """Learned convolution over `warpconvnet.geometry.types.grid.Grid` data.

    The constructor arguments follow `torch.nn.Conv3d`; the layer consumes a
    ``Grid`` and delegates the computation to ``grid_conv``.

    Parameters
    ----------
    in_channels : int
        Channel count of the incoming features.
    out_channels : int
        Channel count of the produced features.
    kernel_size : int or tuple of int
        Extent of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    padding : int or tuple of int, optional
        Zero-padding applied to the input. Defaults to ``0``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        When ``True`` a learnable per-channel bias is created. Defaults to ``True``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 0,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        num_spatial_dims: Optional[int] = 3,
    ):
        super().__init__()

        # Every geometric hyper-parameter goes through ``ntuple`` with the
        # requested number of spatial dimensions before it is stored.
        def per_axis(value):
            return ntuple(value, ndim=num_spatial_dims)

        kernel_size = per_axis(kernel_size)
        stride = per_axis(stride)
        padding = per_axis(padding)
        dilation = per_axis(dilation)

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.num_spatial_dims = num_spatial_dims

        # Weight layout: (out_channels, in_channels, *kernel_size); for three
        # spatial dims that is (out, in, depth, height, width).
        weight_shape = (out_channels, in_channels, *kernel_size)
        self.weight = nn.Parameter(torch.randn(weight_shape))

        # ``bias`` is always registered so ``self.bias`` exists (possibly None).
        self.register_parameter(
            "bias", nn.Parameter(torch.randn(out_channels)) if bias else None
        )

        # The random draws above are placeholders; the real init happens here.
        self.reset_parameters()

    def __repr__(self):
        shown = (
            ("in_channels", self.in_channels),
            ("out_channels", self.out_channels),
            ("kernel_size", self.kernel_size),
            ("stride", self.stride),
            ("padding", self.padding),
            ("dilation", self.dilation),
            ("bias", self.bias is not None),
        )
        rendered = ", ".join(f"{label}={value}" for label, value in shown)
        return f"{self.__class__.__name__}({rendered})"

    def reset_parameters(self):
        """Re-initialise the weight (Kaiming uniform, ``a=1``) and the bias.

        The bias, when present, is drawn uniformly from
        ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` with ``fan_in`` taken from the weight.
        """
        init.kaiming_uniform_(self.weight, a=1)
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / (fan_in**0.5)
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input_grid: Grid) -> Grid:
        """Run ``grid_conv`` on ``input_grid`` with this layer's parameters."""
        conv_options = {
            "weight": self.weight,
            "stride": self.stride,
            "padding": self.padding,
            "dilation": self.dilation,
            "bias": self.bias,
        }
        return grid_conv(grid=input_grid, **conv_options)

MLP

warpconvnet.nn.modules.mlp

BatchedLinear

Bases: Module

A linear layer with batched weights for Muon-friendly optimization.

Instead of a single weight matrix [in_features, out_features * num_matrices], this uses separate weight matrices stacked as [num_matrices, in_features, out_features]. This structure is more suitable for Muon optimization as it can orthogonalize each [in_features, out_features] matrix independently.

Args:

  • in_features: Input feature dimension.
  • out_features: Output feature dimension per matrix.
  • num_matrices: Number of separate matrices (e.g., 3 for Q, K, V).
  • bias: Whether to use bias parameters.

Source code in warpconvnet/nn/modules/mlp.py
class BatchedLinear(nn.Module):
    """
    A linear layer with batched weights for Muon-friendly optimization.

    Instead of a single weight matrix [in_features, out_features * num_matrices],
    this uses separate weight matrices stacked as [num_matrices, in_features, out_features].
    This structure is more suitable for Muon optimization as it can orthogonalize
    each [in_features, out_features] matrix independently.

    Each [in_features, out_features] matrix is Xavier-uniform initialised on its
    own, so the bound is sqrt(6 / (in_features + out_features)) regardless of
    ``num_matrices``. The bias, when enabled, starts at zero.

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension per matrix
        num_matrices: Number of separate matrices (e.g., 3 for Q, K, V)
        bias: Whether to use bias parameters
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_matrices: int = 3,
        bias: bool = True,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_matrices = num_matrices

        # Create batched weight: [num_matrices, in_features, out_features]
        self.weight = nn.Parameter(torch.empty(num_matrices, in_features, out_features))
        # Initialise matrix by matrix. Calling xavier_uniform_ on the stacked
        # 3-D tensor would treat the last axis as a convolution receptive field
        # (fan_in = in_features * out_features, fan_out = num_matrices *
        # out_features) and produce weights far smaller than intended.
        with torch.no_grad():
            for matrix in self.weight:
                nn.init.xavier_uniform_(matrix)

        if bias:
            # Use flat bias for Muon - 1D parameter gets Adam optimization
            self.bias = nn.Parameter(torch.zeros(num_matrices * out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass with batched matrix multiplication.

        Args:
            input: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., num_matrices, out_features], cast to
            ``input.dtype``
        """
        # input: [..., in_features], weight: [num_matrices, in_features, out_features]
        # output: [..., num_matrices, out_features]
        output = torch.einsum("...i,kio->...ko", input, self.weight)

        if self.bias is not None:
            # The flat bias is viewed as [num_matrices, out_features] so it
            # broadcasts over all leading dimensions of the output.
            output += self.bias.view(self.num_matrices, self.out_features)

        output = output.to(input.dtype)
        return output

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, num_matrices={self.num_matrices}, bias={self.bias is not None}"

forward(input: Tensor) -> Tensor

Forward pass with batched matrix multiplication.

Args: input: Input tensor of shape [..., in_features]

Returns: Output tensor of shape [..., num_matrices, out_features]

Source code in warpconvnet/nn/modules/mlp.py
def forward(self, input: Tensor) -> Tensor:
    """
    Project ``input`` through every stacked weight matrix at once.

    Args:
        input: Input tensor of shape [..., in_features]

    Returns:
        Output tensor of shape [..., num_matrices, out_features], cast to
        ``input.dtype``
    """
    # "...i,kio->...ko": contract the feature axis against each of the k matrices.
    projected = torch.einsum("...i,kio->...ko", input, self.weight)

    flat_bias = self.bias
    if flat_bias is not None:
        # The bias is stored flat; view it as [num_matrices, out_features] so it
        # broadcasts over the leading dimensions. The add stays in place.
        projected.add_(flat_bias.view(self.num_matrices, self.out_features))

    return projected.to(input.dtype)

FeatureMLPBlock

Bases: Module

Plain MLP over a feature tensor with variable-depth hidden layers.

Stacks Linear → LayerNorm → activation per hidden width, then a final Linear to out_channels (no norm/activation on the last layer).

Source code in warpconvnet/nn/modules/mlp.py
class FeatureMLPBlock(nn.Module):
    """Feed-forward stack applied directly to a feature tensor.

    For every entry of ``hidden_channels`` the block appends
    Linear → LayerNorm → activation; a closing Linear maps to ``out_channels``
    and is followed by neither norm nor activation. With no hidden widths the
    block is a single Linear layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: Sequence[int] = (),
        activation=nn.GELU,
        bias: bool = True,
    ):
        super().__init__()
        # widths[i] -> widths[i + 1] describes the i-th hidden stage.
        widths = [in_channels, *hidden_channels]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out, bias=bias),
                nn.LayerNorm(fan_out),
                activation(),  # ``activation`` is a module factory, not an instance
            ]
        stages.append(nn.Linear(widths[-1], out_channels, bias=bias))
        self.block = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the stacked layers to ``x``."""
        return self.block(x)

Linear

Bases: BaseSpatialModule

Apply a linear layer to Geometry features.

Parameters:
  • in_features (int) –

    Number of input features.

  • out_features (int) –

    Number of output features.

  • bias (bool, default: True ) –

    If True adds a bias term to the layer. Defaults to True.

Source code in warpconvnet/nn/modules/mlp.py
class Linear(BaseSpatialModule):
    """Linear projection of the features carried by a ``Geometry``.

    Parameters
    ----------
    in_features : int
        Size of each input feature vector.
    out_features : int
        Size of each output feature vector.
    bias : bool, optional
        Whether the wrapped ``nn.Linear`` has a bias term. Defaults to ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.block = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: Geometry):
        """Return ``x`` with its features replaced by their linear projection."""
        projected = self.block(x.feature_tensor)
        return x.replace(batched_features=projected)

LinearNormActivation

Bases: BaseSpatialModule

Linear layer followed by LayerNorm and ReLU.

Parameters:
  • in_features (int) –

    Number of input features.

  • out_features (int) –

    Number of output features.

  • bias (bool, default: True ) –

    Whether to include a bias term. Defaults to True.

Source code in warpconvnet/nn/modules/mlp.py
class LinearNormActivation(BaseSpatialModule):
    """``Linear`` → ``LayerNorm`` → ``ReLU`` applied to ``Geometry`` features.

    Parameters
    ----------
    in_features : int
        Size of each input feature vector.
    out_features : int
        Size of each output feature vector; also the width normalised by
        ``LayerNorm``.
    bias : bool, optional
        Whether the linear layer has a bias term. Defaults to ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        stages = [
            nn.Linear(in_features, out_features, bias=bias),
            nn.LayerNorm(out_features),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x: Geometry):
        """Return ``x`` with its features passed through the three stages."""
        transformed = self.block(x.feature_tensor)
        return x.replace(batched_features=transformed)

MLPBlock

Bases: BaseSpatialModule

MLP block with a residual connection.

Parameters:
  • in_channels (int) –

    Number of input features.

  • out_channels (int, default: None ) –

    Number of output features. Defaults to in_channels.

  • hidden_channels (int, default: None ) –

    Hidden layer size. Defaults to in_channels.

  • activation (``nn.Module``, default: ReLU ) –

    Activation module to apply. Defaults to torch.nn.ReLU.

  • bias (bool, default: True ) –

    If True adds bias terms to the linear layers. Defaults to True.

Source code in warpconvnet/nn/modules/mlp.py
class MLPBlock(BaseSpatialModule):
    """Two-layer MLP with a residual (shortcut) connection.

    The main path is Linear → LayerNorm → activation → Linear → LayerNorm. The
    shortcut is the identity when ``in_channels == out_channels`` and a learned
    linear projection otherwise. The two paths are summed.

    Parameters
    ----------
    in_channels : int
        Number of input features.
    out_channels : int, optional
        Number of output features. Defaults to ``in_channels``.
    hidden_channels : int, optional
        Width of the hidden layer. Defaults to ``in_channels``.
    activation : ``nn.Module``, optional
        Activation module class, instantiated without arguments. Defaults to
        `torch.nn.ReLU`.
    bias : bool, optional
        If ``True`` the linear layers (including a projecting shortcut) have
        bias terms. Defaults to ``True``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        activation=nn.ReLU,
        bias: bool = True,
    ):
        super().__init__()
        # Unspecified widths fall back to the input width.
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels

        main_path = [
            nn.Linear(in_channels, hidden_channels, bias=bias),
            nn.LayerNorm(hidden_channels),
            activation(),
            nn.Linear(hidden_channels, out_channels, bias=bias),
            nn.LayerNorm(out_channels),
        ]
        self.block = nn.Sequential(*main_path)

        # A projection is only needed when the residual widths disagree.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_channels, out_channels, bias=bias)

    def _forward_feature(self, x: Tensor) -> Tensor:
        """Apply the main path to a raw feature tensor and add the shortcut."""
        transformed = self.block(x)
        return transformed + self.shortcut(x)

    def forward(self, x: Union[Tensor, Geometry]):
        """Apply the block to a plain tensor or to the features of a ``Geometry``.

        A tensor input yields a tensor; a ``Geometry`` input yields the result
        of ``x.replace`` with the transformed features.
        """
        if not isinstance(x, Geometry):
            return self._forward_feature(x)
        return x.replace(batched_features=self._forward_feature(x.feature_tensor))

Dense 3D blocks

warpconvnet.nn.modules.conv3d_blocks

Reusable dense 3D convolutional blocks.

DownsampleBlock3d

Bases: Module

2x dense 3D downsampling by strided conv or average pooling.

Source code in warpconvnet/nn/modules/conv3d_blocks.py
class DownsampleBlock3d(nn.Module):
    """Halve the spatial resolution of a dense 3D tensor.

    ``mode="conv"`` uses a learned ``Conv3d`` with kernel 2 and stride 2 and may
    change the channel count. ``mode="avgpool"`` uses parameter-free 2x average
    pooling and therefore requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        super().__init__()
        assert mode in ("conv", "avgpool")
        self.in_channels = in_channels
        self.out_channels = out_channels
        if mode == "avgpool":
            # Pooling cannot change the channel count.
            assert in_channels == out_channels
        else:
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` by 2 along each spatial axis."""
        # The presence of ``conv`` is what distinguishes the two modes.
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.avg_pool3d(x, 2)
        return conv(x)

ResBlock3d

Bases: Module

Pre-norm 3D residual block: norm -> SiLU -> conv, twice, plus skip.

Source code in warpconvnet/nn/modules/conv3d_blocks.py
class ResBlock3d(nn.Module):
    """Pre-norm residual block for dense 3D tensors.

    Two rounds of norm → SiLU → 3x3x3 conv, added to a skip path. The second
    convolution is wrapped in ``zero_module``. The skip path is the identity
    when the channel count is unchanged and a 1x1x1 convolution otherwise.
    """

    def __init__(
        self,
        channels: int,
        out_channels: int | None = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        # A falsy ``out_channels`` (None or 0) falls back to ``channels``.
        self.out_channels = out_channels or channels
        width = self.out_channels

        self.norm1 = norm_layer_3d(norm_type, channels)
        self.norm2 = norm_layer_3d(norm_type, width)
        self.conv1 = nn.Conv3d(channels, width, 3, padding=1)
        self.conv2 = zero_module(nn.Conv3d(width, width, 3, padding=1))

        if channels == width:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv3d(channels, width, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the residual branch output added to the skip path of ``x``."""
        h = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            h = conv(F.silu(norm(h)))
        return h + self.skip_connection(x)

UpsampleBlock3d

Bases: Module

2x dense 3D upsampling by conv + pixel shuffle or nearest interpolation.

Source code in warpconvnet/nn/modules/conv3d_blocks.py
class UpsampleBlock3d(nn.Module):
    """Double the spatial resolution of a dense 3D tensor.

    ``mode="conv"`` applies a 3x3x3 convolution producing ``8 * out_channels``
    channels and rearranges them with ``pixel_shuffle_3d(..., 2)``.
    ``mode="nearest"`` uses nearest-neighbour interpolation and therefore
    requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        super().__init__()
        assert mode in ("conv", "nearest")
        self.in_channels = in_channels
        self.out_channels = out_channels
        if mode == "nearest":
            # Interpolation cannot change the channel count.
            assert in_channels == out_channels
        else:
            # 8 = 2**3 sub-voxels per input voxel, consumed by the pixel shuffle.
            self.conv = nn.Conv3d(in_channels, out_channels * 8, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by 2 along each spatial axis."""
        # The presence of ``conv`` is what distinguishes the two modes.
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return pixel_shuffle_3d(conv(x), 2)

norm_layer_3d(norm_type: Literal['group', 'layer'], channels: int) -> nn.Module

Build a fp32-internal normalization layer for (B, C, D, H, W) tensors.

Source code in warpconvnet/nn/modules/conv3d_blocks.py
def norm_layer_3d(norm_type: Literal["group", "layer"], channels: int) -> nn.Module:
    """Create the normalization layer used for ``(B, C, D, H, W)`` tensors.

    ``"group"`` yields ``GroupNorm32(32, channels)`` and ``"layer"`` yields
    ``ChannelLayerNorm32(channels)``.

    Raises:
        ValueError: If ``norm_type`` is neither ``"group"`` nor ``"layer"``.
    """
    if norm_type not in ("group", "layer"):
        raise ValueError(f"Invalid norm type {norm_type}")
    if norm_type == "group":
        return GroupNorm32(32, channels)
    return ChannelLayerNorm32(channels)

Diffusion transformer blocks

warpconvnet.nn.modules.dit

DiT-style transformer building blocks for diffusion / flow models.

Operates on dense (B, L, C) token sequences. Provides:

  • MultiHeadAttention — self/cross attention with optional RoPE + qk-norm
  • FeedForwardNet — Linear-GELU(tanh)-Linear MLP
  • ModulatedTransformerBlock — DiT block (self-attn + FFN with adaLN modulation)
  • ModulatedTransformerCrossBlock — DiT cross block (self + cross + FFN, adaLN)

These mirror the building blocks introduced by DiT / SD3 / Flux / TRELLIS and are usable for any diffusion-transformer port. For sparse-voxel attention see warpconvnet.nn.modules.attention and warpconvnet.nn.modules.space_attention.

FeedForwardNet

Bases: Module

Linear-GELU(tanh)-Linear feed-forward block (mlp_ratio scaled hidden).

Source code in warpconvnet/nn/modules/dit.py
class FeedForwardNet(nn.Module):
    """Two-layer feed-forward block: Linear → GELU (tanh approximation) → Linear.

    The hidden width is ``int(channels * mlp_ratio)``; input and output widths
    are both ``channels``.
    """

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        expand = nn.Linear(channels, hidden)
        contract = nn.Linear(hidden, channels)
        self.mlp = nn.Sequential(expand, nn.GELU(approximate="tanh"), contract)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward stack to ``x``."""
        return self.mlp(x)

ModulatedTransformerBlock

Bases: Module

DiT block: norm → adaLN-modulated self-attn → norm → adaLN-modulated FFN.

Source code in warpconvnet/nn/modules/dit.py
class ModulatedTransformerBlock(nn.Module):
    """DiT block: adaLN-modulated self-attention followed by an adaLN-modulated FFN.

    Each branch computes ``x + gate * f(norm(x) * (1 + scale) + shift)``.
    With ``share_mod=False`` the six modulation vectors are produced per block
    by ``SiLU -> Linear(channels, 6 * channels)`` applied to ``mod``. With
    ``share_mod=True`` a learned ``6 * channels`` offset is added to ``mod``
    instead (``mod`` presumably already carries ``6 * channels`` features in
    that case -- confirm against the caller).
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full"] = "full",
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod

        # Both norms are affine-free: scale and shift come from the modulation.
        plain_norm = {"elementwise_affine": False, "eps": 1e-6}
        self.norm1 = LayerNorm32(channels, **plain_norm)
        self.norm2 = LayerNorm32(channels, **plain_norm)

        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)

        if share_mod:
            # Learned per-block offset added to an externally supplied modulation.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels**0.5)
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(
        self,
        x: torch.Tensor,
        mod: torch.Tensor,
        phases: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Un-checkpointed block body; see ``forward`` for the arguments."""
        if self.share_mod:
            packed = (self.modulation + mod).type(mod.dtype)
        else:
            packed = self.adaLN_modulation(mod)

        # Split into the six modulation vectors and insert an axis at dim 1 so
        # each one broadcasts against ``x`` along that axis.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            part.unsqueeze(1) for part in packed.chunk(6, dim=1)
        ]

        attn_in = self.norm1(x) * (1 + scale_msa) + shift_msa
        x = x + self.attn(attn_in, phases=phases) * gate_msa

        mlp_in = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        return x + self.mlp(mlp_in) * gate_mlp

    def forward(
        self,
        x: torch.Tensor,
        mod: torch.Tensor,
        phases: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the block, optionally under activation checkpointing.

        Args:
            x: Token features; the module documentation describes dense
                ``(B, L, C)`` sequences.
            mod: Conditioning tensor whose dim 1 is split into six chunks.
            phases: Optional RoPE phases forwarded to the attention layer.
        """
        if not self.use_checkpoint:
            return self._forward(x, mod, phases)
        return torch.utils.checkpoint.checkpoint(
            self._forward, x, mod, phases, use_reentrant=False
        )

ModulatedTransformerCrossBlock

Bases: Module

DiT cross block: self-attn → cross-attn → FFN with adaLN modulation.

Source code in warpconvnet/nn/modules/dit.py
class ModulatedTransformerCrossBlock(nn.Module):
    """DiT cross block: self-attention → cross-attention → FFN.

    The self-attention and FFN branches are adaLN-modulated
    (``x + gate * f(norm(x) * (1 + scale) + shift)``). The cross-attention
    branch is a plain residual over an affine ``LayerNorm32`` and is not
    modulated. The modulation source follows ``share_mod``: a per-block
    ``SiLU -> Linear(channels, 6 * channels)`` when it is ``False``, or a learned
    ``6 * channels`` offset added to ``mod`` when it is ``True``.
    """

    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full"] = "full",
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod

        # norm1 / norm3 precede modulated branches and carry no affine
        # parameters; norm2 precedes the unmodulated cross-attention and does.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)

        self_attn_options = {
            "num_heads": num_heads,
            "type": "self",
            "attn_mode": attn_mode,
            "qkv_bias": qkv_bias,
            "use_rope": use_rope,
            "rope_freq": rope_freq,
            "qk_rms_norm": qk_rms_norm,
        }
        self.self_attn = MultiHeadAttention(channels, **self_attn_options)

        # Cross attention always uses full attention and receives no RoPE options.
        cross_attn_options = {
            "ctx_channels": ctx_channels,
            "num_heads": num_heads,
            "type": "cross",
            "attn_mode": "full",
            "qkv_bias": qkv_bias,
            "qk_rms_norm": qk_rms_norm_cross,
        }
        self.cross_attn = MultiHeadAttention(channels, **cross_attn_options)

        self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)

        if share_mod:
            # Learned per-block offset added to an externally supplied modulation.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels**0.5)
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(
        self,
        x: torch.Tensor,
        mod: torch.Tensor,
        context: torch.Tensor,
        phases: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Un-checkpointed block body; see ``forward`` for the arguments."""
        if self.share_mod:
            packed = (self.modulation + mod).type(mod.dtype)
        else:
            packed = self.adaLN_modulation(mod)

        # Split into the six modulation vectors and insert an axis at dim 1 so
        # each one broadcasts against ``x`` along that axis.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            part.unsqueeze(1) for part in packed.chunk(6, dim=1)
        ]

        # 1) modulated self-attention
        attn_in = self.norm1(x) * (1 + scale_msa) + shift_msa
        x = x + self.self_attn(attn_in, phases=phases) * gate_msa

        # 2) unmodulated cross-attention against ``context``
        x = x + self.cross_attn(self.norm2(x), context)

        # 3) modulated feed-forward
        mlp_in = self.norm3(x) * (1 + scale_mlp) + shift_mlp
        return x + self.mlp(mlp_in) * gate_mlp

    def forward(
        self,
        x: torch.Tensor,
        mod: torch.Tensor,
        context: torch.Tensor,
        phases: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the block, optionally under activation checkpointing.

        Args:
            x: Token features; the module documentation describes dense
                ``(B, L, C)`` sequences.
            mod: Conditioning tensor whose dim 1 is split into six chunks.
            context: Tensor supplying keys/values for the cross-attention.
            phases: Optional RoPE phases forwarded to the self-attention layer.
        """
        if not self.use_checkpoint:
            return self._forward(x, mod, context, phases)
        return torch.utils.checkpoint.checkpoint(
            self._forward, x, mod, context, phases, use_reentrant=False
        )

MultiHeadAttention

Bases: Module

Dense multi-head attention with optional RoPE + qk-RMSNorm.

Supports type='self' (Q=K=V from x) or type='cross' (Q from x, K/V from context). Cross attention is always full.

Source code in warpconvnet/nn/modules/dit.py
class MultiHeadAttention(nn.Module):
    """Dense multi-head attention with optional RoPE + qk-RMSNorm.

    Supports ``type='self'`` (Q=K=V from `x`) or ``type='cross'``
    (Q from `x`, K/V from `context`). Cross attention is always full.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: int | None = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full"] = "full",
        qkv_bias: bool = True,
        use_rope: bool = False,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ("self", "cross")
        assert attn_mode == "full"
        assert type == "self" or attn_mode == "full"

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor | None = None,
        phases: torch.Tensor | None = None,
    ) -> torch.Tensor:
        B, L, _ = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1)
            q, k, v = qkv.unbind(dim=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
            if self.use_rope:
                assert phases is not None, "phases required when use_rope=True"
                q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
                k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
            h = _sdpa_4d(q, k, v)
        else:
            Lkv = context.shape[1]
            q = self.to_q(x).reshape(B, L, self.num_heads, -1)
            kv = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1)
            k, v = kv.unbind(dim=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
            h = _sdpa_4d(q, k, v)
        return self.to_out(h.reshape(B, L, -1))

Embeddings

warpconvnet.nn.modules.embeddings

Position / time embeddings for diffusion and flow-matching models.

  • TimestepEmbedder — sinusoidal scalar→vector embedding of the diffusion timestep.
  • SinusoidalPositionEmbedder — multi-axis sin/cos absolute position embedding.
  • RotaryPositionEmbedder — RoPE phases for arbitrary-dimensional integer coordinates (dense tokens).

For sparse-voxel RoPE see warpconvnet.nn.modules.rope.VoxelRotaryPositionalEmbeddings.

RotaryPositionEmbedder

Bases: Module

RoPE phases for any-D integer coordinates (dense token version).

head_dim must be even. dim is the number of coordinate axes (3 for voxels, 2 for image patches). Use the returned phases inside an attention module via apply_rotary_embedding.

Source code in warpconvnet/nn/modules/embeddings.py
class RotaryPositionEmbedder(nn.Module):
    """Compute RoPE phases for N-dimensional coordinates (dense token version).

    ``head_dim`` must be even. ``dim`` is the number of coordinate axes (3 for
    voxels, 2 for image patches). Every axis receives
    ``head_dim // 2 // dim`` frequencies; any complex slots left over up to
    ``head_dim // 2`` are filled with the identity rotation. Feed the phases
    returned by ``forward`` to ``apply_rotary_embedding`` inside an attention
    module.
    """

    def __init__(
        self,
        head_dim: int,
        dim: int = 3,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
    ):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even"
        self.head_dim = head_dim
        self.dim = dim
        self.rope_freq = rope_freq
        self.freq_dim = head_dim // 2 // dim
        # freqs[j] = rope_freq[0] / rope_freq[1] ** (j / freq_dim). Kept as a
        # plain tensor attribute and moved to the right device on use.
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = rope_freq[0] / (rope_freq[1] ** exponents)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """Return unit complex numbers ``exp(i * index * freq)`` for flat ``indices``."""
        # The frequency table follows the device of the incoming coordinates.
        self.freqs = self.freqs.to(indices.device)
        angles = torch.outer(indices, self.freqs)
        return torch.polar(torch.ones_like(angles), angles)

    @staticmethod
    def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """Rotate consecutive channel pairs of ``x`` by ``phases``.

        The last axis of ``x`` is read as (real, imag) pairs in float32,
        multiplied by ``phases`` (which gains a broadcast axis in front of the
        pair axis), and returned in the original layout and dtype of ``x``.
        """
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * phases.unsqueeze(-2)
        return torch.view_as_real(rotated).flatten(-2).to(x.dtype)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """Return complex phases of shape ``(*indices.shape[:-1], head_dim // 2)``.

        ``indices`` must have ``dim`` entries on its last axis.
        """
        assert indices.shape[-1] == self.dim
        lead_shape = indices.shape[:-1]
        phases = self._get_phases(indices.reshape(-1)).reshape(*lead_shape, -1)

        # When dim * freq_dim falls short of head_dim // 2, fill the remainder
        # with polar(1, 0), i.e. a rotation that leaves those channels untouched.
        missing = self.head_dim // 2 - phases.shape[-1]
        if missing > 0:
            magnitude = torch.ones(*lead_shape, missing, device=phases.device)
            angle = torch.zeros(*lead_shape, missing, device=phases.device)
            phases = torch.cat([phases, torch.polar(magnitude, angle)], dim=-1)
        return phases

SinusoidalPositionEmbedder

Bases: Module

Multi-axis sin/cos absolute position embedding for integer coordinates.

For in_channels=D input coords (N, D), produces (N, channels) by concatenating per-axis sin/cos embeddings (zero-padded if channels isn't an exact multiple of 2*D*freq_dim).

Source code in warpconvnet/nn/modules/embeddings.py
class SinusoidalPositionEmbedder(nn.Module):
    """Absolute sin/cos position embedding over several coordinate axes.

    Input coordinates of shape ``(N, in_channels)`` become ``(N, channels)``.
    Each axis contributes ``freq_dim`` sines followed by ``freq_dim`` cosines,
    where ``freq_dim = channels // in_channels // 2``. When that does not fill
    `channels` exactly the tail is zero-padded.
    """

    def __init__(self, channels: int, in_channels: int = 3):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        self.freq_dim = channels // in_channels // 2
        # freqs[j] = 10000 ** -(j / freq_dim). Kept as a plain tensor attribute
        # and moved to the right device on use.
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000**exponents)

    def _sin_cos(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[sin(x * freqs), cos(x * freqs)]`` for flat coordinates ``x``."""
        self.freqs = self.freqs.to(x.device)
        angles = torch.outer(x, self.freqs)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed coordinates ``x`` of shape ``(N, in_channels)`` into ``(N, channels)``."""
        num_points, num_axes = x.shape
        assert num_axes == self.in_channels
        # Flatten so every scalar coordinate is embedded independently, then
        # regroup the per-axis embeddings of each point side by side.
        embed = self._sin_cos(x.reshape(-1)).reshape(num_points, -1)
        shortfall = self.channels - embed.shape[1]
        if shortfall > 0:
            filler = torch.zeros(num_points, shortfall, device=embed.device)
            embed = torch.cat([embed, filler], dim=-1)
        return embed

TimestepEmbedder

Bases: Module

Sinusoidal timestep → 2-layer MLP (Linear-SiLU-Linear).

Source code in warpconvnet/nn/modules/embeddings.py
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features → Linear → SiLU → Linear.

    Parameters
    ----------
    hidden_size : int
        Output width of both linear layers.
    frequency_embedding_size : int, optional
        Width of the sinusoidal embedding fed to the MLP. Defaults to ``256``.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Return the ``(len(t), dim)`` embedding ``[cos | sin]`` of timesteps ``t``.

        ``dim // 2`` geometrically spaced frequencies (from ``1`` down towards
        ``1 / max_period``) are used; for odd ``dim`` a zero column is appended.
        """
        half = dim // 2
        indices = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * indices / half).to(device=t.device)
        # Outer product of timesteps and frequencies: (len(t), half).
        args = t[:, None].float() * freqs[None]
        parts = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Odd width: one extra zero column keeps the output at ``dim``.
            parts.append(torch.zeros_like(args[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Map timesteps ``t`` of shape ``(N,)`` to ``(N, hidden_size)`` features."""
        freq_features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_features)

Normalizations

warpconvnet.nn.modules.normalizations

BatchNorm

Bases: NormalizationBase

Applies torch.nn.BatchNorm1d to Geometry features.

Parameters:
  • num_features (int) –

    Number of feature channels in the input.

  • eps (float, default: 1e-05 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-5.

  • momentum (float, default: 0.1 ) –

    Momentum factor for the running statistics. Defaults to 0.1.

Source code in warpconvnet/nn/modules/normalizations.py
class BatchNorm(NormalizationBase):
    """Batch normalization of ``Geometry`` features via `torch.nn.BatchNorm1d`.

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input.
    eps : float, optional
        Value added to the denominator for numerical stability. Defaults to ``1e-5``.
    momentum : float, optional
        Momentum factor for the running statistics. Defaults to ``0.1``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        # Build the torch layer, then hand it to the generic wrapper.
        batch_norm = nn.BatchNorm1d(num_features, eps=eps, momentum=momentum)
        super().__init__(batch_norm)

ChannelLayerNorm32

Bases: LayerNorm32

LayerNorm over the channel dim of a (B, C, *spatial) tensor (fp32).

Source code in warpconvnet/nn/modules/normalizations.py
class ChannelLayerNorm32(LayerNorm32):
    """LayerNorm over the channel dim of a `(B, C, *spatial)` tensor (fp32)."""

    def forward(self, x: Tensor) -> Tensor:
        # (B, C, *spatial) -> (B, *spatial, C): the parent normalizes the
        # trailing dimension, so channels are moved last first.
        channels_last = x.movedim(1, -1).contiguous()
        normalized = super().forward(channels_last)
        # Restore the original (B, C, *spatial) layout.
        return normalized.movedim(-1, 1).contiguous()

GroupNorm

Bases: NormalizationBase

Applies torch.nn.GroupNorm to Geometry features.

Parameters:
  • num_groups (int) –

    Number of groups to separate the channels into.

  • num_channels (int) –

    Number of channels expected in the input.

  • eps (float, default: 1e-05 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-5.

Source code in warpconvnet/nn/modules/normalizations.py
class GroupNorm(NormalizationBase):
    """Group normalization of ``Geometry`` features via `torch.nn.GroupNorm`.

    Parameters
    ----------
    num_groups : int
        Number of groups to separate the channels into.
    num_channels : int
        Number of channels expected in the input.
    eps : float, optional
        Value added to the denominator for numerical stability. Defaults to ``1e-5``.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        # Build the torch layer, then hand it to the generic wrapper.
        group_norm = nn.GroupNorm(num_groups, num_channels, eps=eps)
        super().__init__(group_norm)

GroupNorm32

Bases: torch.nn.GroupNorm

torch.nn.GroupNorm that always computes in fp32 then casts back.

Source code in warpconvnet/nn/modules/normalizations.py
class GroupNorm32(nn.GroupNorm):
    """`torch.nn.GroupNorm` evaluated in fp32, with the result cast back to the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        # Normalize an fp32 copy of the input, then restore the caller's dtype.
        normalized = super().forward(x.float())
        return normalized.to(x.dtype)

InstanceNorm

Bases: NormalizationBase

Applies torch.nn.InstanceNorm1d to Geometry features.

Parameters:
  • num_features (int) –

    Number of feature channels in the input.

  • eps (float, default: 1e-05 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-5.

Source code in warpconvnet/nn/modules/normalizations.py
class InstanceNorm(NormalizationBase):
    """Instance normalization of ``Geometry`` features via `torch.nn.InstanceNorm1d`.

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input.
    eps : float, optional
        Value added to the denominator for numerical stability. Defaults to ``1e-5``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        # Build the torch layer, then hand it to the generic wrapper.
        instance_norm = nn.InstanceNorm1d(num_features, eps=eps)
        super().__init__(instance_norm)

LayerNorm

Bases: NormalizationBase

Applies torch.nn.LayerNorm to Geometry features.

Parameters:
  • normalized_shape (list of int) –

    Input shape from an expected input.

  • eps (float, default: 1e-05 ) –

    A value added to the denominator for numerical stability. Defaults to 1e-5.

  • elementwise_affine (bool, default: True ) –

    Whether to learn elementwise affine parameters. Defaults to True.

  • bias (bool, default: True ) –

    If True adds bias parameters. Defaults to True.

Source code in warpconvnet/nn/modules/normalizations.py
class LayerNorm(NormalizationBase):
    """Layer normalization of ``Geometry`` features via `torch.nn.LayerNorm`.

    Parameters
    ----------
    normalized_shape : list of int
        Input shape from an expected input.
    eps : float, optional
        A value added to the denominator for numerical stability. Defaults to ``1e-5``.
    elementwise_affine : bool, optional
        Whether to learn elementwise affine parameters. Defaults to ``True``.
    bias : bool, optional
        If ``True`` adds bias parameters. Defaults to ``True``.
    """

    def __init__(
        self,
        normalized_shape: List[int],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        # Build the torch layer, then hand it to the generic wrapper.
        layer_norm = nn.LayerNorm(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            bias=bias,
        )
        super().__init__(layer_norm)

LayerNorm32

Bases: torch.nn.LayerNorm

torch.nn.LayerNorm that always computes in fp32 then casts back.

Source code in warpconvnet/nn/modules/normalizations.py
class LayerNorm32(nn.LayerNorm):
    """`torch.nn.LayerNorm` evaluated in fp32, with the result cast back to the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        # Normalize an fp32 copy of the input, then restore the caller's dtype.
        normalized = super().forward(x.float())
        return normalized.to(x.dtype)

MultiHeadRMSNorm

Bases: Module

RMSNorm applied independently per attention head.

Parameters:
  • dim (int) –

    Per-head feature dim (i.e. channels // num_heads).

  • heads (int) –

    Number of attention heads.

Source code in warpconvnet/nn/modules/normalizations.py
class MultiHeadRMSNorm(nn.Module):
    """Per-head RMS normalization with a learnable gain for every head.

    The last dimension is L2-normalized in fp32, multiplied by the
    ``(heads, dim)`` gain ``gamma`` and by ``sqrt(dim)``, then cast back to
    the input dtype.

    Parameters
    ----------
    dim : int
        Per-head feature dim (i.e. ``channels // num_heads``).
    heads : int
        Number of attention heads.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        # Unit L2 norm times sqrt(dim) gives unit root-mean-square.
        self.scale = dim**0.5

    def forward(self, x: Tensor) -> Tensor:
        unit = torch.nn.functional.normalize(x.float(), dim=-1)
        scaled = unit * self.gamma * self.scale
        return scaled.to(x.dtype)

NormalizationBase

Bases: BaseSpatialModule

Wrapper for applying a normalization module to Geometry features.

Parameters:
  • norm (``nn.Module``) –

    Normalization module to apply to the feature tensor.

Source code in warpconvnet/nn/modules/normalizations.py
class NormalizationBase(BaseSpatialModule):
    """Adapter that runs a normalization module on ``Geometry`` features.

    Parameters
    ----------
    norm : ``nn.Module``
        Normalization module to apply to the feature tensor.
    """

    def __init__(self, norm: nn.Module):
        super().__init__()
        self.norm = norm

    def __repr__(self):
        # Rendered as ``<ClassName>(<wrapped module repr>)``.
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.norm})"

    def forward(
        self,
        input: Union[Geometry, Tensor],
    ):
        """Apply the wrapped module to ``input`` through ``apply_feature_transform``."""
        normalized = apply_feature_transform(input, self.norm)
        return normalized

RMSNorm

Bases: NormalizationBase

Applies torch.nn.RMSNorm to Geometry features.

Parameters:
  • dim (int) –

    Number of input features.

  • eps (float, default: 1e-06 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-6.

Source code in warpconvnet/nn/modules/normalizations.py
class RMSNorm(NormalizationBase):
    """RMS normalization of ``Geometry`` features via `torch.nn.RMSNorm`.

    Parameters
    ----------
    dim : int
        Number of input features.
    eps : float, optional
        Value added to the denominator for numerical stability. Defaults to ``1e-6``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        # Build the torch layer, then hand it to the generic wrapper.
        rms_norm = nn.RMSNorm(dim, eps=eps)
        super().__init__(rms_norm)

SegmentedLayerNorm

Bases: torch.nn.LayerNorm

Layer normalization that respects variable-length segments.

Parameters:
  • channels (int) –

    Number of feature channels.

  • eps (float, default: 1e-05 ) –

    A value added to the denominator for numerical stability. Defaults to 1e-5.

  • elementwise_affine (bool, default: True ) –

    If True learn per-channel affine parameters. Defaults to True.

  • bias (bool, default: True ) –

    Whether to include a bias term. Defaults to True.

Source code in warpconvnet/nn/modules/normalizations.py
class SegmentedLayerNorm(nn.LayerNorm):
    """Layer normalization that respects variable-length segments.

    Parameters
    ----------
    channels : int
        Number of feature channels.
    eps : float, optional
        A value added to the denominator for numerical stability. Defaults to ``1e-5``.
    elementwise_affine : bool, optional
        If ``True`` learn per-channel affine parameters. Defaults to ``True``.
    bias : bool, optional
        Whether to include a bias term. Defaults to ``True``.
    """

    def __init__(
        self,
        channels: int,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        super().__init__([channels], eps=eps, elementwise_affine=elementwise_affine, bias=bias)

    def forward(self, x: Geometry):
        """Normalize the features of ``x`` with ``segmented_layer_norm`` and its offsets.

        Only ``Geometry`` inputs are accepted; a copy of ``x`` carrying the
        normalized features is returned.
        """
        # Only works for geometry with batched features
        assert isinstance(
            x, Geometry
        ), f"SegmentedLayerNorm only works for Geometry, got {type(x)}"
        out_feature = segmented_layer_norm(
            x.feature_tensor,
            x.offsets,
            # Affine parameters are forwarded only when they are learned.
            gamma=self.weight if self.elementwise_affine else None,
            beta=self.bias if self.elementwise_affine else None,
            eps=self.eps,
        )
        # NOTE(review): confirm that ``Geometry.replace`` accepts the singular
        # ``batched_feature`` keyword used here.
        return x.replace(batched_feature=out_feature)

    def __repr__(self):
        # ``self.weight`` is None when ``elementwise_affine=False``, so take
        # the shape from ``normalized_shape``; it renders exactly like
        # ``weight.shape`` does when the weight exists.
        shape = torch.Size(self.normalized_shape)
        return f"{self.__class__.__name__}({shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine})"

Point convolution

warpconvnet.nn.modules.point_conv

PointConv

Bases: BaseSpatialModule

Point convolution operating on warpconvnet.geometry.types.points.Points.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • neighbor_search_args (`warpconvnet.geometry.coords.search.search_configs.RealSearchConfig`) –

    Configuration for neighbor search.

  • pooling_reduction (`warpconvnet.ops.reductions.REDUCTIONS`, default: None ) –

    Reduction used when downsampling points. Required when out_point_type is "downsample".

  • pooling_voxel_size (float, default: None ) –

    Size of voxels used for downsampling when out_point_type is "downsample".

  • edge_transform_mlp (Module, default: None ) –

    MLP applied to constructed edge features.

  • out_transform_mlp (Module, default: None ) –

    MLP applied after neighborhood reduction.

  • mlp_block (Module, default: MLPBlock ) –

    Module used to build edge_transform_mlp and out_transform_mlp when not provided. Defaults to MLPBlock.

  • hidden_dim (int, default: None ) –

    Hidden dimension of the automatically created MLPs.

  • channel_multiplier (int, default: 2 ) –

    Multiplier used when hidden_dim is None. Defaults to 2.

  • use_rel_pos (bool, default: False ) –

    Include relative coordinates of neighbors as part of the edge features.

  • use_rel_pos_encode (bool, default: False ) –

    Include sinusoidal encoding of relative coordinates in the edge features.

  • pos_encode_dim (int, default: 32 ) –

    Dimension of the positional encoding. Defaults to 32.

  • pos_encode_range (float, default: 4 ) –

    Range of the positional encoding. Defaults to 4.

  • reductions (list of str, default: ('mean',) ) –

    Reductions applied over the neighbor dimension. Defaults to ("mean",).

  • out_point_type ((provided, downsample, same), default: "same" ) –

    Determines the coordinate set on which output features are computed.

  • provided_in_channels (int, default: None ) –

    Number of channels of the provided query points when out_point_type is "provided".

  • bias (bool, default: True ) –

    Whether linear layers contain biases. Defaults to True.

Source code in warpconvnet/nn/modules/point_conv.py
class PointConv(BaseSpatialModule):
    """Point convolution operating on `warpconvnet.geometry.types.points.Points`.

    For every query point, the features of its neighbors in the input point
    cloud are concatenated with the query point's own features (and optionally
    relative coordinates or their positional encoding), transformed by
    ``edge_transform_mlp``, reduced over the neighborhood and transformed by
    ``out_transform_mlp``.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    neighbor_search_args : `warpconvnet.geometry.coords.search.search_configs.RealSearchConfig`
        Configuration for neighbor search.
    pooling_reduction : `warpconvnet.ops.reductions.REDUCTIONS`, optional
        Reduction used when downsampling points. Required when ``out_point_type`` is
        ``"downsample"``.
    pooling_voxel_size : float, optional
        Size of voxels used for downsampling when ``out_point_type`` is ``"downsample"``.
    edge_transform_mlp : nn.Module, optional
        MLP applied to constructed edge features.
    out_transform_mlp : nn.Module, optional
        MLP applied after neighborhood reduction.
    mlp_block : nn.Module, optional
        Module used to build ``edge_transform_mlp`` and ``out_transform_mlp`` when not
        provided. Defaults to `MLPBlock`.
    hidden_dim : int, optional
        Hidden dimension of the automatically created MLPs.
    channel_multiplier : int, optional
        Multiplier used when ``hidden_dim`` is ``None``. Defaults to ``2``.
    use_rel_pos : bool, optional
        Include relative coordinates of neighbors as part of the edge features.
    use_rel_pos_encode : bool, optional
        Include sinusoidal encoding of relative coordinates in the edge features.
        Takes precedence over ``use_rel_pos`` when both are set.
    pos_encode_dim : int, optional
        Dimension of the positional encoding. Defaults to ``32``.
    pos_encode_range : float, optional
        Range of the positional encoding. Defaults to ``4``.
    reductions : list of str, optional
        Reductions applied over the neighbor dimension. Defaults to ``("mean",)``.
    out_point_type : {"provided", "downsample", "same"}, optional
        Determines the coordinate set on which output features are computed.
        Defaults to ``"same"``.
    provided_in_channels : int, optional
        Number of channels of the provided query points when ``out_point_type`` is ``"provided"``.
    bias : bool, optional
        Whether linear layers contain biases. Defaults to ``True``.

    Raises
    ------
    ValueError
        If ``out_point_type`` is not one of the supported values, or if
        ``pooling_voxel_size`` exceeds the radius of a radius search.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neighbor_search_args: RealSearchConfig,
        pooling_reduction: Optional[REDUCTIONS] = None,
        pooling_voxel_size: Optional[float] = None,
        edge_transform_mlp: Optional[nn.Module] = None,
        out_transform_mlp: Optional[nn.Module] = None,
        mlp_block: nn.Module = MLPBlock,
        hidden_dim: Optional[int] = None,
        channel_multiplier: int = 2,
        use_rel_pos: bool = False,
        use_rel_pos_encode: bool = False,
        pos_encode_dim: int = 32,
        pos_encode_range: float = 4,
        reductions: List[REDUCTION_TYPES_STR] = ("mean",),
        out_point_type: Literal["provided", "downsample", "same"] = "same",
        provided_in_channels: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        assert (
            isinstance(reductions, (tuple, list)) and len(reductions) > 0
        ), f"reductions must be a list or tuple of length > 0, got {reductions}"
        # Check the config type before the radius checks below read its attributes.
        assert isinstance(neighbor_search_args, RealSearchConfig)
        if out_point_type == "provided":
            assert pooling_reduction is None
            assert pooling_voxel_size is None
            assert (
                provided_in_channels is not None
            ), "provided_in_channels must be provided for provided type"
        elif out_point_type == "downsample":
            assert (
                pooling_reduction is not None and pooling_voxel_size is not None
            ), "pooling_reduction and pooling_voxel_size must be provided for downsample type"
            assert (
                provided_in_channels is None
            ), "provided_in_channels must be None for downsample type"
            # print warning if search radius is not \sqrt(3) times the downsample voxel size
            if (
                pooling_voxel_size is not None
                and neighbor_search_args.mode == RealSearchMode.RADIUS
                and neighbor_search_args.radius < pooling_voxel_size * (3**0.5)
            ):
                warnings.warn(
                    f"neighbor search radius {neighbor_search_args.radius} is less than sqrt(3) times the downsample voxel size {pooling_voxel_size}",
                    stacklevel=2,
                )
        elif out_point_type == "same":
            assert (
                pooling_reduction is None and pooling_voxel_size is None
            ), "pooling_reduction and pooling_voxel_size must be None for same type"
            assert (
                provided_in_channels is None
            ), "provided_in_channels must be None for same type"
        else:
            # Fail fast: ``forward`` has no branch for any other value.
            raise ValueError(
                f"out_point_type must be one of 'provided', 'downsample', 'same', got {out_point_type!r}"
            )
        if (
            pooling_reduction is not None
            and pooling_voxel_size is not None
            and neighbor_search_args.mode == RealSearchMode.RADIUS
            and pooling_voxel_size > neighbor_search_args.radius
        ):
            raise ValueError(
                f"downsample_voxel_size {pooling_voxel_size} must be <= radius {neighbor_search_args.radius}"
            )

        self.reductions = reductions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_encode = use_rel_pos_encode
        self.out_point_feature_type = out_point_type
        self.neighbor_search_args = neighbor_search_args
        self.pooling_reduction = pooling_reduction
        self.pooling_voxel_size = pooling_voxel_size
        self.positional_encoding = SinusoidalEncoding(
            pos_encode_dim, data_range=pos_encode_range
        )
        # Query points carry ``provided_in_channels`` features only in the
        # "provided" mode; otherwise they share the input channel count.
        if provided_in_channels is None:
            provided_in_channels = in_channels
        if hidden_dim is None:
            hidden_dim = channel_multiplier * max(out_channels, in_channels)
        if edge_transform_mlp is None:
            # Edge input = neighbor features + query features (+ positional part).
            edge_in_channels = in_channels + provided_in_channels
            if use_rel_pos_encode:
                edge_in_channels += pos_encode_dim * 3
            elif use_rel_pos:
                edge_in_channels += 3
            edge_transform_mlp = mlp_block(
                in_channels=edge_in_channels,
                out_channels=out_channels,
                hidden_channels=hidden_dim,
                bias=bias,
            )
        self.edge_transform_mlp = edge_transform_mlp
        self.edge_mlp_in_channels = _get_module_input_channel(edge_transform_mlp)
        if out_transform_mlp is None:
            # One ``out_channels``-wide slice per reduction is concatenated.
            out_transform_mlp = mlp_block(
                in_channels=out_channels * len(reductions),
                out_channels=out_channels,
                hidden_channels=hidden_dim,
                bias=bias,
            )
        self.out_transform_mlp = out_transform_mlp

    def __repr__(self):
        out_str = f"{self.__class__.__name__}(in_channels={self.in_channels} out_channels={self.out_channels}"
        if self.use_rel_pos_encode:
            out_str += f" rel_pos_encode={self.use_rel_pos_encode}"
        if self.pooling_reduction is not None:
            out_str += f" pooling={self.pooling_reduction}"
        if self.neighbor_search_args is not None:
            out_str += f" neighbor={self.neighbor_search_args}"
        out_str += ")"
        return out_str

    def forward(
        self,
        in_pc: Points,
        query_pc: Optional[Points] = None,
    ) -> Points:
        """Compute convolved features on the query points.

        Parameters
        ----------
        in_pc : Points
            Input point cloud whose points act as neighbors.
        query_pc : Points, optional
            Points on which the output is computed. Required when
            ``out_point_type`` is ``"provided"`` and must be ``None`` otherwise:
            for ``"same"`` the input points are used, for ``"downsample"`` the
            voxel-downsampled input points.

        Returns
        -------
        Points
            Points with the query coordinates, offsets and extra attributes,
            carrying the output of ``out_transform_mlp`` as features.
        """
        if self.out_point_feature_type == "provided":
            assert (
                query_pc is not None
            ), "query_point_features must be provided for the provided type"
        elif self.out_point_feature_type == "downsample":
            assert query_pc is None
            query_pc = in_pc.voxel_downsample(
                self.pooling_voxel_size,
                reduction=self.pooling_reduction,
            )
        elif self.out_point_feature_type == "same":
            assert query_pc is None
            query_pc = in_pc

        in_num_channels = in_pc.num_channels
        query_num_channels = query_pc.num_channels
        # The boolean flags act as 0/1 multipliers for the positional part.
        assert (
            in_num_channels
            + query_num_channels
            + self.use_rel_pos_encode * self.positional_encoding.num_channels * 3
            + (not self.use_rel_pos_encode) * self.use_rel_pos * 3
            == self.edge_mlp_in_channels
        ), f"input features shape {in_pc.feature_tensor.shape} and query feature shape {query_pc.feature_tensor.shape} does not match the edge_transform_mlp input channels {self.edge_mlp_in_channels}"

        # Get the neighbors
        neighbors = in_pc.neighbors(
            query_coords=query_pc.batched_coordinates,
            search_args=self.neighbor_search_args,
        )
        neighbor_indices = neighbors.neighbor_indices.long().view(-1)
        neighbor_row_splits = neighbors.neighbor_row_splits
        # Number of neighbors found for each query point.
        num_reps = neighbor_row_splits[1:] - neighbor_row_splits[:-1]

        # repeat the self features using num_reps
        rep_in_features = in_pc.feature_tensor[neighbor_indices]
        self_features = torch.repeat_interleave(
            query_pc.feature_tensor.view(-1, query_num_channels).contiguous(),
            num_reps,
            dim=0,
        )
        edge_features = [rep_in_features, self_features]
        if self.use_rel_pos or self.use_rel_pos_encode:
            in_rep_vertices = in_pc.coordinate_tensor.view(-1, 3)[neighbor_indices]
            self_vertices = torch.repeat_interleave(
                query_pc.coordinate_tensor.view(-1, 3).contiguous(),
                num_reps,
                dim=0,
            )
            # Neighbor position relative to its query point.
            rel_coords = in_rep_vertices.view(-1, 3) - self_vertices.view(-1, 3)
            if self.use_rel_pos_encode:
                edge_features.append(self.positional_encoding(rel_coords))
            elif self.use_rel_pos:
                edge_features.append(rel_coords)
        edge_features = torch.cat(edge_features, dim=1)
        edge_features = self.edge_transform_mlp(edge_features)

        # Reduce the edges of every query point once per configured reduction.
        out_features = [
            row_reduction(edge_features, neighbor_row_splits, reduction=reduction)
            for reduction in self.reductions
        ]
        out_features = torch.cat(out_features, dim=-1)
        out_features = self.out_transform_mlp(out_features)

        return Points(
            batched_coordinates=Coords(
                batched_tensor=query_pc.coordinate_tensor,
                offsets=query_pc.offsets,
            ),
            batched_features=out_features,
            **query_pc.extra_attributes,
        )

forward(in_pc: Points, query_pc: Optional[Points] = None) -> Points

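When query_pc is None, the output is computed on in_pc's own coordinates (out_point_type "same") or on their voxel-downsampled version (out_point_type "downsample"); with out_point_type "provided", query_pc is required.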

Source code in warpconvnet/nn/modules/point_conv.py
def forward(
    self,
    in_pc: Points,
    query_pc: Optional[Points] = None,
) -> Points:
    """
    Compute convolved features on the query points.

    ``query_pc`` is required in the "provided" mode and must be ``None``
    otherwise: the "same" mode queries the input points themselves and the
    "downsample" mode queries their voxel-downsampled version.
    """
    mode = self.out_point_feature_type
    if mode == "provided":
        assert (
            query_pc is not None
        ), "query_point_features must be provided for the provided type"
    elif mode == "downsample":
        assert query_pc is None
        query_pc = in_pc.voxel_downsample(
            self.pooling_voxel_size,
            reduction=self.pooling_reduction,
        )
    elif mode == "same":
        assert query_pc is None
        query_pc = in_pc

    in_num_channels = in_pc.num_channels
    query_num_channels = query_pc.num_channels

    # Width contributed by positional information; the encoding wins when
    # both flags are set.
    encoding_channels = self.positional_encoding.num_channels * 3
    if self.use_rel_pos_encode:
        pos_channels = encoding_channels
    elif self.use_rel_pos:
        pos_channels = 3
    else:
        pos_channels = 0
    expected_channels = in_num_channels + query_num_channels + pos_channels
    assert (
        expected_channels == self.edge_mlp_in_channels
    ), f"input features shape {in_pc.feature_tensor.shape} and query feature shape {query_pc.feature_tensor.shape} does not match the edge_transform_mlp input channels {self.edge_mlp_in_channels}"

    # Neighbor search: flat neighbor indices plus per-query row splits.
    search_result = in_pc.neighbors(
        query_coords=query_pc.batched_coordinates,
        search_args=self.neighbor_search_args,
    )
    flat_indices = search_result.neighbor_indices.long().view(-1)
    row_splits = search_result.neighbor_row_splits
    counts = row_splits[1:] - row_splits[:-1]

    # One row per (query, neighbor) pair: neighbor features next to the
    # repeated features of the query point.
    neighbor_feats = in_pc.feature_tensor[flat_indices]
    query_feats = torch.repeat_interleave(
        query_pc.feature_tensor.view(-1, query_num_channels).contiguous(),
        counts,
        dim=0,
    )
    pieces = [neighbor_feats, query_feats]
    if self.use_rel_pos or self.use_rel_pos_encode:
        neighbor_xyz = in_pc.coordinate_tensor.view(-1, 3)[flat_indices]
        query_xyz = torch.repeat_interleave(
            query_pc.coordinate_tensor.view(-1, 3).contiguous(),
            counts,
            dim=0,
        )
        offsets_xyz = neighbor_xyz - query_xyz
        if self.use_rel_pos_encode:
            pieces.append(self.positional_encoding(offsets_xyz))
        elif self.use_rel_pos:
            pieces.append(offsets_xyz)
    edge_feats = self.edge_transform_mlp(torch.cat(pieces, dim=1))

    # Reduce every neighborhood once per configured reduction and concatenate.
    reduced = [
        row_reduction(edge_feats, row_splits, reduction=name) for name in self.reductions
    ]
    out_feats = self.out_transform_mlp(torch.cat(reduced, dim=-1))

    out_coords = Coords(
        batched_tensor=query_pc.coordinate_tensor,
        offsets=query_pc.offsets,
    )
    return Points(
        batched_coordinates=out_coords,
        batched_features=out_feats,
        **query_pc.extra_attributes,
    )

Point pooling

warpconvnet.nn.modules.point_pool

PointAvgPool

Bases: PointPoolBase

Point pooling using mean reduction.

Parameters:
  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointAvgPool(PointPoolBase):
    """Point pooling that merges features with the ``mean`` reduction.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Maximum number of points to keep when downsampling.
    downsample_voxel_size : float, optional
        Size of voxels used for downsampling.
    return_type : {"point", "sparse"}, optional
        Output geometry type. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        If ``True`` also return the neighbor search result. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        # Everything except the fixed reduction is forwarded unchanged.
        base_kwargs = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.MEAN, **base_kwargs)

PointMaxPool

Bases: PointPoolBase

Point pooling using max reduction.

Parameters:
  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointMaxPool(PointPoolBase):
    """Point pooling that merges features with the ``max`` reduction.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Maximum number of points to keep when downsampling.
    downsample_voxel_size : float, optional
        Size of voxels used for downsampling.
    return_type : {"point", "sparse"}, optional
        Output geometry type. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        If ``True`` also return the neighbor search result. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        # Everything except the fixed reduction is forwarded unchanged.
        base_kwargs = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.MAX, **base_kwargs)

PointPoolBase

Bases: BaseSpatialModule

Base module for pooling points or voxels.

Parameters:
  • reduction (str or `REDUCTIONS`, default: MAX ) –

    Reduction method used when merging features. Defaults to REDUCTIONS.MAX.

  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • unique_method ((torch, ravel, morton), default: "torch" ) –

    Method used to find unique voxel indices. Defaults to "torch".

  • avereage_pooled_coordinates (bool, default: False ) –

    If True average coordinates of points within each voxel. Defaults to False.

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointPoolBase(BaseSpatialModule):
    """Common implementation behind the point/voxel pooling modules.

    The constructor only records the pooling configuration; ``forward`` passes
    that configuration, together with the input point cloud, to ``point_pool``.

    Parameters
    ----------
    reduction : str or `REDUCTIONS`, optional
        Reduction applied when features are merged. A string is converted with
        ``REDUCTIONS(reduction)``. Defaults to ``REDUCTIONS.MAX``.
    downsample_max_num_points : int, optional
        Upper bound on the number of points kept by the downsampling.
    downsample_voxel_size : float, optional
        Voxel size used by the downsampling.
    return_type : {"point", "sparse"}, optional
        Geometry type of the output. Defaults to ``"point"``.
    unique_method : {"torch", "ravel", "morton"}, optional
        Strategy for finding unique voxel indices. Defaults to ``"torch"``.
    avereage_pooled_coordinates : bool, optional
        When ``True`` the coordinates of the points inside each voxel are averaged.
        Defaults to ``False``. The keyword is spelled ``avereage``; it is forwarded
        to ``point_pool`` as ``average_pooled_coordinates``.
    return_neighbor_search_result : bool, optional
        When ``True`` the neighbor search result is returned as well.
        Defaults to ``False``.
    """

    def __init__(
        self,
        reduction: Union[str, REDUCTIONS] = REDUCTIONS.MAX,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        unique_method: Literal["torch", "ravel", "morton"] = "torch",
        avereage_pooled_coordinates: bool = False,
        return_neighbor_search_result: bool = False,
    ):
        super().__init__()
        # Accept either the enum member or its string value.
        self.reduction = REDUCTIONS(reduction) if isinstance(reduction, str) else reduction

        # Downsampling configuration.
        self.downsample_voxel_size = downsample_voxel_size
        self.downsample_max_num_points = downsample_max_num_points
        self.unique_method = unique_method
        self.avereage_pooled_coordinates = avereage_pooled_coordinates

        # Output configuration.
        self.return_type = return_type
        self.return_neighbor_search_result = return_neighbor_search_result

    def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResult]]:
        """Pool ``pc`` with the stored configuration and return what ``point_pool`` returns."""
        pooled = point_pool(
            pc=pc,
            reduction=self.reduction,
            downsample_voxel_size=self.downsample_voxel_size,
            downsample_max_num_points=self.downsample_max_num_points,
            unique_method=self.unique_method,
            average_pooled_coordinates=self.avereage_pooled_coordinates,
            return_type=self.return_type,
            return_neighbor_search_result=self.return_neighbor_search_result,
        )
        return pooled

PointSumPool

Bases: PointPoolBase

Point pooling using sum reduction.

Parameters:
  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointSumPool(PointPoolBase):
    """``PointPoolBase`` with the reduction fixed to ``REDUCTIONS.SUM``.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Upper bound on the number of points kept by the downsampling.
    downsample_voxel_size : float, optional
        Voxel size used by the downsampling.
    return_type : {"point", "sparse"}, optional
        Geometry type of the output. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        When ``True`` the neighbor search result is returned as well.
        Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        # Everything except the reduction is forwarded unchanged; the remaining
        # base-class options keep their defaults.
        super().__init__(
            reduction=REDUCTIONS.SUM,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
            downsample_voxel_size=downsample_voxel_size,
            downsample_max_num_points=downsample_max_num_points,
        )

PointUnpool

Bases: BaseSpatialModule

Undo point pooling by scattering pooled features back to the input cloud.

Parameters:
  • unpooling_mode ((repeat, interpolate), default: "repeat" ) –

    Strategy used when unpooling features. Defaults to FEATURE_UNPOOLING_MODE.REPEAT.

  • concat_unpooled_pc (bool, default: False ) –

    If True concatenate the unpooled point cloud with the input. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointUnpool(BaseSpatialModule):
    """Module form of ``point_unpool``: map pooled features back onto the input cloud.

    Parameters
    ----------
    unpooling_mode : {"repeat", "interpolate"} or `FEATURE_UNPOOLING_MODE`, optional
        Unpooling strategy. A string is converted with
        ``FEATURE_UNPOOLING_MODE(unpooling_mode)``.
        Defaults to ``FEATURE_UNPOOLING_MODE.REPEAT``.
    concat_unpooled_pc : bool, optional
        When ``True`` the unpooled point cloud is concatenated with the input.
        Defaults to ``False``.
    """

    def __init__(
        self,
        unpooling_mode: Union[
            str, FEATURE_UNPOOLING_MODE
        ] = FEATURE_UNPOOLING_MODE.REPEAT,
        concat_unpooled_pc: bool = False,
    ):
        super().__init__()
        self.concat_unpooled_pc = concat_unpooled_pc
        # Accept either the enum member or its string value.
        self.unpooling_mode = (
            FEATURE_UNPOOLING_MODE(unpooling_mode)
            if isinstance(unpooling_mode, str)
            else unpooling_mode
        )

    def forward(self, pooled_pc: Points, unpooled_pc: Points):
        """Unpool ``pooled_pc`` onto ``unpooled_pc`` and return the result of ``point_unpool``."""
        result = point_unpool(
            pooled_pc=pooled_pc,
            unpooled_pc=unpooled_pc,
            concat_unpooled_pc=self.concat_unpooled_pc,
            unpooling_mode=self.unpooling_mode,
        )
        return result

Prune

warpconvnet.nn.modules.prune

SparsePrune

Bases: BaseSpatialModule

Module wrapper around prune_spatially_sparse_tensor so pruning can be composed in nn.Sequential.

Forward Args

spatial_tensor : Geometry — Sparse geometry (e.g., Voxels) whose coordinates/features will be filtered.

mask : Bool[Tensor, "N"] — Boolean mask aligned with spatial_tensor.coordinate_tensor.

Source code in warpconvnet/nn/modules/prune.py
class SparsePrune(BaseSpatialModule):
    """
    Module form of ``prune_spatially_sparse_tensor``, so that pruning can be composed
    with other layers (e.g. in nn.Sequential).

    Forward Args
    ------------
    spatial_tensor : Geometry
        Sparse geometry (e.g., Voxels) whose coordinates/features are filtered.
    mask : Bool[Tensor, "N"]
        Boolean mask aligned with ``spatial_tensor.coordinate_tensor``.

    Returns
    -------
    Geometry
        Whatever ``prune_spatially_sparse_tensor`` returns for the given inputs.
    """

    def forward(
        self,
        spatial_tensor: Geometry,
        mask: Bool[Tensor, "N"],  # noqa: F821
    ) -> Geometry:
        pruned = prune_spatially_sparse_tensor(spatial_tensor, mask)
        return pruned

Sequential

warpconvnet.nn.modules.sequential

GeometryWrapper

Bases: BaseSpatialModule

Wrapper for a spatial module that returns a geometry object.

Source code in warpconvnet/nn/modules/sequential.py
class GeometryWrapper(BaseSpatialModule):
    """Lift a tensor-to-tensor callable so that it maps a geometry to a geometry.

    ``forward`` applies the wrapped callable to ``x.feature_tensor`` and returns
    ``x`` with its features replaced by the result.
    """

    def __init__(self, module: Callable[[Tensor], Tensor]):
        super().__init__()
        self.module = module

    def forward(self, x: Geometry) -> Geometry:
        new_features = self.module(x.feature_tensor)
        return x.replace(batched_features=new_features)

Sequential

Bases: Sequential, BaseSpatialModule

Sequential module that allows for spatial and non-spatial layers to be chained together.

If the module has multiple consecutive non-spatial layers, then it will not create an intermediate spatial features object and will become more efficient.

Source code in warpconvnet/nn/modules/sequential.py
class Sequential(nn.Sequential, BaseSpatialModule):
    """
    Sequential module that allows for spatial and non-spatial layers to be chained together.

    If the module has multiple consecutive non-spatial layers, then it will not create an intermediate
    spatial features object and will become more efficient.
    """

    def forward(self, x: Geometry):
        """Run every child module in order on ``x``.

        Parameters
        ----------
        x : Geometry
            Input geometry. Anything else fails the ``isinstance`` assertion.

        Returns
        -------
        The output of the last module. If that output is a plain ``torch.Tensor``
        it is wrapped back into a geometry via ``in_sf.replace(batched_features=...)``.
        """
        # The message names the type that is actually checked (``Geometry``).
        assert isinstance(x, Geometry), f"Expected Geometry, got {type(x)}"

        # ``in_sf`` tracks the geometry used to re-wrap raw tensors; ``run_forward``
        # hands back both the layer output and the (possibly updated) geometry.
        in_sf = x
        for module in self:
            x, in_sf = run_forward(module, x, in_sf)

        # A trailing non-spatial layer leaves a bare tensor: wrap it again.
        if isinstance(x, torch.Tensor):
            x = in_sf.replace(batched_features=x)

        return x

TupleSequential

Bases: Sequential, BaseSpatialModule

Sequential module that allows multiple inputs for a specified layer.

Source code in warpconvnet/nn/modules/sequential.py
class TupleSequential(Sequential, BaseSpatialModule):
    """
    Sequential module that allows multiple inputs for a specified layer.

    Parameters
    ----------
    *args
        Either the child modules themselves, or a single sequence holding them.
    tuple_layer : int
        Index of the child module that receives all inputs as a tuple
        (keyword-only). Every other module receives only the running value.
    """

    def __init__(self, *args, tuple_layer: int):
        # A single sequence argument is unpacked so both call styles register
        # the same children.
        if len(args) == 1 and isinstance(args[0], Sequence):
            super().__init__(*args[0])
        else:
            super().__init__(*args)
        self.tuple_layer = tuple_layer

    def forward(self, *xs: Geometry):
        """Run the children in order, feeding the extra inputs to ``tuple_layer``.

        Parameters
        ----------
        *xs : Geometry
            ``xs[0]`` is the running input passed from layer to layer; ``xs[1:]``
            are handed, together with the running value, to the module at index
            ``tuple_layer``.

        Returns
        -------
        The output of the last module. If that output is a plain ``torch.Tensor``
        it is wrapped back into a geometry via ``in_sf.replace(batched_features=...)``.
        """
        x = xs[0]
        in_sf = x
        for i, module in enumerate(self):
            if i == self.tuple_layer:
                # Only this layer sees the additional inputs.
                x, in_sf = tuple_run_forward(module, (x, *xs[1:]), in_sf)
            else:
                x, in_sf = run_forward(module, x, in_sf)

        # A trailing non-spatial layer leaves a bare tensor: wrap it again.
        if isinstance(x, torch.Tensor):
            x = in_sf.replace(batched_features=x)

        return x

Sparse convolution

warpconvnet.nn.modules.sparse_conv

SparseConv2d

Bases: SpatiallySparseConv

2D sparse convolution.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Convolution stride. Defaults to 1.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True adds a learnable bias to the output. Defaults to True.

  • transposed (bool, default: False ) –

    Perform a transposed convolution. Defaults to False.

  • generative (bool, default: False ) –

    Use generative convolution. Defaults to False.

  • stride_mode (`STRIDED_CONV_MODE`, default: STRIDE_ONLY ) –

    How to interpret stride when transposed is True.

  • fwd_algo (str, default: None ) –

    Forward (AB gather-scatter) algorithm.

  • dgrad_algo (str, default: None ) –

    Dgrad (AB gather-scatter) algorithm.

  • wgrad_algo (str, default: None ) –

    Wgrad (AtB gather-gather) algorithm.

Source code in warpconvnet/nn/modules/sparse_conv.py
class SparseConv2d(SpatiallySparseConv):
    """``SpatiallySparseConv`` with ``num_spatial_dims`` fixed to ``2``.

    Every argument is forwarded unchanged to the base class.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        If ``True`` adds a learnable bias to the output. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    groups : int, optional
        Number of channel groups. Defaults to ``1``.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How to interpret ``stride`` when ``transposed`` is ``True``.
    fwd_algo : `SPARSE_CONV_AB_ALGO_MODE` or str, optional
        Forward (AB gather-scatter) algorithm.
    dgrad_algo : `SPARSE_CONV_AB_ALGO_MODE` or str, optional
        Dgrad (AB gather-scatter) algorithm.
    wgrad_algo : `SPARSE_CONV_ATB_ALGO_MODE` or str, optional
        Wgrad (AtB gather-gather) algorithm.
    kernel_matmul_batch_size : int, optional
        Batch size used for implicit matrix multiplications. Defaults to ``2``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    use_fp16_accum : bool, optional
        Passed through to the base class unchanged. Defaults to ``None``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        transposed=False,
        generative: bool = False,
        groups: int = 1,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        fwd_algo: Optional[Union[SPARSE_CONV_AB_ALGO_MODE, str]] = None,
        dgrad_algo: Optional[Union[SPARSE_CONV_AB_ALGO_MODE, str]] = None,
        wgrad_algo: Optional[Union[SPARSE_CONV_ATB_ALGO_MODE, str]] = None,
        kernel_matmul_batch_size: int = 2,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        use_fp16_accum: Optional[bool] = None,
    ):
        super().__init__(
            # The only value this subclass pins down.
            num_spatial_dims=2,
            # Layer shape.
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            # Convolution flavour.
            transposed=transposed,
            generative=generative,
            stride_mode=stride_mode,
            order=order,
            # Kernel/algorithm selection and numerics.
            fwd_algo=fwd_algo,
            dgrad_algo=dgrad_algo,
            wgrad_algo=wgrad_algo,
            kernel_matmul_batch_size=kernel_matmul_batch_size,
            compute_dtype=compute_dtype,
            use_fp16_accum=use_fp16_accum,
        )

SparseConv3d

Bases: SpatiallySparseConv

3D sparse convolution.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Convolution stride. Defaults to 1.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True adds a learnable bias to the output. Defaults to True.

  • transposed (bool, default: False ) –

    Perform a transposed convolution. Defaults to False.

  • generative (bool, default: False ) –

    Use generative convolution. Defaults to False.

  • stride_mode (`STRIDED_CONV_MODE`, default: STRIDE_ONLY ) –

    How to interpret stride when transposed is True.

  • fwd_algo (str, default: None ) –

    Forward (AB gather-scatter) algorithm.

  • dgrad_algo (str, default: None ) –

    Dgrad (AB gather-scatter) algorithm.

  • wgrad_algo (str, default: None ) –

    Wgrad (AtB gather-gather) algorithm.

Source code in warpconvnet/nn/modules/sparse_conv.py
class SparseConv3d(SpatiallySparseConv):
    """``SpatiallySparseConv`` with ``num_spatial_dims`` fixed to ``3``.

    Every argument is forwarded unchanged to the base class.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        If ``True`` adds a learnable bias to the output. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    groups : int, optional
        Number of channel groups. Defaults to ``1``.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How to interpret ``stride`` when ``transposed`` is ``True``.
    fwd_algo : `SPARSE_CONV_AB_ALGO_MODE` or str, optional
        Forward (AB gather-scatter) algorithm.
    dgrad_algo : `SPARSE_CONV_AB_ALGO_MODE` or str, optional
        Dgrad (AB gather-scatter) algorithm.
    wgrad_algo : `SPARSE_CONV_ATB_ALGO_MODE` or str, optional
        Wgrad (AtB gather-gather) algorithm.
    kernel_matmul_batch_size : int, optional
        Batch size used for implicit matrix multiplications. Defaults to ``2``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    use_fp16_accum : bool, optional
        Passed through to the base class unchanged. Defaults to ``None``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        transposed=False,
        generative: bool = False,
        groups: int = 1,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        fwd_algo: Optional[Union[SPARSE_CONV_AB_ALGO_MODE, str]] = None,
        dgrad_algo: Optional[Union[SPARSE_CONV_AB_ALGO_MODE, str]] = None,
        wgrad_algo: Optional[Union[SPARSE_CONV_ATB_ALGO_MODE, str]] = None,
        kernel_matmul_batch_size: int = 2,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        use_fp16_accum: Optional[bool] = None,
    ):
        super().__init__(
            # The only value this subclass pins down.
            num_spatial_dims=3,
            # Layer shape.
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            # Convolution flavour.
            transposed=transposed,
            generative=generative,
            stride_mode=stride_mode,
            order=order,
            # Kernel/algorithm selection and numerics.
            fwd_algo=fwd_algo,
            dgrad_algo=dgrad_algo,
            wgrad_algo=wgrad_algo,
            kernel_matmul_batch_size=kernel_matmul_batch_size,
            compute_dtype=compute_dtype,
            use_fp16_accum=use_fp16_accum,
        )

SpatiallySparseConv

Bases: BaseSpatialModule

Sparse convolution layer for warpconvnet.geometry.types.voxels.Voxels.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Convolution stride. Defaults to 1.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True adds a learnable bias to the output. Defaults to True.

  • transposed (bool, default: False ) –

    Perform a transposed convolution. Defaults to False.

  • generative (bool, default: False ) –

    Use generative convolution. Defaults to False.

  • kernel_matmul_batch_size (int, default: 2 ) –

    Batch size used for implicit matrix multiplications. Defaults to 2.

  • num_spatial_dims (int, default: 3 ) –

    Number of spatial dimensions. Defaults to 3.

  • fwd_algo (`SPARSE_CONV_AB_ALGO_MODE` or str, default: None ) –

    Forward (AB gather-scatter) algorithm. Defaults to environment setting.

  • dgrad_algo (`SPARSE_CONV_AB_ALGO_MODE` or str, default: None ) –

    Dgrad (AB gather-scatter) algorithm. Defaults to environment setting.

  • wgrad_algo (`SPARSE_CONV_ATB_ALGO_MODE` or str, default: None ) –

    Wgrad (AtB gather-gather) algorithm. Defaults to environment setting.

  • stride_mode (`STRIDED_CONV_MODE`, default: STRIDE_ONLY ) –

    How to interpret stride when transposed is True.

  • order (`POINT_ORDERING`, default: RANDOM ) –

    Ordering of points in the output. Defaults to POINT_ORDERING.RANDOM.

  • compute_dtype (dtype, default: None ) –

    Data type used for intermediate computations.

  • implicit_matmul_fwd_block_size (int, default: None ) –

    CUDA block size for implicit forward matmuls.

  • implicit_matmul_bwd_block_size (int, default: None ) –

    CUDA block size for implicit backward matmuls.

Source code in warpconvnet/nn/modules/sparse_conv.py
class SpatiallySparseConv(BaseSpatialModule):
    """Sparse convolution layer for `warpconvnet.geometry.types.voxels.Voxels`.

    The constructor normalizes the kernel configuration, allocates the weight
    (and optional bias) and initializes them; ``forward`` delegates the actual
    computation to ``spatially_sparse_conv``.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        If ``True`` adds a learnable bias to the output. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    groups : int, optional
        Number of channel groups. Must be positive and divide both
        ``in_channels`` and ``out_channels``. Defaults to ``1``.
    kernel_matmul_batch_size : int, optional
        Batch size used for implicit matrix multiplications. Defaults to ``2``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    fwd_algo : `SPARSE_CONV_AB_ALGO_MODE` or str, optional
        Forward (AB gather-scatter) algorithm. Defaults to environment setting.
    dgrad_algo : `SPARSE_CONV_AB_ALGO_MODE` or str, optional
        Dgrad (AB gather-scatter) algorithm. Defaults to environment setting.
    wgrad_algo : `SPARSE_CONV_ATB_ALGO_MODE` or str, optional
        Wgrad (AtB gather-gather) algorithm. Defaults to environment setting.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How to interpret ``stride`` when ``transposed`` is ``True``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    use_fp16_accum : bool, optional
        Passed through unchanged to ``spatially_sparse_conv``. Defaults to ``None``.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.

    Raises
    ------
    ValueError
        If ``groups`` is not positive, or does not divide ``in_channels`` or
        ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        groups: int = 1,
        kernel_matmul_batch_size: int = 2,
        num_spatial_dims: Optional[int] = 3,
        fwd_algo: Optional[Union[SPARSE_CONV_AB_ALGO_MODE, str]] = None,
        dgrad_algo: Optional[Union[SPARSE_CONV_AB_ALGO_MODE, str]] = None,
        wgrad_algo: Optional[Union[SPARSE_CONV_ATB_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        use_fp16_accum: Optional[bool] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        super().__init__()
        # Validate the grouping first: ``groups=0`` would otherwise surface as
        # a ZeroDivisionError, and a negative value as an invalid weight shape.
        if groups < 1:
            raise ValueError(f"groups ({groups}) must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError(f"in_channels ({in_channels}) must be divisible by groups ({groups})")
        if out_channels % groups != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by groups ({groups})"
            )

        self.num_spatial_dims = num_spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.use_fp16_accum = use_fp16_accum

        # Expand scalar settings to one entry per spatial dimension.
        _kernel_size = ntuple(kernel_size, ndim=self.num_spatial_dims)
        _stride = ntuple(stride, ndim=self.num_spatial_dims)
        _dilation = ntuple(dilation, ndim=self.num_spatial_dims)

        self.kernel_size = _kernel_size
        self.stride = _stride
        self.dilation = _dilation

        self.transposed = transposed
        self.generative = generative
        self.kernel_matmul_batch_size = kernel_matmul_batch_size

        # Use environment variable values if not explicitly provided
        if fwd_algo is None:
            fwd_algo = WARPCONVNET_FWD_ALGO_MODE
        if dgrad_algo is None:
            dgrad_algo = WARPCONVNET_DGRAD_ALGO_MODE
        if wgrad_algo is None:
            wgrad_algo = WARPCONVNET_WGRAD_ALGO_MODE

        def _parse_algo(algo, enum_cls):
            # Strings are converted to the enum; anything else is stored as given.
            if isinstance(algo, str):
                return enum_cls(algo)
            return algo

        self.fwd_algo = _parse_algo(fwd_algo, SPARSE_CONV_AB_ALGO_MODE)
        self.dgrad_algo = _parse_algo(dgrad_algo, SPARSE_CONV_AB_ALGO_MODE)
        self.wgrad_algo = _parse_algo(wgrad_algo, SPARSE_CONV_ATB_ALGO_MODE)

        self.stride_mode = stride_mode
        self.order = order
        self.compute_dtype = compute_dtype
        self.implicit_matmul_fwd_block_size = implicit_matmul_fwd_block_size
        self.implicit_matmul_bwd_block_size = implicit_matmul_bwd_block_size

        # Number of kernel offsets (product of the per-dimension kernel sizes).
        kernel_volume = int(np.prod(_kernel_size))
        if groups == 1:
            # Weight: [K, C_in, C_out]
            self.weight = nn.Parameter(torch.randn(kernel_volume, in_channels, out_channels))
        else:
            # Group conv weight: [K, G, C_in/G, C_out/G]
            self.weight = nn.Parameter(
                torch.randn(kernel_volume, groups, in_channels // groups, out_channels // groups)
            )

        # ``bias`` stays ``None`` when disabled so ``forward`` can pass it on as-is.
        self.bias: Optional[nn.Parameter] = (
            nn.Parameter(torch.randn(out_channels)) if bias else None
        )
        self.reset_parameters()  # Call after parameters are defined for the chosen backend

    @staticmethod
    def _is_all_ones(value) -> bool:
        """Return ``True`` when a per-dimension setting equals its default of ``1``."""
        try:
            return all(v == 1 for v in value)
        except TypeError:
            # Not iterable: fall back to a plain scalar comparison.
            return value == 1

    def __repr__(self):
        """Return the class name with the core arguments plus any non-default options."""
        fields = [
            f"in_channels={self.in_channels}",
            f"out_channels={self.out_channels}",
            f"kernel_size={self.kernel_size}",
        ]
        # ``stride``/``dilation`` are stored per dimension, so comparing them to
        # the scalar ``1`` would always report them as non-default.
        if not self._is_all_ones(self.stride):
            fields.append(f"stride={self.stride}")
        if not self._is_all_ones(self.dilation):
            fields.append(f"dilation={self.dilation}")
        if self.groups != 1:
            fields.append(f"groups={self.groups}")
        if self.transposed:
            fields.append(f"transposed={self.transposed}")
        if self.generative:
            fields.append(f"generative={self.generative}")
        if self.order != POINT_ORDERING.RANDOM:
            fields.append(f"order={self.order}")
        return f"{self.__class__.__name__}({', '.join(fields)})"

    def _calculate_fan_in_and_fan_out(self):
        """Return ``(fan_in, fan_out)`` for the current channel/kernel configuration."""
        receptive_field_size = np.prod(self.kernel_size)
        # For group conv, each output unit sees only in_channels/groups input
        # channels (and symmetrically out_channels/groups output channels).
        # Matches torch.nn._ConvNd convention.
        fan_in = (self.in_channels // self.groups) * receptive_field_size
        fan_out = (self.out_channels // self.groups) * receptive_field_size
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode: Literal["fan_in", "fan_out"]):
        """Return the fan selected by ``mode`` (case-insensitive ``"fan_in"``/``"fan_out"``)."""
        mode = mode.lower()
        assert mode in ["fan_in", "fan_out"]

        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == "fan_in" else fan_out

    def _custom_kaiming_uniform_(self, tensor, a=0, mode="fan_in", nonlinearity="leaky_relu"):
        """Fill ``tensor`` in place from ``U(-bound, bound)`` and return it.

        ``bound`` is ``sqrt(num_spatial_dims) * gain / sqrt(fan)`` where ``fan``
        comes from this layer's own fan computation rather than the tensor shape.
        """
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # NOTE(review): torch's kaiming_uniform_ scales by sqrt(3); this uses
        # sqrt(num_spatial_dims), which coincides only for 3D — confirm intent.
        bound = math.sqrt(self.num_spatial_dims) * std
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialize the weight and, when present, the bias in place."""
        # Transposed convolutions scale by fan_out, regular ones by fan_in.
        self._custom_kaiming_uniform_(
            self.weight,
            a=math.sqrt(5),
            mode="fan_out" if self.transposed else "fan_in",
        )

        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        input_sparse_tensor: Voxels,
        output_spatially_sparse_tensor: Optional[Voxels] = None,
    ):
        """Apply the convolution by delegating to ``spatially_sparse_conv``.

        Parameters
        ----------
        input_sparse_tensor : Voxels
            Input to convolve.
        output_spatially_sparse_tensor : Voxels, optional
            Forwarded unchanged to ``spatially_sparse_conv``.

        Returns
        -------
        Whatever ``spatially_sparse_conv`` returns for the stored configuration.
        """
        return spatially_sparse_conv(
            input_sparse_tensor=input_sparse_tensor,
            weight=self.weight,
            kernel_size=self.kernel_size,
            stride=self.stride,
            kernel_dilation=self.dilation,
            bias=self.bias,
            groups=self.groups,
            kernel_matmul_batch_size=self.kernel_matmul_batch_size,
            output_spatially_sparse_tensor=output_spatially_sparse_tensor,
            transposed=self.transposed,
            generative=self.generative,
            fwd_algo=self.fwd_algo,
            dgrad_algo=self.dgrad_algo,
            wgrad_algo=self.wgrad_algo,
            stride_mode=self.stride_mode,
            order=self.order,
            compute_dtype=self.compute_dtype,
            use_fp16_accum=self.use_fp16_accum,
            implicit_matmul_fwd_block_size=self.implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=self.implicit_matmul_bwd_block_size,
        )

Sparse ConvNeXt

warpconvnet.nn.modules.sparse_convnext

ConvNeXt block on Voxels (sparse 3D analogue).

Standard ConvNeXt layout:

    SparseConv3d (k=3, submanifold)               → conv
    LayerNorm (affine, fp32)                      → norm
    Linear → SiLU → Linear (zero-init) on .feats  → mlp
    + skip (residual)

Uses warpconvnet's native SparseConv3d (weight layout (K^3, in_channels, out_channels)). Models that need to load checkpoints stored in a different sparse-conv weight layout (TRELLIS uses (Cout, Kd, Kh, Kw, Cin) in some checkpoints, for instance) ship their own block with a layout-compatible conv wrapper.

Sparse DiT

warpconvnet.nn.modules.sparse_dit

Sparse DiT blocks (adaLN-modulated transformer over Voxels).

Generic sparse counterpart to warpconvnet.nn.modules.dit. x is a Voxels token sequence; mod is a per-batch (B, C) (or (B, 6*C) when share_mod=True) conditioning tensor; context is either a dense (B, L, ctx_channels) tensor (image features) or a Voxels (sparse cross-attention).

ModulatedSparseTransformerBlock

Bases: Module

Sparse DiT block: norm1 → adaLN(MSA) → norm2 → adaLN(FFN).

Source code in warpconvnet/nn/modules/sparse_dit.py
class ModulatedSparseTransformerBlock(nn.Module):
    """Sparse DiT block: ``norm1 → adaLN(MSA) → norm2 → adaLN(FFN)``.

    Both branches are residual. The conditioning tensor ``mod`` is turned into
    six chunks (shift/scale/gate for attention, then shift/scale/gate for the
    feed-forward net) that modulate the normalized features and gate each
    branch output.

    With ``share_mod=False`` the chunks come from a per-block
    ``SiLU → Linear(channels, 6 * channels)``; with ``share_mod=True`` a learned
    ``6 * channels`` offset is added to ``mod`` instead.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full"] = "full",
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.share_mod = share_mod
        self.use_checkpoint = use_checkpoint

        # Submodules are created in a fixed order (norms, attention, MLP,
        # modulation) so parameter registration and random init are stable.
        norm_kwargs = dict(elementwise_affine=False, eps=1e-6)
        self.norm1 = LayerNorm32(channels, **norm_kwargs)
        self.norm2 = LayerNorm32(channels, **norm_kwargs)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(channels, mlp_ratio=mlp_ratio)
        if share_mod:
            # Learned offset added to an externally computed modulation.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels**0.5)
        else:
            # Per-block projection of the conditioning vector.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
            )

    def _split_mod(self, mod: torch.Tensor):
        """Return the six modulation chunks, split along ``dim=1``."""
        if not self.share_mod:
            return self.adaLN_modulation(mod).chunk(6, dim=1)
        # Cast back to the incoming dtype after adding the learned offset.
        shared = (self.modulation + mod).type(mod.dtype)
        return shared.chunk(6, dim=1)

    def _forward(self, x: Voxels, mod: torch.Tensor) -> Voxels:
        """Un-checkpointed body of ``forward``."""
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self._split_mod(mod)

        # Attention branch: normalize, modulate, attend, gate, add residual.
        # ``_per_voxel`` presumably expands per-batch rows to per-voxel rows.
        normed = self.norm1(x.feats)
        modulated = normed * (1 + _per_voxel(x, scale_msa)) + _per_voxel(x, shift_msa)
        attn_out = self.attn(x.replace_features(modulated))
        attn_out = attn_out.replace_features(attn_out.feats * _per_voxel(x, gate_msa))
        x = x + attn_out

        # Feed-forward branch: same pattern on the updated ``x``.
        normed = self.norm2(x.feats)
        modulated = normed * (1 + _per_voxel(x, scale_mlp)) + _per_voxel(x, shift_mlp)
        mlp_out = self.mlp(x.replace_features(modulated))
        mlp_out = mlp_out.replace_features(mlp_out.feats * _per_voxel(x, gate_mlp))
        return x + mlp_out

    def forward(self, x: Voxels, mod: torch.Tensor) -> Voxels:
        """Apply the block, through activation checkpointing when ``use_checkpoint`` is set."""
        if not self.use_checkpoint:
            return self._forward(x, mod)
        return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)

ModulatedSparseTransformerCrossBlock

Bases: Module

Sparse DiT cross block: MSA → MCA → FFN with adaLN modulation.

The cross-attn norm (norm2) is the only norm with affine params, matching upstream trellis2.modules.sparse.transformer.modulated.

Source code in warpconvnet/nn/modules/sparse_dit.py
class ModulatedSparseTransformerCrossBlock(nn.Module):
    """Sparse DiT block chaining self-attention, cross-attention and an FFN.

    Self-attention and the FFN are wrapped in adaLN shift/scale/gate
    modulation; cross-attention is a plain pre-norm residual branch. ``norm2``
    (the cross-attention norm) is the only LayerNorm created with affine
    parameters, matching upstream
    `trellis2.modules.sparse.transformer.modulated`.
    """

    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full"] = "full",
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod

        def build_norm(affine: bool):
            return LayerNorm32(channels, elementwise_affine=affine, eps=1e-6)

        # Submodules are created in a fixed order (norms, attentions, MLP,
        # modulation) so parameter registration stays stable.
        self.norm1 = build_norm(False)
        self.norm2 = build_norm(True)
        self.norm3 = build_norm(False)

        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention is always "full" and has its own qk-RMSNorm switch.
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(channels, mlp_ratio=mlp_ratio)

        if share_mod:
            # Learned offset added to an externally supplied modulation tensor.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels**0.5)
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
            )

    def _split_mod(self, mod: torch.Tensor):
        """Return six modulation chunks (split along dim 1) derived from ``mod``."""
        if not self.share_mod:
            projected = self.adaLN_modulation(mod)
        else:
            projected = (self.modulation + mod).type(mod.dtype)
        return projected.chunk(6, dim=1)

    def _forward(
        self,
        x: Voxels,
        mod: torch.Tensor,
        context: torch.Tensor | Voxels,
    ) -> Voxels:
        """Apply self-attention, cross-attention and FFN, each as a residual branch."""
        (
            msa_shift,
            msa_scale,
            msa_gate,
            ffn_shift,
            ffn_scale,
            ffn_gate,
        ) = self._split_mod(mod)

        # 1) Modulated self-attention.
        normed = self.norm1(x.feats)
        modulated = normed * (1 + _per_voxel(x, msa_scale)) + _per_voxel(x, msa_shift)
        branch = self.self_attn(x.replace_features(modulated))
        branch = branch.replace_features(branch.feats * _per_voxel(x, msa_gate))
        x = x + branch

        # 2) Cross-attention: pre-norm only, no scale/shift/gate (per upstream).
        branch = self.cross_attn(x.replace_features(self.norm2(x.feats)), context)
        x = x + branch

        # 3) Modulated feed-forward network.
        normed = self.norm3(x.feats)
        modulated = normed * (1 + _per_voxel(x, ffn_scale)) + _per_voxel(x, ffn_shift)
        branch = self.mlp(x.replace_features(modulated))
        branch = branch.replace_features(branch.feats * _per_voxel(x, ffn_gate))
        return x + branch

    def forward(
        self,
        x: Voxels,
        mod: torch.Tensor,
        context: torch.Tensor | Voxels,
    ) -> Voxels:
        """Run ``_forward``, wrapping it in non-reentrant checkpointing when enabled."""
        if not self.use_checkpoint:
            return self._forward(x, mod, context)
        return torch.utils.checkpoint.checkpoint(
            self._forward, x, mod, context, use_reentrant=False
        )

SparseFeedForwardNet

Bases: Module

Linear-GELU(tanh)-Linear MLP applied per-voxel.

Keeps the upstream attribute layout (self.mlp[0]/self.mlp[2]) so state_dict keys remain mlp.mlp.0.weight etc.

Source code in warpconvnet/nn/modules/sparse_dit.py
class SparseFeedForwardNet(nn.Module):
    """Per-voxel MLP: Linear, tanh-approximated GELU, Linear.

    The layers are stored in an ``nn.Sequential`` named ``mlp`` so the two
    Linear layers sit at ``self.mlp[0]`` and ``self.mlp[2]``, keeping the
    upstream attribute layout (state_dict keys remain ``mlp.mlp.0.weight`` etc.).
    """

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        # Hidden width is truncated to an int, e.g. 4 * 2.5 -> 10.
        hidden = int(channels * mlp_ratio)
        layers = [
            nn.Linear(channels, hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, channels),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Voxels) -> Voxels:
        """Run the MLP on ``x.feats`` and return ``x`` with those features replaced."""
        transformed = self.mlp(x.feats)
        return x.replace_features(transformed)

Sparse DiT attention

warpconvnet.nn.modules.sparse_dit_attention

Sparse multi-head attention on Voxels (varlen flash-attn backend).

Generic sparse-voxel attention surface used by sparse DiT-style diffusion models (e.g. TRELLIS.2 SLAT flow):

  • Self-attention on voxel tokens (full attention via varlen-packed qkv).
  • Cross-attention with sparse Q (voxels) and dense KV (image features).
  • Optional RoPE phases derived from voxel coordinates, with a per-instance cache so multi-block stacks compute phases once.
  • Optional qk-RMSNorm for attention-logit stability under bf16/fp16.

For dense (B, L, C) token attention see warpconvnet.nn.modules.dit.

SparseMultiHeadAttention

Bases: Module

Multi-head attention on a Voxels token sequence.

Self-attention modes share Q/K/V via a single to_qkv Linear; cross modes split into to_q and to_kv (with separate ctx_channels). Only attn_mode='full' is implemented in this Phase-6 cut; windowed attention is a Phase-8/9 extension if the downstream models need it.

The context for cross attention may be either a Voxels (sparse KV) or a dense (B, L, ctx_channels) Tensor (image features).

Source code in warpconvnet/nn/modules/sparse_dit_attention.py
class SparseMultiHeadAttention(nn.Module):
    """Multi-head attention over the tokens of a `Voxels`.

    In self-attention mode a single ``to_qkv`` Linear produces Q, K and V. In
    cross-attention mode Q comes from ``to_q`` and K/V from ``to_kv``, whose
    input width is ``ctx_channels``. Only ``attn_mode='full'`` is implemented
    in this Phase-6 cut; windowed attention is a Phase-8/9 extension if the
    downstream models need it.

    For cross attention, ``context`` may be a ``Voxels`` (sparse KV) or a
    dense ``(B, L, ctx_channels)`` Tensor (image features).
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: int | None = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full"] = "full",
        qkv_bias: bool = True,
        use_rope: bool = False,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ("self", "cross")
        if attn_mode != "full":
            raise NotImplementedError(
                "SparseMultiHeadAttention currently supports only attn_mode='full'"
            )
        if use_rope and type == "cross":
            raise ValueError("Rotary position embeddings only supported for self-attn")

        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        # The context width falls back to the query width when not given.
        self.ctx_channels = channels if ctx_channels is None else ctx_channels
        self._type = type
        self.attn_mode = attn_mode
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type != "self":
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
        else:
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)

        if qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)

    # -- self attention --------------------------------------------------------
    def _forward_self(self, x: Voxels) -> Voxels:
        """Self-attention: project to packed QKV, optionally normalize/rotate Q and K."""
        num_tokens = x.feats.shape[0]
        packed = self.to_qkv(x.feats).reshape(num_tokens, 3, self.num_heads, self.head_dim)
        if self.qk_rms_norm or self.use_rope:
            # Q and K must be touched individually, so unpack and re-pack.
            q, k, v = packed.unbind(dim=1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
            if self.use_rope:
                q, k = self.rope(x, q, k)
            packed = torch.stack([q, k, v], dim=1)
        attended = sparse_scaled_dot_product_attention(packed, x)  # (T, H, D)
        return x.replace_features(self.to_out(attended.reshape(num_tokens, -1)))

    # -- cross attention -------------------------------------------------------
    def _cross_sparse_context(self, x: Voxels, q: torch.Tensor, context: Voxels):
        """Cross-attend ``q`` against a ``Voxels`` context using varlen flash-attn."""
        num_ctx_tokens = context.feats.shape[0]
        kv = self.to_kv(context.feats).reshape(
            num_ctx_tokens, 2, self.num_heads, self.head_dim
        )
        cu_q, max_q = _voxels_cu_seqlens(x)
        cu_kv, max_kv = _voxels_cu_seqlens(context)
        if not self.qk_rms_norm:
            import flash_attn

            return flash_attn.flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_q,
                cu_kv,
                max_seqlen_q=max_q,
                max_seqlen_k=max_kv,
            )

        # K needs its own norm, so K and V are passed separately.
        q = self.q_rms_norm(q)
        k, v = kv.unbind(dim=1)
        k = self.k_rms_norm(k)
        import flash_attn

        return flash_attn.flash_attn_varlen_func(
            q,
            k,
            v,
            cu_q,
            cu_kv,
            max_seqlen_q=max_q,
            max_seqlen_k=max_kv,
        )

    def _cross_dense_context(self, x: Voxels, q: torch.Tensor, context: torch.Tensor):
        """Cross-attend ``q`` against a dense ``(B, L, ctx_channels)`` context."""
        assert context.ndim == 3, "dense context must be (B, L, ctx_channels)"
        batch, length, _ = context.shape
        kv = self.to_kv(context).reshape(batch, length, 2, self.num_heads, self.head_dim)
        if not self.qk_rms_norm:
            return sparse_scaled_dot_product_attention(q, x, kv)
        q = self.q_rms_norm(q)
        k, v = kv.unbind(dim=2)
        k = self.k_rms_norm(k)
        return sparse_scaled_dot_product_attention(q, x, k, v)

    def _forward_cross(
        self,
        x: Voxels,
        context: torch.Tensor | Voxels,
    ) -> Voxels:
        """Cross-attention with voxel queries and either sparse or dense keys/values."""
        num_tokens = x.feats.shape[0]
        q = self.to_q(x.feats).reshape(num_tokens, self.num_heads, self.head_dim)
        if isinstance(context, Voxels):
            attended = self._cross_sparse_context(x, q, context)
        else:
            attended = self._cross_dense_context(x, q, context)
        return x.replace_features(self.to_out(attended.reshape(num_tokens, -1)))

    def forward(
        self,
        x: Voxels,
        context: torch.Tensor | Voxels | None = None,
    ) -> Voxels:
        """Dispatch to self- or cross-attention; cross-attention requires ``context``."""
        if self._type != "self":
            assert context is not None, "cross-attention requires `context`"
            return self._forward_cross(x, context)
        return self._forward_self(x)

SparseRotaryPositionEmbedder

Bases: Module

Per-voxel RoPE phases. Phases are cached on the input Voxels so multi-block stacks pay the polar-conversion cost once.

Source code in warpconvnet/nn/modules/sparse_dit_attention.py
class SparseRotaryPositionEmbedder(nn.Module):
    """Rotary position embedding driven by voxel coordinates.

    The complex phases are stored in the input ``Voxels``' ``spatial_cache``
    so that multi-block stacks pay the polar-conversion cost once.
    """

    def __init__(
        self,
        head_dim: int,
        dim: int = 3,
        rope_freq: tuple[float, float] = (1.0, 10000.0),
    ):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even"
        self.head_dim = head_dim
        self.dim = dim
        self.rope_freq = rope_freq
        # Complex pairs available per coordinate axis.
        self.freq_dim = head_dim // 2 // dim
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        base_scale, base_period = rope_freq[0], rope_freq[1]
        self.freqs = base_scale / (base_period ** exponents)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """Return unit-magnitude complex phases for every (index, frequency) pair."""
        # ``freqs`` is a plain attribute, so it is moved to the input's device here.
        self.freqs = self.freqs.to(indices.device)
        angles = torch.outer(indices, self.freqs)
        return torch.polar(torch.ones_like(angles), angles)

    @staticmethod
    def _rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """Rotate consecutive feature pairs of ``x`` by ``phases``; returns a new tensor."""
        paired = x.float().reshape(*x.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(paired) * phases.unsqueeze(-2)
        flat = torch.view_as_real(rotated).reshape(*rotated.shape[:-1], -1)
        return flat.to(x.dtype)

    def _phases_for(self, voxels: Voxels) -> torch.Tensor:
        """Fetch the phases for ``voxels`` from its spatial cache, computing them on a miss."""
        cache_key = (
            f"rope_phase_{self.dim}d"
            f"_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}"
        )
        cache = voxels.spatial_cache
        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        coords = voxels.coords[..., 1:]  # drop batch col
        phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
        half_dim = self.head_dim // 2
        missing = half_dim - phases.shape[-1]
        if missing > 0:
            # Pad with identity rotations (magnitude 1, angle 0) up to head_dim // 2.
            lead_shape = phases.shape[:-1]
            unit = torch.ones(*lead_shape, missing, device=phases.device)
            angle = torch.zeros(*lead_shape, missing, device=phases.device)
            phases = torch.cat([phases, torch.polar(unit, angle)], dim=-1)
        cache[cache_key] = phases
        return phases

    def forward(
        self,
        voxels: Voxels,
        q_feats: torch.Tensor,
        k_feats: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Rotate per-token Q (and optionally K) features; new tensors are returned.

        ``q_feats`` / ``k_feats`` are the raw ``(T, H, D)`` tensors taken from
        the voxels' feature column post-Linear. Coords come from ``voxels``.
        Returns the rotated Q alone when ``k_feats`` is ``None``, otherwise a
        ``(q, k)`` tuple.
        """
        phases = self._phases_for(voxels)
        rotated_q = self._rotary_embedding(q_feats, phases)
        if k_feats is not None:
            return rotated_q, self._rotary_embedding(k_feats, phases)
        return rotated_q

forward(voxels: Voxels, q_feats: torch.Tensor, k_feats: torch.Tensor | None = None) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]

Rotate per-token Q/K features and return the rotated tensors (the inputs are left unmodified).

q_feats / k_feats are the raw (T, H, D) tensors taken from the voxels' feature column post-Linear. Coords come from voxels.

Source code in warpconvnet/nn/modules/sparse_dit_attention.py
def forward(
    self,
    voxels: Voxels,
    q_feats: torch.Tensor,
    k_feats: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Rotate per-token Q (and optionally K) features; new tensors are returned.

    ``q_feats`` / ``k_feats`` are the raw ``(T, H, D)`` tensors taken from
    the voxels' feature column post-Linear. Coords come from ``voxels``.
    Returns the rotated Q alone when ``k_feats`` is ``None``, otherwise a
    ``(q, k)`` tuple.
    """
    phases = self._phases_for(voxels)
    rotated_q = self._rotary_embedding(q_feats, phases)
    if k_feats is not None:
        return rotated_q, self._rotary_embedding(k_feats, phases)
    return rotated_q

sparse_scaled_dot_product_attention(*args) -> torch.Tensor

Dispatch full sparse attention.

Three call shapes supported (matching the subset of upstream trellis2.modules.sparse.attention.full_attn that the SLAT flow model uses):

  • (qkv_feats, voxels): self-attn. qkv_feats is (T, 3, H, D), sequences delimited by voxels.offsets. Returns (T, H, D).
  • (q_feats, voxels, kv_dense): cross-attn with dense KV. q_feats is (T_q, H, D); kv_dense is (B, L, 2, H, D). Returns (T_q, H, D).
  • (q_feats, voxels, k_dense, v_dense): same as above but K and V passed separately as (B, L, H, D).
Source code in warpconvnet/nn/modules/sparse_dit_attention.py
def sparse_scaled_dot_product_attention(
    *args,
) -> torch.Tensor:
    """Dispatch full sparse attention.

    Three call shapes supported (matching the subset of upstream
    ``trellis2.modules.sparse.attention.full_attn`` that the SLAT flow model
    uses):

    - ``(qkv_feats, voxels)``: self-attn. ``qkv_feats`` is ``(T, 3, H, D)``,
      sequences delimited by ``voxels.offsets``. Returns ``(T, H, D)``.
    - ``(q_feats, voxels, kv_dense)``: cross-attn with dense KV. ``q_feats``
      is ``(T_q, H, D)``; ``kv_dense`` is ``(B, L, 2, H, D)``. Returns
      ``(T_q, H, D)``.
    - ``(q_feats, voxels, k_dense, v_dense)``: same as above but K and V
      passed separately as ``(B, L, H, D)``.

    Raises:
        ValueError: if called with a number of positional arguments other
            than 2, 3 or 4. This is checked before ``flash_attn`` is imported,
            so a malformed call is reported even where flash-attn is missing.
    """
    num_args = len(args)
    if num_args not in (2, 3, 4):
        raise ValueError(f"sparse_scaled_dot_product_attention: bad arity {num_args}")

    import flash_attn  # noqa: WPS433 — lazy: CUDA-only

    if num_args == 2:
        qkv, voxels = args
        cu_q, max_q = _voxels_cu_seqlens(voxels)
        # Go through the lazily imported module like the other two branches;
        # the packed-QKV varlen entry point takes (qkv, cu_seqlens, max_seqlen).
        return flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_q, max_q)

    if num_args == 3:
        q, voxels, kv_dense = args
        assert kv_dense.ndim == 5, "kv_dense must be (B, L, 2, H, D)"
        B, L, _, H, D = kv_dense.shape
        cu_q, max_q = _voxels_cu_seqlens(voxels)
        cu_kv = _dense_cu_seqlens(B, L, q.device)
        return flash_attn.flash_attn_varlen_kvpacked_func(
            q,
            # Flatten the batch so each dense sample is one sequence of length L.
            kv_dense.reshape(B * L, 2, H, D),
            cu_q,
            cu_kv,
            max_seqlen_q=max_q,
            max_seqlen_k=L,
        )

    # num_args == 4: K and V supplied as separate dense tensors.
    q, voxels, k_dense, v_dense = args
    assert k_dense.ndim == 4 and v_dense.ndim == 4
    B, L = k_dense.shape[:2]
    cu_q, max_q = _voxels_cu_seqlens(voxels)
    cu_kv = _dense_cu_seqlens(B, L, q.device)
    return flash_attn.flash_attn_varlen_func(
        q,
        k_dense.reshape(B * L, *k_dense.shape[2:]),
        v_dense.reshape(B * L, *v_dense.shape[2:]),
        cu_q,
        cu_kv,
        max_seqlen_q=max_q,
        max_seqlen_k=L,
    )

Sparse depthwise convolution

warpconvnet.nn.modules.sparse_conv_depth

SparseDepthwiseConv2d

Bases: SpatiallySparseDepthwiseConv

2D spatially sparse depthwise convolution.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SparseDepthwiseConv2d(SpatiallySparseDepthwiseConv):
    """Spatially sparse depthwise convolution fixed to two spatial dimensions.

    Every constructor argument is passed through unchanged to
    ``SpatiallySparseDepthwiseConv``; only ``num_spatial_dims`` is pinned to 2.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        # Collect the pass-through settings once, then add the fixed dimensionality.
        passthrough = dict(
            channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            stride_mode=stride_mode,
            stride_reduce=stride_reduce,
            order=order,
            compute_dtype=compute_dtype,
        )
        super().__init__(num_spatial_dims=2, **passthrough)

SparseDepthwiseConv3d

Bases: SpatiallySparseDepthwiseConv

3D spatially sparse depthwise convolution.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SparseDepthwiseConv3d(SpatiallySparseDepthwiseConv):
    """Spatially sparse depthwise convolution fixed to three spatial dimensions.

    Every constructor argument is passed through unchanged to
    ``SpatiallySparseDepthwiseConv``; only ``num_spatial_dims`` is pinned to 3.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        # Collect the pass-through settings once, then add the fixed dimensionality.
        passthrough = dict(
            channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            stride_mode=stride_mode,
            stride_reduce=stride_reduce,
            order=order,
            compute_dtype=compute_dtype,
        )
        super().__init__(num_spatial_dims=3, **passthrough)

SpatiallySparseDepthwiseConv

Bases: BaseSpatialModule

Spatially sparse depthwise convolution module.

In depthwise convolution, each input channel is convolved with its own kernel, so the number of input channels must equal the number of output channels. The weight shape is (K, C) where K is the kernel volume and C is the number of channels.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SpatiallySparseDepthwiseConv(BaseSpatialModule):
    """
    Spatially sparse depthwise convolution module.

    In depthwise convolution, each input channel is convolved with its own kernel,
    so the number of input channels must equal the number of output channels.
    The weight shape is (K, C) where K is the kernel volume and C is the number of channels.

    Args:
        channels: Number of channels, shared by input and output.
        kernel_size: Kernel extent; normalized to a tuple of ``num_spatial_dims`` entries.
        stride: Convolution stride; normalized like ``kernel_size``.
        dilation: Kernel dilation; normalized like ``kernel_size``.
        bias: When True, a learnable per-channel bias is added to the output.
        transposed: Forwarded to kernel-map generation. Also selects ``fan_out``
            initialization and changes how the output tensor stride is derived.
        generative: Forwarded to kernel-map generation.
        num_spatial_dims: Number of spatial dimensions.
        fwd_algo: Forward algorithm as an enum member or a string. ``None``
            falls back to ``WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE``.
        bwd_algo: Backward algorithm, same conventions as ``fwd_algo`` with
            ``WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE`` as the fallback.
        stride_mode: Forwarded to kernel-map generation.
        stride_reduce: Stored on the module; not read by any method of this class.
        order: Forwarded to kernel-map generation.
        compute_dtype: Forwarded to the depthwise convolution kernel.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        num_spatial_dims: int = 3,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.num_spatial_dims = num_spatial_dims
        self.channels = channels
        self.in_channels = channels  # For compatibility with PyTorch naming
        self.out_channels = channels  # For depthwise, in_channels == out_channels

        # Ensure kernel_size, stride, dilation are tuples for consistent use
        self.kernel_size = ntuple(kernel_size, ndim=self.num_spatial_dims)
        self.stride = ntuple(stride, ndim=self.num_spatial_dims)
        self.dilation = ntuple(dilation, ndim=self.num_spatial_dims)

        self.transposed = transposed
        self.generative = generative
        self.stride_reduce = stride_reduce

        # Use environment variable values if not explicitly provided
        if fwd_algo is None:
            fwd_algo = WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE
        if bwd_algo is None:
            bwd_algo = WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE

        # Map string algo names to depthwise-specific enums if needed
        self.fwd_algo = self._resolve_algo(fwd_algo, SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE)
        self.bwd_algo = self._resolve_algo(bwd_algo, SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE)

        self.stride_mode = stride_mode
        self.order = order
        self.compute_dtype = compute_dtype

        # Depthwise convolution weight shape: (K, C) where K is kernel volume
        kernel_volume = int(np.prod(self.kernel_size))
        self.weight = nn.Parameter(torch.randn(kernel_volume, channels))

        # Optional bias
        if bias:
            self.bias = nn.Parameter(torch.randn(channels))
        else:
            self.bias = None

        self.reset_parameters()

    @staticmethod
    def _resolve_algo(algo, enum_cls):
        """Map a generic algorithm name onto ``enum_cls``; non-strings pass through."""
        if not isinstance(algo, str):
            return algo
        name = algo.lower()
        if name in ("explicit", "explicit_gemm"):
            return enum_cls.EXPLICIT
        if name in ("implicit", "implicit_gemm"):
            return enum_cls.IMPLICIT
        if name == "auto":
            return enum_cls.AUTO
        # Anything else is handed to the enum constructor with its original casing.
        return enum_cls(algo)

    def __repr__(self):
        """Return a compact description listing only non-default settings."""
        out_str = (
            f"{self.__class__.__name__}(channels={self.channels}, "
            f"kernel_size={self.kernel_size}"
        )
        if self.stride != (1,) * self.num_spatial_dims:
            out_str += f", stride={self.stride}"
        if self.dilation != (1,) * self.num_spatial_dims:
            out_str += f", dilation={self.dilation}"
        if self.transposed:
            out_str += f", transposed={self.transposed}"
        if self.generative:
            out_str += f", generative={self.generative}"
        if self.order != POINT_ORDERING.RANDOM:
            out_str += f", order={self.order}"
        if self.bias is None:
            out_str += ", bias=False"
        out_str += ")"
        return out_str

    def _calculate_fan_in_and_fan_out(self):
        """Calculate fan_in and fan_out for depthwise convolution."""
        receptive_field_size = np.prod(self.kernel_size)
        # For depthwise convolution, each channel has its own kernel
        fan_in = receptive_field_size  # One kernel per channel
        fan_out = receptive_field_size  # One output per channel
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode: str):
        """Return the fan selected by ``mode`` (``"fan_in"`` or ``"fan_out"``)."""
        mode = mode.lower()
        assert mode in ["fan_in", "fan_out"]

        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == "fan_in" else fan_out

    def _custom_kaiming_uniform_(
        self, tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"
    ):
        """Fill ``tensor`` in place with Kaiming-uniform values using this module's fan.

        Returns the same tensor after the in-place fill.
        """
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # U(-b, b) has variance b^2 / 3, so b = sqrt(3) * std. The constant is a
        # property of the uniform distribution and does not depend on the
        # number of spatial dimensions.
        bound = math.sqrt(3.0) * std
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    @torch.no_grad()
    def reset_parameters(self):
        """Reset module parameters using appropriate initialization."""
        self._custom_kaiming_uniform_(
            self.weight,
            a=math.sqrt(5),
            mode="fan_out" if self.transposed else "fan_in",
        )

        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        input_sparse_tensor: Voxels,
        output_spatially_sparse_tensor: Optional[Voxels] = None,
    ) -> Voxels:
        """
        Forward pass for spatially sparse depthwise convolution.

        Args:
            input_sparse_tensor: Input sparse tensor
            output_spatially_sparse_tensor: Optional output sparse tensor for transposed conv

        Returns:
            Output sparse tensor
        """
        # Generate output coordinates and kernel map
        batch_indexed_out_coords, out_offsets, kernel_map = (
            generate_output_coords_and_kernel_map(
                input_sparse_tensor=input_sparse_tensor,
                kernel_size=self.kernel_size,
                kernel_dilation=self.dilation,
                stride=self.stride,
                generative=self.generative,
                transposed=self.transposed,
                output_spatially_sparse_tensor=output_spatially_sparse_tensor,
                stride_mode=self.stride_mode,
                order=self.order,
            )
        )

        num_out_coords = batch_indexed_out_coords.shape[0]

        # Apply depthwise convolution
        output_features = spatially_sparse_depthwise_conv(
            input_sparse_tensor.feature_tensor,
            self.weight,
            kernel_map,
            num_out_coords,
            fwd_algo=self.fwd_algo,
            bwd_algo=self.bwd_algo,
            compute_dtype=self.compute_dtype,
        )

        # Add bias if present
        if self.bias is not None:
            output_features = output_features + self.bias

        # Determine output tensor stride
        in_tensor_stride = input_sparse_tensor.tensor_stride
        if in_tensor_stride is None:
            in_tensor_stride = (1,) * self.num_spatial_dims

        if not self.transposed:
            # Regular conv: strides compound with the input's tensor stride.
            out_tensor_stride = tuple(
                o * s for o, s in zip(self.stride, in_tensor_stride)
            )
        else:
            # Transposed conv: adopt the target tensor's stride when known,
            # otherwise fall back to unit stride.
            if (
                output_spatially_sparse_tensor is not None
                and output_spatially_sparse_tensor.tensor_stride is not None
            ):
                out_tensor_stride = output_spatially_sparse_tensor.tensor_stride
            else:
                out_tensor_stride = (1,) * self.num_spatial_dims

        # Create output voxels
        out_offsets_cpu = out_offsets.cpu().int()
        out_coords = IntCoords(
            # Column 0 is the batch index; keep only the spatial columns.
            batch_indexed_out_coords[:, 1:],
            offsets=out_offsets_cpu,
        )
        return input_sparse_tensor.replace(
            batched_coordinates=out_coords,
            batched_features=output_features,
            tensor_stride=out_tensor_stride,
        )

forward(input_sparse_tensor: Voxels, output_spatially_sparse_tensor: Optional[Voxels] = None) -> Voxels

Forward pass for spatially sparse depthwise convolution.

Args: input_sparse_tensor: Input sparse tensor output_spatially_sparse_tensor: Optional output sparse tensor for transposed conv

Returns: Output sparse tensor

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
def forward(
    self,
    input_sparse_tensor: Voxels,
    output_spatially_sparse_tensor: Optional[Voxels] = None,
) -> Voxels:
    """
    Run the depthwise sparse convolution and wrap the result as new voxels.

    Args:
        input_sparse_tensor: Input sparse tensor
        output_spatially_sparse_tensor: Optional output sparse tensor for transposed conv

    Returns:
        Output sparse tensor
    """
    # Step 1: output coordinates plus the input/output kernel map.
    coords_with_batch, offsets, kmap = generate_output_coords_and_kernel_map(
        input_sparse_tensor=input_sparse_tensor,
        kernel_size=self.kernel_size,
        kernel_dilation=self.dilation,
        stride=self.stride,
        generative=self.generative,
        transposed=self.transposed,
        output_spatially_sparse_tensor=output_spatially_sparse_tensor,
        stride_mode=self.stride_mode,
        order=self.order,
    )

    # Step 2: the convolution itself, then the optional per-channel bias.
    feats = spatially_sparse_depthwise_conv(
        input_sparse_tensor.feature_tensor,
        self.weight,
        kmap,
        coords_with_batch.shape[0],
        fwd_algo=self.fwd_algo,
        bwd_algo=self.bwd_algo,
        compute_dtype=self.compute_dtype,
    )
    if self.bias is not None:
        feats = feats + self.bias

    # Step 3: tensor stride of the output.
    unit_stride = (1,) * self.num_spatial_dims
    base_stride = input_sparse_tensor.tensor_stride
    if base_stride is None:
        base_stride = unit_stride

    if self.transposed:
        # Adopt the target's stride when one is available, else unit stride.
        target = output_spatially_sparse_tensor
        has_target_stride = target is not None and target.tensor_stride is not None
        new_stride = target.tensor_stride if has_target_stride else unit_stride
    else:
        # Strides compound with whatever stride the input already carries.
        new_stride = tuple(o * s for o, s in zip(self.stride, base_stride))

    # Step 4: assemble the output voxels (batch column dropped from the coords).
    new_coords = IntCoords(
        coords_with_batch[:, 1:],
        offsets=offsets.cpu().int(),
    )
    return input_sparse_tensor.replace(
        batched_coordinates=new_coords,
        batched_features=feats,
        tensor_stride=new_stride,
    )

reset_parameters()

Reset module parameters using appropriate initialization.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
@torch.no_grad()
def reset_parameters(self):
    """Re-initialize the weight (and the bias, when present) in place.

    The weight is handed to the module's ``_custom_kaiming_uniform_`` helper
    with ``a = sqrt(5)``; transposed layers use ``"fan_out"`` mode, every other
    layer ``"fan_in"``. The bias is drawn uniformly from
    ``[-1 / sqrt(fan_in), 1 / sqrt(fan_in)]``; a non-positive ``fan_in``
    collapses that range to zero.
    """
    if self.transposed:
        fan_mode = "fan_out"
    else:
        fan_mode = "fan_in"
    self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5), mode=fan_mode)

    # Nothing else to do for bias-free layers.
    if self.bias is None:
        return

    fan_in, _fan_out = self._calculate_fan_in_and_fan_out()
    limit = 0
    if fan_in > 0:
        limit = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -limit, limit)

Sparse pooling

warpconvnet.nn.modules.sparse_pool

GlobalPool

Bases: BaseSpatialModule

Pool features across the entire geometry.

Parameters:
  • reduce ((min, max, mean, sum), default: "max" ) –

    Reduction to apply over all features. Defaults to "max".

Source code in warpconvnet/nn/modules/sparse_pool.py
class GlobalPool(BaseSpatialModule):
    """Pool the features of a whole geometry with a single reduction.

    Parameters
    ----------
    reduce : {"max", "min", "mean", "sum"}, optional
        Name of the reduction handed to ``global_pool``. Defaults to ``"max"``.
    """

    def __init__(self, reduce: Literal["min", "max", "mean", "sum"] = "max"):
        super().__init__()
        # Kept verbatim; ``global_pool`` interprets it on every forward call.
        self.reduce = reduce

    def forward(self, x: Geometry):
        pooled = global_pool(x, self.reduce)
        return pooled

PointToSparseWrapper

Bases: PointToVoxel

Deprecated alias for PointToVoxel. Will be removed in a future release.

Source code in warpconvnet/nn/modules/sparse_pool.py
class PointToSparseWrapper(PointToVoxel):
    """Backward-compatible name for ``PointToVoxel``; scheduled for removal.

    Constructing an instance emits a ``DeprecationWarning`` and then forwards
    every argument unchanged to the parent constructor.
    """

    def __init__(self, *args, **kwargs):
        from warnings import warn

        # stacklevel=2 points the warning at the code constructing the wrapper.
        warn(
            "PointToSparseWrapper is deprecated; use PointToVoxel instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)

PointToVoxel

Bases: BaseSpatialModule

Pool points into a sparse tensor, apply an inner module and unpool back to points.

Parameters:
  • inner_module (`BaseSpatialModule`) –

    Module applied on the pooled sparse tensor.

  • voxel_size (float) –

    Voxel size used to pool the input points.

  • reduction (`REDUCTIONS` or str, default: MEAN ) –

    Reduction used when pooling points. Defaults to REDUCTIONS.MEAN.

  • unique_method ((morton, ravel, torch), default: "morton" ) –

    Method used for hashing voxel indices. Defaults to "morton".

  • concat_unpooled_pc (bool, default: True ) –

    If True concatenate the unpooled result with the original input. Defaults to True.

Source code in warpconvnet/nn/modules/sparse_pool.py
class PointToVoxel(BaseSpatialModule):
    """Run a voxel module on points: pool to voxels, apply the module, unpool.

    Parameters
    ----------
    inner_module : `BaseSpatialModule`
        Module applied to the pooled sparse tensor; it must return ``Voxels``.
    voxel_size : float
        Voxel size used both for pooling the points and for converting the
        module output back to points.
    reduction : `REDUCTIONS` or str, optional
        Reduction used when pooling points. Defaults to ``REDUCTIONS.MEAN``.
    unique_method : {"morton", "ravel", "torch"}, optional
        Method used for hashing voxel indices. Defaults to ``"morton"``.
    concat_unpooled_pc : bool, optional
        Forwarded to ``point_unpool``; if ``True`` the unpooled result is
        concatenated with the original input. Defaults to ``True``.
    """

    def __init__(
        self,
        inner_module: BaseSpatialModule,
        voxel_size: float,
        reduction: Union[REDUCTIONS, REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
        unique_method: Literal["morton", "ravel", "torch"] = "morton",
        concat_unpooled_pc: bool = True,
    ):
        super().__init__()
        self.inner_module = inner_module
        # Pooling configuration, read on every forward pass.
        self.voxel_size = voxel_size
        self.reduction = reduction
        self.unique_method = unique_method
        # Unpooling configuration.
        self.concat_unpooled_pc = concat_unpooled_pc

    def forward(self, pc: Points) -> Points:
        pool_options = dict(
            reduction=self.reduction,
            downsample_voxel_size=self.voxel_size,
            return_type="voxel",
            return_to_unique=True,
            unique_method=self.unique_method,
        )
        # ``to_unique`` is passed back to ``point_unpool`` below so the voxel
        # results can be mapped onto the original points.
        voxels, to_unique = point_pool(pc, **pool_options)

        processed = self.inner_module(voxels)
        assert isinstance(processed, Voxels), "Output of inner module must be a Voxels"

        return point_unpool(
            processed.to_point(self.voxel_size),
            pc,
            concat_unpooled_pc=self.concat_unpooled_pc,
            to_unique=to_unique,
        )

SparseMaxPool

Bases: SparsePool

Max pooling for sparse tensors.

Parameters:
  • kernel_size (int) –

    Size of the pooling kernel.

  • stride (int) –

    Stride between pooling windows.

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseMaxPool(SparsePool):
    """``SparsePool`` preset that always reduces with ``"max"``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__(kernel_size=kernel_size, stride=stride, reduce="max")

SparseMinPool

Bases: SparsePool

Min pooling for sparse tensors.

Parameters:
  • kernel_size (int) –

    Size of the pooling kernel.

  • stride (int) –

    Stride between pooling windows.

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseMinPool(SparsePool):
    """``SparsePool`` preset that always reduces with ``"min"``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__(kernel_size=kernel_size, stride=stride, reduce="min")

SparsePool

Bases: BaseSpatialModule

Reduce features of a Voxels object using a strided kernel.

Parameters:
  • kernel_size (int) –

    Size of the pooling kernel.

  • stride (int) –

    Stride between pooling windows.

  • reduce ((max, min, mean, sum, random), default: "max" ) –

    Reduction to apply within each window. Defaults to "max".

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparsePool(BaseSpatialModule):
    """Strided pooling of a ``Voxels`` object via ``sparse_reduce``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    reduce : {"max", "min", "mean", "sum", "random"}, optional
        Reduction to apply within each window. Defaults to ``"max"``.
    """

    def __init__(
        self,
        kernel_size: int,
        stride: int,
        reduce: Literal["max", "min", "mean", "sum", "random"] = "max",
    ):
        super().__init__()
        # All three settings are forwarded untouched to ``sparse_reduce``.
        self.kernel_size, self.stride, self.reduce = kernel_size, stride, reduce

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}(kernel_size={self.kernel_size}, stride={self.stride}, reduce={self.reduce})"

    def forward(self, st: Voxels):
        pooled = sparse_reduce(st, self.kernel_size, self.stride, self.reduce)
        return pooled

SparseUnpool

Bases: BaseSpatialModule

Unpool a sparse tensor back to a higher resolution.

Parameters:
  • kernel_size (int) –

    Size of the unpooling kernel.

  • stride (int) –

    Stride between unpooling windows.

  • concat_unpooled_st (bool, default: True ) –

    If True concatenate the unpooled tensor with the input. Defaults to True.

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseUnpool(BaseSpatialModule):
    """Module wrapper around ``sparse_unpool``.

    Parameters
    ----------
    kernel_size : int
        Size of the unpooling kernel.
    stride : int
        Stride between unpooling windows.
    concat_unpooled_st : bool, optional
        Forwarded to ``sparse_unpool``; if ``True`` the unpooled tensor is
        concatenated with the input. Defaults to ``True``.
    """

    def __init__(self, kernel_size: int, stride: int, concat_unpooled_st: bool = True):
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.concat_unpooled_st = concat_unpooled_st

    def forward(self, st: Voxels, unpooled_st: Voxels):
        # Arguments are passed positionally, in the order ``sparse_unpool`` expects.
        unpool_args = (self.kernel_size, self.stride, self.concat_unpooled_st)
        return sparse_unpool(st, unpooled_st, *unpool_args)

Sparse resampling

warpconvnet.nn.modules.sparse_resample

Sparse spatial resampling: Down/Up/Spatial2Channel/Channel2Spatial/Subdivide.

Pure-tensor index arithmetic, operates directly on warpconvnet.geometry.types.voxels.Voxels. Each block writes a coord/index cache onto voxels.spatial_cache so the paired inverse op can avoid recomputing the gather table.

Scale tracking is intentionally not baked into these classes — Voxels already carries a tensor_stride, and TRELLIS-style fractional _scale tracking lives in the model adapter that uses these (see warpconvnet.models.trellis2.sparse_spatial).

SparseChannel2Spatial

Bases: Module

Inverse of SparseSpatial2Channel. Reads each child slot from the channel block and either uses an explicit subdivision mask or the cache written by a paired SparseSpatial2Channel.

Source code in warpconvnet/nn/modules/sparse_resample.py
class SparseChannel2Spatial(nn.Module):
    """Unpack channel-packed child voxels back into spatial positions.

    Inverse of `SparseSpatial2Channel`. The gather table is taken from the
    ``channel2spatial_{factor}`` entry of ``x.spatial_cache`` when one exists;
    otherwise it is built from an explicit `subdivision` mask whose non-zero
    entries select the child slots to materialize for each input voxel.
    """

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(
        self,
        x: Voxels,
        subdivision: Voxels | None = None,
    ) -> Voxels:
        f = self.factor
        DIM = x.coords.shape[-1] - 1
        cached = x.spatial_cache.get(f"channel2spatial_{f}")

        if cached is not None:
            new_coords, idx, subidx = cached
        elif subdivision is None:
            raise ValueError(
                "SparseChannel2Spatial needs either a cached spatial2channel "
                "or an explicit subdivision tensor"
            )
        else:
            mask = subdivision.feats
            children_per_parent = mask.sum(dim=-1)
            # Column index of every non-zero mask entry == child slot id.
            subidx = mask.nonzero()[:, -1]
            total_children = subidx.shape[0]

            # Scale the parent coordinates (all columns but the first) to the
            # fine grid, then emit one row per selected child.
            parent_coords = x.coords.clone().detach()
            parent_coords[:, 1:] *= f
            new_coords = torch.repeat_interleave(
                parent_coords, children_per_parent, dim=0, output_size=total_children
            )
            # Decode the slot id as a base-``f`` number, one digit per axis.
            for axis in range(DIM):
                new_coords[:, axis + 1] += (subidx // f**axis) % f

            parent_ids = torch.arange(x.coords.shape[0], device=x.device)
            idx = torch.repeat_interleave(
                parent_ids, children_per_parent, dim=0, output_size=total_children
            )

        # View the features as one row per (parent, slot) and gather the rows
        # that correspond to the selected children.
        slots = f**DIM
        per_slot_feats = x.feats.reshape(x.feats.shape[0] * slots, -1)
        gathered = per_slot_feats[idx * slots + subidx]
        return from_feats_coords(gathered, new_coords.int())

SparseDownsample

Bases: Module

Stride-factor average / max pool over coordinates.

Source code in warpconvnet/nn/modules/sparse_resample.py
class SparseDownsample(nn.Module):
    """Stride-`factor` average / max pool over coordinates.

    Voxels whose spatial coordinates fall into the same ``factor``-sized cell
    are merged into one output voxel whose features are the mean or the
    maximum of the merged inputs. The coordinate table and the input-to-output
    index map are stored in ``x.spatial_cache`` under ``downsample_{factor}``,
    and the inverse map under ``upsample_{factor}``, so repeated calls and a
    paired `SparseUpsample` can reuse them.

    Parameters
    ----------
    factor : int
        Integer divisor applied to every spatial coordinate.
    mode : {"mean", "max"}, optional
        Reduction applied to the features that land in the same output voxel.
        Defaults to ``"mean"``.
    """

    def __init__(self, factor: int, mode: Literal["mean", "max"] = "mean"):
        super().__init__()
        assert mode in ("mean", "max")
        self.factor = factor
        self.mode = mode

    def forward(self, x: Voxels) -> Voxels:
        f = self.factor
        cache = x.spatial_cache
        ck = f"downsample_{f}"
        entry = cache.get(ck)
        if entry is None:
            DIM = x.coords.shape[-1] - 1
            # Split coords into columns; column 0 is left untouched, the
            # remaining DIM columns are floor-divided onto the coarse grid.
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i + 1] = coord[i + 1] // f
            # Coarse extent per axis. ``_spatial_shape_of`` is defined elsewhere;
            # presumably it returns the fine spatial extent - TODO confirm.
            MAX = [(s + f - 1) // f for s in _spatial_shape_of(x, cache)]
            # Mixed-radix strides: OFFSET[0] weights column 0, OFFSET[-1] == 1.
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            # One scalar code per voxel; equal codes mean the same output voxel.
            code = sum(c * o for c, o in zip(coord, OFFSET))
            code, idx = code.unique(return_inverse=True)
            # Decode the unique codes back into coordinate columns.
            new_coords = torch.stack(
                [code // OFFSET[0]] + [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)],
                dim=-1,
            )
        else:
            new_coords, idx = entry
        # ``torch.scatter_reduce`` spells its maximum reduction "amax" and
        # rejects "max", so translate the public mode name before the call.
        reduce = "amax" if self.mode == "max" else self.mode
        new_feats = torch.scatter_reduce(
            torch.zeros(
                new_coords.shape[0],
                x.feats.shape[1],
                device=x.feats.device,
                dtype=x.feats.dtype,
            ),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
            src=x.feats,
            reduce=reduce,
            # The zero initial values must not take part in the reduction.
            include_self=False,
        )
        out = from_feats_coords(new_feats, new_coords.int())
        # Share cache dict so paired Up can reuse the inverse map.
        out._extra_attributes["_spatial_cache"] = cache
        if entry is None:
            cache[ck] = (new_coords, idx)
            cache[f"upsample_{f}"] = (x.coords, idx)
        return out

SparseSpatial2Channel

Bases: Module

Pack factor**DIM neighbouring voxels into the channel dim. Output has factor-coarser spatial coords and factor**DIM-times more channels. Missing children are zero-padded.

Source code in warpconvnet/nn/modules/sparse_resample.py
class SparseSpatial2Channel(nn.Module):
    """Pack `factor**DIM` neighbouring voxels into the channel dim. Output has
    `factor`-coarser spatial coords and `factor**DIM`-times more channels.
    Missing children are zero-padded.

    The coarse coordinates, the input-to-output index map and the per-input
    child slot are stored in ``x.spatial_cache`` under
    ``spatial2channel_{factor}``; the matching inverse entry is stored under
    ``channel2spatial_{factor}``.
    """

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def forward(self, x: Voxels) -> Voxels:
        f = self.factor
        # Every coordinate column except the first is treated as spatial.
        DIM = x.coords.shape[-1] - 1
        cache = x.spatial_cache
        ck = f"spatial2channel_{f}"
        entry = cache.get(ck)
        if entry is None:
            # Floor-divide the spatial columns onto the coarse grid; column 0
            # is left untouched.
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i + 1] = coord[i + 1] // f
            # Child slot inside the coarse cell: the per-axis remainders read
            # as a base-``f`` number with axis 0 as the lowest digit.
            subidx = x.coords[:, 1:] % f
            subidx = sum(subidx[..., i] * f**i for i in range(DIM))
            # Coarse extent per axis. ``_spatial_shape_of`` is defined elsewhere;
            # presumably it returns the fine spatial extent - TODO confirm.
            MAX = [(s + f - 1) // f for s in _spatial_shape_of(x, cache)]
            # Mixed-radix strides: OFFSET[0] weights column 0, OFFSET[-1] == 1.
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            # One scalar code per voxel; equal codes share an output voxel.
            code = sum(c * o for c, o in zip(coord, OFFSET))
            code, idx = code.unique(return_inverse=True)
            # Decode the unique codes back into coordinate columns.
            new_coords = torch.stack(
                [code // OFFSET[0]] + [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)],
                dim=-1,
            )
        else:
            new_coords, idx, subidx = entry
        n_per = f**DIM
        # One zero row per (output voxel, child slot); slots that receive no
        # input stay zero.
        new_feats = torch.zeros(
            new_coords.shape[0] * n_per,
            x.feats.shape[1],
            device=x.feats.device,
            dtype=x.feats.dtype,
        )
        new_feats[idx * n_per + subidx] = x.feats
        # Fold the slot rows of each output voxel into its channel dimension.
        out = from_feats_coords(new_feats.reshape(new_coords.shape[0], -1), new_coords.int())
        # The output shares the input's cache dict, so entries written below
        # are visible through both objects.
        out._extra_attributes["_spatial_cache"] = cache
        if entry is None:
            cache[ck] = (new_coords, idx, subidx)
            cache[f"channel2spatial_{f}"] = (x.coords, idx, subidx)
        return out

SparseSubdivide

Bases: Module

Repeat each voxel factor**DIM times along the spatial dims (no pooling, no scatter).

Source code in warpconvnet/nn/modules/sparse_resample.py
class SparseSubdivide(nn.Module):
    """Split every voxel into its `factor**DIM` children.

    Each input row is repeated `factor**DIM` times; the copies receive the
    scaled parent coordinate plus one offset from the dense
    ``[0, factor)**DIM`` grid and an unchanged copy of the parent features
    (no pooling, no scatter).
    """

    def __init__(self, factor: int):
        super().__init__()
        self.factor = factor

    def forward(self, x: Voxels) -> Voxels:
        f = self.factor
        DIM = x.coords.shape[-1] - 1
        n_per = f**DIM
        num_voxels = x.coords.shape[0]

        # All child offsets inside one parent cell, one row per child.
        axes = [torch.arange(f, device=x.device) for _ in range(DIM)]
        mesh = torch.meshgrid(*axes, indexing="ij")
        child_offsets = torch.stack(mesh, dim=-1).reshape(n_per, DIM)

        # Rows are grouped per parent, so tiling the offset table lines each
        # copy up with a distinct child offset.
        new_coords = x.coords.repeat_interleave(n_per, dim=0)
        scaled = new_coords[:, 1:] * f
        new_coords[:, 1:] = scaled + child_offsets.repeat(num_voxels, 1)

        new_feats = x.feats.repeat_interleave(n_per, dim=0)
        return from_feats_coords(new_feats, new_coords.int())

SparseUpsample

Bases: Module

Inverse of SparseDownsample. Requires a paired downsample cache or an explicit subdivision mask of shape (N, factor**DIM).

Source code in warpconvnet/nn/modules/sparse_resample.py
class SparseUpsample(nn.Module):
    """Copy each voxel's features onto its fine-resolution children.

    Inverse of `SparseDownsample`. The fine coordinates and the child-to-parent
    index map come from the ``upsample_{factor}`` entry of ``x.spatial_cache``
    when one exists; otherwise they are built from an explicit `subdivision`
    mask of shape ``(N, factor**DIM)``.
    """

    def __init__(self, factor: int):
        super().__init__()
        self.factor = factor

    def forward(self, x: Voxels, subdivision: Voxels | None = None) -> Voxels:
        f = self.factor
        DIM = x.coords.shape[-1] - 1
        cached = x.spatial_cache.get(f"upsample_{f}")

        if cached is not None:
            new_coords, idx = cached
        elif subdivision is None:
            raise ValueError(
                "SparseUpsample needs either a cached downsample or a subdivision tensor"
            )
        else:
            mask = subdivision.feats
            children_per_parent = mask.sum(dim=-1)
            # Column index of every non-zero mask entry == child slot id.
            slot_ids = mask.nonzero()[:, -1]
            total_children = slot_ids.shape[0]

            # Scale the parent coordinates (all columns but the first) to the
            # fine grid, then emit one row per selected child.
            parent_coords = x.coords.clone().detach()
            parent_coords[:, 1:] *= f
            new_coords = torch.repeat_interleave(
                parent_coords, children_per_parent, dim=0, output_size=total_children
            )
            # Decode the slot id as a base-``f`` number, one digit per axis.
            for axis in range(DIM):
                new_coords[:, axis + 1] += (slot_ids // f**axis) % f

            parent_ids = torch.arange(x.coords.shape[0], device=x.device)
            idx = torch.repeat_interleave(
                parent_ids, children_per_parent, dim=0, output_size=total_children
            )

        # Every child simply inherits its parent's feature row.
        child_feats = x.feats[idx]
        return from_feats_coords(child_feats, new_coords.int())

Sparse U-Net blocks

warpconvnet.nn.modules.sparse_unet

Reusable sparse U-Net blocks and stage assembly on Voxels.

SparseChannelToSpatialResBlock3d

Bases: Module

Residual block that upsamples sparse voxels via channel-to-spatial unpacking.

The block projects channels to out_channels * factor**3 channels, unpacks neighbouring child voxels with SparseChannel2Spatial, then applies a zero-initialized sparse conv and residual skip. An optional subdivision head can predict which child slots to materialize.

Source code in warpconvnet/nn/modules/sparse_unet.py
class SparseChannelToSpatialResBlock3d(nn.Module):
    """Residual block that upsamples sparse voxels via channel-to-spatial unpacking.

    The block projects ``channels`` to ``out_channels * factor**3`` channels,
    unpacks neighbouring child voxels with ``SparseChannel2Spatial``, then
    applies a zero-initialized sparse conv and residual skip. An optional
    subdivision head can predict which child slots to materialize.

    With ``pred_subdiv=True`` the forward pass returns ``(h, subdiv)`` and any
    ``subdiv`` argument is replaced by the block's own prediction; otherwise it
    returns ``h`` alone.

    Raises
    ------
    ValueError
        If ``channels`` is not divisible by ``factor**3``, or ``out_channels``
        is not divisible by ``channels // factor**3``.
    """

    def __init__(
        self,
        channels: int,
        out_channels: int | None = None,
        factor: int = 2,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
        conv_cls: type[nn.Module] = SparseConv3d,
        norm_cls: type[nn.Module] = LayerNorm32,
        kernel_size: int | tuple[int, int, int] = 3,
    ):
        super().__init__()
        self.channels = channels
        # A falsy ``out_channels`` (None or 0) falls back to ``channels``.
        self.out_channels = out_channels or channels
        self.factor = factor
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv
        self.num_children = factor**3

        # The skip path unpacks the input itself, so its channels must split
        # evenly across the child slots ...
        if channels % self.num_children != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by factor**3 ({self.num_children})"
            )
        # ... and the unpacked width must tile evenly up to ``out_channels``.
        if self.out_channels % (channels // self.num_children) != 0:
            raise ValueError(
                "out_channels must be divisible by channels // factor**3 "
                f"({channels // self.num_children})"
            )

        self.norm1 = norm_cls(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = norm_cls(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = conv_cls(channels, self.out_channels * self.num_children, kernel_size)
        self.conv2 = zero_module(conv_cls(self.out_channels, self.out_channels, kernel_size))
        # Per-channel repeat count used by ``_skip`` to reach ``out_channels``.
        self._repeat = self.out_channels // (channels // self.num_children)
        if pred_subdiv:
            # One logit per child slot; ``_forward`` thresholds it at zero.
            self.to_subdiv = nn.Linear(channels, self.num_children)
        self.updown = SparseChannel2Spatial(factor)

    def _skip(self, x: Voxels) -> Voxels:
        # Widen the unpacked input by repeating each channel ``_repeat`` times.
        return x.replace_features(x.feats.repeat_interleave(self._repeat, dim=1))

    def _forward(
        self,
        x: Voxels,
        subdiv: Voxels | None = None,
    ):
        # A predicted subdivision overrides whatever the caller supplied.
        if self.pred_subdiv:
            subdiv = x.replace_features(self.to_subdiv(x.feats))
        h = x.replace_features(self.norm1(x.feats))
        h = h.replace_features(F.silu(h.feats))
        h = self.conv1(h)
        # Boolean mask of child slots to materialize; ``None`` makes ``updown``
        # fall back to its cached gather table.
        sub_bin = subdiv.replace_features(subdiv.feats > 0) if subdiv is not None else None
        # Unpack both the residual branch and the skip input with the same mask.
        h = self.updown(h, sub_bin)
        x = self.updown(x, sub_bin)
        h = h.replace_features(self.norm2(h.feats))
        h = h.replace_features(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self._skip(x)
        if self.pred_subdiv:
            return h, subdiv
        return h

    def forward(
        self,
        x: Voxels,
        subdiv: Voxels | None = None,
    ):
        # Optionally run ``_forward`` under (non-reentrant) activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False)
        return self._forward(x, subdiv)

SparseUNetDecoderStages

Bases: ModuleList

Resolution-stage assembly for sparse U-Net decoders.

The class subclasses torch.nn.ModuleList so assigning it as self.blocks preserves familiar state-dict names such as blocks.0.0.weight. Each stage contains num_blocks[i] residual blocks followed by an optional upsample block between resolutions.

Source code in warpconvnet/nn/modules/sparse_unet.py
class SparseUNetDecoderStages(nn.ModuleList):
    """Resolution-stage assembly for sparse U-Net decoders.

    The class subclasses ``torch.nn.ModuleList`` so assigning it as
    ``self.blocks`` preserves familiar state-dict names such as
    ``blocks.0.0.weight``. Each stage contains ``num_blocks[i]`` residual
    blocks followed by an optional upsample block between resolutions.
    """

    def __init__(
        self,
        model_channels: list[int],
        num_blocks: list[int],
        block_type: list[str],
        up_block_type: list[str],
        block_args: list[dict[str, Any]],
        block_registry: Mapping[str, type[nn.Module]],
        up_block_kwargs: dict[str, Any] | None = None,
    ):
        if not (len(model_channels) == len(num_blocks) == len(block_type) == len(block_args)):
            raise ValueError("model_channels, num_blocks, block_type, and block_args must align")
        if len(up_block_type) != max(0, len(num_blocks) - 1):
            raise ValueError("up_block_type must have one entry between each resolution stage")

        stages: list[nn.ModuleList] = []
        up_block_kwargs = up_block_kwargs or {}
        for i, n_blocks in enumerate(num_blocks):
            stage = nn.ModuleList([])
            for _ in range(n_blocks):
                stage.append(block_registry[block_type[i]](model_channels[i], **block_args[i]))
            if i < len(num_blocks) - 1:
                kwargs = dict(block_args[i])
                kwargs.update(up_block_kwargs)
                stage.append(
                    block_registry[up_block_type[i]](
                        model_channels[i],
                        model_channels[i + 1],
                        **kwargs,
                    )
                )
            stages.append(stage)
        super().__init__(stages)

        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.block_type = block_type
        self.up_block_type = up_block_type

    def run(
        self,
        x: Voxels,
        guide_subs: list[Voxels] | None = None,
        return_subs: bool = False,
        stop_before_stage: int | None = None,
    ):
        """Run decoder stages.

        ``guide_subs`` supplies explicit subdivision masks for upsample blocks.
        ``return_subs`` collects subdivision predictions from blocks that return
        ``(x, subdiv)``. ``stop_before_stage`` returns early before a stage is
        executed, useful for cascade coordinate upsampling.
        """
        if guide_subs is not None and return_subs:
            raise ValueError("guide_subs and return_subs are mutually exclusive")

        subs: list[Voxels] = []
        for i, stage in enumerate(self):
            if stop_before_stage is not None and i == stop_before_stage:
                return (x, subs) if return_subs else x
            for j, block in enumerate(stage):
                is_last_in_stage = j == len(stage) - 1
                is_upsample = i < len(self) - 1 and is_last_in_stage
                if is_upsample and guide_subs is not None:
                    x = block(x, subdiv=guide_subs[i])
                    continue

                out = block(x)
                if isinstance(out, tuple):
                    x, sub = out
                    if return_subs:
                        subs.append(sub)
                else:
                    x = out

        return (x, subs) if return_subs else x

run(x: Voxels, guide_subs: list[Voxels] | None = None, return_subs: bool = False, stop_before_stage: int | None = None)

Run decoder stages.

guide_subs supplies explicit subdivision masks for upsample blocks. return_subs collects subdivision predictions from blocks that return (x, subdiv). stop_before_stage returns early before a stage is executed, useful for cascade coordinate upsampling.

Source code in warpconvnet/nn/modules/sparse_unet.py
def run(
    self,
    x: Voxels,
    guide_subs: list[Voxels] | None = None,
    return_subs: bool = False,
    stop_before_stage: int | None = None,
):
    """Run decoder stages.

    ``guide_subs`` supplies explicit subdivision masks for upsample blocks.
    ``return_subs`` collects subdivision predictions from blocks that return
    ``(x, subdiv)``. ``stop_before_stage`` returns early before a stage is
    executed, useful for cascade coordinate upsampling.
    """
    if guide_subs is not None and return_subs:
        raise ValueError("guide_subs and return_subs are mutually exclusive")

    subs: list[Voxels] = []
    for i, stage in enumerate(self):
        if stop_before_stage is not None and i == stop_before_stage:
            return (x, subs) if return_subs else x
        for j, block in enumerate(stage):
            is_last_in_stage = j == len(stage) - 1
            is_upsample = i < len(self) - 1 and is_last_in_stage
            if is_upsample and guide_subs is not None:
                x = block(x, subdiv=guide_subs[i])
                continue

            out = block(x)
            if isinstance(out, tuple):
                x, sub = out
                if return_subs:
                    subs.append(sub)
            else:
                x = out

    return (x, subs) if return_subs else x

Transforms

warpconvnet.nn.modules.transforms

Transform

Bases: BaseSpatialModule

Point transform module that applies a feature transform to the input point collection. No spatial operations are performed.

Hydra config example usage:

.. code-block:: yaml

    model:
      feature_transform:
        _target_: warpconvnet.nn.modules.transforms.Transform
        feature_transform_fn:
          _target_: torch.nn.ReLU
Source code in warpconvnet/nn/modules/transforms.py
class Transform(BaseSpatialModule):
    """
    Feature transform module that applies a callable to the features of one or
    more geometries. No spatial operations are performed.

    Hydra config example usage:

    .. code-block:: yaml

        model:
          feature_transform:
            _target_: warpconvnet.nn.modules.transforms.Transform
            feature_transform_fn:
              _target_: torch.nn.ReLU
    """

    def __init__(self, feature_transform_fn: nn.Module):
        super().__init__()
        self.feature_transform_fn = feature_transform_fn

    def forward(self, *sfs: Geometry) -> Geometry:
        """
        Apply the feature transform to the features of the input geometries

        Args:
            *sfs: One or more geometries with identical offsets. Their feature
                tensors are passed to ``feature_transform_fn`` as positional
                arguments, in order.

        Returns:
            The first input geometry with its features replaced by the
            transform output
        """
        # ``sfs`` is always a tuple, so a single geometry needs no special case:
        # it is simply the one-element instance of the general path below.
        assert all(isinstance(sf, Geometry) for sf in sfs), "all inputs must be Geometry instances"
        # Assert that all spatial features have the same offsets
        assert all(torch.allclose(sf.offsets, sfs[0].offsets) for sf in sfs)
        reference = sfs[0]
        features = [sf.feature_tensor for sf in sfs]

        out_features = self.feature_transform_fn(*features)
        return reference.replace(
            batched_features=out_features,
        )

forward(*sfs: Tuple[Geometry, ...]) -> Geometry

Apply the feature transform to the input point collection

Args: *sfs: One or more input geometries that share the same offsets; their feature tensors are passed to the transform in order

Returns: Transformed point collection

Source code in warpconvnet/nn/modules/transforms.py
def forward(self, *sfs: Geometry) -> Geometry:
    """
    Apply the feature transform to the features of the input geometries

    Args:
        *sfs: One or more geometries with identical offsets. Their feature
            tensors are passed to ``feature_transform_fn`` as positional
            arguments, in order.

    Returns:
        The first input geometry with its features replaced by the
        transform output
    """
    # ``sfs`` is always a tuple, so a single geometry needs no special case:
    # it is simply the one-element instance of the general path below.
    assert all(isinstance(sf, Geometry) for sf in sfs), "all inputs must be Geometry instances"
    # Assert that all spatial features have the same offsets
    assert all(torch.allclose(sf.offsets, sfs[0].offsets) for sf in sfs)
    reference = sfs[0]
    features = [sf.feature_tensor for sf in sfs]

    out_features = self.feature_transform_fn(*features)
    return reference.replace(
        batched_features=out_features,
    )

Samplers

warpconvnet.nn.samplers

ODE / SDE samplers for diffusion and flow-matching models.

FlowEulerCfgSampler

Bases: _ClassifierFreeGuidanceMixin, FlowEulerSampler

Flow Euler with classifier-free guidance.

Source code in warpconvnet/nn/samplers/flow_euler.py
class FlowEulerCfgSampler(_ClassifierFreeGuidanceMixin, FlowEulerSampler):
    """Flow Euler sampler combined with the classifier-free guidance mixin."""

    @torch.no_grad()
    def sample(
        self,
        model,
        noise: torch.Tensor,
        cond: Any,
        neg_cond: Any,
        steps: int = 50,
        rescale_t: float = 1.0,
        guidance_strength: float = 3.0,
        verbose: bool = True,
        **kwargs,
    ):
        # The guidance inputs travel as keyword arguments through the parent
        # ``sample`` implementation.
        guidance = {
            "neg_cond": neg_cond,
            "guidance_strength": guidance_strength,
        }
        return super().sample(
            model,
            noise,
            cond,
            steps=steps,
            rescale_t=rescale_t,
            verbose=verbose,
            **guidance,
            **kwargs,
        )

FlowEulerGuidanceIntervalSampler

Bases: _GuidanceIntervalMixin, _ClassifierFreeGuidanceMixin, FlowEulerSampler

Flow Euler with CFG + guidance interval.

Source code in warpconvnet/nn/samplers/flow_euler.py
class FlowEulerGuidanceIntervalSampler(
    _GuidanceIntervalMixin, _ClassifierFreeGuidanceMixin, FlowEulerSampler
):
    """Flow Euler sampler combined with the CFG and guidance-interval mixins."""

    @torch.no_grad()
    def sample(
        self,
        model,
        noise: torch.Tensor,
        cond: Any,
        neg_cond: Any,
        steps: int = 50,
        rescale_t: float = 1.0,
        guidance_strength: float = 3.0,
        guidance_interval: tuple[float, float] = (0.0, 1.0),
        verbose: bool = True,
        **kwargs,
    ):
        # The guidance inputs travel as keyword arguments through the parent
        # ``sample`` implementation.
        guidance = {
            "neg_cond": neg_cond,
            "guidance_strength": guidance_strength,
            "guidance_interval": guidance_interval,
        }
        return super().sample(
            model,
            noise,
            cond,
            steps=steps,
            rescale_t=rescale_t,
            verbose=verbose,
            **guidance,
            **kwargs,
        )

FlowEulerSampler

Bases: Sampler

Euler ODE solver for rectified-flow / flow-matching models.

The model is assumed to predict the velocity v = dx_t/dt = (1 - sigma_min) eps - x_0 (noise minus data when sigma_min = 0) for the noisy sample x_t = (1 - t) x_0 + (sigma_min + (1 - sigma_min) t) eps. Stepping is plain Euler: x_{t-Δ} = x_t - (t - t_prev) * v.

Source code in warpconvnet/nn/samplers/flow_euler.py
class FlowEulerSampler(Sampler):
    """Euler ODE integrator for rectified-flow / flow-matching models.

    The conversion helpers assume the noisy sample is
    ``x_t = (1 - t) x_0 + (sigma_min + (1 - sigma_min) t) eps`` and that the
    model predicts the velocity ``v = (1 - sigma_min) eps - x_0`` (this is
    ``eps - x_0`` when ``sigma_min`` is 0). Each step is plain Euler:
    ``x_{t_prev} = x_t - (t - t_prev) * v``.
    """

    def __init__(self, sigma_min: float):
        """Store ``sigma_min``, the noise scale used in the formulas at ``t = 0``."""
        self.sigma_min = sigma_min

    # -- conversion utilities (linear-interp flow parameterisation) ------------
    def _v_to_xstart_eps(self, x_t: torch.Tensor, t: float, v: torch.Tensor):
        """Convert a velocity prediction into ``(x_0, eps)`` estimates."""
        # The tensor-like operand stays on the left of every product: sparse
        # Voxels provide ``__mul__`` but no ``__rmul__``.
        data_scale = 1 - self.sigma_min
        noise_scale = self.sigma_min + (1 - self.sigma_min) * t
        eps = v * (1 - t) + x_t
        x_0 = x_t * data_scale - v * noise_scale
        return x_0, eps

    def _pred_to_xstart(self, x_t: torch.Tensor, t: float, pred: torch.Tensor) -> torch.Tensor:
        """Return the ``x_0`` estimate implied by the model output ``pred``."""
        # Same left-multiply convention as above so Voxels pass through.
        data_scale = 1 - self.sigma_min
        noise_scale = self.sigma_min + (1 - self.sigma_min) * t
        return x_t * data_scale - pred * noise_scale

    def _xstart_to_pred(self, x_t: torch.Tensor, t: float, x_0: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`_pred_to_xstart`: recover the prediction from ``x_0``."""
        data_scale = 1 - self.sigma_min
        noise_scale = self.sigma_min + (1 - self.sigma_min) * t
        return (x_t * data_scale - x_0) / noise_scale

    # -- model wrapper ---------------------------------------------------------
    def _inference_model(
        self,
        model,
        x_t: torch.Tensor,
        t: float,
        cond: Any | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Call ``model(x_t, t_in, cond, **kwargs)`` with a per-batch timestep.

        ``t_in`` is a float32 tensor on ``x_t.device`` holding ``1000 * t``
        once per batch element.
        """
        # Voxels (Geometry) has no ``.shape``; prefer ``batch_size`` and only
        # fall back to the leading dimension when it is missing or None.
        batch = getattr(x_t, "batch_size", None)
        batch = x_t.shape[0] if batch is None else batch
        scaled_t = 1000 * t
        t_in = torch.tensor([scaled_t] * batch, device=x_t.device, dtype=torch.float32)
        return model(x_t, t_in, cond, **kwargs)

    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
        """Return ``(pred_x_0, pred_eps, pred_v)`` for the current sample."""
        velocity = self._inference_model(model, x_t, t, cond, **kwargs)
        x_0_est, eps_est = self._v_to_xstart_eps(x_t=x_t, t=t, v=velocity)
        return x_0_est, eps_est, velocity

    # -- single Euler step -----------------------------------------------------
    @torch.no_grad()
    def sample_once(
        self,
        model,
        x_t: torch.Tensor,
        t: float,
        t_prev: float,
        cond: Any | None = None,
        **kwargs,
    ):
        """Take one Euler step from ``t`` to ``t_prev``.

        Returns a dict with ``"pred_x_prev"`` (the stepped sample) and
        ``"pred_x_0"`` (the clean-sample estimate at ``t``).
        """
        x_0_est, _, velocity = self._get_model_prediction(model, x_t, t, cond, **kwargs)
        step_size = t - t_prev
        # Velocity stays on the left of the product (Voxels lack ``__rmul__``).
        stepped = x_t - velocity * step_size
        return {"pred_x_prev": stepped, "pred_x_0": x_0_est}

    # -- full sampling loop ----------------------------------------------------
    @torch.no_grad()
    def sample(
        self,
        model,
        noise: torch.Tensor,
        cond: Any | None = None,
        steps: int = 50,
        rescale_t: float = 1.0,
        verbose: bool = True,
        tqdm_desc: str = "Sampling",
        **kwargs,
    ):
        """Integrate from ``t = 1`` down to ``t = 0`` in ``steps`` Euler steps.

        The uniform time grid is warped by
        ``rescale_t * t / (1 + (rescale_t - 1) * t)`` before stepping.
        ``verbose`` toggles the tqdm progress bar labelled ``tqdm_desc``;
        extra ``kwargs`` are forwarded to :meth:`sample_once`.

        Returns a dict with ``"samples"`` (the final sample), ``"pred_x_t"``
        (the stepped sample after every step) and ``"pred_x_0"`` (the
        clean-sample estimate at every step).
        """
        grid = np.linspace(1, 0, steps + 1)
        grid = rescale_t * grid / (1 + (rescale_t - 1) * grid)
        times = grid.tolist()
        # Consecutive (t, t_prev) pairs; a list so tqdm knows the total.
        schedule = list(zip(times[:-1], times[1:]))

        current = noise
        x_prev_trace, x_0_trace = [], []
        for t, t_prev in tqdm(schedule, desc=tqdm_desc, disable=not verbose):
            step_out = self.sample_once(model, current, t, t_prev, cond, **kwargs)
            current = step_out["pred_x_prev"]
            x_prev_trace.append(current)
            x_0_trace.append(step_out["pred_x_0"])
        return {"samples": current, "pred_x_t": x_prev_trace, "pred_x_0": x_0_trace}

Sampler

Bases: ABC

Base class for ODE / SDE samplers.

Source code in warpconvnet/nn/samplers/flow_euler.py
class Sampler(ABC):
    """Abstract interface shared by ODE / SDE samplers.

    Subclasses must override :meth:`sample`; the base implementation does
    nothing and returns ``None``.
    """

    @abstractmethod
    def sample(self, model, **kwargs):  # pragma: no cover
        """Run the sampler on ``model``; concrete subclasses supply the logic."""
        return None

Utilities

warpconvnet.nn.utils

Small neural-network utilities shared by model ports and reusable modules.

convert_module_parameters_to(module: nn.Module, dtype: torch.dtype, module_types: tuple[type[nn.Module], ...] = DEFAULT_MIXED_PRECISION_MODULES) -> None

Cast parameters for selected leaf module families in-place.

Intended for diffusion/flow model ports that keep normalization modules in fp32 while casting Linear/Conv layers to fp16 or bf16.

Source code in warpconvnet/nn/utils.py
def convert_module_parameters_to(
    module: nn.Module,
    dtype: torch.dtype,
    module_types: tuple[type[nn.Module], ...] = DEFAULT_MIXED_PRECISION_MODULES,
) -> None:
    """Cast the parameters of ``module`` to ``dtype`` if its type is selected.

    Only ``module`` itself is checked against ``module_types``; when it does
    not match, nothing happens. When it matches, every parameter yielded by
    ``module.parameters()`` has its ``.data`` replaced by a ``dtype`` copy.

    Intended for diffusion/flow model ports that keep normalization modules in
    fp32 while casting Linear/Conv layers to fp16 or bf16.
    """
    # NOTE(review): presumably applied per submodule (e.g. via
    # ``nn.Module.apply``) to cover a whole network -- confirm against callers.
    if not isinstance(module, module_types):
        return
    for param in module.parameters():
        # Rebinding ``.data`` keeps the Parameter object itself in place.
        param.data = param.data.to(dtype)

manual_cast(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor

Cast tensor to dtype only when AMP autocast is not active.

Source code in warpconvnet/nn/utils.py
def manual_cast(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Return ``tensor`` as ``dtype`` unless AMP autocast is active.

    When ``torch.is_autocast_enabled()`` reports an active autocast context
    the input is returned untouched; otherwise it is converted with
    ``tensor.type(dtype)``.
    """
    if torch.is_autocast_enabled():
        # Autocast is on: leave the tensor as it is.
        return tensor
    return tensor.type(dtype)

str_to_dtype(s: str) -> torch.dtype

Parse common dtype spellings used in model config files.

Source code in warpconvnet/nn/utils.py
def str_to_dtype(s: str) -> torch.dtype:
    """Map a dtype spelling from a model config file to a ``torch.dtype``.

    Recognised spellings are ``f16``/``fp16``/``float16``,
    ``bf16``/``bfloat16`` and ``f32``/``fp32``/``float32``. Any other string
    raises ``KeyError``.
    """
    spellings_by_dtype = (
        (torch.float16, ("f16", "fp16", "float16")),
        (torch.bfloat16, ("bf16", "bfloat16")),
        (torch.float32, ("f32", "fp32", "float32")),
    )
    lookup = {
        spelling: dtype
        for dtype, spellings in spellings_by_dtype
        for spelling in spellings
    }
    return lookup[s]

zero_module(module: nn.Module) -> nn.Module

Zero all parameters of module in-place and return it.

Source code in warpconvnet/nn/utils.py
def zero_module(module: nn.Module) -> nn.Module:
    """Set every parameter of ``module`` to zero in-place and return ``module``."""
    weights = module.parameters()
    for weight in weights:
        # Zero through a detached view so the in-place write is not tracked
        # by autograd; the view shares storage with the parameter.
        weight.detach().zero_()
    return module