Neural Networks

warpconvnet.nn

Modules

Activations

warpconvnet.nn.modules.activations

DropPath

Bases: BaseSpatialModule

Stochastic depth regularization.

Parameters:
  • drop_prob (float, default: 0.0 ) –

    Probability of dropping a sample. Defaults to 0.0.

  • scale_by_keep (bool, default: True ) –

    If True, kept samples are divided by the keep probability 1 - drop_prob so the expected output matches the input. Defaults to True.

Source code in warpconvnet/nn/modules/activations.py
class DropPath(BaseSpatialModule):
    """Stochastic depth regularization.

    Module wrapper around :func:`drop_path`. A ``Geometry`` input has the stochastic
    depth applied to its feature tensor; any other input is passed to ``drop_path``
    directly.

    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a sample. Defaults to ``0.0``.
    scale_by_keep : bool, optional
        If ``True``, kept samples are divided by the keep probability
        ``1 - drop_prob`` so the expected output matches the input.
        Defaults to ``True``.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x: Union[Geometry, Tensor]):  # noqa: F821
        """Apply stochastic depth; identity when not training or ``drop_prob == 0``."""
        if isinstance(x, Geometry):
            # Only the feature tensor is replaced on the geometry.
            return x.replace(
                batched_features=drop_path(
                    x.feature_tensor, self.drop_prob, self.training, self.scale_by_keep
                )
            )
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Plain fixed-point formatting: a space sign flag (": 0.3f") would render as
        # "drop_prob= 0.100", and the precision already rounds to 3 decimals.
        return f"drop_prob={self.drop_prob:.3f}, scale_by_keep={self.scale_by_keep}"

ELU

Bases: BaseSpatialModule

Applies the ELU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class ELU(BaseSpatialModule):
    """Element-wise ELU on the features of a ``Geometry``.

    Thin module wrapper around the functional ``elu`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        activated = elu(input)
        return activated

GELU

Bases: BaseSpatialModule

Applies the GELU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class GELU(BaseSpatialModule):
    """Element-wise GELU on the features of a ``Geometry``.

    Thin module wrapper around the functional ``gelu`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        activated = gelu(input)
        return activated

LeakyReLU

Bases: Module

Applies the LeakyReLU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class LeakyReLU(BaseSpatialModule):
    """Applies the LeakyReLU activation to ``Geometry`` features.

    Derives from ``BaseSpatialModule`` for consistency with every other activation
    module here. ``BaseSpatialModule`` is itself an ``nn.Module`` subclass, so
    ``isinstance(..., nn.Module)`` checks still hold.
    """

    def forward(self, input: Geometry):  # noqa: F821
        return leaky_relu(input)

LogSoftmax

Bases: BaseSpatialModule

Applies the log_softmax activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class LogSoftmax(BaseSpatialModule):
    """Log-softmax over the features of a ``Geometry``.

    Thin module wrapper around the functional ``log_softmax`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        log_probs = log_softmax(input)
        return log_probs

ReLU

Bases: BaseSpatialModule

Applies the ReLU activation to Geometry features.

Parameters:
  • inplace (bool, default: False ) –

    Whether to perform the operation in-place. Defaults to False.

Source code in warpconvnet/nn/modules/activations.py
class ReLU(BaseSpatialModule):
    """Applies the ReLU activation to ``Geometry`` features.

    The per-feature transform is delegated to a wrapped ``torch.nn.ReLU``.

    Parameters
    ----------
    inplace : bool, optional
        Whether to perform the operation in-place. Defaults to ``False``.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Kept as a child module; its ``inplace`` flag is also what ``__repr__`` reports.
        self.relu = nn.ReLU(inplace=inplace)

    def __repr__(self):
        class_name = type(self).__name__
        return f"{class_name}(inplace={self.relu.inplace})"

    def forward(self, input: Geometry):  # noqa: F821
        relu_fn = self.relu
        return apply_feature_transform(input, relu_fn)

SiLU

Bases: BaseSpatialModule

Applies the SiLU activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class SiLU(BaseSpatialModule):
    """Element-wise SiLU on the features of a ``Geometry``.

    Thin module wrapper around the functional ``silu`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        activated = silu(input)
        return activated

Sigmoid

Bases: BaseSpatialModule

Applies the sigmoid activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class Sigmoid(BaseSpatialModule):
    """Element-wise sigmoid on the features of a ``Geometry``.

    Thin module wrapper around the functional ``sigmoid`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        activated = sigmoid(input)
        return activated

Softmax

Bases: BaseSpatialModule

Applies the softmax activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class Softmax(BaseSpatialModule):
    """Softmax over the features of a ``Geometry``.

    Thin module wrapper around the functional ``softmax`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        probs = softmax(input)
        return probs

Tanh

Bases: BaseSpatialModule

Applies the tanh activation to Geometry features.

Source code in warpconvnet/nn/modules/activations.py
class Tanh(BaseSpatialModule):
    """Element-wise tanh on the features of a ``Geometry``.

    Thin module wrapper around the functional ``tanh`` helper.
    """

    def forward(self, input: Geometry):  # noqa: F821
        activated = tanh(input)
        return activated

drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True)

Apply stochastic depth to the input tensor.

Parameters:
  • x (``torch.Tensor``) –

    Input tensor to apply stochastic depth to.

  • drop_prob (float, default: 0.0 ) –

    Probability of dropping a sample. Defaults to 0.0.

  • training (bool, default: False ) –

    Whether the module is in training mode. Defaults to False.

  • scale_by_keep (bool, default: True ) –

    If True, kept samples are divided by the keep probability 1 - drop_prob so the expected output matches the input. Defaults to True.

Source code in warpconvnet/nn/modules/activations.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One Bernoulli draw is made per entry along the first dimension and broadcast over
    all remaining dimensions, so each slice ``x[i]`` is either kept or zeroed as a unit.

    Parameters
    ----------
    x : ``torch.Tensor``
        Input tensor; its first dimension indexes the samples that may be dropped.
    drop_prob : float, optional
        Probability of dropping a sample. Defaults to ``0.0``.
    training : bool, optional
        Whether the module is in training mode. Outside training the input is
        returned unchanged. Defaults to ``False``.
    scale_by_keep : bool, optional
        If ``True``, kept samples are divided by the keep probability
        ``1 - drop_prob`` so the expected value is unchanged. Defaults to ``True``.

    Returns
    -------
    ``torch.Tensor``
        ``x`` itself when not training or when ``drop_prob == 0``; otherwise a new
        tensor of the same shape.
    """
    if not training or drop_prob == 0.0:
        return x

    survival_prob = 1 - drop_prob
    # Shape (N, 1, ..., 1): one draw per sample, broadcast across the other axes.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    keep_mask = x.new_empty(mask_shape).bernoulli_(survival_prob)
    # When every sample is dropped there is nothing to rescale (and no division by zero).
    if scale_by_keep and survival_prob > 0.0:
        keep_mask.div_(survival_prob)
    return x * keep_mask

Attention

warpconvnet.nn.modules.attention

Attention

Bases: Module

Source code in warpconvnet/nn/modules/attention.py
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        enable_flash: bool = True,
        use_batched_qkv: bool = True,
    ):
        """
        Multi-head self-attention over padded ``(B, N, C)`` inputs, with optional flash
        attention and optional batched QKV for Muon optimization.

        Args:
            dim: Input feature dimension. Must be divisible by ``num_heads``.
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in QKV projection
            qk_scale: Scale factor for attention scores; a falsy value (e.g. ``None``)
                falls back to ``head_dim ** -0.5``
            attn_drop: Attention dropout rate
            proj_drop: Output projection dropout rate
            enable_flash: Whether to use flash attention (requires ``flash_attn``)
            use_batched_qkv: If True, uses separate Q, K, V matrices stacked as [3, dim, dim]
                           for Muon optimization. Muon can orthogonalize the [dim, dim] matrices
                           more effectively than the concatenated [dim, 3*dim] matrix.
        """
        super().__init__()
        # Fail at construction (as PatchAttention does) rather than at the first reshape.
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.enable_flash = enable_flash
        self.use_batched_qkv = use_batched_qkv

        if enable_flash:
            assert flash_attn is not None, "Make sure flash_attn is installed."
            # flash_attn takes the dropout probability as a call argument.
            self.attn_drop_p = attn_drop
        else:
            self.attn_drop = nn.Dropout(attn_drop)

        if use_batched_qkv:
            # Use BatchedLinear for Muon-friendly QKV projection
            self.qkv = BatchedLinear(dim, dim, num_matrices=3, bias=qkv_bias)
        else:
            # Original single linear layer approach
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: Float[Tensor, "B N C"],  # noqa: F821
        pos_enc: Optional[Float[Tensor, "B N C"]] = None,  # noqa: F821
        mask: Optional[Float[Tensor, "B N N"]] = None,  # noqa: F821
        num_points: Optional[Int[Tensor, "B"]] = None,  # noqa: F821
    ) -> Float[Tensor, "B N C"]:
        """
        Apply self-attention to ``x``.

        Args:
            x: Input features of shape ``(B, N, C)``.
            pos_enc: Optional positional encoding.
                - Flash path: added to ``x`` before the QKV projection.
                - Non-flash path: added to the per-head queries and keys. An encoding
                  whose last dimension is ``C`` is split into heads the same way as the
                  queries; any other encoding (e.g. ``(B, N, head_dim)``) is broadcast
                  across heads.
            mask: Optional attention mask, only used on the non-flash path (the flash
                path ignores it). Float masks are added to the attention logits.
                Boolean masks (e.g. from ``offset_to_mask``) keep entries that are
                ``True``. The mask must broadcast against ``(B, num_heads, N, N)``,
                e.g. ``(B, 1, N, N)``.
            num_points: Optional number of valid points per batch; output rows at or
                beyond ``num_points[b]`` are zeroed.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # Compute QKV with unified approach
        if pos_enc is not None and self.enable_flash:
            # Add positional encoding to input before QKV projection for flash attention
            qkv = self.qkv(x + pos_enc).reshape(B, N, 3, C)
        else:
            qkv = self.qkv(x).reshape(B, N, 3, C)

        # Reshape to [B, N, 3, num_heads, head_dim]
        qkv = qkv.reshape(B, N, 3, self.num_heads, head_dim)

        if not self.enable_flash:
            qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
            q, k, v = (
                qkv[0],
                qkv[1],
                qkv[2],
            )  # make torchscript happy (cannot use tensor as tuple)

            # Apply positional encoding to the query and key (non-flash path)
            if pos_enc is not None:
                if pos_enc.shape[-1] == C and C != head_dim:
                    # Split (..., N, C) into (..., num_heads, N, head_dim) to match q/k;
                    # a plain unsqueeze(1) cannot broadcast C against head_dim.
                    pe = pos_enc.reshape(
                        *pos_enc.shape[:-1], self.num_heads, head_dim
                    ).transpose(-3, -2)
                else:
                    # Encoding shared across heads.
                    pe = pos_enc.unsqueeze(1)
                q = q + pe
                k = k + pe

            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                if mask.dtype == torch.bool:
                    # True marks allowed entries. Adding a bool mask would only shift
                    # scores by 1, so fill disallowed entries instead. finfo.min
                    # (not -inf) keeps fully masked rows finite after softmax.
                    attn = attn.masked_fill(~mask, torch.finfo(attn.dtype).min)
                else:
                    # Additive float mask.
                    attn = attn + mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = attn @ v
            x = x.transpose(1, 2).reshape(B, N, C)
        else:
            # Flash attention path (mask is not applied here).
            # Flash attention - preserve original dtype if possible
            original_dtype = qkv.dtype
            if qkv.dtype not in [torch.float16, torch.bfloat16]:
                # Convert to half precision for flash attention
                qkv_flash = qkv.half()
            else:
                qkv_flash = qkv

            x = flash_attn.flash_attn_qkvpacked_func(
                qkv_flash,
                dropout_p=self.attn_drop_p if self.training else 0.0,
                softmax_scale=self.scale,
            ).reshape(B, N, C)

            # Convert back to original dtype if necessary
            if x.dtype != original_dtype:
                x = x.to(original_dtype)

        x = self.proj(x)
        x = self.proj_drop(x)

        if num_points is not None:
            x = zero_out_points(x, num_points)
        return x

PatchAttention

Bases: BaseSpatialModule

Source code in warpconvnet/nn/modules/attention.py
class PatchAttention(BaseSpatialModule):
    def __init__(
        self,
        dim: int,
        patch_size: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        order: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
        use_batched_qkv: bool = True,
    ):
        """
        Patch attention module with optional batched QKV for Muon optimization.

        Args:
            dim: Input feature dimension. Must be divisible by ``num_heads``.
            patch_size: Maximum number of consecutive (serialized) points per attention patch
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in QKV projection
            qk_scale: Scale factor for attention scores; a falsy value (e.g. ``None``)
                falls back to ``head_dim ** -0.5``
            attn_drop: Attention dropout rate
            proj_drop: Output projection dropout rate
            order: Point ordering for patch generation
            use_batched_qkv: If True, uses separate Q, K, V matrices stacked as [3, dim, dim]
                           for Muon optimization. Muon can orthogonalize the [dim, dim] matrices
                           more effectively than the concatenated [dim, 3*dim] matrix.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.use_batched_qkv = use_batched_qkv

        if use_batched_qkv:
            # Use BatchedLinear for Muon-friendly QKV projection
            self.qkv = BatchedLinear(dim, dim, num_matrices=3, bias=qkv_bias)
        else:
            # Original single linear layer approach
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.order = order
        assert flash_attn is not None, "Make sure flash_attn is installed."
        self.attn_drop_p = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _offset_to_attn_offset(
        self, offsets: Int[Tensor, "B+1"], patch_size: Optional[int] = None
    ) -> Int[Tensor, "B"]:
        """
        Convert batch offsets to the cumulative sequence offsets (``cu_seqlens``) used by
        flash attention's varlen kernels.

        Each batch ``[offsets[b], offsets[b + 1])`` is cut into consecutive chunks of
        ``patch_size`` points. The remainder (fewer than ``patch_size`` points) forms the
        last chunk of that batch. Zero-length chunks are dropped; these arise when a
        batch size is an exact multiple of ``patch_size`` or a batch is empty.

        Example: with patch size 8 and offsets ``[0, 3, 11, 40]`` (3 batches), the result
        is ``[0, 3, 11, 19, 27, 35, 40]``.

        Args:
            offsets: (B+1) cumulative batch offsets (presumably non-decreasing).
            patch_size: Optional[int]; defaults to ``self.patch_size``.
        Returns:
            cum_seqlens: (M,) with the same dtype and device as ``offsets``.
        """
        patch_size = patch_size or self.patch_size
        counts = torch.diff(offsets)
        num_patches_per_batch = counts // patch_size

        # Fast path: no batch holds a full patch, so each batch is a single sequence.
        if num_patches_per_batch.sum() == 0:
            return torch.unique_consecutive(offsets)

        # Each batch contributes its start offset plus one boundary per full patch.
        elements_per_batch = 1 + num_patches_per_batch

        # Batch id of every emitted boundary.
        batch_indices = torch.repeat_interleave(
            torch.arange(len(offsets) - 1, device=offsets.device), elements_per_batch
        )

        # Position of each boundary within its batch (0, 1, 2, ...), computed without a
        # Python loop: global index minus the index where that batch's run begins.
        run_starts = torch.cumsum(elements_per_batch, dim=0) - elements_per_batch
        within_batch_indices = (
            torch.arange(batch_indices.numel(), device=offsets.device)
            - run_starts[batch_indices]
        )

        # Actual boundaries: batch start + patch_index * patch_size.
        result_middle = offsets[:-1][batch_indices] + within_batch_indices * patch_size

        # Add the final offset, keeping the input dtype.
        result = torch.cat([result_middle.to(offsets.dtype), offsets[-1:]])

        # Drop duplicate boundaries, i.e. zero-length sequences.
        return torch.unique_consecutive(result)

    def forward(self, x: Geometry, order: Optional[POINT_ORDERING] = None) -> Geometry:
        """Run patch-wise self-attention over the features of ``x``.

        Features are reordered with ``order`` (default ``self.order``) unless ``x``
        already carries that order. Attention is restricted to chunks of at most
        ``patch_size`` consecutive points within each batch. When a reordering was
        applied, outputs are mapped back with the inverse permutation.
        """
        K = self.patch_size

        feats = x.features
        M, C = feats.shape[:2]
        inverse_perm = None
        order = order or self.order
        # Serialize the features in the requested order unless x is already in it.
        if not hasattr(x, "order") or (order != x.order):
            # Generate new ordering and inverse permutation
            code_result = encode(
                x.coordinate_tensor,
                batch_offsets=x.offsets,
                order=order,
                return_perm=True,
                return_inverse=True,
            )
            feats = feats[code_result.perm]
            inverse_perm = code_result.inverse_perm

        # Compute QKV: (M, 3, num_heads, head_dim)
        qkv = self.qkv(feats).reshape(M, 3, self.num_heads, C // self.num_heads)
        if qkv.dtype not in [torch.float16, torch.bfloat16]:
            # flash_attn kernels operate on half-precision inputs.
            qkv = qkv.to(torch.float16)

        # flash_attn's varlen API expects int32 cumulative sequence lengths on the
        # same device as qkv.
        attn_offsets = self._offset_to_attn_offset(x.offsets, K).to(
            device=qkv.device, dtype=torch.int32
        )
        # Warning: When the loss is NaN, this module will fail during backward with
        # index out of bounds error.
        # e.g. /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [192,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "
        # https://discuss.pytorch.org/t/scattergatherkernel-cu-assertion-idx-dim-0-idx-dim-index-size-index-out-of-bounds/195356
        out_feat = flash_attn.flash_attn_varlen_qkvpacked_func(
            qkv,
            attn_offsets,
            max_seqlen=K,
            dropout_p=self.attn_drop_p if self.training else 0.0,
            softmax_scale=self.scale,
        )
        out_feat = out_feat.reshape(M, C).to(feats.dtype)

        out_feat = self.proj(out_feat)
        out_feat = self.proj_drop(out_feat)

        if inverse_perm is not None:
            out_feat = out_feat[inverse_perm]

        return x.replace(batched_features=out_feat.to(feats.dtype))

offset_to_mask(x: Float[Tensor, 'B M C'], offsets: Float[Tensor, 'B+1'], max_num_points: int, dtype: torch.dtype = torch.bool) -> Float[Tensor, 'B 1 M M']

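Create a boolean (B, 1, M, M) key mask for the points in the batch: for each batch b, the key columns [0, n_b) are True, where n_b is that batch's point count.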

Source code in warpconvnet/nn/modules/attention.py
def offset_to_mask(
    x: Float[Tensor, "B M C"],  # noqa: F821
    offsets: Float[Tensor, "B+1"],  # noqa: F821
    max_num_points: int,  # noqa: F821
    dtype: torch.dtype = torch.bool,
) -> Float[Tensor, "B 1 M M"]:  # noqa: F821
    """
    Create a key mask for the points in the batch.

    For batch ``b`` with ``n_b = offsets[b + 1] - offsets[b]`` points, every query row
    may attend to the key columns ``[0, n_b)``. Those entries are ``True`` and all
    others ``False``.

    Args:
        x: Padded batch; only its batch size and device are used.
        offsets: (B+1) cumulative point offsets (integer-valued).
        max_num_points: Size ``M`` of the last two mask dimensions.
        dtype: Mask dtype. Only ``torch.bool`` is supported.

    Returns:
        A contiguous ``(B, 1, M, M)`` mask on ``x.device``.

    Raises:
        ValueError: If ``dtype`` is not ``torch.bool``.
    """
    B = x.shape[0]
    assert B == offsets.shape[0] - 1
    if dtype != torch.bool:
        raise ValueError(f"Unsupported dtype: {dtype}")

    num_points = offsets.diff().to(x.device)
    # Vectorized over the batch (no per-batch Python loop / host sync):
    # key_valid[b, j] is True for j < n_b.
    key_valid = torch.arange(max_num_points, device=x.device) < num_points.unsqueeze(1)
    mask = key_valid[:, None, None, :].expand(B, 1, max_num_points, max_num_points)
    # Materialize the broadcast so callers receive a writable, contiguous tensor.
    return mask.contiguous()

zero_out_points(x: Float[Tensor, 'B N C'], num_points: Int[Tensor, 'B']) -> Float[Tensor, 'B N C']

Zero out, in place, the padded rows x[b, num_points[b]:] of each batch and return x.

Source code in warpconvnet/nn/modules/attention.py
def zero_out_points(
    x: Float[Tensor, "B N C"], num_points: Int[Tensor, "B"]  # noqa: F821
) -> Float[Tensor, "B N C"]:  # noqa: F821
    """
    Zero out the padded points in the batch, in place.

    For each batch ``b < len(num_points)``, rows ``x[b, num_points[b]:]`` are set to
    zero. The padding mask is built in one vectorized step instead of a per-batch
    Python loop.

    Args:
        x: Padded features; modified in place.
        num_points: Number of valid points per batch entry.

    Returns:
        ``x`` itself, after zeroing.
    """
    count = num_points.shape[0]
    num_rows = x.shape[1]
    # pad[b, i] is True for rows i >= num_points[b].
    pad = torch.arange(num_rows, device=x.device) >= num_points.to(x.device).reshape(-1, 1)
    # Broadcast over any trailing feature dimensions.
    pad = pad.reshape(*pad.shape, *([1] * (x.ndim - 2)))
    # Only the first ``count`` batches are touched; the slice is a view of x.
    x[:count].masked_fill_(pad, 0)
    return x

Base module

warpconvnet.nn.modules.base_module

BaseSpatialModel

Bases: BaseSpatialModule

Base model class.

Source code in warpconvnet/nn/modules/base_module.py
class BaseSpatialModel(BaseSpatialModule):
    """Base class for complete spatial models.

    Declares the data-conversion, loss, evaluation and image/point-cloud hooks that
    concrete models override. Every hook here raises ``NotImplementedError``.
    """

    def data_dict_to_input(self, data_dict, **kwargs) -> Any:
        """Build the model input from a data dictionary; subclasses must override."""
        raise NotImplementedError()

    def loss_dict(self, data_dict, **kwargs) -> Dict:
        """Return a dictionary of loss terms for ``data_dict``; subclasses must override."""
        raise NotImplementedError()

    @torch.no_grad()
    def eval_dict(self, data_dict, **kwargs) -> Dict:
        """Return evaluation outputs for ``data_dict``; runs with autograd disabled."""
        raise NotImplementedError()

    def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]:
        """Return an ``(image_dict, pointcloud_dict)`` pair; subclasses must override."""
        raise NotImplementedError()

data_dict_to_input(data_dict, **kwargs) -> Any

Convert data dictionary to appropriate input for the model.

Source code in warpconvnet/nn/modules/base_module.py
def data_dict_to_input(self, data_dict, **kwargs) -> Any:
    """Map a data dictionary to model input; abstract, subclasses must override."""
    raise NotImplementedError()

eval_dict(data_dict, **kwargs) -> Dict

Compute the evaluation dictionary for the model.

Source code in warpconvnet/nn/modules/base_module.py
@torch.no_grad()
def eval_dict(self, data_dict, **kwargs) -> Dict:
    """Produce evaluation outputs with autograd disabled; abstract, must be overridden."""
    raise NotImplementedError()

image_pointcloud_dict(data_dict, datamodule) -> Tuple[Dict, Dict]

Compute the image dict and pointcloud dict for the model.

Source code in warpconvnet/nn/modules/base_module.py
def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]:
    """Return an ``(image_dict, pointcloud_dict)`` pair; abstract, must be overridden."""
    raise NotImplementedError()

loss_dict(data_dict, **kwargs) -> Dict

Compute the loss dictionary for the model.

Source code in warpconvnet/nn/modules/base_module.py
def loss_dict(self, data_dict, **kwargs) -> Dict:
    """Return a mapping of loss terms; abstract, subclasses must override."""
    raise NotImplementedError()

BaseSpatialModule

Bases: Module

Base module for spatial features. The input is expected to be a Geometry instance.

Source code in warpconvnet/nn/modules/base_module.py
class BaseSpatialModule(nn.Module):
    """Base module for spatial features. The input is expected to be a ``Geometry`` instance."""

    @property
    def device(self):
        """Returns the device that the model is on.

        Uses the first parameter, falling back to the first registered buffer for
        modules that have buffers but no parameters.

        Raises:
            StopIteration: If the module has neither parameters nor buffers.
        """
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        if tensor is None:
            # Same exception type a bare ``next(self.parameters())`` would raise,
            # but with an explanation.
            raise StopIteration(
                f"{type(self).__name__} has no parameters or buffers; its device is undefined."
            )
        return tensor.device

    def forward(self, x: Geometry):
        """Forward pass; must be implemented by subclasses."""
        raise NotImplementedError

device property

Returns the device that the model is on.

forward(x: Geometry)

Forward pass.

Source code in warpconvnet/nn/modules/base_module.py
def forward(self, x: Geometry):
    """Abstract forward pass; concrete spatial modules override this."""
    raise NotImplementedError()

Factor grid

warpconvnet.nn.modules.factor_grid

Neural network modules for FactorGrid operations.

This module provides neural network layers and operations specifically designed for working with FactorGrid geometries in the FIGConvNet architecture.

FactorGridCat

Bases: BaseSpatialModule

Concatenate features of two FactorGrid objects.

This is equivalent to GridFeatureGroupCat but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridCat(BaseSpatialModule):
    """Feature concatenation of two ``FactorGrid`` objects.

    Module counterpart of ``GridFeatureGroupCat`` for ``FactorGrid`` inputs; the work
    is delegated to ``factor_grid_cat``.
    """

    def forward(self, factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
        """Return the concatenation of ``factor_grid1`` and ``factor_grid2``."""
        concatenated = factor_grid_cat(factor_grid1, factor_grid2)
        return concatenated

forward(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid

Concatenate features from two FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
    """Return the feature concatenation of the two inputs via ``factor_grid_cat``."""
    concatenated = factor_grid_cat(factor_grid1, factor_grid2)
    return concatenated

FactorGridGlobalConv

Bases: BaseSpatialModule

Global convolution with intra-communication for FactorGrid.

This is equivalent to GridFeatureConv2DBlocksAndIntraCommunication but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridGlobalConv(BaseSpatialModule):
    """Global convolution with intra-communication for FactorGrid.

    This is equivalent to GridFeatureConv2DBlocksAndIntraCommunication but works with FactorGrid objects.

    Each grid of the input ``FactorGrid`` is processed by its own conv/norm/activation
    block. The results are combined by ``FactorGridIntraCommunication`` and, when more
    than one communication type is used, projected back to ``out_channels``.

    Args:
        in_channels: Input channels of every grid.
        out_channels: Output channels of every grid.
        kernel_size: Convolution kernel size.
        compressed_spatial_dims: One entry per grid; length must match
            ``compressed_memory_formats``.
        compressed_memory_formats: Memory format of each grid.
        stride: Convolution stride.
        up_stride: Optional ``up_stride`` forwarded to each block.
        communication_types: Intra-communication modes; ``None`` means ``["sum"]``.
        norm: Normalization layer class.
        activation: Activation layer class.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        compressed_spatial_dims: Tuple[int, ...],
        compressed_memory_formats: Tuple[GridMemoryFormat, ...],
        stride: int = 1,
        up_stride: Optional[int] = None,
        communication_types: Optional[List[Literal["sum", "mul"]]] = None,
        norm: Type[nn.Module] = nn.BatchNorm2d,
        activation: Type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        # A None default avoids one ["sum"] list object being shared by every instance.
        # A private copy also isolates this module from later mutation of the caller's list.
        communication_types = (
            ["sum"] if communication_types is None else list(communication_types)
        )
        assert len(compressed_spatial_dims) == len(
            compressed_memory_formats
        ), "Number of compressed spatial dimensions and compressed memory formats must match"

        # Create convolution blocks for each compressed spatial dimension
        self.conv_blocks = nn.ModuleList()
        for compressed_spatial_dim, compressed_memory_format in zip(
            compressed_spatial_dims, compressed_memory_formats
        ):
            self.conv_blocks.append(
                _FactorGridConvNormAct(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    compressed_spatial_dim=compressed_spatial_dim,
                    compressed_memory_format=compressed_memory_format,
                    stride=stride,
                    up_stride=up_stride,
                    norm=norm,
                    activation=activation,
                )
            )

        # Intra-communication module
        self.intra_communications = FactorGridIntraCommunication(
            communication_types=communication_types
        )

        # Projection layer if multiple communication types: the communication output
        # has out_channels * len(communication_types) channels.
        if len(communication_types) > 1:
            self.proj = FactorGridProjection(
                in_channels=out_channels * len(communication_types),
                out_channels=out_channels,
                kernel_size=1,
                compressed_spatial_dims=compressed_spatial_dims,
                compressed_memory_formats=compressed_memory_formats,
                stride=1,
            )
        else:
            self.proj = nn.Identity()

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Forward pass through the global convolution module."""
        assert len(factor_grid) == len(
            self.conv_blocks
        ), f"Expected {len(self.conv_blocks)} grids, got {len(factor_grid)}"

        # Apply convolution blocks to each grid
        convolved_grids = []
        for grid, conv_block in zip(factor_grid, self.conv_blocks):
            convolved = conv_block(grid)
            convolved_grids.append(convolved)

        # Apply intra-communication
        factor_grid = self.intra_communications(FactorGrid(convolved_grids))

        # Apply projection if needed
        factor_grid = self.proj(factor_grid)

        return factor_grid

forward(factor_grid: FactorGrid) -> FactorGrid

Forward pass through the global convolution module.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Convolve each grid with its own block, mix them, then apply ``self.proj``."""
    assert len(factor_grid) == len(
        self.conv_blocks
    ), f"Expected {len(self.conv_blocks)} grids, got {len(factor_grid)}"

    # One dedicated block per grid, applied in lockstep.
    per_grid = [block(grid) for grid, block in zip(factor_grid, self.conv_blocks)]

    mixed = self.intra_communications(FactorGrid(per_grid))
    # ``proj`` is an identity when only one communication type is configured.
    return self.proj(mixed)

FactorGridIntraCommunication

Bases: BaseSpatialModule

Intra-communication between grids in a FactorGrid.

This is equivalent to GridFeaturesGroupIntraCommunication but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridIntraCommunication(BaseSpatialModule):
    """Intra-communication between grids in a FactorGrid.

    This is equivalent to GridFeaturesGroupIntraCommunication but works with FactorGrid objects.

    Args:
        communication_types: One or two of ``"sum"`` / ``"mul"``; ``None`` means ``["sum"]``.
    """

    def __init__(self, communication_types: Optional[List[Literal["sum", "mul"]]] = None) -> None:
        super().__init__()
        # A None default avoids sharing one mutable ["sum"] list across all instances.
        if communication_types is None:
            communication_types = ["sum"]
        assert len(communication_types) > 0, "At least one communication type must be provided"
        assert len(communication_types) <= 2, "At most two communication types can be provided"
        # Store a private copy so later mutation of the caller's list cannot leak in.
        self.communication_types = list(communication_types)

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Perform intra-communication between grids in the FactorGrid."""
        return factor_grid_intra_communication(factor_grid, self.communication_types)

forward(factor_grid: FactorGrid) -> FactorGrid

Perform intra-communication between grids in the FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Mix the grids of ``factor_grid`` using the configured communication types."""
    modes = self.communication_types
    return factor_grid_intra_communication(factor_grid, modes)

FactorGridPadToMatch

Bases: BaseSpatialModule

Pad FactorGrid features to match spatial dimensions for UNet skip connections.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridPadToMatch(BaseSpatialModule):
    """Pad FactorGrid features to match spatial dimensions for UNet skip connections.

    Every grid of the "up" factor grid is replicate-padded at the high end of each
    spatial axis until it reaches the spatial size of the corresponding "down" grid.
    Axes that are already at least as large are left alone (no cropping).
    """

    def forward(self, up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid:
        """Pad up_factor_grid to match down_factor_grid spatial dimensions."""
        assert len(up_factor_grid) == len(
            down_factor_grid
        ), "FactorGrids must have same number of grids"

        matched = []
        for up_grid, down_grid in zip(up_factor_grid, down_factor_grid):
            up_tensor = up_grid.grid_features.batched_tensor
            down_tensor = down_grid.grid_features.batched_tensor
            # Spatial extents only (batch and channel axes dropped).
            source_spatial = up_tensor.shape[2:]
            target_spatial = down_tensor.shape[2:]

            if source_spatial == target_spatial:
                matched.append(up_grid)
                continue

            grow_0 = max(0, target_spatial[0] - source_spatial[0])
            grow_1 = max(0, target_spatial[1] - source_spatial[1])
            # F.pad lists the LAST axis first, as (left, right) pairs.
            if len(target_spatial) == 2:
                pad_spec = (0, grow_1, 0, grow_0)
            else:
                grow_2 = max(0, target_spatial[2] - source_spatial[2])
                pad_spec = (0, grow_2, 0, grow_1, 0, grow_0)

            grown = F.pad(up_tensor, pad_spec, mode="replicate")
            matched.append(up_grid.replace(batched_features=grown))

        return FactorGrid(matched)

forward(up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid

Pad up_factor_grid to match down_factor_grid spatial dimensions.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid:
    """Replicate-pad each up grid at the high end of its spatial axes to the down grid's size."""
    assert len(up_factor_grid) == len(
        down_factor_grid
    ), "FactorGrids must have same number of grids"

    matched = []
    for up_grid, down_grid in zip(up_factor_grid, down_factor_grid):
        up_tensor = up_grid.grid_features.batched_tensor
        down_tensor = down_grid.grid_features.batched_tensor
        # Spatial extents only (batch and channel axes dropped).
        source_spatial = up_tensor.shape[2:]
        target_spatial = down_tensor.shape[2:]

        if source_spatial == target_spatial:
            matched.append(up_grid)
            continue

        grow_0 = max(0, target_spatial[0] - source_spatial[0])
        grow_1 = max(0, target_spatial[1] - source_spatial[1])
        # F.pad lists the LAST axis first, as (left, right) pairs.
        if len(target_spatial) == 2:
            pad_spec = (0, grow_1, 0, grow_0)
        else:
            grow_2 = max(0, target_spatial[2] - source_spatial[2])
            pad_spec = (0, grow_2, 0, grow_1, 0, grow_0)

        grown = F.pad(up_tensor, pad_spec, mode="replicate")
        matched.append(up_grid.replace(batched_features=grown))

    return FactorGrid(matched)

FactorGridPool

Bases: BaseSpatialModule

Pooling operation for FactorGrid.

This is equivalent to GridFeatureGroupPool but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridPool(BaseSpatialModule):
    """Reduce every grid of a FactorGrid into a single pooled tensor.

    This is the FactorGrid counterpart of GridFeatureGroupPool.

    Parameters
    ----------
    pooling_type : {"max", "mean", "attention"}, optional
        Reduction handed to ``factor_grid_pool``. ``"max"`` and ``"mean"`` build
        an adaptive 1-D pooling module with output size 1. ``"attention"``
        currently sets both ``pool_op`` and ``attention`` to ``None`` (the
        attention layer is a placeholder that is not built here).

    Raises
    ------
    ValueError
        If ``pooling_type`` is not one of the supported values.
    """

    def __init__(
        self,
        pooling_type: Literal["max", "mean", "attention"] = "max",
    ):
        super().__init__()
        self.pooling_type = pooling_type

        if pooling_type in ("max", "mean"):
            pool_cls = nn.AdaptiveMaxPool1d if pooling_type == "max" else nn.AdaptiveAvgPool1d
            self.pool_op = pool_cls(1)
        elif pooling_type == "attention":
            # Placeholder: the attention layer's size depends on the feature
            # dimensions, which are unknown at construction time.
            self.attention = None
            self.pool_op = None
        else:
            raise ValueError(f"Unsupported pooling type: {pooling_type}")

    def forward(self, factor_grid: FactorGrid) -> Tensor:
        """Pool features from FactorGrid to a single tensor."""
        # ``attention`` is only defined for pooling_type == "attention".
        attention_layer = getattr(self, "attention", None)
        return factor_grid_pool(
            factor_grid,
            self.pooling_type,
            pool_op=self.pool_op,
            attention_layer=attention_layer,
        )

forward(factor_grid: FactorGrid) -> Tensor

Pool features from FactorGrid to a single tensor.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> Tensor:
    """Pool features from FactorGrid to a single tensor.

    Delegates to ``factor_grid_pool`` with this module's pooling type, its
    pooling op, and its ``attention`` layer when one is defined.
    """
    # ``attention`` only exists on modules built with pooling_type == "attention".
    attention_layer = getattr(self, "attention", None)
    return factor_grid_pool(
        factor_grid,
        self.pooling_type,
        pool_op=self.pool_op,
        attention_layer=attention_layer,
    )

FactorGridProjection

Bases: BaseSpatialModule

Projection operation for FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridProjection(BaseSpatialModule):
    """Projection operation for FactorGrid.

    Builds one ``Conv2d -> norm -> activation`` block per grid. Each grid's
    features are convolved in their compressed 2-D layout, where the channel
    axis folds ``channels * compressed_spatial_dim`` together; that is why
    the conv's in/out channel counts are multiplied by
    ``compressed_spatial_dim``.

    Parameters
    ----------
    in_channels : int
        Per-voxel input channels (before folding in the compressed axis).
    out_channels : int
        Per-voxel output channels (before folding in the compressed axis).
    kernel_size : int
        Conv2d kernel size; padding is ``(kernel_size - 1) // 2``.
    compressed_spatial_dims : tuple of int
        Size of the compressed spatial axis of each grid.
    compressed_memory_formats : tuple of GridMemoryFormat
        Memory format of each grid (only used to pair with
        ``compressed_spatial_dims``).
    stride : int, optional
        Conv2d stride. Defaults to ``1``.
    norm : type of nn.Module, optional
        Normalization class, built with ``out_channels * compressed_spatial_dim``.
    activation : type of nn.Module, optional
        Activation class, built with no arguments.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        compressed_spatial_dims: Tuple[int, ...],
        compressed_memory_formats: Tuple[GridMemoryFormat, ...],
        stride: int = 1,
        norm: Type[nn.Module] = nn.BatchNorm2d,
        activation: Type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        for compressed_spatial_dim, compressed_memory_format in zip(
            compressed_spatial_dims, compressed_memory_formats
        ):
            block = nn.Sequential(
                nn.Conv2d(
                    in_channels * compressed_spatial_dim,
                    out_channels * compressed_spatial_dim,
                    kernel_size,
                    stride,
                    (kernel_size - 1) // 2,
                    bias=True,
                ),
                norm(out_channels * compressed_spatial_dim),
                activation(),
            )
            self.convs.append(block)

    def forward(self, grid: FactorGrid) -> FactorGrid:
        """Apply each per-grid projection block to the matching grid's features.

        Raises
        ------
        ValueError
            If ``grid`` does not contain exactly one grid per projection block.
        """
        if len(grid) != len(self.convs):
            raise ValueError(
                f"Expected {len(self.convs)} grids in FactorGrid, got {len(grid)}"
            )

        projected_grids = []
        for sub_grid, conv in zip(grid, self.convs):
            # nn.Conv2d works on tensors, not Grid objects: convolve the batched
            # feature tensor and wrap the result back into a grid.
            # NOTE(review): with stride != 1 the spatial shape changes; confirm
            # Grid.replace accepts features of a different resolution.
            features = sub_grid.grid_features.batched_tensor
            projected_grids.append(sub_grid.replace(batched_features=conv(features)))
        return FactorGrid(projected_grids)

FactorGridToPoint

Bases: BaseSpatialModule

Convert FactorGrid features back to point features using trilinear interpolation.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridToPoint(BaseSpatialModule):
    """Convert FactorGrid features back to point features using trilinear interpolation.

    For every point, features are sampled from each grid of the FactorGrid
    with ``F.grid_sample``. They are concatenated with the point's own
    features and, optionally, its position relative to the centre of the
    first grid's bounds. The result is mixed by a two-layer MLP.

    Parameters
    ----------
    grid_in_channels : int
        Channels of each grid once converted to ``b_c_x_y_z`` layout.
    point_in_channels : int
        Channels of the input point features.
    num_grids : int
        Number of grids in the FactorGrid; sizes the MLP input.
    out_channels : int
        Channels of the output point features.
    use_rel_pos : bool, optional
        Append the raw relative position (3 channels) when
        ``use_rel_pos_embed`` is False.
    use_rel_pos_embed : bool, optional
        Append a sinusoidal encoding of the relative position
        (``pos_embed_dim * 3`` channels); takes precedence over ``use_rel_pos``.
    pos_embed_dim : int, optional
        Encoding size per coordinate for ``use_rel_pos_embed``.
    sample_method, neighbor_search_type, knn_k, reductions
        Stored on the module; not used by the interpolation path in this class.
    """

    def __init__(
        self,
        grid_in_channels: int,
        point_in_channels: int,
        num_grids: int,
        out_channels: int,
        use_rel_pos: bool = True,
        use_rel_pos_embed: bool = False,
        pos_embed_dim: int = 32,
        sample_method: Literal["graphconv", "interp"] = "interp",
        neighbor_search_type: Literal["radius", "knn"] = "radius",
        knn_k: int = 16,
        reductions: List[str] = ["mean"],
    ):
        super().__init__()
        self.grid_in_channels = grid_in_channels
        self.point_in_channels = point_in_channels
        self.out_channels = out_channels
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_embed = use_rel_pos_embed
        self.pos_embed_dim = pos_embed_dim
        self.sample_method = sample_method
        self.neighbor_search_type = neighbor_search_type
        self.knn_k = knn_k
        self.reductions = reductions
        self.freqs = get_freqs(pos_embed_dim)

        # MLP input = sampled features of every grid + point features
        # (+ optional relative-position channels).
        combined_channels = grid_in_channels * num_grids + point_in_channels
        if use_rel_pos_embed:
            combined_channels += pos_embed_dim * 3
        elif use_rel_pos:
            combined_channels += 3

        self.combine_mlp = nn.Sequential(
            nn.Linear(combined_channels, out_channels),
            nn.LayerNorm(out_channels),
            nn.GELU(),
            nn.Linear(out_channels, out_channels),
        )

    def _normalize_coordinates(self, coords: Tensor, grid: Grid) -> Tensor:
        """Map ``coords`` into ``[-1, 1]`` using ``grid.bounds``, clamping outliers."""
        bounds_min, bounds_max = grid.bounds

        # [bounds_min, bounds_max] -> [0, 1]
        normalized = (coords - bounds_min) / (bounds_max - bounds_min)

        # [0, 1] -> [-1, 1], the range grid_sample expects.
        normalized = 2.0 * normalized - 1.0

        # Points outside the bounds are clamped to the border.
        normalized = torch.clamp(normalized, -1.0, 1.0)

        return normalized

    def _sample_from_grid(self, grid: Grid, point_coords: Tensor) -> Tensor:
        """Trilinearly sample ``grid`` at ``point_coords``; returns a ``(B * N, C)`` tensor."""
        grid_features_tensor = grid.grid_features.batched_tensor
        batch_size = grid_features_tensor.shape[0]

        normalized_coords = self._normalize_coordinates(point_coords, grid)

        # (B*N, 3) -> (B, N, 1, 1, 3): grid_sample's 5-D sampling-grid layout.
        # NOTE(review): this assumes every batch item holds the same number of
        # points; variable-length batches would be mis-assigned or fail to
        # reshape — confirm with callers.
        normalized_coords = normalized_coords.view(batch_size, -1, 1, 1, 3)

        # grid_sample reads the last axis of the sampling grid as indices into
        # (W, H, D), i.e. the LAST spatial axis first. The input below is laid
        # out as (B, C, X, Y, Z), so (x, y, z) coordinates must be reversed to
        # (z, y, x) for each coordinate to address its own axis.
        sample_locations = normalized_coords.flip(-1)

        grid_reshaped = grid.grid_features.to_memory_format(GridMemoryFormat.b_c_x_y_z)

        sampled_features = F.grid_sample(
            grid_reshaped.batched_tensor,
            sample_locations,
            mode="bilinear",  # trilinear for 5-D inputs
            padding_mode="border",
            align_corners=True,
        )  # Shape: B, C, N, 1, 1

        sampled_features = sampled_features.squeeze(-1).squeeze(-1).transpose(1, 2)  # B, N, C
        return sampled_features.flatten(start_dim=0, end_dim=1)  # B*N, C

    def forward(self, factor_grid: FactorGrid, point_features: Points) -> Points:
        """Convert FactorGrid features to point features using trilinear interpolation.

        Returns a new ``Points`` with the input coordinates and
        ``out_channels``-wide features produced by ``combine_mlp``.
        """
        point_coords = point_features.coordinate_tensor
        point_feats = point_features.feature_tensor

        # One (B*N, C) tensor per grid, concatenated along channels.
        all_grid_features = [self._sample_from_grid(grid, point_coords) for grid in factor_grid]
        if len(all_grid_features) > 1:
            grid_feat_per_point = torch.cat(all_grid_features, dim=-1)
        else:
            grid_feat_per_point = all_grid_features[0]

        features_to_combine = [point_feats, grid_feat_per_point]
        if self.use_rel_pos_embed or self.use_rel_pos:
            # Relative position w.r.t. the centre of the FIRST grid's bounds.
            bounds_min, bounds_max = factor_grid[0].bounds
            rel_pos = point_coords - ((bounds_max + bounds_min) / 2.0)
            if self.use_rel_pos_embed:
                features_to_combine.append(
                    sinusoidal_encoding(rel_pos, self.pos_embed_dim, data_range=2)
                )
            else:
                features_to_combine.append(rel_pos)
        combined_features = torch.cat(features_to_combine, dim=-1)

        output_features = self.combine_mlp(combined_features)

        return Points(
            batched_coordinates=point_features.batched_coordinates,
            batched_features=output_features,
        )

forward(factor_grid: FactorGrid, point_features: Points) -> Points

Convert FactorGrid features to point features using trilinear interpolation.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid, point_features: Points) -> Points:
    """Convert FactorGrid features to point features using trilinear interpolation.

    Samples every grid at the point coordinates, concatenates the result with
    the point features (plus an optional relative-position term), and mixes
    them with ``combine_mlp``. Returns a new ``Points`` with the input
    coordinates and the mixed features.
    """
    point_coords = point_features.coordinate_tensor
    point_feats = point_features.feature_tensor

    # One (B*N, C) tensor per grid, concatenated along channels.
    all_grid_features = [self._sample_from_grid(grid, point_coords) for grid in factor_grid]
    if len(all_grid_features) > 1:
        grid_feat_per_point = torch.cat(all_grid_features, dim=-1)
    else:
        grid_feat_per_point = all_grid_features[0]

    features_to_combine = [point_feats, grid_feat_per_point]
    if self.use_rel_pos_embed or self.use_rel_pos:
        # Relative position w.r.t. the centre of the FIRST grid's bounds.
        bounds_min, bounds_max = factor_grid[0].bounds
        rel_pos = point_coords - ((bounds_max + bounds_min) / 2.0)
        if self.use_rel_pos_embed:
            features_to_combine.append(
                sinusoidal_encoding(rel_pos, self.pos_embed_dim, data_range=2)
            )
        else:
            features_to_combine.append(rel_pos)
    combined_features = torch.cat(features_to_combine, dim=-1)

    output_features = self.combine_mlp(combined_features)

    return Points(
        batched_coordinates=point_features.batched_coordinates,
        batched_features=output_features,
    )

FactorGridTransform

Bases: BaseSpatialModule

Apply a transform operation to all grids in a FactorGrid.

This is equivalent to GridFeatureGroupTransform but works with FactorGrid objects.

Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridTransform(BaseSpatialModule):
    """Run one transform module over every grid of a FactorGrid.

    FactorGrid analogue of GridFeatureGroupTransform.

    Parameters
    ----------
    transform : nn.Module
        Module applied to each grid through ``factor_grid_transform``.
    in_place : bool, optional
        Forwarded unchanged to ``factor_grid_transform``. Defaults to ``True``.
    """

    def __init__(self, transform: nn.Module, in_place: bool = True) -> None:
        super().__init__()
        self.in_place = in_place
        self.transform = transform

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Return ``factor_grid`` with ``self.transform`` applied to each grid."""
        transform, in_place = self.transform, self.in_place
        return factor_grid_transform(factor_grid, transform, in_place)

forward(factor_grid: FactorGrid) -> FactorGrid

Apply transform to all grids in the FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Return ``factor_grid`` with ``self.transform`` applied to each of its grids.

    Delegates to ``factor_grid_transform``, forwarding ``self.in_place``.
    """
    transform, in_place = self.transform, self.in_place
    return factor_grid_transform(factor_grid, transform, in_place)

PointToFactorGrid

Bases: BaseSpatialModule

Convert point features to FactorGrid representation.

Source code in warpconvnet/nn/modules/factor_grid.py
class PointToFactorGrid(BaseSpatialModule):
    """Convert point features to FactorGrid representation.

    Parameters
    ----------
    in_channels, out_channels : int
        Sizes of ``self.projections`` (not applied by ``forward``).
    grid_shapes : list of (int, int, int)
        Per-grid resolution, ordered (X, Y, Z).
    memory_formats : list of GridMemoryFormat
        Per-grid compressed layout: ``b_xc_y_z``, ``b_yc_x_z`` or ``b_zc_x_y``.
    aabb_max, aabb_min : (float, float, float)
        Bounding box handed to ``points_to_factor_grid`` as ``bounds``.
    use_rel_pos, use_rel_pos_embed, pos_encode_dim
        Stored on the module; not used by ``forward``.
    search_radius, k, search_type, reduction
        Forwarded to ``points_to_factor_grid``.

    Raises
    ------
    ValueError
        If a memory format is not one of the supported compressed layouts.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_shapes: List[Tuple[int, int, int]],
        memory_formats: List[GridMemoryFormat],
        aabb_max: Tuple[float, float, float],
        aabb_min: Tuple[float, float, float],
        use_rel_pos: bool = True,
        use_rel_pos_embed: bool = True,
        pos_encode_dim: int = 32,
        search_radius: Optional[float] = None,
        k: int = 8,
        search_type: Literal["radius", "knn", "voxel"] = "radius",
        reduction: str = "mean",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_shapes = grid_shapes
        self.memory_formats = memory_formats
        self.aabb_max = aabb_max
        self.aabb_min = aabb_min
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_embed = use_rel_pos_embed
        self.pos_encode_dim = pos_encode_dim
        self.search_radius = search_radius
        self.k = k
        self.search_type = search_type
        self.reduction = reduction

        # Size of the spatial axis each memory format folds into its channel
        # dimension. grid_shape is ordered (X, Y, Z).
        self.compressed_spatial_dims = []
        for grid_shape, memory_format in zip(grid_shapes, memory_formats):
            if memory_format == GridMemoryFormat.b_zc_x_y:
                compressed_dim = grid_shape[2]  # Z dimension
            elif memory_format == GridMemoryFormat.b_xc_y_z:
                compressed_dim = grid_shape[0]  # X dimension
            elif memory_format == GridMemoryFormat.b_yc_x_z:
                compressed_dim = grid_shape[1]  # Y dimension
            else:
                raise ValueError(f"Unsupported memory format: {memory_format}")

            self.compressed_spatial_dims.append(compressed_dim)

        # One projection per grid mapping in_channels -> out_channels * compressed_dim.
        # NOTE(review): these layers are not applied in forward(), so their
        # parameters receive no gradients; confirm whether they are still needed.
        self.projections = nn.ModuleList()
        for compressed_dim in self.compressed_spatial_dims:
            proj = nn.Linear(in_channels, out_channels * compressed_dim)
            self.projections.append(proj)

    def forward(self, points: Points) -> FactorGrid:
        """Convert point features to FactorGrid."""
        # Imported lazily, presumably to avoid an import cycle — TODO confirm.
        from warpconvnet.geometry.types.conversion.to_factor_grid import points_to_factor_grid

        # NOTE(review): the bounds tensors are created on the default device on
        # every call; confirm points_to_factor_grid moves them to the points' device.
        factor_grid = points_to_factor_grid(
            points=points,
            grid_shapes=self.grid_shapes,
            memory_formats=self.memory_formats,
            bounds=(torch.tensor(self.aabb_min), torch.tensor(self.aabb_max)),
            search_radius=self.search_radius,
            k=self.k,
            search_type=self.search_type,
            reduction=self.reduction,
        )

        # Projections are intentionally not applied here; they can be applied
        # later in the pipeline if needed.
        return factor_grid

forward(points: Points) -> FactorGrid

Convert point features to FactorGrid.

Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, points: Points) -> FactorGrid:
    """Convert point features to FactorGrid.

    Delegates to ``points_to_factor_grid`` with this module's grid shapes,
    memory formats, bounding box and neighbor-search settings. The stored
    projections are not applied.
    """
    from warpconvnet.geometry.types.conversion.to_factor_grid import points_to_factor_grid

    # Bounding box as (min, max) tensors.
    bounds = (torch.tensor(self.aabb_min), torch.tensor(self.aabb_max))
    return points_to_factor_grid(
        points=points,
        grid_shapes=self.grid_shapes,
        memory_formats=self.memory_formats,
        bounds=bounds,
        search_radius=self.search_radius,
        k=self.k,
        search_type=self.search_type,
        reduction=self.reduction,
    )

Grid convolution

warpconvnet.nn.modules.grid_conv

GridConv

Bases: BaseSpatialModule

Convolutional layer for warpconvnet.geometry.types.grid.Grid data.

Parameters mirror those of torch.nn.Conv3d but operate on a Grid object instead of plain tensors.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Stride of the convolution. Defaults to 1.

  • padding (int or tuple of int, default: 0 ) –

    Zero-padding added to both ends of each spatial dimension of the input (all six sides of a 3D grid), as in torch.nn.Conv3d. Defaults to 0.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True, adds a learnable bias to the output. Defaults to True.

  • num_spatial_dims (int, default: 3 ) –

    Number of spatial dimensions. Defaults to 3.

Source code in warpconvnet/nn/modules/grid_conv.py
class GridConv(BaseSpatialModule):
    """Convolution over `warpconvnet.geometry.types.grid.Grid` data.

    Takes the same arguments as `torch.nn.Conv3d` but consumes and produces
    ``Grid`` objects; the computation itself is delegated to ``grid_conv``.

    Parameters
    ----------
    in_channels : int
        Channels of the input grid features.
    out_channels : int
        Channels of the output grid features.
    kernel_size : int or tuple of int
        Kernel size; an int is expanded to ``num_spatial_dims`` entries.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    padding : int or tuple of int, optional
        Zero-padding added to both ends of each spatial dimension, as in
        `torch.nn.Conv3d`. Defaults to ``0``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        Whether to learn a per-output-channel bias. Defaults to ``True``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 0,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        num_spatial_dims: Optional[int] = 3,
    ):
        super().__init__()
        # Expand each hyper-parameter to one value per spatial dimension.
        kernel_size, stride, padding, dilation = (
            ntuple(value, ndim=num_spatial_dims)
            for value in (kernel_size, stride, padding, dilation)
        )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.num_spatial_dims = num_spatial_dims

        # Weight layout: (out_channels, in_channels, *kernel_size).
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_size))

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.randn(out_channels))

        self.reset_parameters()

    def __repr__(self):
        fields = [
            f"in_channels={self.in_channels}",
            f"out_channels={self.out_channels}",
            f"kernel_size={self.kernel_size}",
            f"stride={self.stride}",
            f"padding={self.padding}",
            f"dilation={self.dilation}",
            f"bias={self.bias is not None}",
        ]
        return f"{type(self).__name__}({', '.join(fields)})"

    def reset_parameters(self):
        # Kaiming-uniform weights (a=1); bias drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        init.kaiming_uniform_(self.weight, a=1)
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / (fan_in**0.5)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, input_grid: Grid) -> Grid:
        conv_kwargs = dict(
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=self.bias,
        )
        return grid_conv(grid=input_grid, weight=self.weight, **conv_kwargs)

MLP

warpconvnet.nn.modules.mlp

BatchedLinear

Bases: Module

A linear layer with batched weights for Muon-friendly optimization.

Instead of a single weight matrix [in_features, out_features * num_matrices], this uses separate weight matrices stacked as [num_matrices, in_features, out_features]. This structure is more suitable for Muon optimization as it can orthogonalize each [in_features, out_features] matrix independently.

Args:
  • in_features – Input feature dimension.
  • out_features – Output feature dimension per matrix.
  • num_matrices – Number of separate matrices (e.g., 3 for Q, K, V).
  • bias – Whether to use bias parameters.

Source code in warpconvnet/nn/modules/mlp.py
class BatchedLinear(nn.Module):
    """
    A linear layer with batched weights for Muon-friendly optimization.

    Instead of a single weight matrix [in_features, out_features * num_matrices],
    this uses separate weight matrices stacked as [num_matrices, in_features, out_features].
    This structure is more suitable for Muon optimization as it can orthogonalize
    each [in_features, out_features] matrix independently.

    Weights are initialized with a Xavier/Glorot-uniform bound computed per
    [in_features, out_features] matrix: U(-b, b) with
    b = sqrt(6 / (in_features + out_features)).

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension per matrix
        num_matrices: Number of separate matrices (e.g., 3 for Q, K, V)
        bias: Whether to use bias parameters
    """

    def __init__(
        self, in_features: int, out_features: int, num_matrices: int = 3, bias: bool = True
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_matrices = num_matrices

        # Create batched weight: [num_matrices, in_features, out_features]
        self.weight = nn.Parameter(torch.empty(num_matrices, in_features, out_features))
        # nn.init.xavier_uniform_ on this 3-D tensor would treat dim 2 as a
        # receptive field (fan_in = in*out, fan_out = num_matrices*out), shrinking
        # the bound by roughly sqrt(out_features). Use per-matrix fans instead.
        bound = (6.0 / (in_features + out_features)) ** 0.5
        nn.init.uniform_(self.weight, -bound, bound)

        if bias:
            # Use flat bias for Muon - 1D parameter gets Adam optimization
            self.bias = nn.Parameter(torch.zeros(num_matrices * out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass with batched matrix multiplication.

        Args:
            input: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., num_matrices, out_features]
        """
        # input: [..., in_features], weight: [num_matrices, in_features, out_features]
        # output: [..., num_matrices, out_features]
        output = torch.einsum("...i,kio->...ko", input, self.weight)

        if self.bias is not None:
            # Flat bias reshaped to broadcast over [..., num_matrices, out_features].
            output += self.bias.view(self.num_matrices, self.out_features)

        output = output.to(input.dtype)
        return output

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"num_matrices={self.num_matrices}, bias={self.bias is not None}"
        )

forward(input: Tensor) -> Tensor

Forward pass with batched matrix multiplication.

Args:
  • input – Input tensor of shape [..., in_features].

Returns:
  • Output tensor of shape [..., num_matrices, out_features].

Source code in warpconvnet/nn/modules/mlp.py
def forward(self, input: Tensor) -> Tensor:
    """
    Multiply ``input`` by each stacked weight matrix and add the optional bias.

    Args:
        input: Tensor of shape [..., in_features]

    Returns:
        Tensor of shape [..., num_matrices, out_features], cast to ``input.dtype``.
    """
    # [..., in] x [k, in, out] -> [..., k, out]
    result = torch.einsum("...i,kio->...ko", input, self.weight)
    bias = self.bias
    if bias is not None:
        # The flat (k * out) bias is viewed as (k, out) to broadcast over the result.
        result += bias.view(self.num_matrices, self.out_features)
    return result.to(input.dtype)

Linear

Bases: BaseSpatialModule

Apply a linear layer to Geometry features.

Parameters:
  • in_features (int) –

    Number of input features.

  • out_features (int) –

    Number of output features.

  • bias (bool, default: True ) –

    If True adds a bias term to the layer. Defaults to True.

Source code in warpconvnet/nn/modules/mlp.py
class Linear(BaseSpatialModule):
    """Apply a linear layer to ``Geometry`` features.

    A plain ``Tensor`` is also accepted and transformed directly, matching
    the behavior of ``MLPBlock``.

    Parameters
    ----------
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    bias : bool, optional
        If ``True`` adds a bias term to the layer. Defaults to ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.block = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: Union[Tensor, Geometry]):
        # Raw feature tensors bypass the Geometry wrapper.
        if isinstance(x, Tensor):
            return self.block(x)
        # Geometry: transform features, keep everything else via replace().
        return x.replace(batched_features=self.block(x.feature_tensor))

LinearNormActivation

Bases: BaseSpatialModule

Linear layer followed by LayerNorm and ReLU.

Parameters:
  • in_features (int) –

    Number of input features.

  • out_features (int) –

    Number of output features.

  • bias (bool, default: True ) –

    Whether to include a bias term. Defaults to True.

Source code in warpconvnet/nn/modules/mlp.py
class LinearNormActivation(BaseSpatialModule):
    """Linear layer followed by ``LayerNorm`` and ``ReLU``.

    Accepts either a ``Geometry`` (its features are transformed) or a plain
    ``Tensor`` (transformed directly), matching the behavior of ``MLPBlock``.

    Parameters
    ----------
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    bias : bool, optional
        Whether to include a bias term. Defaults to ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_features, out_features, bias=bias),
            nn.LayerNorm(out_features),
            nn.ReLU(),
        )

    def forward(self, x: Union[Tensor, Geometry]):
        # Raw feature tensors bypass the Geometry wrapper.
        if isinstance(x, Tensor):
            return self.block(x)
        # Geometry: transform features, keep everything else via replace().
        return x.replace(batched_features=self.block(x.feature_tensor))

MLPBlock

Bases: BaseSpatialModule

MLP block with a residual connection.

Parameters:
  • in_channels (int) –

    Number of input features.

  • out_channels (int, default: None ) –

    Number of output features. Defaults to in_channels.

  • hidden_channels (int, default: None ) –

    Hidden layer size. Defaults to in_channels.

  • activation (nn.Module class, default: nn.ReLU ) –

    Activation class; it is instantiated with no arguments and applied after the hidden LayerNorm. Defaults to torch.nn.ReLU.

  • bias (bool, default: True ) –

    If True adds bias terms to the linear layers. Defaults to True.

Source code in warpconvnet/nn/modules/mlp.py
class MLPBlock(BaseSpatialModule):
    """Residual two-layer MLP.

    Computes ``LayerNorm(Linear(act(LayerNorm(Linear(x))))) + shortcut(x)``.
    The shortcut is a ``Linear`` projection when the channel count changes
    and an identity otherwise.

    Parameters
    ----------
    in_channels : int
        Number of input features.
    out_channels : int, optional
        Number of output features. Defaults to ``in_channels``.
    hidden_channels : int, optional
        Hidden layer size. Defaults to ``in_channels``.
    activation : ``nn.Module`` class, optional
        Activation class, instantiated with no arguments. Defaults to `torch.nn.ReLU`.
    bias : bool, optional
        If ``True`` adds bias terms to the linear layers. Defaults to ``True``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        activation=nn.ReLU,
        bias: bool = True,
    ):
        super().__init__()
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.block = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=bias),
            nn.LayerNorm(hidden_channels),
            activation(),
            nn.Linear(hidden_channels, out_channels, bias=bias),
            nn.LayerNorm(out_channels),
        )
        if in_channels != out_channels:
            self.shortcut = nn.Linear(in_channels, out_channels, bias=bias)
        else:
            self.shortcut = nn.Identity()

    def _forward_feature(self, x: Tensor) -> Tensor:
        # Residual sum of the MLP branch and the (possibly projected) input.
        return self.block(x) + self.shortcut(x)

    def forward(self, x: Union[Tensor, Geometry]):
        if not isinstance(x, Geometry):
            return self._forward_feature(x)
        return x.replace(batched_features=self._forward_feature(x.feature_tensor))

Normalizations

warpconvnet.nn.modules.normalizations

BatchNorm

Bases: NormalizationBase

Applies torch.nn.BatchNorm1d to Geometry features.

Parameters:
  • num_features (int) –

    Number of feature channels in the input.

  • eps (float, default: 1e-05 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-5.

  • momentum (float, default: 0.1 ) –

    Momentum factor for the running statistics. Defaults to 0.1.

Source code in warpconvnet/nn/modules/normalizations.py
class BatchNorm(NormalizationBase):
    """Normalize ``Geometry`` features with `torch.nn.BatchNorm1d`.

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input.
    eps : float, optional
        Added to the variance for numerical stability. Defaults to ``1e-5``.
    momentum : float, optional
        Momentum of the running-statistics update. Defaults to ``0.1``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        norm = nn.BatchNorm1d(num_features, eps=eps, momentum=momentum)
        super().__init__(norm)

GroupNorm

Bases: NormalizationBase

Applies torch.nn.GroupNorm to Geometry features.

Parameters:
  • num_groups (int) –

    Number of groups to separate the channels into.

  • num_channels (int) –

    Number of channels expected in the input.

  • eps (float, default: 1e-05 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-5.

Source code in warpconvnet/nn/modules/normalizations.py
class GroupNorm(NormalizationBase):
    """Normalize ``Geometry`` features with `torch.nn.GroupNorm`.

    Parameters
    ----------
    num_groups : int
        Number of groups the channels are split into.
    num_channels : int
        Number of channels expected in the input.
    eps : float, optional
        Added to the variance for numerical stability. Defaults to ``1e-5``.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        norm = nn.GroupNorm(num_groups, num_channels, eps=eps)
        super().__init__(norm)

InstanceNorm

Bases: NormalizationBase

Applies torch.nn.InstanceNorm1d to Geometry features.

Parameters:
  • num_features (int) –

    Number of feature channels in the input.

  • eps (float, default: 1e-05 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-5.

Source code in warpconvnet/nn/modules/normalizations.py
class InstanceNorm(NormalizationBase):
    """Instance normalization of ``Geometry`` features backed by `torch.nn.InstanceNorm1d`.

    The configured ``InstanceNorm1d`` module is handed to `NormalizationBase`, which
    applies it through ``apply_feature_transform``.

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input.
    eps : float, optional
        Value added to the denominator for numerical stability. Defaults to ``1e-5``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        # NOTE(review): nn.InstanceNorm1d interprets a 2-D input as unbatched (C, L).
        # If apply_feature_transform passes a 2-D (points, channels) feature tensor
        # straight through, statistics are taken along the last axis of each row —
        # confirm this is the intended normalization axis.
        instance_norm = nn.InstanceNorm1d(num_features, eps=eps)
        super().__init__(instance_norm)

LayerNorm

Bases: NormalizationBase

Applies torch.nn.LayerNorm to Geometry features.

Parameters:
  • normalized_shape (list of int) –

    Input shape from an expected input.

  • eps (float, default: 1e-05 ) –

    A value added to the denominator for numerical stability. Defaults to 1e-5.

  • elementwise_affine (bool, default: True ) –

    Whether to learn elementwise affine parameters. Defaults to True.

  • bias (bool, default: True ) –

    If True adds bias parameters. Defaults to True.

Source code in warpconvnet/nn/modules/normalizations.py
class LayerNorm(NormalizationBase):
    """Layer normalization of ``Geometry`` features backed by `torch.nn.LayerNorm`.

    The configured ``LayerNorm`` module is handed to `NormalizationBase`, which
    applies it through ``apply_feature_transform``.

    Parameters
    ----------
    normalized_shape : list of int
        Input shape from an expected input.
    eps : float, optional
        A value added to the denominator for numerical stability. Defaults to ``1e-5``.
    elementwise_affine : bool, optional
        Whether to learn elementwise affine parameters. Defaults to ``True``.
    bias : bool, optional
        If ``True`` adds bias parameters. Defaults to ``True``.
    """

    def __init__(
        self,
        normalized_shape: List[int],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        layer_norm = nn.LayerNorm(
            normalized_shape, eps=eps, elementwise_affine=elementwise_affine, bias=bias
        )
        super().__init__(layer_norm)

NormalizationBase

Bases: BaseSpatialModule

Wrapper for applying a normalization module to Geometry features.

Parameters:
  • norm (nn.Module) –

    Normalization module to apply to the feature tensor.

Source code in warpconvnet/nn/modules/normalizations.py
class NormalizationBase(BaseSpatialModule):
    """Adapter that runs an ordinary normalization module on ``Geometry`` features.

    Subclasses only construct the concrete ``torch.nn`` normalization layer; this
    base class stores it as ``self.norm`` and applies it in `forward` through
    ``apply_feature_transform``.

    Parameters
    ----------
    norm : ``nn.Module``
        Normalization module to apply to the feature tensor.
    """

    def __init__(self, norm: nn.Module):
        super().__init__()
        self.norm = norm

    def __repr__(self):
        # Show the wrapper name around the wrapped module's own repr.
        wrapped = self.norm
        return f"{type(self).__name__}({wrapped})"

    def forward(
        self,
        input: Union[Geometry, Tensor],
    ):
        # All geometry/tensor dispatch lives in the shared helper.
        transform = self.norm
        return apply_feature_transform(input, transform)

RMSNorm

Bases: NormalizationBase

Applies torch.nn.RMSNorm to Geometry features.

Parameters:
  • dim (int) –

    Number of input features.

  • eps (float, default: 1e-06 ) –

    Value added to the denominator for numerical stability. Defaults to 1e-6.

Source code in warpconvnet/nn/modules/normalizations.py
class RMSNorm(NormalizationBase):
    """RMS normalization of ``Geometry`` features backed by `torch.nn.RMSNorm`.

    The configured ``RMSNorm`` module is handed to `NormalizationBase`, which
    applies it through ``apply_feature_transform``.

    Parameters
    ----------
    dim : int
        Number of input features.
    eps : float, optional
        Value added to the denominator for numerical stability. Defaults to ``1e-6``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        rms_norm = nn.RMSNorm(dim, eps=eps)
        super().__init__(rms_norm)

SegmentedLayerNorm

Bases: torch.nn.LayerNorm

Layer normalization that respects variable-length segments.

Parameters:
  • channels (int) –

    Number of feature channels.

  • eps (float, default: 1e-05 ) –

    A value added to the denominator for numerical stability. Defaults to 1e-5.

  • elementwise_affine (bool, default: True ) –

    If True learn per-channel affine parameters. Defaults to True.

  • bias (bool, default: True ) –

    Whether to include a bias term. Defaults to True.

Source code in warpconvnet/nn/modules/normalizations.py
class SegmentedLayerNorm(nn.LayerNorm):
    """Layer normalization that respects variable-length segments.

    The learnable parameters come from `torch.nn.LayerNorm` over ``[channels]``.
    The normalization itself is delegated to ``segmented_layer_norm``, which
    receives the feature tensor together with the geometry's ``offsets``.

    Parameters
    ----------
    channels : int
        Number of feature channels.
    eps : float, optional
        A value added to the denominator for numerical stability. Defaults to ``1e-5``.
    elementwise_affine : bool, optional
        If ``True`` learn per-channel affine parameters. Defaults to ``True``.
    bias : bool, optional
        Whether to include a bias term. Defaults to ``True``.
    """

    def __init__(
        self,
        channels: int,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        super().__init__([channels], eps=eps, elementwise_affine=elementwise_affine, bias=bias)

    def forward(self, x: Geometry):
        """Normalize the features of ``x`` and return ``x.replace(...)`` with them.

        Raises
        ------
        AssertionError
            If ``x`` is not a ``Geometry`` (plain tensors carry no ``offsets``).
        """
        # Only works for geometry with batched features
        assert isinstance(
            x, Geometry
        ), f"SegmentedLayerNorm only works for Geometry, got {type(x)}"
        out_feature = segmented_layer_norm(
            x.feature_tensor,
            x.offsets,
            # nn.LayerNorm leaves weight/bias as None when elementwise_affine is
            # False (and bias as None when bias=False), so these may be None.
            gamma=self.weight if self.elementwise_affine else None,
            beta=self.bias if self.elementwise_affine else None,
            eps=self.eps,
        )
        # ``replace`` is called with ``batched_features`` (plural) elsewhere in the
        # package; the former singular ``batched_feature`` did not match it.
        return x.replace(batched_features=out_feature)

    def __repr__(self):
        # ``self.weight`` is None when elementwise_affine=False, so fall back to the
        # configured normalized shape instead of crashing.
        shape = self.weight.shape if self.weight is not None else tuple(self.normalized_shape)
        return (
            f"{self.__class__.__name__}({shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine})"
        )

Point convolution

warpconvnet.nn.modules.point_conv

PointConv

Bases: BaseSpatialModule

Point convolution operating on warpconvnet.geometry.types.points.Points.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • neighbor_search_args (warpconvnet.geometry.coords.search.search_configs.RealSearchConfig) –

    Configuration for neighbor search.

  • pooling_reduction (warpconvnet.ops.reductions.REDUCTIONS, default: None ) –

    Reduction used when downsampling points. Required when out_point_type is "downsample".

  • pooling_voxel_size (float, default: None ) –

    Size of voxels used for downsampling when out_point_type is "downsample".

  • edge_transform_mlp (Module, default: None ) –

    MLP applied to constructed edge features.

  • out_transform_mlp (Module, default: None ) –

    MLP applied after neighborhood reduction.

  • mlp_block (Module, default: MLPBlock ) –

    Module used to build edge_transform_mlp and out_transform_mlp when not provided. Defaults to MLPBlock.

  • hidden_dim (int, default: None ) –

    Hidden dimension of the automatically created MLPs.

  • channel_multiplier (int, default: 2 ) –

    Multiplier used when hidden_dim is None. Defaults to 2.

  • use_rel_pos (bool, default: False ) –

    Include relative coordinates of neighbors as part of the edge features.

  • use_rel_pos_encode (bool, default: False ) –

    Include sinusoidal encoding of relative coordinates in the edge features.

  • pos_encode_dim (int, default: 32 ) –

    Dimension of the positional encoding. Defaults to 32.

  • pos_encode_range (float, default: 4 ) –

    Range of the positional encoding. Defaults to 4.

  • reductions (list of str, default: ('mean',) ) –

    Reductions applied over the neighbor dimension. Defaults to ("mean",).

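  • out_point_type ((provided, downsample, same), default: "same" ) –

    Determines the coordinate set on which output features are computed: "same" reuses the input points, "downsample" voxel-downsamples them, and "provided" uses the query points passed to forward. Defaults to "same".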

  • provided_in_channels (int, default: None ) –

    Number of channels of the provided query points when out_point_type is "provided".

  • bias (bool, default: True ) –

    Whether linear layers contain biases. Defaults to True.

Source code in warpconvnet/nn/modules/point_conv.py
class PointConv(BaseSpatialModule):
    """Point convolution operating on `warpconvnet.geometry.types.points.Points`.

    For every output (query) point, the features of its neighbors in the input
    cloud are concatenated with the query point's own features and, optionally,
    with the neighbor-minus-query relative coordinates (raw or sinusoidally
    encoded). These edge features go through ``edge_transform_mlp``, are reduced
    over each neighborhood with every entry of ``reductions``, concatenated along
    the channel dimension and passed through ``out_transform_mlp``.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    neighbor_search_args : `warpconvnet.geometry.coords.search.search_configs.RealSearchConfig`
        Configuration for neighbor search.
    pooling_reduction : `warpconvnet.ops.reductions.REDUCTIONS`, optional
        Reduction used when downsampling points. Required when ``out_point_type`` is
        ``"downsample"``.
    pooling_voxel_size : float, optional
        Size of voxels used for downsampling when ``out_point_type`` is ``"downsample"``.
        Must not exceed the search radius when a radius search is used.
    edge_transform_mlp : nn.Module, optional
        MLP applied to constructed edge features.
    out_transform_mlp : nn.Module, optional
        MLP applied after neighborhood reduction.
    mlp_block : nn.Module, optional
        Module used to build ``edge_transform_mlp`` and ``out_transform_mlp`` when not
        provided. Defaults to `MLPBlock`.
    hidden_dim : int, optional
        Hidden dimension of the automatically created MLPs. Defaults to
        ``channel_multiplier * max(out_channels, in_channels)``.
    channel_multiplier : int, optional
        Multiplier used when ``hidden_dim`` is ``None``. Defaults to ``2``.
    use_rel_pos : bool, optional
        Include relative coordinates of neighbors as part of the edge features.
        Ignored when ``use_rel_pos_encode`` is ``True``.
    use_rel_pos_encode : bool, optional
        Include sinusoidal encoding of relative coordinates in the edge features.
    pos_encode_dim : int, optional
        Dimension of the positional encoding. Defaults to ``32``.
    pos_encode_range : float, optional
        Range of the positional encoding. Defaults to ``4``.
    reductions : list of str, optional
        Reductions applied over the neighbor dimension. Defaults to ``("mean",)``.
    out_point_type : {"provided", "downsample", "same"}, optional
        Determines the coordinate set on which output features are computed:
        ``"same"`` reuses the input points, ``"downsample"`` voxel-downsamples the
        input points, and ``"provided"`` uses the ``query_pc`` passed to `forward`.
        Defaults to ``"same"``.
    provided_in_channels : int, optional
        Number of channels of the provided query points when ``out_point_type`` is ``"provided"``.
    bias : bool, optional
        Whether linear layers contain biases. Defaults to ``True``.

    Raises
    ------
    ValueError
        If ``pooling_voxel_size`` is larger than the radius of a radius search.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neighbor_search_args: RealSearchConfig,
        pooling_reduction: Optional[REDUCTIONS] = None,
        pooling_voxel_size: Optional[float] = None,
        edge_transform_mlp: Optional[nn.Module] = None,
        out_transform_mlp: Optional[nn.Module] = None,
        mlp_block: nn.Module = MLPBlock,
        hidden_dim: Optional[int] = None,
        channel_multiplier: int = 2,
        use_rel_pos: bool = False,
        use_rel_pos_encode: bool = False,
        pos_encode_dim: int = 32,
        pos_encode_range: float = 4,
        reductions: List[REDUCTION_TYPES_STR] = ("mean",),
        out_point_type: Literal["provided", "downsample", "same"] = "same",
        provided_in_channels: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        # Check the type first: the validation below reads ``.mode`` and ``.radius``.
        assert isinstance(
            neighbor_search_args, RealSearchConfig
        ), f"neighbor_search_args must be a RealSearchConfig, got {type(neighbor_search_args)}"
        assert (
            isinstance(reductions, (tuple, list)) and len(reductions) > 0
        ), f"reductions must be a list or tuple of length > 0, got {reductions}"
        if out_point_type == "provided":
            assert pooling_reduction is None
            assert pooling_voxel_size is None
            assert (
                provided_in_channels is not None
            ), "provided_in_channels must be provided for provided type"
        elif out_point_type == "downsample":
            assert (
                pooling_reduction is not None and pooling_voxel_size is not None
            ), "pooling_reduction and pooling_voxel_size must be provided for downsample type"
            assert (
                provided_in_channels is None
            ), "provided_in_channels must be None for downsample type"
            # Warn when the search radius is shorter than sqrt(3) * voxel size,
            # i.e. the length of a voxel's main diagonal.
            if (
                pooling_voxel_size is not None
                and neighbor_search_args.mode == RealSearchMode.RADIUS
                and neighbor_search_args.radius < pooling_voxel_size * (3**0.5)
            ):
                warnings.warn(
                    f"neighbor search radius {neighbor_search_args.radius} is less than sqrt(3) times the downsample voxel size {pooling_voxel_size}",
                    stacklevel=2,
                )
        elif out_point_type == "same":
            assert (
                pooling_reduction is None and pooling_voxel_size is None
            ), "pooling_reduction and pooling_voxel_size must be None for same type"
            assert provided_in_channels is None, "provided_in_channels must be None for same type"
        if (
            pooling_reduction is not None
            and pooling_voxel_size is not None
            and neighbor_search_args.mode == RealSearchMode.RADIUS
            and pooling_voxel_size > neighbor_search_args.radius
        ):
            raise ValueError(
                f"downsample_voxel_size {pooling_voxel_size} must be <= radius {neighbor_search_args.radius}"
            )

        self.reductions = reductions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_encode = use_rel_pos_encode
        self.out_point_feature_type = out_point_type
        self.neighbor_search_args = neighbor_search_args
        self.pooling_reduction = pooling_reduction
        self.pooling_voxel_size = pooling_voxel_size
        self.positional_encoding = SinusoidalEncoding(pos_encode_dim, data_range=pos_encode_range)
        # Outside "provided" mode the query points are derived from the input points
        # ("same" reuses them, "downsample" pools them), so assume in_channels.
        if provided_in_channels is None:
            provided_in_channels = in_channels
        if hidden_dim is None:
            hidden_dim = channel_multiplier * max(out_channels, in_channels)
        if edge_transform_mlp is None:
            # Edge features = neighbor features + query features + optional
            # relative-position channels (encoding takes precedence over raw xyz).
            edge_in_channels = in_channels + provided_in_channels
            if use_rel_pos_encode:
                edge_in_channels += pos_encode_dim * 3
            elif use_rel_pos:
                edge_in_channels += 3
            edge_transform_mlp = mlp_block(
                in_channels=edge_in_channels,
                out_channels=out_channels,
                hidden_channels=hidden_dim,
                bias=bias,
            )
        self.edge_transform_mlp = edge_transform_mlp
        self.edge_mlp_in_channels = _get_module_input_channel(edge_transform_mlp)
        if out_transform_mlp is None:
            # One block of out_channels per reduction is concatenated in forward().
            out_transform_mlp = mlp_block(
                in_channels=out_channels * len(reductions),
                out_channels=out_channels,
                hidden_channels=hidden_dim,
                bias=bias,
            )
        self.out_transform_mlp = out_transform_mlp

    def __repr__(self):
        out_str = f"{self.__class__.__name__}(in_channels={self.in_channels} out_channels={self.out_channels}"
        if self.use_rel_pos_encode:
            out_str += f" rel_pos_encode={self.use_rel_pos_encode}"
        if self.pooling_reduction is not None:
            out_str += f" pooling={self.pooling_reduction}"
        if self.neighbor_search_args is not None:
            out_str += f" neighbor={self.neighbor_search_args}"
        out_str += ")"
        return out_str

    def forward(
        self,
        in_pc: Points,
        query_pc: Optional[Points] = None,
    ) -> Points:
        """Compute output point features from neighborhoods in ``in_pc``.

        Parameters
        ----------
        in_pc : Points
            Input points whose features are gathered over each neighborhood.
        query_pc : Points, optional
            Output points. Required when ``out_point_type`` is ``"provided"`` and
            must be ``None`` otherwise: ``"same"`` reuses ``in_pc`` and
            ``"downsample"`` voxel-downsamples ``in_pc``.

        Returns
        -------
        Points
            Points on the query coordinates and offsets, carrying the convolved
            features and the query's extra attributes.
        """
        if self.out_point_feature_type == "provided":
            assert (
                query_pc is not None
            ), "query_point_features must be provided for the provided type"
        elif self.out_point_feature_type == "downsample":
            assert query_pc is None
            query_pc = in_pc.voxel_downsample(
                self.pooling_voxel_size,
                reduction=self.pooling_reduction,
            )
        elif self.out_point_feature_type == "same":
            assert query_pc is None
            query_pc = in_pc

        in_num_channels = in_pc.num_channels
        query_num_channels = query_pc.num_channels
        # Booleans act as 0/1 here, mirroring the edge_in_channels sum in __init__.
        assert (
            in_num_channels
            + query_num_channels
            + self.use_rel_pos_encode * self.positional_encoding.num_channels * 3
            + (not self.use_rel_pos_encode) * self.use_rel_pos * 3
            == self.edge_mlp_in_channels
        ), f"input features shape {in_pc.feature_tensor.shape} and query feature shape {query_pc.feature_tensor.shape} does not match the edge_transform_mlp input channels {self.edge_mlp_in_channels}"

        # Get the neighbors
        neighbors = in_pc.neighbors(
            query_coords=query_pc.batched_coordinates,
            search_args=self.neighbor_search_args,
        )
        neighbor_indices = neighbors.neighbor_indices.long().view(-1)
        neighbor_row_splits = neighbors.neighbor_row_splits
        # Number of neighbors of each query point.
        num_reps = neighbor_row_splits[1:] - neighbor_row_splits[:-1]

        # Pair every gathered neighbor feature with a copy of its query's feature.
        rep_in_features = in_pc.feature_tensor[neighbor_indices]
        self_features = torch.repeat_interleave(
            query_pc.feature_tensor.view(-1, query_num_channels).contiguous(),
            num_reps,
            dim=0,
        )
        edge_features = [rep_in_features, self_features]
        if self.use_rel_pos or self.use_rel_pos_encode:
            in_rep_vertices = in_pc.coordinate_tensor.view(-1, 3)[neighbor_indices]
            self_vertices = torch.repeat_interleave(
                query_pc.coordinate_tensor.view(-1, 3).contiguous(),
                num_reps,
                dim=0,
            )
            rel_coords = in_rep_vertices.view(-1, 3) - self_vertices.view(-1, 3)
            if self.use_rel_pos_encode:
                edge_features.append(self.positional_encoding(rel_coords))
            elif self.use_rel_pos:
                edge_features.append(rel_coords)
        edge_features = torch.cat(edge_features, dim=1)
        edge_features = self.edge_transform_mlp(edge_features)

        # Reduce each neighborhood once per requested reduction and stack channels.
        out_features = []
        for reduction in self.reductions:
            out_features.append(
                row_reduction(edge_features, neighbor_row_splits, reduction=reduction)
            )
        out_features = torch.cat(out_features, dim=-1)
        out_features = self.out_transform_mlp(out_features)

        return Points(
            batched_coordinates=Coords(
                batched_tensor=query_pc.coordinate_tensor,
                offsets=query_pc.offsets,
            ),
            batched_features=out_features,
            **query_pc.extra_attributes,
        )

forward(in_pc: Points, query_pc: Optional[Points] = None) -> Points

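Computes output features on the query points. When out_point_type is "same", these are the input points themselves. When it is "downsample", they are a voxel-downsampled copy of the input. When it is "provided", they are the query_pc argument, which must be None in the other two modes.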

Source code in warpconvnet/nn/modules/point_conv.py
def forward(
    self,
    in_pc: Points,
    query_pc: Optional[Points] = None,
) -> Points:
    """Compute output point features from neighborhoods in ``in_pc``.

    Parameters
    ----------
    in_pc : Points
        Input points whose features are gathered over each neighborhood.
    query_pc : Points, optional
        Output points. Required when ``out_point_type`` is ``"provided"`` and
        must be ``None`` otherwise: ``"same"`` reuses ``in_pc`` and
        ``"downsample"`` voxel-downsamples ``in_pc``.

    Returns
    -------
    Points
        Points on the query coordinates and offsets, carrying the convolved
        features and the query's extra attributes.
    """
    if self.out_point_feature_type == "provided":
        assert (
            query_pc is not None
        ), "query_point_features must be provided for the provided type"
    elif self.out_point_feature_type == "downsample":
        assert query_pc is None
        query_pc = in_pc.voxel_downsample(
            self.pooling_voxel_size,
            reduction=self.pooling_reduction,
        )
    elif self.out_point_feature_type == "same":
        assert query_pc is None
        query_pc = in_pc

    in_num_channels = in_pc.num_channels
    query_num_channels = query_pc.num_channels
    # Booleans act as 0/1 here, mirroring the edge_in_channels sum in __init__.
    assert (
        in_num_channels
        + query_num_channels
        + self.use_rel_pos_encode * self.positional_encoding.num_channels * 3
        + (not self.use_rel_pos_encode) * self.use_rel_pos * 3
        == self.edge_mlp_in_channels
    ), f"input features shape {in_pc.feature_tensor.shape} and query feature shape {query_pc.feature_tensor.shape} does not match the edge_transform_mlp input channels {self.edge_mlp_in_channels}"

    # Get the neighbors
    neighbors = in_pc.neighbors(
        query_coords=query_pc.batched_coordinates,
        search_args=self.neighbor_search_args,
    )
    neighbor_indices = neighbors.neighbor_indices.long().view(-1)
    neighbor_row_splits = neighbors.neighbor_row_splits
    # Number of neighbors of each query point.
    num_reps = neighbor_row_splits[1:] - neighbor_row_splits[:-1]

    # Pair every gathered neighbor feature with a copy of its query's feature.
    rep_in_features = in_pc.feature_tensor[neighbor_indices]
    self_features = torch.repeat_interleave(
        query_pc.feature_tensor.view(-1, query_num_channels).contiguous(),
        num_reps,
        dim=0,
    )
    edge_features = [rep_in_features, self_features]
    if self.use_rel_pos or self.use_rel_pos_encode:
        in_rep_vertices = in_pc.coordinate_tensor.view(-1, 3)[neighbor_indices]
        self_vertices = torch.repeat_interleave(
            query_pc.coordinate_tensor.view(-1, 3).contiguous(),
            num_reps,
            dim=0,
        )
        rel_coords = in_rep_vertices.view(-1, 3) - self_vertices.view(-1, 3)
        if self.use_rel_pos_encode:
            edge_features.append(self.positional_encoding(rel_coords))
        elif self.use_rel_pos:
            edge_features.append(rel_coords)
    edge_features = torch.cat(edge_features, dim=1)
    edge_features = self.edge_transform_mlp(edge_features)

    # Reduce each neighborhood once per requested reduction and stack channels.
    out_features = []
    for reduction in self.reductions:
        out_features.append(
            row_reduction(edge_features, neighbor_row_splits, reduction=reduction)
        )
    out_features = torch.cat(out_features, dim=-1)
    out_features = self.out_transform_mlp(out_features)

    return Points(
        batched_coordinates=Coords(
            batched_tensor=query_pc.coordinate_tensor,
            offsets=query_pc.offsets,
        ),
        batched_features=out_features,
        **query_pc.extra_attributes,
    )

Point pooling

warpconvnet.nn.modules.point_pool

PointAvgPool

Bases: PointPoolBase

Point pooling using mean reduction.

Parameters:
  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointAvgPool(PointPoolBase):
    """Point pooling that merges features with the ``mean`` reduction.

    A thin specialization of `PointPoolBase` with ``reduction=REDUCTIONS.MEAN``;
    all other options keep the base-class defaults unless given here.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Maximum number of points to keep when downsampling.
    downsample_voxel_size : float, optional
        Size of voxels used for downsampling.
    return_type : {"point", "sparse"}, optional
        Output geometry type. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        If ``True`` also return the neighbor search result. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        pool_options = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.MEAN, **pool_options)

PointMaxPool

Bases: PointPoolBase

Point pooling using max reduction.

Parameters:
  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointMaxPool(PointPoolBase):
    """Point pooling that merges features with the ``max`` reduction.

    A thin specialization of `PointPoolBase` with ``reduction=REDUCTIONS.MAX``;
    all other options keep the base-class defaults unless given here.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Maximum number of points to keep when downsampling.
    downsample_voxel_size : float, optional
        Size of voxels used for downsampling.
    return_type : {"point", "sparse"}, optional
        Output geometry type. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        If ``True`` also return the neighbor search result. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        pool_options = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.MAX, **pool_options)

PointPoolBase

Bases: BaseSpatialModule

Base module for pooling points or voxels.

Parameters:
  • reduction (str or REDUCTIONS, default: MAX ) –

    Reduction method used when merging features. Defaults to REDUCTIONS.MAX.

  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • unique_method ((torch, ravel, morton), default: "torch" ) –

    Method used to find unique voxel indices. Defaults to "torch".

  • avereage_pooled_coordinates (bool, default: False ) –

    If True average coordinates of points within each voxel. Defaults to False.

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointPoolBase(BaseSpatialModule):
    """Base module for pooling points or voxels.

    The constructor only records the pooling configuration; `forward` forwards
    it, together with the input, to ``point_pool``.

    Parameters
    ----------
    reduction : str or `REDUCTIONS`, optional
        Reduction method used when merging features. Defaults to ``REDUCTIONS.MAX``.
    downsample_max_num_points : int, optional
        Maximum number of points to keep when downsampling.
    downsample_voxel_size : float, optional
        Size of voxels used for downsampling.
    return_type : {"point", "sparse"}, optional
        Output geometry type. Defaults to ``"point"``.
    unique_method : {"torch", "ravel", "morton"}, optional
        Method used to find unique voxel indices. Defaults to ``"torch"``.
    avereage_pooled_coordinates : bool, optional
        If ``True`` average coordinates of points within each voxel. Defaults to ``False``.
    return_neighbor_search_result : bool, optional
        If ``True`` also return the neighbor search result. Defaults to ``False``.
    """

    def __init__(
        self,
        reduction: Union[str, REDUCTIONS] = REDUCTIONS.MAX,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        unique_method: Literal["torch", "ravel", "morton"] = "torch",
        avereage_pooled_coordinates: bool = False,
        return_neighbor_search_result: bool = False,
    ):
        super().__init__()
        # Accept either a REDUCTIONS member or its string value.
        self.reduction = REDUCTIONS(reduction) if isinstance(reduction, str) else reduction
        self.downsample_max_num_points = downsample_max_num_points
        self.downsample_voxel_size = downsample_voxel_size
        self.return_type = return_type
        self.return_neighbor_search_result = return_neighbor_search_result
        self.unique_method = unique_method
        self.avereage_pooled_coordinates = avereage_pooled_coordinates

    def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResult]]:
        """Pool ``pc`` with the stored configuration.

        Returns the pooled geometry, or a ``(geometry, search_result)`` tuple when
        ``return_neighbor_search_result`` is set.
        """
        pool_kwargs = {
            "reduction": self.reduction,
            "downsample_max_num_points": self.downsample_max_num_points,
            "downsample_voxel_size": self.downsample_voxel_size,
            "return_type": self.return_type,
            "return_neighbor_search_result": self.return_neighbor_search_result,
            "unique_method": self.unique_method,
            "avereage_pooled_coordinates": self.avereage_pooled_coordinates,
        }
        return point_pool(pc=pc, **pool_kwargs)

PointSumPool

Bases: PointPoolBase

Point pooling using sum reduction.

Parameters:
  • downsample_max_num_points (int, default: None ) –

    Maximum number of points to keep when downsampling.

  • downsample_voxel_size (float, default: None ) –

    Size of voxels used for downsampling.

  • return_type ((point, sparse), default: "point" ) –

    Output geometry type. Defaults to "point".

  • return_neighbor_search_result (bool, default: False ) –

    If True also return the neighbor search result. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointSumPool(PointPoolBase):
    """Point pooling that merges features with the ``sum`` reduction.

    A thin specialization of `PointPoolBase` with ``reduction=REDUCTIONS.SUM``;
    all other options keep the base-class defaults unless given here.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Maximum number of points to keep when downsampling.
    downsample_voxel_size : float, optional
        Size of voxels used for downsampling.
    return_type : {"point", "sparse"}, optional
        Output geometry type. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        If ``True`` also return the neighbor search result. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        pool_options = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.SUM, **pool_options)

PointUnpool

Bases: BaseSpatialModule

Undo point pooling by scattering pooled features back to the input cloud.

Parameters:
  • unpooling_mode ((repeat, interpolate), default: "repeat" ) –

    Strategy used when unpooling features. Defaults to FEATURE_UNPOOLING_MODE.REPEAT.

  • concat_unpooled_pc (bool, default: False ) –

    If True concatenate the unpooled point cloud with the input. Defaults to False.

Source code in warpconvnet/nn/modules/point_pool.py
class PointUnpool(BaseSpatialModule):
    """Scatter pooled point features back onto the original (unpooled) point cloud.

    Parameters
    ----------
    unpooling_mode : {"repeat", "interpolate"} or `FEATURE_UNPOOLING_MODE`, optional
        Strategy used when unpooling features. String values are converted to
        ``FEATURE_UNPOOLING_MODE``. Defaults to ``FEATURE_UNPOOLING_MODE.REPEAT``.
    concat_unpooled_pc : bool, optional
        When ``True`` the unpooled point cloud is concatenated with the input.
        Defaults to ``False``.
    """

    def __init__(
        self,
        unpooling_mode: Union[str, FEATURE_UNPOOLING_MODE] = FEATURE_UNPOOLING_MODE.REPEAT,
        concat_unpooled_pc: bool = False,
    ):
        super().__init__()
        # Strings become enum members; any other value is stored unchanged.
        self.unpooling_mode = (
            FEATURE_UNPOOLING_MODE(unpooling_mode)
            if isinstance(unpooling_mode, str)
            else unpooling_mode
        )
        self.concat_unpooled_pc = concat_unpooled_pc

    def forward(self, pooled_pc: Points, unpooled_pc: Points):
        """Unpool ``pooled_pc`` onto ``unpooled_pc`` via ``point_unpool``.

        The configured ``unpooling_mode`` and ``concat_unpooled_pc`` are
        forwarded; the result is whatever ``point_unpool`` returns.
        """
        unpool_options = dict(
            unpooling_mode=self.unpooling_mode,
            concat_unpooled_pc=self.concat_unpooled_pc,
        )
        return point_unpool(pooled_pc=pooled_pc, unpooled_pc=unpooled_pc, **unpool_options)

Prune

warpconvnet.nn.modules.prune

SparsePrune

Bases: BaseSpatialModule

Module wrapper around prune_spatially_sparse_tensor so pruning can be composed in nn.Sequential.

Forward Args

  • spatial_tensor (Geometry) –

    Sparse geometry (e.g., Voxels) whose coordinates/features will be filtered.

  • mask (Bool[Tensor, "N"]) –

    Boolean mask aligned with spatial_tensor.coordinate_tensor.

Source code in warpconvnet/nn/modules/prune.py
class SparsePrune(BaseSpatialModule):
    """
    Prune a sparse geometry with a boolean mask.

    Module wrapper around ``prune_spatially_sparse_tensor`` so pruning can be
    composed in ``nn.Sequential``-style containers.
    """

    def forward(
        self,
        spatial_tensor: Geometry,
        mask: Bool[Tensor, "N"],  # noqa: F821
    ) -> Geometry:
        """Keep only the entries of ``spatial_tensor`` selected by ``mask``.

        Parameters
        ----------
        spatial_tensor : Geometry
            Sparse geometry (e.g., ``Voxels``) whose coordinates/features will be filtered.
        mask : Bool[Tensor, "N"]
            Boolean mask aligned with ``spatial_tensor.coordinate_tensor``.

        Returns
        -------
        Geometry
            The result of ``prune_spatially_sparse_tensor(spatial_tensor, mask)``.
        """
        pruned = prune_spatially_sparse_tensor(spatial_tensor, mask)
        return pruned

Sequential

warpconvnet.nn.modules.sequential

GeometryWrapper

Bases: BaseSpatialModule

Wrapper that applies a tensor-to-tensor module to a geometry's features and returns a geometry object.

Source code in warpconvnet/nn/modules/sequential.py
class GeometryWrapper(BaseSpatialModule):
    """Apply a tensor-to-tensor callable to a geometry's features.

    The wrapped ``module`` receives ``x.feature_tensor``; its output becomes the
    features of the geometry returned by ``x.replace``.

    Parameters
    ----------
    module : Callable[[Tensor], Tensor]
        Feature transform applied to the raw feature tensor.
    """

    def __init__(self, module: Callable[[Tensor], Tensor]):
        super().__init__()
        self.module = module

    def forward(self, x: Geometry) -> Geometry:
        """Return ``x`` with features replaced by ``self.module(x.feature_tensor)``."""
        transformed = self.module(x.feature_tensor)
        return x.replace(batched_features=transformed)

Sequential

Bases: Sequential, BaseSpatialModule

Sequential module that allows for spatial and non-spatial layers to be chained together.

When several non-spatial layers appear consecutively, no intermediate spatial features object is created between them, which makes the chain more efficient.

Source code in warpconvnet/nn/modules/sequential.py
class Sequential(nn.Sequential, BaseSpatialModule):
    """
    Sequential module that allows for spatial and non-spatial layers to be chained together.

    When several non-spatial layers appear consecutively, no intermediate spatial
    features object is created between them, which makes the chain more efficient.
    """

    def forward(self, x: Geometry):
        """Run every sub-module in order on ``x``.

        Parameters
        ----------
        x : Geometry
            Input geometry.

        Returns
        -------
        Geometry
            Output of the last layer. If the chain ends with a raw tensor (e.g.
            produced by a non-spatial layer), it is wrapped back into a geometry
            via ``in_sf.replace(batched_features=...)``.
        """
        # The message previously named the legacy ``BatchedSpatialFeatures`` type,
        # which is not what is being checked.
        assert isinstance(x, Geometry), f"Expected Geometry, got {type(x)}"

        # ``run_forward`` (defined elsewhere) returns the layer output together with
        # the geometry to use as context; presumably the most recent geometry seen.
        in_sf = x
        for module in self:
            x, in_sf = run_forward(module, x, in_sf)

        if isinstance(x, torch.Tensor):
            x = in_sf.replace(batched_features=x)

        return x

TupleSequential

Bases: Sequential, BaseSpatialModule

Sequential module that passes additional inputs to one designated layer (selected by tuple_layer); all other layers receive only the running output.

Source code in warpconvnet/nn/modules/sequential.py
class TupleSequential(Sequential, BaseSpatialModule):
    """
    Sequential module that allows multiple inputs for a specified layer.

    The first positional input flows through every layer as in ``Sequential``.
    The layer at index ``tuple_layer`` is run through ``tuple_run_forward`` with
    ``(x, *xs[1:])``, where ``x`` is the running output, so it additionally sees
    the remaining inputs. All other layers receive only the running output.

    Parameters
    ----------
    *args : nn.Module or Sequence of nn.Module
        Sub-modules, passed either individually or as a single sequence.
    tuple_layer : int
        Index of the sub-module that receives the extra inputs. NOTE(review): an
        out-of-range index is not validated; the extra inputs are then ignored.
    """

    def __init__(self, *args, tuple_layer: int):
        # Accept both ``TupleSequential(a, b, ...)`` and ``TupleSequential([a, b, ...])``.
        if len(args) == 1 and isinstance(args[0], Sequence):
            super().__init__(*args[0])
        else:
            super().__init__(*args)
        self.tuple_layer = tuple_layer

    def forward(self, *xs: Geometry):
        """Run the chain on ``xs[0]``, feeding ``xs[1:]`` to layer ``tuple_layer``.

        Parameters
        ----------
        *xs : Geometry
            The primary input followed by the extra inputs for ``tuple_layer``.

        Returns
        -------
        Geometry
            Output of the last layer; a trailing raw tensor is wrapped back into
            a geometry via ``in_sf.replace(batched_features=...)``.
        """
        x = xs[0]
        in_sf = x
        for i, module in enumerate(self):
            if i == self.tuple_layer:
                x, in_sf = tuple_run_forward(module, (x, *xs[1:]), in_sf)
            else:
                x, in_sf = run_forward(module, x, in_sf)

        if isinstance(x, torch.Tensor):
            x = in_sf.replace(batched_features=x)

        return x

Sparse convolution

warpconvnet.nn.modules.sparse_conv

SparseConv2d

Bases: SpatiallySparseConv

2D sparse convolution.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Convolution stride. Defaults to 1.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True adds a learnable bias to the output. Defaults to True.

  • transposed (bool, default: False ) –

    Perform a transposed convolution. Defaults to False.

  • generative (bool, default: False ) –

    Use generative convolution. Defaults to False.

  • stride_mode (`STRIDED_CONV_MODE`, default: STRIDE_ONLY ) –

    How to interpret stride when transposed is True.

  • fwd_algo (`SPARSE_CONV_FWD_ALGO_MODE` or str, default: None ) –

    Forward algorithm to use.

  • bwd_algo (`SPARSE_CONV_BWD_ALGO_MODE` or str, default: None ) –

    Backward algorithm to use.

  • kernel_matmul_batch_size (int, default: 2 ) –

    Batch size used for implicit matrix multiplications. Defaults to 2.

  • order (`POINT_ORDERING`, default: RANDOM ) –

    Ordering of points in the output. Defaults to POINT_ORDERING.RANDOM.

  • compute_dtype (dtype, default: None ) –

    Data type used for intermediate computations.

  • implicit_matmul_fwd_block_size (int, default: None ) –

    CUDA block size for implicit forward matmuls.

  • implicit_matmul_bwd_block_size (int, default: None ) –

    CUDA block size for implicit backward matmuls.

Source code in warpconvnet/nn/modules/sparse_conv.py
class SparseConv2d(SpatiallySparseConv):
    """2D sparse convolution.

    Convenience subclass of ``SpatiallySparseConv`` that fixes
    ``num_spatial_dims=2``; all other arguments are forwarded unchanged.

    Parameters
    ----------
    in_channels : int
        Channel count of the input features.
    out_channels : int
        Channel count of the output features.
    kernel_size : int or tuple of int
        Extent of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        Add a learnable bias to the output when ``True``. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How ``stride`` is interpreted when ``transposed`` is ``True``.
        Defaults to ``STRIDED_CONV_MODE.STRIDE_ONLY``.
    fwd_algo : `SPARSE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm. ``None`` falls back to the environment setting.
    bwd_algo : `SPARSE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm. ``None`` falls back to the environment setting.
    kernel_matmul_batch_size : int, optional
        Batch size for implicit matrix multiplications. Defaults to ``2``.
    order : `POINT_ORDERING`, optional
        Point ordering of the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type for intermediate computations.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        transposed=False,
        generative: bool = False,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        fwd_algo: Optional[Union[SPARSE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_CONV_BWD_ALGO_MODE, str]] = None,
        kernel_matmul_batch_size: int = 2,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        # The only thing this subclass decides is the dimensionality.
        super().__init__(
            num_spatial_dims=2,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            stride_mode=stride_mode,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            kernel_matmul_batch_size=kernel_matmul_batch_size,
            order=order,
            compute_dtype=compute_dtype,
            implicit_matmul_fwd_block_size=implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=implicit_matmul_bwd_block_size,
        )

SparseConv3d

Bases: SpatiallySparseConv

3D sparse convolution.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Convolution stride. Defaults to 1.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True adds a learnable bias to the output. Defaults to True.

  • transposed (bool, default: False ) –

    Perform a transposed convolution. Defaults to False.

  • generative (bool, default: False ) –

    Use generative convolution. Defaults to False.

  • stride_mode (`STRIDED_CONV_MODE`, default: STRIDE_ONLY ) –

    How to interpret stride when transposed is True.

  • fwd_algo (`SPARSE_CONV_FWD_ALGO_MODE` or str, default: None ) –

    Forward algorithm to use.

  • bwd_algo (`SPARSE_CONV_BWD_ALGO_MODE` or str, default: None ) –

    Backward algorithm to use.

  • kernel_matmul_batch_size (int, default: 2 ) –

    Batch size used for implicit matrix multiplications. Defaults to 2.

  • order (`POINT_ORDERING`, default: RANDOM ) –

    Ordering of points in the output. Defaults to POINT_ORDERING.RANDOM.

  • compute_dtype (dtype, default: None ) –

    Data type used for intermediate computations.

  • implicit_matmul_fwd_block_size (int, default: None ) –

    CUDA block size for implicit forward matmuls.

  • implicit_matmul_bwd_block_size (int, default: None ) –

    CUDA block size for implicit backward matmuls.

Source code in warpconvnet/nn/modules/sparse_conv.py
class SparseConv3d(SpatiallySparseConv):
    """3D sparse convolution.

    Convenience subclass of ``SpatiallySparseConv`` that fixes
    ``num_spatial_dims=3``; all other arguments are forwarded unchanged.

    Parameters
    ----------
    in_channels : int
        Channel count of the input features.
    out_channels : int
        Channel count of the output features.
    kernel_size : int or tuple of int
        Extent of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        Add a learnable bias to the output when ``True``. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How ``stride`` is interpreted when ``transposed`` is ``True``.
        Defaults to ``STRIDED_CONV_MODE.STRIDE_ONLY``.
    fwd_algo : `SPARSE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm. ``None`` falls back to the environment setting.
    bwd_algo : `SPARSE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm. ``None`` falls back to the environment setting.
    kernel_matmul_batch_size : int, optional
        Batch size for implicit matrix multiplications. Defaults to ``2``.
    order : `POINT_ORDERING`, optional
        Point ordering of the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type for intermediate computations.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        transposed=False,
        generative: bool = False,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        fwd_algo: Optional[Union[SPARSE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_CONV_BWD_ALGO_MODE, str]] = None,
        kernel_matmul_batch_size: int = 2,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        # The only thing this subclass decides is the dimensionality.
        super().__init__(
            num_spatial_dims=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            stride_mode=stride_mode,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            kernel_matmul_batch_size=kernel_matmul_batch_size,
            order=order,
            compute_dtype=compute_dtype,
            implicit_matmul_fwd_block_size=implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=implicit_matmul_bwd_block_size,
        )

SpatiallySparseConv

Bases: BaseSpatialModule

Sparse convolution layer for warpconvnet.geometry.types.voxels.Voxels.

Parameters:
  • in_channels (int) –

    Number of input feature channels.

  • out_channels (int) –

    Number of output feature channels.

  • kernel_size (int or tuple of int) –

    Size of the convolution kernel.

  • stride (int or tuple of int, default: 1 ) –

    Convolution stride. Defaults to 1.

  • dilation (int or tuple of int, default: 1 ) –

    Spacing between kernel elements. Defaults to 1.

  • bias (bool, default: True ) –

    If True adds a learnable bias to the output. Defaults to True.

  • transposed (bool, default: False ) –

    Perform a transposed convolution. Defaults to False.

  • generative (bool, default: False ) –

    Use generative convolution. Defaults to False.

  • kernel_matmul_batch_size (int, default: 2 ) –

    Batch size used for implicit matrix multiplications. Defaults to 2.

  • num_spatial_dims (int, default: 3 ) –

    Number of spatial dimensions. Defaults to 3.

  • fwd_algo (`SPARSE_CONV_FWD_ALGO_MODE` or str, default: None ) –

    Forward algorithm to use. Defaults to environment setting.

  • bwd_algo (`SPARSE_CONV_BWD_ALGO_MODE` or str, default: None ) –

    Backward algorithm to use. Defaults to environment setting.

  • stride_mode (`STRIDED_CONV_MODE`, default: STRIDE_ONLY ) –

    How to interpret stride when transposed is True.

  • order (`POINT_ORDERING`, default: RANDOM ) –

    Ordering of points in the output. Defaults to POINT_ORDERING.RANDOM.

  • compute_dtype (dtype, default: None ) –

    Data type used for intermediate computations.

  • implicit_matmul_fwd_block_size (int, default: None ) –

    CUDA block size for implicit forward matmuls.

  • implicit_matmul_bwd_block_size (int, default: None ) –

    CUDA block size for implicit backward matmuls.

Source code in warpconvnet/nn/modules/sparse_conv.py
class SpatiallySparseConv(BaseSpatialModule):
    """Sparse convolution layer for `warpconvnet.geometry.types.voxels.Voxels`.

    The learnable weight has shape ``(K, in_channels, out_channels)`` where
    ``K = prod(kernel_size)``; the optional bias has shape ``(out_channels,)``.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        If ``True`` adds a learnable bias to the output. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    kernel_matmul_batch_size : int, optional
        Batch size used for implicit matrix multiplications. Defaults to ``2``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    fwd_algo : `SPARSE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm to use. Defaults to environment setting.
    bwd_algo : `SPARSE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm to use. Defaults to environment setting.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How to interpret ``stride`` when ``transposed`` is ``True``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        kernel_matmul_batch_size: int = 2,
        num_spatial_dims: Optional[int] = 3,
        fwd_algo: Optional[Union[SPARSE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        super().__init__()
        self.num_spatial_dims = num_spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Normalize kernel_size, stride, dilation to per-dimension tuples for consistent use
        _kernel_size = ntuple(kernel_size, ndim=self.num_spatial_dims)
        _stride = ntuple(stride, ndim=self.num_spatial_dims)
        _dilation = ntuple(dilation, ndim=self.num_spatial_dims)

        self.kernel_size = _kernel_size
        self.stride = _stride
        self.dilation = _dilation

        self.transposed = transposed
        self.generative = generative
        self.kernel_matmul_batch_size = kernel_matmul_batch_size

        # Use environment variable values if not explicitly provided
        if fwd_algo is None:
            fwd_algo = WARPCONVNET_FWD_ALGO_MODE
        if bwd_algo is None:
            bwd_algo = WARPCONVNET_BWD_ALGO_MODE

        # Convert string to enum, but preserve lists for direct passing to functional layer
        if isinstance(fwd_algo, str):
            self.fwd_algo = SPARSE_CONV_FWD_ALGO_MODE(fwd_algo)
        else:
            # Keep lists as-is (from env vars or direct user input)
            self.fwd_algo = fwd_algo

        if isinstance(bwd_algo, str):
            self.bwd_algo = SPARSE_CONV_BWD_ALGO_MODE(bwd_algo)
        else:
            # Keep lists as-is (from env vars or direct user input)
            self.bwd_algo = bwd_algo

        self.stride_mode = stride_mode
        self.order = order
        self.compute_dtype = compute_dtype
        self.implicit_matmul_fwd_block_size = implicit_matmul_fwd_block_size
        self.implicit_matmul_bwd_block_size = implicit_matmul_bwd_block_size

        # Weight layout: (kernel volume, in_channels, out_channels). The random
        # values are placeholders; reset_parameters() overwrites them below.
        self.weight = nn.Parameter(torch.randn(np.prod(_kernel_size), in_channels, out_channels))
        self.bias: Optional[nn.Parameter] = (
            nn.Parameter(torch.randn(out_channels)) if bias else None
        )
        self.reset_parameters()  # Call after parameters are defined for the chosen backend

    def __repr__(self):
        """Return the class name plus constructor arguments that differ from defaults."""
        out_str = f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, kernel_size={self.kernel_size}"
        # stride/dilation are stored per dimension (via ``ntuple``), so compare
        # element-wise: comparing the whole sequence to the int 1 is always unequal.
        if any(s != 1 for s in self.stride):
            out_str += f", stride={self.stride}"
        if any(d != 1 for d in self.dilation):
            out_str += f", dilation={self.dilation}"
        if self.transposed:
            out_str += f", transposed={self.transposed}"
        if self.generative:
            out_str += f", generative={self.generative}"
        if self.order != POINT_ORDERING.RANDOM:
            out_str += f", order={self.order}"
        out_str += ")"
        return out_str

    def _calculate_fan_in_and_fan_out(self):
        """Return ``(fan_in, fan_out)`` with the kernel volume as receptive field."""
        receptive_field_size = np.prod(self.kernel_size)
        fan_in = self.in_channels * receptive_field_size
        fan_out = self.out_channels * receptive_field_size
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode: Literal["fan_in", "fan_out"]):
        """Return fan-in or fan-out depending on ``mode`` (case-insensitive)."""
        mode = mode.lower()
        assert mode in ["fan_in", "fan_out"]

        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == "fan_in" else fan_out

    def _custom_kaiming_uniform_(self, tensor, a=0, mode="fan_in", nonlinearity="leaky_relu"):
        """In-place Kaiming-uniform init using this layer's sparse-kernel fan.

        Mirrors ``torch.nn.init.kaiming_uniform_`` but computes the fan from
        ``kernel_size``/channel counts instead of the tensor's shape.
        """
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # U(-b, b) has standard deviation b / sqrt(3), hence b = sqrt(3) * std.
        # The factor comes from the uniform distribution, not from the number of
        # spatial dimensions (using num_spatial_dims only coincided for 3D layers
        # and under-initialized 2D ones).
        bound = math.sqrt(3.0) * std
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialize ``weight`` (Kaiming-uniform) and ``bias`` (U(-1/sqrt(fan_in), 1/sqrt(fan_in)))."""
        self._custom_kaiming_uniform_(
            self.weight,
            a=math.sqrt(5),
            mode="fan_out" if self.transposed else "fan_in",
        )

        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        input_sparse_tensor: Voxels,
        output_spatially_sparse_tensor: Optional[Voxels] = None,
    ):
        """Apply the convolution by delegating to ``spatially_sparse_conv``.

        Parameters
        ----------
        input_sparse_tensor : Voxels
            Input sparse voxels.
        output_spatially_sparse_tensor : Voxels, optional
            Forwarded as ``output_spatially_sparse_tensor``; presumably supplies
            the output coordinates explicitly (confirm in the functional API).
        """
        return spatially_sparse_conv(
            input_sparse_tensor=input_sparse_tensor,
            weight=self.weight,
            kernel_size=self.kernel_size,
            stride=self.stride,
            kernel_dilation=self.dilation,
            bias=self.bias,
            kernel_matmul_batch_size=self.kernel_matmul_batch_size,
            output_spatially_sparse_tensor=output_spatially_sparse_tensor,
            transposed=self.transposed,
            generative=self.generative,
            fwd_algo=self.fwd_algo,
            bwd_algo=self.bwd_algo,
            stride_mode=self.stride_mode,
            order=self.order,
            compute_dtype=self.compute_dtype,
            implicit_matmul_fwd_block_size=self.implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=self.implicit_matmul_bwd_block_size,
        )

Sparse depthwise convolution

warpconvnet.nn.modules.sparse_conv_depth

SparseDepthwiseConv2d

Bases: SpatiallySparseDepthwiseConv

2D spatially sparse depthwise convolution.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SparseDepthwiseConv2d(SpatiallySparseDepthwiseConv):
    """2D spatially sparse depthwise convolution.

    Convenience subclass of ``SpatiallySparseDepthwiseConv`` that fixes
    ``num_spatial_dims=2`` and forwards every other argument unchanged.

    Parameters
    ----------
    channels : int
        Number of feature channels; input and output channel counts are equal.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        Forwarded as ``bias`` (presumably enables a learnable bias). Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    fwd_algo : `SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm. ``None`` uses the environment-configured default.
    bwd_algo : `SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm. ``None`` uses the environment-configured default.
    stride_mode : `STRIDED_CONV_MODE`, optional
        Strided-convolution mode. Defaults to ``STRIDED_CONV_MODE.STRIDE_ONLY``.
    stride_reduce : str, optional
        Reduction name forwarded to the base class (presumably applied when
        striding merges features; confirm). Defaults to ``"max"``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        # Only the dimensionality is decided here.
        super().__init__(
            num_spatial_dims=2,
            channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            stride_mode=stride_mode,
            stride_reduce=stride_reduce,
            order=order,
            compute_dtype=compute_dtype,
        )

SparseDepthwiseConv3d

Bases: SpatiallySparseDepthwiseConv

3D spatially sparse depthwise convolution.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SparseDepthwiseConv3d(SpatiallySparseDepthwiseConv):
    """3D spatially sparse depthwise convolution.

    Convenience subclass of ``SpatiallySparseDepthwiseConv`` that fixes
    ``num_spatial_dims=3`` and forwards every other argument unchanged.

    Parameters
    ----------
    channels : int
        Number of feature channels; input and output channel counts are equal.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        Forwarded as ``bias`` (presumably enables a learnable bias). Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    fwd_algo : `SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm. ``None`` uses the environment-configured default.
    bwd_algo : `SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm. ``None`` uses the environment-configured default.
    stride_mode : `STRIDED_CONV_MODE`, optional
        Strided-convolution mode. Defaults to ``STRIDED_CONV_MODE.STRIDE_ONLY``.
    stride_reduce : str, optional
        Reduction name forwarded to the base class (presumably applied when
        striding merges features; confirm). Defaults to ``"max"``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        # Only the dimensionality is decided here.
        super().__init__(
            num_spatial_dims=3,
            channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            stride_mode=stride_mode,
            stride_reduce=stride_reduce,
            order=order,
            compute_dtype=compute_dtype,
        )

SpatiallySparseDepthwiseConv

Bases: BaseSpatialModule

Spatially sparse depthwise convolution module.

In depthwise convolution, each input channel is convolved with its own kernel, so the number of input channels must equal the number of output channels. The weight shape is (K, C) where K is the kernel volume and C is the number of channels.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SpatiallySparseDepthwiseConv(BaseSpatialModule):
    """
    Spatially sparse depthwise convolution module.

    In depthwise convolution, each input channel is convolved with its own kernel,
    so the number of input channels must equal the number of output channels.
    The weight shape is (K, C) where K is the kernel volume and C is the number of channels.

    Parameters
    ----------
    channels : int
        Number of input channels (equal to the number of output channels).
    kernel_size : int or tuple of int
        Kernel extent; an int is expanded to ``num_spatial_dims`` entries.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Kernel dilation. Defaults to ``1``.
    bias : bool, optional
        If ``True`` a learnable per-channel bias of shape ``(C,)`` is added.
        Defaults to ``True``.
    transposed : bool, optional
        Run as a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Forwarded to output-coordinate generation. Defaults to ``False``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    fwd_algo, bwd_algo : enum member or str, optional
        Forward / backward algorithm. ``None`` falls back to
        ``WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE`` /
        ``WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE``. The strings
        ``"explicit"``/``"explicit_gemm"``, ``"implicit"``/``"implicit_gemm"`` and
        ``"auto"`` (case-insensitive) map to the matching enum member; any other
        string is passed to the enum constructor.
    stride_mode : STRIDED_CONV_MODE, optional
        Forwarded to output-coordinate generation.
    stride_reduce : str, optional
        Stored as an attribute; not used by ``forward`` in this class.
    order : POINT_ORDERING, optional
        Point ordering forwarded to output-coordinate generation.
    compute_dtype : torch.dtype, optional
        Forwarded to ``spatially_sparse_depthwise_conv``.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        num_spatial_dims: int = 3,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.num_spatial_dims = num_spatial_dims
        self.channels = channels
        self.in_channels = channels  # For compatibility with PyTorch naming
        self.out_channels = channels  # For depthwise, in_channels == out_channels

        # Ensure kernel_size, stride, dilation are tuples for consistent use
        self.kernel_size = ntuple(kernel_size, ndim=self.num_spatial_dims)
        self.stride = ntuple(stride, ndim=self.num_spatial_dims)
        self.dilation = ntuple(dilation, ndim=self.num_spatial_dims)

        self.transposed = transposed
        self.generative = generative
        self.stride_reduce = stride_reduce

        # Fall back to the WARPCONVNET_* defaults when not explicitly provided.
        if fwd_algo is None:
            fwd_algo = WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE
        if bwd_algo is None:
            bwd_algo = WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE
        self.fwd_algo = self._resolve_algo(fwd_algo, SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE)
        self.bwd_algo = self._resolve_algo(bwd_algo, SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE)

        self.stride_mode = stride_mode
        self.order = order
        self.compute_dtype = compute_dtype

        # Depthwise convolution weight shape: (K, C) where K is kernel volume
        kernel_volume = int(np.prod(self.kernel_size))
        self.weight = nn.Parameter(torch.randn(kernel_volume, channels))

        # Optional bias
        if bias:
            self.bias = nn.Parameter(torch.randn(channels))
        else:
            self.bias = None

        self.reset_parameters()

    @staticmethod
    def _resolve_algo(algo, enum_cls):
        """Map an algorithm given as an enum member or string to a member of ``enum_cls``.

        Non-string values are returned unchanged. Generic GEMM names are mapped
        to the depthwise-specific members; other strings go through ``enum_cls(algo)``.
        """
        if not isinstance(algo, str):
            return algo
        key = algo.lower()
        if key in ("explicit", "explicit_gemm"):
            return enum_cls.EXPLICIT
        if key in ("implicit", "implicit_gemm"):
            return enum_cls.IMPLICIT
        if key == "auto":
            return enum_cls.AUTO
        # Unrecognized alias: defer to the enum's own (case-sensitive) value lookup.
        return enum_cls(algo)

    def __repr__(self):
        out_str = (
            f"{self.__class__.__name__}(channels={self.channels}, "
            f"kernel_size={self.kernel_size}"
        )
        if self.stride != (1,) * self.num_spatial_dims:
            out_str += f", stride={self.stride}"
        if self.dilation != (1,) * self.num_spatial_dims:
            out_str += f", dilation={self.dilation}"
        if self.transposed:
            out_str += f", transposed={self.transposed}"
        if self.generative:
            out_str += f", generative={self.generative}"
        if self.order != POINT_ORDERING.RANDOM:
            out_str += f", order={self.order}"
        if self.bias is None:
            out_str += ", bias=False"
        out_str += ")"
        return out_str

    def _calculate_fan_in_and_fan_out(self):
        """Calculate fan_in and fan_out for depthwise convolution.

        Each channel has its own kernel, so both fans equal the kernel volume.
        """
        receptive_field_size = np.prod(self.kernel_size)
        fan_in = receptive_field_size  # One kernel per channel
        fan_out = receptive_field_size  # One output per channel
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode: str):
        """Return the fan selected by ``mode`` (``"fan_in"`` or ``"fan_out"``, case-insensitive).

        Raises
        ------
        ValueError
            If ``mode`` is not one of the two supported values.
        """
        mode = mode.lower()
        if mode not in ("fan_in", "fan_out"):
            raise ValueError(f"mode must be 'fan_in' or 'fan_out', got {mode!r}")

        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == "fan_in" else fan_out

    def _custom_kaiming_uniform_(self, tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
        """Kaiming-uniform initialization of ``tensor`` in place using depthwise fans."""
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # U(-b, b) has variance b^2 / 3, so b = sqrt(3) * std yields the target std.
        # This constant is independent of the number of spatial dimensions.
        bound = math.sqrt(3.0) * std
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    @torch.no_grad()
    def reset_parameters(self):
        """Reset module parameters using appropriate initialization."""
        self._custom_kaiming_uniform_(
            self.weight,
            a=math.sqrt(5),
            mode="fan_out" if self.transposed else "fan_in",
        )

        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        input_sparse_tensor: Voxels,
        output_spatially_sparse_tensor: Optional[Voxels] = None,
    ) -> Voxels:
        """
        Forward pass for spatially sparse depthwise convolution.

        Args:
            input_sparse_tensor: Input sparse tensor
            output_spatially_sparse_tensor: Optional output sparse tensor for transposed conv

        Returns:
            Output sparse tensor
        """
        # Generate output coordinates and kernel map
        batch_indexed_out_coords, out_offsets, kernel_map = generate_output_coords_and_kernel_map(
            input_sparse_tensor=input_sparse_tensor,
            kernel_size=self.kernel_size,
            kernel_dilation=self.dilation,
            stride=self.stride,
            generative=self.generative,
            transposed=self.transposed,
            output_spatially_sparse_tensor=output_spatially_sparse_tensor,
            stride_mode=self.stride_mode,
            order=self.order,
        )

        num_out_coords = batch_indexed_out_coords.shape[0]

        # Apply depthwise convolution
        output_features = spatially_sparse_depthwise_conv(
            input_sparse_tensor.feature_tensor,
            self.weight,
            kernel_map,
            num_out_coords,
            fwd_algo=self.fwd_algo,
            bwd_algo=self.bwd_algo,
            compute_dtype=self.compute_dtype,
        )

        # Add bias if present
        if self.bias is not None:
            output_features = output_features + self.bias

        # Determine output tensor stride
        in_tensor_stride = input_sparse_tensor.tensor_stride
        if in_tensor_stride is None:
            in_tensor_stride = (1,) * self.num_spatial_dims

        if not self.transposed:
            out_tensor_stride = tuple(o * s for o, s in zip(self.stride, in_tensor_stride))
        else:
            if (
                output_spatially_sparse_tensor is not None
                and output_spatially_sparse_tensor.tensor_stride is not None
            ):
                out_tensor_stride = output_spatially_sparse_tensor.tensor_stride
            else:
                out_tensor_stride = (1,) * self.num_spatial_dims

        # Create output voxels; column 0 of the generated coords is dropped
        # (presumably the batch index, per the variable name — TODO confirm).
        out_offsets_cpu = out_offsets.cpu().int()
        out_coords = IntCoords(
            batch_indexed_out_coords[:, 1:],
            offsets=out_offsets_cpu,
        )
        return input_sparse_tensor.replace(
            batched_coordinates=out_coords,
            batched_features=output_features,
            tensor_stride=out_tensor_stride,
        )

forward(input_sparse_tensor: Voxels, output_spatially_sparse_tensor: Optional[Voxels] = None) -> Voxels

Forward pass for spatially sparse depthwise convolution.

Args: input_sparse_tensor – input sparse tensor; output_spatially_sparse_tensor – optional output sparse tensor for transposed convolution.

Returns: Output sparse tensor

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
def forward(
    self,
    input_sparse_tensor: Voxels,
    output_spatially_sparse_tensor: Optional[Voxels] = None,
) -> Voxels:
    """
    Run the sparse depthwise convolution on ``input_sparse_tensor``.

    Args:
        input_sparse_tensor: Voxels whose features are convolved.
        output_spatially_sparse_tensor: Optional target Voxels. It is forwarded to
            output-coordinate generation and, for transposed convolutions, supplies
            the output tensor stride when it has one.

    Returns:
        ``input_sparse_tensor.replace(...)`` carrying the new coordinates,
        features and tensor stride.
    """
    unit_stride = (1,) * self.num_spatial_dims

    coords_with_batch, offsets, kmap = generate_output_coords_and_kernel_map(
        input_sparse_tensor=input_sparse_tensor,
        kernel_size=self.kernel_size,
        kernel_dilation=self.dilation,
        stride=self.stride,
        generative=self.generative,
        transposed=self.transposed,
        output_spatially_sparse_tensor=output_spatially_sparse_tensor,
        stride_mode=self.stride_mode,
        order=self.order,
    )

    features = spatially_sparse_depthwise_conv(
        input_sparse_tensor.feature_tensor,
        self.weight,
        kmap,
        coords_with_batch.shape[0],
        fwd_algo=self.fwd_algo,
        bwd_algo=self.bwd_algo,
        compute_dtype=self.compute_dtype,
    )
    if self.bias is not None:
        features = features + self.bias

    source_stride = input_sparse_tensor.tensor_stride
    if source_stride is None:
        source_stride = unit_stride

    if self.transposed:
        # Transposed: take the stride from the explicit output tensor if it has
        # one, otherwise fall back to unit stride.
        has_target_stride = (
            output_spatially_sparse_tensor is not None
            and output_spatially_sparse_tensor.tensor_stride is not None
        )
        result_stride = (
            output_spatially_sparse_tensor.tensor_stride if has_target_stride else unit_stride
        )
    else:
        # Regular: strides compound multiplicatively per dimension.
        result_stride = tuple(k * s for k, s in zip(self.stride, source_stride))

    new_coords = IntCoords(
        coords_with_batch[:, 1:],
        offsets=offsets.cpu().int(),
    )
    return input_sparse_tensor.replace(
        batched_coordinates=new_coords,
        batched_features=features,
        tensor_stride=result_stride,
    )

reset_parameters()

Reset module parameters using appropriate initialization.

Source code in warpconvnet/nn/modules/sparse_conv_depth.py
@torch.no_grad()
def reset_parameters(self):
    """Re-initialize ``weight`` and, when present, ``bias`` in place.

    The weight goes through ``self._custom_kaiming_uniform_`` with ``a=sqrt(5)``,
    using ``"fan_out"`` for transposed convolutions and ``"fan_in"`` otherwise.
    The bias is drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, or set to 0
    when ``fan_in`` is not positive.
    """
    fan_mode = "fan_out" if self.transposed else "fan_in"
    self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5), mode=fan_mode)

    if self.bias is None:
        return
    fan_in = self._calculate_fan_in_and_fan_out()[0]
    limit = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    init.uniform_(self.bias, -limit, limit)

Sparse pooling

warpconvnet.nn.modules.sparse_pool

GlobalPool

Bases: BaseSpatialModule

Pool features across the entire geometry.

Parameters:
  • reduce ((min, max, mean, sum), default: "max" ) –

    Reduction to apply over all features. Defaults to "max".

Source code in warpconvnet/nn/modules/sparse_pool.py
class GlobalPool(BaseSpatialModule):
    """Collapse the features of a whole geometry with one reduction.

    Parameters
    ----------
    reduce : {"min", "max", "mean", "sum"}, optional
        Reduction passed to ``global_pool``. Defaults to ``"max"``.
    """

    def __init__(self, reduce: Literal["min", "max", "mean", "sum"] = "max"):
        super().__init__()
        self.reduce = reduce

    def forward(self, x: Geometry):
        reduction = self.reduce
        return global_pool(x, reduction)

PointToSparseWrapper

Bases: BaseSpatialModule

Pool points into a sparse tensor, apply an inner module and unpool back to points.

Parameters:
  • inner_module (`BaseSpatialModule`) –

    Module applied on the pooled sparse tensor.

  • voxel_size (float) –

    Voxel size used to pool the input points.

  • reduction (`REDUCTIONS` or str, default: MEAN ) –

    Reduction used when pooling points. Defaults to REDUCTIONS.MEAN.

  • unique_method ((morton, ravel, torch), default: "morton" ) –

    Method used for hashing voxel indices. Defaults to "morton".

  • concat_unpooled_pc (bool, default: True ) –

    If True concatenate the unpooled result with the original input. Defaults to True.

Source code in warpconvnet/nn/modules/sparse_pool.py
class PointToSparseWrapper(BaseSpatialModule):
    """Run a voxel-space module on a point cloud.

    The points are pooled into voxels with ``point_pool``, the voxels go through
    ``inner_module``, and the result is mapped back onto the original points with
    ``point_unpool``.

    Parameters
    ----------
    inner_module : `BaseSpatialModule`
        Module applied to the pooled ``Voxels``; it must return ``Voxels``.
    voxel_size : float
        Voxel size used both for pooling and for converting back to points.
    reduction : `REDUCTIONS` or str, optional
        Reduction used when pooling points. Defaults to ``REDUCTIONS.MEAN``.
    unique_method : {"morton", "ravel", "torch"}, optional
        Method used for hashing voxel indices. Defaults to ``"morton"``.
    concat_unpooled_pc : bool, optional
        Forwarded to ``point_unpool``; if ``True`` the unpooled result is
        concatenated with the original input. Defaults to ``True``.
    """

    def __init__(
        self,
        inner_module: BaseSpatialModule,
        voxel_size: float,
        reduction: Union[REDUCTIONS, REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
        unique_method: Literal["morton", "ravel", "torch"] = "morton",
        concat_unpooled_pc: bool = True,
    ):
        super().__init__()
        self.inner_module = inner_module
        self.voxel_size = voxel_size
        self.reduction = reduction
        self.concat_unpooled_pc = concat_unpooled_pc
        self.unique_method = unique_method

    def forward(self, pc: Points) -> Points:
        pooled, to_unique = point_pool(
            pc,
            reduction=self.reduction,
            downsample_voxel_size=self.voxel_size,
            return_type="voxel",
            return_to_unique=True,
            unique_method=self.unique_method,
        )
        processed = self.inner_module(pooled)
        assert isinstance(processed, Voxels), "Output of inner module must be a Voxels"
        # Back to point form at the same resolution before scattering to ``pc``.
        as_points = processed.to_point(self.voxel_size)
        return point_unpool(
            as_points,
            pc,
            concat_unpooled_pc=self.concat_unpooled_pc,
            to_unique=to_unique,
        )

SparseMaxPool

Bases: SparsePool

Max pooling for sparse tensors.

Parameters:
  • kernel_size (int) –

    Size of the pooling kernel.

  • stride (int) –

    Stride between pooling windows.

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseMaxPool(SparsePool):
    """``SparsePool`` with the reduction fixed to ``"max"``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__(kernel_size=kernel_size, stride=stride, reduce="max")

SparseMinPool

Bases: SparsePool

Min pooling for sparse tensors.

Parameters:
  • kernel_size (int) –

    Size of the pooling kernel.

  • stride (int) –

    Stride between pooling windows.

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseMinPool(SparsePool):
    """``SparsePool`` with the reduction fixed to ``"min"``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__(kernel_size=kernel_size, stride=stride, reduce="min")

SparsePool

Bases: BaseSpatialModule

Reduce features of a Voxels object using a strided kernel.

Parameters:
  • kernel_size (int) –

    Size of the pooling kernel.

  • stride (int) –

    Stride between pooling windows.

  • reduce ((max, min, mean, sum, random), default: "max" ) –

    Reduction to apply within each window. Defaults to "max".

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparsePool(BaseSpatialModule):
    """Pool the features of a ``Voxels`` over strided kernel windows.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    reduce : {"max", "min", "mean", "sum", "random"}, optional
        Reduction passed to ``sparse_reduce``. Defaults to ``"max"``.
    """

    def __init__(
        self,
        kernel_size: int,
        stride: int,
        reduce: Literal["max", "min", "mean", "sum", "random"] = "max",
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.reduce = reduce

    def __repr__(self):
        settings = f"kernel_size={self.kernel_size}, stride={self.stride}, reduce={self.reduce}"
        return f"{type(self).__name__}({settings})"

    def forward(self, st: Voxels):
        kernel_size, stride, reduce = self.kernel_size, self.stride, self.reduce
        return sparse_reduce(st, kernel_size, stride, reduce)

SparseUnpool

Bases: BaseSpatialModule

Unpool a sparse tensor back to a higher resolution.

Parameters:
  • kernel_size (int) –

    Size of the unpooling kernel.

  • stride (int) –

    Stride between unpooling windows.

  • concat_unpooled_st (bool, default: True ) –

    If True concatenate the unpooled tensor with the input. Defaults to True.

Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseUnpool(BaseSpatialModule):
    """Unpool a sparse tensor back to a higher resolution via ``sparse_unpool``.

    Parameters
    ----------
    kernel_size : int
        Size of the unpooling kernel.
    stride : int
        Stride between unpooling windows.
    concat_unpooled_st : bool, optional
        Forwarded to ``sparse_unpool``; if ``True`` the unpooled tensor is
        concatenated with the input. Defaults to ``True``.
    """

    def __init__(self, kernel_size: int, stride: int, concat_unpooled_st: bool = True):
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.concat_unpooled_st = concat_unpooled_st

    def forward(self, st: Voxels, unpooled_st: Voxels):
        return sparse_unpool(
            st, unpooled_st, self.kernel_size, self.stride, self.concat_unpooled_st
        )

Transforms

warpconvnet.nn.modules.transforms

Transform

Bases: BaseSpatialModule

Point transform module that applies a feature transform to the input point collection. No spatial operations are performed.

Hydra config example usage:

.. code-block:: yaml

    model:
      feature_transform:
        _target_: warpconvnet.nn.modules.transforms.Transform
        feature_transform_fn:
          _target_: torch.nn.ReLU
Source code in warpconvnet/nn/modules/transforms.py
class Transform(BaseSpatialModule):
    """
    Point transform module that applies a feature transform to the input point collection.
    No spatial operations are performed.

    Hydra config example usage:

    .. code-block:: yaml

        model:
          feature_transform:
            _target_: warpconvnet.nn.modules.transforms.Transform
            feature_transform_fn:
              _target_: torch.nn.ReLU
    """

    def __init__(self, feature_transform_fn: nn.Module):
        super().__init__()
        self.feature_transform_fn = feature_transform_fn

    def forward(self, *sfs: Geometry) -> Geometry:
        """
        Apply the feature transform to one or more geometries.

        The feature tensors of all inputs are passed positionally, in order, to
        ``feature_transform_fn``. Its output replaces the features of the first
        input, whose coordinates and other fields are kept.

        Args:
            *sfs: One or more ``Geometry`` objects. When several are given they
                must all have the same ``offsets``.

        Returns:
            The first input with its features replaced by the transform output.

        Raises:
            ValueError: If called with no inputs.
        """
        # ``sfs`` is always a tuple (varargs), so validate its elements, not the tuple.
        if not sfs:
            raise ValueError("Transform.forward expects at least one Geometry input")
        assert all(isinstance(sf, Geometry) for sf in sfs), "All inputs must be Geometry"

        first = sfs[0]
        # The result reuses the first input's layout, so every other input must
        # share its offsets. Check shapes first so a batch-size mismatch fails the
        # assertion instead of raising inside torch.allclose.
        assert all(
            sf.offsets.shape == first.offsets.shape and torch.allclose(sf.offsets, first.offsets)
            for sf in sfs[1:]
        ), "All inputs must have the same offsets"

        out_features = self.feature_transform_fn(*(sf.feature_tensor for sf in sfs))
        return first.replace(batched_features=out_features)

forward(*sfs: Tuple[Geometry, ...]) -> Geometry

Apply the feature transform to the input point collection

Args: *sfs – one or more Geometry inputs; when several are given they must share the same offsets.

Returns: Transformed point collection

Source code in warpconvnet/nn/modules/transforms.py
def forward(self, *sfs: Geometry) -> Geometry:
    """
    Apply the feature transform to one or more geometries.

    The feature tensors of all inputs are passed positionally, in order, to
    ``feature_transform_fn``. Its output replaces the features of the first
    input, whose coordinates and other fields are kept.

    Args:
        *sfs: One or more ``Geometry`` objects. When several are given they
            must all have the same ``offsets``.

    Returns:
        The first input with its features replaced by the transform output.

    Raises:
        ValueError: If called with no inputs.
    """
    # ``sfs`` is always a tuple (varargs), so validate its elements, not the tuple.
    if not sfs:
        raise ValueError("Transform.forward expects at least one Geometry input")
    assert all(isinstance(sf, Geometry) for sf in sfs), "All inputs must be Geometry"

    first = sfs[0]
    # The result reuses the first input's layout, so every other input must
    # share its offsets. Check shapes first so a batch-size mismatch fails the
    # assertion instead of raising inside torch.allclose.
    assert all(
        sf.offsets.shape == first.offsets.shape and torch.allclose(sf.offsets, first.offsets)
        for sf in sfs[1:]
    ), "All inputs must have the same offsets"

    out_features = self.feature_transform_fn(*(sf.feature_tensor for sf in sfs))
    return first.replace(batched_features=out_features)