Source code for sionna.phy.mimo.precoding

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Classes and functions related to MIMO transmit precoding."""

import math
from typing import Optional, Tuple, Union, List
import torch

from sionna.phy.config import config, dtypes, Precision
from sionna.phy.constants import PI
from sionna.phy.utils import expand_to_rank

# Public API of this module: the names exported by ``from <module> import *``
__all__ = [
    "rzf_precoding_matrix",
    "cbf_precoding_matrix",
    "rzf_precoder",
    "grid_of_beams_dft_ula",
    "grid_of_beams_dft",
    "flatten_precoding_mat",
    "normalize_precoding_power",
]


def rzf_precoding_matrix(
    h: torch.Tensor,
    alpha: Union[float, torch.Tensor] = 0.0,
    precision: Optional[Precision] = None,
) -> torch.Tensor:
    r"""Computes the Regularized Zero-Forcing (RZF) Precoder.

    This function computes the RZF precoding matrix for a MIMO link, assuming
    the following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{G}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^K` is the received signal vector,
    :math:`\mathbf{H}\in\mathbb{C}^{K\times M}` is the known channel matrix,
    :math:`\mathbf{G}\in\mathbb{C}^{M\times K}` is the precoding matrix,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the symbol vector to be precoded,
    and :math:`\mathbf{n}\in\mathbb{C}^K` is a noise vector.

    The precoding matrix :math:`\mathbf{G}` is defined as:

    .. math::

        \mathbf{G} = \mathbf{V}\mathbf{D}

    where

    .. math::

        \mathbf{V} &= \mathbf{H}^{\mathsf{H}}\left(\mathbf{H} \mathbf{H}^{\mathsf{H}} + \alpha \mathbf{I} \right)^{-1}\\
        \mathbf{D} &= \mathop{\text{diag}}\left( \lVert \mathbf{v}_{k} \rVert_2^{-1}, k=0,\dots,K-1 \right)

    where :math:`\alpha\ge 0` is the regularization parameter. For
    :math:`\alpha=0` (the default) this reduces to zero-forcing, which requires
    :math:`\mathbf{H}\mathbf{H}^{\mathsf{H}}` to be positive definite. The
    Cholesky factorization is run without error checking, so if this
    condition is violated the output is not meaningful.

    The matrix :math:`\mathbf{D}` ensures that each stream is precoded with a
    unit-norm vector, i.e.,
    :math:`\mathop{\text{tr}}\left(\mathbf{G}\mathbf{G}^{\mathsf{H}}\right)=K`.
    Columns of :math:`\mathbf{V}` with zero norm are returned as zeros, with
    finite (zero) gradients.

    The function returns the matrix :math:`\mathbf{G}`.

    :param h: Channel matrices with shape [..., K, M]
    :param alpha: Regularization parameter with shape [...] or scalar
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.

    :output g: [..., M, K], `torch.complex`. Precoding matrices.

    .. rubric:: Examples

    .. code-block:: python

        h = torch.complex(torch.randn(4, 8), torch.randn(4, 8))
        g = rzf_precoding_matrix(h, alpha=0.1)
        # g.shape = torch.Size([8, 4])
    """
    # Determine dtype
    if precision is None:
        cdtype = config.cdtype
    else:
        cdtype = dtypes[precision]["torch"]["cdtype"]
    h = h.to(dtype=cdtype)
    alpha = torch.as_tensor(alpha, dtype=cdtype, device=h.device)

    # Regularized Gram matrix H H^H + alpha I, shape [..., K, K]
    gram = h @ h.mH
    alpha = expand_to_rank(alpha, gram.dim(), axis=-1)
    k = gram.shape[-1]
    eye = torch.eye(k, dtype=cdtype, device=gram.device)
    eye = expand_to_rank(eye, gram.dim(), 0)
    gram = gram + alpha * eye

    # Solve L L^H X = H for X, with L the Cholesky factor of the Gram matrix.
    # cholesky_ex with check_errors=False is used for CUDA graph compatibility.
    l, _ = torch.linalg.cholesky_ex(gram, check_errors=False)
    y = torch.linalg.solve_triangular(l, h, upper=False)
    x = torch.linalg.solve_triangular(l.mH, y, upper=True)
    # X^H = H^H (H H^H + alpha I)^{-1}, shape [..., M, K]
    g = x.mH

    # Normalize each column to unit norm. Dividing by a "safe" norm (1 where the
    # norm is zero) leaves zero columns at zero in the forward pass, exactly as
    # before, but avoids the NaN gradients that torch.where(norm > 0, g / norm, g)
    # produces: the masked branch still evaluates 0/0 during backward.
    norm = torch.linalg.vector_norm(g, dim=-2, keepdim=True)
    norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    return g / norm
def cbf_precoding_matrix(
    h: torch.Tensor,
    precision: Optional[Precision] = None,
) -> torch.Tensor:
    r"""Computes the conjugate beamforming (CBF) Precoder.

    This function computes the CBF precoding matrix for a MIMO link, assuming
    the following model:

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{G}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^K` is the received signal vector,
    :math:`\mathbf{H}\in\mathbb{C}^{K\times M}` is the known channel matrix,
    :math:`\mathbf{G}\in\mathbb{C}^{M\times K}` is the precoding matrix,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the symbol vector to be precoded,
    and :math:`\mathbf{n}\in\mathbb{C}^K` is a noise vector.

    The precoding matrix :math:`\mathbf{G}` is defined as:

    .. math::

        \mathbf{G} = \mathbf{V}\mathbf{D}

    where

    .. math::

        \mathbf{V} &= \mathbf{H}^{\mathsf{H}} \\
        \mathbf{D} &= \mathop{\text{diag}}\left( \lVert \mathbf{v}_{k} \rVert_2^{-1}, k=0,\dots,K-1 \right).

    The matrix :math:`\mathbf{D}` ensures that each stream is precoded with a
    unit-norm vector, i.e.,
    :math:`\mathop{\text{tr}}\left(\mathbf{G}\mathbf{G}^{\mathsf{H}}\right)=K`.
    Columns belonging to an all-zero channel row are returned as zeros, with
    finite (zero) gradients.

    The function returns the matrix :math:`\mathbf{G}`.

    :param h: Channel matrices with shape [..., K, M]
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.

    :output g: [..., M, K], `torch.complex`. Precoding matrices.

    .. rubric:: Examples

    .. code-block:: python

        h = torch.complex(torch.randn(4, 8), torch.randn(4, 8))
        g = cbf_precoding_matrix(h)
        # g.shape = torch.Size([8, 4])
    """
    # Determine dtype
    if precision is None:
        cdtype = config.cdtype
    else:
        cdtype = dtypes[precision]["torch"]["cdtype"]
    h = h.to(dtype=cdtype)

    # Conjugate transpose of the channel matrix, shape [..., M, K]
    g = h.mH

    # Normalize each column to unit norm. The "safe" norm (1 where the norm is
    # zero) keeps zero columns at zero without the NaN gradients produced by
    # torch.where(norm > 0, g / norm, g).
    norm = torch.linalg.vector_norm(g, dim=-2, keepdim=True)
    norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    return g / norm
def rzf_precoder(
    x: torch.Tensor,
    h: torch.Tensor,
    alpha: Union[float, torch.Tensor] = 0.0,
    return_precoding_matrix: bool = False,
    precision: Optional[Precision] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Regularized Zero-Forcing (RZF) Precoder.

    Applies RZF precoding to symbol vectors for a MIMO link described by

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{G}\mathbf{x} + \mathbf{n}

    with received vector :math:`\mathbf{y}\in\mathbb{C}^K`, known channel
    :math:`\mathbf{H}\in\mathbb{C}^{K\times M}`, precoding matrix
    :math:`\mathbf{G}\in\mathbb{C}^{M\times K}`, symbol vector
    :math:`\mathbf{x}\in\mathbb{C}^K`, and noise
    :math:`\mathbf{n}\in\mathbb{C}^K`.

    Following (Eq. 4.37) :cite:p:`BHS2017`, the precoding matrix is

    .. math::

        \mathbf{G} = \mathbf{V}\mathbf{D}

    with

    .. math::

        \mathbf{V} &= \mathbf{H}^{\mathsf{H}}\left(\mathbf{H} \mathbf{H}^{\mathsf{H}} + \alpha \mathbf{I} \right)^{-1}\\
        \mathbf{D} &= \mathop{\text{diag}}\left( \lVert \mathbf{v}_{k} \rVert_2^{-1}, k=0,\dots,K-1 \right)

    where :math:`\alpha` is the regularization parameter. Each stream is
    precoded with a unit-norm vector, so that
    :math:`\mathop{\text{tr}}\left(\mathbf{G}\mathbf{G}^{\mathsf{H}}\right)=K`.
    The matrix itself is obtained from :func:`rzf_precoding_matrix`.

    The function returns the precoded vector :math:`\mathbf{G}\mathbf{x}`.

    :param x: Symbol vectors to be precoded with shape [..., K]
    :param h: Channel matrices with shape [..., K, M]
    :param alpha: Regularization parameter with shape [...] or scalar
    :param return_precoding_matrix: If `True`, the precoding matrices are
        returned as well
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.

    :output x_precoded: [..., M], `torch.complex`. Precoded symbol vectors.
    :output g: [..., M, K], `torch.complex`. Precoding matrices. Only
        returned if ``return_precoding_matrix=True``.

    .. rubric:: Examples

    .. code-block:: python

        x = torch.complex(torch.randn(4), torch.randn(4))
        h = torch.complex(torch.randn(4, 8), torch.randn(4, 8))
        x_precoded = rzf_precoder(x, h, alpha=0.1)
        # x_precoded.shape = torch.Size([8])
    """
    cdtype = (
        config.cdtype
        if precision is None
        else dtypes[precision]["torch"]["cdtype"]
    )
    symbols = x.to(dtype=cdtype)
    channel = h.to(dtype=cdtype)

    # Precoding matrices, shape [..., M, K]
    precoder = rzf_precoding_matrix(channel, alpha=alpha, precision=precision)

    # [..., M, K] @ [..., K, 1] -> [..., M, 1] -> [..., M]
    precoded = torch.matmul(precoder, symbols[..., None])[..., 0]

    return (precoded, precoder) if return_precoding_matrix else precoded
def grid_of_beams_dft_ula(
    num_ant: int,
    oversmpl: int = 1,
    precision: Optional[Precision] = None,
) -> torch.Tensor:
    r"""Computes the Discrete Fourier Transform (DFT) Grid of Beam (GoB)
    coefficients for a uniform linear array (ULA).

    The coefficient applied to antenna :math:`n` for beam :math:`m` is
    expressed as:

    .. math::

        c_n^m = \frac{1}{\sqrt{N}} e^{j\frac{2\pi n m}{N O}}, \quad n=0,\dots,N-1, \ m=0,\dots,NO-1

    where :math:`N` is the number of antennas ``num_ant`` and :math:`O` is the
    oversampling factor ``oversmpl``. The factor :math:`1/\sqrt{N}` gives every
    beam unit norm.

    For half-wavelength antenna spacing, the main lobe of beam :math:`m`
    points in the azimuth direction
    :math:`\theta = \mathrm{arc sin} \left( 2\frac{m}{NO} \right)` if
    :math:`m\le NO/2` and
    :math:`\theta = \mathrm{arc sin} \left( 2\frac{m-NO}{NO} \right)` if
    :math:`m\ge NO/2`, where :math:`\theta=0` defines the perpendicular to the
    antenna array.

    :param num_ant: Number of antennas (positive integer)
    :param oversmpl: Oversampling factor (positive integer)
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.

    :output gob: [num_ant x oversmpl, num_ant], `torch.complex`. The
        :math:`m`-th row contains the `num_ant` antenna coefficients for the
        :math:`m`-th DFT beam.

    :raises ValueError: If ``num_ant`` or ``oversmpl`` is smaller than 1.

    .. rubric:: Examples

    .. code-block:: python

        gob = grid_of_beams_dft_ula(num_ant=8, oversmpl=2)
        # gob.shape = torch.Size([16, 8])
    """
    if precision is None:
        rdtype = config.dtype
    else:
        rdtype = dtypes[precision]["torch"]["dtype"]
    oversmpl = int(oversmpl)
    if num_ant < 1:
        raise ValueError("num_ant must be a positive integer")
    if oversmpl < 1:
        raise ValueError("oversmpl must be a positive integer")
    num_beams = num_ant * oversmpl

    # Beam indices [num_beams, 1] and antenna indices [1, num_ant]
    beam_ind = torch.arange(num_beams, device=config.device).unsqueeze(-1)
    antenna_ind = torch.arange(num_ant, device=config.device).unsqueeze(0)

    # Reduce m*n modulo N*O in exact integer arithmetic before converting to
    # floating point. cos/sin are then evaluated on phases in [0, 2*pi) instead
    # of up to ~2*pi*N, which keeps low-precision (e.g. float32) results accurate.
    phase_ind = (beam_ind * antenna_ind) % num_beams
    phases = 2 * PI * phase_ind.to(rdtype) / num_beams

    # Unit-modulus coefficients scaled by 1/sqrt(N) for unit-norm beams
    gob = torch.complex(torch.cos(phases), torch.sin(phases)) / math.sqrt(num_ant)
    return gob
def grid_of_beams_dft(
    num_ant_v: int,
    num_ant_h: int,
    oversmpl_v: int = 1,
    oversmpl_h: int = 1,
    precision: Optional[Precision] = None,
) -> torch.Tensor:
    r"""Computes the Discrete Fourier Transform (DFT) Grid of Beam (GoB)
    coefficients for a uniform rectangular array (URA).

    GoB indices are arranged over a 2D grid indexed by :math:`(m_v,m_h)`.
    The coefficient of the beam with index :math:`(m_v,m_h)` applied to the
    antenna located at row :math:`n_v` and column :math:`n_h` of the
    rectangular array is expressed as:

    .. math::

        c_{n_v,n_h}^{m_v,m_h} = \frac{1}{\sqrt{N_v N_h}} e^{j\frac{2\pi n_v m_v}{N_v O_v}} e^{j\frac{2\pi n_h m_h}{N_h O_h}}

    where :math:`n_v=0,\dots,N_v-1`, :math:`n_h=0,\dots,N_h-1`,
    :math:`m_v=0,\dots,N_v O_v-1`, :math:`m_h=0,\dots,N_h O_h-1`,
    :math:`N_v,N_h` are the numbers of antenna rows and columns
    ``num_ant_v``, ``num_ant_h``, and :math:`O_v,O_h` are the oversampling
    factors ``oversmpl_v``, ``oversmpl_h`` in the vertical and horizontal
    direction, respectively.

    The 2D coefficients of each beam are flattened column by column (see
    :func:`~sionna.phy.mimo.flatten_precoding_mat`). The antenna at
    :math:`(n_v,n_h)` is therefore mapped to index :math:`n_h N_v + n_v`, and
    the resulting vector can be written concisely as

    .. math::

        c^{m_v,m_h} = c^{m_h} \otimes c^{m_v}

    where :math:`\otimes` denotes the Kronecker product and
    :math:`c^{m_v},c^{m_h}` are the ULA DFT beams computed as in
    :func:`~sionna.phy.mimo.grid_of_beams_dft_ula`.
    Such a DFT GoB is, e.g., defined in Section 5.2.2.2.1 :cite:p:`3GPPTS38214`.

    :param num_ant_v: Number of antenna rows (i.e., in vertical direction)
    :param num_ant_h: Number of antenna columns (i.e., in horizontal direction)
    :param oversmpl_v: Oversampling factor in vertical direction
    :param oversmpl_h: Oversampling factor in horizontal direction
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.

    :output gob: [num_ant_v x oversmpl_v, num_ant_h x oversmpl_h, num_ant_v x num_ant_h],
        `torch.complex`. The elements :math:`[m_v,m_h,:]` contain the antenna
        coefficients of the DFT beam with index pair :math:`(m_v,m_h)`.

    .. rubric:: Examples

    .. code-block:: python

        gob = grid_of_beams_dft(num_ant_v=4, num_ant_h=8)
        # gob.shape = torch.Size([4, 8, 32])
    """
    # Vertical ULA beams: [num_ant_v * oversmpl_v, 1, num_ant_v, 1]
    gob_v = grid_of_beams_dft_ula(num_ant_v, oversmpl=oversmpl_v, precision=precision)
    gob_v = gob_v[:, None, :, None]
    # Horizontal ULA beams: [1, num_ant_h * oversmpl_h, 1, num_ant_h]
    gob_h = grid_of_beams_dft_ula(num_ant_h, oversmpl=oversmpl_h, precision=precision)
    gob_h = gob_h[None, :, None, :]

    # Outer product via broadcasting:
    # [num_ant_v * oversmpl_v, num_ant_h * oversmpl_h, num_ant_v, num_ant_h]
    coef_vh = gob_h * gob_v

    # Flatten (column-wise) the last two dimensions into one antenna axis
    coef_vh = flatten_precoding_mat(coef_vh)
    return coef_vh
def flatten_precoding_mat(
    precoding_mat: torch.Tensor,
    by_column: bool = True,
) -> torch.Tensor:
    r"""Flattens a [..., num_ant_v, num_ant_h] precoding matrix associated with
    a rectangular array by producing a [..., num_ant_v x num_ant_h] precoding
    vector.

    :param precoding_mat: Precoding matrix with shape
        [..., num_antennas_vertical, num_antennas_horizontal] (at least two
        dimensions). The element :math:`(i,j)` contains the precoding
        coefficient of the antenna element located at row :math:`i` and
        column :math:`j` of a rectangular antenna array.
    :param by_column: If `True`, flattening occurs on a per-column basis, i.e.,
        the columns are concatenated in order (the second column is appended
        to the first, and so on), so that element :math:`(i,j)` lands at index
        :math:`j \cdot \text{num\_antennas\_vertical} + i`. Else, the rows are
        concatenated in order, so that element :math:`(i,j)` lands at index
        :math:`i \cdot \text{num\_antennas\_horizontal} + j`.

    :output precoding_vec: [..., num_antennas_vertical x num_antennas_horizontal],
        same dtype as ``precoding_mat``. Flattened precoding matrix.

    .. rubric:: Examples

    .. code-block:: python

        mat = torch.randn(4, 8, dtype=torch.complex64)
        vec = flatten_precoding_mat(mat)
        # vec.shape = torch.Size([32])
    """
    # A row-major reshape of the transposed matrix walks the original matrix
    # column by column.
    if by_column:
        precoding_mat = precoding_mat.mT

    # Merge the last two dimensions, keeping all leading (batch) dimensions
    shape = list(precoding_mat.shape[:-2]) + [-1]
    return precoding_mat.reshape(shape)
def normalize_precoding_power(
    precoding_vec: torch.Tensor,
    tx_power_list: Optional[Union[List[float], torch.Tensor]] = None,
    precision: Optional[Precision] = None,
) -> torch.Tensor:
    r"""Normalizes each precoding vector to unit norm by default, or scales it
    according to ``tx_power_list`` if provided.

    Each row :math:`\mathbf{v}_i` of ``precoding_vec`` is mapped to
    :math:`p_i\,\mathbf{v}_i / \lVert \mathbf{v}_i \rVert_2`, where
    :math:`p_i` is the :math:`i`-th entry of ``tx_power_list``
    (:math:`p_i=1` if ``tx_power_list`` is `None`). The resulting row has
    Euclidean norm :math:`|p_i|`, i.e., squared norm :math:`p_i^2`.
    Rows that are entirely zero are returned as zeros instead of NaN.

    :param precoding_vec: Precoding vectors with shape [N, M]. A 1-D input of
        shape [M] is treated as [1, M]. Each row contains a set of antenna
        coefficients whose power is to be normalized.
    :param tx_power_list: Scaling factor :math:`p_i` of the :math:`i`-th
        precoding vector, given as a list or tensor of length N (a single
        element is broadcast to all rows). If `None`, every row is normalized
        to unit norm.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.

    :output precoding_vec: [N, M], `torch.complex`. Normalized antenna
        coefficients. A 1-D input yields a [1, M] output.

    .. rubric:: Examples

    .. code-block:: python

        vec = torch.complex(torch.randn(4, 8), torch.randn(4, 8))
        vec_norm = normalize_precoding_power(vec)
        # Each row now has unit power
    """
    if precision is None:
        cdtype = config.cdtype
        rdtype = config.dtype
    else:
        cdtype = dtypes[precision]["torch"]["cdtype"]
        rdtype = dtypes[precision]["torch"]["dtype"]
    precoding_vec = precoding_vec.to(dtype=cdtype)
    if precoding_vec.dim() == 1:
        precoding_vec = precoding_vec.unsqueeze(0)

    # Per-row scaling factors, shape [N, 1] (or [1, 1] to broadcast)
    if tx_power_list is None:
        tx_power = torch.ones(
            precoding_vec.shape[0], dtype=rdtype, device=precoding_vec.device
        )
    else:
        tx_power = torch.as_tensor(
            tx_power_list, dtype=rdtype, device=precoding_vec.device
        )
    tx_power = tx_power.unsqueeze(-1)

    # Row norms; all-zero rows are divided by 1 so that they stay zero rather
    # than becoming NaN (0/0). This matches the column normalization in
    # rzf_precoding_matrix / cbf_precoding_matrix.
    norm = torch.linalg.vector_norm(precoding_vec, dim=1, keepdim=True)
    norm = torch.where(norm > 0, norm, torch.ones_like(norm))

    # NOTE(review): rows are scaled by p_i (an amplitude), so each row's power is
    # p_i**2 even though the parameter is named "power". Kept as-is for backward
    # compatibility; confirm whether sqrt(p_i) scaling was intended.
    return (precoding_vec / norm) * tx_power