Source code for sionna.ofdm.resource_grid

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to the resource grid"""

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Layer

from .pilot_pattern import PilotPattern, EmptyPilotPattern, KroneckerPilotPattern # pylint: disable=line-too-long
from sionna.utils import flatten_last_dims, flatten_dims, split_dim
import matplotlib.pyplot as plt
from matplotlib import colors


class ResourceGrid():
    # pylint: disable=line-too-long
    r"""Defines a `ResourceGrid` spanning multiple OFDM symbols and subcarriers.

    Parameters
    ----------
    num_ofdm_symbols : int
        Number of OFDM symbols.

    fft_size : int
        FFT size (i.e., the number of subcarriers).

    subcarrier_spacing : float
        The subcarrier spacing in Hz.

    num_tx : int
        Number of transmitters. Defaults to 1.

    num_streams_per_tx : int
        Number of streams per transmitter. Defaults to 1.

    cyclic_prefix_length : int
        Length of the cyclic prefix. Defaults to 0.

    num_guard_carriers : (int, int)
        Pair of integers defining the number of guard carriers
        at the left and right side of the resource grid.
        Defaults to `(0,0)`.

    dc_null : bool
        Indicates if the DC carrier is nulled or not.
        Defaults to `False`.

    pilot_pattern : One of [None, "kronecker", "empty", PilotPattern]
        An instance of :class:`~sionna.ofdm.PilotPattern`, a string
        shorthand for the :class:`~sionna.ofdm.KroneckerPilotPattern`
        or :class:`~sionna.ofdm.EmptyPilotPattern`, or `None`.
        Defaults to `None` which is equivalent to `"empty"`.

    pilot_ofdm_symbol_indices : List, int
        List of indices of OFDM symbols reserved for pilot transmissions.
        Only needed if ``pilot_pattern="kronecker"``. Defaults to `None`.

    dtype : tf.Dtype
        Defines the datatype for internal calculations and the output
        dtype. Must be `tf.complex64` or `tf.complex128`.
        Defaults to `tf.complex64`.
    """
    def __init__(self,
                 num_ofdm_symbols,
                 fft_size,
                 subcarrier_spacing,
                 num_tx=1,
                 num_streams_per_tx=1,
                 cyclic_prefix_length=0,
                 num_guard_carriers=(0,0),
                 dc_null=False,
                 pilot_pattern=None,
                 pilot_ofdm_symbol_indices=None,
                 dtype=tf.complex64):
        super().__init__()
        self._dtype = dtype
        self._num_ofdm_symbols = num_ofdm_symbols
        self._fft_size = fft_size
        self._subcarrier_spacing = subcarrier_spacing
        self._cyclic_prefix_length = int(cyclic_prefix_length)
        self._num_tx = num_tx
        self._num_streams_per_tx = num_streams_per_tx
        self._num_guard_carriers = np.array(num_guard_carriers)
        self._dc_null = dc_null
        self._pilot_ofdm_symbol_indices = pilot_ofdm_symbol_indices

        # Validate the grid parameters *before* the pilot pattern is built:
        # the pilot pattern is constructed from these very parameters, so
        # an invalid configuration should be reported by the checks below
        # rather than surface from inside the pilot pattern constructor.
        # `_check_settings` does not read the pilot pattern.
        self._check_settings()
        self.pilot_pattern = pilot_pattern

    @property
    def cyclic_prefix_length(self):
        """Length of the cyclic prefix."""
        return self._cyclic_prefix_length

    @property
    def num_tx(self):
        """Number of transmitters."""
        return self._num_tx

    @property
    def num_streams_per_tx(self):
        """Number of streams per transmitter."""
        return self._num_streams_per_tx

    @property
    def num_ofdm_symbols(self):
        """The number of OFDM symbols of the resource grid."""
        return self._num_ofdm_symbols

    @property
    def num_resource_elements(self):
        """Number of resource elements."""
        return self._fft_size*self._num_ofdm_symbols

    @property
    def num_effective_subcarriers(self):
        """Number of subcarriers used for data and pilot transmissions."""
        # `dc_null` is a bool and counts as one subcarrier when set
        n = self._fft_size - self._dc_null - np.sum(self._num_guard_carriers)
        return n

    @property
    def effective_subcarrier_ind(self):
        """Returns the indices of the effective subcarriers.

        The result is a ``range`` if the DC carrier is not nulled and a
        NumPy array (with the DC index removed) otherwise.
        """
        num_gc = self._num_guard_carriers
        sc_ind = range(num_gc[0], self.fft_size-num_gc[1])
        if self.dc_null:
            # Position of the DC carrier relative to the first
            # non-guard subcarrier
            sc_ind = np.delete(sc_ind, self.dc_ind-num_gc[0])
        return sc_ind

    @property
    def num_data_symbols(self):
        """Number of resource elements used for data transmissions."""
        n = self.num_effective_subcarriers * self._num_ofdm_symbols - \
               self.num_pilot_symbols
        return tf.cast(n, tf.int32)

    @property
    def num_pilot_symbols(self):
        """Number of resource elements used for pilot symbols."""
        return self.pilot_pattern.num_pilot_symbols

    @property
    def num_zero_symbols(self):
        """Number of empty resource elements."""
        n = (self._fft_size-self.num_effective_subcarriers) * \
               self._num_ofdm_symbols
        return tf.cast(n, tf.int32)

    @property
    def num_guard_carriers(self):
        """Number of left and right guard carriers."""
        return self._num_guard_carriers

    @property
    def dc_ind(self):
        """Index of the DC subcarrier.

        If ``fft_size`` is odd, the index is (``fft_size``-1)/2.
        If ``fft_size`` is even, the index is ``fft_size``/2.
        """
        # Both cases above are exactly the floor division by two
        return int(self._fft_size // 2)

    @property
    def fft_size(self):
        """The FFT size."""
        return self._fft_size

    @property
    def subcarrier_spacing(self):
        """The subcarrier spacing [Hz]."""
        return self._subcarrier_spacing

    @property
    def ofdm_symbol_duration(self):
        """Duration of an OFDM symbol with cyclic prefix [s]."""
        return (1. + self.cyclic_prefix_length/self.fft_size) \
               / self.subcarrier_spacing

    @property
    def bandwidth(self):
        """The occupied bandwidth [Hz]: ``fft_size*subcarrier_spacing``."""
        return self.fft_size*self.subcarrier_spacing

    @property
    def num_time_samples(self):
        """The number of time-domain samples occupied by the resource grid."""
        return (self.fft_size + self.cyclic_prefix_length) \
                * self._num_ofdm_symbols

    @property
    def dc_null(self):
        """Indicates if the DC carrier is nulled or not."""
        return self._dc_null

    @property
    def pilot_pattern(self):
        """The used PilotPattern.

        Can be set to `None`, ``"empty"``, ``"kronecker"``, or an instance
        of :class:`~sionna.ofdm.PilotPattern`. Any other value raises a
        `ValueError` (or an `AssertionError` for an unknown string).
        """
        return self._pilot_pattern

    @pilot_pattern.setter
    def pilot_pattern(self, value):
        # `None` is documented to be equivalent to the "empty" shorthand
        if value is None:
            value = "empty"

        if isinstance(value, PilotPattern):
            pass
        elif isinstance(value, str):
            assert value in ["kronecker", "empty"],\
                "Unknown pilot pattern"
            if value=="empty":
                value = EmptyPilotPattern(self._num_tx,
                                          self._num_streams_per_tx,
                                          self._num_ofdm_symbols,
                                          self.num_effective_subcarriers,
                                          dtype=self._dtype)
            elif value=="kronecker":
                assert self._pilot_ofdm_symbol_indices is not None,\
                    "You must provide pilot_ofdm_symbol_indices."
                value = KroneckerPilotPattern(self,
                        self._pilot_ofdm_symbol_indices, dtype=self._dtype)
        else:
            raise ValueError("Unsupported pilot_pattern")
        self._pilot_pattern = value

    def _check_settings(self):
        """Validate that all properties define a valid resource grid.

        Raises an `AssertionError` describing the first violated
        constraint and returns `True` otherwise.
        """
        assert self._num_ofdm_symbols > 0, \
            "`num_ofdm_symbols` must be positive`."
        assert self._fft_size > 0, \
            "`fft_size` must be positive`."
        assert self._cyclic_prefix_length>=0, \
            "`cyclic_prefix_length must be nonnegative."
        assert self._cyclic_prefix_length<=self._fft_size, \
            "`cyclic_prefix_length cannot be longer than `fft_size`."
        assert self._num_tx > 0, \
            "`num_tx` must be positive`."
        assert self._num_streams_per_tx > 0, \
            "`num_streams_per_tx` must be positive`."
        assert len(self._num_guard_carriers)==2, \
            "`num_guard_carriers` must have two elements."
        assert np.all(np.greater_equal(self._num_guard_carriers, 0)), \
            "`num_guard_carriers` must have nonnegative entries."
        # A nulled DC carrier takes one subcarrier away from the budget
        assert np.sum(self._num_guard_carriers)<=self._fft_size-self._dc_null,\
            "Total number of guardcarriers cannot be larger than `fft_size`."
        assert self._dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        return True

    def build_type_grid(self):
        """Returns a tensor indicating the type of each resource element.

        Resource elements can be one of

        - 0 : Data symbol
        - 1 : Pilot symbol
        - 2 : Guard carrier symbol
        - 3 : DC carrier symbol

        Output
        ------
        : [num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.int32
            Tensor indicating for each transmitter and stream the type of
            the resource elements of the corresponding resource grid.
            The type can be one of [0,1,2,3] as explained above.
        """
        shape = [self._num_tx, self._num_streams_per_tx, self._num_ofdm_symbols]
        gc_l = 2*tf.ones(shape+[self._num_guard_carriers[0]], tf.int32)
        gc_r = 2*tf.ones(shape+[self._num_guard_carriers[1]], tf.int32)
        # The DC block has width 1 if the DC carrier is nulled, else width 0
        dc   = 3*tf.ones(shape + [tf.cast(self._dc_null, tf.int32)], tf.int32)
        mask = self.pilot_pattern.mask
        # Position of the DC carrier within the effective subcarriers
        # NOTE(review): this assumes the DC index is not inside the left
        # guard band (i.e., num_guard_carriers[0] <= dc_ind) - confirm.
        split_ind = self.dc_ind-self._num_guard_carriers[0]
        rg_type = tf.concat([gc_l,                 # Left Guards
                             mask[...,:split_ind], # Data & pilots
                             dc,                   # DC
                             mask[...,split_ind:], # Data & pilots
                             gc_r], -1)            # Right guards
        return rg_type

    def show(self, tx_ind=0, tx_stream_ind=0):
        """Visualizes the resource grid for a specific transmitter and stream.

        Input
        -----
        tx_ind : int
            Indicates the transmitter index.

        tx_stream_ind : int
            Indicates the index of the stream.

        Output
        ------
        : `matplotlib.figure`
            A handle to a matplot figure object.
        """
        fig = plt.figure()
        data = self.build_type_grid()[tx_ind, tx_stream_ind]
        # One color per resource element type (0: data, ..., 3: DC carrier)
        cmap = colors.ListedColormap([[60/256,8/256,72/256],
                              [45/256,91/256,128/256],
                              [45/256,172/256,111/256],
                              [250/256,228/256,62/256]])
        bounds=[0,1,2,3,4]
        norm = colors.BoundaryNorm(bounds, cmap.N)
        # Transpose so that subcarriers are on the y-axis and
        # OFDM symbols on the x-axis
        img = plt.imshow(np.transpose(data), interpolation="nearest",
                         origin="lower", cmap=cmap, norm=norm,
                         aspect="auto")
        cbar = plt.colorbar(img, ticks=[0.5, 1.5, 2.5,3.5],
                            orientation="vertical", shrink=0.8)
        cbar.set_ticklabels(["Data", "Pilot", "Guard carrier", "DC carrier"])
        plt.title("OFDM Resource Grid")
        plt.ylabel("Subcarrier Index")
        plt.xlabel("OFDM Symbol")
        plt.xticks(range(0, data.shape[0]))
        return fig
class ResourceGridMapper(Layer):
    # pylint: disable=line-too-long
    r"""ResourceGridMapper(resource_grid, dtype=tf.complex64, **kwargs)

    Maps a tensor of modulated data symbols to a ResourceGrid.

    This layer takes as input a tensor of modulated data symbols
    and maps them together with pilot symbols onto an
    OFDM :class:`~sionna.ofdm.ResourceGrid`. The output can be
    converted to a time-domain signal with the
    :class:`~sionna.ofdm.Modulator` or further processed in the
    frequency domain.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    : [batch_size, num_tx, num_streams_per_tx, num_data_symbols], tf.complex
        The modulated data symbols to be mapped onto the resource grid.

    Output
    ------
    : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
        The full OFDM resource grid in the frequency domain.
    """
    def __init__(self, resource_grid, dtype=tf.complex64, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._resource_grid = resource_grid

    def build(self, input_shape): # pylint: disable=unused-argument
        """Store the type grid of the resource grid together with the
        positions of all pilot (type 1) and data (type 0) resource
        elements, which are used in `call` as scatter indices.
        """
        type_grid = self._resource_grid.build_type_grid()
        self._rg_type = type_grid
        self._pilot_ind = tf.where(type_grid==1)
        self._data_ind = tf.where(type_grid==0)

    def call(self, inputs):
        # Build the resource grid template: all-zero except for the pilots
        pilot_symbols = flatten_last_dims(
            self._resource_grid.pilot_pattern.pilots, 3)
        grid = tf.scatter_nd(self._pilot_ind,
                             pilot_symbols,
                             self._rg_type.shape)

        # Append a trailing batch dimension and replicate the template
        # once for every example of the batch
        grid = tf.expand_dims(grid, axis=-1)
        num_examples = tf.shape(inputs)[0]
        target_shape = tf.concat([tf.shape(grid)[:-1], [num_examples]],
                                 axis=0)
        grid = tf.broadcast_to(grid, target_shape)

        # The scatter update needs the batch dimension last, hence
        # flatten the data symbols and transpose them
        data_symbols = tf.transpose(flatten_last_dims(inputs, 3))
        grid = tf.tensor_scatter_nd_update(grid, self._data_ind, data_symbols)

        # Move the batch dimension back to the front
        return tf.transpose(grid, perm=[4, 0, 1, 2, 3])
class ResourceGridDemapper(Layer):
    # pylint: disable=line-too-long
    r"""ResourceGridDemapper(resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Extracts data-carrying resource elements from a resource grid.

    This layer takes as input an OFDM :class:`~sionna.ofdm.ResourceGrid` and
    extracts the data-carrying resource elements. In other words, it implements
    the reverse operation of :class:`~sionna.ofdm.ResourceGridMapper`.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    : [batch_size, num_rx, num_streams_per_rx, num_ofdm_symbols, fft_size, data_dim]
        The full OFDM resource grid in the frequency domain.
        The last dimension `data_dim` is optional. If `data_dim`
        is used, it refers to the dimensionality of the data that should be
        demapped to individual streams. An example would be LLRs.

    Output
    ------
    : [batch_size, num_rx, num_streams_per_rx, num_data_symbols, data_dim]
        The data that were mapped into the resource grid.
        The last dimension `data_dim` is only returned if it was used for the
        input (also if it has size one).
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._stream_management = stream_management
        self._resource_grid = resource_grid

        # Precompute indices to extract data symbols.
        # Sorting the flattened mask in ascending order moves the entries
        # with the smallest mask value (presumably 0 = data, see
        # `ResourceGrid.build_type_grid`) to the front. The sort must be
        # stable so that the data symbols keep their original order;
        # this is requested explicitly instead of relying on the current
        # default implementation.
        mask = resource_grid.pilot_pattern.mask
        num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
        data_ind = tf.argsort(flatten_last_dims(mask),
                              direction="ASCENDING",
                              stable=True)
        self._data_ind = data_ind[...,:num_data_symbols]

    def call(self, y): # pylint: disable=arguments-renamed

        # y has shape
        # [batch_size, num_rx, num_streams_per_rx, num_ofdm_symbols,...
        # ..., fft_size, data_dim]

        # If data_dim is not provided, add a dummy dimension.
        # Remember that it was added so that only this dummy dimension is
        # removed again at the end, and not a `data_dim` of size one that
        # was explicitly provided by the caller.
        added_dummy_dim = len(y.shape)==5
        if added_dummy_dim:
            y = tf.expand_dims(y, -1)

        # Remove nulled subcarriers from y (guards, dc). New shape:
        # [batch_size, num_rx, num_streams_per_rx, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers, data_dim]
        y = tf.gather(y, self._resource_grid.effective_subcarrier_ind, axis=-2)

        # Transpose tensor to shape
        # [num_rx, num_streams_per_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, data_dim, batch_size]
        y = tf.transpose(y, [1, 2, 3, 4, 5, 0])

        # Merge num_rx and num_streams_per_rx
        # [num_rx * num_streams_per_rx, num_ofdm_symbols,...
        #  ...,num_effective_subcarriers, data_dim, batch_size]
        y = flatten_dims(y, 2, 0)

        # Put first dimension into the right ordering
        stream_ind = self._stream_management.stream_ind
        y = tf.gather(y, stream_ind, axis=0)

        # Reshape first dimensions to [num_tx, num_streams] so that
        # we can compare to the way the streams were created.
        # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., data_dim, batch_size]
        num_streams = self._stream_management.num_streams_per_tx
        num_tx = self._stream_management.num_tx
        y = split_dim(y, [num_tx, num_streams], 0)

        # Flatten resource grid dimensions
        # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,...
        #  ..., data_dim, batch_size]
        y = flatten_dims(y, 2, 2)

        # Gather data symbols
        # [num_tx, num_streams, num_data_symbols, data_dim, batch_size]
        y = tf.gather(y, self._data_ind, batch_dims=2, axis=2)

        # Put batch_dim first
        # [batch_size, num_tx, num_streams, num_data_symbols, data_dim]
        y = tf.transpose(y, [4, 0, 1, 2, 3])

        # Remove the dummy data_dim if it was added above
        if added_dummy_dim:
            y = tf.squeeze(y, -1)

        return y
class RemoveNulledSubcarriers(Layer):
    # pylint: disable=line-too-long
    r"""RemoveNulledSubcarriers(resource_grid, **kwargs)

    Removes nulled guard and/or DC subcarriers from a resource grid.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    Input
    -----
    : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
        Full resource grid. Only the last dimension is indexed; the
        gather operation itself does not depend on the dtype.

    Output
    ------
    : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Resource grid without nulled subcarriers.
    """
    def __init__(self, resource_grid, **kwargs):
        # Initialize the Keras layer before setting any attribute,
        # consistent with the other layers of this module
        super().__init__(**kwargs)
        # Indices of the subcarriers that are neither guard nor nulled DC
        self._sc_ind = resource_grid.effective_subcarrier_ind

    def call(self, inputs):
        # Keep only the effective subcarriers along the last dimension
        return tf.gather(inputs, self._sc_ind, axis=-1)