#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Scheduling algorithms for Sionna SYS"""
from typing import Any, List, Optional, Tuple, Union
import torch
from sionna.phy import Block, config, dtypes
from sionna.phy.config import Precision
from sionna.phy.utils import insert_dims
__all__ = ["PFSchedulerSUMIMO"]
# [docs]  (link marker left over from the rendered HTML source listing; not part of the module code)
class PFSchedulerSUMIMO(Block):
    r"""Proportional fairness (PF) scheduler for single-user MIMO (SU-MIMO) systems.

    Schedules users according to a proportional fairness (PF) metric in a
    single-user (SU) multiple-input multiple-output (MIMO) system, i.e., at most
    one user is scheduled per time-frequency resource.

    Fixing the time slot :math:`t`, :math:`\tilde{R}_t(u,i)` is
    the :emphasis:`achievable` rate for user :math:`u` on the time-frequency
    resource :math:`i` during the current slot.
    Let :math:`T_{t-1}(u)` denote the throughput :emphasis:`achieved` by user
    :math:`u` up to and including slot :math:`t-1`.
    Resource :math:`i` is assigned to the user with the highest PF metric,
    as defined in :cite:p:`Jalali00`:

    .. math::

        \operatorname{argmax}_{u} \frac{\tilde{R}_{t}(u,i)}{T_{t-1}(u)}.

    All streams within a scheduled resource element are assigned to the
    selected user.

    Let :math:`R_t(u)` be the rate achieved by user :math:`u` in slot :math:`t`.
    The throughput :math:`T` by each user :math:`u` is updated via
    geometric discounting:

    .. math::

        T_t(u) = \beta \, T_{t-1}(u) + (1-\beta) \, R_t(u)

    where :math:`\beta\in(0,1)` is the discount factor.

    :param num_ut: Number of user terminals
    :param num_freq_res: Number of available frequency resources.
        A frequency resource is the smallest frequency unit that can be
        allocated to a user, typically a physical resource block (PRB).
    :param num_ofdm_sym: Number of OFDM symbols in a slot
    :param batch_size: Batch size or shape, given as an `int` or as a
        list/tuple of `int` (a `torch.Size` is accepted too). It can account
        for multiple sectors in which scheduling is performed simultaneously.
        If `None`, the batch size is set to [].
    :param num_streams_per_ut: Number of streams per user. Defaults to 1.
    :param beta: Discount factor for computing the time-averaged achieved rate.
        Must be within (0,1). Defaults to 0.98.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input rate_last_slot: [batch_size, num_ut], `torch.float`.
        Rate achieved by each user in the last slot.
    :input rate_achievable_curr_slot: [batch_size, num_ofdm_sym, num_freq_res, num_ut], `torch.float`.
        Achievable rate for each user across the OFDM grid in the
        current slot.

    :output is_scheduled: [batch_size, num_ofdm_sym, num_freq_res, num_ut, num_streams_per_ut], `torch.bool`.
        Whether a user is scheduled for transmission for each available resource.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.sys import PFSchedulerSUMIMO

        num_ut = 4
        num_freq_res = 52
        num_ofdm_sym = 14
        batch_size = 10

        # Create PF scheduler
        scheduler = PFSchedulerSUMIMO(
            num_ut,
            num_freq_res,
            num_ofdm_sym,
            batch_size=batch_size
        )

        # Generate random achievable rates and last slot rates
        rate_last_slot = torch.rand(batch_size, num_ut) * 100
        rate_achievable_curr_slot = torch.rand(
            batch_size, num_ofdm_sym, num_freq_res, num_ut
        ) * 100

        # Get scheduling decisions
        is_scheduled = scheduler(rate_last_slot, rate_achievable_curr_slot)
        print(is_scheduled.shape)
        # torch.Size([10, 14, 52, 4, 1])
    """

    def __init__(
        self,
        num_ut: int,
        num_freq_res: int,
        num_ofdm_sym: int,
        batch_size: Optional[Union[List[int], Tuple[int, ...], int]] = None,
        num_streams_per_ut: int = 1,
        beta: float = 0.98,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
    ) -> None:
        super().__init__(precision=precision, device=device)

        # Always keep the batch shape as a private list of ints: `call`
        # concatenates it with other lists to build the expected input shapes.
        self._batch_size = self._normalize_batch_size(batch_size)
        self._num_ut = int(num_ut)
        self._num_freq_res = int(num_freq_res)
        self._num_ofdm_sym = int(num_ofdm_sym)
        self._num_streams_per_ut = int(num_streams_per_ut)

        # Validate and store beta as Python float to avoid device issues with
        # torch.compile
        assert 0.0 < beta < 1.0, "Discount factor 'beta' must be within (0, 1)"
        self._beta_value = float(beta)

        # Register state tensors as buffers for proper device tracking.
        # Both buffers are sized from the int-cast attributes above so that
        # their shapes always agree with the shape checks done in `call`.

        # Average achieved rate (internal state), initialized to one
        self.register_buffer(
            "_rate_achieved_past",
            torch.ones(
                self._batch_size + [self._num_ut],
                dtype=self.dtype,
                device=self.device,
            ),
            persistent=False,
        )

        # PF metric (internal state for debugging)
        self.register_buffer(
            "_pf_metric",
            torch.zeros(
                self._batch_size
                + [self._num_ofdm_sym, self._num_freq_res, self._num_ut],
                dtype=self.dtype,
                device=self.device,
            ),
            persistent=False,
        )

    @staticmethod
    def _normalize_batch_size(
        batch_size: Optional[Union[List[int], Tuple[int, ...], int]],
    ) -> List[int]:
        """Return the batch shape as a freshly allocated list of ints.

        `None` maps to the empty shape ``[]``, a single `int` to a
        one-element list, and any sequence of sizes (list, tuple,
        `torch.Size`) to a list holding the same sizes. A new list is always
        returned so that the scheduler never aliases a caller-owned object.
        """
        if batch_size is None:
            return []
        if isinstance(batch_size, int):
            return [batch_size]
        return [int(dim) for dim in batch_size]

    @property
    def rate_achieved_past(self) -> torch.Tensor:
        r"""[batch_size, num_ut] (read-only) : :math:`\beta`-discounted
        time-averaged achieved rate for each user"""
        return self._rate_achieved_past

    @property
    def pf_metric(self) -> torch.Tensor:
        """[batch_size, num_ofdm_sym, num_freq_res, num_ut] (read-only) :
        Proportional fairness (PF) metric in the last slot"""
        return self._pf_metric

    @property
    def beta(self) -> float:
        """Get/set the discount factor for computing the time-averaged
        achieved rate. Must be within (0,1)."""
        return self._beta_value

    @beta.setter
    def beta(self, value: float) -> None:
        assert 0.0 < value < 1.0, "Discount factor 'beta' must be within (0, 1)"
        self._beta_value = float(value)

    def call(
        self,
        rate_last_slot: torch.Tensor,
        rate_achievable_curr_slot: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Compute scheduling decisions based on proportional fairness.

        Updates the discounted average achieved rate with `rate_last_slot`,
        computes the PF metric for the current slot and assigns every
        time/frequency resource to the user with the highest metric. Both
        the updated average rate and the PF metric are written in place to
        the internal buffers exposed via :attr:`rate_achieved_past` and
        :attr:`pf_metric`.

        :param rate_last_slot: [batch_size, num_ut] rate achieved by each
            user in the last slot
        :param rate_achievable_curr_slot: [batch_size, num_ofdm_sym,
            num_freq_res, num_ut] achievable rate in the current slot
        :param kwargs: Additional keyword arguments; not used by this method

        :return: [batch_size, num_ofdm_sym, num_freq_res, num_ut,
            num_streams_per_ut] boolean tensor flagging the scheduled user
            on each resource
        """
        # ------------------------ #
        # Validate and cast inputs #
        # ------------------------ #
        expected_rate_last_slot_shape = self._batch_size + [self._num_ut]
        assert list(rate_last_slot.shape) == expected_rate_last_slot_shape, (
            f"Inconsistent 'rate_last_slot' shape: expected {expected_rate_last_slot_shape}, "
            f"got {list(rate_last_slot.shape)}"
        )

        expected_rate_achievable_shape = self._batch_size + [
            self._num_ofdm_sym,
            self._num_freq_res,
            self._num_ut,
        ]
        assert list(rate_achievable_curr_slot.shape) == expected_rate_achievable_shape, (
            f"Inconsistent 'rate_achievable_curr_slot' shape: expected "
            f"{expected_rate_achievable_shape}, got {list(rate_achievable_curr_slot.shape)}"
        )

        # [batch_size, num_ut]
        rate_last_slot = rate_last_slot.to(self.dtype)
        # [batch_size, num_ofdm_sym, num_freq_res, num_ut]
        rate_achievable_curr_slot = rate_achievable_curr_slot.to(self.dtype)

        # ---------------------------- #
        # Update average achieved rate #
        # ---------------------------- #
        # [batch_size, num_ut]
        # Use Python float for beta to avoid device mismatch issues with
        # torch.compile
        beta = self._beta_value
        rate_achieved_past_new = (
            beta * self._rate_achieved_past + (1 - beta) * rate_last_slot
        )

        # Store updated state using in-place copy for torch.compile
        # compatibility
        self._rate_achieved_past.copy_(rate_achieved_past_new)

        # Two singleton axes are inserted so that the per-user average
        # broadcasts over the OFDM symbol and frequency resource dimensions
        # [batch_size, 1, 1, num_ut]
        rate_achieved_past = insert_dims(rate_achieved_past_new, 2, axis=-2)

        # ----------------- #
        # Compute PF metric #
        # ----------------- #
        # [batch_size, num_ofdm_sym, num_freq_res, num_ut]
        pf_metric = rate_achievable_curr_slot / rate_achieved_past

        # Store for debugging access using in-place copy for torch.compile
        # compatibility
        self._pf_metric.copy_(pf_metric)

        # ------------ #
        # Schedule UTs #
        # ------------ #
        # Assign each time/frequency resource to the user with highest PF
        # metric
        # [batch_size, num_ofdm_sym, num_freq_res]
        scheduled_ut = torch.argmax(self._pf_metric, dim=-1)

        # [batch_size, num_ofdm_sym, num_freq_res, num_ut]
        is_scheduled = torch.nn.functional.one_hot(
            scheduled_ut, num_classes=self._num_ut
        )

        # [batch_size, num_ofdm_sym, num_freq_res, num_ut, 1]
        is_scheduled = is_scheduled.unsqueeze(-1)

        # All streams of the selected user are scheduled on the resource
        # [batch_size, num_ofdm_sym, num_freq_res, num_ut, num_streams]
        is_scheduled = is_scheduled.expand(
            *is_scheduled.shape[:-1], self._num_streams_per_ut
        )

        return is_scheduled.to(torch.bool)