Source code for sionna.utils.tensors

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Functions extending TensorFlow tensor operations"""

import tensorflow as tf
import sionna as sn

def expand_to_rank(tensor, target_rank, axis=-1):
    """Inserts as many axes to a tensor as needed to achieve a desired rank.

    Length-one dimensions are inserted into ``tensor`` starting at ``axis``
    until the result has rank ``target_rank``. The dimension index follows
    Python indexing rules, i.e., zero-based, where a negative index is
    counted backward from the end.

    Args:
        tensor : A tensor.

        target_rank (int) : The rank of the output tensor.
            If ``target_rank`` is not larger than the rank of ``tensor``,
            no dimension is inserted.

        axis (int) : The dimension index at which to expand the
            shape of ``tensor``. Given a ``tensor`` of `D` dimensions,
            ``axis`` must be within the range `[-(D+1), D]` (inclusive).

    Returns:
        A tensor with the same data as ``tensor``, with
        ``target_rank`` - rank(``tensor``) additional dimensions inserted at
        the index specified by ``axis``.
        If ``target_rank`` <= rank(``tensor``), the shape of ``tensor`` is
        left unchanged.
    """
    # Clamp at zero so that a tensor which already has enough dimensions
    # passes through `insert_dims` with nothing added.
    missing_dims = tf.maximum(target_rank - tf.rank(tensor), 0)
    return insert_dims(tensor, missing_dims, axis)
def flatten_dims(tensor, num_dims, axis):
    """Flattens a specified set of dimensions of a tensor.

    This operation flattens ``num_dims`` dimensions of a ``tensor``
    starting at a given ``axis``.

    Args:
        tensor : A tensor.

        num_dims (int): The number of dimensions
            to combine. Must be greater than or equal to two and less or
            equal than the rank of ``tensor``.

        axis (int): The index of the dimension from which to start.
            Must satisfy ``0 <= axis`` and
            ``axis + num_dims <= rank(tensor)``.

    Returns:
        A tensor of the same type as ``tensor`` with ``num_dims``-1 fewer
        dimensions, but the same number of elements.
    """
    msg = "`num_dims` must be >= 2"
    tf.debugging.assert_greater_equal(num_dims, 2, msg)

    msg = "`num_dims` must <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)

    msg = "0<= `axis` <= rank(tensor)-1"
    tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
    tf.debugging.assert_greater_equal(axis, 0, msg)

    msg ="`num_dims`+`axis` <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims + axis, tf.rank(tensor), msg)

    if num_dims==len(tensor.shape):
        # Every dimension is merged: the result is a vector.
        new_shape = [-1]
    elif axis==0:
        # All remaining sizes come from the dynamic shape, so the merged
        # leading dimension can be inferred by `tf.reshape`.
        shape = tf.shape(tensor)
        new_shape = tf.concat([[-1], shape[axis+num_dims:]], 0)
    else:
        shape = tf.shape(tensor)
        static_dims = tensor.shape[axis:axis+num_dims]
        if static_dims.is_fully_defined():
            # Known sizes: keep the product a compile-time constant so
            # that static shape inference of the result is preserved.
            flat_dim = tf.reduce_prod(static_dims)
        else:
            # A partially known `TensorShape` cannot be converted to a
            # tensor, so fall back to the dynamic shape in that case.
            flat_dim = tf.reduce_prod(shape[axis:axis+num_dims])
        new_shape = tf.concat([shape[:axis],
                               [flat_dim],
                               shape[axis+num_dims:]], 0)

    return tf.reshape(tensor, new_shape)
def flatten_last_dims(tensor, num_dims=2):
    """Flattens the last `n` dimensions of a tensor.

    This operation flattens the last ``num_dims`` dimensions of a ``tensor``.
    It is a simplified version of the function ``flatten_dims``.

    Args:
        tensor : A tensor.

        num_dims (int): The number of dimensions
            to combine. Must be greater than or equal to two and less or
            equal than the rank of ``tensor``.

    Returns:
        A tensor of the same type as ``tensor`` with ``num_dims``-1 fewer
        dimensions, but the same number of elements.
    """
    msg = "`num_dims` must be >= 2"
    tf.debugging.assert_greater_equal(num_dims, 2, msg)

    msg = "`num_dims` must <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)

    if num_dims==len(tensor.shape):
        # Every dimension is merged: the result is a vector.
        new_shape = [-1]
    else:
        shape = tf.shape(tensor)
        static_dims = tensor.shape[-num_dims:]
        if static_dims.is_fully_defined():
            # Known sizes: keep the product a compile-time constant so
            # that static shape inference of the result is preserved.
            last_dim = tf.reduce_prod(static_dims)
        else:
            # A partially known `TensorShape` cannot be converted to a
            # tensor, so fall back to the dynamic shape in that case.
            last_dim = tf.reduce_prod(shape[-num_dims:])
        new_shape = tf.concat([shape[:-num_dims], [last_dim]], 0)

    return tf.reshape(tensor, new_shape)
def insert_dims(tensor, num_dims, axis=-1):
    """Adds multiple length-one dimensions to a tensor.

    This operation is an extension to TensorFlow's ``expand_dims`` function.
    It inserts ``num_dims`` dimensions of length one starting from the
    dimension ``axis`` of a ``tensor``. The dimension index follows Python
    indexing rules, i.e., zero-based, where a negative index is counted
    backward from the end.

    Args:
        tensor : A tensor.

        num_dims (int) : The number of dimensions to add. Must be
            nonnegative.

        axis : The dimension index at which to expand the
            shape of ``tensor``. Given a ``tensor`` of `D` dimensions,
            ``axis`` must be within the range `[-(D+1), D]` (inclusive).

    Returns:
        A tensor with the same data as ``tensor``, with ``num_dims``
        additional dimensions inserted at the index specified by ``axis``.
    """
    count_msg = "`num_dims` must be nonnegative."
    tf.debugging.assert_greater_equal(num_dims, 0, count_msg)

    tensor_rank = tf.rank(tensor)
    range_msg = "`axis` is out of range `[-(D+1), D]`)"
    tf.debugging.assert_less_equal(axis, tensor_rank, range_msg)
    tf.debugging.assert_greater_equal(axis, -(tensor_rank+1), range_msg)

    # Map a negative index to the equivalent nonnegative insertion point;
    # -1 corresponds to appending after the last dimension.
    if axis >= 0:
        insert_at = axis
    else:
        insert_at = tensor_rank + axis + 1

    old_shape = tf.shape(tensor)
    leading = old_shape[:insert_at]
    unit_dims = tf.ones([num_dims], tf.int32)
    trailing = old_shape[insert_at:]

    return tf.reshape(tensor, tf.concat([leading, unit_dims, trailing], 0))
def split_dim(tensor, shape, axis):
    """Reshapes a dimension of a tensor into multiple dimensions.

    The dimension ``axis`` of ``tensor`` is replaced by the dimensions
    given in ``shape``; all other dimensions are kept as they are.

    Args:
        tensor : A tensor.

        shape (list or TensorShape): The shape to which the dimension should
            be reshaped.

        axis (int): The index of the axis to be reshaped. Must satisfy
            ``0 <= axis <= rank(tensor)-1``.

    Returns:
        A tensor of the same type as ``tensor`` with len(``shape``)-1
        additional dimensions, but the same number of elements.
    """
    range_msg = "0<= `axis` <= rank(tensor)-1"
    tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, range_msg)
    tf.debugging.assert_greater_equal(axis, 0, range_msg)

    old_shape = tf.shape(tensor)
    leading = old_shape[:axis]
    trailing = old_shape[axis+1:]

    # The dimension at `axis` is dropped and `shape` takes its place.
    return tf.reshape(tensor, tf.concat([leading, shape, trailing], 0))
def matrix_sqrt(tensor):
    r"""Computes the square root of a matrix.

    Given a batch of Hermitian positive semi-definite matrices
    :math:`\mathbf{A}`, returns matrices :math:`\mathbf{B}`,
    such that :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, M]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of the same shape and type as ``tensor`` containing
        the matrix square root of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.config.xla_compat`.
    """
    if not sn.config.xla_compat or tf.executing_eagerly():
        return tf.linalg.sqrtm(tensor)

    # XLA-compatible graph path: build the root from an
    # eigendecomposition, U diag(sqrt(|lambda|)) U^H.
    eigvals, eigvecs = tf.linalg.eigh(tensor)
    root_vals = tf.sqrt(tf.abs(eigvals))
    root_vals = tf.cast(root_vals, eigvecs.dtype)

    # Broadcasting over the row axis scales column i of U by root i.
    scaled_vecs = eigvecs * tf.expand_dims(root_vals, -2)
    return tf.matmul(scaled_vecs, eigvecs, adjoint_b=True)
def matrix_sqrt_inv(tensor):
    r"""Computes the inverse square root of a Hermitian matrix.

    Given a batch of Hermitian positive definite matrices
    :math:`\mathbf{A}`, with square root matrices :math:`\mathbf{B}`,
    such that :math:`\mathbf{B}\mathbf{B}^H = \mathbf{A}`, the function
    returns :math:`\mathbf{B}^{-1}`, such that
    :math:`\mathbf{B}^{-1}\mathbf{B}=\mathbf{I}`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, M]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of the same shape and type as ``tensor`` containing
        the inverse matrix square root of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.Config.xla_compat`.
    """
    if not sn.config.xla_compat or tf.executing_eagerly():
        return tf.linalg.inv(tf.linalg.sqrtm(tensor))

    # XLA-compatible graph path: build the inverse root from an
    # eigendecomposition, U diag(1/sqrt(|lambda|)) U^H.
    eigvals, eigvecs = tf.linalg.eigh(tensor)
    magnitudes = tf.abs(eigvals)
    tf.debugging.assert_positive(magnitudes,
                                 "Input must be positive definite.")
    inv_root_vals = 1/tf.sqrt(magnitudes)
    inv_root_vals = tf.cast(inv_root_vals, eigvecs.dtype)

    # Broadcasting over the row axis scales column i of U by value i.
    scaled_vecs = eigvecs * tf.expand_dims(inv_root_vals, -2)
    return tf.matmul(scaled_vecs, eigvecs, adjoint_b=True)
def matrix_inv(tensor):
    r"""Computes the inverse of a Hermitian matrix.

    Given a batch of Hermitian positive definite matrices
    :math:`\mathbf{A}`, the function
    returns :math:`\mathbf{A}^{-1}`, such that
    :math:`\mathbf{A}^{-1}\mathbf{A}=\mathbf{I}`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, M]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of the same shape and type as ``tensor``, containing
        the inverse of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.Config.xla_compat`.
    """
    # The eigendecomposition path is taken only for complex inputs in
    # XLA-compatible graph mode; everything else uses `tf.linalg.inv`.
    use_eigh = (tensor.dtype in [tf.complex64, tf.complex128]
                and sn.config.xla_compat
                and not tf.executing_eagerly())
    if not use_eigh:
        return tf.linalg.inv(tensor)

    # Inverse built as U diag(1/|lambda|) U^H.
    eigvals, eigvecs = tf.linalg.eigh(tensor)
    magnitudes = tf.abs(eigvals)
    tf.debugging.assert_positive(magnitudes,
                                 "Input must be positive definite.")
    inv_vals = tf.cast(1/magnitudes, eigvecs.dtype)

    # Broadcasting over the row axis scales column i of U by value i.
    scaled_vecs = eigvecs * tf.expand_dims(inv_vals, -2)
    return tf.matmul(scaled_vecs, eigvecs, adjoint_b=True)
def matrix_pinv(tensor):
    r"""Computes the Moore–Penrose (or pseudo) inverse of a matrix.

    Given a batch of :math:`M \times K` matrices :math:`\mathbf{A}` with rank
    :math:`K` (i.e., linearly independent columns), the function returns
    :math:`\mathbf{A}^+`, such that
    :math:`\mathbf{A}^{+}\mathbf{A}=\mathbf{I}_K`.

    The result is computed as
    :math:`\left(\mathbf{A}^H\mathbf{A}\right)^{-1}\mathbf{A}^H`.

    The two inner dimensions are assumed to correspond to the matrix rows
    and columns, respectively.

    Args:
        tensor ([..., M, K]) : A tensor of rank greater than or equal
            to two.

    Returns:
        A tensor of shape ([..., K, M]) of the same type as ``tensor``,
        containing the pseudo inverse of its last two dimensions.

    Note:
        If you want to use this function in Graph mode with XLA, i.e., within
        a function that is decorated with ``@tf.function(jit_compile=True)``,
        you must set ``sionna.config.xla_compat=true``.
        See :py:attr:`~sionna.config.xla_compat`.
    """
    # Gram matrix A^H A of shape [..., K, K].
    gram = tf.matmul(tensor, tensor, adjoint_a=True)
    return tf.matmul(matrix_inv(gram), tensor, adjoint_b=True)