# Source code for sionna.mimo.detection

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Classes and functions related to MIMO channel detection"""

import warnings
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.utils import expand_to_rank, matrix_sqrt_inv, flatten_last_dims, flatten_dims, split_dim, insert_dims, hard_decisions
from sionna.mapping import Constellation, SymbolLogits2LLRs, LLRs2SymbolLogits, PAM2QAM, Demapper, SymbolDemapper, SymbolInds2Bits, DemapperWithPrior, SymbolLogits2Moments
from sionna.mimo.utils import complex2real_channel, whiten_channel, List2LLR, List2LLRSimple, complex2real_matrix, complex2real_vector, real2complex_vector
from sionna.mimo.equalization import lmmse_equalizer, zf_equalizer, mf_equalizer

class LinearDetector(Layer):
    # pylint: disable=line-too-long
    r"""LinearDetector(equalizer, output, demapping_method, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)

    Convenience layer that chains an equalizer, such as
    :func:`~sionna.mimo.lmmse_equalizer`, with a :class:`~sionna.mapping.Demapper`
    (or a :class:`~sionna.mapping.SymbolDemapper` when symbol outputs are requested).

    Parameters
    ----------
    equalizer : str, one of ["lmmse", "zf", "mf"], or an equalizer function
        The equalizer to be used. Either one of the existing equalizers
        :func:`~sionna.mimo.lmmse_equalizer`, :func:`~sionna.mimo.zf_equalizer`, or
        :func:`~sionna.mimo.mf_equalizer` can be used, or a custom equalizer
        callable provided that has the same input/output specification.

    output : One of ["bit", "symbol"], str
        The type of output, either LLRs on bits or logits on constellation symbols.

    demapping_method : One of ["app", "maxlog"], str
        The demapping method used.

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of ``y``. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices

    Output
    ------
    One of:

    : [..., num_streams, num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`

    : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`
       Hard-decisions correspond to the symbol indices.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you might need to set ``sionna.Config.xla_compat=true``. This depends on the
    chosen equalizer function. See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 equalizer,
                 output,
                 demapping_method,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._output = output
        self._hard_out = hard_out

        # A string selects one of the built-in equalizers; any other object
        # is taken to be a user-supplied equalizer callable.
        if not isinstance(equalizer, str):
            self._equalizer = equalizer
        else:
            assert equalizer in ["lmmse", "zf", "mf"], "Unknown equalizer."
            self._equalizer = (lmmse_equalizer if equalizer == "lmmse"
                               else zf_equalizer if equalizer == "zf"
                               else mf_equalizer)

        assert output in ("bit", "symbol"), "Unknown output"
        assert demapping_method in ("app","maxlog"), "Unknown demapping method"

        self._constellation = Constellation.create_or_check_constellation(
                                                        constellation_type,
                                                        num_bits_per_symbol,
                                                        constellation,
                                                        dtype=dtype)

        # Bit outputs go through an LLR demapper, symbol outputs through a
        # symbol (logit) demapper.
        if output == "bit":
            self._demapper = Demapper(demapping_method,
                                      constellation=self._constellation,
                                      hard_out=hard_out,
                                      dtype=dtype)
        else:
            self._demapper = SymbolDemapper(constellation=self._constellation,
                                            hard_out=hard_out,
                                            dtype=dtype)

    def call(self, inputs):
        # Equalize, then demap the equalized symbols using the effective
        # noise variance reported by the equalizer.
        x_hat, no_eff = self._equalizer(*inputs)
        out = self._demapper([x_hat, no_eff])

        if self._output != 'bit':
            return out

        # The last axis of the bit demapper output is split into
        # [num_streams, num_bits_per_symbol].
        num_streams = tf.shape(inputs[1])[-1]
        bits_per_symbol = self._constellation.num_bits_per_symbol
        return split_dim(out, [num_streams, bits_per_symbol], tf.rank(out)-1)
class MaximumLikelihoodDetector(Layer):
    # pylint: disable=line-too-long
    r"""
    MaximumLikelihoodDetector(output, demapping_method, num_streams, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, with_prior=False, dtype=tf.complex64, **kwargs)

    MIMO maximum-likelihood (ML) detector.
    If the ``with_prior`` flag is set, prior knowledge on the bits or constellation points is assumed to be available.

    This layer implements MIMO maximum-likelihood (ML) detection assuming the
    following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^K` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.
    If the ``with_prior`` flag is set, it is assumed that prior information of the transmitted signal :math:`\mathbf{x}` is available,
    provided either as LLRs on the bits mapped onto :math:`\mathbf{x}` or as logits on the individual
    constellation points forming :math:`\mathbf{x}`.

    Prior to demapping, the received signal is whitened:

    .. math::
        \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\
        &=  \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\
        &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}

    The layer can compute ML detection of symbols or bits with either
    soft- or hard-decisions. Note that decisions are computed symbol-/bit-wise
    and not jointly for the entire vector :math:`\textbf{x}` (or the underlying vector
    of bits).

    All :math:`|\mathcal{C}|^K` candidate vectors :math:`\mathbf{x}` are enumerated
    when the layer is constructed, so memory and computational complexity grow
    exponentially with the number of streams.

    **ML detection of bits:**

    Soft-decisions on bits are called log-likelihood ratios (LLR).
    With the "app" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is then computed according to

    .. math::
        \begin{align}
            LLR(k,i)&= \ln\left(\frac{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}\right)\\
                    &=\ln\left(\frac{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }\right)
        \end{align}

    where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the
    sets of vectors of constellation points for which the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is equal to 1 and 0, respectively.
    :math:`\Pr\left( \mathbf{x} \right)` is the prior distribution of the vector of
    constellation points :math:`\mathbf{x}`. Assuming that the constellation points and
    bit levels are independent, it is computed from the prior of the bits according to

    .. math::
        \Pr\left( \mathbf{x} \right) = \prod_{k=1}^K \prod_{i=1}^{I} \sigma \left( LLR_p(k,i) \right)

    where :math:`LLR_p(k,i)` is the prior knowledge of the :math:`i\text{th}` bit of the
    :math:`k\text{th}` user given as an LLR and which is set to :math:`0` if no prior knowledge is assumed to be available,
    and :math:`\sigma\left(\cdot\right)` is the sigmoid function.
    The definition of the LLR has been chosen such that it is equivalent with that of logit. This is
    different from many textbooks in communications, where the LLR is
    defined as :math:`LLR(k,i) = \ln\left(\frac{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}\right)`.

    With the "maxlog" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is approximated like

    .. math::
        \begin{align}
            LLR(k,i) \approx&\ln\left(\frac{
                \max_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \exp\left(
                    -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                    \right) \Pr\left( \mathbf{x} \right) \right)
                }{
                \max_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \exp\left(
                    -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                    \right) \Pr\left( \mathbf{x} \right) \right)
                }\right)\\
                = &\min_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left(\Pr\left( \mathbf{x} \right) \right) \right) -
                    \min_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left( \Pr\left( \mathbf{x} \right) \right) \right).
        \end{align}

    **ML detection of symbols:**

    Soft-decisions on symbols are called logits (i.e., unnormalized log-probability).

    With the "app" demapping method, the logit for the
    constellation point :math:`c \in \mathcal{C}` of the :math:`k\text{th}` user is computed according to

    .. math::
        \begin{align}
            \text{logit}(k,c) &= \ln\left(\sum_{\mathbf{x} : x_k = c} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right)\Pr\left( \mathbf{x} \right)\right).
        \end{align}

    With the "maxlog" demapping method, the logit for the constellation point :math:`c \in \mathcal{C}`
    of the :math:`k\text{th}` user is approximated like

    .. math::
        \text{logit}(k,c) \approx \max_{\mathbf{x} : x_k = c} \left(
                -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 + \ln \left( \Pr\left( \mathbf{x} \right) \right)
                \right).

    When hard decisions are requested, this layer returns for the :math:`k` th stream

    .. math::
        \hat{c}_k = \underset{c \in \mathcal{C}}{\text{argmax}} \left( \sum_{\mathbf{x} : x_k = c} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right)\Pr\left( \mathbf{x} \right) \right)

    where :math:`\mathcal{C}` is the set of constellation points. With the "maxlog"
    demapping method, the sum is replaced by a maximum, i.e., the argmax is taken over
    the "maxlog" logits.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either LLRs on bits or logits on constellation symbols.

    demapping_method : One of ["app", "maxlog"], str
        The demapping method used.

    num_streams : tf.int
        Number of transmitted streams

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    with_prior : bool
        If `True`, it is assumed that prior knowledge on the bits or constellation points is available.
        This prior information is given as LLRs (for bits) or log-probabilities (for constellation points) as an
        additional input to the layer.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of ``y``. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h, s) or (y, h, prior, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices.

    prior : [...,num_streams,num_bits_per_symbol] or [...,num_streams,num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", then LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", then logits of the transmitted constellation points are expected.
        Only required if the ``with_prior`` flag is set.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    Output
    ------
    One of:

    : [..., num_streams, num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
       Hard-decisions correspond to the symbol indices.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 output,
                 demapping_method,
                 num_streams,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 with_prior=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert dtype in [tf.complex64, tf.complex128],\
            "dtype must be tf.complex64 or tf.complex128"

        assert output in ("bit", "symbol"), "Unknown output"

        assert demapping_method in ("app","maxlog"), "Unknown demapping method"

        self._output = output
        self._demapping_method = demapping_method
        self._hard_out = hard_out
        self._with_prior = with_prior

        # Reduction over candidate vectors: log-sum-exp for "app",
        # max for "maxlog"
        if self._demapping_method == "app":
            self._reduce = tf.reduce_logsumexp
        else:
            self._reduce = tf.reduce_max

        # Create constellation object
        self._constellation = Constellation.create_or_check_constellation(
                                                        constellation_type,
                                                        num_bits_per_symbol,
                                                        constellation,
                                                        dtype=dtype)

        # Precomputed lookup tables (see `_build_vecs`):
        # vecs : [num_vecs, num_streams]
        #   All possible transmitted vectors of constellation points.
        # vecs_ind : [num_vecs, num_streams, 2]
        #   (stream index, constellation point index) pairs for every
        #   candidate vector; used with tf.gather_nd to look up priors.
        # c : [num_vecs/num_points, num_streams, num_points]
        #   `c[:,k,s]` gives the indices into the first dimension of `vecs`
        #   of the vectors for which the `k`th stream transmitted the `s`th
        #   constellation point.
        vecs, vecs_ind, c = self._build_vecs(num_streams)
        self._vecs = tf.cast(vecs, dtype)
        self._vecs_ind = tf.cast(vecs_ind, tf.int32)
        self._c = tf.cast(c, tf.int32)

        if output == 'bit':
            num_bits_per_symbol = self._constellation.num_bits_per_symbol
            self._logits2llr = SymbolLogits2LLRs(
                                    method=demapping_method,
                                    num_bits_per_symbol=num_bits_per_symbol,
                                    hard_out=hard_out,
                                    dtype=dtype.real_dtype,
                                    **kwargs)
            # Converts bit-LLR priors into symbol-logit priors
            self._llrs2logits = LLRs2SymbolLogits(
                                    num_bits_per_symbol=num_bits_per_symbol,
                                    hard_out=False,
                                    dtype=dtype.real_dtype,
                                    **kwargs)

    @property
    def constellation(self):
        return self._constellation

    def _build_vecs(self, num_streams):
        """
        Build the list of all possible transmitted vectors of constellation
        points, together with the index tables used for prior lookup and for
        gathering per-symbol exponents.

        Input
        ------
        num_streams : int
            Number of transmitted streams

        Output
        -------
        vecs : [num_vecs, num_streams], complex
            List of all possible transmitted vectors, with
            ``num_vecs = num_points**num_streams``.

        vecs_ind : [num_vecs, num_streams, 2], int
            For every vector and stream, the pair
            (stream index, constellation point index).

        c : [num_vecs/num_points, num_streams, num_points], int
            `c[:,k,s]` gives the indices in the first dimension of `vecs`
            of the vectors for which the `k`th stream transmitted the
            `s`th constellation point.
        """
        points = self._constellation.points
        num_points = points.shape[0]

        # Recursively generate all possible transmitted vectors of symbols
        # and the corresponding constellation point indices.
        # `n` is the remaining number of streams to process.
        def _build_vecs_(n):
            if n == 1:
                # For a single stream, the possibly transmitted vectors are
                # the constellation points themselves. No recursion needed.
                vecs = np.expand_dims(points, axis=1)
                vecs_ind = np.expand_dims(np.arange(num_points), axis=1)
            else:
                # For `n >= 2` streams, the list of possibly transmitted
                # vectors is
                #
                # [c_1 v , c_2 v, ..., c_N v]
                #
                # where `[c_1, ..., c_N]` is the constellation of size N, and
                # `v` is the list of possible vectors for `n-1` streams.
                # This list therefore has length `N x len(v)`.
                v, vi = _build_vecs_(n-1)
                vecs = []
                vecs_ind = []
                for i,p in enumerate(points):
                    vecs.append(np.concatenate([np.full([v.shape[0], 1], p),
                                                v], axis=1))
                    vecs_ind.append(np.concatenate([np.full([v.shape[0], 1], i),
                                                    vi], axis=1))
                vecs = np.concatenate(vecs, axis=0)
                vecs_ind = np.concatenate(vecs_ind, axis=0)
            return vecs, vecs_ind

        # [num_vecs, K]
        vecs, sym_ind = _build_vecs_(num_streams)

        # For every constellation point index `s` and every stream `j`,
        # collect the rows of `vecs` for which stream `j` transmitted point
        # `s`. Comparing integer point indices (instead of the complex point
        # values) is exact and stays well-defined even if the constellation
        # contains coinciding points.
        # [num_vecs/num_points, num_streams, num_points]
        c = np.stack([
                np.stack([np.where(sym_ind[:, j] == s)[0]
                          for j in range(num_streams)], axis=-1)
                for s in range(num_points)], axis=-1)

        # Pair every symbol index with its stream index for tf.gather_nd
        # [num_vecs, K, 2]
        tx_ind = np.tile(np.expand_dims(np.arange(num_streams), axis=0),
                         [sym_ind.shape[0], 1])
        vecs_ind = np.stack([tx_ind, sym_ind], axis=-1)

        return vecs, vecs_ind, c

    def call(self, inputs):
        if self._with_prior:
            y, h, prior, s = inputs

            # If operating on bits, compute the prior on symbols from the
            # prior on bits
            if self._output == 'bit':
                # [..., K, num_points]
                prior = self._llrs2logits(prior)
        else:
            y, h, s = inputs

        # Compute the inverse square root of the noise covariance matrix
        s_inv = matrix_sqrt_inv(s)

        # Whiten the observation
        y = tf.expand_dims(y, -1)
        y = tf.squeeze(tf.matmul(s_inv, y), axis=-1)

        # Compute channel after whitening
        h = tf.matmul(s_inv, h)

        # Add extra dims for broadcasting with the dimension corresponding
        # to all possible transmitted vectors
        # Shape: [..., 1, M, K]
        h = tf.expand_dims(h, axis=-3)

        # Add extra dims for broadcasting with the dimension corresponding
        # to all possible transmitted vectors
        # Shape: [..., 1, M]
        y = tf.expand_dims(y, axis=-2)

        # Reshape list of all possible vectors from
        # [num_vecs, K]
        # to
        # [1,...,1, num_vecs, K, 1]
        vecs = self._vecs
        vecs = tf.expand_dims(vecs, axis=-1)
        vecs = expand_to_rank(vecs, tf.rank(h), 0)

        # Compute exponents -||y - Hx||^2 for every candidate vector
        # [..., num_vecs]
        diff = y - tf.squeeze(h@vecs, axis=-1)
        exponents = -tf.reduce_sum(tf.square(tf.abs(diff)), axis=-1)

        # Add prior
        if self._with_prior:
            prior = expand_to_rank(prior, tf.rank(exponents), axis=0)
            prior_rank = tf.rank(prior)
            # Move the [K, num_points] axes to the front for gather_nd
            transpose_ind = tf.concat([[prior_rank-2, prior_rank-1],
                                        tf.range(prior_rank-2)], axis=0)
            prior = tf.transpose(prior, transpose_ind)
            # Look up, for every candidate vector and stream, the prior of the
            # transmitted point: [num_vecs, K, ...]
            prior = tf.gather_nd(prior, self._vecs_ind)
            # [..., num_vecs, K]
            transpose_ind = tf.concat([tf.range(2, prior_rank), [0, 1]],
                                      axis=0)
            prior = tf.transpose(prior, transpose_ind)
            # Log-prior of a vector is the sum over its streams
            # [..., num_vecs]
            prior = tf.reduce_sum(prior, axis=-1)
            exponents = exponents + prior

        # Gather exponents for all symbols
        # [..., num_vecs/num_points, K, num_points]
        exp = tf.gather(exponents, self._c, axis=-1)

        # Compute logits on constellation points
        # [..., K, num_points]
        logits = self._reduce(exp, axis=-3)

        if self._output == 'bit':
            # Compute LLRs or hard decisions
            return self._logits2llr(logits)
        else:
            if self._hard_out:
                return tf.argmax(logits, axis=-1, output_type=tf.int32)
            else:
                return logits
class MaximumLikelihoodDetectorWithPrior(MaximumLikelihoodDetector):
    # pylint: disable=line-too-long
    r"""
    MaximumLikelihoodDetectorWithPrior(output, demapping_method, num_streams, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)

    MIMO maximum-likelihood (ML) detector, assuming prior
    knowledge on the bits or constellation points is available.

    This class is deprecated as the functionality has been integrated
    into :class:`~sionna.mimo.MaximumLikelihoodDetector`. It is equivalent to
    :class:`~sionna.mimo.MaximumLikelihoodDetector` with ``with_prior=True``,
    and instantiating it emits a :class:`DeprecationWarning`.

    This layer implements MIMO maximum-likelihood (ML) detection assuming the
    following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^K` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.
    It is assumed that prior information of the transmitted signal :math:`\mathbf{x}` is available,
    provided either as LLRs on the bits modulated onto :math:`\mathbf{x}` or as logits on the individual
    constellation points forming :math:`\mathbf{x}`.

    Prior to demapping, the received signal is whitened:

    .. math::
        \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\
        &=  \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\
        &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}

    The layer can compute ML detection of symbols or bits with either
    soft- or hard-decisions. Note that decisions are computed symbol-/bit-wise
    and not jointly for the entire vector :math:`\textbf{x}` (or the underlying vector
    of bits).

    **ML detection of bits:**

    Soft-decisions on bits are called log-likelihood ratios (LLR).
    With the "app" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is then computed according to

    .. math::
        \begin{align}
            LLR(k,i)&= \ln\left(\frac{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}\right)\\
                    &=\ln\left(\frac{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }\right)
        \end{align}

    where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the
    sets of vectors of constellation points for which the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is equal to 1 and 0, respectively.
    :math:`\Pr\left( \mathbf{x} \right)` is the prior distribution of the vector of
    constellation points :math:`\mathbf{x}`. Assuming that the constellation points and
    bit levels are independent, it is computed from the prior of the bits according to

    .. math::
        \Pr\left( \mathbf{x} \right) = \prod_{k=1}^K \prod_{i=1}^{I} \sigma \left( LLR_p(k,i) \right)

    where :math:`LLR_p(k,i)` is the prior knowledge of the :math:`i\text{th}` bit of the
    :math:`k\text{th}` user given as an LLR, and :math:`\sigma\left(\cdot\right)` is the sigmoid function.
    The definition of the LLR has been chosen such that it is equivalent with that of logit. This is
    different from many textbooks in communications, where the LLR is
    defined as :math:`LLR(k,i) = \ln\left(\frac{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}\right)`.

    With the "maxlog" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is approximated like

    .. math::
        \begin{align}
            LLR(k,i) \approx&\ln\left(\frac{
                \max_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \exp\left(
                    -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                    \right) \Pr\left( \mathbf{x} \right) \right)
                }{
                \max_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \exp\left(
                    -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                    \right) \Pr\left( \mathbf{x} \right) \right)
                }\right)\\
                = &\min_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left(\Pr\left( \mathbf{x} \right) \right) \right) -
                    \min_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left( \Pr\left( \mathbf{x} \right) \right) \right).
        \end{align}

    **ML detection of symbols:**

    Soft-decisions on symbols are called logits (i.e., unnormalized log-probability).

    With the "app" demapping method, the logit for the
    constellation point :math:`c \in \mathcal{C}` of the :math:`k\text{th}` user is computed according to

    .. math::
        \begin{align}
            \text{logit}(k,c) &= \ln\left(\sum_{\mathbf{x} : x_k = c} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right)\Pr\left( \mathbf{x} \right)\right).
        \end{align}

    With the "maxlog" demapping method, the logit for the constellation point :math:`c \in \mathcal{C}`
    of the :math:`k\text{th}` user is approximated like

    .. math::
        \text{logit}(k,c) \approx \max_{\mathbf{x} : x_k = c} \left(
                -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 + \ln \left( \Pr\left( \mathbf{x} \right) \right)
                \right).

    When hard decisions are requested, this layer returns for the :math:`k` th stream

    .. math::
        \hat{c}_k = \underset{c \in \mathcal{C}}{\text{argmax}} \left( \sum_{\mathbf{x} : x_k = c} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right)\Pr\left( \mathbf{x} \right) \right)

    where :math:`\mathcal{C}` is the set of constellation points.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either LLRs on bits or logits on constellation symbols.

    demapping_method : One of ["app", "maxlog"], str
        The demapping method used.

    num_streams : tf.int
        Number of transmitted streams

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of ``y``. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h, prior, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices.

    prior : [...,num_streams,num_bits_per_symbol] or [...,num_streams,num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", then LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", then logits of the transmitted constellation points are expected.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    Output
    ------
    One of:

    : [..., num_streams, num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
       Hard-decisions correspond to the symbol indices.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 output,
                 demapping_method,
                 num_streams,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 dtype=tf.complex64,
                 **kwargs):
        # The class is documented as deprecated; surface that at runtime too.
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            "MaximumLikelihoodDetectorWithPrior is deprecated. Use "
            "MaximumLikelihoodDetector with `with_prior=True` instead.",
            DeprecationWarning,
            stacklevel=2)

        super().__init__(   output=output,
                            demapping_method=demapping_method,
                            num_streams=num_streams,
                            constellation_type=constellation_type,
                            num_bits_per_symbol=num_bits_per_symbol,
                            constellation=constellation,
                            hard_out=hard_out,
                            with_prior=True,
                            dtype=dtype,
                            **kwargs)
class KBestDetector(Layer):
    # pylint: disable=line-too-long
    r"""KBestDetector(output, num_streams, k, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, use_real_rep=False, list2llr=None, dtype=tf.complex64)

    MIMO K-Best detector

    This layer implements K-Best MIMO detection as described
    in (Eq. 4-5) [FT2015]_. It can either generate hard decisions (for symbols
    or bits) or compute LLRs.

    The algorithm operates in either the complex or real-valued domain.
    Although both options produce identical results, the former has the advantage
    that it can be applied to arbitrary non-QAM constellations. It also reduces
    the number of streams (or depth) by a factor of two.

    The way soft-outputs (i.e., LLRs) are computed is determined by the
    ``list2llr`` function. The default solution
    :class:`~sionna.mimo.List2LLRSimple` assigns a predetermined
    value to all LLRs without counter-hypothesis.

    This layer assumes the following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.

    In a first optional step, the channel model is converted to its real-valued equivalent,
    see :func:`~sionna.mimo.complex2real_channel`. We assume in the sequel the complex-valued
    representation. Then, the channel is whitened using :func:`~sionna.mimo.whiten_channel`:

    .. math::
        \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\
        &=  \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\
        &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}.

    Next, the columns of :math:`\tilde{\mathbf{H}}` are sorted according
    to their norm in descending order. Then, the QR decomposition of the
    resulting channel matrix is computed:

    .. math::
        \tilde{\mathbf{H}} = \mathbf{Q}\mathbf{R}

    where :math:`\mathbf{Q}\in\mathbb{C}^{M\times S}` has orthonormal columns and
    :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is upper-triangular.
    The channel outputs are then pre-multiplied by :math:`\mathbf{Q}^{\mathsf{H}}`.
    This leads to the final channel model on which the K-Best detection algorithm operates:

    .. math::
        \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}}

    where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S`,
    :math:`\bar{\mathbf{x}}\in\mathbb{C}^S`, and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S`
    with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`.

    **LLR Computation**

    The K-Best algorithm produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S`
    and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2`
    for :math:`k=1,\dots,K`. If the real-valued channel representation is used, the distance
    metrics are scaled by 0.5 to account for the reduced noise power in each complex dimension.
    A hard-decision is simply the candidate with the shortest distance.
    Various ways to compute LLRs from this list (and possibly
    additional side-information) are possible. The (sub-optimal) default solution
    is :class:`~sionna.mimo.List2LLRSimple`. Custom solutions can be provided.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either bits or symbols. Whether soft- or
        hard-decisions are returned can be configured with the
        ``hard_out`` flag.

    num_streams : int
        Number of transmitted streams

    k : int
        The number of paths to keep. Cannot be larger than the
        number of constellation points to the power of the number of
        streams. Larger values are reduced to this maximum and a
        warning is issued.

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`. The detector cannot compute soft-symbols,
        i.e., ``hard_out`` must be `True` if ``output`` equals `"symbol"`.

    use_real_rep : bool
        If `True`, the detector uses the real-valued equivalent representation
        of the channel. Note that this only works with a QAM constellation.
        Defaults to `False`.

    list2llr: `None` or instance of :class:`~sionna.mimo.List2LLR`
        The function to be used to compute LLRs from a list of candidate solutions.
        If `None` (or the string `"default"`), the default solution
        :class:`~sionna.mimo.List2LLRSimple` is used.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of ``y``. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    -----
    (y, h, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices

    Output
    ------
    One of:

    : [...,num_streams,num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`

    : [...,num_streams], tf.int
        Hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
        Hard-decisions correspond to the symbol indices. Soft-symbols are not supported.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 output,
                 num_streams,
                 k,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 use_real_rep=False,
                 list2llr="default",
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        assert dtype in [tf.complex64, tf.complex128],\
            "dtype must be tf.complex64 or tf.complex128."

        assert output in ("bit", "symbol"), "Unknown output"

        err_msg = "You must provide either constellation or " + \
                  "constellation_type and num_bits_per_symbol."
        if constellation is None:
            assert constellation_type is not None and \
                   num_bits_per_symbol is not None, err_msg
        else:
            assert constellation_type is None and \
                   num_bits_per_symbol is None, err_msg

        if constellation is not None:
            assert constellation.points.dtype==dtype, \
                "Constellation has wrong dtype."

        self._output = output
        self._hard_out = hard_out
        self._use_real_rep = use_real_rep

        if self._use_real_rep:
            # Real-valued representation is used
            err_msg = "Only QAM can be used for the real-valued representation"
            if constellation_type is not None:
                assert constellation_type=="qam", err_msg
            else:
                assert constellation._constellation_type=="qam", err_msg

            # Double the number of streams to detect (real and imaginary parts)
            self._num_streams = 2*num_streams

            # Half the number of bits for the PAM constellation
            if num_bits_per_symbol is None:
                n = constellation.num_bits_per_symbol//2
                self._num_bits_per_symbol = n
            else:
                self._num_bits_per_symbol = num_bits_per_symbol//2

            # Generate a PAM constellation with 0.5 energy
            c = Constellation("pam",
                              self._num_bits_per_symbol,
                              normalize=False,
                              dtype=dtype)
            c._points /= tf.cast(np.std(c._points)*np.sqrt(2), c._points.dtype)
            self._constellation = tf.cast(c.points, dtype.real_dtype)

            # Maps pairs of PAM indices back to QAM indices
            self._pam2qam = PAM2QAM(2*self._num_bits_per_symbol)

        else:
            # Complex-valued representation is used
            # Number of streams is equal to number of transmitters
            self._num_streams = num_streams

            # Create constellation or take the one provided
            c = Constellation.create_or_check_constellation(
                                                    constellation_type,
                                                    num_bits_per_symbol,
                                                    constellation,
                                                    dtype=dtype)
            self._constellation = c.points
            self._num_bits_per_symbol = c.num_bits_per_symbol

        # Number of constellation symbols
        self._num_symbols = self._constellation.shape[0]

        # Number of best paths to keep
        self._k = np.minimum(k, self._num_symbols**self._num_streams)
        if self._k < k:
            msg = "KBestDetector: " + \
                  f"The provided value of k={k} is larger than " + \
                  "the possible maximum number of paths. " + \
                  f"It has been set to k={self._k}."
            warnings.warn(msg)

        # Compute the number of previous paths a layer needs to consider
        num_paths = [1] # The first layer considers a single path
        for l in range(1, self._num_streams+1):
            # The lth layer considers min(k, num_symbols**l) paths
            num_paths.append(np.minimum(self._k, self._num_symbols**l))
        self._num_paths = tf.constant(tf.stack(num_paths, 0), tf.int32)

        # The symbols and indices for all paths will be stored in tensors
        # of shape [batch_size, k, num_streams]. However, only
        # a subset of the available entries are updated by each stream.
        # To enable XLA, we need to compute the relevant indices of the tensors
        # that will be updated through tf.tensor_scatter_nd_update.
        indices = np.zeros([self._num_streams, self._k*self._num_streams, 2],
                           np.int32)
        for l in range(0, self._num_streams):
            # Use the Python list (not the tf tensor) to size the numpy array
            ind = np.zeros([num_paths[l+1], self._num_streams])
            ind[:, :l+1] = 1
            ind = np.stack(np.where(ind), -1)
            indices[l,:ind.shape[0],:ind.shape[1]] = ind
        self._indices = tf.constant(indices, dtype=tf.int32)

        if self._output=="bit":
            # Truthiness is used (as in ``call``) so that non-bool flags
            # select the same code path at construction and call time
            if not self._hard_out:
                # ``None`` is documented as selecting the default solution
                if list2llr is None or \
                   (isinstance(list2llr, str) and list2llr=="default"):
                    self.list2llr = List2LLRSimple(self._num_bits_per_symbol)
                else:
                    self.list2llr = list2llr
            else:
                if self._use_real_rep:
                    n = 2*self._num_bits_per_symbol
                else:
                    n = self._num_bits_per_symbol
                self._symbolinds2bits = SymbolInds2Bits(n,
                                                    dtype=dtype.real_dtype)
        else:
            assert self._hard_out, \
                "Soft-symbols are not supported for this detector."

    @property
    def list2llr(self):
        """Function used to compute LLRs from the list of candidates."""
        return self._list2llr

    @list2llr.setter
    def list2llr(self, value):
        assert isinstance(value, List2LLR), \
            "list2llr must be an instance of List2LLR"
        self._list2llr = value

    def _preprocessing(self, inputs):
        """Real-valued conversion (optional), whitening, column sorting,
        QR decomposition, and projection of ``y`` onto ``Q^H``.

        Returns the projected ``y``, the upper-triangular ``r``, and the
        column permutation applied to ``h``.
        """
        y, h, s = inputs

        # Convert to real-valued representation if desired
        if self._use_real_rep:
            y, h, s = complex2real_channel(y, h, s)

        # Whiten channel
        y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=W0632

        # Order columns of H in order of decreasing norm
        h_norm = tf.reduce_sum(tf.abs(h)**2, axis=1)
        column_order = tf.argsort(h_norm, axis=-1, direction="DESCENDING")
        h = tf.gather(h, column_order, axis=-1, batch_dims=1)

        # Compute QR decomposition of sorted channel
        # r is upper triangular
        q, r = tf.linalg.qr(h)

        # Project y on Q'
        y = tf.squeeze(tf.matmul(q, tf.expand_dims(y, -1), adjoint_a=True),
                       -1)

        return y, r, column_order

    def _select_best_paths(self, dists, path_syms, path_inds):
        """Keep the (at most) k paths with the smallest distance metrics."""
        # Determine the number of paths to keep (either all or k)
        num_paths = tf.shape(path_syms)[1]
        k = tf.minimum(num_paths, self._k)

        # Get the k paths with the shortest distance
        dists, ind = tf.math.top_k(-dists, k=k, sorted=True)
        dists = -dists

        # Select the same best paths for the symbols and symbol indices
        path_syms = tf.gather(path_syms, ind, axis=1, batch_dims=1)
        path_inds = tf.gather(path_inds, ind, axis=1, batch_dims=1)

        return dists, path_syms, path_inds

    def _next_layer(self, y, r, dists, path_syms, path_inds, stream):
        """Extend all surviving paths by one stream and keep the k best."""
        batch_size = tf.shape(y)[0]

        # Streams are processed in reverse order
        stream_ind = self._num_streams-1-stream

        # Current number of considered paths
        num_paths = tf.gather(self._num_paths, stream)

        # Store input tensors for scatter update later on
        dists_o = dists
        path_syms_o = path_syms
        path_inds_o = path_inds

        # Extract relevant values from input tensor
        dists = dists[..., :num_paths]
        path_syms = path_syms[..., :num_paths, :stream]
        path_inds = path_inds[..., :num_paths, :stream]

        # Each path creates num_symbols branches
        dists = tf.repeat(dists, repeats=self._num_symbols, axis=1)
        path_syms = tf.repeat(path_syms, repeats=self._num_symbols, axis=1)
        path_inds = tf.repeat(path_inds, repeats=self._num_symbols, axis=1)

        # Append to each path the symbols corresponding to the branch
        syms = tf.reshape(self._constellation, [1,-1])
        syms = tf.repeat(syms, self._k, 0)
        syms = tf.reshape(syms, [1, -1, 1])
        syms = tf.repeat(syms, batch_size, 0)
        syms = syms[:,:num_paths*self._num_symbols]
        path_syms = tf.concat([path_syms, syms], axis=-1)

        # Do the same for the symbol indices
        inds = tf.reshape(tf.range(0, self._num_symbols), [1, -1])
        inds = tf.repeat(inds, self._k, 0)
        inds = tf.reshape(inds, [1, -1, 1])
        inds = tf.repeat(inds, batch_size, 0)
        inds = inds[:,:num_paths*self._num_symbols]
        path_inds = tf.concat([path_inds, inds], axis=-1)

        # Compute partial distances
        # Extract the row of r corresponding to layer and reverse the order
        y = tf.expand_dims(y[:, stream_ind], axis=-1)
        r = tf.expand_dims(tf.reverse(r[:, stream_ind, stream_ind:], [-1]), 1)
        delta = tf.pow(tf.abs(y - tf.reduce_sum(r*path_syms, axis=-1)), 2)

        # Update distances
        dists += delta

        # Get k best paths
        dists, path_syms, path_inds = self._select_best_paths(dists,
                                                              path_syms,
                                                              path_inds)

        # Scatter updates of dists
        tensor = tf.transpose(dists_o, perm=[1, 0])
        updates = tf.transpose(dists, perm=[1, 0])
        indices = tf.expand_dims(tf.range(tf.shape(updates)[0],
                                          dtype=tf.int32), -1)
        dists = tf.tensor_scatter_nd_update(tensor, indices, updates)
        dists = tf.transpose(dists, perm=[1, 0])

        # Scatter update of path_syms
        tensor = tf.transpose(path_syms_o, [1, 2, 0])
        updates = tf.transpose(path_syms, [1, 2, 0])
        updates = tf.reshape(updates, [-1, batch_size])
        indices = self._indices[stream, :self._num_paths[stream+1]*(stream+1)]
        path_syms = tf.tensor_scatter_nd_update(tensor, indices, updates)
        path_syms = tf.transpose(path_syms, perm=[2, 0, 1])

        # Scatter update of path_inds (same indices as for path_syms)
        tensor = tf.transpose(path_inds_o, [1, 2, 0])
        updates = tf.transpose(path_inds, [1, 2, 0])
        updates = tf.reshape(updates, [-1, batch_size])
        path_inds = tf.tensor_scatter_nd_update(tensor, indices, updates)
        path_inds = tf.transpose(path_inds, perm=[2, 0, 1])

        return dists, path_syms, path_inds

    def _unsort(self, column_order, tensor, transpose=True):
        """Undo the column sorting applied in ``_preprocessing``.

        If ``transpose`` is `True`, the unsorting is done along the last
        dimension. Otherwise, it is done along the second-last dimension.
        """
        unsort_inds = tf.argsort(column_order, axis=-1)
        if transpose:
            tensor = tf.transpose(tensor, perm=[0, 2, 1])
        tensor = tf.gather(tensor, unsort_inds, axis=-2, batch_dims=1)
        if transpose:
            tensor = tf.transpose(tensor, perm=[0, 2, 1])
        return tensor

    def build(self, input_shape):
        assert input_shape[1][-2]>=input_shape[1][-1], \
            "The number of receive antennas cannot be smaller " \
            "than the number of streams"

    def call(self, inputs):
        # Flatten the batch dimensions
        y, h, s = inputs
        batch_shape = tf.shape(y)[:-1]
        num_batch_dims = len(batch_shape)
        if num_batch_dims > 1:
            y = flatten_dims(y, num_batch_dims, 0)
            h = flatten_dims(h, num_batch_dims, 0)
            s = flatten_dims(s, num_batch_dims, 0)
            inputs = (y,h,s)

        # Initialization
        # (i) (optional) Convert to real-valued representation
        # (ii) Whiten channel
        # (iii) Sort columns of H by decreasing column norm
        # (iv) QR Decomposition of H
        # (v) Project y onto Q'
        y, r, column_order = self._preprocessing(inputs)
        batch_size = tf.shape(y)[0]

        # Tensor to keep track of the aggregate distances of all paths
        dists = tf.zeros([batch_size, self._k], y.dtype.real_dtype)
        # Tensor to store constellation symbols of all paths
        path_syms = tf.zeros([batch_size, self._k, self._num_streams], y.dtype)
        # Tensor to store constellation symbol indices of all paths
        path_inds = tf.zeros([batch_size, self._k, self._num_streams],
                             tf.int32)

        # Sequential K-Best algorithm
        for stream in range(0, self._num_streams):
            dists, path_syms, path_inds = self._next_layer(y,
                                                           r,
                                                           dists,
                                                           path_syms,
                                                           path_inds,
                                                           stream)

        # Reverse order as detection started with the last symbol first
        path_syms = tf.reverse(path_syms, axis=[-1])
        path_inds = tf.reverse(path_inds, axis=[-1])

        # Processing for hard-decisions
        if self._hard_out:
            path_inds = self._unsort(column_order, path_inds)
            # Paths are sorted by distance, so the first one is the best
            hard_dec = path_inds[:,0,:]

            # Real-valued representation
            if self._use_real_rep:
                hard_dec = \
                    self._pam2qam(hard_dec[...,:self._num_streams//2],
                                  hard_dec[...,self._num_streams//2:])

            # Hard decisions on bits
            if self._output=="bit":
                hard_dec = self._symbolinds2bits(hard_dec)

            # Reshape batch dimensions
            if num_batch_dims > 1:
                hard_dec = split_dim(hard_dec, batch_shape, 0)

            return hard_dec

        # Processing for soft-decisions
        if self._use_real_rep:
            # Real-valued representation
            llr = self.list2llr([y, r, dists, path_inds, path_syms])
            llr = self._unsort(column_order, llr, transpose=False)

            # Combine LLRs from PAM symbols in the correct order
            llr1 = llr[:,:self._num_streams//2]
            llr2 = llr[:,self._num_streams//2:]
            llr1 = tf.expand_dims(llr1, -1)
            llr2 = tf.expand_dims(llr2, -1)
            llr = tf.concat([llr1, llr2], -1)
            llr = tf.reshape(llr, [-1, self._num_streams//2,
                                   2*self._num_bits_per_symbol])
        else:
            # Complex-valued representation
            llr = self.list2llr([y, r, dists, path_inds, path_syms])
            llr = self._unsort(column_order, llr, transpose=False)

        # Reshape batch dimensions
        if num_batch_dims > 1:
            llr = split_dim(llr, batch_shape, 0)

        return llr
class EPDetector(Layer):
    # pylint: disable=line-too-long
    r"""EPDetector(output, num_bits_per_symbol, hard_out=False, l=10, beta=0.9, dtype=tf.complex64)

    MIMO Expectation Propagation (EP) detector

    This layer implements Expectation Propagation (EP) MIMO detection as described
    in [EP2014]_. It can generate hard- or soft-decisions for symbols or bits.

    This layer assumes the following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.

    The channel model is first whitened using :func:`~sionna.mimo.whiten_channel`
    and then converted to its real-valued equivalent,
    see :func:`~sionna.mimo.complex2real_channel`, prior to MIMO detection.

    The algorithm naturally produces symbol logits for the two PAM
    constellations (real and imaginary parts) of every stream. For bit
    outputs, these logits are converted to LLRs with
    :class:`~sionna.mapping.SymbolLogits2LLRs` (max-log) and the LLRs of both
    PAM constellations are interleaved. For symbol outputs, the PAM logits are
    combined into QAM logits using :class:`~sionna.mapping.PAM2QAM`.
    Custom conversions of symbol logits to LLRs can be implemented by using the
    soft-symbol output.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either bits or symbols. Whether soft- or
        hard-decisions are returned can be configured with the
        ``hard_out`` flag.

    num_bits_per_symbol : int
        The number of bits per QAM constellation symbol, e.g., 4 for QAM16.
        Must be a positive even integer.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    l : int
        Number of iterations. Defaults to 10.

    beta : float
        Parameter :math:`\beta\in[0,1]` for update smoothing.
        Defaults to 0.9.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        Precision used for internal computations. Defaults to ``tf.complex64``.
        Especially for large MIMO setups, the precision can make a significant
        performance difference.

    Input
    -----
    (y, h, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices

    Output
    ------
    One of:

    : [...,num_streams,num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`

    : [...,num_streams,2**num_bits_per_symbol], tf.float or [...,num_streams], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`

    Note
    ----
    For numerical stability, we do not recommend to use this function in Graph
    mode with XLA, i.e., within a function that is decorated with
    ``@tf.function(jit_compile=True)``.
    However, it is possible to do so by setting
    ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 output,
                 num_bits_per_symbol,
                 hard_out=False,
                 l=10,
                 beta=0.9,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        assert dtype in [tf.complex64, tf.complex128], \
            "Invalid dtype"
        self._cdtype = tf.dtypes.as_dtype(dtype)
        self._rdtype = self._cdtype.real_dtype

        # Variable used to avoid numerical instabilities
        # See paragraph after Eq. (38)
        if self._cdtype == tf.complex64:
            self._prec = 1e-6
        else:
            self._prec = 1e-12

        assert output in ("bit", "symbol"), "Unknown output"
        self._output = output
        self._hard_out = hard_out

        # Each QAM symbol is split into two PAM symbols of equal size
        assert num_bits_per_symbol >= 2 and num_bits_per_symbol % 2 == 0, \
            "num_bits_per_symbol must be a positive even integer (QAM)"

        if self._output=="symbol":
            self._pam2qam = PAM2QAM(num_bits_per_symbol, hard_out)
        else:
            # Use the internal real dtype so that Keras does not autocast
            # float64 logits down to float32
            self._symbollogits2llrs = SymbolLogits2LLRs("maxlog",
                                                        num_bits_per_symbol//2,
                                                        hard_out=hard_out,
                                                        dtype=self._rdtype)

        assert l>=1, "l must be a positive integer"
        self._l = l
        assert 0.0<= beta <=1.0, "beta must be in [0,1]"
        self._beta = beta

        # Create PAM constellations for real-valued detection
        # (built in the requested precision to avoid complex64 rounding)
        self._num_bits_per_symbol = num_bits_per_symbol//2
        points = Constellation("pam",
                               int(self._num_bits_per_symbol),
                               dtype=self._cdtype).points

        # Scale constellation points to half the energy because QAM is assumed
        self._points = tf.cast(points/np.sqrt(2.0), self._rdtype)

        # Average symbol energy
        self._es = tf.constant(np.var(self._points), self._rdtype)

    def compute_sigma_mu(self, h_t_h, h_t_y, no, lam, gam):
        """Equations (28) and (29)"""

        # Prepare inputs
        lam = tf.linalg.diag(lam)
        gam = tf.expand_dims(gam, axis=-1)

        # Computations
        sigma = tf.linalg.inv(h_t_h + no*lam)
        mu = tf.squeeze(tf.matmul(sigma, h_t_y + no*gam), axis=-1)
        sigma *= no
        sigma = tf.linalg.diag_part(sigma)

        return sigma, mu

    def compute_v_x_obs(self, sigma, mu, lam, gam):
        """Equations (31) and (32)"""

        # Clip variances from below for numerical stability
        v_obs = tf.maximum(1/(1/sigma-lam), self._prec)
        x_obs = v_obs*(mu/sigma-gam)
        return v_obs, x_obs

    def compute_v_x(self, v_obs, x_obs):
        """Equation (33)"""

        # Compute probability mass function for the symbols
        x_obs = tf.expand_dims(x_obs, -1)
        v_obs = tf.expand_dims(v_obs, -1)

        points = expand_to_rank(self._points, tf.rank(x_obs), axis=0)
        logits = -tf.pow(x_obs-points, 2) / (tf.cast(2, self._rdtype)*v_obs)
        pmf = tf.math.softmax(logits)

        # Compute mean and variance of all symbols
        x = tf.reduce_sum(points * pmf, axis=-1, keepdims=True)
        v = tf.reduce_sum((points-x)**2 * pmf, axis=-1)
        v = tf.maximum(v, self._prec)
        x = tf.squeeze(x, axis=-1)

        return v, x, logits

    def update_lam_gam(self, v, v_obs, x, x_obs, lam, gam):
        """Equations (35), (36), (37), (38)"""

        # Save old values of lam, and gam
        lam_old = lam
        gam_old = gam

        # Compute potential new values (35), (36)
        lam = 1/v - 1/v_obs
        gam = x/v - x_obs/v_obs

        # Only update nonnegative values
        lam_new = tf.where(lam<0, lam_old, lam)
        gam_new = tf.where(lam<0, gam_old, gam)

        # Damp updates (37), (38)
        lam_damp = (1-self._beta)*lam_new + self._beta*lam_old
        gam_damp = (1-self._beta)*gam_new + self._beta*gam_old

        return lam_damp, gam_damp

    def call(self, inputs):
        # Flatten the batch dimensions
        y, h, s = inputs
        batch_shape = tf.shape(y)[:-1]
        num_batch_dims = len(batch_shape)
        if num_batch_dims > 1:
            y = flatten_dims(y, num_batch_dims, 0)
            h = flatten_dims(h, num_batch_dims, 0)
            s = flatten_dims(s, num_batch_dims, 0)

        # Number of (complex-valued) transmit streams
        n_t = tf.shape(h)[-1]

        # Whiten channel
        y, h, s = whiten_channel(y, h, s)

        # Convert channel to real-valued representation
        y, h, s = complex2real_channel(y, h, s)

        # Convert all inputs to desired dtypes
        y = tf.cast(y, self._rdtype)
        h = tf.cast(h, self._rdtype)

        # Noise variance per real dimension of the whitened channel
        no = tf.cast(0.5, self._rdtype)

        # Gather relevant parameters
        batch_dims = tf.shape(y)[:-1]
        n_t_r = tf.shape(h)[-1]

        # Initialize gamma and lambda (Paragraph after Eq. (29))
        gam = tf.zeros(tf.concat([batch_dims, [n_t_r]], axis=0), y.dtype)
        lam = tf.ones(tf.concat([batch_dims, [n_t_r]], axis=0), y.dtype)
        lam /= tf.cast(self._es, y.dtype)

        # Precompute values that are repeatedly needed
        h_t_h = tf.matmul(h, h, transpose_a=True)
        y = tf.expand_dims(y, axis=-1)
        h_t_y = tf.matmul(h, y, transpose_a=True)
        no = expand_to_rank(no, tf.rank(h), axis=-1)

        for _ in range(self._l):
            sigma, mu = self.compute_sigma_mu(h_t_h, h_t_y, no, lam, gam)
            v_obs, x_obs = self.compute_v_x_obs(sigma, mu, lam, gam)
            v, x, logits = self.compute_v_x(v_obs, x_obs)
            lam, gam = self.update_lam_gam(v, v_obs, x, x_obs, lam, gam)

        # Extract the logits for the 2 PAM constellations for each stream
        # (real parts first, imaginary parts second)
        pam1_logits = logits[...,:n_t,:]
        pam2_logits = logits[...,n_t:,:]

        if self._output=="bit":
            # Compute LLRs for both PAM constellations
            llr1 = self._symbollogits2llrs(pam1_logits)
            llr2 = self._symbollogits2llrs(pam2_logits)

            # Put LLRs in the correct order and shape
            llr = tf.stack([llr1, llr2], -1)
            out = flatten_last_dims(llr)
        elif self._hard_out:
            # Take hard decisions on PAM symbols
            pam1_ind = tf.argmax(pam1_logits, axis=-1, output_type=tf.int32)
            pam2_ind = tf.argmax(pam2_logits, axis=-1, output_type=tf.int32)

            # Transform to QAM indices
            out = self._pam2qam(pam1_ind, pam2_ind)
        else:
            # Combine PAM logits into QAM logits
            out = self._pam2qam(pam1_logits, pam2_logits)

        # Reshape batch dimensions
        if num_batch_dims > 1:
            out = split_dim(out, batch_shape, 0)

        return out
class MMSEPICDetector(Layer):
    # pylint: disable=line-too-long
    r"""MMSEPICDetector(output, demapping_method="maxlog", num_iter=1, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)

    Minimum mean square error (MMSE) with parallel interference cancellation (PIC) detector

    This layer implements the MMSE PIC detector, as proposed in [CST2011]_.
    For ``num_iter``>1, this implementation performs MMSE PIC self-iterations.
    MMSE PIC self-iterations can be understood as a concatenation of MMSE PIC
    detectors from [CST2011]_, which forward intrinsic LLRs to the next
    self-iteration.

    Compared to [CST2011]_, this implementation also accepts priors on the
    constellation symbols as an alternative to priors on the bits.

    This layer assumes the following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.

    The algorithm starts by computing the soft symbols
    :math:`\bar{x}_s=\mathbb{E}\left[ x_s \right]` and
    variances :math:`v_s=\mathbb{E}\left[ |e_s|^2\right]` from the priors,
    where :math:`e_s = x_s - \bar{x}_s`, for all :math:`s=1,\dots,S`.

    Next, for each stream, the interference caused by all other streams is
    cancelled from the observation :math:`\mathbf{y}` using the soft symbols,
    leading to

    .. math::
        \hat{\mathbf{y}}_s = \mathbf{y} - \sum_{j\neq s} \mathbf{h}_j \bar{x}_j
        = \mathbf{h}_s x_s + \tilde{\mathbf{n}}_s,\quad s=1,\dots,S

    where :math:`\tilde{\mathbf{n}}_s=\sum_{j\neq s} \mathbf{h}_j e_j + \mathbf{n}`.

    Then, a linear MMSE filter :math:`\mathbf{w}_s` is computed to reduce the
    residual noise for each observation :math:`\hat{\mathbf{y}}_s`, which is
    given as

    .. math::
        \mathbf{w}_s = \left( \mathbf{H} \mathbf{D}_s\mathbf{H}^{\mathsf{H}} +\mathbf{S} \right)^{-1}\mathbf{h}_s

    where :math:`\mathbf{D}_s \in \mathbb{C}^{S\times S}` is diagonal with entries

    .. math::
        \left[\mathbf{D}_s\right]_{i,i} = \begin{cases}
                                            v_i & i\neq s \\
                                            1 & i=s.
                                          \end{cases}

    The filtered observations

    .. math::
        \tilde{z}_s = \mathbf{w}_s^{\mathsf{H}} \hat{\mathbf{y}}_s = \tilde{\mu}_s x_s + \mathbf{w}_s^{\mathsf{H}}\tilde{\mathbf{n}}_s

    where :math:`\tilde{\mu}_s=\mathbf{w}_s^{\mathsf{H}} \mathbf{h}_s`, are then demapped to either symbol logits or LLRs, assuming that the remaining noise is Gaussian with variance

    .. math::
        \nu_s^2 = \mathop{\text{Var}}\left[\tilde{z}_s\right] = \mathbf{w}_s^{\mathsf{H}} \left(\sum_{j\neq s} \mathbf{h}_j \mathbf{h}_j^{\mathsf{H}} v_j +\mathbf{S} \right)\mathbf{w}_s.

    The resulting soft-symbols can then be used for the next self-iteration of the algorithm.

    Note that this algorithm can be substantially simplified as described in [CST2011]_ to avoid
    the computation of different matrix inverses for each stream. This is the version which is
    implemented.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either LLRs on bits or logits on constellation
        symbols.

    demapping_method : One of ["app", "maxlog"], str
        The demapping method used.
        Defaults to "maxlog".

    num_iter : int
        Number of MMSE PIC iterations.
        Defaults to 1.

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of ``y``. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    -----
    (y, h, prior, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals

    h : [...,M,S], tf.complex
        2+D tensor containing the channel matrices

    prior : [...,S,num_bits_per_symbol] or [...,S,num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", then LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", then logits of the transmitted constellation points are expected.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices

    Output
    ------
    One of:

    : [...,S,num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`

    : [...,S,2**num_bits_per_symbol], tf.float or [...,S], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`

    Note
    ----
    For numerical stability, we do not recommend to use this function in Graph
    mode with XLA, i.e., within a function that is decorated with
    ``@tf.function(jit_compile=True)``.
    However, it is possible to do so by setting
    ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 output,
                 demapping_method="maxlog",
                 num_iter=1,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(num_iter, int), "num_iter must be an integer"
        assert output in ("bit", "symbol"), "Unknown output"
        assert demapping_method in ("app", "maxlog"), "Unknown demapping method"

        assert dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"

        self._num_iter = num_iter
        self._output = output
        # Lower bound for 1 - var_x * mu (see Step 5 in ``call``)
        self._epsilon = 1e-4
        self._realdtype = dtype.real_dtype
        self._demapping_method = demapping_method
        self._hard_out = hard_out

        # Create constellation object
        self._constellation = Constellation.create_or_check_constellation(
                                                        constellation_type,
                                                        num_bits_per_symbol,
                                                        constellation,
                                                        dtype=dtype)

        # Soft symbol mapping
        self._llr_2_symbol_logits = LLRs2SymbolLogits(
                                self._constellation.num_bits_per_symbol,
                                dtype=self._realdtype)

        if self._output == "symbol":
            # Conversion of the final extrinsic LLRs to the symbol output
            self._llr_2_symbol_logits_output = LLRs2SymbolLogits(
                                self._constellation.num_bits_per_symbol,
                                dtype=self._realdtype,
                                hard_out=hard_out)
            # Conversion of symbol-logit priors to bit LLRs; uses the
            # real dtype so that float64 priors are not autocast to float32
            self._symbol_logits_2_llrs = SymbolLogits2LLRs(
                method=demapping_method,
                num_bits_per_symbol=self._constellation.num_bits_per_symbol,
                dtype=self._realdtype)

        self._symbol_logits_2_moments = SymbolLogits2Moments(
                                            constellation=self._constellation,
                                            dtype=self._realdtype)

        # soft output demapping
        self._bit_demapper = DemapperWithPrior(
                                        demapping_method=demapping_method,
                                        constellation=self._constellation,
                                        dtype=dtype)

    def call(self, inputs):
        y, h, prior, s = inputs
        # y is unwhitened receive signal
        #   [..., M]
        # h the channel estimate
        #   [..., M, K]
        # prior is either the soft input LLRs
        #   [..., K, num_bits_per_symbol] or symbol logits [..., K, Q]
        # s the noise covariance matrix
        #   [..., M, M]

        ## Preprocessing
        # Whiten channel
        # y : [..., M]
        # s : [..., M, M]
        y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=unbalanced-tuple-unpacking

        # matched filtering of y
        # [..., K, 1]
        y_mf = insert_dims(tf.linalg.matvec(h, y, adjoint_a=True),
                           num_dims=1, axis=-1)

        ## Step 1: compute Gram matrix
        # [..., K, K]
        g = tf.matmul(h, h, adjoint_a=True)

        # For XLA compatibility, this implementation performs the MIMO
        # equalization in the real-valued domain
        # [..., 2M, 2K]
        hr = complex2real_matrix(h)
        # [..., 2K, 2K]
        gr = tf.matmul(hr, hr, adjoint_a=True)

        # Compute a priori LLRs
        if self._output == "symbol":
            llr_a = self._symbol_logits_2_llrs(prior)
        else:
            llr_a = prior
        # llr_a is [..., K, num_bits_per_symbol]
        llr_shape = tf.shape(llr_a)

        def mmse_pic_self_iteration(llr_d, llr_a, it):
            # MMSE PIC takes in a priori LLRs: the output of the previous
            # iteration (or the prior, in the first iteration) replaces llr_a
            llr_a = llr_d

            # Step 2: compute soft symbol estimates and variances
            # x_hat, var_x : [..., K]
            x_logits = self._llr_2_symbol_logits(llr_a)
            x_hat, var_x = self._symbol_logits_2_moments(x_logits)

            # Step 3: perform parallel interference cancellation
            # H^H y_hat_i = y_mf - sum_j!=i gj x_hat_j = y + g_i x_hat_i
            # - sum_j g_j x_hat_j
            # [..., K, K]
            y_mf_pic = y_mf + g * insert_dims(x_hat, num_dims=1, axis=-2) \
                - tf.linalg.matmul(g, insert_dims(x_hat, num_dims=1, axis=-1))

            # Step 4: compute A^-1 matrix
            # Calculate MMSE Filter (efficiently)
            # W^H = A^-1 H^H
            # A = H^H H \Lambda + N_0 I_Mt
            # \Lambda_ii is a diagonal matrix with \Lambda_ii = E_i = error_var

            # Stack error variances and make it real
            # Note: Imaginary part is zero
            var_x = tf.cast(tf.concat([var_x, var_x], axis=-1),
                            dtype=self._realdtype)
            var_x_row_vec = insert_dims(var_x, num_dims=1, axis=-2)
            # [..., 2K, 2K]
            a = gr * var_x_row_vec

            i = expand_to_rank(tf.eye(tf.shape(a)[-1], dtype=a.dtype),
                               tf.rank(a), 0)
            a = a + i

            # a is non-hermitian! That's why we can't use sn.utils.matrix_inv
            # XLA can't invert complex matrices, that's why we work with the
            # real valued domain
            a_inv = tf.linalg.inv(a)

            # Step 5: compute unbiased MMSE filter and outputs, calculate A\H^H

            # Calculate bias mu_i = diag(A^-1 H^H H) = diag(A^-1 G)
            # Diagonal elements of matrix matrix multiplication simplified
            # to sum and dot-product
            # [..., 2K]
            mu = tf.reduce_sum(a_inv * tf.linalg.matrix_transpose(gr), axis=-1)

            # Make y_mf_pic columns real (after transposition,
            # the last dimension corresponds to vectors)
            # [..., K, 2K]
            y_mf_pic_trans = tf.linalg.matrix_transpose(y_mf_pic)
            y_mf_pic_trans = complex2real_vector(y_mf_pic_trans)
            # stack them such that y_mf_pic_trans has shape [..., 2K, 2K]
            y_mf_pic_trans = tf.concat([y_mf_pic_trans, y_mf_pic_trans],
                                       axis=-2)

            # Efficient parallel equalization after PIC
            # z_i = i'th row of a_inv * y_MF_PIC_i
            # boils down to tf.reduce_sum(a_inv * y_mf_pic_trans, axis=-1)
            # divide by mu_i for unbiasedness
            # [..., K]
            x_hat = real2complex_vector(
                        tf.reduce_sum(a_inv * y_mf_pic_trans, axis=-1)
                        / tf.cast(mu, dtype=a_inv.dtype))

            # Compute post equalization signal error estimate:
            # rho_i = mu_i / (1 - var_x_i * mu_i)
            # 1 - var_x_i * mu_i can become numerically 0, or even slightly
            # smaller than zero due to limited numerical precision
            # [..., 2K]
            var_x = tf.divide(mu, tf.maximum(1 - var_x * mu, self._epsilon))
            # real variances map to the same complex valued variances in this
            # model
            var_x, _ = tf.split(var_x, 2, -1)

            no_eff = 1. / var_x

            # Step 6: LLR demapping with prior llr_a; the extrinsic part is
            # obtained after the loop by subtracting llr_a
            # [..., K, num_bits_per_symbols]
            llr_d = tf.reshape(self._bit_demapper([x_hat, llr_a, no_eff]),
                               llr_shape)

            # Advance the iteration counter so that ``dec_stop`` terminates
            # the loop on its own
            return llr_d, llr_a, it + 1

        # Stopping condition (required for tf.while_loop)
        def dec_stop(llr_d, llr_a, it):  # pylint: disable=W0613
            return tf.less(it, self._num_iter)

        # start decoding iterations
        it = tf.constant(0)
        null_prior = tf.zeros(llr_shape, dtype=self._realdtype)
        llr_d, llr_a, _ = tf.while_loop(dec_stop,
                                        mmse_pic_self_iteration,
                                        (llr_a, null_prior, it),
                                        parallel_iterations=1,
                                        maximum_iterations=self._num_iter)
        llr_e = llr_d - llr_a
        if self._output == "symbol":
            # convert back to symbols if requested.
            # output symbol logits computed on extrinsic LLRs
            # (hard decisions, if requested, are produced by this layer)
            out = self._llr_2_symbol_logits_output(llr_e)
        else:
            # output extrinsic LLRs
            out = llr_e
            # Only bit LLRs are thresholded; symbol indices must not be
            if self._hard_out:
                out = hard_decisions(out)

        return out