# Source code for sionna.sys.link_adaptation

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Link adaptation for Sionna SYS"""

from typing import Any, Optional, Tuple, Union

import torch

from sionna.phy import Block, config, dtypes
from sionna.phy.config import Precision
from sionna.phy.utils import (
    db_to_lin,
    expand_to_rank,
    find_true_position,
    gather_from_batched_indices,
    insert_dims,
    lin_to_db,
    scalar_to_shaped_tensor,
)
from sionna.sys.phy_abstraction import PHYAbstraction
from sionna.sys.utils import is_scheduled_in_slot

# Public API of this module: the names exported by ``from ... import *``
__all__ = ["InnerLoopLinkAdaptation", "OuterLoopLinkAdaptation"]


def _to_python_float(value: Union[float, int, torch.Tensor]) -> float:
    """Return ``value`` as a built-in Python ``float``.

    A tensor is first detached and copied to the host, and its scalar is read
    with ``.item()``. Every other input (Python ints/floats, NumPy scalars,
    ...) is passed straight to :class:`float`.

    Note: reading a tensor calls ``.item()``, which causes a graph break under
    ``torch.compile``. To avoid graph breaks, pass Python floats directly when
    setting properties during compiled execution.

    :param value: Value to convert (float, int, or scalar tensor)

    :output value: Python float value
    """
    if isinstance(value, torch.Tensor):
        # Host copy first, so the scalar can be extracted whatever the device
        value = value.detach().cpu().item()
    return float(value)


class InnerLoopLinkAdaptation(Block):
    r"""Inner loop link adaptation (ILLA).

    Computes the highest available modulation and coding scheme (MCS) whose
    associated transport block error rate (TBLER) does not exceed the
    specified ``bler_target``:

    .. math::

        \max \left\{ \text{MCS}: \ \text{TBLER}(\text{MCS}, \text{SINR}_{\text{eff}}) \le \text{BLER}_{\text{target}} \right\}

    where :math:`\text{SINR}_{\text{eff}}` is the effective SINR value
    provided as input. If no such MCS exists, the lowest available MCS index
    is returned. If a user is not scheduled, ``fill_mcs_value`` is returned.

    :param phy_abstraction: An instance of :class:`~sionna.sys.PHYAbstraction`.
        If `None`, a default instance is created.
    :param bler_target: BLER target, either a scalar or a per-user tensor of
        shape [..., num_ut]. Defaults to 0.1.
    :param fill_mcs_value: MCS value assigned to non-scheduled users.
        Defaults to 0.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input sinr: [..., num_ofdm_symbols, num_subcarriers, num_ut, num_streams_per_ut], `torch.float` | `None` (default).
        SINR for each OFDM symbol, subcarrier, user and stream.
        If `None`, then ``sinr_eff`` and ``num_allocated_re`` are both
        required.
    :input sinr_eff: [..., num_ut], `torch.float` | `None` (default).
        Estimated effective SINR for each user. If `None`, then ``sinr`` is
        required.
    :input num_allocated_re: [..., num_ut], `torch.int32` | `None` (default).
        Number of allocated resources in a slot, computed across OFDM
        symbols, subcarriers and streams, for each user. If `None`, then
        ``sinr`` is required.
    :input mcs_table_index: [..., num_ut], `torch.int32` | `int` (default: 1).
        MCS table index for each user. For further details, refer to the
        :ref:`mcs_table_cat_note`.
    :input mcs_category: [..., num_ut], `torch.int32` | `int` (default: 0).
        MCS table category for each user. For further details, refer to the
        :ref:`mcs_table_cat_note`.
    :input return_lowest_available_mcs: `bool` (default: `False`).
        If `True`, the lowest MCS available in ``phy_abstraction`` BLER
        tables is also returned for each user. Only used for internal
        purposes.

    :output mcs_index: [..., num_ut].
        Highest available MCS whose BLER does not exceed the target, or the
        lowest available MCS if no such MCS exists, for each user.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.sys import PHYAbstraction, InnerLoopLinkAdaptation

        bler_target = 0.1

        # Initialize the PHY abstraction object
        phy_abs = PHYAbstraction()

        # Initialize the ILLA object
        illa = InnerLoopLinkAdaptation(phy_abs, bler_target=bler_target)

        # Effective SINR for each user
        sinr_eff = torch.tensor([0.1, 10, 100])

        # N. allocated resource elements for each user
        num_allocated_re = torch.tensor([20, 30, 30])

        # Compute the MCS index for each user
        mcs_index = illa(sinr_eff=sinr_eff,
                         num_allocated_re=num_allocated_re,
                         mcs_table_index=1,
                         mcs_category=0)
        print("Selected MCS index =", mcs_index)
    """

    def __init__(
        self,
        phy_abstraction: Optional[PHYAbstraction] = None,
        bler_target: Union[float, torch.Tensor] = 0.1,
        fill_mcs_value: int = 0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
    ) -> None:
        super().__init__(precision=precision, device=device)
        if phy_abstraction is None:
            phy_abstraction = PHYAbstraction(precision=precision, device=device)
        self._phy_abstraction = phy_abstraction
        self._fill_mcs_value = torch.tensor(
            fill_mcs_value, dtype=torch.int32, device=self.device
        )
        # Go through the property setter, so that float and tensor targets are
        # converted exactly as in later assignments. This also avoids the
        # torch.tensor(<tensor>) copy-construct warning.
        self.bler_target = bler_target

    @property
    def phy_abstraction(self) -> PHYAbstraction:
        """PHYAbstraction object used to compute TBLER (read-only)."""
        return self._phy_abstraction

    @property
    def bler_target(self) -> torch.Tensor:
        """Get/set the BLER target (scalar, or per user with shape [..., num_ut])."""
        return self._bler_target

    @bler_target.setter
    def bler_target(self, value: Union[float, torch.Tensor]) -> None:
        if isinstance(value, torch.Tensor):
            self._bler_target = value.to(dtype=self.dtype, device=self.device)
        else:
            self._bler_target = torch.tensor(
                value, dtype=self.dtype, device=self.device
            )

    def call(
        self,
        sinr: Optional[torch.Tensor] = None,
        sinr_eff: Optional[torch.Tensor] = None,
        num_allocated_re: Optional[torch.Tensor] = None,
        mcs_table_index: Union[int, torch.Tensor] = 1,
        mcs_category: Union[int, torch.Tensor] = 0,
        return_lowest_available_mcs: bool = False,
        **kwargs: Any,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Select the highest MCS index meeting the BLER target for each user.

        Extra keyword arguments are forwarded to the PHY abstraction call.
        Returns ``(mcs_index, lowest_available_mcs)`` when
        ``return_lowest_available_mcs`` is `True`, otherwise ``mcs_index``.
        """
        # Validate inputs
        assert (sinr is not None) ^ (
            (sinr_eff is not None) and (num_allocated_re is not None)
        ), "Either 'sinr' or ('sinr_eff', 'num_allocated_re') is required as input"

        # Number of available MCS indices
        num_mcs = self._phy_abstraction.bler_table_interp.shape[2]

        # Check which UTs are scheduled
        ut_is_scheduled = is_scheduled_in_slot(
            sinr=sinr, num_allocated_re=num_allocated_re
        )

        # Determine batch dimensions and num_ut
        if sinr is not None:
            sinr = sinr.to(self.dtype)
            batch_dims = list(sinr.shape[:-4])
            num_ut = sinr.shape[-2]
        else:
            sinr_eff = sinr_eff.to(self.dtype)
            batch_dims = list(sinr_eff.shape[:-1])
            num_ut = sinr_eff.shape[-1]

        # ----------------------- #
        # Tile across MCS indices #
        # ----------------------- #
        # [..., num_mcs, num_ut]
        mcs_index_all = torch.arange(num_mcs, dtype=torch.int32, device=self.device)
        mcs_index_all = insert_dims(
            mcs_index_all, len(batch_dims), axis=0
        ).unsqueeze(-1)
        mcs_index_all = mcs_index_all.expand(*batch_dims, num_mcs, num_ut)

        # [..., num_mcs, num_ut]
        mcs_table_index_tiled = scalar_to_shaped_tensor(
            mcs_table_index, torch.int32, batch_dims + [num_ut], device=self.device
        )
        mcs_table_index_tiled = mcs_table_index_tiled.unsqueeze(-2).expand(
            *batch_dims, num_mcs, num_ut
        )

        # [..., num_mcs, num_ut]
        mcs_category_tiled = scalar_to_shaped_tensor(
            mcs_category, torch.int32, batch_dims + [num_ut], device=self.device
        )
        mcs_category_tiled = mcs_category_tiled.unsqueeze(-2).expand(
            *batch_dims, num_mcs, num_ut
        )

        if num_allocated_re is not None:
            # [..., num_mcs, num_ut]
            num_allocated_re = num_allocated_re.to(torch.int32)
            num_allocated_re_tiled = num_allocated_re.unsqueeze(-2).expand(
                *batch_dims, num_mcs, num_ut
            )
        else:
            num_allocated_re_tiled = None

        # -------------- #
        # Effective SINR #
        # -------------- #
        # Expand across all possible MCS indices
        if sinr is not None:
            # [..., num_mcs, num_ofdm_symbols, num_subcarriers, num_ut, num_streams_per_ut]
            sinr_tiled = sinr.unsqueeze(-5).expand(
                *batch_dims, num_mcs, *sinr.shape[-4:]
            )
            sinr_eff_tiled = None
        else:
            # [..., num_mcs, num_ut]
            sinr_tiled = None
            sinr_eff_tiled = sinr_eff.unsqueeze(-2).expand(
                *batch_dims, num_mcs, num_ut
            )

        # ----- #
        # TBLER #
        # ----- #
        # [..., num_mcs, num_ut]
        *_, tbler_per_mcs, _ = self._phy_abstraction(
            mcs_index_all,
            sinr=sinr_tiled,
            sinr_eff=sinr_eff_tiled,
            num_allocated_re=num_allocated_re_tiled,
            mcs_table_index=mcs_table_index_tiled,
            mcs_category=mcs_category_tiled,
            check_mcs_index_validity=False,
            **kwargs,
        )

        # ---------- #
        # Select MCS #
        # ---------- #
        # A per-user target [..., num_ut] must line up with the user axis of
        # tbler_per_mcs [..., num_mcs, num_ut]. Without an explicit MCS axis,
        # right-aligned broadcasting would pair batch dims with num_mcs.
        # A scalar target broadcasts as is.
        bler_target = self._bler_target
        if bler_target.dim() > 0:
            # [..., 1, num_ut]
            bler_target = bler_target.unsqueeze(-2)

        # Find the highest MCS with TBLER <= bler_target
        # If no such MCS is found, returns -1
        # Note: TBLER can be -inf for MCS indices without BLER data in the table,
        # which should still be considered valid (as -inf <= bler_target is True)
        # [..., num_ut]
        mcs_index = find_true_position(
            tbler_per_mcs <= bler_target,
            side="last",
            axis=-2,
        )

        # Lowest available MCS (TBLER in valid range [0, 1])
        # [..., num_ut]
        lowest_available_mcs = find_true_position(
            (tbler_per_mcs >= 0) & (tbler_per_mcs <= 1), side="first", axis=-2
        )

        # If all MCS have TBLER > bler_target, select lowest available MCS
        mcs_index = torch.where(mcs_index != -1, mcs_index, lowest_available_mcs)

        # A non-scheduled user receives MCS=fill_mcs_value
        mcs_index = torch.where(ut_is_scheduled, mcs_index, self._fill_mcs_value)

        if return_lowest_available_mcs:
            return mcs_index, lowest_available_mcs
        return mcs_index
class OuterLoopLinkAdaptation(Block):
    r"""Outer-loop link adaptation (OLLA).

    The modulation and coding scheme (MCS) index for a user is determined as
    the highest index whose corresponding transport block error rate (TBLER)
    does not exceed the specified ``bler_target``. The SINR value used for
    TBLER computation is the last effective SINR feedback,
    :math:`\text{SINR}_{\text{eff}}` [dB], reduced by an offset value,
    :math:`\Delta_{\mathrm{offset}}`:

    .. math::

        \max \left\{ \text{MCS}: \ \text{TBLER}(\text{MCS}, \text{SINR}_{\text{eff}}-\Delta_{\text{offset}}) \le \text{BLER}_{\text{target}} \right\}

    The value of :math:`\Delta_{\text{offset}}` is adjusted depending on the
    HARQ feedback :cite:p:`Pedersen05`:

    .. math::

        \Delta_{\mathrm{offset}} = \left\{
        \begin{array}{l}
            \Delta_{\mathrm{offset}} - \Delta_{\mathrm{down}} \quad \text{if HARQ=ACK} \\
            \Delta_{\mathrm{offset}} + \Delta_{\mathrm{up}} \quad \text{if HARQ=NACK}
        \end{array}
        \right.

    where the relationship between :math:`\Delta_{\mathrm{up}}` and
    :math:`\Delta_{\mathrm{down}}` is given by :cite:p:`Sampath97`:

    .. math::

        \frac{\Delta_{\mathrm{up}}}{\Delta_{\mathrm{down}}} = \frac{1 - \mathrm{BLER}_{\mathrm{target}}}{\mathrm{BLER}_{\mathrm{target}}}.

    :param phy_abstraction: An instance of :class:`~sionna.sys.PHYAbstraction`
    :param num_ut: Number of user terminals
    :param bler_target: BLER target value, strictly within 0 and 1.
        Defaults to 0.1.
    :param delta_up: Increment applied to the SINR offset [dB] when a NACK is
        received for a user. Must be positive. Defaults to 1.0.
    :param batch_size: Batch size or shape (int, list or tuple). It accounts
        for multiple users for which link adaptation is performed
        simultaneously. If `None`, the batch size is set to [].
    :param sinr_eff_init: Initial value of effective SINR for each user.
        Non-positive values are treated as missing and replaced by
        ``sinr_eff_init_fill``. If `float`, the same value is assigned to all
        users. Defaults to 1.0.
    :param sinr_eff_init_fill: Value replacing non-positive ``sinr_eff_init``
        values. Must be positive. Defaults to 1.0.
    :param offset_min: Minimum SINR [dB] offset value. Defaults to -20.0.
    :param offset_max: Maximum SINR [dB] offset value. Defaults to 20.0.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :raises ValueError: If ``sinr_eff_init_fill`` or ``delta_up`` is not
        positive, or if ``bler_target`` is not strictly within 0 and 1.

    :input num_allocated_re: [..., num_ut], `torch.int32`.
        Number of allocated resources in the upcoming slot, computed across
        OFDM symbols, subcarriers and streams, for each user.
    :input harq_feedback: [..., num_ut], -1 | 0 | 1.
        If 0 (1, resp.), then a NACK (ACK, resp.) is received. If -1,
        feedback is missing.
    :input sinr_eff: [..., num_ut], `torch.float` | `None` (default).
        Estimated effective SINR for each user. Non-positive values are
        treated as missing.
    :input mcs_table_index: [..., num_ut], `torch.int32` | `int` (default: 1).
        MCS table index for each user. For further details, refer to the
        :ref:`mcs_table_cat_note`.
    :input mcs_category: [..., num_ut], `torch.int32` | `int` (default: 0).
        MCS table category for each user. For further details, refer to the
        :ref:`mcs_table_cat_note`.

    :output mcs_index: [..., num_ut].
        Selected MCS index for each user.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.sys import PHYAbstraction, OuterLoopLinkAdaptation

        num_ut = 4
        bler_target = 0.1

        # Initialize the PHY abstraction object
        phy_abs = PHYAbstraction()

        # Initialize the OLLA object
        olla = OuterLoopLinkAdaptation(phy_abs, num_ut=num_ut,
                                       bler_target=bler_target)

        # Number of allocated REs for each user
        num_allocated_re = torch.tensor([100, 200, 150, 50])

        # HARQ feedback for each user (-1: N/A, 0: NACK, 1: ACK)
        harq_feedback = torch.tensor([1, 0, 1, -1])

        # Effective SINR feedback for each user
        sinr_eff = torch.tensor([10.0, 5.0, 8.0, 0.0])

        # Compute the MCS index for each user
        mcs_index = olla(num_allocated_re, harq_feedback, sinr_eff)
    """

    def __init__(
        self,
        phy_abstraction: PHYAbstraction,
        num_ut: int,
        bler_target: float = 0.1,
        delta_up: float = 1.0,
        batch_size: Optional[Union[int, list, tuple]] = None,
        sinr_eff_init: float = 1.0,
        sinr_eff_init_fill: float = 1.0,
        offset_min: float = -20.0,
        offset_max: float = 20.0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
    ) -> None:
        super().__init__(precision=precision, device=device)

        if sinr_eff_init_fill <= 0:
            raise ValueError("'sinr_eff_init_fill' must be positive")

        # Store scalar parameters as Python floats to avoid device issues with
        # torch.compile. Validate them before any state is created.
        bler_target_value = self._check_bler_target(_to_python_float(bler_target))
        delta_up_value = self._check_delta_up(_to_python_float(delta_up))

        # Normalize the batch shape to a list. Tuples and torch.Size are
        # unpacked rather than wrapped as a single element.
        if batch_size is None:
            batch_size = []
        elif isinstance(batch_size, (list, tuple)):
            batch_size = list(batch_size)
        else:
            batch_size = [batch_size]
        self._batch_size = batch_size
        self._num_ut = num_ut
        self._phy_abstraction = phy_abstraction
        self._illa = InnerLoopLinkAdaptation(
            phy_abstraction,
            bler_target=bler_target,
            precision=precision,
            device=device,
        )

        self._bler_target_value = bler_target_value
        self._delta_up_value = delta_up_value
        self._delta_down_value = self._get_delta_down_value()
        self._offset_min_value = _to_python_float(offset_min)
        self._offset_max_value = _to_python_float(offset_max)

        # Initialize effective SINR [dB], filling N/A values (<=0).
        # State tensors are regular attributes (not buffers) for simpler
        # torch.compile handling.
        self._sinr_eff_db_last = self._init_sinr_eff_db(
            sinr_eff_init, sinr_eff_init_fill, self.device
        )

        # Reset SINR offset to 0
        self._offset = torch.zeros(
            self._batch_size + [self._num_ut], dtype=self.dtype, device=self.device
        )

    @staticmethod
    def _check_bler_target(value: float) -> float:
        """Return ``value`` if it is strictly within 0 and 1, otherwise raise ValueError."""
        # Rejects 1 (division by zero in delta_down), values outside [0, 1]
        # (negative delta_down), 0 and NaN
        if not 0.0 < value < 1.0:
            raise ValueError("'bler_target' must be strictly within 0 and 1")
        return value

    @staticmethod
    def _check_delta_up(value: float) -> float:
        """Return ``value`` if it is positive, otherwise raise ValueError."""
        if value <= 0:
            raise ValueError("'delta_up' must be positive")
        return value

    def _get_delta_down_value(self) -> float:
        """Compute ``delta_up * bler_target / (1 - bler_target)`` from the stored floats."""
        return (
            self._delta_up_value
            * self._bler_target_value
            / (1 - self._bler_target_value)
        )

    def _init_sinr_eff_db(
        self,
        sinr_eff_init: Union[float, torch.Tensor],
        sinr_eff_init_fill: float,
        device: Union[str, torch.device],
    ) -> torch.Tensor:
        """Convert an initial effective SINR to dB, filling non-positive entries.

        :param sinr_eff_init: Initial effective SINR (linear scale), scalar or
            tensor, shaped to ``batch_size + [num_ut]`` by
            ``scalar_to_shaped_tensor``
        :param sinr_eff_init_fill: Linear value used where ``sinr_eff_init``
            is non-positive
        :param device: Device of the returned tensor

        :output sinr_eff_db: Effective SINR [dB] for each user
        """
        sinr_eff_init_tensor = scalar_to_shaped_tensor(
            sinr_eff_init,
            self.dtype,
            self._batch_size + [self._num_ut],
            device=device,
        )
        sinr_eff_init_fill_db = lin_to_db(
            torch.tensor(sinr_eff_init_fill, dtype=self.dtype, device=device),
            precision=self.precision,
        )
        return torch.where(
            sinr_eff_init_tensor > 0,
            lin_to_db(sinr_eff_init_tensor, precision=self.precision),
            sinr_eff_init_fill_db,
        )

    def reset(
        self,
        sinr_eff_init: float = 1.0,
        sinr_eff_init_fill: float = 0.1,
    ) -> None:
        """Resets the values of ``sinr_eff_db_last`` and ``offset``.

        :param sinr_eff_init: Initial effective SINR value (linear scale).
            Non-positive entries are replaced by ``sinr_eff_init_fill``.
        :param sinr_eff_init_fill: Fill value for N/A SINR entries (linear
            scale). Must be positive. Note that this default (0.1) differs
            from the constructor's default (1.0).

        :raises ValueError: If ``sinr_eff_init_fill`` is not positive.
        """
        # Same check as the constructor. A non-positive fill gives -inf/NaN dB.
        if sinr_eff_init_fill <= 0:
            raise ValueError("'sinr_eff_init_fill' must be positive")

        # Keep the state on the device it currently lives on
        device = self._sinr_eff_db_last.device
        self._sinr_eff_db_last = self._init_sinr_eff_db(
            sinr_eff_init, sinr_eff_init_fill, device
        )

        # Reset SINR offset to 0
        self._offset = torch.zeros_like(self._offset)

    @property
    def offset(self) -> torch.Tensor:
        """Effective SINR [dB] offset for each user (read-only)."""
        return self._offset

    @property
    def offset_max(self) -> float:
        """Get/set the maximum ``offset`` value."""
        return self._offset_max_value

    @offset_max.setter
    def offset_max(self, value: float) -> None:
        self._offset_max_value = _to_python_float(value)

    @property
    def offset_min(self) -> float:
        """Get/set the minimum ``offset`` value."""
        return self._offset_min_value

    @offset_min.setter
    def offset_min(self, value: float) -> None:
        self._offset_min_value = _to_python_float(value)

    @property
    def bler_target(self) -> float:
        """Get/set the BLER target for each user (strictly within 0 and 1)."""
        return self._bler_target_value

    @bler_target.setter
    def bler_target(self, value: float) -> None:
        # Validate before mutating, so an invalid value leaves the state intact
        value = self._check_bler_target(_to_python_float(value))
        self._bler_target_value = value
        self._delta_down_value = self._get_delta_down_value()
        self._illa.bler_target = value

    @property
    def sinr_eff_db_last(self) -> torch.Tensor:
        """Get/set the last observed effective SINR [dB] value for each user."""
        return self._sinr_eff_db_last

    @sinr_eff_db_last.setter
    def sinr_eff_db_last(self, value: torch.Tensor) -> None:
        self._sinr_eff_db_last = value.to(dtype=self.dtype)

    @property
    def delta_down(self) -> float:
        r"""Decrement applied to the SINR offset when an ACK is received (read-only).

        Computed as ``delta_up * bler_target / (1 - bler_target)``.
        """
        return self._delta_down_value

    @property
    def delta_up(self) -> float:
        """Get/set the increment applied to the SINR offset when a NACK is received."""
        return self._delta_up_value

    @delta_up.setter
    def delta_up(self, value: float) -> None:
        self._delta_up_value = self._check_delta_up(_to_python_float(value))
        self._delta_down_value = self._get_delta_down_value()

    def call(
        self,
        num_allocated_re: torch.Tensor,
        harq_feedback: Optional[torch.Tensor] = None,
        sinr_eff: Optional[torch.Tensor] = None,
        mcs_table_index: Union[int, torch.Tensor] = 1,
        mcs_category: Union[int, torch.Tensor] = 0,
    ) -> torch.Tensor:
        """Run outer loop link adaptation for one slot.

        Updates ``offset`` from ``harq_feedback``. Stores the positive entries
        of ``sinr_eff`` (in dB) into ``sinr_eff_db_last``. Then applies ILLA to
        the offset-corrected effective SINR.
        """
        shape = num_allocated_re.shape

        # Handle defaults with the module's device (expected to match the
        # input device). Missing HARQ (-1) leaves the offset unchanged.
        # Missing SINR (0) keeps the last observed value.
        if harq_feedback is None:
            harq_feedback = torch.full(shape, -1, dtype=torch.int32, device=self.device)
        else:
            harq_feedback = harq_feedback.to(dtype=torch.int32)
        if sinr_eff is None:
            sinr_eff = torch.zeros(shape, dtype=self.dtype, device=self.device)
        else:
            sinr_eff = sinr_eff.to(dtype=self.dtype)
        num_allocated_re = num_allocated_re.to(dtype=torch.int32)

        mcs_table_index = scalar_to_shaped_tensor(
            mcs_table_index, torch.int32, list(shape), device=self.device
        )
        mcs_category = scalar_to_shaped_tensor(
            mcs_category, torch.int32, list(shape), device=self.device
        )

        # ---------------------------- #
        # Update effective SINR offset #
        # ---------------------------- #
        # Use Python floats for scalar parameters to avoid device issues with
        # torch.compile
        delta_down = self._delta_down_value
        delta_up = self._delta_up_value
        offset_min = self._offset_min_value
        offset_max = self._offset_max_value

        new_offset = torch.where(
            harq_feedback == 1,
            self._offset - delta_down,
            torch.where(harq_feedback == 0, self._offset + delta_up, self._offset),
        )
        # Project offset to [offset_min; offset_max]
        self._offset = torch.clamp(new_offset, offset_min, offset_max)

        # ----------------------------------- #
        # Update last observed effective SINR #
        # ----------------------------------- #
        self._sinr_eff_db_last = torch.where(
            sinr_eff > 0,
            lin_to_db(sinr_eff, precision=self.precision),
            self._sinr_eff_db_last,
        )

        # -------------------------- #
        # Offset SINR and apply ILLA #
        # -------------------------- #
        sinr_eff_offset = db_to_lin(
            self._sinr_eff_db_last - self._offset, precision=self.precision
        )

        mcs_index = self._illa(
            sinr_eff=sinr_eff_offset,
            num_allocated_re=num_allocated_re,
            mcs_table_index=mcs_table_index,
            mcs_category=mcs_category,
            return_lowest_available_mcs=False,
        )
        return mcs_index