flatten_last_dims

sionna.phy.utils.flatten_last_dims(tensor: torch.Tensor, num_dims: int = 2) → torch.Tensor

Flattens the last num_dims dimensions of a tensor.

This operation merges the last num_dims dimensions of tensor into a single dimension whose size is the product of their sizes. It is a simplified version of the function flatten_dims.
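For intuition, here is a minimal pure-PyTorch sketch of the same operation. It assumes the operation is a row-major reshape that merges the trailing num_dims axes. The name flatten_last_dims_sketch is hypothetical, and this is not Sionna's source. The range check mirrors the constraint on num_dims listed under Parameters.

```python
import math

import torch


def flatten_last_dims_sketch(tensor: torch.Tensor, num_dims: int = 2) -> torch.Tensor:
    """Hypothetical re-implementation (not Sionna's code): merge the last num_dims axes."""
    rank = tensor.dim()
    # Enforce 2 <= num_dims <= rank, as documented for num_dims
    if not 2 <= num_dims <= rank:
        raise ValueError(f"num_dims must be in [2, {rank}], got {num_dims}")
    # Leading axes are kept unchanged
    lead_shape = tensor.shape[:-num_dims]
    # The merged axis has the product of the flattened axes' sizes
    merged = math.prod(tensor.shape[-num_dims:])
    # reshape keeps row-major (C-order) element ordering
    return tensor.reshape(*lead_shape, merged)


x = torch.arange(24).reshape(2, 3, 4)
y = flatten_last_dims_sketch(x, num_dims=2)
print(y.shape)  # torch.Size([2, 12])
```

Passing num_dims equal to the rank of the input collapses it to a 1-D tensor. Because the sketch uses a row-major reshape, element (i, j) of the merged 3 × 4 block lands at index i * 4 + j.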

Parameters:
  • tensor (torch.Tensor) – Input tensor

  • num_dims (int) – Number of dimensions to combine. Must be greater than or equal to two and less than or equal to the rank of tensor.

Outputs:

tensor – A tensor of the same dtype as tensor, with num_dims - 1 fewer dimensions but the same number of elements
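The shape rule above can be checked with PyTorch's built-in torch.flatten, which merges all axes from start_dim through end_dim (default -1). This example assumes that flatten_last_dims behaves like torch.flatten(tensor, start_dim=-num_dims). That equivalence is stated here as an assumption, not as documented Sionna behavior. A rank-4 input flattened with num_dims=3 loses num_dims - 1 = 2 dimensions and keeps all of its elements:

```python
import torch

x = torch.ones([2, 3, 4, 5])  # rank 4, 120 elements
num_dims = 3

# Merge the last num_dims axes with PyTorch's built-in flatten
y = torch.flatten(x, start_dim=-num_dims)

print(y.shape)                  # torch.Size([2, 60])
print(x.dim() - y.dim())        # 2, i.e. num_dims - 1
print(y.numel() == x.numel())   # True
```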

Examples

import torch
from sionna.phy.utils import flatten_last_dims

x = torch.ones([2, 3, 4])
print(x.shape)
# torch.Size([2, 3, 4])

y = flatten_last_dims(x, num_dims=2)
print(y.shape)
# torch.Size([2, 12])