#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Numerical methods for Sionna PHY and SYS"""
from typing import Callable, Optional, Tuple, Union
import torch
from sionna.phy import config, dtypes
__all__ = ["expand_bound", "bisection_method"]
[docs]
def expand_bound(
f: Callable,
bound: torch.Tensor,
side: str,
step_expand: float = 2.0,
max_n_iter: int = 100,
precision: Optional[str] = None,
**kwargs,
) -> torch.Tensor:
r"""
Expands the left (right, respectively) search interval end point until the
function ``f`` becomes positive (negative, resp.)
:param f: Generic function handle that takes batched inputs and returns batched outputs.
Applies a different decreasing univariate function to each of its inputs.
Must accept input batches of the same shape as ``left`` and ``right``.
:param bound: [...], `torch.float`.
Left (if ``side`` is 'left') or right (if ``side`` is 'right') end point
of the initial search interval, for each batch.
:param side: 'left' | 'right'.
See ``bound``.
:param step_expand: Geometric progression factor at which the bound is expanded. Must be
higher than 1.
:param max_n_iter: Maximum number of iterations.
:param precision: Precision used for internal calculations and outputs.
If set to `None`,
:attr:`~sionna.phy.config.Config.precision` is used.
:param kwargs: Additional arguments for function ``f``.
:output bound: [...], `torch.float`.
Final value of expanded bound.
.. rubric:: Examples
.. code-block:: python
import torch
from sionna.phy.utils.numerics import expand_bound
# Define a decreasing univariate function of x
def f(x):
return -x + 1.0
# Expand right bound until f becomes negative
bound = torch.tensor(0.5)
result = expand_bound(f, bound, side='right', step_expand=2.0)
print(result)
"""
if precision is None:
dtype = config.dtype
else:
dtype = dtypes[precision]["torch"]["dtype"]
# Determine device from input tensor, fall back to config.device
if isinstance(bound, torch.Tensor):
device = bound.device
else:
device = config.device
# Cast inputs - preserve the device of input tensor
if not isinstance(bound, torch.Tensor):
bound = torch.tensor(bound, dtype=dtype, device=device)
else:
bound = bound.to(dtype=dtype) # Keep on original device
if not isinstance(step_expand, torch.Tensor):
step_expand = torch.tensor(step_expand, dtype=dtype, device=device)
else:
step_expand = step_expand.to(dtype=dtype, device=device)
# Validate inputs
assert side in ["left", "right"], "side must be 'left' or 'right'"
assert step_expand > 1, "step_expand must be > 1"
# Initialize left and right bounds for search intervals
if side == "left":
for i in range(max_n_iter):
f_val = f(bound, **kwargs)
# Check if any bound needs expansion (f < 0)
condition = f_val < 0
if not torch.any(condition):
break
# Update bounds where f < 0
# bound = bound - step_expand^i
step = torch.pow(torch.abs(step_expand), i)
bound = torch.where(condition, bound - step, bound)
assert torch.all(
f(bound, **kwargs) >= 0
), "Root cannot be found. Please either increase 'step_expand' or 'max_n_iter'"
else:
for i in range(max_n_iter):
f_val = f(bound, **kwargs)
# Check if any bound needs expansion (f > 0)
condition = f_val > 0
if not torch.any(condition):
break
# Update bounds where f > 0
# bound = bound + step_expand^i
step = torch.pow(torch.abs(step_expand), i)
bound = torch.where(condition, bound + step, bound)
assert torch.all(
f(bound, **kwargs) <= 0
), "Root cannot be found. Please either increase 'step_expand' or 'max_n_iter'"
return bound
[docs]
def bisection_method(
f: Callable,
left: torch.Tensor,
right: torch.Tensor,
regula_falsi: bool = False,
expand_to_left: bool = True,
expand_to_right: bool = True,
step_expand: float = 2.0,
eps_x: float = 1e-5,
eps_y: float = 1e-4,
max_n_iter: int = 100,
return_brackets: bool = False,
precision: Optional[str] = None,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
r"""
Implements the classic bisection method for estimating the root of batches of decreasing
univariate functions
:param f: Generic function handle that takes batched inputs and returns batched outputs.
Applies a different decreasing univariate function to each of its inputs.
Must accept input batches of the same shape as ``left`` and ``right``.
:param left: [...], `torch.float`.
Left end point of the initial search interval, for each batch.
The root is guessed to be contained within [``left``, ``right``].
:param right: [...], `torch.float`.
Right end point of the initial search interval, for each batch.
:param regula_falsi: If `True`, then the `regula falsi` method is employed to determine the
next root guess. This guess is computed as the x-intercept of the line
passing through the two points formed by the function evaluated at the
current search interval endpoints.
Else, the next root guess is computed as the middle point of the current
search interval. Defaults to `False`.
:param expand_to_left: If `True` and ``f(left)`` is negative, then ``left`` is decreased by a
geometric progression of ``step_expand`` until ``f`` becomes positive,
for each batch.
If `False`, then ``left`` is not decreased. Defaults to `True`.
:param expand_to_right: If `True` and ``f(left)`` is positive, then ``right`` is increased by a
geometric progression of ``step_expand`` until ``f`` becomes negative,
for each batch.
If `False`, then ``right`` is not increased. Defaults to `True`.
:param step_expand: See ``expand_to_left`` and ``expand_to_right``. Defaults to 2.0.
:param eps_x: Convergence criterion. Search terminates after ``max_n_iter`` iterations
or if, for each batch, either the search interval length is smaller than
``eps_x`` or the function absolute value is smaller than ``eps_y``. Defaults to 1e-5.
:param eps_y: Convergence criterion. See ``eps_x``. Defaults to 1e-4.
:param max_n_iter: Maximum number of iterations. Defaults to 100.
:param return_brackets: If `True`, the final values of search interval ``left`` and ``right``
end point are returned. Defaults to `False`.
:param precision: Precision used for internal calculations and outputs.
If set to `None`,
:attr:`~sionna.phy.config.Config.precision` is used.
:param kwargs: Additional arguments for function ``f``.
:output x_opt: [...], `torch.float`.
Estimated roots of the input batch of functions ``f``.
:output f_opt: [...], `torch.float`.
Value of function ``f`` evaluated at ``x_opt``.
:output left: [...], `torch.float`.
Final value of left end points of the search intervals.
Only returned if ``return_brackets`` is `True`.
:output right: [...], `torch.float`.
Final value of right end points of the search intervals.
Only returned if ``return_brackets`` is `True`.
.. rubric:: Examples
.. code-block:: python
import torch
from sionna.phy.utils import bisection_method
# Define a decreasing univariate function of x
def f(x, a):
return - torch.pow(x - a, 3)
# Initial search interval
left, right = 0., 2.
# Input parameter for function a
a = 3
# Perform bisection method
x_opt, _ = bisection_method(f, left, right, eps_x=1e-4, eps_y=0, a=a)
print(x_opt.numpy())
# 2.9999084
"""
if precision is None:
dtype = config.dtype
else:
dtype = dtypes[precision]["torch"]["dtype"]
# Determine device from input tensors, fall back to config.device
if isinstance(left, torch.Tensor):
device = left.device
elif isinstance(right, torch.Tensor):
device = right.device
else:
device = config.device
# Cast inputs - preserve the device of input tensors
if not isinstance(left, torch.Tensor):
left = torch.tensor(left, dtype=dtype, device=device)
else:
left = left.to(dtype=dtype) # Keep on original device
if not isinstance(right, torch.Tensor):
right = torch.tensor(right, dtype=dtype, device=device)
else:
right = right.to(dtype=dtype) # Keep on original device
eps_x = torch.tensor(eps_x, dtype=dtype, device=device)
# Validate inputs
assert torch.all(left <= right), "bound_left must be <= bound_right"
# -------------------------- #
# Expand (or not) end points #
# -------------------------- #
if expand_to_left:
# Decrease left bracket until f gets positive
left = expand_bound(
f,
left,
"left",
step_expand=step_expand,
max_n_iter=max_n_iter,
precision=precision,
**kwargs,
)
else:
left = torch.where(f(right, **kwargs) > 0, right, left)
if expand_to_right:
# Increase left bracket until f gets negative
right = expand_bound(
f,
right,
"right",
step_expand=step_expand,
max_n_iter=max_n_iter,
precision=precision,
**kwargs,
)
else:
right = torch.where(f(left, **kwargs) < 0, left, right)
# -------------- #
# Initialization #
# -------------- #
def get_x_next(left, right):
"""Computes the next guess of the function root"""
if regula_falsi:
# Regula falsi:
# Compute x-intercept of function evaluated at current end points
f_left = f(left, **kwargs)
f_right = f(right, **kwargs)
x_next = torch.where(
right > left,
(left * f_right - right * f_left) / (f_right - f_left),
left,
)
else:
# Compute middle point of search interval
x_next = (left + right) / 2
return x_next
x_next = get_x_next(left, right)
f_next = f(x_next, **kwargs)
# -------------- #
# Bisection loop #
# -------------- #
for _ in range(max_n_iter):
# Convergence criterion
# Condition 1: Interval length is small enough
stop_cond1 = torch.abs(right - left) < eps_x
# Condition 2: Function value is small enough
stop_cond2 = torch.abs(f_next) < eps_y
if torch.all(stop_cond1 | stop_cond2):
break
# Update left and right bounds
# Next guess
x_next = get_x_next(left, right)
f_next = f(x_next, **kwargs)
# If f_next >= 0, then shrink interval to the right
left = torch.where(f_next >= 0, x_next, left)
# If f_next <= 0, then shrink interval to the left
right = torch.where(f_next <= 0, x_next, right)
if return_brackets:
return x_next, f_next, left, right
else:
return x_next, f_next