Neural Networks¶
warpconvnet.nn
¶
Modules¶
Activations¶
warpconvnet.nn.modules.activations
¶
DropPath
¶
Bases: BaseSpatialModule
Stochastic depth regularization.
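Parameters:
- drop_prob (float, optional) – Probability of dropping a sample. Defaults to 0.0.
- scale_by_keep (bool, optional) – If True, kept samples are divided by 1 - drop_prob (the keep probability). Defaults to True.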
Source code in warpconvnet/nn/modules/activations.py
class DropPath(BaseSpatialModule):
    """Stochastic depth regularization.

    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a sample. Defaults to ``0.0``.
    scale_by_keep : bool, optional
        If ``True`` the kept samples are divided by ``1 - drop_prob`` (the keep
        probability) so the expected value is preserved. Defaults to ``True``.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x: Union[Geometry, Tensor]):  # noqa: F821
        """Apply ``drop_path`` to a raw tensor or to a ``Geometry``'s feature tensor."""
        if isinstance(x, Geometry):
            return x.replace(
                batched_features=drop_path(
                    x.feature_tensor, self.drop_prob, self.training, self.scale_by_keep
                )
            )
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Plain ``.3f``: the previous ``: 0.3f`` spec inserted a stray leading space.
        return f"drop_prob={self.drop_prob:.3f}"
ELU
¶
Bases: BaseSpatialModule
Applies the ELU activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class ELU(BaseSpatialModule):
    """Module wrapper that applies ELU to ``Geometry`` features via the functional ``elu``."""

    def forward(self, input: Geometry):  # noqa: F821
        activated = elu(input)
        return activated
GELU
¶
Bases: BaseSpatialModule
Applies the GELU activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class GELU(BaseSpatialModule):
    """Module wrapper that applies GELU to ``Geometry`` features via the functional ``gelu``."""

    def forward(self, input: Geometry):  # noqa: F821
        activated = gelu(input)
        return activated
LeakyReLU
¶
Bases: Module
Applies the LeakyReLU activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class LeakyReLU(BaseSpatialModule):
    """Applies the LeakyReLU activation to ``Geometry`` features.

    Derives from ``BaseSpatialModule`` like the other activation wrappers in this
    module, so it exposes the same ``device`` property and base-class contract.
    ``BaseSpatialModule`` is itself an ``nn.Module``, so ``isinstance`` checks
    against ``nn.Module`` still hold.
    """

    def forward(self, input: Geometry):  # noqa: F821
        return leaky_relu(input)
LogSoftmax
¶
Bases: BaseSpatialModule
Applies the log_softmax activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class LogSoftmax(BaseSpatialModule):
    """Module wrapper that applies ``log_softmax`` to ``Geometry`` features."""

    def forward(self, input: Geometry):  # noqa: F821
        log_probs = log_softmax(input)
        return log_probs
ReLU
¶
Bases: BaseSpatialModule
Apply the ReLU activation to Geometry features.
Parameters:
- inplace (bool, optional) – Whether to perform the operation in-place. Defaults to False.
Source code in warpconvnet/nn/modules/activations.py
class ReLU(BaseSpatialModule):
    """ReLU over ``Geometry`` features, implemented with a wrapped ``torch.nn.ReLU``.

    Parameters
    ----------
    inplace : bool, optional
        Passed through to ``torch.nn.ReLU``. Defaults to ``False``.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def __repr__(self):
        class_name = type(self).__name__
        return f"{class_name}(inplace={self.relu.inplace})"

    def forward(self, input: Geometry):  # noqa: F821
        # Delegate to the shared helper that applies a module to the feature tensor.
        transformed = apply_feature_transform(input, self.relu)
        return transformed
SiLU
¶
Bases: BaseSpatialModule
Applies the SiLU activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class SiLU(BaseSpatialModule):
    """Module wrapper that applies SiLU to ``Geometry`` features via the functional ``silu``."""

    def forward(self, input: Geometry):  # noqa: F821
        activated = silu(input)
        return activated
Sigmoid
¶
Bases: BaseSpatialModule
Applies the sigmoid activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class Sigmoid(BaseSpatialModule):
    """Module wrapper that applies the sigmoid to ``Geometry`` features."""

    def forward(self, input: Geometry):  # noqa: F821
        squashed = sigmoid(input)
        return squashed
Softmax
¶
Bases: BaseSpatialModule
Applies the softmax activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class Softmax(BaseSpatialModule):
    """Module wrapper that applies ``softmax`` to ``Geometry`` features."""

    def forward(self, input: Geometry):  # noqa: F821
        probs = softmax(input)
        return probs
Tanh
¶
Bases: BaseSpatialModule
Applies the tanh activation to Geometry features.
Source code in warpconvnet/nn/modules/activations.py
class Tanh(BaseSpatialModule):
    """Module wrapper that applies ``tanh`` to ``Geometry`` features."""

    def forward(self, input: Geometry):  # noqa: F821
        squashed = tanh(input)
        return squashed
drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True)
¶
Apply stochastic depth to the input tensor.
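Parameters:
- x (torch.Tensor) – Input tensor to apply stochastic depth to.
- drop_prob (float, optional) – Probability of dropping a sample. Defaults to 0.0.
- training (bool, optional) – Whether the module is in training mode. Defaults to False.
- scale_by_keep (bool, optional) – If True, kept samples are divided by 1 - drop_prob (the keep probability). Defaults to True.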
Source code in warpconvnet/nn/modules/activations.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples (along dim 0) of ``x``: stochastic depth.

    Parameters
    ----------
    x : ``torch.Tensor``
        Input tensor; dimension 0 indexes the samples that may be dropped.
    drop_prob : float, optional
        Probability of dropping a sample. Defaults to ``0.0``.
    training : bool, optional
        Dropping only happens when ``True``. Defaults to ``False``.
    scale_by_keep : bool, optional
        If ``True`` kept samples are divided by ``1 - drop_prob`` (when that is
        positive). Defaults to ``True``.

    Returns
    -------
    ``torch.Tensor``
        ``x`` itself when dropping is disabled, otherwise a new masked tensor.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across every remaining dimension.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
Attention¶
warpconvnet.nn.modules.attention
¶
Attention
¶
Bases: Module
Source code in warpconvnet/nn/modules/attention.py
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        enable_flash: bool = True,
        use_batched_qkv: bool = True,
    ):
        """
        Multi-head self-attention with optional flash attention and batched QKV for Muon optimization.
        Args:
            dim: Input feature dimension (must be divisible by ``num_heads``)
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in QKV projection
            qk_scale: Scale factor for attention scores (defaults to ``head_dim**-0.5``)
            attn_drop: Attention dropout rate
            proj_drop: Output projection dropout rate
            enable_flash: Whether to use flash attention (requires ``flash_attn``)
            use_batched_qkv: If True, uses separate Q, K, V matrices stacked as [3, dim, dim]
                for Muon optimization. Muon can orthogonalize the [dim, dim] matrices
                more effectively than the concatenated [dim, 3*dim] matrix.
        """
        super().__init__()
        # Same check as PatchAttention; a non-divisible dim would otherwise fail later in reshape.
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.enable_flash = enable_flash
        self.use_batched_qkv = use_batched_qkv
        if enable_flash:
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.attn_drop_p = attn_drop
        else:
            self.attn_drop = nn.Dropout(attn_drop)
        if use_batched_qkv:
            # Use BatchedLinear for Muon-friendly QKV projection
            self.qkv = BatchedLinear(dim, dim, num_matrices=3, bias=qkv_bias)
        else:
            # Original single linear layer approach
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: Float[Tensor, "B N C"],  # noqa: F821
        pos_enc: Optional[Float[Tensor, "B N C"]] = None,  # noqa: F821
        mask: Optional[Float[Tensor, "B N N"]] = None,  # noqa: F821
        num_points: Optional[Int[Tensor, "B"]] = None,  # noqa: F821
    ) -> Float[Tensor, "B N C"]:
        """Self-attention over ``x``.

        Args:
            x: Input features ``(B, N, C)``.
            pos_enc: Optional positional encoding ``(B, N, C)``. On the flash path it is
                added to ``x`` before the QKV projection. On the non-flash path it is split
                into heads and added to ``q`` and ``k``.
            mask: Optional attention mask of shape ``(B, N, N)`` or ``(B, 1, N, N)``. A float
                mask is added to the logits. A bool mask marks keys to keep (as built by
                ``offset_to_mask``). The mask is only applied on the non-flash path.
            num_points: Optional valid-point counts per batch; rows past the count are zeroed.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # Compute QKV with unified approach
        if pos_enc is not None and self.enable_flash:
            # Add positional encoding to input before QKV projection for flash attention
            qkv = self.qkv(x + pos_enc)
        else:
            qkv = self.qkv(x)
        # Reshape to [B, N, 3, num_heads, head_dim]
        qkv = qkv.reshape(B, N, 3, self.num_heads, head_dim)
        if not self.enable_flash:
            qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
            q, k, v = (
                qkv[0],
                qkv[1],
                qkv[2],
            )  # make torchscript happy (cannot use tensor as tuple)
            if pos_enc is not None:
                # Split pos_enc channels into heads so it matches q/k: (B, num_heads, N, head_dim).
                # (A plain unsqueeze(1) only broadcasts correctly when num_heads == 1.)
                pos_heads = pos_enc.reshape(B, N, self.num_heads, head_dim).transpose(1, 2)
                q = q + pos_heads
                k = k + pos_heads
            attn = (q @ k.transpose(-2, -1)) * self.scale
            if mask is not None:
                if mask.dim() == 3:
                    # (B, N, N) -> (B, 1, N, N) so the batch axis is not broadcast against heads.
                    mask = mask.unsqueeze(1)
                if mask.dtype == torch.bool:
                    # True = keep; use the dtype's min instead of -inf so fully masked rows stay finite.
                    attn = attn.masked_fill(~mask, torch.finfo(attn.dtype).min)
                else:
                    attn = attn + mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
            x = x.transpose(1, 2).reshape(B, N, C)
        else:
            # Flash attention path. NOTE: `mask` is not applied here.
            # Flash attention - preserve original dtype if possible
            original_dtype = qkv.dtype
            if qkv.dtype not in (torch.float16, torch.bfloat16):
                # Convert to half precision for flash attention
                qkv_flash = qkv.half()
            else:
                qkv_flash = qkv
            x = flash_attn.flash_attn_qkvpacked_func(
                qkv_flash,
                dropout_p=self.attn_drop_p if self.training else 0.0,
                softmax_scale=self.scale,
            ).reshape(B, N, C)
            # Convert back to original dtype if necessary
            if x.dtype != original_dtype:
                x = x.to(original_dtype)
        x = self.proj(x)
        x = self.proj_drop(x)
        if num_points is not None:
            x = zero_out_points(x, num_points)
        return x
PatchAttention
¶
Bases: BaseSpatialModule
Source code in warpconvnet/nn/modules/attention.py
class PatchAttention(BaseSpatialModule):
    def __init__(
        self,
        dim: int,
        patch_size: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        order: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
        use_batched_qkv: bool = True,
    ):
        """
        Patch attention module with optional batched QKV for Muon optimization.
        Args:
            dim: Input feature dimension
            patch_size: Maximum number of points per attention patch
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in QKV projection
            qk_scale: Scale factor for attention scores (defaults to ``head_dim**-0.5``)
            attn_drop: Attention dropout rate
            proj_drop: Output projection dropout rate
            order: Point ordering used to serialize points before patching
            use_batched_qkv: If True, uses separate Q, K, V matrices stacked as [3, dim, dim]
                for Muon optimization. Muon can orthogonalize the [dim, dim] matrices
                more effectively than the concatenated [dim, 3*dim] matrix.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.use_batched_qkv = use_batched_qkv
        if use_batched_qkv:
            # Use BatchedLinear for Muon-friendly QKV projection
            self.qkv = BatchedLinear(dim, dim, num_matrices=3, bias=qkv_bias)
        else:
            # Original single linear layer approach
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.order = order
        assert flash_attn is not None, "Make sure flash_attn is installed."
        self.attn_drop_p = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _offset_to_attn_offset(
        self, offsets: Int[Tensor, "B+1"], patch_size: Optional[int] = None
    ) -> Int[Tensor, "B"]:
        """
        Convert batch offsets to cumulative attention offsets required for flash attention.

        Each batch ``[offsets[b], offsets[b+1])`` is cut into consecutive chunks of at
        most ``patch_size`` points. Empty batches contribute no chunk, and a batch whose
        length is a multiple of ``patch_size`` does not produce a zero-length trailing chunk.

        If the patch size is 8 and the offsets are [0, 3, 11, 40] (3 batches),
        the cumulative attention offsets are [0, 3, 3 + 8 = 11, 11 + 8, 11 + 8 + 8, 11 + 8 + 8 + 8, 40].
        Args:
            offsets: (B+1)
            patch_size: Optional[int]; defaults to ``self.patch_size``
        Returns:
            cum_seqlens: M, with the same dtype as ``offsets``
        """
        patch_size = patch_size or self.patch_size
        counts = torch.diff(offsets)
        # Chunk starts per batch: ceil(count / patch_size), which is 0 for an empty batch.
        num_chunks = torch.div(counts + (patch_size - 1), patch_size, rounding_mode="floor")
        total_chunks = int(num_chunks.sum())
        if total_chunks == 0:
            # Every batch is empty; nothing to split.
            return offsets
        device = offsets.device
        batch_indices = torch.repeat_interleave(
            torch.arange(len(counts), device=device), num_chunks
        )
        # Exclusive prefix sum gives the flat index of each batch's first chunk.
        first_chunk = torch.cumsum(num_chunks, dim=0) - num_chunks
        chunk_in_batch = torch.arange(total_chunks, device=device) - first_chunk[batch_indices]
        chunk_starts = offsets[:-1][batch_indices] + chunk_in_batch * patch_size
        result = torch.cat([chunk_starts.to(offsets.dtype), offsets[-1:]])
        return result.contiguous()

    def forward(self, x: Geometry, order: Optional[POINT_ORDERING] = None) -> Geometry:
        """Attend within fixed-size patches of serialized points and return a new ``Geometry``."""
        K = self.patch_size
        feats = x.features
        M, C = feats.shape[:2]
        inverse_perm = None
        # Explicit None check so a falsy enum member passed by the caller is respected.
        order = order if order is not None else self.order
        if not hasattr(x, "order") or (order != x.order):
            # Generate new ordering and inverse permutation
            code_result = encode(
                x.coordinate_tensor,
                batch_offsets=x.offsets,
                order=order,
                return_perm=True,
                return_inverse=True,
            )
            feats = feats[code_result.perm]
            inverse_perm = code_result.inverse_perm
        # Compute QKV: (M, 3, num_heads, head_dim)
        qkv = self.qkv(feats).reshape(M, 3, self.num_heads, C // self.num_heads)
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.float16)
        # flash-attn requires int32 cumulative sequence lengths on the same device as qkv.
        attn_offsets = self._offset_to_attn_offset(x.offsets, K).to(
            device=qkv.device, dtype=torch.int32
        )
        # Warning: When the loss is NaN, this module will fail during backward with
        # index out of bounds error.
        # e.g. /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [192,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "
        # https://discuss.pytorch.org/t/scattergatherkernel-cu-assertion-idx-dim-0-idx-dim-index-size-index-out-of-bounds/195356
        out_feat = flash_attn.flash_attn_varlen_qkvpacked_func(
            qkv,
            attn_offsets,
            max_seqlen=K,
            dropout_p=self.attn_drop_p if self.training else 0.0,
            softmax_scale=self.scale,
        )
        out_feat = out_feat.reshape(M, C).to(feats.dtype)
        out_feat = self.proj(out_feat)
        out_feat = self.proj_drop(out_feat)
        if inverse_perm is not None:
            # Undo the serialization permutation so rows line up with the input points.
            out_feat = out_feat[inverse_perm]
        return x.replace(batched_features=out_feat.to(feats.dtype))
offset_to_mask(x: Float[Tensor, 'B M C'], offsets: Float[Tensor, 'B+1'], max_num_points: int, dtype: torch.dtype = torch.bool) -> Float[Tensor, 'B 1 M M']
¶
Create a mask for the points in the batch.
Source code in warpconvnet/nn/modules/attention.py
def offset_to_mask(
    x: Float[Tensor, "B M C"],  # noqa: F821
    offsets: Float[Tensor, "B+1"],  # noqa: F821
    max_num_points: int,  # noqa: F821
    dtype: torch.dtype = torch.bool,
) -> Float[Tensor, "B 1 M M"]:  # noqa: F821
    """
    Build a key-validity mask of shape ``(B, 1, max_num_points, max_num_points)``.

    For batch ``b``, every query row has its first ``offsets[b+1] - offsets[b]``
    key columns set to ``True``; all other entries are ``False``. Only
    ``torch.bool`` is supported; any other ``dtype`` raises ``ValueError``.
    """
    batch_size = x.shape[0]
    assert batch_size == offsets.shape[0] - 1
    mask = torch.zeros(
        (batch_size, 1, max_num_points, max_num_points),
        dtype=dtype,
        device=x.device,
    )
    counts = offsets.diff()
    if dtype != torch.bool:
        raise ValueError(f"Unsupported dtype: {dtype}")
    for batch_idx in range(batch_size):
        mask[batch_idx, :, :, : counts[batch_idx]] = True
    return mask
zero_out_points(x: Float[Tensor, 'B N C'], num_points: Int[Tensor, 'B']) -> Float[Tensor, 'B N C']
¶
Zero out the points in the batch.
Source code in warpconvnet/nn/modules/attention.py
def zero_out_points(
    x: Float[Tensor, "B N C"], num_points: Int[Tensor, "B"]  # noqa: F821
) -> Float[Tensor, "B N C"]:  # noqa: F821
    """
    Zero every row of ``x[b]`` from index ``num_points[b]`` onward.

    ``x`` is modified in place and the same tensor is returned.
    """
    num_batches = num_points.shape[0]
    batch_idx = 0
    while batch_idx < num_batches:
        x[batch_idx, num_points[batch_idx] :] = 0
        batch_idx += 1
    return x
Base module¶
warpconvnet.nn.modules.base_module
¶
BaseSpatialModel
¶
Bases: BaseSpatialModule
Base model class.
Source code in warpconvnet/nn/modules/base_module.py
class BaseSpatialModel(BaseSpatialModule):
    """Base class for complete spatial models.

    Every hook below raises ``NotImplementedError`` and is meant to be overridden.
    """

    def data_dict_to_input(self, data_dict, **kwargs) -> Any:
        """Build the model input from ``data_dict``."""
        raise NotImplementedError()

    def loss_dict(self, data_dict, **kwargs) -> Dict:
        """Return a dictionary of losses for ``data_dict``."""
        raise NotImplementedError()

    @torch.no_grad()
    def eval_dict(self, data_dict, **kwargs) -> Dict:
        """Return a dictionary of evaluation metrics; runs without autograd."""
        raise NotImplementedError()

    def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]:
        """Return an (image dict, point-cloud dict) pair for visualization/logging."""
        raise NotImplementedError()
data_dict_to_input(data_dict, **kwargs) -> Any
¶
Convert data dictionary to appropriate input for the model.
Source code in warpconvnet/nn/modules/base_module.py
def data_dict_to_input(self, data_dict, **kwargs) -> Any:
    """Build the model input from ``data_dict``; subclasses must override."""
    raise NotImplementedError()
eval_dict(data_dict, **kwargs) -> Dict
¶
Compute the evaluation dictionary for the model.
Source code in warpconvnet/nn/modules/base_module.py
@torch.no_grad()
def eval_dict(self, data_dict, **kwargs) -> Dict:
    """Return evaluation metrics for ``data_dict`` (autograd disabled); subclasses must override."""
    raise NotImplementedError()
image_pointcloud_dict(data_dict, datamodule) -> Tuple[Dict, Dict]
¶
Compute the image dict and pointcloud dict for the model.
Source code in warpconvnet/nn/modules/base_module.py
def image_pointcloud_dict(self, data_dict, datamodule) -> Tuple[Dict, Dict]:
    """Return an (image dict, point-cloud dict) pair; subclasses must override."""
    raise NotImplementedError()
loss_dict(data_dict, **kwargs) -> Dict
¶
Compute the loss dictionary for the model.
Source code in warpconvnet/nn/modules/base_module.py
def loss_dict(self, data_dict, **kwargs) -> Dict:
    """Return the losses for ``data_dict``; subclasses must override."""
    raise NotImplementedError()
BaseSpatialModule
¶
Bases: Module
Base module for spatial features. The input must be a Geometry instance.
Source code in warpconvnet/nn/modules/base_module.py
class BaseSpatialModule(nn.Module):
    """Base module for spatial features. The input must be a ``Geometry`` instance."""

    @property
    def device(self):
        """Return the device the module lives on.

        Uses the first parameter, falling back to the first buffer for modules that
        own no parameters.

        Raises
        ------
        RuntimeError
            If the module has neither parameters nor buffers. Parameter-free wrappers
            such as activation modules fall into this case, so no device can be inferred.
        """
        for tensor in self.parameters():
            return tensor.device
        for tensor in self.buffers():
            return tensor.device
        # Previously ``next(...)`` leaked a bare StopIteration here.
        raise RuntimeError(
            f"{type(self).__name__} has no parameters or buffers; its device is undefined."
        )

    def forward(self, x: Geometry):
        """Forward pass."""
        raise NotImplementedError
Factor grid¶
warpconvnet.nn.modules.factor_grid
¶
Neural network modules for FactorGrid operations.
This module provides neural network layers and operations specifically designed for working with FactorGrid geometries in the FIGConvNet architecture.
FactorGridCat
¶
Bases: BaseSpatialModule
Concatenate features of two FactorGrid objects.
This is equivalent to GridFeatureGroupCat but works with FactorGrid objects.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridCat(BaseSpatialModule):
    """Concatenate the features of two ``FactorGrid`` objects.

    ``FactorGrid`` counterpart of ``GridFeatureGroupCat``; the work is delegated to
    ``factor_grid_cat``.
    """

    def forward(self, factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
        """Return ``factor_grid_cat(factor_grid1, factor_grid2)``."""
        merged = factor_grid_cat(factor_grid1, factor_grid2)
        return merged
forward(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid
¶
Concatenate features from two FactorGrid objects.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
    """Return ``factor_grid_cat(factor_grid1, factor_grid2)``."""
    merged = factor_grid_cat(factor_grid1, factor_grid2)
    return merged
FactorGridGlobalConv
¶
Bases: BaseSpatialModule
Global convolution with intra-communication for FactorGrid.
This is equivalent to GridFeatureConv2DBlocksAndIntraCommunication but works with FactorGrid objects.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridGlobalConv(BaseSpatialModule):
    """Global convolution with intra-communication for FactorGrid.
    This is equivalent to GridFeatureConv2DBlocksAndIntraCommunication but works with FactorGrid objects.

    Parameters
    ----------
    in_channels, out_channels : int
        Channels per grid before/after the convolution blocks.
    kernel_size, stride, up_stride
        Forwarded to each per-grid ``_FactorGridConvNormAct`` block.
    compressed_spatial_dims, compressed_memory_formats : tuple
        One entry per grid; both tuples must have the same length.
    communication_types : list of {"sum", "mul"}, optional
        Intra-communication modes; defaults to ``["sum"]``. With more than one mode a
        ``FactorGridProjection`` maps the widened channels back to ``out_channels``.
    norm, activation : type
        Layer classes used inside each convolution block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        compressed_spatial_dims: Tuple[int, ...],
        compressed_memory_formats: Tuple[GridMemoryFormat, ...],
        stride: int = 1,
        up_stride: Optional[int] = None,
        communication_types: Optional[List[Literal["sum", "mul"]]] = None,
        norm: Type[nn.Module] = nn.BatchNorm2d,
        activation: Type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        # Fresh list per instance: a shared mutable default could be mutated through any module.
        if communication_types is None:
            communication_types = ["sum"]
        else:
            communication_types = list(communication_types)
        assert len(compressed_spatial_dims) == len(
            compressed_memory_formats
        ), "Number of compressed spatial dimensions and compressed memory formats must match"
        # Create convolution blocks for each compressed spatial dimension
        self.conv_blocks = nn.ModuleList()
        for compressed_spatial_dim, compressed_memory_format in zip(
            compressed_spatial_dims, compressed_memory_formats
        ):
            self.conv_blocks.append(
                _FactorGridConvNormAct(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    compressed_spatial_dim=compressed_spatial_dim,
                    compressed_memory_format=compressed_memory_format,
                    stride=stride,
                    up_stride=up_stride,
                    norm=norm,
                    activation=activation,
                )
            )
        # Intra-communication module
        self.intra_communications = FactorGridIntraCommunication(
            communication_types=communication_types
        )
        # Projection layer if multiple communication types
        if len(communication_types) > 1:
            self.proj = FactorGridProjection(
                in_channels=out_channels * len(communication_types),
                out_channels=out_channels,
                kernel_size=1,
                compressed_spatial_dims=compressed_spatial_dims,
                compressed_memory_formats=compressed_memory_formats,
                stride=1,
            )
        else:
            self.proj = nn.Identity()

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Forward pass through the global convolution module."""
        assert len(factor_grid) == len(
            self.conv_blocks
        ), f"Expected {len(self.conv_blocks)} grids, got {len(factor_grid)}"
        # Apply convolution blocks to each grid
        convolved_grids = [conv_block(grid) for grid, conv_block in zip(factor_grid, self.conv_blocks)]
        # Apply intra-communication
        factor_grid = self.intra_communications(FactorGrid(convolved_grids))
        # Apply projection if needed
        return self.proj(factor_grid)
forward(factor_grid: FactorGrid) -> FactorGrid
¶
Forward pass through the global convolution module.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Convolve each grid with its own block, mix the grids, then project."""
    assert len(factor_grid) == len(
        self.conv_blocks
    ), f"Expected {len(self.conv_blocks)} grids, got {len(factor_grid)}"
    per_grid = [block(sub_grid) for sub_grid, block in zip(factor_grid, self.conv_blocks)]
    mixed = self.intra_communications(FactorGrid(per_grid))
    # ``proj`` is nn.Identity unless several communication types were requested.
    projected = self.proj(mixed)
    return projected
FactorGridIntraCommunication
¶
Bases: BaseSpatialModule
Intra-communication between grids in a FactorGrid.
This is equivalent to GridFeaturesGroupIntraCommunication but works with FactorGrid objects.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridIntraCommunication(BaseSpatialModule):
    """Intra-communication between grids in a FactorGrid.
    This is equivalent to GridFeaturesGroupIntraCommunication but works with FactorGrid objects.

    Parameters
    ----------
    communication_types : list of {"sum", "mul"}, optional
        One or two communication modes passed to ``factor_grid_intra_communication``.
        Defaults to ``["sum"]``.
    """

    def __init__(
        self, communication_types: Optional[List[Literal["sum", "mul"]]] = None
    ) -> None:
        super().__init__()
        # Copy into a per-instance list so instances never share (and mutate) one default list.
        communication_types = ["sum"] if communication_types is None else list(communication_types)
        assert len(communication_types) > 0, "At least one communication type must be provided"
        assert len(communication_types) <= 2, "At most two communication types can be provided"
        self.communication_types = communication_types

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Perform intra-communication between grids in the FactorGrid."""
        return factor_grid_intra_communication(factor_grid, self.communication_types)
forward(factor_grid: FactorGrid) -> FactorGrid
¶
Perform intra-communication between grids in the FactorGrid.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Mix information across the grids using the configured communication types."""
    mixed = factor_grid_intra_communication(factor_grid, self.communication_types)
    return mixed
FactorGridPadToMatch
¶
Bases: BaseSpatialModule
Pad FactorGrid features to match spatial dimensions for UNet skip connections.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridPadToMatch(BaseSpatialModule):
    """Pad FactorGrid features to match spatial dimensions for UNet skip connections.

    Each grid of ``up_factor_grid`` is padded on the high side (``replicate`` mode)
    along every spatial dimension where it is smaller than the paired grid of
    ``down_factor_grid``. Dimensions that are already as large or larger are left
    untouched; nothing is cropped.
    """

    def __init__(self):
        super().__init__()

    def forward(self, up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid:
        """Pad up_factor_grid to match down_factor_grid spatial dimensions."""
        assert len(up_factor_grid) == len(
            down_factor_grid
        ), "FactorGrids must have same number of grids"
        padded_grids = []
        for up_grid, down_grid in zip(up_factor_grid, down_factor_grid):
            up_features = up_grid.grid_features.batched_tensor
            down_features = down_grid.grid_features.batched_tensor
            # Spatial dimensions: everything after (batch, channel).
            up_shape = up_features.shape[2:]
            down_shape = down_features.shape[2:]
            if up_shape == down_shape:
                # No padding needed
                padded_grids.append(up_grid)
                continue
            assert len(up_shape) == len(
                down_shape
            ), f"Spatial rank mismatch: {tuple(up_shape)} vs {tuple(down_shape)}"
            # F.pad takes (low, high) pairs starting from the LAST dimension, so walk the
            # spatial dims in reverse and pad only on the high side. This reproduces the
            # previous 2D/3D special cases and works for any rank F.pad supports.
            pad = []
            for up_size, down_size in zip(reversed(up_shape), reversed(down_shape)):
                pad.extend((0, max(0, down_size - up_size)))
            padded_features = F.pad(up_features, tuple(pad), mode="replicate")
            padded_grids.append(up_grid.replace(batched_features=padded_features))
        return FactorGrid(padded_grids)
forward(up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid
¶
Pad up_factor_grid to match down_factor_grid spatial dimensions.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, up_factor_grid: FactorGrid, down_factor_grid: FactorGrid) -> FactorGrid:
    """Pad up_factor_grid to match down_factor_grid spatial dimensions.

    Each spatial dimension where the up grid is smaller is padded on its high side
    with ``replicate`` mode; larger dimensions are left as they are.
    """
    assert len(up_factor_grid) == len(
        down_factor_grid
    ), "FactorGrids must have same number of grids"
    padded_grids = []
    for up_grid, down_grid in zip(up_factor_grid, down_factor_grid):
        up_features = up_grid.grid_features.batched_tensor
        down_features = down_grid.grid_features.batched_tensor
        # Spatial dimensions: everything after (batch, channel).
        up_shape = up_features.shape[2:]
        down_shape = down_features.shape[2:]
        if up_shape == down_shape:
            # No padding needed
            padded_grids.append(up_grid)
            continue
        assert len(up_shape) == len(
            down_shape
        ), f"Spatial rank mismatch: {tuple(up_shape)} vs {tuple(down_shape)}"
        # F.pad pairs run from the last dimension backwards; pad only the high side.
        pad = []
        for up_size, down_size in zip(reversed(up_shape), reversed(down_shape)):
            pad.extend((0, max(0, down_size - up_size)))
        padded_features = F.pad(up_features, tuple(pad), mode="replicate")
        padded_grids.append(up_grid.replace(batched_features=padded_features))
    return FactorGrid(padded_grids)
FactorGridPool
¶
Bases: BaseSpatialModule
Pooling operation for FactorGrid.
This is equivalent to GridFeatureGroupPool but works with FactorGrid objects.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridPool(BaseSpatialModule):
    """Reduce a ``FactorGrid`` to a single tensor.

    ``FactorGrid`` counterpart of ``GridFeatureGroupPool``.

    Parameters
    ----------
    pooling_type : {"max", "mean", "attention"}
        ``"max"``/``"mean"`` use an adaptive 1D pool of output size 1. ``"attention"``
        currently builds no layers: both ``attention`` and ``pool_op`` are ``None``.
        Any other value raises ``ValueError``.
    """

    def __init__(
        self,
        pooling_type: Literal["max", "mean", "attention"] = "max",
    ):
        super().__init__()
        self.pooling_type = pooling_type
        if pooling_type not in ("max", "mean", "attention"):
            raise ValueError(f"Unsupported pooling type: {pooling_type}")
        if pooling_type == "attention":
            # Placeholder; the attention layer's size would depend on the input features.
            self.attention = None
            self.pool_op = None
        else:
            pool_cls = nn.AdaptiveMaxPool1d if pooling_type == "max" else nn.AdaptiveAvgPool1d
            self.pool_op = pool_cls(1)

    def forward(self, factor_grid: FactorGrid) -> Tensor:
        """Delegate to ``factor_grid_pool`` with this module's pooling configuration."""
        attention_layer = getattr(self, "attention", None)
        pooled = factor_grid_pool(
            factor_grid,
            self.pooling_type,
            pool_op=self.pool_op,
            attention_layer=attention_layer,
        )
        return pooled
forward(factor_grid: FactorGrid) -> Tensor
¶
Pool features from FactorGrid to a single tensor.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> Tensor:
    """Delegate to ``factor_grid_pool`` with this module's pooling configuration."""
    attention_layer = getattr(self, "attention", None)
    pooled = factor_grid_pool(
        factor_grid,
        self.pooling_type,
        pool_op=self.pool_op,
        attention_layer=attention_layer,
    )
    return pooled
FactorGridProjection
¶
Bases: BaseSpatialModule
Projection operation for FactorGrid.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridProjection(BaseSpatialModule):
    """Projection operation for FactorGrid: one ``Conv2d -> norm -> activation`` block per grid.

    Each block's ``Conv2d`` maps ``in_channels * compressed_spatial_dim`` to
    ``out_channels * compressed_spatial_dim`` channels. It uses the given
    ``kernel_size``/``stride`` and padding ``(kernel_size - 1) // 2``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        compressed_spatial_dims: Tuple[int, ...],
        compressed_memory_formats: Tuple[GridMemoryFormat, ...],
        stride: int = 1,
        norm: Type[nn.Module] = nn.BatchNorm2d,
        activation: Type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        for compressed_spatial_dim, compressed_memory_format in zip(
            compressed_spatial_dims, compressed_memory_formats
        ):
            block = nn.Sequential(
                nn.Conv2d(
                    in_channels * compressed_spatial_dim,
                    out_channels * compressed_spatial_dim,
                    kernel_size,
                    stride,
                    (kernel_size - 1) // 2,
                    bias=True,
                ),
                norm(out_channels * compressed_spatial_dim),
                activation(),
            )
            self.convs.append(block)

    def forward(self, grid: FactorGrid) -> FactorGrid:
        """Project every grid of ``grid`` with its matching conv block.

        The blocks are plain ``nn.Conv2d`` stacks, so they are applied to each grid's
        feature tensor rather than to the Grid object itself. The result is wrapped
        back with ``replace``.
        NOTE(review): assumes each grid's ``batched_tensor`` is already in the 4-D
        compressed layout the matching ``Conv2d`` was built for — confirm.
        """
        projected_grids = []
        # Distinct loop name: the original loop variable shadowed the ``grid`` argument.
        for sub_grid, conv in zip(grid, self.convs):
            features = sub_grid.grid_features.batched_tensor
            projected_grids.append(sub_grid.replace(batched_features=conv(features)))
        return FactorGrid(projected_grids)
FactorGridToPoint
¶
Bases: BaseSpatialModule
Convert FactorGrid features back to point features using trilinear interpolation.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridToPoint(BaseSpatialModule):
"""Convert FactorGrid features back to point features using trilinear interpolation."""
def __init__(
self,
grid_in_channels: int,
point_in_channels: int,
num_grids: int,
out_channels: int,
use_rel_pos: bool = True,
use_rel_pos_embed: bool = False,
pos_embed_dim: int = 32,
sample_method: Literal["graphconv", "interp"] = "interp",
neighbor_search_type: Literal["radius", "knn"] = "radius",
knn_k: int = 16,
reductions: List[str] = ["mean"],
):
super().__init__()
self.grid_in_channels = grid_in_channels
self.point_in_channels = point_in_channels
self.out_channels = out_channels
self.use_rel_pos = use_rel_pos
self.use_rel_pos_embed = use_rel_pos_embed
self.pos_embed_dim = pos_embed_dim
self.sample_method = sample_method
self.neighbor_search_type = neighbor_search_type
self.knn_k = knn_k
self.reductions = reductions
self.freqs = get_freqs(pos_embed_dim)
# Calculate combined channels for MLP
combined_channels = grid_in_channels * num_grids + point_in_channels
if use_rel_pos_embed:
combined_channels += pos_embed_dim * 3
elif use_rel_pos:
combined_channels += 3
self.combine_mlp = nn.Sequential(
nn.Linear(combined_channels, out_channels),
nn.LayerNorm(out_channels),
nn.GELU(),
nn.Linear(out_channels, out_channels),
)
def _normalize_coordinates(self, coords: Tensor, grid: Grid) -> Tensor:
"""Normalize coordinates to [-1, 1] range for grid_sample using grid bounds."""
# Get bounds from the grid
bounds_min, bounds_max = grid.bounds
# Normalize to [0, 1] first
normalized = (coords - bounds_min) / (bounds_max - bounds_min)
# Convert to [-1, 1] range for grid_sample
normalized = 2.0 * normalized - 1.0
# Ensure coordinates are within bounds
normalized = torch.clamp(normalized, -1.0, 1.0)
return normalized
def _sample_from_grid(self, grid: Grid, point_coords: Tensor) -> Tensor:
"""Sample features from grid using trilinear interpolation."""
grid_features_tensor = grid.grid_features.batched_tensor
batch_size = grid_features_tensor.shape[0]
# Normalize point coordinates using grid bounds
normalized_coords = self._normalize_coordinates(point_coords, grid)
# Reshape coordinates for grid_sample: (B, N, 3) -> (B, N, 1, 1, 3)
# grid_sample expects (B, H, W, D, 3) for 3D grids
normalized_coords = normalized_coords.view(batch_size, -1, 1, 1, 3)
# Convert grid to standard format for grid_sample
grid_reshaped = grid.grid_features.to_memory_format(GridMemoryFormat.b_c_x_y_z)
# Use grid_sample for trilinear interpolation
sampled_features = F.grid_sample(
grid_reshaped.batched_tensor,
normalized_coords,
mode="bilinear", # For 3D, this becomes trilinear
padding_mode="border",
align_corners=True,
) # Shape: B, C, N, 1, 1
# Reshape to (B*N, C)
sampled_features = sampled_features.squeeze(-1).squeeze(-1).transpose(1, 2) # B, N, C
sampled_features = sampled_features.flatten(start_dim=0, end_dim=1) # B*N, C
return sampled_features
def forward(self, factor_grid: FactorGrid, point_features: Points) -> Points:
    """Convert FactorGrid features to point features using trilinear interpolation.

    Every grid in ``factor_grid`` is sampled at the point coordinates. The
    samples are concatenated along the channel axis together with the input
    point features. An optional position term is then appended:

    - with ``use_rel_pos_embed``, a sinusoidal encoding of the position
      relative to the center of the first grid's bounds;
    - otherwise, with ``use_rel_pos``, the raw relative position.

    The concatenation is passed through ``combine_mlp``.

    Parameters
    ----------
    factor_grid : FactorGrid
        Grids to sample from; must yield at least one grid.
    point_features : Points
        Query points that supply coordinates and input features.

    Returns
    -------
    Points
        Points with the input coordinates and the ``combine_mlp`` output features.

    Raises
    ------
    ValueError
        If ``factor_grid`` yields no grids.
    """
    point_coords = point_features.coordinate_tensor
    point_feats = point_features.feature_tensor

    all_grid_features = [self._sample_from_grid(grid, point_coords) for grid in factor_grid]
    if not all_grid_features:
        raise ValueError("factor_grid must contain at least one grid")
    # Skip the torch.cat copy when there is only one grid.
    if len(all_grid_features) == 1:
        grid_feat_per_point = all_grid_features[0]
    else:
        grid_feat_per_point = torch.cat(all_grid_features, dim=-1)

    features_to_combine = [point_feats, grid_feat_per_point]
    if self.use_rel_pos_embed or self.use_rel_pos:
        # Positions are taken relative to the center of the first grid's bounds.
        bounds_min, bounds_max = factor_grid[0].bounds
        rel_pos = point_coords - ((bounds_max + bounds_min) / 2.0)
        if self.use_rel_pos_embed:
            features_to_combine.append(
                sinusoidal_encoding(rel_pos, self.pos_embed_dim, data_range=2)
            )
        else:
            features_to_combine.append(rel_pos)
    combined_features = torch.cat(features_to_combine, dim=-1)

    output_features = self.combine_mlp(combined_features)
    return Points(
        batched_coordinates=point_features.batched_coordinates,
        batched_features=output_features,
    )
forward(factor_grid: FactorGrid, point_features: Points) -> Points
¶
Convert FactorGrid features to point features using trilinear interpolation.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid, point_features: Points) -> Points:
    """Convert FactorGrid features to point features using trilinear interpolation.

    Every grid in ``factor_grid`` is sampled at the point coordinates. The
    samples are concatenated along the channel axis together with the input
    point features. An optional position term is then appended:

    - with ``use_rel_pos_embed``, a sinusoidal encoding of the position
      relative to the center of the first grid's bounds;
    - otherwise, with ``use_rel_pos``, the raw relative position.

    The concatenation is passed through ``combine_mlp``.

    Parameters
    ----------
    factor_grid : FactorGrid
        Grids to sample from; must yield at least one grid.
    point_features : Points
        Query points that supply coordinates and input features.

    Returns
    -------
    Points
        Points with the input coordinates and the ``combine_mlp`` output features.

    Raises
    ------
    ValueError
        If ``factor_grid`` yields no grids.
    """
    point_coords = point_features.coordinate_tensor
    point_feats = point_features.feature_tensor

    all_grid_features = [self._sample_from_grid(grid, point_coords) for grid in factor_grid]
    if not all_grid_features:
        raise ValueError("factor_grid must contain at least one grid")
    # Skip the torch.cat copy when there is only one grid.
    if len(all_grid_features) == 1:
        grid_feat_per_point = all_grid_features[0]
    else:
        grid_feat_per_point = torch.cat(all_grid_features, dim=-1)

    features_to_combine = [point_feats, grid_feat_per_point]
    if self.use_rel_pos_embed or self.use_rel_pos:
        # Positions are taken relative to the center of the first grid's bounds.
        bounds_min, bounds_max = factor_grid[0].bounds
        rel_pos = point_coords - ((bounds_max + bounds_min) / 2.0)
        if self.use_rel_pos_embed:
            features_to_combine.append(
                sinusoidal_encoding(rel_pos, self.pos_embed_dim, data_range=2)
            )
        else:
            features_to_combine.append(rel_pos)
    combined_features = torch.cat(features_to_combine, dim=-1)

    output_features = self.combine_mlp(combined_features)
    return Points(
        batched_coordinates=point_features.batched_coordinates,
        batched_features=output_features,
    )
FactorGridTransform
¶
Bases: BaseSpatialModule
Apply a transform operation to all grids in a FactorGrid.
This is equivalent to GridFeatureGroupTransform but works with FactorGrid objects.
Source code in warpconvnet/nn/modules/factor_grid.py
class FactorGridTransform(BaseSpatialModule):
    """Run a single transform module over every grid of a ``FactorGrid``.

    This is the ``FactorGrid`` counterpart of ``GridFeatureGroupTransform``.

    Parameters
    ----------
    transform : nn.Module
        Module to apply to each grid.
    in_place : bool, optional
        Passed through to ``factor_grid_transform``. Defaults to ``True``.
    """

    def __init__(self, transform: nn.Module, in_place: bool = True) -> None:
        super().__init__()
        self.transform = transform
        self.in_place = in_place

    def forward(self, factor_grid: FactorGrid) -> FactorGrid:
        """Delegate to ``factor_grid_transform`` with this module's settings."""
        transform = self.transform
        in_place = self.in_place
        return factor_grid_transform(factor_grid, transform, in_place)
forward(factor_grid: FactorGrid) -> FactorGrid
¶
Apply transform to all grids in the FactorGrid.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, factor_grid: FactorGrid) -> FactorGrid:
    """Delegate to ``factor_grid_transform`` using ``self.transform`` and ``self.in_place``."""
    transform = self.transform
    in_place = self.in_place
    return factor_grid_transform(factor_grid, transform, in_place)
PointToFactorGrid
¶
Bases: BaseSpatialModule
Convert point features to FactorGrid representation.
Source code in warpconvnet/nn/modules/factor_grid.py
class PointToFactorGrid(BaseSpatialModule):
    """Convert point features to a ``FactorGrid`` representation.

    Parameters
    ----------
    in_channels : int
        Input feature channels; sizes the projection layers.
    out_channels : int
        Output channels per compressed slice of the projection layers.
    grid_shapes : list of tuple of int
        Shape of each grid, indexed as ``(X, Y, Z)``. Must have the same length
        as ``memory_formats``.
    memory_formats : list of GridMemoryFormat
        Memory format of each grid. Only ``b_xc_y_z``, ``b_yc_x_z`` and
        ``b_zc_x_y`` are supported.
    aabb_max, aabb_min : tuple of float
        Upper and lower corners of the bounding box. They are passed to
        ``points_to_factor_grid`` as ``bounds``.
    use_rel_pos, use_rel_pos_embed, pos_encode_dim
        Stored on the module; not used by ``forward``.
    search_radius, k, search_type, reduction
        Forwarded to ``points_to_factor_grid``.

    Raises
    ------
    ValueError
        If ``grid_shapes`` and ``memory_formats`` differ in length, or if a
        memory format is unsupported.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_shapes: List[Tuple[int, int, int]],
        memory_formats: List[GridMemoryFormat],
        aabb_max: Tuple[float, float, float],
        aabb_min: Tuple[float, float, float],
        use_rel_pos: bool = True,
        use_rel_pos_embed: bool = True,
        pos_encode_dim: int = 32,
        search_radius: Optional[float] = None,
        k: int = 8,
        search_type: Literal["radius", "knn", "voxel"] = "radius",
        reduction: str = "mean",
    ):
        super().__init__()
        # zip() below would silently drop trailing entries on a length mismatch.
        if len(grid_shapes) != len(memory_formats):
            raise ValueError(
                f"grid_shapes and memory_formats must have the same length, "
                f"got {len(grid_shapes)} and {len(memory_formats)}"
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_shapes = grid_shapes
        self.memory_formats = memory_formats
        self.aabb_max = aabb_max
        self.aabb_min = aabb_min
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_embed = use_rel_pos_embed
        self.pos_encode_dim = pos_encode_dim
        self.search_radius = search_radius
        self.k = k
        self.search_type = search_type
        self.reduction = reduction
        # Size of the spatial axis folded into the channel dimension for each
        # grid. grid_shape is indexed as (X, Y, Z).
        self.compressed_spatial_dims = []
        for grid_shape, memory_format in zip(grid_shapes, memory_formats):
            if memory_format == GridMemoryFormat.b_zc_x_y:
                compressed_dim = grid_shape[2]  # Z dimension
            elif memory_format == GridMemoryFormat.b_xc_y_z:
                compressed_dim = grid_shape[0]  # X dimension
            elif memory_format == GridMemoryFormat.b_yc_x_z:
                compressed_dim = grid_shape[1]  # Y dimension
            else:
                raise ValueError(f"Unsupported memory format: {memory_format}")
            self.compressed_spatial_dims.append(compressed_dim)
        # One projection per grid. NOTE(review): these are not applied in
        # forward(). Parameters that never receive gradients may require
        # DistributedDataParallel(find_unused_parameters=True) -- TODO confirm
        # the intended use.
        self.projections = nn.ModuleList()
        for compressed_dim in self.compressed_spatial_dims:
            self.projections.append(nn.Linear(in_channels, out_channels * compressed_dim))

    def forward(self, points: Points) -> FactorGrid:
        """Convert point features to a FactorGrid via ``points_to_factor_grid``."""
        from warpconvnet.geometry.types.conversion.to_factor_grid import points_to_factor_grid

        factor_grid = points_to_factor_grid(
            points=points,
            grid_shapes=self.grid_shapes,
            memory_formats=self.memory_formats,
            bounds=(torch.tensor(self.aabb_min), torch.tensor(self.aabb_max)),
            search_radius=self.search_radius,
            k=self.k,
            search_type=self.search_type,
            reduction=self.reduction,
        )
        # The projections are not applied here; the factor grid is returned as
        # produced by the conversion.
        return factor_grid
forward(points: Points) -> FactorGrid
¶
Convert point features to FactorGrid.
Source code in warpconvnet/nn/modules/factor_grid.py
def forward(self, points: Points) -> FactorGrid:
    """Build a FactorGrid from ``points`` using ``points_to_factor_grid``.

    Projection layers are not applied here; the factor grid is returned exactly
    as the conversion function produces it.
    """
    from warpconvnet.geometry.types.conversion.to_factor_grid import points_to_factor_grid

    lower = torch.tensor(self.aabb_min)
    upper = torch.tensor(self.aabb_max)
    return points_to_factor_grid(
        points=points,
        grid_shapes=self.grid_shapes,
        memory_formats=self.memory_formats,
        bounds=(lower, upper),
        search_radius=self.search_radius,
        k=self.k,
        search_type=self.search_type,
        reduction=self.reduction,
    )
Grid convolution¶
warpconvnet.nn.modules.grid_conv
¶
GridConv
¶
Bases: BaseSpatialModule
Convolutional layer for warpconvnet.geometry.types.grid.Grid data.
Parameters mirror those of torch.nn.Conv3d but operate on a
Grid object instead of plain tensors.
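**Parameters:** `in_channels` (int); `out_channels` (int); `kernel_size` (int or tuple of int); `stride` (default `1`); `padding` (default `0`); `dilation` (default `1`); `bias` (bool, default `True`); `num_spatial_dims` (int, default `3`).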
Source code in warpconvnet/nn/modules/grid_conv.py
class GridConv(BaseSpatialModule):
    """Dense convolution over `warpconvnet.geometry.types.grid.Grid` inputs.

    The constructor arguments follow `torch.nn.Conv3d`, but the layer consumes
    and produces ``Grid`` objects rather than raw tensors.

    Parameters
    ----------
    in_channels : int
        Channels of the input features.
    out_channels : int
        Channels of the output features.
    kernel_size : int or tuple of int
        Kernel extent per spatial dimension.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    padding : int or tuple of int, optional
        Zero-padding applied on each spatial side. Defaults to ``0``.
    dilation : int or tuple of int, optional
        Spacing between kernel taps. Defaults to ``1``.
    bias : bool, optional
        Whether to learn an additive bias. Defaults to ``True``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 0,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        num_spatial_dims: Optional[int] = 3,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Expand scalar hyper-parameters to one value per spatial dimension.
        self.kernel_size = ntuple(kernel_size, ndim=num_spatial_dims)
        self.stride = ntuple(stride, ndim=num_spatial_dims)
        self.padding = ntuple(padding, ndim=num_spatial_dims)
        self.dilation = ntuple(dilation, ndim=num_spatial_dims)
        self.num_spatial_dims = num_spatial_dims
        # Weight layout: (out_channels, in_channels, *kernel_size).
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *self.kernel_size))
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.randn(out_channels))
        self.reset_parameters()

    def __repr__(self):
        fields = [
            f"in_channels={self.in_channels}",
            f"out_channels={self.out_channels}",
            f"kernel_size={self.kernel_size}",
            f"stride={self.stride}",
            f"padding={self.padding}",
            f"dilation={self.dilation}",
            f"bias={self.bias is not None}",
        ]
        return f"{self.__class__.__name__}(" + ", ".join(fields) + ")"

    def reset_parameters(self):
        # Kaiming-uniform weights; bias drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        init.kaiming_uniform_(self.weight, a=1)
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / (fan_in**0.5)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, input_grid: Grid) -> Grid:
        conv_options = dict(
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=self.bias,
        )
        return grid_conv(grid=input_grid, weight=self.weight, **conv_options)
MLP¶
warpconvnet.nn.modules.mlp
¶
BatchedLinear
¶
Bases: Module
A linear layer with batched weights for Muon-friendly optimization.
Instead of a single weight matrix [in_features, out_features * num_matrices], this uses separate weight matrices stacked as [num_matrices, in_features, out_features]. This structure is more suitable for Muon optimization as it can orthogonalize each [in_features, out_features] matrix independently.
Args: `in_features` – input feature dimension; `out_features` – output feature dimension per matrix; `num_matrices` – number of separate matrices (e.g., 3 for Q, K, V); `bias` – whether to use bias parameters.
Source code in warpconvnet/nn/modules/mlp.py
class BatchedLinear(nn.Module):
    """
    A linear layer with batched weights for Muon-friendly optimization.

    Instead of a single weight matrix [in_features, out_features * num_matrices],
    this uses separate weight matrices stacked as [num_matrices, in_features, out_features].
    This structure is more suitable for Muon optimization as it can orthogonalize
    each [in_features, out_features] matrix independently.

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension per matrix
        num_matrices: Number of separate matrices (e.g., 3 for Q, K, V)
        bias: Whether to use bias parameters
    """

    def __init__(
        self, in_features: int, out_features: int, num_matrices: int = 3, bias: bool = True
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_matrices = num_matrices
        # Create batched weight: [num_matrices, in_features, out_features]
        self.weight = nn.Parameter(torch.empty(num_matrices, in_features, out_features))
        if bias:
            # Use flat bias for Muon - 1D parameter gets Adam optimization
            self.bias = nn.Parameter(torch.zeros(num_matrices * out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-uniform initialize each weight matrix independently and zero the bias.

        Calling ``nn.init.xavier_uniform_`` on the whole 3-D weight treats dim 0
        as fan-out, dim 1 as fan-in, and dim 2 as a receptive field. That gives
        fan_in = in_features * out_features and shrinks the bound far below
        the per-matrix value. Instead, each [in_features, out_features] matrix
        is drawn from U(-a, a) with a = sqrt(6 / (in_features + out_features)),
        i.e. Xavier-uniform with gain 1 for that matrix.
        """
        bound = (6.0 / (self.in_features + self.out_features)) ** 0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass with batched matrix multiplication.

        Args:
            input: Input tensor of shape [..., in_features]

        Returns:
            Output tensor of shape [..., num_matrices, out_features]
        """
        # input: [..., in_features], weight: [num_matrices, in_features, out_features]
        # output: [..., num_matrices, out_features]
        output = torch.einsum("...i,kio->...ko", input, self.weight)
        if self.bias is not None:
            output += self.bias.view(self.num_matrices, self.out_features)
        output = output.to(input.dtype)
        return output

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, num_matrices={self.num_matrices}, bias={self.bias is not None}"
forward(input: Tensor) -> Tensor
¶
Forward pass with batched matrix multiplication.
Args: `input` – input tensor of shape `[..., in_features]`.
Returns: output tensor of shape `[..., num_matrices, out_features]`.
Source code in warpconvnet/nn/modules/mlp.py
def forward(self, input: Tensor) -> Tensor:
    """
    Apply every stacked weight matrix to ``input`` in a single einsum.

    Args:
        input: Tensor whose last dimension is in_features.

    Returns:
        Tensor of shape [..., num_matrices, out_features], cast back to
        ``input.dtype``.
    """
    # Contract the feature axis of `input` against each of the k matrices.
    projected = torch.einsum("...i,kio->...ko", input, self.weight)
    if self.bias is None:
        return projected.to(input.dtype)
    # The flat bias holds one out_features-sized chunk per matrix.
    projected += self.bias.view(self.num_matrices, self.out_features)
    return projected.to(input.dtype)
Linear
¶
Bases: BaseSpatialModule
Apply a linear layer to Geometry features.
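**Parameters:** `in_features` (int) – number of input features; `out_features` (int) – number of output features; `bias` (bool, default `True`) – whether to add a bias term.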
Source code in warpconvnet/nn/modules/mlp.py
class Linear(BaseSpatialModule):
    """Project the features of a ``Geometry`` with `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
        Width of the incoming feature tensor.
    out_features : int
        Width of the produced feature tensor.
    bias : bool, optional
        Whether the linear map includes a bias. Defaults to ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.block = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: Geometry):
        projected = self.block(x.feature_tensor)
        return x.replace(batched_features=projected)
LinearNormActivation
¶
Bases: BaseSpatialModule
Linear layer followed by LayerNorm and ReLU.
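**Parameters:** `in_features` (int); `out_features` (int); `bias` (bool, default `True`) – whether the linear layer has a bias term.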
Source code in warpconvnet/nn/modules/mlp.py
class LinearNormActivation(BaseSpatialModule):
    """Apply ``Linear -> LayerNorm -> ReLU`` to ``Geometry`` features.

    Parameters
    ----------
    in_features : int
        Width of the incoming features.
    out_features : int
        Width of the produced features.
    bias : bool, optional
        Whether the linear layer has a bias. Defaults to ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        layers = [
            nn.Linear(in_features, out_features, bias=bias),
            nn.LayerNorm(out_features),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x: Geometry):
        transformed = self.block(x.feature_tensor)
        return x.replace(batched_features=transformed)
MLPBlock
¶
Bases: BaseSpatialModule
MLP block with a residual connection.
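**Parameters:** `in_channels` (int); `out_channels` (int, defaults to `in_channels`); `hidden_channels` (int, defaults to `in_channels`); `activation` (activation module class, default `torch.nn.ReLU`); `bias` (bool, default `True`).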
Source code in warpconvnet/nn/modules/mlp.py
class MLPBlock(BaseSpatialModule):
    """Two-layer MLP with LayerNorm and a residual shortcut.

    The main path is ``Linear -> LayerNorm -> activation -> Linear -> LayerNorm``.
    The shortcut is the identity when input and output widths match, and a
    linear projection otherwise.

    Parameters
    ----------
    in_channels : int
        Width of the input features.
    out_channels : int, optional
        Width of the output features. Falls back to ``in_channels``.
    hidden_channels : int, optional
        Width of the hidden layer. Falls back to ``in_channels``.
    activation : callable returning ``nn.Module``, optional
        Activation class, instantiated with no arguments. Defaults to `torch.nn.ReLU`.
    bias : bool, optional
        Whether the linear layers carry biases. Defaults to ``True``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        activation=nn.ReLU,
        bias: bool = True,
    ):
        super().__init__()
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.block = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=bias),
            nn.LayerNorm(hidden_channels),
            activation(),
            nn.Linear(hidden_channels, out_channels, bias=bias),
            nn.LayerNorm(out_channels),
        )
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_channels, out_channels, bias=bias)

    def _forward_feature(self, x: Tensor) -> Tensor:
        return self.block(x) + self.shortcut(x)

    def forward(self, x: Union[Tensor, Geometry]):
        if not isinstance(x, Geometry):
            return self._forward_feature(x)
        return x.replace(batched_features=self._forward_feature(x.feature_tensor))
Normalizations¶
warpconvnet.nn.modules.normalizations
¶
BatchNorm
¶
Bases: NormalizationBase
Applies torch.nn.BatchNorm1d to Geometry features.
**Parameters:** `num_features` (int); `eps` (float, default `1e-5`); `momentum` (float, default `0.1`).
Source code in warpconvnet/nn/modules/normalizations.py
class BatchNorm(NormalizationBase):
    """Wrap `torch.nn.BatchNorm1d` for ``Geometry`` features.

    Parameters
    ----------
    num_features : int
        Number of feature channels being normalized.
    eps : float, optional
        Stabilizing constant added to the denominator. Defaults to ``1e-5``.
    momentum : float, optional
        Update factor for the running statistics. Defaults to ``0.1``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        batch_norm = nn.BatchNorm1d(num_features, eps=eps, momentum=momentum)
        super().__init__(batch_norm)
GroupNorm
¶
Bases: NormalizationBase
Applies torch.nn.GroupNorm to Geometry features.
**Parameters:** `num_groups` (int); `num_channels` (int); `eps` (float, default `1e-5`).
Source code in warpconvnet/nn/modules/normalizations.py
class GroupNorm(NormalizationBase):
    """Wrap `torch.nn.GroupNorm` for ``Geometry`` features.

    Parameters
    ----------
    num_groups : int
        How many groups the channels are split into.
    num_channels : int
        Number of channels in the input.
    eps : float, optional
        Stabilizing constant added to the denominator. Defaults to ``1e-5``.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        group_norm = nn.GroupNorm(num_groups, num_channels, eps=eps)
        super().__init__(group_norm)
InstanceNorm
¶
Bases: NormalizationBase
Applies torch.nn.InstanceNorm1d to Geometry features.
**Parameters:** `num_features` (int); `eps` (float, default `1e-5`).
Source code in warpconvnet/nn/modules/normalizations.py
class InstanceNorm(NormalizationBase):
    """Wrap `torch.nn.InstanceNorm1d` for ``Geometry`` features.

    Parameters
    ----------
    num_features : int
        Number of feature channels being normalized.
    eps : float, optional
        Stabilizing constant added to the denominator. Defaults to ``1e-5``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        instance_norm = nn.InstanceNorm1d(num_features, eps=eps)
        super().__init__(instance_norm)
LayerNorm
¶
Bases: NormalizationBase
Applies torch.nn.LayerNorm to Geometry features.
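**Parameters:** `normalized_shape` (list of int); `eps` (float, default `1e-5`); `elementwise_affine` (bool, default `True`); `bias` (bool, default `True`).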
Source code in warpconvnet/nn/modules/normalizations.py
class LayerNorm(NormalizationBase):
    """Wrap `torch.nn.LayerNorm` for ``Geometry`` features.

    Parameters
    ----------
    normalized_shape : list of int
        Trailing shape over which statistics are computed.
    eps : float, optional
        Stabilizing constant added to the denominator. Defaults to ``1e-5``.
    elementwise_affine : bool, optional
        Learn a per-element scale (and shift). Defaults to ``True``.
    bias : bool, optional
        Learn the additive shift. Defaults to ``True``.
    """

    def __init__(
        self,
        normalized_shape: List[int],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        layer_norm = nn.LayerNorm(
            normalized_shape, eps=eps, elementwise_affine=elementwise_affine, bias=bias
        )
        super().__init__(layer_norm)
NormalizationBase
¶
Bases: BaseSpatialModule
Wrapper for applying a normalization module to Geometry features.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/normalizations.py
class NormalizationBase(BaseSpatialModule):
    """Adapter that runs a normalization module on ``Geometry`` features.

    ``forward`` accepts either a ``Geometry`` or a plain ``Tensor`` and hands
    it to ``apply_feature_transform`` together with the wrapped module.

    Parameters
    ----------
    norm : ``nn.Module``
        Normalization module used on the feature tensor.
    """

    def __init__(self, norm: nn.Module):
        super().__init__()
        self.norm = norm

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.norm})"

    def forward(self, input: Union[Geometry, Tensor]):
        norm_module = self.norm
        return apply_feature_transform(input, norm_module)
RMSNorm
¶
Bases: NormalizationBase
Applies torch.nn.RMSNorm to Geometry features.
**Parameters:** `dim` (int); `eps` (float, default `1e-6`).
Source code in warpconvnet/nn/modules/normalizations.py
class RMSNorm(NormalizationBase):
    """Wrap `torch.nn.RMSNorm` for ``Geometry`` features.

    Parameters
    ----------
    dim : int
        Number of feature channels.
    eps : float, optional
        Stabilizing constant added to the denominator. Defaults to ``1e-6``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        rms_norm = nn.RMSNorm(dim, eps=eps)
        super().__init__(rms_norm)
SegmentedLayerNorm
¶
Bases: LayerNorm
Layer normalization that respects variable-length segments.
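**Parameters:** `channels` (int); `eps` (float, default `1e-5`); `elementwise_affine` (bool, default `True`); `bias` (bool, default `True`).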
Source code in warpconvnet/nn/modules/normalizations.py
class SegmentedLayerNorm(nn.LayerNorm):
    """Layer normalization that respects variable-length segments.

    The feature tensor and the segment boundaries in ``x.offsets`` are passed
    to ``segmented_layer_norm``, so only ``Geometry`` inputs are supported.

    Parameters
    ----------
    channels : int
        Number of feature channels.
    eps : float, optional
        A value added to the denominator for numerical stability. Defaults to ``1e-5``.
    elementwise_affine : bool, optional
        If ``True`` learn per-channel affine parameters. Defaults to ``True``.
    bias : bool, optional
        Whether to include a bias term. Defaults to ``True``.
    """

    def __init__(
        self,
        channels: int,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        super().__init__([channels], eps=eps, elementwise_affine=elementwise_affine, bias=bias)

    def forward(self, x: Geometry):
        # Only works for geometry with batched features (segments come from x.offsets).
        assert isinstance(
            x, Geometry
        ), f"SegmentedLayerNorm only works for Geometry, got {type(x)}"
        out_feature = segmented_layer_norm(
            x.feature_tensor,
            x.offsets,
            gamma=self.weight if self.elementwise_affine else None,
            beta=self.bias if self.elementwise_affine else None,
            eps=self.eps,
        )
        # The keyword is ``batched_features``, as in every other ``replace`` call.
        return x.replace(batched_features=out_feature)

    def __repr__(self):
        # nn.LayerNorm sets ``weight`` to None when elementwise_affine=False, so
        # fall back to normalized_shape in that case.
        shape = self.weight.shape if self.weight is not None else self.normalized_shape
        return f"{self.__class__.__name__}({shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine})"
Point convolution¶
warpconvnet.nn.modules.point_conv
¶
PointConv
¶
Bases: BaseSpatialModule
Point convolution operating on warpconvnet.geometry.types.points.Points.
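**Parameters:** `in_channels`, `out_channels`, `neighbor_search_args` (required); optional `pooling_reduction`, `pooling_voxel_size`, `edge_transform_mlp`, `out_transform_mlp`, `mlp_block` (default `MLPBlock`), `hidden_dim`, `channel_multiplier` (default `2`), `use_rel_pos`, `use_rel_pos_encode`, `pos_encode_dim` (default `32`), `pos_encode_range` (default `4`), `reductions` (default `("mean",)`), `out_point_type` (`"provided"`, `"downsample"` or `"same"`; default `"same"`), `provided_in_channels`, `bias` (default `True`). See the source docstring below for details.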
Source code in warpconvnet/nn/modules/point_conv.py
class PointConv(BaseSpatialModule):
    """Point convolution operating on `warpconvnet.geometry.types.points.Points`.

    Processing steps for every query point:

    1. Find its neighbors in the input with ``neighbor_search_args``.
    2. For each neighbor, concatenate the neighbor's features, the query
       point's features and, optionally, a relative-position term.
    3. Pass these edge features through ``edge_transform_mlp``.
    4. Reduce over the neighborhood with each entry of ``reductions``,
       concatenate the results, and pass them through ``out_transform_mlp``.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    neighbor_search_args : `warpconvnet.geometry.coords.search.search_configs.RealSearchConfig`
        Configuration for neighbor search.
    pooling_reduction : `warpconvnet.ops.reductions.REDUCTIONS`, optional
        Reduction used when downsampling points. Required when ``out_point_type`` is
        ``"downsample"``.
    pooling_voxel_size : float, optional
        Size of voxels used for downsampling when ``out_point_type`` is ``"downsample"``.
    edge_transform_mlp : nn.Module, optional
        MLP applied to constructed edge features.
    out_transform_mlp : nn.Module, optional
        MLP applied after neighborhood reduction.
    mlp_block : nn.Module, optional
        Module used to build ``edge_transform_mlp`` and ``out_transform_mlp`` when not
        provided. Defaults to `MLPBlock`.
    hidden_dim : int, optional
        Hidden dimension of the automatically created MLPs.
    channel_multiplier : int, optional
        Multiplier used when ``hidden_dim`` is ``None``. Defaults to ``2``.
    use_rel_pos : bool, optional
        Include relative coordinates of neighbors as part of the edge features.
    use_rel_pos_encode : bool, optional
        Include sinusoidal encoding of relative coordinates in the edge features.
        Takes precedence over ``use_rel_pos``.
    pos_encode_dim : int, optional
        Dimension of the positional encoding. Defaults to ``32``.
    pos_encode_range : float, optional
        Range of the positional encoding. Defaults to ``4``.
    reductions : list of str, optional
        Reductions applied over the neighbor dimension. Defaults to ``("mean",)``.
    out_point_type : {"provided", "downsample", "same"}, optional
        Determines the coordinate set on which output features are computed.
    provided_in_channels : int, optional
        Number of channels of the provided query points when ``out_point_type`` is ``"provided"``.
    bias : bool, optional
        Whether linear layers contain biases. Defaults to ``True``.

    Raises
    ------
    ValueError
        If ``out_point_type`` is not one of the supported values, or if
        ``pooling_voxel_size`` exceeds the search radius for a radius search.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        neighbor_search_args: RealSearchConfig,
        pooling_reduction: Optional[REDUCTIONS] = None,
        pooling_voxel_size: Optional[float] = None,
        edge_transform_mlp: Optional[nn.Module] = None,
        out_transform_mlp: Optional[nn.Module] = None,
        mlp_block: nn.Module = MLPBlock,
        hidden_dim: Optional[int] = None,
        channel_multiplier: int = 2,
        use_rel_pos: bool = False,
        use_rel_pos_encode: bool = False,
        pos_encode_dim: int = 32,
        pos_encode_range: float = 4,
        reductions: List[REDUCTION_TYPES_STR] = ("mean",),
        out_point_type: Literal["provided", "downsample", "same"] = "same",
        provided_in_channels: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        assert (
            isinstance(reductions, (tuple, list)) and len(reductions) > 0
        ), f"reductions must be a list or tuple of length > 0, got {reductions}"
        # Check the config type first: the branches below read .mode and .radius.
        assert isinstance(neighbor_search_args, RealSearchConfig)
        if out_point_type == "provided":
            assert pooling_reduction is None
            assert pooling_voxel_size is None
            assert (
                provided_in_channels is not None
            ), "provided_in_channels must be provided for provided type"
        elif out_point_type == "downsample":
            assert (
                pooling_reduction is not None and pooling_voxel_size is not None
            ), "pooling_reduction and pooling_voxel_size must be provided for downsample type"
            assert (
                provided_in_channels is None
            ), "provided_in_channels must be None for downsample type"
            # Warn when the search radius is below the voxel diagonal (sqrt(3) * voxel size).
            if (
                pooling_voxel_size is not None
                and neighbor_search_args.mode == RealSearchMode.RADIUS
                and neighbor_search_args.radius < pooling_voxel_size * (3**0.5)
            ):
                warnings.warn(
                    f"neighbor search radius {neighbor_search_args.radius} is less than sqrt(3) times the downsample voxel size {pooling_voxel_size}",
                    stacklevel=2,
                )
        elif out_point_type == "same":
            assert (
                pooling_reduction is None and pooling_voxel_size is None
            ), "pooling_reduction and pooling_voxel_size must be None for same type"
            assert provided_in_channels is None, "provided_in_channels must be None for same type"
        else:
            # Previously accepted silently; forward() would then fail with query_pc=None.
            raise ValueError(
                f"out_point_type must be one of 'provided', 'downsample', 'same', got {out_point_type!r}"
            )
        if (
            pooling_reduction is not None
            and pooling_voxel_size is not None
            and neighbor_search_args.mode == RealSearchMode.RADIUS
            and pooling_voxel_size > neighbor_search_args.radius
        ):
            raise ValueError(
                f"downsample_voxel_size {pooling_voxel_size} must be <= radius {neighbor_search_args.radius}"
            )
        self.reductions = reductions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_encode = use_rel_pos_encode
        self.out_point_feature_type = out_point_type
        self.neighbor_search_args = neighbor_search_args
        self.pooling_reduction = pooling_reduction
        self.pooling_voxel_size = pooling_voxel_size
        self.positional_encoding = SinusoidalEncoding(pos_encode_dim, data_range=pos_encode_range)
        # Without an explicit width ("downsample"/"same"), query features are
        # assumed to have in_channels channels.
        if provided_in_channels is None:
            provided_in_channels = in_channels
        if hidden_dim is None:
            hidden_dim = channel_multiplier * max(out_channels, in_channels)
        if edge_transform_mlp is None:
            # Edge features = neighbor features + query features (+ position term).
            edge_in_channels = in_channels + provided_in_channels
            if use_rel_pos_encode:
                edge_in_channels += pos_encode_dim * 3
            elif use_rel_pos:
                edge_in_channels += 3
            edge_transform_mlp = mlp_block(
                in_channels=edge_in_channels,
                out_channels=out_channels,
                hidden_channels=hidden_dim,
                bias=bias,
            )
        self.edge_transform_mlp = edge_transform_mlp
        self.edge_mlp_in_channels = _get_module_input_channel(edge_transform_mlp)
        if out_transform_mlp is None:
            # One out_channels-wide block per reduction, concatenated.
            out_transform_mlp = mlp_block(
                in_channels=out_channels * len(reductions),
                out_channels=out_channels,
                hidden_channels=hidden_dim,
                bias=bias,
            )
        self.out_transform_mlp = out_transform_mlp

    def __repr__(self):
        out_str = f"{self.__class__.__name__}(in_channels={self.in_channels} out_channels={self.out_channels}"
        if self.use_rel_pos_encode:
            out_str += f" rel_pos_encode={self.use_rel_pos_encode}"
        if self.pooling_reduction is not None:
            out_str += f" pooling={self.pooling_reduction}"
        if self.neighbor_search_args is not None:
            out_str += f" neighbor={self.neighbor_search_args}"
        out_str += ")"
        return out_str

    def forward(
        self,
        in_pc: Points,
        query_pc: Optional[Points] = None,
    ) -> Points:
        """Compute output features on the query points selected by ``out_point_type``.

        - ``"provided"``: ``query_pc`` is required and supplies the output points.
        - ``"downsample"``: ``query_pc`` must be None; the output points are
          ``in_pc.voxel_downsample(pooling_voxel_size, reduction=pooling_reduction)``.
        - ``"same"``: ``query_pc`` must be None; the output points are ``in_pc``.

        The returned ``Points`` carry the query coordinates, offsets and extra
        attributes.
        """
        if self.out_point_feature_type == "provided":
            assert (
                query_pc is not None
            ), "query_point_features must be provided for the provided type"
        elif self.out_point_feature_type == "downsample":
            assert query_pc is None
            query_pc = in_pc.voxel_downsample(
                self.pooling_voxel_size,
                reduction=self.pooling_reduction,
            )
        elif self.out_point_feature_type == "same":
            assert query_pc is None
            query_pc = in_pc
        in_num_channels = in_pc.num_channels
        query_num_channels = query_pc.num_channels
        assert (
            in_num_channels
            + query_num_channels
            + self.use_rel_pos_encode * self.positional_encoding.num_channels * 3
            + (not self.use_rel_pos_encode) * self.use_rel_pos * 3
            == self.edge_mlp_in_channels
        ), f"input features shape {in_pc.feature_tensor.shape} and query feature shape {query_pc.feature_tensor.shape} does not match the edge_transform_mlp input channels {self.edge_mlp_in_channels}"
        # Get the neighbors
        neighbors = in_pc.neighbors(
            query_coords=query_pc.batched_coordinates,
            search_args=self.neighbor_search_args,
        )
        neighbor_indices = neighbors.neighbor_indices.long().view(-1)
        neighbor_row_splits = neighbors.neighbor_row_splits
        # Number of neighbors of each query point.
        num_reps = neighbor_row_splits[1:] - neighbor_row_splits[:-1]
        # Gather neighbor features and repeat each query point's own features once per neighbor.
        rep_in_features = in_pc.feature_tensor[neighbor_indices]
        self_features = torch.repeat_interleave(
            query_pc.feature_tensor.view(-1, query_num_channels).contiguous(),
            num_reps,
            dim=0,
        )
        edge_features = [rep_in_features, self_features]
        if self.use_rel_pos or self.use_rel_pos_encode:
            in_rep_vertices = in_pc.coordinate_tensor.view(-1, 3)[neighbor_indices]
            self_vertices = torch.repeat_interleave(
                query_pc.coordinate_tensor.view(-1, 3).contiguous(),
                num_reps,
                dim=0,
            )
            rel_coords = in_rep_vertices.view(-1, 3) - self_vertices.view(-1, 3)
            if self.use_rel_pos_encode:
                edge_features.append(self.positional_encoding(rel_coords))
            elif self.use_rel_pos:
                edge_features.append(rel_coords)
        edge_features = torch.cat(edge_features, dim=1)
        edge_features = self.edge_transform_mlp(edge_features)
        out_features = []
        for reduction in self.reductions:
            out_features.append(
                row_reduction(edge_features, neighbor_row_splits, reduction=reduction)
            )
        out_features = torch.cat(out_features, dim=-1)
        out_features = self.out_transform_mlp(out_features)
        return Points(
            batched_coordinates=Coords(
                batched_tensor=query_pc.coordinate_tensor,
                offsets=query_pc.offsets,
            ),
            batched_features=out_features,
            **query_pc.extra_attributes,
        )
forward(in_pc: Points, query_pc: Optional[Points] = None) -> Points
¶
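Computes output features on the query points selected by `out_point_type`: for "provided", `query_pc` must be given and supplies the output points; for "downsample", `query_pc` must be None and the output points are a voxel-downsampled copy of `in_pc`; for "same", `query_pc` must be None and the output points are `in_pc` itself.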
Source code in warpconvnet/nn/modules/point_conv.py
def forward(
    self,
    in_pc: Points,
    query_pc: Optional[Points] = None,
) -> Points:
    """Aggregate edge features from neighbors in ``in_pc`` onto a set of query points.

    The query points are chosen by ``self.out_point_feature_type``:

    * ``"provided"``: ``query_pc`` must be given and is used as-is.
    * ``"downsample"``: ``query_pc`` must be ``None``; the query points are
      ``in_pc.voxel_downsample(self.pooling_voxel_size, reduction=self.pooling_reduction)``.
    * ``"same"``: ``query_pc`` must be ``None``; ``in_pc`` itself is the query set.

    For every (query point, neighbor) pair returned by ``in_pc.neighbors``, the edge
    feature is the concatenation of:

    * the neighbor's feature,
    * the query point's feature,
    * optionally, the relative position ``neighbor - query``.

    The relative position is encoded with ``self.positional_encoding`` when
    ``use_rel_pos_encode`` is set, and used raw when only ``use_rel_pos`` is set.
    Edge features go through ``edge_transform_mlp``. They are then reduced per query
    point with each reduction in ``self.reductions``, the results are concatenated
    along the channel axis, and the concatenation goes through ``out_transform_mlp``.

    Parameters
    ----------
    in_pc : Points
        Source points whose features are gathered from the neighbors.
    query_pc : Points, optional
        Query points; only required when ``out_point_feature_type == "provided"``.

    Returns
    -------
    Points
        Points on the query coordinates/offsets, carrying the output features and the
        query's extra attributes.

    Raises
    ------
    ValueError
        If ``out_point_feature_type`` is not a recognized value and ``query_pc`` is
        ``None``, so no query points can be determined.
    """
    if self.out_point_feature_type == "provided":
        assert (
            query_pc is not None
        ), "query_point_features must be provided for the provided type"
    elif self.out_point_feature_type == "downsample":
        assert query_pc is None
        query_pc = in_pc.voxel_downsample(
            self.pooling_voxel_size,
            reduction=self.pooling_reduction,
        )
    elif self.out_point_feature_type == "same":
        assert query_pc is None
        query_pc = in_pc
    if query_pc is None:
        # Fail with a descriptive error instead of an AttributeError on ``query_pc`` below.
        raise ValueError(
            f"Cannot determine query points: out_point_feature_type="
            f"{self.out_point_feature_type!r} and no query_pc was given"
        )

    in_num_channels = in_pc.num_channels
    query_num_channels = query_pc.num_channels
    # Edge MLP input = [neighbor features | query features | optional relative position],
    # where the relative position is either encoded or the raw 3 coordinates.
    assert (
        in_num_channels
        + query_num_channels
        + self.use_rel_pos_encode * self.positional_encoding.num_channels * 3
        + (not self.use_rel_pos_encode) * self.use_rel_pos * 3
        == self.edge_mlp_in_channels
    ), f"input features shape {in_pc.feature_tensor.shape} and query feature shape {query_pc.feature_tensor.shape} does not match the edge_transform_mlp input channels {self.edge_mlp_in_channels}"

    # Get the neighbors of every query point in in_pc.
    neighbors = in_pc.neighbors(
        query_coords=query_pc.batched_coordinates,
        search_args=self.neighbor_search_args,
    )
    neighbor_indices = neighbors.neighbor_indices.long().view(-1)
    neighbor_row_splits = neighbors.neighbor_row_splits
    # Number of neighbors per query point, used to repeat the query-side tensors.
    num_reps = neighbor_row_splits[1:] - neighbor_row_splits[:-1]

    rep_in_features = in_pc.feature_tensor[neighbor_indices]
    self_features = torch.repeat_interleave(
        query_pc.feature_tensor.view(-1, query_num_channels).contiguous(),
        num_reps,
        dim=0,
    )
    edge_features = [rep_in_features, self_features]
    if self.use_rel_pos or self.use_rel_pos_encode:
        in_rep_vertices = in_pc.coordinate_tensor.view(-1, 3)[neighbor_indices]
        self_vertices = torch.repeat_interleave(
            query_pc.coordinate_tensor.view(-1, 3).contiguous(),
            num_reps,
            dim=0,
        )
        rel_coords = in_rep_vertices.view(-1, 3) - self_vertices.view(-1, 3)
        if self.use_rel_pos_encode:
            edge_features.append(self.positional_encoding(rel_coords))
        elif self.use_rel_pos:
            edge_features.append(rel_coords)
    edge_features = torch.cat(edge_features, dim=1)
    edge_features = self.edge_transform_mlp(edge_features)

    # Reduce edges per query point with every configured reduction, then concatenate.
    out_features = []
    for reduction in self.reductions:
        out_features.append(
            row_reduction(edge_features, neighbor_row_splits, reduction=reduction)
        )
    out_features = torch.cat(out_features, dim=-1)
    out_features = self.out_transform_mlp(out_features)
    return Points(
        batched_coordinates=Coords(
            batched_tensor=query_pc.coordinate_tensor,
            offsets=query_pc.offsets,
        ),
        batched_features=out_features,
        **query_pc.extra_attributes,
    )
Point pooling¶
warpconvnet.nn.modules.point_pool
¶
PointAvgPool
¶
Bases: PointPoolBase
Point pooling using mean reduction.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/point_pool.py
class PointAvgPool(PointPoolBase):
    """Point pooling that merges features with the ``mean`` reduction.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Upper bound on the number of points kept after downsampling.
    downsample_voxel_size : float, optional
        Voxel size used for downsampling.
    return_type : {"point", "sparse"}, optional
        Geometry type of the result. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        When ``True`` the neighbor search result is returned as well. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        pool_kwargs = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.MEAN, **pool_kwargs)
PointMaxPool
¶
Bases: PointPoolBase
Point pooling using max reduction.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/point_pool.py
class PointMaxPool(PointPoolBase):
    """Point pooling that merges features with the ``max`` reduction.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Upper bound on the number of points kept after downsampling.
    downsample_voxel_size : float, optional
        Voxel size used for downsampling.
    return_type : {"point", "sparse"}, optional
        Geometry type of the result. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        When ``True`` the neighbor search result is returned as well. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        pool_kwargs = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.MAX, **pool_kwargs)
PointPoolBase
¶
Bases: BaseSpatialModule
Base module for pooling points or voxels.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/point_pool.py
class PointPoolBase(BaseSpatialModule):
    """Shared implementation for pooling points or voxels via ``point_pool``.

    Parameters
    ----------
    reduction : str or `REDUCTIONS`, optional
        How features that are merged together are combined. Strings are converted to
        ``REDUCTIONS``. Defaults to ``REDUCTIONS.MAX``.
    downsample_max_num_points : int, optional
        Upper bound on the number of points kept after downsampling.
    downsample_voxel_size : float, optional
        Voxel size used for downsampling.
    return_type : {"point", "sparse"}, optional
        Geometry type of the result. Defaults to ``"point"``.
    unique_method : {"torch", "ravel", "morton"}, optional
        Method used to find unique voxel indices. Defaults to ``"torch"``.
    avereage_pooled_coordinates : bool, optional
        When ``True`` the coordinates of the points in each voxel are averaged.
        Defaults to ``False``. (The parameter name's spelling is part of the public API.)
    return_neighbor_search_result : bool, optional
        When ``True`` the neighbor search result is returned as well. Defaults to ``False``.
    """

    def __init__(
        self,
        reduction: Union[str, REDUCTIONS] = REDUCTIONS.MAX,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        unique_method: Literal["torch", "ravel", "morton"] = "torch",
        avereage_pooled_coordinates: bool = False,
        return_neighbor_search_result: bool = False,
    ):
        super().__init__()
        self.reduction = REDUCTIONS(reduction) if isinstance(reduction, str) else reduction
        self.downsample_max_num_points = downsample_max_num_points
        self.downsample_voxel_size = downsample_voxel_size
        self.return_type = return_type
        self.return_neighbor_search_result = return_neighbor_search_result
        self.unique_method = unique_method
        self.avereage_pooled_coordinates = avereage_pooled_coordinates

    def forward(self, pc: Points) -> Union[Geometry, Tuple[Geometry, RealSearchResult]]:
        # Every configuration attribute is forwarded verbatim to the functional op.
        options = {
            "reduction": self.reduction,
            "downsample_max_num_points": self.downsample_max_num_points,
            "downsample_voxel_size": self.downsample_voxel_size,
            "return_type": self.return_type,
            "return_neighbor_search_result": self.return_neighbor_search_result,
            "unique_method": self.unique_method,
            "avereage_pooled_coordinates": self.avereage_pooled_coordinates,
        }
        return point_pool(pc=pc, **options)
PointSumPool
¶
Bases: PointPoolBase
Point pooling using sum reduction.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/point_pool.py
class PointSumPool(PointPoolBase):
    """Point pooling that merges features with the ``sum`` reduction.

    Parameters
    ----------
    downsample_max_num_points : int, optional
        Upper bound on the number of points kept after downsampling.
    downsample_voxel_size : float, optional
        Voxel size used for downsampling.
    return_type : {"point", "sparse"}, optional
        Geometry type of the result. Defaults to ``"point"``.
    return_neighbor_search_result : bool, optional
        When ``True`` the neighbor search result is returned as well. Defaults to ``False``.
    """

    def __init__(
        self,
        downsample_max_num_points: Optional[int] = None,
        downsample_voxel_size: Optional[float] = None,
        return_type: Literal["point", "sparse"] = "point",
        return_neighbor_search_result: bool = False,
    ):
        pool_kwargs = dict(
            downsample_max_num_points=downsample_max_num_points,
            downsample_voxel_size=downsample_voxel_size,
            return_type=return_type,
            return_neighbor_search_result=return_neighbor_search_result,
        )
        super().__init__(reduction=REDUCTIONS.SUM, **pool_kwargs)
PointUnpool
¶
Bases: BaseSpatialModule
Undo point pooling by scattering pooled features back to the input cloud.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/point_pool.py
class PointUnpool(BaseSpatialModule):
    """Scatter pooled features back onto the original (unpooled) point cloud.

    Parameters
    ----------
    unpooling_mode : {"repeat", "interpolate"} or `FEATURE_UNPOOLING_MODE`, optional
        Unpooling strategy; strings are converted to ``FEATURE_UNPOOLING_MODE``.
        Defaults to ``FEATURE_UNPOOLING_MODE.REPEAT``.
    concat_unpooled_pc : bool, optional
        When ``True`` the unpooled point cloud is concatenated with the input.
        Defaults to ``False``.
    """

    def __init__(
        self,
        unpooling_mode: Union[str, FEATURE_UNPOOLING_MODE] = FEATURE_UNPOOLING_MODE.REPEAT,
        concat_unpooled_pc: bool = False,
    ):
        super().__init__()
        self.unpooling_mode = (
            FEATURE_UNPOOLING_MODE(unpooling_mode)
            if isinstance(unpooling_mode, str)
            else unpooling_mode
        )
        self.concat_unpooled_pc = concat_unpooled_pc

    def forward(self, pooled_pc: Points, unpooled_pc: Points):
        mode = self.unpooling_mode
        concat = self.concat_unpooled_pc
        return point_unpool(
            pooled_pc=pooled_pc,
            unpooled_pc=unpooled_pc,
            unpooling_mode=mode,
            concat_unpooled_pc=concat,
        )
Prune¶
warpconvnet.nn.modules.prune
¶
SparsePrune
¶
Bases: BaseSpatialModule
Module wrapper around prune_spatially_sparse_tensor so pruning can be composed in nn.Sequential.
Forward Args
spatial_tensor : Geometry
Sparse geometry (e.g., Voxels) whose coordinates/features will be filtered.
mask : Bool[Tensor, "N"]
Boolean mask aligned with spatial_tensor.coordinate_tensor.
Source code in warpconvnet/nn/modules/prune.py
class SparsePrune(BaseSpatialModule):
    """
    Prune a sparse geometry with a boolean mask.

    Thin module wrapper around ``prune_spatially_sparse_tensor`` so pruning can be
    composed inside ``nn.Sequential``.

    Forward Args
    ------------
    spatial_tensor : Geometry
        Sparse geometry (e.g., Voxels) whose coordinates/features are filtered.
    mask : Bool[Tensor, "N"]
        Boolean mask aligned with ``spatial_tensor.coordinate_tensor``.
    """

    def forward(self, spatial_tensor: Geometry, mask: Bool[Tensor, "N"]) -> Geometry:  # noqa: F821
        pruned = prune_spatially_sparse_tensor(spatial_tensor, mask)
        return pruned
Sequential¶
warpconvnet.nn.modules.sequential
¶
GeometryWrapper
¶
Bases: BaseSpatialModule
Wrapper for a spatial module that returns a geometry object.
Source code in warpconvnet/nn/modules/sequential.py
class GeometryWrapper(BaseSpatialModule):
    """Apply a tensor-to-tensor callable to a geometry's features.

    ``forward`` runs ``module`` on ``x.feature_tensor`` and returns ``x`` with its
    batched features replaced by the result.
    """

    def __init__(self, module: Callable[[Tensor], Tensor]):
        super().__init__()
        self.module = module

    def forward(self, x: Geometry) -> Geometry:
        new_features = self.module(x.feature_tensor)
        return x.replace(batched_features=new_features)
Sequential
¶
Bases: Sequential, BaseSpatialModule
Sequential module that allows for spatial and non-spatial layers to be chained together.
Consecutive non-spatial layers run directly on the feature tensor, without creating an intermediate spatial features object between them, which makes the chain more efficient.
Source code in warpconvnet/nn/modules/sequential.py
class Sequential(nn.Sequential, BaseSpatialModule):
    """
    Sequential module that allows spatial and non-spatial layers to be chained together.

    Each layer is applied through ``run_forward``, which threads both the running output
    ``x`` and the most recent geometry ``in_sf``. Consecutive non-spatial layers can
    therefore work on plain feature tensors without creating an intermediate spatial
    features object between them, which is more efficient. If the final output is a
    tensor, it is wrapped back into ``in_sf`` via ``replace(batched_features=...)``.
    """

    def forward(self, x: Geometry):
        # The message names the type that is actually checked (``Geometry``).
        assert isinstance(x, Geometry), f"Expected Geometry, got {type(x)}"
        in_sf = x
        for module in self:
            x, in_sf = run_forward(module, x, in_sf)
        # Re-wrap a trailing tensor output into the last known geometry.
        if isinstance(x, torch.Tensor):
            x = in_sf.replace(batched_features=x)
        return x
TupleSequential
¶
Bases: Sequential, BaseSpatialModule
Sequential module that allows multiple inputs for a specified layer.
Source code in warpconvnet/nn/modules/sequential.py
class TupleSequential(Sequential, BaseSpatialModule):
    """
    Sequential module whose ``tuple_layer``-th layer receives multiple inputs.

    Every layer except ``tuple_layer`` is applied to the running output, as in
    ``Sequential``. The layer at index ``tuple_layer`` is applied through
    ``tuple_run_forward`` with the running output followed by the extra positional
    inputs passed to ``forward`` (``xs[1:]``).

    Parameters
    ----------
    *args
        The layers, either as separate arguments or as a single sequence.
    tuple_layer : int
        Index of the layer that receives the extra inputs.
    """

    def __init__(self, *args, tuple_layer: int):
        # Accept both TupleSequential(a, b, c, ...) and TupleSequential([a, b, c], ...).
        layers = args[0] if len(args) == 1 and isinstance(args[0], Sequence) else args
        super().__init__(*layers)
        self.tuple_layer = tuple_layer

    def forward(self, *xs: Geometry):
        if not xs:
            raise TypeError("TupleSequential.forward() requires at least one input geometry")
        x = xs[0]
        extra_inputs = xs[1:]
        in_sf = x
        for i, module in enumerate(self):
            if i == self.tuple_layer:
                x, in_sf = tuple_run_forward(module, (x, *extra_inputs), in_sf)
            else:
                x, in_sf = run_forward(module, x, in_sf)
        # Re-wrap a trailing tensor output into the last known geometry.
        if isinstance(x, torch.Tensor):
            x = in_sf.replace(batched_features=x)
        return x
Sparse convolution¶
warpconvnet.nn.modules.sparse_conv
¶
SparseConv2d
¶
Bases: SpatiallySparseConv
2D sparse convolution.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/sparse_conv.py
class SparseConv2d(SpatiallySparseConv):
    """Sparse convolution over 2D voxels (``SpatiallySparseConv`` with ``num_spatial_dims=2``).

    Parameters
    ----------
    in_channels : int
        Channels of the input features.
    out_channels : int
        Channels of the output features.
    kernel_size : int or tuple of int
        Kernel extent per spatial dimension.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Kernel dilation. Defaults to ``1``.
    bias : bool, optional
        Whether a learnable bias is added. Defaults to ``True``.
    transposed : bool, optional
        Run a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Run a generative convolution. Defaults to ``False``.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How ``stride`` is interpreted when ``transposed`` is ``True``.
    fwd_algo : `SPARSE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm.
    bwd_algo : `SPARSE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm.
    kernel_matmul_batch_size : int, optional
        Batch size for implicit matrix multiplications. Defaults to ``2``.
    order : `POINT_ORDERING`, optional
        Output point ordering. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type for intermediate computations.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        transposed=False,
        generative: bool = False,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        fwd_algo: Optional[Union[SPARSE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_CONV_BWD_ALGO_MODE, str]] = None,
        kernel_matmul_batch_size: int = 2,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        # Execution-related options are forwarded unchanged to the base class.
        conv_options = dict(
            stride_mode=stride_mode,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            kernel_matmul_batch_size=kernel_matmul_batch_size,
            order=order,
            compute_dtype=compute_dtype,
            implicit_matmul_fwd_block_size=implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=implicit_matmul_bwd_block_size,
        )
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            num_spatial_dims=2,
            **conv_options,
        )
SparseConv3d
¶
Bases: SpatiallySparseConv
3D sparse convolution.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/sparse_conv.py
class SparseConv3d(SpatiallySparseConv):
    """Sparse convolution over 3D voxels (``SpatiallySparseConv`` with ``num_spatial_dims=3``).

    Parameters
    ----------
    in_channels : int
        Channels of the input features.
    out_channels : int
        Channels of the output features.
    kernel_size : int or tuple of int
        Kernel extent per spatial dimension.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Kernel dilation. Defaults to ``1``.
    bias : bool, optional
        Whether a learnable bias is added. Defaults to ``True``.
    transposed : bool, optional
        Run a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Run a generative convolution. Defaults to ``False``.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How ``stride`` is interpreted when ``transposed`` is ``True``.
    fwd_algo : `SPARSE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm.
    bwd_algo : `SPARSE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm.
    kernel_matmul_batch_size : int, optional
        Batch size for implicit matrix multiplications. Defaults to ``2``.
    order : `POINT_ORDERING`, optional
        Output point ordering. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type for intermediate computations.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        bias=True,
        transposed=False,
        generative: bool = False,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        fwd_algo: Optional[Union[SPARSE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_CONV_BWD_ALGO_MODE, str]] = None,
        kernel_matmul_batch_size: int = 2,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        # Execution-related options are forwarded unchanged to the base class.
        conv_options = dict(
            stride_mode=stride_mode,
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            kernel_matmul_batch_size=kernel_matmul_batch_size,
            order=order,
            compute_dtype=compute_dtype,
            implicit_matmul_fwd_block_size=implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=implicit_matmul_bwd_block_size,
        )
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            num_spatial_dims=3,
            **conv_options,
        )
SpatiallySparseConv
¶
Bases: BaseSpatialModule
Sparse convolution layer for warpconvnet.geometry.types.voxels.Voxels.
Parameters: see the Parameters section of the source docstring below.
Source code in warpconvnet/nn/modules/sparse_conv.py
class SpatiallySparseConv(BaseSpatialModule):
    """Sparse convolution layer for `warpconvnet.geometry.types.voxels.Voxels`.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    out_channels : int
        Number of output feature channels.
    kernel_size : int or tuple of int
        Size of the convolution kernel.
    stride : int or tuple of int, optional
        Convolution stride. Defaults to ``1``.
    dilation : int or tuple of int, optional
        Spacing between kernel elements. Defaults to ``1``.
    bias : bool, optional
        If ``True`` adds a learnable bias to the output. Defaults to ``True``.
    transposed : bool, optional
        Perform a transposed convolution. Defaults to ``False``.
    generative : bool, optional
        Use generative convolution. Defaults to ``False``.
    kernel_matmul_batch_size : int, optional
        Batch size used for implicit matrix multiplications. Defaults to ``2``.
    num_spatial_dims : int, optional
        Number of spatial dimensions. Defaults to ``3``.
    fwd_algo : `SPARSE_CONV_FWD_ALGO_MODE` or str, optional
        Forward algorithm to use. Defaults to environment setting.
    bwd_algo : `SPARSE_CONV_BWD_ALGO_MODE` or str, optional
        Backward algorithm to use. Defaults to environment setting.
    stride_mode : `STRIDED_CONV_MODE`, optional
        How to interpret ``stride`` when ``transposed`` is ``True``.
    order : `POINT_ORDERING`, optional
        Ordering of points in the output. Defaults to ``POINT_ORDERING.RANDOM``.
    compute_dtype : torch.dtype, optional
        Data type used for intermediate computations.
    implicit_matmul_fwd_block_size : int, optional
        CUDA block size for implicit forward matmuls.
    implicit_matmul_bwd_block_size : int, optional
        CUDA block size for implicit backward matmuls.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        kernel_matmul_batch_size: int = 2,
        num_spatial_dims: Optional[int] = 3,
        fwd_algo: Optional[Union[SPARSE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
        implicit_matmul_fwd_block_size: Optional[int] = None,
        implicit_matmul_bwd_block_size: Optional[int] = None,
    ):
        super().__init__()
        self.num_spatial_dims = num_spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Ensure kernel_size, stride, dilation are per-dimension tuples for consistent use.
        _kernel_size = ntuple(kernel_size, ndim=self.num_spatial_dims)
        self.kernel_size = _kernel_size
        self.stride = ntuple(stride, ndim=self.num_spatial_dims)
        self.dilation = ntuple(dilation, ndim=self.num_spatial_dims)
        self.transposed = transposed
        self.generative = generative
        self.kernel_matmul_batch_size = kernel_matmul_batch_size

        # Use environment variable values if not explicitly provided.
        if fwd_algo is None:
            fwd_algo = WARPCONVNET_FWD_ALGO_MODE
        if bwd_algo is None:
            bwd_algo = WARPCONVNET_BWD_ALGO_MODE
        # Convert strings to enums; anything else (enum members, lists from env vars or
        # direct user input) is passed through unchanged to the functional layer.
        if isinstance(fwd_algo, str):
            self.fwd_algo = SPARSE_CONV_FWD_ALGO_MODE(fwd_algo)
        else:
            self.fwd_algo = fwd_algo
        if isinstance(bwd_algo, str):
            self.bwd_algo = SPARSE_CONV_BWD_ALGO_MODE(bwd_algo)
        else:
            self.bwd_algo = bwd_algo

        self.stride_mode = stride_mode
        self.order = order
        self.compute_dtype = compute_dtype
        self.implicit_matmul_fwd_block_size = implicit_matmul_fwd_block_size
        self.implicit_matmul_bwd_block_size = implicit_matmul_bwd_block_size

        # Weight layout: (kernel volume, in_channels, out_channels).
        self.weight = nn.Parameter(torch.randn(np.prod(_kernel_size), in_channels, out_channels))
        self.bias: Optional[nn.Parameter] = (
            nn.Parameter(torch.randn(out_channels)) if bias else None
        )
        # Replace the random placeholders with the initialization scheme below.
        self.reset_parameters()

    def __repr__(self):
        # Show the required arguments plus the optional ones that differ from their
        # defaults. stride/dilation are stored as per-dimension tuples, so they are
        # compared with a tuple of ones; comparing with the scalar 1 is always unequal.
        ones = (1,) * self.num_spatial_dims
        out_str = (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, kernel_size={self.kernel_size}"
        )
        if tuple(self.stride) != ones:
            out_str += f", stride={self.stride}"
        if tuple(self.dilation) != ones:
            out_str += f", dilation={self.dilation}"
        if self.transposed:
            out_str += f", transposed={self.transposed}"
        if self.generative:
            out_str += f", generative={self.generative}"
        if self.order != POINT_ORDERING.RANDOM:
            out_str += f", order={self.order}"
        if self.bias is None:
            out_str += ", bias=False"
        out_str += ")"
        return out_str

    def _calculate_fan_in_and_fan_out(self):
        """Return ``(fan_in, fan_out)``: channel counts times the kernel volume."""
        receptive_field_size = np.prod(self.kernel_size)
        fan_in = self.in_channels * receptive_field_size
        fan_out = self.out_channels * receptive_field_size
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode: Literal["fan_in", "fan_out"]):
        """Return the fan selected by ``mode`` (case-insensitive)."""
        mode = mode.lower()
        assert mode in ["fan_in", "fan_out"]
        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == "fan_in" else fan_out

    def _custom_kaiming_uniform_(self, tensor, a=0, mode="fan_in", nonlinearity="leaky_relu"):
        """Fill ``tensor`` in place from U(-bound, bound) and return it.

        ``bound = sqrt(num_spatial_dims) * calculate_gain(nonlinearity, a) / sqrt(fan)``.
        """
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        # NOTE(review): torch.nn.init.kaiming_uniform_ uses sqrt(3) * std; the
        # sqrt(num_spatial_dims) factor coincides with it only for 3-D kernels — confirm.
        bound = math.sqrt(self.num_spatial_dims) * std
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialize ``weight`` (and ``bias`` when present) in place."""
        self._custom_kaiming_uniform_(
            self.weight,
            a=math.sqrt(5),
            mode="fan_out" if self.transposed else "fan_in",
        )
        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        input_sparse_tensor: Voxels,
        output_spatially_sparse_tensor: Optional[Voxels] = None,
    ):
        """Apply the convolution via the functional ``spatially_sparse_conv``."""
        return spatially_sparse_conv(
            input_sparse_tensor=input_sparse_tensor,
            weight=self.weight,
            kernel_size=self.kernel_size,
            stride=self.stride,
            kernel_dilation=self.dilation,
            bias=self.bias,
            kernel_matmul_batch_size=self.kernel_matmul_batch_size,
            output_spatially_sparse_tensor=output_spatially_sparse_tensor,
            transposed=self.transposed,
            generative=self.generative,
            fwd_algo=self.fwd_algo,
            bwd_algo=self.bwd_algo,
            stride_mode=self.stride_mode,
            order=self.order,
            compute_dtype=self.compute_dtype,
            implicit_matmul_fwd_block_size=self.implicit_matmul_fwd_block_size,
            implicit_matmul_bwd_block_size=self.implicit_matmul_bwd_block_size,
        )
Sparse depthwise convolution¶
warpconvnet.nn.modules.sparse_conv_depth
¶
SparseDepthwiseConv2d
¶
Bases: SpatiallySparseDepthwiseConv
2D spatially sparse depthwise convolution.
Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SparseDepthwiseConv2d(SpatiallySparseDepthwiseConv):
    """Depthwise convolution over 2D sparse voxels.

    Fixes ``num_spatial_dims=2`` and forwards every other argument unchanged to
    ``SpatiallySparseDepthwiseConv``.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        depthwise_options = dict(
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            stride_mode=stride_mode,
            stride_reduce=stride_reduce,
            order=order,
            compute_dtype=compute_dtype,
        )
        super().__init__(
            channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            num_spatial_dims=2,
            **depthwise_options,
        )
SparseDepthwiseConv3d
¶
Bases: SpatiallySparseDepthwiseConv
3D spatially sparse depthwise convolution.
Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SparseDepthwiseConv3d(SpatiallySparseDepthwiseConv):
    """Depthwise convolution over 3D sparse voxels.

    Fixes ``num_spatial_dims=3`` and forwards every other argument unchanged to
    ``SpatiallySparseDepthwiseConv``.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        bias: bool = True,
        transposed: bool = False,
        generative: bool = False,
        fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
        bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
        stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
        stride_reduce: str = "max",
        order: POINT_ORDERING = POINT_ORDERING.RANDOM,
        compute_dtype: Optional[torch.dtype] = None,
    ):
        depthwise_options = dict(
            fwd_algo=fwd_algo,
            bwd_algo=bwd_algo,
            stride_mode=stride_mode,
            stride_reduce=stride_reduce,
            order=order,
            compute_dtype=compute_dtype,
        )
        super().__init__(
            channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias,
            transposed=transposed,
            generative=generative,
            num_spatial_dims=3,
            **depthwise_options,
        )
SpatiallySparseDepthwiseConv
¶
Bases: BaseSpatialModule
Spatially sparse depthwise convolution module.
In depthwise convolution, each input channel is convolved with its own kernel, so the number of input channels must equal the number of output channels. The weight shape is (K, C) where K is the kernel volume and C is the number of channels.
Source code in warpconvnet/nn/modules/sparse_conv_depth.py
class SpatiallySparseDepthwiseConv(BaseSpatialModule):
"""
Spatially sparse depthwise convolution module.
In depthwise convolution, each input channel is convolved with its own kernel,
so the number of input channels must equal the number of output channels.
The weight shape is (K, C) where K is the kernel volume and C is the number of channels.
"""
def __init__(
    self,
    channels: int,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    dilation: Union[int, Tuple[int, ...]] = 1,
    bias: bool = True,
    transposed: bool = False,
    generative: bool = False,
    num_spatial_dims: int = 3,
    fwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, str]] = None,
    bwd_algo: Optional[Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, str]] = None,
    stride_mode: STRIDED_CONV_MODE = STRIDED_CONV_MODE.STRIDE_ONLY,
    stride_reduce: str = "max",
    order: POINT_ORDERING = POINT_ORDERING.RANDOM,
    compute_dtype: Optional[torch.dtype] = None,
):
    super().__init__()
    self.num_spatial_dims = num_spatial_dims
    # Depthwise: every channel has its own kernel, so input and output channel counts
    # are both ``channels`` (in_channels/out_channels mirror PyTorch naming).
    self.channels = self.in_channels = self.out_channels = channels
    # Normalize geometry arguments to per-dimension tuples.
    ndim = self.num_spatial_dims
    self.kernel_size, self.stride, self.dilation = (
        ntuple(value, ndim=ndim) for value in (kernel_size, stride, dilation)
    )
    self.transposed = transposed
    self.generative = generative
    self.stride_reduce = stride_reduce

    def _resolve_algo(algo, enum_cls):
        # Strings may use generic names ("explicit_gemm", ...) that map onto the
        # depthwise-specific enum members; other strings go to the enum constructor.
        # Non-string values are passed through untouched.
        if not isinstance(algo, str):
            return algo
        member_by_alias = {
            "explicit": "EXPLICIT",
            "explicit_gemm": "EXPLICIT",
            "implicit": "IMPLICIT",
            "implicit_gemm": "IMPLICIT",
            "auto": "AUTO",
        }
        member = member_by_alias.get(algo.lower())
        return getattr(enum_cls, member) if member is not None else enum_cls(algo)

    # Fall back to the environment-configured algorithms when none is given.
    if fwd_algo is None:
        fwd_algo = WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE
    if bwd_algo is None:
        bwd_algo = WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE
    self.fwd_algo = _resolve_algo(fwd_algo, SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE)
    self.bwd_algo = _resolve_algo(bwd_algo, SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE)

    self.stride_mode = stride_mode
    self.order = order
    self.compute_dtype = compute_dtype

    # Weight layout: (kernel volume, channels) — one kernel per channel.
    kernel_volume = int(np.prod(self.kernel_size))
    self.weight = nn.Parameter(torch.randn(kernel_volume, channels))
    self.bias = nn.Parameter(torch.randn(channels)) if bias else None
    self.reset_parameters()
def __repr__(self):
    # Always show channels and kernel_size; other settings only when non-default.
    unit = (1,) * self.num_spatial_dims
    fields = [f"channels={self.channels}", f"kernel_size={self.kernel_size}"]
    if self.stride != unit:
        fields.append(f"stride={self.stride}")
    if self.dilation != unit:
        fields.append(f"dilation={self.dilation}")
    if self.transposed:
        fields.append(f"transposed={self.transposed}")
    if self.generative:
        fields.append(f"generative={self.generative}")
    if self.order != POINT_ORDERING.RANDOM:
        fields.append(f"order={self.order}")
    if self.bias is None:
        fields.append("bias=False")
    return f"{self.__class__.__name__}({', '.join(fields)})"
def _calculate_fan_in_and_fan_out(self):
"""Calculate fan_in and fan_out for depthwise convolution."""
receptive_field_size = np.prod(self.kernel_size)
# For depthwise convolution, each channel has its own kernel
fan_in = receptive_field_size # One kernel per channel
fan_out = receptive_field_size # One output per channel
return fan_in, fan_out
def _calculate_correct_fan(self, mode: str):
    """Return the fan named by ``mode`` (``"fan_in"`` or ``"fan_out"``, any case)."""
    normalized = mode.lower()
    assert normalized in ["fan_in", "fan_out"]
    fans = dict(zip(("fan_in", "fan_out"), self._calculate_fan_in_and_fan_out()))
    return fans[normalized]
def _custom_kaiming_uniform_(self, tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Fill ``tensor`` in place with Kaiming-uniform values for depthwise convolution.

    Mirrors ``torch.nn.init.kaiming_uniform_``. The fan comes from this module's
    ``_calculate_correct_fan`` rather than from the shape of ``tensor``.

    Args:
        tensor: Tensor filled in place.
        a: Negative slope forwarded to ``calculate_gain`` (used by ``leaky_relu``).
        mode: ``"fan_in"`` or ``"fan_out"`` (case-insensitive).
        nonlinearity: Nonlinearity name forwarded to ``calculate_gain``.

    Returns:
        ``tensor`` after it has been filled.
    """
    fan = self._calculate_correct_fan(mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    # U(-b, b) has standard deviation b / sqrt(3), so b = sqrt(3) * std gives the
    # target std regardless of how many spatial dimensions the kernel has.
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
@torch.no_grad()
def reset_parameters(self):
    """Re-initialize ``weight`` (Kaiming uniform) and, when present, ``bias``.

    The weight uses ``fan_out`` for transposed convolutions and ``fan_in``
    otherwise. The bias is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
    or filled with zeros when ``fan_in`` is not positive.
    """
    weight_mode = "fan_in"
    if self.transposed:
        weight_mode = "fan_out"
    self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5), mode=weight_mode)
    if self.bias is None:
        return
    fan_in = self._calculate_fan_in_and_fan_out()[0]
    bound = 0
    if fan_in > 0:
        bound = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -bound, bound)
def forward(
    self,
    input_sparse_tensor: Voxels,
    output_spatially_sparse_tensor: Optional[Voxels] = None,
) -> Voxels:
    """
    Run the spatially sparse depthwise convolution on ``input_sparse_tensor``.

    Args:
        input_sparse_tensor: Input sparse tensor.
        output_spatially_sparse_tensor: Optional output sparse tensor. It is
            forwarded to ``generate_output_coords_and_kernel_map``. For
            transposed convolutions, its ``tensor_stride`` (when set) becomes
            the output tensor stride.

    Returns:
        Output sparse tensor, built via ``input_sparse_tensor.replace`` with the
        new coordinates, convolved features and output tensor stride.
    """
    # Batch-indexed output coordinates, per-batch offsets and the kernel map.
    coords_with_batch, out_offsets, kmap = generate_output_coords_and_kernel_map(
        input_sparse_tensor=input_sparse_tensor,
        kernel_size=self.kernel_size,
        kernel_dilation=self.dilation,
        stride=self.stride,
        generative=self.generative,
        transposed=self.transposed,
        output_spatially_sparse_tensor=output_spatially_sparse_tensor,
        stride_mode=self.stride_mode,
        order=self.order,
    )
    n_out = coords_with_batch.shape[0]

    features = spatially_sparse_depthwise_conv(
        input_sparse_tensor.feature_tensor,
        self.weight,
        kmap,
        n_out,
        fwd_algo=self.fwd_algo,
        bwd_algo=self.bwd_algo,
        compute_dtype=self.compute_dtype,
    )
    if self.bias is not None:
        features = features + self.bias

    # Tensor stride of the result: scaled by `stride` for regular convs; taken
    # from the provided output tensor (or unit stride) for transposed convs.
    unit_stride = (1,) * self.num_spatial_dims
    in_stride = input_sparse_tensor.tensor_stride
    if in_stride is None:
        in_stride = unit_stride
    if self.transposed:
        target = output_spatially_sparse_tensor
        has_target_stride = target is not None and target.tensor_stride is not None
        out_stride = target.tensor_stride if has_target_stride else unit_stride
    else:
        out_stride = tuple(step * base for step, base in zip(self.stride, in_stride))

    # Presumably column 0 is the batch index; only the remaining columns are kept.
    spatial_coords = IntCoords(
        coords_with_batch[:, 1:],
        offsets=out_offsets.cpu().int(),
    )
    return input_sparse_tensor.replace(
        batched_coordinates=spatial_coords,
        batched_features=features,
        tensor_stride=out_stride,
    )
forward(input_sparse_tensor: Voxels, output_spatially_sparse_tensor: Optional[Voxels] = None) -> Voxels
¶
Forward pass for spatially sparse depthwise convolution.
Args: input_sparse_tensor: Input sparse tensor; output_spatially_sparse_tensor: Optional output sparse tensor for transposed conv
Returns: Output sparse tensor
Source code in warpconvnet/nn/modules/sparse_conv_depth.py
def forward(
    self,
    input_sparse_tensor: Voxels,
    output_spatially_sparse_tensor: Optional[Voxels] = None,
) -> Voxels:
    """
    Run the spatially sparse depthwise convolution on ``input_sparse_tensor``.

    Args:
        input_sparse_tensor: Input sparse tensor.
        output_spatially_sparse_tensor: Optional output sparse tensor. It is
            forwarded to ``generate_output_coords_and_kernel_map``. For
            transposed convolutions, its ``tensor_stride`` (when set) becomes
            the output tensor stride.

    Returns:
        Output sparse tensor, built via ``input_sparse_tensor.replace`` with the
        new coordinates, convolved features and output tensor stride.
    """
    # Batch-indexed output coordinates, per-batch offsets and the kernel map.
    coords_with_batch, out_offsets, kmap = generate_output_coords_and_kernel_map(
        input_sparse_tensor=input_sparse_tensor,
        kernel_size=self.kernel_size,
        kernel_dilation=self.dilation,
        stride=self.stride,
        generative=self.generative,
        transposed=self.transposed,
        output_spatially_sparse_tensor=output_spatially_sparse_tensor,
        stride_mode=self.stride_mode,
        order=self.order,
    )
    n_out = coords_with_batch.shape[0]

    features = spatially_sparse_depthwise_conv(
        input_sparse_tensor.feature_tensor,
        self.weight,
        kmap,
        n_out,
        fwd_algo=self.fwd_algo,
        bwd_algo=self.bwd_algo,
        compute_dtype=self.compute_dtype,
    )
    if self.bias is not None:
        features = features + self.bias

    # Tensor stride of the result: scaled by `stride` for regular convs; taken
    # from the provided output tensor (or unit stride) for transposed convs.
    unit_stride = (1,) * self.num_spatial_dims
    in_stride = input_sparse_tensor.tensor_stride
    if in_stride is None:
        in_stride = unit_stride
    if self.transposed:
        target = output_spatially_sparse_tensor
        has_target_stride = target is not None and target.tensor_stride is not None
        out_stride = target.tensor_stride if has_target_stride else unit_stride
    else:
        out_stride = tuple(step * base for step, base in zip(self.stride, in_stride))

    # Presumably column 0 is the batch index; only the remaining columns are kept.
    spatial_coords = IntCoords(
        coords_with_batch[:, 1:],
        offsets=out_offsets.cpu().int(),
    )
    return input_sparse_tensor.replace(
        batched_coordinates=spatial_coords,
        batched_features=features,
        tensor_stride=out_stride,
    )
reset_parameters()
¶
Reset module parameters using appropriate initialization.
Source code in warpconvnet/nn/modules/sparse_conv_depth.py
@torch.no_grad()
def reset_parameters(self):
    """Re-initialize ``weight`` (Kaiming uniform) and, when present, ``bias``.

    The weight uses ``fan_out`` for transposed convolutions and ``fan_in``
    otherwise. The bias is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
    or filled with zeros when ``fan_in`` is not positive.
    """
    weight_mode = "fan_in"
    if self.transposed:
        weight_mode = "fan_out"
    self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5), mode=weight_mode)
    if self.bias is None:
        return
    fan_in = self._calculate_fan_in_and_fan_out()[0]
    bound = 0
    if fan_in > 0:
        bound = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -bound, bound)
Sparse pooling¶
warpconvnet.nn.modules.sparse_pool
¶
GlobalPool
¶
Bases: BaseSpatialModule
Pool features across the entire geometry.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/sparse_pool.py
class GlobalPool(BaseSpatialModule):
    """Pool features across the entire geometry.

    Parameters
    ----------
    reduce : {"min", "max", "mean", "sum"}, optional
        Reduction handed to ``global_pool``. Defaults to ``"max"``.
    """

    def __init__(self, reduce: Literal["min", "max", "mean", "sum"] = "max"):
        super().__init__()
        self.reduce = reduce

    def forward(self, x: Geometry):
        """Return ``global_pool`` applied to ``x`` with the configured reduction."""
        reduction = self.reduce
        pooled = global_pool(x, reduction)
        return pooled
PointToSparseWrapper
¶
Bases: BaseSpatialModule
Pool points into a sparse tensor, apply an inner module and unpool back to points.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/sparse_pool.py
class PointToSparseWrapper(BaseSpatialModule):
    """Pool points into a sparse tensor, apply an inner module and unpool back to points.

    Parameters
    ----------
    inner_module : `BaseSpatialModule`
        Module applied on the pooled sparse tensor. It must return ``Voxels``.
    voxel_size : float
        Voxel size used to pool the input points.
    reduction : `REDUCTIONS` or str, optional
        Reduction used when pooling points. Defaults to ``REDUCTIONS.MEAN``.
    unique_method : {"morton", "ravel", "torch"}, optional
        Method used for hashing voxel indices. Defaults to ``"morton"``.
    concat_unpooled_pc : bool, optional
        If ``True`` concatenate the unpooled result with the original input. Defaults to ``True``.
    """

    def __init__(
        self,
        inner_module: BaseSpatialModule,
        voxel_size: float,
        reduction: Union[REDUCTIONS, REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
        unique_method: Literal["morton", "ravel", "torch"] = "morton",
        concat_unpooled_pc: bool = True,
    ):
        super().__init__()
        self.inner_module = inner_module
        self.voxel_size = voxel_size
        self.reduction = reduction
        self.concat_unpooled_pc = concat_unpooled_pc
        self.unique_method = unique_method

    def forward(self, pc: Points) -> Points:
        """Voxelize ``pc``, run ``inner_module`` on the voxels and unpool back onto ``pc``.

        Raises
        ------
        TypeError
            If ``inner_module`` does not return a ``Voxels`` instance.
        """
        st, to_unique = point_pool(
            pc,
            reduction=self.reduction,
            downsample_voxel_size=self.voxel_size,
            return_type="voxel",
            return_to_unique=True,
            unique_method=self.unique_method,
        )
        out_st = self.inner_module(st)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(out_st, Voxels):
            raise TypeError(
                f"Output of inner module must be a Voxels, got {type(out_st).__name__}"
            )
        unpooled_pc = point_unpool(
            out_st.to_point(self.voxel_size),
            pc,
            concat_unpooled_pc=self.concat_unpooled_pc,
            to_unique=to_unique,
        )
        return unpooled_pc
SparseMaxPool
¶
Bases: SparsePool
Max pooling for sparse tensors.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseMaxPool(SparsePool):
    """Max pooling for sparse tensors: a ``SparsePool`` fixed to ``reduce="max"``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__(kernel_size=kernel_size, stride=stride, reduce="max")
SparseMinPool
¶
Bases: SparsePool
Min pooling for sparse tensors.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseMinPool(SparsePool):
    """Min pooling for sparse tensors: a ``SparsePool`` fixed to ``reduce="min"``.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__(kernel_size=kernel_size, stride=stride, reduce="min")
SparsePool
¶
Bases: BaseSpatialModule
Reduce features of a Voxels object using a strided kernel.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/sparse_pool.py
class SparsePool(BaseSpatialModule):
    """Reduce features of a ``Voxels`` object using a strided kernel.

    Parameters
    ----------
    kernel_size : int
        Size of the pooling kernel.
    stride : int
        Stride between pooling windows.
    reduce : {"max", "min", "mean", "sum", "random"}, optional
        Reduction to apply within each window. Defaults to ``"max"``.
    """

    def __init__(
        self,
        kernel_size: int,
        stride: int,
        reduce: Literal["max", "min", "mean", "sum", "random"] = "max",
    ):
        super().__init__()
        self.kernel_size, self.stride, self.reduce = kernel_size, stride, reduce

    def __repr__(self):
        cls_name = type(self).__name__
        return (
            f"{cls_name}(kernel_size={self.kernel_size}, "
            f"stride={self.stride}, reduce={self.reduce})"
        )

    def forward(self, st: Voxels):
        """Return ``sparse_reduce`` applied to ``st`` with this module's settings."""
        pool_args = (self.kernel_size, self.stride, self.reduce)
        return sparse_reduce(st, *pool_args)
SparseUnpool
¶
Bases: BaseSpatialModule
Unpool a sparse tensor back to a higher resolution.
| Parameters: |
|
|---|
Source code in warpconvnet/nn/modules/sparse_pool.py
class SparseUnpool(BaseSpatialModule):
    """Unpool a sparse tensor back to a higher resolution.

    Parameters
    ----------
    kernel_size : int
        Size of the unpooling kernel.
    stride : int
        Stride between unpooling windows.
    concat_unpooled_st : bool, optional
        If ``True`` concatenate the unpooled tensor with the input. Defaults to ``True``.
    """

    def __init__(self, kernel_size: int, stride: int, concat_unpooled_st: bool = True):
        super().__init__()
        self.kernel_size, self.stride = kernel_size, stride
        self.concat_unpooled_st = concat_unpooled_st

    def forward(self, st: Voxels, unpooled_st: Voxels):
        """Return ``sparse_unpool`` of ``st`` onto ``unpooled_st`` with this module's settings."""
        unpool_args = (self.kernel_size, self.stride, self.concat_unpooled_st)
        return sparse_unpool(st, unpooled_st, *unpool_args)
Transforms¶
warpconvnet.nn.modules.transforms
¶
Transform
¶
Bases: BaseSpatialModule
Point transform module that applies a feature transform to the input point collection. No spatial operations are performed.
Hydra config example usage:
.. code-block:: yaml
model:
feature_transform:
_target_: warpconvnet.nn.point_transform.Transform
feature_transform_fn:
  _target_: torch.nn.ReLU
Source code in warpconvnet/nn/modules/transforms.py
class Transform(BaseSpatialModule):
    """
    Point transform module that applies a feature transform to the input geometry.
    No spatial operations are performed.

    Hydra config example usage:

    .. code-block:: yaml

        model:
          feature_transform:
            _target_: warpconvnet.nn.point_transform.Transform
            feature_transform_fn:
              _target_: torch.nn.ReLU
    """

    def __init__(self, feature_transform_fn: nn.Module):
        super().__init__()
        self.feature_transform_fn = feature_transform_fn

    def forward(self, *sfs: Geometry) -> Geometry:
        """
        Apply the feature transform to one or more geometries.

        Args:
            *sfs: One or more ``Geometry`` inputs. Their feature tensors are passed
                positionally to ``feature_transform_fn``. When more than one input
                is given, all inputs must have identical ``offsets``.

        Returns:
            The first input with its features replaced by the transform output.

        Raises:
            ValueError: If no input is given or the inputs' offsets differ.
            TypeError: If any input is not a ``Geometry``.
        """
        if not sfs:
            raise ValueError("Transform.forward expects at least one Geometry input")
        for sf in sfs:
            if not isinstance(sf, Geometry):
                raise TypeError(
                    f"Transform.forward expects Geometry inputs, got {type(sf).__name__}"
                )
        first = sfs[0]
        if len(sfs) == 1:
            return first.replace(batched_features=self.feature_transform_fn(first.feature_tensor))
        # Exact comparison: torch.allclose applies a relative tolerance (rtol=1e-5),
        # so large integer offsets differing by one or more would still "match".
        if not all(torch.equal(sf.offsets, first.offsets) for sf in sfs[1:]):
            raise ValueError("All inputs to Transform.forward must share the same offsets")
        out_features = self.feature_transform_fn(*(sf.feature_tensor for sf in sfs))
        return first.replace(batched_features=out_features)
forward(*sfs: Tuple[Geometry, ...]) -> Geometry
¶
Apply the feature transform to the input point collection
Args: *sfs: One or more Geometry inputs that share the same offsets
Returns: Transformed point collection
Source code in warpconvnet/nn/modules/transforms.py
def forward(self, *sfs: Geometry) -> Geometry:
    """
    Apply the feature transform to one or more geometries.

    Args:
        *sfs: One or more ``Geometry`` inputs. Their feature tensors are passed
            positionally to ``feature_transform_fn``. When more than one input
            is given, all inputs must have identical ``offsets``.

    Returns:
        The first input with its features replaced by the transform output.

    Raises:
        ValueError: If no input is given or the inputs' offsets differ.
        TypeError: If any input is not a ``Geometry``.
    """
    if not sfs:
        raise ValueError("Transform.forward expects at least one Geometry input")
    for sf in sfs:
        if not isinstance(sf, Geometry):
            raise TypeError(
                f"Transform.forward expects Geometry inputs, got {type(sf).__name__}"
            )
    first = sfs[0]
    if len(sfs) == 1:
        return first.replace(batched_features=self.feature_transform_fn(first.feature_tensor))
    # Exact comparison: torch.allclose applies a relative tolerance (rtol=1e-5),
    # so large integer offsets differing by one or more would still "match".
    if not all(torch.equal(sf.offsets, first.offsets) for sf in sfs[1:]):
        raise ValueError("All inputs to Transform.forward must share the same offsets")
    out_features = self.feature_transform_fn(*(sf.feature_tensor for sf in sfs))
    return first.replace(batched_features=out_features)