Neural Network Functionals

warpconvnet.nn.functional

Functional operators used by WarpConvNet modules.

The package intentionally keeps implementations in separate submodules (sparse_conv, point_pool, etc.) to avoid large monolithic files. Defining this initializer makes the namespace explicit so documentation generators like mkdocstrings can traverse it.

Functionals

Sparse convolution

warpconvnet.nn.functional.sparse_conv

Sparse depthwise convolution

warpconvnet.nn.functional.sparse_conv_depth

spatially_sparse_depthwise_conv(in_features: Float[Tensor, 'N C'], weight: Float[Tensor, 'K C'], kernel_map: IntSearchResult, num_out_coords: int, fwd_algo: Union[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE], None] = None, bwd_algo: Union[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE], None] = None, compute_dtype: Optional[torch.dtype] = None) -> Float[Tensor, 'M C']

Perform spatially sparse depthwise convolution.

Args:

- in_features: Input features of shape (N, C)
- weight: Depthwise convolution weights of shape (K, C)
- kernel_map: Kernel mapping from IntSearchResult
- num_out_coords: Number of output coordinates
- fwd_algo: Forward algorithm(s) to use. Can be a single algorithm or a list to limit the benchmark search space. If None, uses the environment variable WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE (default: explicit)
- bwd_algo: Backward algorithm(s) to use. Can be a single algorithm or a list to limit the benchmark search space. If None, uses the environment variable WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE (default: explicit)
- compute_dtype: Computation dtype (defaults to the input dtype)

Returns: Output features of shape (M, C)

Source code in warpconvnet/nn/functional/sparse_conv_depth.py
def spatially_sparse_depthwise_conv(
    in_features: Float[Tensor, "N C"],
    weight: Float[Tensor, "K C"],
    kernel_map: IntSearchResult,
    num_out_coords: int,
    fwd_algo: Union[
        SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_FWD_ALGO_MODE], None
    ] = None,
    bwd_algo: Union[
        SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE, List[SPARSE_DEPTHWISE_CONV_BWD_ALGO_MODE], None
    ] = None,
    compute_dtype: Optional[torch.dtype] = None,
) -> Float[Tensor, "M C"]:
    """
    Run a spatially sparse depthwise convolution.

    This is a thin front-end over ``UnifiedSpatiallySparseDepthwiseConvFunction``.
    It only resolves the algorithm modes before dispatching.

    Args:
        in_features: Input features of shape (N, C).
        weight: Depthwise convolution weights of shape (K, C).
        kernel_map: Kernel mapping from IntSearchResult.
        num_out_coords: Number of output coordinates.
        fwd_algo: Forward algorithm, or a list of algorithms that limits the benchmark
            search space. When None, ``WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE`` is used
            (default: explicit).
        bwd_algo: Backward algorithm, or a list of algorithms that limits the benchmark
            search space. When None, ``WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE`` is used
            (default: explicit).
        compute_dtype: Computation dtype (defaults to the input dtype).

    Returns:
        Output features of shape (M, C).
    """
    # Explicit arguments win; otherwise fall back to the environment-configured modes.
    fwd_algo = WARPCONVNET_DEPTHWISE_CONV_FWD_ALGO_MODE if fwd_algo is None else fwd_algo
    bwd_algo = WARPCONVNET_DEPTHWISE_CONV_BWD_ALGO_MODE if bwd_algo is None else bwd_algo

    autograd_args = (
        in_features,
        weight,
        kernel_map,
        num_out_coords,
        fwd_algo,
        bwd_algo,
        compute_dtype,
    )
    return UnifiedSpatiallySparseDepthwiseConvFunction.apply(*autograd_args)

Sparse pooling

warpconvnet.nn.functional.sparse_pool

sparse_avg_pool(voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Optional[Union[int, Tuple[int, ...]]] = None) -> Voxels

Average pooling for spatially sparse tensors.

Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_avg_pool(
    voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Voxels:
    """
    Average pooling for spatially sparse tensors.

    Delegates to ``sparse_reduce`` with ``REDUCTIONS.MEAN``. When ``stride`` is None,
    ``sparse_reduce`` uses ``kernel_size`` as the stride.
    """
    mean_reduction = REDUCTIONS.MEAN
    return sparse_reduce(
        voxels=voxels,
        kernel_size=kernel_size,
        stride=stride,
        reduction=mean_reduction,
    )

sparse_max_pool(voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Optional[Union[int, Tuple[int, ...]]] = None) -> Voxels

Max pooling for spatially sparse tensors.

Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_max_pool(
    voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Voxels:
    """
    Max pooling for spatially sparse tensors.

    Delegates to ``sparse_reduce`` with ``REDUCTIONS.MAX``. When ``stride`` is None,
    ``sparse_reduce`` uses ``kernel_size`` as the stride.
    """
    max_reduction = REDUCTIONS.MAX
    return sparse_reduce(
        voxels=voxels,
        kernel_size=kernel_size,
        stride=stride,
        reduction=max_reduction,
    )

sparse_reduce(voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Optional[Union[int, Tuple[int, ...]]] = None, reduction: Union[REDUCTIONS, str] = REDUCTIONS.MAX, order: POINT_ORDERING = POINT_ORDERING.RANDOM) -> Voxels

Pool a spatially sparse tensor with a configurable reduction (max by default).

Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_reduce(
    voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Optional[Union[int, Tuple[int, ...]]] = None,
    reduction: Union[REDUCTIONS, str] = REDUCTIONS.MAX,
    order: POINT_ORDERING = POINT_ORDERING.RANDOM,
) -> Voxels:
    """
    Pool a spatially sparse tensor with a configurable reduction (max by default).

    Output coordinates are produced by ``stride_coords``. Each output row combines,
    with ``reduction``, the input features that the kernel map assigns to it. The
    kernel map comes from ``generate_kernel_map``, or from the voxels' cache when a
    map with the same key is already stored.

    Args:
        voxels: Input sparse tensor.
        kernel_size: Pooling window size, an int or one value per spatial dimension.
        stride: Pooling stride; defaults to ``kernel_size`` when None.
        reduction: A ``REDUCTIONS`` member or its string value.
        order: Point ordering forwarded to ``stride_coords``.

    Returns:
        Voxels with the strided coordinates and the reduced features. The ``stride``
        attribute is set to ``stride * voxels.stride``, where a None input stride
        counts as all ones.

    Side effects:
        Creates ``voxels.cache`` if it is missing and stores the kernel map in it.
        ``sparse_unpool`` looks up a kernel map under a key built the same way.
    """
    if isinstance(reduction, str):
        reduction = REDUCTIONS(reduction)

    if stride is None:
        stride = kernel_size
    ndim = voxels.num_spatial_dims
    stride = ntuple(stride, ndim=ndim)
    kernel_size = ntuple(kernel_size, ndim=ndim)

    # Accumulate the tensor stride: output stride = pooling stride * input stride.
    in_tensor_stride = voxels.stride
    if in_tensor_stride is None:
        in_tensor_stride = ntuple(1, ndim=ndim)
    out_tensor_stride = tuple(o * s for o, s in zip(stride, in_tensor_stride))

    batch_indexed_in_coords = voxels.batch_indexed_coordinates
    batch_indexed_out_coords, output_offsets = stride_coords(
        batch_indexed_in_coords, stride, order=order
    )
    # Imported locally; presumably avoids a circular import with sparse_conv — TODO confirm.
    from warpconvnet.nn.functional.sparse_conv import STRIDED_CONV_MODE

    kernel_map_cache_key = IntSearchCacheKey(
        kernel_size=kernel_size,
        kernel_dilation=ntuple(1, ndim=ndim),
        transposed=False,
        generative=False,
        stride_mode=str(STRIDED_CONV_MODE.STRIDE_ONLY),
        skip_symmetric_kernel_map=False,
        in_offsets=voxels.offsets,
        out_offsets=output_offsets,
    )
    kernel_map = None
    if voxels.cache is not None:
        kernel_map = voxels.cache.get(kernel_map_cache_key)

    if kernel_map is None:
        # Find mapping from in to out
        kernel_map: IntSearchResult = generate_kernel_map(
            batch_indexed_in_coords,
            batch_indexed_out_coords,
            in_to_out_stride_ratio=stride,
            kernel_size=kernel_size,
            kernel_dilation=ntuple(1, ndim=ndim),
            skip_symmetric_kernel_map=False,
        )

    # The map is (re)stored even when it was just read from the cache, so a later
    # sparse_unpool call can always find it.
    if voxels.cache is None:
        voxels._extra_attributes["_cache"] = IntSearchCache()
    voxels.cache.put(kernel_map_cache_key, kernel_map)

    # CSR layout: gather input rows in map order, then reduce each segment delimited
    # by map_offsets.
    in_maps, unique_out_maps, map_offsets = kernel_map.to_csr()
    in_features = voxels.feature_tensor
    device = in_features.device

    out_features = row_reduction(in_features[in_maps], map_offsets.to(device), reduction)

    if len(unique_out_maps) != batch_indexed_out_coords.shape[0]:
        warnings.warn(
            f"Some output coordinates don't have any input maps. {batch_indexed_out_coords.shape[0] - len(unique_out_maps)} output coordinates are missing.",
            stacklevel=2,
        )

        # cchoy: This is a rare case where some output coordinates don't have any input maps.
        # We need to zero out the features for those coordinates.
        new_out_features = torch.zeros(
            batch_indexed_out_coords.shape[0],
            in_features.shape[1],
            device=voxels.device,
            dtype=voxels.dtype,
        )
        new_out_features[unique_out_maps] = out_features
        out_features = new_out_features

    # Note: the cache key above was built from output_offsets before this move to CPU.
    output_offsets = output_offsets.cpu()
    # Drop the leading batch-index column; batch membership is carried by output_offsets.
    return voxels.replace(
        batched_coordinates=IntCoords(
            batch_indexed_out_coords[:, 1:],
            output_offsets,
        ),
        batched_features=out_features,
        stride=out_tensor_stride,
    )

sparse_unpool(pooled_voxels: Voxels, unpooled_voxels: Voxels, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]], concat_unpooled_voxels: bool = False) -> Voxels

Unpooling for spatially sparse tensors.

Source code in warpconvnet/nn/functional/sparse_pool.py
def sparse_unpool(
    pooled_voxels: Voxels,
    unpooled_voxels: Voxels,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]],
    concat_unpooled_voxels: bool = False,
) -> Voxels:
    """
    Unpooling for spatially sparse tensors.

    Reuses the kernel map that pooling ``unpooled_voxels`` into ``pooled_voxels``
    stored in ``pooled_voxels.cache``, and copies each pooled feature back to the
    unpooled voxels mapped to it.

    Args:
        pooled_voxels: Coarse tensor whose cache holds the pooling kernel map.
        unpooled_voxels: Fine tensor that receives the features.
        kernel_size: Pooling window size used when pooling.
        stride: Pooling stride (normalized but not part of the cache key).
        concat_unpooled_voxels: If True, prepend ``unpooled_voxels``' own features
            along the channel axis.

    Returns:
        ``unpooled_voxels`` with replaced features.

    Raises:
        AssertionError: If the cache or the cached kernel map is missing.
    """
    num_dims = pooled_voxels.num_spatial_dims
    stride = ntuple(stride, ndim=num_dims)
    kernel_size = ntuple(kernel_size, ndim=num_dims)
    unit_dilation = ntuple(1, ndim=num_dims)

    from warpconvnet.nn.functional.sparse_conv import STRIDED_CONV_MODE

    # Same key layout as the one sparse_reduce stores when pooling fine -> coarse.
    pooling_key = IntSearchCacheKey(
        kernel_size=kernel_size,
        kernel_dilation=unit_dilation,
        transposed=False,
        generative=False,
        stride_mode=str(STRIDED_CONV_MODE.STRIDE_ONLY),
        skip_symmetric_kernel_map=False,
        in_offsets=unpooled_voxels.offsets,
        out_offsets=pooled_voxels.offsets,
    )
    coarse_cache = pooled_voxels.cache
    assert coarse_cache is not None
    pooling_map = coarse_cache.get(pooling_key)
    assert pooling_map is not None

    # The map was built in the pooling direction: its "in" side indexes unpooled
    # voxels and its "out" side indexes pooled voxels.
    # NOTE(review): assumes every unpooled voxel appears exactly once in in_maps
    # (e.g. kernel_size == stride); otherwise the row count will not match — confirm.
    fine_order = torch.argsort(pooling_map.in_maps)
    gathered = pooled_voxels.feature_tensor[pooling_map.out_maps[fine_order]]
    if concat_unpooled_voxels:
        gathered = torch.cat([unpooled_voxels.feature_tensor, gathered], dim=-1)

    return unpooled_voxels.replace(batched_features=gathered)

Sparse ops helpers

warpconvnet.nn.functional.sparse_ops

cat_spatially_sparse_tensors(*sparse_tensors: Sequence[Voxels]) -> Voxels

Concatenate a list of spatially sparse tensors.

Source code in warpconvnet/nn/functional/sparse_ops.py
def cat_spatially_sparse_tensors(
    *sparse_tensors: "Voxels",  # noqa: F821
) -> "Voxels":  # noqa: F821
    """
    Concatenate the features of spatially sparse tensors along the channel axis.

    All inputs must have identical batch offsets. Coordinates and every other
    attribute of the result come from the first tensor.

    Args:
        *sparse_tensors: One or more sparse tensors to concatenate, in order.

    Returns:
        A copy of the first tensor whose features are the channel-wise
        concatenation of all inputs' features.

    Raises:
        ValueError: If no tensors are given, or if their offsets differ.
    """
    if not sparse_tensors:
        raise ValueError("At least one sparse tensor is required")

    reference = sparse_tensors[0]
    offsets = reference.offsets
    for sparse_tensor in sparse_tensors[1:]:
        # Match the reference's device/dtype first. Unlike torch.allclose,
        # torch.equal returns False on a shape mismatch instead of raising.
        if not torch.equal(sparse_tensor.offsets.to(offsets), offsets):
            raise ValueError("All sparse tensors must have the same offsets")

    features_tensor = torch.cat(
        [sparse_tensor.feature_tensor for sparse_tensor in sparse_tensors], dim=-1
    )
    return reference.replace(batched_features=features_tensor)

prune_spatially_sparse_tensor(spatial_tensor: Geometry, mask: Bool[Tensor, N]) -> Geometry

Prune a spatially sparse tensor using a boolean mask.

Args: spatial_tensor: Geometry instance whose coordinates/features will be filtered. mask: Boolean mask of shape (N,) aligned with the flattened coordinates/features.

Returns: New Geometry instance containing only the entries where mask == True.

Source code in warpconvnet/nn/functional/sparse_ops.py
def prune_spatially_sparse_tensor(
    spatial_tensor: Geometry,
    mask: Bool[Tensor, "N"],  # noqa: F821
) -> Geometry:
    """
    Keep only the entries of a spatially sparse tensor selected by a boolean mask.

    Args:
        spatial_tensor: Geometry instance whose coordinates and features are filtered.
            Its ``batched_coordinates`` must provide a ``prune(mask)`` method.
        mask: Mask of length N, one entry per flattened coordinate/feature row.
            It is moved to the tensor's device and cast to bool if needed.

    Returns:
        A new Geometry holding only the rows where ``mask`` is True.

    Raises:
        ValueError: If the mask length differs from the number of coordinates.
        TypeError: If the coordinate container has no ``prune`` method.
    """
    num_coords = spatial_tensor.coordinate_tensor.shape[0]
    num_mask = mask.shape[0]
    if num_mask != num_coords:
        raise ValueError(
            f"Mask length {num_mask} must match number of coordinates {num_coords}"
        )

    mask = mask.to(spatial_tensor.device)
    mask = mask if mask.dtype == torch.bool else mask.bool()

    batched_coords = spatial_tensor.batched_coordinates
    if not hasattr(batched_coords, "prune"):
        raise TypeError(f"{batched_coords.__class__.__name__} does not implement prune()")

    return spatial_tensor.replace(
        batched_coordinates=batched_coords.prune(mask),
        batched_features=spatial_tensor.feature_tensor[mask],
    )

Point pooling

warpconvnet.nn.functional.point_pool

point_pool(pc: Points, reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR], downsample_max_num_points: Optional[int] = None, downsample_voxel_size: Optional[float] = None, return_type: Literal['point', 'voxel'] = 'point', average_pooled_coordinates: bool = False, return_neighbor_search_result: bool = False, return_to_unique: bool = False, unique_method: Literal['torch', 'ravel', 'morton'] = 'torch') -> Geometry

Pool points in a point cloud. When downsample_max_num_points is provided, pooling is delegated to a max-number-of-points strategy. When only downsample_voxel_size is provided, the point cloud is downsampled by voxelizing it with that voxel size. When both are provided, downsample_max_num_points takes precedence and downsample_voxel_size is not used.

Args:

- pc: Input point cloud
- reduction: Reduction type
- downsample_max_num_points: Target number of points
- downsample_voxel_size: Voxel size to downsample with
- return_type: "point" or "voxel"
- return_neighbor_search_result: Also return the neighbor search result
- return_to_unique: Also return the ToUnique object

Returns: Points or Voxels (optionally paired with the extra result requested above)

Source code in warpconvnet/nn/functional/point_pool.py
def point_pool(
    pc: "Points",  # noqa: F821
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR],
    downsample_max_num_points: Optional[int] = None,
    downsample_voxel_size: Optional[float] = None,
    return_type: Literal["point", "voxel"] = "point",
    average_pooled_coordinates: bool = False,
    return_neighbor_search_result: bool = False,
    return_to_unique: bool = False,
    unique_method: Literal["torch", "ravel", "morton"] = "torch",
) -> Geometry:
    """
    Pool points in a point cloud.

    Exactly one strategy is used, chosen in this order:

    * If ``downsample_max_num_points`` is given, pooling is delegated to
      ``_pool_by_max_num_points``. ``downsample_voxel_size``,
      ``average_pooled_coordinates`` and ``unique_method`` are not used on this path.
    * Otherwise, if ``reduction`` is ``REDUCTIONS.RANDOM``, pooling is delegated to
      ``_pool_by_random_sample`` with ``downsample_voxel_size``.
      ``average_pooled_coordinates`` and ``unique_method`` are not used on this path.
    * Otherwise, points are grouped into voxels of size ``downsample_voxel_size`` and
      the features in each voxel are combined with ``reduction``.

    When both ``downsample_max_num_points`` and ``downsample_voxel_size`` are given,
    ``downsample_max_num_points`` takes precedence.

    Args:
        pc: Input point cloud.
        reduction: A ``REDUCTIONS`` member or its string value.
        downsample_max_num_points: Target number of points for max-num-points pooling.
        downsample_voxel_size: Voxel size for voxel pooling.
        return_type: ``"point"`` to return Points, ``"voxel"`` to return Voxels.
        average_pooled_coordinates: On the voxel path, average the coordinates of
            each voxel's points. Not allowed with ``return_type="voxel"``.
        return_neighbor_search_result: Also return a ``RealSearchResult`` built from
            the unique indices and offsets. Not allowed with RANDOM reduction.
        return_to_unique: Also return the ``ToUnique`` object. Not allowed with
            ``downsample_max_num_points`` or RANDOM reduction.
        unique_method: Uniquing method forwarded to ``voxel_downsample_csr_mapping``.

    Returns:
        On the voxel path: the pooled Points/Voxels, or ``(pooled, to_unique)`` if
        ``return_to_unique``, or else ``(pooled, RealSearchResult)`` if
        ``return_neighbor_search_result``. ``return_to_unique`` wins if both flags
        are set. The other paths return whatever their helper returns.

    Raises:
        AssertionError: On invalid argument combinations (see above).
    """
    from warpconvnet.geometry.types.points import Points
    from warpconvnet.geometry.types.voxels import Voxels

    if isinstance(reduction, str):
        reduction = REDUCTIONS(reduction)
    # assert at least one of the two is provided
    assert (
        downsample_max_num_points is not None or downsample_voxel_size is not None
    ), "Either downsample_max_num_points or downsample_voxel_size must be provided."
    assert return_type in [
        "point",
        "voxel",
    ], "return_type must be either point or voxel."
    if return_type == "voxel":
        assert (
            not average_pooled_coordinates
        ), "averaging pooled coordinates is not supported for Voxels return type"
        RETURN_CLS = Voxels
    else:
        RETURN_CLS = Points

    # downsample_max_num_points takes precedence over downsample_voxel_size.
    if downsample_max_num_points is not None:
        assert (
            not return_to_unique
        ), "return_to_unique must be False when downsample_max_num_points is provided."
        return _pool_by_max_num_points(
            pc,
            reduction,
            downsample_max_num_points,
            return_type,
            return_neighbor_search_result,
        )

    if reduction == REDUCTIONS.RANDOM:
        assert (
            not return_to_unique
        ), "return_to_unique must be False when reduction is RANDOM."
        assert (
            not return_neighbor_search_result
        ), "return_neighbor_search_result must be False when reduction is RANDOM."
        return _pool_by_random_sample(
            pc,
            downsample_voxel_size,
            return_type,
        )

    # Voxel downsample: downsample_voxel_size is non-None here because of the
    # first assert combined with the early return above.
    (
        unique_coords,
        unique_offsets,
        to_csr_indices,
        to_csr_offsets,
        to_unique,
    ) = voxel_downsample_csr_mapping(
        batched_points=pc.coordinate_tensor,
        offsets=pc.offsets,
        voxel_size=downsample_voxel_size,
        unique_method=unique_method,
    )

    # Gather features in CSR order, then reduce each voxel's segment.
    down_features = row_reduction(
        pc.feature_tensor[to_csr_indices],
        to_csr_offsets,
        reduction=reduction,
    )

    batched_coords = _generate_batched_coords(
        pc.coordinate_tensor,
        return_type,
        to_csr_indices,
        to_csr_offsets,
        to_unique.to_unique_indices,
        unique_offsets,
        downsample_voxel_size,
        average_pooled_coordinates,
    )

    out_sf = RETURN_CLS(
        batched_coordinates=batched_coords,
        batched_features=down_features,
        voxel_size=downsample_voxel_size,
    )

    if return_to_unique:
        return out_sf, to_unique
    if return_neighbor_search_result:
        return out_sf, RealSearchResult(to_unique.to_unique_indices, unique_offsets)
    return out_sf

point_pool_by_code(pc: Points, code: Int[Tensor, N], reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR], average_pooled_coordinates: bool = False, return_to_unique: bool = False) -> Points

Pool points based on a user-provided clustering code.

Source code in warpconvnet/nn/functional/point_pool.py
def point_pool_by_code(
    pc: "Points",  # noqa: F821
    code: Int[Tensor, "N"],  # noqa: F821
    reduction: Union[REDUCTIONS | REDUCTION_TYPES_STR],
    average_pooled_coordinates: bool = False,
    return_to_unique: bool = False,
) -> "Points":  # noqa: F821
    """
    Pool points that share the same entry in a user-provided clustering code.

    Args:
        pc: Input point cloud.
        code: One code value per point; points with equal codes are pooled together.
        reduction: A ``REDUCTIONS`` member or its string value, applied to features.
        average_pooled_coordinates: If True, each cluster's coordinate is the mean of
            its points. Otherwise the coordinate at ``to_unique_indices`` is picked.
        return_to_unique: Also return the ``ToUnique`` object.

    Returns:
        The pooled point cloud, or ``(pooled, to_unique)`` when ``return_to_unique``.
    """
    from warpconvnet.geometry.types.points import Points

    reduction = REDUCTIONS(reduction) if isinstance(reduction, str) else reduction

    # Unique the codes. This also populates the CSR grouping used below.
    uniquer = ToUnique(return_to_unique_indices=return_to_unique)
    pooled_code = uniquer.to_unique(code)

    # NOTE(review): to_unique_indices is read below even when return_to_unique is
    # False; confirm ToUnique populates it regardless of that flag.
    if average_pooled_coordinates:
        pooled_coords = row_reduction(
            pc.coordinate_tensor[uniquer.to_csr_indices],
            uniquer.to_csr_offsets,
            reduction=REDUCTIONS.MEAN,
        )
    else:
        pooled_coords = pc.coordinate_tensor[uniquer.to_unique_indices]

    pooled_features = row_reduction(
        pc.feature_tensor[uniquer.to_csr_indices],
        uniquer.to_csr_offsets,
        reduction=reduction,
    )

    pooled_offsets = offsets_from_offsets(
        pc.offsets,
        uniquer.to_unique_indices,
        device="cpu",
    )

    pooled_pc = pc.replace(
        coordinate_tensor=pooled_coords,
        feature_tensor=pooled_features,
        offsets=pooled_offsets,
        code=pooled_code,
    )
    return (pooled_pc, uniquer) if return_to_unique else pooled_pc

Point unpooling

warpconvnet.nn.functional.point_unpool

point_unpool(pooled_pc: BatchedSpatialFeatures, unpooled_pc: Points, concat_unpooled_pc: bool, unpooling_mode: Optional[Union[str, FEATURE_UNPOOLING_MODE]] = FEATURE_UNPOOLING_MODE.REPEAT, to_unique: Optional[ToUnique] = None) -> Points

Unpool features to a denser set of points.

Source code in warpconvnet/nn/functional/point_unpool.py
def point_unpool(
    pooled_pc: "BatchedSpatialFeatures",  # noqa: F821
    unpooled_pc: "Points",  # noqa: F821
    concat_unpooled_pc: bool,
    unpooling_mode: Optional[
        Union[str, FEATURE_UNPOOLING_MODE]
    ] = FEATURE_UNPOOLING_MODE.REPEAT,
    to_unique: Optional[ToUnique] = None,
) -> "Points":  # noqa: F821
    """
    Unpool features to a denser set of points.

    Args:
        pooled_pc: Coarse geometry carrying the features to spread.
        unpooled_pc: Dense point cloud that receives the features.
        concat_unpooled_pc: If True, append ``unpooled_pc``'s own features after the
            unpooled ones along the channel axis.
        unpooling_mode: Unpooling mode forwarded to ``_unpool_features``.
        to_unique: Optional ``ToUnique`` mapping forwarded to ``_unpool_features``.

    Returns:
        ``unpooled_pc`` with replaced features.
    """
    spread = _unpool_features(
        pooled_pc=pooled_pc,
        unpooled_pc=unpooled_pc,
        to_unique=to_unique,
        unpooling_mode=unpooling_mode,
    )
    merged = (
        torch.cat((spread, unpooled_pc.feature_tensor), dim=-1)
        if concat_unpooled_pc
        else spread
    )
    return unpooled_pc.replace(batched_features=merged)

Grid convolution

warpconvnet.nn.functional.grid_conv

grid_conv(grid: Grid, weight: Float[Tensor, 'C_out C_in D H W'], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, bias: bool = True) -> Grid

3D Convolution on a Grid geometry type.

It is a simple wrapper on torch.nn.functional.conv3d. The output grid shape is computed as follows:

$$D_{out} = \left\lfloor \frac{D_{in} + 2\,p_0 - d_0\,(k_0 - 1) - 1}{s_0} \right\rfloor + 1$$

$$H_{out} = \left\lfloor \frac{H_{in} + 2\,p_1 - d_1\,(k_1 - 1) - 1}{s_1} \right\rfloor + 1$$

$$W_{out} = \left\lfloor \frac{W_{in} + 2\,p_2 - d_2\,(k_2 - 1) - 1}{s_2} \right\rfloor + 1$$

where $p_i$, $d_i$, $k_i$ and $s_i$ are the padding, dilation, kernel size and stride along dimension $i$.

Args: grid: Grid geometry type weight: Weight tensor stride: Stride padding: Padding dilation: Dilation bias: Bias

Returns: Grid: Output grid

Source code in warpconvnet/nn/functional/grid_conv.py
def grid_conv(
    grid: Grid,
    weight: Float[Tensor, "C_out C_in D H W"],  # noqa: F821
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    bias: Union[bool, Tensor, None] = True,
) -> Grid:
    """
    3D Convolution on a Grid geometry type.

    It is a simple wrapper on torch.nn.functional.conv3d.
    The output grid shape is computed as follows:

    D_out = ((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0]) + 1
    H_out = ((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1]) + 1
    W_out = ((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) // stride[2]) + 1

    Args:
        grid: Grid geometry type. It is converted to ``b_c_z_x_y`` (with a warning)
            if stored in another memory format.
        weight: Weight tensor of shape (C_out, C_in, kD, kH, kW).
        stride: Stride, an int or one value per spatial dimension.
        padding: Padding, an int or one value per spatial dimension.
        dilation: Dilation, an int or one value per spatial dimension.
        bias: Optional bias tensor of shape (C_out,), passed to ``F.conv3d``.
            ``F.conv3d`` only accepts a tensor or None, and a boolean (including the
            default ``True``) carries no bias values. Any non-tensor value is
            therefore treated as "no bias".

    Returns:
        Grid: Output grid in ``b_c_z_x_y`` memory format. It has the same bounds and
        batch size as the input, and its grid shape matches the convolution output.
    """
    if grid.memory_format != GridMemoryFormat.b_c_z_x_y:
        warnings.warn(
            f"Input grid memory format is {grid.memory_format}, converting to {GridMemoryFormat.b_c_z_x_y}",
            stacklevel=2,
        )
        grid = grid.to_memory_format(GridMemoryFormat.b_c_z_x_y)

    # Passing a bool as conv3d's bias raises TypeError; only forward real tensors.
    bias_tensor = bias if isinstance(bias, Tensor) else None

    output_tensor = F.conv3d(grid.features, weight, bias_tensor, stride, padding, dilation)

    # Stride, padding and dilation can change the spatial size, so read it back from
    # the output. In b_c_z_x_y the trailing axes are (Z, X, Y).
    out_z, out_x, out_y = tuple(output_tensor.shape[2:])

    # NOTE(review): assumes GridCoords.from_shape expects (X, Y, Z) ordering — confirm.
    return Grid(
        batched_coordinates=GridCoords.from_shape(
            grid_shape=(out_x, out_y, out_z),
            bounds=grid.bounds,
            batch_size=grid.batch_size,
            device=grid.device,
        ),
        batched_features=output_tensor,
        memory_format=GridMemoryFormat.b_c_z_x_y,
    )

Factor grid

warpconvnet.nn.functional.factor_grid

factor_grid_cat(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid

Concatenate features from two FactorGrid objects.

Args: factor_grid1: First FactorGrid factor_grid2: Second FactorGrid

Returns: FactorGrid with concatenated features

Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_cat(factor_grid1: FactorGrid, factor_grid2: FactorGrid) -> FactorGrid:
    """Concatenate features from two FactorGrid objects.

    Grids are paired positionally. Each pair's features are concatenated along the
    channel axis: the last axis for ``b_x_y_z_c``, and axis 1 for the channel-first
    and factorized (channel fused with one spatial axis) formats. The coordinates
    and other attributes come from ``factor_grid1``.

    Args:
        factor_grid1: First FactorGrid
        factor_grid2: Second FactorGrid

    Returns:
        FactorGrid with concatenated features

    Raises:
        AssertionError: If the two FactorGrids have different lengths.
        ValueError: If paired grids use different memory formats, or a format is
            unsupported.
    """
    assert len(factor_grid1) == len(
        factor_grid2
    ), f"FactorGrid lengths must match: {len(factor_grid1)} != {len(factor_grid2)}"

    # Formats whose channel (or fused channel) axis is dimension 1.
    channel_dim1_formats = (
        GridMemoryFormat.b_c_x_y_z,
        GridMemoryFormat.b_c_z_x_y,
        GridMemoryFormat.b_zc_x_y,
        GridMemoryFormat.b_xc_y_z,
        GridMemoryFormat.b_yc_x_z,
    )

    concatenated_grids = []
    for grid1, grid2 in zip(factor_grid1, factor_grid2):
        fmt = grid1.memory_format
        # With mismatched layouts, concatenation would silently mix different axes.
        if grid2.memory_format != fmt:
            raise ValueError(
                f"Memory formats must match: {fmt} != {grid2.memory_format}"
            )

        if fmt == GridMemoryFormat.b_x_y_z_c:
            cat_dim = -1
        elif fmt in channel_dim1_formats:
            cat_dim = 1
        else:
            raise ValueError(f"Unsupported memory format: {fmt}")

        concatenated_features = torch.cat(
            [grid1.grid_features.batched_tensor, grid2.grid_features.batched_tensor],
            dim=cat_dim,
        )
        concatenated_grids.append(grid1.replace(batched_features=concatenated_features))

    return FactorGrid(concatenated_grids)

factor_grid_intra_communication(factor_grid: FactorGrid, communication_types: List[Literal['sum', 'mul']] = ['sum'], cat_fn: Optional[Callable] = None) -> FactorGrid

Perform intra-communication between grids in a FactorGrid with multiple communication types.

Args: factor_grid: Input FactorGrid communication_types: List of communication types to apply cat_fn: Function to concatenate results from multiple communication types

Returns: FactorGrid with inter-grid communication applied

Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_intra_communication(
    factor_grid: FactorGrid,
    communication_types: Optional[List[Literal["sum", "mul"]]] = None,
    cat_fn: Optional[Callable] = None,
) -> FactorGrid:
    """Perform intra-communication between grids in a FactorGrid with multiple communication types.

    Args:
        factor_grid: Input FactorGrid
        communication_types: One or two communication types to apply. Defaults to
            ``["sum"]`` when None.
        cat_fn: Function that combines the two results when two types are given.
            Defaults to ``factor_grid_cat``.

    Returns:
        FactorGrid with intra-grid communication applied

    Raises:
        ValueError: If the number of communication types is not 1 or 2.
    """
    # None sentinel instead of a shared mutable list default.
    if communication_types is None:
        communication_types = ["sum"]

    num_types = len(communication_types)
    if num_types == 1:
        return _factor_grid_intra_communication(factor_grid, communication_types[0])
    if num_types == 2:
        # Apply both communication types and concatenate
        result1 = _factor_grid_intra_communication(factor_grid, communication_types[0])
        result2 = _factor_grid_intra_communication(factor_grid, communication_types[1])
        combine = cat_fn if cat_fn is not None else factor_grid_cat
        return combine(result1, result2)
    raise ValueError(f"Unsupported number of communication types: {num_types}")

factor_grid_pool(factor_grid: FactorGrid, pooling_type: Literal['max', 'mean', 'attention'] = 'max', pool_op: Optional[Callable] = None, attention_layer: Optional[Callable] = None) -> Tensor

Pool features from FactorGrid to a single tensor.

Args: factor_grid: Input FactorGrid pooling_type: Type of pooling ("max", "mean", "attention") pool_op: Pooling operation function attention_layer: Attention layer for attention pooling

Returns: Pooled tensor of shape [B, total_channels]

Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_pool(
    factor_grid: FactorGrid,
    pooling_type: Literal["max", "mean", "attention"] = "max",
    pool_op: Optional[Callable] = None,
    attention_layer: Optional[Callable] = None,
) -> Tensor:
    """Pool features from FactorGrid to a single tensor.

    Each grid must use a factorized memory format (``b_zc_x_y``, ``b_xc_y_z`` or
    ``b_yc_x_z``), i.e. a 4-D tensor ``(B, fused_channels, S1, S2)``. The two
    spatial axes are flattened, pooled away, and the per-grid results are
    concatenated along the last axis.

    Args:
        factor_grid: Input FactorGrid
        pooling_type: Type of pooling ("max", "mean", "attention")
        pool_op: Optional pooling callable for "max"/"mean". It receives
            ``(B, C, N)`` and its output is squeezed on the last axis.
        attention_layer: Optional callable for "attention", called as
            ``layer(q, k, v)`` on ``(B, N, C)`` and returning a pair whose first item
            is averaged over N. Without it, "attention" falls back to mean pooling.

    Returns:
        Pooled tensor of shape [B, total_channels]. With an attention layer, the
        channel size is whatever the layer outputs.

    Raises:
        ValueError: On an unsupported pooling type, memory format, or tensor rank.
    """
    if pooling_type not in ("max", "mean", "attention"):
        raise ValueError(f"Unsupported pooling type: {pooling_type}")

    factorized_formats = (
        GridMemoryFormat.b_zc_x_y,
        GridMemoryFormat.b_xc_y_z,
        GridMemoryFormat.b_yc_x_z,
    )

    pooled_features = []
    for grid in factor_grid:
        features = grid.grid_features.batched_tensor
        fmt = grid.memory_format
        if fmt not in factorized_formats:
            raise ValueError(f"Unsupported memory format for pooling: {fmt}")
        if features.ndim != 4:
            raise ValueError(
                f"Expected a 4D feature tensor for {fmt}, got shape {tuple(features.shape)}"
            )

        # (B, fused_C, S1, S2) -> (B, fused_C, S1*S2). reshape (unlike view) also
        # handles non-contiguous tensors.
        features_flat = features.reshape(features.shape[0], features.shape[1], -1)

        if pooling_type == "attention":
            if attention_layer is not None:
                tokens = features_flat.transpose(1, 2)  # B, N, C
                attended, _ = attention_layer(tokens, tokens, tokens)
                pooled = attended.mean(dim=1)  # B, C
            else:
                # Fallback: simple mean pooling
                pooled = F.adaptive_avg_pool1d(features_flat, 1).squeeze(-1)
        elif pool_op is not None:
            pooled = pool_op(features_flat).squeeze(-1)  # B, channels
        elif pooling_type == "max":
            pooled = F.adaptive_max_pool1d(features_flat, 1).squeeze(-1)
        else:  # mean
            pooled = F.adaptive_avg_pool1d(features_flat, 1).squeeze(-1)

        pooled_features.append(pooled)

    # Concatenate pooled features from all grids
    return torch.cat(pooled_features, dim=-1)

factor_grid_transform(factor_grid: FactorGrid, transform_fn: Callable[[Tensor], Tensor]) -> FactorGrid

Apply a transform function to all grids in a FactorGrid.

Args: factor_grid: Input FactorGrid transform_fn: Function to apply to each grid's features

Returns: FactorGrid with transformed features

Source code in warpconvnet/nn/functional/factor_grid.py
def factor_grid_transform(
    factor_grid: FactorGrid,
    transform_fn: Callable[[Tensor], Tensor],
) -> FactorGrid:
    """Map a feature transform over every grid of a FactorGrid.

    Args:
        factor_grid: Input FactorGrid.
        transform_fn: Callable applied to each grid's batched feature tensor, in order.

    Returns:
        A new FactorGrid whose grids are copies of the inputs with transformed features.
    """
    return FactorGrid(
        [
            grid.replace(batched_features=transform_fn(grid.grid_features.batched_tensor))
            for grid in factor_grid
        ]
    )

Global pooling

warpconvnet.nn.functional.global_pool

global_pool(x: Geometry, reduce: Literal['max', 'mean', 'sum']) -> Geometry

Pool over all coordinates and return a single feature per batch.

Args: x: Input geometry instance to pool. reduce: Reduction type used to combine features.

Returns: Geometry object with a single coordinate and feature per batch.

Source code in warpconvnet/nn/functional/global_pool.py
def global_pool(x: Geometry, reduce: Literal["max", "mean", "sum"]) -> Geometry:
    """Reduce all coordinates of each batch item to a single feature row.

    Args:
        x: Input geometry instance to pool.
        reduce: Reduction used to combine the features of a batch item
            (``"max"``, ``"mean"`` or ``"sum"``).

    Returns:
        Geometry of the same type with exactly one row per batch item. Its
        coordinates are all-zero int32 vectors and its offsets are ``0..B``.
    """
    batch_size = x.batch_size
    spatial_dims = x.num_spatial_dims
    # One all-zero coordinate per batch item.
    pooled_coords = torch.zeros(
        batch_size, spatial_dims, dtype=torch.int32, device=x.device
    )
    # Offsets [0, 1, ..., B]: every batch item owns exactly one row.
    pooled_offsets = torch.arange(batch_size + 1, dtype=torch.int32)

    pooled_features = _global_pool(x, reduce)

    # Rebuild the containers with the same concrete classes as the input.
    coords_cls = x.batched_coordinates.__class__
    features_cls = x.batched_features.__class__
    return x.replace(
        batched_coordinates=coords_cls(pooled_coords, pooled_offsets),
        batched_features=features_cls(pooled_features, pooled_offsets),
        offsets=pooled_offsets,
    )

global_scale(x: Geometry, scale: Float[Tensor, 'B C']) -> Geometry

Scale the features of every point by the scale vector of its batch item (row b of scale multiplies all points of batch item b). Coordinates are left unchanged; only the features are replaced.

Source code in warpconvnet/nn/functional/global_pool.py
def global_scale(x: Geometry, scale: Float[Tensor, "B C"]) -> Geometry:
    """Scale every point's features by the scale vector of its batch item.

    Row ``b`` of ``scale`` multiplies the features of all points that belong to
    batch item ``b``, as delimited by ``x.offsets``. Coordinates are left
    unchanged; only the features are replaced.

    Args:
        x: Input geometry.
        scale: Per-batch scale factors with one row per batch item. Each row is
            broadcast against the feature rows of that batch item.

    Returns:
        Geometry with the same coordinates as ``x`` and scaled features.
    """
    offsets = x.offsets
    counts = offsets.diff()  # number of points in each batch item
    B = counts.shape[0]
    # assert that the scale has the same batch size
    assert scale.shape[0] == B, "Scale must have the same batch size as the input"
    # Expand (B, C) -> one row per point, so each point gets its batch item's scale.
    per_point_scale = torch.repeat_interleave(scale, counts.to(scale.device), dim=0)
    scaled_features = x.feature_tensor * per_point_scale
    return x.replace(
        batched_features=scaled_features,
    )

Feature transforms

warpconvnet.nn.functional.transforms

apply_feature_transform(input: Union[Geometry, Tensor], transform: Callable)

Apply a function to the feature tensor of a Geometry or raw Tensor.

Source code in warpconvnet/nn/functional/transforms.py
def apply_feature_transform(
    input: Union[Geometry, Tensor],
    transform: Callable,
):
    """Apply ``transform`` to a Geometry's feature tensor or to a raw Tensor.

    Args:
        input: A Geometry (its ``feature_tensor`` is transformed and a copy with
            the new features is returned) or a plain Tensor (transformed
            directly).
        transform: Callable mapping a tensor to a tensor.

    Returns:
        A Geometry with transformed features, or the transformed Tensor.
    """
    if not isinstance(input, Geometry):
        assert isinstance(input, Tensor), f"Expected Tensor, got {type(input)}"
        return transform(input)
    new_features = transform(input.feature_tensor)
    return input.replace(batched_features=new_features)

cat(*inputs: Geometry, dim: int = -1)

Concatenate feature tensors from multiple geometries along dim.

Source code in warpconvnet/nn/functional/transforms.py
def cat(*inputs: Geometry, dim: int = -1):
    """Concatenate feature tensors from multiple geometries along ``dim``.

    Accepts either several geometries as positional arguments or a single
    sequence of geometries. All inputs must share the same batch offsets. The
    result is ``inputs[0].replace(...)``, so it keeps the coordinates and every
    other attribute of the first input.

    Args:
        *inputs: Geometries to concatenate, or one sequence containing them.
        dim: Feature-tensor dimension along which to concatenate.

    Returns:
        A copy of the first input whose features are the concatenation.

    Raises:
        AssertionError: If no inputs are given, an input is not a Geometry, or
            the offsets differ between inputs.
    """
    # If called with a single sequence argument, unpack it
    if len(inputs) == 1 and isinstance(inputs[0], Sequence):
        inputs = inputs[0]
    assert len(inputs) > 0, "cat requires at least one input"
    for geom in inputs:
        assert isinstance(
            geom, Geometry
        ), f"Expected all inputs to be Geometry, got {type(geom)}"
    # Compare as int64 to ignore int32/int64 differences. torch.equal returns
    # False (instead of raising like allclose) when the offset lengths differ.
    reference_offsets = inputs[0].offsets.long()
    assert all(
        torch.equal(geom.offsets.long(), reference_offsets) for geom in inputs[1:]
    ), "All inputs must have the same offsets"
    return inputs[0].replace(
        batched_features=torch.cat([geom.feature_tensor for geom in inputs], dim=dim),
    )

create_activation_function(torch_func)

Wrap a torch.nn.functional activation for use with Geometry objects.

Source code in warpconvnet/nn/functional/transforms.py
def create_activation_function(torch_func):
    """Wrap a ``torch.nn.functional`` activation for use with Geometry objects.

    Args:
        torch_func: Callable taking a tensor as its first argument (plus any
            extra positional/keyword arguments) and returning a tensor, e.g.
            ``torch.nn.functional.relu``.

    Returns:
        ``wrapper(input, *args, **kwargs)``. It applies
        ``torch_func(features, *args, **kwargs)`` to the features of a Geometry,
        or directly to a raw Tensor, via ``apply_feature_transform``. Calling it
        with only ``input`` behaves exactly as before.
    """

    def wrapper(input: Geometry, *args, **kwargs):
        # Forward extra arguments (e.g. ``negative_slope`` for leaky_relu) so
        # this matches the calling convention of create_norm_function.
        return apply_feature_transform(
            input, lambda x: torch_func(x, *args, **kwargs)
        )

    return wrapper

create_norm_function(torch_norm_func)

Create a geometry-aware wrapper around a normalization function.

Source code in warpconvnet/nn/functional/transforms.py
def create_norm_function(torch_norm_func):
    """Return a Geometry-aware version of a tensor normalization function.

    Args:
        torch_norm_func: Callable taking a tensor as its first argument, plus
            any extra positional/keyword arguments.

    Returns:
        ``wrapper(input, *args, **kwargs)``. It applies
        ``torch_norm_func(features, *args, **kwargs)`` through
        ``apply_feature_transform``.
    """

    def wrapper(input: Geometry, *args, **kwargs):
        def _normalize(features):
            return torch_norm_func(features, *args, **kwargs)

        return apply_feature_transform(input, _normalize)

    return wrapper

Encodings

warpconvnet.nn.functional.encodings

get_freqs(num_freqs: int, data_range: float = 2.0, device: Optional[torch.device] = None) -> Float[Tensor, num_freqs]

Generate logarithmically spaced frequencies used in positional encoding.

Args: num_freqs: Number of frequency bands to generate. data_range: Range of the input data that the encoding will span. device: Device to create the frequency tensor on.

Returns: 1-D tensor containing num_freqs frequencies.

Source code in warpconvnet/nn/functional/encodings.py
def get_freqs(
    num_freqs: int, data_range: float = 2.0, device: Optional[torch.device] = None
) -> Float[Tensor, "num_freqs"]:  # noqa: F821
    """Generate logarithmically spaced (powers-of-two) positional-encoding frequencies.

    The ``k``-th frequency is ``2**k * 2*pi / data_range`` for ``k`` in
    ``[0, num_freqs)``. The lowest band therefore completes one period over
    ``data_range``.

    Args:
        num_freqs: Number of frequency bands to generate.
        data_range: Range of the input data that the encoding will span.
        device: Device to create the frequency tensor on (CPU when ``None``).

    Returns:
        1-D tensor containing ``num_freqs`` frequencies in the default
        floating-point dtype.
    """
    if device is None:
        device = torch.device("cpu")
    # Build the exponents in floating point. An integer ``2 ** arange`` silently
    # overflows int64 once num_freqs exceeds 63.
    exponents = torch.arange(
        num_freqs, dtype=torch.get_default_dtype(), device=device
    )
    freqs = 2.0**exponents
    return (2 * np.pi / data_range) * freqs

sinusoidal_encoding(x: Float[Tensor, '... D'], num_channels: Optional[int] = None, data_range: Optional[float] = None, encoding_axis: int = -1, freqs: Optional[Float[Tensor, num_channels]] = None, concat_input: bool = False) -> Float[Tensor, '... D*num_channels']

Apply sinusoidal encoding to the input tensor.

Args: x: Input tensor of any shape. num_channels: Number of output channels per input channel (must be even: half cosines, half sines); required when freqs is None. data_range: Range of the input data (max - min); required when freqs is None. encoding_axis: Axis to apply the encoding to. Only -1 (the last axis) is currently supported. freqs: Frequencies to use for the sinusoidal encoding. If None, they are computed from data_range and num_channels. concat_input: Whether to append each raw input value to its encoding.

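Returns: Tensor with sinusoidal encoding applied. For input shape [..., C] with F frequencies (F = num_channels / 2 when freqs is None), the output shape is [..., C*2F]. If concat_input is True, each raw input value is appended after its own cos/sin features, giving shape [..., C*(2F+1)].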

Source code in warpconvnet/nn/functional/encodings.py
def sinusoidal_encoding(
    x: Float[Tensor, "... D"],
    num_channels: Optional[int] = None,
    data_range: Optional[float] = None,
    encoding_axis: int = -1,
    freqs: Optional[Float[Tensor, "num_channels"]] = None,  # noqa: F821
    concat_input: bool = False,
) -> Float[Tensor, "... D*num_channels"]:
    """
    Apply sinusoidal encoding to the input tensor.

    Args:
        x: Input tensor of any shape.
        num_channels: Number of output channels per input channel (must be even:
            half cosines, half sines). Required when ``freqs`` is None.
        data_range: Range of the input data (max - min). Required when ``freqs``
            is None.
        encoding_axis: Axis to apply the encoding to. Only ``-1`` (the last
            axis) is currently supported.
        freqs: 1-D tensor of frequencies to use. If None, they are computed by
            ``get_freqs(num_channels // 2, data_range)``. User-supplied
            frequencies are moved to ``x``'s device.
        concat_input: Whether to append each raw input value to its encoding.

    Returns:
        Tensor with sinusoidal encoding applied. For input shape [..., C] and F
        frequencies (F = num_channels // 2 when ``freqs`` is None), the output
        shape is [..., C*2F]. For each input channel the layout is
        [cos(f_1 x), ..., cos(f_F x), sin(f_1 x), ..., sin(f_F x)]. If
        ``concat_input`` is True, the raw value x is appended after each
        channel's cos/sin block, giving [..., C*(2F+1)].
    """
    assert encoding_axis == -1, "Only encoding_axis=-1 is supported at the moment"

    x = x.unsqueeze(encoding_axis)  # (..., C) -> (..., C, 1)
    if freqs is None:
        assert (
            num_channels is not None and data_range is not None
        ), "num_channels and data_range must be provided if freqs are not given"
        assert (
            num_channels % 2 == 0
        ), f"num_channels must be even for sin/cos, got {num_channels}"
        freqs = get_freqs(num_channels // 2, data_range, device=x.device)
    else:
        # User-supplied frequencies may live on a different device than x.
        freqs = freqs.to(device=x.device)
    # Reshape (F,) so it broadcasts against (..., C, 1) -> (..., C, F).
    freqs = freqs.reshape((1,) * (x.ndim - 1) + freqs.shape)
    freqed_x = x * freqs
    parts = [freqed_x.cos(), freqed_x.sin()]
    if concat_input:
        parts.append(x)
    # Concatenate per input channel, then merge the channel and encoding axes.
    return torch.cat(parts, dim=encoding_axis).flatten(start_dim=-2)

Normalizations

warpconvnet.nn.functional.normalizations

SegmentedLayerNormFunction

Bases: Function

Custom autograd function for segmented layer normalization (core normalization only).

This function implements both forward and backward passes for the core normalization: (x - mean) / std, without gamma and beta scaling/bias parameters.

Source code in warpconvnet/nn/functional/normalizations.py
class SegmentedLayerNormFunction(Function):
    """
    Custom autograd function for segmented layer normalization (core normalization only).

    For each segment ``k`` (delimited by ``offsets``), every channel is
    normalized as ``(x - mean_k) / sqrt(var_k + eps)``. Here ``mean_k`` and
    ``var_k`` are the per-channel mean and biased variance of that segment. No
    gamma/beta scaling or bias is applied here.
    """

    @staticmethod
    def forward(ctx: Any, x: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
        """
        Forward pass for segmented layer normalization (core normalization only).

        Args:
            ctx: Context for saving tensors needed in backward pass
            x: Input tensor of shape (N, D)
            offsets: Segment boundaries of shape (K+1,)
            eps: Epsilon for numerical stability

        Returns:
            Normalized tensor of shape (N, D) with mean=0, std=1 per segment
        """
        N, D = x.shape
        K = offsets.shape[0] - 1

        # Convert offsets to appropriate types
        offsets = offsets.to(dtype=torch.int64)  # For segment_csr
        d_offsets = offsets.to(x.device)

        # Compute mean and variance using segmented reduction
        mean = segment_csr(x, d_offsets, reduce="mean")  # Shape: (K, D)
        x_squared = x * x
        mean_squared = segment_csr(x_squared, d_offsets, reduce="mean")  # Shape: (K, D)
        # E[x^2] - E[x]^2 can dip slightly below zero from floating-point
        # cancellation; clamp so sqrt(variance + eps) can never produce NaN.
        variance = (mean_squared - mean * mean).clamp_min(0.0)  # Shape: (K, D)

        # Compute standard deviation
        std = torch.sqrt(variance + eps)  # Shape: (K, D)

        # Normalize: (x - mean) / std
        output = torch.zeros_like(x)

        # Subtract mean from each element
        _C.utils.segmented_arithmetic(x, mean, output, d_offsets, "subtract")

        # Divide by standard deviation
        _C.utils.segmented_arithmetic(output, std, output, d_offsets, "divide")

        # Save tensors for backward pass
        ctx.save_for_backward(x, mean, std, d_offsets)
        ctx.eps = eps
        ctx.N = N
        ctx.D = D
        ctx.K = K

        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
        """
        Backward pass for segmented layer normalization (core normalization only).

        Computes the exact gradient, including the dependence of ``mean`` and
        ``std`` on ``x``. With ``x_hat = (x - mean) / std`` and ``g = grad_output``,
        per segment and channel::

            grad_x = (g - mean(g) - x_hat * mean(g * x_hat)) / std

        Args:
            ctx: Context containing saved tensors
            grad_output: Gradient of loss w.r.t. output

        Returns:
            Gradients w.r.t. (x, offsets, eps)
        """
        x, mean, std, d_offsets = ctx.saved_tensors

        grad_x = None

        if ctx.needs_input_grad[0]:
            # grad_output can be an expanded (stride-0) view, e.g. after
            # ``.sum()``; materialize it before handing it to the C++ kernels.
            grad_output = grad_output.contiguous()

            # Recompute the normalized input x_hat = (x - mean) / std.
            x_hat = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(x, mean, x_hat, d_offsets, "subtract")
            _C.utils.segmented_arithmetic(x_hat, std, x_hat, d_offsets, "divide")

            # Per-segment, per-channel means of g and g * x_hat: shape (K, D).
            mean_g = segment_csr(grad_output, d_offsets, reduce="mean")
            mean_g_xhat = segment_csr(grad_output * x_hat, d_offsets, reduce="mean")

            # centered = g - x_hat * mean(g * x_hat) - mean(g)
            projection = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(
                x_hat, mean_g_xhat, projection, d_offsets, "multiply"
            )
            centered = grad_output - projection
            _C.utils.segmented_arithmetic(
                centered, mean_g, centered, d_offsets, "subtract"
            )

            grad_x = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(centered, std, grad_x, d_offsets, "divide")

        # Return gradients in the same order as forward inputs
        # (x, offsets, eps)
        return grad_x, None, None

backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...] staticmethod

Backward pass for segmented layer normalization (core normalization only).

Treats mean and std as constants (detached from gradient computation). This simplifies the backward pass significantly.

Args: ctx: Context containing saved tensors grad_output: Gradient of loss w.r.t. output

Returns: Gradients w.r.t. (x, offsets, eps)

Source code in warpconvnet/nn/functional/normalizations.py
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
    """
    Backward pass for segmented layer normalization (core normalization only).

    Computes the exact gradient, including the dependence of ``mean`` and
    ``std`` on ``x``. With ``x_hat = (x - mean) / std`` and ``g = grad_output``,
    per segment and channel::

        grad_x = (g - mean(g) - x_hat * mean(g * x_hat)) / std

    Args:
        ctx: Context containing saved tensors
        grad_output: Gradient of loss w.r.t. output

    Returns:
        Gradients w.r.t. (x, offsets, eps)
    """
    x, mean, std, d_offsets = ctx.saved_tensors

    grad_x = None

    if ctx.needs_input_grad[0]:
        # grad_output can be an expanded (stride-0) view, e.g. after
        # ``.sum()``; materialize it before handing it to the C++ kernels.
        grad_output = grad_output.contiguous()

        # Recompute the normalized input x_hat = (x - mean) / std.
        x_hat = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(x, mean, x_hat, d_offsets, "subtract")
        _C.utils.segmented_arithmetic(x_hat, std, x_hat, d_offsets, "divide")

        # Per-segment, per-channel means of g and g * x_hat: shape (K, D).
        mean_g = segment_csr(grad_output, d_offsets, reduce="mean")
        mean_g_xhat = segment_csr(grad_output * x_hat, d_offsets, reduce="mean")

        # centered = g - x_hat * mean(g * x_hat) - mean(g)
        projection = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(
            x_hat, mean_g_xhat, projection, d_offsets, "multiply"
        )
        centered = grad_output - projection
        _C.utils.segmented_arithmetic(centered, mean_g, centered, d_offsets, "subtract")

        grad_x = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(centered, std, grad_x, d_offsets, "divide")

    # Return gradients in the same order as forward inputs
    # (x, offsets, eps)
    return grad_x, None, None

forward(ctx: Any, x: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor staticmethod

Forward pass for segmented layer normalization (core normalization only).

Args: ctx: Context for saving tensors needed in backward pass x: Input tensor of shape (N, D) offsets: Segment boundaries of shape (K+1,) eps: Epsilon for numerical stability

Returns: Normalized tensor of shape (N, D) with mean=0, std=1 per segment

Source code in warpconvnet/nn/functional/normalizations.py
@staticmethod
def forward(ctx: Any, x: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
    """
    Forward pass for segmented layer normalization (core normalization only).

    Args:
        ctx: Context for saving tensors needed in backward pass
        x: Input tensor of shape (N, D)
        offsets: Segment boundaries of shape (K+1,)
        eps: Epsilon for numerical stability

    Returns:
        Normalized tensor of shape (N, D) with mean=0, std=1 per segment
    """
    N, D = x.shape
    K = offsets.shape[0] - 1

    # Convert offsets to appropriate types
    offsets = offsets.to(dtype=torch.int64)  # For segment_csr
    d_offsets = offsets.to(x.device)

    # Compute mean and variance using segmented reduction
    mean = segment_csr(x, d_offsets, reduce="mean")  # Shape: (K, D)
    x_squared = x * x
    mean_squared = segment_csr(x_squared, d_offsets, reduce="mean")  # Shape: (K, D)
    # E[x^2] - E[x]^2 can dip slightly below zero from floating-point
    # cancellation; clamp so sqrt(variance + eps) can never produce NaN.
    variance = (mean_squared - mean * mean).clamp_min(0.0)  # Shape: (K, D)

    # Compute standard deviation
    std = torch.sqrt(variance + eps)  # Shape: (K, D)

    # Normalize: (x - mean) / std
    output = torch.zeros_like(x)

    # Subtract mean from each element
    _C.utils.segmented_arithmetic(x, mean, output, d_offsets, "subtract")

    # Divide by standard deviation
    _C.utils.segmented_arithmetic(output, std, output, d_offsets, "divide")

    # Save tensors for backward pass
    ctx.save_for_backward(x, mean, std, d_offsets)
    ctx.eps = eps
    ctx.N = N
    ctx.D = D
    ctx.K = K

    return output

segmented_layer_norm(x: Float[Tensor, 'N D'], offsets: Int[Tensor, K + 1], gamma: Optional[Float[Tensor, 'K D']] = None, beta: Optional[Float[Tensor, 'K D']] = None, eps: float = 1e-05) -> Float[Tensor, 'N D']

Layer normalization on segmented data.

This is a segmented reduction of the form:

.. math:: y_i = \gamma \odot \frac{x_i - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}} + \beta

where :math:`k` is the segment containing row :math:`i`, and :math:`\mu_k` and :math:`\sigma_k^2` are the per-channel mean and (biased) variance of the :math:`k`-th segment. :math:`\gamma` and :math:`\beta` are optional learnable per-channel parameters of shape (D,), shared by all segments.

Args: x: Input tensor of shape (N, D) offsets: Segment boundaries of shape (K+1,) where K is the number of segments gamma: Optional learnable scale parameters of shape (D,) beta: Optional learnable bias parameters of shape (D,) eps: Epsilon value for numerical stability

Returns: Normalized tensor of shape (N, D)

Source code in warpconvnet/nn/functional/normalizations.py
def segmented_layer_norm(
    x: Float[Tensor, "N D"],
    offsets: Int[Tensor, "K+1"],
    gamma: Optional[Float[Tensor, "D"]] = None,  # noqa: F821
    beta: Optional[Float[Tensor, "D"]] = None,  # noqa: F821
    eps: float = 1e-5,
) -> Float[Tensor, "N D"]:
    r"""
    Layer normalization on segmented data.

    For row :math:`i` belonging to segment :math:`k`, this computes

    .. math::
        y_i = \gamma \odot \frac{x_i - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}} + \beta

    where :math:`\mu_k` and :math:`\sigma_k^2` are the per-channel mean and
    (biased) variance of the :math:`k`-th segment. :math:`\gamma` and
    :math:`\beta` are optional learnable per-channel parameters shared by all
    segments.

    Args:
        x: Input tensor of shape (N, D)
        offsets: Segment boundaries of shape (K+1,) where K is the number of segments
        gamma: Optional learnable scale parameters of shape (D,)
        beta: Optional learnable bias parameters of shape (D,)
        eps: Epsilon added to the variance inside the square root

    Returns:
        Normalized tensor of shape (N, D)
    """
    # Apply core normalization using the autograd function
    normalized: Tensor = segmented_norm(x, offsets, eps)  # type: ignore[assignment]

    # gamma/beta of shape (D,) broadcast across all N rows.
    if gamma is not None and beta is not None:
        return torch.addcmul(beta, gamma, normalized)
    if gamma is not None:
        return normalized * gamma
    if beta is not None:
        return normalized + beta
    return normalized

segmented_norm(x: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor

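Segmented normalization: normalize each segment to zero mean and unit variance per channel, without learnable scale or bias.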

Source code in warpconvnet/nn/functional/normalizations.py
def segmented_norm(x: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
    """Normalize each segment of ``x`` per channel, without affine parameters.

    Delegates to ``SegmentedLayerNormFunction.apply``.

    Args:
        x: Input tensor.
        offsets: Segment boundaries.
        eps: Epsilon added to the variance before the square root.

    Returns:
        The normalized tensor.
    """
    normalized = SegmentedLayerNormFunction.apply(x, offsets, eps)
    return normalized  # type: ignore[return-value]

Segmented arithmetic

warpconvnet.nn.functional.segmented_arithmetics

SegmentedArithmeticFunction

Bases: Function

Custom autograd function for segmented arithmetic operations.

This ensures proper gradient flow through our segmented arithmetic operations.

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
class SegmentedArithmeticFunction(Function):
    """
    Custom autograd function for segmented arithmetic operations.

    Combines an (N, D) tensor ``x`` with a segment-wise (K, D) tensor ``y``.
    Row ``k`` of ``y`` is applied to the rows of ``x`` in segment ``k``, as
    delimited by ``offsets``. Gradients flow to both ``x`` and ``y``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        x: Tensor,
        y: Tensor,
        offsets: Tensor,
        operation: str,
        eps: float = 1e-5,
    ) -> Tensor:
        """
        Forward pass for segmented arithmetic.

        Args:
            ctx: Context for backward pass
            x: Input tensor of shape (N, D)
            y: Segment-wise tensor of shape (K, D)
            offsets: Segment boundaries of shape (K+1,)
            operation: Operation type ("add", "subtract", "multiply", "divide")
            eps: Epsilon added to ``y**2`` in the "divide" gradient w.r.t. ``y``
        Returns:
            Result tensor of shape (N, D)
        """
        # Perform the operation
        output = torch.zeros_like(x)
        _C.utils.segmented_arithmetic(x, y, output, offsets, operation)

        # save_for_backward only accepts tensors, so the float eps (and the
        # operation string) are stored as plain ctx attributes.
        ctx.save_for_backward(x, y, offsets)
        ctx.operation = operation
        ctx.eps = eps

        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
        """
        Backward pass for segmented arithmetic.

        Args:
            ctx: Context holding ``x``, ``y``, ``offsets`` (saved tensors) and
                ``operation``, ``eps`` (attributes).
            grad_output: Gradient of the loss w.r.t. the (N, D) output.

        Returns:
            Gradients w.r.t. (x, y, offsets, operation, eps): one entry per
            forward input. Only the first two can be non-None.
        """
        x, y, offsets = ctx.saved_tensors
        operation = ctx.operation
        eps = ctx.eps
        # grad_output can be an expanded (stride-0) view, e.g. after ``.sum()``;
        # materialize it before handing it to the C++ kernel.
        grad_output = grad_output.contiguous()
        seg_offsets = offsets.to(torch.int64)  # segment_csr expects int64 offsets

        grad_x = None
        grad_y = None

        # Gradient w.r.t. x
        if ctx.needs_input_grad[0]:
            if operation in ("add", "subtract"):
                # d(x +/- y)/dx = 1
                grad_x = grad_output
            elif operation == "multiply":
                # grad_x = grad_output * y (broadcast y to segments)
                grad_x = torch.zeros_like(x)
                _C.utils.segmented_arithmetic(
                    grad_output, y, grad_x, offsets, "multiply"
                )
            elif operation == "divide":
                # grad_x = grad_output / y (broadcast y to segments)
                grad_x = torch.zeros_like(x)
                _C.utils.segmented_arithmetic(grad_output, y, grad_x, offsets, "divide")

        # Gradient w.r.t. y: reduce the per-row contributions over each segment.
        if ctx.needs_input_grad[1]:
            if operation == "add":
                grad_y = segment_csr(grad_output, seg_offsets, reduce="sum")
            elif operation == "subtract":
                grad_y = -segment_csr(grad_output, seg_offsets, reduce="sum")
            elif operation == "multiply":
                grad_y = segment_csr(grad_output * x, seg_offsets, reduce="sum")
            elif operation == "divide":
                # d(x / y)/dy = -x / y^2; eps keeps the denominator away from 0.
                x_div_y_sq = torch.zeros_like(x)
                _C.utils.segmented_arithmetic(
                    x, y * y + eps, x_div_y_sq, offsets, "divide"
                )
                grad_y = segment_csr(
                    -grad_output * x_div_y_sq, seg_offsets, reduce="sum"
                )

        # One gradient per forward input: (x, y, offsets, operation, eps).
        return grad_x, grad_y, None, None, None

backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...] staticmethod

Backward pass for segmented arithmetic.

Returns: Gradients w.r.t. (x, y, offsets, operation)

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Tuple[Optional[Tensor], ...]:
    """
    Backward pass for segmented arithmetic.

    Args:
        ctx: Context holding ``x``, ``y``, ``offsets`` (saved tensors) and
            ``operation``, ``eps`` (attributes).
        grad_output: Gradient of the loss w.r.t. the (N, D) output.

    Returns:
        Gradients w.r.t. (x, y, offsets, operation, eps): one entry per
        forward input. Only the first two can be non-None.
    """
    x, y, offsets = ctx.saved_tensors
    operation = ctx.operation
    eps = ctx.eps
    # grad_output can be an expanded (stride-0) view, e.g. after ``.sum()``;
    # materialize it before handing it to the C++ kernel.
    grad_output = grad_output.contiguous()
    seg_offsets = offsets.to(torch.int64)  # segment_csr expects int64 offsets

    grad_x = None
    grad_y = None

    # Gradient w.r.t. x
    if ctx.needs_input_grad[0]:
        if operation in ("add", "subtract"):
            # d(x +/- y)/dx = 1
            grad_x = grad_output
        elif operation == "multiply":
            # grad_x = grad_output * y (broadcast y to segments)
            grad_x = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(
                grad_output, y, grad_x, offsets, "multiply"
            )
        elif operation == "divide":
            # grad_x = grad_output / y (broadcast y to segments)
            grad_x = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(grad_output, y, grad_x, offsets, "divide")

    # Gradient w.r.t. y: reduce the per-row contributions over each segment.
    if ctx.needs_input_grad[1]:
        if operation == "add":
            grad_y = segment_csr(grad_output, seg_offsets, reduce="sum")
        elif operation == "subtract":
            grad_y = -segment_csr(grad_output, seg_offsets, reduce="sum")
        elif operation == "multiply":
            grad_y = segment_csr(grad_output * x, seg_offsets, reduce="sum")
        elif operation == "divide":
            # d(x / y)/dy = -x / y^2; eps keeps the denominator away from 0.
            x_div_y_sq = torch.zeros_like(x)
            _C.utils.segmented_arithmetic(
                x, y * y + eps, x_div_y_sq, offsets, "divide"
            )
            grad_y = segment_csr(-grad_output * x_div_y_sq, seg_offsets, reduce="sum")

    # One gradient per forward input: (x, y, offsets, operation, eps).
    return grad_x, grad_y, None, None, None

forward(ctx: Any, x: Tensor, y: Tensor, offsets: Tensor, operation: str, eps: float = 1e-05) -> Tensor staticmethod

Forward pass for segmented arithmetic.

Args: ctx: Context for backward pass x: Input tensor of shape (N, D) y: Segment-wise tensor of shape (K, D) offsets: Segment boundaries of shape (K+1,) operation: Operation type ("add", "subtract", "multiply", "divide") eps: Epsilon value for numerical stability Returns: Result tensor of shape (N, D)

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
@staticmethod
def forward(
    ctx: Any,
    x: Tensor,
    y: Tensor,
    offsets: Tensor,
    operation: str,
    eps: float = 1e-5,
) -> Tensor:
    """
    Forward pass for segmented arithmetic.

    Args:
        ctx: Context for backward pass
        x: Input tensor of shape (N, D)
        y: Segment-wise tensor of shape (K, D)
        offsets: Segment boundaries of shape (K+1,)
        operation: Operation type ("add", "subtract", "multiply", "divide")
        eps: Epsilon added to ``y**2`` in the "divide" gradient w.r.t. ``y``
    Returns:
        Result tensor of shape (N, D)
    """
    # Perform the operation
    output = torch.zeros_like(x)
    _C.utils.segmented_arithmetic(x, y, output, offsets, operation)

    # save_for_backward only accepts tensors, so the float eps (and the
    # operation string) are stored as plain ctx attributes.
    ctx.save_for_backward(x, y, offsets)
    ctx.operation = operation
    ctx.eps = eps

    return output

segmented_add(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor

Segment-wise addition.

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_add(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5) -> Tensor:
    """Add row ``k`` of ``y`` to every row of ``x`` in segment ``k``.

    ``eps`` is passed through to ``SegmentedArithmeticFunction``; the "add"
    operation itself does not use it.
    """
    op_name = "add"
    result = SegmentedArithmeticFunction.apply(x, y, offsets, op_name, eps)
    return result  # type: ignore[return-value]

segmented_divide(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor

Segment-wise division.

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_divide(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Divide every row of ``x`` in segment ``k`` by row ``k`` of ``y``.

    ``eps`` is passed through to ``SegmentedArithmeticFunction``, whose
    backward adds it to ``y**2`` when computing the gradient w.r.t. ``y``.
    """
    op_name = "divide"
    result = SegmentedArithmeticFunction.apply(x, y, offsets, op_name, eps)
    return result  # type: ignore[return-value]

segmented_multiply(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor

Segment-wise multiplication.

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_multiply(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Multiply every row of ``x`` in segment ``k`` by row ``k`` of ``y``.

    ``eps`` is passed through to ``SegmentedArithmeticFunction``; the
    "multiply" operation itself does not use it.
    """
    op_name = "multiply"
    result = SegmentedArithmeticFunction.apply(x, y, offsets, op_name, eps)
    return result  # type: ignore[return-value]

segmented_subtract(x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-05) -> Tensor

Segment-wise subtraction.

Source code in warpconvnet/nn/functional/segmented_arithmetics.py
def segmented_subtract(
    x: Tensor, y: Tensor, offsets: Tensor, eps: float = 1e-5
) -> Tensor:
    """Subtract row ``k`` of ``y`` from every row of ``x`` in segment ``k``.

    ``eps`` is passed through to ``SegmentedArithmeticFunction``; the
    "subtract" operation itself does not use it.
    """
    op_name = "subtract"
    result = SegmentedArithmeticFunction.apply(x, y, offsets, op_name, eps)
    return result  # type: ignore[return-value]

Batched matrix multiplication

warpconvnet.nn.functional.bmm

bmm(sf: Geometry, weights: Float[Tensor, 'B C_in C_out']) -> Geometry

Batch matrix multiplication.

Source code in warpconvnet/nn/functional/bmm.py
def bmm(
    sf: Geometry,
    weights: Float[Tensor, "B C_in C_out"],
) -> Geometry:
    """Multiply each batch item's features by that item's own weight matrix.

    Args:
        sf: Geometry whose features are either concatenated (``CatFeatures``)
            or padded (``PadFeatures``).
        weights: One (C_in, C_out) matrix per batch item.

    Returns:
        Geometry with C_out-channel features of the same container type.

    Raises:
        ValueError: If the features are neither CatFeatures nor PadFeatures.
    """
    assert sf.batch_size == weights.shape[0]
    features = sf.batched_features
    if isinstance(features, CatFeatures):
        # Pad to B x N x C_in, multiply, then drop the padding again.
        padded_in = cat_to_pad_tensor(sf.feature_tensor, sf.offsets)
        padded_out = torch.bmm(padded_in, weights)
        new_features = CatFeatures(pad_to_cat_tensor(padded_out, sf.offsets), sf.offsets)
    elif isinstance(features, PadFeatures):
        # Already B x M x C_in; multiply directly.
        new_features = PadFeatures(torch.bmm(sf.feature_tensor, weights), sf.offsets)
    else:
        raise ValueError(f"Unsupported batched features type: {type(sf.batched_features)}")
    return sf.replace(batched_features=new_features)