Source code for sionna.ofdm.demodulator

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for the OFDM Demodulator"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.signal import fftshift
from sionna.constants import PI
from sionna.utils import expand_to_rank
from sionna.signal import fft
import numpy as np

class OFDMDemodulator(Layer):
    # pylint: disable=line-too-long
    r"""
    OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs)

    Computes the frequency-domain representation of an OFDM waveform
    with cyclic prefix removal.

    The demodulator assumes that the input sequence is generated by the
    :class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
    the received signal sequence is given as:

    .. math::

        y_b = \sum_{\ell =L_\text{min}}^{L_\text{max}} \bar{h}_\ell x_{b-\ell} + w_b, \quad b \in[L_\text{min}, N_B+L_\text{max}-1]

    where :math:`\bar{h}_\ell` are the discrete-time channel taps,
    :math:`x_{b}` is the transmitted signal,
    and :math:`w_b` Gaussian noise.

    Starting from the first symbol, the demodulator cuts the input
    sequence into pieces of size ``cyclic_prefix_length + fft_size``,
    and throws away any trailing symbols. For each piece, the cyclic
    prefix is removed and the ``fft_size``-point discrete Fourier
    transform is computed.

    Since the input sequence starts at time :math:`L_\text{min}`,
    the FFT-window has a timing offset of :math:`L_\text{min}` symbols,
    which leads to a subcarrier-dependent phase shift of
    :math:`e^{\frac{j2\pi k L_\text{min}}{N}}`, where :math:`k`
    is the subcarrier index, :math:`N` is the FFT size,
    and :math:`L_\text{min} \le 0` is the largest negative time lag of
    the discrete-time channel impulse response. This phase shift
    is removed in this layer, by explicitly multiplying
    each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
    This is a very important step to enable channel estimation with
    sparse pilot patterns that need to interpolate the channel frequency
    response across subcarriers. It also ensures that the
    channel frequency response `seen` by the time-domain channel
    is close to the :class:`~sionna.channel.OFDMChannel`.

    Parameters
    ----------
    fft_size : int
        FFT size (i.e., the number of subcarriers).

    l_min : int
        The largest negative time lag of the discrete-time channel
        impulse response. It should be the same value as that used by the
        `cir_to_time_channel` function.

    cyclic_prefix_length : int
        Integer indicating the length of the cyclic prefix that
        is prepended to each OFDM symbol.

    Input
    -----
    :[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n], tf.complex
        Tensor containing the time-domain signal along the last dimension.
        `n` is a nonnegative integer.

    Output
    ------
    :[...,num_ofdm_symbols,fft_size], tf.complex
        Tensor containing the OFDM resource grid along the last
        two dimensions.
    """

    def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs):
        super().__init__(**kwargs)
        # The property setters below validate and store the values.
        self.fft_size = fft_size
        self.l_min = l_min
        self.cyclic_prefix_length = cyclic_prefix_length

    @property
    def fft_size(self):
        """int : FFT size, i.e., the number of subcarriers."""
        return self._fft_size

    @fft_size.setter
    def fft_size(self, value):
        # Validate *after* the integer conversion: a value such as 0.5 is
        # positive but truncates to 0, which would otherwise slip through
        # and later cause a division by zero.
        size = int(value)
        assert size > 0, "`fft_size` must be positive."
        self._fft_size = size

    @property
    def l_min(self):
        """int : Largest negative time lag of the channel impulse response."""
        return self._l_min

    @l_min.setter
    def l_min(self, value):
        assert value <= 0, "l_min must be nonpositive."
        self._l_min = int(value)

    @property
    def cyclic_prefix_length(self):
        """int : Length of the cyclic prefix prepended to each OFDM symbol."""
        return self._cyclic_prefix_length

    @cyclic_prefix_length.setter
    def cyclic_prefix_length(self, value):
        assert value >= 0, "`cyclic_prefix_length` must be nonnegative."
        self._cyclic_prefix_length = int(value)

    def build(self, input_shape): # pylint: disable=unused-argument
        """Precompute the per-subcarrier phase compensation.

        The rotation ``exp(-j*2*pi*k*l_min/fft_size)`` for
        ``k = 0,...,fft_size-1`` is evaluated once in double precision on
        the host. It is cast to the dtype of the input in :meth:`call`,
        so that double-precision inputs are not limited to
        single-precision accuracy.

        Nothing that depends on ``input_shape`` is cached here: the number
        of OFDM symbols is derived from the actual input in :meth:`call`,
        because ``build`` only ever sees the shape of the first input.
        """
        subcarrier_ind = np.arange(self.fft_size, dtype=np.float64)
        phase = -2.0 * PI * self.l_min / self.fft_size * subcarrier_ind
        self._phase_compensation = np.exp(1j * phase)

    def call(self, inputs):
        """Demodulate OFDM waveform onto a resource grid.

        Args:
            inputs (tf.complex):
                `[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n]`,
                where the `n` trailing samples that do not fill a complete
                OFDM symbol are discarded.

        Returns:
            `tf.complex` : The demodulated inputs of shape
            `[...,num_ofdm_symbols, fft_size]`.
        """
        symbol_length = self.fft_size + self.cyclic_prefix_length

        # Prefer the static length (keeps static shape information for
        # downstream layers); fall back to the dynamic shape if unknown.
        num_samples = tf.compat.dimension_value(inputs.shape[-1])
        if num_samples is None:
            num_samples = tf.shape(inputs)[-1]
        num_ofdm_symbols = num_samples // symbol_length

        # Cut last samples that do not fit into an OFDM symbol
        inputs = inputs[..., :num_ofdm_symbols * symbol_length]

        # Reshape input to separate OFDM symbols
        new_shape = tf.concat([tf.shape(inputs)[:-1],
                               [num_ofdm_symbols],
                               [symbol_length]], 0)
        x = tf.reshape(inputs, new_shape)

        # Remove cyclic prefix
        x = x[..., self.cyclic_prefix_length:]

        # Compute FFT
        x = fft(x)

        # Apply phase shift compensation to all subcarriers.
        # The cast to the working precision happens on the host.
        rot = tf.constant(
            self._phase_compensation.astype(x.dtype.as_numpy_dtype))
        rot = expand_to_rank(rot, tf.rank(x), 0)
        x = x * rot

        # Shift DC subcarrier to the middle
        x = fftshift(x, axes=-1)

        return x