#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions for LDPC decoding."""
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
import torch
from sionna.phy.object import Object
from sionna.phy.fec.utils import llr2mi
# Public API of this module: the three decoder callback classes defined below.
__all__ = [
"EXITCallback",
"DecoderStatisticsCallback",
"WeightedBPCallback",
]
# [docs]  (leftover documentation link marker; commented out so it is not evaluated as code)
class EXITCallback(Object):
    # pylint: disable=line-too-long
    """Callback for the LDPCBPDecoder to track EXIT statistics.

    Can be registered as ``c2v_callbacks`` or ``v2c_callbacks`` in the
    :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder` and the
    :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder`.

    This callback requires all-zero codeword simulations.

    On every call, one mutual-information estimate is computed from all
    non-zero entries of ``msg`` and accumulated under the iteration index
    ``it``. The running average per iteration is available via :attr:`mi`.

    :param num_iter: Maximum number of decoding iterations.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').

    :input msg: [batch_size, num_vns, max_degree], `torch.float`.
        The v2c or c2v messages.
    :input it: `int`.
        Current number of decoding iterations.

    :output msg: `torch.float`.
        Same as ``msg``.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.fec.ldpc import LDPCBPDecoder
        from sionna.phy.fec.ldpc.utils import EXITCallback

        # Create callback
        exit_cb = EXITCallback(num_iter=20)

        # Create decoder with callback
        decoder = LDPCBPDecoder(pcm, v2c_callbacks=[exit_cb])

        # After decoding, access mutual information
        mi = exit_cb.mi
    """

    def __init__(
        self,
        num_iter: int,
        device: Optional[str] = None,
    ):
        super().__init__(device=device)

        # One accumulator slot per iteration index 0..num_iter (inclusive):
        # "_mi" holds the summed MI estimates, "_num_samples" the number of
        # estimates that went into each sum.
        num_slots = num_iter + 1
        for buffer_name in ("_mi", "_num_samples"):
            self.register_buffer(
                buffer_name,
                torch.zeros(num_slots, dtype=torch.float32, device=self.device),
            )

    @property
    def mi(self) -> torch.Tensor:
        """Mutual information after each iteration.

        Computed as accumulated MI divided by the number of accumulated
        estimates; slots that never received an estimate evaluate ``0 / 0``.
        """
        return torch.div(self._mi, self._num_samples)

    def __call__(
        self,
        msg: torch.Tensor,
        it: int,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Process messages and update EXIT statistics.

        :param msg: Messages to evaluate; returned unchanged.
        :param it: Iteration index under which the estimate is accumulated.
        :return: The input ``msg``.
        """
        flat_msg = msg.reshape(-1)

        # Exact zeros are treated as padding and excluded from the estimate
        is_valid = flat_msg != 0
        if not is_valid.any():
            # Nothing to measure; leave the statistics untouched
            return msg

        # The sign is flipped before handing the LLRs to llr2mi
        # (presumably to match its sign convention -- TODO confirm)
        flipped_llrs = -1 * flat_msg[is_valid]
        mi_estimate = llr2mi(flipped_llrs)

        self._mi[it] = self._mi[it] + mi_estimate
        self._num_samples[it] = self._num_samples[it] + 1.0
        return msg
# [docs]  (leftover documentation link marker; commented out so it is not evaluated as code)
class DecoderStatisticsCallback(Object):
    """Callback for the LDPCBPDecoder to track decoder statistics.

    Can be registered as ``c2v_callbacks`` in the
    :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder` and the
    :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder`.

    Remark: the decoding statistics are based on CN convergence, i.e.,
    successful decoding is assumed if all check nodes are fulfilled.
    This overestimates the success-rate as it includes cases where the decoder
    converges to the wrong codeword.

    :param num_iter: Maximum number of decoding iterations.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').

    :input msg: [batch_size, num_vns, max_degree], `torch.float`.
        v2c messages.
    :input it: `int`.
        Current number of decoding iterations.

    :output msg: `torch.float`.
        Same as ``msg``.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.fec.ldpc import LDPCBPDecoder
        from sionna.phy.fec.ldpc.utils import DecoderStatisticsCallback

        # Create callback
        stats_cb = DecoderStatisticsCallback(num_iter=20)

        # Create decoder with callback
        decoder = LDPCBPDecoder(pcm, c2v_callbacks=[stats_cb])

        # After decoding, access statistics
        print(stats_cb.success_rate)
        print(stats_cb.avg_number_iterations)
    """

    def __init__(
        self,
        num_iter: int,
        device: Optional[str] = None,
    ):
        super().__init__(device=device)
        self._num_iter = num_iter

        # Per-iteration counters: codewords seen / codewords passing the check
        for counter_name in ("_num_samples", "_decoded_samples"):
            self.register_buffer(
                counter_name,
                torch.zeros(num_iter, dtype=torch.int64, device=self.device),
            )

    @property
    def num_samples(self) -> torch.Tensor:
        """Number of processed codewords per iteration index"""
        return self._num_samples

    @property
    def num_decoded_cws(self) -> torch.Tensor:
        """Number of decoded codewords after each iteration"""
        return self._decoded_samples

    @property
    def success_rate(self) -> torch.Tensor:
        """Success rate after each iteration"""
        decoded = self._decoded_samples.to(torch.float64)
        processed = self._num_samples.to(torch.float64)
        return torch.div(decoded, processed)

    @property
    def avg_number_iterations(self) -> torch.Tensor:
        """Average number of decoding iterations.

        Sums, over all iteration indices, the number of codewords that were
        not yet decoded and normalizes by the number of codewords recorded
        at iteration index 0.
        """
        decoded = self._decoded_samples.to(torch.float64)
        processed = self._num_samples.to(torch.float64)
        still_active = processed - decoded
        return still_active.sum() / processed[0]

    def reset_stats(self) -> None:
        """Reset internal statistics"""
        # Fresh zero tensors are registered under the same buffer names
        for counter_name in ("_num_samples", "_decoded_samples"):
            self.register_buffer(
                counter_name,
                torch.zeros(
                    self._num_iter, dtype=torch.int64, device=self.device
                ),
            )

    def __call__(
        self,
        msg: torch.Tensor,
        it: int,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Process messages and update decoder statistics.

        :param msg: Messages of shape [batch_size, num_nodes, max_degree];
            returned unchanged.
        :param it: Iteration index; calls with ``it >= num_iter`` do not
            modify the statistics.
        :return: The input ``msg``.
        """
        signs = torch.sign(msg)
        # Exact zeros are mapped to +1 so that they are neutral in the product
        signs = torch.where(signs == 0, torch.ones_like(signs), signs)

        # A node counts as fulfilled if the sign product over its edges is > 0
        node_ok = torch.prod(signs, dim=2) > 0  # [bs, num_nodes]
        # A codeword counts as decoded if all of its nodes are fulfilled
        cw_ok = torch.all(node_ok, dim=1)  # [bs]
        decoded_now = cw_ok.sum().to(torch.int64)
        num_cws = msg.shape[0]

        if not it < self._num_iter:
            return msg

        self._num_samples[it] = self._num_samples[it] + num_cws
        self._decoded_samples[it] = self._decoded_samples[it] + decoded_now
        return msg
# [docs]  (leftover documentation link marker; commented out so it is not evaluated as code)
class WeightedBPCallback(Object):
    # pylint: disable=line-too-long
    r"""Callback for the LDPCBPDecoder to enable weighted BP :cite:p:`Nachmani`.

    The BP decoder is fully differentiable and can be made trainable
    by following the concept of *weighted BP* :cite:p:`Nachmani` leading to

    .. math::
        y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{\textcolor{red}{w_{i' \to j}} \cdot x_{i' \to j}}{2} \right) \right)

    where :math:`w_{i \to j}` denotes the trainable weight of message
    :math:`x_{i \to j}`.
    Please note that the training of some check node types may be not supported.

    Can be registered as ``c2v_callbacks`` and ``v2c_callbacks`` in the
    :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder` and the
    :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder`.

    :param num_edges: Number of edges in the decoding graph.
    :param pcm: Optional parity-check matrix. If provided, enables weighted BP
        in padded message format used by the decoder.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').

    :input msg: `torch.float`. Either flat messages of shape
        [batch_size, num_edges] or padded messages of shape
        [batch_size, num_nodes, max_degree]. Padded messages are only
        weighted if ``pcm`` was provided and the shape matches the VN or CN
        layout derived from it; otherwise they are returned unchanged.

    :output msg: `torch.float`.
        Same shape as ``msg``.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.fec.ldpc import LDPCBPDecoder
        from sionna.phy.fec.ldpc.utils import WeightedBPCallback
        import numpy as np

        # Create a simple parity-check matrix
        pcm = np.array([[1, 1, 0, 1], [0, 1, 1, 1]])

        # Create callback with trainable weights
        weighted_cb = WeightedBPCallback(num_edges=int(np.sum(pcm)), pcm=pcm)

        # Create decoder with callback
        decoder = LDPCBPDecoder(pcm, v2c_callbacks=[weighted_cb])

        # Access trainable weights
        print(weighted_cb.weights)
    """

    def __init__(
        self,
        num_edges: int,
        pcm: Optional[Union[np.ndarray, sp.spmatrix]] = None,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        # Note: Using nn.Parameter instead of register_buffer since requires_grad=True
        self._edge_weights = torch.nn.Parameter(torch.ones(
            num_edges, dtype=self.dtype, device=self.device
        ))

        # Build indices for padded format if PCM is provided
        # NOTE(review): num_edges is not checked against the number of
        # non-zero entries of pcm; a pcm with more edges than num_edges
        # would yield gather indices beyond the weight vector.
        self._has_pcm = pcm is not None
        if self._has_pcm:
            self._build_padded_indices(pcm)

    @staticmethod
    def _padded_gather_index(node_idx: np.ndarray, num_nodes: int):
        """Map (node, slot) positions of a padded layout to flat edge indices.

        :param node_idx: Node index of every edge, in flat edge order.
        :param num_nodes: Number of nodes (rows of the padded layout).
        :return: Tuple ``(gather_idx, max_degree)`` where ``gather_idx`` is an
            ``int32`` array of shape [num_nodes, max_degree]. Row ``n`` lists
            the flat indices of the edges of node ``n``; unused (padding)
            slots are 0.
        """
        # Flat edge indices grouped by node (same argsort call as before, so
        # the order of edges within a node is unchanged)
        edge_order = np.argsort(node_idx)

        # Number of edges per node
        degrees = np.bincount(node_idx, minlength=num_nodes)
        max_degree = int(degrees.max()) if degrees.size > 0 else 0

        gather_idx = np.zeros((num_nodes, max_degree), dtype=np.int32)
        if edge_order.size > 0:
            # Position of each node's first edge within the sorted edge list
            row_starts = np.cumsum(degrees) - degrees
            nodes_sorted = node_idx[edge_order]
            # Slot of each edge inside its node's row
            slots = np.arange(edge_order.size) - row_starts[nodes_sorted]
            gather_idx[nodes_sorted, slots] = edge_order
        return gather_idx, max_degree

    def _build_padded_indices(
        self, pcm: Union[np.ndarray, sp.spmatrix]
    ) -> None:
        """Build index arrays for mapping flat edge weights to padded format.

        Registers the buffers ``_vn_gather_idx`` and ``_cn_gather_idx`` and
        stores the node counts and maximum degrees used by ``__call__`` to
        recognize the VN- and CN-padded layouts.

        :param pcm: Parity-check matrix; rows are check nodes, columns are
            variable nodes. Dense inputs are converted to a sparse matrix.
        """
        # Convert to sparse if needed
        pcm_sparse = pcm if sp.issparse(pcm) else sp.csr_matrix(pcm)
        num_cns, num_vns = pcm_sparse.shape

        # Get edge indices (same logic as in LDPCBPDecoder)
        cn_idx, vn_idx, _ = sp.find(pcm_sparse)

        # VN-padded layout: (vn, position) -> edge_index
        vn_gather_idx, max_vn_degree = self._padded_gather_index(
            vn_idx, num_vns
        )
        # CN-padded layout: (cn, position) -> edge_index
        cn_gather_idx, max_cn_degree = self._padded_gather_index(
            cn_idx, num_cns
        )

        # Register as buffers
        self.register_buffer(
            "_vn_gather_idx",
            torch.tensor(vn_gather_idx, dtype=torch.int32, device=self.device)
        )
        self.register_buffer(
            "_cn_gather_idx",
            torch.tensor(cn_gather_idx, dtype=torch.int32, device=self.device)
        )

        self._num_vns = num_vns
        self._num_cns = num_cns
        self._max_vn_degree = max_vn_degree
        self._max_cn_degree = max_cn_degree

    @property
    def weights(self) -> torch.Tensor:
        """Trainable edge weights"""
        return self._edge_weights

    def show_weights(self, size: float = 7) -> None:
        """Show histogram of trainable weights.

        :param size: Figure size of the matplotlib figure.
        """
        plt.figure(figsize=(size, size))
        plt.hist(
            self._edge_weights.detach().cpu().numpy(),
            density=True,
            bins=20,
            align="mid",
        )
        plt.xlabel("weight value")
        plt.ylabel("density")
        plt.grid(True, which="both", axis="both")
        plt.title("Weight Distribution")

    def __call__(
        self,
        msg: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Multiply messages with trainable weights for weighted BP.

        :param msg: Flat messages [batch_size, num_edges] or padded messages
            [batch_size, num_nodes, max_degree].
        :return: Weighted messages of the same shape. The input is returned
            unchanged if it is neither 2D nor a 3D tensor matching the VN or
            CN layout of the parity-check matrix given at construction.
        """
        if msg.dim() == 2:
            # Flat format: broadcasts [bs, num_edges] * [num_edges]
            return msg * self._edge_weights

        if msg.dim() != 3 or not self._has_pcm:
            return msg

        # Padded format [batch_size, num_nodes, max_degree]; the layout is
        # recognized by its shape, with the VN layout taking precedence.
        layout = (msg.shape[1], msg.shape[2])
        if layout == (self._num_vns, self._max_vn_degree):
            gather_idx = self._vn_gather_idx
        elif layout == (self._num_cns, self._max_cn_degree):
            gather_idx = self._cn_gather_idx
        else:
            return msg

        # Padding slots pick up weight 0's value (gather index 0)
        weights_padded = self._edge_weights[gather_idx]  # [num_nodes, max_degree]
        # [bs, num_nodes, max_degree] * [1, num_nodes, max_degree]
        return msg * weights_padded.unsqueeze(0)