#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Utilities for computation with complex numbers"""
import drjit as dr
import mitsuba as mi
from typing import Tuple, Literal
[docs]
def cpx_add(
a: Tuple[mi.TensorXf, mi.TensorXf],
b: Tuple[mi.TensorXf, mi.TensorXf]
) -> Tuple[mi.TensorXf, mi.TensorXf]:
r"""Element-wise addition of two complex-valued tensors
Each tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param a: First tensor
:param b: Second tensor
"""
return (a[0] + b[0], a[1] + b[1])
[docs]
def cpx_sub(
a: Tuple[mi.TensorXf, mi.TensorXf],
b: Tuple[mi.TensorXf, mi.TensorXf]
) -> Tuple[mi.TensorXf, mi.TensorXf]:
r"""Element-wise substraction of a complex-valued tensor from another
Each tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param a: First tensor
:param b: Second tensor which is substracted from the first
"""
return (a[0] - b[0], a[1] - b[1])
[docs]
def cpx_mul(
a: Tuple[mi.TensorXf, mi.TensorXf],
b: Tuple[mi.TensorXf, mi.TensorXf]
) -> Tuple[mi.TensorXf, mi.TensorXf]:
r"""Element-wise multiplication of two complex-valued tensors
Each tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param a: First tensor
:param b: Second tensor
"""
re = a[0]*b[0] - a[1]*b[1]
im = a[0]*b[1] + a[1]*b[0]
return (re, im)
[docs]
def cpx_div(
a: Tuple[mi.TensorXf, mi.TensorXf],
b: Tuple[mi.TensorXf, mi.TensorXf]
) -> Tuple[mi.TensorXf, mi.TensorXf]:
r"""Element-wise division of a complex-valued tensor by another
Each tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param a: First tensor
:param b: Second tensor by which the first is divided
"""
d = dr.rcp(dr.square(b[0]) + dr.square(b[1]))
re = (a[0]*b[0] + a[1]*b[1]) * d
im = (a[1]*b[0] - a[0]*b[1]) * d
return (re, im)
[docs]
def cpx_exp(
x: Tuple[mi.TensorXf, mi.TensorXf]
) -> Tuple[mi.TensorXf, mi.TensorXf]:
r"""Element-wise exponential of a complex-valued tensor
The tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param x: A tensor
"""
exp_re = dr.exp(x[0])
sin_im, cos_im = dr.sincos(x[1])
return (exp_re*cos_im, exp_re*sin_im)
[docs]
def cpx_abs(
x: Tuple[mi.TensorXf, mi.TensorXf]
) -> mi.TensorXf:
r"""Element-wise absolute value of a complex-valued tensor
The tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param x: A tensor
"""
return dr.safe_sqrt(cpx_abs_square(x))
[docs]
def cpx_abs_square(
x: Tuple[mi.TensorXf, mi.TensorXf]
) -> mi.TensorXf:
r"""Element-wise absolute squared value of a complex-valued tensor
The tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
:param x: A tensor
"""
return dr.square(x[0]) + dr.square(x[1])
[docs]
def cpx_sqrt(
x: Tuple[mi.TensorXf, mi.TensorXf]
) -> mi.TensorXf:
r"""Element-wise square root of a complex-valued tensor
The tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
The following formula is implemented to compute the square roots of complex
numbers:
https://en.wikipedia.org/wiki/Square_root#Algebraic_formula
:param x: A tensor
"""
x_r = x[0]
x_i = x[1]
r = dr.safe_sqrt(dr.square(x_r) + dr.square(x_i))
y_r = dr.safe_sqrt(0.5*(r + x_r))
y_i = dr.sign(x_i)*dr.safe_sqrt(0.5*(r - x_r))
return (y_r, y_i)
[docs]
def cpx_convert(
x: Tuple[mi.TensorXf, mi.TensorXf],
out_type: Literal["numpy", "jax", "tf", "torch"]
):
r"""
Converts a complex-valued tensor to any of the supported frameworks
The tensor is represented as a tuple of two real-valued tensors,
corresponding to the real and imaginary part, respectively.
Note that the chosen framework must be installed for this function
to work.
:param x: A tensor
:param out_type: Name of the target framework. Currently supported
are `Numpy <https://numpy.org>`_ ("numpy"),
`Jax <https://jax.readthedocs.io/en/latest/index.html>`_ ("jax"),
`TensorFlow <https://www.tensorflow.org>`_ ("tf"),
and `PyTorch <https://pytorch.org>`_ ("torch").
:return type: :py:class:`np.array` | :py:class:`jax.array` |
:py:class:`tf.Tensor` | :py:class:`torch.tensor`
"""
if out_type == "numpy":
return x[0].numpy() + 1j*x[1].numpy()
elif out_type == "tf":
try:
import tensorflow as tf # pylint: disable=import-outside-toplevel
except ImportError as e:
raise ImportError("Please install TensorFlow to use this feature.")\
from e
return tf.complex(x[0].tf(), x[1].tf())
elif out_type == "torch":
try:
import torch # pylint: disable=import-outside-toplevel
except ImportError as e:
raise ImportError("Please install PyTorch to use this feature.") \
from e
return torch.complex(x[0].torch(), x[1].torch())
elif out_type == "jax":
try:
from jax import lax # pylint: disable=import-outside-toplevel
except ImportError as e:
raise ImportError("Please install Jax to use this feature.") from e
return lax.complex(x[0].jax(), x[1].jax())
else:
raise ValueError(f"Unsupported target: {out_type}")