Source code for sionna.phy.mimo.equalization

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0#
"""Classes and functions related to MIMO channel equalization"""

import tensorflow as tf
from sionna.phy import config, dtypes
from sionna.phy.utils import expand_to_rank, matrix_pinv
from sionna.phy.mimo.utils import whiten_channel

def lmmse_matrix(h, s=None, precision=None):
    # pylint: disable=line-too-long
    r"""MIMO LMMSE Equalization matrix

    This function computes the LMMSE equalization matrix for a MIMO link,
    assuming the following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    This function returns the LMMSE equalization matrix:

    .. math::

        \mathbf{G} = \mathbf{H}^{\mathsf{H}} \left(\mathbf{H}\mathbf{H}^{\mathsf{H}} + \mathbf{S}\right)^{-1}.

    If ``s`` is `None`, i.e., :math:`\mathbf{S}=\mathbf{I}_M`, the following
    mathematically equivalent but numerically more stable expression is
    computed instead:

    .. math::

        \mathbf{G} = \left(\mathbf{H}^{\mathsf{H}}\mathbf{H} + \mathbf{I}_K\right)^{-1}\mathbf{H}^{\mathsf{H}} .

    Input
    -----
    h : [...,M,K], `tf.complex`
        Channel matrices

    s : `None` (default) | [...,M,M], `tf.complex`
        Noise covariance matrices. If `None`, the noise is assumed to be
        white with unit variance.

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Output
    ------
    g : [...,K,M], `tf.complex`
        LMMSE equalization matrices
    """
    # Cast inputs
    if precision is None:
        cdtype = config.tf_cdtype
    else:
        cdtype = dtypes[precision]["tf"]["cdtype"]
    h = tf.cast(h, dtype=cdtype)

    if s is None:
        #---------------------------------------#
        # Compute g = (h^* @ h + I_K)^-1 @ h^*  #
        #---------------------------------------#
        # The KxK identity is built from the dynamic shape so that this also
        # works when K is not known statically (e.g., inside tf.function).
        eye = expand_to_rank(tf.eye(tf.shape(h)[-1], dtype=h.dtype),
                             tf.rank(h), 0)
        # h^* @ h + I is Hermitian positive definite -> Cholesky applies
        hhi = tf.matmul(h, h, adjoint_a=True) + eye
        chol = tf.linalg.cholesky(hhi)
        return tf.linalg.cholesky_solve(chol, tf.linalg.adjoint(h))

    s = tf.cast(s, dtype=cdtype)

    #------------------------------------#
    # Compute g = h^* @ (h @ h^* + s)^-1 #
    #------------------------------------#
    # hhs = h @ h^* + s.
    # Note that hhs^* = hhs, hence it admits a Cholesky decomposition
    hhs = tf.matmul(h, h, adjoint_b=True) + s

    # Solve hhs @ g_t = h in the unknown g_t
    chol = tf.linalg.cholesky(hhs)
    g_t = tf.linalg.cholesky_solve(chol, h)

    # Compute g = g_t^* = (hhs^-1 @ h)^* = h^* @ hhs^-1
    return tf.linalg.adjoint(g_t)
def lmmse_equalizer(y, h, s, whiten_interference=True, precision=None):
    # pylint: disable=line-too-long
    r"""MIMO LMMSE Equalizer

    This function implements LMMSE equalization for a MIMO link, assuming the
    following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    (Lemma B.19) [BHS2017]_ :

    .. math::

        \hat{\mathbf{x}} = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}\mathbf{G}\mathbf{y}

    where

    .. math::

        \mathbf{G} = \mathbf{H}^{\mathsf{H}} \left(\mathbf{H}\mathbf{H}^{\mathsf{H}} + \mathbf{S}\right)^{-1}.

    This leads to the post-equalized per-symbol model:

    .. math::

        \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1

    where the variances :math:`\sigma^2_k` of the effective residual noise
    terms :math:`e_k` are given by the diagonal elements of

    .. math::

        \mathop{\text{diag}}\left(\mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]\right)
        = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H} \right)^{-1} - \mathbf{I}.

    Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}`
    is important for the :class:`~sionna.phy.mapping.Demapper` although it does
    not change the signal-to-noise ratio.

    The function returns :math:`\hat{\mathbf{x}}` and
    :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.

    Input
    -----
    y : [...,M], `tf.complex`
        Received signals

    h : [...,M,K], `tf.complex`
        Channel matrices

    s : [...,M,M], `tf.complex`
        Noise covariance matrices

    whiten_interference : `bool`, (default `True`)
        If `True`, the interference is first whitened before equalization.
        In this case, an alternative expression for the receive filter is used
        that can be numerically more stable.

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Output
    ------
    x_hat : [...,K], `tf.complex`
        Estimated symbol vectors

    no_eff : [...,K], `tf.float`
        Effective noise variance estimates
    """
    # Resolve the complex dtype used for all internal computations
    cdtype = (config.tf_cdtype if precision is None
              else dtypes[precision]["tf"]["cdtype"])
    y, h, s = (tf.cast(t, dtype=cdtype) for t in (y, h, s))

    # Model: y = Hx + n with E[nn'] = S and E[x] = E[n] = 0.
    # LMMSE estimate: x_hat = diag(GH)^-1 G y with G = H'(HH'+S)^-1,
    # giving the per-symbol model x_hat_k = x_k + e_k where
    # diag(E[ee']) = diag(GH)^-1 - I.
    if whiten_interference:
        # After whitening, the noise covariance is the identity, so the
        # (more stable) s=None variant of the LMMSE matrix can be used.
        y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=unbalanced-tuple-unpacking
        g = lmmse_matrix(h, s=None, precision=precision)
    else:
        g = lmmse_matrix(h, s, precision=precision)

    # G @ y for every batch element
    g_y = tf.squeeze(tf.matmul(g, tf.expand_dims(y, -1)), axis=-1)

    # Per-stream gains diag(G @ H)
    gain = tf.linalg.diag_part(tf.matmul(g, h))

    # Unbiased estimate and residual error variance
    x_hat = g_y / gain
    unit = tf.cast(1, dtype=gain.dtype)
    no_eff = tf.math.real(unit / gain - unit)

    return x_hat, no_eff
def zf_equalizer(y, h, s, precision=None):
    # pylint: disable=line-too-long
    r"""Applies MIMO ZF Equalizer

    This function implements zero-forcing (ZF) equalization for a MIMO link,
    assuming the following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    (Eq. 4.10) [BHS2017]_ :

    .. math::

        \hat{\mathbf{x}} = \mathbf{G}\mathbf{y}

    where

    .. math::

        \mathbf{G} = \left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}.

    This leads to the post-equalized per-symbol model:

    .. math::

        \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1

    where the variances :math:`\sigma^2_k` of the effective residual noise
    terms :math:`e_k` are given by the diagonal elements of the matrix

    .. math::

        \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]
        = \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}.

    The function returns :math:`\hat{\mathbf{x}}` and
    :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.

    Input
    -----
    y : [...,M], `tf.complex`
        Received signals

    h : [...,M,K], `tf.complex`
        Channel matrices

    s : [...,M,M], `tf.complex`
        Noise covariance matrices

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Output
    ------
    x_hat : [...,K], `tf.complex`
        Estimated symbol vectors

    no_eff : [...,K], `tf.float`
        Effective noise variance estimates
    """
    # Cast inputs
    if precision is None:
        cdtype = config.tf_cdtype
    else:
        cdtype = dtypes[precision]["tf"]["cdtype"]
    y = tf.cast(y, dtype=cdtype)
    h = tf.cast(h, dtype=cdtype)
    s = tf.cast(s, dtype=cdtype)

    # We assume the model:
    # y = Hx + n, where E[nn']=S.
    # E[x]=E[n]=0
    #
    # The ZF estimate of x is given as:
    # x_hat = Gy
    # with G=(H'H)^(-1)H'.
    #
    # This leads us to the per-symbol model:
    #
    # x_hat_k = x_k + e_k
    #
    # The elements of the residual noise vector e have variance:
    # E[ee'] = GSG'

    # Compute G (pseudo-inverse of H)
    g = matrix_pinv(h)

    # Compute x_hat
    y = tf.expand_dims(y, -1)
    x_hat = tf.squeeze(tf.matmul(g, y), axis=-1)

    # Compute residual error variance.
    # Only the diagonal of GSG' is needed:
    #   (GSG')_kk = sum_m (GS)_km * conj(G_km),
    # which avoids forming the full KxK product.
    gs = tf.matmul(g, s)
    no_eff = tf.reduce_sum(tf.math.real(gs * tf.math.conj(g)), axis=-1)

    return x_hat, no_eff
def mf_equalizer(y, h, s, precision=None):
    # pylint: disable=line-too-long
    r"""MIMO Matched Filter (MF) Equalizer

    This function implements matched filter (MF) equalization for a
    MIMO link, assuming the following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`,
    :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`.

    The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as
    (Eq. 4.11) [BHS2017]_ :

    .. math::

        \hat{\mathbf{x}} = \mathbf{G}\mathbf{y}

    where

    .. math::

        \mathbf{G} = \mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}.

    This leads to the post-equalized per-symbol model:

    .. math::

        \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1

    where the variances :math:`\sigma^2_k` of the effective residual noise
    terms :math:`e_k` are given by the diagonal elements of the matrix

    .. math::

        \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]
        = \left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)\left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)^{\mathsf{H}} + \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}.

    Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}`
    in the definition of :math:`\mathbf{G}`
    is important for the :class:`~sionna.phy.mapping.Demapper` although it does
    not change the signal-to-noise ratio.

    The function returns :math:`\hat{\mathbf{x}}` and
    :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`.

    Input
    -----
    y : [...,M], `tf.complex`
        Received signals

    h : [...,M,K], `tf.complex`
        Channel matrices

    s : [...,M,M], `tf.complex`
        Noise covariance matrices

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Output
    ------
    x_hat : [...,K], `tf.complex`
        Estimated symbol vectors

    no_eff : [...,K], `tf.float`
        Effective noise variance estimates
    """
    # Cast inputs
    if precision is None:
        cdtype = config.tf_cdtype
    else:
        cdtype = dtypes[precision]["tf"]["cdtype"]
    y = tf.cast(y, dtype=cdtype)
    h = tf.cast(h, dtype=cdtype)
    s = tf.cast(s, dtype=cdtype)

    # We assume the model:
    # y = Hx + n, where E[nn']=S.
    # E[x]=E[n]=0
    #
    # The MF estimate of x is given as:
    # x_hat = Gy
    # with G=diag(H'H)^-1 H'.
    #
    # This leads us to the per-symbol model:
    #
    # x_hat_k = x_k + e_k
    #
    # The elements of the residual noise vector e have variance:
    # E[ee'] = (I-GH)(I-GH)' + GSG'

    # Compute G.
    # diag(H'H)_k is the energy of the k-th column of H. It is applied as a
    # per-row scaling of H' via broadcasting, so neither the full KxK Gram
    # matrix nor a diagonal-matrix product needs to be formed.
    col_energy = tf.reduce_sum(tf.square(tf.abs(h)), axis=-2)  # [...,K]
    g = tf.linalg.adjoint(h) / tf.expand_dims(tf.cast(col_energy, h.dtype), -1)

    # Compute x_hat
    y = tf.expand_dims(y, -1)
    x_hat = tf.squeeze(tf.matmul(g, y), axis=-1)

    # Compute residual error variance (diagonals only):
    #   ((I-GH)(I-GH)')_kk = sum_j |(I-GH)_kj|^2
    #   (GSG')_kk          = sum_m (GS)_km * conj(G_km)
    gh = tf.matmul(g, h)
    # Identity sized from the dynamic shape so unknown K is supported
    i = expand_to_rank(tf.eye(tf.shape(gh)[-1], dtype=gh.dtype),
                       tf.rank(gh), 0)
    interference = tf.reduce_sum(tf.square(tf.abs(i - gh)), axis=-1)
    gs = tf.matmul(g, s)
    noise = tf.reduce_sum(tf.math.real(gs * tf.math.conj(g)), axis=-1)

    # abs() guards against tiny negative values caused by rounding
    no_eff = tf.abs(interference + noise)

    return x_hat, no_eff