# Source code for sionna.phy.utils.tensors

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Functions extending PyTorch tensor operations for Sionna PHY and SYS"""

from typing import Any, List, Optional, Union

import torch

from sionna.phy import config

# Public API of this module: the names bound by ``from ... import *``.
__all__ = [
    "expand_to_rank",
    "flatten_dims",
    "flatten_last_dims",
    "insert_dims",
    "split_dim",
    "diag_part_axis",
    "flatten_multi_index",
    "gather_from_batched_indices",
    "tensor_values_are_in_set",
    "random_tensor_from_values",
    "enumerate_indices",
    "find_true_position",
]


def expand_to_rank(
    tensor: torch.Tensor, target_rank: int, axis: int = -1
) -> torch.Tensor:
    """Inserts as many axes to a tensor as needed to achieve a desired rank.

    Length-one dimensions are inserted into ``tensor`` at position ``axis``
    until the result has rank ``target_rank``. The dimension index follows
    Python indexing rules, i.e., zero-based, where a negative index is
    counted backward from the end.

    :param tensor: Input tensor
    :param target_rank: Rank of the output tensor. If ``target_rank`` is
        smaller than the rank of ``tensor``, the function does nothing.
    :param axis: Dimension index at which to expand the shape of
        ``tensor``. Given a ``tensor`` of `D` dimensions, ``axis`` must be
        within the range `[-(D+1), D]` (inclusive).

    :output tensor: A tensor with the same data as ``tensor``, with
        ``target_rank`` - rank(``tensor``) additional dimensions inserted
        at the index specified by ``axis``. If ``target_rank`` <=
        rank(``tensor``), ``tensor`` is returned.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import expand_to_rank

        x = torch.ones([3, 4])
        print(x.shape)
        # torch.Size([3, 4])

        y = expand_to_rank(x, 4, axis=-1)
        print(y.shape)
        # torch.Size([3, 4, 1, 1])
    """
    missing_dims = target_rank - tensor.dim()
    if missing_dims < 0:
        # Already at (or above) the requested rank: insert nothing.
        missing_dims = 0
    return insert_dims(tensor, missing_dims, axis)
def flatten_dims(tensor: torch.Tensor, num_dims: int, axis: int) -> torch.Tensor:
    """Flattens a specified set of dimensions of a tensor.

    This operation flattens ``num_dims`` consecutive dimensions of a
    ``tensor`` starting at a given ``axis``. The dimension index follows
    Python indexing rules, i.e., zero-based, where a negative index is
    counted backward from the end.

    :param tensor: Input tensor
    :param num_dims: Number of dimensions to combine. Must be greater than
        or equal to two and less than or equal to the rank of ``tensor``.
    :param axis: Index of the dimension from which to start. Given a
        ``tensor`` of `D` dimensions, ``axis`` must be within the range
        `[-D, D-1]`, and (after normalizing a negative ``axis``)
        ``axis`` + ``num_dims`` must not exceed `D`.

    :output tensor: A tensor of the same type as ``tensor`` with
        ``num_dims`` - 1 lesser dimensions, but the same number of elements

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import flatten_dims

        x = torch.ones([2, 3, 4, 5])
        print(x.shape)
        # torch.Size([2, 3, 4, 5])

        y = flatten_dims(x, num_dims=2, axis=1)
        print(y.shape)
        # torch.Size([2, 12, 5])

        z = flatten_dims(x, num_dims=2, axis=-3)
        print(z.shape)
        # torch.Size([2, 12, 5])
    """
    rank = tensor.dim()
    assert num_dims >= 2, "`num_dims` must be >= 2"
    assert num_dims <= rank, "`num_dims` must <= rank(`tensor`)"
    # Accept negative axes, consistent with `insert_dims` and `diag_part_axis`.
    if axis < 0:
        axis += rank
    assert 0 <= axis <= rank - 1, "`axis` must be in [-rank(tensor), rank(tensor) - 1]"
    assert num_dims + axis <= rank, "`num_dims` + `axis` <= rank(`tensor`)"
    return torch.flatten(tensor, start_dim=axis, end_dim=axis + num_dims - 1)
def flatten_last_dims(tensor: torch.Tensor, num_dims: int = 2) -> torch.Tensor:
    """Flattens the last `n` dimensions of a tensor.

    The trailing ``num_dims`` dimensions of ``tensor`` are merged into a
    single dimension. This is a simplified version of ``flatten_dims``.

    :param tensor: Input tensor
    :param num_dims: Number of dimensions to combine. Must be greater than
        or equal to two and less or equal than the rank of ``tensor``.

    :output tensor: A tensor of the same type as ``tensor`` with
        ``num_dims`` - 1 lesser dimensions, but the same number of elements

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import flatten_last_dims

        x = torch.ones([2, 3, 4])
        print(x.shape)
        # torch.Size([2, 3, 4])

        y = flatten_last_dims(x, num_dims=2)
        print(y.shape)
        # torch.Size([2, 12])
    """
    rank = tensor.dim()
    assert num_dims >= 2, "`num_dims` must be >= 2"
    assert num_dims <= rank, "`num_dims` must <= rank(`tensor`)"
    # Index of the first dimension that gets merged into the trailing one.
    first_merged = rank - num_dims
    return torch.flatten(tensor, start_dim=first_merged, end_dim=-1)
def insert_dims(tensor: torch.Tensor, num_dims: int, axis: int = -1) -> torch.Tensor:
    """Adds multiple length-one dimensions to a tensor.

    This extends PyTorch's ``unsqueeze``: ``num_dims`` dimensions of length
    one are inserted into ``tensor`` starting at dimension ``axis``. The
    dimension index follows Python indexing rules, i.e., zero-based, where
    a negative index is counted backward from the end.

    :param tensor: Input tensor
    :param num_dims: Number of dimensions to add
    :param axis: Dimension index at which to expand the shape of
        ``tensor``. Given a ``tensor`` of `D` dimensions, ``axis`` must be
        within the range `[-(D+1), D]` (inclusive).

    :output tensor: A tensor with the same data as ``tensor``, with
        ``num_dims`` additional dimensions inserted at the index specified
        by ``axis``

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import insert_dims

        x = torch.ones([3, 4])
        print(x.shape)
        # torch.Size([3, 4])

        y = insert_dims(x, num_dims=2, axis=-1)
        print(y.shape)
        # torch.Size([3, 4, 1, 1])
    """
    assert num_dims >= 0, "`num_dims` must be nonnegative."
    rank = tensor.dim()
    assert -(rank + 1) <= axis <= rank, "`axis` is out of range `[-(D+1), D]`)"
    # A negative axis counts over the D+1 possible insertion slots, so -1
    # means "append after the last existing dimension".
    position = rank + axis + 1 if axis < 0 else axis
    old_shape = tuple(tensor.shape)
    expanded = (*old_shape[:position], *((1,) * num_dims), *old_shape[position:])
    return tensor.reshape(expanded)
def split_dim(
    tensor: torch.Tensor, shape: Union[List[int], torch.Size], axis: int
) -> torch.Tensor:
    """Reshapes a dimension of a tensor into multiple dimensions.

    This operation splits the dimension ``axis`` of a ``tensor`` into
    multiple dimensions according to ``shape``. The dimension index follows
    Python indexing rules, i.e., zero-based, where a negative index is
    counted backward from the end.

    :param tensor: Input tensor
    :param shape: Shape to which the dimension should be reshaped
    :param axis: Index of the axis to be reshaped. Given a ``tensor`` of
        `D` dimensions, ``axis`` must be within the range `[-D, D-1]`.

    :output tensor: A tensor of the same type as ``tensor`` with
        len(``shape``) - 1 additional dimensions, but the same number of
        elements

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import split_dim

        x = torch.ones([2, 12])
        print(x.shape)
        # torch.Size([2, 12])

        y = split_dim(x, shape=[3, 4], axis=1)
        print(y.shape)
        # torch.Size([2, 3, 4])

        z = split_dim(x, shape=[3, 4], axis=-1)
        print(z.shape)
        # torch.Size([2, 3, 4])
    """
    rank = tensor.dim()
    # Accept negative axes, consistent with `insert_dims` and `diag_part_axis`.
    if axis < 0:
        axis += rank
    assert 0 <= axis <= rank - 1, "`axis` must be in [-rank(tensor), rank(tensor) - 1]"
    s = tensor.shape
    new_shape = list(s[:axis]) + list(shape) + list(s[axis + 1 :])
    return tensor.reshape(new_shape)
def diag_part_axis(tensor: torch.Tensor, axis: int, offset: int = 0) -> torch.Tensor:
    r"""Extracts the batched diagonal part of a batched tensor over the specified axis.

    This generalizes PyTorch's ``torch.diagonal``, which extracts the
    diagonal over the last two dimensions; that behavior is reproduced with
    ``axis`` = -2. The diagonal is taken over the pair of dimensions
    (``axis``, ``axis`` + 1) and placed at position ``axis`` of the output.

    :param tensor: A tensor of rank greater than or equal to two
        (:math:`N\ge 2`) with shape [s(1), ..., s(N)]
    :param axis: Axis index starting from which the diagonal part is
        extracted
    :param offset: Offset of the diagonal from the main diagonal. Positive
        values select superdiagonals, negative values select subdiagonals.

    :output tensor: Tensor containing the diagonal part of input ``tensor``
        over axis (``axis``, ``axis`` + 1), with shape
        [s(1), ..., min[s(``axis``), s(``axis`` + 1)], s(``axis`` + 2), ..., s(N)]

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import diag_part_axis

        a = torch.arange(27).reshape(3, 3, 3)
        print(a.numpy())
        #  [[[ 0  1  2]
        #    [ 3  4  5]
        #    [ 6  7  8]]
        #
        #    [[ 9 10 11]
        #    [12 13 14]
        #    [15 16 17]]
        #
        #    [[18 19 20]
        #    [21 22 23]
        #    [24 25 26]]]

        dp_0 = diag_part_axis(a, axis=0)
        print(dp_0.numpy())
        # [[ 0  1  2]
        #  [12 13 14]
        #  [24 25 26]]

        dp_1 = diag_part_axis(a, axis=1)
        print(dp_1.numpy())
        # [[ 0  4  8]
        #  [ 9 13 17]
        #  [18 22 26]]
    """
    rank = tensor.dim()
    assert rank >= 2, "The input tensor must have rank >= 2."
    axis = axis + rank if axis < 0 else axis
    assert 0 <= axis <= rank - 2, "Input value of 'axis' out of boundaries."
    # torch.diagonal appends the diagonal as the trailing dimension; move it
    # back to the slot of the first reduced axis.
    diagonal = tensor.diagonal(offset=offset, dim1=axis, dim2=axis + 1)
    return diagonal.movedim(-1, axis)
def flatten_multi_index(
    indices: torch.Tensor, shape: Union[List[int], torch.Size]
) -> torch.Tensor:
    r"""Converts a tensor of index arrays into a tensor of flat indices.

    Each multi-index in the last dimension of ``indices`` is mapped to its
    row-major (C-order) linear position within a tensor of shape ``shape``.

    :param indices: Indices to flatten with shape [..., N] and dtype
        `torch.int32` or `torch.int64`
    :param shape: Shape of each index dimension [N]. Note that it must hold
        that ``indices[..., n]`` < ``shape[n]`` for all n and batch
        dimension.

    :output flat_indices: Flattened indices with shape [...] and dtype
        `torch.int64`

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import flatten_multi_index

        indices = torch.tensor([2, 3])
        shape = [5, 6]
        print(flatten_multi_index(indices, shape).numpy())
        # 15 = 2*6 + 3
    """
    idx = indices.to(torch.int64)
    dev = idx.device
    dims = torch.tensor(shape, dtype=torch.int64, device=dev)

    # Assert that indices are within valid bounds
    assert torch.all(idx >= 0), "indices must be non-negative"
    assert torch.all(idx < dims), "indices out of bounds"

    # Row-major strides: stride[n] = prod(shape[n+1:]), with the last one 1.
    trailing = torch.cat([dims[1:], torch.ones(1, dtype=torch.int64, device=dev)])
    reversed_cumprod = torch.cumprod(torch.flip(trailing, dims=[0]), dim=0)
    strides = torch.flip(reversed_cumprod, dims=[0])
    return torch.sum(idx * strides, dim=-1)
def gather_from_batched_indices(
    params: torch.Tensor, indices: torch.Tensor
) -> torch.Tensor:
    r"""Gathers the values of a tensor ``params`` according to batch-specific ``indices``.

    Every length-N multi-index stored in the last dimension of ``indices``
    selects one element of ``params``; the result keeps the batch shape of
    ``indices``.

    :param params: Tensor containing the values to gather with shape
        [s(1), ..., s(N)]
    :param indices: Tensor containing, for each batch [...], the indices at
        which ``params`` is gathered with shape [..., N] and dtype
        `torch.int32` or `torch.int64`. Note that 0 :math:`\le`
        ``indices[...,n]`` :math:`<` s(n) must hold for all n=1,...,N.

    :output values: Tensor containing the gathered values with shape [...]

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import gather_from_batched_indices

        params = torch.tensor([[10, 20, 30],
                               [40, 50, 60],
                               [70, 80, 90]])
        print(params.shape)
        # torch.Size([3, 3])

        indices = torch.tensor([[[0, 1], [1, 2], [2, 0], [0, 0]],
                                [[0, 0], [2, 2], [2, 1], [0, 1]]])
        print(indices.shape)
        # torch.Size([2, 4, 2])
        # Note that the batch shape is [2, 4]. Each batch contains a list
        # of 2 indices

        print(gather_from_batched_indices(params, indices).numpy())
        # [[20 60 70 10]
        #  [10 90 80 20]]
        # Note that the output shape coincides with the batch shape.
        # Element [i,j] coincides with params[indices[i,j,:]]
    """
    # Translate each multi-index into a row-major linear offset, then read
    # the flattened parameter tensor at those offsets.
    linear_positions = flatten_multi_index(indices, shape=list(params.shape))
    flat_params = params.flatten()
    return flat_params[linear_positions]
def tensor_values_are_in_set(
    tensor: torch.Tensor, admissible_set: Union[torch.Tensor, List[Any]]
) -> torch.Tensor:
    r"""Checks if the input ``tensor`` values are contained in the specified ``admissible_set``.

    :param tensor: Tensor to validate
    :param admissible_set: Set of valid values that the input ``tensor``
        must be composed of. A non-tensor collection is converted with
        ``torch.tensor`` on the device of ``tensor``.

    :output result: Returns `True` if and only if ``tensor`` values are
        contained in ``admissible_set`` (as a 0-dim boolean tensor)

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import tensor_values_are_in_set

        tensor = torch.tensor([[1, 0], [0, 1]])

        print(tensor_values_are_in_set(tensor, [0, 1, 2]).item())
        # True

        print(tensor_values_are_in_set(tensor, [0, 2]).item())
        # False
    """
    allowed = (
        admissible_set
        if isinstance(admissible_set, torch.Tensor)
        else torch.tensor(admissible_set, device=tensor.device)
    )
    membership = torch.isin(tensor, allowed)
    return membership.all()
def random_tensor_from_values(
    values: Union[torch.Tensor, List[Any]],
    shape: Union[List[int], torch.Size],
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    r"""Generates a tensor of the specified ``shape``, with elements randomly sampled from the provided set of ``values``.

    Indices into ``values`` are drawn uniformly with ``torch.randint`` on
    ``config.device`` using the generator returned by
    ``config.torch_rng()``.

    :param values: The set of values to sample from. A non-tensor
        collection is converted with ``torch.tensor`` on ``config.device``.
    :param shape: The desired shape of the output tensor
    :param dtype: Desired dtype of the output. If `None`, the dtype of
        ``values`` is kept.

    :output tensor: A tensor with the specified shape, where each element
        is randomly selected from ``values``

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.utils import random_tensor_from_values

        values = [0, 10, 20]
        shape = [2, 3]
        print(random_tensor_from_values(values, shape).numpy())
        # e.g. (random)
        # [[ 0 20  0]
        #  [10  0 20]]
    """
    if isinstance(values, torch.Tensor):
        pool = values
    else:
        pool = torch.tensor(values, device=config.device)

    num_values = len(pool)
    rng = config.torch_rng()
    picks = torch.randint(
        low=0,
        high=num_values,
        size=tuple(shape),
        dtype=torch.int64,
        device=config.device,
        generator=rng,
    )
    sampled = pool[picks]
    return sampled if dtype is None else sampled.to(dtype=dtype)
def enumerate_indices(bounds: Union[List[int], torch.Tensor]) -> torch.Tensor:
    r"""Enumerates all indices between 0 (included) and ``bounds`` (excluded) in lexicographic order.

    :param bounds: Collection of index bounds

    :output indices: Collection of all indices, in lexicographic order,
        with shape [prod(bounds), len(bounds)]. For a single bound `n` the
        shape is [n, 1]; for empty ``bounds`` the single empty index is
        returned, with shape [1, 0].

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.utils import enumerate_indices

        print(enumerate_indices([2, 3]).numpy())
        # [[0 0]
        #  [0 1]
        #  [0 2]
        #  [1 0]
        #  [1 1]
        #  [1 2]]

        print(enumerate_indices([3]).shape)
        # torch.Size([3, 1])
    """
    if isinstance(bounds, torch.Tensor):
        bounds_list = bounds.tolist()
    else:
        bounds_list = list(bounds)
    num_bounds = len(bounds_list)

    if num_bounds == 0:
        # The empty product has exactly one (empty) index tuple; handled
        # separately because torch.cartesian_prod rejects an empty list.
        return torch.zeros((1, 0), dtype=torch.int64, device=config.device)

    ranges = [torch.arange(b, device=config.device) for b in bounds_list]
    # torch.cartesian_prod returns its sole input unchanged (1-D) when given
    # a single tensor; the reshape restores the documented 2-D layout.
    return torch.cartesian_prod(*ranges).reshape(-1, num_bounds)
def find_true_position(
    bool_tensor: torch.Tensor, side: str = "last", axis: int = -1
) -> torch.Tensor:
    """Finds the index of the first or last `True` value along the specified axis.

    When no `True` value is present (including when ``axis`` has length
    zero), it returns -1.

    :param bool_tensor: Boolean tensor of any shape
    :param side: ``'first'`` | ``'last'``. If ``'first'``, the first `True`
        position is found, else the last.
    :param axis: Axis along which to find the first or last `True` value

    :output position: Tensor of `torch.int64` indices, containing the index
        of the first or last `True` value. Its shape is ``bool_tensor.shape``
        with specified ``axis`` removed.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import find_true_position

        x = torch.tensor([True, False, True, False, True])
        print(find_true_position(x, side='first').item())
        # 0
        print(find_true_position(x, side='last').item())
        # 4
    """
    assert side in ["first", "last"], "input side must be 'first' or 'last'"

    # Check if any True exists along the axis (well-defined even for an
    # empty axis, where it is all False)
    any_true = torch.any(bool_tensor, dim=axis)
    size_along_axis = bool_tensor.shape[axis]

    if size_along_axis == 0:
        # argmax cannot reduce over an empty dimension, and no True value
        # can exist there, so every position is "not found".
        return torch.full(
            any_true.shape, -1, dtype=torch.int64, device=bool_tensor.device
        )

    if side == "first":
        # argmax returns the first maximal position (True -> 1, False -> 0)
        idx = torch.argmax(bool_tensor.to(torch.int32), dim=axis)
    else:
        # Flip, find first, then convert to last position
        flipped = torch.flip(bool_tensor, dims=[axis])
        idx_from_end = torch.argmax(flipped.to(torch.int32), dim=axis)
        idx = size_along_axis - 1 - idx_from_end

    # Return -1 where no True was found
    return torch.where(any_true, idx, torch.tensor(-1, device=bool_tensor.device))