Source code for sionna.phy.mimo.equalization

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Classes and functions related to MIMO channel equalization."""

from typing import Optional, Tuple
import torch

from sionna.phy.config import config, dtypes, Precision
from sionna.phy.utils import expand_to_rank, matrix_pinv
from sionna.phy.mimo.utils import whiten_channel

__all__ = [
    "lmmse_matrix",
    "lmmse_equalizer",
    "zf_equalizer",
    "mf_equalizer",
]


[docs] def lmmse_matrix( h: torch.Tensor, s: Optional[torch.Tensor] = None, precision: Optional[Precision] = None, ) -> torch.Tensor: r"""MIMO LMMSE Equalization matrix. This function computes the LMMSE equalization matrix for a MIMO link, assuming the following model: .. math:: \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols, :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix, and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector. It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`, :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`. This function returns the LMMSE equalization matrix: .. math:: \mathbf{G} = \mathbf{H}^{\mathsf{H}} \left(\mathbf{H}\mathbf{H}^{\mathsf{H}} + \mathbf{S}\right)^{-1}. If :math:`\mathbf{S}=\mathbf{I}_M`, a numerically more stable version of the equalization matrix is computed: .. math:: \mathbf{G} = \left(\mathbf{H}^{\mathsf{H}}\mathbf{H} + \mathbf{I}\right)^{-1}\mathbf{H}^{\mathsf{H}} . :param h: Channel matrices with shape [..., M, K] :param s: Noise covariance matrices with shape [..., M, M]. If `None`, the noise is assumed to be white with unit variance. :param precision: Precision used for internal calculations and outputs. If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used. :output g: [..., K, M], `torch.complex`. LMMSE equalization matrices. .. rubric:: Examples .. code-block:: python h = torch.complex(torch.randn(4, 2), torch.randn(4, 2)) g = lmmse_matrix(h) # g.shape = torch.Size([2, 4]) """ # Determine dtype if precision is None: cdtype = config.cdtype else: cdtype = dtypes[precision]["torch"]["cdtype"] h = h.to(dtype=cdtype) s_none = s is None if s is not None: s = s.to(dtype=cdtype) else: # Identity matrix with proper batch dimensions k = h.shape[-1] s = torch.eye(k, dtype=cdtype, device=h.device) s = expand_to_rank(s, h.dim(), 0) if not s_none: # Compute g = h^H @ (h @ h^H + s)^-1 # hhs = h @ h^H + s # Note that hhs^H = hhs, hence it admits a Cholesky decomposition hhs = h @ h.mH + s # Solve hhs @ g_t = h for g_t using Cholesky # Use cholesky_ex with check_errors=False for CUDA graph compatibility chol, _ = torch.linalg.cholesky_ex(hhs, check_errors=False) # cholesky_solve: solve L L^H x = b # First solve L y = h, then L^H g_t = y y = torch.linalg.solve_triangular(chol, h, upper=False) g_t = torch.linalg.solve_triangular(chol.mH, y, upper=True) # Compute g = g_t^H = (hhs^-1 @ h)^H = h^H @ hhs^-1 g = g_t.mH else: # Compute g = (h^H @ h + I)^-1 @ h^H hhs = h.mH @ h + s # Use cholesky_ex with check_errors=False for CUDA graph compatibility chol, _ = torch.linalg.cholesky_ex(hhs, check_errors=False) h_h = h.mH y = torch.linalg.solve_triangular(chol, h_h, upper=False) g = torch.linalg.solve_triangular(chol.mH, y, upper=True) return g
[docs] def lmmse_equalizer( y: torch.Tensor, h: torch.Tensor, s: torch.Tensor, whiten_interference: bool = True, precision: Optional[Precision] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""MIMO LMMSE Equalizer. This function implements LMMSE equalization for a MIMO link, assuming the following model: .. math:: \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols, :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix, and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector. It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`, :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`. The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as (Lemma B.19) :cite:p:`BHS2017` : .. math:: \hat{\mathbf{x}} = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}\mathbf{G}\mathbf{y} where .. math:: \mathbf{G} = \mathbf{H}^{\mathsf{H}} \left(\mathbf{H}\mathbf{H}^{\mathsf{H}} + \mathbf{S}\right)^{-1}. This leads to the post-equalized per-symbol model: .. math:: \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1 where the variances :math:`\sigma^2_k` of the effective residual noise terms :math:`e_k` are given by the diagonal elements of .. math:: \mathop{\text{diag}}\left(\mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right]\right) = \mathop{\text{diag}}\left(\mathbf{G}\mathbf{H} \right)^{-1} - \mathbf{I}. Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{G}\mathbf{H}\right)^{-1}` is important for the :class:`~sionna.phy.mapping.Demapper` although it does not change the signal-to-noise ratio. The function returns :math:`\hat{\mathbf{x}}` and :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`. :param y: Received signals with shape [..., M] :param h: Channel matrices with shape [..., M, K] :param s: Noise covariance matrices with shape [..., M, M] :param whiten_interference: If `True`, the interference is first whitened before equalization. In this case, an alternative expression for the receive filter is used that can be numerically more stable. :param precision: Precision used for internal calculations and outputs. If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used. :output x_hat: [..., K], `torch.complex`. Estimated symbol vectors. :output no_eff: [..., K], `torch.float`. Effective noise variance for each stream. .. rubric:: Examples .. code-block:: python y = torch.complex(torch.randn(4), torch.randn(4)) h = torch.complex(torch.randn(4, 2), torch.randn(4, 2)) s = torch.eye(4, dtype=torch.complex64) x_hat, no_eff = lmmse_equalizer(y, h, s) """ # Determine dtype if precision is None: cdtype = config.cdtype else: cdtype = dtypes[precision]["torch"]["cdtype"] y = y.to(dtype=cdtype) h = h.to(dtype=cdtype) s = s.to(dtype=cdtype) # LMMSE estimate of x: # x_hat = diag(GH)^(-1) @ G @ y # with G = H^H @ (H @ H^H + S)^(-1) # # This leads to the per-symbol model: # x_hat_k = x_k + e_k # # The elements of the residual noise vector e have variance: # diag(E[ee^H]) = diag(GH)^(-1) - I if not whiten_interference: # Compute equalizer matrix G g = lmmse_matrix(h, s, precision=precision) else: # Whiten channel y, h = whiten_channel(y, h, s, return_s=False) # Compute equalizer matrix G g = lmmse_matrix(h, s=None, precision=precision) # Compute G @ y gy = (g @ y.unsqueeze(-1)).squeeze(-1) # Compute G @ H gh = g @ h # Compute diag(G @ H) d = torch.diagonal(gh, dim1=-2, dim2=-1) # Compute x_hat = diag(G @ H)^-1 @ G @ y x_hat = gy / d # Compute residual error variance no_eff = (1.0 / d - 1.0).real return x_hat, no_eff
[docs] def zf_equalizer( y: torch.Tensor, h: torch.Tensor, s: torch.Tensor, precision: Optional[Precision] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""MIMO ZF Equalizer. This function implements zero-forcing (ZF) equalization for a MIMO link, assuming the following model: .. math:: \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols, :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix, and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector. It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`, :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`. The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as (Eq. 4.10) :cite:p:`BHS2017` : .. math:: \hat{\mathbf{x}} = \mathbf{G}\mathbf{y} where .. math:: \mathbf{G} = \left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}. This leads to the post-equalized per-symbol model: .. math:: \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1 where the variances :math:`\sigma^2_k` of the effective residual noise terms :math:`e_k` are given by the diagonal elements of the matrix .. math:: \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right] = \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}. The function returns :math:`\hat{\mathbf{x}}` and :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`. :param y: Received signals with shape [..., M] :param h: Channel matrices with shape [..., M, K] :param s: Noise covariance matrices with shape [..., M, M] :param precision: Precision used for internal calculations and outputs. If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used. :output x_hat: [..., K], `torch.complex`. Estimated symbol vectors. :output no_eff: [..., K], `torch.float`. Effective noise variance for each stream. .. rubric:: Examples .. code-block:: python y = torch.complex(torch.randn(8), torch.randn(8)) h = torch.complex(torch.randn(8, 4), torch.randn(8, 4)) s = torch.eye(8, dtype=torch.complex64) x_hat, no_eff = zf_equalizer(y, h, s) """ # Determine dtype if precision is None: cdtype = config.cdtype else: cdtype = dtypes[precision]["torch"]["cdtype"] y = y.to(dtype=cdtype) h = h.to(dtype=cdtype) s = s.to(dtype=cdtype) # ZF estimate of x: # x_hat = G @ y # with G = (H^H @ H)^(-1) @ H^H (Moore-Penrose pseudoinverse) # # This leads to the per-symbol model: # x_hat_k = x_k + e_k # # The elements of the residual noise vector e have variance: # E[e @ e^H] = G @ S @ G^H # Compute G (pseudoinverse) g = matrix_pinv(h) # Compute x_hat x_hat = (g @ y.unsqueeze(-1)).squeeze(-1) # Compute residual error variance gsg = g @ s @ g.mH no_eff = torch.diagonal(gsg, dim1=-2, dim2=-1).real return x_hat, no_eff
[docs] def mf_equalizer( y: torch.Tensor, h: torch.Tensor, s: torch.Tensor, precision: Optional[Precision] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""MIMO Matched Filter (MF) Equalizer. This function implements matched filter (MF) equalization for a MIMO link, assuming the following model: .. math:: \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols, :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix, and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector. It is assumed that :math:`\mathbb{E}\left[\mathbf{x}\right]=\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}`, :math:`\mathbb{E}\left[\mathbf{x}\mathbf{x}^{\mathsf{H}}\right]=\mathbf{I}_K` and :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`. The estimated symbol vector :math:`\hat{\mathbf{x}}\in\mathbb{C}^K` is given as (Eq. 4.11) :cite:p:`BHS2017` : .. math:: \hat{\mathbf{x}} = \mathbf{G}\mathbf{y} where .. math:: \mathbf{G} = \mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}\mathbf{H}^{\mathsf{H}}. This leads to the post-equalized per-symbol model: .. math:: \hat{x}_k = x_k + e_k,\quad k=0,\dots,K-1 where the variances :math:`\sigma^2_k` of the effective residual noise terms :math:`e_k` are given by the diagonal elements of the matrix .. math:: \mathbb{E}\left[\mathbf{e}\mathbf{e}^{\mathsf{H}}\right] = \left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)\left(\mathbf{I}-\mathbf{G}\mathbf{H} \right)^{\mathsf{H}} + \mathbf{G}\mathbf{S}\mathbf{G}^{\mathsf{H}}. Note that the scaling by :math:`\mathop{\text{diag}}\left(\mathbf{H}^{\mathsf{H}}\mathbf{H}\right)^{-1}` in the definition of :math:`\mathbf{G}` is important for the :class:`~sionna.phy.mapping.Demapper` although it does not change the signal-to-noise ratio. The function returns :math:`\hat{\mathbf{x}}` and :math:`\boldsymbol{\sigma}^2=\left[\sigma^2_0,\dots, \sigma^2_{K-1}\right]^{\mathsf{T}}`. :param y: Received signals with shape [..., M] :param h: Channel matrices with shape [..., M, K] :param s: Noise covariance matrices with shape [..., M, M] :param precision: Precision used for internal calculations and outputs. If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used. :output x_hat: [..., K], `torch.complex`. Estimated symbol vectors. :output no_eff: [..., K], `torch.float`. Effective noise variance for each stream. .. rubric:: Examples .. code-block:: python y = torch.complex(torch.randn(8), torch.randn(8)) h = torch.complex(torch.randn(8, 4), torch.randn(8, 4)) s = torch.eye(8, dtype=torch.complex64) x_hat, no_eff = mf_equalizer(y, h, s) """ # Determine dtype if precision is None: cdtype = config.cdtype else: cdtype = dtypes[precision]["torch"]["cdtype"] y = y.to(dtype=cdtype) h = h.to(dtype=cdtype) s = s.to(dtype=cdtype) # MF estimate of x: # x_hat = G @ y # with G = diag(H^H @ H)^-1 @ H^H # # This leads to the per-symbol model: # x_hat_k = x_k + e_k # # The elements of the residual noise vector e have variance: # E[e @ e^H] = (I - G @ H) @ (I - G @ H)^H + G @ S @ G^H # Compute G hth = h.mH @ h d = torch.diag_embed(1.0 / torch.diagonal(hth, dim1=-2, dim2=-1)) g = d @ h.mH # Compute x_hat x_hat = (g @ y.unsqueeze(-1)).squeeze(-1) # Compute residual error variance gsg = g @ s @ g.mH gh = g @ h k = gh.shape[-1] i = torch.eye(k, dtype=gsg.dtype, device=gsg.device) i = expand_to_rank(i, gsg.dim(), 0) i_minus_gh = i - gh no_eff = torch.abs(torch.diagonal(i_minus_gh @ i_minus_gh.mH + gsg, dim1=-2, dim2=-1)) return x_hat, no_eff