flatten_dims

sionna.phy.utils.flatten_dims(tensor: torch.Tensor, num_dims: int, axis: int) → torch.Tensor

Flattens a specified set of dimensions of a tensor.

This operation merges num_dims consecutive dimensions of a tensor, starting at axis, into a single dimension.

Parameters:
  • tensor (torch.Tensor) – Input tensor

  • num_dims (int) – Number of dimensions to combine. Must be at least two and less than or equal to the rank of tensor.

  • axis (int) – Index of the dimension from which to start

Outputs:

tensor – A tensor of the same type as tensor with num_dims - 1 fewer dimensions, but the same number of elements

Examples

import torch
from sionna.phy.utils import flatten_dims

x = torch.ones([2, 3, 4, 5])
print(x.shape)
# torch.Size([2, 3, 4, 5])

y = flatten_dims(x, num_dims=2, axis=1)
print(y.shape)
# torch.Size([2, 12, 5])
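
For intuition, the same shape transformation can be reproduced with plain PyTorch. The following is a minimal sketch, not Sionna's implementation: flatten_dims_sketch is a hypothetical helper built on torch.flatten, and it assumes a non-negative axis. It also shows how an index in the merged dimension maps back to the original dimensions.

```python
import torch


def flatten_dims_sketch(tensor: torch.Tensor, num_dims: int, axis: int) -> torch.Tensor:
    """Hypothetical stand-in for illustration; assumes axis >= 0."""
    # Merge dimensions axis, axis + 1, ..., axis + num_dims - 1 into one
    return torch.flatten(tensor, start_dim=axis, end_dim=axis + num_dims - 1)


x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = flatten_dims_sketch(x, num_dims=2, axis=1)
print(y.shape)
# torch.Size([2, 12, 5])

# Index 7 of the merged dimension corresponds to (1, 3) in the original
# dimensions of sizes (3, 4), since 7 = 1 * 4 + 3
print(torch.equal(y[0, 7], x[0, 1, 3]))
# True
```

The rank drops from 4 to 3, that is by num_dims - 1, while the number of elements stays at 120.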