# Source code for sionna.phy.utils.numerics

# pylint: disable=line-too-long
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0#
"""Numerical methods for Sionna PHY and SYS"""

import tensorflow as tf
from sionna.phy import config, dtypes


def expand_bound(f,
                 bound,
                 side,
                 step_expand=2.,
                 max_n_iter=100,
                 precision=None,
                 **kwargs):
    r"""
    Expands the left (right, respectively) search interval end point until the
    function ``f`` becomes non-negative (non-positive, resp.)

    At iteration :math:`i=0,1,\dots`, every batch entry of ``bound`` at which
    ``f`` still has the wrong sign is moved outwards (to the left if ``side``
    is 'left', to the right otherwise) by ``step_expand``:math:`^i`. Entries
    that already have the desired sign are left untouched.

    Input
    -----

    f : `callable`
        Generic function handle that takes batched inputs and returns batched outputs.
        Applies a different decreasing univariate function to each of its inputs.
        Must accept input batches of the same shape as ``bound``.

    bound : [...], `tf.float`
        Left (if ``side`` is 'left') or right (if ``side`` is 'right') end point
        of the initial search interval, for each batch

    side : 'left' | 'right'
        See ``bound``

    step_expand : `float`
        Geometric progression factor at which the bound is expanded. Must be
        higher than 1.

    max_n_iter : `int`
        Maximum number of iterations

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    kwargs : `dict`
        Additional arguments for function ``f``

    Output
    ------

    bound : [...], `tf.float`
        Final value of expanded bound

    Note
    ----
    A ``tf.debugging`` assertion fails if ``side`` is neither 'left' nor
    'right', if ``step_expand`` is not larger than 1, or if ``f`` still has
    the wrong sign for at least one batch entry after ``max_n_iter``
    iterations.
    """
    if precision is None:
        rdtype = config.tf_rdtype
    else:
        rdtype = dtypes[precision]["tf"]["rdtype"]

    # Cast inputs
    bound = tf.cast(bound, rdtype)
    step_expand = tf.cast(step_expand, rdtype)

    # Validate inputs
    tf.debugging.assert_equal(
        side in ['left', 'right'], True,
        message="side must be 'left' or 'right'")

    tf.debugging.assert_greater(step_expand, tf.cast(1., rdtype),
                                message="step_expand must be > 1")

    expand_left = side == 'left'
    # Direction in which the bound moves: -1 towards the left, +1 to the right
    direction = tf.cast(-1. if expand_left else 1., rdtype)

    def needs_expansion(f_val):
        """Mask of batch entries whose bound has not crossed the root yet"""
        # f is decreasing: left of the root it is >= 0, right of it <= 0
        if expand_left:
            return f_val < 0
        return f_val > 0

    # pylint: disable=unused-argument
    def keep_expanding(bound, f_val, ii):
        """Loop condition: at least one entry still has the wrong sign"""
        return tf.reduce_any(needs_expansion(f_val))

    def expand_once(bound, f_val, ii):
        """Moves the not-yet-valid entries outwards by step_expand**ii"""
        step = tf.pow(tf.abs(step_expand), tf.cast(ii, rdtype))
        bound = tf.where(needs_expansion(f_val),
                         bound + direction * step,
                         bound)
        # Evaluate f only once per iteration and carry its value along in the
        # loop state, so that the loop condition and the final check below do
        # not need to recompute it
        return [bound, f(bound, **kwargs), ii + 1]

    bound, f_bound, _ = tf.while_loop(
        keep_expanding,
        expand_once,
        [bound, f(bound, **kwargs), 0],
        maximum_iterations=max_n_iter)

    # The loop may have stopped because max_n_iter was hit: check the sign
    if expand_left:
        has_valid_sign = f_bound >= 0
    else:
        has_valid_sign = f_bound <= 0
    tf.debugging.assert_equal(
        tf.reduce_all(has_valid_sign),
        True,
        message="Root cannot be found. Please either increase "
        "'step_expand' or 'max_n_iter'")
    return bound


def bisection_method(f,
                     left,
                     right,
                     regula_falsi=False,
                     expand_to_left=True,
                     expand_to_right=True,
                     step_expand=2.,
                     eps_x=1e-5,
                     eps_y=1e-4,
                     max_n_iter=100,
                     return_brackets=False,
                     precision=None,
                     **kwargs):
    r"""
    Implements the classic bisection method for estimating the root of batches
    of decreasing univariate functions

    Input
    -----

    f : `callable`
        Generic function handle that takes batched inputs and returns batched
        outputs. Applies a different decreasing univariate function to each of
        its inputs. Must accept input batches of the same shape as ``left``
        and ``right``.

    left : [...], `tf.float`
        Left end point of the initial search interval, for each batch.
        The root is guessed to be contained within [``left``, ``right``].

    right : [...], `tf.float`
        Right end point of the initial search interval, for each batch

    regula_falsi : `bool` (default: `False`)
        If `True`, then the `regula falsi` method is employed to determine the
        next root guess. This guess is computed as the x-intercept of the line
        passing through the two points formed by the function evaluated at the
        current search interval endpoints. If this line is horizontal, the
        middle point of the search interval is used instead.
        Else, the next root guess is computed as the middle point of the
        current search interval.

    expand_to_left : `bool` (default: `True`)
        If `True` and ``f(left)`` is negative, then ``left`` is decreased by a
        geometric progression of ``step_expand`` until ``f`` becomes
        non-negative, for each batch.
        If `False`, then ``left`` is not decreased.

    expand_to_right : `bool` (default: `True`)
        If `True` and ``f(right)`` is positive, then ``right`` is increased by
        a geometric progression of ``step_expand`` until ``f`` becomes
        non-positive, for each batch.
        If `False`, then ``right`` is not increased.

    step_expand : `float` (default: 2.)
        See ``expand_to_left`` and ``expand_to_right``

    eps_x : `float` (default: 1e-5)
        Convergence criterion. Search terminates after ``max_n_iter``
        iterations or if, for each batch, either the search interval length is
        smaller than ``eps_x`` or the function absolute value is smaller than
        ``eps_y``.

    eps_y : `float` (default: 1e-4)
        Convergence criterion. See ``eps_x``.

    max_n_iter : `int` (default: 100)
        Maximum number of iterations. It applies both to the expansion of the
        search interval end points and to the bisection loop.

    return_brackets : `bool` (default: `False`)
        If `True`, the final values of search interval ``left`` and ``right``
        end point are returned

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    kwargs : `dict`
        Additional arguments for function ``f``

    Output
    ------

    x_opt : [...], `tf.float`
        Estimated roots of the input batch of functions ``f``

    f_opt : [...], `tf.float`
        Value of function ``f`` evaluated at ``x_opt``

    left : [...], `tf.float`
        Final value of left end points of the search intervals.
        Only returned if ``return_brackets`` is `True`.

    right : [...], `tf.float`
        Final value of right end points of the search intervals.
        Only returned if ``return_brackets`` is `True`.

    Example
    -------
    .. code-block:: Python

        import tensorflow as tf
        from sionna.phy.utils import bisection_method

        # Define a decreasing univariate function of x
        def f(x, a):
            return - tf.math.pow(x - a, 3)

        # Initial search interval
        left, right = 0., 2.
        # Input parameter for function a
        a = 3

        # Perform bisection method
        x_opt, _ = bisection_method(f, left, right, eps_x=1e-4, eps_y=0, a=a)
        print(x_opt.numpy())
        # 2.9999084
    """
    if precision is None:
        rdtype = config.tf_rdtype
    else:
        rdtype = dtypes[precision]["tf"]["rdtype"]

    # Cast inputs. This is done before validation so that end points provided
    # with different types (e.g., a Python int and a float tensor) can be
    # compared.
    left = tf.cast(left, rdtype)
    right = tf.cast(right, rdtype)
    eps_x = tf.cast(eps_x, rdtype)

    # Validate inputs
    tf.debugging.assert_less_equal(
        left, right,
        message="bound_left must be <= bound_right")

    # -------------------------- #
    # Expand (or not) end points #
    # -------------------------- #
    if expand_to_left:
        # Decrease left bracket until f gets non-negative
        left = expand_bound(f,
                            left,
                            'left',
                            step_expand=step_expand,
                            max_n_iter=max_n_iter,
                            precision=precision,
                            **kwargs)
    else:
        # f is still positive at the right end point: the root cannot lie to
        # its left, hence collapse the left end point onto the right one
        left = tf.where(f(right, **kwargs) > 0, right, left)

    if expand_to_right:
        # Increase right bracket until f gets non-positive
        right = expand_bound(f,
                             right,
                             'right',
                             step_expand=step_expand,
                             max_n_iter=max_n_iter,
                             precision=precision,
                             **kwargs)
    else:
        # f is already negative at the left end point: the root cannot lie to
        # its right, hence collapse the right end point onto the left one
        right = tf.where(f(left, **kwargs) < 0, left, right)

    # -------------- #
    # Initialization #
    # -------------- #
    def get_x_next(left, right):
        """Computes the next guess of the function root"""
        # Middle point of search interval
        x_mid = (left + right) / tf.cast(2, rdtype)
        if not regula_falsi:
            return x_mid

        # Regula falsi:
        # Compute x-intercept of the line through the function values at the
        # current end points
        f_left = f(left, **kwargs)
        f_right = f(right, **kwargs)
        f_diff = f_right - f_left
        is_interval = right > left
        # A horizontal line has no x-intercept: guard against division by 0,
        # which would otherwise produce a NaN guess
        use_intercept = tf.logical_and(
            is_interval,
            tf.not_equal(f_diff, tf.zeros_like(f_diff)))
        safe_f_diff = tf.where(use_intercept, f_diff, tf.ones_like(f_diff))
        x_intercept = (left * f_right - right * f_left) / safe_f_diff
        return tf.where(use_intercept,
                        x_intercept,
                        tf.where(is_interval, x_mid, left))

    x_next = get_x_next(left, right)
    f_next = f(x_next, **kwargs)

    # -------------- #
    # Bisection loop #
    # -------------- #
    # pylint: disable=unused-argument
    def cond_bisection(left, right, x_next, f_next):
        """Convergence criterion (If True, search continues)"""
        # Condition 1: Interval length is small enough
        stop_cond1 = tf.abs(right - left) < eps_x
        # Condition 2: Function value is small enough
        stop_cond2 = tf.abs(f_next) < eps_y
        # tf.logical_not (rather than Python's `not`) keeps the condition a
        # tensor operation, so that it does not rely on AutoGraph conversion
        return tf.logical_not(tf.reduce_all(stop_cond1 | stop_cond2))

    def body_bisection(left, right, x_next, f_next):
        """Bisection body: Update left and right bounds"""
        # Next guess
        x_next = get_x_next(left, right)
        f_next = f(x_next, **kwargs)
        # If f_next >= 0, then shrink interval to the right
        left = tf.where(f_next >= 0, x_next, left)
        # If f_next <= 0, then shrink interval to the left
        right = tf.where(f_next <= 0, x_next, right)
        return [left, right, x_next, f_next]

    # Perform bisection method
    left, right, x_opt, f_opt = tf.while_loop(
        cond_bisection,
        body_bisection,
        [left, right, x_next, f_next],
        maximum_iterations=max_n_iter)

    if return_brackets:
        return x_opt, f_opt, left, right
    else:
        return x_opt, f_opt