diag_part_axis

sionna.phy.utils.diag_part_axis(tensor: torch.Tensor, axis: int, offset: int = 0) → torch.Tensor

Extracts the batched diagonal part of a batched tensor over the specified axis.

This is an extension of PyTorch’s torch.diagonal function to a pair of consecutive axes located anywhere in the tensor. Setting axis = -2 reproduces torch.diagonal applied to the last two dimensions (dim1=-2, dim2=-1).

Parameters:
  • tensor (torch.Tensor) – A tensor of rank greater than or equal to two (\(N\ge 2\)) with shape [s(1), …, s(N)]

  • axis (int) – Index of the first of the two consecutive axes (axis, axis + 1) over which the diagonal part is extracted

  • offset (int) – Offset of the diagonal from the main diagonal. Positive values select superdiagonals, negative values select subdiagonals.

Outputs:

tensor – Tensor containing the diagonal part of the input tensor over the axes (axis, axis + 1), with shape [s(1), …, min[s(axis), s(axis + 1)], s(axis + 2), …, s(N)]
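
The shape above describes the main diagonal (offset = 0). The sketch below shows how a nonzero offset could shorten the diagonal. It uses only torch.diagonal and movedim, and it assumes that diag_part_axis follows the offset semantics of torch.diagonal, which this page does not state explicitly. The helper diag_len is hypothetical and only illustrates the arithmetic.

```python
import torch


def diag_len(rows: int, cols: int, offset: int = 0) -> int:
    """Length of the offset-th diagonal of a rows x cols matrix (hypothetical helper)."""
    if offset >= 0:
        return max(min(rows, cols - offset), 0)
    return max(min(rows + offset, cols), 0)


x = torch.zeros(2, 3, 5, 4)  # shape [2, 3, 5, 4]
axis = 1                     # diagonal over axes (1, 2), i.e. the 3 x 5 part

shapes = {}
for offset in (0, 1, 3, -1):
    # torch.diagonal appends the diagonal as the last dimension,
    # so move it back to position `axis`.
    d = torch.diagonal(x, offset=offset, dim1=axis, dim2=axis + 1).movedim(-1, axis)
    shapes[offset] = tuple(d.shape)
    assert d.shape[axis] == diag_len(3, 5, offset)

print(shapes)
```

Under this assumption, only the entry at position axis changes with the offset; all other dimensions keep their size.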

Examples

import torch
from sionna.phy.utils import diag_part_axis

a = torch.arange(27).reshape(3, 3, 3)
print(a.numpy())
# [[[ 0  1  2]
#   [ 3  4  5]
#   [ 6  7  8]]
#
#  [[ 9 10 11]
#   [12 13 14]
#   [15 16 17]]
#
#  [[18 19 20]
#   [21 22 23]
#   [24 25 26]]]

dp_0 = diag_part_axis(a, axis=0)
print(dp_0.numpy())
# [[ 0  1  2]
#  [12 13 14]
#  [24 25 26]]

dp_1 = diag_part_axis(a, axis=1)
print(dp_1.numpy())
# [[ 0  4  8]
#  [ 9 13 17]
#  [18 22 26]]
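
For readers who want to check these results without Sionna installed, the following is a minimal sketch of an equivalent computation in plain PyTorch. The function diag_part_axis_ref is a hypothetical stand-in written for illustration, not the library's implementation; it assumes the documented behavior can be obtained by calling torch.diagonal on the axes (axis, axis + 1) and moving the resulting diagonal dimension back to position axis.

```python
import torch


def diag_part_axis_ref(tensor: torch.Tensor, axis: int, offset: int = 0) -> torch.Tensor:
    """Hypothetical reference: diagonal over (axis, axis + 1), kept at position `axis`."""
    axis = axis % tensor.dim()  # normalise negative axes such as -2
    diag = torch.diagonal(tensor, offset=offset, dim1=axis, dim2=axis + 1)
    # torch.diagonal puts the diagonal last; move it back to `axis`.
    return diag.movedim(-1, axis)


a = torch.arange(27).reshape(3, 3, 3)

dp_0 = diag_part_axis_ref(a, axis=0)
dp_1 = diag_part_axis_ref(a, axis=1)
print(dp_0.tolist())  # [[0, 1, 2], [12, 13, 14], [24, 25, 26]]
print(dp_1.tolist())  # [[0, 4, 8], [9, 13, 17], [18, 22, 26]]

# axis = -2 matches torch.diagonal over the last two dimensions
assert torch.equal(diag_part_axis_ref(a, axis=-2), torch.diagonal(a, dim1=-2, dim2=-1))

# offset = 1 selects the first superdiagonal of each 3 x 3 slice
sup = diag_part_axis_ref(a, axis=1, offset=1)
print(sup.tolist())  # [[1, 5], [10, 14], [19, 23]]
```

The movedim call is what distinguishes this from a bare torch.diagonal call, which always returns the diagonal as the last dimension.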