Neural Network Functionals¶
warpconvnet.nn.functional
¶
Functional operators used by WarpConvNet modules.
The package intentionally keeps implementations in separate submodules
(sparse_conv, point_pool, etc.) to avoid large monolithic files.
Defining this initializer makes the namespace explicit so documentation
generators like mkdocstrings can traverse it.
Functionals¶
Sparse convolution¶
warpconvnet.nn.functional.sparse_conv
¶
Sparse depthwise convolution¶
warpconvnet.nn.functional.sparse_conv_depth
¶
spatially_sparse_depthwise_conv(in_features: Float[Tensor, 'N C'], weight: Float[Tensor, 'K C'], kernel_map: IntSearchResult, num_out_coords: int, fwd_algo: Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE], None] = None, bwd_algo: Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE], None] = None, compute_dtype: Optional[torch.dtype] = None) -> Float[Tensor, 'M C']
¶
Perform spatially sparse depthwise convolution.
Args: in_features: Input features of shape (N, C) weight: Depthwise convolution weights of shape (K, C) kernel_map: Kernel mapping from IntSearchResult num_out_coords: Number of output coordinates fwd_algo: Forward algorithm(s) to use. Can be a single algorithm or a list to limit benchmark search space. If None, uses environment variable WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE (default: explicit) bwd_algo: Backward algorithm(s) to use. Can be a single algorithm or a list to limit benchmark search space. If None, uses environment variable WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE (default: explicit) compute_dtype: Computation dtype (defaults to input dtype)
Returns: Output features of shape (M, C)
Source code in warpconvnet/nn/functional/sparse_conv_depth.py
def spatially_sparse_depthwise_conv(
    in_features: Float[Tensor, "N C"],
    weight: Float[Tensor, "K C"],
    kernel_map: IntSearchResult,
    num_out_coords: int,
    fwd_algo: Union[
        SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE], None
    ] = None,
    bwd_algo: Union[
        SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE], None
    ] = None,
    compute_dtype: Optional[torch.dtype] = None,
) -> Float[Tensor, "M C"]:
    """
    Run a spatially sparse depthwise convolution.

    Args:
        in_features: Input features of shape (N, C).
        weight: Depthwise convolution weights of shape (K, C).
        kernel_map: Kernel mapping (an IntSearchResult) linking inputs to outputs.
        num_out_coords: Number of output coordinates (M).
        fwd_algo: Forward algorithm, or a list of algorithms that restricts the
            benchmark search space. When None, falls back to the
            WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE setting (environment
            variable; default: explicit).
        bwd_algo: Backward algorithm(s), same convention as ``fwd_algo``; when
            None, falls back to WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE.
        compute_dtype: Dtype used for the computation (defaults to the input dtype).

    Returns:
        Output features of shape (M, C).
    """
    # Explicit arguments win; otherwise use the environment-configured modes.
    resolved_fwd_algo = (
        WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE if fwd_algo is None else fwd_algo
    )
    resolved_bwd_algo = (
        WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE if bwd_algo is None else bwd_algo
    )
    conv_args = (
        in_features,
        weight,
        kernel_map,
        num_out_coords,
        resolved_fwd_algo,
        resolved_bwd_algo,
        compute_dtype,
    )
    return UnifiedSpatiallySparseDepthwiseConvFunction.apply(*conv_args)
Sparse pooling¶
warpconvnet.nn.functional.sparse_pool
¶
sparse_avg_pool(voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Optional[Union[int, Tuple[int, ...]]] = None) -> Voxels
¶
Average pooling for spatially sparse tensors.
Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_avg_pool(
    voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Voxels:
    """
    Average pooling for spatially sparse tensors.

    Thin wrapper around ``sparse_reduce`` using a mean reduction.

    Args:
        voxels: Input sparse voxels.
        kernel_size: Pooling window, an int or one value per spatial dimension.
        stride: Pooling stride; ``sparse_reduce`` uses ``kernel_size`` when None.

    Returns:
        The pooled voxels.
    """
    mean_reduction = REDUCTIONS.MEAN
    return sparse_reduce(voxels, kernel_size, stride=stride, reduction=mean_reduction)
sparse_max_pool(voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Optional[Union[int, Tuple[int, ...]]] = None) -> Voxels
¶
Max pooling for spatially sparse tensors.
Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_max_pool(
    voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Voxels:
    """
    Max pooling for spatially sparse tensors.

    Thin wrapper around ``sparse_reduce`` using a max reduction.

    Args:
        voxels: Input sparse voxels.
        kernel_size: Pooling window, an int or one value per spatial dimension.
        stride: Pooling stride; ``sparse_reduce`` uses ``kernel_size`` when None.

    Returns:
        The pooled voxels.
    """
    return sparse_reduce(
        voxels,
        kernel_size=kernel_size,
        stride=stride,
        reduction=REDUCTIONS.MAX,
    )
sparse_reduce(voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Optional[Union[int, Tuple[int, ...]]] = None, reduction: Union[REDUCTIONS, str] = REDUCTIONS.MAX, order: POINT_ORDERING = POINT_ORDERING.RANDOM) -> Voxels
¶
Pool a spatially sparse tensor with the given reduction (max by default).
Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_reduce(
    voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Optional[Union[int, Tuple[int, ...]]] = None,
    reduction: Union[REDUCTIONS, str] = REDUCTIONS.MAX,
    order: POINT_ORDERING = POINT_ORDERING.RANDOM,
) -> Voxels:
    """
    Pool a spatially sparse tensor with an arbitrary reduction.

    Output coordinates are the input coordinates strided by ``stride``. The
    input voxels that the kernel map (window ``kernel_size``) associates with
    each output voxel are combined with ``reduction``. The kernel map is stored
    in ``voxels.cache`` under a key that ``sparse_unpool`` rebuilds later.

    Args:
        voxels: Input sparse voxels.
        kernel_size: Pooling window, an int or one value per spatial dimension.
        stride: Pooling stride; defaults to ``kernel_size``.
        reduction: A ``REDUCTIONS`` member or its string value (default: max).
        order: Point ordering forwarded to ``stride_coords``.

    Returns:
        Voxels at the strided coordinates with reduced features. The tensor
        stride is the input tensor stride multiplied element-wise by ``stride``.
        Output coordinates that receive no input are zero-filled (a warning is
        emitted).
    """
    if isinstance(reduction, str):
        reduction = REDUCTIONS(reduction)
    if stride is None:
        stride = kernel_size
    ndim = voxels.num_spatial_dims
    stride = ntuple(stride, ndim=ndim)
    kernel_size = ntuple(kernel_size, ndim=ndim)

    in_tensor_stride = voxels.stride
    if in_tensor_stride is None:
        in_tensor_stride = ntuple(1, ndim=ndim)
    out_tensor_stride = tuple(o * s for o, s in zip(stride, in_tensor_stride))

    batch_indexed_in_coords = voxels.batch_indexed_coordinates
    batch_indexed_out_coords, output_offsets = stride_coords(
        batch_indexed_in_coords, stride, order=order
    )

    # Function-local import kept from the original (presumably avoids an import cycle).
    from warpconvnet.nn.functional.sparse_conv import STRIDED_CONV_MODE

    kernel_map_cache_key = IntSearchCacheKey(
        kernel_size=kernel_size,
        kernel_dilation=ntuple(1, ndim=ndim),
        transposed=False,
        generative=False,
        stride_mode=str(STRIDED_CONV_MODE.STRIDE_ONLY),
        skip_symmetric_kernel_map=False,
        in_offsets=voxels.offsets,
        out_offsets=output_offsets,
    )
    kernel_map = None
    if voxels.cache is not None:
        kernel_map = voxels.cache.get(kernel_map_cache_key)
    if kernel_map is None:
        # Find mapping from in to out
        kernel_map = generate_kernel_map(
            batch_indexed_in_coords,
            batch_indexed_out_coords,
            in_to_out_stride_ratio=stride,
            kernel_size=kernel_size,
            kernel_dilation=ntuple(1, ndim=ndim),
            skip_symmetric_kernel_map=False,
        )
        if voxels.cache is None:
            voxels._extra_attributes["_cache"] = IntSearchCache()
        voxels.cache.put(kernel_map_cache_key, kernel_map)

    in_maps, unique_out_maps, map_offsets = kernel_map.to_csr()
    in_features = voxels.feature_tensor
    device = in_features.device
    out_features = row_reduction(in_features[in_maps], map_offsets.to(device), reduction)

    num_out_coords = batch_indexed_out_coords.shape[0]
    if len(unique_out_maps) != num_out_coords:
        warnings.warn(
            f"Some output coordinates don't have any input maps. {num_out_coords - len(unique_out_maps)} output coordinates are missing.",
            stacklevel=2,
        )
        # cchoy: This is a rare case where some output coordinates don't have any input maps.
        # We need to zero out the features for those coordinates.
        # Allocate from out_features so dtype/device always match the rows
        # scattered into it below.
        new_out_features = out_features.new_zeros((num_out_coords, *out_features.shape[1:]))
        new_out_features[unique_out_maps] = out_features
        out_features = new_out_features

    output_offsets = output_offsets.cpu()
    return voxels.replace(
        # Drop column 0 (the batch index, per the ``batch_indexed_*`` naming).
        batched_coordinates=IntCoords(
            batch_indexed_out_coords[:, 1:],
            output_offsets,
        ),
        batched_features=out_features,
        stride=out_tensor_stride,
    )
sparse_unpool(pooled_voxels: Voxels, unpooled_voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]], concat_unpooled_voxels: bool = False) -> Voxels
¶
Unpooling for spatially sparse tensors.
Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_unpool(
    pooled_voxels: Voxels,
    unpooled_voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]],
    concat_unpooled_voxels: bool = False,
) -> Voxels:
    """
    Unpooling for spatially sparse tensors.

    Gathers pooled features back onto ``unpooled_voxels`` using the kernel map
    cached on ``pooled_voxels``. That map is looked up with the same key that
    ``sparse_reduce`` builds for the forward pooling.

    Args:
        pooled_voxels: Result of the matching pooling; must carry the kernel-map cache.
        unpooled_voxels: Voxels before pooling; their coordinates are kept.
        kernel_size: Kernel size used for pooling (part of the cache key).
        stride: Stride used for pooling. It is normalised but not otherwise
            used; the pooled layout enters the key via ``pooled_voxels.offsets``.
        concat_unpooled_voxels: If True, place ``unpooled_voxels``' own features
            before the unpooled features along the last dimension.

    Returns:
        ``unpooled_voxels`` with replaced features.

    Raises:
        AssertionError: If the cache or the matching kernel map is missing.
    """
    ndim = pooled_voxels.num_spatial_dims
    stride = ntuple(stride, ndim=ndim)
    kernel_size = ntuple(kernel_size, ndim=ndim)
    # use the cache for the transposed case to get the kernel map
    from warpconvnet.nn.functional.sparse_conv import STRIDED_CONV_MODE

    kernel_map_cache_key = IntSearchCacheKey(
        kernel_size=kernel_size,
        kernel_dilation=ntuple(1, ndim=ndim),
        transposed=False,
        generative=False,
        stride_mode=str(STRIDED_CONV_MODE.STRIDE_ONLY),
        skip_symmetric_kernel_map=False,
        in_offsets=unpooled_voxels.offsets,
        out_offsets=pooled_voxels.offsets,
    )
    assert pooled_voxels.cache is not None, (
        "sparse_unpool requires pooled_voxels to carry the kernel-map cache "
        "populated when it was pooled (see sparse_reduce)"
    )
    kernel_map = pooled_voxels.cache.get(kernel_map_cache_key)
    assert kernel_map is not None, (
        "No cached kernel map matches these voxels and kernel_size; pool with "
        "sparse_reduce using the same kernel_size before unpooling"
    )
    # Reverse the pooling map: in_maps index unpooled voxels, out_maps pooled ones.
    unpooled_maps = kernel_map.in_maps
    pooled_maps = kernel_map.out_maps
    # Order pairs by unpooled index so row i belongs to unpooled voxel i.
    # NOTE(review): assumes every unpooled voxel appears exactly once in the
    # map (e.g. kernel_size == stride) — confirm for overlapping windows.
    perm = torch.argsort(unpooled_maps)
    rep_feats = pooled_voxels.feature_tensor[pooled_maps[perm]]
    if concat_unpooled_voxels:
        rep_feats = torch.cat([unpooled_voxels.feature_tensor, rep_feats], dim=-1)
    return unpooled_voxels.replace(batched_features=rep_feats)
Sparse ops helpers¶
warpconvnet.nn.functional.sparse_ops
¶
cat_spatially_sparse_tensors(*sparse_tensors: Sequence[Voxels]) -> Voxels
¶
Concatenate a list of spatially sparse tensors.
Source code in warpconvnet/nn/functional/sparse_ops.py
def cat_spatially_sparse_tensors(
    *sparse_tensors: "Voxels",  # noqa: F821
) -> "Voxels":  # noqa: F821
    """
    Concatenate a list of spatially sparse tensors channel-wise.

    All inputs must have identical batch offsets. The result is built from the
    first input via ``replace`` with the features of all inputs concatenated
    along the last dimension, in argument order.

    Args:
        *sparse_tensors: Sparse tensors to concatenate.

    Returns:
        The first tensor with concatenated features.

    Raises:
        ValueError: If no tensors are given or their offsets differ.
    """
    if not sparse_tensors:
        raise ValueError("At least one sparse tensor is required")
    # Check that all sparse tensors have the same offsets
    offsets = sparse_tensors[0].offsets
    for sparse_tensor in sparse_tensors[1:]:
        # Offsets are integer counts, so compare exactly: allclose's relative
        # tolerance accepts off-by-one offsets once they reach ~1e5.
        # torch.equal also returns False (instead of raising) on shape mismatch.
        if not torch.equal(sparse_tensor.offsets.to(offsets), offsets):
            raise ValueError("All sparse tensors must have the same offsets")
    # Concatenate the features tensors
    features_tensor = torch.cat(
        [sparse_tensor.feature_tensor for sparse_tensor in sparse_tensors], dim=-1
    )
    return sparse_tensors[0].replace(batched_features=features_tensor)
prune_spatially_sparse_tensor(spatial_tensor: Geometry, mask: Bool[Tensor, N]) -> Geometry
¶
Prune a spatially sparse tensor using a boolean mask.
Args: spatial_tensor: Geometry instance whose coordinates/features will be filtered. mask: Boolean mask of shape (N,) aligned with the flattened coordinates/features.
Returns: New Geometry instance containing only the entries where mask == True.
Source code in warpconvnet/nn/functional/sparse_ops.py
def prune_spatially_sparse_tensor(
    spatial_tensor: Geometry,
    mask: Bool[Tensor, "N"],  # noqa: F821
) -> Geometry:
    """
    Keep only the entries of a spatially sparse tensor selected by a boolean mask.

    Args:
        spatial_tensor: Geometry whose coordinates and features are filtered.
        mask: Mask of shape (N,) aligned with the flattened coordinates/features.
            It is moved to the geometry's device, and non-bool masks are
            converted with ``.bool()``.

    Returns:
        A new Geometry (built with ``replace``) holding only the rows where
        ``mask`` is True.

    Raises:
        ValueError: If the mask length differs from the number of coordinates.
        TypeError: If the coordinate container has no ``prune`` method.
    """
    num_coords = spatial_tensor.coordinate_tensor.shape[0]
    if mask.shape[0] != num_coords:
        raise ValueError(
            f"Mask length {mask.shape[0]} must match number of coordinates {num_coords}"
        )

    mask = mask.to(spatial_tensor.device)
    mask = mask if mask.dtype == torch.bool else mask.bool()

    coords = spatial_tensor.batched_coordinates
    if not hasattr(coords, "prune"):
        raise TypeError(f"{coords.__class__.__name__} does not implement prune()")

    # Coordinates are pruned first, then features (same order as before).
    return spatial_tensor.replace(
        batched_coordinates=coords.prune(mask),
        batched_features=spatial_tensor.feature_tensor[mask],
    )
Point pooling¶
warpconvnet.nn.functional.point_pool
¶
point_pool(pc: Points, reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR], downsample_max_num_points: Optional[int] = None, downsample_voxel_size: Optional[float] = None, return_type: Literal['point', 'voxel'] = 'point', average_pooled_coordinates: bool = False, return_neighbor_search_result: bool = False, return_to_unique: bool = False, unique_method: Literal['torch', 'ravel', 'morton'] = 'torch') -> Geometry
¶
Pool points in a point cloud. When downsample_max_num_points is provided, the point cloud is downsampled to that number of points. Otherwise the point cloud is downsampled onto a voxel grid of size downsample_voxel_size. When both are provided, downsample_max_num_points takes precedence and downsample_voxel_size is ignored.
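Args: pc: Input Points. reduction: Reduction type (a REDUCTIONS member or its string value). downsample_max_num_points: Number of points to downsample to. downsample_voxel_size: Voxel size to downsample to. return_type: "point" to return Points, "voxel" to return Voxels. average_pooled_coordinates: Average the coordinates of pooled points (not supported with return_type="voxel"). return_neighbor_search_result: Also return the neighbor search result. return_to_unique: Also return the ToUnique mapping. unique_method: Uniquing method used for voxel downsampling. Returns: Points or Voxels, optionally paired with the requested mapping.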
Source code in warpconvnet/nn/functional/point_pool.py
def point_pool(
    pc: "Points",  # noqa: F821
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR],
    downsample_max_num_points: Optional[int] = None,
    downsample_voxel_size: Optional[float] = None,
    return_type: Literal["point", "voxel"] = "point",
    average_pooled_coordinates: bool = False,
    return_neighbor_search_result: bool = False,
    return_to_unique: bool = False,
    unique_method: Literal["torch", "ravel", "morton"] = "torch",
) -> Geometry:
    """
    Pool points in a point cloud.

    When ``downsample_max_num_points`` is provided, the point cloud is
    downsampled to that number of points. Otherwise the point cloud is
    downsampled onto a voxel grid of size ``downsample_voxel_size``. When both
    are provided, ``downsample_max_num_points`` takes precedence and
    ``downsample_voxel_size`` is ignored.

    Args:
        pc: Input Points.
        reduction: Reduction type (a ``REDUCTIONS`` member or its string value).
            On the voxel path, ``REDUCTIONS.RANDOM`` selects random sampling.
        downsample_max_num_points: Number of points to downsample to.
        downsample_voxel_size: Voxel size to downsample to.
        return_type: ``"point"`` to return Points, ``"voxel"`` to return Voxels.
        average_pooled_coordinates: Average the coordinates of pooled points.
            Not supported with ``return_type="voxel"``; only used on the
            voxel-reduction path.
        return_neighbor_search_result: Also return a neighbor search result.
        return_to_unique: Also return the ``ToUnique`` mapping (voxel path only).
        unique_method: Uniquing method forwarded to ``voxel_downsample_csr_mapping``.

    Returns:
        Points or Voxels. On the voxel-reduction path, ``return_to_unique``
        yields ``(geometry, to_unique)``; otherwise ``return_neighbor_search_result``
        yields ``(geometry, RealSearchResult)``.

    Raises:
        AssertionError: On invalid argument combinations (see checks below).
    """
    from warpconvnet.geometry.types.points import Points
    from warpconvnet.geometry.types.voxels import Voxels

    if isinstance(reduction, str):
        reduction = REDUCTIONS(reduction)
    # assert at least one of the two is provided
    assert (
        downsample_max_num_points is not None or downsample_voxel_size is not None
    ), "Either downsample_max_num_points or downsample_voxel_size must be provided."
    assert return_type in [
        "point",
        "voxel",
    ], "return_type must be either point or voxel."
    if return_type == "voxel":
        assert (
            not average_pooled_coordinates
        ), "averaging pooled coordinates is not supported for Voxels return type"
        RETURN_CLS = Voxels
    else:
        RETURN_CLS = Points

    # downsample_max_num_points takes precedence over downsample_voxel_size.
    if downsample_max_num_points is not None:
        assert (
            not return_to_unique
        ), "return_to_unique must be False when downsample_max_num_points is provided."
        return _pool_by_max_num_points(
            pc,
            reduction,
            downsample_max_num_points,
            return_type,
            return_neighbor_search_result,
        )

    if reduction == REDUCTIONS.RANDOM:
        assert (
            not return_to_unique
        ), "return_to_unique must be False when reduction is RANDOM."
        assert (
            not return_neighbor_search_result
        ), "return_neighbor_search_result must be False when reduction is RANDOM."
        return _pool_by_random_sample(
            pc,
            downsample_voxel_size,
            return_type,
        )

    # voxel downsample
    (
        unique_coords,
        unique_offsets,
        to_csr_indices,
        to_csr_offsets,
        to_unique,
    ) = voxel_downsample_csr_mapping(
        batched_points=pc.coordinate_tensor,
        offsets=pc.offsets,
        voxel_size=downsample_voxel_size,
        unique_method=unique_method,
    )
    down_features = row_reduction(
        pc.feature_tensor[to_csr_indices],
        to_csr_offsets,
        reduction=reduction,
    )
    batched_coords = _generate_batched_coords(
        pc.coordinate_tensor,
        return_type,
        to_csr_indices,
        to_csr_offsets,
        to_unique.to_unique_indices,
        unique_offsets,
        downsample_voxel_size,
        average_pooled_coordinates,
    )
    out_sf = RETURN_CLS(
        batched_coordinates=batched_coords,
        batched_features=down_features,
        voxel_size=downsample_voxel_size,
    )
    if return_to_unique:
        return out_sf, to_unique
    if return_neighbor_search_result:
        return out_sf, RealSearchResult(to_unique.to_unique_indices, unique_offsets)
    return out_sf
point_pool_by_code(pc: Points, code: Int[Tensor, N], reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR], average_pooled_coordinates: bool = False, return_to_unique: bool = False) -> Points
¶
Pool points based on a user-provided clustering code.
Source code in warpconvnet/nn/functional/point_pool.py
def point_pool_by_code(
    pc: "Points",  # noqa: F821
    code: Int[Tensor, "N"],  # noqa: F821
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR],
    average_pooled_coordinates: bool = False,
    return_to_unique: bool = False,
) -> "Points":  # noqa: F821
    """Pool points based on a user-provided clustering code.

    Points sharing a ``code`` value are merged into one output point.

    Args:
        pc: Input Points.
        code: One integer code per point; points with equal codes are pooled.
        reduction: Reduction for the features (a ``REDUCTIONS`` member or its
            string value).
        average_pooled_coordinates: If True, output coordinates are the mean of
            each group's coordinates; otherwise the coordinates of one
            representative point per group are used.
        return_to_unique: Also return the ``ToUnique`` object used for grouping.

    Returns:
        The pooled Points, or ``(points, to_unique)`` when ``return_to_unique``.
    """
    if isinstance(reduction, str):
        reduction = REDUCTIONS(reduction)

    # ``to_unique_indices`` is read below on every path (representative
    # coordinates and the output offsets), not only when the caller wants the
    # ToUnique object back, so always request it.
    # NOTE(review): assumes the ToUnique flag only controls whether these
    # indices are computed — confirm against ToUnique.
    to_unique = ToUnique(return_to_unique_indices=True)
    unique_code = to_unique.to_unique(code)

    # get the coordinates
    if average_pooled_coordinates:
        coords = row_reduction(
            pc.coordinate_tensor[to_unique.to_csr_indices],
            to_unique.to_csr_offsets,
            reduction=REDUCTIONS.MEAN,
        )
    else:
        coords = pc.coordinate_tensor[to_unique.to_unique_indices]

    # get the features
    features = row_reduction(
        pc.feature_tensor[to_unique.to_csr_indices],
        to_unique.to_csr_offsets,
        reduction=reduction,
    )

    # get the offsets
    offsets = offsets_from_offsets(
        pc.offsets,
        to_unique.to_unique_indices,
        device="cpu",
    )

    out_pc = pc.replace(
        coordinate_tensor=coords,
        feature_tensor=features,
        offsets=offsets,
        code=unique_code,
    )
    if return_to_unique:
        return out_pc, to_unique
    return out_pc
Point unpooling¶
warpconvnet.nn.functional.point_unpool
¶
point_unpool(pooled_pc: BatchedSpatialFeatures, unpooled_pc: Points, concat_unpooled_pc: bool, unpooling_mode: Optional[Union[str, FEATURE_UNPOOLING_MODE]] = FEATURE_UNPOOLING_MODE.REPEAT, to_unique: Optional[ToUnique] = None) -> Points
¶
Unpool features to a denser set of points.
Source code in warpconvnet/nn/functional/point_unpool.py
def point_unpool(
    pooled_pc: "BatchedSpatialFeatures",  # noqa: F821
    unpooled_pc: "Points",  # noqa: F821
    concat_unpooled_pc: bool,
    unpooling_mode: Optional[
        Union[str, FEATURE_UNPOOLING_MODE]
    ] = FEATURE_UNPOOLING_MODE.REPEAT,
    to_unique: Optional[ToUnique] = None,
) -> "Points":  # noqa: F821
    """Unpool features from a pooled geometry onto a denser set of points.

    Args:
        pooled_pc: Pooled geometry whose features are spread back out.
        unpooled_pc: Dense target points; only their features are replaced.
        concat_unpooled_pc: If True, append ``unpooled_pc``'s own features after
            the unpooled features along the last dimension.
        unpooling_mode: Unpooling strategy forwarded to ``_unpool_features``.
        to_unique: Optional mapping from an earlier pooling call, forwarded to
            ``_unpool_features``.

    Returns:
        ``unpooled_pc`` with its features replaced.
    """
    features = _unpool_features(
        pooled_pc=pooled_pc,
        unpooled_pc=unpooled_pc,
        to_unique=to_unique,
        unpooling_mode=unpooling_mode,
    )
    if not concat_unpooled_pc:
        return unpooled_pc.replace(batched_features=features)
    merged = torch.cat((features, unpooled_pc.feature_tensor), dim=-1)
    return unpooled_pc.replace(batched_features=merged)
Grid convolution¶
warpconvnet.nn.functional.grid_conv
¶
grid_conv(grid: Grid, weight: Float[Tensor, 'C_out C_in D H W'], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, bias: bool = True) -> Grid
¶
3D Convolution on a Grid geometry type.
It is a simple wrapper on torch.nn.functional.conv3d. The output grid shape is computed as follows:
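$$
\begin{aligned}
D_{\text{out}} &= \left\lfloor \frac{D_{\text{in}} + 2\,\text{padding}[0] - \text{dilation}[0]\,(\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} \right\rfloor + 1 \\
H_{\text{out}} &= \left\lfloor \frac{H_{\text{in}} + 2\,\text{padding}[1] - \text{dilation}[1]\,(\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} \right\rfloor + 1 \\
W_{\text{out}} &= \left\lfloor \frac{W_{\text{in}} + 2\,\text{padding}[2] - \text{dilation}[2]\,(\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} \right\rfloor + 1
\end{aligned}
$$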
Args: grid: Grid geometry type weight: Weight tensor stride: Stride padding: Padding dilation: Dilation bias: Bias
Returns: Grid: Output grid
Source code in warpconvnet/nn/functional/grid_conv.py
def grid_conv(
    grid: Grid,
    weight: Float[Tensor, "C_out C_in D H W"],  # noqa: F821
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    bias: Union[bool, Tensor, None] = True,
) -> Grid:
    """
    3D Convolution on a Grid geometry type.

    It is a simple wrapper on torch.nn.functional.conv3d, where kernel_size is
    weight.shape[2:]. The output grid shape is computed as follows:
    D_out = ((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0]) + 1
    H_out = ((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1]) + 1
    W_out = ((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) // stride[2]) + 1

    Args:
        grid: Input grid. It is converted to GridMemoryFormat.b_c_z_x_y (with a
            warning) if it uses another memory format.
        weight: Weight tensor of shape (C_out, C_in, kD, kH, kW).
        stride: Stride, an int or one value per spatial dimension.
        padding: Padding, an int or one value per spatial dimension.
        dilation: Dilation, an int or one value per spatial dimension.
        bias: Optional bias tensor of shape (C_out,). conv3d only accepts a
            tensor or None, and this function has no bias parameter of its own,
            so a bool (including the default True) adds no bias.

    Returns:
        Grid: Output grid in GridMemoryFormat.b_c_z_x_y, with coordinates
        rebuilt from the output shape and the input bounds.
    """
    # F.conv3d consumes channels-first input, i.e. the b_c_z_x_y layout.
    if grid.memory_format != GridMemoryFormat.b_c_z_x_y:
        warnings.warn(
            f"Input grid memory format is {grid.memory_format}, converting to {GridMemoryFormat.b_c_z_x_y}"
        )
        grid = grid.to_memory_format(GridMemoryFormat.b_c_z_x_y)

    # F.conv3d accepts only a Tensor or None for ``bias``; forwarding a bool
    # raises a TypeError.
    bias_tensor = bias if isinstance(bias, Tensor) else None
    output_tensor = F.conv3d(grid.features, weight, bias_tensor, stride, padding, dilation)

    # Spatial dims of the b_c_z_x_y output, in (Z, X, Y) order per the format name.
    out_z, out_x, out_y = tuple(output_tensor.shape[2:])
    # Stride/padding/dilation change the spatial extent, so the coordinates are
    # regenerated from the output shape while keeping the input bounds.
    # NOTE(review): assumes GridCoords.from_shape takes (X, Y, Z) order — confirm.
    return Grid(
        batched_coordinates=GridCoords.from_shape(
            grid_shape=(out_x, out_y, out_z),
            bounds=grid.bounds,
            batch_size=grid.batch_size,
            device=grid.device,
        ),
        batched_features=output_tensor,
        memory_format=GridMemoryFormat.b_c_z_x_y,
    )
Factor grid¶
warpconvnet.nn.functional.factor_grid
¶
factor_grid_cat(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid
¶
Concatenate features from two FactorGrid objects.
Args: factor_grid1: First FactorGrid factor_grid2: Second FactorGrid
Returns: FactorGrid with concatenated features
Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_cat(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
    """Concatenate the features of two FactorGrid objects grid by grid.

    Args:
        factor_grid1: First FactorGrid. Its grids supply the memory format and
            every non-feature attribute of the result (via ``replace``).
        factor_grid2: Second FactorGrid, with the same number of grids.

    Returns:
        FactorGrid whose i-th grid holds grid1[i]'s features followed by
        grid2[i]'s along the channel dimension.

    Raises:
        AssertionError: If the two FactorGrids have different lengths.
        ValueError: If a grid uses an unsupported memory format.
    """
    len1, len2 = len(factor_grid1), len(factor_grid2)
    assert len1 == len2, f"FactorGrid lengths must match: {len1} != {len2}"

    # Factorized formats fold the channel axis into one spatial axis (dim 1).
    factorized_formats = (
        GridMemoryFormat.b_zc_x_y,
        GridMemoryFormat.b_xc_y_z,
        GridMemoryFormat.b_yc_x_z,
    )

    def _channel_dim(memory_format) -> int:
        # Dimension along which channels are concatenated for this format.
        if memory_format == GridMemoryFormat.b_x_y_z_c:
            return -1
        if memory_format == GridMemoryFormat.b_c_x_y_z or memory_format in factorized_formats:
            return 1
        raise ValueError(f"Unsupported memory format: {memory_format}")

    merged_grids = []
    for left, right in zip(factor_grid1, factor_grid2):
        left_feats = left.grid_features.batched_tensor
        right_feats = right.grid_features.batched_tensor
        # NOTE(review): for factorized formats this stacks all of left's
        # combined rows before right's; confirm that matches how the combined
        # channel*axis dimension is un-flattened downstream.
        merged = torch.cat([left_feats, right_feats], dim=_channel_dim(left.memory_format))
        merged_grids.append(left.replace(batched_features=merged))
    return FactorGrid(merged_grids)
factor_grid_intra_communication(factor_grid: FactorGrid, communication_types: List[Literal['sum', 'mul']] = ['sum'], cat_fn: Optional[Callable] = None) -> FactorGrid
¶
Perform intra-communication between grids in a FactorGrid with multiple communication types.
Args: factor_grid: Input FactorGrid communication_types: List of communication types to apply cat_fn: Function to concatenate results from multiple communication types
Returns: FactorGrid with inter-grid communication applied
Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_intra_communication(
    factor_grid: FactorGrid,
    communication_types: Optional[List[Literal["sum", "mul"]]] = None,
    cat_fn: Optional[Callable] = None,
) -> FactorGrid:
    """Perform intra-communication between grids in a FactorGrid with multiple communication types.

    Args:
        factor_grid: Input FactorGrid.
        communication_types: One or two communication types ("sum" / "mul").
            Defaults to ``["sum"]``.
        cat_fn: Function combining the results of two communication types;
            defaults to ``factor_grid_cat``. Unused with a single type.

    Returns:
        FactorGrid with intra-grid communication applied.

    Raises:
        ValueError: If more than two (or zero) communication types are given.
    """
    # None instead of a mutable list default, which would be shared across calls.
    if communication_types is None:
        communication_types = ["sum"]

    num_types = len(communication_types)
    if num_types == 1:
        return _factor_grid_intra_communication(factor_grid, communication_types[0])
    if num_types == 2:
        # Apply both communication types and concatenate
        result1 = _factor_grid_intra_communication(factor_grid, communication_types[0])
        result2 = _factor_grid_intra_communication(factor_grid, communication_types[1])
        combine = factor_grid_cat if cat_fn is None else cat_fn
        return combine(result1, result2)
    raise ValueError(f"Unsupported number of communication types: {num_types}")
factor_grid_pool(factor_grid: FactorGrid, pooling_type: Literal['max', 'mean', 'attention'] = 'max', pool_op: Optional[Callable] = None, attention_layer: Optional[Callable] = None) -> Tensor
¶
Pool features from FactorGrid to a single tensor.
Args: factor_grid: Input FactorGrid pooling_type: Type of pooling ("max", "mean", "attention") pool_op: Pooling operation function attention_layer: Attention layer for attention pooling
Returns: Pooled tensor of shape [B, total_channels]
Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_pool(
    factor_grid: FactorGrid,
    pooling_type: Literal["max", "mean", "attention"] = "max",
    pool_op: Optional[Callable] = None,
    attention_layer: Optional[Callable] = None,
) -> Tensor:
    """Pool features from FactorGrid to a single tensor.

    Each grid's 4-D features (B, A*C, P, Q) are flattened over the two trailing
    spatial dims to (B, A*C, P*Q) and pooled over that last dim. The per-grid
    results are then concatenated along the last dimension.

    Args:
        factor_grid: Input FactorGrid (factorized memory formats only).
        pooling_type: Type of pooling ("max", "mean", "attention").
        pool_op: Custom pooling function for "max"/"mean"; its output is
            squeezed on the last dim.
        attention_layer: Attention layer called as ``layer(q, k, v)`` and
            returning a tuple whose first item is used; without it,
            "attention" falls back to mean pooling.

    Returns:
        Pooled tensor of shape [B, total_channels].

    Raises:
        ValueError: On an unsupported pooling type or memory format, or if the
            FactorGrid is empty.
    """
    if pooling_type not in ("max", "mean", "attention"):
        raise ValueError(f"Unsupported pooling type: {pooling_type}")

    pooled_features = []
    for grid in factor_grid:
        # Get the features tensor
        features = grid.grid_features.batched_tensor  # Shape depends on memory format
        fmt = grid.memory_format
        # Flatten the spatial dims. reshape (not view) because the batched
        # tensor may be non-contiguous, e.g. after a permute.
        if fmt == GridMemoryFormat.b_zc_x_y:
            # B, Z*C, X, Y -> B, Z*C, X*Y
            B, ZC, _, _ = features.shape
            features_flat = features.reshape(B, ZC, -1)
        elif fmt == GridMemoryFormat.b_xc_y_z:
            # B, X*C, Y, Z -> B, X*C, Y*Z
            B, XC, _, _ = features.shape
            features_flat = features.reshape(B, XC, -1)
        elif fmt == GridMemoryFormat.b_yc_x_z:
            # B, Y*C, X, Z -> B, Y*C, X*Z
            B, YC, _, _ = features.shape
            features_flat = features.reshape(B, YC, -1)
        else:
            raise ValueError(f"Unsupported memory format for pooling: {fmt}")

        if pooling_type == "attention":
            if attention_layer is not None:
                # Convert to B, N, C for attention
                features_t = features_flat.transpose(1, 2)  # B, N, C
                attended, _ = attention_layer(features_t, features_t, features_t)
                pooled = attended.mean(dim=1)  # B, C
            else:
                # Fallback: simple mean pooling
                pooled = F.adaptive_avg_pool1d(features_flat, 1).squeeze(-1)
        elif pool_op is not None:
            pooled = pool_op(features_flat).squeeze(-1)  # B, channels
        elif pooling_type == "max":
            pooled = F.adaptive_max_pool1d(features_flat, 1).squeeze(-1)
        else:  # mean
            pooled = F.adaptive_avg_pool1d(features_flat, 1).squeeze(-1)
        pooled_features.append(pooled)

    if not pooled_features:
        raise ValueError("factor_grid must contain at least one grid")
    # Concatenate pooled features from all grids
    return torch.cat(pooled_features, dim=-1)
factor_grid_transform(factor_grid: FactorGrid, transform_fn: Callable[[Tensor], Tensor]) -> FactorGrid
¶
Apply a transform function to all grids in a FactorGrid.
Args: factor_grid: Input FactorGrid transform_fn: Function to apply to each grid's features
Returns: FactorGrid with transformed features
Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_transform(
    factor_grid: FactorGrid,
    transform_fn: Callable[[Tensor], Tensor],
) -> FactorGrid:
    """Map ``transform_fn`` over the feature tensor of every grid in a FactorGrid.

    Args:
        factor_grid: Input FactorGrid.
        transform_fn: Function applied to each grid's batched feature tensor.

    Returns:
        A new FactorGrid whose grids are ``grid.replace(batched_features=...)``
        with the transformed features, in the original order.
    """
    return FactorGrid(
        [
            grid.replace(batched_features=transform_fn(grid.grid_features.batched_tensor))
            for grid in factor_grid
        ]
    )
Global pooling¶
warpconvnet.nn.functional.global_pool
¶
global_pool(x: Geometry, reduce: Literal['max', 'mean', 'sum']) -> Geometry
¶
Pool over all coordinates and return a single feature per batch.
Args: x: Input geometry instance to pool. reduce: Reduction type used to combine features.
Returns: Geometry object with a single coordinate and feature per batch.
Source code in warpconvnet/nn/functional/global_pool.py
def global_pool(x: Geometry, reduce: Literal["max", "mean", "sum"]) -> Geometry:
    """Pool over all coordinates, producing one entry per batch item.

    Args:
        x: Input geometry instance to pool.
        reduce: Reduction used to combine features ("max", "mean" or "sum").

    Returns:
        Geometry (built with ``replace``) holding one all-zero int32 coordinate
        and one pooled feature row per batch item, with offsets ``0..B``.
    """
    batch_size = x.batch_size
    # One all-zero coordinate per batch item.
    zero_coords = torch.zeros(
        batch_size, x.num_spatial_dims, dtype=torch.int32, device=x.device
    )
    # Offsets 0, 1, ..., B: exactly one row per batch item (default device).
    unit_offsets = torch.arange(batch_size + 1, dtype=torch.int32)
    pooled = _global_pool(x, reduce)

    coords_cls = type(x.batched_coordinates)
    features_cls = type(x.batched_features)
    return x.replace(
        batched_coordinates=coords_cls(zero_coords, unit_offsets),
        batched_features=features_cls(pooled, unit_offsets),
        offsets=unit_offsets,
    )
global_scale(x: Geometry, scale: Float[Tensor, 'B C']) -> Geometry
¶
Scale each point's features by the row of scale that belongs to its batch item. Coordinates are left unchanged; only the features are replaced.
Source code in warpconvnet/nn/functional/global_pool.py
def global_scale(x: "Geometry", scale: "Float[Tensor, 'B C']") -> "Geometry":  # noqa: F821
    """
    Scale every point's features by the scale vector of its batch item.

    Row ``b`` of ``scale`` multiplies the features of all points in batch item
    ``b``, as delimited by ``x.offsets``. Coordinates are left unchanged; only
    the features are replaced.

    Args:
        x: Input geometry.
        scale: Per-batch scale with one row per batch item, e.g. shape (B, C).

    Returns:
        ``x`` with scaled features.

    Raises:
        AssertionError: If ``scale`` does not have one row per batch item.
    """
    offsets = x.offsets
    # Number of points in each batch item.
    counts = offsets.diff()
    B = counts.shape[0]
    # assert that the scale has the same batch size
    assert scale.shape[0] == B, "Scale must have the same batch size as the input"
    # Expand scale from one row per batch item to one row per point.
    per_point_scale = torch.repeat_interleave(scale, counts.to(scale.device), dim=0)
    return x.replace(
        batched_features=x.feature_tensor * per_point_scale,
    )
Feature transforms¶
warpconvnet.nn.functional.transforms
¶
apply_feature_transform(input: Union[Geometry, Tensor], transform: Callable)
¶
Apply a function to the feature tensor of a Geometry or raw Tensor.
Source code in warpconvnet/nn/functional/transforms.py
def apply_feature_transform(
    input: Union[Geometry, Tensor],
    transform: Callable,
):
    """Run ``transform`` on the features of a Geometry, or directly on a Tensor.

    Args:
        input: A Geometry, whose ``feature_tensor`` is transformed and wrapped
            back with ``replace``, or a plain Tensor.
        transform: Callable applied to the feature tensor.

    Returns:
        A Geometry with transformed features, or the transformed Tensor.

    Raises:
        AssertionError: If ``input`` is neither a Geometry nor a Tensor.
    """
    if not isinstance(input, Geometry):
        assert isinstance(input, Tensor), f"Expected Tensor, got {type(input)}"
        return transform(input)
    return input.replace(batched_features=transform(input.feature_tensor))
cat(*inputs: Geometry, dim: int = -1)
¶
Concatenate feature tensors from multiple geometries along dim.
Source code in warpconvnet/nn/functional/transforms.py
def cat(*inputs: Geometry, dim: int = -1):
    """Concatenate feature tensors from multiple geometries along ``dim``.

    Accepts either several geometries or a single sequence of geometries. The
    result is the first geometry with its features replaced by the concatenation.

    Args:
        *inputs: Geometries (or one sequence of geometries) with equal offsets.
        dim: Feature dimension to concatenate along (default: last).

    Returns:
        The first geometry with concatenated features.

    Raises:
        AssertionError: If no geometries are given, an input is not a Geometry,
            or the offsets differ.
    """
    # If called with a single sequence argument, unpack it
    if len(inputs) == 1 and isinstance(inputs[0], Sequence):
        inputs = inputs[0]
    assert len(inputs) > 0, "cat requires at least one Geometry"
    bad_types = [type(g).__name__ for g in inputs if not isinstance(g, Geometry)]
    assert not bad_types, f"Expected all inputs to be Geometry instances, got {bad_types}"
    # Cast to long to ignore int32/int64 differences, and compare exactly:
    # offsets are integer counts, and allclose's relative tolerance would
    # accept off-by-one offsets once they reach ~1e5. torch.equal also returns
    # False (instead of raising) when the shapes differ.
    ref_offsets = inputs[0].offsets.long()
    assert all(
        torch.equal(g.offsets.to(device=ref_offsets.device, dtype=torch.long), ref_offsets)
        for g in inputs[1:]
    ), "All inputs must have the same offsets"
    return inputs[0].replace(
        batched_features=torch.cat([g.feature_tensor for g in inputs], dim=dim),
    )
create_activation_function(torch_func)
¶
Wrap a torch.nn.functional activation for use with Geometry objects.
Source code in warpconvnet/nn/functional/transforms.py
def create_activation_function(torch_func):
    """Turn a ``torch.nn.functional`` activation into a Geometry-aware function.

    Args:
        torch_func: Callable applied to the feature tensor.

    Returns:
        A function that takes a Geometry and applies ``torch_func`` to its
        features through ``apply_feature_transform``.
    """

    def wrapper(input: Geometry):
        return apply_feature_transform(input, transform=torch_func)

    return wrapper
create_norm_function(torch_norm_func)
¶
Create a geometry-aware wrapper around a normalization function.
Source code in warpconvnet/nn/functional/transforms.py
def create_norm_function(torch_norm_func):
    """Build a Geometry-aware wrapper around a normalization function.

    Args:
        torch_norm_func: Normalization callable whose first argument is the
            feature tensor.

    Returns:
        A function ``wrapper(input, *args, **kwargs)`` that calls
        ``torch_norm_func(features, *args, **kwargs)`` on the features of
        ``input`` through ``apply_feature_transform``.
    """

    def wrapper(input: Geometry, *args, **kwargs):
        def _normalize(features):
            # Features go first; the extra arguments are forwarded unchanged.
            return torch_norm_func(features, *args, **kwargs)

        return apply_feature_transform(input, _normalize)

    return wrapper
Encodings¶
warpconvnet.nn.functional.encodings
¶
get_freqs(num_freqs: int, data_range: float = 2.0, device: Optional[torch.device] = None) -> Float[Tensor, 'num_freqs']
¶
Generate logarithmically spaced frequencies used in positional encoding.
Args: num_freqs: Number of frequency bands to generate. data_range: Range of the input data that the encoding will span. device: Device to create the frequency tensor on.
Returns:
1‑D tensor containing num_freqs frequencies.
Source code in warpconvnet/nn/functional/encodings.py
def get_freqs(
    num_freqs: int, data_range: float = 2.0, device: Optional[torch.device] = None
) -> Float[Tensor, "num_freqs"]:  # noqa: F821
    """Generate logarithmically (octave) spaced frequencies for positional encoding.

    The k-th frequency is ``2**k * 2*pi / data_range`` for ``k = 0 .. num_freqs-1``.

    Args:
        num_freqs: Number of frequency bands to generate.
        data_range: Range of the input data that the encoding will span.
        device: Device to create the frequency tensor on; CPU when None.

    Returns:
        1-D tensor containing ``num_freqs`` frequencies.
    """
    target_device = torch.device("cpu") if device is None else device
    octaves = torch.pow(2, torch.arange(num_freqs, device=target_device))
    base = 2 * np.pi / data_range
    return base * octaves
sinusoidal_encoding(x: Float[Tensor, '... D'], num_channels: Optional[int] = None, data_range: Optional[float] = None, encoding_axis: int = -1, freqs: Optional[Float[Tensor, 'num_channels']] = None, concat_input: bool = False) -> Float[Tensor, '... D*num_channels']
¶
Apply sinusoidal encoding to the input tensor.
Args: x: Input tensor of any shape. num_channels: Number of channels in the output per input channel. data_range: Range of the input data (max - min). encoding_axis: Axis to apply the encoding to. Only -1 (the last axis) is currently supported. freqs: Frequencies to use for the sinusoidal encoding. If None, the frequencies are calculated from the data range and num_channels. concat_input: Whether to concatenate the input to the output.
Returns: Tensor with sinusoidal encoding applied. For input shape [..., C], the output shape is [..., C*num_channels]. If concat_input is True, the input is concatenated to the output.
Source code in warpconvnet/nn/functional/encodings.py
def sinusoidal_encoding(
    x: Float[Tensor, "... D"],
    num_channels: Optional[int] = None,
    data_range: Optional[float] = None,
    encoding_axis: int = -1,
    freqs: Optional[Float[Tensor, "num_channels"]] = None,  # noqa: F821
    concat_input: bool = False,
) -> Float[Tensor, "... D*num_channels"]:
    """
    Apply sinusoidal encoding to the input tensor.

    Args:
        x: Input tensor of shape (..., D).
        num_channels: Number of output channels per input channel (must be
            even: half cosines, half sines). Required when ``freqs`` is None.
        data_range: Range of the input data (max - min). Required when
            ``freqs`` is None.
        encoding_axis: Axis to apply the encoding to. Only -1 (the last axis)
            is currently supported.
        freqs: Frequencies to use for the sinusoidal encoding. If None, they
            are computed with ``get_freqs(num_channels // 2, data_range)``.
            Supplied frequencies are moved to ``x``'s device.
        concat_input: Whether to append the raw input value after each
            channel's cos/sin features.

    Returns:
        Tensor with sinusoidal encoding applied. For input shape [..., D] and
        F frequencies the output shape is [..., D*2F] (cosines then sines per
        input channel), or [..., D*(2F+1)] when ``concat_input`` is True.
    """
    assert encoding_axis == -1, "Only encoding_axis=-1 is supported at the moment"
    # (..., D) -> (..., D, 1) so every channel is multiplied by every frequency.
    x = x.unsqueeze(encoding_axis)
    if freqs is None:
        assert (
            num_channels is not None and data_range is not None
        ), "num_channels and data_range must be provided if freqs are not given"
        assert (
            num_channels % 2 == 0
        ), f"num_channels must be even for sin/cos, got {num_channels}"
        freqs = get_freqs(num_channels // 2, data_range, device=x.device)
    else:
        # Caller-supplied frequencies (e.g. from get_freqs, which defaults to
        # CPU) must share x's device for the broadcast multiply below.
        freqs = freqs.to(device=x.device)
    # (F,) -> (1, ..., 1, F) to broadcast against (..., D, 1).
    freqs = freqs.reshape((1,) * (x.dim() - 1) + freqs.shape)
    freqed_x = x * freqs
    parts = [freqed_x.cos(), freqed_x.sin()]
    if concat_input:
        parts.append(x)
    # Concatenate per channel, then merge the (D, features) axes.
    return torch.cat(parts, dim=encoding_axis).flatten(start_dim=-2)
Normalizations¶
warpconvnet.nn.functional.normalizations
¶
SegmentedLayerNormFunction
¶
Bases: Function
Custom autograd function for segmented layer normalization (core normalization only).
This function implements both forward and backward passes for the core normalization: (x - mean) / std, without gamma and beta scaling/bias parameters.
Source code in warpconvnet/nn/functional/normalizations.py
class SegmentedLayerNormFunction(Function):
    """
    Custom autograd function for segmented layer normalization (core normalization only).

    Rows ``offsets[k]:offsets[k+1]`` of ``x`` form segment ``k``. For every
    segment and every feature column the forward pass computes
    ``(x - mean_k) / sqrt(var_k + eps)``, where ``mean_k`` and ``var_k`` are the
    column-wise mean and biased variance over that segment's rows. No
    gamma/beta scaling or bias is applied here.

    The backward pass differentiates through the segment statistics as well
    (both mean and std depend on ``x``), so it is the exact gradient of the
    forward computation.
    """

    @staticmethod
    def forward(ctx: Any, x: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
        """
        Forward pass for segmented layer normalization (core normalization only).

        Args:
            ctx: Context for saving tensors needed in backward pass
            x: Input tensor of shape (N, D)
            offsets: Segment boundaries of shape (K+1,)
            eps: Epsilon added to the variance for numerical stability

        Returns:
            Normalized tensor of shape (N, D) with mean=0, std~1 per segment
        """
        N, D = x.shape
        K = offsets.shape[0] - 1
        # segment_csr expects int64 offsets on the same device as the data.
        d_offsets = offsets.to(dtype=torch.int64).to(x.device)
        mean = segment_csr(x, d_offsets, reduce="mean")  # (K, D)
        mean_squared = segment_csr(x * x, d_offsets, reduce="mean")  # (K, D)
        # E[x^2] - E[x]^2 can come out slightly negative through floating-point
        # cancellation; clamp so sqrt(variance + eps) stays real for eps > 0.
        variance = (mean_squared - mean * mean).clamp(min=0)  # (K, D)
        std = torch.sqrt(variance + eps)  # (K, D)
        # output = (x - mean) / std, broadcasting each segment's statistics.
        output = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(x, mean, output, d_offsets, "subtract")
        _C.utils.segmented_arithmetic(output, std, output, d_offsets, "divide")
        ctx.save_for_backward(x, mean, std, d_offsets)
        ctx.eps = eps
        ctx.N = N
        ctx.D = D
        ctx.K = K
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
        """
        Backward pass for segmented layer normalization (core normalization only).

        With ``x_hat = (x - mean_k) / std_k`` and ``g = grad_output`` the exact
        gradient, per segment ``k`` and feature column, is::

            grad_x = (g - mean_k(g) - x_hat * mean_k(g * x_hat)) / std_k

        where ``mean_k`` averages over the rows of segment ``k``.

        Args:
            ctx: Context containing saved tensors
            grad_output: Gradient of loss w.r.t. output, shape (N, D)

        Returns:
            Gradients w.r.t. (x, offsets, eps); only ``x`` is differentiable.
        """
        x, mean, std, d_offsets = ctx.saved_tensors
        grad_x = None
        if ctx.needs_input_grad[0]:
            # Gradients such as those of `.sum()` arrive expanded (stride 0);
            # the custom kernel may assume a dense layout, so materialize one.
            grad_output = grad_output.contiguous()
            # Recompute x_hat = (x - mean) / std into a single buffer.
            x_hat = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(x, mean, x_hat, d_offsets, "subtract")
            _C.utils.segmented_arithmetic(x_hat, std, x_hat, d_offsets, "divide")
            # Per-segment column means of g and g * x_hat, each (K, D).
            grad_mean = segment_csr(grad_output, d_offsets, reduce="mean")
            grad_xhat_mean = segment_csr(grad_output * x_hat, d_offsets, reduce="mean")
            # centered = g - mean_k(g)
            centered = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(
                grad_output, grad_mean, centered, d_offsets, "subtract"
            )
            # correction = x_hat * mean_k(g * x_hat)
            correction = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(
                x_hat, grad_xhat_mean, correction, d_offsets, "multiply"
            )
            centered.sub_(correction)
            grad_x = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(centered, std, grad_x, d_offsets, "divide")
        # One gradient per forward input: (x, offsets, eps).
        return grad_x, None, None
backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]
staticmethod
¶
Backward pass for segmented layer normalization (core normalization only).
Treats mean and std as constants (detached from gradient computation). This simplifies the backward pass significantly.
Args: ctx: Context containing saved tensors grad_output: Gradient of loss w.r.t. output
Returns: Gradients w.r.t. (x, offsets, eps)
Source code in warpconvnet/nn/functional/normalizations.py
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
    """
    Backward pass for segmented layer normalization (core normalization only).

    With ``x_hat = (x - mean_k) / std_k`` and ``g = grad_output`` the exact
    gradient, per segment ``k`` and feature column, is::

        grad_x = (g - mean_k(g) - x_hat * mean_k(g * x_hat)) / std_k

    where ``mean_k`` averages over the rows of segment ``k``.

    Args:
        ctx: Context containing saved tensors
        grad_output: Gradient of loss w.r.t. output, shape (N, D)

    Returns:
        Gradients w.r.t. (x, offsets, eps); only ``x`` is differentiable.
    """
    x, mean, std, d_offsets = ctx.saved_tensors
    grad_x = None
    if ctx.needs_input_grad[0]:
        # Gradients such as those of `.sum()` arrive expanded (stride 0);
        # the custom kernel may assume a dense layout, so materialize one.
        grad_output = grad_output.contiguous()
        # Recompute x_hat = (x - mean) / std into a single buffer.
        x_hat = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(x, mean, x_hat, d_offsets, "subtract")
        _C.utils.segmented_arithmetic(x_hat, std, x_hat, d_offsets, "divide")
        # Per-segment column means of g and g * x_hat, each (K, D).
        grad_mean = segment_csr(grad_output, d_offsets, reduce="mean")
        grad_xhat_mean = segment_csr(grad_output * x_hat, d_offsets, reduce="mean")
        # centered = g - mean_k(g)
        centered = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(
            grad_output, grad_mean, centered, d_offsets, "subtract"
        )
        # correction = x_hat * mean_k(g * x_hat)
        correction = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(
            x_hat, grad_xhat_mean, correction, d_offsets, "multiply"
        )
        centered.sub_(correction)
        grad_x = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(centered, std, grad_x, d_offsets, "divide")
    # One gradient per forward input: (x, offsets, eps).
    return grad_x, None, None
forward(ctx: Any, x: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor
staticmethod
¶
Forward pass for segmented layer normalization (core normalization only).
Args: ctx: Context for saving tensors needed in backward pass x: Input tensor of shape (N, D) offsets: Segment boundaries of shape (K+1,) eps: Epsilon for numerical stability
Returns: Normalized tensor of shape (N, D) with mean=0, std=1 per segment
Source code in warpconvnet/nn/functional/normalizations.py
@staticmethod
def forward(ctx: Any, x: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
    """
    Forward pass for segmented layer normalization (core normalization only).

    Args:
        ctx: Context for saving tensors needed in backward pass
        x: Input tensor of shape (N, D)
        offsets: Segment boundaries of shape (K+1,)
        eps: Epsilon added to the variance for numerical stability

    Returns:
        Normalized tensor of shape (N, D) with mean=0, std~1 per segment
    """
    N, D = x.shape
    K = offsets.shape[0] - 1
    # segment_csr expects int64 offsets on the same device as the data.
    d_offsets = offsets.to(dtype=torch.int64).to(x.device)
    mean = segment_csr(x, d_offsets, reduce="mean")  # (K, D)
    mean_squared = segment_csr(x * x, d_offsets, reduce="mean")  # (K, D)
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation; clamp so sqrt(variance + eps) stays real for eps > 0.
    variance = (mean_squared - mean * mean).clamp(min=0)  # (K, D)
    std = torch.sqrt(variance + eps)  # (K, D)
    # output = (x - mean) / std, broadcasting each segment's statistics.
    output = torch.zeros_like(x)
    _C.utils.segmented_arithmetic(x, mean, output, d_offsets, "subtract")
    _C.utils.segmented_arithmetic(output, std, output, d_offsets, "divide")
    ctx.save_for_backward(x, mean, std, d_offsets)
    ctx.eps = eps
    ctx.N = N
    ctx.D = D
    ctx.K = K
    return output
segmented_layer_norm(x: Float[Tensor, 'N D'], offsets: Int[Tensor, K + 1], gamma: Optional[Float[Tensor, 'K D']] = None, beta: Optional[Float[Tensor, 'K D']] = None, eps: float = 1e-05) -> Float[Tensor, 'N D']
¶
Layer normalization on segmented data.
This is a segmented reduction of the form:
$$\gamma \odot \frac{x_i - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}} + \beta$$
where $\mu_k$ and $\sigma_k^2$ are the column-wise mean and (biased) variance of the rows in the $k$-th segment,
and $\gamma$ and $\beta$ are optional learnable per-channel parameters of shape (D,), shared by all segments.
Args: x: Input tensor of shape (N, D) offsets: Segment boundaries of shape (K+1,) where K is the number of segments gamma: Optional learnable scale parameters of shape (D,) beta: Optional learnable bias parameters of shape (D,) eps: Epsilon value for numerical stability
Returns: Normalized tensor of shape (N, D)
Source code in warpconvnet/nn/functional/normalizations.py
def segmented_layer_norm(
    x: Float[Tensor, "N D"],
    offsets: Int[Tensor, "K+1"],
    gamma: Optional[Float[Tensor, "D"]] = None,  # noqa: F821
    beta: Optional[Float[Tensor, "D"]] = None,  # noqa: F821
    eps: float = 1e-5,
) -> Float[Tensor, "N D"]:
    r"""
    Layer normalization on segmented data.

    For each segment :math:`k` (rows ``offsets[k]:offsets[k+1]``) this computes

    .. math::
        \gamma \odot \frac{x_i - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}} + \beta

    where :math:`\mu_k` and :math:`\sigma_k^2` are the column-wise mean and
    (biased) variance of the rows in the :math:`k`-th segment, and
    :math:`\gamma` and :math:`\beta` are optional per-channel parameters shared
    by all segments.

    Args:
        x: Input tensor of shape (N, D)
        offsets: Segment boundaries of shape (K+1,) where K is the number of segments
        gamma: Optional learnable scale parameters of shape (D,)
        beta: Optional learnable bias parameters of shape (D,)
        eps: Epsilon added to the variance for numerical stability

    Returns:
        Normalized tensor of shape (N, D)
    """
    # Apply core normalization using the autograd function.
    normalized: Tensor = segmented_norm(x, offsets, eps)  # type: ignore[assignment]
    if gamma is not None and beta is not None:
        # Fused beta + gamma * normalized.
        return torch.addcmul(beta, gamma, normalized)
    if gamma is not None:
        # (D,) broadcasts across the N rows.
        return normalized * gamma
    if beta is not None:
        return normalized + beta
    return normalized
segmented_norm(x: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor
¶
Segmented normalization.
Source code in warpconvnet/nn/functional/normalizations.py
def segmented_norm(x: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
    """
    Normalize ``x`` per segment without any affine scale or bias.

    Thin functional entry point for ``SegmentedLayerNormFunction``: each
    segment's columns are shifted by their mean and divided by
    ``sqrt(variance + eps)``.
    """
    normalized = SegmentedLayerNormFunction.apply(x, offsets, eps)
    return normalized  # type: ignore[assignment]
Segmented arithmetic¶
warpconvnet.nn.functional.segmented_arithmetics
¶
SegmentedArithmeticFunction
¶
Bases: Function
Custom autograd function for segmented arithmetic operations.
This ensures proper gradient flow through our segmented arithmetic operations.
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
class SegmentedArithmeticFunction(Function):
    """
    Custom autograd function for segmented arithmetic operations.

    Computes ``x[i] <op> y[k]`` for every row ``i`` of segment ``k`` (segments
    delimited by ``offsets``) and supplies the matching backward pass so that
    gradients flow to both ``x`` and ``y``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        x: Tensor,
        y: Tensor,
        offsets: Tensor,
        operation: str,
        eps: float = 1e-5,
    ) -> Tensor:
        """
        Forward pass for segmented arithmetic.

        Args:
            ctx: Context for backward pass
            x: Input tensor of shape (N, D)
            y: Segment-wise tensor of shape (K, D)
            offsets: Segment boundaries of shape (K+1,)
            operation: Operation type ("add", "subtract", "multiply", "divide")
            eps: Added to ``y**2`` in the backward pass of "divide" w.r.t. ``y``

        Returns:
            Result tensor of shape (N, D)
        """
        output = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(x, y, output, offsets, operation)
        # save_for_backward accepts only tensors; the float eps and the
        # operation string are kept as plain ctx attributes.
        ctx.save_for_backward(x, y, offsets)
        ctx.operation = operation
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
        """
        Backward pass for segmented arithmetic.

        Returns:
            Gradients w.r.t. (x, y, offsets, operation, eps); only ``x`` and
            ``y`` are differentiable.
        """
        x, y, offsets = ctx.saved_tensors
        operation = ctx.operation
        eps = ctx.eps
        # segment_csr expects int64 offsets.
        seg_offsets = offsets.to(torch.int64)
        grad_x = None
        grad_y = None

        # Gradient w.r.t. x
        if ctx.needs_input_grad[0]:
            if operation in ("add", "subtract"):
                grad_x = grad_output
            elif operation in ("multiply", "divide"):
                # grad_x = grad_output (*|/) y, broadcasting y to its segment.
                grad_x = torch.zeros_like(x)
                _C.utils.segmented_arithmetic(grad_output, y, grad_x, offsets, operation)

        # Gradient w.r.t. y: reduce the per-row contributions over each segment.
        if ctx.needs_input_grad[1]:
            if operation == "add":
                grad_y = segment_csr(grad_output, seg_offsets, reduce="sum")
            elif operation == "subtract":
                grad_y = -segment_csr(grad_output, seg_offsets, reduce="sum")
            elif operation == "multiply":
                grad_y = segment_csr(grad_output * x, seg_offsets, reduce="sum")
            elif operation == "divide":
                # grad_y = -sum(grad_output * x / (y^2 + eps)) per segment
                x_div_y_sq = torch.zeros_like(x)
                y_squared = y * y + eps
                _C.utils.segmented_arithmetic(x, y_squared, x_div_y_sq, offsets, "divide")
                grad_y = segment_csr(
                    -grad_output * x_div_y_sq, seg_offsets, reduce="sum"
                )

        # Autograd requires exactly one entry per forward input:
        # (x, y, offsets, operation, eps).
        return grad_x, grad_y, None, None, None
backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]
staticmethod
¶
Backward pass for segmented arithmetic.
Returns: Gradients w.r.t. (x, y, offsets, operation)
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
    """
    Backward pass for segmented arithmetic.

    Returns:
        Gradients w.r.t. (x, y, offsets, operation, eps); only ``x`` and
        ``y`` are differentiable.
    """
    x, y, offsets = ctx.saved_tensors
    operation = ctx.operation
    eps = ctx.eps
    # segment_csr expects int64 offsets.
    seg_offsets = offsets.to(torch.int64)
    grad_x = None
    grad_y = None

    # Gradient w.r.t. x
    if ctx.needs_input_grad[0]:
        if operation in ("add", "subtract"):
            grad_x = grad_output
        elif operation in ("multiply", "divide"):
            # grad_x = grad_output (*|/) y, broadcasting y to its segment.
            grad_x = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(grad_output, y, grad_x, offsets, operation)

    # Gradient w.r.t. y: reduce the per-row contributions over each segment.
    if ctx.needs_input_grad[1]:
        if operation == "add":
            grad_y = segment_csr(grad_output, seg_offsets, reduce="sum")
        elif operation == "subtract":
            grad_y = -segment_csr(grad_output, seg_offsets, reduce="sum")
        elif operation == "multiply":
            grad_y = segment_csr(grad_output * x, seg_offsets, reduce="sum")
        elif operation == "divide":
            # grad_y = -sum(grad_output * x / (y^2 + eps)) per segment
            x_div_y_sq = torch.zeros_like(x)
            y_squared = y * y + eps
            _C.utils.segmented_arithmetic(x, y_squared, x_div_y_sq, offsets, "divide")
            grad_y = segment_csr(-grad_output * x_div_y_sq, seg_offsets, reduce="sum")

    # Autograd requires exactly one entry per forward input:
    # (x, y, offsets, operation, eps).
    return grad_x, grad_y, None, None, None
forward(ctx: Any, x: Tensor, y: Tensor, offsets: Tensor, operation: str, eps: float = 1e-05) -> Tensor
staticmethod
¶
Forward pass for segmented arithmetic.
Args: ctx: Context for backward pass x: Input tensor of shape (N, D) y: Segment-wise tensor of shape (K, D) offsets: Segment boundaries of shape (K+1,) operation: Operation type ("add", "subtract", "multiply", "divide") eps: Epsilon value for numerical stability Returns: Result tensor of shape (N, D)
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
@staticmethod
def forward(
    ctx: Any,
    x: Tensor,
    y: Tensor,
    offsets: Tensor,
    operation: str,
    eps: float = 1e-5,
) -> Tensor:
    """
    Forward pass for segmented arithmetic.

    Args:
        ctx: Context for backward pass
        x: Input tensor of shape (N, D)
        y: Segment-wise tensor of shape (K, D)
        offsets: Segment boundaries of shape (K+1,)
        operation: Operation type ("add", "subtract", "multiply", "divide")
        eps: Added to ``y**2`` in the backward pass of "divide" w.r.t. ``y``

    Returns:
        Result tensor of shape (N, D)
    """
    output = torch.zeros_like(x)
    _C.utils.segmented_arithmetic(x, y, output, offsets, operation)
    # save_for_backward accepts only tensors; the float eps and the
    # operation string are kept as plain ctx attributes.
    ctx.save_for_backward(x, y, offsets)
    ctx.operation = operation
    ctx.eps = eps
    return output
segmented_add(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor
¶
Segment-wise addition.
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_add(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Segment-wise addition: add ``y[k]`` to every row of ``x`` in segment ``k``.

    ``eps`` is forwarded to ``SegmentedArithmeticFunction`` unchanged.
    """
    result = SegmentedArithmeticFunction.apply(x, y, offsets, "add", eps)
    return result  # type: ignore[return-value]
segmented_divide(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor
¶
Segment-wise division.
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_divide(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Segment-wise division: divide every row of ``x`` in segment ``k`` by ``y[k]``.

    ``eps`` is forwarded to ``SegmentedArithmeticFunction``, whose backward
    pass adds it to ``y**2`` when computing the gradient w.r.t. ``y``.
    """
    result = SegmentedArithmeticFunction.apply(x, y, offsets, "divide", eps)
    return result  # type: ignore[return-value]
segmented_multiply(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor
¶
Segment-wise multiplication.
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_multiply(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Segment-wise multiplication: scale every row of ``x`` in segment ``k`` by ``y[k]``.

    ``eps`` is forwarded to ``SegmentedArithmeticFunction`` unchanged.
    """
    result = SegmentedArithmeticFunction.apply(x, y, offsets, "multiply", eps)
    return result  # type: ignore[return-value]
segmented_subtract(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor
¶
Segment-wise subtraction.
Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_subtract(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Segment-wise subtraction: subtract ``y[k]`` from every row of ``x`` in segment ``k``.

    ``eps`` is forwarded to ``SegmentedArithmeticFunction`` unchanged.
    """
    result = SegmentedArithmeticFunction.apply(x, y, offsets, "subtract", eps)
    return result  # type: ignore[return-value]
Batched matrix multiplication¶
warpconvnet.nn.functional.bmm
¶
bmm(sf: Geometry, weights: Float[Tensor, 'B C_in C_out']) -> Geometry
¶
Batch matrix multiplication.
Source code in warpconvnet/nn/functional/bmm.py
def bmm(
    sf: Geometry,
    weights: Float[Tensor, "B C_in C_out"],
) -> Geometry:
    """
    Batch matrix multiplication of each batch item's features by its own weight.

    Args:
        sf: Geometry whose batched features are either concatenated
            (``CatFeatures``) or padded (``PadFeatures``).
        weights: Per-batch weight matrices; ``weights.shape[0]`` must equal
            ``sf.batch_size``.

    Returns:
        A Geometry with the same feature container type holding the
        multiplied features.

    Raises:
        ValueError: If the batched features are of an unsupported type.
    """
    assert sf.batch_size == weights.shape[0]
    features = sf.batched_features
    if isinstance(features, CatFeatures):
        # Concatenated rows -> padded (B, N, C_in), multiply, then back to rows.
        padded = cat_to_pad_tensor(sf.feature_tensor, sf.offsets)
        product = torch.bmm(padded, weights)
        new_features = CatFeatures(pad_to_cat_tensor(product, sf.offsets), sf.offsets)
    elif isinstance(features, PadFeatures):
        # Already (B, M, C_in): multiply directly to (B, M, C_out).
        new_features = PadFeatures(torch.bmm(sf.feature_tensor, weights), sf.offsets)
    else:
        raise ValueError(f"Unsupported batched features type: {type(features)}")
    return sf.replace(batched_features=new_features)