gather_from_batched_indices
- sionna.phy.utils.gather_from_batched_indices(params: torch.Tensor, indices: torch.Tensor) → torch.Tensor
Gathers values of the tensor params at batch-specific positions given by indices, selecting one element of params per batch position.
- Parameters:
params (torch.Tensor) – Tensor of shape [s(1), …, s(N)] containing the values to gather.
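indices (torch.Tensor) – Tensor of shape […, N] and dtype torch.int32 or torch.int64 containing, for each batch position […], the N-dimensional index of the element of params to gather. Note that \(0 \le\) indices[..., n] \(< s(n)\) must hold for all \(n = 1, \dots, N\), where indices[..., n] denotes the n-th entry along the last axis (counting from 1).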
- Outputs:
values – Tensor of shape […], i.e., the batch shape of indices, containing the gathered values.
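To make the input/output contract concrete, here is a minimal sketch of a pure-PyTorch function with the same semantics. It is not Sionna's implementation, and the name gather_nd_reference is hypothetical. It relies on PyTorch advanced indexing: splitting indices along its last axis yields N integer tensors of batch shape […], and indexing params with that tuple picks one scalar per batch position. It also checks the documented precondition on the index range.

```python
import torch


def gather_nd_reference(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: out[b] = params[indices[b, 0], ..., indices[b, N-1]]."""
    if indices.shape[-1] != params.dim():
        raise ValueError("indices.shape[-1] must equal the number of dimensions of params")

    # Enforce 0 <= indices[..., n] < s(n); sizes broadcasts against the last axis
    sizes = torch.tensor(params.shape, device=indices.device)
    if (indices < 0).any() or (indices >= sizes).any():
        raise IndexError("indices out of range for params")

    # unbind(-1) returns N tensors of batch shape [...]; indexing params with
    # the tuple of them gathers one element of params per batch position
    return params[tuple(indices.long().unbind(-1))]
```

Applied to the params and indices from the example below, this sketch returns [[20, 60, 70, 10], [10, 90, 80, 20]], and an out-of-range entry such as [3, 0] raises an IndexError instead of reading the wrong element.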
Examples
```python
import torch
from sionna.phy.utils import gather_from_batched_indices

params = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])
print(params.shape)
# torch.Size([3, 3])

indices = torch.tensor([[[0, 1], [1, 2], [2, 0], [0, 0]],
                        [[0, 0], [2, 2], [2, 1], [0, 1]]])
print(indices.shape)
# torch.Size([2, 4, 2])
# Note that the batch shape is [2, 4]. Each batch position contains
# an index vector of length 2 (one index per dimension of params).

print(gather_from_batched_indices(params, indices).numpy())
# [[20 60 70 10]
#  [10 90 80 20]]
# Note that the output shape coincides with the batch shape.
# Element [i, j] equals params[indices[i, j, 0], indices[i, j, 1]].
```
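An equivalent way to view the operation is to convert each index vector into a row-major linear offset using the strides of params, and then gather from the flattened tensor. The sketch below illustrates that formulation under this assumption; it is one common way to implement such a gather, not necessarily how Sionna does it, and gather_via_linear_index is a hypothetical name. It deliberately performs no bounds checking.

```python
import torch


def gather_via_linear_index(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical: gather by mapping N-D indices to row-major flat offsets (no bounds check)."""
    shape = params.shape

    # Row-major strides: stride(n) = s(n+1) * ... * s(N), and the last stride is 1
    strides = [1] * len(shape)
    for n in range(len(shape) - 2, -1, -1):
        strides[n] = strides[n + 1] * shape[n + 1]
    strides_t = torch.tensor(strides, dtype=torch.int64, device=indices.device)

    # Flat offset per batch position: sum_n indices[..., n] * stride(n), shape [...]
    flat = (indices.to(torch.int64) * strides_t).sum(dim=-1)

    # reshape(-1) flattens in row-major order, consistent with the strides above
    return params.reshape(-1)[flat]


# Usage with a 3-D params tensor: (1, 2, 3) -> offset 1*12 + 2*4 + 3 = 23
params3 = torch.arange(24).reshape(2, 3, 4)
print(gather_via_linear_index(params3, torch.tensor([[1, 2, 3], [0, 1, 0]])).tolist())
# [23, 4]
```

This view also shows why the bound \(0 \le\) indices[..., n] \(< s(n)\) matters: an out-of-range entry such as [0, 3] on a 3 × 3 params still yields a valid flat offset (3) and silently selects params[1, 0] instead of failing.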