#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for channel decoding and utility functions."""
import tensorflow as tf
import numpy as np
import scipy as sp # for sparse H matrix computations
from sionna.phy import Block
from sionna.phy.fec.ldpc.encoding import LDPC5GEncoder
import types
class LDPCBPDecoder(Block):
    # pylint: disable=line-too-long
    r"""Iterative belief propagation decoder for low-density parity-check (LDPC)
    codes and other `codes on graphs`.

    This class defines a generic belief propagation decoder for decoding
    with arbitrary parity-check matrices. It can be used to iteratively
    estimate/recover the transmitted codeword (or information bits) based on the
    LLR-values of the received noisy codeword observation.

    Per default, the decoder implements the flooding message passing algorithm
    [Ryan]_, i.e., all nodes are updated in a parallel fashion. Different check
    node update functions are available:

    (1) `boxplus`

        .. math::
            y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right)

    (2) `boxplus-phi`

        .. math::
            y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right)

        with :math:`\phi(x)=-\operatorname{log}\left(\operatorname{tanh}\left(\frac{x}{2}\right)\right)`

    (3) `minsum`

        .. math::
            \qquad y_{j \to i} = \alpha_{j \to i} \cdot {min}_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}|\right)

    (4) `offset-minsum`

        .. math::
            \qquad y_{j \to i} = \alpha_{j \to i} \cdot {max} \left( {min}_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}| \right)-\beta , 0\right)

    where :math:`\beta=0.5` and :math:`y_{j \to i}` denotes the message
    from check node (CN) *j* to variable node (VN) *i* and :math:`x_{i \to j}`
    from VN *i* to CN *j*, respectively. Further, :math:`\mathcal{N}(j)`
    denotes all indices of connected VNs to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    [Ryan]_ and [Chen]_ for offset corrected minsum.

    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied (cf. [Richardson]_ for details), this
    can be done by :class:`~sionna.phy.fec.ldpc.encoding.LDPC5GEncoder` and
    :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder`, respectively.

    If required, the decoder can be made trainable and is fully differentiable
    by following the concept of `weighted BP` [Nachmani]_. For this, custom
    callbacks can be registered that scale the messages during decoding. Please
    see the corresponding tutorial notebook for details.

    For numerical stability, the decoder applies LLR clipping of +/- `llr_max`
    to the input LLRs.

    Parameters
    ----------
    pcm: ndarray | scipy.sparse.csr_matrix | scipy.sparse.csc_matrix
        An ndarray of shape `[n-k, n]` defining the parity-check matrix
        consisting only of `0` or `1` entries. Can also be a scipy sparse
        matrix (or sparse array) in CSR or CSC format.

    cn_update: str, "boxplus-phi" (default) | "boxplus" | "minsum" | "offset-minsum" | "identity" | callable
        Check node update rule to be used as described above.
        If a callable is provided, it will be used instead as CN update.
        It is called with the ragged tensor of v2c messages of shape
        `[num_cns, None, batch_size]` (the second dimension is ragged, i.e.,
        depends on the individual CN degree) and the clipping value
        ``llr_max`` as second positional argument.

    vn_update: str, "sum" (default) | "identity" | callable
        Variable node update rule to be used.
        If a callable is provided, it will be used instead as VN update.
        It is called with the ragged tensor of c2v messages of shape
        `[num_vns, None, batch_size]` (the second dimension is ragged, i.e.,
        depends on the individual VN degree), the channel LLRs and the
        clipping value ``llr_max``. It must return the updated v2c messages
        and the marginalized estimate of each VN.

    cn_schedule: "flooding" | [num_update_steps, num_active_nodes], tf.int
        Defines the CN update scheduling per BP iteration. Can be either
        "flooding" to update all nodes in parallel (recommended) or a 2D tensor
        where each row defines the `num_active_nodes` node indices to be
        updated per subiteration. In this case each BP iteration runs
        `num_update_steps` subiterations, thus the decoder's level of
        parallelization is lower and usually the decoding throughput decreases.

    hard_out: `bool`, (default `True`)
        If `True`, the decoder provides hard-decided codeword bits instead of
        soft-values.

    num_iter: int (default 20)
        Defining the number of decoder iterations (due to batching, no early
        stopping used at the moment!).

    llr_max: float (default 20)
        Non-negative clipping value. The input LLRs are clipped to
        +/- `llr_max` and the value is passed to the node update functions
        for internal clipping.

    v2c_callbacks: `None` (default) | callable | list of callables
        Each callable will be executed after each VN update with the following
        arguments `msg_vn_rag_`, `it`, `x_hat`, where `msg_vn_rag_` are the v2c
        messages as ragged tensor of shape `[num_vns, None, batch_size]`,
        `x_hat` is the current estimate of each VN of shape
        `[num_vns, batch_size]` and `it` is the current iteration counter.
        The callbacks are additionally executed once before the first
        iteration with `it=0` and `x_hat` set to the channel LLRs.
        It must return an updated version of `msg_vn_rag_` of same shape.

    c2v_callbacks: `None` (default) | callable | list of callables
        Each callable will be executed after each CN update with the following
        arguments `msg_cn_rag_` and `it` where `msg_cn_rag_` are the c2v
        messages as ragged tensor of shape `[num_cns, None, batch_size]` and
        `it` is the current iteration counter.
        It must return an updated version of `msg_cn_rag_` of same shape.

    return_state: `bool`, (default `False`)
        If `True`, the internal VN messages ``msg_v2c`` from the last decoding
        iteration are returned as second output. They can be passed to the
        next call via the keyword argument ``msg_v2c``.
        This can be used for iterative demapping and decoding.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel logits/llr values.

    num_iter: `None` (default) | int
        Number of decoding iterations for this call (keyword-only). If
        `None`, the ``num_iter`` property of the decoder is used.

    msg_v2c: `None` (default) | [num_edges, batch_size], tf.float
        Tensor of VN messages representing the internal decoder state
        (keyword-only). If `None`, the messages are initialized with the
        channel LLRs. Required only if the decoder shall use its previous
        internal state, e.g. for iterative detection and decoding (IDD)
        schemes.

    Output
    ------
    : [...,n], tf.float
        Tensor of same shape as ``llr_ch`` containing
        bit-wise soft-estimates (or hard-decided bit-values) of all
        codeword bits.

    : [num_edges, batch_size], tf.float:
        Tensor of VN messages representing the internal decoder state.
        Returned only if ``return_state`` is set to `True`.

    Note
    ----
    As decoding input logits :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}`
    are assumed for compatibility with the learning framework, but internally
    log-likelihood ratios (LLRs) with definition
    :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.

    The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC
    codes and, thus, supports arbitrary parity-check matrices.

    The decoder is implemented by using "ragged Tensors" [TF_ragged]_ to
    account for arbitrary node degrees. To avoid a performance degradation
    caused by a severe indexing overhead, the batch-dimension is shifted to
    the last dimension during decoding.
    """

    def __init__(self,
                 pcm,
                 cn_update="boxplus-phi",
                 vn_update="sum",
                 cn_schedule="flooding",
                 hard_out=True,
                 num_iter=20,
                 llr_max=20.,
                 v2c_callbacks=None,
                 c2v_callbacks=None,
                 return_state=False,
                 precision=None,
                 **kwargs):

        # 'cn_type' was renamed to 'cn_update'. Reject it before the kwargs
        # are forwarded to the base class so that the user gets this message.
        if 'cn_type' in kwargs:
            raise TypeError("'cn_type' is deprecated; use 'cn_update' instead.")

        super().__init__(precision=precision, **kwargs)

        # check inputs for consistency
        if not isinstance(hard_out, bool):
            raise TypeError('hard_out must be bool.')
        if not isinstance(num_iter, int):
            raise TypeError('num_iter must be int.')
        if num_iter<0:
            raise ValueError('num_iter cannot be negative.')
        if not isinstance(return_state, bool):
            raise TypeError('return_state must be bool.')

        if isinstance(pcm, np.ndarray):
            if not np.array_equal(pcm, pcm.astype(bool)):
                raise ValueError('PC matrix must be binary.')
        elif sp.sparse.issparse(pcm) and pcm.format in ('csr', 'csc'):
            # covers csr/csc sparse matrices as well as sparse arrays
            if not np.array_equal(pcm.data, pcm.data.astype(bool)):
                raise ValueError('PC matrix must be binary.')
        else:
            raise TypeError("Unsupported dtype of pcm.")

        # init decoder parameters
        self._pcm = pcm
        self._hard_out = hard_out
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)
        self._return_state = return_state

        self._num_cns = pcm.shape[0] # total number of check nodes
        self._num_vns = pcm.shape[1] # total number of variable nodes

        # internal value for llr clipping
        if not isinstance(llr_max, (int, float)):
            raise TypeError("llr_max must be int or float.")
        # same constraint as enforced by the llr_max setter
        if llr_max<0:
            raise ValueError('llr_max cannot be negative.')
        self._llr_max = tf.cast(llr_max, self.rdtype)

        self._v2c_callbacks = self._as_callback_list(v2c_callbacks,
                                                     "v2c_callbacks")
        self._c2v_callbacks = self._as_callback_list(c2v_callbacks,
                                                     "c2v_callbacks")

        if isinstance(cn_schedule, str) and cn_schedule=="flooding":
            self._scheduling = "flooding"
            # a single "subiteration" in which all CNs are active
            self._cn_schedule = tf.stack([tf.range(self._num_cns)], axis=0)
        elif tf.is_tensor(cn_schedule) or isinstance(cn_schedule, np.ndarray):
            cn_schedule = tf.cast(cn_schedule, tf.int32)
            self._scheduling = "custom"
            # check custom schedule for consistency
            if len(cn_schedule.shape)!=2:
                raise ValueError("cn_schedule must be of rank 2.")
            if tf.reduce_max(cn_schedule)>=self._num_cns:
                msg = "cn_schedule can only contain values smaller number_cns."
                raise ValueError(msg)
            if tf.reduce_min(cn_schedule)<0:
                msg = "cn_schedule cannot contain negative values."
                raise ValueError(msg)
            self._cn_schedule = cn_schedule
        else:
            msg = "cn_schedule can be 'flooding' or an array of ints."
            raise ValueError(msg)

        ######################
        # Init graph structure
        ######################

        # make pcm sparse first if ndarray is provided
        if isinstance(pcm, np.ndarray):
            pcm = sp.sparse.csr_matrix(pcm)

        # Assign all edges to CN and VN nodes, respectively
        self._cn_idx, self._vn_idx, _ = sp.sparse.find(pcm)

        # sort indices explicitly, as scipy.sparse.find changed from column to
        # row sorting in scipy>=1.11
        idx = np.argsort(self._vn_idx)
        self._cn_idx = self._cn_idx[idx]
        self._vn_idx = self._vn_idx[idx]

        # number of edges equals number of non-zero elements in the
        # parity-check matrix
        self._num_edges = len(self._vn_idx)

        # pre-load the CN function; any callable (function, bound method,
        # partial, callable object) is accepted as custom update rule
        if cn_update=='boxplus':
            # check node update using the tanh function
            self._cn_update = cn_update_tanh
        elif cn_update=='boxplus-phi':
            # check node update using the "_phi" function
            self._cn_update = cn_update_phi
        elif cn_update in ('minsum', 'min'):
            # check node update using the min-sum approximation
            self._cn_update = cn_update_minsum
        elif cn_update=="offset-minsum":
            # check node update using the offset corrected min-sum
            self._cn_update = cn_update_offset_minsum
        elif cn_update=='identity':
            self._cn_update = cn_node_update_identity
        elif callable(cn_update):
            self._cn_update = cn_update
        else:
            raise TypeError("Provided cn_update not supported.")

        # pre-load the VN function
        if vn_update=='sum':
            self._vn_update = vn_update_sum
        elif vn_update=='identity':
            self._vn_update = vn_node_update_identity
        elif callable(vn_update):
            self._vn_update = vn_update
        else:
            raise TypeError("Provided vn_update not supported.")

        ######################
        # init graph structure
        ######################

        # Permutation index to rearrange edge messages into CN perspective
        v2c_perm = np.argsort(self._cn_idx)
        # and the inverse operation;
        v2c_perm_inv = np.argsort(v2c_perm)
        # only required for layered decoding
        self._v2c_perm_inv = tf.constant(v2c_perm_inv)

        # Initialize a ragged tensor that allows to gather
        # from the v2c messages (from VN perspective) and returns
        # a ragged tensor of incoming messages of each CN.
        # This needs to be ragged as the CN degree can be irregular.
        self._v2c_perm = tf.RaggedTensor.from_value_rowids(
                                values=v2c_perm,
                                value_rowids=self._cn_idx[v2c_perm])

        self._c2v_perm = tf.RaggedTensor.from_value_rowids(
                                values=v2c_perm_inv,
                                value_rowids=self._vn_idx)

    @staticmethod
    def _as_callback_list(callbacks, name):
        """Normalize a callback argument into a sequence of callables.

        Parameters
        ----------
        callbacks: `None` | callable | list or tuple of callables
            User-provided callbacks.
        name: str
            Name of the argument, used in the error message.

        Returns
        -------
        : list | tuple
            An empty list for `None`, a one-element list for a single
            callable, or the given list/tuple itself.

        Raises
        ------
        TypeError
            If ``callbacks`` is neither `None`, a callable, nor a list/tuple
            consisting only of callables.
        """
        msg = f"{name} must be a list of callables."
        if callbacks is None:
            return []
        if isinstance(callbacks, (list, tuple)):
            if not all(callable(cb) for cb in callbacks):
                raise TypeError(msg)
            return callbacks
        if callable(callbacks):
            # allow that user provides a single callable
            return [callbacks,]
        raise TypeError(msg)

    ###############################
    # Public methods and properties
    ###############################

    @property
    def pcm(self):
        """Parity-check matrix of LDPC code"""
        return self._pcm

    @property
    def num_cns(self):
        """Number of check nodes"""
        return self._num_cns

    @property
    def num_vns(self):
        """Number of variable nodes"""
        return self._num_vns

    @property
    def n(self):
        """Codeword length"""
        return self._num_vns

    @property
    def coderate(self):
        """Coderate assuming independent parity checks"""
        return (self._num_vns - self._num_cns) / self._num_vns

    @property
    def num_edges(self):
        """Number of edges in decoding graph"""
        return self._num_edges

    @property
    def num_iter(self):
        "Number of decoding iterations"
        return self._num_iter

    @num_iter.setter
    def num_iter(self, num_iter):
        "Number of decoding iterations"
        if not isinstance(num_iter, int):
            raise TypeError('num_iter must be int.')
        if num_iter<0:
            raise ValueError('num_iter cannot be negative.')
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

    @property
    def llr_max(self):
        """Max LLR value used for internal calculations and rate-matching"""
        return self._llr_max

    @llr_max.setter
    def llr_max(self, value):
        """Max LLR value used for internal calculations"""
        if value<0:
            raise ValueError('llr_max cannot be negative.')
        self._llr_max = tf.cast(value, dtype=self.rdtype)

    @property
    def return_state(self):
        """Return internal decoder state for IDD schemes"""
        return self._return_state

    #########################
    # Decoding functions
    #########################

    def _bp_iter(self, msg_v2c, msg_c2v, llr_ch, x_hat, it, num_iter):
        """Body of the main decoding loop (one BP iteration).

        Parameters
        ----------
        msg_v2c: [num_edges, batch_size], tf.float
            Tensor of VN messages representing the internal decoder state.

        msg_c2v: [num_edges, batch_size], tf.float
            Tensor of CN messages representing the internal decoder state.

        llr_ch: [num_vns, batch_size], tf.float
            Channel LLRs (batch dimension moved to the last axis).

        x_hat : [num_vns, batch_size], tf.float
            Current bit-wise soft-estimates of all `n` codeword bits.

        it: tf.int
            Current iteration number

        num_iter: int
            Total number of decoding iterations

        Returns
        -------
        msg_v2c: [num_edges, batch_size], tf.float
            Updated VN messages.

        msg_c2v: [num_edges, batch_size], tf.float
            Updated CN messages.

        llr_ch: [num_vns, batch_size], tf.float
            Unchanged channel LLRs.

        x_hat : [num_vns, batch_size], tf.float
            Updated bit-wise soft-estimates of all `n` codeword bits.

        it: tf.int
            Incremented iteration number

        num_iter: int
            Unchanged total number of decoding iterations
        """
        # Unroll loop to keep XLA / Keras compatibility
        # For flooding this will be unrolled to a single loop iteration
        for j in range(self._cn_schedule.shape[0]):

            # get active check nodes
            if self._scheduling=="flooding":
                # for flooding all CNs are active
                v2c_perm = self._v2c_perm
            else: # select active CNs for j-th subiteration
                cn_idx = tf.gather(self._cn_schedule, j, axis=0)
                v2c_perm = tf.gather(self._v2c_perm, cn_idx, axis=0)

            # Gather ragged tensor of incoming messages at CN.
            # The shape is [num_cns, None, batch_size,...].
            # The None dimension is the ragged dimension and depends on the
            # individual check node degree
            msg_cn_rag = tf.gather(msg_v2c, v2c_perm, axis=0)

            # Apply the CN update
            msg_cn_rag_ = self._cn_update(msg_cn_rag, self.llr_max)

            # Apply CN callbacks
            for cb in self._c2v_callbacks:
                msg_cn_rag_ = cb(msg_cn_rag_, it)

            # Apply partial message updates for layered decoding
            if self._scheduling!="flooding":
                # note: the scatter update operation is quite expensive
                up_idx = tf.gather(self._c2v_perm.flat_values,
                                   v2c_perm.flat_values)
                # only the messages of the active CNs are updated
                msg_c2v = tf.tensor_scatter_nd_update(
                                     msg_c2v,
                                     tf.expand_dims(up_idx, axis=1),
                                     msg_cn_rag_.flat_values)
            else:
                # for flooding all nodes are updated
                msg_c2v = msg_cn_rag_.flat_values

            # Gather ragged tensor of incoming messages at VN.
            # Note for performance reasons this includes the re-permute
            # of edges from CN to VN perspective.
            # The shape is [num_vns, None, batch_size,...].
            msg_vn_rag = tf.gather(msg_c2v, self._c2v_perm, axis=0)

            # Apply the VN update
            msg_vn_rag_, x_hat = self._vn_update(msg_vn_rag,
                                                 llr_ch,
                                                 self.llr_max)

            # apply v2c callbacks
            for cb in self._v2c_callbacks:
                msg_vn_rag_ = cb(msg_vn_rag_, it+1, x_hat)

            # we return flat values to avoid ragged tensors passing the tf.
            # while boundary (possible issues with XLA)
            msg_v2c = msg_vn_rag_.flat_values

        # increase iteration counter
        it += 1

        return msg_v2c, msg_c2v, llr_ch, x_hat, it, num_iter

    # pylint: disable=unused-argument,unused-variable
    def _stop_cond(self, msg_v2c, msg_c2v, llr_ch, x_hat, it, num_iter):
        """Stops decoding loop after num_iter iterations.

        Most inputs are ignored, just for compatibility with tf.while.
        """
        return it < num_iter

    #########################
    # Sionna Block functions
    #########################

    # pylint: disable=(unused-argument)
    def build(self, input_shape, **kwargs):
        """Check that the last input dimension equals the codeword length.

        Raises
        ------
        AssertionError
            If the last dimension of ``input_shape`` is not `n`.
        """
        # Explicit raise instead of ``assert`` so that the check is not
        # stripped when Python runs with optimizations (-O).
        if input_shape[-1] != self._num_vns:
            raise AssertionError('Last dimension must be of length n.')

    def call(self, llr_ch, /, *, num_iter=None, msg_v2c=None):
        """Iterative BP decoding function.

        See the class docstring for the description of inputs and outputs.
        """

        if num_iter is None:
            num_iter=self.num_iter

        # clip LLRs for numerical stability
        llr_ch = tf.clip_by_value(llr_ch,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # Remember the input shape to restore it at the end. tf.shape also
        # works if several leading dimensions are unknown.
        input_shape = tf.shape(llr_ch)

        # reshape to support multi-dimensional inputs
        llr_ch_reshaped = tf.reshape(llr_ch, [-1, self._num_vns])

        # batch dimension is last dimension due to ragged tensor representation
        llr_ch = tf.transpose(llr_ch_reshaped, (1, 0))

        # logits are converted into "true" LLRs as usually done in literature
        llr_ch *= -1.

        # If no initial decoder state is provided, the v2c messages are
        # initialized with the channel LLRs. A provided state is relevant for
        # IDD schemes.
        if msg_v2c is None:
            # init v2c messages with channel LLRs
            msg_v2c = tf.gather(llr_ch, self._vn_idx)
        else:
            # invert sign due to logit definition; tf.negative creates a new
            # tensor and never modifies a caller-provided (numpy) array
            msg_v2c = tf.negative(msg_v2c)

        # msg_v2c is of shape [num_edges, batch_size]
        # it contains all edge message from VN to CN
        # Hereby, self._vn_idx indicates the index of the associated VN
        # and self._cn_idx the index of the associated CN

        # messages from CN perspective; are initialized to zero
        msg_c2v = tf.zeros_like(msg_v2c)

        # apply VN callbacks before first iteration
        if self._v2c_callbacks:
            msg_vn_rag_ = tf.RaggedTensor.from_value_rowids(
                                values=msg_v2c,
                                value_rowids=self._vn_idx)
            # apply v2c callbacks
            for cb in self._v2c_callbacks:
                msg_vn_rag_ = cb(msg_vn_rag_, tf.constant(0, tf.int32), llr_ch)

            # Ensure shape as otherwise XLA cannot infer
            # the output signature of the loop
            msg_v2c = msg_vn_rag_.flat_values

        #####################
        # Main decoding loop
        #####################

        # msg_v2c : decoder state (from VN perspective)
        # msg_c2v : decoder state (from CN perspective)
        # llr_ch : channel llrs
        # llr_ch : x_hat; automatically returns llr_ch for 0 iterations
        # tf.constant(0, tf.int32) : iteration counter
        # num_iter : total number of iterations
        inputs = (msg_v2c, msg_c2v, llr_ch, llr_ch,
                  tf.constant(0, tf.int32), num_iter)

        # and run main decoding loop for num_iter iterations
        msg_v2c, _, _, x_hat, _, _ = tf.while_loop(
                                        self._stop_cond,self._bp_iter,
                                        inputs, maximum_iterations=num_iter)

        ######################
        # Post process outputs
        ######################

        # restore batch dimension to first dimension
        x_hat = tf.transpose(x_hat, (1,0))

        if self._hard_out: # hard decide decoder output if required
            x_hat = tf.greater_equal(tf.cast(0, self.rdtype), x_hat)
            x_hat = tf.cast(x_hat, self.rdtype)
        else:
            x_hat *= -1.  # convert LLRs back into logits

        # Reshape so that the output matches the original input dimensions
        x_reshaped = tf.reshape(x_hat, input_shape)

        if not self._return_state:
            return x_reshaped
        else:
            msg_v2c *= -1 # invert sign due to logit definition
            return x_reshaped, msg_v2c
#######################
# Node update functions
#######################
# pylint: disable=unused-argument,unused-variable
def vn_node_update_identity(msg_c2v_rag, llr_ch, llr_clipping=None, **kwargs):
    # pylint: disable=line-too-long
    r"""Identity-style variable node update for testing and debugging.

    The incoming c2v messages are passed through unchanged as v2c messages.
    In addition, all incoming messages of each node are summed up together
    with the channel LLR, and this marginal is returned as second output.

    Parameters
    ----------
    msg_c2v_rag: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of c2v messages. The second axis is ragged and
        reflects the individual node degrees.

    llr_ch: [num_nodes, batch_size], tf.float
        Tensor containing the channel LLRs.

    llr_clipping: `None` (default) | float
        Accepted for interface compatibility with the other VN update
        functions; it is not used by this function.

    Returns
    -------
    msg_v2c_rag : tf.ragged
        The unmodified input ``msg_c2v_rag``.

    x_tot: tf.float
        Marginalized LLRs per variable node of shape `[num_nodes, batch_size]`.
        Can be used as final estimate per VN.
    """
    # collapse the ragged (edge) axis: one summed value per node
    per_node_sum = tf.reduce_sum(msg_c2v_rag, axis=1)
    marginal = per_node_sum + llr_ch
    return msg_c2v_rag, marginal
def vn_update_sum(msg_c2v_rag, llr_ch, llr_clipping=None):
    # pylint: disable=line-too-long
    r"""Variable node update function implementing the `sum` update.

    This function implements the (extrinsic) variable node update
    function. It takes the sum over all incoming messages ``msg`` excluding
    the intrinsic (= outgoing) message itself.

    Additionally, the channel LLR ``llr_ch`` is considered in each variable
    node.

    Parameters
    ----------
    msg_c2v_rag: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of shape `[num_nodes, None, batch_size]` where the second
        axis is ragged (represents individual node degrees).
        Represents c2v messages.

    llr_ch: [num_nodes, batch_size], tf.float
        Tensor containing the channel LLRs.

    llr_clipping: `None` (default) | float
        Clipping value used for internal processing. If `None`, no internal
        clipping is applied.

    Returns
    -------
    msg_v2c_rag : tf.ragged
        Updated v2c messages. Ragged tensor of same shape as ``msg_c2v_rag``.

    x_tot: tf.float
        Marginalized LLRs per variable node of shape `[num_nodes, batch_size]`.
    """
    # aggregate all incoming messages per node
    x = tf.reduce_sum(msg_c2v_rag, axis=1)
    x_tot = tf.add(x, llr_ch)

    # Subtract the extrinsic message from the node value, i.e., compute
    # x_tot[node] - msg for every edge of the node.
    # TF2.9 does not support XLA for the addition of ragged tensors; as a
    # workaround, x_tot is gathered per edge via the row ids of the ragged
    # tensor and added to the (negated) flat values.
    x_e = tf.ragged.map_flat_values(
            lambda msg, tot, row_ind: msg + tf.gather(tot, row_ind),
            -1.*msg_c2v_rag, x_tot, msg_c2v_rag.value_rowids())

    if llr_clipping is not None:
        x_e = tf.clip_by_value(x_e,
                    clip_value_min=-llr_clipping, clip_value_max=llr_clipping)
        x_tot = tf.clip_by_value(x_tot,
                    clip_value_min=-llr_clipping, clip_value_max=llr_clipping)
    return x_e, x_tot
# pylint: disable=unused-argument,unused-variable
def cn_node_update_identity(msg_v2c_rag, *args, **kwargs):
    # pylint: disable=line-too-long
    r"""Dummy check node update that returns the first argument unchanged.

    Used for testing and debugging of message passing decoding. All further
    positional arguments (e.g., the clipping value passed positionally by the
    decoder) and keyword arguments (e.g., ``llr_clipping=...``) are accepted
    and ignored.

    Parameters
    ----------
    msg_v2c_rag: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of shape `[num_nodes, None, batch_size]` where the second
        axis is ragged (represents individual node degrees).
        Represents v2c messages.

    Returns
    -------
    msg_c2v_rag : [num_nodes, None, batch_size], tf.ragged
        c2v messages, i.e., the unmodified input ``msg_v2c_rag``.
    """
    return msg_v2c_rag
def cn_update_offset_minsum(msg_v2c_rag, llr_clipping=None, offset=0.5):
    # pylint: disable=line-too-long
    r"""Check node update function implementing the offset corrected minsum.
    The function implements
    .. math::
            \qquad y_{j \to i} = \alpha_{j \to i} \cdot {max} \left( {min}_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}| \right)-\beta , 0\right)
    where :math:`\beta` is given by ``offset`` (default 0.5) and
    :math:`y_{j \to i}` denotes the message from check node (CN) *j* to
    variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*,
    respectively. Further, :math:`\mathcal{N}(j)` denotes all indices of
    connected VNs to CN *j* and
    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})
    is the sign of the outgoing message (a sign of zero is treated as +1).
    With ``offset=0`` this reduces to the plain minsum update. For further
    details we refer to [Chen]_.
    Parameters
    ----------
    msg_v2c_rag: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of shape `[num_nodes, None, batch_size]` where the second
        axis is ragged (represents individual node degrees).
        Represents v2c messages.
    llr_clipping: `None` (default) | float
        If not `None`, the outgoing c2v messages are clipped to
        ``[-llr_clipping, llr_clipping]``. If `None`, no output clipping is
        applied.
    offset: float (default `0.5`)
        Offset :math:`\beta` subtracted from the magnitude of each outgoing
        message; the resulting magnitude is floored at zero.
    Returns
    -------
    msg_c2v : [num_nodes, None, batch_size], tf.ragged
        Ragged tensor with the same ragged structure as ``msg_v2c_rag``
        containing the updated c2v messages.
    Note
    ----
    Input messages are clipped to ``[-1e5, 1e5]`` internally. The detection
    of duplicated minima assumes that the message magnitudes are well below
    this internal value (e.g., ensured by prior LLR clipping).
    """
    def _sign_val_minsum(msg):
        """Helper returning the element-wise sign with 0 mapped to +1.
        Must be called with `map_flat_values`."""
        sign_val = tf.sign(msg)
        sign_val = tf.where(tf.equal(sign_val, 0),
                            tf.ones_like(sign_val),
                            sign_val)
        return sign_val
    # a constant used to overwrite the first min (marker value)
    large_val = 100000.
    # keep all input magnitudes <= large_val so that the marker is unique
    msg_v2c_rag = tf.clip_by_value(msg_v2c_rag,
                               clip_value_min=-large_val,
                               clip_value_max=large_val)
    # llr_clipping is only applied to the output (we assume the input was
    # already clipped by the previous function)
    # calculate sign of outgoing msg and the node
    sign_val = tf.ragged.map_flat_values(_sign_val_minsum, msg_v2c_rag)
    sign_node = tf.reduce_prod(sign_val, axis=1)
    # TF2.9 does not support XLA for the multiplication of ragged tensors
    # the following code provides a workaround that supports XLA
    # (per-edge sign times node sign = extrinsic sign, as sign^2 = 1)
    # sign_val = self._stop_ragged_gradient(sign_val) \
    #             * tf.expand_dims(sign_node, axis=1)
    sign_val = tf.ragged.map_flat_values(
                                    lambda x, y, row_ind:
                                    tf.multiply(x, tf.gather(y, row_ind)),
                                    sign_val,
                                    sign_node,
                                    sign_val.value_rowids())
    # remove sign from messages
    msg = tf.ragged.map_flat_values(tf.abs, msg_v2c_rag)
    # Calculate the extrinsic minimum per CN, i.e., for each message of
    # index i, find the smallest and the second smallest value.
    # However, in some cases the second smallest value may equal the
    # smallest value (multiplicity of mins).
    # Please note that this needs to be applied to raggedTensors, e.g.,
    # tf.top_k() is currently not supported and all ops must support graph
    # and XLA mode.
    # find min_value per node
    min_val = tf.reduce_min(msg, axis=1, keepdims=True)
    # TF2.9 does not support XLA for the subtraction of ragged tensors
    # the following code provides a workaround that supports XLA
    # and subtract min; the new array contains zero at the min positions
    # benefits from broadcasting; all other values are positive
    msg_min1 = tf.ragged.map_flat_values(lambda x, y, row_ind:
                                            x - tf.gather(y, row_ind),
                                            msg,
                                            tf.squeeze(min_val, axis=1),
                                            msg.value_rowids())
    # replace 0 (=min positions) with large value to ignore it for further
    # min calculations
    msg = tf.ragged.map_flat_values(
                        lambda x: tf.where(tf.equal(x, 0), large_val, x),
                        msg_min1)
    # find the second smallest element (we add min_val as this has been
    # subtracted before)
    min_val_2 = tf.reduce_min(msg, axis=1, keepdims=True) + min_val
    # Detect duplicated minima (i.e., min_val occurs at two incoming
    # messages). As the LLRs per node are <LLR_MAX and we have
    # replaced at least 1 position (position with message "min_val") by
    # large_val, it holds for the sum < large_val + node_degree*LLR_MAX.
    # If the sum > 2*large_val, the multiplicity of the min is at least 2.
    # Hence node_sum < 0 for a unique min and node_sum > 0 for a duplicate.
    node_sum = tf.reduce_sum(msg, axis=1, keepdims=True) - (2*large_val-1.)
    # NOTE: despite its name, this indicator is 1 if the min is UNIQUE
    # (node_sum < 0) and 0 if the min occurs at two or more edges.
    double_min = 0.5*(1-tf.sign(node_sum))
    # extrinsic value at the min position(s): for a unique min the second
    # smallest value is taken; for a duplicated min both edges get min_val
    min_val_e = (1-double_min) * min_val + (double_min) * min_val_2
    # replace all values with min_val except the position where the min
    # occurred (=extrinsic min).
    # no XLA support for TF 2.15
    # msg_e = tf.where(msg==large_val, min_val_e, min_val)
    # broadcast per-node values to per-edge values via row ids
    min_1 = tf.squeeze(tf.gather(min_val, msg.value_rowids()), axis=1)
    min_e = tf.squeeze(tf.gather(min_val_e, msg.value_rowids()), axis=1)
    msg_e = tf.ragged.map_flat_values(
                lambda x: tf.where(x==large_val, min_e, min_1), msg)
    # it seems like tf.where does not set the shape of tf.ragged properly
    # we need to ensure the shape manually
    msg_e = tf.ragged.map_flat_values(
                lambda x: tf.ensure_shape(x, msg.flat_values.shape), msg_e)
    # apply offset; magnitudes are floored at zero
    msg_e = tf.ragged.map_flat_values(lambda x,y: tf.maximum(x-y, 0),
                                      msg_e, offset)
    # TF2.9 does not support XLA for the multiplication of ragged tensors
    # the following code provides a workaround that supports XLA
    # and apply sign
    #msg = sign_val * msg_e
    msg = tf.ragged.map_flat_values(tf.multiply, sign_val, msg_e)
    # clip output values if required
    if llr_clipping is not None:
        msg = tf.clip_by_value(msg,
                    clip_value_min=-llr_clipping, clip_value_max=llr_clipping)
    return msg
[docs]
def cn_update_minsum(msg_v2c_rag, llr_clipping=None):
    # pylint: disable=line-too-long
    r"""Check node update function implementing the (plain) `minsum` update.
    For each edge, the outgoing message is
    .. math::
            \qquad y_{j \to i} = \alpha_{j \to i} \cdot {min}_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}|\right)
    with the extrinsic sign
    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})
    Here, :math:`y_{j \to i}` is the message from check node (CN) *j* to
    variable node (VN) *i*, :math:`x_{i \to j}` the message from VN *i* to
    CN *j*, and :math:`\mathcal{N}(j)` the set of VNs connected to CN *j*.
    See [Ryan]_ and [Chen]_ for details.
    Internally, this is the offset-minsum update with a zero offset.
    Parameters
    ----------
    msg_v2c_rag: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of v2c messages where the second axis is ragged
        (i.e., depends on the individual node degrees).
    llr_clipping: `None` (default) | float
        If not `None`, the outgoing messages are clipped to
        ``[-llr_clipping, llr_clipping]``.
    Returns
    -------
    msg_c2v : [num_nodes, None, batch_size], tf.ragged
        Ragged tensor with the same ragged structure as ``msg_v2c_rag``
        containing the updated c2v messages.
    """
    # plain min-sum == offset min-sum without offset
    return cn_update_offset_minsum(msg_v2c_rag,
                                   llr_clipping=llr_clipping,
                                   offset=0)
[docs]
def cn_update_tanh(msg, llr_clipping=None):
    # pylint: disable=line-too-long
    r"""Check node update function implementing the exact `boxplus` operation.
    For every edge, the boxplus over all incoming v2c messages of the check
    node is computed while excluding the edge's own (intrinsic) message.
    The exact boxplus is evaluated with the tanh rule
    .. math::
            y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right)
    where :math:`y_{j \to i}` is the message from check node (CN) *j* to
    variable node (VN) *i*, :math:`x_{i \to j}` the message from VN *i* to
    CN *j*, and :math:`\mathcal{N}(j)` the set of VNs connected to CN *j*.
    See [Ryan]_ for details. The extrinsic product is obtained by dividing
    the node-wide product by the edge's own factor, which is why exact
    zeros are replaced by a tiny value beforehand. Clipping is applied for
    numerical stability.
    Parameters
    ----------
    msg: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of v2c messages where the second axis is ragged
        (i.e., depends on the individual node degrees).
    llr_clipping: `None` (default) | float
        If not `None`, the outgoing messages are clipped to
        ``[-llr_clipping, llr_clipping]``.
    Returns
    -------
    msg_c2v : [num_nodes, None, batch_size], tf.ragged
        Ragged tensor with the same ragged structure as ``msg`` containing
        the updated c2v messages.
    """
    # largest |argument| passed to atanh (value chosen for tf.float32)
    atanh_clip_value = 1 - 1e-7

    # tanh(x/2) per edge; tanh/atanh are not overloaded for ragged tensors,
    # hence they are applied to the flat values
    t_edge = tf.ragged.map_flat_values(tf.tanh, msg / 2)
    # exact zeros would make the inversion below undefined
    t_edge = tf.ragged.map_flat_values(
        lambda v: tf.where(tf.equal(v, 0), tf.ones_like(v) * 1e-12, v),
        t_edge)
    # product of all tanh values of each node
    node_prod = tf.reduce_prod(t_edge, axis=1)

    # extrinsic product: node product times the inverse of the own factor.
    # Broadcasting is done via gather on the row ids, as TF2.9 does not
    # support XLA for the multiplication of ragged tensors.
    # Note that this is (potentially) numerically unstable.
    t_ext = tf.ragged.map_flat_values(
        lambda inv_own, prod, rows: inv_own * tf.gather(prod, rows),
        t_edge**-1, node_prod, t_edge.value_rowids())

    # numerical zeros (introduced by the tiny-value replacement above) are
    # flushed to exact zero to keep the result stable
    t_ext = tf.ragged.map_flat_values(
        lambda v: tf.where(tf.less(tf.abs(v), 1e-7), tf.zeros_like(v), v),
        t_ext)
    t_ext = tf.clip_by_value(t_ext,
                             clip_value_min=-atanh_clip_value,
                             clip_value_max=atanh_clip_value)
    msg_c2v = 2 * tf.ragged.map_flat_values(tf.atanh, t_ext)

    if llr_clipping is None:
        return msg_c2v
    return tf.clip_by_value(msg_c2v,
                            clip_value_min=-llr_clipping,
                            clip_value_max=llr_clipping)
[docs]
def cn_update_phi(msg, llr_clipping=None):
    # pylint: disable=line-too-long
    r"""Check node update function implementing the `boxplus` operation.
    This function implements the (extrinsic) check node update function
    based on the numerically more stable `"_phi"` function (cf. [Ryan]_).
    It calculates the boxplus function over all incoming messages ``msg``
    excluding the intrinsic (=outgoing) message itself.
    The exact boxplus function is implemented by using the `"_phi"` function
    as in [Ryan]_.
    The function implements
    .. math::
            y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right)
    where :math:`\phi(x)=-\operatorname{log} \left( \operatorname{tanh} \left( \frac{x}{2} \right) \right)`
    and :math:`y_{j \to i}` denotes the message from check node
    (CN) *j* to variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to
    CN *j*, respectively. Further, :math:`\mathcal{N}(j)` denotes all indices
    of connected VNs to CN *j* and
    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})
    is the sign of the outgoing message (a sign of zero is treated as +1).
    For further details we refer to [Ryan]_.
    Note that for numerical stability clipping can be applied. The sign
    path is excluded from gradient computation (``tf.stop_gradient``).
    Parameters
    ----------
    msg: [num_nodes, None, batch_size], tf.ragged
        Ragged tensor of shape `[num_nodes, None, batch_size]` where the second
        axis is ragged (represents individual node degrees).
        Represents v2c messages.
    llr_clipping: `None` (default) | float
        If not `None`, the outgoing messages are clipped to
        ``[-llr_clipping, llr_clipping]``. If `None`, no output clipping is
        applied.
    Returns
    -------
    msg_c2v : [num_nodes, None, batch_size], tf.ragged
        Ragged tensor with the same ragged structure as ``msg`` containing
        the updated c2v messages.
    Raises
    ------
    TypeError
        If the message dtype is neither `tf.float32` nor `tf.float64`.
    """
    def _phi(x):
        # pylint: disable=line-too-long
        r"""Utility function for the boxplus-phi check node update.
        This function implements the (element-wise) `"phi"` function as defined
        in [Ryan]_  :math:`\phi(x)=-\operatorname{log} \left( \operatorname{tanh} \left( \frac{x}{2} \right) \right)`,
        evaluated as :math:`\log(e^x+1) - \log(e^x-1)`.
        The input is clipped to a dtype-specific range; the upper clip value
        numerically equals phi of the lower one (phi is its own inverse),
        e.g., :math:`\phi(10^{-12}) \approx 28.324` for `tf.float64`.
        Parameters
        ----------
        x : tf.float
            Input tensor of arbitrary shape (non-negative values expected).
        Returns
        -------
        : tf.float
            Tensor of same shape and dtype as ``x``.
        """
        if x.dtype==tf.float32:
            # the clipping values are optimized for tf.float32
            x = tf.clip_by_value(x,
                    clip_value_min=8.5e-8, clip_value_max=16.635532)
        elif x.dtype==tf.float64:
            x = tf.clip_by_value(x,
                    clip_value_min=1e-12, clip_value_max=28.324079)
        else:
            raise TypeError("Unsupported dtype for phi function.")
        return tf.math.log(tf.math.exp(x)+1) - tf.math.log(tf.math.exp(x)-1)
    ##################
    # Sign of messages
    ##################
    sign_val = tf.sign(msg)
    # map sign 0 to +1
    # TF2.14 does not support XLA for tf.where
    # the following code provides a workaround that supports XLA
    sign_val = tf.ragged.map_flat_values(lambda x : tf.where(tf.equal(x, 0),
                                         tf.ones_like(x), x), sign_val)
    # calculate sign of entire node
    sign_node = tf.reduce_prod(sign_val, axis=1)
    # extrinsic sign = own sign * node sign (as sign^2 = 1)
    # TF2.9 does not support XLA for the multiplication of ragged tensors
    # the following code provides a workaround that supports XLA
    #sign_val = sign_val * tf.expand_dims(sign_node, axis=1)
    sign_val = tf.ragged.map_flat_values(
                lambda x,y,row_ind : x * tf.gather(y, row_ind),
                sign_val, sign_node, sign_val.value_rowids())
    ###################
    # Value of messages
    ###################
    msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign
    # apply _phi element-wise
    msg = tf.ragged.map_flat_values(_phi, msg)
    # sum over entire node
    msg_sum = tf.reduce_sum(msg, axis=1)
    # extrinsic sum = node sum - own contribution
    # TF2.9 does not support XLA for the addition of ragged tensors
    # the following code provides a workaround that supports XLA
    #msg = tf.add( -msg, tf.expand_dims(msg_sum, axis=1)) # remove own edge
    msg = tf.ragged.map_flat_values(
                            lambda x, y, row_ind : x + tf.gather(y, row_ind),
                            -1.*msg, msg_sum, msg.value_rowids())
    # treat the sign as a constant w.r.t. gradients; then apply _phi
    # element-wise via map_flat_values (_phi does not support ragged Tensors)
    sign_val = sign_val.with_flat_values(tf.stop_gradient(sign_val.flat_values))
    msg_e = sign_val * tf.ragged.map_flat_values(_phi, msg)
    # clip output values if required
    if llr_clipping is not None:
        msg_e = tf.clip_by_value(msg_e,
                    clip_value_min=-llr_clipping, clip_value_max=llr_clipping)
    return msg_e
[docs]
class LDPC5GDecoder(LDPCBPDecoder):
    # pylint: disable=line-too-long
    r"""Iterative belief propagation decoder for 5G NR LDPC codes.
    Inherits from :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder` and
    provides a wrapper for 5G compatibility, i.e., automatically handles
    rate-matching according to [3GPPTS38212_LDPC]_.
    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied and, thus, the encoder object is
    required as input.
    If required the decoder can be made trainable and is differentiable
    (the training of some check node types may not be supported) following
    the concept of "weighted BP" [Nachmani]_.
    Parameters
    ----------
    encoder: LDPC5GEncoder
        An instance of :class:`~sionna.phy.fec.ldpc.encoding.LDPC5GEncoder`
        containing the correct code parameters.
    cn_update: `str`, "boxplus-phi" (default) | "boxplus" | "minsum" | "offset-minsum" | "identity" | callable
        Check node update rule to be used as described in
        :class:`~sionna.phy.fec.ldpc.decoding.LDPCBPDecoder`.
        If a callable is provided, it will be used instead as CN update.
        The input of the function is a ragged tensor of v2c messages of shape
        `[num_cns, None, batch_size]` where the second dimension is ragged
        (i.e., depends on the individual CN degree).
    vn_update: `str`, "sum" (default) | "identity" | callable
        Variable node update rule to be used.
        If a callable is provided, it will be used instead as VN update.
        The input of the function is a ragged tensor of c2v messages of shape
        `[num_vns, None, batch_size]` where the second dimension is ragged
        (i.e., depends on the individual VN degree).
    cn_schedule: "flooding" | "layered" | [num_update_steps, num_active_nodes], tf.int
        Defines the CN update scheduling per BP iteration. Can be either
        "flooding" to update all nodes in parallel (recommended) or "layered"
        to sequentially update all CNs in the same lifting group together or
        a 2D tensor where each row defines the `num_active_nodes` node indices
        to be updated per subiteration. In this case each BP iteration runs
        `num_update_steps` subiterations, thus the decoder's level of
        parallelization is lower and usually the decoding throughput decreases.
    hard_out: `bool`, (default `True`)
        If `True`, the decoder provides hard-decided codeword bits instead of
        soft-values.
    return_infobits: `bool`, (default `True`)
        If `True`, only the `k` info bits (soft or hard-decided) are returned.
        Otherwise all `n` positions are returned.
    prune_pcm: `bool`, (default `True`)
        If `True`, all punctured degree-1 VNs and connected check nodes are
        removed from the decoding graph (see [Cammerer]_ for details). Besides
        numerical differences, this should yield the same decoding result but
        improves the decoding throughput and reduces the memory footprint.
    num_iter: `int` (default: 20)
        Defining the number of decoder iterations (due to batching, no early
        stopping used at the moment!).
    llr_max: `float` (default: 20) | `None`
        Internal clipping value for all internal messages. If `None`, no
        clipping is applied.
    v2c_callbacks: `None` (default) | list of callables
        Each callable will be executed after each VN update with the following
        arguments `msg_vn_rag_`, `it`, `x_hat`, where `msg_vn_rag_` are the
        v2c messages as ragged tensor of shape `[num_vns, None, batch_size]`,
        `x_hat` is the current estimate of each VN of shape
        `[num_vns, batch_size]` and `it` is the current iteration counter.
        It must return an updated version of `msg_vn_rag_` of same shape.
    c2v_callbacks: `None` (default) | list of callables
        Each callable will be executed after each CN update with the following
        arguments `msg_cn_rag_` and `it` where `msg_cn_rag_` are the c2v
        messages as ragged tensor of shape `[num_cns, None, batch_size]` and
        `it` is the current iteration counter.
        It must return an updated version of `msg_cn_rag_` of same shape.
    return_state: `bool`, (default `False`)
        If `True`, the internal VN messages ``msg_vn`` from the last decoding
        iteration are returned, and ``msg_vn`` or `None` needs to be given as a
        second input when calling the decoder.
        This can be used for iterative demapping and decoding.
    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.
    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel logits/llr values.
    num_iter: `None` | int
        Optional number of decoding iterations for this call; forwarded to
        the underlying BP decoder.
    msg_v2c: `None` | [num_edges, batch_size], tf.float
        Tensor of VN messages representing the internal decoder state.
        Required only if the decoder shall use its previous internal state, e.g.
        for iterative detection and decoding (IDD) schemes.
    Output
    ------
    : [...,n] or [...,k], tf.float
        Tensor of same shape as ``llr_ch`` containing
        bit-wise soft-estimates (or hard-decided bit-values) of all
        `n` codeword bits or only the `k` information bits if
        ``return_infobits`` is True.
    : [num_edges, batch_size], tf.float:
        Tensor of VN messages representing the internal decoder state.
        Returned only if ``return_state`` is set to `True`.
        Remark: always returns the entire decoder state, even if
        ``return_infobits`` is True.
    Note
    ----
    As decoding input logits :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}`
    are assumed for compatibility with the learning framework, but internally
    LLRs with definition :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are
    used.
    The decoder is not (particularly) optimized for Quasi-cyclic (QC) LDPC
    codes and, thus, supports arbitrary parity-check matrices.
    The decoder is implemented by using "ragged Tensors" [TF_ragged]_ to
    account for arbitrary node degrees. To avoid a performance degradation
    caused by a severe indexing overhead, the batch-dimension is shifted to
    the last dimension during decoding.
    """
    def __init__(self,
                 encoder,
                 cn_update="boxplus-phi",
                 vn_update="sum",
                 cn_schedule="flooding",
                 hard_out=True,
                 return_infobits=True,
                 num_iter=20,
                 llr_max=20.,
                 v2c_callbacks=None,
                 c2v_callbacks=None,
                 prune_pcm=True,
                 return_state=False,
                 precision=None,
                 **kwargs):
        # needs the 5G Encoder to access all 5G parameters
        if not isinstance(encoder, LDPC5GEncoder):
            raise TypeError("encoder must be of class LDPC5GEncoder.")
        self._encoder = encoder
        pcm = encoder.pcm
        if not isinstance(return_infobits, bool):
            raise TypeError('return_infobits must be bool.')
        self._return_infobits = return_infobits
        if not isinstance(return_state, bool):
            raise TypeError('return_state must be bool.')
        self._return_state = return_state
        # 'cn_type' was replaced by 'cn_update'; fail loudly instead of
        # forwarding the old keyword to the base class
        if 'cn_type' in kwargs:
            raise TypeError("'cn_type' is deprecated; use 'cn_update' instead.")
        # Only the string "layered" selects the 5G lifting-group schedule.
        # An explicit schedule may be an array/tensor for which `== "layered"`
        # would be an (ambiguous) element-wise comparison.
        layered = isinstance(cn_schedule, str) and cn_schedule == "layered"
        # prune punctured degree-1 VNs and connected CNs. A punctured
        # VN-1 node will always "send" llr=0 to the connected CN. Thus, this
        # CN will only send 0 messages to all other VNs, i.e., does not
        # contribute to the decoding process.
        if not isinstance(prune_pcm, bool):
            raise TypeError('prune_pcm must be bool.')
        self._prune_pcm = prune_pcm
        if prune_pcm:
            # VN degrees as flat 1D array (works for dense and sparse pcm;
            # scipy.sparse returns a (1, n) matrix for sum(axis=0))
            dv = np.asarray(pcm.sum(axis=0)).ravel()
            # find index of the first position of the trailing run of
            # degree-1 VNs
            last_pos = encoder.n_ldpc
            for idx in range(encoder.n_ldpc-1, 0, -1):
                if dv[idx] == 1:
                    last_pos = idx
                else:
                    break
            nb_punc_bits = self._num_punctured_bits()
            # if layered decoding is used, quantize the number of punctured
            # bits to a multiple of z; otherwise scheduling groups of Z CNs
            # becomes impossible
            if layered:
                z = encoder.z
                nb_punc_bits = (nb_punc_bits // z) * z
            # effective codeword length after pruning of vn-1 nodes
            self._n_pruned = int(max(last_pos, encoder.n_ldpc - nb_punc_bits))
            self._nb_pruned_nodes = encoder.n_ldpc - self._n_pruned
            # check for consistency before using the value for slicing
            if self._nb_pruned_nodes < 0:
                msg = "Internal error: number of pruned nodes must be non-negative."
                raise ArithmeticError(msg)
            # remove last CNs and VNs from pcm; guard the 0 case as
            # pcm[:-0, :-0] would yield an empty matrix
            if self._nb_pruned_nodes > 0:
                pcm = pcm[:-self._nb_pruned_nodes, :-self._nb_pruned_nodes]
        else:
            # no pruning; same length as before
            self._nb_pruned_nodes = 0
            self._n_pruned = encoder.n_ldpc
        if layered:
            # one sub-iteration per lifting group of z consecutive CNs
            z = encoder.z
            num_blocks = int(pcm.shape[0]/z)
            cn_schedule = tf.stack([np.arange(z) + i*z
                                    for i in range(num_blocks)], axis=0)
        super().__init__(pcm,
                         cn_update=cn_update,
                         vn_update=vn_update,
                         cn_schedule=cn_schedule,
                         hard_out=hard_out,
                         num_iter=num_iter,
                         llr_max=llr_max,
                         v2c_callbacks=v2c_callbacks,
                         c2v_callbacks=c2v_callbacks,
                         return_state=return_state,
                         precision=precision,
                         **kwargs)
    ###############################
    # Public methods and properties
    ###############################
    @property
    def encoder(self):
        """LDPC Encoder used for rate-matching/recovery"""
        return self._encoder
    ##################
    # Internal helpers
    ##################
    def _num_punctured_bits(self):
        """Number of punctured positions at the end of the mother codeword.
        The mother codeword without filler bits has length
        ``n_ldpc - k_filler``; its first ``2*z`` positions are punctured and
        the next ``n`` positions are transmitted. The remaining positions are
        the punctured bits counted here.
        """
        k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits
        return ((self.encoder.n_ldpc - k_filler)
                - self.encoder.n - 2*self.encoder.z)
    ########################
    # Sionna block functions
    ########################
    def build(self, input_shape, **kwargs):
        """Check that the last input dimension equals the codeword length n.
        Raises
        ------
        ValueError
            If ``input_shape[-1]`` differs from ``encoder.n``.
        """
        # check input dimensions for consistency
        if input_shape[-1]!=self.encoder.n:
            raise ValueError('Last dimension must be of length n.')
        self._old_shape_5g = input_shape
    def call(self, llr_ch, /, *, num_iter=None, msg_v2c=None):
        """Iterative BP decoding function and rate matching.
        Reverts the 5G rate-matching (output interleaver, puncturing,
        shortening/filler bits), runs the BP decoder of the base class and
        maps the result back to the transmitted (or info) bit positions.
        """
        llr_ch_shape = llr_ch.get_shape().as_list()
        new_shape = [-1, self.encoder.n]
        llr_ch_reshaped = tf.reshape(llr_ch, new_shape)
        batch_size = tf.shape(llr_ch_reshaped)[0]
        # invert if rate-matching output interleaver was applied as defined in
        # Sec. 5.4.2.2 in 38.212
        if self._encoder.num_bits_per_symbol is not None:
            llr_ch_reshaped = tf.gather(llr_ch_reshaped,
                                        self._encoder.out_int_inv,
                                        axis=-1)
        # undo puncturing of the first 2*Z bit positions
        llr_5g = tf.concat(
                    [tf.zeros([batch_size, 2*self.encoder.z], self.rdtype),
                     llr_ch_reshaped], axis=1)
        # undo puncturing of the last positions
        # total length must be n_ldpc, while llr_ch has length n
        # first 2*z positions are already added
        # -> add n_ldpc - k_filler - n - 2Z punctured positions (minus the
        # positions that were pruned from the graph)
        k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits
        nb_punc_bits = self._num_punctured_bits()
        llr_5g = tf.concat(
                    [llr_5g,
                     tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes],
                              self.rdtype)], axis=1)
        # undo shortening (= add 0 positions after k bits, i.e. LLR=LLR_max)
        # the first k positions are the systematic bits
        x1 = tf.slice(llr_5g, [0, 0], [batch_size, self.encoder.k])
        # parity part
        nb_par_bits = (self.encoder.n_ldpc - k_filler
                       - self.encoder.k - self._nb_pruned_nodes)
        x2 = tf.slice(llr_5g,
                      [0, self.encoder.k],
                      [batch_size, nb_par_bits])
        # filler bits are known zeros -> maximum confidence
        # negative sign due to logit definition
        # NOTE(review): assumes self._llr_max is not None; with llr_max=None
        # this cast would fail -- confirm how the base class stores it.
        llr_filler = -tf.cast(self._llr_max, self.rdtype) \
                     * tf.ones([batch_size, k_filler], self.rdtype)
        llr_5g = tf.concat([x1, llr_filler, x2], axis=1)
        # and run the core decoder
        output = super().call(llr_5g, num_iter=num_iter, msg_v2c=msg_v2c)
        if self._return_state:
            x_hat, msg_v2c = output
        else:
            x_hat = output
        if self._return_infobits:
            # 5G NR code is systematic: the first k positions are info bits
            u_hat = tf.slice(x_hat, [0, 0], [batch_size, self.encoder.k])
            # Reshape u_hat so that it matches the original input dimensions;
            # overwrite first dimension as this could be None
            output_shape = llr_ch_shape[0:-1] + [self.encoder.k]
            output_shape[0] = -1
            result = tf.reshape(u_hat, output_shape)
        else:
            # return all codeword bits
            # The transmitted CW bits are not the same as used during decoding
            # cf. last parts of 5G encoding function
            x = tf.reshape(x_hat, [batch_size, self._n_pruned])
            # remove filler bits at pos (k, k_ldpc)
            x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k])
            x_no_filler2 = tf.slice(x,
                                    [0, self.encoder.k_ldpc],
                                    [batch_size,
                                     self._n_pruned-self.encoder.k_ldpc])
            x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1)
            # shorten the first 2*Z positions and end after n bits
            x_short = tf.slice(x_no_filler,
                               [0, 2*self.encoder.z],
                               [batch_size, self.encoder.n])
            # if used, apply rate-matching output interleaver again as
            # Sec. 5.4.2.2 in 38.212
            if self._encoder.num_bits_per_symbol is not None:
                x_short = tf.gather(x_short, self._encoder.out_int, axis=-1)
            # Reshape x_short so that it matches the original input dimensions;
            # overwrite first dimension as this could be None
            llr_ch_shape[0] = -1
            result = tf.reshape(x_short, llr_ch_shape)
        if self._return_state:
            return result, msg_v2c
        return result