# Source code for sionna.phy.fec.ldpc.decoding

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for channel decoding and utility functions."""

from typing import Callable, List, Optional, Tuple, Union
import types

import numpy as np
import scipy as sp
import torch

from sionna.phy import Block
from sionna.phy.fec.ldpc.encoding import LDPC5GEncoder


# Public API of this module, assembled group by group.
# Decoder blocks
__all__ = ["LDPCBPDecoder", "LDPC5GDecoder"]
# Variable node (VN) update rules
__all__ += ["vn_update_sum", "vn_node_update_identity"]
# Check node (CN) update rules
__all__ += [
    "cn_update_tanh",
    "cn_update_phi",
    "cn_update_minsum",
    "cn_update_offset_minsum",
    "cn_node_update_identity",
]


class LDPCBPDecoder(Block):
    # pylint: disable=line-too-long
    r"""Iterative belief propagation decoder for low-density parity-check (LDPC)
    codes and other codes on graphs.

    This class defines a generic belief propagation decoder for decoding with
    arbitrary parity-check matrices. It can be used to iteratively
    estimate/recover the transmitted codeword (or information bits) based on
    the LLR-values of the received noisy codeword observation.

    Per default, the decoder implements the flooding message passing
    algorithm :cite:p:`Ryan`, i.e., all nodes are updated in a parallel
    fashion. Different check node update functions are available:

    (1) `boxplus`

        .. math::
            y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right)

    (2) `boxplus-phi`

        .. math::
            y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right)

        with :math:`\phi(x)=-\operatorname{log}\left(\operatorname{tanh}\left(\frac{x}{2}\right)\right)`

    (3) `minsum`

        .. math::
            \qquad y_{j \to i} = \alpha_{j \to i} \cdot \min_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}|\right)

    (4) `offset-minsum`

        .. math::
            \qquad y_{j \to i} = \alpha_{j \to i} \cdot \max \left( \min_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}| \right)-\beta , 0\right)

    where :math:`\beta=0.5` and :math:`y_{j \to i}` denotes the message from
    check node (CN) *j* to variable node (VN) *i* and :math:`x_{i \to j}` from
    VN *i* to CN *j*, respectively. Further, :math:`\mathcal{N}(j)` denotes
    all indices of connected VNs to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    :cite:p:`Ryan` and :cite:p:`Chen` for offset corrected minsum.

    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied (cf. :cite:p:`Richardson` for
    details); this can be done by
    :class:`~sionna.phy.fec.ldpc.encoding.LDPC5GEncoder` and
    :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder`, respectively.

    If required, the decoder can be made trainable and is fully
    differentiable by following the concept of *weighted BP*
    :cite:p:`Nachmani`. For this, custom callbacks can be registered that
    scale the messages during decoding. Please see the corresponding tutorial
    notebook for details.

    For numerical stability, the decoder applies LLR clipping of
    +/- ``llr_max`` to the input LLRs.

    :param pcm: An ndarray of shape `[n-k, n]` defining the parity-check
        matrix consisting only of `0` or `1` entries. Can also be of type
        `scipy.sparse.csr_matrix` or `scipy.sparse.csc_matrix`.
    :param cn_update: Check node update rule to be used as described above.
        One of "boxplus-phi" (default), "boxplus", "minsum" (or "min"),
        "offset-minsum", "identity", or a callable. If a callable is provided,
        it will be used instead as CN update. It is called as
        ``cn_update(msg, mask, llr_max)`` with a tensor of v2c messages of
        shape `[batch_size, num_cns, max_degree]`, a mask of shape
        `[num_cns, max_degree]` and the clipping value ``llr_max``, and must
        return the c2v messages in the same shape.
    :param vn_update: Variable node update rule to be used. One of "sum"
        (default), "identity", or a callable. If a callable is provided, it
        will be used instead as VN update. It is called as
        ``vn_update(msg, mask, llr_ch, llr_max)`` with a tensor of c2v
        messages of shape `[batch_size, num_vns, max_degree]`, a mask of
        shape `[num_vns, max_degree]`, the channel LLRs of shape
        `[batch_size, num_vns]` and the clipping value ``llr_max``, and must
        return a tuple ``(msg_v2c, x_hat)``.
    :param cn_schedule: Defines the CN update scheduling per BP iteration.
        Can be either "flooding" to update all nodes in parallel (recommended)
        or a 2D tensor of shape `[num_update_steps, num_active_nodes]` where
        each row defines the node indices to be updated per subiteration. In
        this case each BP iteration runs ``num_update_steps`` subiterations,
        thus the decoder's level of parallelization is lower and usually the
        decoding throughput decreases.
    :param hard_out: If `True`, the decoder provides hard-decided codeword
        bits instead of soft-values.
    :param num_iter: Defining the number of decoder iterations (due to
        batching, no early stopping used at the moment!).
    :param llr_max: Internal clipping value for all internal messages. If
        `None`, no clipping is applied. Must not be negative.
    :param v2c_callbacks: A callable or a list of callables. Each callable
        will be executed after each VN update with the following arguments
        ``msg_vn``, ``it``, ``x_hat``, where ``msg_vn`` are the v2c messages
        as tensor of shape `[batch_size, num_vns, max_degree]`, ``x_hat`` is
        the current estimate of each VN of shape `[batch_size, num_vns]`, and
        ``it`` is the current iteration counter. The callbacks are also
        executed once before the first iteration with ``it=0`` and the
        channel LLRs as ``x_hat``. It must return an updated version of
        ``msg_vn`` of same shape.
    :param c2v_callbacks: A callable or a list of callables. Each callable
        will be executed after each CN update with the following arguments
        ``msg_cn`` and ``it`` where ``msg_cn`` are the c2v messages as tensor
        of shape `[batch_size, num_cns, max_degree]` and ``it`` is the
        current iteration counter. It must return an updated version of
        ``msg_cn`` of same shape.
    :param return_state: If `True`, the internal VN messages ``msg_v2c`` from
        the last decoding iteration are returned, and ``msg_v2c`` (or `None`)
        can be passed as keyword argument to the next decoder call. This can
        be used for iterative demapping and decoding.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').

    :input llr_ch: [..., n], `torch.float`.
        Tensor containing the channel logits/llr values.
    :input msg_v2c: `None` | [batch_size, num_edges], `torch.float`.
        Tensor of VN messages representing the internal decoder state.
        Required only if the decoder shall use its previous internal state,
        e.g., for iterative detection and decoding (IDD) schemes.

    :output x_hat: [..., n], `torch.float`.
        Tensor of same shape as ``llr_ch`` containing bit-wise soft-estimates
        (or hard-decided bit-values) of all codeword bits.
    :output msg_v2c: [batch_size, num_edges], `torch.float`.
        Tensor of VN messages representing the internal decoder state.
        Returned only if ``return_state`` is set to `True`.

    .. rubric:: Notes

    As decoding input logits :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}`
    are assumed for compatibility with the learning framework, but internally
    log-likelihood ratios (LLRs) with definition
    :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.

    The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC
    codes and, thus, supports arbitrary parity-check matrices.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.utils import load_parity_check_examples
        from sionna.phy.fec.ldpc import LDPCBPDecoder

        # Load (7,4) Hamming code
        pcm, k, n, _ = load_parity_check_examples(0)
        decoder = LDPCBPDecoder(pcm, num_iter=10)

        # Decode random LLRs
        llr_ch = torch.randn(100, n) * 2.0
        c_hat = decoder(llr_ch)
        print(c_hat.shape)
        # torch.Size([100, 7])
    """

    def __init__(
        self,
        pcm: Union[np.ndarray, sp.sparse.csr_matrix, sp.sparse.csc_matrix],
        cn_update: Union[str, Callable] = "boxplus-phi",
        vn_update: Union[str, Callable] = "sum",
        cn_schedule: Union[str, np.ndarray, torch.Tensor] = "flooding",
        hard_out: bool = True,
        num_iter: int = 20,
        llr_max: Optional[float] = 20.0,
        v2c_callbacks: Optional[Union[Callable, List[Callable]]] = None,
        c2v_callbacks: Optional[Union[Callable, List[Callable]]] = None,
        return_state: bool = False,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        # 'cn_type' was renamed to 'cn_update'. Check for it *before* kwargs
        # are forwarded to the base class, so users see this dedicated
        # message instead of a generic unexpected-keyword error.
        if "cn_type" in kwargs:
            raise TypeError("'cn_type' is deprecated; use 'cn_update' instead.")

        super().__init__(precision=precision, device=device, **kwargs)

        # Check inputs for consistency
        if not isinstance(hard_out, bool):
            raise TypeError("hard_out must be bool.")
        if not isinstance(num_iter, int):
            raise TypeError("num_iter must be int.")
        if num_iter < 0:
            raise ValueError("num_iter cannot be negative.")
        if not isinstance(return_state, bool):
            raise TypeError("return_state must be bool.")

        # The PCM must be binary; for sparse inputs only the stored entries
        # need to be inspected.
        if isinstance(pcm, np.ndarray):
            pcm_entries = pcm
        elif isinstance(pcm, (sp.sparse.csr_matrix, sp.sparse.csc_matrix)):
            pcm_entries = pcm.data
        else:
            raise TypeError("Unsupported dtype of pcm.")
        if not np.array_equal(pcm_entries, pcm_entries.astype(bool)):
            raise ValueError("PC matrix must be binary.")

        # Init decoder parameters
        self._pcm = pcm
        self._hard_out = hard_out
        self._num_iter = num_iter
        self._return_state = return_state
        self._num_cns = pcm.shape[0]  # total number of check nodes
        self._num_vns = pcm.shape[1]  # total number of variable nodes

        # Internal value for LLR clipping (same constraints as the setter)
        if llr_max is not None:
            if not isinstance(llr_max, (int, float)):
                raise TypeError("llr_max must be int or float.")
            if llr_max < 0:
                raise ValueError("llr_max cannot be negative.")
        self._llr_max = float(llr_max) if llr_max is not None else None

        self._v2c_callbacks = self._to_callback_list(v2c_callbacks, "v2c_callbacks")
        self._c2v_callbacks = self._to_callback_list(c2v_callbacks, "c2v_callbacks")

        # Make pcm sparse first if ndarray is provided
        if isinstance(pcm, np.ndarray):
            pcm_sparse = sp.sparse.csr_matrix(pcm)
        else:
            pcm_sparse = pcm

        # Assign all edges to CN and VN nodes, respectively
        cn_idx, vn_idx, _ = sp.sparse.find(pcm_sparse)
        # Sort indices explicitly (scipy.sparse.find changed from column to
        # row sorting in scipy>=1.11); edges are stored in "VN order".
        idx = np.argsort(vn_idx)
        self._cn_idx = cn_idx[idx]
        self._vn_idx = vn_idx[idx]

        # Number of edges equals number of non-zero elements in PCM
        self._num_edges = len(self._vn_idx)

        # Pre-load the CN function
        if cn_update == "boxplus":
            self._cn_update = cn_update_tanh
        elif cn_update == "boxplus-phi":
            self._cn_update = cn_update_phi
        elif cn_update in ("minsum", "min"):
            self._cn_update = cn_update_minsum
        elif cn_update == "offset-minsum":
            self._cn_update = cn_update_offset_minsum
        elif cn_update == "identity":
            self._cn_update = cn_node_update_identity
        elif callable(cn_update):
            self._cn_update = cn_update
        else:
            raise TypeError("Provided cn_update not supported.")

        # Pre-load the VN function
        if vn_update == "sum":
            self._vn_update = vn_update_sum
        elif vn_update == "identity":
            self._vn_update = vn_node_update_identity
        elif callable(vn_update):
            self._vn_update = vn_update
        else:
            raise TypeError("Provided vn_update not supported.")

        ######################
        # Init graph structure
        ######################

        # Handle scheduling
        # Register as buffers for CUDAGraph compatibility
        if isinstance(cn_schedule, str) and cn_schedule == "flooding":
            self._scheduling = "flooding"
            # A single "sub-iteration" that contains all CNs
            self.register_buffer(
                "_cn_schedule",
                torch.arange(
                    self._num_cns, dtype=torch.int32, device=self.device
                ).unsqueeze(0),
            )
        elif isinstance(cn_schedule, (np.ndarray, torch.Tensor)):
            if isinstance(cn_schedule, np.ndarray):
                cn_schedule = torch.tensor(
                    cn_schedule, dtype=torch.int32, device=self.device
                )
            else:
                cn_schedule = cn_schedule.to(dtype=torch.int32, device=self.device)
            self._scheduling = "custom"
            if len(cn_schedule.shape) != 2:
                raise ValueError("cn_schedule must be of rank 2.")
            # max()/min() below would fail with an opaque error otherwise
            if cn_schedule.numel() == 0:
                raise ValueError("cn_schedule cannot be empty.")
            if cn_schedule.max() >= self._num_cns:
                raise ValueError(
                    "cn_schedule can only contain values smaller than num_cns."
                )
            if cn_schedule.min() < 0:
                raise ValueError("cn_schedule cannot contain negative values.")
            self.register_buffer("_cn_schedule", cn_schedule)
        else:
            raise ValueError("cn_schedule can be 'flooding' or an array of ints.")

        # Build index arrays for message permutation
        # Permutation index to rearrange edge messages into CN perspective
        v2c_perm = np.argsort(self._cn_idx)
        # And the inverse operation
        v2c_perm_inv = np.argsort(v2c_perm)

        self.register_buffer(
            "_v2c_perm", torch.tensor(v2c_perm, dtype=torch.int32, device=self.device)
        )
        self.register_buffer(
            "_v2c_perm_inv",
            torch.tensor(v2c_perm_inv, dtype=torch.int32, device=self.device),
        )
        self.register_buffer(
            "_vn_idx_t",
            torch.tensor(self._vn_idx, dtype=torch.int32, device=self.device),
        )
        self.register_buffer(
            "_cn_idx_t",
            torch.tensor(self._cn_idx, dtype=torch.int32, device=self.device),
        )

        # Compute row splits for CN perspective (after v2c_perm)
        cn_idx_sorted = self._cn_idx[v2c_perm]
        cn_row_splits = self._compute_row_splits(cn_idx_sorted, self._num_cns)
        self.register_buffer(
            "_cn_row_splits",
            torch.tensor(cn_row_splits, dtype=torch.int32, device=self.device),
        )

        # Compute row splits for VN perspective
        vn_row_splits = self._compute_row_splits(self._vn_idx, self._num_vns)
        self.register_buffer(
            "_vn_row_splits",
            torch.tensor(vn_row_splits, dtype=torch.int32, device=self.device),
        )

        # Compute max degrees for padding
        cn_degrees = np.diff(cn_row_splits)
        vn_degrees = np.diff(vn_row_splits)
        self._max_cn_degree = int(cn_degrees.max()) if len(cn_degrees) > 0 else 0
        self._max_vn_degree = int(vn_degrees.max()) if len(vn_degrees) > 0 else 0

        # Padded index arrays / masks for vectorized operations. They are
        # registered as (non-persistent) buffers like the tensors above, so
        # that they follow device/dtype moves of the module without changing
        # the state_dict layout.
        cn_gather_idx, cn_mask = self._build_padded_indices(
            cn_idx_sorted, cn_row_splits, self._num_cns, self._max_cn_degree
        )
        self.register_buffer("_cn_gather_idx", cn_gather_idx, persistent=False)
        self.register_buffer("_cn_mask", cn_mask, persistent=False)

        vn_gather_idx, vn_mask = self._build_padded_indices(
            self._vn_idx, vn_row_splits, self._num_vns, self._max_vn_degree
        )
        self.register_buffer("_vn_gather_idx", vn_gather_idx, persistent=False)
        self.register_buffer("_vn_mask", vn_mask, persistent=False)

        # Scatter indices mapping padded messages back to edge format
        self.register_buffer(
            "_cn_scatter_idx",
            self._build_scatter_indices(
                cn_row_splits, self._num_cns, self._max_cn_degree
            ),
            persistent=False,
        )
        self.register_buffer(
            "_vn_scatter_idx",
            self._build_scatter_indices(
                vn_row_splits, self._num_vns, self._max_vn_degree
            ),
            persistent=False,
        )

        # Precompute valid position indices for scatter operations (avoids
        # dynamic shapes). For CN: which positions in the flattened padded
        # array are valid
        cn_valid_positions, cn_valid_edge_idx = self._build_valid_scatter_indices(
            cn_row_splits, self._num_cns, self._max_cn_degree
        )
        self.register_buffer("_cn_valid_positions", cn_valid_positions)
        self.register_buffer("_cn_valid_edge_idx", cn_valid_edge_idx)

        # For VN: which positions in the flattened padded array are valid
        vn_valid_positions, vn_valid_edge_idx = self._build_valid_scatter_indices(
            vn_row_splits, self._num_vns, self._max_vn_degree
        )
        self.register_buffer("_vn_valid_positions", vn_valid_positions)
        self.register_buffer("_vn_valid_edge_idx", vn_valid_edge_idx)

        # Precompute per-subiteration index arrays for custom scheduling
        if self._scheduling == "custom":
            self._build_custom_schedule_indices(cn_row_splits, v2c_perm)

    #########################
    # Initialization helpers
    #########################

    @staticmethod
    def _to_callback_list(
        callbacks: Optional[Union[Callable, List[Callable]]], name: str
    ) -> List[Callable]:
        """Normalize a callback argument into a list of callables.

        Accepts `None` (no callbacks), a single callable of any kind
        (function, bound method, ``functools.partial``, callable object) or a
        list/tuple/``torch.nn.ModuleList`` of callables.

        :param callbacks: The user-provided callback argument.
        :param name: Argument name used in the error message.
        :output callbacks: List of callables.
        :raises TypeError: If ``callbacks`` or one of its elements is not
            callable.
        """
        if callbacks is None:
            return []
        # Checked before callable(): a ModuleList is itself an nn.Module and
        # therefore callable, but it must be treated as a container.
        if isinstance(callbacks, (list, tuple, torch.nn.ModuleList)):
            callbacks = list(callbacks)
            if not all(callable(cb) for cb in callbacks):
                raise TypeError(f"{name} must be a list of callables.")
            return callbacks
        if callable(callbacks):
            return [callbacks]
        raise TypeError(f"{name} must be a list of callables.")

    def _compute_row_splits(self, idx: np.ndarray, num_nodes: int) -> np.ndarray:
        """Compute row splits from sorted node indices.

        ``row_splits[i]:row_splits[i+1]`` is the range of edges attached to
        node ``i``.

        :param idx: Node index of every edge, sorted by node.
        :param num_nodes: Total number of nodes.
        :output row_splits: Array of length ``num_nodes + 1``.
        """
        counts = np.bincount(np.asarray(idx, dtype=np.int64), minlength=num_nodes)
        return np.concatenate(([0], np.cumsum(counts)))

    @staticmethod
    def _ragged_layout(
        starts: np.ndarray, degrees: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Flatten ragged rows into (row, slot, source index) triples.

        Row ``r`` owns the ``degrees[r]`` consecutive source indices starting
        at ``starts[r]``; its entries occupy slots ``0..degrees[r]-1`` of a
        padded row.

        :param starts: First source index of each row.
        :param degrees: Number of entries of each row.
        :output rows: Row of every entry.
        :output slots: Position of every entry within its padded row.
        :output sources: Source index of every entry.
        """
        starts = np.asarray(starts, dtype=np.int64)
        degrees = np.asarray(degrees, dtype=np.int64)
        rows = np.repeat(np.arange(len(degrees), dtype=np.int64), degrees)
        # Flat position at which each row's run of entries begins
        row_offsets = np.cumsum(degrees) - degrees
        slots = np.arange(int(degrees.sum()), dtype=np.int64) - np.repeat(
            row_offsets, degrees
        )
        sources = np.repeat(starts, degrees) + slots
        return rows, slots, sources

    def _node_layout(
        self, row_splits: np.ndarray, num_nodes: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Padded (node, slot, edge) coordinates of the first ``num_nodes`` nodes."""
        splits = np.asarray(row_splits, dtype=np.int64)[: num_nodes + 1]
        return self._ragged_layout(splits[:-1], np.diff(splits))

    def _build_padded_indices(
        self, idx: np.ndarray, row_splits: np.ndarray, num_nodes: int, max_degree: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build padded gather indices and mask for vectorized operations.

        :param idx: Node index per edge (unused; only ``row_splits`` is
            needed, kept for API compatibility).
        :param row_splits: Row splits array for the nodes.
        :param num_nodes: Number of nodes.
        :param max_degree: Maximum degree (padding size).
        :output gather_idx: [num_nodes, max_degree] edge index per slot (0 for
            padded slots).
        :output mask: [num_nodes, max_degree] 1 for valid slots, 0 otherwise.
        """
        gather_idx = np.zeros((num_nodes, max_degree), dtype=np.int32)
        mask = np.zeros((num_nodes, max_degree), dtype=np.float32)
        nodes, slots, edges = self._node_layout(row_splits, num_nodes)
        gather_idx[nodes, slots] = edges
        mask[nodes, slots] = 1.0
        return (
            torch.tensor(gather_idx, dtype=torch.int32, device=self.device),
            torch.tensor(mask, dtype=self.dtype, device=self.device),
        )

    def _build_scatter_indices(
        self, row_splits: np.ndarray, num_nodes: int, max_degree: int
    ) -> torch.Tensor:
        """Build scatter indices for converting padded format back to flat.

        :output scatter_idx: [num_nodes, max_degree] edge index per slot (0
            for padded slots).
        """
        scatter_idx = np.zeros((num_nodes, max_degree), dtype=np.int32)
        nodes, slots, edges = self._node_layout(row_splits, num_nodes)
        scatter_idx[nodes, slots] = edges
        return torch.tensor(scatter_idx, dtype=torch.int32, device=self.device)

    def _build_valid_scatter_indices(
        self, row_splits: np.ndarray, num_nodes: int, max_degree: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build precomputed valid position and edge indices for scatter operations.

        This precomputes which positions in the flattened padded format are
        valid and their corresponding edge indices, avoiding dynamic shape
        operations during forward pass.

        :param row_splits: Row splits array for the nodes.
        :param num_nodes: Number of nodes.
        :param max_degree: Maximum degree (padding size).
        :output valid_positions: [num_edges] indices into flattened padded array.
        :output valid_edge_idx: [num_edges] corresponding edge indices.
        """
        nodes, slots, edges = self._node_layout(row_splits, num_nodes)
        # Position in flattened padded array: node * max_degree + slot
        valid_positions = nodes * max_degree + slots
        return (
            torch.tensor(valid_positions, dtype=torch.int32, device=self.device),
            torch.tensor(edges, dtype=torch.int32, device=self.device),
        )

    # Names of the per-sub-iteration tensors of a custom schedule
    _SCHEDULE_KEYS = ("gather_idx", "cn_mask", "edge_idx", "valid_positions")

    def _build_custom_schedule_indices(
        self, cn_row_splits: np.ndarray, v2c_perm: np.ndarray
    ) -> None:
        """Build index arrays for custom/layered CN scheduling.

        For each sub-iteration, we need indices to:
        1. Gather messages for only the active CNs
        2. Scatter updated messages back to the correct edge positions

        ``v2c_perm[cn_order_pos]`` gives the VN-order index of the edge at
        CN-order position ``cn_order_pos``.

        The tensors of sub-iteration ``j`` are registered as non-persistent
        buffers ``_schedule_<key>_<j>`` (see ``_SCHEDULE_KEYS``) and are
        retrieved with :meth:`_schedule_tensors`.
        """
        cn_row_splits = np.asarray(cn_row_splits, dtype=np.int64)
        cn_schedule_np = self._cn_schedule.cpu().numpy()
        for j, active_cns in enumerate(cn_schedule_np):
            active_cns = active_cns.astype(np.int64)
            num_active = len(active_cns)
            starts = cn_row_splits[active_cns]
            degrees = cn_row_splits[active_cns + 1] - starts
            max_active_degree = int(degrees.max()) if num_active > 0 else 0

            # Entries of the active CNs: row i holds the CN-order edges of
            # active_cns[i]
            rows, slots, edge_cn_order = self._ragged_layout(starts, degrees)
            gather_idx = np.zeros((num_active, max_active_degree), dtype=np.int32)
            mask = np.zeros((num_active, max_active_degree), dtype=np.float32)
            gather_idx[rows, slots] = edge_cn_order
            mask[rows, slots] = 1.0
            # VN-order edge indices used to scatter the updated messages back
            edge_vn_order = v2c_perm[edge_cn_order]
            # Valid positions in the flattened padded array
            valid_positions = rows * max_active_degree + slots

            tensors = {
                "gather_idx": torch.tensor(
                    gather_idx, dtype=torch.int32, device=self.device
                ),
                "cn_mask": torch.tensor(mask, dtype=self.dtype, device=self.device),
                "edge_idx": torch.tensor(
                    edge_vn_order, dtype=torch.int32, device=self.device
                ),
                "valid_positions": torch.tensor(
                    valid_positions, dtype=torch.int32, device=self.device
                ),
            }
            for key in self._SCHEDULE_KEYS:
                self.register_buffer(
                    f"_schedule_{key}_{j}", tensors[key], persistent=False
                )

    def _schedule_tensors(
        self, sub_iter: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(gather_idx, cn_mask, edge_idx, valid_positions)`` of a sub-iteration."""
        return tuple(
            getattr(self, f"_schedule_{key}_{sub_iter}") for key in self._SCHEDULE_KEYS
        )

    ###############################
    # Public methods and properties
    ###############################

    @property
    def pcm(self) -> Union[np.ndarray, sp.sparse.csr_matrix, sp.sparse.csc_matrix]:
        """Parity-check matrix of LDPC code."""
        return self._pcm

    @property
    def num_cns(self) -> int:
        """Number of check nodes."""
        return self._num_cns

    @property
    def num_vns(self) -> int:
        """Number of variable nodes."""
        return self._num_vns

    @property
    def n(self) -> int:
        """Codeword length."""
        return self._num_vns

    @property
    def coderate(self) -> float:
        """Coderate assuming independent parity checks."""
        return (self._num_vns - self._num_cns) / self._num_vns

    @property
    def num_edges(self) -> int:
        """Number of edges in decoding graph."""
        return self._num_edges

    @property
    def num_iter(self) -> int:
        """Number of decoding iterations."""
        return self._num_iter

    @num_iter.setter
    def num_iter(self, num_iter: int) -> None:
        """Set number of decoding iterations."""
        if not isinstance(num_iter, int):
            raise TypeError("num_iter must be int.")
        if num_iter < 0:
            raise ValueError("num_iter cannot be negative.")
        self._num_iter = num_iter

    @property
    def llr_max(self) -> Optional[float]:
        """Max LLR value used for internal calculations."""
        return self._llr_max

    @llr_max.setter
    def llr_max(self, value: Optional[float]) -> None:
        """Set max LLR value."""
        if value is not None and value < 0:
            raise ValueError("llr_max cannot be negative.")
        self._llr_max = float(value) if value is not None else None

    @property
    def return_state(self) -> bool:
        """Return internal decoder state for IDD schemes."""
        return self._return_state

    #########################
    # Decoding functions
    #########################

    def _gather_to_cn(self, msg_v2c: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Gather v2c messages to CN perspective with padding.

        :param msg_v2c: [batch_size, num_edges] tensor of v2c messages.
        :output msg_cn: [batch_size, num_cns, max_cn_degree] tensor of CN messages.
        :output mask: [num_cns, max_cn_degree] mask tensor.
        """
        msg_cn_flat = msg_v2c[:, self._v2c_perm]  # [bs, num_edges]
        msg_cn = msg_cn_flat[:, self._cn_gather_idx]  # [bs, num_cns, max_cn_degree]
        # mask [num_cns, max_cn_degree] broadcasts with [bs, num_cns, max_deg]
        msg_cn = msg_cn * self._cn_mask
        return msg_cn, self._cn_mask

    def _scatter_from_cn(
        self, msg_cn: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Scatter CN messages back to flat edge format.

        :param msg_cn: [batch_size, num_cns, max_cn_degree] tensor.
        :param mask: [num_cns, max_cn_degree] mask tensor (unused, kept for API).
        :output msg_c2v: [batch_size, num_edges] tensor in VN order.
        """
        batch_size = msg_cn.shape[0]
        msg_flat = msg_cn.reshape(batch_size, -1)  # [bs, num_cns * max_cn_degree]
        msg_c2v = torch.zeros(
            batch_size, self._num_edges, dtype=msg_cn.dtype, device=msg_cn.device
        )
        valid_msg = msg_flat[:, self._cn_valid_positions]  # [bs, num_edges]
        msg_c2v[:, self._cn_valid_edge_idx] = valid_msg
        # CN order -> VN order
        msg_c2v = msg_c2v[:, self._v2c_perm_inv]
        return msg_c2v

    def _gather_to_vn(self, msg_c2v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Gather c2v messages to VN perspective with padding.

        :param msg_c2v: [batch_size, num_edges] tensor of c2v messages (VN order).
        :output msg_vn: [batch_size, num_vns, max_vn_degree] padded VN messages.
        :output mask: [num_vns, max_vn_degree] VN mask tensor.
        """
        msg_vn = msg_c2v[:, self._vn_gather_idx]  # [bs, num_vns, max_vn_degree]
        # mask [num_vns, max_vn_degree] broadcasts with [bs, num_vns, max_deg]
        msg_vn = msg_vn * self._vn_mask
        return msg_vn, self._vn_mask

    def _scatter_from_vn(
        self, msg_vn: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Scatter VN messages back to flat edge format.

        :param msg_vn: [batch_size, num_vns, max_vn_degree] tensor.
        :param mask: [num_vns, max_vn_degree] mask tensor (unused, kept for API).
        :output msg_v2c: [batch_size, num_edges] tensor.
        """
        batch_size = msg_vn.shape[0]
        msg_flat = msg_vn.reshape(batch_size, -1)
        msg_v2c = torch.zeros(
            batch_size, self._num_edges, dtype=msg_vn.dtype, device=msg_vn.device
        )
        valid_msg = msg_flat[:, self._vn_valid_positions]  # [bs, num_edges]
        msg_v2c[:, self._vn_valid_edge_idx] = valid_msg
        return msg_v2c

    def _bp_iter(
        self,
        msg_v2c: torch.Tensor,
        msg_c2v: torch.Tensor,
        llr_ch: torch.Tensor,
        x_hat: torch.Tensor,
        it: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Main decoding iteration.

        Runs one CN update followed by one VN update per sub-iteration of the
        schedule (a single sub-iteration for flooding).

        :param msg_v2c: [batch_size, num_edges] v2c messages.
        :param msg_c2v: [batch_size, num_edges] c2v messages.
        :param llr_ch: [batch_size, num_vns] channel LLRs.
        :param x_hat: [batch_size, num_vns] current estimate.
        :param it: Current iteration number.
        :output msg_v2c: Updated v2c messages.
        :output msg_c2v: Updated c2v messages.
        :output x_hat: Updated VN estimates.
        """
        for j in range(self._cn_schedule.shape[0]):
            if self._scheduling == "flooding":
                # Flooding: update all CNs in parallel
                msg_cn, cn_mask = self._gather_to_cn(msg_v2c)
                msg_cn_out = self._cn_update(msg_cn, cn_mask, self._llr_max)
                for cb in self._c2v_callbacks:
                    msg_cn_out = cb(msg_cn_out, it)
                msg_c2v = self._scatter_from_cn(msg_cn_out, cn_mask)
            else:
                # Custom/layered scheduling: update only active CNs
                msg_c2v = self._bp_iter_custom_cn(msg_v2c, msg_c2v, j, it)

            # VN update
            msg_vn, vn_mask = self._gather_to_vn(msg_c2v)
            msg_vn_out, x_hat = self._vn_update(msg_vn, vn_mask, llr_ch, self._llr_max)
            for cb in self._v2c_callbacks:
                msg_vn_out = cb(msg_vn_out, it + 1, x_hat)
            msg_v2c = self._scatter_from_vn(msg_vn_out, vn_mask)

        return msg_v2c, msg_c2v, x_hat

    def _bp_iter_custom_cn(
        self,
        msg_v2c: torch.Tensor,
        msg_c2v: torch.Tensor,
        sub_iter: int,
        it: int,
    ) -> torch.Tensor:
        """Process CN update for custom scheduling (only active CNs).

        :param msg_v2c: [batch_size, num_edges] v2c messages in VN order.
        :param msg_c2v: [batch_size, num_edges] current c2v messages.
        :param sub_iter: Sub-iteration index (which CNs to update).
        :param it: Current iteration number.
        :output msg_c2v: Updated c2v messages; messages of inactive CNs are
            kept unchanged.
        """
        batch_size = msg_v2c.shape[0]
        gather_idx, cn_mask, edge_idx, valid_positions = self._schedule_tensors(
            sub_iter
        )

        msg_v2c_cn_order = msg_v2c[:, self._v2c_perm]
        msg_cn = msg_v2c_cn_order[:, gather_idx]  # [bs, num_active_cns, max_degree]
        # mask [num_active_cns, max_degree] broadcasts with [bs, ...]
        msg_cn = msg_cn * cn_mask

        msg_cn_out = self._cn_update(msg_cn, cn_mask, self._llr_max)
        for cb in self._c2v_callbacks:
            msg_cn_out = cb(msg_cn_out, it)

        valid_msg = msg_cn_out.reshape(batch_size, -1)[:, valid_positions]
        # Out-of-place update so the caller's tensor is not modified
        msg_c2v = msg_c2v.clone()
        msg_c2v[:, edge_idx] = valid_msg
        return msg_c2v

    def build(self, input_shape: tuple, **kwargs) -> None:
        """Build block and validate input shape."""
        if input_shape[-1] != self._num_vns:
            raise ValueError("Last dimension must be of length n.")

    def call(
        self,
        llr_ch: torch.Tensor,
        /,
        *,
        num_iter: Optional[int] = None,
        msg_v2c: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Iterative BP decoding function.

        :param llr_ch: Channel LLRs (logits) of shape [..., n].
        :param num_iter: Number of iterations. If `None`, uses default.
        :param msg_v2c: Initial v2c messages for IDD schemes.
        :output x_hat: Decoded bits of shape [..., n].
        :output msg_v2c: Decoder state. Returned only if ``return_state`` is `True`.
        """
        if num_iter is None:
            num_iter = self._num_iter

        # Clip LLRs for numerical stability
        if self._llr_max is not None:
            llr_ch = llr_ch.clamp(-self._llr_max, self._llr_max)

        # Reshape to support multi-dimensional inputs
        llr_ch_shape = list(llr_ch.shape)
        llr_ch_reshaped = llr_ch.reshape([-1, self._num_vns])  # [batch_size, num_vns]

        # Logits are converted into "true" LLRs as usually done in literature
        llr_ch_reshaped = llr_ch_reshaped * -1.0

        # Initialize v2c messages
        if msg_v2c is None:
            msg_v2c = llr_ch_reshaped[:, self._vn_idx_t]  # [bs, num_edges]
        else:
            msg_v2c = msg_v2c * -1.0  # invert sign due to logit definition

        # Messages from CN perspective; initialized to zero
        msg_c2v = torch.zeros_like(msg_v2c)

        # Apply VN callbacks before first iteration
        if self._v2c_callbacks:
            msg_vn, vn_mask = self._gather_to_vn(msg_v2c)
            for cb in self._v2c_callbacks:
                msg_vn = cb(msg_vn, 0, llr_ch_reshaped)
            msg_v2c = self._scatter_from_vn(msg_vn, vn_mask)

        # Initialize x_hat
        x_hat = llr_ch_reshaped

        # Main decoding loop
        for it in range(num_iter):
            msg_v2c, msg_c2v, x_hat = self._bp_iter(
                msg_v2c, msg_c2v, llr_ch_reshaped, x_hat, it
            )

        if self._hard_out:
            # Hard decide decoder output (internal LLR <= 0 -> bit 1)
            x_hat = (x_hat <= 0).to(self.dtype)
        else:
            x_hat = x_hat * -1.0  # convert LLRs back into logits

        # Reshape to match original input dimensions
        x_reshaped = x_hat.reshape(llr_ch_shape)

        if not self._return_state:
            return x_reshaped
        # invert sign due to logit definition
        return x_reshaped, msg_v2c * -1.0
#######################
# Node update functions
#######################


def vn_node_update_identity(
    msg_c2v: torch.Tensor,
    mask: torch.Tensor,
    llr_ch: torch.Tensor,
    llr_clipping: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # pylint: disable=line-too-long
    r"""Pass-through variable node update for testing and debugging.

    The incoming c2v messages are handed back unchanged (the very same
    tensor object) as outgoing messages. As second output, the marginal of
    every node is returned, i.e., the sum over all incoming messages plus
    the channel LLR.

    :param msg_c2v: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing c2v messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges. Not used by this function.
    :param llr_ch: Tensor of shape `[batch_size, num_nodes]` containing the
        channel LLRs.
    :param llr_clipping: Accepted for signature compatibility with the other
        VN update functions; no clipping is applied.

    :output msg_v2c: The unmodified input ``msg_c2v``.
    :output x_tot: Tensor of shape `[batch_size, num_nodes]` with the
        node marginals.
    """
    marginal = llr_ch + torch.sum(msg_c2v, dim=2)  # [bs, num_nodes]
    return msg_c2v, marginal
def vn_update_sum(
    msg_c2v: torch.Tensor,
    mask: torch.Tensor,
    llr_ch: torch.Tensor,
    llr_clipping: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # pylint: disable=line-too-long
    r"""Variable node update implementing the (extrinsic) `sum` rule.

    Every VN accumulates its channel LLR ``llr_ch`` and all incoming c2v
    messages. The outgoing message on an edge is this total minus the
    message that arrived on that same edge (extrinsic principle).

    :param msg_c2v: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing c2v messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges; outgoing messages at padded positions are set to zero.
    :param llr_ch: Tensor of shape `[batch_size, num_nodes]` containing the
        channel LLRs.
    :param llr_clipping: If not `None`, both outputs are clipped to
        ``[-llr_clipping, llr_clipping]``.

    :output msg_v2c: Tensor of shape `[batch_size, num_nodes, max_degree]`
        with the extrinsic v2c messages.
    :output x_tot: Tensor of shape `[batch_size, num_nodes]` with the
        total (a-posteriori) LLR of each node.
    """
    # Channel LLR plus the sum of all incoming messages per node
    total = msg_c2v.sum(dim=2) + llr_ch  # [bs, num_nodes]

    # Remove each edge's own contribution again, then zero the padding
    extrinsic = mask * (total.unsqueeze(2) - msg_c2v)

    if llr_clipping is None:
        return extrinsic, total

    return (
        torch.clamp(extrinsic, min=-llr_clipping, max=llr_clipping),
        torch.clamp(total, min=-llr_clipping, max=llr_clipping),
    )
def cn_node_update_identity(
    msg_v2c: torch.Tensor,
    mask: torch.Tensor,
    llr_clipping: Optional[float] = None,
) -> torch.Tensor:
    # pylint: disable=line-too-long
    r"""Pass-through check node update used for testing and debugging.

    Returns ``msg_v2c`` itself (the same tensor object, no copy and no
    processing).

    :param msg_v2c: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing v2c messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges (unused).
    :param llr_clipping: Clipping value (unused).
    """
    # Both arguments exist only for signature compatibility with the other
    # check node update functions.
    del mask, llr_clipping
    return msg_v2c
def cn_update_offset_minsum(
    msg_v2c: torch.Tensor,
    mask: torch.Tensor,
    llr_clipping: Optional[float] = None,
    offset: float = 0.5,
) -> torch.Tensor:
    # pylint: disable=line-too-long
    r"""Check node update function implementing the offset corrected minsum.

    The function implements

    .. math::
        \qquad y_{j \to i} = \alpha_{j \to i} \cdot \max \left( \min_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}| \right)-\beta , 0\right)

    where :math:`\beta` is given by ``offset`` (default 0.5; ``offset=0``
    yields plain minsum) and :math:`y_{j \to i}` denotes the message from
    check node (CN) *j* to variable node (VN) *i* and :math:`x_{i \to j}` from
    VN *i* to CN *j*, respectively. Further, :math:`\mathcal{N}(j)` denotes
    all indices of connected VNs to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    :cite:p:`Chen`.

    :param msg_v2c: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing v2c messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges (1 / `True`) and padded positions (0 / `False`). Boolean masks
        are cast to the dtype of ``msg_v2c``.
    :param llr_clipping: Clipping value used for internal processing. If
        `None`, no internal clipping is applied.
    :param offset: Offset value :math:`\beta` subtracted from the magnitude
        of each outgoing message before clamping it at zero.

    :output msg_c2v: Tensor of the same shape as ``msg_v2c`` with the
        extrinsic c2v messages; padded positions are zero.
    """
    # Sentinel magnitude for padded edges so they never win the minimum
    # search. NOTE: assumes valid |messages| stay well below this value
    # (e.g. by using ``llr_clipping``).
    large_val = 1e6

    # ``1.0 - mask`` is undefined for boolean tensors
    if mask.dtype == torch.bool:
        mask = mask.to(msg_v2c.dtype)

    # mask [num_nodes, max_degree] broadcasts with [bs, num_nodes, max_degree]
    inv_mask = 1.0 - mask

    sign_val = torch.sign(msg_v2c)
    # Treat exact zeros as positive so they do not annihilate the product
    sign_val = torch.where(sign_val == 0, torch.ones_like(sign_val), sign_val)
    # For padded positions, set sign to 1 (neutral for product)
    sign_val = sign_val * mask + inv_mask
    sign_node = sign_val.prod(dim=2, keepdim=True)  # [bs, num_nodes, 1]
    # Extrinsic sign: total sign / intrinsic sign (== product for +-1)
    sign_out = sign_node * sign_val  # [bs, num_nodes, max_degree]

    msg_abs = torch.abs(msg_v2c)
    # For padded positions, set to large value so they don't affect min
    msg_abs_masked = msg_abs * mask + large_val * inv_mask
    min_val, _ = msg_abs_masked.min(dim=2, keepdim=True)  # [bs, num_nodes, 1]

    # Offsets to the minimum; exactly 0 where an edge attains the minimum
    msg_min1 = (msg_abs_masked - min_val) * mask + large_val * inv_mask
    is_min_position = msg_min1 == 0
    num_min_positions = (is_min_position * mask).sum(dim=2, keepdim=True)
    # If the minimum is attained more than once, every edge sees it
    double_min = (num_min_positions > 1).to(msg_v2c.dtype)

    # Second smallest magnitude (excluding all minimum positions)
    msg_for_second_min = torch.where(
        is_min_position, torch.full_like(msg_min1, large_val), msg_min1
    )
    min_val_2, _ = msg_for_second_min.min(dim=2, keepdim=True)
    min_val_2 = min_val_2 + min_val

    # Extrinsic magnitude: minimum positions get the second minimum (or the
    # minimum again on ties), all other positions get the minimum
    min_val_e = (1 - double_min) * min_val_2 + double_min * min_val
    msg_e = torch.where(is_min_position, min_val_e, min_val)

    msg_e = torch.clamp(msg_e - offset, min=0)
    msg_e = msg_e * mask
    msg_out = sign_out * msg_e

    if llr_clipping is not None:
        msg_out = msg_out.clamp(-llr_clipping, llr_clipping)

    return msg_out
def cn_update_minsum(
    msg_v2c: torch.Tensor,
    mask: torch.Tensor,
    llr_clipping: Optional[float] = None,
) -> torch.Tensor:
    # pylint: disable=line-too-long
    r"""Check node update function implementing the `minsum` update.

    The function implements

    .. math::
        \qquad y_{j \to i} = \alpha_{j \to i} \cdot \min_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}|\right)

    where :math:`y_{j \to i}` denotes the message from check node (CN) *j* to
    variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*,
    respectively. Further, :math:`\mathcal{N}(j)` denotes all indices of
    connected VNs to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    :cite:p:`Ryan` and :cite:p:`Chen`.

    Internally this delegates to :func:`cn_update_offset_minsum` with a zero
    offset.

    :param msg_v2c: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing v2c messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges.
    :param llr_clipping: Clipping value used for internal processing. If
        `None`, no internal clipping is applied.
    """
    # Plain minsum is offset-minsum with beta = 0
    no_offset = 0.0
    return cn_update_offset_minsum(
        msg_v2c, mask, llr_clipping=llr_clipping, offset=no_offset
    )
def cn_update_tanh(
    msg_v2c: torch.Tensor,
    mask: torch.Tensor,
    llr_clipping: Optional[float] = None,
) -> torch.Tensor:
    # pylint: disable=line-too-long
    r"""Check node update function implementing the `boxplus` operation.

    Computes the exact (extrinsic) boxplus over all incoming messages of a
    check node, leaving out the message on the outgoing edge itself, via the
    tanh rule

    .. math::
        y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right)

    where :math:`y_{j \to i}` denotes the message from check node (CN) *j* to
    variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*,
    respectively. Further, :math:`\mathcal{N}(j)` denotes all indices of
    connected VNs to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    :cite:p:`Ryan`.

    The extrinsic product is obtained by dividing the full product by the
    intrinsic factor. The argument of the inverse tanh is clamped to
    :math:`\pm(1-10^{-7})` for numerical stability, and the output can be
    clipped additionally.

    :param msg_v2c: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing v2c messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges.
    :param llr_clipping: Clipping value used for internal processing. If
        `None`, no internal clipping is applied.
    """
    t = torch.tanh(msg_v2c / 2)
    # Exact zeros would break the division by the intrinsic factor below
    t = t.masked_fill(t == 0, 1e-12)

    # Padded edges contribute the neutral factor 1 to the product
    padded = t * mask + (1.0 - mask)
    total = torch.prod(padded, dim=2, keepdim=True)  # [bs, num_nodes, 1]

    # Remove the intrinsic factor: (1 / t_i) * prod_j t_j
    ratio = (1.0 / padded) * total
    ratio = ratio.masked_fill(ratio.abs() < 1e-7, 0.0)

    bound = 1 - 1e-7
    ratio = torch.clamp(ratio, min=-bound, max=bound)
    out = 2 * torch.atanh(ratio) * mask

    if llr_clipping is not None:
        out = out.clamp(-llr_clipping, llr_clipping)
    return out
def cn_update_phi(
    msg_v2c: torch.Tensor,
    mask: torch.Tensor,
    llr_clipping: Optional[float] = None,
) -> torch.Tensor:
    # pylint: disable=line-too-long
    r"""Check node update function implementing the `boxplus-phi` operation.

    Computes the exact (extrinsic) boxplus over all incoming messages of a
    check node, excluding the outgoing edge itself, in the log domain via
    the :math:`\phi` function (cf. :cite:p:`Ryan`):

    .. math::
        y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right)

    where :math:`\phi(x)=-\operatorname{log}\left(\operatorname{tanh}\left(\frac{x}{2}\right)\right)`
    and :math:`y_{j \to i}` denotes the message from check node (CN) *j* to
    variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*,
    respectively. Further, :math:`\mathcal{N}(j)` denotes all indices of
    connected VNs to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    :cite:p:`Ryan`.

    The argument of :math:`\phi` is clamped to a dtype-dependent range for
    numerical stability, and the output can be clipped additionally.

    :param msg_v2c: Tensor of shape `[batch_size, num_nodes, max_degree]`
        representing v2c messages.
    :param mask: Tensor of shape `[num_nodes, max_degree]` indicating valid
        edges.
    :param llr_clipping: Clipping value used for internal processing. If
        `None`, no internal clipping is applied.
    """

    def phi(x: torch.Tensor) -> torch.Tensor:
        r"""Evaluate :math:`\phi(x)=\log\frac{e^x+1}{e^x-1}` on a clamped input."""
        if x.dtype == torch.float32:
            lo, hi = 8.5e-8, 16.635532
        elif x.dtype == torch.float64:
            lo, hi = 1e-12, 28.324079
        else:
            lo, hi = 1e-7, 20.0
        x = x.clamp(min=lo, max=hi)
        return torch.log(torch.exp(x) + 1) - torch.log(torch.exp(x) - 1)

    padding = 1.0 - mask

    # Signs: zeros count as +1, padded edges are neutral (+1)
    signs = torch.sign(msg_v2c)
    signs = signs.masked_fill(signs == 0, 1.0)
    signs = signs * mask + padding
    # Extrinsic sign; multiplying by a +-1 factor equals dividing by it
    extrinsic_sign = signs * signs.prod(dim=2, keepdim=True)

    # Padded edges contribute 0 to the log-domain sum
    phi_in = phi(torch.abs(msg_v2c)) * mask
    phi_total = phi_in.sum(dim=2, keepdim=True)  # [bs, num_nodes, 1]
    magnitude = phi(phi_total - phi_in)

    out = extrinsic_sign * magnitude * mask
    if llr_clipping is not None:
        out = out.clamp(-llr_clipping, llr_clipping)
    return out
class LDPC5GDecoder(LDPCBPDecoder):
    # pylint: disable=line-too-long
    r"""Iterative belief propagation decoder for 5G NR LDPC codes.

    Inherits from :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder` and
    provides a wrapper for 5G compatibility, i.e., automatically handles
    rate-matching according to :cite:p:`3GPPTS38212`.

    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied and, thus, the encoder object is
    required as input.

    If required the decoder can be made trainable and is differentiable
    (the training of some check node types may not be supported) following
    the concept of "weighted BP" :cite:p:`Nachmani`.

    When ``harq_mode=True``, the decoder supports HARQ-IR decoding. Pass
    ``rv`` to :meth:`call` together with stacked LLRs of shape
    ``[..., num_rvs, n]`` to combine multiple redundancy versions before
    BP decoding. PCM pruning is automatically disabled in HARQ mode.

    :param encoder: An instance of
        :class:`~sionna.phy.fec.ldpc.encoding.LDPC5GEncoder` containing the
        correct code parameters.
    :param cn_update: Check node update rule to be used as described in
        :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder`. One of
        "boxplus-phi" (default), "boxplus", "minsum", "offset-minsum",
        "identity", or a callable. If a callable is provided, it will be used
        instead as CN update. The input of the function is a tensor of v2c
        messages of shape `[batch_size, num_cns, max_degree]` with a mask of
        shape `[num_cns, max_degree]`.
    :param vn_update: Variable node update rule to be used. One of "sum"
        (default), "identity", or a callable. If a callable is provided, it
        will be used instead as VN update. The input of the function is a
        tensor of c2v messages of shape `[batch_size, num_vns, max_degree]`
        with a mask of shape `[num_vns, max_degree]`.
    :param cn_schedule: Defines the CN update scheduling per BP iteration.
        Can be either "flooding" to update all nodes in parallel
        (recommended) or "layered" to sequentially update all CNs in the same
        lifting group together or a 2D array/tensor of shape
        `[num_update_steps, num_active_nodes]` where each row defines the
        node indices to be updated per subiteration. In this case each BP
        iteration runs ``num_update_steps`` subiterations, thus the decoder's
        level of parallelization is lower and usually the decoding throughput
        decreases.
    :param hard_out: If `True`, the decoder provides hard-decided codeword
        bits instead of soft-values.
    :param return_infobits: If `True`, only the `k` info bits (soft or
        hard-decided) are returned. Otherwise all `n` positions are returned.
    :param prune_pcm: If `True`, all punctured degree-1 VNs and connected
        check nodes are removed from the decoding graph (see
        :cite:p:`Cammerer` for details). Besides numerical differences, this
        should yield the same decoding result but improves the decoding
        throughput and reduces the memory footprint. Automatically set to
        `False` when ``harq_mode=True``.
    :param num_iter: Defining the number of decoder iterations (due to
        batching, no early stopping used at the moment!).
    :param llr_max: Internal clipping value for all internal messages. If
        `None`, no clipping is applied.
    :param v2c_callbacks: Each callable will be executed after each VN update
        with the following arguments ``msg_vn``, ``it``, ``x_hat``, where
        ``msg_vn`` are the v2c messages as tensor of shape
        `[batch_size, num_vns, max_degree]`, ``x_hat`` is the current
        estimate of each VN of shape `[batch_size, num_vns]`, and ``it`` is
        the current iteration counter. It must return an updated version of
        ``msg_vn`` of same shape.
    :param c2v_callbacks: Each callable will be executed after each CN update
        with the following arguments ``msg_cn`` and ``it`` where ``msg_cn``
        are the c2v messages as tensor of shape
        `[batch_size, num_cns, max_degree]` and ``it`` is the current
        iteration counter. It must return an updated version of ``msg_cn``
        of same shape.
    :param return_state: If `True`, the internal VN messages ``msg_vn`` from
        the last decoding iteration are returned, and ``msg_vn`` or `None`
        needs to be given as a second input when calling the decoder. This
        can be used for iterative demapping and decoding.
    :param harq_mode: If `True`, the decoder operates in HARQ-IR mode. PCM
        pruning is automatically disabled, and ``rv`` can be passed to
        :meth:`call` to combine multiple redundancy versions.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').

    :input llr_ch: [..., n] or [..., num_rvs, n], `torch.float`.
        Tensor containing the channel logits/llr values. When ``rv`` is
        given, the second-to-last dimension must equal ``len(rv)``.
    :input rv: `None` | list of int.
        Redundancy versions corresponding to the stacked LLRs. If `None`,
        standard single-transmission decoding is used. Only evaluated when
        ``harq_mode=True``.
    :input msg_v2c: `None` | [batch_size, num_edges], `torch.float`.
        Tensor of VN messages representing the internal decoder state.
        Required only if the decoder shall use its previous internal state,
        e.g., for iterative detection and decoding (IDD) schemes.

    :output x_hat: [..., n] or [..., k], `torch.float`.
        Tensor with the leading dimensions of ``llr_ch`` (without the
        redundancy-version axis in HARQ mode) containing bit-wise
        soft-estimates (or hard-decided bit-values) of all `n` codeword bits
        or only the `k` information bits if ``return_infobits`` is `True`.
    :output msg_v2c: [batch_size, num_edges], `torch.float`.
        Tensor of VN messages representing the internal decoder state.
        Returned only if ``return_state`` is set to `True`. Remark: always
        returns entire decoder state, even if ``return_infobits`` is `True`.

    .. rubric:: Notes

    As decoding input logits :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}`
    are assumed for compatibility with the learning framework, but
    internally LLRs with definition
    :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.

    The decoder is not (particularly) optimized for Quasi-cyclic (QC) LDPC
    codes and, thus, supports arbitrary parity-check matrices.

    The batch-dimension is shifted to the last dimension during decoding to
    avoid a performance degradation caused by a severe indexing overhead.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder

        # Create encoder and decoder
        encoder = LDPC5GEncoder(k=100, n=200)
        decoder = LDPC5GDecoder(encoder, num_iter=20)

        # Encode and decode
        u = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        c = encoder(u)
        llr_ch = 2.0 * (2.0 * c - 1.0)  # Perfect LLRs
        u_hat = decoder(llr_ch)
        print(torch.equal(u, u_hat))  # True

        # HARQ example: combine two redundancy versions
        dec_harq = LDPC5GDecoder(encoder, num_iter=20, harq_mode=True)
        c_rvs = encoder(u, rv=[0, 2])  # [10, 2, 200]
        llr_rvs = 2.0 * (2.0 * c_rvs - 1.0)
        u_hat = dec_harq(llr_rvs, rv=[0, 2])  # [10, 100]
    """

    def __init__(
        self,
        encoder: LDPC5GEncoder,
        cn_update: Union[str, Callable] = "boxplus-phi",
        vn_update: Union[str, Callable] = "sum",
        cn_schedule: Union[str, np.ndarray, torch.Tensor] = "flooding",
        hard_out: bool = True,
        return_infobits: bool = True,
        num_iter: int = 20,
        llr_max: Optional[float] = 20.0,
        v2c_callbacks: Optional[List[Callable]] = None,
        c2v_callbacks: Optional[List[Callable]] = None,
        prune_pcm: bool = True,
        return_state: bool = False,
        harq_mode: bool = False,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        # The 5G encoder is required to access all 5G code parameters
        if not isinstance(encoder, LDPC5GEncoder):
            raise TypeError("encoder must be of class LDPC5GEncoder.")

        # Store encoder reference (assigned after super().__init__() for
        # nn.Module compatibility)
        _encoder_ref = encoder
        pcm = encoder.pcm

        if not isinstance(return_infobits, bool):
            raise TypeError("return_infobits must be bool.")
        self._return_infobits = return_infobits

        if not isinstance(return_state, bool):
            raise TypeError("return_state must be bool.")

        if not isinstance(harq_mode, bool):
            raise TypeError("harq_mode must be bool.")
        self._harq_mode = harq_mode

        # Reject the removed ``cn_type`` argument explicitly
        if "cn_type" in kwargs:
            raise TypeError("'cn_type' is deprecated; use 'cn_update' instead.")

        if not isinstance(prune_pcm, bool):
            raise TypeError("prune_pcm must be bool.")
        # Pruning must be disabled in HARQ mode
        if harq_mode:
            prune_pcm = False
        self._prune_pcm = prune_pcm

        # ``cn_schedule`` may also be an array/tensor. Comparing those with a
        # string is element-wise (NumPy yields an all-False array), which
        # makes a bare ``if cn_schedule == "layered"`` raise.
        is_layered = isinstance(cn_schedule, str) and cn_schedule == "layered"

        n_ldpc = encoder.n_ldpc
        if prune_pcm:
            # Index of the first VN of the trailing run of degree-1 VNs
            dv = np.sum(pcm, axis=0)  # VN degree
            last_pos = n_ldpc
            for idx in range(n_ldpc - 1, 0, -1):
                if dv[0, idx] != 1:
                    break
                last_pos = idx

            # Number of filler bits
            k_filler = encoder.k_ldpc - encoder.k
            # Number of punctured bits
            nb_punc_bits = (n_ldpc - k_filler) - encoder.n - 2 * encoder.z
            # If layered decoding is used, quantize the number of punctured
            # bits to a multiple of z so CN groups of size z stay intact
            if is_layered:
                nb_punc_bits = int(np.floor(nb_punc_bits / encoder.z) * encoder.z)

            # Effective codeword length after pruning of degree-1 VNs
            self._n_pruned = int(max(last_pos, n_ldpc - nb_punc_bits))
            self._nb_pruned_nodes = n_ldpc - self._n_pruned
            if self._nb_pruned_nodes < 0:
                raise ArithmeticError(
                    "Internal error: number of pruned nodes must be positive."
                )

            # Remove the last CNs and VNs from the pcm. Explicit stop indices
            # are required: ``pcm[:-0, :-0]`` would select an EMPTY matrix
            # when nothing is pruned.
            num_rows, num_cols = pcm.shape
            pcm = pcm[
                : num_rows - self._nb_pruned_nodes,
                : num_cols - self._nb_pruned_nodes,
            ]
        else:
            self._nb_pruned_nodes = 0
            self._n_pruned = n_ldpc

        # Layered scheduling: row i lists the CNs of lifting group i,
        # i.e. indices i*z, ..., i*z + z - 1
        if is_layered:
            z = encoder.z
            num_blocks = pcm.shape[0] // z
            cn_schedule = np.arange(num_blocks * z).reshape(num_blocks, z)

        super().__init__(
            pcm,
            cn_update=cn_update,
            vn_update=vn_update,
            cn_schedule=cn_schedule,
            hard_out=hard_out,
            num_iter=num_iter,
            llr_max=llr_max,
            v2c_callbacks=v2c_callbacks,
            c2v_callbacks=c2v_callbacks,
            return_state=return_state,
            precision=precision,
            device=device,
            **kwargs,
        )

        # Assign encoder after super().__init__() for nn.Module compatibility
        self._encoder = _encoder_ref

    ###############################
    # Public methods and properties
    ###############################

    @property
    def encoder(self) -> LDPC5GEncoder:
        """LDPC Encoder used for rate-matching/recovery."""
        return self._encoder

    ########################
    # Sionna block functions
    ########################

    def build(self, input_shape: tuple, **kwargs) -> None:
        """Build block and check input dimensions.

        :param input_shape: Shape of ``llr_ch``; its last entry must equal
            the codeword length ``n`` of the encoder.
        :raises ValueError: If the last dimension is not of length ``n``.
        """
        if input_shape[-1] != self.encoder.n:
            raise ValueError("Last dimension must be of length n.")
        self._old_shape_5g = input_shape

    def call(
        self,
        llr_ch: torch.Tensor,
        /,
        *,
        rv: Optional[List[int]] = None,
        num_iter: Optional[int] = None,
        msg_v2c: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Iterative BP decoding function and rate-recovery.

        :param llr_ch: Channel LLRs of shape ``[..., n]`` (standard) or
            ``[..., num_rvs, n]`` (HARQ).
        :param rv: List of RV indices matching the stacked LLRs. Only used in
            HARQ mode (``harq_mode=True``), where it defaults to ``[0]``;
            ignored otherwise.
        :param num_iter: Number of iterations. If `None`, uses default.
        :param msg_v2c: Initial v2c messages for IDD schemes.

        :output x_hat: Decoded bits of shape ``[..., n]`` or ``[..., k]``,
            where ``...`` are the leading dimensions of ``llr_ch`` without
            the redundancy-version axis.
        :output msg_v2c: Decoder state. Returned only if ``return_state`` is
            `True`.

        :raises ValueError: In HARQ mode, if ``len(rv) > 1`` and the
            second-to-last dimension of ``llr_ch`` differs from ``len(rv)``.
        """
        enc = self.encoder
        llr_ch_shape = list(llr_ch.shape)
        input_device = llr_ch.device

        n = enc.n
        k = enc.k
        z = enc.z
        k_filler = enc.k_filler
        # Compressed rate-matching buffer length without the pruned tail
        buf_len = enc.n_cb_comp - self._nb_pruned_nodes

        # --- Build compressed RM buffer -----------------------------------
        if self._harq_mode:
            rv = [0] if rv is None else list(rv)
            num_rvs = len(rv)
            if num_rvs > 1:
                if llr_ch_shape[-2] != num_rvs:
                    raise ValueError(
                        f"Second-to-last dimension of llr_ch "
                        f"({llr_ch_shape[-2]}) must equal len(rv) "
                        f"({num_rvs})."
                    )
                llr_ch_flat = llr_ch.reshape(-1, num_rvs, n)
                leading_shape = llr_ch_shape[:-2]
            else:
                # A single RV may come with or without a size-1 RV axis
                has_rv_dim = len(llr_ch_shape) >= 3 and llr_ch_shape[-2] == 1
                leading_shape = (
                    llr_ch_shape[:-2] if has_rv_dim else llr_ch_shape[:-1]
                )
                llr_ch_flat = llr_ch.reshape(-1, n).unsqueeze(1)
            batch_size = llr_ch_flat.shape[0]

            # Accumulate multiple RVs: zero-pad each to buf_len and rotate
            # it by its start position from get_start_positions_comp
            starts = enc.get_start_positions_comp(rv)
            pad_len = buf_len - n
            llr_buf = torch.zeros(
                batch_size, buf_len, dtype=self.dtype, device=input_device
            )
            for rv_idx in range(num_rvs):
                llr_rv = llr_ch_flat[:, rv_idx, :]
                if enc.num_bits_per_symbol is not None:
                    llr_rv = llr_rv[:, enc.out_int_inv]
                llr_buf = llr_buf + torch.nn.functional.pad(
                    llr_rv.to(self.dtype), (0, pad_len)
                ).roll(starts[rv_idx], dims=1)
        else:
            leading_shape = llr_ch_shape[:-1]
            llr_ch_reshaped = llr_ch.reshape(-1, n)
            batch_size = llr_ch_reshaped.shape[0]
            if enc.num_bits_per_symbol is not None:
                llr_ch_reshaped = llr_ch_reshaped[:, enc.out_int_inv]
            # Cast like the HARQ branch so every piece concatenated below
            # shares ``self.dtype``
            llr_buf = torch.nn.functional.pad(
                llr_ch_reshaped.to(self.dtype), (0, buf_len - n)
            )

        # --- Reassemble full (pruned) n_ldpc vector -------------------------
        # Filler bits are known zeros -> strongly negative logit.
        # NOTE(review): assumes ``self._llr_max`` is None when the decoder was
        # built with ``llr_max=None``; then fall back to the constructor's
        # default magnitude (20.0) instead of failing on ``-None``.
        filler_mag = self._llr_max if self._llr_max is not None else 20.0
        n_sys_rm = k - 2 * z
        llr_5g = torch.cat(
            [
                # First 2*z positions are not in the buffer: LLR 0 (unknown)
                torch.zeros(batch_size, 2 * z, dtype=self.dtype, device=input_device),
                llr_buf[:, :n_sys_rm],
                -filler_mag
                * torch.ones(batch_size, k_filler, dtype=self.dtype, device=input_device),
                llr_buf[:, n_sys_rm:],
            ],
            dim=1,
        )

        # --- BP decoding --------------------------------------------------
        output = super().call(llr_5g, num_iter=num_iter, msg_v2c=msg_v2c)
        if self._return_state:
            x_hat, msg_v2c_out = output
        else:
            x_hat, msg_v2c_out = output, None

        # --- Output formatting --------------------------------------------
        # Concrete leading dims (no -1) so zero-size batches reshape too
        if self._return_infobits:
            result = x_hat[:, :k].reshape(leading_shape + [k])
        else:
            x = x_hat.reshape(batch_size, self._n_pruned)
            # Drop the filler positions [k, k_ldpc)
            x_no_filler = torch.cat(
                [x[:, :k], x[:, enc.k_ldpc:self._n_pruned]], dim=1
            )
            # Skip the 2*z positions inserted with zero LLR above and keep n
            x_short = x_no_filler[:, 2 * z:2 * z + n]
            if enc.num_bits_per_symbol is not None:
                x_short = x_short[:, enc.out_int]
            result = x_short.reshape(leading_shape + [n])

        if self._return_state:
            return result, msg_v2c_out
        return result