flatten_multi_index
- sionna.phy.utils.flatten_multi_index(indices: torch.Tensor, shape: List[int] | torch.Size) → torch.Tensor
Converts a tensor of index arrays into a tensor of flat indices.
- Parameters:
indices (torch.Tensor) – Indices to flatten with shape […, N] and dtype torch.int32 or torch.int64
shape (List[int] | torch.Size) – Size of each index dimension, with shape [N]. It must hold that
indices[..., n] < shape[n] for all n and all batch dimensions.
- Outputs:
flat_indices – Flattened indices with shape […]
Examples
import torch
from sionna.phy.utils import flatten_multi_index

indices = torch.tensor([2, 3])
shape = [5, 6]
print(flatten_multi_index(indices, shape).numpy())
# 15 = 2*6 + 3
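The arithmetic in the example above (15 = 2*6 + 3) matches row-major (C-order) flattening. The following is a minimal sketch of that computation in plain PyTorch, applied to a batch of indices. The helper name `flatten_multi_index_reference` is hypothetical and not part of Sionna, and the row-major ordering is an assumption inferred from the example rather than taken from the library source.

```python
import torch


def flatten_multi_index_reference(indices, shape):
    """Hypothetical reference helper (not Sionna's implementation).

    Assumes row-major ordering: the stride of dimension n is the
    product of all sizes after it, i.e. prod(shape[n+1:]).
    """
    strides = [1] * len(shape)
    for n in range(len(shape) - 2, -1, -1):
        strides[n] = strides[n + 1] * shape[n + 1]
    strides = torch.tensor(strides, dtype=indices.dtype, device=indices.device)
    # Weight each index by its stride and sum over the last axis [..., N] -> [...]
    return (indices * strides).sum(dim=-1)


# A batch of three 2-D indices into a [5, 6] grid
batch = torch.tensor([[2, 3], [0, 0], [4, 5]])
flat = flatten_multi_index_reference(batch, [5, 6])
print(flat.tolist())  # [15, 0, 29]
```

The last axis of `indices` is consumed, so an input of shape [3, 2] yields an output of shape [3], in line with the […, N] to […] convention described under Parameters and Outputs.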