split_dim

sionna.phy.utils.split_dim(tensor: torch.Tensor, shape: List[int] | torch.Size, axis: int) → torch.Tensor

Reshapes a dimension of a tensor into multiple dimensions.

This operation splits dimension axis of tensor into len(shape) dimensions whose sizes are given by shape. All other dimensions are left unchanged.
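The operation can be understood as a reshape that keeps the leading and trailing dimensions and replaces dimension axis with the entries of shape. The sketch below illustrates this with plain torch.reshape; split_dim_sketch is a hypothetical helper, not the Sionna implementation, and it assumes a non-negative axis and that the product of shape equals the size of dimension axis.

```python
import math
from typing import List, Union

import torch


def split_dim_sketch(tensor: torch.Tensor,
                     shape: Union[List[int], torch.Size],
                     axis: int) -> torch.Tensor:
    """Hypothetical illustration of split_dim via torch.reshape.

    Assumes a non-negative axis and prod(shape) == tensor.shape[axis].
    """
    shape = list(shape)
    # The new sub-dimensions must hold exactly the elements of the old one
    if math.prod(shape) != tensor.shape[axis]:
        raise ValueError(
            f"prod({shape}) != tensor.shape[{axis}] = {tensor.shape[axis]}")
    # Keep dims before and after `axis`, replace dim `axis` by `shape`
    new_shape = list(tensor.shape[:axis]) + shape + list(tensor.shape[axis + 1:])
    return torch.reshape(tensor, new_shape)


x = torch.ones([2, 12, 5])
y = split_dim_sketch(x, shape=[3, 4], axis=1)
print(y.shape)  # torch.Size([2, 3, 4, 5])
```

The trailing dimension of size 5 passes through untouched, and the result gains len(shape) - 1 = 1 extra dimension.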

Parameters:
  • tensor (torch.Tensor) – Input tensor

  • shape (List[int] | torch.Size) – Shape into which the dimension is reshaped. The product of its entries must equal the size of dimension axis.

  • axis (int) – Index of the axis to be reshaped

Outputs:

tensor – A tensor of the same dtype as tensor with len(shape) - 1 additional dimensions, but the same number of elements

Examples

import torch
from sionna.phy.utils import split_dim

x = torch.ones([2, 12])
print(x.shape)
# torch.Size([2, 12])

y = split_dim(x, shape=[3, 4], axis=1)
print(y.shape)
# torch.Size([2, 3, 4])
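To see how individual elements are redistributed, the snippet below fills the input with distinct values. It assumes that split_dim follows the row-major (C-order) semantics of torch.reshape, so it uses torch.reshape directly instead of calling split_dim: index j along the split axis maps to the index pair (j // 4, j % 4).

```python
import torch

# Distinct values 0..23 make the element ordering visible
x = torch.arange(24).reshape(2, 12)
# Assumption: split_dim(x, shape=[3, 4], axis=1) is equivalent to this reshape
y = x.reshape(2, 3, 4)

# Row-major split: position j on axis 1 becomes (j // 4, j % 4)
for j in range(12):
    assert torch.equal(y[:, j // 4, j % 4], x[:, j])

# y[0, 1] therefore holds x[0, 4:8], i.e. the values 4, 5, 6, 7
print(y[0, 1])
```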