flatten_multi_index

sionna.phy.utils.flatten_multi_index(indices: torch.Tensor, shape: List[int] | torch.Size) → torch.Tensor

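Converts a tensor of multi-dimensional indices into a tensor of flat (linear) indices, i.e., the positions the indexed elements occupy when a tensor of the given shape is flattened in row-major order.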

Parameters:
  • indices (torch.Tensor) – Indices to flatten, with shape […, N] and dtype torch.int32 or torch.int64. The last dimension holds the N components of each index.

  • shape (List[int] | torch.Size) – Size of each of the N index dimensions, given as a sequence of length N. It must hold that indices[..., n] < shape[n] for every n and every batch entry.
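The bound on indices[..., n] matters because an out-of-range component silently aliases another position: in a shape of [5, 6], the index [1, 6] would map to 1*6 + 6 = 12, the same flat index as [2, 0]. The following is a minimal sketch of a precondition check one might run before flattening. The helper indices_in_bounds is hypothetical (it is not part of Sionna), and it additionally assumes indices must be non-negative, which the parameter description above does not state explicitly.

```python
import torch

def indices_in_bounds(indices: torch.Tensor, shape) -> bool:
    """Hypothetical helper: True if 0 <= indices[..., n] < shape[n] for all n and batch entries.

    The non-negativity check is an assumption added for illustration.
    """
    bounds = torch.as_tensor(list(shape), dtype=indices.dtype, device=indices.device)
    # bounds has shape [N] and broadcasts against the trailing dimension of indices
    return bool(((indices >= 0) & (indices < bounds)).all())

shape = [5, 6]
print(indices_in_bounds(torch.tensor([[2, 3], [4, 5]]), shape))  # True
print(indices_in_bounds(torch.tensor([[2, 3], [1, 6]]), shape))  # False: 6 is not < 6
```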

Outputs:

flat_indices (torch.Tensor) – Flattened indices with shape […], i.e., the last dimension of indices is removed.
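Assuming row-major (C-order) flattening, which is consistent with the 2*6 + 3 example in this section, each flat index is flat = Σ_n indices[..., n] · Π_{m>n} shape[m]. The sketch below evaluates that sum for a single index using Horner's scheme in plain Python. flatten_index is a hypothetical name chosen for illustration, not Sionna's implementation.

```python
from typing import Sequence

def flatten_index(index: Sequence[int], shape: Sequence[int]) -> int:
    """Row-major flattening of one N-dimensional index (illustrative sketch)."""
    flat = 0
    for i, s in zip(index, shape):
        # Horner's scheme: scale the accumulated value by the size of the
        # current dimension, then add this dimension's component
        flat = flat * s + i
    return flat

print(flatten_index([2, 3], [5, 6]))        # 15 = 2*6 + 3
print(flatten_index([1, 2, 3], [4, 5, 6]))  # 45 = 1*(5*6) + 2*6 + 3
```

Horner's form avoids computing the stride products explicitly; it produces the same result as the sum above.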

Examples

import torch
from sionna.phy.utils import flatten_multi_index

indices = torch.tensor([2, 3])
shape = [5, 6]
print(flatten_multi_index(indices, shape).numpy())
# 15 = 2*6 + 3
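The example above flattens a single index. Because indices may carry arbitrary leading batch dimensions, the same computation can be vectorized by precomputing row-major strides and taking a weighted sum over the last axis. The following is a minimal pure-PyTorch sketch under that row-major assumption; flatten_multi_index_sketch is a hypothetical function for illustration and is not how Sionna necessarily implements flatten_multi_index.

```python
import torch

def flatten_multi_index_sketch(indices: torch.Tensor, shape) -> torch.Tensor:
    """Hypothetical re-implementation for illustration; not Sionna's code."""
    shape = [int(s) for s in shape]
    # Row-major strides: strides[n] = shape[n+1] * ... * shape[N-1]
    strides = [1] * len(shape)
    for n in range(len(shape) - 2, -1, -1):
        strides[n] = strides[n + 1] * shape[n + 1]
    strides = torch.tensor(strides, dtype=indices.dtype, device=indices.device)
    # Weighted sum over the last dimension maps shape [..., N] to [...]
    return (indices * strides).sum(dim=-1)

# A batch of three 2-D indices into a [5, 6] grid
indices = torch.tensor([[2, 3], [0, 0], [4, 5]])
print(flatten_multi_index_sketch(indices, [5, 6]).tolist())  # [15, 0, 29]
```

Note that torch.sum promotes integer inputs to torch.int64 by default, so this sketch returns int64 even for int32 inputs.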