#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
from typing import Optional, Union
import numpy as np
import torch
from sionna.phy import Block
from sionna.phy.config import Precision
from sionna.phy.signal import ifft
from sionna.phy.utils import flatten_last_dims
__all__ = ["OFDMModulator"]
# (stray "[docs]" Sphinx source-link marker from HTML extraction removed)
class OFDMModulator(Block):
    r"""Computes the time-domain representation of an OFDM resource grid
    with (optional) cyclic prefix.

    :param cyclic_prefix_length: Integer or vector of integers indicating the
        length of the cyclic prefix that is prepended to each OFDM symbol.
        None of its elements can be larger than the FFT size. Besides
        Python numbers, NumPy arrays and tensors, any value accepted by
        :func:`torch.tensor` (e.g., a list of integers) may be given.
        Defaults to `0`.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for tensor operations. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input inputs: [..., num_ofdm_symbols, fft_size], `torch.complex`.
        Resource grid in the frequency domain.

    :output x_time: [..., num_ofdm_symbols*(fft_size+cyclic_prefix_length)] or [..., num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)], `torch.complex`.
        Time-domain OFDM signal.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.ofdm import OFDMModulator

        modulator = OFDMModulator(cyclic_prefix_length=16)

        # Resource grid: [batch, num_ofdm_symbols, fft_size]
        x_freq = torch.randn(64, 14, 72, dtype=torch.complex64)
        x_time = modulator(x_freq)
        print(x_time.shape)
        # torch.Size([64, 1232])  # 14 * (72 + 16) = 1232
    """

    def __init__(
        self,
        cyclic_prefix_length: Union[int, np.ndarray, torch.Tensor] = 0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(precision=precision, device=device, **kwargs)
        self._cp_length_scalar: Optional[int] = None  # Cached scalar for call()
        # Trailing input shape (num_ofdm_symbols, fft_size) for which the
        # gather index `_ind` was computed; `None` while no index exists.
        self._ind_shape: Optional[tuple] = None
        # Register tensors as buffers for CUDA graph compatibility
        self.register_buffer("_cyclic_prefix_length", None)
        self.register_buffer("_ind", None)
        self.cyclic_prefix_length = cyclic_prefix_length

    @property
    def cyclic_prefix_length(self) -> torch.Tensor:
        """Get/set the cyclic prefix length (scalar or per-symbol)"""
        return self._cyclic_prefix_length

    @cyclic_prefix_length.setter
    def cyclic_prefix_length(self, value: Union[int, np.ndarray, torch.Tensor]) -> None:
        if isinstance(value, (int, float)):
            value = torch.tensor([value], dtype=torch.int32, device=self.device)
        elif isinstance(value, torch.Tensor):
            value = value.to(dtype=torch.int32, device=self.device)
        else:
            # NumPy arrays, NumPy scalars, lists, tuples, ...
            value = torch.tensor(value, dtype=torch.int32, device=self.device)

        if not torch.all(value >= 0):
            raise ValueError("`cyclic_prefix_length` must be nonnegative.")
        if not 0 <= value.dim() <= 1:
            raise ValueError("`cyclic_prefix_length` must be of rank 0 or 1.")

        # Store as 0D if scalar, 1D otherwise
        if value.numel() == 1:
            # Register as buffer for CUDA graph compatibility
            self.register_buffer("_cyclic_prefix_length", value.squeeze())
            # Cache scalar value to avoid .item() during tracing
            self._cp_length_scalar = int(value.item())
        else:
            self.register_buffer("_cyclic_prefix_length", value)
            self._cp_length_scalar = None

        # Any previously computed gather index belongs to the old CP
        # configuration. Drop it so that call() never gathers with stale
        # (or missing) indices after the CP length has been changed.
        self.register_buffer("_ind", None)
        self._ind_shape = None

    def build(self, input_shape: tuple) -> None:
        """Build the modulator based on input shape.

        Validates the cyclic prefix length against the FFT size and, for
        per-symbol cyclic prefix lengths, precomputes the gather index that
        interleaves each cyclic prefix with its OFDM symbol.

        :param input_shape: Shape of the input tensor
            `[..., num_ofdm_symbols, fft_size]`
        :raises ValueError: If a cyclic prefix length exceeds `fft_size`, or
            if a per-symbol cyclic prefix vector does not have exactly
            `num_ofdm_symbols` entries.
        """
        num_ofdm_symbols, fft_size = input_shape[-2:]
        cp_len = self._cyclic_prefix_length

        if not torch.all(cp_len <= fft_size):
            raise ValueError("`cyclic_prefix_length` cannot be larger than `fft_size`.")

        if cp_len.dim() == 1:
            if cp_len.shape[0] != num_ofdm_symbols:
                raise ValueError(
                    "`cyclic_prefix_length` must be of size [num_ofdm_symbols]"
                )

            # Build the gather indices for variable CP lengths.
            # Convert to list once to avoid .item() calls during tracing
            indices_list = []
            for i, cp_length_i in enumerate(cp_len.tolist()):
                symbol_start = i * fft_size
                symbol_end = symbol_start + fft_size
                # CP indices (last cp_length_i samples of the OFDM symbol)
                indices_list.append(
                    torch.arange(
                        symbol_end - cp_length_i,
                        symbol_end,
                        dtype=torch.int64,
                        device=self.device,
                    )
                )
                # Data indices (the full OFDM symbol)
                indices_list.append(
                    torch.arange(
                        symbol_start,
                        symbol_end,
                        dtype=torch.int64,
                        device=self.device,
                    )
                )

            # Concatenate all indices and remember the grid shape they
            # were computed for
            self.register_buffer("_ind", torch.cat(indices_list))
            self._ind_shape = tuple(input_shape[-2:])

    def call(self, inputs: torch.Tensor) -> torch.Tensor:
        """Modulate OFDM resource grid to time-domain signal.

        :param inputs: Resource grid in frequency domain with shape
            `[..., num_ofdm_symbols, fft_size]`
        :output x_time: Time-domain OFDM signal with shape
            `[..., num_ofdm_symbols*(fft_size+cyclic_prefix_length)]` or
            `[..., num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)]`
        :raises ValueError: If the cyclic prefix configuration is not
            compatible with the shape of `inputs`.
        """
        # Shift DC subcarrier to first position
        x_freq = torch.fft.ifftshift(inputs, dim=-1)

        # Compute IFFT along the last dimension
        x_time = ifft(x_freq, precision=self.precision)

        cp_len = self._cyclic_prefix_length
        if cp_len.dim() == 1:
            # Individual CP length per OFDM symbol.
            # (Re)build the gather index if it is missing or was computed
            # for another grid shape, e.g., after `cyclic_prefix_length`
            # has been re-assigned. This is a pure Python-level check.
            if self._ind is None or self._ind_shape != tuple(inputs.shape[-2:]):
                self.build(inputs.shape)

            # Flatten last two dimensions
            x_time = flatten_last_dims(x_time, 2)

            # Gather full time-domain signal
            return x_time[..., self._ind]

        # Same CP length for all OFDM symbols
        # Use cached scalar to avoid .item() during tracing
        cp_length = self._cp_length_scalar

        # Slicing would silently clamp an oversized CP to `fft_size`
        # samples, so reject it explicitly.
        if cp_length > x_time.shape[-1]:
            raise ValueError("`cyclic_prefix_length` cannot be larger than `fft_size`.")

        if cp_length > 0:
            # Obtain cyclic prefix
            cp = x_time[..., -cp_length:]
            # Prepend cyclic prefix
            x_time = torch.cat([cp, x_time], dim=-1)

        # Serialize last two dimensions
        return flatten_last_dims(x_time, 2)