#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for decoding of linear codes."""
from typing import Optional
import itertools
import warnings
import numpy as np
import scipy as sp
import torch
from sionna.phy import Block
from sionna.phy.fec.coding import pcm2gm, make_systematic
from sionna.phy.fec.utils import int_mod_2
from sionna.phy.utils import hard_decisions
__all__ = ["OSDecoder"]
[docs]
class OSDecoder(Block):
# pylint: disable=line-too-long
r"""Ordered statistics decoding (OSD) for binary, linear block codes.
This block implements the OSD algorithm as proposed in :cite:p:`Fossorier` and,
thereby, approximates maximum likelihood decoding for a sufficiently large
order :math:`t`. The algorithm works for arbitrary linear block codes, but
has a high computational complexity for long codes.
The algorithm consists of the following steps:
1. Sort LLRs according to their reliability and apply the same column
permutation to the generator matrix.
2. Bring the permuted generator matrix into its systematic form
(so-called *most-reliable basis*).
3. Hard-decide and re-encode the :math:`k` most reliable bits and
discard the remaining :math:`n-k` received positions.
4. Generate all possible error patterns up to :math:`t` errors in the
:math:`k` most reliable positions and find the most likely codeword within
these candidates.
This implementation of the OSD algorithm uses the LLR-based distance metric
from :cite:p:`Stimming_LLR` which simplifies the handling of higher-order
modulation schemes.
:param enc_mat: Binary generator matrix of shape `[k, n]`. If ``is_pcm`` is
`True`, ``enc_mat`` is interpreted as parity-check matrix of shape
`[n-k, n]`.
:param t: Order of the OSD algorithm.
:param is_pcm: If `True`, ``enc_mat`` is interpreted as parity-check matrix.
:param encoder: Sionna block that implements a FEC encoder.
If not `None`, ``enc_mat`` will be ignored and the code as specified by
the encoder is used to initialize OSD.
:param precision: Precision used for internal calculations and outputs.
If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
:param device: Device for computation (e.g., 'cpu', 'cuda:0').
If `None`, :attr:`~sionna.phy.config.Config.device` is used.
:input llr_ch: [..., n], `torch.float`.
Tensor containing the channel logits/llr values.
:output c_hat: [..., n], `torch.float`.
Tensor of same shape as ``llr_ch`` containing binary hard-decisions
of all codeword bits.
.. rubric:: Notes
OS decoding is of high complexity and is only feasible for small values of
:math:`t` as :math:`{n \choose t}` patterns must be evaluated. The
advantage of OSD is that it works for arbitrary linear block codes and
provides an estimate of the expected ML performance for sufficiently large
:math:`t`. However, for some code families, more efficient decoding
algorithms with close to ML performance exist which can exploit certain
code specific properties. Examples of such decoders are the
:class:`~sionna.phy.fec.conv.ViterbiDecoder` algorithm for convolutional codes
or the :class:`~sionna.phy.fec.polar.decoding.PolarSCLDecoder` for Polar codes
(for a sufficiently large list size).
It is recommended to run the decoder with ``torch.compile()`` as it
significantly reduces the memory complexity (typically 4-5x reduction)
and improves execution speed (typically 7x or more). Without compilation,
the decoder materializes large intermediate tensors of shape
``[batch_size, num_patterns, n]`` where ``num_patterns`` can be very
large for higher values of ``t``.
.. rubric:: Examples
.. code-block:: python
import torch
from sionna.phy.fec.utils import load_parity_check_examples
from sionna.phy.fec.linear import LinearEncoder, OSDecoder
# Load (7,4) Hamming code
pcm, k, n, _ = load_parity_check_examples(0)
encoder = LinearEncoder(pcm, is_pcm=True)
decoder = OSDecoder(encoder=encoder, t=2)
# Generate random codeword and add noise
u = torch.randint(0, 2, (10, k), dtype=torch.float32)
c = encoder(u)
llr_ch = 2.0 * (2.0 * c - 1.0) # Perfect LLRs
c_hat = decoder(llr_ch)
print(torch.equal(c, c_hat))
# True
"""
def __init__(
    self,
    enc_mat: Optional[np.ndarray] = None,
    t: int = 0,
    is_pcm: bool = False,
    encoder: Optional[Block] = None,
    precision: Optional[str] = None,
    device: Optional[str] = None,
    **kwargs,
):
    """Validate the code description and pre-compute all error patterns.

    :param enc_mat: Binary generator matrix (or parity-check matrix if
        ``is_pcm`` is `True`) given as NumPy array or SciPy CSR/CSC matrix.
    :param t: Order of the OSD algorithm. Must be integer-valued.
    :param is_pcm: If `True`, ``enc_mat`` is converted via ``pcm2gm``.
    :param encoder: Encoder block exposing an attribute ``k``. If given,
        the generator matrix is obtained by encoding an identity matrix.
    :param precision: Forwarded to the parent block.
    :param device: Forwarded to the parent block.

    :raises TypeError: If ``is_pcm`` is not `bool`, ``enc_mat`` has an
        unsupported type, or ``t`` is not integer-valued.
    :raises ValueError: If ``enc_mat`` is not binary, ``encoder.k`` is
        `None`, or neither ``enc_mat`` nor ``encoder`` is provided.
    :raises ResourceWarning: If ``num_patterns * n`` exceeds ``1e11``.
    """
    super().__init__(precision=precision, device=device, **kwargs)

    if not isinstance(is_pcm, bool):
        raise TypeError("is_pcm must be bool.")

    self._llr_max = 100.0  # internal clipping value for llrs

    if enc_mat is not None:
        # For sparse containers only the explicitly stored entries need
        # to be checked; implicit zeros are binary by construction.
        if isinstance(enc_mat, np.ndarray):
            entries = enc_mat
        elif isinstance(
            enc_mat, (sp.sparse.csr_matrix, sp.sparse.csc_matrix)
        ):
            entries = enc_mat.data
        else:
            raise TypeError("Unsupported dtype of enc_mat.")
        if not np.array_equal(entries, entries.astype(bool)):
            raise ValueError("enc_mat must be binary.")

    if int(t) != t:
        raise TypeError("t must be int.")
    self._t = int(t)

    if encoder is not None:
        # Test that encoder is already initialized (relevant for conv codes)
        if encoder.k is None:
            raise ValueError(
                "It seems as if the encoder is not "
                "initialized or has no attribute k."
            )
        # Encode identity matrix to get k basis vectors of the code
        u = torch.eye(encoder.k, dtype=self.dtype, device=self.device)
        u = u.unsqueeze(0)
        # Encode and remove batch_dim
        gm = encoder(u).squeeze(0).to(self.dtype)
        self._gm = gm
    else:
        if enc_mat is None:
            raise ValueError("enc_mat cannot be None if no encoder is provided.")
        if is_pcm:
            gm = pcm2gm(enc_mat)
        else:
            # Check if gm is of full rank (raise error otherwise)
            make_systematic(enc_mat)
            gm = enc_mat
        # Densify SciPy sparse input so that it can be converted to a
        # dense torch tensor below.
        if sp.sparse.issparse(gm):
            gm = gm.toarray()
        # Register as buffer for CUDAGraph compatibility
        self.register_buffer(
            "_gm", torch.tensor(gm, dtype=self.dtype, device=self.device)
        )

    self._k = self._gm.shape[0]
    self._n = self._gm.shape[1]

    # Complexity estimate used for the warning / abort thresholds below
    num_patterns = self._num_error_patterns(self._n, self._t)
    # Storage/computational complexity scales with n
    num_symbols = num_patterns * self._n
    if num_symbols > 1e9:  # empirically found to be a good trade-off
        warnings.warn(
            f"Required memory complexity is large for the "
            f"given code parameters and t={t}. Please consider small "
            f"batch-sizes to keep the inference complexity small and "
            f"activate torch.compile() if possible."
        )
    if num_symbols > 1e11:  # empirically found to be a good trade-off
        raise ResourceWarning(
            "Due to its high complexity, OSD is not "
            "feasible for the selected parameters. "
            "Please consider using a smaller value for t."
        )

    # Pre-compute all error patterns, one index tensor per error weight.
    # The validated integer ``self._t`` is used (not the raw argument) so
    # that integer-valued floats such as ``t=2.0`` do not break ``range``.
    self._err_patterns = [
        self._gen_error_patterns(self._k, t_i)
        for t_i in range(1, self._t + 1)
    ]
@property
def gm(self) -> torch.Tensor:
    """Generator matrix of the code.

    Its first dimension defines ``k`` and its second dimension defines
    ``n`` of this decoder.
    """
    return self._gm
@property
def n(self) -> int:
    """Codeword length, i.e., number of columns of the generator matrix."""
    return self._n
@property
def k(self) -> int:
    """Number of information bits per codeword, i.e., number of rows of
    the generator matrix."""
    return self._k
@property
def t(self) -> int:
    """Order of the OSD algorithm (stored as `int` during initialization)."""
    return self._t
def _num_error_patterns(self, n: int, t: int) -> int:
r"""Returns number of possible error patterns for t errors in n
positions, i.e., calculates :math:`{n \choose t}`.
:param n: Length of vector.
:param t: Number of errors.
:output num_patterns: Number of error patterns.
"""
return sp.special.comb(n, t, exact=True, repetition=False)
def _gen_error_patterns(self, n: int, t: int) -> torch.Tensor:
r"""Returns tensor of all possible error patterns for t errors in n
positions.
:param n: Length of vector.
:param t: Number of errors.
:output err_patterns: Tensor of shape [`num_patterns`, `t`] where
`num_patterns` = :math:`{n \choose t}`, containing the `t` error
indices.
"""
err_patterns = list(itertools.combinations(range(n), t))
return torch.tensor(err_patterns, dtype=torch.long, device=self.device)
def _get_dist(self, llr: torch.Tensor, c_hat: torch.Tensor) -> torch.Tensor:
"""Distance function used for ML candidate selection.
Currently, the distance metric from Polar decoding :cite:p:`Stimming_LLR`
literature is implemented.
:param llr: Received llrs of the channel observations of shape [bs, n].
:param c_hat: Candidate codewords for which the distance to ``llr``
shall be evaluated of shape [bs, num_cand, n].
:output d: Distance between ``llr`` and ``c_hat`` for each of the
`num_cand` codeword candidates of shape [bs, num_cand].
Reference
---------
[Stimming_LLR] Alexios Balatsoukas-Stimming, Mani Bastani Parizi,
Andreas Burg, "LLR-Based Successive Cancellation List Decoding
of Polar Codes." IEEE Trans Signal Processing, 2015.
"""
# Broadcast llr to all codeword candidates
llr = llr.unsqueeze(1)
llr_sign = llr * (-2.0 * c_hat + 1.0) # apply BPSK mapping
d = torch.log1p(torch.exp(llr_sign))
return d.mean(dim=2)
def _find_min_dist(
    self,
    llr_ch: torch.Tensor,
    ep: torch.Tensor,
    gm_mrb: torch.Tensor,
    c: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Select the error pattern that yields the closest candidate codeword.

    All patterns are evaluated in one shot when the intermediate tensor is
    estimated to fit into 1 GB; otherwise the patterns are processed in
    chunks of roughly 256 MB while tracking the running optimum.

    :param llr_ch: Channel observations as llrs after mrb sorting of shape
        [bs, n].
    :param ep: Tensor of shape [`num_patterns`, `t`] where
        `num_patterns` = :math:`{n \choose t}`, containing the `t` error
        indices.
    :param gm_mrb: Most reliable basis for each batch example of shape
        [bs, k, n].
    :param c: Most reliable base codeword of shape [bs, n].

    :output d_best: Distance of shape [bs] for the most likely codeword
        after testing all ``ep`` error patterns.
    :output c_hat: Codeword of shape [bs, n] for the most likely
        codeword after testing all ``ep`` error patterns.
    """
    num_patterns, t_val = ep.shape[0], ep.shape[1]
    bs = llr_ch.shape[0]
    rows = torch.arange(bs, device=self.device)

    def _evaluate(patterns: torch.Tensor):
        # Sum the t selected basis rows, add the base codeword and reduce
        # mod 2 -> candidates of shape [bs, len(patterns), n]
        flipped = gm_mrb[:, patterns, :].sum(dim=2) + c.unsqueeze(1)
        cand = int_mod_2(flipped)
        # Distances of shape [bs, len(patterns)]
        return cand, self._get_dist(llr_ch, cand)

    # Size estimate of the gathered tensor [bs, num_patterns, t, n]
    elem_bytes = 4 if llr_ch.dtype == torch.float32 else 8
    bytes_per_pattern = bs * t_val * self._n * elem_bytes
    single_pass_limit = 1024 * 1024 * 1024  # 1 GB

    if num_patterns * bytes_per_pattern <= single_pass_limit:
        # Small enough - process all at once (best for torch.compile)
        cand, dist = _evaluate(ep)
        winner = dist.argmin(dim=1)  # [bs]
        return dist[rows, winner], cand[rows, winner]

    # Large case - aim at ~256 MB per chunk to avoid OOM
    chunk_budget = 256 * 1024 * 1024
    step = min(max(1, chunk_budget // bytes_per_pattern), num_patterns)

    # Running optimum; starts at the base codeword with infinite distance
    d_best = torch.full(
        (bs,), float("inf"), dtype=llr_ch.dtype, device=self.device
    )
    c_hat_best = c.clone()

    for lo in range(0, num_patterns, step):
        cand, dist = _evaluate(ep[lo : lo + step])
        chunk_min, chunk_arg = dist.min(dim=1)  # [bs]
        # Only a strictly smaller distance replaces the current optimum
        better = chunk_min < d_best
        c_hat_best = torch.where(
            better.unsqueeze(1), cand[rows, chunk_arg], c_hat_best
        )
        d_best = torch.where(better, chunk_min, d_best)

    return d_best, c_hat_best
def _find_mrb(
    self, gm: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the most reliable basis for every generator matrix in a batch.

    Each matrix is reduced row by row: the first non-zero column of the
    current row becomes its pivot and is eliminated from all other rows.
    Afterwards the columns are reordered such that the pivot columns come
    first, followed by the remaining columns in ascending index order.

    :param gm: Generator matrix for each batch example of shape [bs, k, n].

    :output gm_mrb: Tensor of shape [bs, k, n] containing the most
        reliable basis in systematic form for each batch example.
    :output idx_sort: Tensor of shape [bs, n] containing the indices of
        column permutations applied during mrb calculation.
    """
    dev = self.device
    num_rows, num_cols = self._k, self._n
    bs = gm.shape[0]
    batch_ids = torch.arange(bs, device=dev)

    # Pivot column chosen for each row of each batch example
    idx_pivot = torch.zeros((bs, num_rows), dtype=torch.long, device=dev)

    # Bring gm in systematic form (by so-called pivot method)
    gm = gm.clone()
    for row in range(num_rows):
        pivot_row = gm[:, row, :]  # [bs, n]
        # argmax returns the position of the first 1 in the row
        pivot_col = pivot_row.argmax(dim=-1)  # [bs]
        idx_pivot[:, row] = pivot_col

        # Entries of the pivot column for all rows -> [bs, k]; advanced
        # indexing yields a copy, so the in-place write leaves gm untouched
        col_vals = gm[batch_ids, :, pivot_col]
        # The pivot row itself must not be cancelled
        col_vals[:, row] = 0

        # Add the pivot row (mod 2) to every row that has a 1 in the
        # pivot column
        gm = int_mod_2(gm + col_vals.unsqueeze(-1) * pivot_row.unsqueeze(1))

    # Identify non-pivot columns: shift every pivot index by n so that a
    # sort places all non-pivot indices in the first n-k slots
    marked = torch.arange(num_cols, device=dev).repeat(bs, 1)
    marked[batch_ids.unsqueeze(1), idx_pivot] += num_cols
    idx_parity = marked.argsort(dim=1)[:, : num_cols - num_rows]  # [bs, n-k]

    idx_sort = torch.cat((idx_pivot, idx_parity), dim=1)  # [bs, n]

    # Apply the column permutation to every row of gm
    col_index = idx_sort.unsqueeze(1).expand(-1, num_rows, -1)
    gm_mrb = gm.gather(2, col_index)
    return gm_mrb, idx_sort
[docs]
def build(self, input_shape: tuple) -> None:
"""Check for valid input shapes."""
if input_shape[-1] != self._n:
raise ValueError(f"Last dimension must be of size n={self._n}.")
def call(self, llr_ch: torch.Tensor, /) -> torch.Tensor:
    r"""Applies ordered statistic decoding to inputs.

    Remark: the decoder is implemented with the LLR definition
    llr = log(p(x=1)/p(x=0)), i.e., positive LLRs favor bit value 1.

    :param llr_ch: Channel LLRs of shape [..., n].

    :output c_hat: Hard decisions of shape [..., n].

    :raises ValueError: If the last dimension of ``llr_ch`` is not ``n``.
    """
    # Validate input shape
    if llr_ch.shape[-1] != self._n:
        raise ValueError(f"Last dimension must be of size n={self._n}.")

    # Flatten batch-dim
    input_shape = llr_ch.shape
    llr_ch = llr_ch.reshape(-1, self._n)
    bs = llr_ch.shape[0]

    # Clip inputs
    llr_ch = llr_ch.clamp(-self._llr_max, self._llr_max)

    # Step 1: sort LLRs by decreasing reliability
    idx_sort = llr_ch.abs().argsort(dim=-1, descending=True)

    # Permute gm per batch sample individually
    gm = self._gm.unsqueeze(0).expand(bs, -1, -1)
    # Gather columns according to idx_sort
    batch_idx = torch.arange(bs, device=self.device).view(-1, 1, 1)
    row_idx = torch.arange(self._k, device=self.device).view(1, -1, 1)
    idx_sort_expanded = idx_sort.unsqueeze(1).expand(-1, self._k, -1)
    gm_sort = gm[batch_idx, row_idx, idx_sort_expanded]

    # Step 2: Find most reliable basis (MRB)
    gm_mrb, idx_mrb = self._find_mrb(gm_sort)

    # Compose the reliability sort with the mrb column permutation
    idx_sort = idx_sort.gather(1, idx_mrb)
    llr_sort = llr_ch.gather(1, idx_sort)

    # Find inverse permutation for final output
    idx_sort_inv = idx_sort.argsort(dim=1)

    # Step 3: Hard-decide k most reliable positions and encode
    u_hd = hard_decisions(llr_sort[:, : self._k])
    u_hd = u_hd.unsqueeze(1)  # [bs, 1, k]
    c = torch.matmul(u_hd, gm_mrb).squeeze(1)  # [bs, n]
    c = int_mod_2(c)

    # Step 4: search for the most likely pattern, starting from the
    # order-0 candidate. _get_dist expects a list of candidates, thus
    # expand to [bs, 1, n]
    d_best = self._get_dist(llr_sort, c.unsqueeze(1)).squeeze(1)  # [bs]
    c_hat_best = c

    # Known in advance - can be unrolled
    for ep in self._err_patterns:
        # Best candidate among all patterns of this error weight
        d, c_hat = self._find_min_dist(llr_sort, ep, gm_mrb, c)
        # Keep it only where it strictly improves on the current optimum
        improved = d < d_best  # [bs]
        c_hat_best = torch.where(improved.unsqueeze(1), c_hat, c_hat_best)
        d_best = torch.where(improved, d, d_best)

    # Undo permutations for final codeword
    c_hat_best = c_hat_best.gather(1, idx_sort_inv)

    # Restore input shape
    return c_hat_best.reshape(input_shape)