flatten_last_dims
- sionna.phy.utils.flatten_last_dims(tensor: torch.Tensor, num_dims: int = 2) → torch.Tensor
Flattens the last n dimensions of a tensor.
This operation flattens the last num_dims dimensions of tensor, merging them into a single dimension whose size is the product of their sizes. It is a simplified version of the function flatten_dims.
- Parameters:
tensor (torch.Tensor) – Input tensor
num_dims (int) – Number of trailing dimensions to combine. Must be greater than or equal to two and less than or equal to the rank of tensor. Defaults to 2.
- Outputs:
tensor – A tensor of the same type as tensor with num_dims - 1 fewer dimensions, but the same number of elements.
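The shape bookkeeping stated under Outputs follows from simple arithmetic: for an input of shape (d_0, ..., d_{r-1}) and n = num_dims, the result has shape (d_0, ..., d_{r-n-1}, d_{r-n} · ... · d_{r-1}), so the rank drops by n - 1 while the element count is unchanged. The sketch below checks this with a hypothetical helper, expected_output_shape, and a plain torch.reshape to that target shape. It does not call Sionna and assumes nothing about the library's internals beyond the documented output shape.

```python
import math
import torch

def expected_output_shape(shape, num_dims):
    """Illustrative helper (hypothetical name): leading axes are kept, and the
    trailing num_dims axes collapse into one axis whose size is their product."""
    return tuple(shape[:-num_dims]) + (math.prod(shape[-num_dims:]),)

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
out_shape = expected_output_shape(x.shape, num_dims=3)
y = x.reshape(out_shape)  # plain row-major reshape to the target shape

print(out_shape)                     # (2, 60)
print(y.dim() == x.dim() - (3 - 1))  # True: num_dims - 1 fewer dimensions
print(y.numel() == x.numel())        # True: same number of elements
```

For the documented example, expected_output_shape((2, 3, 4), 2) gives (2, 12).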
Examples
```python
import torch
from sionna.phy.utils import flatten_last_dims

x = torch.ones([2, 3, 4])
print(x.shape)
# torch.Size([2, 3, 4])

# Merge the last two dimensions (3 and 4) into one of size 12
y = flatten_last_dims(x, num_dims=2)
print(y.shape)
# torch.Size([2, 12])
```
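To make the behavior concrete without depending on Sionna, the operation can be approximated in plain PyTorch. This is a minimal sketch under stated assumptions, not the library's implementation: the function name flatten_last_dims_sketch is hypothetical, it uses torch.flatten with start_dim=-num_dims to merge the trailing axes, and it enforces the documented constraint 2 <= num_dims <= rank by raising ValueError (the exception type Sionna actually uses is an assumption).

```python
import torch

def flatten_last_dims_sketch(tensor: torch.Tensor, num_dims: int = 2) -> torch.Tensor:
    """Hypothetical stand-in for illustration only; not Sionna's source code."""
    rank = tensor.dim()
    # Documented constraint: 2 <= num_dims <= rank (ValueError is an assumption).
    if not 2 <= num_dims <= rank:
        raise ValueError(f"num_dims must satisfy 2 <= num_dims <= {rank}, got {num_dims}")
    # torch.flatten merges axes start_dim..end_dim (inclusive); with the default
    # end_dim=-1, start_dim=-num_dims selects exactly the last num_dims axes.
    return torch.flatten(tensor, start_dim=-num_dims)

x = torch.ones([2, 3, 4, 5])
print(flatten_last_dims_sketch(x, num_dims=3).shape)  # torch.Size([2, 60])
print(flatten_last_dims_sketch(x, num_dims=4).shape)  # torch.Size([120])
```

Setting num_dims equal to the rank collapses the whole tensor into one dimension, which is the upper limit allowed by the parameter description.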