gather_from_batched_indices

sionna.phy.utils.gather_from_batched_indices(params: torch.Tensor, indices: torch.Tensor) → torch.Tensor

Gathers the values of a tensor params according to batch-specific indices.

Parameters:
  • params (torch.Tensor) – Tensor containing the values to gather with shape [s(1), …, s(N)]

  • indices (torch.Tensor) – Tensor with shape […, N] and dtype torch.int32 or torch.int64 containing, for each batch element […], the N indices at which params is gathered. Note that \(0 \le\) indices[...,n] \(<\) s(n) must hold for all n=1,…,N.

Outputs:

values – Tensor containing the gathered values with shape […]

Examples

import torch
from sionna.phy.utils import gather_from_batched_indices

params = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
print(params.shape)
# torch.Size([3, 3])

indices = torch.tensor([[[0, 1], [1, 2], [2, 0], [0, 0]],
                       [[0, 0], [2, 2], [2, 1], [0, 1]]])
print(indices.shape)
# torch.Size([2, 4, 2])
# Note that the batch shape is [2, 4]. Each batch element holds
# 2 indices, one per dimension of params

print(gather_from_batched_indices(params, indices).numpy())
# [[20 60 70 10]
#  [10 90 80 20]]
# Note that the output shape coincides with the batch shape.
# Element [i, j] equals params[indices[i, j, 0], indices[i, j, 1]]
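
The gather semantics described above can also be reproduced with plain PyTorch advanced indexing, which may help when checking results by hand. The following is a minimal sketch under the assumption that the function behaves exactly as documented here; it is not Sionna's implementation, and the helper name `gather_reference` is hypothetical.

```python
import torch


def gather_reference(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper for illustration, not part of Sionna.
    # Split the last axis of `indices` into N index tensors of shape [...],
    # one per dimension of `params`, and use them for advanced indexing.
    return params[tuple(indices.unbind(dim=-1))]


params = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
indices = torch.tensor([[[0, 1], [1, 2], [2, 0], [0, 0]],
                        [[0, 0], [2, 2], [2, 1], [0, 1]]])

values = gather_reference(params, indices)
print(values.shape)
# torch.Size([2, 4])
print(values.tolist())
# [[20, 60, 70, 10], [10, 90, 80, 20]]
```

Each tensor produced by `unbind` has the batch shape [2, 4], so the indexed result takes that shape as well, matching the output documented above.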