Geometry

warpconvnet.geometry

Geometry containers

Every geometry type in WarpConvNet wraps a Coords instance and a Features instance that share the same ragged-batch metadata.

  • warpconvnet.geometry.base.coords.Coords stores concatenated coordinates plus an offsets vector (length B + 1 for a batch of B examples) marking where each example begins and ends.
  • warpconvnet.geometry.base.features.Features (and the CatFeatures and PadFeatures specializations) stores feature tensors that obey the same offsets so coordinates and features always stay aligned.
  • warpconvnet.geometry.base.geometry.Geometry wires the pair together, validates their shapes, and exposes device/dtype utilities with AMP-aware accessors.

This shared contract lets subclasses freely switch between point clouds, voxels, or grids without duplicating batching logic. See Batched coordinate layout for a deeper explanation of how concatenated tensors and offsets interact.

Types

WarpConvNet ships several geometry containers that unify coordinate systems with their associated features. Use these types as the canonical interfaces for points, voxels, dense grids, and FIGConvNet factor grids.

Points

Flexible point-cloud geometry supporting ragged batches, concatenated or padded features, and neighbor-search utilities used by sparse convolution modules.

warpconvnet.geometry.types.points.Points

Bases: Geometry

Interface class for collections of points

A point collection is a set of points in a geometric space (dim=1 (line), 2 (plane), 3 (space), 4 (space-time)).

Source code in warpconvnet/geometry/types/points.py
class Points(Geometry):
    """
    Interface class for collections of points

    A point collection is a set of points in a geometric space
    (dim=1 (line), 2 (plane), 3 (space), 4 (space-time)).

    Coordinates are held in a ``RealCoords`` and features in a ``CatFeatures``
    or ``PadFeatures``. Methods that return a new collection rebuild it with
    ``self.__class__`` and forward ``extra_attributes`` as keyword arguments.
    """

    def __init__(
        self,
        batched_coordinates: (
            List[Float[Tensor, "N 3"]] | Float[Tensor, "N 3"] | RealCoords
        ),  # noqa: F722,F821
        batched_features: (
            List[Float[Tensor, "N C"]]
            | Float[Tensor, "N C"]
            | Float[Tensor, "B M C"]
            | CatFeatures
            | PadFeatures
        ),  # noqa: F722,F821
        offsets: Optional[Int[Tensor, "B + 1"]] = None,  # noqa: F722,F821
        device: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize a point collection with coordinates and features.

        Args:
            batched_coordinates: A list of per-example coordinate tensors, one
                concatenated tensor (then ``offsets`` is required), or a
                ready-made ``RealCoords``.
            batched_features: Features aligned with the coordinates. Must be a
                list when the coordinates are a list and a tensor when the
                coordinates are a tensor; may also be a features object.
            offsets: Batch boundaries, required for tensor coordinates.
            device: Device forwarded to the coordinate/feature constructors.
            **kwargs: Extra attributes forwarded to ``Geometry.__init__``.
        """
        if isinstance(batched_coordinates, list):
            assert isinstance(
                batched_features, list
            ), "If coords is a list, features must be a list too."
            assert len(batched_coordinates) == len(batched_features)
            # Assert all elements in coords and features have same length
            assert all(
                len(c) == len(f) for c, f in zip(batched_coordinates, batched_features)
            ), "All elements in coords and features must have same length"
            batched_coordinates = RealCoords(batched_coordinates, device=device)
        elif isinstance(batched_coordinates, Tensor):
            assert (
                isinstance(batched_features, Tensor) and offsets is not None
            ), "If coordinate is a tensor, features must be a tensor and offsets must be provided."
            batched_coordinates = RealCoords(batched_coordinates, offsets=offsets, device=device)

        if isinstance(batched_features, list):
            batched_features = CatFeatures(batched_features, device=device)
        elif isinstance(batched_features, Tensor):
            # Raw feature tensors are wrapped using the coordinates' offsets so
            # both containers describe the same batch layout.
            batched_features = to_batched_features(
                batched_features, batched_coordinates.offsets, device=device
            )

        Geometry.__init__(
            self,
            batched_coordinates,
            batched_features,
            **kwargs,
        )

    def sort(
        self,
        voxel_size: float,
        ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
    ):
        """
        Sort the points according to the ordering provided.
        The voxel size defines the smallest discretization and points in the same voxel will have random order.

        Args:
            voxel_size: Cell size used to quantize coordinates before encoding.
            ordering: Ordering passed to ``encode``; recorded on the result.

        Returns:
            A new instance whose coordinates and features are permuted by the
            same permutation; offsets are unchanged.
        """
        # Warp uses int32 so only 10 bits per coordinate supported. Thus max 1024.
        assert self.device.type != "cpu", "Sorting is only supported on GPU"
        result = encode(
            torch.floor(self.coordinate_tensor / voxel_size).int(),
            batch_offsets=self.offsets,
            order=ordering,
            return_perm=True,
        )
        kwargs = self.extra_attributes.copy()
        # Record the ordering (as Voxels.sort does) so ``self.ordering`` reflects it.
        kwargs["ordering"] = ordering
        # The neighbor cache is keyed on (search_args, offsets, query offsets)
        # only. Sorting keeps the offsets but permutes the points, so a
        # propagated cache would return indices for the pre-sort order. This is
        # a no-op if ``extra_attributes`` does not carry the cache.
        kwargs.pop("_cache", None)
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[result.perm],
                offsets=self.offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor[result.perm],
                offsets=self.offsets,
            ),
            **kwargs,
        )

    def voxel_downsample(
        self,
        voxel_size: float,
        reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.RANDOM,
    ) -> "Points":
        """
        Voxel downsample the coordinates.

        With ``REDUCTIONS.RANDOM`` one point's coordinates and features are
        kept per voxel. Otherwise the features of all points in a voxel are
        combined with ``row_reduction`` using ``reduction``. The result carries
        ``voxel_size`` as an extra attribute.
        """
        assert self.device.type != "cpu", "Voxel downsample is only supported on GPU"
        # Copy before adding ``voxel_size`` (as ``sort`` does). If
        # ``extra_attributes`` returns the live dict, writing into it would
        # silently change ``self.voxel_size`` too.
        extra_args = self.extra_attributes.copy()
        extra_args["voxel_size"] = voxel_size
        assert isinstance(
            self.batched_features, CatFeatures
        ), "Voxel downsample is only supported for CatBatchedFeatures"
        if reduction == REDUCTIONS.RANDOM:
            to_unique_indices, unique_offsets = voxel_downsample_random_indices(
                batched_points=self.coordinate_tensor,
                offsets=self.offsets,
                voxel_size=voxel_size,
            )
            return self.__class__(
                batched_coordinates=RealCoords(
                    batched_tensor=self.coordinate_tensor[to_unique_indices],
                    offsets=unique_offsets,
                ),
                batched_features=CatFeatures(
                    batched_tensor=self.feature_tensor[to_unique_indices],
                    offsets=unique_offsets,
                ),
                **extra_args,
            )

        (
            _batch_indexed_down_coords,
            unique_offsets,
            to_csr_indices,
            to_csr_offsets,
            to_unique,
        ) = voxel_downsample_csr_mapping(
            batched_points=self.coordinate_tensor,
            offsets=self.offsets,
            voxel_size=voxel_size,
        )

        # Group points per voxel in CSR form and reduce each group's features.
        voxel_groups = RealSearchResult(to_csr_indices, to_csr_offsets)
        down_features = row_reduction(
            self.feature_tensor,
            voxel_groups.neighbor_row_splits,
            reduction,
        )

        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[to_unique.to_unique_indices],
                offsets=unique_offsets,
            ),
            batched_features=CatFeatures(batched_tensor=down_features, offsets=unique_offsets),
            **extra_args,
        )

    def random_downsample(self, num_sample_points: int) -> "Points":
        """
        Downsample the coordinates to the specified number of points.

        If the batch size is N, the total number of output points is N * num_sample_points.
        """
        sampled_indices, sample_offsets = random_sample_per_batch(self.offsets, num_sample_points)
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[sampled_indices],
                offsets=sample_offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor[sampled_indices],
                offsets=sample_offsets,
            ),
            **self.extra_attributes,
        )

    def contiguous(self) -> "Points":
        """Ensure coordinates and features are contiguous in memory.

        This is important for memory access patterns and can improve
        performance for operations that require contiguous memory.

        Returns:
            Points: ``self`` if both tensors are already contiguous, otherwise
            a new Points instance with contiguous tensors.
        """
        if self.coordinate_tensor.is_contiguous() and self.feature_tensor.is_contiguous():
            return self

        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor.contiguous(),
                offsets=self.offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor.contiguous(),
                offsets=self.offsets,
            ),
            **self.extra_attributes,
        )

    def neighbors(
        self,
        search_args: RealSearchConfig,
        query_coords: Optional["Coords"] = None,
    ) -> RealSearchResult:
        """
        Returns CSR format neighbor indices.

        ``query_coords`` defaults to this collection's own coordinates. Results
        are cached under ``self._extra_attributes["_cache"]``, keyed on
        ``(search_args, self.offsets, query_coords.offsets)``.
        """
        if query_coords is None:
            query_coords = self.batched_coordinates

        assert isinstance(query_coords, Coords), "query_coords must be BatchedCoordinates"

        # cache the neighbor search result
        if self.cache is not None:
            neighbor_search_result = self.cache.get(
                search_args, self.offsets, query_coords.offsets
            )
            if neighbor_search_result is not None:
                return neighbor_search_result

        neighbor_search_result = neighbor_search(
            self.coordinate_tensor,
            self.offsets,
            query_coords.batched_tensor,
            query_coords.offsets,
            search_args,
        )
        if self.cache is None:
            self._extra_attributes["_cache"] = RealSearchCache()
        self.cache.put(search_args, self.offsets, query_coords.offsets, neighbor_search_result)
        return neighbor_search_result

    @property
    def voxel_size(self):
        """Voxel size stored in the extra attributes, or None."""
        return self._extra_attributes.get("voxel_size", None)

    @property
    def ordering(self):
        """Ordering stored in the extra attributes, or None."""
        return self._extra_attributes.get("ordering", None)

    @classmethod
    def from_list_of_coordinates(
        cls,
        coordinates: List[Float[Tensor, "N 3"]],
        features: Optional[List[Float[Tensor, "N C"]]] = None,
        encoding_channels: Optional[int] = None,
        encoding_range: Optional[Tuple[float, float]] = None,
        encoding_dim: Optional[int] = -1,
    ):
        """
        Create a point collection from a list of coordinates.

        If ``features`` is None, features are generated per example with
        ``sinusoidal_encoding(coords, encoding_channels, encoding_range,
        encoding_dim)``, so ``encoding_range`` must be provided. A tensor
        ``coordinates`` argument is split along its first dimension.
        """
        # if the input is a tensor, expand it to a list of tensors
        if isinstance(coordinates, Tensor):
            coordinates = list(coordinates)  # this expands the tensor to a list of tensors

        if features is None:
            assert (
                encoding_range is not None
            ), "Encoding range must be provided if features are not provided"
            features = [
                sinusoidal_encoding(coords, encoding_channels, encoding_range, encoding_dim)
                for coords in coordinates
            ]

        # Create BatchedContinuousCoordinates
        batched_coordinates = RealCoords(coordinates)
        # Create CatBatchedFeatures
        batched_features = CatFeatures(features)

        return cls(batched_coordinates, batched_features)

    def to_voxels(
        self,
        voxel_size: float,
        reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
    ) -> "Voxels":
        """
        Convert the point collection to a spatially sparse tensor.

        Delegates to ``points_to_voxels(self, voxel_size, reduction)``.
        """
        return points_to_voxels(self, voxel_size, reduction)

contiguous() -> Points

Ensure coordinates and features are contiguous in memory.

This is important for memory access patterns and can improve performance for operations that require contiguous memory.

Returns: Points: the same instance if both tensors are already contiguous; otherwise a new Points instance with contiguous tensors.

Source code in warpconvnet/geometry/types/points.py
def contiguous(self) -> "Points":
    """Return a version of this collection backed by contiguous tensors.

    When both the coordinate and feature tensors are already contiguous the
    instance itself is returned unchanged. Otherwise a new instance of the
    same class is built from contiguous copies, reusing the offsets and
    extra attributes.

    Returns:
        Points: ``self`` or a new instance with contiguous tensors.
    """
    already_packed = (
        self.coordinate_tensor.is_contiguous() and self.feature_tensor.is_contiguous()
    )
    if already_packed:
        return self

    shared_offsets = self.offsets
    packed_coords = RealCoords(
        batched_tensor=self.coordinate_tensor.contiguous(),
        offsets=shared_offsets,
    )
    packed_feats = CatFeatures(
        batched_tensor=self.feature_tensor.contiguous(),
        offsets=shared_offsets,
    )
    return self.__class__(
        batched_coordinates=packed_coords,
        batched_features=packed_feats,
        **self.extra_attributes,
    )

from_list_of_coordinates(coordinates: List[Float[Tensor, 'N 3']], features: Optional[List[Float[Tensor, 'N C']]] = None, encoding_channels: Optional[int] = None, encoding_range: Optional[Tuple[float, float]] = None, encoding_dim: Optional[int] = -1) classmethod

Create a point collection from a list of coordinates.

Source code in warpconvnet/geometry/types/points.py
@classmethod
def from_list_of_coordinates(
    cls,
    coordinates: List[Float[Tensor, "N 3"]],
    features: Optional[List[Float[Tensor, "N C"]]] = None,
    encoding_channels: Optional[int] = None,
    encoding_range: Optional[Tuple[float, float]] = None,
    encoding_dim: Optional[int] = -1,
):
    """
    Create a point collection from a list of coordinates.

    If ``features`` is None, features are generated per example with
    ``sinusoidal_encoding(coords, encoding_channels, encoding_range,
    encoding_dim)``, so ``encoding_range`` must be provided. A tensor
    ``coordinates`` argument is split along its first dimension.
    """
    # if the input is a tensor, expand it to a list of tensors
    if isinstance(coordinates, Tensor):
        coordinates = list(coordinates)  # this expands the tensor to a list of tensors

    if features is None:
        assert (
            encoding_range is not None
        ), "Encoding range must be provided if features are not provided"
        # Use a distinct loop name so the iterable is not shadowed.
        features = [
            sinusoidal_encoding(coords, encoding_channels, encoding_range, encoding_dim)
            for coords in coordinates
        ]

    # Create BatchedContinuousCoordinates
    batched_coordinates = RealCoords(coordinates)
    # Create CatBatchedFeatures
    batched_features = CatFeatures(features)

    return cls(batched_coordinates, batched_features)

neighbors(search_args: RealSearchConfig, query_coords: Optional[Coords] = None) -> RealSearchResult

Returns CSR format neighbor indices

Source code in warpconvnet/geometry/types/points.py
def neighbors(
    self,
    search_args: RealSearchConfig,
    query_coords: Optional["Coords"] = None,
) -> RealSearchResult:
    """
    Compute (or fetch from cache) CSR-format neighbor indices.

    The query defaults to this collection's own coordinates. A hit in
    ``self.cache`` for ``(search_args, self.offsets, query offsets)`` is
    returned directly. Otherwise the search runs and its result is stored,
    creating a ``RealSearchCache`` under ``_extra_attributes["_cache"]`` on
    first use.
    """
    queries = self.batched_coordinates if query_coords is None else query_coords
    assert isinstance(queries, Coords), "query_coords must be BatchedCoordinates"

    existing_cache = self.cache
    if existing_cache is not None:
        cached = existing_cache.get(search_args, self.offsets, queries.offsets)
        if cached is not None:
            return cached

    found = neighbor_search(
        self.coordinate_tensor,
        self.offsets,
        queries.batched_tensor,
        queries.offsets,
        search_args,
    )
    # Lazily create the cache the first time a search is performed.
    if self.cache is None:
        self._extra_attributes["_cache"] = RealSearchCache()
    self.cache.put(search_args, self.offsets, queries.offsets, found)
    return found

random_downsample(num_sample_points: int) -> Points

Downsample the coordinates to the specified number of points.

If the batch size is N, the total number of output points is N * num_sample_points.

Source code in warpconvnet/geometry/types/points.py
def random_downsample(self, num_sample_points: int) -> "Points":
    """
    Randomly keep ``num_sample_points`` points from every batch element.

    With a batch of N examples the result holds N * num_sample_points points.
    Coordinates and features are gathered with the same sampled indices.
    """
    picked, picked_offsets = random_sample_per_batch(self.offsets, num_sample_points)
    kept_coords = RealCoords(
        batched_tensor=self.coordinate_tensor[picked],
        offsets=picked_offsets,
    )
    kept_feats = CatFeatures(
        batched_tensor=self.feature_tensor[picked],
        offsets=picked_offsets,
    )
    return self.__class__(
        batched_coordinates=kept_coords,
        batched_features=kept_feats,
        **self.extra_attributes,
    )

sort(voxel_size: float, ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ)

Sort the points according to the ordering provided. The voxel size defines the smallest discretization and points in the same voxel will have random order.

Source code in warpconvnet/geometry/types/points.py
def sort(
    self,
    voxel_size: float,
    ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
):
    """
    Sort the points according to the ordering provided.
    The voxel size defines the smallest discretization and points in the same voxel will have random order.

    Args:
        voxel_size: Cell size used to quantize coordinates before encoding.
        ordering: Ordering passed to ``encode``; recorded on the result.

    Returns:
        A new instance whose coordinates and features are permuted by the
        same permutation; offsets are unchanged.
    """
    # Warp uses int32 so only 10 bits per coordinate supported. Thus max 1024.
    assert self.device.type != "cpu", "Sorting is only supported on GPU"
    result = encode(
        torch.floor(self.coordinate_tensor / voxel_size).int(),
        batch_offsets=self.offsets,
        order=ordering,
        return_perm=True,
    )
    kwargs = self.extra_attributes.copy()
    # Record the ordering (as Voxels.sort does) so ``ordering`` reflects it.
    kwargs["ordering"] = ordering
    # The neighbor cache is keyed on (search_args, offsets, query offsets)
    # only. Sorting keeps the offsets but permutes the points, so a
    # propagated cache would return indices for the pre-sort order. This is
    # a no-op if ``extra_attributes`` does not carry the cache.
    kwargs.pop("_cache", None)
    return self.__class__(
        batched_coordinates=RealCoords(
            batched_tensor=self.coordinate_tensor[result.perm],
            offsets=self.offsets,
        ),
        batched_features=CatFeatures(
            batched_tensor=self.feature_tensor[result.perm],
            offsets=self.offsets,
        ),
        **kwargs,
    )

to_voxels(voxel_size: float, reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.MEAN) -> Voxels

Convert the point collection to a spatially sparse tensor.

Source code in warpconvnet/geometry/types/points.py
def to_voxels(
    self,
    voxel_size: float,
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
) -> "Voxels":
    """
    Voxelize this point collection into a spatially sparse tensor.

    Thin wrapper: forwards ``self``, ``voxel_size`` and ``reduction`` to
    ``points_to_voxels`` and returns its result.
    """
    voxelized = points_to_voxels(self, voxel_size, reduction)
    return voxelized

voxel_downsample(voxel_size: float, reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.RANDOM) -> Points

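Voxel downsample the point collection with voxels of size voxel_size. With REDUCTIONS.RANDOM, one point's coordinates and features are kept per voxel; otherwise the features of all points in a voxel are combined with the given reduction. The result records voxel_size as an extra attribute.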

Source code in warpconvnet/geometry/types/points.py
def voxel_downsample(
    self,
    voxel_size: float,
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.RANDOM,
) -> "Points":
    """
    Voxel downsample the coordinates.

    With ``REDUCTIONS.RANDOM`` one point's coordinates and features are kept
    per voxel. Otherwise the features of all points in a voxel are combined
    with ``row_reduction`` using ``reduction``. The result carries
    ``voxel_size`` as an extra attribute.
    """
    assert self.device.type != "cpu", "Voxel downsample is only supported on GPU"
    # Copy before adding ``voxel_size`` (as ``sort`` does). If
    # ``extra_attributes`` returns the live dict, writing into it would
    # silently change the source collection's ``voxel_size`` too.
    extra_args = self.extra_attributes.copy()
    extra_args["voxel_size"] = voxel_size
    assert isinstance(
        self.batched_features, CatFeatures
    ), "Voxel downsample is only supported for CatBatchedFeatures"
    if reduction == REDUCTIONS.RANDOM:
        to_unique_indices, unique_offsets = voxel_downsample_random_indices(
            batched_points=self.coordinate_tensor,
            offsets=self.offsets,
            voxel_size=voxel_size,
        )
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[to_unique_indices],
                offsets=unique_offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor[to_unique_indices],
                offsets=unique_offsets,
            ),
            **extra_args,
        )

    (
        _batch_indexed_down_coords,
        unique_offsets,
        to_csr_indices,
        to_csr_offsets,
        to_unique,
    ) = voxel_downsample_csr_mapping(
        batched_points=self.coordinate_tensor,
        offsets=self.offsets,
        voxel_size=voxel_size,
    )

    # Group points per voxel in CSR form and reduce each group's features.
    voxel_groups = RealSearchResult(to_csr_indices, to_csr_offsets)
    down_features = row_reduction(
        self.feature_tensor,
        voxel_groups.neighbor_row_splits,
        reduction,
    )

    return self.__class__(
        batched_coordinates=RealCoords(
            batched_tensor=self.coordinate_tensor[to_unique.to_unique_indices],
            offsets=unique_offsets,
        ),
        batched_features=CatFeatures(batched_tensor=down_features, offsets=unique_offsets),
        **extra_args,
    )

Voxels

Sparse voxel geometry that accepts integer coordinates with tensor strides and offers helpers to move between dense tensors and CSR-style batched features.

warpconvnet.geometry.types.voxels.Voxels

Bases: Geometry

Source code in warpconvnet/geometry/types/voxels.py
class Voxels(Geometry):
    """Sparse voxel geometry built from integer coordinates and per-voxel features.

    Coordinates are stored as ``IntCoords`` (which also carries the optional
    tensor stride) and features as ``CatFeatures`` / ``PadFeatures``.
    """

    def __init__(
        self,
        batched_coordinates: List[Float[Tensor, "N 3"]] | Float[Tensor, "N 3"] | IntCoords,
        batched_features: (
            List[Float[Tensor, "N C"]]
            | Float[Tensor, "N C"]
            | Float[Tensor, "B M C"]
            | CatFeatures
            | PadFeatures
        ),
        offsets: Optional[Int[Tensor, "B + 1"]] = None,  # noqa: F722,F821
        device: Optional[str] = None,
        **kwargs,
    ):
        """Create a voxel collection.

        Args:
            batched_coordinates: A list of per-example coordinate tensors, one
                concatenated tensor (then ``offsets`` is required), or an
                ``IntCoords``.
            batched_features: Features aligned with the coordinates. Must be a
                list when the coordinates are a list and a tensor when the
                coordinates are a tensor.
            offsets: Batch boundaries, required for tensor coordinates.
            device: Device forwarded to the coordinate/feature constructors.
            **kwargs: ``tensor_stride`` (or its alias ``stride``) is consumed
                here and applied to the coordinates. Everything else is
                forwarded to ``Geometry.__init__``.
        """
        # Pop both aliases up front so ``stride`` never leaks into the
        # forwarded kwargs when ``tensor_stride`` is also given. As before, a
        # falsy ``tensor_stride`` falls back to ``stride``.
        stride_alias = kwargs.pop("stride", None)
        tensor_stride = kwargs.pop("tensor_stride", None) or stride_alias

        if isinstance(batched_coordinates, list):
            assert isinstance(
                batched_features, list
            ), "If coords is a list, features must be a list too."
            assert len(batched_coordinates) == len(batched_features)
            # Assert all elements in coords and features have same length
            assert all(len(c) == len(f) for c, f in zip(batched_coordinates, batched_features))
            batched_coordinates = IntCoords(
                batched_coordinates, device=device, tensor_stride=tensor_stride
            )
        elif isinstance(batched_coordinates, Tensor):
            assert (
                isinstance(batched_features, Tensor) and offsets is not None
            ), "If coordinate is a tensor, features must be a tensor and offsets must be provided."
            batched_coordinates = IntCoords(
                batched_coordinates,
                offsets=offsets,
                device=device,
                tensor_stride=tensor_stride,
            )
        else:
            # Input is a BatchedDiscreteCoordinates
            if tensor_stride is not None:
                batched_coordinates.set_tensor_stride(tensor_stride)

        if isinstance(batched_features, list):
            batched_features = CatFeatures(batched_features, device=device)
        elif isinstance(batched_features, Tensor):
            batched_features = to_batched_features(
                batched_features, batched_coordinates.offsets, device=device
            )

        Geometry.__init__(self, batched_coordinates, batched_features, **kwargs)

    @classmethod
    def from_dense(
        cls,
        dense_tensor: Float[Tensor, "B C H W"] | Float[Tensor, "B C H W D"],
        dense_tensor_channel_dim: int = 1,
        target_spatial_sparse_tensor: Optional["Voxels"] = None,
        dense_max_coords: Optional[Tuple[int, ...]] = None,
        **kwargs,
    ):
        """Build a sparse voxel collection from a dense batched tensor.

        Without ``target_spatial_sparse_tensor``, every location whose feature
        vector is not all zeros becomes a voxel. With a target, features are
        gathered at the target's coordinates (shifted by the target's minimum
        coordinate), and the target is returned with its features replaced.

        Args:
            dense_tensor: Dense tensor with the batch dimension first.
            dense_tensor_channel_dim: Axis holding the channels.
            target_spatial_sparse_tensor: Optional voxels whose coordinates
                define where to sample.
            dense_max_coords: Optional per-axis bound (after the min shift) for
                target coordinates. Coordinates greater than it receive zero
                features. Only used together with a target.
            **kwargs: Forwarded to the constructor when no target is given.
        """
        # Move the channel dimension to the end unless it is already last.
        # (In the skipped cases moveaxis would be a no-op.)
        if dense_tensor_channel_dim != -1 and dense_tensor.ndim != dense_tensor_channel_dim + 1:
            dense_tensor = dense_tensor.moveaxis(dense_tensor_channel_dim, -1)
        spatial_shape = dense_tensor.shape[1:-1]
        batched_spatial_shape = dense_tensor.shape[:-1]

        # Flatten the spatial dimensions
        flattened_tensor = dense_tensor.flatten(0, -2)

        if target_spatial_sparse_tensor is None:
            # abs sum all elements in the tensor
            abs_sum = torch.abs(dense_tensor).sum(dim=-1, keepdim=False)
            # Find all non-zero elements. Expected to be sorted.
            non_zero_inds = torch.nonzero(abs_sum).int()
            # Convert multi-dimensional indices to flattened indices
            flattened_indices = ravel_multi_index(non_zero_inds, batched_spatial_shape)

            # Use index_select to get the features
            non_zero_feats = torch.index_select(flattened_tensor, 0, flattened_indices)

            offsets = offsets_from_batch_index(non_zero_inds[:, 0])
            return cls(
                batched_coordinates=IntCoords(non_zero_inds[:, 1:], offsets=offsets),
                batched_features=CatFeatures(non_zero_feats, offsets=offsets),
                **kwargs,
            )
        else:
            assert target_spatial_sparse_tensor.num_spatial_dims == len(spatial_shape)
            assert target_spatial_sparse_tensor.batch_size == batched_spatial_shape[0]
            # Use the provided spatial sparse tensor's coordinate only
            batch_indexed_coords = target_spatial_sparse_tensor.batch_indexed_coordinates.clone()
            # subtract the min_coords
            min_coords = target_spatial_sparse_tensor.coordinate_tensor.min(dim=0).values.view(
                1, -1
            )
            batch_indexed_coords[:, 1:] = batch_indexed_coords[:, 1:] - min_coords
            if dense_max_coords is not None:
                # The bound is declared as a Python tuple, and torch comparison
                # operators do not accept tuples. Convert it on the coordinates'
                # device so the broadcasted comparison works.
                max_coords_bound = torch.as_tensor(
                    dense_max_coords,
                    dtype=batch_indexed_coords.dtype,
                    device=batch_indexed_coords.device,
                )
                invalid = (batch_indexed_coords[:, 1:] > max_coords_bound).any(dim=1)
                # Point out-of-range rows at index 0; their features are zeroed below.
                batch_indexed_coords[invalid] = 0
            else:
                sparse_max_coords = batch_indexed_coords[:, 1:].max(dim=0).values
                # This assumes the max_coords are already aligned with the spatial_shape.
                assert torch.all(
                    sparse_max_coords.cpu() < torch.tensor(spatial_shape)
                ), "Max coords must be aligned with the spatial shape."
            # Ravel the coordinates. This assumes the max_coords are already aligned with the spatial_shape.
            flattened_indices = ravel_multi_index(batch_indexed_coords, batched_spatial_shape)
            # use index_select to get the features
            non_zero_feats = torch.index_select(flattened_tensor, 0, flattened_indices)
            if dense_max_coords is not None:
                non_zero_feats[invalid] = 0
            return target_spatial_sparse_tensor.replace(batched_features=non_zero_feats)

    def to_dense(
        self,
        channel_dim: int = 1,
        spatial_shape: Optional[Tuple[int, ...]] = None,
        min_coords: Optional[Tuple[int, ...]] = None,
        max_coords: Optional[Tuple[int, ...]] = None,
    ) -> Float[Tensor, "B C H W D"] | Float[Tensor, "B C H W"]:
        """Scatter the sparse features into a dense zero-initialized tensor.

        The extent is chosen as follows:

        * Neither ``spatial_shape`` nor ``min_coords``: the bounding box of the
          coordinates (coordinates are shifted by their minimum).
        * ``min_coords`` plus ``max_coords`` or ``spatial_shape``: voxels
          outside that window are dropped.
        * Only ``spatial_shape`` (length = number of spatial dims):
          coordinates are used as-is.

        The dense tensor is allocated as ``(B, *spatial_shape, C)``. Channels
        are then moved to ``channel_dim`` unless it is ``-1``.

        Raises:
            ValueError: If ``spatial_shape`` has the wrong length.
        """
        device = self.batched_coordinates.device

        # Get the batch indexed coordinates and features
        batch_indexed_coords = self.batched_coordinates.batch_indexed_coordinates.clone()
        features = self.batched_features.batched_tensor

        # Get the spatial shape.
        # If min_coords and max_coords are provided, assert that spatial_shape matches
        if spatial_shape is None and min_coords is None:
            # Get the min max coordinates
            coords = batch_indexed_coords[:, 1:]
            if coords.shape[0] == 0:  # Handle empty tensor case
                min_coords_tensor = torch.zeros(
                    self.num_spatial_dims, dtype=torch.long, device=device
                )
                spatial_shape_tensor = torch.zeros(
                    self.num_spatial_dims, dtype=torch.long, device=device
                )
            else:
                min_coords_tensor = coords.min(dim=0).values
                max_coords_tensor = coords.max(dim=0).values
                spatial_shape_tensor = max_coords_tensor - min_coords_tensor + 1
            # Shift the coordinates to the min_coords
            batch_indexed_coords[:, 1:] = batch_indexed_coords[:, 1:] - min_coords_tensor.to(
                device
            )
            spatial_shape = tuple(s.item() for s in spatial_shape_tensor)
        elif min_coords is not None:
            # Assert either max_coords or spatial_shape is provided
            assert max_coords is not None or spatial_shape is not None
            # Convert min_coords to tensor
            min_coords_tensor = torch.tensor(min_coords, dtype=torch.int32, device=device)
            if max_coords is None:
                # convert spatial_shape to tensor
                spatial_shape_tensor = torch.tensor(
                    spatial_shape, dtype=torch.int32, device=device
                )
            else:  # both min_coords and max_coords are provided
                # Convert max_coords to tensor
                max_coords_tensor = torch.tensor(max_coords, dtype=torch.int32, device=device)
                assert len(min_coords_tensor) == len(max_coords_tensor) == self.num_spatial_dims
                spatial_shape_tensor = max_coords_tensor - min_coords_tensor + 1
            # Keep only voxels inside [min, min + shape) on every spatial axis.
            mask = torch.ones(batch_indexed_coords.shape[0], dtype=torch.bool, device=device)
            for d in range(1, batch_indexed_coords.shape[1]):
                mask &= (batch_indexed_coords[:, d] >= min_coords_tensor[d - 1].item()) & (
                    batch_indexed_coords[:, d]
                    < min_coords_tensor[d - 1].item() + spatial_shape_tensor[d - 1].item()
                )
            batch_indexed_coords = batch_indexed_coords[mask]
            features = features[mask]
            spatial_shape = tuple(s.item() for s in spatial_shape_tensor)
        elif spatial_shape is not None and len(spatial_shape) == self.coordinate_tensor.shape[1]:
            # Coordinates are used unshifted; presumably the caller guarantees
            # they lie in [0, spatial_shape) — NOTE(review): not checked here.
            pass
        else:
            raise ValueError(
                f"Provided spatial shape {spatial_shape} must be same length as the number of spatial dimensions {self.num_spatial_dims}."
            )

        if isinstance(spatial_shape, Tensor):  # Should be tuple by now
            spatial_shape = spatial_shape.tolist()

        # Create a dense tensor
        dense_tensor = torch.zeros(
            (self.batch_size, *spatial_shape, self.num_channels),
            dtype=self.batched_features.dtype,
            device=self.batched_features.device,
        )

        if batch_indexed_coords.shape[0] > 0:  # Only scatter if there are points
            # Flatten view and scatter (assignment; duplicate coordinates overwrite)
            flattened_indices = ravel_multi_index(
                batch_indexed_coords, (self.batch_size, *spatial_shape)
            )
            dense_tensor.flatten(0, -2)[flattened_indices] = features

        if channel_dim != -1:
            # Put the channel dimension in the specified position and move the rest of the dimensions contiguous
            dense_tensor = dense_tensor.moveaxis(-1, channel_dim)
        return dense_tensor

    def to_point(self, voxel_size: Optional[float] = None) -> "Points":  # noqa: F821
        """Convert voxel coordinates to real-valued point coordinates.

        Coordinates are multiplied by ``voxel_size``, scaled per axis by the
        tensor stride when one is set. Features are shared with the result.
        """
        if voxel_size is None:
            assert (
                self.voxel_size is not None
            ), "Voxel size must be provided or the object must have been initialized with a voxel size to convert to point."
            voxel_size = self.voxel_size

        # tensor stride
        if self.tensor_stride is not None:
            tensor_stride = self.tensor_stride
            # multiply voxel_size by tensor_stride
            voxel_size = torch.Tensor([[voxel_size * s for s in tensor_stride]]).to(self.device)

        from warpconvnet.geometry.types.points import Points

        batched_points = RealCoords(self.coordinate_tensor * voxel_size, self.offsets)
        return Points(batched_points, self.batched_features)

    def sort(self, ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ) -> "Voxels":
        """Return voxels permuted into ``ordering``; ``self`` if already in it.

        The result records ``ordering`` and the encoded ``code`` as extra
        attributes and keeps the tensor stride.
        """
        if ordering == self.ordering:
            return self

        assert isinstance(
            self.batched_features, CatFeatures
        ), "Features must be a CatBatchedFeatures to sort."

        code_result = encode(
            self.coordinate_tensor,
            batch_offsets=self.offsets,
            order=ordering,
            return_perm=True,
        )
        kwargs = self.extra_attributes.copy()
        kwargs["ordering"] = ordering
        kwargs["code"] = code_result.codes
        return self.__class__(
            # Carry the tensor stride over; it lives on the coordinates object
            # and would otherwise be dropped by the rebuild.
            batched_coordinates=IntCoords(
                self.coordinate_tensor[code_result.perm],
                self.offsets,
                tensor_stride=self.tensor_stride,
            ),
            batched_features=CatFeatures(self.feature_tensor[code_result.perm], self.offsets),
            **kwargs,
        )

    def unique(self) -> "Voxels":
        """Return voxels with duplicate coordinates removed (one kept per voxel)."""
        unique_indices, batch_offsets = voxel_downsample_random_indices(
            self.coordinate_tensor,
            self.offsets,
        )
        # Preserve the tensor stride, which is stored on the coordinates object.
        coords = IntCoords(
            self.coordinate_tensor[unique_indices],
            batch_offsets,
            tensor_stride=self.tensor_stride,
        )
        feats = CatFeatures(self.feature_tensor[unique_indices], batch_offsets)
        return self.__class__(coords, feats, **self.extra_attributes)

    @property
    def coordinate_hashmap(self) -> TorchHashTable:
        """Hash table of the coordinates, provided by the coordinates object."""
        return self.batched_coordinates.hashmap

    @property
    def voxel_size(self):
        """Voxel size stored in the extra attributes, or None."""
        return self.extra_attributes.get("voxel_size", None)

    @property
    def ordering(self):
        """Ordering stored in the extra attributes, or None."""
        return self.extra_attributes.get("ordering", None)

    @property
    def stride(self):
        """Alias of ``tensor_stride``."""
        return self.tensor_stride

    @property
    def tensor_stride(self):
        """Tensor stride held by the coordinates object."""
        return self.batched_coordinates.tensor_stride

    def set_tensor_stride(self, tensor_stride: Union[int, Tuple[int, ...]]):
        """Set the tensor stride on the coordinates object in place."""
        self.batched_coordinates.set_tensor_stride(tensor_stride)

    @property
    def batch_indexed_coordinates(self) -> Tensor:
        """Coordinates with a leading batch-index column."""
        return self.batched_coordinates.batch_indexed_coordinates

Grid

Regular dense grid representation that keeps GridCoords and GridFeatures in sync, providing utilities for shape validation, format conversions, and batch-aware initialization.

warpconvnet.geometry.types.grid.Grid

Bases: Geometry

Grid geometry representation that combines coordinates and features.

This class provides a unified interface for grid-based geometries with any memory format, combining grid coordinates with grid features.

Args:

  • batched_coordinates (GridCoords): Coordinate system for the grid
  • batched_features (Union[GridFeatures, Tensor]): Grid features
  • memory_format (GridMemoryFormat): Memory format for the features
  • grid_shape (Tuple[int, int, int], optional): Grid resolution (H, W, D)
  • num_channels (int, optional): Number of feature channels
  • **kwargs: Additional parameters

Source code in warpconvnet/geometry/types/grid.py
class Grid(Geometry):
    """Grid geometry representation that combines coordinates and features.

    This class provides a unified interface for grid-based geometries with any
    memory format, combining grid coordinates with grid features.

    Args:
        batched_coordinates (GridCoords): Coordinate system for the grid
        batched_features (Union[GridFeatures, Tensor]): Grid features. A plain
            tensor is wrapped in ``GridFeatures`` using a clone of the
            coordinate offsets.
        memory_format (GridMemoryFormat, optional): Memory format for the
            features. Required when ``batched_features`` is a tensor; when it
            is a ``GridFeatures`` it must be ``None`` or match its format.
        grid_shape (Tuple[int, int, int], optional): Grid resolution (H, W, D).
            For tensor input it defaults to ``batched_coordinates.grid_shape``.
        num_channels (int, optional): Number of feature channels. For tensor
            input it is inferred from the tensor shape and memory format.
        **kwargs: Additional parameters forwarded to the ``Geometry`` base class
    """

    def __init__(
        self,
        batched_coordinates: GridCoords,
        batched_features: Union[GridFeatures, Tensor],
        memory_format: Optional[GridMemoryFormat] = None,
        grid_shape: Optional[Tuple[int, int, int]] = None,
        num_channels: Optional[int] = None,
        **kwargs,
    ):
        if isinstance(batched_features, Tensor):
            assert (
                memory_format is not None
            ), "Memory format must be provided if features are a tensor"
            if grid_shape is None:
                grid_shape = batched_coordinates.grid_shape

            # If num_channels not provided, infer it from tensor shape and memory format.
            # Factorized formats pack one spatial axis together with the channels
            # on dim 1, so the channel count is that dim divided by the axis size.
            if num_channels is None:
                if memory_format == GridMemoryFormat.b_x_y_z_c:
                    num_channels = batched_features.shape[-1]
                elif memory_format == GridMemoryFormat.b_c_x_y_z:
                    num_channels = batched_features.shape[1]
                elif memory_format == GridMemoryFormat.b_c_z_x_y:
                    num_channels = batched_features.shape[1]
                elif memory_format == GridMemoryFormat.b_zc_x_y:
                    zc = batched_features.shape[1]
                    num_channels = zc // grid_shape[2]
                elif memory_format == GridMemoryFormat.b_xc_y_z:
                    xc = batched_features.shape[1]
                    num_channels = xc // grid_shape[0]
                elif memory_format == GridMemoryFormat.b_yc_x_z:
                    yc = batched_features.shape[1]
                    num_channels = yc // grid_shape[1]
                else:
                    raise ValueError(f"Unsupported memory format: {memory_format}")

            # Create GridFeatures with same offsets as coordinates
            batched_features = GridFeatures(
                batched_features,
                batched_coordinates.offsets.clone(),
                memory_format,
                grid_shape,
                num_channels,
            )
        else:
            assert (
                memory_format is None or memory_format == batched_features.memory_format
            ), f"Memory format must be None or match the GridFeatures memory format: {batched_features.memory_format}. Provided: {memory_format}"

        # Check that the grid is valid
        self.check(batched_coordinates, batched_features)

        # Coordinates and features must describe the same ragged batch layout.
        assert (
            batched_coordinates.offsets == batched_features.offsets
        ).all(), "Coordinate and feature offsets must match"

        # Initialize base class
        super().__init__(batched_coordinates, batched_features, **kwargs)

    def check(self, coords: GridCoords, features: GridFeatures):
        """
        Check if the grid dimensions are consistent.

        Raises:
            AssertionError: if the coordinates are not 3-D, if the number of
                coordinate points differs from the number of feature vectors,
                or if the grid shapes of coordinates and features disagree.
        """
        assert coords.shape[-1] == 3

        num_coords = coords.numel() // 3
        num_features = features.numel() // features.num_channels
        assert (
            num_coords == num_features
        ), f"Number of coordinates ({num_coords}) must match number of features ({num_features})"
        assert (
            coords.grid_shape == features.grid_shape
        ), f"Grid shape ({coords.grid_shape}) must match feature grid shape ({features.grid_shape})"

    @classmethod
    def from_shape(
        cls,
        grid_shape: Tuple[int, int, int],
        num_channels: int,
        memory_format: GridMemoryFormat = GridMemoryFormat.b_x_y_z_c,
        bounds: Optional[Tuple[Tensor, Tensor]] = None,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> "Grid":
        """
        Create a new Grid geometry from a grid shape. The coordinates will be lazily initialized and the features will be created as an empty tensor.

        Args:
            grid_shape: Grid resolution (H, W, D)
            num_channels: Number of feature channels
            memory_format: Memory format for features
            bounds: Min and max bounds for the grid
            batch_size: Number of batches
            device: Device to create tensors on
            dtype: Data type for feature tensors (coordinates are not affected)
            **kwargs: Additional parameters

        Returns:
            Initialized grid geometry
        """
        # Create coordinates. By default, data will be lazily initialized and coordinates will be flattened.
        coords = GridCoords.from_shape(
            grid_shape=grid_shape,
            bounds=bounds,
            batch_size=batch_size,
            device=device,
            flatten=True,
        )

        # Create empty features with same offsets
        features = GridFeatures.create_empty(
            grid_shape=grid_shape,
            num_channels=num_channels,
            batch_size=batch_size,
            memory_format=memory_format,
            device=device,
            dtype=dtype,
        )

        # Make sure offsets match
        assert (
            coords.offsets == features.offsets
        ).all(), "Coordinate and feature offsets must match"

        return cls(coords, features, memory_format, **kwargs)

    @property
    def grid_features(self) -> GridFeatures:
        """Get the grid features."""
        return self.batched_features

    @property
    def grid_coords(self) -> GridCoords:
        """Get the grid coordinates."""
        return self.batched_coordinates

    @property
    def grid_shape(self) -> Tuple[int, int, int]:
        """Get the grid shape (H, W, D)."""
        return self.grid_coords.grid_shape

    @property
    def bounds(self) -> Tuple[Tensor, Tensor]:
        """Get the bounds of the grid."""
        return self.grid_coords.bounds

    @property
    def num_channels(self) -> int:
        """Get the number of feature channels."""
        return self.grid_features.num_channels

    @property
    def memory_format(self) -> GridMemoryFormat:
        """Get the memory format."""
        return self.grid_features.memory_format

    def channel_size(self, memory_format: Optional[GridMemoryFormat] = None):
        """Return the size of the feature tensor's channel axis for a memory format.

        For the non-factorized formats (``b_x_y_z_c``, ``b_c_x_y_z``,
        ``b_c_z_x_y``) this is ``num_channels``. The factorized formats pack
        one spatial axis together with the channels, so the size is
        ``num_channels`` times that axis's resolution.

        Args:
            memory_format: Format to query; defaults to this grid's format.

        Raises:
            ValueError: if ``memory_format`` is not a known grid format.
        """
        if memory_format is None:
            memory_format = self.memory_format
        # b_c_z_x_y was previously missing here even though __init__ and
        # replace() accept it, so channel_size() raised for a valid grid.
        if memory_format in (
            GridMemoryFormat.b_x_y_z_c,
            GridMemoryFormat.b_c_x_y_z,
            GridMemoryFormat.b_c_z_x_y,
        ):
            return self.num_channels
        elif memory_format == GridMemoryFormat.b_xc_y_z:
            return self.num_channels * self.grid_shape[0]
        elif memory_format == GridMemoryFormat.b_yc_x_z:
            return self.num_channels * self.grid_shape[1]
        elif memory_format == GridMemoryFormat.b_zc_x_y:
            return self.num_channels * self.grid_shape[2]
        else:
            raise ValueError(f"Unsupported memory format: {memory_format}")

    def to_memory_format(self, memory_format: GridMemoryFormat) -> "Grid":
        """Convert to a different memory format.

        Returns ``self`` unchanged when the format already matches.
        """
        if memory_format != self.memory_format:
            return self.replace(
                batched_features=self.grid_features.to_memory_format(memory_format),
                memory_format=memory_format,
            )
        return self

    @property
    def shape(self) -> Dict[str, Union[int, Tuple[int, ...]]]:
        """Get the shape information."""
        H, W, D = self.grid_shape
        return {
            "grid_shape": self.grid_shape,
            "batch_size": self.batch_size,
            "num_channels": self.num_channels,
            "total_elements": H * W * D * self.batch_size,
        }

    def to(self, device: torch.device) -> "Grid":
        """Move the geometry to the specified device.

        The result has the same class as ``self`` and keeps its extra
        attributes, matching how other geometries rebuild themselves
        (``self.__class__(..., **self.extra_attributes)``).
        """
        return self.__class__(
            self.grid_coords.to(device),
            self.grid_features.to(device),
            self.memory_format,
            **self.extra_attributes,
        )

    def replace(
        self,
        batched_coordinates: Optional[GridCoords] = None,
        batched_features: Optional[Union[GridFeatures, Tensor]] = None,
        **kwargs,
    ) -> "Grid":
        """Create a new instance with replaced coordinates and/or features.

        Plain tensors are wrapped in ``GridFeatures`` using this grid's offsets,
        memory format and grid shape:

        * 5-D tensors are accepted for ``b_x_y_z_c``, ``b_c_x_y_z`` and
          ``b_c_z_x_y``; every non-batch dimension, including channels, must
          match this grid.
        * 4-D tensors are accepted for the factorized formats; the two spatial
          dimensions must match and the channel count is recomputed from the
          packed channel axis.

        ``GridFeatures`` are used as-is, and ``None`` is forwarded to the base
        class ``replace`` (presumably keeping the current features -- confirm).

        Raises:
            ValueError: for an unsupported memory format / tensor rank combination.
        """
        # Convert the batched_features to a GridFeatures if it is a tensor
        if isinstance(batched_features, Tensor) and batched_features.ndim == 5:
            # Based on the memory format, we have to check the shape of the tensor
            if self.memory_format == GridMemoryFormat.b_x_y_z_c:
                in_H, in_W, in_D, in_C = batched_features.shape[1:5]
                assert in_H == self.grid_shape[0]
                assert in_W == self.grid_shape[1]
                assert in_D == self.grid_shape[2]
                assert in_C == self.num_channels
            elif self.memory_format == GridMemoryFormat.b_c_x_y_z:
                in_C, in_H, in_W, in_D = batched_features.shape[1:5]
                assert in_C == self.num_channels
                assert in_H == self.grid_shape[0]
                assert in_W == self.grid_shape[1]
                assert in_D == self.grid_shape[2]
            elif self.memory_format == GridMemoryFormat.b_c_z_x_y:
                in_C, in_D, in_H, in_W = batched_features.shape[1:5]
                assert in_C == self.num_channels
                assert in_D == self.grid_shape[2]
                assert in_H == self.grid_shape[0]
                assert in_W == self.grid_shape[1]
            else:
                raise ValueError(f"Unsupported memory format: {self.memory_format}")

            batched_features = GridFeatures(
                batched_tensor=batched_features,
                offsets=self.grid_features.offsets,
                memory_format=self.memory_format,
                grid_shape=self.grid_shape,
                num_channels=in_C,
            )
        elif isinstance(batched_features, Tensor) and batched_features.ndim == 4:
            # This is the compressed format
            assert self.memory_format in [
                GridMemoryFormat.b_zc_x_y,
                GridMemoryFormat.b_xc_y_z,
                GridMemoryFormat.b_yc_x_z,
            ], f"Unsupported memory format: {self.memory_format} for feature tensor of shape {batched_features.shape}"
            # Assert that the grid shape is consistent with the feature tensor shape
            # Only the channel dim can change when using .replace()
            # e.g. in_H, in_W, in_D == self.grid_shape[0], self.grid_shape[1], self.grid_shape[2]
            compressed_dim = batched_features.shape[1]  # this is the compressed_dim * channels
            new_channel_dim = None
            if self.memory_format == GridMemoryFormat.b_zc_x_y:
                assert batched_features.shape[2] == self.grid_shape[0]
                assert batched_features.shape[3] == self.grid_shape[1]
                new_channel_dim = compressed_dim // self.grid_shape[2]
            elif self.memory_format == GridMemoryFormat.b_xc_y_z:
                assert batched_features.shape[2] == self.grid_shape[1]
                assert batched_features.shape[3] == self.grid_shape[2]
                new_channel_dim = compressed_dim // self.grid_shape[0]
            elif self.memory_format == GridMemoryFormat.b_yc_x_z:
                assert batched_features.shape[2] == self.grid_shape[0]
                assert batched_features.shape[3] == self.grid_shape[2]
                new_channel_dim = compressed_dim // self.grid_shape[1]
            else:
                raise ValueError(f"Unsupported memory format: {self.memory_format}")

            batched_features = GridFeatures(
                batched_tensor=batched_features,
                offsets=self.grid_features.offsets,
                memory_format=self.memory_format,
                grid_shape=self.grid_shape,
                num_channels=new_channel_dim,
            )
        elif batched_features is None or isinstance(batched_features, GridFeatures):
            # Nothing to convert. None must be accepted because it is the
            # parameter default (e.g. replace(batched_coordinates=...)).
            pass
        else:
            # getattr keeps the error readable for inputs that have no .shape.
            raise ValueError(
                f"Unsupported feature tensor shape: {getattr(batched_features, 'shape', type(batched_features))}"
            )

        return super().replace(batched_coordinates, batched_features, **kwargs)

bounds: Tuple[Tensor, Tensor] property

Get the bounds of the grid.

grid_coords: GridCoords property

Get the grid coordinates.

grid_features: GridFeatures property

Get the grid features.

grid_shape: Tuple[int, int, int] property

Get the grid shape (H, W, D).

memory_format: GridMemoryFormat property

Get the memory format.

num_channels: int property

Get the number of feature channels.

shape: Dict[str, Union[int, Tuple[int, ...]]] property

Get the shape information.

check(coords: GridCoords, features: GridFeatures)

Check if the grid dimensions are consistent

Source code in warpconvnet/geometry/types/grid.py
def check(self, coords: GridCoords, features: GridFeatures):
    """
    Validate that ``coords`` and ``features`` describe the same grid.

    Raises:
        AssertionError: if the coordinates are not 3-D, if the number of
            coordinate points differs from the number of feature vectors, or
            if the two grid shapes disagree.
    """
    assert coords.shape[-1] == 3

    # Both counts are per-point: 3 values per coordinate, num_channels per feature.
    point_count = coords.numel() // 3
    feature_count = features.numel() // features.num_channels
    assert (
        point_count == feature_count
    ), f"Number of coordinates ({point_count}) must match number of features ({feature_count})"

    assert (
        coords.grid_shape == features.grid_shape
    ), f"Grid shape ({coords.grid_shape}) must match feature grid shape ({features.grid_shape})"

from_shape(grid_shape: Tuple[int, int, int], num_channels: int, memory_format: GridMemoryFormat = GridMemoryFormat.b_x_y_z_c, bounds: Optional[Tuple[Tensor, Tensor]] = None, batch_size: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> Grid classmethod

Create a new Grid geometry from a grid shape. The coordinates will be lazily initialized and the features will be created as an empty tensor.

Args:

- grid_shape: Grid resolution (H, W, D)
- num_channels: Number of feature channels
- memory_format: Memory format for features
- bounds: Min and max bounds for the grid
- batch_size: Number of batches
- device: Device to create tensors on
- dtype: Data type for feature tensors
- **kwargs: Additional parameters

Returns: Initialized grid geometry

Source code in warpconvnet/geometry/types/grid.py
@classmethod
def from_shape(
    cls,
    grid_shape: Tuple[int, int, int],
    num_channels: int,
    memory_format: GridMemoryFormat = GridMemoryFormat.b_x_y_z_c,
    bounds: Optional[Tuple[Tensor, Tensor]] = None,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> "Grid":
    """
    Build a Grid of the given resolution with lazily initialized, flattened
    coordinates and an empty feature tensor.

    Args:
        grid_shape: Grid resolution (H, W, D)
        num_channels: Number of feature channels
        memory_format: Memory format for features
        bounds: Min and max bounds for the grid
        batch_size: Number of batches
        device: Device to create tensors on
        dtype: Data type for feature tensors (only the features receive it)
        **kwargs: Additional parameters forwarded to the constructor

    Returns:
        Initialized grid geometry
    """
    grid_coords = GridCoords.from_shape(
        grid_shape=grid_shape,
        bounds=bounds,
        batch_size=batch_size,
        device=device,
        flatten=True,
    )
    grid_feats = GridFeatures.create_empty(
        grid_shape=grid_shape,
        num_channels=num_channels,
        batch_size=batch_size,
        memory_format=memory_format,
        device=device,
        dtype=dtype,
    )

    # Both halves were built independently; their batch layouts must agree.
    assert (
        grid_coords.offsets == grid_feats.offsets
    ).all(), "Coordinate and feature offsets must match"

    return cls(grid_coords, grid_feats, memory_format, **kwargs)

replace(batched_coordinates: Optional[GridCoords] = None, batched_features: Optional[Union[GridFeatures, Tensor]] = None, **kwargs) -> Grid

Create a new instance with replaced coordinates and/or features.

Source code in warpconvnet/geometry/types/grid.py
def replace(
    self,
    batched_coordinates: Optional[GridCoords] = None,
    batched_features: Optional[Union[GridFeatures, Tensor]] = None,
    **kwargs,
) -> "Grid":
    """Create a new instance with replaced coordinates and/or features.

    Plain tensors are wrapped in ``GridFeatures`` using this grid's offsets,
    memory format and grid shape:

    * 5-D tensors are accepted for ``b_x_y_z_c``, ``b_c_x_y_z`` and
      ``b_c_z_x_y``; every non-batch dimension, including channels, must
      match this grid.
    * 4-D tensors are accepted for the factorized formats; the two spatial
      dimensions must match and the channel count is recomputed from the
      packed channel axis.

    ``GridFeatures`` are used as-is, and ``None`` is forwarded to the base
    class ``replace`` (presumably keeping the current features -- confirm).

    Raises:
        ValueError: for an unsupported memory format / tensor rank combination.
    """
    # Convert the batched_features to a GridFeatures if it is a tensor
    if isinstance(batched_features, Tensor) and batched_features.ndim == 5:
        # Based on the memory format, we have to check the shape of the tensor
        if self.memory_format == GridMemoryFormat.b_x_y_z_c:
            in_H, in_W, in_D, in_C = batched_features.shape[1:5]
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
            assert in_D == self.grid_shape[2]
            assert in_C == self.num_channels
        elif self.memory_format == GridMemoryFormat.b_c_x_y_z:
            in_C, in_H, in_W, in_D = batched_features.shape[1:5]
            assert in_C == self.num_channels
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
            assert in_D == self.grid_shape[2]
        elif self.memory_format == GridMemoryFormat.b_c_z_x_y:
            in_C, in_D, in_H, in_W = batched_features.shape[1:5]
            assert in_C == self.num_channels
            assert in_D == self.grid_shape[2]
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
        else:
            raise ValueError(f"Unsupported memory format: {self.memory_format}")

        batched_features = GridFeatures(
            batched_tensor=batched_features,
            offsets=self.grid_features.offsets,
            memory_format=self.memory_format,
            grid_shape=self.grid_shape,
            num_channels=in_C,
        )
    elif isinstance(batched_features, Tensor) and batched_features.ndim == 4:
        # This is the compressed format
        assert self.memory_format in [
            GridMemoryFormat.b_zc_x_y,
            GridMemoryFormat.b_xc_y_z,
            GridMemoryFormat.b_yc_x_z,
        ], f"Unsupported memory format: {self.memory_format} for feature tensor of shape {batched_features.shape}"
        # Assert that the grid shape is consistent with the feature tensor shape
        # Only the channel dim can change when using .replace()
        # e.g. in_H, in_W, in_D == self.grid_shape[0], self.grid_shape[1], self.grid_shape[2]
        compressed_dim = batched_features.shape[1]  # this is the compressed_dim * channels
        new_channel_dim = None
        if self.memory_format == GridMemoryFormat.b_zc_x_y:
            assert batched_features.shape[2] == self.grid_shape[0]
            assert batched_features.shape[3] == self.grid_shape[1]
            new_channel_dim = compressed_dim // self.grid_shape[2]
        elif self.memory_format == GridMemoryFormat.b_xc_y_z:
            assert batched_features.shape[2] == self.grid_shape[1]
            assert batched_features.shape[3] == self.grid_shape[2]
            new_channel_dim = compressed_dim // self.grid_shape[0]
        elif self.memory_format == GridMemoryFormat.b_yc_x_z:
            assert batched_features.shape[2] == self.grid_shape[0]
            assert batched_features.shape[3] == self.grid_shape[2]
            new_channel_dim = compressed_dim // self.grid_shape[1]
        else:
            raise ValueError(f"Unsupported memory format: {self.memory_format}")

        batched_features = GridFeatures(
            batched_tensor=batched_features,
            offsets=self.grid_features.offsets,
            memory_format=self.memory_format,
            grid_shape=self.grid_shape,
            num_channels=new_channel_dim,
        )
    elif batched_features is None or isinstance(batched_features, GridFeatures):
        # Nothing to convert. None must be accepted because it is the
        # parameter default (e.g. replace(batched_coordinates=...)).
        pass
    else:
        # getattr keeps the error readable for inputs that have no .shape.
        raise ValueError(
            f"Unsupported feature tensor shape: {getattr(batched_features, 'shape', type(batched_features))}"
        )

    return super().replace(batched_coordinates, batched_features, **kwargs)

to(device: torch.device) -> Grid

Move the geometry to the specified device.

Source code in warpconvnet/geometry/types/grid.py
def to(self, device: torch.device) -> "Grid":
    """Move the geometry to the specified device.

    The result has the same class as ``self`` and keeps its extra attributes,
    matching how other geometries rebuild themselves
    (``self.__class__(..., **self.extra_attributes)``). Previously this always
    built a plain ``Grid`` and dropped the extra attributes.
    """
    return self.__class__(
        self.grid_coords.to(device),
        self.grid_features.to(device),
        self.memory_format,
        **self.extra_attributes,
    )

to_memory_format(memory_format: GridMemoryFormat) -> Grid

Convert to a different memory format.

Source code in warpconvnet/geometry/types/grid.py
def to_memory_format(self, memory_format: GridMemoryFormat) -> "Grid":
    """Return this grid with its features laid out in ``memory_format``.

    When the requested format already matches, ``self`` is returned as-is.
    Otherwise the features are converted and a new grid is built via
    ``replace``.
    """
    if memory_format == self.memory_format:
        return self
    converted = self.grid_features.to_memory_format(memory_format)
    return self.replace(batched_features=converted, memory_format=memory_format)

Factor grid

Container that bundles multiple Grid instances with distinct factorized memory formats so that FIGConvNet layers can operate on complementary spatial perspectives.

warpconvnet.geometry.types.factor_grid.FactorGrid dataclass

A group of grid geometries with different factorized memory formats.

This class implements the core concept of FIGConvNet where the 3D space is represented as multiple factorized 2D grids with different memory formats.

Args:

- grids: List of Grid objects, each using a distinct factorized memory format (b_zc_x_y, b_xc_y_z or b_yc_x_z)

Source code in warpconvnet/geometry/types/factor_grid.py
@dataclass
class FactorGrid:
    """A group of grid geometries with different factorized memory formats.

    This class implements the core concept of FIGConvNet where the 3D space
    is represented as multiple factorized 2D grids with different memory formats.

    Args:
        grids: List of Grid objects. Each must use a distinct factorized
            memory format (``b_zc_x_y``, ``b_xc_y_z`` or ``b_yc_x_z``). All
            grids must share the same batch size and number of channels. Their
            grid shapes may differ.
        **kwargs: Extra attributes stored on the instance. They are propagated
            by ``to`` and ``__add__``.
    """

    grids: List[Grid]
    _extra_attributes: Dict[str, Any] = field(default_factory=dict, init=True)  # Store extra args

    def __init__(self, grids: List[Grid], **kwargs):
        self.grids = grids

        # Validate we have at least one geometry
        assert len(grids) > 0, "At least one geometry must be provided"

        batch_size = grids[0].batch_size
        num_channels = grids[0].num_channels

        # Verify all geometries have the same batch size and channels. Grid
        # shapes are intentionally not compared: each factorized grid may have
        # its own resolution (see create_from_grid_shape).
        for geo in grids:
            assert geo.batch_size == batch_size, "All geometries must have the same batch size"
            assert (
                geo.num_channels == num_channels
            ), "All geometries must have the same number of channels"

            # Ensure each geometry uses a factorized format
            assert geo.memory_format in [
                GridMemoryFormat.b_zc_x_y,
                GridMemoryFormat.b_xc_y_z,
                GridMemoryFormat.b_yc_x_z,
            ], f"Expected factorized format, got {geo.memory_format}"

        # Check for memory format duplicates
        memory_formats = [geo.memory_format for geo in grids]
        assert len(memory_formats) == len(
            set(memory_formats)
        ), "Each geometry must have a unique memory format"

        # Extra arguments for subclasses
        # First check _extra_attributes in kwargs. This happens when we use dataclasses.replace
        if "_extra_attributes" in kwargs:
            attr = kwargs.pop("_extra_attributes")
            assert isinstance(attr, dict), f"_extra_attributes must be a dictionary, got {attr}"
            # Update kwargs
            for k, v in attr.items():
                kwargs[k] = v
        self._extra_attributes = kwargs

    @classmethod
    def create_from_grid_shape(
        cls,
        grid_shapes: List[Tuple[int, int, int]],
        num_channels: int,
        memory_formats: Optional[List[Union[GridMemoryFormat, str]]] = None,
        bounds: Optional[Tuple[Tensor, Tensor]] = None,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "FactorGrid":
        """Create a new factorized grid geometry with initialized geometries.

        Args:
            grid_shapes: List of grid resolutions (H, W, D), one per memory format
            num_channels: Number of feature channels
            memory_formats: List of factorized formats to use (enum members or
                their string values). Defaults to
                ``["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]``.
            bounds: Min and max bounds for the grid
            batch_size: Number of batches
            device: Device to create tensors on
            dtype: Data type for feature tensors

        Returns:
            Initialized factorized grid geometry
        """
        # A None default avoids sharing one mutable list across all calls.
        if memory_formats is None:
            memory_formats = ["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]
        assert len(grid_shapes) == len(
            memory_formats
        ), "grid_shapes and memory_formats must have the same length"
        for grid_shape in grid_shapes:
            assert (
                isinstance(grid_shape, tuple) and len(grid_shape) == 3
            ), f"grid_shape: {grid_shape} must be a tuple of 3 integers."
        # Create one grid per (shape, format) pair
        geometries = []
        for grid_shape, memory_format in zip(grid_shapes, memory_formats):
            if isinstance(memory_format, str):
                memory_format = GridMemoryFormat(memory_format)
            geometry = Grid.from_shape(
                grid_shape,
                num_channels,
                memory_format=memory_format,
                bounds=bounds,
                batch_size=batch_size,
                device=device,
                dtype=dtype,
            )
            geometries.append(geometry)

        return cls(geometries)

    @property
    def batch_size(self) -> int:
        """Return the batch size of the geometries."""
        return self.grids[0].batch_size

    @property
    def num_channels(self) -> int:
        """Return the number of channels in the geometries."""
        return self.grids[0].num_channels

    @property
    def device(self) -> torch.device:
        """Return the device of the geometries."""
        return self.grids[0].device

    def to(self, device: torch.device) -> "FactorGrid":
        """Move all geometries to the specified device.

        The result keeps this object's class and extra attributes.
        """
        return self.__class__(
            [geo.to(device) for geo in self.grids], **self._extra_attributes
        )

    def __getitem__(self, idx: int) -> Grid:
        """Get a specific geometry from the group."""
        return self.grids[idx]

    def __len__(self) -> int:
        """Get the number of geometries in the group."""
        return len(self.grids)

    def __iter__(self):
        """Iterate over the grids."""
        return iter(self.grids)

    def __repr__(self) -> str:
        """String representation of the FactorGrid."""
        body = "".join(f"\n\t{grid}" for grid in self.grids)
        return f"FactorGrid({body}\n)"

    def __add__(self, other: "FactorGrid") -> "FactorGrid":
        """Add two FactorGrid objects together element-wise.

        Grids are paired by position, so both operands must list the same
        memory formats in the same order. The result keeps this object's
        class and extra attributes.
        """
        assert len(self) == len(
            other
        ), f"FactorGrid lengths must match: {len(self)} != {len(other)}"
        new_grids = []
        for grid_a, grid_b in zip(self.grids, other.grids):
            # Pairing is positional; adding tensors of different factorized
            # layouts would be meaningless (or could silently broadcast).
            assert (
                grid_a.memory_format == grid_b.memory_format
            ), f"Memory formats must match: {grid_a.memory_format} != {grid_b.memory_format}"
            # Add features together using Grid.replace()
            new_features = (
                grid_a.grid_features.batched_tensor + grid_b.grid_features.batched_tensor
            )
            new_grid = grid_a.replace(batched_features=new_features)
            new_grids.append(new_grid)
        return self.__class__(new_grids, **self._extra_attributes)

    def get_by_format(self, memory_format: GridMemoryFormat) -> Optional[Grid]:
        """Get a geometry with the specified memory format.

        Args:
            memory_format: The memory format to look for

        Returns:
            The geometry with the requested format, or None if not found
        """
        for geo in self.grids:
            if geo.memory_format == memory_format:
                return geo
        return None

    @property
    def shapes(self) -> List[Dict[str, Union[int, Tuple[int, ...]]]]:
        """Get shape information for all geometries."""
        return [geo.shape for geo in self.grids]

batch_size: int property

Return the batch size of the geometries.

device: torch.device property

Return the device of the geometries.

num_channels: int property

Return the number of channels in the geometries.

shapes: List[Dict[str, Union[int, Tuple[int, ...]]]] property

Get shape information for all geometries.

create_from_grid_shape(grid_shapes: List[Tuple[int, int, int]], num_channels: int, memory_formats: List[Union[GridMemoryFormat, str]] = ['b_zc_x_y', 'b_xc_y_z', 'b_yc_x_z'], bounds: Optional[Tuple[Tensor, Tensor]] = None, batch_size: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> FactorGrid classmethod

Create a new factorized grid geometry with initialized geometries.

Args:

- grid_shapes: List of grid resolutions (H, W, D), one per memory format
- num_channels: Number of feature channels
- memory_formats: List of factorized formats to use
- bounds: Min and max bounds for the grid
- batch_size: Number of batches
- device: Device to create tensors on
- dtype: Data type for feature tensors

Returns: Initialized factorized grid geometry

Source code in warpconvnet/geometry/types/factor_grid.py
@classmethod
def create_from_grid_shape(
    cls,
    grid_shapes: List[Tuple[int, int, int]],
    num_channels: int,
    memory_formats: Optional[List[Union[GridMemoryFormat, str]]] = None,
    bounds: Optional[Tuple[Tensor, Tensor]] = None,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> "FactorGrid":
    """Create a new factorized grid geometry with initialized geometries.

    Args:
        grid_shapes: List of grid resolutions (H, W, D), one per memory format
        num_channels: Number of feature channels
        memory_formats: List of factorized formats to use (enum members or
            their string values). Defaults to
            ``["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]``.
        bounds: Min and max bounds for the grid
        batch_size: Number of batches
        device: Device to create tensors on
        dtype: Data type for feature tensors

    Returns:
        Initialized factorized grid geometry
    """
    # A None default avoids sharing one mutable list across all calls.
    if memory_formats is None:
        memory_formats = ["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]
    assert len(grid_shapes) == len(
        memory_formats
    ), "grid_shapes and memory_formats must have the same length"
    for grid_shape in grid_shapes:
        assert (
            isinstance(grid_shape, tuple) and len(grid_shape) == 3
        ), f"grid_shape: {grid_shape} must be a tuple of 3 integers."
    # Create one grid per (shape, format) pair
    geometries = []
    for grid_shape, memory_format in zip(grid_shapes, memory_formats):
        if isinstance(memory_format, str):
            memory_format = GridMemoryFormat(memory_format)
        geometry = Grid.from_shape(
            grid_shape,
            num_channels,
            memory_format=memory_format,
            bounds=bounds,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
        )
        geometries.append(geometry)

    return cls(geometries)

get_by_format(memory_format: GridMemoryFormat) -> Optional[Grid]

Get a geometry with the specified memory format.

Args: memory_format: The memory format to look for

Returns: The geometry with the requested format, or None if not found

Source code in warpconvnet/geometry/types/factor_grid.py
def get_by_format(self, memory_format: GridMemoryFormat) -> Optional[Grid]:
    """Look up the grid stored with ``memory_format``.

    Args:
        memory_format: Memory format to search for.

    Returns:
        The first grid whose ``memory_format`` equals the argument, or ``None``
        when no grid uses that format.
    """
    matches = (grid for grid in self.grids if grid.memory_format == memory_format)
    return next(matches, None)

to(device: torch.device) -> FactorGrid

Move all geometries to the specified device.

Source code in warpconvnet/geometry/types/factor_grid.py
def to(self, device: torch.device) -> "FactorGrid":
    """Move all geometries to the specified device.

    The result keeps this object's class and extra attributes. Previously a
    plain ``FactorGrid`` was built and ``_extra_attributes`` were dropped.
    """
    return self.__class__(
        [geo.to(device) for geo in self.grids], **self._extra_attributes
    )