split_dim
sionna.phy.utils.split_dim(tensor: torch.Tensor, shape: List[int] | torch.Size, axis: int) → torch.Tensor
Reshapes a dimension of a tensor into multiple dimensions.
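This operation splits dimension `axis` of `tensor` into multiple dimensions according to `shape`. Because the number of elements is preserved, the product of the entries of `shape` must equal the size of dimension `axis`.
Parameters: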
tensor (torch.Tensor) – Input tensor
shape (List[int] | torch.Size) – Shape to which the dimension should be reshaped
axis (int) – Index of the axis to be reshaped
Outputs:
tensor – A tensor of the same type as `tensor` with len(shape) - 1 additional dimensions but the same number of elements
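To make the shape arithmetic concrete, the following minimal pure-Python sketch computes the output shape that results from splitting one dimension. `split_dim_shape` is a hypothetical helper written for illustration only; it is not part of Sionna, it does not claim to mirror how `split_dim` is implemented, and it assumes a non-negative `axis`.

```python
from math import prod
from typing import List, Sequence


def split_dim_shape(in_shape: Sequence[int], shape: Sequence[int], axis: int) -> List[int]:
    """Hypothetical helper: shape obtained by splitting dimension `axis` of `in_shape` into `shape`.

    Assumes 0 <= axis < len(in_shape) (negative axes are not handled here).
    """
    if not 0 <= axis < len(in_shape):
        raise ValueError(f"axis {axis} out of range for rank {len(in_shape)}")
    # The new sub-dimensions must hold exactly as many elements as the old dimension.
    if prod(shape) != in_shape[axis]:
        raise ValueError(f"cannot split dimension of size {in_shape[axis]} into {list(shape)}")
    # Keep the leading dims, insert the new sub-dims, keep the trailing dims.
    return list(in_shape[:axis]) + list(shape) + list(in_shape[axis + 1:])


new_shape = split_dim_shape([2, 12], [3, 4], axis=1)
print(new_shape)  # [2, 3, 4]

# The rank grows by len(shape) - 1 while the element count is unchanged.
assert len(new_shape) == 2 + len([3, 4]) - 1
assert prod(new_shape) == 2 * 12
```

In PyTorch, passing such a shape to `x.reshape(new_shape)` yields a tensor whose dimension `axis` has been split in row-major order, which matches the behavior described above.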
Examples
```python
import torch
from sionna.phy.utils import split_dim

x = torch.ones([2, 12])
print(x.shape)
# torch.Size([2, 12])

# Split the last dimension (size 12) into two dimensions of sizes 3 and 4
y = split_dim(x, shape=[3, 4], axis=1)
print(y.shape)
# torch.Size([2, 3, 4])
```