#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Functions extending TensorFlow tensor operations for Sionna PHY and SYS"""
import tensorflow as tf
from sionna.phy import config
def expand_to_rank(tensor, target_rank, axis=-1):
    """Inserts as many axes into a tensor as needed to achieve a desired rank

    Length-one dimensions are inserted starting at ``axis`` so that the
    resulting tensor has rank ``target_rank``. The dimension index follows
    Python indexing rules, i.e., zero-based, where a negative index is
    counted backward from the end.

    Input
    -----
    tensor : `tf.Tensor`
        Input tensor
    target_rank : `int`
        Rank of the output tensor.
        If ``target_rank`` is smaller than the rank of ``tensor``,
        the function does nothing.
    axis : `int`
        Dimension index at which to expand the shape of ``tensor``.
        Given a ``tensor`` of `D` dimensions, ``axis`` must be within
        the range `[-(D+1), D]` (inclusive).

    Output
    ------
    : `tf.Tensor`
        A tensor with the same data as ``tensor``, with
        ``target_rank`` - rank(``tensor``) additional dimensions inserted at
        the index specified by ``axis``.
        If ``target_rank`` <= rank(``tensor``), ``tensor`` is returned.
    """
    # Number of axes still missing; clipped at zero so that a tensor which
    # is already deep enough is passed through unchanged
    missing_dims = tf.maximum(target_rank - tf.rank(tensor), 0)
    return insert_dims(tensor, missing_dims, axis)
def flatten_dims(tensor, num_dims, axis):
    """
    Flattens a specified set of dimensions of a tensor

    This operation flattens ``num_dims`` dimensions of a ``tensor``
    starting at a given ``axis``.

    Input
    -----
    tensor : `tf.Tensor`
        Input tensor
    num_dims : `int`
        Number of dimensions to combine. Must be greater than or equal to
        two and less than or equal to the rank of ``tensor``.
    axis : `int`
        Index of the dimension from which to start

    Output
    ------
    : `tf.Tensor`
        A tensor of the same type as ``tensor`` with ``num_dims``-1 lesser
        dimensions, but the same number of elements
    """
    msg = "`num_dims` must be >= 2"
    tf.debugging.assert_greater_equal(num_dims, 2, msg)
    msg = "`num_dims` must <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)
    msg = "0<= `axis` <= rank(tensor)-1"
    tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
    tf.debugging.assert_greater_equal(axis, 0, msg)
    msg = "`num_dims`+`axis` <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims + axis, tf.rank(tensor), msg)

    if num_dims == len(tensor.shape):
        # All dimensions are flattened
        new_shape = [-1]
    elif axis == 0:
        # Let reshape infer the combined leading dimension
        shape = tf.shape(tensor)
        new_shape = tf.concat([[-1], shape[axis+num_dims:]], 0)
    else:
        shape = tf.shape(tensor)
        # Use the dynamic shape (not tensor.shape) so that tensors with
        # partially unknown static shapes are supported as well
        flat_dim = tf.reduce_prod(shape[axis:axis+num_dims])
        new_shape = tf.concat([shape[:axis],
                               [flat_dim],
                               shape[axis+num_dims:]], 0)
    return tf.reshape(tensor, new_shape)
def flatten_last_dims(tensor, num_dims=2):
    """
    Flattens the last `n` dimensions of a tensor

    This operation flattens the last ``num_dims`` dimensions of a ``tensor``.
    It is a simplified version of the function ``flatten_dims``.

    Input
    -----
    tensor : `tf.Tensor`
        Input tensor
    num_dims : `int`
        Number of dimensions to combine.
        Must be greater than or equal to two and less or equal
        than the rank of ``tensor``.

    Output
    ------
    : `tf.Tensor`
        A tensor of the same type as ``tensor`` with ``num_dims``-1 lesser
        dimensions, but the same number of elements
    """
    msg = "`num_dims` must be >= 2"
    tf.debugging.assert_greater_equal(num_dims, 2, msg)
    msg = "`num_dims` must <= rank(`tensor`)"
    tf.debugging.assert_less_equal(num_dims, tf.rank(tensor), msg)

    if num_dims == len(tensor.shape):
        # All dimensions are flattened
        new_shape = [-1]
    else:
        shape = tf.shape(tensor)
        # Use the dynamic shape (not tensor.shape) so that tensors with
        # partially unknown static shapes are supported as well
        last_dim = tf.reduce_prod(shape[-num_dims:])
        new_shape = tf.concat([shape[:-num_dims], [last_dim]], 0)
    return tf.reshape(tensor, new_shape)
def insert_dims(tensor, num_dims, axis=-1):
    """Adds multiple length-one dimensions to a tensor

    This operation extends TensorFlow's ``expand_dims`` function by
    inserting ``num_dims`` dimensions of length one starting from the
    dimension ``axis`` of a ``tensor``. The dimension index follows Python
    indexing rules, i.e., zero-based, where a negative index is counted
    backward from the end.

    Input
    -----
    tensor : `tf.Tensor`
        Input tensor
    num_dims : `int`
        Number of dimensions to add
    axis : `int`
        Dimension index at which to expand the shape of ``tensor``.
        Given a ``tensor`` of `D` dimensions, ``axis`` must be within
        the range `[-(D+1), D]` (inclusive).

    Output
    ------
    : `tf.tensor`
        A tensor with the same data as ``tensor``, with ``num_dims``
        additional dimensions inserted at the index specified by ``axis``
    """
    msg = "`num_dims` must be nonnegative."
    tf.debugging.assert_greater_equal(num_dims, 0, msg)
    input_rank = tf.rank(tensor)
    msg = "`axis` is out of range `[-(D+1), D]`)"
    tf.debugging.assert_less_equal(axis, input_rank, msg)
    tf.debugging.assert_greater_equal(axis, -(input_rank+1), msg)

    # Map a negative axis to the equivalent nonnegative insertion point
    if axis < 0:
        axis = input_rank + axis + 1
    in_shape = tf.shape(tensor)
    ones = tf.ones([num_dims], tf.int32)
    out_shape = tf.concat([in_shape[:axis], ones, in_shape[axis:]], 0)
    return tf.reshape(tensor, out_shape)
def split_dim(tensor, shape, axis):
    """Reshapes a dimension of a tensor into multiple dimensions

    This operation splits the dimension ``axis`` of a ``tensor`` into
    multiple dimensions according to ``shape``.

    Input
    -----
    tensor : `tf.Tensor`
        Input tensor
    shape : (list or TensorShape)
        Shape to which the dimension should be reshaped
    axis : `int`
        Index of the axis to be reshaped

    Output
    ------
    : `tf.Tensor`
        A tensor of the same type as ``tensor`` with len(``shape``)-1
        additional dimensions, but the same number of elements
    """
    msg = "0<= `axis` <= rank(tensor)-1"
    tf.debugging.assert_less_equal(axis, tf.rank(tensor)-1, msg)
    tf.debugging.assert_greater_equal(axis, 0, msg)

    old_shape = tf.shape(tensor)
    # Keep everything before and after `axis`; replace the axis itself
    # by the requested sub-shape
    head = old_shape[:axis]
    tail = old_shape[axis+1:]
    return tf.reshape(tensor, tf.concat([head, shape, tail], 0))
def diag_part_axis(tensor,
                   axis,
                   **kwargs):
    # pylint: disable=line-too-long
    r"""
    Extracts the batched diagonal part of a batched tensor over the specified axis

    This is an extension of TensorFlow`s ``tf.linalg.diag_part`` function, which
    extracts the diagonal over the last two dimensions. This behavior can be
    reproduced by setting ``axis`` =-2.

    Input
    -----
    tensor : [s(1), ..., s(N)], `any`
        A tensor of rank greater than or equal to two (:math:`N\ge 2`)
    axis : `int`
        Axis index starting from which the diagonal part is extracted
    kwargs : `dict`
        Optional inputs for TensorFlow's `linalg.diag_part`, such as the
        diagonal offset ``k`` or the padding value ``padding_value``. See
        TensorFlow's `linalg.diag_part` for more details.

    Output
    ------
    : [s(1), ..., min[s(``axis``),s(``axis`` +1)], s(``axis`` +2), ..., s(N))], `any`
        Tensor containing the diagonal part of input ``tensor`` over axis
        (``axis``, ``axis`` +1)

    Example
    -------
    .. code-block:: Python

        import tensorflow as tf
        from sionna.phy.utils import diag_part_axis

        a = tf.reshape(tf.range(27), [3,3,3])
        print(a.numpy())
        # [[[ 0  1  2]
        #   [ 3  4  5]
        #   [ 6  7  8]]
        #
        #  [[ 9 10 11]
        #   [12 13 14]
        #   [15 16 17]]
        #
        #  [[18 19 20]
        #   [21 22 23]
        #   [24 25 26]]]

        dp_0 = diag_part_axis(a, axis=0)
        print(dp_0.numpy())
        # [[ 0  1  2]
        #  [12 13 14]
        #  [24 25 26]]

        dp_1 = diag_part_axis(a, axis=1)
        print(dp_1.numpy())
        # [[ 0  4  8]
        #  [ 9 13 17]
        #  [18 22 26]]
    """
    tf.debugging.assert_rank_at_least(tensor, 2,
                                      message='The input tensor must have rank >= 2.')
    # Normalize a negative axis to its nonnegative equivalent
    if axis < 0:
        axis = tf.rank(tensor) + axis
    shape_in = tf.shape(tensor)
    # The diagonal spans dimensions (axis, axis+1), so axis+1 must still be
    # a valid dimension index
    tf.debugging.assert_equal((axis < 0) | (axis > len(shape_in)-2),
                              False,
                              message="Input value of 'axis' out of boundaries.")
    # If a diagonal offset `k` with a length is passed, diag_part is assumed
    # to append one extra dimension to its output; len_k tracks that so the
    # back-transposition below can account for it
    # NOTE(review): this assumes any `k` with __len__ adds exactly one
    # dimension — confirm against tf.linalg.diag_part's band behavior
    if 'k' not in kwargs:
        len_k = 0
    else:
        if hasattr(kwargs['k'], '__len__'):
            len_k = 1
        else:
            len_k = 0
    # Push the axis (axis,axis+1) to the last dimensions
    index1 = tf.concat([range(axis),
                        range(axis+2, len(shape_in)),
                        range(axis, axis+2)], 0)
    tensor_out = tf.transpose(tensor, index1)
    # Extract the diagonal part out of the 2 innermost dimensions
    tensor_out = tf.linalg.diag_part(tensor_out,
                                     **kwargs)
    # Push the two last dimensions back to the original location
    # (the last len_k+1 dimensions of the intermediate result move back
    # to position `axis`)
    index2 = tf.concat([range(axis),
                        range(len(tensor_out.shape)-1-len_k,
                              len(tensor_out.shape)),
                        range(axis, len(tensor_out.shape)-1-len_k)], 0)
    tensor_out = tf.transpose(tensor_out, index2)
    return tensor_out
def flatten_multi_index(indices, shape):
    # pylint: disable=line-too-long
    r"""
    Converts a tensor of index arrays into a tensor of flat indices

    Input
    -----
    indices : [..., N], `tf.int32`
        Indices to flatten
    shape : [N], `tf.int32`
        Shape of each index dimension.
        Note that it must hold that ``indices[..., n]<shape[n]`` for all n and
        batch dimension

    Output
    ------
    flat_indices : [...], `tf.int32`
        Flattened indices

    Example
    -------
    .. code-block:: Python

        import tensorflow as tf
        from sionna.phy.utils import flatten_multi_index

        indices = tf.constant([2, 3])
        shape = [5, 6]
        print(flatten_multi_index(indices, shape).numpy())
        # 15 = 2*6 + 3
    """
    idx = tf.cast(indices, tf.int32)
    n_batch_dims = tf.rank(idx) - 1

    # Validate that every index lies within [0, shape[n])
    tf.debugging.assert_less(idx, insert_dims(shape, n_batch_dims, axis=0))
    tf.debugging.assert_non_negative(idx)

    # Row-major strides: [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1]
    strides = tf.math.cumprod([1] + shape[::-1][:-1])[::-1]
    strides = insert_dims(strides, n_batch_dims, axis=0)

    # Dot product of the multi-indices with the strides along the last axis
    return tf.reduce_sum(strides * idx, axis=-1)
def gather_from_batched_indices(params, indices):
    # pylint: disable=line-too-long
    r"""
    Gathers the values of a tensor ``params`` according to batch-specific
    ``indices``

    Input
    -----
    params : [s(1), ..., s(N)], `any`
        Tensor containing the values to gather
    indices : [..., N], `tf.int32`
        Tensor containing, for each batch `[...]`, the indices at which
        ``params`` is gathered. Note that 0 :math:`\le`
        ``indices[...,n]`` :math:`<` `s(n)` must hold for all `n=1,...,N`

    Output
    ------
    : [...], `any`
        Tensor containing the gathered values

    Example
    -------
    .. code-block:: Python

        import tensorflow as tf
        from sionna.phy.utils import gather_from_batched_indices

        params = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
        print(params.shape)
        # TensorShape([3, 3])

        indices = tf.constant([[[0, 1], [1, 2], [2, 0], [0, 0]],
                               [[0, 0], [2, 2], [2, 1], [0, 1]]])
        print(indices.shape)
        # TensorShape([2, 4, 2])
        # Note that the batch shape is [2, 4]. Each batch contains a list of 2 indices

        print(gather_from_batched_indices(params, indices).numpy())
        # [[20, 60, 70, 10],
        #  [10, 90, 80, 20]]
        # Note that the output shape coincides with the batch shape.
        # Element [i,j] coincides with params[indices[i,j,:]]
    """
    # Linearize params and convert each multi-index to its flat position,
    # then gather from the flattened parameter vector
    flat_params = tf.reshape(params, [-1])
    flat_idx = flatten_multi_index(indices, shape=params.shape)
    return tf.gather(flat_params, flat_idx)
def tensor_values_are_in_set(tensor,
                             admissible_set):
    r"""
    Checks if the input ``tensor`` values are contained in the
    specified ``admissible_set``

    Input
    -----
    tensor : `tf.Tensor` | `list`
        Tensor to validate
    admissible_set : `tf.Tensor` | `list`
        Set of valid values that the input ``tensor`` must be composed of

    Output
    ------
    : `bool`
        Returns `True` if and only if ``tensor`` values are contained in
        ``admissible_set``

    Example
    -------
    .. code-block:: Python

        import tensorflow as tf
        from sionna.phy.utils import tensor_values_are_in_set

        tensor = tf.Variable([[1, 0], [0, 1]])
        print(tensor_values_are_in_set(tensor, [0, 1, 2]).numpy())
        # True
        print(tensor_values_are_in_set(tensor, [0, 2]).numpy())
        # False
    """
    # Work on flat views of both inputs
    values = tf.reshape(tensor, [-1])               # [num_values]
    allowed = tf.reshape(admissible_set, [-1])      # [set_size]

    # matches[i, j] is True iff values[j] equals allowed[i]
    matches = tf.equal(values[tf.newaxis, :],
                       allowed[:, tf.newaxis])

    # A value is admissible if it matches at least one allowed entry
    value_is_admissible = tf.reduce_any(matches, axis=0)

    # All values must be admissible
    return tf.reduce_all(value_is_admissible)
def random_tensor_from_values(values, shape, dtype=None):
    r"""
    Generates a tensor of the specified ``shape``, with elements randomly
    sampled from the provided set of ``values``

    Input
    -----
    values : `tf.Tensor` | `list`
        The set of values to sample from
    shape : `tf.Tensor` | `list`
        The desired shape of the output tensor
    dtype : `tf.dtype`
        Desired dtype of the output

    Returns
    -------
    : `tf.Tensor`
        A tensor with the specified shape, where each element is randomly
        selected from ``values``

    Example
    -------
    .. code-block:: Python

        from sionna.phy.utils import random_tensor_from_values

        values = [0, 10, 20]
        shape = [2, 3]
        print(random_tensor_from_values(values, shape).numpy())
        # array([[ 0, 20,  0],
        #        [10,  0, 20]], dtype=int32)
    """
    n_total = tf.reduce_prod(shape)
    # Draw one random position into `values` per output element
    rand_idx = config.tf_rng.uniform(shape=(n_total,),
                                     minval=0,
                                     maxval=len(values),
                                     dtype=tf.int32)
    # Map positions to values and arrange them into the requested shape
    out = tf.reshape(tf.gather(values, rand_idx), shape)
    return out if dtype is None else tf.cast(out, dtype)
def enumerate_indices(bounds):
    r"""
    Enumerates all indices between 0 (included) and ``bounds`` (excluded) in
    lexicographic order

    Input
    -----
    bounds : `list` | `tf.Tensor` | `np.array`, `int`
        Collection of index bounds

    Output
    ------
    : [prod(bounds), len(bounds)]
        Collection of all indices, in lexicographic order

    Example
    -------
    .. code-block:: Python

        from sionna.phy.utils import enumerate_indices

        print(enumerate_indices([2, 3]).numpy())
        # [[0 0]
        #  [0 1]
        #  [0 2]
        #  [1 0]
        #  [1 1]
        #  [1 2]]
    """
    # One flat index per element of the full index grid
    n_total = tf.reduce_prod(bounds)
    flat = tf.range(n_total)
    # Unravel to multi-indices; result has shape [len(bounds), n_total]
    multi = tf.unravel_index(flat, dims=bounds)
    # Transpose so each row is one multi-index, in lexicographic order
    return tf.transpose(multi, [1, 0])
def find_true_position(bool_tensor,
                       side='last',
                       axis=-1):
    """
    Finds the index of the first or last (according to the value of ``side``)
    `True` value along the specified axis.
    When no `True` value is present, it returns -1

    Input
    -----
    bool_tensor : `tf.bool`
        Boolean tensor of any shape
    side : "first" | "last"
        If "first", the first `True` position is found, else the last
    axis : `int` (default: -1)
        Axis along which to find the last `True` value

    Output
    ------
    index : `tf.int32`
        Tensor of indices, containing the index of the first or last `True`
        value.
        Its shape is ``bool_tensor.shape`` with specified ``axis`` removed
    """
    tf.debugging.assert_equal(side in ['first', 'last'],
                              True,
                              message="input side must be 'first' or 'last'")
    rank = tf.rank(bool_tensor)
    shape = tf.shape(bool_tensor)

    # Normalize a negative axis to its nonnegative equivalent
    if axis < 0:
        axis = rank + axis

    # Build an index ramp [0, ..., shape[axis]-1] along `axis` and
    # broadcast it to the full tensor shape
    ramp = tf.range(shape[axis], dtype=tf.int32)
    ramp = insert_dims(ramp, axis, axis=0)
    ramp = insert_dims(ramp, rank - axis - 1, axis=-1)
    ramp = tf.broadcast_to(ramp, shape)

    if side == 'last':
        # Filler -1 so that reduce_max yields the last True index,
        # or -1 when no True is present
        index = tf.reduce_max(tf.where(bool_tensor, ramp, -1), axis=axis)
    else:
        # Filler shape[axis] so that reduce_min yields the first True index
        index = tf.reduce_min(tf.where(bool_tensor, ramp, shape[axis]),
                              axis=axis)
    # Map "no True found" (still equal to the filler shape[axis]) to -1
    index = tf.where(index != shape[axis], index, -1)
    return index