Source code for sionna.phy.channel.discrete_channel

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for discrete channel models"""

from typing import Optional, Tuple, Union

import torch

from sionna.phy import Block
from sionna.phy.config import Precision
from sionna.phy.utils import expand_to_rank, rand

# Public API of this module: only the four channel blocks are exported by
# ``from ... import *``; the autograd helper functions stay private.
__all__ = [
    "BinaryMemorylessChannel",
    "BinarySymmetricChannel",
    "BinaryZChannel",
    "BinaryErasureChannel",
]


class _STEBinarizer(torch.autograd.Function):
    """Straight-through estimator for binarization.

    The forward pass applies a hard threshold at 0.5, the backward pass
    forwards the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """Hard decision in forward pass.

        :param ctx: Autograd context (unused).
        :param x: Tensor of soft values.

        :returns: Tensor of the same shape as ``x`` that is 0 where
            ``x < 0.5`` and 1 elsewhere. Floating-point inputs keep their
            dtype; other inputs yield the default floating-point dtype.
        """
        # Passing Python scalars to torch.where produces the default dtype,
        # which silently downcasts e.g. float64 inputs. Keep the input
        # precision so that the backward gradient matches the input dtype.
        out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        # ``~(x < 0.5)`` (rather than ``x >= 0.5``) keeps the mapping of NaN
        # entries to 1, exactly like ``torch.where(x < 0.5, 0.0, 1.0)``.
        return (~(x < 0.5)).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Identity in backward pass.

        :param ctx: Autograd context (unused).
        :param grad_output: Gradient w.r.t. the binarized output.

        :returns: ``grad_output`` unchanged.
        """
        return grad_output


class _CustomXOR(torch.autograd.Function):
    """XOR of two tensors with a straight-through gradient.

    The forward pass computes the XOR of the (binary) inputs, the backward
    pass hands the incoming gradient to both inputs unchanged.
    """

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Compute the XOR of ``a`` and ``b``.

        Integer tensors use modulo-2 addition, all other dtypes use the
        absolute difference.
        """
        integer_dtypes = (
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        )
        if a.dtype not in integer_dtypes:
            return torch.abs(a - b)
        return torch.remainder(a + b, 2)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pass the incoming gradient through to both inputs."""
        grad_a = grad_output
        grad_b = grad_output
        return grad_a, grad_b


class BinaryMemorylessChannel(Block):
    r"""
    Discrete binary memoryless channel with (possibly) asymmetric bit flipping
    probabilities.

    Input bits are flipped with probability :math:`p_\text{b,0}` and
    :math:`p_\text{b,1}`, respectively.

    .. figure:: /phy/figures/BMC_channel.png
       :align: center

    This block supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            \operatorname{log} \frac{p_{b,1}}{1-p_{b,0}}, \qquad \text{if} \, y=0 \\
            \operatorname{log} \frac{1-p_{b,1}}{p_{b,0}}, \qquad \text{if} \, y=1 \\
        \end{cases}

    The LLRs are clipped to :math:`\pm` ``llr_max``.

    The error probability :math:`p_\text{b}` can be either scalar or a
    tensor (broadcastable to the shape of the input). This allows
    different error probabilities per bit position. In any case, its last
    dimension must be of length 2 and is interpreted as :math:`p_\text{b,0}`
    and :math:`p_\text{b,1}`.

    :param return_llrs: If `True`, the layer returns log-likelihood ratios
        instead of binary values based on ``pb``. Defaults to `False`.
    :param bipolar_input: If `True`, the expected input is given as
        :math:`\{-1,1\}` instead of :math:`\{0,1\}`. Defaults to `False`.
    :param llr_max: Clipping value of the LLRs. Defaults to 100.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input x: [...,n], `torch.Tensor`.
        Input sequence to the channel consisting of binary values
        :math:`\{0,1\}` or :math:`\{-1,1\}`, respectively.
    :input pb: [...,2], `torch.Tensor`.
        Error probability. Can be a tuple of two scalars or of any
        shape that can be broadcasted to the shape of ``x``. It has an
        additional last dimension which is interpreted as
        :math:`p_\text{b,0}` and :math:`p_\text{b,1}`.

    :output y: [...,n], `torch.Tensor`.
        Output sequence of same length as the input ``x``. If
        ``return_llrs`` is `False`, the output uses the same alphabet as the
        input (binary or bipolar); otherwise clipped LLRs are returned.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.channel import BinaryMemorylessChannel

        channel = BinaryMemorylessChannel()
        x = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        pb = (0.1, 0.2)  # p(flip|0) = 0.1, p(flip|1) = 0.2
        y = channel(x, pb)
        print(y.shape)
        # torch.Size([10, 100])
    """

    def __init__(
        self,
        return_llrs: bool = False,
        bipolar_input: bool = False,
        llr_max: float = 100.0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(return_llrs, bool):
            raise TypeError("return_llrs must be bool.")
        self._return_llrs = return_llrs

        if not isinstance(bipolar_input, bool):
            raise TypeError("bipolar_input must be bool.")
        self._bipolar_input = bipolar_input

        if llr_max < 0.0:
            raise ValueError("llr_max must be a non-negative value.")
        # Register as buffer for CUDAGraph compatibility
        self.register_buffer(
            "_llr_max", torch.tensor(llr_max, dtype=self.dtype, device=self.device)
        )

        self._check_input = True  # check input for consistency (i.e., binary)
        self._eps = 1e-9  # small additional term for numerical stability
        self.register_buffer(
            "_temperature", torch.tensor(0.1, dtype=self.dtype, device=self.device)
        )

    @property
    def llr_max(self) -> torch.Tensor:
        """Get/set maximum value used for LLR calculations."""
        return self._llr_max

    @llr_max.setter
    def llr_max(self, value: float) -> None:
        if value < 0:
            raise ValueError("llr_max cannot be negative.")
        # Register as buffer for CUDAGraph compatibility
        self.register_buffer(
            "_llr_max", torch.tensor(value, dtype=self.dtype, device=self.device)
        )

    @property
    def temperature(self) -> torch.Tensor:
        """Get/set temperature for Gumbel-softmax trick."""
        return self._temperature

    @temperature.setter
    def temperature(self, value: float) -> None:
        # NOTE(review): zero passes this check although the sampler divides
        # by the temperature.
        if value < 0:
            raise ValueError("temperature cannot be negative.")
        # Register as buffer for CUDAGraph compatibility
        self.register_buffer(
            "_temperature", torch.tensor(value, dtype=self.dtype, device=self.device)
        )

    @torch.compiler.disable
    def _check_inputs(self, x: torch.Tensor) -> None:
        """Check input x for consistency, i.e., verify that all values are
        binary or bipolar values.

        The check runs only until it has passed once; afterwards the
        ``_check_input`` flag is cleared and later calls return immediately.

        This method is excluded from torch.compile to avoid recompilation
        issues caused by the mutable _check_input flag.

        :raises ValueError: If ``x`` contains values outside the expected
            alphabet.
        """
        if self._check_input:
            x_float = x.to(self.dtype)
            if self._bipolar_input:
                valid = torch.logical_or(x_float == -1, x_float == 1)
            else:
                valid = torch.logical_or(x_float == 0, x_float == 1)
            if not valid.all():
                raise ValueError("Input must be binary.")
            self._check_input = False

    def _check_dtype(self, x: torch.Tensor, allow_uint: bool = True) -> None:
        """Check input dtype for consistency with parameters.

        :param x: Channel input whose dtype is validated.
        :param allow_uint: If `False`, unsigned integer dtypes are rejected
            when binary (non-LLR) outputs are requested.

        :raises TypeError: If the dtype of ``x`` is not supported for the
            configured operation mode.
        """
        float_dtypes = (torch.float32, torch.float64)
        signed_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
        unsigned_int_dtypes = (torch.uint8,)

        if self._return_llrs:
            if x.dtype not in float_dtypes:
                raise TypeError("LLR outputs require float dtypes.")
        else:
            if self._bipolar_input:
                valid_dtypes = float_dtypes + signed_int_dtypes
                if x.dtype not in valid_dtypes:
                    raise TypeError(
                        "Only signed dtypes are supported for bipolar inputs."
                    )
            else:
                valid_dtypes = float_dtypes + signed_int_dtypes + unsigned_int_dtypes
                if x.dtype not in valid_dtypes:
                    raise TypeError("Only real-valued dtypes are supported.")

            if not allow_uint:
                if x.dtype in unsigned_int_dtypes:
                    raise TypeError("Only signed dtypes supported.")

    def _sample_errors(self, pb: torch.Tensor, shape: torch.Size) -> torch.Tensor:
        """Sample binary error vector with given error probability.

        This function is based on the Gumbel-softmax trick to keep the
        sampling differentiable.

        :param pb: Error probability, broadcastable to ``shape``.
        :param shape: Shape of the error pattern to draw.

        :returns: Tensor of shape ``shape`` with entries in {0, 1}, where a
            1 marks an error event.
        """
        # Implementation follows https://arxiv.org/pdf/1611.01144v5.pdf
        # and https://arxiv.org/pdf/1906.07748.pdf

        # Uses smart rand that switches to global RNG in compiled mode
        u1 = rand(
            shape, dtype=pb.dtype, device=self.device, generator=self.torch_rng
        )
        u2 = rand(
            shape, dtype=pb.dtype, device=self.device, generator=self.torch_rng
        )
        u = torch.stack((u1, u2), dim=-1)

        # Sample Gumbel distribution
        eps = torch.tensor(self._eps, dtype=pb.dtype, device=self.device)
        temp = self._temperature.to(pb.dtype)
        q = -torch.log(-torch.log(u + eps) + eps)
        # Category 0 is "error" (probability pb), category 1 is "no error"
        p = torch.stack((pb, 1 - pb), dim=-1)
        p = expand_to_rank(p, q.dim(), axis=0)
        p = p.broadcast_to(q.shape)
        a = (torch.log(p + eps) + q) / temp

        # Apply softmax
        e_cat = torch.nn.functional.softmax(a, dim=-1)

        # Binarize final values via straight-through estimator
        return _STEBinarizer.apply(e_cat[..., 0])

    def build(self, *input_shapes) -> None:
        """Verify correct input shapes.

        :raises ValueError: If the shape information of ``pb`` does not
            indicate a last dimension (or tuple length) of 2.
        """
        pb_shapes = input_shapes[1]
        # Allow tuple of scalars as alternative input
        # NOTE(review): torch.Size is a tuple subclass; if the shape of a
        # tensor-valued pb arrives here as torch.Size, this branch checks its
        # rank rather than its last dimension. Confirm against Block.
        if isinstance(pb_shapes, (tuple, list)):
            if len(pb_shapes) != 2:
                raise ValueError("Last dim of pb must be of length 2.")
        else:
            if len(pb_shapes) > 0:
                if pb_shapes[-1] != 2:
                    raise ValueError("Last dim of pb must be of length 2.")
            else:
                raise ValueError("Last dim of pb must be of length 2.")

    def call(
        self,
        x: torch.Tensor,
        pb: Union[Tuple[float, float], torch.Tensor],
    ) -> torch.Tensor:
        """Apply discrete binary memoryless channel to inputs.

        :param x: Binary or bipolar input sequence.
        :param pb: Flip probabilities, either a pair ``(pb0, pb1)`` or a
            tensor whose last dimension holds ``pb0`` and ``pb1``.

        :returns: Channel output of the same shape as ``x``; hard values in
            the input alphabet, or clipped LLRs if ``return_llrs`` is set.

        :raises TypeError: If the dtype of ``x`` is not supported.
        :raises ValueError: If ``x`` is not binary/bipolar (checked only
            until the first successful check).
        """
        # Check input dtype for consistency with parameters
        self._check_dtype(x)

        # Allow pb to be a tuple of two scalars
        if isinstance(pb, (tuple, list)):
            pb0 = pb[0]
            pb1 = pb[1]
        else:
            pb0 = pb[..., 0]
            pb1 = pb[..., 1]

        # Convert to tensor and clip to the valid probability range
        if not isinstance(pb0, torch.Tensor):
            pb0 = torch.tensor(pb0, dtype=self.dtype, device=self.device)
        else:
            pb0 = pb0.to(dtype=self.dtype, device=self.device)
        if not isinstance(pb1, torch.Tensor):
            pb1 = torch.tensor(pb1, dtype=self.dtype, device=self.device)
        else:
            pb1 = pb1.to(dtype=self.dtype, device=self.device)

        pb0 = pb0.clamp(0.0, 1.0)
        pb1 = pb1.clamp(0.0, 1.0)

        # Check x for consistency (binary, bipolar)
        self._check_inputs(x)

        e0 = self._sample_errors(pb0, x.shape)
        e1 = self._sample_errors(pb1, x.shape)

        if self._bipolar_input:
            neutral_element = torch.tensor(-1, dtype=x.dtype, device=self.device)
        else:
            neutral_element = torch.tensor(0, dtype=x.dtype, device=self.device)

        # Mask e0 and e1 with input such that e0 only applies where x==0
        e = torch.where(x == neutral_element, e0, e1)
        e = e.to(x.dtype)

        if self._bipolar_input:
            # Flip signs for bipolar case
            y = x * (-2 * e + 1)
        else:
            # XOR for binary case
            y = _CustomXOR.apply(x, e)

        # If LLRs should be returned
        if self._return_llrs:
            if not self._bipolar_input:
                y = 2 * y - 1  # transform to bipolar

            # Remark: Sionna uses the logit definition log[p(x=1)/p(x=0)]
            eps = self._eps
            # ``1 - pb - eps`` turns negative for pb == 1, and the log of a
            # negative number is NaN, which also survives the final clamp.
            # Bounding the argument below by eps yields a large finite LLR
            # instead, mirroring how ``pb + eps`` handles pb == 0.
            log_not_pb0 = torch.log((1 - pb0 - eps).clamp(min=eps))
            log_not_pb1 = torch.log((1 - pb1 - eps).clamp(min=eps))
            y0 = -(torch.log(pb1 + eps) - log_not_pb0)
            y1 = log_not_pb1 - torch.log(pb0 + eps)

            # Multiply by y to keep gradient
            y = torch.where(y == 1, y1, y0).to(y.dtype) * y

            # Clip output LLRs
            llr_max = self._llr_max.to(y.dtype)
            y = y.clamp(-llr_max, llr_max)

        return y


class BinarySymmetricChannel(BinaryMemorylessChannel):
    r"""
    Discrete binary symmetric channel which randomly flips bits with
    probability :math:`p_\text{b}`.

    .. figure:: /phy/figures/BSC_channel.png
       :align: center

    This layer supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            \operatorname{log} \frac{p_{b}}{1-p_{b}}, \qquad \text{if}\, y=0 \\
            \operatorname{log} \frac{1-p_{b}}{p_{b}}, \qquad \text{if}\, y=1 \\
        \end{cases}

    where :math:`y` denotes the binary output of the channel.

    The bit flipping probability :math:`p_\text{b}` can be either a scalar or
    a tensor (broadcastable to the shape of the input). This allows
    different bit flipping probabilities per bit position.

    :param return_llrs: If `True`, the layer returns log-likelihood ratios
        instead of binary values based on ``pb``. Defaults to `False`.
    :param bipolar_input: If `True`, the expected input is given as
        :math:`\{-1,1\}` instead of :math:`\{0,1\}`. Defaults to `False`.
    :param llr_max: Clipping value of the LLRs. Defaults to 100.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input x: [...,n], `torch.Tensor`.
        Input sequence to the channel.
    :input pb: `torch.Tensor` or `float`.
        Bit flipping probability. Can be a scalar or of any shape that
        can be broadcasted to the shape of ``x``.

    :output y: [...,n], `torch.Tensor`.
        Output sequence of same length as the input ``x``. If
        ``return_llrs`` is `False`, the output uses the same alphabet as the
        input (binary or bipolar); otherwise clipped LLRs are returned.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.channel import BinarySymmetricChannel

        channel = BinarySymmetricChannel()
        x = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        pb = 0.1  # bit flip probability
        y = channel(x, pb)
        print(y.shape)
        # torch.Size([10, 100])
    """

    def __init__(
        self,
        return_llrs: bool = False,
        bipolar_input: bool = False,
        llr_max: float = 100.0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            return_llrs=return_llrs,
            bipolar_input=bipolar_input,
            llr_max=llr_max,
            precision=precision,
            device=device,
            **kwargs,
        )

    def build(self, *input_shapes) -> None:
        """No shape validation needed for BSC."""

    def call(
        self,
        x: torch.Tensor,
        pb: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Apply discrete binary symmetric channel, i.e., randomly flip
        bits with probability ``pb``.

        :param x: Binary or bipolar input sequence.
        :param pb: Bit flipping probability (scalar or tensor broadcastable
            to the shape of ``x``).

        :returns: Channel output of the same shape as ``x``.
        """
        # BSC is implemented by calling the base class with symmetric pb.
        # The probability must stay floating point: casting it to the dtype
        # of an integer-typed ``x`` would truncate it to 0 and silently
        # disable all bit flips. Use the block's own precision instead.
        if not isinstance(pb, torch.Tensor):
            pb = torch.tensor(pb, dtype=self.dtype, device=self.device)
        else:
            pb = pb.to(dtype=self.dtype, device=self.device)
        pb_stacked = torch.stack((pb, pb), dim=-1)
        return super().call(x, pb_stacked)


class BinaryZChannel(BinaryMemorylessChannel):
    r"""
    Block that implements the binary Z-channel.

    In the Z-channel, transmission errors only occur for the transmission of
    second input element (i.e., if a `1` is transmitted) with error probability
    :math:`p_\text{b}` but the first element is always correctly received.

    .. figure:: /phy/figures/Z_channel.png
       :align: center

    This block supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            \operatorname{log} \left( p_b \right), \qquad \text{if} \, y=0 \\
            \infty, \qquad \qquad \text{if} \, y=1 \\
        \end{cases}

    assuming equally probable inputs :math:`P(X=0) = P(X=1) = 0.5`.

    The error probability :math:`p_\text{b}` can be either a scalar or a
    tensor (broadcastable to the shape of the input). This allows
    different error probabilities per bit position.

    :param return_llrs: If `True`, the layer returns log-likelihood ratios
        instead of binary values based on ``pb``. Defaults to `False`.
    :param bipolar_input: If `True`, the expected input is given as
        :math:`\{-1,1\}` instead of :math:`\{0,1\}`. Defaults to `False`.
    :param llr_max: Clipping value of the LLRs. Defaults to 100.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input x: [...,n], `torch.Tensor`.
        Input sequence to the channel.
    :input pb: `torch.Tensor` or `float`.
        Error probability. Can be a scalar or of any shape that can be
        broadcasted to the shape of ``x``.

    :output y: [...,n], `torch.Tensor`.
        Output sequence of same length as the input ``x``. If
        ``return_llrs`` is `False`, the output is binary and otherwise
        soft-values are returned.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.channel import BinaryZChannel

        channel = BinaryZChannel()
        x = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        pb = 0.1  # error probability for 1->0
        y = channel(x, pb)
        print(y.shape)
        # torch.Size([10, 100])
    """

    def __init__(
        self,
        return_llrs: bool = False,
        bipolar_input: bool = False,
        llr_max: float = 100.0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            return_llrs=return_llrs,
            bipolar_input=bipolar_input,
            llr_max=llr_max,
            precision=precision,
            device=device,
            **kwargs,
        )

    def build(self, *input_shapes) -> None:
        """No shape validation needed for Z-channel."""

    def call(
        self,
        x: torch.Tensor,
        pb: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Apply binary Z-channel, i.e., randomly flip 1s to 0s with
        probability ``pb``.

        :param x: Binary or bipolar input sequence.
        :param pb: Error probability for a transmitted `1` (scalar or tensor
            broadcastable to the shape of ``x``).

        :returns: Channel output of the same shape as ``x``.
        """
        # Z-channel is implemented by calling the base class with p(1|0)=0.
        # The probability must stay floating point: casting it to the dtype
        # of an integer-typed ``x`` would truncate it to 0 and silently
        # disable all errors. Use the block's own precision instead.
        if not isinstance(pb, torch.Tensor):
            pb = torch.tensor(pb, dtype=self.dtype, device=self.device)
        else:
            pb = pb.to(dtype=self.dtype, device=self.device)
        pb_stacked = torch.stack((torch.zeros_like(pb), pb), dim=-1)
        return super().call(x, pb_stacked)


class BinaryErasureChannel(BinaryMemorylessChannel):
    r"""
    Binary erasure channel (BEC) where a bit is either correctly received
    or erased.

    In the binary erasure channel, bits are always correctly received or
    erased with erasure probability :math:`p_\text{b}`.

    .. figure:: /phy/figures/BEC_channel.png
       :align: center

    This block supports binary inputs (:math:`x \in \{0, 1\}`) and `bipolar`
    inputs (:math:`x \in \{-1, 1\}`).

    If activated, the channel directly returns log-likelihood ratios (LLRs)
    defined as

    .. math::
        \ell =
        \begin{cases}
            -\infty, \qquad \text{if} \, y=0 \\
            0, \qquad \quad \,\, \text{if} \, y=? \\
            \infty, \qquad \quad \text{if} \, y=1 \\
        \end{cases}

    The erasure probability :math:`p_\text{b}` can be either a scalar or a
    tensor (broadcastable to the shape of the input). This allows
    different erasure probabilities per bit position.

    Please note that the output of the BEC is ternary. Hereby, `-1` indicates
    an erasure for the binary configuration and `0` for the bipolar mode,
    respectively.

    :param return_llrs: If `True`, the layer returns log-likelihood ratios
        instead of binary values based on ``pb``. Defaults to `False`.
    :param bipolar_input: If `True`, the expected input is given as
        :math:`\{-1,1\}` instead of :math:`\{0,1\}`. Defaults to `False`.
    :param llr_max: Clipping value of the LLRs. Defaults to 100.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :input x: [...,n], `torch.Tensor`.
        Input sequence to the channel.
    :input pb: `torch.Tensor` or `float`.
        Erasure probability. Can be a scalar or of any shape that can be
        broadcasted to the shape of ``x``.

    :output y: [...,n], `torch.Tensor`.
        Output sequence of same length as the input ``x``. If
        ``return_llrs`` is `False`, the output is ternary where each `-1`
        and each `0` indicate an erasure for the binary and bipolar input,
        respectively.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.channel import BinaryErasureChannel

        channel = BinaryErasureChannel()
        x = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        pb = 0.1  # erasure probability
        y = channel(x, pb)
        print(y.shape)
        # torch.Size([10, 100])
    """

    def __init__(
        self,
        return_llrs: bool = False,
        bipolar_input: bool = False,
        llr_max: float = 100.0,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            return_llrs=return_llrs,
            bipolar_input=bipolar_input,
            llr_max=llr_max,
            precision=precision,
            device=device,
            **kwargs,
        )

    def build(self, *input_shapes) -> None:
        """No shape validation needed for BEC."""

    def call(
        self,
        x: torch.Tensor,
        pb: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """Erase entries of ``x`` at random with probability ``pb``."""
        # Unsigned inputs are rejected as the binary erasure marker is -1
        self._check_dtype(x, allow_uint=False)

        # Bring pb to the internal precision and the valid probability range
        if isinstance(pb, torch.Tensor):
            pb = pb.to(dtype=self.dtype, device=self.device)
        else:
            pb = torch.tensor(pb, dtype=self.dtype, device=self.device)
        pb = pb.clamp(0.0, 1.0)

        # Verify that x is binary / bipolar
        self._check_inputs(x)

        # Draw the erasure pattern (1 marks an erased position)
        erasures = self._sample_errors(pb, x.shape)

        if not self._return_llrs:
            # Ternary output: the erasure marker depends on the input alphabet
            marker_value = 0 if self._bipolar_input else -1
            marker = torch.tensor(marker_value, dtype=x.dtype, device=self.device)
            return torch.where(erasures == 0, x, marker)

        # LLR output
        # Remark: the Sionna logit definition is llr = log[p(x=1)/p(x=0)]
        llr = x if self._bipolar_input else 2 * x - 1
        llr = llr * self._llr_max.to(llr.dtype)
        # Erased positions carry no information, i.e., an LLR of 0
        zero = torch.tensor(0, dtype=llr.dtype, device=self.device)
        return torch.where(erasures == 1, zero, llr)