diag_part_axis
- sionna.phy.utils.diag_part_axis(tensor: torch.Tensor, axis: int, offset: int = 0) → torch.Tensor
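Extracts the batched diagonal part of a batched tensor over the two consecutive axes (axis, axis + 1).
This is an extension of PyTorch’s torch.diagonal function applied to the last two dimensions (dim1=-2, dim2=-1, as in torch.linalg.diagonal). That behavior can be reproduced by setting axis = -2. Note that torch.diagonal itself defaults to dim1=0, dim2=1 and appends the diagonal as the last dimension, whereas this function places the diagonal at position axis.
- Parameters: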
tensor (torch.Tensor) – A tensor of rank greater than or equal to two (\(N\ge 2\)) with shape [s(1), …, s(N)]
axis (int) – Index of the first of the two consecutive axes (axis, axis + 1) over which the diagonal part is extracted
offset (int) – Offset of the diagonal from the main diagonal. Positive values select superdiagonals, negative values select subdiagonals.
- Outputs:
tensor – Tensor containing the diagonal part of the input tensor over axes (axis, axis + 1), with shape [s(1), …, min[s(axis), s(axis + 1)], s(axis + 2), …, s(N)]
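The output shape and the effect of offset can be illustrated with plain PyTorch. The following is a minimal sketch, assuming the function behaves like torch.diagonal over axes (axis, axis + 1) with the extracted diagonal moved back to position axis. The helper diag_part_axis_sketch is hypothetical and is not Sionna’s implementation. Under this sketch, a nonzero offset can shorten the diagonal below min[s(axis), s(axis + 1)].

```python
import torch


def diag_part_axis_sketch(tensor: torch.Tensor, axis: int, offset: int = 0) -> torch.Tensor:
    """Hypothetical re-implementation for illustration only (not Sionna's code)."""
    ndim = tensor.dim()
    if ndim < 2:
        raise ValueError("tensor must have rank >= 2")
    # Normalize a negative axis, e.g. axis=-2 refers to the second-to-last dimension
    if axis < 0:
        axis += ndim
    # Both axis and axis + 1 must be valid dimension indices
    if not 0 <= axis < ndim - 1:
        raise ValueError("axis and axis + 1 must be valid dimensions")
    # torch.diagonal removes dims (axis, axis + 1) and appends the diagonal as the last dim
    diag = torch.diagonal(tensor, offset=offset, dim1=axis, dim2=axis + 1)
    # Move the diagonal dimension back to position `axis`
    return torch.movedim(diag, -1, axis)


# Shapes: dims 1 and 2 have sizes 3 and 4, so the main diagonal has length 3
x = torch.zeros(2, 3, 4, 5)
print(diag_part_axis_sketch(x, axis=1).shape)             # torch.Size([2, 3, 5])
print(diag_part_axis_sketch(x, axis=1, offset=1).shape)   # torch.Size([2, 3, 5])
print(diag_part_axis_sketch(x, axis=1, offset=-1).shape)  # torch.Size([2, 2, 5])

# Values: superdiagonal (offset > 0) and subdiagonal (offset < 0) of a 3x3 matrix
m = torch.arange(9).reshape(3, 3)
print(diag_part_axis_sketch(m, axis=0, offset=1))   # tensor([1, 5])
print(diag_part_axis_sketch(m, axis=0, offset=-1))  # tensor([3, 7])
```

With axis = -2 this sketch reduces to torch.linalg.diagonal, since the diagonal is already appended at the position of the removed axis pair.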
Examples
import torch
from sionna.phy.utils import diag_part_axis

a = torch.arange(27).reshape(3, 3, 3)
print(a.numpy())
# [[[ 0  1  2]
#   [ 3  4  5]
#   [ 6  7  8]]
#
#  [[ 9 10 11]
#   [12 13 14]
#   [15 16 17]]
#
#  [[18 19 20]
#   [21 22 23]
#   [24 25 26]]]

dp_0 = diag_part_axis(a, axis=0)
print(dp_0.numpy())
# [[ 0  1  2]
#  [12 13 14]
#  [24 25 26]]

dp_1 = diag_part_axis(a, axis=1)
print(dp_1.numpy())
# [[ 0  4  8]
#  [ 9 13 17]
#  [18 22 26]]
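As a cross-check of the printed outputs, the same values can be obtained with plain PyTorch. This sketch assumes that extracting over axes (axis, axis + 1) is equivalent to torch.diagonal followed by torch.movedim, which places the diagonal at position axis instead of at the end. It also shows that, for this rank-3 tensor, axis = 1 is the same as axis = -2 and therefore matches torch.linalg.diagonal.

```python
import torch

a = torch.arange(27).reshape(3, 3, 3)

# torch.diagonal over dims (0, 1) appends the diagonal as the last dimension ...
d0 = torch.diagonal(a, dim1=0, dim2=1)
print(d0.tolist())  # [[0, 12, 24], [1, 13, 25], [2, 14, 26]]
# ... so it is moved back to position 0 to match dp_0 above
dp_0 = torch.movedim(d0, -1, 0)
print(dp_0.tolist())  # [[0, 1, 2], [12, 13, 14], [24, 25, 26]]

# For axis=1 the diagonal goes back to position 1 (here also the last dimension)
dp_1 = torch.movedim(torch.diagonal(a, dim1=1, dim2=2), -1, 1)
print(dp_1.tolist())  # [[0, 4, 8], [9, 13, 17], [18, 22, 26]]

# On a rank-3 tensor, axis=1 equals axis=-2, i.e. the last two dimensions
print(torch.equal(dp_1, torch.linalg.diagonal(a)))  # True
```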