flatten_dims
- sionna.phy.utils.flatten_dims(tensor: torch.Tensor, num_dims: int, axis: int) -> torch.Tensor
Flattens a specified set of dimensions of a tensor.
This operation flattens num_dims dimensions of tensor into a single dimension, starting at the dimension given by axis.
- Parameters:
tensor (torch.Tensor) – Input tensor
num_dims (int) – Number of dimensions to combine. Must be at least two and at most the rank of tensor.
axis (int) – Index of the dimension from which to start flattening
- Outputs:
tensor – A tensor of the same type as tensor with num_dims - 1 fewer dimensions, but the same number of elements
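The output shape follows from this description: the num_dims sizes starting at axis are replaced by their product, and all other sizes are kept. A minimal sketch of that shape arithmetic in plain Python follows. It assumes num_dims >= 2 and that all merged dimensions exist in the input; the helper name flattened_shape is hypothetical and not part of Sionna.

```python
from math import prod

def flattened_shape(shape, num_dims, axis):
    """Hypothetical helper: the shape flatten_dims is expected to return."""
    rank = len(shape)
    # Assumed constraints: at least two dims are merged, and all of them exist
    if not 2 <= num_dims <= rank:
        raise ValueError("num_dims must satisfy 2 <= num_dims <= rank")
    if not 0 <= axis <= rank - num_dims:
        raise ValueError("axis + num_dims must not exceed the rank")
    # The merged dimension has the product of the merged sizes
    merged = prod(shape[axis:axis + num_dims])
    # Keep leading dims, insert the merged size, keep trailing dims
    return tuple(shape[:axis]) + (merged,) + tuple(shape[axis + num_dims:])

print(flattened_shape([2, 3, 4, 5], num_dims=2, axis=1))  # (2, 12, 5)
print(flattened_shape([2, 3, 4, 5], num_dims=4, axis=0))  # (120,)
```

The result always has rank - num_dims + 1 dimensions, and the product of its sizes equals that of the input.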
Examples
import torch
from sionna.phy.utils import flatten_dims

x = torch.ones([2, 3, 4, 5])
print(x.shape)
# torch.Size([2, 3, 4, 5])

y = flatten_dims(x, num_dims=2, axis=1)
print(y.shape)
# torch.Size([2, 12, 5])
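For comparison, a similar result can be obtained with PyTorch's built-in torch.flatten, which merges the dimensions from start_dim through end_dim, both inclusive. The sketch below assumes that flatten_dims(x, num_dims, axis) corresponds to torch.flatten(x, start_dim=axis, end_dim=axis + num_dims - 1). This correspondence is inferred from the description above, not taken from the Sionna source, and the function name flatten_dims_sketch is hypothetical.

```python
import torch

def flatten_dims_sketch(tensor: torch.Tensor, num_dims: int, axis: int) -> torch.Tensor:
    # Hypothetical stand-in for sionna.phy.utils.flatten_dims (not Sionna's code)
    if not 2 <= num_dims <= tensor.dim():
        raise ValueError("num_dims must satisfy 2 <= num_dims <= rank")
    if not 0 <= axis <= tensor.dim() - num_dims:
        raise ValueError("axis + num_dims must not exceed the rank")
    # end_dim is inclusive, so the last merged dimension is axis + num_dims - 1
    return torch.flatten(tensor, start_dim=axis, end_dim=axis + num_dims - 1)

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = flatten_dims_sketch(x, num_dims=2, axis=1)
print(y.shape)  # torch.Size([2, 12, 5])
# Row-major element order is preserved, so the result matches a plain reshape
print(torch.equal(y, x.reshape(2, 12, 5)))  # True
# Reshaping back to the original shape recovers the input
print(torch.equal(y.reshape(x.shape), x))  # True
```

The explicit checks mirror the documented constraints; torch.flatten on its own would also accept negative dimension indices.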