# Source code for sionna.phy.object
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Definition of Sionna Object."""
import random
from typing import Any, Optional
import torch
import numpy as np
from .config import config, dtypes, Precision
__all__ = ["Object"]
# [docs]  (documentation-link residue from the rendered source page)
class Object(torch.nn.Module):
    """Base class for Sionna PHY objects.

    Stores a floating-point precision and a device string, exposes the
    matching PyTorch/NumPy dtypes and the random number generators held by
    the global ``config``, and provides helpers to convert inputs to tensors
    of that precision on that device.

    :param precision: Floating-point precision ('single' or 'double') to be
        used within the block.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
        Defaults to `None`.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0') to be used
        within the block.
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.
        Defaults to `None`.
    :raises ValueError: If ``precision`` is not a key of ``dtypes`` or
        ``device`` is not listed in ``config.available_devices``.
    """

    def __init__(
        self,
        *args: Any,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the object.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Initialize nn.Module first so attribute assignment below works
        super().__init__()
        if precision is not None and precision not in dtypes:
            raise ValueError(f"Invalid precision: {precision}")
        if device is not None and device not in config.available_devices:
            raise ValueError(f"Invalid device: {device}")
        self._precision: Precision = (
            config.precision if precision is None else precision
        )
        # Use _device_str to avoid conflict with nn.Module internals
        self._device_str: str = config.device if device is None else device

    @property
    def dtype(self) -> torch.dtype:
        """Get the PyTorch real-valued dtype based on the current precision."""
        return dtypes[self.precision]["torch"]["dtype"]

    @property
    def cdtype(self) -> torch.dtype:
        """Get the PyTorch complex-valued dtype based on the current precision."""
        return dtypes[self.precision]["torch"]["cdtype"]

    @property
    def np_dtype(self) -> type:
        """Get the NumPy real-valued dtype based on the current precision."""
        return dtypes[self.precision]["np"]["dtype"]

    @property
    def np_cdtype(self) -> type:
        """Get the NumPy complex-valued dtype based on the current precision."""
        return dtypes[self.precision]["np"]["cdtype"]

    @property
    def precision(self) -> "Precision":
        """Get the floating-point precision ('single' or 'double')."""
        return self._precision

    @property
    def device(self) -> str:
        """Get the device for computation (e.g., 'cpu', 'cuda:0')."""
        return self._device_str

    @property
    def torch_rng(self) -> torch.Generator:
        """Get the PyTorch random number generator for the object's device."""
        return config.torch_rng(self.device)

    @property
    def np_rng(self) -> np.random.Generator:
        """Get the NumPy random number generator."""
        return config.np_rng

    @property
    def py_rng(self) -> random.Random:
        """Get the Python random number generator."""
        return config.py_rng

    @staticmethod
    def _rebuild_like(template: Any, items: Any) -> Any:
        """Build a sequence of the same type as ``template`` from ``items``.

        Namedtuple classes take their fields as positional arguments, so
        they are rebuilt by unpacking; every other list/tuple type is
        rebuilt from a single iterable argument.
        """
        items = list(items)
        if isinstance(template, tuple) and hasattr(template, "_fields"):
            return type(template)(*items)
        return type(template)(items)

    def _convert(self, v: Any) -> Any:
        """Convert ``v`` to the object's precision and device.

        - `None`, strings and ints (including bools) are returned unchanged.
        - Lists, tuples (including namedtuples) and dicts are converted
          element-wise; dicts are returned as plain ``dict``.
        - Python floats and complex numbers become tensors created directly
          in :attr:`dtype` / :attr:`cdtype`.
        - Tensors and other array-likes are moved to :attr:`device`;
          floating-point and complex data is cast to :attr:`dtype` /
          :attr:`cdtype`, all other dtypes are left unchanged.

        :param v: Value to convert.
        :return: The converted value.
        """
        # None stays None
        if v is None:
            return None
        # Handle recursion for lists/tuples/dicts
        if isinstance(v, (list, tuple)):
            return self._rebuild_like(v, (self._convert(x) for x in v))
        if isinstance(v, dict):
            return {k: self._convert(val) for k, val in v.items()}
        # Strings and ints stay as-is (ints often used for shapes/indices)
        if isinstance(v, (str, int)):
            return v
        # Python scalars: create the tensor directly in the target dtype.
        # Going through torch's default dtype first would round the value
        # to that dtype before widening it, losing precision for 'double'.
        if isinstance(v, (float, complex)):
            scalar_dtype = self.cdtype if isinstance(v, complex) else self.dtype
            return torch.as_tensor(v, dtype=scalar_dtype, device=self.device)
        # Convert numpy arrays and other array-like objects to tensors
        if not isinstance(v, torch.Tensor):
            v = torch.as_tensor(v, device=self.device)
        # Determine target dtype
        if v.is_complex():
            target_dtype = self.cdtype
        elif v.is_floating_point():
            target_dtype = self.dtype
        else:
            # Keep integer/boolean dtypes unchanged
            target_dtype = v.dtype
        # Only call .to() if conversion is needed
        if v.device != torch.device(self.device) or v.dtype != target_dtype:
            v = v.to(device=self.device, dtype=target_dtype)
        return v

    def _get_shape(self, v: Any) -> Any:
        """Extract the shape of ``v`` as a tuple, if available.

        Lists, tuples (including namedtuples) and dicts are processed
        element-wise and returned with the same structure; dicts are
        returned as plain ``dict``. Objects without a ``shape`` attribute
        yield an empty tuple.

        :param v: Value (or nested structure of values) to inspect.
        :return: Shape tuple, or a structure of shape tuples mirroring ``v``.
        """
        # Handle recursion
        if isinstance(v, (list, tuple)):
            return self._rebuild_like(v, (self._get_shape(x) for x in v))
        if isinstance(v, dict):
            return {k: self._get_shape(val) for k, val in v.items()}
        if hasattr(v, "shape"):
            return tuple(v.shape)
        return ()

    def __setattr__(self, name: str, value: Any) -> None:
        """Override to ensure property setters are called even for nn.Module values.

        PyTorch's nn.Module.__setattr__ intercepts nn.Module assignments and
        registers them in _modules, bypassing property setters. This override
        checks if there's a property descriptor with a setter on the class and
        uses it instead.

        :param name: Attribute name being assigned.
        :param value: Value to assign.
        """
        descriptor = getattr(type(self), name, None)
        if isinstance(descriptor, property) and descriptor.fset is not None:
            # Use the property setter directly
            descriptor.fset(self, value)
        else:
            # Fall back to nn.Module's default behavior
            super().__setattr__(name, value)