Source code for sionna.phy.block


#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0#
"""Definition of Sionna PHY Object and Block classes"""

from abc import ABC
from abc import abstractmethod
import tensorflow as tf
import numpy as np
from .config import config, dtypes

[docs] class Object(ABC): """Abstract class for Sionna PHY objects Parameters ---------- precision : `None` (default) | "single" | "double" Precision used for internal calculations. If set to `None`, the default :attr:`~sionna.phy.config.Config.precision` is used. """ # pylint: disable=unused-argument def __init__(self, *args, precision=None, **kwargs): if precision is None: self._precision = config.precision elif precision in ['single', 'double']: self._precision = precision else: raise ValueError("'precision' must be 'single' or 'double'") @property def precision(self): """ `str`, "single" | "double" : Precision used for all compuations """ return self._precision @property def cdtype(self): """ `tf.complex` : Type for complex floating point numbers """ return dtypes[self.precision]['tf']['cdtype'] @property def rdtype(self): """ `tf.float` : Type for real floating point numbers """ return dtypes[self.precision]['tf']['rdtype'] def _cast_or_check_precision(self, v): """Cast tensor to internal precision or check if a variable has the right precision """ # Check correct dtype for Variables if isinstance(v, tf.Variable): if v.dtype.is_complex: if v.dtype != self.cdtype: msg = f"Wrong dtype. Expected {self.cdtype}" + \ f", got {v.dtype}" raise ValueError(msg) elif v.dtype.is_floating: if v.dtype != self.rdtype: msg = f"Wrong dtype. Expected {self.cdtype}" + \ f", got {v.dtype}" raise ValueError("Wrong dtype") # Cast tensors to the correct dtype else: if not isinstance(v, tf.Tensor): v = tf.convert_to_tensor(v) if v.dtype.is_complex: v = tf.cast(v, self.cdtype) else: v = tf.cast(v, self.rdtype) return v
[docs] class Block(Object): """Abstract class for Sionna PHY processing blocks Parameters ---------- precision : `None` (default) | "single" | "double" Precision used for internal calculations and outputs. If set to `None`, the default :attr:`~sionna.phy.config.Config.precision` is used. """ # pylint: disable=unused-argument def __init__(self, *args, precision=None, **kwargs): super().__init__(precision=precision, **kwargs) # Boolean flag indicating if the block's build function has been called # This will prevent rebuilding the block in Eager mode each time it is # called. self._built = False @property def built(self): """ `bool` : Indicates if the blocks' build function was called """ return self._built
[docs] def build(self, *arg_shapes, **kwarg_shapes): """ Method to (optionally) initialize the block based on the inputs' shapes """ pass
[docs] @abstractmethod def call(self, *args, **kwargs): """ Abstract call method with arbitrary arguments and keyword arguments """ raise NotImplementedError("Subclasses must implement this method.")
def _convert_to_tensor(self, v): """Casts floating or complex tensors to the block's precision""" if isinstance(v, np.ndarray): v = tf.convert_to_tensor(v) if isinstance(v, tf.Tensor): if v.dtype.is_floating: v = tf.cast(v, self.rdtype) elif v.dtype.is_complex: v = tf.cast(v, self.cdtype) return v def _get_shape(self, v): """Converts an input to the corresponding TensorShape""" try : v = tf.convert_to_tensor(v) except (TypeError, ValueError): pass if hasattr(v, "shape"): return tf.TensorShape(v.shape) else: return tf.TensorShape([]) def __call__(self, *args, **kwargs): args, kwargs = tf.nest.map_structure(self._convert_to_tensor, [args, kwargs]) with tf.init_scope(): # pylint: disable=not-context-manager if not self._built: shapes = tf.nest.map_structure(self._get_shape, [args, kwargs]) self.build(*shapes[0], **shapes[1]) self._built = True return self.call(*args, **kwargs)