Geometry
warpconvnet.geometry
Geometry containers
Every geometry type in WarpConvNet wraps a Coords instance and a Features
instance that share the same ragged-batch metadata.
`warpconvnet.geometry.base.coords.Coords` stores concatenated coordinates plus an `offsets` vector marking where each example begins. `warpconvnet.geometry.base.features.Features` (and the `CatFeatures` and `PadFeatures` specializations) stores feature tensors that obey the same offsets, so coordinates and features always stay aligned. `warpconvnet.geometry.base.geometry.Geometry` wires the pair together, validates their shapes, and exposes device/dtype utilities with AMP-aware accessors.
This shared contract lets the point-cloud, voxel, and grid subclasses reuse the same batching logic instead of duplicating it. See Batched coordinate layout for a deeper explanation of how concatenated tensors and offsets interact.
Types
WarpConvNet ships several geometry containers that unify coordinate systems with their associated features. Use these types as the canonical interfaces for points, voxels, dense grids, and FIGConvNet factor grids.
Points
Flexible point-cloud geometry supporting ragged batches, feature paddings, and neighbor search utilities for sparse convolution modules.
warpconvnet.geometry.types.points.Points
Bases: Geometry
Interface class for collections of points.
A point collection is a set of points in a 1-D (line), 2-D (plane), 3-D (space), or 4-D (space-time) geometric space.
Source code in warpconvnet/geometry/types/points.py
class Points(Geometry):
    """
    Interface class for collections of points

    A point collection is a set of points in a geometric space
    (dim=1 (line), 2 (plane), 3 (space), 4 (space-time)).

    Coordinates are held in a ``RealCoords`` and features in a ``CatFeatures`` or
    ``PadFeatures``; features given as a tensor are batched with the coordinate offsets.
    """

    def __init__(
        self,
        batched_coordinates: (
            List[Float[Tensor, "N 3"]] | Float[Tensor, "N 3"] | RealCoords
        ),  # noqa: F722,F821
        batched_features: (
            List[Float[Tensor, "N C"]]
            | Float[Tensor, "N C"]
            | Float[Tensor, "B M C"]
            | CatFeatures
            | PadFeatures
        ),  # noqa: F722,F821
        offsets: Optional[Int[Tensor, "B + 1"]] = None,  # noqa: F722,F821
        device: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize a point collection with coordinates and features.

        Args:
            batched_coordinates: Per-example coordinate tensors, one concatenated
                coordinate tensor (``offsets`` required), or a ``RealCoords``.
            batched_features: Per-example feature tensors (required when the
                coordinates are a list), a feature tensor (required when the
                coordinates are a tensor), or an existing ``CatFeatures`` /
                ``PadFeatures``.
            offsets: Batch offsets used when ``batched_coordinates`` is a tensor.
            device: Optional device forwarded to the coordinate/feature constructors.
            **kwargs: Extra attributes forwarded to ``Geometry.__init__``.
        """
        if isinstance(batched_coordinates, list):
            assert isinstance(
                batched_features, list
            ), "If coords is a list, features must be a list too."
            assert len(batched_coordinates) == len(batched_features)
            # Assert all elements in coords and features have same length
            assert all(
                len(c) == len(f) for c, f in zip(batched_coordinates, batched_features)
            ), "All elements in coords and features must have same length"
            batched_coordinates = RealCoords(batched_coordinates, device=device)
        elif isinstance(batched_coordinates, Tensor):
            assert (
                isinstance(batched_features, Tensor) and offsets is not None
            ), "If coordinate is a tensor, features must be a tensor and offsets must be provided."
            batched_coordinates = RealCoords(batched_coordinates, offsets=offsets, device=device)

        if isinstance(batched_features, list):
            batched_features = CatFeatures(batched_features, device=device)
        elif isinstance(batched_features, Tensor):
            # Batch a raw feature tensor with the (possibly just built) coordinate offsets.
            batched_features = to_batched_features(
                batched_features, batched_coordinates.offsets, device=device
            )

        Geometry.__init__(
            self,
            batched_coordinates,
            batched_features,
            **kwargs,
        )

    def sort(
        self,
        voxel_size: float,
        ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
    ):
        """
        Sort the points according to the ordering provided.

        Coordinates are quantized with ``floor(coords / voxel_size)`` before encoding, so
        points that fall in the same voxel end up in an unspecified relative order.

        Args:
            voxel_size: Quantization step used to compute the sort codes.
            ordering: Space-filling-curve ordering passed to ``encode``.

        Returns:
            A new instance of the same class with permuted coordinates and features; the
            original batch offsets are reused and ``ordering`` is recorded.
        """
        # Warp uses int32 so only 10 bits per coordinate supported. Thus max 1024.
        assert self.device.type != "cpu", "Sorting is only supported on GPU"
        result = encode(
            torch.floor(self.coordinate_tensor / voxel_size).int(),
            batch_offsets=self.offsets,
            order=ordering,
            return_perm=True,
        )
        # Copy the extra attributes but drop any cached neighbor search: the cache is
        # keyed by batch offsets, which are unchanged here, while the point order is not,
        # so cached neighbor indices would refer to the wrong points.
        kwargs = {k: v for k, v in self.extra_attributes.items() if k != "_cache"}
        kwargs["ordering"] = ordering
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[result.perm],
                offsets=self.offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor[result.perm],
                offsets=self.offsets,
            ),
            **kwargs,
        )

    def voxel_downsample(
        self,
        voxel_size: float,
        reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.RANDOM,
    ) -> "Points":
        """
        Voxel downsample the coordinates.

        With the ``RANDOM`` reduction (enum member or its string value) one point per
        occupied voxel is kept, as selected by ``voxel_downsample_random_indices``.
        Any other reduction combines the features of all points in a voxel with
        ``row_reduction``; a single representative coordinate is kept per voxel.

        Args:
            voxel_size: Edge length of the voxels used for grouping.
            reduction: Feature reduction applied per voxel.

        Returns:
            A new instance of the same class with ``voxel_size`` recorded in its extra
            attributes.
        """
        assert self.device.type != "cpu", "Voxel downsample is only supported on GPU"
        # Work on a copy so this instance's attributes are never mutated, and drop any
        # cached neighbor search, which belongs to the original (denser) point set.
        extra_args = {k: v for k, v in self.extra_attributes.items() if k != "_cache"}
        extra_args["voxel_size"] = voxel_size
        assert isinstance(
            self.batched_features, CatFeatures
        ), "Voxel downsample is only supported for CatBatchedFeatures"

        # Accept both the enum member and its string value (see REDUCTION_TYPES_STR).
        if reduction in (REDUCTIONS.RANDOM, REDUCTIONS.RANDOM.value):
            to_unique_indices, unique_offsets = voxel_downsample_random_indices(
                batched_points=self.coordinate_tensor,
                offsets=self.offsets,
                voxel_size=voxel_size,
            )
            return self.__class__(
                batched_coordinates=RealCoords(
                    batched_tensor=self.coordinate_tensor[to_unique_indices],
                    offsets=unique_offsets,
                ),
                batched_features=CatFeatures(
                    batched_tensor=self.feature_tensor[to_unique_indices],
                    offsets=unique_offsets,
                ),
                **extra_args,
            )

        (
            _batch_indexed_down_coords,
            unique_offsets,
            to_csr_indices,
            to_csr_offsets,
            to_unique,
        ) = voxel_downsample_csr_mapping(
            batched_points=self.coordinate_tensor,
            offsets=self.offsets,
            voxel_size=voxel_size,
        )

        # The CSR mapping groups point indices per voxel; reduce features over each group.
        voxel_groups = RealSearchResult(to_csr_indices, to_csr_offsets)
        down_features = row_reduction(
            self.feature_tensor,
            voxel_groups.neighbor_row_splits,
            reduction,
        )

        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[to_unique.to_unique_indices],
                offsets=unique_offsets,
            ),
            batched_features=CatFeatures(batched_tensor=down_features, offsets=unique_offsets),
            **extra_args,
        )

    def random_downsample(self, num_sample_points: int) -> "Points":
        """
        Downsample the coordinates to the specified number of points.

        If the batch size is N, the total number of output points is N * num_sample_points.
        Coordinates and features are gathered with the same indices and stay aligned.

        Args:
            num_sample_points: Number of points sampled per example.

        Returns:
            A new instance of the same class holding the sampled points.
        """
        sampled_indices, sample_offsets = random_sample_per_batch(self.offsets, num_sample_points)
        # Drop any cached neighbor search: it was computed for the original points and its
        # key (batch offsets) can coincide with the sampled offsets.
        extra_args = {k: v for k, v in self.extra_attributes.items() if k != "_cache"}
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[sampled_indices],
                offsets=sample_offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor[sampled_indices],
                offsets=sample_offsets,
            ),
            **extra_args,
        )

    def contiguous(self) -> "Points":
        """Ensure coordinates and features are contiguous in memory.

        This is important for memory access patterns and can improve
        performance for operations that require contiguous memory.

        Returns:
            Points: ``self`` if both tensors are already contiguous, otherwise a new
            instance with contiguous copies (extra attributes are carried over).
        """
        if self.coordinate_tensor.is_contiguous() and self.feature_tensor.is_contiguous():
            return self
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor.contiguous(),
                offsets=self.offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor.contiguous(),
                offsets=self.offsets,
            ),
            **self.extra_attributes,
        )

    def neighbors(
        self,
        search_args: RealSearchConfig,
        query_coords: Optional["Coords"] = None,
    ) -> RealSearchResult:
        """
        Returns CSR format neighbor indices.

        Args:
            search_args: Neighbor search configuration (also part of the cache key).
            query_coords: Query coordinates; defaults to this collection's coordinates.

        Returns:
            The neighbor search result, taken from the per-instance cache when available.
        """
        if query_coords is None:
            query_coords = self.batched_coordinates
        assert isinstance(query_coords, Coords), "query_coords must be BatchedCoordinates"

        # cache the neighbor search result
        # NOTE(review): the cache key is (search_args, self.offsets, query offsets) and
        # does not include the query coordinate values, so two different query sets with
        # identical offsets share an entry — confirm callers never rely on that.
        if self.cache is not None:
            neighbor_search_result = self.cache.get(
                search_args, self.offsets, query_coords.offsets
            )
            if neighbor_search_result is not None:
                return neighbor_search_result

        neighbor_search_result = neighbor_search(
            self.coordinate_tensor,
            self.offsets,
            query_coords.batched_tensor,
            query_coords.offsets,
            search_args,
        )
        if self.cache is None:
            # Lazily create the cache on first use.
            self._extra_attributes["_cache"] = RealSearchCache()
        self.cache.put(search_args, self.offsets, query_coords.offsets, neighbor_search_result)
        return neighbor_search_result

    @property
    def voxel_size(self):
        """Voxel size recorded by ``voxel_downsample`` (or passed at construction), else None."""
        return self._extra_attributes.get("voxel_size", None)

    @property
    def ordering(self):
        """Ordering recorded by ``sort`` (or passed at construction), else None."""
        return self._extra_attributes.get("ordering", None)

    @classmethod
    def from_list_of_coordinates(
        cls,
        coordinates: List[Float[Tensor, "N 3"]],
        features: Optional[List[Float[Tensor, "N C"]]] = None,
        encoding_channels: Optional[int] = None,
        encoding_range: Optional[Tuple[float, float]] = None,
        encoding_dim: Optional[int] = -1,
    ):
        """
        Create a point collection from a list of coordinates.

        Args:
            coordinates: Per-example coordinate tensors. A stacked tensor is also
                accepted and is split along its first dimension.
            features: Optional per-example feature tensors. When omitted, features are
                generated with ``sinusoidal_encoding`` of each coordinate tensor.
            encoding_channels: Channels for the sinusoidal encoding (required when
                ``features`` is None).
            encoding_range: Range for the sinusoidal encoding (required when
                ``features`` is None).
            encoding_dim: Dimension argument forwarded to ``sinusoidal_encoding``.

        Returns:
            A new instance of ``cls``.
        """
        # if the input is a tensor, expand it to a list of tensors
        if isinstance(coordinates, Tensor):
            coordinates = list(coordinates)  # iterates over the first dimension

        if features is None:
            assert (
                encoding_channels is not None
            ), "Encoding channels must be provided if features are not provided"
            assert (
                encoding_range is not None
            ), "Encoding range must be provided if features are not provided"
            features = [
                sinusoidal_encoding(coords, encoding_channels, encoding_range, encoding_dim)
                for coords in coordinates
            ]

        batched_coordinates = RealCoords(coordinates)
        batched_features = CatFeatures(features)
        return cls(batched_coordinates, batched_features)

    def to_voxels(
        self,
        voxel_size: float,
        reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
    ) -> "Voxels":
        """
        Convert the point collection to a spatially sparse tensor.

        Delegates to ``points_to_voxels`` with the given voxel size and reduction.
        """
        return points_to_voxels(self, voxel_size, reduction)
contiguous() -> Points
Ensure coordinates and features are contiguous in memory.
This matters for memory access patterns and can improve performance for operations that require contiguous memory. If both tensors are already contiguous, the instance itself is returned.
Returns:
    Points: A new Points instance with contiguous tensors (or self when no copy is needed).
Source code in warpconvnet/geometry/types/points.py
def contiguous(self) -> "Points":
    """Return this collection with memory-contiguous coordinate and feature tensors.

    Contiguous storage gives better memory access patterns and is required by some
    operations.

    Returns:
        Points: ``self`` when nothing needs copying, otherwise a new instance of the same
        class built from contiguous copies (offsets and extra attributes carried over).
    """
    already_contiguous = (
        self.coordinate_tensor.is_contiguous() and self.feature_tensor.is_contiguous()
    )
    if already_contiguous:
        return self

    coords = RealCoords(
        batched_tensor=self.coordinate_tensor.contiguous(),
        offsets=self.offsets,
    )
    feats = CatFeatures(
        batched_tensor=self.feature_tensor.contiguous(),
        offsets=self.offsets,
    )
    return self.__class__(batched_coordinates=coords, batched_features=feats, **self.extra_attributes)
from_list_of_coordinates(coordinates: List[Float[Tensor, 'N 3']], features: Optional[List[Float[Tensor, 'N C']]] = None, encoding_channels: Optional[int] = None, encoding_range: Optional[Tuple[float, float]] = None, encoding_dim: Optional[int] = -1)
classmethod
Create a point collection from a list of coordinates. If features are omitted, they are generated from the coordinates with a sinusoidal encoding configured by encoding_channels, encoding_range, and encoding_dim.
Source code in warpconvnet/geometry/types/points.py
@classmethod
def from_list_of_coordinates(
    cls,
    coordinates: List[Float[Tensor, "N 3"]],
    features: Optional[List[Float[Tensor, "N C"]]] = None,
    encoding_channels: Optional[int] = None,
    encoding_range: Optional[Tuple[float, float]] = None,
    encoding_dim: Optional[int] = -1,
):
    """
    Create a point collection from a list of coordinates.

    Args:
        coordinates: Per-example coordinate tensors. A stacked tensor is also accepted
            and is split along its first dimension.
        features: Optional per-example feature tensors. When omitted, features are
            generated with ``sinusoidal_encoding`` of each coordinate tensor.
        encoding_channels: Channels for the sinusoidal encoding (required when
            ``features`` is None).
        encoding_range: Range for the sinusoidal encoding (required when ``features``
            is None).
        encoding_dim: Dimension argument forwarded to ``sinusoidal_encoding``.

    Returns:
        A new instance of ``cls``.
    """
    # if the input is a tensor, expand it to a list of tensors
    if isinstance(coordinates, Tensor):
        coordinates = list(coordinates)  # iterates over the first dimension

    if features is None:
        assert (
            encoding_channels is not None
        ), "Encoding channels must be provided if features are not provided"
        assert (
            encoding_range is not None
        ), "Encoding range must be provided if features are not provided"
        # Use a distinct loop name so the comprehension does not shadow ``coordinates``.
        features = [
            sinusoidal_encoding(coords, encoding_channels, encoding_range, encoding_dim)
            for coords in coordinates
        ]

    batched_coordinates = RealCoords(coordinates)
    batched_features = CatFeatures(features)
    return cls(batched_coordinates, batched_features)
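neighbors(search_args: RealSearchConfig, query_coords: Optional[Coords] = None) -> RealSearchResult
Return neighbor indices in CSR format. Results are cached on the instance, keyed by the search configuration and the batch offsets of the points and queries.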
Source code in warpconvnet/geometry/types/points.py
def neighbors(
    self,
    search_args: RealSearchConfig,
    query_coords: Optional["Coords"] = None,
) -> RealSearchResult:
    """
    Return CSR-format neighbor indices of this point set for the given queries.

    Args:
        search_args: Neighbor search configuration; also part of the cache key.
        query_coords: Query coordinates; defaults to this collection's own coordinates.

    Returns:
        The neighbor search result, served from the per-instance cache when an entry
        for the same configuration and batch offsets exists.
    """
    queries = self.batched_coordinates if query_coords is None else query_coords
    assert isinstance(queries, Coords), "query_coords must be BatchedCoordinates"

    # Try the cache first (keyed by search_args and both offset vectors).
    if self.cache is not None:
        cached = self.cache.get(search_args, self.offsets, queries.offsets)
        if cached is not None:
            return cached

    result = neighbor_search(
        self.coordinate_tensor,
        self.offsets,
        queries.batched_tensor,
        queries.offsets,
        search_args,
    )

    # Lazily create the cache on first use, then store the fresh result.
    if self.cache is None:
        self._extra_attributes["_cache"] = RealSearchCache()
    self.cache.put(search_args, self.offsets, queries.offsets, result)
    return result
random_downsample(num_sample_points: int) -> Points
Randomly sample num_sample_points points from each example.
If the batch size is N, the total number of output points is N * num_sample_points.
Source code in warpconvnet/geometry/types/points.py
def random_downsample(self, num_sample_points: int) -> "Points":
    """
    Downsample the coordinates to the specified number of points.

    If the batch size is N, the total number of output points is N * num_sample_points.
    Coordinates and features are gathered with the same indices and stay aligned.

    Args:
        num_sample_points: Number of points sampled per example.

    Returns:
        A new instance of the same class holding the sampled points.
    """
    sampled_indices, sample_offsets = random_sample_per_batch(self.offsets, num_sample_points)
    # Drop any cached neighbor search: it was computed for the original points and its key
    # (batch offsets) can coincide with the sampled offsets.
    extra_args = {k: v for k, v in self.extra_attributes.items() if k != "_cache"}
    return self.__class__(
        batched_coordinates=RealCoords(
            batched_tensor=self.coordinate_tensor[sampled_indices],
            offsets=sample_offsets,
        ),
        batched_features=CatFeatures(
            batched_tensor=self.feature_tensor[sampled_indices],
            offsets=sample_offsets,
        ),
        **extra_args,
    )
sort(voxel_size: float, ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ)
Sort the points according to the ordering provided. The voxel size defines the smallest discretization: points that fall in the same voxel end up in an arbitrary relative order.
Source code in warpconvnet/geometry/types/points.py
def sort(
    self,
    voxel_size: float,
    ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
):
    """
    Sort the points according to the ordering provided.

    Coordinates are quantized with ``floor(coords / voxel_size)`` before encoding, so points
    that fall in the same voxel end up in an unspecified relative order.

    Args:
        voxel_size: Quantization step used to compute the sort codes.
        ordering: Space-filling-curve ordering passed to ``encode``.

    Returns:
        A new instance of the same class with permuted coordinates and features; the
        original batch offsets are reused and ``ordering`` is recorded.
    """
    # Warp uses int32 so only 10 bits per coordinate supported. Thus max 1024.
    assert self.device.type != "cpu", "Sorting is only supported on GPU"
    result = encode(
        torch.floor(self.coordinate_tensor / voxel_size).int(),
        batch_offsets=self.offsets,
        order=ordering,
        return_perm=True,
    )
    # Copy the extra attributes but drop any cached neighbor search: the cache is keyed by
    # batch offsets, which are unchanged here, while the point order is not, so cached
    # neighbor indices would refer to the wrong points.
    kwargs = {k: v for k, v in self.extra_attributes.items() if k != "_cache"}
    kwargs["ordering"] = ordering
    return self.__class__(
        batched_coordinates=RealCoords(
            batched_tensor=self.coordinate_tensor[result.perm],
            offsets=self.offsets,
        ),
        batched_features=CatFeatures(
            batched_tensor=self.feature_tensor[result.perm],
            offsets=self.offsets,
        ),
        **kwargs,
    )
to_voxels(voxel_size: float, reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.MEAN) -> Voxels
Convert the point collection to a spatially sparse tensor (Voxels), reducing the features of points that share a voxel with the given reduction.
Source code in warpconvnet/geometry/types/points.py
def to_voxels(
    self,
    voxel_size: float,
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.MEAN,
) -> "Voxels":
    """
    Voxelize this point collection into a spatially sparse tensor.

    Args:
        voxel_size: Edge length of the voxels.
        reduction: How features of points sharing a voxel are combined.

    Returns:
        The ``Voxels`` produced by ``points_to_voxels``.
    """
    voxels = points_to_voxels(self, voxel_size, reduction)
    return voxels
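voxel_downsample(voxel_size: float, reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.RANDOM) -> Points
Downsample the points to one point per occupied voxel of edge length voxel_size (GPU only). With the default random reduction, a single point is kept per voxel; with any other reduction, the features of all points in a voxel are combined using that reduction.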
Source code in warpconvnet/geometry/types/points.py
def voxel_downsample(
    self,
    voxel_size: float,
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR] = REDUCTIONS.RANDOM,
) -> "Points":
    """
    Voxel downsample the coordinates.

    With the ``RANDOM`` reduction (enum member or its string value) one point per occupied
    voxel is kept, as selected by ``voxel_downsample_random_indices``. Any other reduction
    combines the features of all points in a voxel with ``row_reduction``; a single
    representative coordinate is kept per voxel.

    Args:
        voxel_size: Edge length of the voxels used for grouping.
        reduction: Feature reduction applied per voxel.

    Returns:
        A new instance of the same class with ``voxel_size`` recorded in its extra
        attributes.
    """
    assert self.device.type != "cpu", "Voxel downsample is only supported on GPU"
    # Work on a copy so this instance's attributes are never mutated, and drop any cached
    # neighbor search, which belongs to the original (denser) point set.
    extra_args = {k: v for k, v in self.extra_attributes.items() if k != "_cache"}
    extra_args["voxel_size"] = voxel_size
    assert isinstance(
        self.batched_features, CatFeatures
    ), "Voxel downsample is only supported for CatBatchedFeatures"

    # Accept both the enum member and its string value (see REDUCTION_TYPES_STR).
    if reduction in (REDUCTIONS.RANDOM, REDUCTIONS.RANDOM.value):
        to_unique_indices, unique_offsets = voxel_downsample_random_indices(
            batched_points=self.coordinate_tensor,
            offsets=self.offsets,
            voxel_size=voxel_size,
        )
        return self.__class__(
            batched_coordinates=RealCoords(
                batched_tensor=self.coordinate_tensor[to_unique_indices],
                offsets=unique_offsets,
            ),
            batched_features=CatFeatures(
                batched_tensor=self.feature_tensor[to_unique_indices],
                offsets=unique_offsets,
            ),
            **extra_args,
        )

    (
        _batch_indexed_down_coords,
        unique_offsets,
        to_csr_indices,
        to_csr_offsets,
        to_unique,
    ) = voxel_downsample_csr_mapping(
        batched_points=self.coordinate_tensor,
        offsets=self.offsets,
        voxel_size=voxel_size,
    )

    # The CSR mapping groups point indices per voxel; reduce features over each group.
    voxel_groups = RealSearchResult(to_csr_indices, to_csr_offsets)
    down_features = row_reduction(
        self.feature_tensor,
        voxel_groups.neighbor_row_splits,
        reduction,
    )

    return self.__class__(
        batched_coordinates=RealCoords(
            batched_tensor=self.coordinate_tensor[to_unique.to_unique_indices],
            offsets=unique_offsets,
        ),
        batched_features=CatFeatures(batched_tensor=down_features, offsets=unique_offsets),
        **extra_args,
    )
Voxels
Sparse voxel geometry that accepts integer coordinates with tensor strides and offers helpers to convert to and from dense tensors, with features stored as offset-batched (CSR-style) tensors.
warpconvnet.geometry.types.voxels.Voxels
Bases: Geometry
Source code in warpconvnet/geometry/types/voxels.py
class Voxels(Geometry):
    """
    Sparse voxel geometry built from integer coordinates (``IntCoords``) and features that
    share the coordinate batch offsets.

    A tensor stride may be supplied through the ``tensor_stride`` keyword or its alias
    ``stride``.
    """

    def __init__(
        self,
        batched_coordinates: List[Float[Tensor, "N 3"]] | Float[Tensor, "N 3"] | IntCoords,
        batched_features: (
            List[Float[Tensor, "N C"]]
            | Float[Tensor, "N C"]
            | Float[Tensor, "B M C"]
            | CatFeatures
            | PadFeatures
        ),
        offsets: Optional[Int[Tensor, "B + 1"]] = None,  # noqa: F722,F821
        device: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize a voxel collection.

        Args:
            batched_coordinates: Per-example coordinate tensors, one concatenated
                coordinate tensor (``offsets`` required), or an ``IntCoords``. When an
                ``IntCoords`` is given together with a tensor stride, the stride is set on
                that object in place.
            batched_features: Per-example feature tensors, a feature tensor, or an
                existing ``CatFeatures`` / ``PadFeatures``.
            offsets: Batch offsets used when ``batched_coordinates`` is a tensor.
            device: Optional device forwarded to the coordinate/feature constructors.
            **kwargs: ``tensor_stride`` / ``stride`` are consumed here; all other entries
                are forwarded to ``Geometry.__init__`` as extra attributes.
        """
        # Pop both spellings so an unused ``stride`` is never forwarded to Geometry as a
        # stray extra attribute (``pop(a) or pop(b)`` would skip the second pop).
        tensor_stride = kwargs.pop("tensor_stride", None)
        stride_alias = kwargs.pop("stride", None)
        tensor_stride = tensor_stride or stride_alias

        if isinstance(batched_coordinates, list):
            assert isinstance(
                batched_features, list
            ), "If coords is a list, features must be a list too."
            assert len(batched_coordinates) == len(batched_features)
            # Assert all elements in coords and features have same length
            assert all(
                len(c) == len(f) for c, f in zip(batched_coordinates, batched_features)
            ), "All elements in coords and features must have same length"
            batched_coordinates = IntCoords(
                batched_coordinates, device=device, tensor_stride=tensor_stride
            )
        elif isinstance(batched_coordinates, Tensor):
            assert (
                isinstance(batched_features, Tensor) and offsets is not None
            ), "If coordinate is a tensor, features must be a tensor and offsets must be provided."
            batched_coordinates = IntCoords(
                batched_coordinates,
                offsets=offsets,
                device=device,
                tensor_stride=tensor_stride,
            )
        else:
            # Input is a BatchedDiscreteCoordinates
            if tensor_stride is not None:
                batched_coordinates.set_tensor_stride(tensor_stride)

        if isinstance(batched_features, list):
            batched_features = CatFeatures(batched_features, device=device)
        elif isinstance(batched_features, Tensor):
            batched_features = to_batched_features(
                batched_features, batched_coordinates.offsets, device=device
            )

        Geometry.__init__(self, batched_coordinates, batched_features, **kwargs)

    @classmethod
    def from_dense(
        cls,
        dense_tensor: Float[Tensor, "B C H W"] | Float[Tensor, "B C H W D"],
        dense_tensor_channel_dim: int = 1,
        target_spatial_sparse_tensor: Optional["Voxels"] = None,
        dense_max_coords: Optional[Tuple[int, ...]] = None,
        **kwargs,
    ):
        """
        Create voxels from a dense tensor.

        Args:
            dense_tensor: Dense input with the batch on axis 0.
            dense_tensor_channel_dim: Axis that holds the channels; it is moved last.
            target_spatial_sparse_tensor: When given, the dense tensor is sampled at this
                tensor's coordinates (shifted by its minimum coordinate) and the result
                replaces its features. Otherwise every location whose features are not all
                zero becomes a voxel.
            dense_max_coords: Optional per-axis inclusive upper bound for the shifted
                coordinates; voxels beyond it receive zero features.
            **kwargs: Extra attributes for the new instance (only used when
                ``target_spatial_sparse_tensor`` is None).
        """
        # Move the channel dimension to the end unless it is already the last axis.
        if (
            dense_tensor_channel_dim != -1
            and dense_tensor_channel_dim != dense_tensor.ndim - 1
        ):
            dense_tensor = dense_tensor.moveaxis(dense_tensor_channel_dim, -1)
        spatial_shape = dense_tensor.shape[1:-1]
        batched_spatial_shape = dense_tensor.shape[:-1]

        # Flatten the spatial dimensions
        flattened_tensor = dense_tensor.flatten(0, -2)

        if target_spatial_sparse_tensor is None:
            # abs sum all elements in the tensor
            abs_sum = torch.abs(dense_tensor).sum(dim=-1, keepdim=False)
            # Find all non-zero elements. Expected to be sorted.
            non_zero_inds = torch.nonzero(abs_sum).int()
            # Convert multi-dimensional indices to flattened indices
            flattened_indices = ravel_multi_index(non_zero_inds, batched_spatial_shape)
            # Use index_select to get the features
            non_zero_feats = torch.index_select(flattened_tensor, 0, flattened_indices)
            offsets = offsets_from_batch_index(non_zero_inds[:, 0])
            return cls(
                batched_coordinates=IntCoords(non_zero_inds[:, 1:], offsets=offsets),
                batched_features=CatFeatures(non_zero_feats, offsets=offsets),
                **kwargs,
            )

        assert target_spatial_sparse_tensor.num_spatial_dims == len(spatial_shape)
        assert target_spatial_sparse_tensor.batch_size == batched_spatial_shape[0]
        # Use the provided spatial sparse tensor's coordinate only
        batch_indexed_coords = target_spatial_sparse_tensor.batch_indexed_coordinates.clone()
        # subtract the min_coords
        min_coords = target_spatial_sparse_tensor.coordinate_tensor.min(dim=0).values.view(
            1, -1
        )
        batch_indexed_coords[:, 1:] = batch_indexed_coords[:, 1:] - min_coords
        if dense_max_coords is not None:
            # torch cannot compare a tensor with a Python tuple, so convert the bound first.
            max_coords = torch.as_tensor(
                dense_max_coords,
                dtype=batch_indexed_coords.dtype,
                device=batch_indexed_coords.device,
            ).view(1, -1)
            invalid = (batch_indexed_coords[:, 1:] > max_coords).any(dim=1)
            # Point out-of-range voxels at index 0; their features are zeroed below.
            batch_indexed_coords[invalid] = 0
        else:
            sparse_max_coords = batch_indexed_coords[:, 1:].max(dim=0).values
            # This assumes the max_coords are already aligned with the spatial_shape.
            assert torch.all(
                sparse_max_coords.cpu() < torch.tensor(spatial_shape)
            ), "Max coords must be aligned with the spatial shape."
        # Ravel the coordinates. This assumes the max_coords are already aligned with the spatial_shape.
        flattened_indices = ravel_multi_index(batch_indexed_coords, batched_spatial_shape)
        # use index_select to get the features
        non_zero_feats = torch.index_select(flattened_tensor, 0, flattened_indices)
        if dense_max_coords is not None:
            non_zero_feats[invalid] = 0
        return target_spatial_sparse_tensor.replace(batched_features=non_zero_feats)

    def to_dense(
        self,
        channel_dim: int = 1,
        spatial_shape: Optional[Tuple[int, ...]] = None,
        min_coords: Optional[Tuple[int, ...]] = None,
        max_coords: Optional[Tuple[int, ...]] = None,
    ) -> Float[Tensor, "B C H W D"] | Float[Tensor, "B C H W"]:
        """
        Scatter the voxel features into a dense tensor.

        The spatial window is chosen as follows:
          * neither ``spatial_shape`` nor ``min_coords``: the bounding box of the
            coordinates, with coordinates shifted so the minimum maps to index 0;
          * ``min_coords`` plus ``max_coords`` (inclusive) or ``spatial_shape``: voxels
            outside the window are dropped and the rest are shifted by ``min_coords``;
          * only ``spatial_shape`` (one entry per spatial axis): coordinates are used
            as-is (presumably they must already lie inside ``[0, spatial_shape)``).

        Args:
            channel_dim: Output axis for the channels (``-1`` keeps them last).
            spatial_shape: Dense spatial extent.
            min_coords: Lower corner of the window.
            max_coords: Inclusive upper corner of the window.

        Returns:
            A dense tensor with the batch on axis 0 and channels at ``channel_dim``.

        Raises:
            ValueError: If only ``spatial_shape`` is given and its length does not match
                the number of spatial dimensions.
        """
        device = self.batched_coordinates.device

        # Get the batch indexed coordinates and features
        batch_indexed_coords = self.batched_coordinates.batch_indexed_coordinates.clone()
        features = self.batched_features.batched_tensor

        # Get the spatial shape.
        # If min_coords and max_coords are provided, assert that spatial_shape matches
        if spatial_shape is None and min_coords is None:
            # Get the min max coordinates
            coords = batch_indexed_coords[:, 1:]
            if coords.shape[0] == 0:  # Handle empty tensor case
                min_coords_tensor = torch.zeros(
                    self.num_spatial_dims, dtype=torch.long, device=device
                )
                spatial_shape_tensor = torch.zeros(
                    self.num_spatial_dims, dtype=torch.long, device=device
                )
            else:
                min_coords_tensor = coords.min(dim=0).values
                max_coords_tensor = coords.max(dim=0).values
                spatial_shape_tensor = max_coords_tensor - min_coords_tensor + 1
            # Shift the coordinates to the min_coords
            batch_indexed_coords[:, 1:] = batch_indexed_coords[:, 1:] - min_coords_tensor.to(
                device
            )
            spatial_shape = tuple(s.item() for s in spatial_shape_tensor)
        elif min_coords is not None:
            # Assert either max_coords or spatial_shape is provided
            assert max_coords is not None or spatial_shape is not None
            # Convert min_coords to tensor
            min_coords_tensor = torch.tensor(min_coords, dtype=torch.int32, device=device)
            if max_coords is None:
                # convert spatial_shape to tensor
                spatial_shape_tensor = torch.tensor(
                    spatial_shape, dtype=torch.int32, device=device
                )
            else:  # both min_coords and max_coords are provided
                # Convert max_coords to tensor
                max_coords_tensor = torch.tensor(max_coords, dtype=torch.int32, device=device)
                assert len(min_coords_tensor) == len(max_coords_tensor) == self.num_spatial_dims
                spatial_shape_tensor = max_coords_tensor - min_coords_tensor + 1
            # Express coordinates relative to min_coords, keep only those inside
            # [0, spatial_shape) on every axis, then store the shifted values so they index
            # the dense grid directly (previously the coordinates were masked but never
            # shifted, which mis-placed voxels whenever min_coords was non-zero).
            relative_coords = batch_indexed_coords[:, 1:] - min_coords_tensor.view(1, -1)
            mask = (
                (relative_coords >= 0) & (relative_coords < spatial_shape_tensor.view(1, -1))
            ).all(dim=1)
            batch_indexed_coords = batch_indexed_coords[mask]
            batch_indexed_coords[:, 1:] = relative_coords[mask]
            features = features[mask]
            spatial_shape = tuple(s.item() for s in spatial_shape_tensor)
        elif spatial_shape is not None and len(spatial_shape) == self.coordinate_tensor.shape[1]:
            # Coordinates are used as-is with the provided spatial shape.
            pass
        else:
            raise ValueError(
                f"Provided spatial shape {spatial_shape} must be same length as the number of spatial dimensions {self.num_spatial_dims}."
            )

        if isinstance(spatial_shape, Tensor):  # Should be tuple by now
            spatial_shape = spatial_shape.tolist()

        # Create a dense tensor
        dense_tensor = torch.zeros(
            (self.batch_size, *spatial_shape, self.num_channels),
            dtype=self.batched_features.dtype,
            device=self.batched_features.device,
        )

        if batch_indexed_coords.shape[0] > 0:  # Only scatter if there are points
            # Flatten view and scatter the features into their raveled positions
            flattened_indices = ravel_multi_index(
                batch_indexed_coords, (self.batch_size, *spatial_shape)
            )
            dense_tensor.flatten(0, -2)[flattened_indices] = features

        if channel_dim != -1:
            # Put the channel dimension in the specified position and move the rest of the dimensions contiguous
            dense_tensor = dense_tensor.moveaxis(-1, channel_dim)

        return dense_tensor

    def to_point(self, voxel_size: Optional[float] = None) -> "Points":  # noqa: F821
        """
        Convert the voxels to a ``Points`` collection by scaling integer coordinates.

        Args:
            voxel_size: Voxel edge length; defaults to the stored ``voxel_size``. It is
                multiplied by the tensor stride when one is set (a scalar stride applies
                to every axis).

        Returns:
            Points whose coordinates are ``coordinate_tensor * voxel_size``, sharing this
            instance's features.
        """
        if voxel_size is None:
            assert (
                self.voxel_size is not None
            ), "Voxel size must be provided or the object must have been initialized with a voxel size to convert to point."
            voxel_size = self.voxel_size

        # tensor stride
        if self.tensor_stride is not None:
            tensor_stride = self.tensor_stride
            if isinstance(tensor_stride, int):
                # A scalar stride scales every spatial axis uniformly.
                voxel_size = voxel_size * tensor_stride
            else:
                # multiply voxel_size by tensor_stride per axis
                voxel_size = torch.Tensor([[voxel_size * s for s in tensor_stride]]).to(
                    self.device
                )

        from warpconvnet.geometry.types.points import Points

        batched_points = RealCoords(self.coordinate_tensor * voxel_size, self.offsets)
        return Points(batched_points, self.batched_features)

    def sort(self, ordering: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ) -> "Voxels":
        """
        Reorder the voxels along a space-filling curve.

        Returns ``self`` when already in the requested ordering; otherwise a new instance
        with permuted coordinates/features, and ``ordering`` and ``code`` recorded.
        """
        if ordering == self.ordering:
            return self

        assert isinstance(
            self.batched_features, CatFeatures
        ), "Features must be a CatBatchedFeatures to sort."

        code_result = encode(
            self.coordinate_tensor,
            batch_offsets=self.offsets,
            order=ordering,
            return_perm=True,
        )
        kwargs = self.extra_attributes.copy()
        kwargs["ordering"] = ordering
        kwargs["code"] = code_result.codes
        return self.__class__(
            batched_coordinates=IntCoords(self.coordinate_tensor[code_result.perm], self.offsets),
            batched_features=CatFeatures(self.feature_tensor[code_result.perm], self.offsets),
            **kwargs,
        )

    def unique(self) -> "Voxels":
        """Keep one entry per duplicated coordinate, as chosen by ``voxel_downsample_random_indices``."""
        unique_indices, batch_offsets = voxel_downsample_random_indices(
            self.coordinate_tensor,
            self.offsets,
        )
        coords = IntCoords(self.coordinate_tensor[unique_indices], batch_offsets)
        feats = CatFeatures(self.feature_tensor[unique_indices], batch_offsets)
        return self.__class__(coords, feats, **self.extra_attributes)

    @property
    def coordinate_hashmap(self) -> TorchHashTable:
        """Hash table of the coordinates, provided by the coordinate container."""
        return self.batched_coordinates.hashmap

    @property
    def voxel_size(self):
        """Voxel size stored in the extra attributes, or None."""
        return self.extra_attributes.get("voxel_size", None)

    @property
    def ordering(self):
        """Ordering stored in the extra attributes (set by ``sort``), or None."""
        return self.extra_attributes.get("ordering", None)

    @property
    def stride(self):
        """Alias of ``tensor_stride``."""
        return self.tensor_stride

    @property
    def tensor_stride(self):
        """Tensor stride stored on the coordinate container."""
        return self.batched_coordinates.tensor_stride

    def set_tensor_stride(self, tensor_stride: Union[int, Tuple[int, ...]]):
        """Set the tensor stride on the coordinate container (in place)."""
        self.batched_coordinates.set_tensor_stride(tensor_stride)

    @property
    def batch_indexed_coordinates(self) -> Tensor:
        """Coordinates with the batch index prepended, from the coordinate container."""
        return self.batched_coordinates.batch_indexed_coordinates
Grid
Regular dense grid representation that keeps GridCoords and GridFeatures in sync, providing utilities for shape validation, format conversions, and batch-aware initialization.
warpconvnet.geometry.types.grid.Grid
Bases: Geometry
Grid geometry representation that combines coordinates and features.
This class provides a unified interface for grid-based geometries in any memory format, combining grid coordinates with grid features.
Args:
    batched_coordinates (GridCoords): Coordinate system for the grid.
    batched_features (Union[GridFeatures, Tensor]): Grid features.
    memory_format (GridMemoryFormat): Memory format for the features.
    grid_shape (Tuple[int, int, int], optional): Grid resolution (H, W, D).
    num_channels (int, optional): Number of feature channels.
    **kwargs: Additional parameters.
Source code in warpconvnet/geometry/types/grid.py
class Grid(Geometry):
"""Grid geometry representation that combines coordinates and features.
This class provides a unified interface for grid-based geometries with any
memory format, combining grid coordinates with grid features.
Args:
batched_coordinates (GridCoords): Coordinate system for the grid
batched_features (Union[GridFeatures, Tensor]): Grid features
memory_format (GridMemoryFormat): Memory format for the features
grid_shape (Tuple[int, int, int], optional): Grid resolution (H, W, D)
num_channels (int, optional): Number of feature channels
**kwargs: Additional parameters
"""
def __init__(
    self,
    batched_coordinates: GridCoords,
    batched_features: Union[GridFeatures, Tensor],
    memory_format: Optional[GridMemoryFormat] = None,
    grid_shape: Optional[Tuple[int, int, int]] = None,
    num_channels: Optional[int] = None,
    **kwargs,
):
    """
    Build a grid geometry from grid coordinates and features.

    A raw feature tensor is wrapped in ``GridFeatures`` using a clone of the coordinate
    offsets; ``memory_format`` is then mandatory, and ``grid_shape`` / ``num_channels``
    are inferred when omitted. An existing ``GridFeatures`` is used as-is, and
    ``memory_format`` must then be None or equal to its memory format.
    """
    if not isinstance(batched_features, Tensor):
        assert (
            memory_format is None or memory_format == batched_features.memory_format
        ), f"Memory format must be None or match the GridFeatures memory format: {batched_features.memory_format}. Provided: {memory_format}"
    else:
        assert (
            memory_format is not None
        ), "Memory format must be provided if features are a tensor"
        if grid_shape is None:
            grid_shape = batched_coordinates.grid_shape

        if num_channels is None:
            fmt = memory_format
            if fmt == GridMemoryFormat.b_x_y_z_c:
                # Channels-last layout.
                num_channels = batched_features.shape[-1]
            elif fmt in (GridMemoryFormat.b_c_x_y_z, GridMemoryFormat.b_c_z_x_y):
                # Channels-first layouts.
                num_channels = batched_features.shape[1]
            elif fmt == GridMemoryFormat.b_zc_x_y:
                # One spatial axis is folded into axis 1: divide by that axis length.
                num_channels = batched_features.shape[1] // grid_shape[2]
            elif fmt == GridMemoryFormat.b_xc_y_z:
                num_channels = batched_features.shape[1] // grid_shape[0]
            elif fmt == GridMemoryFormat.b_yc_x_z:
                num_channels = batched_features.shape[1] // grid_shape[1]
            else:
                raise ValueError(f"Unsupported memory format: {memory_format}")

        # Create GridFeatures with same offsets as coordinates
        batched_features = GridFeatures(
            batched_features,
            batched_coordinates.offsets.clone(),
            memory_format,
            grid_shape,
            num_channels,
        )

    # Validate shapes, then offsets, before handing off to Geometry.
    self.check(batched_coordinates, batched_features)
    assert (
        batched_coordinates.offsets == batched_features.offsets
    ).all(), "Coordinate and feature offsets must match"
    super().__init__(batched_coordinates, batched_features, **kwargs)
def check(self, coords: GridCoords, features: GridFeatures):
    """
    Assert that ``coords`` and ``features`` describe the same grid.

    Coordinates must be 3-D, the number of coordinate points must equal the number of
    feature vectors, and both must report the same ``grid_shape``.
    """
    assert coords.shape[-1] == 3
    coord_count = coords.numel() // 3
    feature_count = features.numel() // features.num_channels
    assert (
        coord_count == feature_count
    ), f"Number of coordinates ({coord_count}) must match number of features ({feature_count})"
    coord_grid, feat_grid = coords.grid_shape, features.grid_shape
    assert (
        coord_grid == feat_grid
    ), f"Grid shape ({coord_grid}) must match feature grid shape ({feat_grid})"
@classmethod
def from_shape(
    cls,
    grid_shape: Tuple[int, int, int],
    num_channels: int,
    memory_format: GridMemoryFormat = GridMemoryFormat.b_x_y_z_c,
    bounds: Optional[Tuple[Tensor, Tensor]] = None,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> "Grid":
    """
    Build a grid geometry of a given resolution with lazily created coordinates and an
    empty feature tensor.

    Args:
        grid_shape: Grid resolution (H, W, D).
        num_channels: Number of feature channels.
        memory_format: Memory format for the features.
        bounds: Min and max bounds for the grid.
        batch_size: Number of batches.
        device: Device on which tensors are created.
        dtype: Data type of the feature tensor.
        **kwargs: Extra attributes forwarded to the constructor.

    Returns:
        The new grid geometry.
    """
    # Flattened coordinates; GridCoords initializes the data lazily by default.
    grid_coords = GridCoords.from_shape(
        grid_shape=grid_shape,
        bounds=bounds,
        batch_size=batch_size,
        device=device,
        flatten=True,
    )
    # Features as produced by GridFeatures.create_empty in the requested layout.
    grid_feats = GridFeatures.create_empty(
        grid_shape=grid_shape,
        num_channels=num_channels,
        batch_size=batch_size,
        memory_format=memory_format,
        device=device,
        dtype=dtype,
    )
    assert (
        grid_coords.offsets == grid_feats.offsets
    ).all(), "Coordinate and feature offsets must match"
    return cls(grid_coords, grid_feats, memory_format, **kwargs)
@property
def grid_features(self) -> GridFeatures:
    """The ``GridFeatures`` held by this grid (same object as ``batched_features``)."""
    return self.batched_features
@property
def grid_coords(self) -> GridCoords:
    """The ``GridCoords`` held by this grid (same object as ``batched_coordinates``)."""
    return self.batched_coordinates
@property
def grid_shape(self) -> Tuple[int, int, int]:
    """Grid resolution (H, W, D), as reported by the grid coordinates."""
    coords = self.grid_coords
    return coords.grid_shape
@property
def bounds(self) -> Tuple[Tensor, Tensor]:
    """(min, max) bounds of the grid, as reported by the grid coordinates."""
    coords = self.grid_coords
    return coords.bounds
@property
def num_channels(self) -> int:
    """Number of feature channels, as reported by the grid features."""
    feats = self.grid_features
    return feats.num_channels
@property
def memory_format(self) -> GridMemoryFormat:
    """Memory layout of the feature tensor, as reported by the grid features."""
    feats = self.grid_features
    return feats.memory_format
def channel_size(self, memory_format: Optional[GridMemoryFormat] = None):
    """
    Return the size of the channel dimension of a feature tensor in ``memory_format``.

    For the non-factorized layouts (``b_x_y_z_c``, ``b_c_x_y_z``, ``b_c_z_x_y``) this is
    ``num_channels``. For the factorized layouts one spatial axis is folded into the
    channel dimension, so the size is ``num_channels`` times that axis' extent. As the
    format names imply, x, y and z correspond to ``grid_shape[0]``, ``[1]`` and ``[2]``.

    Args:
        memory_format: Layout to query. Defaults to this grid's memory format.

    Returns:
        The size of the channel dimension.

    Raises:
        ValueError: If ``memory_format`` is not supported.
    """
    if memory_format is None:
        memory_format = self.memory_format
    if memory_format in (
        GridMemoryFormat.b_x_y_z_c,
        GridMemoryFormat.b_c_x_y_z,
        # b_c_z_x_y is also a supported non-factorized layout (see Grid.__init__ and
        # Grid.replace); its channel axis holds exactly num_channels.
        GridMemoryFormat.b_c_z_x_y,
    ):
        return self.num_channels
    if memory_format == GridMemoryFormat.b_xc_y_z:
        return self.num_channels * self.grid_shape[0]
    if memory_format == GridMemoryFormat.b_yc_x_z:
        return self.num_channels * self.grid_shape[1]
    if memory_format == GridMemoryFormat.b_zc_x_y:
        return self.num_channels * self.grid_shape[2]
    raise ValueError(f"Unsupported memory format: {memory_format}")
def to_memory_format(self, memory_format: GridMemoryFormat) -> "Grid":
    """Return a grid whose features use ``memory_format``.

    ``self`` is returned unchanged when it already uses that format. Otherwise the
    features are converted and a new grid is built through ``replace``.
    """
    if memory_format == self.memory_format:
        return self
    converted = self.grid_features.to_memory_format(memory_format)
    return self.replace(batched_features=converted, memory_format=memory_format)
@property
def shape(self) -> Dict[str, Union[int, Tuple[int, ...]]]:
    """Summary of sizes: grid shape, batch size, channel count and total cell count."""
    size_h, size_w, size_d = self.grid_shape
    cells_per_batch = size_h * size_w * size_d
    return dict(
        grid_shape=self.grid_shape,
        batch_size=self.batch_size,
        num_channels=self.num_channels,
        total_elements=cells_per_batch * self.batch_size,
    )
def to(self, device: torch.device) -> "Grid":
    """Return a new ``Grid`` whose coordinates and features live on ``device``.

    NOTE(review): this always constructs a plain ``Grid`` (not ``type(self)``) and
    forwards only the memory format, so nothing else given to the original
    constructor is carried over.
    """
    coords_on_device = self.grid_coords.to(device)
    feats_on_device = self.grid_features.to(device)
    return Grid(coords_on_device, feats_on_device, self.memory_format)
def replace(
    self,
    batched_coordinates: Optional[GridCoords] = None,
    batched_features: Optional[Union[GridFeatures, Tensor]] = None,
    **kwargs,
) -> "Grid":
    """Create a new instance with replaced coordinates and/or features.

    Args:
        batched_coordinates: New coordinates, or ``None``. Forwarded unchanged to
            the base-class ``replace``.
        batched_features: New features, given as one of:

            * ``None``: forwarded unchanged to the base-class ``replace``
              (e.g. when only coordinates or ``kwargs`` change);
            * a ``GridFeatures``: used as-is;
            * a 5D tensor in this grid's non-factorized memory format
              (``b_x_y_z_c``, ``b_c_x_y_z`` or ``b_c_z_x_y``) whose spatial sizes
              and channel count match this grid;
            * a 4D tensor in this grid's factorized memory format
              (``b_zc_x_y``, ``b_xc_y_z`` or ``b_yc_x_z``) whose two spatial sizes
              match this grid. The channel count is inferred from the compressed
              (folded-axis size * channels) dimension and may differ from the
              current one.

            Raw tensors are wrapped in ``GridFeatures`` that reuse this grid's
            offsets, memory format and grid shape.
        **kwargs: Forwarded to the base-class ``replace``.

    Returns:
        The geometry produced by the base-class ``replace``.

    Raises:
        AssertionError: If a tensor's layout or sizes are inconsistent with the grid.
        ValueError: If the memory format or the feature type is not supported.
    """
    if batched_features is None:
        # Nothing to convert; the base class receives None as-is.
        pass
    elif isinstance(batched_features, Tensor) and batched_features.ndim == 5:
        # Based on the memory format, we have to check the shape of the tensor
        if self.memory_format == GridMemoryFormat.b_x_y_z_c:
            in_H, in_W, in_D, in_C = batched_features.shape[1:5]
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
            assert in_D == self.grid_shape[2]
            assert in_C == self.num_channels
        elif self.memory_format == GridMemoryFormat.b_c_x_y_z:
            in_C, in_H, in_W, in_D = batched_features.shape[1:5]
            assert in_C == self.num_channels
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
            assert in_D == self.grid_shape[2]
        elif self.memory_format == GridMemoryFormat.b_c_z_x_y:
            in_C, in_D, in_H, in_W = batched_features.shape[1:5]
            assert in_C == self.num_channels
            assert in_D == self.grid_shape[2]
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
        else:
            raise ValueError(f"Unsupported memory format: {self.memory_format}")
        batched_features = GridFeatures(
            batched_tensor=batched_features,
            offsets=self.grid_features.offsets,
            memory_format=self.memory_format,
            grid_shape=self.grid_shape,
            num_channels=in_C,
        )
    elif isinstance(batched_features, Tensor) and batched_features.ndim == 4:
        # Factorized ("compressed") format: one spatial axis is folded into dim 1.
        assert self.memory_format in [
            GridMemoryFormat.b_zc_x_y,
            GridMemoryFormat.b_xc_y_z,
            GridMemoryFormat.b_yc_x_z,
        ], f"Unsupported memory format: {self.memory_format} for feature tensor of shape {batched_features.shape}"
        # Only the channel count may change through .replace(); the two remaining
        # spatial dims must match the grid.
        compressed_dim = batched_features.shape[1]  # folded-axis size * channels
        if self.memory_format == GridMemoryFormat.b_zc_x_y:
            assert batched_features.shape[2] == self.grid_shape[0]
            assert batched_features.shape[3] == self.grid_shape[1]
            folded_axis = 2
        elif self.memory_format == GridMemoryFormat.b_xc_y_z:
            assert batched_features.shape[2] == self.grid_shape[1]
            assert batched_features.shape[3] == self.grid_shape[2]
            folded_axis = 0
        elif self.memory_format == GridMemoryFormat.b_yc_x_z:
            assert batched_features.shape[2] == self.grid_shape[0]
            assert batched_features.shape[3] == self.grid_shape[2]
            folded_axis = 1
        else:
            raise ValueError(f"Unsupported memory format: {self.memory_format}")
        folded_size = self.grid_shape[folded_axis]
        # Floor division would silently truncate the channel count otherwise.
        assert compressed_dim % folded_size == 0, (
            f"Compressed dimension ({compressed_dim}) must be divisible by the folded "
            f"grid size ({folded_size})"
        )
        batched_features = GridFeatures(
            batched_tensor=batched_features,
            offsets=self.grid_features.offsets,
            memory_format=self.memory_format,
            grid_shape=self.grid_shape,
            num_channels=compressed_dim // folded_size,
        )
    elif isinstance(batched_features, GridFeatures):
        # Already wrapped; no action needed.
        pass
    elif isinstance(batched_features, Tensor):
        raise ValueError(f"Unsupported feature tensor shape: {batched_features.shape}")
    else:
        # Non-tensor inputs may not have ``.shape``; report the type instead.
        raise ValueError(f"Unsupported feature type: {type(batched_features).__name__}")
    return super().replace(batched_coordinates, batched_features, **kwargs)
bounds: Tuple[Tensor, Tensor]
property
¶
Get the bounds of the grid.
grid_coords: GridCoords
property
¶
Get the grid coordinates.
grid_features: GridFeatures
property
¶
Get the grid features.
grid_shape: Tuple[int, int, int]
property
¶
Get the grid shape (H, W, D).
memory_format: GridMemoryFormat
property
¶
Get the memory format.
num_channels: int
property
¶
Get the number of feature channels.
shape: Dict[str, Union[int, Tuple[int, ...]]]
property
¶
Get the shape information.
check(coords: GridCoords, features: GridFeatures)
¶
Check if the grid dimensions are consistent
Source code in warpconvnet/geometry/types/grid.py
def check(self, coords: GridCoords, features: GridFeatures):
    """
    Validate that coordinates and features describe the same grid.

    Asserts that the coordinates are 3D, that the number of coordinate rows equals
    the number of feature vectors (``features.numel() // features.num_channels``),
    and that both carry the same ``grid_shape``.
    """
    assert coords.shape[-1] == 3
    n_points = coords.numel() // 3
    n_feature_vectors = features.numel() // features.num_channels
    assert n_points == n_feature_vectors, (
        f"Number of coordinates ({n_points}) must match number of features ({n_feature_vectors})"
    )
    coord_grid = coords.grid_shape
    feat_grid = features.grid_shape
    assert coord_grid == feat_grid, (
        f"Grid shape ({coord_grid}) must match feature grid shape ({feat_grid})"
    )
from_shape(grid_shape: Tuple[int, int, int], num_channels: int, memory_format: GridMemoryFormat = GridMemoryFormat.b_x_y_z_c, bounds: Optional[Tuple[Tensor, Tensor]] = None, batch_size: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs) -> Grid
classmethod
¶
Create a new Grid geometry from a grid shape. The coordinates will be lazily initialized and the features will be created as an empty tensor.
Args: grid_shape: Grid resolution (H, W, D) num_channels: Number of feature channels memory_format: Memory format for features bounds: Min and max bounds for the grid batch_size: Number of batches device: Device to create tensors on dtype: Data type for feature tensors **kwargs: Additional parameters
Returns: Initialized grid geometry
Source code in warpconvnet/geometry/types/grid.py
@classmethod
def from_shape(
    cls,
    grid_shape: Tuple[int, int, int],
    num_channels: int,
    memory_format: GridMemoryFormat = GridMemoryFormat.b_x_y_z_c,
    bounds: Optional[Tuple[Tensor, Tensor]] = None,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    **kwargs,
) -> "Grid":
    """
    Build a Grid of the given resolution with lazily initialized coordinates and
    empty features.

    Args:
        grid_shape: Grid resolution (H, W, D)
        num_channels: Number of feature channels
        memory_format: Memory format for the features
        bounds: Min and max bounds for the grid
        batch_size: Number of batches
        device: Device to create tensors on
        dtype: Data type for the feature tensors
        **kwargs: Forwarded to the constructor of ``cls``
    Returns:
        A new ``cls`` instance wrapping the created coordinates and features.
    """
    # Arguments shared by the coordinate and feature factories.
    common = {"grid_shape": grid_shape, "batch_size": batch_size, "device": device}
    # Flattened coordinates; by default their data is initialized lazily.
    coords = GridCoords.from_shape(bounds=bounds, flatten=True, **common)
    features = GridFeatures.create_empty(
        num_channels=num_channels,
        memory_format=memory_format,
        dtype=dtype,
        **common,
    )
    # Both factories are expected to produce identical batch offsets.
    assert (coords.offsets == features.offsets).all(), "Coordinate and feature offsets must match"
    return cls(coords, features, memory_format, **kwargs)
replace(batched_coordinates: Optional[GridCoords] = None, batched_features: Optional[Union[GridFeatures, Tensor]] = None, **kwargs) -> Grid
¶
Create a new instance with replaced coordinates and/or features.
Source code in warpconvnet/geometry/types/grid.py
def replace(
    self,
    batched_coordinates: Optional[GridCoords] = None,
    batched_features: Optional[Union[GridFeatures, Tensor]] = None,
    **kwargs,
) -> "Grid":
    """Create a new instance with replaced coordinates and/or features.

    Args:
        batched_coordinates: New coordinates, or ``None``. Forwarded unchanged to
            the base-class ``replace``.
        batched_features: New features, given as one of:

            * ``None``: forwarded unchanged to the base-class ``replace``
              (e.g. when only coordinates or ``kwargs`` change);
            * a ``GridFeatures``: used as-is;
            * a 5D tensor in this grid's non-factorized memory format
              (``b_x_y_z_c``, ``b_c_x_y_z`` or ``b_c_z_x_y``) whose spatial sizes
              and channel count match this grid;
            * a 4D tensor in this grid's factorized memory format
              (``b_zc_x_y``, ``b_xc_y_z`` or ``b_yc_x_z``) whose two spatial sizes
              match this grid. The channel count is inferred from the compressed
              (folded-axis size * channels) dimension and may differ from the
              current one.

            Raw tensors are wrapped in ``GridFeatures`` that reuse this grid's
            offsets, memory format and grid shape.
        **kwargs: Forwarded to the base-class ``replace``.

    Returns:
        The geometry produced by the base-class ``replace``.

    Raises:
        AssertionError: If a tensor's layout or sizes are inconsistent with the grid.
        ValueError: If the memory format or the feature type is not supported.
    """
    if batched_features is None:
        # Nothing to convert; the base class receives None as-is.
        pass
    elif isinstance(batched_features, Tensor) and batched_features.ndim == 5:
        # Based on the memory format, we have to check the shape of the tensor
        if self.memory_format == GridMemoryFormat.b_x_y_z_c:
            in_H, in_W, in_D, in_C = batched_features.shape[1:5]
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
            assert in_D == self.grid_shape[2]
            assert in_C == self.num_channels
        elif self.memory_format == GridMemoryFormat.b_c_x_y_z:
            in_C, in_H, in_W, in_D = batched_features.shape[1:5]
            assert in_C == self.num_channels
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
            assert in_D == self.grid_shape[2]
        elif self.memory_format == GridMemoryFormat.b_c_z_x_y:
            in_C, in_D, in_H, in_W = batched_features.shape[1:5]
            assert in_C == self.num_channels
            assert in_D == self.grid_shape[2]
            assert in_H == self.grid_shape[0]
            assert in_W == self.grid_shape[1]
        else:
            raise ValueError(f"Unsupported memory format: {self.memory_format}")
        batched_features = GridFeatures(
            batched_tensor=batched_features,
            offsets=self.grid_features.offsets,
            memory_format=self.memory_format,
            grid_shape=self.grid_shape,
            num_channels=in_C,
        )
    elif isinstance(batched_features, Tensor) and batched_features.ndim == 4:
        # Factorized ("compressed") format: one spatial axis is folded into dim 1.
        assert self.memory_format in [
            GridMemoryFormat.b_zc_x_y,
            GridMemoryFormat.b_xc_y_z,
            GridMemoryFormat.b_yc_x_z,
        ], f"Unsupported memory format: {self.memory_format} for feature tensor of shape {batched_features.shape}"
        # Only the channel count may change through .replace(); the two remaining
        # spatial dims must match the grid.
        compressed_dim = batched_features.shape[1]  # folded-axis size * channels
        if self.memory_format == GridMemoryFormat.b_zc_x_y:
            assert batched_features.shape[2] == self.grid_shape[0]
            assert batched_features.shape[3] == self.grid_shape[1]
            folded_axis = 2
        elif self.memory_format == GridMemoryFormat.b_xc_y_z:
            assert batched_features.shape[2] == self.grid_shape[1]
            assert batched_features.shape[3] == self.grid_shape[2]
            folded_axis = 0
        elif self.memory_format == GridMemoryFormat.b_yc_x_z:
            assert batched_features.shape[2] == self.grid_shape[0]
            assert batched_features.shape[3] == self.grid_shape[2]
            folded_axis = 1
        else:
            raise ValueError(f"Unsupported memory format: {self.memory_format}")
        folded_size = self.grid_shape[folded_axis]
        # Floor division would silently truncate the channel count otherwise.
        assert compressed_dim % folded_size == 0, (
            f"Compressed dimension ({compressed_dim}) must be divisible by the folded "
            f"grid size ({folded_size})"
        )
        batched_features = GridFeatures(
            batched_tensor=batched_features,
            offsets=self.grid_features.offsets,
            memory_format=self.memory_format,
            grid_shape=self.grid_shape,
            num_channels=compressed_dim // folded_size,
        )
    elif isinstance(batched_features, GridFeatures):
        # Already wrapped; no action needed.
        pass
    elif isinstance(batched_features, Tensor):
        raise ValueError(f"Unsupported feature tensor shape: {batched_features.shape}")
    else:
        # Non-tensor inputs may not have ``.shape``; report the type instead.
        raise ValueError(f"Unsupported feature type: {type(batched_features).__name__}")
    return super().replace(batched_coordinates, batched_features, **kwargs)
to(device: torch.device) -> Grid
¶
Move the geometry to the specified device.
Source code in warpconvnet/geometry/types/grid.py
def to(self, device: torch.device) -> "Grid":
    """Return a new ``Grid`` whose coordinates and features live on ``device``.

    NOTE(review): this always constructs a plain ``Grid`` (not ``type(self)``) and
    forwards only the memory format, so nothing else given to the original
    constructor is carried over.
    """
    coords_on_device = self.grid_coords.to(device)
    feats_on_device = self.grid_features.to(device)
    return Grid(coords_on_device, feats_on_device, self.memory_format)
to_memory_format(memory_format: GridMemoryFormat) -> Grid
¶
Convert to a different memory format.
Source code in warpconvnet/geometry/types/grid.py
def to_memory_format(self, memory_format: GridMemoryFormat) -> "Grid":
    """Return a grid whose features use ``memory_format``.

    ``self`` is returned unchanged when it already uses that format. Otherwise the
    features are converted and a new grid is built through ``replace``.
    """
    if memory_format == self.memory_format:
        return self
    converted = self.grid_features.to_memory_format(memory_format)
    return self.replace(batched_features=converted, memory_format=memory_format)
Factor grid¶
Container that bundles multiple Grid instances with distinct factorized
memory formats so that FIGConvNet layers can operate on complementary spatial
perspectives.
warpconvnet.geometry.types.factor_grid.FactorGrid
dataclass
¶
A group of grid geometries with different factorized memory formats.
This class implements the core concept of FIGConvNet where the 3D space is represented as multiple factorized 2D grids with different memory formats.
Args: grids: List of Grid objects, each using a different factorized memory format
Source code in warpconvnet/geometry/types/factor_grid.py
@dataclass
class FactorGrid:
    """A group of grid geometries with different factorized memory formats.

    This class implements the core concept of FIGConvNet where the 3D space
    is represented as multiple factorized 2D grids with different memory formats.

    Args:
        grids: Non-empty list of ``Grid`` objects. All must share the same batch
            size and number of channels, and each must use a distinct factorized
            memory format (``b_zc_x_y``, ``b_xc_y_z`` or ``b_yc_x_z``). Grid shapes
            are not compared and may differ between factors.
        **kwargs: Extra attributes for subclasses, stored in ``_extra_attributes``.
    """

    grids: List[Grid]
    # Extra keyword arguments given to __init__ (for subclasses).
    _extra_attributes: Dict[str, Any] = field(default_factory=dict, init=True)

    def __init__(self, grids: List[Grid], **kwargs):
        """Validate and store ``grids``; keep any extra ``kwargs`` as attributes.

        Raises:
            AssertionError: If ``grids`` is empty, if batch sizes or channel counts
                differ, if a grid uses a non-factorized memory format, or if two
                grids share a memory format.
        """
        self.grids = grids
        # Validate we have at least one geometry
        assert len(grids) > 0, "At least one geometry must be provided"
        batch_size = grids[0].batch_size
        num_channels = grids[0].num_channels
        # Verify all geometries have the same batch size and number of channels.
        # Grid shapes are not compared: create_from_grid_shape accepts one shape
        # per factor, so resolutions may differ between grids.
        for geo in grids:
            assert geo.batch_size == batch_size, "All geometries must have the same batch size"
            assert (
                geo.num_channels == num_channels
            ), "All geometries must have the same number of channels"
            # Ensure each geometry uses a factorized format
            assert geo.memory_format in [
                GridMemoryFormat.b_zc_x_y,
                GridMemoryFormat.b_xc_y_z,
                GridMemoryFormat.b_yc_x_z,
            ], f"Expected factorized format, got {geo.memory_format}"
        # Check for memory format duplicates
        memory_formats = [geo.memory_format for geo in grids]
        assert len(memory_formats) == len(
            set(memory_formats)
        ), "Each geometry must have a unique memory format"
        # Extra arguments for subclasses.
        # dataclasses.replace() calls __init__ with the stored dict under the
        # "_extra_attributes" key, so flatten it back into kwargs first.
        if "_extra_attributes" in kwargs:
            attr = kwargs.pop("_extra_attributes")
            assert isinstance(attr, dict), f"_extra_attributes must be a dictionary, got {attr}"
            # Update kwargs
            for k, v in attr.items():
                kwargs[k] = v
        self._extra_attributes = kwargs

    @classmethod
    def create_from_grid_shape(
        cls,
        grid_shapes: List[Tuple[int, int, int]],
        num_channels: int,
        memory_formats: Optional[List[Union[GridMemoryFormat, str]]] = None,
        bounds: Optional[Tuple[Tensor, Tensor]] = None,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "FactorGrid":
        """Create a new factorized grid geometry with initialized geometries.

        Args:
            grid_shapes: List of grid resolutions (H, W, D), one per memory format
            num_channels: Number of feature channels
            memory_formats: List of factorized formats to use (enum members or their
                string values). Defaults to ``["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]``.
            bounds: Min and max bounds for the grid
            batch_size: Number of batches
            device: Device to create tensors on
            dtype: Data type for feature tensors
        Returns:
            Initialized factorized grid geometry
        Raises:
            AssertionError: If ``grid_shapes`` and ``memory_formats`` differ in length
                or a grid shape is not a 3-tuple.
        """
        # None sentinel instead of a shared mutable list default.
        if memory_formats is None:
            memory_formats = ["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]
        assert len(grid_shapes) == len(
            memory_formats
        ), "grid_shapes and memory_formats must have the same length"
        for grid_shape in grid_shapes:
            assert (
                isinstance(grid_shape, tuple) and len(grid_shape) == 3
            ), f"grid_shape: {grid_shape} must be a tuple of 3 integers."
        # Create one grid per (grid shape, memory format) pair.
        geometries = []
        for grid_shape, memory_format in zip(grid_shapes, memory_formats):
            if isinstance(memory_format, str):
                memory_format = GridMemoryFormat(memory_format)
            geometry = Grid.from_shape(
                grid_shape,
                num_channels,
                memory_format=memory_format,
                bounds=bounds,
                batch_size=batch_size,
                device=device,
                dtype=dtype,
            )
            geometries.append(geometry)
        return cls(geometries)

    @property
    def batch_size(self) -> int:
        """Return the batch size of the geometries."""
        return self.grids[0].batch_size

    @property
    def num_channels(self) -> int:
        """Return the number of channels in the geometries."""
        return self.grids[0].num_channels

    @property
    def device(self) -> torch.device:
        """Return the device of the first geometry."""
        return self.grids[0].device

    def to(self, device: torch.device) -> "FactorGrid":
        """Move all geometries to the specified device.

        NOTE(review): only ``grids`` are forwarded; ``_extra_attributes`` and the
        concrete subclass type are not carried over.
        """
        return FactorGrid([geo.to(device) for geo in self.grids])

    def __getitem__(self, idx: int) -> Grid:
        """Get a specific geometry from the group."""
        return self.grids[idx]

    def __len__(self) -> int:
        """Get the number of geometries in the group."""
        return len(self.grids)

    def __iter__(self):
        """Iterate over the grids."""
        return iter(self.grids)

    def __repr__(self) -> str:
        """String representation of the FactorGrid."""
        out_str = "FactorGrid("
        for grid in self.grids:
            out_str += f"\n\t{grid}"
        out_str += "\n)"
        return out_str

    def __add__(self, other: "FactorGrid") -> "FactorGrid":
        """Add two FactorGrid objects together element-wise.

        Each grid of ``self`` is paired with the grid of ``other`` that has the same
        memory format, not the one at the same index. Factors listed in a different
        order are therefore still added correctly. The result keeps ``self``'s
        ordering.

        Raises:
            AssertionError: If the lengths differ or ``other`` lacks one of
                ``self``'s memory formats.
        """
        assert len(self) == len(
            other
        ), f"FactorGrid lengths must match: {len(self)} != {len(other)}"
        new_grids = []
        for grid_a in self.grids:
            # Positional pairing could silently add mismatched factors when their
            # tensors happen to share a shape (e.g. cubic grids).
            grid_b = other.get_by_format(grid_a.memory_format)
            assert (
                grid_b is not None
            ), f"Other FactorGrid has no grid with memory format {grid_a.memory_format}"
            # Add features together using Grid.replace()
            new_features = (
                grid_a.grid_features.batched_tensor + grid_b.grid_features.batched_tensor
            )
            new_grids.append(grid_a.replace(batched_features=new_features))
        return FactorGrid(new_grids)

    def get_by_format(self, memory_format: GridMemoryFormat) -> Optional[Grid]:
        """Get a geometry with the specified memory format.

        Args:
            memory_format: The memory format to look for
        Returns:
            The geometry with the requested format, or None if not found
        """
        for geo in self.grids:
            if geo.memory_format == memory_format:
                return geo
        return None

    @property
    def shapes(self) -> List[Dict[str, Union[int, Tuple[int, ...]]]]:
        """Get shape information (``Grid.shape``) for all geometries."""
        return [geo.shape for geo in self.grids]
batch_size: int
property
¶
Return the batch size of the geometries.
device: torch.device
property
¶
Return the device of the geometries.
num_channels: int
property
¶
Return the number of channels in the geometries.
shapes: List[Dict[str, Union[int, Tuple[int, ...]]]]
property
¶
Get shape information for all geometries.
create_from_grid_shape(grid_shapes: List[Tuple[int, int, int]], num_channels: int, memory_formats: List[Union[GridMemoryFormat, str]] = ['b_zc_x_y', 'b_xc_y_z', 'b_yc_x_z'], bounds: Optional[Tuple[Tensor, Tensor]] = None, batch_size: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> FactorGrid
classmethod
¶
Create a new factorized grid geometry with initialized geometries.
Args: grid_shapes: List of grid resolutions (H, W, D) num_channels: Number of feature channels memory_formats: List of factorized formats to use bounds: Min and max bounds for the grid batch_size: Number of batches device: Device to create tensors on dtype: Data type for feature tensors
Returns: Initialized factorized grid geometry
Source code in warpconvnet/geometry/types/factor_grid.py
@classmethod
def create_from_grid_shape(
    cls,
    grid_shapes: List[Tuple[int, int, int]],
    num_channels: int,
    memory_formats: Optional[List[Union[GridMemoryFormat, str]]] = None,
    bounds: Optional[Tuple[Tensor, Tensor]] = None,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> "FactorGrid":
    """Create a new factorized grid geometry with initialized geometries.

    Args:
        grid_shapes: List of grid resolutions (H, W, D), one per memory format
        num_channels: Number of feature channels
        memory_formats: List of factorized formats to use (enum members or their
            string values). Defaults to ``["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]``.
        bounds: Min and max bounds for the grid
        batch_size: Number of batches
        device: Device to create tensors on
        dtype: Data type for feature tensors
    Returns:
        Initialized factorized grid geometry
    Raises:
        AssertionError: If ``grid_shapes`` and ``memory_formats`` differ in length
            or a grid shape is not a 3-tuple.
    """
    # None sentinel instead of a shared mutable list default.
    if memory_formats is None:
        memory_formats = ["b_zc_x_y", "b_xc_y_z", "b_yc_x_z"]
    assert len(grid_shapes) == len(
        memory_formats
    ), "grid_shapes and memory_formats must have the same length"
    for grid_shape in grid_shapes:
        assert (
            isinstance(grid_shape, tuple) and len(grid_shape) == 3
        ), f"grid_shape: {grid_shape} must be a tuple of 3 integers."
    # Create one grid per (grid shape, memory format) pair.
    geometries = []
    for grid_shape, memory_format in zip(grid_shapes, memory_formats):
        if isinstance(memory_format, str):
            memory_format = GridMemoryFormat(memory_format)
        geometry = Grid.from_shape(
            grid_shape,
            num_channels,
            memory_format=memory_format,
            bounds=bounds,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
        )
        geometries.append(geometry)
    return cls(geometries)
get_by_format(memory_format: GridMemoryFormat) -> Optional[Grid]
¶
Get a geometry with the specified memory format.
Args: memory_format: The memory format to look for
Returns: The geometry with the requested format, or None if not found
Source code in warpconvnet/geometry/types/factor_grid.py
def get_by_format(self, memory_format: GridMemoryFormat) -> Optional[Grid]:
    """Look up the grid that uses ``memory_format``.

    Args:
        memory_format: The memory format to search for
    Returns:
        The first grid in ``self.grids`` whose memory format equals
        ``memory_format``, or None when there is no such grid
    """
    matches = (grid for grid in self.grids if grid.memory_format == memory_format)
    return next(matches, None)
to(device: torch.device) -> FactorGrid
¶
Move all geometries to the specified device.
Source code in warpconvnet/geometry/types/factor_grid.py
def to(self, device: torch.device) -> "FactorGrid":
    """Return a new FactorGrid whose grids have all been moved to ``device``.

    NOTE(review): only ``grids`` are forwarded; ``_extra_attributes`` and the
    concrete subclass type are not carried over.
    """
    moved = []
    for grid in self.grids:
        moved.append(grid.to(device))
    return FactorGrid(moved)