# Source code for sionna.ofdm.channel_estimation

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Functions related to OFDM channel estimation"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np
from sionna.channel.tr38901 import models
from sionna.utils import flatten_last_dims, expand_to_rank, matrix_inv
from sionna.ofdm import ResourceGrid, RemoveNulledSubcarriers
from sionna import PI, SPEED_OF_LIGHT
from scipy.special import jv
import itertools
from abc import ABC, abstractmethod
import json
from importlib_resources import files

class BaseChannelEstimator(ABC, Layer):
    # pylint: disable=line-too-long
    r"""BaseChannelEstimator(resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)

    Abstract layer for implementing an OFDM channel estimator.

    Any layer that implements an OFDM channel estimator must implement this
    class and its
    :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`
    abstract method.

    This class extracts the pilots from the received resource grid ``y``, calls
    the :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`
    method to estimate the channel for the pilot-carrying resource elements,
    and then interpolates the channel to compute channel estimates for the
    data-carrying resource elements using the interpolation method specified
    by ``interpolation_type`` or the ``interpolator`` object.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    interpolation_type : One of ["nn", "lin", "lin_time_avg"], string, or `None`
        The interpolation method to be used.
        It is ignored if ``interpolator`` is not `None`.
        Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn"`)
        or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
        averaging across OFDM symbols (`"lin_time_avg"`).
        If `None` and ``interpolator`` is `None`, no interpolation is
        performed and the estimates for the pilot-carrying resource elements
        are returned as computed by
        :meth:`~sionna.ofdm.BaseChannelEstimator.estimate_at_pilot_locations`.
        Defaults to "nn".

    interpolator : BaseChannelInterpolator
        An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
        such as :class:`~sionna.ofdm.LMMSEInterpolator`,
        or `None`. In the latter case, the interpolator specified
        by ``interpolation_type`` is used.
        Otherwise, the ``interpolator`` is used and ``interpolation_type``
        is ignored.
        Defaults to `None`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
        Observed resource grid

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams (nulled subcarriers are removed before
        estimation)

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """
    def __init__(self, resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(resource_grid, ResourceGrid),\
            "You must provide a valid instance of ResourceGrid."
        self._pilot_pattern = resource_grid.pilot_pattern
        self._removed_nulled_scs = RemoveNulledSubcarriers(resource_grid)

        assert interpolation_type in ["nn","lin","lin_time_avg",None], \
            "Unsupported `interpolation_type`"
        self._interpolation_type = interpolation_type

        # Select the interpolator. An explicitly provided interpolator takes
        # precedence over ``interpolation_type``. If it stays `None`, no
        # interpolation is performed in ``call``.
        self._interpol = None
        if interpolator is not None:
            assert isinstance(interpolator, BaseChannelInterpolator), \
                "`interpolator` must implement the BaseChannelInterpolator interface"
            self._interpol = interpolator
        elif self._interpolation_type == "nn":
            self._interpol = NearestNeighborInterpolator(self._pilot_pattern)
        elif self._interpolation_type == "lin":
            self._interpol = LinearInterpolator(self._pilot_pattern)
        elif self._interpolation_type == "lin_time_avg":
            self._interpol = LinearInterpolator(self._pilot_pattern,
                                                time_avg=True)

        # Precompute indices to gather received pilot signals.
        # Sorting the flattened mask in descending order brings the pilot
        # positions to the front; only the first num_pilot_symbols are kept.
        num_pilot_symbols = self._pilot_pattern.num_pilot_symbols
        mask = flatten_last_dims(self._pilot_pattern.mask)
        pilot_ind = tf.argsort(mask, axis=-1, direction="DESCENDING")
        self._pilot_ind = pilot_ind[...,:num_pilot_symbols]

    @abstractmethod
    def estimate_at_pilot_locations(self, y_pilots, no):
        """
        Estimates the channel for the pilot-carrying resource elements.

        This is an abstract method that must be implemented by a concrete
        OFDM channel estimator that implements this class.

        Input
        -----
        y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
            Observed signals for the pilot-carrying resource elements

        no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
            Variance of the AWGN

        Output
        ------
        h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
            Channel estimates for the pilot-carrying resource elements

        err_var : Same shape as ``h_hat``, tf.float
            Channel estimation error variance for the pilot-carrying
            resource elements
        """
        pass

    def call(self, inputs):

        y, no = inputs

        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,..
        # ... fft_size]
        #
        # no can have shapes [], [batch_size], [batch_size, num_rx]
        # or [batch_size, num_rx, num_rx_ant]

        # Removed nulled subcarriers (guards, dc)
        y_eff = self._removed_nulled_scs(y)

        # Flatten the resource grid for pilot extraction
        # New shape: [...,num_ofdm_symbols*num_effective_subcarriers]
        y_eff_flat = flatten_last_dims(y_eff)

        # Gather pilots along the last dimensions
        # Resulting shape: y_eff_flat.shape[:-1] + pilot_ind.shape, i.e.:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_pilot_symbols]
        y_pilots = tf.gather(y_eff_flat, self._pilot_ind, axis=-1)

        # Channel estimates and error variances at the pilot positions,
        # computed by the concrete estimator
        h_hat, err_var = self.estimate_at_pilot_locations(y_pilots, no)

        # Interpolate channel estimates over the resource grid.
        # The decision is based on the selected interpolator so that an
        # explicitly provided ``interpolator`` is used even if
        # ``interpolation_type`` is None.
        if self._interpol is not None:
            h_hat, err_var = self._interpol(h_hat, err_var)
            # Interpolation can produce slightly negative variances
            err_var = tf.maximum(err_var, tf.cast(0, err_var.dtype))

        return h_hat, err_var
class LSChannelEstimator(BaseChannelEstimator, Layer):
    # pylint: disable=line-too-long
    r"""LSChannelEstimator(resource_grid, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)

    Layer implementing least-squares (LS) channel estimation for OFDM MIMO systems.

    The channel is first estimated with LS at the pilot positions. The
    resulting estimates and error variances are then interpolated across the
    entire resource grid with the selected interpolation function.

    To keep the description simple, the algorithm is stated for a vectorized
    observation in which every element to be estimated carries a nonzero
    pilot. The implementation itself operates on a full OFDM resource grid
    with sparse pilot patterns. The assumed model is

    .. math::

        \mathbf{y} = \mathbf{h}\odot\mathbf{p} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^{M}` is the received signal vector,
    :math:`\mathbf{p}\in\mathbb{C}^M` is the vector of pilot symbols,
    :math:`\mathbf{h}\in\mathbb{C}^{M}` is the channel vector to be estimated,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a zero-mean noise vector whose
    elements have variance :math:`N_0`. The operator :math:`\odot` denotes
    element-wise multiplication.

    The channel estimate :math:`\hat{\mathbf{h}}` and the error variances
    :math:`\sigma^2_i`, :math:`i=0,\dots,M-1`, are computed as

    .. math::

        \hat{\mathbf{h}} &= \mathbf{y} \odot
                           \frac{\mathbf{p}^\star}{\left|\mathbf{p}\right|^2}
                         = \mathbf{h} + \tilde{\mathbf{h}}\\
             \sigma^2_i &= \mathbb{E}\left[\tilde{h}_i \tilde{h}_i^\star \right]
                         = \frac{N_0}{\left|p_i\right|^2}.

    The channel estimates and error variances are then interpolated across
    the entire resource grid.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    interpolation_type : One of ["nn", "lin", "lin_time_avg"], string
        The interpolation method to be used.
        It is ignored if ``interpolator`` is not `None`.
        Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn"`)
        or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
        averaging across OFDM symbols (`"lin_time_avg"`).
        Defaults to "nn".

    interpolator : BaseChannelInterpolator
        An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
        such as :class:`~sionna.ofdm.LMMSEInterpolator`,
        or `None`. In the latter case, the interpolator specified
        by ``interpolation_type`` is used.
        Otherwise, the ``interpolator`` is used and ``interpolation_type``
        is ignored.
        Defaults to `None`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
        Observed resource grid

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    h_ls : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_ls``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """

    def estimate_at_pilot_locations(self, y_pilots, no):
        # y_pilots: [batch_size, num_rx, num_rx_ant, num_tx, num_streams,
        #            num_pilot_symbols], tf.complex -- received pilot REs.
        # no: [batch_size, num_rx, num_rx_ant] or a prefix of it, tf.float
        #     -- AWGN variance.

        # Pilot symbols of shape [num_tx, num_streams, num_pilot_symbols];
        # they broadcast against y_pilots without explicit reshaping.
        pilot_symbols = self._pilot_pattern.pilots

        # LS estimate y / p. Zero pilots would produce Inf; the safe division
        # maps those entries to 0. They are not used during interpolation.
        h_ls = tf.math.divide_no_nan(y_pilots, pilot_symbols)

        # Bring N0 (trailing axes) and the pilots (leading axes) to the rank
        # of h_ls so that N0 / |p|^2 broadcasts to the shape of h_ls.
        target_rank = tf.rank(h_ls)
        noise_var = expand_to_rank(no, target_rank, -1)
        pilots_full_rank = expand_to_rank(pilot_symbols, target_rank, 0)

        err_var = tf.math.divide_no_nan(noise_var,
                                        tf.abs(pilots_full_rank)**2)

        return h_ls, err_var
class BaseChannelInterpolator(ABC):
    # pylint: disable=line-too-long
    r"""BaseChannelInterpolator()

    Abstract base class for OFDM channel interpolators.

    Every OFDM channel interpolator must implement this callable interface.
    An OFDM channel estimator (:class:`~sionna.ofdm.BaseChannelEstimator`)
    calls the interpolator to turn channel estimates obtained for the
    pilot-carrying resource elements into channel estimates for the
    data-carrying resource elements.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """

    @abstractmethod
    def __call__(self, h_hat, err_var):
        """Interpolate ``h_hat`` and ``err_var`` over the resource grid."""
class NearestNeighborInterpolator(BaseChannelInterpolator):
    # pylint: disable=line-too-long
    r"""NearestNeighborInterpolator(pilot_pattern)

    Nearest-neighbor channel estimate interpolation on a resource grid.

    This class assigns to each element of an OFDM resource grid one of
    ``num_pilots`` provided channel estimates and error
    variances according to the nearest neighbor method. It is assumed
    that the measurements were taken at the nonzero positions of a
    :class:`~sionna.ofdm.PilotPattern`.

    The figure below shows how four channel estimates are interpolated
    across a resource grid. Grey fields indicate measurement positions
    while the colored regions show which resource elements are assigned
    to the same measurement value.

    .. image:: ../figures/nearest_neighbor_interpolation.png

    Distances are measured with the Manhattan metric over (OFDM symbol,
    subcarrier) indices. Pilots with zero energy are never selected. Ties
    are resolved in favor of the pilot that comes first in the pilot
    sequence.

    Parameters
    ----------
    pilot_pattern : PilotPattern
        An instance of :class:`~sionna.ofdm.PilotPattern`

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, same dtype as the input ``err_var``
        Channel estimation error variances across the entire resource grid
        for all transmitters and streams
    """
    def __init__(self, pilot_pattern):
        super().__init__()

        assert(pilot_pattern.num_pilot_symbols>0),\
            """The pilot pattern cannot be empty"""

        # Reshape mask to shape [-1,num_ofdm_symbols,num_effective_subcarriers]
        mask = np.array(pilot_pattern.mask)
        mask_shape = mask.shape # Store to reconstruct the original shape
        mask = np.reshape(mask, [-1] + list(mask_shape[-2:]))

        # Reshape the pilots to shape [-1, num_pilot_symbols]
        pilots = pilot_pattern.pilots
        pilots = np.reshape(pilots, [-1] + [pilots.shape[-1]])

        max_num_zero_pilots = np.max(np.sum(np.abs(pilots)==0, -1))
        assert max_num_zero_pilots<pilots.shape[-1],\
            """Each pilot sequence must have at least one nonzero entry"""

        # Distance assigned to zero-energy pilots. It exceeds the largest
        # possible Manhattan distance on the grid, (S-1)+(C-1), so these
        # pilots can never be chosen as nearest neighbor.
        max_dist = np.sum(mask_shape[-2:])

        # Compute gather indices for nearest neighbor interpolation
        gather_ind = np.zeros_like(mask, dtype=np.int32)
        for a in range(gather_ind.shape[0]): # For each pilot pattern...
            i_p, j_p = np.where(mask[a]) # ...determine the pilot indices

            # Loop-invariant for this pilot pattern: zero-energy pilots
            zero_pilots = np.abs(pilots[a])==0

            for i in range(mask_shape[-2]): # Iterate over...
                # OFDM-symbol part of the distance, shared by the whole row
                row_dist = np.abs(i-i_p)
                for j in range(mask_shape[-1]): # ... all resource elements

                    # Compute Manhattan distance to all pilot positions
                    d = row_dist + np.abs(j-j_p)

                    # Exclude pilot positions with zero energy
                    d[zero_pilots] = max_dist

                    # Index of the closest pilot (first one on ties)
                    gather_ind[a, i, j] = np.argmin(d)

        # Reshape to the original shape of the mask, i.e.:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers]
        self._gather_ind = tf.reshape(gather_ind, mask_shape)

    def _interpolate(self, inputs):
        # inputs has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_pilots]

        # Transpose inputs to bring batch_dims for gather last. New shape:
        # [num_tx, num_streams_per_tx, num_pilots, k, l, m]
        perm = tf.roll(tf.range(tf.rank(inputs)), -3, 0)
        inputs = tf.transpose(inputs, perm)

        # Interpolate through gather. Shape:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  ..., num_effective_subcarriers, k, l, m]
        outputs = tf.gather(inputs, self._gather_ind, 2, batch_dims=2)

        # Transpose outputs to bring batch_dims first again. New shape:
        # [k, l, m, num_tx, num_streams_per_tx,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        perm = tf.roll(tf.range(tf.rank(outputs)), 3, 0)
        outputs = tf.transpose(outputs, perm)

        return outputs

    def __call__(self, h_hat, err_var):

        h_hat = self._interpolate(h_hat)
        err_var = self._interpolate(err_var)
        return h_hat, err_var
class LinearInterpolator(BaseChannelInterpolator):
    # pylint: disable=line-too-long
    r"""LinearInterpolator(pilot_pattern, time_avg=False)

    Linear channel estimate interpolation on a resource grid.

    This class computes for each element of an OFDM resource grid
    a channel estimate based on ``num_pilots`` provided channel estimates and
    error variances through linear interpolation.
    It is assumed that the measurements were taken at the nonzero positions
    of a :class:`~sionna.ofdm.PilotPattern`.

    The interpolation is done first across sub-carriers and then
    across OFDM symbols.

    Parameters
    ----------
    pilot_pattern : PilotPattern
        An instance of :class:`~sionna.ofdm.PilotPattern`

    time_avg : bool
        If enabled, measurements will be averaged across OFDM symbols
        (i.e., time). This is useful for channels that do not vary
        substantially over the duration of an OFDM frame.
        Defaults to `False`.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex
        Channel estimates for the pilot-carrying resource elements

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.float
        Channel estimation error variances for the pilot-carrying resource elements

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_hat``, tf.float
        Channel estimation error variances across the entire resource grid
        for all transmitters and streams. Its precision follows the complex
        dtype of ``h_hat`` (e.g., `tf.float64` for `tf.complex128`).
    """
    def __init__(self, pilot_pattern, time_avg=False):
        super().__init__()

        assert(pilot_pattern.num_pilot_symbols>0),\
            """The pilot pattern cannot be empty"""

        self._time_avg = time_avg

        # Reshape mask to shape [-1,num_ofdm_symbols,num_effective_subcarriers]
        mask = np.array(pilot_pattern.mask)
        mask_shape = mask.shape # Store to reconstruct the original shape
        mask = np.reshape(mask, [-1] + list(mask_shape[-2:]))

        # Reshape the pilots to shape [-1, num_pilot_symbols]
        pilots = pilot_pattern.pilots
        pilots = np.reshape(pilots, [-1] + [pilots.shape[-1]])

        max_num_zero_pilots = np.max(np.sum(np.abs(pilots)==0, -1))
        assert max_num_zero_pilots<pilots.shape[-1],\
            """Each pilot sequence must have at least one nonzero entry"""

        # Create actual pilot patterns for each stream over the resource grid
        z = np.zeros_like(mask, dtype=pilots.dtype)
        for a in range(z.shape[0]):
            z[a][np.where(mask[a])] = pilots[a]

        # Linear interpolation works as follows:
        # We compute for each resource element (RE)
        # x_0 : The x-value (i.e., sub-carrier index or OFDM symbol) at which
        #       the first channel measurement was taken
        # x_1 : The x-value (i.e., sub-carrier index or OFDM symbol) at which
        #       the second channel measurement was taken
        # y_0 : The first channel estimate
        # y_1 : The second channel estimate
        # x   : The x-value (i.e., sub-carrier index or OFDM symbol)
        #
        # The linearly interpolated value y is then given as:
        # y = (x-x_0) * (y_1-y_0) / (x_1-x_0) + y_0
        #
        # The following code pre-computes various quantities and indices
        # that are needed to compute x_0, x_1, y_0, y_1, x for frequency- and
        # time-domain interpolation.

        ##
        ## Frequency-domain interpolation
        ##
        self._x_freq = tf.cast(expand_to_rank(tf.range(0, mask.shape[-1]),
                                              7,
                                              axis=0),
                               pilots.dtype)

        # Permutation indices to shift batch_dims last during gather
        self._perm_fwd_freq = tf.roll(tf.range(6), -3, 0)

        x_0_freq = np.zeros_like(mask, np.int32)
        x_1_freq = np.zeros_like(mask, np.int32)

        # Set REs of OFDM symbols without any pilot equal to -1 (dummy value)
        x_0_freq[np.sum(np.abs(z), axis=-1)==0] = -1
        x_1_freq[np.sum(np.abs(z), axis=-1)==0] = -1

        y_0_freq_ind = np.copy(x_0_freq) # Indices used to gather estimates
        y_1_freq_ind = np.copy(x_1_freq) # Indices used to gather estimates

        # For each stream
        for a in range(z.shape[0]):

            pilot_count = 0 # Counts the number of non-zero pilots

            # Indices of non-zero pilots within the pilots vector
            pilot_ind = np.where(np.abs(pilots[a]))[0]

            # Go through all OFDM symbols
            for i in range(x_0_freq.shape[1]):

                # Indices of non-zero pilots within the OFDM symbol
                pilot_ind_ofdm = np.where(np.abs(z[a][i]))[0]

                # If OFDM symbol contains only one non-zero pilot
                if len(pilot_ind_ofdm)==1:
                    # Set the indices of the first and second pilot to the same
                    # value for all REs of the OFDM symbol
                    x_0_freq[a][i] = pilot_ind_ofdm[0]
                    x_1_freq[a][i] = pilot_ind_ofdm[0]
                    y_0_freq_ind[a,i] = pilot_ind[pilot_count]
                    y_1_freq_ind[a,i] = pilot_ind[pilot_count]

                # If OFDM symbol contains two or more pilots
                elif len(pilot_ind_ofdm)>=2:
                    x0 = 0
                    x1 = 1

                    # Go through all resource elements of this OFDM symbol
                    for j in range(x_0_freq.shape[2]):
                        x_0_freq[a,i,j] = pilot_ind_ofdm[x0]
                        x_1_freq[a,i,j] = pilot_ind_ofdm[x1]
                        y_0_freq_ind[a,i,j] = pilot_ind[pilot_count + x0]
                        y_1_freq_ind[a,i,j] = pilot_ind[pilot_count + x1]
                        # Advance to the next pair of pilots once the second
                        # pilot is reached (unless it is the last one, in
                        # which case the last pair is used to extrapolate)
                        if j==pilot_ind_ofdm[x1] and x1<len(pilot_ind_ofdm)-1:
                            x0 = x1
                            x1 += 1

                pilot_count += len(pilot_ind_ofdm)

        x_0_freq = np.reshape(x_0_freq, mask_shape)
        x_1_freq = np.reshape(x_1_freq, mask_shape)
        x_0_freq = expand_to_rank(x_0_freq, 7, axis=0)
        x_1_freq = expand_to_rank(x_1_freq, 7, axis=0)
        self._x_0_freq = tf.cast(x_0_freq, pilots.dtype)
        self._x_1_freq = tf.cast(x_1_freq, pilots.dtype)

        # We add +1 here to shift all indices as the input will be padded
        # at the beginning with 0, (i.e., the dummy index -1 will become 0).
        self._y_0_freq_ind = np.reshape(y_0_freq_ind, mask_shape)+1
        self._y_1_freq_ind = np.reshape(y_1_freq_ind, mask_shape)+1

        ##
        ## Time-domain interpolation
        ##
        self._x_time = tf.expand_dims(tf.range(0, mask.shape[-2]), -1)
        self._x_time = tf.cast(expand_to_rank(self._x_time, 7, axis=0),
                               dtype=pilots.dtype)

        # Indices used to gather estimates
        self._perm_fwd_time = tf.roll(tf.range(7), -3, 0)

        y_0_time_ind = np.zeros(z.shape[:2], np.int32) # Gather indices
        y_1_time_ind = np.zeros(z.shape[:2], np.int32) # Gather indices

        # For each stream
        for a in range(z.shape[0]):

            # Indices of OFDM symbols for which channel estimates were computed
            ofdm_ind = np.where(np.sum(np.abs(z[a]), axis=-1))[0]

            # Only one OFDM symbol with pilots
            if len(ofdm_ind)==1:
                y_0_time_ind[a] = ofdm_ind[0]
                y_1_time_ind[a] = ofdm_ind[0]

            # Two or more OFDM symbols with pilots
            elif len(ofdm_ind)>=2:
                x0 = 0
                x1 = 1
                for i in range(z.shape[1]):
                    y_0_time_ind[a,i] = ofdm_ind[x0]
                    y_1_time_ind[a,i] = ofdm_ind[x1]
                    if i==ofdm_ind[x1] and x1<len(ofdm_ind)-1:
                        x0 = x1
                        x1 += 1

        self._y_0_time_ind = np.reshape(y_0_time_ind, mask_shape[:-1])
        self._y_1_time_ind = np.reshape(y_1_time_ind, mask_shape[:-1])

        self._x_0_time = expand_to_rank(tf.expand_dims(self._y_0_time_ind, -1),
                                        7, axis=0)
        self._x_0_time = tf.cast(self._x_0_time, dtype=pilots.dtype)
        self._x_1_time = expand_to_rank(tf.expand_dims(self._y_1_time_ind, -1),
                                        7, axis=0)
        self._x_1_time = tf.cast(self._x_1_time, dtype=pilots.dtype)

        #
        # Other precomputed values
        #
        # Undo permutation of batch_dims for gather
        self._perm_bwd = tf.roll(tf.range(7), 3, 0)

        # Padding for the inputs: one leading zero along the last dimension
        pad = np.zeros([6, 2], np.int32)
        pad[-1, 0] = 1
        self._pad = pad

        # Number of ofdm symbols carrying at least one pilot.
        # Used for time-averaging (optional)
        n = np.sum(np.abs(np.reshape(z, mask_shape)), axis=-1, keepdims=True)
        n = np.sum(n>0, axis=-2, keepdims=True)
        self._num_pilot_ofdm_symbols = expand_to_rank(n, 7, axis=0)

    def _interpolate_1d(self, inputs, x, x0, x1, y0_ind, y1_ind):

        # Gather the right values for y0 and y1
        y0 = tf.gather(inputs, y0_ind, axis=2, batch_dims=2)
        y1 = tf.gather(inputs, y1_ind, axis=2, batch_dims=2)

        # Undo the permutation of the inputs
        y0 = tf.transpose(y0, self._perm_bwd)
        y1 = tf.transpose(y1, self._perm_bwd)

        # Compute linear interpolation. If x1 == x0 (single measurement),
        # the safe division yields a zero slope, i.e., y0 is repeated.
        slope = tf.math.divide_no_nan(y1-y0, tf.cast(x1-x0, dtype=y0.dtype))
        return tf.cast(x-x0, dtype=y0.dtype)*slope + y0

    def _interpolate(self, inputs):
        #
        # Prepare inputs
        #
        # inputs has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_pilots]

        # Pad the inputs with a leading 0.
        # All undefined channel estimates will get this value.
        inputs = tf.pad(inputs, self._pad, constant_values=0)

        # Transpose inputs to bring batch_dims for gather last. New shape:
        # [num_tx, num_streams_per_tx, 1+num_pilots, k, l, m]
        inputs = tf.transpose(inputs, self._perm_fwd_freq)

        #
        # Frequency-domain interpolation
        #
        # h_hat_freq has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ...num_effective_subcarriers]
        h_hat_freq = self._interpolate_1d(inputs,
                                          self._x_freq,
                                          self._x_0_freq,
                                          self._x_1_freq,
                                          self._y_0_freq_ind,
                                          self._y_1_freq_ind)

        #
        # Time-domain interpolation
        #

        # Time-domain averaging (optional)
        if self._time_avg:
            num_ofdm_symbols = h_hat_freq.shape[-2]
            h_hat_freq = tf.reduce_sum(h_hat_freq, axis=-2, keepdims=True)
            h_hat_freq /= tf.cast(self._num_pilot_ofdm_symbols,h_hat_freq.dtype)
            h_hat_freq = tf.repeat(h_hat_freq, [num_ofdm_symbols], axis=-2)

        # Transpose h_hat_freq to bring batch_dims for gather last. New shape:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ...num_effective_subcarriers, k, l, m]
        h_hat_time = tf.transpose(h_hat_freq, self._perm_fwd_time)

        # h_hat_time has shape:
        # [k, l, m, num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #  ...num_effective_subcarriers]
        h_hat_time = self._interpolate_1d(h_hat_time,
                                          self._x_time,
                                          self._x_0_time,
                                          self._x_1_time,
                                          self._y_0_time_ind,
                                          self._y_1_time_ind)

        return h_hat_time

    def __call__(self, h_hat, err_var):

        # The interpolator requires complex-valued inputs. Use the complex
        # dtype of h_hat instead of a hard-coded tf.complex64 so that
        # double-precision estimates keep double-precision error variances.
        # Non-complex h_hat falls back to the previous tf.complex64.
        cdtype = tf.as_dtype(h_hat.dtype)
        if not cdtype.is_complex:
            cdtype = tf.complex64

        h_hat = self._interpolate(h_hat)

        err_var = tf.cast(err_var, cdtype)
        err_var = self._interpolate(err_var)
        err_var = tf.math.real(err_var)

        return h_hat, err_var
class LMMSEInterpolator1D:
    # pylint: disable=line-too-long
    r"""LMMSEInterpolator1D(pilot_mask, cov_mat, last_step)

    This class performs LMMSE interpolation across the inner dimension of the
    input ``h_hat``.

    The two inner dimensions of the input ``h_hat`` form a matrix
    :math:`\hat{\mathbf{H}} \in \mathbb{C}^{N \times M}`.
    LMMSE interpolation is performed across the inner dimension as follows:

    .. math::
        \tilde{\mathbf{h}}_n = \mathbf{A}_n \hat{\mathbf{h}}_n

    where :math:`1 \leq n \leq N` and :math:`\hat{\mathbf{h}}_n` is
    the :math:`n^{\text{th}}` (transposed) row of :math:`\hat{\mathbf{H}}`.
    :math:`\mathbf{A}_n` is the :math:`M \times M` LMMSE interpolation matrix:

    .. math::
        \mathbf{A}_n = \mathbf{R} \mathbf{\Pi}_n \left( \mathbf{\Pi}_n^\intercal \mathbf{R} \mathbf{\Pi}_n + \mathbf{\Sigma}_n \right)^{-1} \mathbf{\Pi}_n^\intercal

    where :math:`\mathbf{R}` is the :math:`M \times M` covariance matrix
    across the inner dimension of the quantity which is estimated,
    :math:`\mathbf{\Pi}_n` is the :math:`M \times K_n` matrix that spreads
    :math:`K_n` values to a vector of size :math:`M` according to the
    ``pilot_mask`` for the :math:`n^{\text{th}}` row, and
    :math:`\mathbf{\Sigma}_n \in \mathbb{R}^{K_n \times K_n}` is the channel
    estimation error covariance, built from ``err_var`` and assumed to be
    diagonal. Its :math:`i^{\text{th}}` diagonal element is the error variance
    of the :math:`i^{\text{th}}` pilot-carrying element of the
    :math:`n^{\text{th}}` row:

    .. math::
        \left[ \mathbf{\Sigma}_n \right]_{i,i} = \left[ \mathbf{\Pi}_n^\intercal \boldsymbol{\sigma}^2_n \right]_i

    with :math:`\boldsymbol{\sigma}^2_n \in \mathbb{R}^{M}` the
    :math:`n^{\text{th}}` (transposed) row of ``err_var``. Error variances of
    elements that are not pilots are ignored.
    Instead of inverting a matrix, :math:`\mathbf{A}_n` is obtained by solving
    a least-squares problem with ``tf.linalg.lstsq`` (``fast=False``) to avoid
    matrix inversion for badly-conditioned covariance matrices.

    The returned channel estimates are

    .. math::
        \begin{bmatrix}
            {\tilde{\mathbf{h}}_1}^\intercal\\
            \vdots\\
            {\tilde{\mathbf{h}}_N}^\intercal
        \end{bmatrix}.

    The returned channel estimation error variances are the diagonal
    coefficients of

    .. math::
        \mathbf{E}_n = \mathbf{R} - \mathbf{A}_n \mathbf{\Xi}_n \mathbf{R}, \quad 1 \leq n \leq N

    clipped to be non-negative, where :math:`\mathbf{\Xi}_n` is the
    :math:`M \times M` diagonal matrix whose diagonal elements are one for the
    pilot-carrying elements of the :math:`n^{\text{th}}` row and zero
    otherwise. Note that interpolation is not performed for rows not carrying
    any pilots: for such rows, the estimates are zero and the error variances
    are the diagonal of :math:`\mathbf{R}`.

    If ``last_step`` is `False`, the estimates and error variances are
    additionally rescaled so that the estimates have the variance expected by
    the next interpolation step:

    .. math::
        \begin{align}
            s_{n,m} &= \frac{2 \left[\mathbf{R}\right]_{m,m}}{\left[\mathbf{R}\right]_{m,m} - \left[\mathbf{E}_n\right]_{m,m} + \left[\hat{\mathbf{E}}_n\right]_{m,m}}\\
            \left[\tilde{\mathbf{h}}_n\right]_m &\leftarrow s_{n,m} \left[\tilde{\mathbf{h}}_n\right]_m\\
            \left[\mathbf{E}_n\right]_{m,m} &\leftarrow s_{n,m}\left(s_{n,m}-1\right)\left[\hat{\mathbf{E}}_n\right]_{m,m} + \left(1-s_{n,m}\right)\left[\mathbf{R}\right]_{m,m} + s_{n,m}\left[\mathbf{E}_n\right]_{m,m}
        \end{align}

    where :math:`\hat{\mathbf{E}}_n = \mathbf{A}_n \left( \mathbf{R} + \mathbf{D}_n \right) \mathbf{A}_n^{\mathrm{H}}`
    and :math:`\mathbf{D}_n` is the diagonal matrix of the input error
    variances of the pilot-carrying elements of the :math:`n^{\text{th}}` row.
    The updated error variances are again clipped to be non-negative.

    **Remark**: The interpolation matrix differs across rows as different
    rows may carry pilots on different elements and/or have different
    estimation error variances.

    Parameters
    ----------
    pilot_mask : [num_tx, num_streams_per_tx, :math:`N`, :math:`M`], int
        Mask indicating the allocation of resource elements.
        0 : Data,
        1 : Pilot,
        2 : Not used.
        Only the elements equal to 1 are treated as pilots.

    cov_mat : [:math:`M`, :math:`M`], tf.complex64 or tf.complex128
        Covariance matrix of the channel across the inner dimension.

    last_step : bool
        Set to `True` if this is the last interpolation step, in which case no
        scaling is applied. If `False`, the output is scaled to ensure its
        variance is as expected by the following interpolation step.

    Input
    -----
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimates.

    err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.float
        Channel estimation error variances. Must have the real dtype
        corresponding to ``cov_mat``. Values at elements that are not pilots
        are ignored.

    Output
    ------
    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, :math:`N`, :math:`M`], tf.complex
        Channel estimates interpolated across the inner dimension.

    err_var : Same shape as ``h_hat``, tf.float
        The channel estimation error variances of the interpolated channel estimates.
    """

    def __init__(self, pilot_mask, cov_mat, last_step):

        self._cdtype = cov_mat.dtype
        assert self._cdtype in (tf.complex64, tf.complex128),\
            "`cov_mat` dtype must be one of tf.complex64 or tf.complex128"
        self._rdtype = self._cdtype.real_dtype
        self._rzero = tf.constant(0.0, self._rdtype)

        # Interpolation is performed along the inner dimension of
        # the resource grid, which may be either the subcarriers
        # or the OFDM symbols dimension.
        # This dimension is referred to as the inner dimension.
        # The other dimension of the resource grid is referred to
        # as the outer dimension.

        # Size of the inner dimension.
        inner_dim_size = tf.shape(pilot_mask)[-1]
        self._inner_dim_size = inner_dim_size
        # Size of the outer dimension.
        outer_dim_size = tf.shape(pilot_mask)[-2]
        self._outer_dim_size = outer_dim_size

        self._cov_mat = cov_mat
        self._last_step = last_step

        # Computation of the interpolation matrix is done by solving the
        # least-squares problem:
        #
        # X = min_Z |AZ - B|_F^2
        #
        # where A = (\Pi^T R \Pi + S) and
        #       B = \Pi^T R
        # where R is the channel covariance matrix, S the error
        # diagonal covariance matrix, and \Pi the matrix that spreads the
        # pilots according to the pilot pattern along the inner axis.
        # As R (a covariance matrix) and A are Hermitian, the interpolation
        # matrix restricted to the pilot columns is X^H = R \Pi A^{-1}.

        # Extracting the locations of pilots from the pilot mask
        num_tx = tf.shape(pilot_mask)[0]
        num_streams_per_tx = tf.shape(pilot_mask)[1]

        # List of indices of pilots in the inner dimension for every
        # transmit antenna, stream, and outer dimension element.
        pilot_indices = []
        # Maximum number of pilots carried by an inner dimension.
        max_num_pil = 0
        # Indices used to add the error variance to the diagonal
        # elements of the covariance matrix restricted
        # to the elements carrying pilots.
        # Entries that do not carry a pilot keep the index [0,0,0,0,0];
        # their error variances are zeroed in __call__ (via `_is_pilot`)
        # so that they do not corrupt that entry.
        add_err_var_indices = np.zeros([num_tx, num_streams_per_tx,
                                        outer_dim_size, inner_dim_size, 5],
                                       int)
        # True for the elements treated as pilots (pilot_mask == 1)
        is_pilot = np.zeros([num_tx, num_streams_per_tx,
                             outer_dim_size, inner_dim_size], bool)
        for tx in range(num_tx):
            pilot_indices.append([])
            for st in range(num_streams_per_tx):
                pilot_indices[-1].append([])
                for oi in range(outer_dim_size):
                    pilot_indices[-1][-1].append([])
                    num_pil = 0 # Number of pilots on this outer dim
                    for ii in range(inner_dim_size):
                        # Check if this RE is carrying a pilot
                        # for this stream
                        if pilot_mask[tx,st,oi,ii] == 0:
                            continue
                        if pilot_mask[tx,st,oi,ii] == 1:
                            pilot_indices[tx][st][oi].append(ii)
                            indices = [tx, st, oi, num_pil, num_pil]
                            add_err_var_indices[tx, st, oi, ii] = indices
                            is_pilot[tx, st, oi, ii] = True
                            num_pil += 1
                    max_num_pil = max(max_num_pil, num_pil)
        # [num_tx, num_streams_per_tx, outer_dim_size, inner_dim_size, 5]
        self._add_err_var_indices = tf.cast(add_err_var_indices, tf.int32)
        # [num_tx, num_streams_per_tx, outer_dim_size, inner_dim_size]
        self._is_pilot = tf.constant(is_pilot)

        # Different subcarriers/symbols may carry a different number of pilots.
        # To handle such cases, we create a tensor of square matrices of
        # size the maximum number of pilots carried by an inner dimension
        # and zero-padding is used to handle axes with less pilots than the
        # maximum value. The obtained structure is:
        #
        # |B 0|
        # |0 0|
        #
        pil_cov_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                                max_num_pil, max_num_pil], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            num_pil = len(pil_ind)
            tmp = np.take(cov_mat, pil_ind, axis=0)
            pil_cov_mat_ = np.take(tmp, pil_ind, axis=1)
            pil_cov_mat[tx,st,oi,:num_pil,:num_pil] = pil_cov_mat_
        # [num_tx, num_streams_per_tx, outer_dim_size, max_num_pil, max_num_pil]
        self._pil_cov_mat = tf.constant(pil_cov_mat, self._cdtype)

        # Pre-compute the covariance matrix with only the rows corresponding
        # to pilots, i.e., \Pi^T R (zero-padded to max_num_pil rows).
        b_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                          max_num_pil, inner_dim_size], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            num_pil = len(pil_ind)
            b_mat_ = np.take(cov_mat, pil_ind, axis=0)
            b_mat[tx,st,oi,:num_pil,:] = b_mat_
        self._b_mat = tf.constant(b_mat, self._cdtype)

        # Indices used to fill with zeros the columns of the interpolation
        # matrix not corresponding to pilots.
        # The result is a matrix of size inner_dim_size x inner_dim_size
        # where columns not corresponding to pilots are set to zero.
        pil_loc = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                            inner_dim_size, max_num_pil, 5], dtype=int)
        for tx,st,oi,p,ii in itertools.product(range(num_tx),
                                               range(num_streams_per_tx),
                                               range(outer_dim_size),
                                               range(max_num_pil),
                                               range(inner_dim_size)):
            if p >= len(pilot_indices[tx][st][oi]):
                # An extra dummy row/column is added to push there the
                # coefficients of the zero padding. It is removed later.
                pil_loc[tx, st, oi, ii, p] = [tx, st, oi, inner_dim_size,
                                              inner_dim_size]
            else:
                pil_loc[tx, st, oi, ii, p] = [tx, st, oi, ii,
                                              pilot_indices[tx][st][oi][p]]
        self._pil_loc = tf.cast(pil_loc, tf.int32)

        # Covariance matrix for each stream with only the rows corresponding
        # to pilot-carrying REs not set to 0, i.e., \Xi R.
        # This is required to compute the estimation error variances.
        err_var_mat = np.zeros([num_tx, num_streams_per_tx, outer_dim_size,
                                inner_dim_size, inner_dim_size], complex)
        for tx,st,oi in itertools.product(range(num_tx),
                                          range(num_streams_per_tx),
                                          range(outer_dim_size)):
            pil_ind = pilot_indices[tx][st][oi]
            mask = np.zeros([inner_dim_size], complex)
            mask[pil_ind] = 1.0
            mask = np.expand_dims(mask, axis=1)
            err_var_mat[tx,st,oi] = cov_mat*mask
        self._err_var_mat = tf.constant(err_var_mat, self._cdtype)

    def __call__(self, h_hat, err_var):

        # h_hat : [batch_size, num_rx, num_rx_ant, num_tx,
        #          num_streams_per_tx, outer_dim_size, inner_dim_size]
        # err_var : [batch_size, num_rx, num_rx_ant, num_tx,
        #          num_streams_per_tx, outer_dim_size, inner_dim_size]

        batch_size = tf.shape(h_hat)[0]
        num_rx = tf.shape(h_hat)[1]
        num_rx_ant = tf.shape(h_hat)[2]
        num_tx = tf.shape(h_hat)[3]
        num_tx_stream = tf.shape(h_hat)[4]
        outer_dim_size = self._outer_dim_size
        inner_dim_size = self._inner_dim_size

        # Only the error variances of pilot-carrying elements are used.
        # Non-pilot entries all map to index [0,0,0,0,0] in
        # `_add_err_var_indices`, so they are zeroed here to prevent them
        # from being accumulated into that entry of A. tf.where is used
        # instead of a multiplication to avoid inf*0 = NaN.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        err_var_pil = tf.where(self._is_pilot, err_var, self._rzero)

        #####################################
        # Compute the interpolation matrix
        #####################################

        # Computation of the interpolation matrix is done by solving the
        # least-squares problem:
        #
        # X = min_Z |AZ - B|_F^2
        #
        # where A = (\Pi^T R \Pi + S) and
        #       B = \Pi^T R
        # where R is the channel covariance matrix, S the error
        # diagonal covariance matrix, and \Pi the matrix that spreads the
        # pilots according to the pilot pattern along the inner axis.

        #
        # Computing A
        #

        # Covariance matrices restricted to pilot locations
        # [num_tx, num_streams_per_tx, outer_dim_size, max_num_pil, max_num_pil]
        pil_cov_mat = self._pil_cov_mat

        # Adding batch, receive, and receive antennas dimensions to the
        # covariance matrices restricted to pilot locations
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, max_num_pil, max_num_pil]
        pil_cov_mat = expand_to_rank(pil_cov_mat, 8, 0)
        pil_cov_mat = tf.tile(pil_cov_mat,
                              [batch_size, num_rx, num_rx_ant, 1, 1, 1, 1, 1])

        # Adding the error variances to the diagonal of the covariance
        # matrices restricted to pilots
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, max_num_pil, max_num_pil]
        pil_cov_mat_ = tf.transpose(pil_cov_mat, [3, 4, 5, 6, 7, 0, 1, 2])
        err_var_ = tf.complex(err_var_pil, self._rzero)
        err_var_ = tf.transpose(err_var_, [3, 4, 5, 6, 0, 1, 2])
        a_mat = tf.tensor_scatter_nd_add(pil_cov_mat_,
                                         self._add_err_var_indices, err_var_)
        a_mat = tf.transpose(a_mat, [5, 6, 7, 0, 1, 2, 3, 4])

        #
        # Computing B
        #

        # B is pre-computed as it only depends on the channel covariance and
        # pilot pattern.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, max_num_pil, inner_dim_size]
        b_mat = self._b_mat
        b_mat = expand_to_rank(b_mat, 8, 0)
        b_mat = tf.tile(b_mat, [batch_size, num_rx, num_rx_ant, 1, 1, 1, 1, 1])

        #
        # Computing the interpolation matrix
        #

        # Using lstsq to compute the columns of the interpolation matrix
        # corresponding to pilots.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size, max_num_pil]
        ext_mat = tf.linalg.lstsq(a_mat, b_mat, fast=False)
        ext_mat = tf.transpose(ext_mat, [0,1,2,3,4,5,7,6], conjugate=True)

        # Filling with zeros the columns not corresponding to pilots.
        # An extra dummy row and column are added to scatter there the
        # coefficients corresponding to the zero padding.
        # They are then removed.
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size, inner_dim_size]
        ext_mat = tf.transpose(ext_mat, [3, 4, 5, 6, 7, 0, 1, 2])
        ext_mat = tf.scatter_nd(self._pil_loc, ext_mat,
                                [num_tx, num_tx_stream, outer_dim_size,
                                 inner_dim_size+1, inner_dim_size+1,
                                 batch_size, num_rx, num_rx_ant])
        ext_mat = tf.transpose(ext_mat, [5, 6, 7, 0, 1, 2, 3, 4])
        ext_mat = ext_mat[...,:inner_dim_size,:inner_dim_size]

        ################################################
        # Apply interpolation over the inner dimension
        ################################################

        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        h_hat = tf.expand_dims(h_hat, axis=-1)
        h_hat = tf.matmul(ext_mat, h_hat)
        h_hat = tf.squeeze(h_hat, axis=-1)

        ##############################
        # Compute the error variances
        ##############################

        # Keep track of the previous (pilot-only) estimation error variances
        # for later use
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        err_var_old = err_var_pil

        # diag(R - A \Xi R)
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
        #  outer_dim_size, inner_dim_size]
        cov_mat = expand_to_rank(self._cov_mat, 8, 0)
        err_var = tf.linalg.diag_part(cov_mat)
        err_var_mat = expand_to_rank(self._err_var_mat, 8, 0)
        err_var_mat = tf.transpose(err_var_mat, [0, 1, 2, 3, 4, 5, 7, 6])
        err_var = err_var - tf.reduce_sum(ext_mat*err_var_mat, axis=-1)
        err_var = tf.math.real(err_var)
        err_var = tf.maximum(err_var, self._rzero)

        #####################################
        # If this is *not* the last
        # interpolation step, scale the
        # output `h_hat` to ensure
        # it has the variance expected by the
        # next interpolation step.
        #
        # The error variance `err_var` is
        # updated accordingly.
        #####################################
        if not self._last_step:
            #
            # Variance of h_hat
            #
            # Conjugate transpose of LMMSE matrix
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size, inner_dim_size]
            ext_mat_h = tf.transpose(ext_mat, [0, 1, 2, 3, 4, 5, 7, 6],
                                     conjugate=True)
            # First part of the estimate covariance: diag(A R A^H)
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size, inner_dim_size]
            h_hat_var_1 = tf.matmul(cov_mat, ext_mat_h)
            h_hat_var_1 = tf.transpose(h_hat_var_1, [0, 1, 2, 3, 4, 5, 7, 6])
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            h_hat_var_1 = tf.reduce_sum(ext_mat*h_hat_var_1, axis=-1)
            # Second part of the estimate covariance: diag(A D A^H)
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            err_var_old_c = tf.complex(err_var_old, self._rzero)
            err_var_old_c = tf.expand_dims(err_var_old_c, axis=-1)
            h_hat_var_2 = err_var_old_c*ext_mat_h
            h_hat_var_2 = tf.transpose(h_hat_var_2, [0, 1, 2, 3, 4, 5, 7, 6])
            h_hat_var_2 = tf.reduce_sum(ext_mat*h_hat_var_2, axis=-1)
            # Variance of h_hat
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            h_hat_var = h_hat_var_1 + h_hat_var_2
            # Scaling factor
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            err_var_c = tf.complex(err_var, self._rzero)
            h_var = tf.linalg.diag_part(cov_mat)
            s = tf.math.divide_no_nan(2.*h_var, h_hat_var + h_var - err_var_c)
            # Apply scaling to estimate
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            h_hat = s*h_hat
            # Updated variance
            # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,
            #  outer_dim_size, inner_dim_size]
            err_var = s*(s-1.)*h_hat_var + (1.-s)*h_var + s*err_var_c
            err_var = tf.math.real(err_var)
            err_var = tf.maximum(err_var, self._rzero)

        return h_hat, err_var

class SpatialChannelFilter:
    # pylint: disable=line-too-long
    r"""SpatialChannelFilter(cov_mat, last_step)

    Implements linear minimum mean square error (LMMSE) smoothing across the
    receive antennas.

    We consider the following model:

    .. math::

        \mathbf{y} = \mathbf{h} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^{M}` is the vector of input channel
    estimates across the :math:`M` receive antennas,
    :math:`\mathbf{h}\in\mathbb{C}^{M}` is the channel vector to be estimated
    with covariance matrix
    :math:`\mathbb{E}\left[ \mathbf{h} \mathbf{h}^{\mathsf{H}} \right] = \mathbf{R}`,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a zero-mean error vector with
    diagonal covariance matrix :math:`\mathbf{\Sigma}` whose diagonal elements
    are the input error variances ``err_var``.

    The channel estimate :math:`\hat{\mathbf{h}}` is computed as

    .. math::

        \hat{\mathbf{h}} = \mathbf{A} \mathbf{y}

    where

    .. math::

        \mathbf{A} = \mathbf{R} \left( \mathbf{R} + \mathbf{\Sigma} \right)^{-1}.

    The estimation error is:

    .. math::

        \tilde{\mathbf{h}} = \mathbf{h} - \hat{\mathbf{h}}

    The error variances

    .. math::

        \sigma^2_i = \mathbb{E}\left[\tilde{h}_i \tilde{h}_i^\star \right], 0 \leq i \leq M-1

    are the diagonal elements of

    .. math::

        \mathbb{E}\left[\mathbf{\tilde{h}} \mathbf{\tilde{h}}^{\mathsf{H}} \right] = \mathbf{R} - \mathbf{A}\mathbf{R}

    clipped to be non-negative.

    If ``last_step`` is `False`, the estimates and error variances are
    additionally rescaled so that the estimates have the variance expected by
    the next interpolation step, using
    :math:`s_i = 2 R_{i,i} / \left( R_{i,i} - \sigma^2_i + \hat{\sigma}^2_i \right)`
    with :math:`\hat{\sigma}^2_i` the diagonal elements of
    :math:`\mathbf{A} \left( \mathbf{R} + \mathbf{\Sigma} \right) \mathbf{A}^{\mathsf{H}}`.
    The estimates are multiplied by :math:`s_i` and the error variances are
    updated to
    :math:`s_i \left( s_i - 1 \right) \hat{\sigma}^2_i + \left( 1 - s_i \right) R_{i,i} + s_i \sigma^2_i`.

    Note
    ----
    If you want to use this function in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.

    Parameters
    ----------
    cov_mat : [num_rx_ant, num_rx_ant], tf.complex
        Spatial covariance matrix of the channel

    last_step : bool
        Set to `True` if this is the last interpolation step, in which case no
        scaling is applied. If `False`, the output is scaled to ensure its
        variance is as expected by the following interpolation step.

    Input
    -----
    h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.complex
        Channel estimates.

    err_var : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.float
        Channel estimation error variances.

    Output
    ------
    h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.complex
        Channel estimates smoothed across the spatial dimension

    err_var : [batch_size, num_rx, num_tx, num_streams_per_tx, num_ofdm_symbols, num_subcarriers, num_rx_ant], tf.float
        The channel estimation error variances of the smoothed channel estimates.
    """

    def __init__(self, cov_mat, last_step):

        self._rzero = tf.zeros((), cov_mat.dtype.real_dtype)
        self._cov_mat = cov_mat
        self._last_step = last_step

        # Indices for adding a tensor of vectors [..., num_rx_ant] to the
        # diagonal of a tensor of matrices [..., num_rx_ant, num_rx_ant]
        num_rx_ant = cov_mat.shape[0]
        add_diag_indices = [[rxa, rxa] for rxa in range(num_rx_ant)]
        self._add_diag_indices = tf.cast(add_diag_indices, tf.int32)

    def __call__(self, h_hat, err_var):

        # h_hat : [batch_size, num_rx, num_tx, num_streams_per_tx,
        #          num_ofdm_symbols, num_subcarriers, num_rx_ant]
        # err_var : [batch_size, num_rx, num_tx, num_streams_per_tx,
        #          num_ofdm_symbols, num_subcarriers, num_rx_ant]

        # [..., num_rx_ant]
        err_var = tf.complex(err_var, self._rzero)
        # Keep track of the previous estimation error variances for later use
        err_var_old = err_var

        # [num_rx_ant, num_rx_ant]
        cov_mat = self._cov_mat
        # Plain (non-conjugate) transpose: used to compute diag(A R) as
        # sum_j A_ij R_ji
        cov_mat_t = tf.transpose(cov_mat)
        num_rx_ant = tf.shape(cov_mat)[0]

        ##########################################
        # Compute LMMSE matrix
        ##########################################

        # [..., num_rx_ant, num_rx_ant]
        cov_mat = expand_to_rank(cov_mat, tf.rank(err_var)+1, axis=0)

        # Adding the error variances to the diagonal
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = tf.broadcast_to(cov_mat,
                            tf.concat([tf.shape(err_var), [num_rx_ant]], axis=0))
        # [num_rx_ant, ...]
        err_var_ = tf.transpose(err_var, [6, 0, 1, 2, 3, 4, 5])
        # [num_rx_ant, num_rx_ant, ...]
        lmmse_mat = tf.transpose(lmmse_mat, [6, 7, 0, 1, 2, 3, 4, 5])
        lmmse_mat = tf.tensor_scatter_nd_add(lmmse_mat, self._add_diag_indices,
                                             err_var_)
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = tf.transpose(lmmse_mat, [2, 3, 4, 5, 6, 7, 0, 1])

        # R (R + Sigma)^{-1}
        # [..., num_rx_ant, num_rx_ant]
        lmmse_mat = matrix_inv(lmmse_mat)
        lmmse_mat = tf.matmul(cov_mat, lmmse_mat)

        ##########################################
        # Apply smoothing
        ##########################################

        # [..., num_rx_ant, 1]
        h_hat = tf.expand_dims(h_hat, axis=-1)
        # [..., num_rx_ant]
        h_hat = tf.squeeze(tf.matmul(lmmse_mat, h_hat), axis=-1)

        ##########################################
        # Compute the estimation error variances
        ##########################################

        # [..., num_rx_ant, num_rx_ant]
        cov_mat_t = expand_to_rank(cov_mat_t, tf.rank(lmmse_mat), axis=0)
        # diag(A R)
        # [..., num_rx_ant]
        err_var = tf.reduce_sum(cov_mat_t*lmmse_mat, axis=-1)
        # diag(R - A R)
        # [..., num_rx_ant]
        err_var = tf.linalg.diag_part(cov_mat) - err_var
        err_var = tf.math.real(err_var)
        err_var = tf.maximum(err_var, self._rzero)

        ##########################################
        # If this is *not* the last
        # interpolation step, scale the
        # output `h_hat` to ensure
        # it has the variance expected by the
        # next interpolation step.
        #
        # The error variance `err_var` is
        # updated accordingly.
        ##########################################
        if not self._last_step:
            #
            # Variance of h_hat
            #
            # Conjugate transpose of the LMMSE matrix
            # [..., num_rx_ant, num_rx_ant]
            lmmse_mat_h = tf.transpose(lmmse_mat, [0, 1, 2, 3, 4, 5, 7, 6],
                                       conjugate=True)
            # First part of the estimate covariance: diag(A R A^H)
            # [..., num_rx_ant, num_rx_ant]
            h_hat_var_1 = tf.matmul(cov_mat, lmmse_mat_h)
            h_hat_var_1 = tf.transpose(h_hat_var_1, [0, 1, 2, 3, 4, 5, 7, 6])
            # [..., num_rx_ant]
            h_hat_var_1 = tf.reduce_sum(lmmse_mat*h_hat_var_1, axis=-1)
            # Second part of the estimate covariance: diag(A Sigma A^H)
            # [..., num_rx_ant, 1]
            err_var_old = tf.expand_dims(err_var_old, axis=-1)
            # [..., num_rx_ant, num_rx_ant]
            h_hat_var_2 = err_var_old*lmmse_mat_h
            # [..., num_rx_ant, num_rx_ant]
            h_hat_var_2 = tf.transpose(h_hat_var_2, [0, 1, 2, 3, 4, 5, 7, 6])
            # [..., num_rx_ant]
            h_hat_var_2 = tf.reduce_sum(lmmse_mat*h_hat_var_2, axis=-1)
            # Variance of h_hat
            # [..., num_rx_ant]
            h_hat_var = h_hat_var_1 + h_hat_var_2
            # Scaling factor
            # [..., num_rx_ant]
            err_var_c = tf.complex(err_var, self._rzero)
            h_var = tf.linalg.diag_part(cov_mat)
            s = tf.math.divide_no_nan(2.*h_var, h_hat_var + h_var - err_var_c)
            # Apply scaling to estimate
            # [..., num_rx_ant]
            h_hat = s*h_hat
            # Updated variance
            # [..., num_rx_ant]
            err_var = s*(s-1.)*h_hat_var + (1.-s)*h_var + s*err_var_c
            err_var = tf.math.real(err_var)
            err_var = tf.maximum(err_var, self._rzero)

        return h_hat, err_var
[docs]class LMMSEInterpolator(BaseChannelInterpolator): # pylint: disable=line-too-long r"""LMMSEInterpolator(pilot_pattern, cov_mat_time, cov_mat_freq, cov_mat_space=None, order='t-f') LMMSE interpolation on a resource grid with optional spatial smoothing. This class computes for each element of an OFDM resource grid a channel estimate and error variance through linear minimum mean square error (LMMSE) interpolation/smoothing. It is assumed that the measurements were taken at the nonzero positions of a :class:`~sionna.ofdm.PilotPattern`. Depending on the value of ``order``, the interpolation is carried out accross time (t), i.e., OFDM symbols, frequency (f), i.e., subcarriers, and optionally space (s), i.e., receive antennas, in any desired order. For simplicity, we describe the underlying algorithm assuming that interpolation across the sub-carriers is performed first, followed by interpolation across OFDM symbols, and finally by spatial smoothing across receive antennas. The algorithm is similar if interpolation and/or smoothing are performed in a different order. For clarity, antenna indices are omitted when describing frequency and time interpolation, as the same process is applied to all the antennas. The input ``h_hat`` is first reshaped to a resource grid :math:`\hat{\mathbf{H}} \in \mathbb{C}^{N \times M}`, by scattering the channel estimates at pilot locations according to the ``pilot_pattern``. :math:`N` denotes the number of OFDM symbols and :math:`M` the number of sub-carriers. The first pass consists in interpolating across the sub-carriers: .. math:: \hat{\mathbf{h}}_n^{(1)} = \mathbf{A}_n \hat{\mathbf{h}}_n where :math:`1 \leq n \leq N` is the OFDM symbol index and :math:`\hat{\mathbf{h}}_n` is the :math:`n^{\text{th}}` (transposed) row of :math:`\hat{\mathbf{H}}`. :math:`\mathbf{A}_n` is the :math:`M \times M` matrix such that: .. math:: \mathbf{A}_n = \bar{\mathbf{A}}_n \mathbf{\Pi}_n^\intercal where .. 
math:: \bar{\mathbf{A}}_n = \underset{\mathbf{Z} \in \mathbb{C}^{M \times K_n}}{\text{argmin}} \left\lVert \mathbf{Z}\left( \mathbf{\Pi}_n^\intercal \mathbf{R^{(f)}} \mathbf{\Pi}_n + \mathbf{\Sigma}_n \right) - \mathbf{R^{(f)}} \mathbf{\Pi}_n \right\rVert_{\text{F}}^2 and :math:`\mathbf{R^{(f)}}` is the :math:`M \times M` channel frequency covariance matrix, :math:`\mathbf{\Pi}_n` the :math:`M \times K_n` matrix that spreads :math:`K_n` values to a vector of size :math:`M` according to the ``pilot_pattern`` for the :math:`n^{\text{th}}` OFDM symbol, and :math:`\mathbf{\Sigma}_n \in \mathbb{R}^{K_n \times K_n}` is the channel estimation error covariance built from ``err_var`` and assumed to be diagonal. Computation of :math:`\bar{\mathbf{A}}_n` is done using an algorithm based on complete orthogonal decomposition. This is done to avoid matrix inversion for badly conditioned covariance matrices. The channel estimation error variances after the first interpolation pass are computed as .. math:: \mathbf{\Sigma}^{(1)}_n = \text{diag} \left( \mathbf{R^{(f)}} - \mathbf{A}_n \mathbf{\Xi}_n \mathbf{R^{(f)}} \right) where :math:`\mathbf{\Xi}_n` is the diagonal matrix of size :math:`M \times M` that zeros the columns corresponding to sub-carriers not carrying any pilots. Note that interpolation is not performed for OFDM symbols which do not carry pilots. **Remark**: The interpolation matrix differs across OFDM symbols as different OFDM symbols may carry pilots on different sub-carriers and/or have different estimation error variances. Scaling of the estimates is then performed to ensure that their variances match the ones expected by the next interpolation step, and the error variances are updated accordingly: .. 
math:: \begin{align} \left[\hat{\mathbf{h}}_n^{(2)}\right]_m &= s_{n,m} \left[\hat{\mathbf{h}}_n^{(1)}\right]_m\\ \left[\mathbf{\Sigma}^{(2)}_n\right]_{m,m} &= s_{n,m}\left( s_{n,m}-1 \right) \left[\hat{\mathbf{\Sigma}}^{(1)}_n\right]_{m,m} + \left( 1 - s_{n,m} \right) \left[\mathbf{R^{(f)}}\right]_{m,m} + s_{n,m} \left[\mathbf{\Sigma}^{(1)}_n\right]_{m,m} \end{align} where the scaling factor :math:`s_{n,m}` is such that: .. math:: \mathbb{E} \left\{ \left\lvert s_{n,m} \left[\hat{\mathbf{h}}_n^{(1)}\right]_m \right\rvert^2 \right\} = \left[\mathbf{R^{(f)}}\right]_{m,m} + \mathbb{E} \left\{ \left\lvert s_{n,m} \left[\hat{\mathbf{h}}^{(1)}_n\right]_m - \left[\mathbf{h}_n\right]_m \right\rvert^2 \right\} which leads to: .. math:: \begin{align} s_{n,m} &= \frac{2 \left[\mathbf{R^{(f)}}\right]_{m,m}}{\left[\mathbf{R^{(f)}}\right]_{m,m} - \left[\mathbf{\Sigma}^{(1)}_n\right]_{m,m} + \left[\hat{\mathbf{\Sigma}}^{(1)}_n\right]_{m,m}}\\ \hat{\mathbf{\Sigma}}^{(1)}_n &= \mathbf{A}_n \mathbf{R^{(f)}} \mathbf{A}_n^{\mathrm{H}}. \end{align} The second pass consists in interpolating across the OFDM symbols: .. math:: \hat{\mathbf{h}}_m^{(3)} = \mathbf{B}_m \tilde{\mathbf{h}}^{(2)}_m where :math:`1 \leq m \leq M` is the sub-carrier index and :math:`\tilde{\mathbf{h}}^{(2)}_m` is the :math:`m^{\text{th}}` column of .. math:: \hat{\mathbf{H}}^{(2)} = \begin{bmatrix} {\hat{\mathbf{h}}_1^{(2)}}^\intercal\\ \vdots\\ {\hat{\mathbf{h}}_N^{(2)}}^\intercal \end{bmatrix} and :math:`\mathbf{B}_m` is the :math:`N \times N` interpolation LMMSE matrix: .. math:: \mathbf{B}_m = \bar{\mathbf{B}}_m \tilde{\mathbf{\Pi}}_m^\intercal where .. 
math:: \bar{\mathbf{B}}_m = \underset{\mathbf{Z} \in \mathbb{C}^{N \times L_m}}{\text{argmin}} \left\lVert \mathbf{Z} \left( \tilde{\mathbf{\Pi}}_m^\intercal \mathbf{R^{(t)}}\tilde{\mathbf{\Pi}}_m + \tilde{\mathbf{\Sigma}}^{(2)}_m \right) - \mathbf{R^{(t)}}\tilde{\mathbf{\Pi}}_m \right\rVert_{\text{F}}^2 where :math:`\mathbf{R^{(t)}}` is the :math:`N \times N` channel time covariance matrix, :math:`\tilde{\mathbf{\Pi}}_m` the :math:`N \times L_m` matrix that spreads :math:`L_m` values to a vector of size :math:`N` according to the ``pilot_pattern`` for the :math:`m^{\text{th}}` sub-carrier, and :math:`\tilde{\mathbf{\Sigma}}^{(2)}_m \in \mathbb{R}^{L_m \times L_m}` is the diagonal matrix of channel estimation error variances built by gathering the error variances from (:math:`\mathbf{\Sigma}^{(2)}_1,\dots,\mathbf{\Sigma}^{(2)}_N`) corresponding to resource elements carried by the :math:`m^{\text{th}}` sub-carrier. Computation of :math:`\bar{\mathbf{B}}_m` is done using an algorithm based on complete orthogonal decomposition. This is done to avoid matrix inversion for badly conditioned covariance matrices. The resulting channel estimate for the resource grid is .. math:: \hat{\mathbf{H}}^{(3)} = \left[ \hat{\mathbf{h}}_1^{(3)} \dots \hat{\mathbf{h}}_M^{(3)} \right] The resulting channel estimation error variances are the diagonal coefficients of the matrices .. math:: \mathbf{\Sigma}^{(3)}_m = \mathbf{R^{(t)}} - \mathbf{B}_m \tilde{\mathbf{\Xi}}_m \mathbf{R^{(t)}}, 1 \leq m \leq M where :math:`\tilde{\mathbf{\Xi}}_m` is the diagonal matrix of size :math:`N \times N` that zeros the columns corresponding to OFDM symbols not carrying any pilots. **Remark**: The interpolation matrix differs across sub-carriers as different sub-carriers may have different estimation error variances computed by the first pass. 
However, all sub-carriers carry at least one channel estimate as a result of the first pass, ensuring that a channel estimate is computed for all the resource elements after the second pass. **Remark:** LMMSE interpolation requires knowledge of the time and frequency covariance matrices of the channel. The notebook `OFDM MIMO Channel Estimation and Detection <../examples/OFDM_MIMO_Detection.ipynb>`_ shows how to estimate such matrices for arbitrary channel models. Moreover, the functions :func:`~sionna.ofdm.tdl_time_cov_mat` and :func:`~sionna.ofdm.tdl_freq_cov_mat` compute the expected time and frequency covariance matrices, respectively, for the :class:`~sionna.channel.tr38901.TDL` channel models. Scaling of the estimates is then performed to ensure that their variances match the ones expected by the next smoothing step, and the error variances are updated accordingly: .. math:: \begin{align} \left[\hat{\mathbf{h}}_m^{(4)}\right]_n &= \gamma_{m,n} \left[\hat{\mathbf{h}}_m^{(3)}\right]_n\\ \left[\mathbf{\Sigma}^{(4)}_m\right]_{n,n} &= \gamma_{m,n}\left( \gamma_{m,n}-1 \right) \left[\hat{\mathbf{\Sigma}}^{(3)}_m\right]_{n,n} + \left( 1 - \gamma_{m,n} \right) \left[\mathbf{R^{(t)}}\right]_{n,n} + \gamma_{m,n} \left[\mathbf{\Sigma}^{(3)}_n\right]_{m,m} \end{align} where: .. math:: \begin{align} \gamma_{m,n} &= \frac{2 \left[\mathbf{R^{(t)}}\right]_{n,n}}{\left[\mathbf{R^{(t)}}\right]_{n,n} - \left[\mathbf{\Sigma}^{(3)}_m\right]_{n,n} + \left[\hat{\mathbf{\Sigma}}^{(3)}_n\right]_{m,m}}\\ \hat{\mathbf{\Sigma}}^{(3)}_m &= \mathbf{B}_m \mathbf{R^{(t)}} \mathbf{B}_m^{\mathrm{H}} \end{align} Finally, a spatial smoothing step is applied to every resource element carrying a channel estimate. For clarity, we drop the resource element indexing :math:`(n,m)`. We denote by :math:`L` the number of receive antennas, and by :math:`\mathbf{R^{(s)}}\in\mathbb{C}^{L \times L}` the spatial covariance matrix. LMMSE spatial smoothing consists in the following computations: .. 
math:: \hat{\mathbf{h}}^{(5)} = \mathbf{C} \hat{\mathbf{h}}^{(4)} where .. math:: \mathbf{C} = \mathbf{R^{(s)}} \left( \mathbf{R^{(s)}} + \mathbf{\Sigma}^{(4)} \right)^{-1}. The estimation error variances are the digonal coefficients of .. math:: \mathbf{\Sigma}^{(5)} = \mathbf{R^{(s)}} - \mathbf{C}\mathbf{R^{(s)}} The smoothed channel estimate :math:`\hat{\mathbf{h}}^{(5)}` and corresponding error variances :math:`\text{diag}\left( \mathbf{\Sigma}^{(5)} \right)` are returned for every resource element :math:`(m,n)`. **Remark:** No scaling is performed after the last interpolation or smoothing step. **Remark:** All passes assume that the estimation error covariance matrix (:math:`\mathbf{\Sigma}`, :math:`\tilde{\mathbf{\Sigma}}^{(2)}`, or :math:`\tilde{\mathbf{\Sigma}}^{(4)}`) is diagonal, which may not be accurate. When this assumption does not hold, this interpolator is only an approximation of LMMSE interpolation. **Remark:** The order in which frequency interpolation, temporal interpolation, and, optionally, spatial smoothing are applied, is controlled using the ``order`` parameter. Note ---- This layer does not support graph mode with XLA. Parameters ---------- pilot_pattern : PilotPattern An instance of :class:`~sionna.ofdm.PilotPattern` cov_mat_time : [num_ofdm_symbols, num_ofdm_symbols], tf.complex Time covariance matrix of the channel cov_mat_freq : [fft_size, fft_size], tf.complex Frequency covariance matrix of the channel cov_time_space : [num_rx_ant, num_rx_ant], tf.complex Spatial covariance matrix of the channel. Defaults to `None`. Only required if spatial smoothing is requested (see ``order``). order : str Order in which to perform interpolation and optional smoothing. For example, ``"t-f-s"`` means that interpolation across the OFDM symbols is performed first (``"t"``: time), followed by interpolation across the sub-carriers (``"f"``: frequency), and finally smoothing across the receive antennas (``"s"``: space). 
Similarly, ``"f-t"`` means interpolation across the sub-carriers followed by interpolation across the OFDM symbols and no spatial smoothing. The spatial covariance matrix (``cov_time_space``) is only required when spatial smoothing is requested. Time and frequency interpolation are not optional to ensure that a channel estimate is computed for all resource elements. Input ----- h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex Channel estimates for the pilot-carrying resource elements err_var : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_pilot_symbols], tf.complex Channel estimation error variances for the pilot-carrying resource elements Output ------ h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex Channel estimates accross the entire resource grid for all transmitters and streams err_var : Same shape as ``h_hat``, tf.float Channel estimation error variances accross the entire resource grid for all transmitters and streams """ def __init__(self, pilot_pattern, cov_mat_time, cov_mat_freq, cov_mat_space=None, order='t-f'): # Check the specified order order = order.split('-') assert 2 <= len(order) <= 3, "Invalid order for interpolation." 
spatial_smoothing = False freq_smoothing = False time_smoothing = False for o in order: assert o in ('s', 'f', 't'), f"Uknown dimension {o}" if o == 's': assert not spatial_smoothing,\ "Spatial smoothing can be specified at most once" spatial_smoothing = True elif o == 't': assert not time_smoothing,\ "Temporal interpolation can be specified once only" time_smoothing = True elif o == 'f': assert not freq_smoothing,\ "Frequency interpolation can be specified once only" freq_smoothing = True if spatial_smoothing: assert cov_mat_space is not None,\ "A spatial covariance matrix is required for spatial smoothing" assert freq_smoothing, "Frequency interpolation is required" assert time_smoothing, "Time interpolation is required" self._order = order self._num_ofdm_symbols = pilot_pattern.num_ofdm_symbols self._num_effective_subcarriers =pilot_pattern.num_effective_subcarriers # Build pilot masks for every stream pilot_mask = self._build_pilot_mask(pilot_pattern) # Build indices for mapping channel estimates and # error variances that are given as input to a # resource grid num_pilots = pilot_pattern.pilots.shape[2] inputs_to_rg_indices = self._build_inputs2rg_indices(pilot_mask, num_pilots) self._inputs_to_rg_indices = tf.cast(inputs_to_rg_indices, tf.int32) # 1D interpolator according to requested order # Interpolation is always performed along the inner dimension. interpolators = [] # Masks for masking error variances that were not updated err_var_masks = [] for i, o in enumerate(order): # Is it the last one? 
last_step = i == len(order)-1 # Frequency if o == "f": interpolator = LMMSEInterpolator1D(pilot_mask, cov_mat_freq, last_step=last_step) pilot_mask = self._update_pilot_mask_interp(pilot_mask) err_var_mask = tf.cast(pilot_mask == 1, cov_mat_freq.dtype.real_dtype) # Time elif o == 't': pilot_mask = tf.transpose(pilot_mask, [0, 1, 3, 2]) interpolator = LMMSEInterpolator1D(pilot_mask, cov_mat_time, last_step=last_step) pilot_mask = self._update_pilot_mask_interp(pilot_mask) pilot_mask = tf.transpose(pilot_mask, [0, 1, 3, 2]) err_var_mask = tf.cast(pilot_mask == 1, cov_mat_freq.dtype.real_dtype) # Space elif o == 's': interpolator = SpatialChannelFilter(cov_mat_space, last_step=last_step) err_var_mask = tf.cast(pilot_mask == 1, cov_mat_freq.dtype.real_dtype) interpolators.append(interpolator) err_var_masks.append(err_var_mask) self._interpolators = interpolators self._err_var_masks = err_var_masks def _build_pilot_mask(self, pilot_pattern): """ Build for every transmitter and stream a pilot mask indicating which REs are allocated to pilots, data, or not used. 
# 0 -> Data # 1 -> Pilot # 2 -> Not used """ mask = pilot_pattern.mask pilots = pilot_pattern.pilots num_tx = mask.shape[0] num_streams_per_tx = mask.shape[1] num_ofdm_symbols = mask.shape[2] num_effective_subcarriers = mask.shape[3] pilot_mask = np.zeros([num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], int) for tx,st in itertools.product( range(num_tx), range(num_streams_per_tx)): pil_index = 0 for sb,sc in itertools.product( range(num_ofdm_symbols), range(num_effective_subcarriers)): if mask[tx,st,sb,sc] == 1: if np.abs(pilots[tx,st,pil_index]) > 0.0: pilot_mask[tx,st,sb,sc] = 1 else: pilot_mask[tx,st,sb,sc] = 2 pil_index += 1 return pilot_mask def _build_inputs2rg_indices(self, pilot_mask, num_pilots): """ Builds indices for mapping channel estimates and error variances that are given as input to a resource grid """ num_tx = pilot_mask.shape[0] num_streams_per_tx = pilot_mask.shape[1] num_ofdm_symbols = pilot_mask.shape[2] num_effective_subcarriers = pilot_mask.shape[3] inputs_to_rg_indices = np.zeros([num_tx, num_streams_per_tx, num_pilots, 4], int) for tx,st in itertools.product( range(num_tx), range(num_streams_per_tx)): pil_index = 0 # Pilot index for this stream for sb,sc in itertools.product( range(num_ofdm_symbols), range(num_effective_subcarriers)): if pilot_mask[tx,st,sb,sc] == 0: continue if pilot_mask[tx,st,sb,sc] == 1: inputs_to_rg_indices[tx, st, pil_index] = [tx, st, sb, sc] pil_index += 1 return inputs_to_rg_indices def _update_pilot_mask_interp(self, pilot_mask): """ Update the pilot mask to label the resource elements for which the channel was interpolated. 
""" interpolated = np.any(pilot_mask == 1, axis=-1, keepdims=True) pilot_mask = np.where(interpolated, 1, pilot_mask) return pilot_mask def __call__(self, h_hat, err_var): # h_hat : [batch_size, num_rx, num_rx_ant, num_tx, # num_streams_per_tx, num_pilots] # err_var : [batch_size, num_rx, num_rx_ant, num_tx, # num_streams_per_tx, num_pilots] batch_size = tf.shape(h_hat)[0] num_rx = tf.shape(h_hat)[1] num_rx_ant = tf.shape(h_hat)[2] num_tx = tf.shape(h_hat)[3] num_tx_stream = tf.shape(h_hat)[4] num_ofdm_symbols = self._num_ofdm_symbols num_effective_subcarriers = self._num_effective_subcarriers # For some estimator, err_var might not have the same shape # as h_hat err_var = tf.broadcast_to(err_var, tf.shape(h_hat)) # Mapping the channel estimates and error variances to a resource grid # all : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, # num_ofdm_symbols, num_effective_subcarriers] h_hat = tf.transpose(h_hat, [3, 4, 5, 0, 1, 2]) err_var = tf.transpose(err_var, [3, 4, 5, 0, 1, 2]) h_hat = tf.scatter_nd(self._inputs_to_rg_indices, h_hat, [num_tx, num_tx_stream, num_ofdm_symbols, num_effective_subcarriers, batch_size, num_rx, num_rx_ant]) err_var = tf.scatter_nd(self._inputs_to_rg_indices, err_var, [num_tx, num_tx_stream, num_ofdm_symbols, num_effective_subcarriers, batch_size, num_rx, num_rx_ant]) h_hat = tf.transpose(h_hat, [4, 5, 6, 0, 1, 2, 3]) err_var = tf.transpose(err_var, [4, 5, 6, 0, 1, 2, 3]) # Interpolation # Performed according to the requested order. Transpose are used as # 1D interpolation is performed along the inner axis. 
items = zip(self._order, self._interpolators, self._err_var_masks) for o,interp,err_var_mask in items: # Frequency if o == 'f': # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, # num_ofdm_symbols, num_effective_subcarriers] h_hat, err_var = interp(h_hat, err_var) err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0) err_var = err_var*err_var_mask # Time elif o == 't': # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, # num_effective_subcarriers, num_ofdm_symbols] h_hat = tf.transpose(h_hat, [0, 1, 2, 3, 4, 6, 5]) err_var = tf.transpose(err_var, [0, 1, 2, 3, 4, 6, 5]) h_hat, err_var = interp(h_hat, err_var) # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, # num_ofdm_symbols, num_effective_subcarriers] h_hat = tf.transpose(h_hat, [0, 1, 2, 3, 4, 6, 5]) err_var = tf.transpose(err_var, [0, 1, 2, 3, 4, 6, 5]) err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0) err_var = err_var*err_var_mask # Space elif o == 's': # [batch_size, num_rx, num_tx, num_streams_per_tx, # num_ofdm_symbols, num_effective_subcarriers, num_rx_ant] h_hat = tf.transpose(h_hat, [0, 1, 3, 4, 5, 6, 2]) err_var = tf.transpose(err_var, [0, 1, 3, 4, 5, 6, 2]) h_hat, err_var = interp(h_hat, err_var) # [batch_size, num_rx, num_tx, num_streams_per_tx, # num_ofdm_symbols, num_effective_subcarriers, num_rx_ant] h_hat = tf.transpose(h_hat, [0, 1, 6, 2, 3, 4, 5]) err_var = tf.transpose(err_var, [0, 1, 6, 2, 3, 4, 5]) err_var_mask = expand_to_rank(err_var_mask, tf.rank(err_var), 0) err_var = err_var*err_var_mask return h_hat, err_var
####################################################### # Utilities #######################################################
def tdl_freq_cov_mat(model, subcarrier_spacing, fft_size, delay_spread,
                     dtype=tf.complex64):
    # pylint: disable=line-too-long
    r"""
    Computes the frequency covariance matrix of a
    :class:`~sionna.channel.tr38901.TDL` channel model.

    The channel frequency covariance matrix :math:`\mathbf{R}^{(f)}` of a TDL channel model is

    .. math::
        \mathbf{R}^{(f)}_{u,v} = \sum_{\ell=1}^L P_\ell e^{-j 2 \pi \tau_\ell \Delta_f (u-v)}, 1 \leq u,v \leq M

    where :math:`M` is the FFT size, :math:`L` is the number of paths for the selected TDL model,
    :math:`P_\ell` and :math:`\tau_\ell` are the average power and delay for the :math:`\ell^{\text{th}}` path, respectively,
    and :math:`\Delta_f` is the sub-carrier spacing.

    The power delay profile is normalized to unit total power. For LoS
    models, the specular and non-specular components of the first path,
    which share the same zero delay, are merged into a single path.

    Input
    ------
    model : str
        TDL model for which to return the covariance matrix.
        Should be one of "A", "B", "C", "D", or "E".

    subcarrier_spacing : float
        Sub-carrier spacing [Hz]

    fft_size : int
        FFT size

    delay_spread : float
        Delay spread [s]

    dtype : tf.DType
        Datatype to use for the output.
        Should be one of `tf.complex64` or `tf.complex128`.
        Defaults to `tf.complex64`.

    Output
    ------
    cov_mat : [fft_size, fft_size], tf.complex
        Channel frequency covariance matrix
    """

    assert dtype in (tf.complex64, tf.complex128),\
        "The `dtype` should be a complex datatype"

    #
    # Load the power delay profile
    #

    assert model in ('A', 'B', 'C', 'D', 'E'), "Invalid TDL model"
    parameters_fname = f"TDL-{model}.json"
    source = files(models).joinpath(parameters_fname)
    # JSON is UTF-8 by specification; don't depend on the platform locale
    with open(source, encoding="utf-8") as parameter_file:
        params = json.load(parameter_file)

    # LoS scenario ?
    los = bool(params['los'])

    # Retrieve power and delays
    delays = np.array(params['delays'])*delay_spread
    mean_powers = np.power(10.0, np.array(params['powers'])/10.0)

    if los:
        # Add the power of the specular and non-specular component of
        # the first path
        mean_powers[0] = mean_powers[0] + mean_powers[1]
        mean_powers = np.concatenate([mean_powers[:1], mean_powers[2:]],
                                     axis=0)
        # The first two paths have 0 delays as they correspond to the
        # specular and reflected components of the first path.
        delays = delays[1:]

    # Normalize the PDP
    mean_powers = mean_powers / np.sum(mean_powers)

    #
    # Build frequency covariance matrix
    #
    # R = E diag(P) E^H with E[n, l] = exp(-j 2 pi Delta_f n tau_l).
    # This never materializes one [fft_size, fft_size] matrix per path,
    # reducing peak memory from O(L * fft_size^2) to O(fft_size^2).
    n = np.arange(fft_size)
    phases = -2.*np.pi*subcarrier_spacing*n[:, np.newaxis]*delays[np.newaxis, :]
    steering = np.exp(1j*phases) # [fft_size, num_paths]
    cov_mat = np.matmul(steering*mean_powers, np.conj(steering).T)

    return tf.cast(cov_mat, dtype)
def tdl_time_cov_mat(model, speed, carrier_frequency, ofdm_symbol_duration,
                     num_ofdm_symbols, los_angle_of_arrival=PI/4.,
                     dtype=tf.complex64):
    # pylint: disable=line-too-long
    r"""
    Computes the time covariance matrix of a
    :class:`~sionna.channel.tr38901.TDL` channel model.

    For non-line-of-sight (NLoS) models, the channel time covariance matrix
    :math:`\mathbf{R^{(t)}}` of a TDL channel model is

    .. math::
        \mathbf{R^{(t)}}_{u,v} = J_0 \left( \nu \Delta_t \left( u-v \right) \right)

    where :math:`J_0` is the zero-order Bessel function of the first kind,
    :math:`\Delta_t` the duration of an OFDM symbol, and :math:`\nu` the Doppler
    spread defined by

    .. math::
        \nu = 2 \pi \frac{v}{c} f_c

    where :math:`v` is the movement speed, :math:`c` the speed of light, and
    :math:`f_c` the carrier frequency.

    For line-of-sight (LoS) channel models, the channel time covariance matrix
    is

    .. math::
        \mathbf{R^{(t)}}_{u,v} = P_{\text{NLoS}} J_0 \left( \nu \Delta_t \left( u-v \right) \right) + P_{\text{LoS}}e^{j \nu \Delta_t \left( u-v \right) \cos{\alpha_{\text{LoS}}}}

    where :math:`\alpha_{\text{LoS}}` is the angle-of-arrival for the LoS path,
    :math:`P_{\text{NLoS}}` the total power of NLoS paths, and
    :math:`P_{\text{LoS}}` the power of the LoS path. The power delay profile
    is normalized to unit power, i.e.,
    :math:`P_{\text{NLoS}} + P_{\text{LoS}} = 1`.

    Input
    ------
    model : str
        TDL model for which to return the covariance matrix.
        Should be one of "A", "B", "C", "D", or "E".

    speed : float
        Speed [m/s]

    carrier_frequency : float
        Carrier frequency [Hz]

    ofdm_symbol_duration : float
        Duration of an OFDM symbol [s]

    num_ofdm_symbols : int
        Number of OFDM symbols

    los_angle_of_arrival : float
        Angle-of-arrival for LoS path [radian]. Only used with LoS models.
        Defaults to :math:`\pi/4`.

    dtype : tf.DType
        Datatype to use for the output.
        Should be one of `tf.complex64` or `tf.complex128`.
        Defaults to `tf.complex64`.

    Output
    ------
    cov_mat : [num_ofdm_symbols, num_ofdm_symbols], tf.complex
        Channel time covariance matrix
    """

    # Doppler spread nu = 2 pi v f_c / c
    doppler_spread = 2.*PI*speed/SPEED_OF_LIGHT*carrier_frequency

    # Read the power delay profile of the requested TDL model
    assert model in ('A', 'B', 'C', 'D', 'E'), "Invalid TDL model"
    source = files(models).joinpath(f"TDL-{model}.json")
    # pylint: disable=unspecified-encoding
    with open(source) as parameter_file:
        params = json.load(parameter_file)
    is_los = bool(params['los'])

    # Linear path powers, normalized to unit total power
    powers = np.power(10.0, np.array(params['powers'])/10.0)
    powers = powers / np.sum(powers)

    # Doppler phase for every pair of OFDM symbols (u, v), i.e. for (u - v)
    symbol_idx = np.arange(num_ofdm_symbols)
    phase = doppler_spread*ofdm_symbol_duration*np.subtract.outer(symbol_idx,
                                                                  symbol_idx)

    if not is_los:
        # All the power is carried by NLoS (Rayleigh) paths
        cov_mat = jv(0.0, phase)*np.sum(powers)
    else:
        # The first profile entry is the LoS component, the rest are NLoS
        nlos_term = jv(0.0, phase)*np.sum(powers[1:])
        los_term = np.exp(1j*phase*np.cos(los_angle_of_arrival))*powers[0]
        cov_mat = nlos_term+los_term

    return tf.cast(cov_mat, dtype)