#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for Polar decoding such as successive cancellation (SC),
successive cancellation list (SCL) and iterative belief propagation (BP)
decoding."""
from typing import Optional, Tuple, Union
import numbers
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from sionna.phy import Block
from sionna.phy.fec.crc import CRCDecoder, CRCEncoder
from sionna.phy.fec.polar.encoding import Polar5GEncoder
# Names exported by ``from <this module> import *``: the decoder blocks.
__all__ = [
    "PolarSCDecoder", "PolarSCLDecoder",
    "PolarBPDecoder", "Polar5GDecoder",
]
class PolarSCDecoder(Block):
    """Successive cancellation (SC) decoder :cite:p:`Arikan_Polar` for Polar codes
    and Polar-like codes.

    :param frozen_pos: Array of `int` defining the ``n-k`` indices of the
        frozen positions.
    :param n: Defining the codeword length.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input llr_ch: [..., n], `torch.float`.
        Tensor containing the channel LLR values (as logits).

    :output u_hat: [..., k], `torch.float`.
        Tensor containing hard-decided estimations of all ``k``
        information bits.

    .. rubric:: Notes

    This block implements the SC decoder as described in
    :cite:p:`Arikan_Polar`. However, the implementation follows the `recursive
    tree` :cite:p:`Gross_Fast_SCL` terminology and combines nodes for increased
    throughputs without changing the outcome of the algorithm.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that, although of little practical relevance, setting frozen bits to
    `1` may result in `affine` codes instead of linear codes, as the
    `all-zero` codeword is not necessarily part of the code anymore.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import PolarSCDecoder, PolarEncoder
        from sionna.phy.fec.polar.utils import generate_5g_ranking

        k, n = 100, 256
        frozen_pos, _ = generate_5g_ranking(k, n)
        encoder = PolarEncoder(frozen_pos, n)
        decoder = PolarSCDecoder(frozen_pos, n)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        llr_ch = 20.0 * (2.0 * codewords - 1)  # BPSK without noise
        decoded = decoder(llr_ch)
        print(torch.equal(bits, decoded))
        # True
    """

    def __init__(
        self,
        frozen_pos: np.ndarray,
        n: int,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        """Validate the code parameters and register the decoder buffers.

        :raises TypeError: If ``n`` is not a number or ``frozen_pos`` is not
            integer-valued.
        :raises ValueError: If ``frozen_pos`` has more than ``n`` entries or
            ``n`` is not a positive power of 2.
        :raises ArithmeticError: If the derived information positions do not
            match ``k`` (duplicate or out-of-range entries in ``frozen_pos``).
        """
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)
        # Accept array-likes and every integer dtype. Note that
        # ``np.issubdtype(dtype, int)`` only matches NumPy's *default* integer
        # type and would reject e.g. int32 or unsigned index arrays.
        frozen_pos = np.asarray(frozen_pos)
        if not np.issubdtype(frozen_pos.dtype, np.integer):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos) > n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        # Exact integer power-of-two test; also rejects n <= 0 with a
        # ValueError instead of failing inside ``int(np.log2(n))``.
        if n < 1 or n & (n - 1):
            raise ValueError("n must be a power of 2.")

        # Store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        # A mismatch means frozen_pos had duplicate or out-of-range entries.
        if self._k != len(self._info_pos):
            msg = "Internal error: invalid info_pos generated."
            raise ArithmeticError(msg)

        # Register info_pos as buffer for torch.compile compatibility
        self.register_buffer(
            "_info_pos_t",
            torch.tensor(self._info_pos, dtype=torch.int64, device=self.device),
        )

        self._llr_max = 30.0  # Internal max LLR value

        # Frozen-bit indicator vector (1 = frozen, 0 = information bit)
        self._frozen_ind = np.zeros(self._n)
        self._frozen_ind[self._frozen_pos] = 1
        # Register frozen indicator as tensor buffer for torch.compile
        # compatibility
        self.register_buffer(
            "_frozen_ind_t",
            torch.tensor(self._frozen_ind, dtype=self.dtype, device=self.device),
        )

        # Rate-0 pruning shortcut in ``_polar_decode_sc``; disabled here.
        self._use_fast_sc = False

    @property
    def n(self) -> int:
        """Codeword length."""
        return self._n

    @property
    def k(self) -> int:
        """Number of information bits."""
        return self._k

    @property
    def frozen_pos(self) -> np.ndarray:
        """Frozen positions for Polar decoding."""
        return self._frozen_pos

    @property
    def info_pos(self) -> np.ndarray:
        """Information bit positions for Polar encoding."""
        return self._info_pos

    @property
    def llr_max(self) -> float:
        """Maximum LLR value for internal calculations."""
        return self._llr_max

    def _cn_op(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Check-node update (boxplus) for LLR inputs.

        Computes ``log(1 + e^(x+y)) - log(e^x + e^y)`` element-wise on inputs
        clipped to ``[-llr_max, llr_max]``. ``softplus`` and ``logaddexp``
        evaluate both terms without forming large intermediate exponentials.
        See :cite:p:`Stimming_LLR` and :cite:p:`Hashemi_SSCL` for detailed
        equations.
        """
        x_in = torch.clamp(x, min=-self._llr_max, max=self._llr_max)
        y_in = torch.clamp(y, min=-self._llr_max, max=self._llr_max)
        return F.softplus(x_in + y_in) - torch.logaddexp(x_in, y_in)

    def _vn_op(
        self, x: torch.Tensor, y: torch.Tensor, u_hat: torch.Tensor
    ) -> torch.Tensor:
        """VN update for LLR inputs: ``(1 - 2*u_hat) * x + y``."""
        return (1 - 2 * u_hat) * x + y

    def _polar_decode_sc(
        self, llr_ch: torch.Tensor, frozen_ind: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Recursive SC decoding function.

        Recursively branch decoding tree and split into decoding of `upper`
        and `lower` path until reaching a leaf node.

        The function returns the u_hat decisions at stage `0` and the bit
        decisions of the intermediate stage `s` (i.e., the re-encoded
        version of `u_hat` until the current stage `s`).

        This decoder parallelizes over the batch-dimension, i.e., the tree
        is processed for all samples in the batch in parallel. This yields a
        higher throughput, but does not improve the latency.
        """
        # Current (sub-)codeword length
        n = frozen_ind.shape[0]
        half = n // 2

        if n > 1:
            if self._use_fast_sc:
                # Rate-0 node: all bits frozen, decisions are all zero.
                if frozen_ind.sum() == n:
                    u_hat = torch.zeros_like(llr_ch)
                    return u_hat, u_hat

            llr_ch1 = llr_ch[..., :half]
            llr_ch2 = llr_ch[..., half:]
            frozen_ind1 = frozen_ind[:half]
            frozen_ind2 = frozen_ind[half:]

            # Upper path: check-node update, then decode upper half
            x_llr1_in = self._cn_op(llr_ch1, llr_ch2)
            u_hat1, u_hat1_up = self._polar_decode_sc(x_llr1_in, frozen_ind1)

            # Lower path: variable-node update using the upper decisions
            x_llr2_in = self._vn_op(llr_ch1, llr_ch2, u_hat1_up)
            u_hat2, u_hat2_up = self._polar_decode_sc(x_llr2_in, frozen_ind2)

            # Combine u_hat from both branches
            u_hat = torch.cat([u_hat1, u_hat2], -1)

            # Re-encode for the current stage: [u1 XOR u2, u2]
            u_hat1_up_int = u_hat1_up.to(torch.int8)
            u_hat2_up_int = u_hat2_up.to(torch.int8)
            u_hat1_up_int = torch.bitwise_xor(u_hat1_up_int, u_hat2_up_int)
            u_hat1_up = u_hat1_up_int.to(self.dtype)
            u_hat_up = torch.cat([u_hat1_up, u_hat2_up], -1)
        else:  # Leaf reached: perform basic decoding op (=decision)
            # Tensor ops (no Python branching) avoid CUDA graph breaks;
            # frozen_ind is a 1-element tensor at this point.
            is_frozen = frozen_ind[0] == 1

            # Frozen case: u_hat = 0
            frozen_result = torch.zeros_like(llr_ch)

            # Non-frozen case: hard decision (positive LLR -> 0)
            decision_result = 0.5 * (1.0 - torch.sign(llr_ch))
            # Exact 0 LLRs yield 0.5; map them to 1
            decision_result = torch.where(
                decision_result == 0.5,
                torch.ones_like(decision_result),
                decision_result,
            )

            u_hat = torch.where(is_frozen, frozen_result, decision_result)
            u_hat_up = u_hat

        return u_hat, u_hat_up

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Check if shape of input is invalid."""
        if input_shape[-1] != self._n:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch: torch.Tensor) -> torch.Tensor:
        """Successive cancellation (SC) decoding function.

        Performs successive cancellation decoding and returns the estimated
        information bits.

        :param llr_ch: Tensor of shape `[..., n]` containing the
            channel LLR values (as logits).

        :output u_hat: Tensor of shape `[..., k]` containing hard-decided
            estimations of all ``k`` information bits.

        Note: This function recursively unrolls the SC decoding tree, thus,
        for larger values of ``n`` building the decoding graph can become
        time consuming.
        """
        # Reshape inputs to [-1, n]
        input_shape = llr_ch.shape
        llr_ch = llr_ch.reshape((-1, self._n))
        llr_ch = -1.0 * llr_ch  # Logits are converted into "true" llrs

        # Decode
        u_hat_n, _ = self._polar_decode_sc(llr_ch, self._frozen_ind_t)

        # Recover the k information bit positions using pre-registered buffer
        u_hat = u_hat_n[:, self._info_pos_t]

        # Reconstruct input shape
        output_shape = list(input_shape[:-1]) + [self.k]
        return u_hat.reshape(output_shape)
class PolarSCLDecoder(Block):
# pylint: disable=line-too-long
"""Successive cancellation list (SCL) decoder :cite:p:`Tal_SCL` for Polar codes
and Polar-like codes.
:param frozen_pos: Array of `int` defining the ``n-k`` indices of the
frozen positions.
:param n: Defining the codeword length.
:param list_size: Defines the list size of the decoder.
:param crc_degree: Defining the CRC polynomial to be used. Can be any
value from `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.
:param use_hybrid_sc: If `True`, SC decoding is applied and only the
codewords with invalid CRC are decoded with SCL. This option
requires an outer CRC specified via ``crc_degree``.
:param use_fast_scl: If `True`, tree pruning is used to reduce
the decoding complexity. The output is equivalent to the
non-pruned version (besides numerical differences).
:param cpu_only: If `True`, a NumPy-based decoder runs on the CPU.
This option is usually slower, but also more memory efficient
and in particular recommended for larger blocklengths.
:param use_scatter: If `True`, scatter update is used for tensor
updates. This option is usually slower, but more memory efficient.
:param ind_iil_inv: If not `None`, the sequence is used as inverse
input bit interleaver before evaluating the CRC. This only
affects the CRC evaluation but the output sequence is not
permuted.
:param return_crc_status: If `True`, the decoder additionally returns
the CRC status indicating if a codeword was (most likely)
correctly recovered. This is only available if ``crc_degree`` is
not `None`.
:param precision: Precision used for internal calculations and outputs.
If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
:param device: Device for computation (e.g., 'cpu', 'cuda:0').
If `None`, :attr:`~sionna.phy.config.Config.device` is used.
:input llr_ch: [..., n], `torch.float`.
Tensor containing the channel LLR values (as logits).
:output b_hat: [..., k], `torch.float`.
Binary tensor containing hard-decided estimations of all `k`
information bits.
:output crc_status: [...], `torch.bool`.
CRC status indicating if a codeword was (most likely) correctly
recovered. This is only returned if ``return_crc_status`` is `True`.
Note that false positives are possible.
.. rubric:: Notes
This block implements the successive cancellation list (SCL) decoder
as described in :cite:p:`Tal_SCL` but uses LLR-based message updates
:cite:p:`Stimming_LLR`. The implementation follows the notation from
:cite:p:`Gross_Fast_SCL`, :cite:p:`Hashemi_SSCL`. If option ``use_fast_scl`` is
active, tree pruning is used and tree nodes are combined if possible
(see :cite:p:`Hashemi_SSCL` for details).
For longer code lengths, the complexity of the decoding graph becomes
large and we recommend using the ``cpu_only`` option, which uses an
embedded NumPy decoder. Further, this function recursively unrolls the
SCL decoding tree; thus, for larger values of ``n``, building the
decoding graph can become time-consuming. Please consider the
``cpu_only`` option if building the graph takes too long.
A hybrid SC/SCL decoder as proposed in :cite:p:`Cammerer_Hybrid_SCL` (using
SC instead of BP) can be activated with the option ``use_hybrid_sc`` if and
only if an outer CRC is available. Please note that, due to the false
positive rate of the CRC, the results do not exactly match SCL performance.
As commonly done, we assume frozen bits are set to `0`. Please note
that, although of little practical relevance, setting frozen bits to
`1` may result in `affine` codes instead of linear codes, as the
`all-zero` codeword is not necessarily part of the code anymore.
.. rubric:: Examples
.. code-block:: python
import torch
from sionna.phy.fec.polar import PolarSCLDecoder, PolarEncoder
from sionna.phy.fec.polar.utils import generate_5g_ranking
k, n = 100, 256
frozen_pos, _ = generate_5g_ranking(k, n)
encoder = PolarEncoder(frozen_pos, n)
decoder = PolarSCLDecoder(frozen_pos, n, list_size=8)
bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
codewords = encoder(bits)
llr_ch = 20.0 * (2.0 * codewords - 1) # BPSK without noise
decoded = decoder(llr_ch)
print(torch.equal(bits, decoded))
# True
"""
def __init__(
    self,
    frozen_pos: np.ndarray,
    n: int,
    list_size: int = 8,
    crc_degree: Optional[str] = None,
    use_hybrid_sc: bool = False,
    use_fast_scl: bool = True,
    cpu_only: bool = False,
    use_scatter: bool = False,
    ind_iil_inv: Optional[np.ndarray] = None,
    return_crc_status: bool = False,
    *,
    precision: Optional[str] = None,
    device: Optional[str] = None,
    **kwargs,
):
    """Validate the decoder parameters and set up all internal state.

    :raises TypeError: If ``n`` is not a number, ``list_size`` is not an
        `int`, one of the flags is not a `bool`, or ``frozen_pos`` is not
        integer-valued.
    :raises ValueError: If ``frozen_pos`` has more than ``n`` entries,
        ``n`` or ``list_size`` is not a positive power of 2, the CRC does
        not fit into the ``k`` non-frozen positions, ``ind_iil_inv`` does
        not have length ``k``, or ``return_crc_status``/``use_hybrid_sc``
        is requested without ``crc_degree``.
    :raises ArithmeticError: If the derived information positions do not
        match ``k`` (duplicate or out-of-range entries in ``frozen_pos``).
    """
    super().__init__(precision=precision, device=device, **kwargs)

    if not isinstance(n, numbers.Number):
        raise TypeError("n must be a number.")
    n = int(n)
    if not isinstance(list_size, int):
        raise TypeError("list_size must be integer.")
    for flag_name, flag in (
        ("cpu_only", cpu_only),
        ("use_scatter", use_scatter),
        ("use_fast_scl", use_fast_scl),
        ("use_hybrid_sc", use_hybrid_sc),
        ("return_crc_status", return_crc_status),
    ):
        if not isinstance(flag, bool):
            raise TypeError(f"{flag_name} must be bool.")

    # Accept array-likes and every integer dtype. ``np.issubdtype(dt, int)``
    # only matches NumPy's default integer type and would reject e.g. int32.
    frozen_pos = np.asarray(frozen_pos)
    if not np.issubdtype(frozen_pos.dtype, np.integer):
        raise TypeError("frozen_pos contains non int.")
    if len(frozen_pos) > n:
        msg = "Num. of elements in frozen_pos cannot be greater than n."
        raise ValueError(msg)
    # Exact integer power-of-two tests; values <= 0 are rejected with a
    # ValueError instead of crashing inside ``int(np.log2(...))``.
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of 2.")
    if list_size < 1 or list_size & (list_size - 1):
        raise ValueError("list_size must be a power of 2.")

    # CPU mode is recommended for larger values of n and L
    if not cpu_only and not use_hybrid_sc:
        if n > 128:
            warnings.warn(
                "Required resource allocation is large "
                "for the selected blocklength. Consider option `cpu_only=True`.",
                stacklevel=2,
            )
        if list_size > 32:
            warnings.warn(
                "Resource allocation is high for the "
                "selected list_size. Consider option `cpu_only=True`.",
                stacklevel=2,
            )

    # Internal decoder parameters
    self._use_fast_scl = use_fast_scl
    self._use_scatter = use_scatter
    self._cpu_only = cpu_only
    self._use_hybrid_sc = use_hybrid_sc

    # Store internal attributes
    self._n = n
    self._frozen_pos = frozen_pos
    self._k = self._n - len(self._frozen_pos)
    self._list_size = list_size
    self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
    self._llr_max = 30.0
    # A mismatch means frozen_pos had duplicate or out-of-range entries.
    if self._k != len(self._info_pos):
        raise ArithmeticError("Internal error: invalid info_pos generated.")

    # Frozen-bit indicator vector (1 = frozen, 0 = information bit)
    self._frozen_ind = np.zeros(self._n)
    self._frozen_ind[self._frozen_pos] = 1
    self._cw_ind = np.arange(self._n)
    self._n_stages = int(np.log2(self._n))

    # Register frozen indicator as tensor buffer for torch.compile compatibility
    self.register_buffer(
        "_frozen_ind_t",
        torch.tensor(self._frozen_ind, dtype=self.dtype, device=self.device),
    )
    # Register info_pos as tensor buffer to avoid torch.tensor() in call
    self.register_buffer(
        "_info_pos_t",
        torch.tensor(self._info_pos, dtype=torch.int64, device=self.device),
    )

    # Init CRC check (if needed)
    if crc_degree is not None:
        self._use_crc = True
        self._crc_encoder = CRCEncoder(
            crc_degree, precision=precision, device=device
        )
        self._crc_decoder = CRCDecoder(
            self._crc_encoder, precision=precision, device=device
        )
        self._k_crc = self._crc_decoder.encoder.crc_length
    else:
        self._use_crc = False
        self._k_crc = 0
    if self._k < self._k_crc:
        msg = "Value of k is too small for given CRC_degree."
        raise ValueError(msg)

    if crc_degree is None and return_crc_status:
        raise ValueError("Returning CRC status requires given crc_degree.")
    self._return_crc_status = return_crc_status

    # Store the inverse interleaver pattern (length k, i.e. incl. CRC bits)
    if ind_iil_inv is not None:
        if ind_iil_inv.shape[0] != self._k:
            raise ValueError(
                f"ind_iil_inv must be of length k={self._k} "
                "(including CRC bits)."
            )
        self._ind_iil_inv = ind_iil_inv
        self._iil = True
        # Register as tensor buffer to avoid torch.tensor() in call
        self.register_buffer(
            "_ind_iil_inv_t",
            torch.tensor(ind_iil_inv, dtype=torch.int32, device=self.device),
        )
    else:
        self._iil = False

    # Use SC decoder first and use numpy-based SCL as "afterburner".
    # Validate the CRC requirement before building the SC decoder.
    if self._use_hybrid_sc:
        if not self._use_crc:
            raise ValueError("Hybrid SC requires outer CRC.")
        self._decoder_sc = PolarSCDecoder(
            frozen_pos, n, precision=precision, device=device
        )
@property
def n(self) -> int:
    """int: Length ``n`` of the Polar codeword (validated as a power of 2)."""
    codeword_length = self._n
    return codeword_length
@property
def k(self) -> int:
    """int: Number of non-frozen (information) positions, ``n - len(frozen_pos)``."""
    num_info = self._k
    return num_info
@property
def k_crc(self) -> int:
    """int: Number of CRC bits (0 when no ``crc_degree`` was given)."""
    crc_bits = self._k_crc
    return crc_bits
@property
def frozen_pos(self) -> np.ndarray:
    """np.ndarray: Indices of the frozen bit positions used for decoding."""
    positions = self._frozen_pos
    return positions
@property
def info_pos(self) -> np.ndarray:
    """np.ndarray: Sorted indices of the non-frozen (information) positions."""
    positions = self._info_pos
    return positions
@property
def llr_max(self) -> float:
    """float: Clipping bound applied to LLRs in all internal updates."""
    bound = self._llr_max
    return bound
@property
def list_size(self) -> int:
    """int: List size ``L`` of the SCL decoder (``2 * L`` path slots internally)."""
    size = self._list_size
    return size
# NumPy-based decoder helper functions
def _cn_op_np(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Check node update (boxplus) for LLRs in NumPy.
See :cite:p:`Stimming_LLR` and :cite:p:`Hashemi_SSCL` for detailed equations.
"""
x_in = np.maximum(np.minimum(x, self._llr_max), -self._llr_max)
y_in = np.maximum(np.minimum(y, self._llr_max), -self._llr_max)
llr_out = np.log(1 + np.exp(x_in + y_in))
llr_out -= np.log(np.exp(x_in) + np.exp(y_in))
return llr_out
def _vn_op_np(
self, x: np.ndarray, y: np.ndarray, u_hat: np.ndarray
) -> np.ndarray:
"""Variable node update (boxplus) for LLRs in Numpy."""
return np.multiply((1 - 2 * u_hat), x) + y
def _update_rate0_code_np(self, cw_ind: np.ndarray) -> None:
"""Update rate-0 (i.e., all frozen) sub-code at pos ``cw_ind``.
See Eq. (26) in :cite:p:`Hashemi_SSCL`.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
ind = np.expand_dims(self._dec_pointer, axis=-1)
llr_in = np.take_along_axis(
self.msg_llr[:, :, stage_ind, cw_ind], ind, axis=1
)
llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
pm_val = np.log(1 + np.exp(-llr_clip))
self.msg_pm += np.sum(pm_val, axis=-1)
def _update_rep_code_np(self, cw_ind: np.ndarray) -> None:
"""Update rep. code sub-code at position ``cw_ind``.
See Eq. (31) in :cite:p:`Hashemi_SSCL`.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
bs = self._dec_pointer.shape[0]
llr = np.zeros([bs, 2 * self._list_size, n])
for i in range(bs):
llr_i = self.msg_llr[i, self._dec_pointer[i, :], stage_ind, :]
llr[i, :, :] = llr_i[:, cw_ind]
llr[:, self._list_size :, :] = -llr[:, self._list_size :, :]
llr_in = np.maximum(np.minimum(llr, self._llr_max), -self._llr_max)
pm_val = np.sum(np.log(1 + np.exp(-llr_in)), axis=-1)
self.msg_pm += pm_val
for i in range(bs):
ind_dec = self._dec_pointer[i, self._list_size :]
for j in cw_ind:
self.msg_uhat[i, ind_dec, stage_ind, j] = 1
self._update_single_bit_np([cw_ind[-1]])
self._sort_decoders_np()
self._duplicate_paths_np()
def _update_single_bit_np(self, ind_u: list) -> None:
"""Update single bit at position ``ind_u`` of all decoders."""
if self._frozen_ind[ind_u] == 0:
ind_dec = np.expand_dims(
self._dec_pointer[:, self._list_size :], axis=-1
)
uhat_slice = self.msg_uhat[:, :, 0, ind_u]
np.put_along_axis(uhat_slice, ind_dec, 1.0, axis=1)
self.msg_uhat[:, :, 0, ind_u] = uhat_slice
def _update_pm_np(self, ind_u: list) -> None:
"""Update path metric of all decoders at bit position ``ind_u``.
We apply Eq. (10) from :cite:p:`Stimming_LLR`.
"""
ind = np.expand_dims(self._dec_pointer, axis=-1)
u_hat = np.take_along_axis(self.msg_uhat[:, :, 0, ind_u], ind, axis=1)
u_hat = np.squeeze(u_hat, axis=-1)
llr_in = np.take_along_axis(self.msg_llr[:, :, 0, ind_u], ind, axis=1)
llr_in = np.squeeze(llr_in, axis=-1)
llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
self.msg_pm += np.log(
1 + np.exp(-np.multiply((1 - 2 * u_hat), llr_clip))
)
def _sort_decoders_np(self) -> None:
"""Sort decoders according to their path metric."""
ind = np.argsort(self.msg_pm, axis=-1)
self.msg_pm = np.take_along_axis(self.msg_pm, ind, axis=1)
self._dec_pointer = np.take_along_axis(self._dec_pointer, ind, axis=1)
def _duplicate_paths_np(self) -> None:
"""Copy first ``list_size``/2 paths into lower part.
Decoder indices are encoded in ``self._dec_pointer``.
"""
ind_low = self._dec_pointer[:, : self._list_size]
ind_up = self._dec_pointer[:, self._list_size :]
for i in range(ind_up.shape[0]):
self.msg_uhat[i, ind_up[i, :], :, :] = self.msg_uhat[
i, ind_low[i, :], :, :
]
self.msg_llr[i, ind_up[i, :], :, :] = self.msg_llr[
i, ind_low[i, :], :, :
]
self.msg_pm[:, self._list_size :] = self.msg_pm[:, : self._list_size]
def _polar_decode_scl_np(self, cw_ind: np.ndarray) -> None:
    """Recursive decoding function in NumPy.

    We follow the terminology from :cite:p:`Hashemi_SSCL` and
    :cite:p:`Stimming_LLR` and branch the messages into a `left` and `right`
    update paths until reaching a leaf node.
    Tree pruning as proposed in :cite:p:`Hashemi_SSCL` is used to minimize
    the tree depth while maintaining the same output.

    The function returns nothing; it reads and updates ``self.msg_llr``,
    ``self.msg_uhat``, ``self.msg_pm`` and ``self._dec_pointer`` (allocated
    by ``_decode_np_batch``).

    :param cw_ind: Codeword positions covered by the current tree node.
        ``len(cw_ind)`` is assumed to be a power of 2; the node's input
        LLRs are read from stage ``log2(len(cw_ind))``.
    """
    n = len(cw_ind)
    stage_ind = int(np.log2(n))
    if n > 1:
        if self._use_fast_scl:
            # Rate-0 node: every position frozen -> metric update only.
            if np.sum(self._frozen_ind[cw_ind]) == n:
                self._update_rate0_code_np(cw_ind)
                return
            # Repetition node: only the last position is an info bit.
            if (
                self._frozen_ind[cw_ind[-1]] == 0
                and np.sum(self._frozen_ind[cw_ind[:-1]]) == n - 1
            ):
                self._update_rep_code_np(cw_ind)
                return
        cw_ind_left = cw_ind[0 : int(n / 2)]
        cw_ind_right = cw_ind[int(n / 2) :]
        # Left branch: check-node (boxplus) update, stage s -> s-1
        llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
        llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]
        self.msg_llr[:, :, stage_ind - 1, cw_ind_left] = self._cn_op_np(
            llr_left, llr_right
        )
        self._polar_decode_scl_np(cw_ind_left)
        # Right branch: variable-node update using the left decisions.
        # The parent LLRs are read again because the left recursion may have
        # copied paths (``_duplicate_paths_np`` overwrites msg_llr rows).
        u_hat_left_up = self.msg_uhat[:, :, stage_ind - 1, cw_ind_left]
        llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
        llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]
        self.msg_llr[:, :, stage_ind - 1, cw_ind_right] = self._vn_op_np(
            llr_left, llr_right, u_hat_left_up
        )
        self._polar_decode_scl_np(cw_ind_right)
        # Combine u_hat: parent partial sums are [left XOR right, right]
        u_hat_left_up = self.msg_uhat[:, :, stage_ind - 1, cw_ind_left]
        u_hat_right_up = self.msg_uhat[:, :, stage_ind - 1, cw_ind_right]
        u_hat_left = (u_hat_left_up != u_hat_right_up) + 0
        u_hat = np.concatenate([u_hat_left, u_hat_right_up], axis=-1)
        self.msg_uhat[:, :, stage_ind, cw_ind] = u_hat
    else:
        # Leaf: branch on the bit (no-op if frozen) and update the metrics.
        self._update_single_bit_np(cw_ind)
        self._update_pm_np(cw_ind)
        # Only information bits create new candidates: keep the best
        # list_size paths and duplicate them for the next branching.
        if self._frozen_ind[cw_ind] == 0:
            self._sort_decoders_np()
            self._duplicate_paths_np()
def _decode_np_batch(
self, llr_ch: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Decode batch of ``llr_ch`` with Numpy decoder."""
bs = llr_ch.shape[0]
self.msg_uhat = np.zeros(
[bs, 2 * self._list_size, self._n_stages + 1, self._n]
)
self.msg_llr = np.zeros(
[bs, 2 * self._list_size, self._n_stages + 1, self._n]
)
self.msg_pm = np.zeros([bs, 2 * self._list_size])
self.msg_pm[:, 1 : self._list_size] = self._llr_max
self.msg_pm[:, self._list_size + 1 :] = self._llr_max
self._dec_pointer = np.arange(2 * self._list_size)
self._dec_pointer = np.tile(
np.expand_dims(self._dec_pointer, axis=0), [bs, 1]
)
self.msg_llr[:, :, self._n_stages, :] = np.expand_dims(llr_ch, axis=1)
self._polar_decode_scl_np(self._cw_ind)
self._sort_decoders_np()
for ind in range(bs):
self.msg_uhat[ind, :, :, :] = self.msg_uhat[
ind, self._dec_pointer[ind], :, :
]
return self.msg_uhat, self.msg_pm
def _decode_np_hybrid(
self,
llr_ch: np.ndarray,
u_hat_sc: np.ndarray,
crc_valid: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""Hybrid SCL decoding stage that decodes iff CRC from previous SC
decoding attempt failed.
This option avoids the usage of the high-complexity SCL decoder in
cases where SC would be sufficient. For further details we refer to
:cite:p:`Cammerer_Hybrid_SCL` (we use SC instead of the proposed BP
stage).
This decoder does not exactly implement SCL as the CRC can be
false positive after the SC stage. However, in these cases SCL+CRC
may also yield the wrong results.
"""
bs = llr_ch.shape[0]
crc_valid = np.squeeze(crc_valid, axis=-1)
ind_invalid = np.arange(bs)[np.invert(crc_valid)]
llr_ch_hyb = np.take(llr_ch, ind_invalid, axis=0)
msg_uhat_hyb, msg_pm_hyb = self._decode_np_batch(llr_ch_hyb)
msg_uhat = np.zeros([bs, 2 * self._list_size, 1, self._n])
msg_pm = np.ones([bs, 2 * self._list_size]) * self._llr_max * self.k
msg_pm[:, 0] = 0
msg_uhat[:, 0, 0, self._info_pos] = u_hat_sc
ind_hyb = 0
for ind in range(bs):
if not crc_valid[ind]:
msg_uhat[ind, :, 0, :] = msg_uhat_hyb[ind_hyb, :, 0, :]
msg_pm[ind, :] = msg_pm_hyb[ind_hyb, :]
ind_hyb += 1
return msg_uhat, msg_pm
# =========================================================================
# PyTorch tensor-based decoder helper functions (for GPU acceleration)
# =========================================================================
def _cn_op_pt(
self, x: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
"""Check-node update (boxplus) for LLR inputs in PyTorch.
Operations are performed element-wise.
See :cite:p:`Stimming_LLR` and :cite:p:`Hashemi_SSCL` for detailed equations.
"""
x_in = torch.clamp(x, min=-self._llr_max, max=self._llr_max)
y_in = torch.clamp(y, min=-self._llr_max, max=self._llr_max)
# Implements log(1+e^(x+y)) - log(e^x+e^y)
llr_out = F.softplus(x_in + y_in)
llr_out = llr_out - torch.logsumexp(
torch.stack([x_in, y_in], dim=-1), dim=-1
)
return llr_out
def _vn_op_pt(
self, x: torch.Tensor, y: torch.Tensor, u_hat: torch.Tensor
) -> torch.Tensor:
"""Variable node update for LLR inputs in PyTorch.
Operations are performed element-wise.
See :cite:p:`Stimming_LLR` and :cite:p:`Hashemi_SSCL` for detailed equations.
"""
return (1 - 2 * u_hat) * x + y
def _sort_decoders_pt(
self,
msg_pm: torch.Tensor,
msg_uhat: torch.Tensor,
msg_llr: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sort decoders according to their path metric in PyTorch."""
ind = torch.argsort(msg_pm, dim=-1)
# For msg_pm: [batch, 2*L]
msg_pm = torch.gather(msg_pm, 1, ind)
# For msg_uhat: [batch, 2*L, stages+1, n]
ind_expanded = ind.unsqueeze(-1).unsqueeze(-1).expand(
-1, -1, msg_uhat.shape[2], msg_uhat.shape[3]
)
msg_uhat = torch.gather(msg_uhat, 1, ind_expanded)
# For msg_llr: [batch, 2*L, stages+1, n]
ind_expanded = ind.unsqueeze(-1).unsqueeze(-1).expand(
-1, -1, msg_llr.shape[2], msg_llr.shape[3]
)
msg_llr = torch.gather(msg_llr, 1, ind_expanded)
return msg_pm, msg_uhat, msg_llr
def _duplicate_paths_pt(
self,
msg_uhat: torch.Tensor,
msg_llr: torch.Tensor,
msg_pm: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Duplicate paths by copying the upper branch into the lower one."""
# Take first list_size paths and tile them
msg_uhat = msg_uhat[:, : self._list_size, :, :].repeat(1, 2, 1, 1)
msg_llr = msg_llr[:, : self._list_size, :, :].repeat(1, 2, 1, 1)
msg_pm = msg_pm[:, : self._list_size].repeat(1, 2)
return msg_uhat, msg_llr, msg_pm
def _update_pm_pt(
self,
ind_u: np.ndarray,
msg_uhat: torch.Tensor,
msg_llr: torch.Tensor,
msg_pm: torch.Tensor,
) -> torch.Tensor:
"""Update path metric after updating bit_pos ``ind_u`` in PyTorch.
We implement Eq. (10) from :cite:p:`Stimming_LLR`.
"""
u_hat = msg_uhat[:, :, 0, ind_u[0]]
llr = msg_llr[:, :, 0, ind_u[0]]
llr_in = torch.clamp(llr, min=-self._llr_max, max=self._llr_max)
# Numerically stable: log(1 + exp(-x))
msg_pm = msg_pm + F.softplus(-(1 - 2 * u_hat) * llr_in)
return msg_pm
def _update_single_bit_pt(
self, ind_u: np.ndarray, msg_uhat: torch.Tensor
) -> torch.Tensor:
"""Update single bit at position ``ind_u`` for all decoders in PyTorch.
Uses branchless computation for torch.compile compatibility.
For info bits (non-frozen), sets upper half decoders' bit to 1.
For frozen bits, sets to 0 (no-op since frozen bits are 0).
"""
# Get info bit indicator from tensor buffer (1 if info, 0 if frozen)
# This avoids data-dependent Python control flow
is_info_bit = 1.0 - self._frozen_ind_t[ind_u[0]]
# Set upper half decoders' bit at position ind_u
# Value is 1 for info bits, 0 for frozen bits (branchless)
msg_uhat1 = msg_uhat[:, : self._list_size, :, :]
msg_uhat21 = msg_uhat[:, self._list_size :, 0:1, : ind_u[0]]
msg_uhat22 = msg_uhat[:, self._list_size :, 0:1, ind_u[0] + 1 :]
# Insert value: 1 if info bit, 0 if frozen
batch_size = msg_uhat.shape[0]
msg_insert = is_info_bit * torch.ones(
batch_size, self._list_size, 1, 1,
dtype=msg_uhat.dtype, device=msg_uhat.device
)
msg_uhat23 = torch.cat([msg_uhat21, msg_insert, msg_uhat22], dim=3)
msg_uhat24 = msg_uhat[:, self._list_size :, 1:, :]
msg_uhat2 = torch.cat([msg_uhat23, msg_uhat24], dim=2)
msg_uhat = torch.cat([msg_uhat1, msg_uhat2], dim=1)
return msg_uhat
def _update_rate0_code_pt(
self,
msg_pm: torch.Tensor,
msg_uhat: torch.Tensor,
msg_llr: torch.Tensor,
cw_ind: np.ndarray,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Update rate-0 sub-code (all frozen) at pos ``cw_ind`` in PyTorch.
See Eq. (26) in :cite:p:`Hashemi_SSCL`.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
llr = msg_llr[:, :, stage_ind, cw_ind[0] : cw_ind[-1] + 1]
llr_in = torch.clamp(llr, min=-self._llr_max, max=self._llr_max)
# Update path metric for complete sub-block
pm_val = F.softplus(-llr_in)
msg_pm = msg_pm + pm_val.sum(dim=-1)
return msg_pm, msg_uhat, msg_llr
def _update_rep_code_pt(
self,
msg_pm: torch.Tensor,
msg_uhat: torch.Tensor,
msg_llr: torch.Tensor,
cw_ind: np.ndarray,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Update rep. code sub-code at position ``cw_ind`` in PyTorch.
See Eq. (31) in :cite:p:`Hashemi_SSCL`.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
# Get LLRs for this sub-code
llr = msg_llr[:, :, stage_ind, cw_ind[0] : cw_ind[-1] + 1]
llr_in = torch.clamp(llr, min=-self._llr_max, max=self._llr_max)
# Upper branch has negative LLR values (bit is 1)
llr_low = llr_in[:, : self._list_size, :]
llr_up = -llr_in[:, self._list_size :, :]
llr_pm = torch.cat([llr_low, llr_up], dim=1)
pm_val = F.softplus(-llr_pm)
msg_pm = msg_pm + pm_val.sum(dim=-1)
# Set bits to 1 for upper branch decoders
# Use split/concat approach
msg_uhat1 = msg_uhat[:, : self._list_size, :, :]
msg_uhat21 = msg_uhat[:, self._list_size :, stage_ind : stage_ind + 1, : cw_ind[0]]
msg_uhat22 = msg_uhat[:, self._list_size :, stage_ind : stage_ind + 1, cw_ind[-1] + 1 :]
batch_size = msg_uhat.shape[0]
msg_ones = torch.ones(
batch_size, self._list_size, 1, n,
dtype=msg_uhat.dtype, device=msg_uhat.device
)
msg_uhat23 = torch.cat([msg_uhat21, msg_ones, msg_uhat22], dim=3)
msg_uhat24_1 = msg_uhat[:, self._list_size :, :stage_ind, :]
msg_uhat24_2 = msg_uhat[:, self._list_size :, stage_ind + 1 :, :]
msg_uhat2 = torch.cat([msg_uhat24_1, msg_uhat23, msg_uhat24_2], dim=2)
msg_uhat = torch.cat([msg_uhat1, msg_uhat2], dim=1)
# Branch last bit and update
msg_uhat = self._update_single_bit_pt([cw_ind[-1]], msg_uhat)
msg_pm, msg_uhat, msg_llr = self._sort_decoders_pt(msg_pm, msg_uhat, msg_llr)
msg_uhat, msg_llr, msg_pm = self._duplicate_paths_pt(msg_uhat, msg_llr, msg_pm)
return msg_pm, msg_uhat, msg_llr
def _update_left_branch_pt(
self,
msg_llr: torch.Tensor,
stage_ind: int,
cw_ind_left: np.ndarray,
cw_ind_right: np.ndarray,
) -> torch.Tensor:
"""Update messages of left branch in PyTorch."""
llr_left_in = msg_llr[:, :, stage_ind, cw_ind_left[0] : cw_ind_left[-1] + 1]
llr_right_in = msg_llr[:, :, stage_ind, cw_ind_right[0] : cw_ind_right[-1] + 1]
llr_left_out = self._cn_op_pt(llr_left_in, llr_right_in)
# Use split/concatenation approach
llr_left0 = msg_llr[:, :, stage_ind - 1, : cw_ind_left[0]]
llr_right = msg_llr[:, :, stage_ind - 1, cw_ind_right[0] : cw_ind_right[-1] + 1]
llr_right1 = msg_llr[:, :, stage_ind - 1, cw_ind_right[-1] + 1 :]
llr_s = torch.cat([llr_left0, llr_left_out, llr_right, llr_right1], dim=2)
llr_s = llr_s.unsqueeze(2)
msg_llr1 = msg_llr[:, :, : stage_ind - 1, :]
msg_llr2 = msg_llr[:, :, stage_ind:, :]
msg_llr = torch.cat([msg_llr1, llr_s, msg_llr2], dim=2)
return msg_llr
def _update_right_branch_pt(
self,
msg_llr: torch.Tensor,
msg_uhat: torch.Tensor,
stage_ind: int,
cw_ind_left: np.ndarray,
cw_ind_right: np.ndarray,
) -> torch.Tensor:
"""Update messages for right branch in PyTorch."""
u_hat_left_up = msg_uhat[:, :, stage_ind - 1, cw_ind_left[0] : cw_ind_left[-1] + 1]
llr_left_in = msg_llr[:, :, stage_ind, cw_ind_left[0] : cw_ind_left[-1] + 1]
llr_right = msg_llr[:, :, stage_ind, cw_ind_right[0] : cw_ind_right[-1] + 1]
llr_right_out = self._vn_op_pt(llr_left_in, llr_right, u_hat_left_up)
# Use split/concatenation approach
llr_left0 = msg_llr[:, :, stage_ind - 1, : cw_ind_left[0]]
llr_left = msg_llr[:, :, stage_ind - 1, cw_ind_left[0] : cw_ind_left[-1] + 1]
llr_right1 = msg_llr[:, :, stage_ind - 1, cw_ind_right[-1] + 1 :]
llr_s = torch.cat([llr_left0, llr_left, llr_right_out, llr_right1], dim=2)
llr_s = llr_s.unsqueeze(2)
msg_llr1 = msg_llr[:, :, : stage_ind - 1, :]
msg_llr2 = msg_llr[:, :, stage_ind:, :]
msg_llr = torch.cat([msg_llr1, llr_s, msg_llr2], dim=2)
return msg_llr
def _update_branch_u_pt(
self,
msg_uhat: torch.Tensor,
stage_ind: int,
cw_ind_left: np.ndarray,
cw_ind_right: np.ndarray,
) -> torch.Tensor:
"""Update ``u_hat`` messages after executing both branches in PyTorch."""
u_hat_left_up = msg_uhat[:, :, stage_ind - 1, cw_ind_left[0] : cw_ind_left[-1] + 1]
u_hat_right_up = msg_uhat[:, :, stage_ind - 1, cw_ind_right[0] : cw_ind_right[-1] + 1]
# Combine u_hat via XOR
u_hat_left = (u_hat_left_up.int() ^ u_hat_right_up.int()).to(msg_uhat.dtype)
# Use split/concatenation approach
u_hat_left_0 = msg_uhat[:, :, stage_ind, : cw_ind_left[0]]
u_hat_right_1 = msg_uhat[:, :, stage_ind, cw_ind_right[-1] + 1 :]
u_hat = torch.cat([u_hat_left_0, u_hat_left, u_hat_right_up, u_hat_right_1], dim=2)
msg_uhat1 = msg_uhat[:, :, :stage_ind, :]
msg_uhat2 = msg_uhat[:, :, stage_ind + 1 :, :]
u_hat = u_hat.unsqueeze(2)
msg_uhat = torch.cat([msg_uhat1, u_hat, msg_uhat2], dim=2)
return msg_uhat
def _polar_decode_scl_pt(
    self,
    cw_ind: np.ndarray,
    msg_uhat: torch.Tensor,
    msg_llr: torch.Tensor,
    msg_pm: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Recursive decoding function for SCL decoding in PyTorch.

    We follow the terminology from :cite:p:`Hashemi_SSCL` and
    :cite:p:`Stimming_LLR` and branch the messages into a `left` and `right`
    update path until reaching a leaf node. Tree pruning as proposed in
    :cite:p:`Hashemi_SSCL` is used to minimize the tree depth while
    maintaining the same output.

    :param cw_ind: Codeword indices covered by the current tree node.
    :param msg_uhat: Hard-decision messages of all decoders.
    :param msg_llr: LLR messages of all decoders.
    :param msg_pm: Path metrics of all decoders.
    :return: ``(msg_uhat, msg_llr, msg_pm)``. Note that this order differs
        from the argument order.
    """
    n = len(cw_ind)
    stage_ind = int(np.log2(n))

    # Leaf node: take the bit decision and update the path metric. Only
    # non-frozen positions trigger sorting and path duplication.
    if n <= 1:
        msg_uhat = self._update_single_bit_pt(cw_ind, msg_uhat)
        msg_pm = self._update_pm_pt(cw_ind, msg_uhat, msg_llr, msg_pm)
        if self._frozen_ind[cw_ind] == 0:
            msg_pm, msg_uhat, msg_llr = self._sort_decoders_pt(
                msg_pm, msg_uhat, msg_llr
            )
            msg_uhat, msg_llr, msg_pm = self._duplicate_paths_pt(
                msg_uhat, msg_llr, msg_pm
            )
        return msg_uhat, msg_llr, msg_pm

    # Fast-SCL pruning: whole rate-0 sub-codes and repetition sub-codes
    # (only the last bit is non-frozen) are decoded in one shot.
    if self._use_fast_scl:
        if np.sum(self._frozen_ind[cw_ind]) == n:
            msg_pm, msg_uhat, msg_llr = self._update_rate0_code_pt(
                msg_pm, msg_uhat, msg_llr, cw_ind
            )
            return msg_uhat, msg_llr, msg_pm
        is_rep_code = (
            self._frozen_ind[cw_ind[-1]] == 0
            and np.sum(self._frozen_ind[cw_ind[:-1]]) == n - 1
        )
        if is_rep_code:
            msg_pm, msg_uhat, msg_llr = self._update_rep_code_pt(
                msg_pm, msg_uhat, msg_llr, cw_ind
            )
            return msg_uhat, msg_llr, msg_pm

    half = n // 2
    ind_left, ind_right = cw_ind[:half], cw_ind[half:]

    # Left child: pass LLRs down, then decode it recursively.
    msg_llr = self._update_left_branch_pt(msg_llr, stage_ind, ind_left, ind_right)
    msg_uhat, msg_llr, msg_pm = self._polar_decode_scl_pt(
        ind_left, msg_uhat, msg_llr, msg_pm
    )

    # Right child: uses the decisions of the left child.
    msg_llr = self._update_right_branch_pt(
        msg_llr, msg_uhat, stage_ind, ind_left, ind_right
    )
    msg_uhat, msg_llr, msg_pm = self._polar_decode_scl_pt(
        ind_right, msg_uhat, msg_llr, msg_pm
    )

    # Merge both children's decisions into this node.
    msg_uhat = self._update_branch_u_pt(msg_uhat, stage_ind, ind_left, ind_right)
    return msg_uhat, msg_llr, msg_pm
@torch.compiler.disable
def _decode_pt_hybrid(
self, llr_ch: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Hybrid SC/SCL decoding in PyTorch.
Runs SC decoding first, checks CRC, then runs SCL only for samples
with failed CRC. This is more efficient than full SCL when most
samples decode correctly.
Note:
This function is marked with @torch.compiler.disable because
it uses data-dependent conditional logic that causes graph breaks.
"""
batch_size = llr_ch.shape[0]
device = llr_ch.device
dtype = llr_ch.dtype
# Step 1: Run SC decoding on all samples
u_hat_sc = self._decoder_sc(-llr_ch)
# Step 2: Check CRC to find which samples need SCL
# Apply input bit interleaver inverse before CRC check if needed
if self._iil:
u_hat_sc_crc = u_hat_sc[:, self._ind_iil_inv_t]
else:
u_hat_sc_crc = u_hat_sc
_, crc_valid = self._crc_decoder(u_hat_sc_crc)
crc_valid = crc_valid.squeeze(-1) # [batch_size]
# Step 3: Initialize output with SC results
msg_uhat = torch.zeros(
batch_size, 2 * self._list_size, 1, self._n,
dtype=dtype, device=device
)
msg_pm = torch.ones(
batch_size, 2 * self._list_size,
dtype=dtype, device=device
) * self._llr_max * self.k
msg_pm[:, 0] = 0 # SC result has zero path metric
# Place SC results in first decoder slot at info positions
msg_uhat[:, 0, 0, self._info_pos_t] = u_hat_sc
# Step 4: Find samples with invalid CRC
invalid_mask = ~crc_valid
invalid_indices = torch.nonzero(invalid_mask, as_tuple=True)[0]
# Step 5: Run SCL only on invalid samples (if any)
if invalid_indices.numel() > 0:
llr_invalid = llr_ch[invalid_indices]
msg_uhat_scl, msg_pm_scl = self._decode_pt(llr_invalid)
# Merge SCL results into output
# msg_uhat_scl has shape [num_invalid, 2*L, stages+1, n]
# We only need the final stage (index 0 after sorting)
msg_uhat[invalid_indices] = msg_uhat_scl[:, :, 0:1, :]
msg_pm[invalid_indices] = msg_pm_scl
return msg_uhat, msg_pm
@torch.compiler.disable
def _decode_pt(
self, llr_ch: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Main decoding function in PyTorch.
Initializes memory and calls recursive decoding function.
Note:
This function is marked with @torch.compiler.disable because the
recursive SCL algorithm uses position-dependent slicing that causes
graph breaks. Disabling compilation here allows the rest of the
model to be compiled while this runs in eager mode.
"""
batch_size = llr_ch.shape[0]
device = llr_ch.device
dtype = llr_ch.dtype
# Allocate memory for all 2*list_size decoders
msg_uhat = torch.zeros(
batch_size, 2 * self._list_size, self._n_stages + 1, self._n,
dtype=dtype, device=device
)
msg_llr = torch.zeros(
batch_size, 2 * self._list_size, self._n_stages, self._n,
dtype=dtype, device=device
)
# Init all 2*L decoders with same llr_ch
llr_ch_expanded = llr_ch.reshape(-1, 1, 1, self._n)
llr_ch_expanded = llr_ch_expanded.expand(-1, 2 * self._list_size, 1, -1)
# Init last stage with llr_ch
msg_llr = torch.cat([msg_llr, llr_ch_expanded], dim=2)
# Init all remaining L-1 decoders with high penalty
pm0 = torch.zeros(batch_size, 1, dtype=dtype, device=device)
pm1 = self._llr_max * torch.ones(
batch_size, self._list_size - 1, dtype=dtype, device=device
)
msg_pm = torch.cat([pm0, pm1, pm0, pm1], dim=1)
# Call recursive graph function
msg_uhat, msg_llr, msg_pm = self._polar_decode_scl_pt(
self._cw_ind, msg_uhat, msg_llr, msg_pm
)
# Sort output
msg_pm, msg_uhat, msg_llr = self._sort_decoders_pt(
msg_pm, msg_uhat, msg_llr
)
return msg_uhat, msg_pm
[docs]
def build(self, input_shape: Tuple[int, ...]) -> None:
    """Validate the input shape.

    :param input_shape: Shape of the channel input.
    :raises ValueError: If the last dimension is not the codeword length ``n``.
    """
    last_dim = input_shape[-1]
    if last_dim == self._n:
        return
    raise ValueError("Invalid input shape.")
def call(
    self, llr_ch: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Successive cancellation list (SCL) decoding function.

    This function performs successive cancellation list decoding
    and returns the estimated information bits.
    An outer CRC can be applied optionally by setting ``crc_degree``.

    :param llr_ch: Tensor of shape `[..., n]` containing the
        channel LLR values (as logits).

    :output b_hat: Tensor of shape `[..., k]` containing hard-decided
        estimations of all ``k`` information bits. CRC parity bits (if
        any) are part of these ``k`` bits and are not removed here.
    :output crc_status: CRC status. Returned only if
        ``return_crc_status`` is `True`.

    Note: This function recursively unrolls the SCL decoding tree,
    thus, for larger values of ``n`` building the decoding graph can
    become time consuming. Please consider the ``cpu_only`` option
    instead.
    """
    input_shape = llr_ch.shape
    batch_shape = list(input_shape[:-1])

    # Flatten all batch dimensions and convert logits to LLRs.
    llr = -1.0 * llr_ch.reshape((-1, self._n))

    if self._use_hybrid_sc:
        # SC first, SCL only for samples whose SC estimate fails the CRC.
        msg_uhat, msg_pm = self._decode_pt_hybrid(llr)
    elif self._cpu_only:
        # NumPy-based decoder; its outputs are converted back to tensors.
        uhat_np, pm_np = self._decode_np_batch(llr.cpu().numpy())
        msg_uhat = torch.tensor(uhat_np, dtype=self.dtype, device=self.device)
        msg_pm = torch.tensor(pm_np, dtype=self.dtype, device=self.device)
    else:
        msg_uhat, msg_pm = self._decode_pt(llr)

    if self._use_crc:
        # Information bits (incl. CRC) of every candidate, read at stage 0.
        cand_bits = msg_uhat[:, :, 0, self._info_pos_t]
        if self._iil:
            cand_bits = cand_bits[:, :, self._ind_iil_inv_t]
        _, crc_valid = self._crc_decoder(cand_bits)
        # Penalise candidates that fail the CRC so that valid ones win.
        penalty = (1.0 - crc_valid.float()) * self._llr_max * self.k
        msg_pm = msg_pm + penalty.squeeze(-1)

    # Pick the candidate with the smallest path metric per sample.
    rows = torch.arange(msg_uhat.shape[0], device=msg_uhat.device)
    best = torch.argmin(msg_pm, dim=-1)
    u_hat = msg_uhat[rows, best, 0, :][:, self._info_pos_t]
    u_hat = u_hat.reshape(batch_shape + [self.k])

    if not self._return_crc_status:
        return u_hat
    crc_status = crc_valid[rows, best].reshape(batch_shape)
    return u_hat, crc_status
[docs]
class PolarBPDecoder(Block):
    # pylint: disable=line-too-long
    """Belief propagation (BP) decoder for Polar codes :cite:p:`Arikan_Polar` and
    Polar-like codes based on :cite:p:`Arikan_BP` and :cite:p:`Forney_Graphs`.

    :param frozen_pos: Array of `int` defining the ``n-k`` indices of the
        frozen positions.
    :param n: Defining the codeword length. Must be a positive power of 2.
    :param num_iter: Defining the number of decoder iterations (no early
        stopping used at the moment). Must be positive.
    :param hard_out: If `True`, the decoder provides hard-decided
        information bits instead of soft-values.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input llr_ch: [..., n], `torch.float`.
        Tensor containing the channel logits/llr values.

    :output u_hat: [..., k], `torch.float`.
        Tensor containing bit-wise soft-estimates (or hard-decided
        bit-values) of all ``k`` information bits.

    .. rubric:: Notes

    This decoder is fully differentiable and, thus, well-suited for
    gradient descent-based learning tasks such as `learned code design`
    :cite:p:`Ebada_Design`.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear codes as the
    `all-zero` codeword is not necessarily part of the code anymore.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import PolarBPDecoder, PolarEncoder
        from sionna.phy.fec.polar.utils import generate_5g_ranking

        k, n = 100, 256
        frozen_pos, _ = generate_5g_ranking(k, n)
        encoder = PolarEncoder(frozen_pos, n)
        decoder = PolarBPDecoder(frozen_pos, n, num_iter=20)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        llr_ch = 20.0 * (2.0 * codewords - 1)  # BPSK without noise
        decoded = decoder(llr_ch)
        print(torch.equal(bits, decoded))
        # True
    """

    def __init__(
        self,
        frozen_pos: np.ndarray,
        n: int,
        num_iter: int = 20,
        hard_out: bool = True,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)
        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos) > n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        # Check n > 0 first. Otherwise np.log2 yields -inf/nan and int() fails
        # with an unrelated OverflowError/ValueError.
        if n <= 0 or np.log2(n) != int(np.log2(n)):
            raise ValueError("n must be a power of 2.")
        if not isinstance(hard_out, bool):
            raise TypeError("hard_out must be boolean.")

        # Store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        if self._k != len(self._info_pos):
            raise ArithmeticError("Internal error: invalid info_pos generated.")

        # Register info_pos as buffer for torch.compile compatibility
        self.register_buffer(
            "_info_pos_t",
            torch.tensor(self._info_pos, dtype=torch.int64, device=self.device),
        )

        if not isinstance(num_iter, int):
            raise TypeError("num_iter must be integer.")
        if num_iter <= 0:
            raise ValueError("num_iter must be a positive value.")
        self._num_iter = num_iter

        self._llr_max = 19.3
        self._hard_out = hard_out
        self._n_stages = int(np.log2(self._n))

    @property
    def n(self) -> int:
        """Codeword length."""
        return self._n

    @property
    def k(self) -> int:
        """Number of information bits."""
        return self._k

    @property
    def frozen_pos(self) -> np.ndarray:
        """Frozen positions for Polar decoding."""
        return self._frozen_pos

    @property
    def info_pos(self) -> np.ndarray:
        """Information bit positions for Polar encoding."""
        return self._info_pos

    @property
    def llr_max(self) -> float:
        """Maximum LLR value for internal calculations."""
        return self._llr_max

    @property
    def num_iter(self) -> int:
        """Number of decoding iterations."""
        return self._num_iter

    @num_iter.setter
    def num_iter(self, num_iter: int) -> None:
        """Number of decoding iterations."""
        if not isinstance(num_iter, int):
            raise ValueError("num_iter must be int.")
        if num_iter < 0:
            raise ValueError("num_iter cannot be negative.")
        # With zero iterations no stage-0 messages exist to read the
        # estimates from, so reject it here (as ``__init__`` does).
        if num_iter == 0:
            raise ValueError("num_iter must be a positive value.")
        self._num_iter = num_iter

    @property
    def hard_out(self) -> bool:
        """Indicates if decoder hard-decides outputs."""
        return self._hard_out

    def _boxplus(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Check-node update (boxplus) for LLR inputs.

        Inputs are clipped to ``[-llr_max, llr_max]`` before evaluating
        ``log(1 + e^(x+y)) - log(e^x + e^y)``.
        """
        x_in = torch.clamp(x, min=-self._llr_max, max=self._llr_max)
        y_in = torch.clamp(y, min=-self._llr_max, max=self._llr_max)
        llr_out = torch.log(1 + torch.exp(x_in + y_in))
        llr_out = llr_out - torch.log(torch.exp(x_in) + torch.exp(y_in))
        return llr_out

    def _stage_indices(
        self, ind_s: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return the butterfly index patterns of stage ``ind_s``.

        :return: ``(ind_1, ind_2, ind_inv)``, where

            * ``ind_1``/``ind_2`` are the first/second input positions of the
              ``n/2`` butterflies;
            * ``ind_inv`` restores natural order after the outputs are
              concatenated as ``[out_1, out_2]``.
        """
        ind_range = np.arange(self._n // 2)
        ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
        ind_2 = ind_1 + 2**ind_s
        ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))
        return ind_1, ind_2, ind_inv

    def _decode_bp(
        self, llr_ch: torch.Tensor, num_iter: int
    ) -> torch.Tensor:
        """Iterative BP decoding function with LLR-values.

        :param llr_ch: Channel LLRs of shape ``[batch_size, n]``.
        :param num_iter: Number of BP iterations. Must be positive.
        :return: Estimates of the ``k`` information bits, shape
            ``[batch_size, k]``. These are hard decisions if ``hard_out``,
            otherwise logits.
        :raises ValueError: If ``num_iter`` is not positive.
        """
        bs = llr_ch.shape[0]
        device = llr_ch.device
        n_stages = self._n_stages
        n_half = self._n // 2

        # The index patterns depend only on the stage. Compute them once
        # instead of twice per stage and iteration.
        stage_ind = [self._stage_indices(ind_s) for ind_s in range(n_stages)]

        # Init frozen positions with (clipped) infinity
        msg_r_in = torch.zeros((bs, self._n), dtype=self.dtype, device=device)
        msg_r_in[:, self._frozen_pos] = self._llr_max

        # Right-to-left messages of the previous iteration. Older iterations
        # are never read again, so they are not kept.
        msg_l_prev = None
        for ind_it in range(num_iter):
            msg_r = [None] * (n_stages + 1)
            msg_l = [None] * (n_stages + 1)

            # Update left-to-right messages
            for ind_s in range(n_stages):
                ind_1, ind_2, ind_inv = stage_ind[ind_s]
                # Load incoming l messages
                if ind_s == n_stages - 1:
                    l1_in = llr_ch[:, ind_1]
                    l2_in = llr_ch[:, ind_2]
                elif ind_it == 0:
                    l1_in = torch.zeros(
                        (bs, n_half), dtype=self.dtype, device=device
                    )
                    l2_in = torch.zeros(
                        (bs, n_half), dtype=self.dtype, device=device
                    )
                else:
                    l_in = msg_l_prev[ind_s + 1]
                    l1_in = l_in[:, ind_1]
                    l2_in = l_in[:, ind_2]
                # Load incoming r messages
                r_in = msg_r_in if ind_s == 0 else msg_r[ind_s]
                r1_in = r_in[:, ind_1]
                r2_in = r_in[:, ind_2]

                r1_out = self._boxplus(r1_in, l2_in + r2_in)
                r2_out = self._boxplus(r1_in, l1_in) + r2_in
                # Re-concatenate output in natural order
                msg_r[ind_s + 1] = torch.cat([r1_out, r2_out], 1)[:, ind_inv]

            # Update right-to-left messages
            for ind_s in range(n_stages - 1, -1, -1):
                ind_1, ind_2, ind_inv = stage_ind[ind_s]
                l_in = llr_ch if ind_s == n_stages - 1 else msg_l[ind_s + 1]
                l1_in = l_in[:, ind_1]
                l2_in = l_in[:, ind_2]
                r_in = msg_r_in if ind_s == 0 else msg_r[ind_s]
                r1_in = r_in[:, ind_1]
                r2_in = r_in[:, ind_2]

                # Node update functions
                l1_out = self._boxplus(l1_in, l2_in + r2_in)
                l2_out = self._boxplus(r1_in, l1_in) + l2_in
                msg_l[ind_s] = torch.cat([l1_out, l2_out], 1)[:, ind_inv]

            msg_l_prev = msg_l

        if msg_l_prev is None:
            raise ValueError("num_iter must be a positive value.")

        # Recover u_hat using pre-registered buffer
        u_hat = msg_l_prev[0][:, self._info_pos_t]
        if self._hard_out:
            u_hat = torch.where(
                u_hat > 0,
                torch.zeros_like(u_hat),
                torch.ones_like(u_hat),
            )
        else:
            u_hat = -1.0 * u_hat  # Re-transform to logits
        return u_hat

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build and check if shape of input is invalid."""
        if input_shape[-1] != self._n:
            raise ValueError("Invalid input shape")

    def call(self, llr_ch: torch.Tensor) -> torch.Tensor:
        """Iterative BP decoding function.

        This function performs ``num_iter`` belief propagation decoding
        iterations and returns the estimated information bits.

        :param llr_ch: Tensor of shape `[..., n]` containing the
            channel logits/llr values.

        :output u_hat: Tensor of shape `[..., k]` containing bit-wise
            soft-estimates (or hard-decided bit-values) of all ``k``
            information bits.

        Note: All ``num_iter`` iterations are always executed (no early
        stopping). When gradients are tracked, the autograd graph, and thus
        memory, grows with ``n`` and ``num_iter``.
        """
        # Reshape inputs to [-1, n]
        input_shape = llr_ch.shape
        llr_ch = llr_ch.reshape((-1, self._n))
        llr_ch = -1.0 * llr_ch  # Logits to LLRs

        # Decode
        u_hat = self._decode_bp(llr_ch, self._num_iter)

        # Reconstruct input shape
        output_shape = list(input_shape[:-1]) + [self.k]
        return u_hat.reshape(output_shape)
[docs]
class Polar5GDecoder(Block):
    # pylint: disable=line-too-long
    """Wrapper for 5G compliant decoding including rate-recovery and CRC
    removal.

    :param enc_polar: Instance of the
        :class:`~sionna.phy.fec.polar.encoding.Polar5GEncoder` used for
        encoding including rate-matching.
    :param dec_type: Defining the decoder to be used. Must be one of
        `{"SC", "SCL", "hybSCL", "BP"}`.
    :param list_size: Defining the list size iff list-decoding is used.
        Only required for ``dec_types`` `{"SCL", "hybSCL"}`.
    :param num_iter: Defining the number of BP iterations. Only required
        for ``dec_type`` `"BP"`.
    :param return_crc_status: If `True`, the decoder additionally returns
        the CRC status indicating if a codeword was (most likely) correctly
        recovered.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input llr_ch: [..., n], `torch.float`.
        Tensor containing the channel logits/llr values.

    :output b_hat: [..., k], `torch.float`.
        Binary tensor containing hard-decided estimations of all `k`
        information bits.
    :output crc_status: [...], `torch.bool`.
        CRC status indicating if a codeword was (most likely) correctly
        recovered. This is only returned if ``return_crc_status`` is `True`.
        Note that false positives are possible.

    .. rubric:: Notes

    This block supports the uplink and downlink Polar rate-matching scheme
    without `codeword segmentation`.

    Although the decoding `list size` is not provided by 3GPP
    :cite:p:`3GPPTS38212`, the consortium has agreed on a `list size` of 8 for
    the 5G decoding reference curves :cite:p:`Bioglio_Design`.

    All list decoders apply `CRC-aided` decoding. However, the non-list
    decoders (`"SC"` and `"BP"`) cannot exploit the CRC, which leads to an
    effective rate-loss.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import Polar5GEncoder, Polar5GDecoder

        k, n = 100, 200
        encoder = Polar5GEncoder(k, n)
        decoder = Polar5GDecoder(encoder, dec_type="SCL", list_size=8)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        llr_ch = 20.0 * (2.0 * codewords - 1)  # BPSK without noise
        decoded = decoder(llr_ch)
        print(torch.equal(bits, decoded))
        # True
    """

    def __init__(
        self,
        enc_polar: Polar5GEncoder,
        dec_type: str = "SC",
        list_size: int = 8,
        num_iter: int = 20,
        return_crc_status: bool = False,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(enc_polar, Polar5GEncoder):
            raise TypeError("enc_polar must be Polar5GEncoder.")
        if not isinstance(dec_type, str):
            raise TypeError("dec_type must be str.")

        # Store internal attributes
        self._n_target = enc_polar.n_target
        self._k_target = enc_polar.k_target
        self._n_polar = enc_polar.n_polar
        self._k_polar = enc_polar.k_polar
        self._k_crc = enc_polar.enc_crc.crc_length
        # The uplink uses the channel interleaver (undone in ``call``); the
        # downlink uses the input-bit interleaver.
        self._bil = enc_polar._channel_type == "uplink"
        self._iil = enc_polar._channel_type == "downlink"
        self._llr_max = 100
        self._enc_polar = enc_polar
        self._dec_type = dec_type

        # Initialize the de-interleaver patterns
        self._init_interleavers()

        # Initialize decoder
        if dec_type == "SC":
            self._warn_crc_unused("SC")
            self._polar_dec = PolarSCDecoder(
                self._enc_polar.frozen_pos,
                self._n_polar,
                precision=precision,
                device=device,
            )
        elif dec_type == "SCL":
            self._polar_dec = PolarSCLDecoder(
                self._enc_polar.frozen_pos,
                self._n_polar,
                crc_degree=self._enc_polar.enc_crc.crc_degree,
                list_size=list_size,
                ind_iil_inv=self.ind_iil_inv,
                precision=precision,
                device=device,
            )
        elif dec_type == "hybSCL":
            self._polar_dec = PolarSCLDecoder(
                self._enc_polar.frozen_pos,
                self._n_polar,
                crc_degree=self._enc_polar.enc_crc.crc_degree,
                list_size=list_size,
                use_hybrid_sc=True,
                ind_iil_inv=self.ind_iil_inv,
                precision=precision,
                device=device,
            )
        elif dec_type == "BP":
            self._warn_crc_unused("BP")
            if not isinstance(num_iter, int):
                raise TypeError("num_iter must be int.")
            if num_iter <= 0:
                raise ValueError("num_iter must be positive.")
            self._num_iter = num_iter
            self._polar_dec = PolarBPDecoder(
                self._enc_polar.frozen_pos,
                self._n_polar,
                num_iter=num_iter,
                hard_out=True,
                precision=precision,
                device=device,
            )
        else:
            raise ValueError("Unknown value for dec_type.")

        if not isinstance(return_crc_status, bool):
            raise TypeError("return_crc_status must be bool.")
        self._return_crc_status = return_crc_status
        if self._return_crc_status:
            if dec_type in ("SCL", "hybSCL"):
                # Re-use the CRC decoder of the list decoder
                self._dec_crc = self._polar_dec._crc_decoder
            else:
                self._dec_crc = CRCDecoder(
                    self._enc_polar.enc_crc,
                    precision=precision,
                    device=device,
                )

    @staticmethod
    def _warn_crc_unused(dec_name: str) -> None:
        """Warn that decoder ``dec_name`` cannot exploit the 5G CRC."""
        # Use the warnings machinery (filterable, deduplicated) instead of
        # print(). stacklevel=3 points at the caller of ``__init__``.
        warnings.warn(
            "5G Polar codes use an integrated CRC that cannot be "
            f"materialized with {dec_name} decoding and, thus, causes a "
            "degraded performance. Please consider SCL decoding instead.",
            stacklevel=3,
        )

    @property
    def k_target(self) -> int:
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n_target(self) -> int:
        """Codeword length including rate-matching."""
        return self._n_target

    @property
    def k_polar(self) -> int:
        """Number of information bits of mother Polar code."""
        return self._k_polar

    @property
    def n_polar(self) -> int:
        """Codeword length of mother Polar code."""
        return self._n_polar

    @property
    def llr_max(self) -> float:
        """Maximum LLR value for internal calculations."""
        return self._llr_max

    @property
    def dec_type(self) -> str:
        """Decoder type used for decoding as str."""
        return self._dec_type

    @property
    def polar_dec(self):
        """Decoder instance used for decoding."""
        return self._polar_dec

    def _init_interleavers(self) -> None:
        """Initialize inverse interleaver patterns for rate-recovery."""
        # Channel interleaver
        ind_ch_int = self._enc_polar.channel_interleaver(
            np.arange(self._n_target)
        )
        self.ind_ch_int_inv = np.argsort(ind_ch_int)

        # Sub-block interleaver
        ind_sub_int = self._enc_polar.subblock_interleaving(
            np.arange(self._n_polar)
        )
        self.ind_sub_int_inv = np.argsort(ind_sub_int)

        # Input bit interleaver
        if self._iil:
            self.ind_iil_inv = np.argsort(
                self._enc_polar.input_interleaver(np.arange(self._k_polar))
            )
        else:
            self.ind_iil_inv = None

        # Register as buffers for torch.compile compatibility. int64 matches
        # the other index buffers of this module and is accepted as an index
        # dtype by all PyTorch versions.
        self.register_buffer(
            "_ind_ch_int_inv_t",
            torch.tensor(
                self.ind_ch_int_inv, dtype=torch.int64, device=self.device
            ),
        )
        self.register_buffer(
            "_ind_sub_int_inv_t",
            torch.tensor(
                self.ind_sub_int_inv, dtype=torch.int64, device=self.device
            ),
        )
        if self._iil:
            self.register_buffer(
                "_ind_iil_inv_t",
                torch.tensor(
                    self.ind_iil_inv, dtype=torch.int64, device=self.device
                ),
            )
        else:
            self._ind_iil_inv_t = None

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build and check if shape of input is invalid."""
        if input_shape[-1] != self._n_target:
            raise ValueError("Invalid input shape.")

    def call(
        self, llr_ch: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Polar decoding and rate-recovery for uplink and downlink 5G Polar codes.

        :param llr_ch: Tensor of shape `[..., n]` containing the
            channel logits/llr values.

        :output b_hat: Tensor of shape `[..., k]` containing hard-decided
            estimates of all ``k`` information bits.
        :output crc_status: CRC status. Returned only if
            ``return_crc_status`` is `True`.
        """
        input_shape = llr_ch.shape
        llr_ch = llr_ch.reshape((-1, self._n_target))

        # 1.) Undo channel interleaving (uplink only)
        if self._bil:
            llr_deint = llr_ch[:, self._ind_ch_int_inv_t]
        else:
            llr_deint = llr_ch

        # 2.) Remove puncturing, shortening, repetition
        if self._n_target >= self._n_polar:
            # Repetition: positions n_polar..n_target-1 repeat the first
            # n_rep positions, so their LLRs are added.
            n_rep = self._n_target - self._n_polar
            llr_1 = llr_deint[:, :n_rep]
            llr_2 = llr_deint[:, n_rep : self._n_polar]
            llr_3 = llr_deint[:, self._n_polar :]
            llr_dematched = torch.cat([llr_1 + llr_3, llr_2], 1)
        else:
            fill_shape = (llr_deint.shape[0], self._n_polar - self._n_target)
            if self._k_polar / self._n_target <= 7 / 16:
                # Puncturing: missing leading positions get logit 0
                # (no information).
                llr_zero = torch.zeros(
                    fill_shape, dtype=self.dtype, device=llr_deint.device
                )
                llr_dematched = torch.cat([llr_zero, llr_deint], 1)
            else:
                # Shortening: missing trailing positions get a large
                # negative logit.
                llr_infty = -self._llr_max * torch.ones(
                    fill_shape, dtype=self.dtype, device=llr_deint.device
                )
                llr_dematched = torch.cat([llr_deint, llr_infty], 1)

        # 3.) Remove subblock interleaving
        llr_dec = llr_dematched[:, self._ind_sub_int_inv_t]

        # 4.) Run main decoder
        u_hat_crc = self._polar_dec(llr_dec)

        # 5.) Remove input bit interleaving for downlink channels only
        if self._ind_iil_inv_t is not None:
            u_hat_crc = u_hat_crc[:, self._ind_iil_inv_t]

        # 6.) Evaluate or remove CRC (and PC)
        if self._return_crc_status:
            u_hat, crc_status = self._dec_crc(u_hat_crc)
        else:
            # Slice by explicit length. ``[:, :-k_crc]`` would return an
            # empty tensor for ``k_crc == 0``.
            u_hat = u_hat_crc[:, : u_hat_crc.shape[-1] - self._k_crc]

        # Reconstruct input shape
        output_shape = list(input_shape[:-1]) + [self._k_target]
        u_hat_reshape = u_hat.reshape(output_shape).to(self.dtype)

        if self._return_crc_status:
            crc_status = crc_status.reshape(list(input_shape[:-1]))
            return u_hat_reshape, crc_status
        return u_hat_reshape