
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for channel decoding and utility functions."""

import tensorflow as tf
import numpy as np
import scipy as sp # for sparse H matrix computations
from tensorflow.keras.layers import Layer
from sionna.fec.ldpc.encoding import LDPC5GEncoder
from sionna.fec.utils import llr2mi
import matplotlib.pyplot as plt

class LDPCBPDecoder(Layer):
    # pylint: disable=line-too-long
    r"""LDPCBPDecoder(pcm, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs)

    Iterative belief propagation decoder for low-density parity-check (LDPC)
    codes and other `codes on graphs`.

    This class defines a generic belief propagation decoder for decoding
    with arbitrary parity-check matrices. It can be used to iteratively
    estimate/recover the transmitted codeword (or information bits) based on
    the LLR-values of the received noisy codeword observation.

    The decoder implements the flooding SPA algorithm [Ryan]_, i.e., all nodes
    are updated in a parallel fashion. Different check node update functions
    are available:

    (1) `boxplus`

        .. math::
            y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right)

    (2) `boxplus-phi`

        .. math::
            y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right)

        with :math:`\phi(x)=-\operatorname{log}\left(\operatorname{tanh} \left(\frac{x}{2}\right) \right)`

    (3) `minsum`

        .. math::
            y_{j \to i} = \alpha_{j \to i} \cdot \min_{i' \in \mathcal{N}(j) \setminus i} \left(|x_{i' \to j}|\right)

    where :math:`y_{j \to i}` denotes the message from check node (CN) *j* to
    variable node (VN) *i* and :math:`x_{i \to j}` the message from VN *i* to
    CN *j*, respectively. Further, :math:`\mathcal{N}(j)` denotes all indices
    of VNs connected to CN *j* and

    .. math::
        \alpha_{j \to i} = \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{sign}(x_{i' \to j})

    is the sign of the outgoing message. For further details we refer to
    [Ryan]_.

    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied (cf. [Richardson]_ for details); this
    can be done by :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder` and
    :class:`~sionna.fec.ldpc.decoding.LDPC5GDecoder`, respectively.

    If required, the decoder can be made trainable and is fully differentiable
    by following the concept of `weighted BP` [Nachmani]_ as shown in Fig. 1,
    leading to

    .. math::
        y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}(j) \setminus i} \operatorname{tanh} \left( \frac{\textcolor{red}{w_{i' \to j}} \cdot x_{i' \to j}}{2} \right) \right)

    where :math:`w_{i \to j}` denotes the trainable weight of message
    :math:`x_{i \to j}`. Please note that training of some check node types
    may not be supported.

    .. figure:: ../figures/weighted_bp.png

        Fig. 1: Weighted BP as proposed in [Nachmani]_.

    For numerical stability, the decoder applies LLR clipping of +/- 20 to the
    input LLRs.

    The class inherits from the Keras layer class and can be used as layer in
    a Keras model.

    Parameters
    ----------
    pcm: ndarray
        An ndarray of shape `[n-k, n]` defining the parity-check matrix
        consisting only of `0` or `1` entries. Can be also of type
        `scipy.sparse.csr_matrix` or `scipy.sparse.csc_matrix`.

    trainable: bool
        Defaults to False. If True, every outgoing variable node message is
        scaled with a trainable scalar.

    cn_type: str
        A string defaulting to '"boxplus-phi"'. One of {`"boxplus"`,
        `"boxplus-phi"`, `"minsum"`} where '"boxplus"' implements the
        single-parity-check APP decoding rule, '"boxplus-phi"' implements the
        numerically more stable version of boxplus [Ryan]_, and '"minsum"'
        implements the min-approximation of the CN update rule [Ryan]_.

    hard_out: bool
        Defaults to True.
If True, the decoder provides hard-decided codeword bits instead of soft-values. track_exit: bool Defaults to False. If True, the decoder tracks EXIT characteristics. Note that this requires the all-zero CW as input. num_iter: int Defining the number of decoder iteration (no early stopping used at the moment!). stateful: bool Defaults to False. If True, the internal VN messages ``msg_vn`` from the last decoding iteration are returned, and ``msg_vn`` or `None` needs to be given as a second input when calling the decoder. This is required for iterative demapping and decoding. output_dtype: tf.DType Defaults to tf.float32. Defines the output datatype of the layer (internal precision remains tf.float32). Input ----- llrs_ch or (llrs_ch, msg_vn): Tensor or Tuple (only required if ``stateful`` is True): llrs_ch: [...,n], tf.float32 2+D tensor containing the channel logits/llr values. msg_vn: None or RaggedTensor, tf.float32 Ragged tensor of VN messages. Required only if ``stateful`` is True. Output ------ : [...,n], tf.float32 2+D Tensor of same shape as ``inputs`` containing bit-wise soft-estimates (or hard-decided bit-values) of all codeword bits. : RaggedTensor, tf.float32: Tensor of VN messages. Returned only if ``stateful`` is set to True. Attributes ---------- pcm: ndarray An ndarray of shape `[n-k, n]` defining the parity-check matrix consisting only of `0` or `1` entries. Can be also of type `scipy. sparse.csr_matrix` or `scipy.sparse.csc_matrix`. num_cns: int Defining the number of check nodes. num_vns: int Defining the number of variable nodes. num_edges: int Defining the total number of edges. trainable: bool If True, the decoder uses trainable weights. _atanh_clip_value: float Defining the internal clipping value before the atanh is applied (relates to the CN update). _cn_type: str Defining the CN update function type. _cn_update: A function defining the CN update. _hard_out: bool If True, the decoder outputs hard-decided bits. _cn_con: ndarray An ndarray of shape `[num_edges]` defining all edges from check node perspective. _vn_con: ndarray An ndarray of shape `[num_edges]` defining all edges from variable node perspective. _vn_mask_tf: tf.float32 A ragged Tensor of shape `[num_vns, None]` defining the incoming message indices per VN. The second dimension is ragged and depends on the node degree. _cn_mask_tf: tf.float32 A ragged Tensor of shape `[num_cns, None]` defining the incoming message indices per CN. The second dimension is ragged and depends on the node degree. _ind_cn: ndarray An ndarray of shape `[num_edges]` defining the permutation index to rearrange messages from variable into check node perspective. _ind_cn_inv: ndarray An ndarray of shape `[num_edges]` defining the permutation index to rearrange messages from check into variable node perspective. _vn_row_splits: ndarray An ndarray of shape `[num_vns+1]` defining the row split positions of a 1D vector consisting of all edges messages. Used to build a ragged Tensor of incoming VN messages. _cn_row_splits: ndarray An ndarray of shape `[num_cns+1]` defining the row split positions of a 1D vector consisting of all edges messages. Used to build a ragged Tensor of incoming CN messages. _edge_weights: tf.float32 A Tensor of shape `[num_edges]` defining a (trainable) weight per outgoing VN message. Raises: ValueError If the shape of ``pcm`` is invalid or contains other values than `0` or `1` or dtype is not `tf.float32`. ValueError If ``num_iter`` is not an integer greater (or equal) `0`. 
ValueError If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}. ValueError If ``inputs`` is not of shape `[batch_size, n]`. InvalidArgumentError When rank(``inputs``)<2. Note ---- As decoding input logits :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for compatibility with the learning framework, but internally log-likelihood ratios (LLRs) with definition :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used. The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC codes and, thus, supports arbitrary parity-check matrices. The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to account for arbitrary node degrees. To avoid a performance degradation caused by a severe indexing overhead, the batch-dimension is shifted to the last dimension during decoding. If the decoder is made trainable [Nachmani]_, for performance improvements only variable to check node messages are scaled as the VN operation is linear and, thus, would not increase the expressive power of the weights. """ def __init__(self, pcm, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs): super().__init__(dtype=output_dtype, **kwargs) assert isinstance(trainable, bool), 'trainable must be bool.' assert isinstance(hard_out, bool), 'hard_out must be bool.' assert isinstance(track_exit, bool), 'track_exit must be bool.' assert isinstance(cn_type, str) , 'cn_type must be str.' assert isinstance(num_iter, int), 'num_iter must be int.' assert num_iter>=0, 'num_iter cannot be negative.' assert isinstance(stateful, bool), 'stateful must be bool.' assert isinstance(output_dtype, tf.DType), \ 'output_dtype must be tf.Dtype.' if isinstance(pcm, np.ndarray): assert np.array_equal(pcm, pcm.astype(bool)), 'PC matrix \ must be binary.' elif isinstance(pcm, sp.sparse.csr_matrix): assert np.array_equal(pcm.data, pcm.data.astype(bool)), \ 'PC matrix must be binary.' elif isinstance(pcm, sp.sparse.csc_matrix): assert np.array_equal(pcm.data, pcm.data.astype(bool)), \ 'PC matrix must be binary.' 
else: raise TypeError("Unsupported dtype of pcm.") if output_dtype not in (tf.float16, tf.float32, tf.float64): raise ValueError( 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') if output_dtype is not tf.float32: print('Note: decoder uses tf.float32 for internal calculations.') # init decoder parameters self._pcm = pcm self._trainable = trainable self._cn_type = cn_type self._hard_out = hard_out self._track_exit = track_exit self._num_iter = tf.constant(num_iter, dtype=tf.int32) self._stateful = stateful self._output_dtype = output_dtype # clipping value for the atanh function is applied (tf.float32 is used) self._atanh_clip_value = 1 - 1e-7 # internal value for llr clipping self._llr_max = tf.constant(20., tf.float32) # init code parameters self._num_cns = pcm.shape[0] # total number of check nodes self._num_vns = pcm.shape[1] # total number of variable nodes # make pcm sparse first if ndarray is provided if isinstance(pcm, np.ndarray): pcm = sp.sparse.csr_matrix(pcm) # find all edges from variable and check node perspective self._cn_con, self._vn_con, _ = sp.sparse.find(pcm) # sort indices explicitly, as scipy.sparse.find changed from column to # row sorting in scipy>=1.11 idx = np.argsort(self._vn_con) self._cn_con = self._cn_con[idx] self._vn_con = self._vn_con[idx] # number of edges equals number of non-zero elements in the # parity-check matrix self._num_edges = len(self._vn_con) # permutation index to rearrange messages into check node perspective self._ind_cn = np.argsort(self._cn_con) # inverse permutation index to rearrange messages back into variable # node perspective self._ind_cn_inv = np.argsort(self._ind_cn) # generate row masks (array of integers defining the row split pos.) self._vn_row_splits = self._gen_node_mask_row(self._vn_con) self._cn_row_splits = self._gen_node_mask_row( self._cn_con[self._ind_cn]) # pre-load the CN function for performance reasons if self._cn_type=='boxplus': # check node update using the tanh function self._cn_update = self._cn_update_tanh elif self._cn_type=='boxplus-phi': # check node update using the "_phi" function self._cn_update = self._cn_update_phi elif self._cn_type=='minsum': # check node update using the min-sum approximation self._cn_update = self._cn_update_minsum else: raise ValueError('Unknown node type.') # init trainable weights if needed self._has_weights = False # indicates if trainable weights exist if self._trainable: self._has_weights = True self._edge_weights = tf.Variable(tf.ones(self._num_edges), trainable=self._trainable, dtype=tf.float32) # track mutual information during decoding self._ie_c = 0 self._ie_v = 0 ######################################### # Public methods and properties ######################################### @property def pcm(self): """Parity-check matrix of LDPC code.""" return self._pcm @property def num_cns(self): """Number of check nodes.""" return self._num_cns @property def num_vns(self): """Number of variable nodes.""" return self._num_vns @property def num_edges(self): """Number of edges in decoding graph.""" return self._num_edges @property def has_weights(self): """Indicates if decoder has trainable weights.""" return self._has_weights @property def edge_weights(self): """Trainable weights of the BP decoder.""" if not self._has_weights: return [] else: return self._edge_weights @property def output_dtype(self): """Output dtype of decoder.""" return self._output_dtype @property def ie_c(self): "Extrinsic mutual information at check node." 
        return self._ie_c

    @property
    def ie_v(self):
        "Extrinsic mutual information at variable node."
        return self._ie_v

    @property
    def num_iter(self):
        "Number of decoding iterations."
        return self._num_iter

    @num_iter.setter
    def num_iter(self, num_iter):
        "Number of decoding iterations."
        assert isinstance(num_iter, int), 'num_iter must be int.'
        assert num_iter>=0, 'num_iter cannot be negative.'
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

    @property
    def llr_max(self):
        """Max LLR value used for internal calculations and rate-matching."""
        return self._llr_max

    @llr_max.setter
    def llr_max(self, value):
        """Max LLR value used for internal calculations and rate-matching."""
        assert value>=0, 'llr_max cannot be negative.'
        self._llr_max = tf.cast(value, dtype=tf.float32)
    def show_weights(self, size=7):
        """Show histogram of trainable weights.

        Input
        -----
        size: float
            Figure size of the matplotlib figure.
        """
        # only plot if weights exist
        if self._has_weights:
            weights = self._edge_weights.numpy()
            plt.figure(figsize=(size,size))
            plt.hist(weights, density=True, bins=20, align='mid')
            plt.xlabel('weight value')
            plt.ylabel('density')
            plt.grid(True, which='both', axis='both')
            plt.title('Weight Distribution')
        else:
            print("No weights to show.")
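    # Illustrative sketch of weighted BP training (not part of the original
    # source): it assumes a decoder built with ``trainable=True`` and
    # ``hard_out=False``; ``decoder``, ``llr_ch``, ``bits`` and ``optimizer``
    # are placeholder names. The trained weights can then be inspected with
    # ``show_weights()``.
    #
    #     bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    #     with tf.GradientTape() as tape:
    #         llr_hat = decoder(llr_ch)      # soft outputs (logits)
    #         loss = bce(bits, llr_hat)
    #     grads = tape.gradient(loss, decoder.trainable_variables)
    #     optimizer.apply_gradients(zip(grads, decoder.trainable_variables))
    #     decoder.show_weights()             # histogram of trained weights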
######################### # Utility methods ######################### def _gen_node_mask(self, con): """ Generates internal node masks indicating which msg index belongs to which node index. """ ind = np.argsort(con) con = con[ind] node_mask = [] cur_node = 0 cur_mask = [] for i in range(self._num_edges): if con[i] == cur_node: cur_mask.append(ind[i]) else: node_mask.append(cur_mask) cur_mask = [ind[i]] cur_node += 1 node_mask.append(cur_mask) return node_mask def _gen_node_mask_row(self, con): """ Defining the row split positions of a 1D vector consisting of all edges messages. Used to build a ragged Tensor of incoming node messages. """ node_mask = [0] # the first element indicates the first node index (=0) cur_node = 0 for i in range(self._num_edges): if con[i] != cur_node: node_mask.append(i) cur_node += 1 node_mask.append(self._num_edges) # last element must be the number of # elements (delimiter) return node_mask def _vn_update(self, msg, llr_ch): """ Variable node update function. This function implements the (extrinsic) variable node update function. It takes the sum over all incoming messages ``msg`` excluding the intrinsic (= outgoing) message itself. Additionally, the channel LLR ``llr_ch`` is added to each message. """ # aggregate all incoming messages per node x = tf.reduce_sum(msg, axis=1) x = tf.add(x, llr_ch) # TF2.9 does not support XLA for the addition of ragged tensors # the following code provides a workaround that supports XLA # subtract extrinsic message from node value # x = tf.expand_dims(x, axis=1) # x = tf.add(-msg, x) x = tf.ragged.map_flat_values(lambda x, y, row_ind : x + tf.gather(y, row_ind), -1.*msg, x, msg.value_rowids()) return x def _where_ragged(self, msg): """Helper to replace 0 elements from ragged tensor (called with map_flat_values).""" return tf.where(tf.equal(msg, 0), tf.ones_like(msg) * 1e-12, msg) def _where_ragged_inv(self, msg): """Helper to replace small elements from ragged tensor (called with map_flat_values) with exact `0`.""" msg_mod = tf.where(tf.less(tf.abs(msg), 1e-7), tf.zeros_like(msg), msg) return msg_mod def _cn_update_tanh(self, msg): """Check node update function implementing the exact boxplus operation. This function implements the (extrinsic) check node update function. It calculates the boxplus function over all incoming messages "msg" excluding the intrinsic (=outgoing) message itself. The exact boxplus function is implemented by using the tanh function. The input is expected to be a ragged Tensor of shape `[num_cns, None, batch_size]`. Note that for numerical stability clipping is applied. """ msg = msg / 2 # tanh is not overloaded for ragged tensors msg = tf.ragged.map_flat_values(tf.tanh, msg) # tanh is not overloaded # for ragged tensors; map to flat tensor first msg = tf.ragged.map_flat_values(self._where_ragged, msg) msg_prod = tf.reduce_prod(msg, axis=1) # TF2.9 does not support XLA for the multiplication of ragged tensors # the following code provides a workaround that supports XLA # ^-1 to avoid division # Note this is (potentially) numerically unstable # msg = msg**-1 * tf.expand_dims(msg_prod, axis=1) # remove own edge msg = tf.ragged.map_flat_values(lambda x, y, row_ind : x * tf.gather(y, row_ind), msg**-1, msg_prod, msg.value_rowids()) # Overwrite small (numerical zeros) message values with exact zero # these are introduced by the previous "_where_ragged" operation # this is required to keep the product stable (cf. 
_phi_update for log # sum implementation) msg = tf.ragged.map_flat_values(self._where_ragged_inv, msg) msg = tf.clip_by_value(msg, clip_value_min=-self._atanh_clip_value, clip_value_max=self._atanh_clip_value) # atanh is not overloaded for ragged tensors msg = 2 * tf.ragged.map_flat_values(tf.atanh, msg) return msg def _phi(self, x): """Helper function for the check node update. This function implements the (element-wise) `"_phi"` function as defined in [Ryan]_. """ # the clipping values are optimized for tf.float32 x = tf.clip_by_value(x, clip_value_min=8.5e-8, clip_value_max=16.635532) return tf.math.log(tf.math.exp(x)+1) - tf.math.log(tf.math.exp(x)-1) def _cn_update_phi(self, msg): """Check node update function implementing the exact boxplus operation. This function implements the (extrinsic) check node update function based on the numerically more stable `"_phi"` function (cf. [Ryan]_). It calculates the boxplus function over all incoming messages ``msg`` excluding the intrinsic (=outgoing) message itself. The exact boxplus function is implemented by using the `"_phi"` function as in [Ryan]_. The input is expected to be a ragged Tensor of shape `[num_cns, None, batch_size]`. Note that for numerical stability clipping is applied. """ sign_val = tf.sign(msg) # TF2.14 does not support XLA for tf.where and ragged tensors in # CPU mode. The following code provides a workaround that supports XLA # sign_val = tf.where(tf.equal(sign_val, 0), # tf.ones_like(sign_val), # sign_val) sign_val = tf.ragged.map_flat_values(lambda x : tf.where(tf.equal(x, 0), tf.ones_like(x),x), sign_val) sign_node = tf.reduce_prod(sign_val, axis=1) # TF2.9 does not support XLA for the multiplication of ragged tensors # the following code provides a workaround that supports XLA # sign_val = sign_val * tf.expand_dims(sign_node, axis=1) sign_val = tf.ragged.map_flat_values(lambda x, y, row_ind : x * tf.gather(y, row_ind), sign_val, sign_node, sign_val.value_rowids()) msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign # apply _phi element-wise (does not support ragged Tensors) msg = tf.ragged.map_flat_values(self._phi, msg) msg_sum = tf.reduce_sum(msg, axis=1) # TF2.9 does not support XLA for the addition of ragged tensors # the following code provides a workaround that supports XLA # msg = tf.add( -msg, tf.expand_dims(msg_sum, axis=1)) # remove own edge msg = tf.ragged.map_flat_values(lambda x, y, row_ind : x + tf.gather(y, row_ind), -1.*msg, msg_sum, msg.value_rowids()) # apply _phi element-wise (does not support ragged Tensors) msg = self._stop_ragged_gradient(sign_val) * tf.ragged.map_flat_values( self._phi, msg) return msg def _stop_ragged_gradient(self, rt): """Helper function as TF 2.5 does not support ragged gradient stopping""" return rt.with_flat_values(tf.stop_gradient(rt.flat_values)) def _sign_val_minsum(self, msg): """Helper to replace find sign-value during min-sum decoding. Must be called with `map_flat_values`.""" sign_val = tf.sign(msg) sign_val = tf.where(tf.equal(sign_val, 0), tf.ones_like(sign_val), sign_val) return sign_val def _cn_update_minsum(self, msg): """ Check node update function implementing the min-sum approximation. This function approximates the (extrinsic) check node update function based on the min-sum approximation (cf. [Ryan]_). It calculates the "extrinsic" min function over all incoming messages ``msg`` excluding the intrinsic (=outgoing) message itself. The input is expected to be a ragged Tensor of shape `[num_vns, None, batch_size]`. 
""" # a constant used to overwrite the first min LARGE_VAL = 10000. # pylint: disable=invalid-name # clip values for numerical stability msg = tf.clip_by_value(msg, clip_value_min=-self._llr_max, clip_value_max=self._llr_max) # calculate sign of outgoing msg and the node sign_val = tf.ragged.map_flat_values(self._sign_val_minsum, msg) sign_node = tf.reduce_prod(sign_val, axis=1) # TF2.9 does not support XLA for the multiplication of ragged tensors # the following code provides a workaround that supports XLA # sign_val = self._stop_ragged_gradient(sign_val) \ # * tf.expand_dims(sign_node, axis=1) sign_val = tf.ragged.map_flat_values( lambda x, y, row_ind: tf.multiply(x, tf.gather(y, row_ind)), self._stop_ragged_gradient(sign_val), sign_node, sign_val.value_rowids()) # remove sign from messages msg = tf.ragged.map_flat_values(tf.abs, msg) # Calculate the extrinsic minimum per CN, i.e., for each message of # index i, find the smallest and the second smallest value. # However, in some cases the second smallest value may equal the # smallest value (multiplicity of mins). # Please note that this needs to be applied to raggedTensors, e.g., # tf.top_k() is currently not supported and all ops must support graph # and XLA mode. # find min_value per node min_val = tf.reduce_min(msg, axis=1, keepdims=True) # TF2.9 does not support XLA for the subtraction of ragged tensors # the following code provides a workaround that supports XLA # and subtract min; the new array contains zero at the min positions # benefits from broadcasting; all other values are positive msg_min1 = tf.ragged.map_flat_values(lambda x, y, row_ind: x - tf.gather(y, row_ind), msg, tf.squeeze(min_val, axis=1), msg.value_rowids()) # replace 0 (=min positions) with large value to ignore it for further # min calculations msg = tf.ragged.map_flat_values( lambda x: tf.where(tf.equal(x, 0), LARGE_VAL, x), msg_min1) # find the second smallest element (we add min_val as this has been # subtracted before) min_val_2 = tf.reduce_min(msg, axis=1, keepdims=True) + min_val # Detect duplicated minima (i.e., min_val occurs at two incoming # messages). As the LLRs per node are <LLR_MAX and we have # replace at least 1 position (position with message "min_val") by # LARGE_VAL, it holds for the sum < LARGE_VAL + node_degree*LLR_MAX. # If the sum > 2*LARGE_VAL, the multiplicity of the min is at least 2. node_sum = tf.reduce_sum(msg, axis=1, keepdims=True) - (2*LARGE_VAL-1.) # indicator that duplicated min was detected (per node) double_min = 0.5*(1-tf.sign(node_sum)) # if a duplicate min occurred, both edges must have min_val, otherwise # the second smallest value is taken min_val_e = (1-double_min) * min_val + (double_min) * min_val_2 # replace all values with min_val except the position where the min # occurred (=extrinsic min). 
# no XLA support for TF 2.15 # msg_e = tf.where(msg==LARGE_VAL, min_val_e, min_val) min_1 = tf.squeeze(tf.gather(min_val, msg.value_rowids()), axis=1) min_e = tf.squeeze(tf.gather(min_val_e, msg.value_rowids()), axis=1) msg_e = tf.ragged.map_flat_values( lambda x: tf.where(x==LARGE_VAL, min_e, min_1), msg) # it seems like tf.where does not set the shape of tf.ragged properly # we need to ensure the shape manually msg_e = tf.ragged.map_flat_values( lambda x: tf.ensure_shape(x, msg.flat_values.shape), msg_e) # TF2.9 does not support XLA for the multiplication of ragged tensors # the following code provides a workaround that supports XLA # and apply sign #msg = sign_val * msg_e msg = tf.ragged.map_flat_values(tf.multiply, sign_val, msg_e) return msg def _mult_weights(self, x): """Multiply messages with trainable weights for weighted BP.""" # transpose for simpler broadcasting of training variables x = tf.transpose(x, (1, 0)) x = tf.math.multiply(x, self._edge_weights) x = tf.transpose(x, (1, 0)) return x ######################### # Keras layer functions ######################### def build(self, input_shape): # Raise AssertionError if shape of x is invalid if self._stateful: assert(len(input_shape)==2), \ "For stateful decoding, a tuple of two inputs is expected." input_shape = input_shape[0] assert (input_shape[-1]==self._num_vns), \ 'Last dimension must be of length n.' assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' def call(self, inputs): """Iterative BP decoding function. This function performs ``num_iter`` belief propagation decoding iterations and returns the estimated codeword. Args: llr_ch or (llr_ch, msg_vn): llr_ch (tf.float32): Tensor of shape `[...,n]` containing the channel logits/llr values. msg_vn (tf.float32) : Ragged tensor containing the VN messages, or None. Required if ``stateful`` is set to True. Returns: `tf.float32`: Tensor of shape `[...,n]` containing bit-wise soft-estimates (or hard-decided bit-values) of all codeword bits. Raises: ValueError: If ``inputs`` is not of shape `[batch_size, n]`. InvalidArgumentError: When rank(``inputs``)<2. """ # Extract inputs if self._stateful: llr_ch, msg_vn = inputs else: llr_ch = inputs tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') # internal calculations still in tf.float32 llr_ch = tf.cast(llr_ch, tf.float32) # clip llrs for numerical stability llr_ch = tf.clip_by_value(llr_ch, clip_value_min=-self._llr_max, clip_value_max=self._llr_max) # last dim must be of length n tf.debugging.assert_equal(tf.shape(llr_ch)[-1], self._num_vns, 'Last dimension must be of length n.') llr_ch_shape = llr_ch.get_shape().as_list() new_shape = [-1, self._num_vns] llr_ch_reshaped = tf.reshape(llr_ch, new_shape) # must be done during call, as XLA fails otherwise due to ragged # indices placed on the CPU device. # create permutation index from cn perspective self._cn_mask_tf = tf.ragged.constant(self._gen_node_mask(self._cn_con), row_splits_dtype=tf.int32) # batch dimension is last dimension due to ragged tensor representation llr_ch = tf.transpose(llr_ch_reshaped, (1,0)) llr_ch = -1. * llr_ch # logits are converted into "true" llrs # init internal decoder state if not explicitly # provided (e.g., required to restore decoder state for iterative # detection and decoding) # load internal state from previous iteration # required for iterative det./dec. 
if not self._stateful or msg_vn is None: msg_shape = tf.stack([tf.constant(self._num_edges), tf.shape(llr_ch)[1]], axis=0) msg_vn = tf.zeros(msg_shape, dtype=tf.float32) else: msg_vn = msg_vn.flat_values # track exit decoding trajectory; requires all-zero cw? if self._track_exit: self._ie_c = tf.zeros(self._num_iter+1) self._ie_v = tf.zeros(self._num_iter+1) # perform one decoding iteration # Remark: msg_vn cannot be ragged as input for tf.while_loop as # otherwise XLA will not be supported (with TF 2.5) def dec_iter(llr_ch, msg_vn, it): it += 1 msg_vn = tf.RaggedTensor.from_row_splits( values=msg_vn, row_splits=tf.constant(self._vn_row_splits, tf.int32)) # variable node update msg_vn = self._vn_update(msg_vn, llr_ch) # track exit decoding trajectory; requires all-zero cw if self._track_exit: # neg values as different llr def is expected mi = llr2mi(-1. * msg_vn.flat_values) self._ie_v = tf.tensor_scatter_nd_add(self._ie_v, tf.reshape(it, (1, 1)), tf.reshape(mi, (1))) # scale outgoing vn messages (weighted BP); only if activated if self._has_weights: msg_vn = tf.ragged.map_flat_values(self._mult_weights, msg_vn) # permute edges into CN perspective msg_cn = tf.gather(msg_vn.flat_values, self._cn_mask_tf, axis=None) # check node update using the pre-defined function msg_cn = self._cn_update(msg_cn) # track exit decoding trajectory; requires all-zero cw? if self._track_exit: # neg values as different llr def is expected mi = llr2mi(-1.*msg_cn.flat_values) # update pos i+1 such that first iter is stored as 0 self._ie_c = tf.tensor_scatter_nd_add(self._ie_c, tf.reshape(it, (1, 1)), tf.reshape(mi, (1))) # re-permute edges to variable node perspective msg_vn = tf.gather(msg_cn.flat_values, self._ind_cn_inv, axis=None) return llr_ch, msg_vn, it # stopping condition (required for tf.while_loop) def dec_stop(llr_ch, msg_vn, it): # pylint: disable=W0613 return tf.less(it, self._num_iter) # start decoding iterations it = tf.constant(0) # maximum_iterations required for XLA _, msg_vn, _ = tf.while_loop(dec_stop, dec_iter, (llr_ch, msg_vn, it), parallel_iterations=1, maximum_iterations=self._num_iter) # raggedTensor for final marginalization msg_vn = tf.RaggedTensor.from_row_splits( values=msg_vn, row_splits=tf.constant(self._vn_row_splits, tf.int32)) # marginalize and remove ragged Tensor x_hat = tf.add(llr_ch, tf.reduce_sum(msg_vn, axis=1)) # restore batch dimension to first dimension x_hat = tf.transpose(x_hat, (1,0)) x_hat = -1. * x_hat # convert llrs back into logits if self._hard_out: # hard decide decoder output if required x_hat = tf.cast(tf.less(0.0, x_hat), self._output_dtype) # Reshape c_short so that it matches the original input dimensions output_shape = llr_ch_shape output_shape[0] = -1 # overwrite batch dim (can be None in Keras) x_reshaped = tf.reshape(x_hat, output_shape) # cast output to output_dtype x_out = tf.cast(x_reshaped, self._output_dtype) if not self._stateful: return x_out else: return x_out, msg_vn
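# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the Sionna API): decodes noisy BPSK
# observations of the all-zero codeword of a small (7,4) Hamming code with the
# generic BP decoder above. The helper name, the example parity-check matrix,
# and the noise variance are chosen for illustration only; the decoder expects
# channel logits with the convention log(p(x=1)/p(x=0)).
# ---------------------------------------------------------------------------
def _example_bp_decoding(batch_size=10, no=0.5, num_iter=20):
    """Sketch: decode noisy all-zero (7,4) Hamming codewords with LDPCBPDecoder."""
    pcm = np.array([[1, 1, 1, 0, 1, 0, 0],
                    [1, 1, 0, 1, 0, 1, 0],
                    [1, 0, 1, 1, 0, 0, 1]])      # [n-k, n] parity-check matrix
    decoder = LDPCBPDecoder(pcm, num_iter=num_iter, hard_out=True)
    # all-zero codeword, BPSK mapping 0 -> +1, AWGN with noise variance no
    y = 1.0 + np.sqrt(no) * np.random.randn(batch_size, 7)
    # channel logits log(p(x=1)/p(x=0)) = -2y/no for this channel model
    llr_ch = tf.constant(-2.0 * y / no, dtype=tf.float32)
    c_hat = decoder(llr_ch)                      # hard-decided codeword bits
    return tf.reduce_mean(c_hat)                 # BER (all-zero codeword sent)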
class LDPC5GDecoder(LDPCBPDecoder):
    # pylint: disable=line-too-long
    r"""LDPC5GDecoder(encoder, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, return_infobits=True, prune_pcm=True, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs)

    (Iterative) belief propagation decoder for 5G NR LDPC codes.

    Inherits from :class:`~sionna.fec.ldpc.decoding.LDPCBPDecoder` and provides
    a wrapper for 5G compatibility, i.e., it automatically handles puncturing
    and shortening according to [3GPPTS38212_LDPC]_.

    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied and, thus, the encoder object is
    required as input.

    If required, the decoder can be made trainable and is differentiable
    (the training of some check node types may not be supported) following the
    concept of "weighted BP" [Nachmani]_.

    For numerical stability, the decoder applies LLR clipping of +/- 20 to the
    input LLRs.

    The class inherits from the Keras layer class and can be used as layer in
    a Keras model.

    Parameters
    ----------
    encoder: LDPC5GEncoder
        An instance of :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`
        containing the correct code parameters.

    trainable: bool
        Defaults to False. If True, every outgoing variable node message is
        scaled with a trainable scalar.

    cn_type: str
        A string defaulting to '"boxplus-phi"'. One of {`"boxplus"`,
        `"boxplus-phi"`, `"minsum"`} where '"boxplus"' implements the
        single-parity-check APP decoding rule, '"boxplus-phi"' implements the
        numerically more stable version of boxplus [Ryan]_, and '"minsum"'
        implements the min-approximation of the CN update rule [Ryan]_.

    hard_out: bool
        Defaults to True. If True, the decoder provides hard-decided codeword
        bits instead of soft-values.

    track_exit: bool
        Defaults to False. If True, the decoder tracks EXIT characteristics.
        Note that this requires the all-zero CW as input.

    return_infobits: bool
        Defaults to True. If True, only the `k` info bits (soft or
        hard-decided) are returned. Otherwise, all `n` positions are returned.

    prune_pcm: bool
        Defaults to True. If True, all punctured degree-1 VNs and connected
        check nodes are removed from the decoding graph (see [Cammerer]_ for
        details). Besides numerical differences, this should yield the same
        decoding result but improves the decoding throughput and reduces the
        memory footprint.

    num_iter: int
        Defining the number of decoder iterations (no early stopping used at
        the moment!).

    stateful: bool
        Defaults to False. If True, the internal VN messages ``msg_vn`` from
        the last decoding iteration are returned, and ``msg_vn`` or `None`
        needs to be given as a second input when calling the decoder. This is
        required for iterative demapping and decoding.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer
        (internal precision remains tf.float32).

    Input
    -----
    llrs_ch or (llrs_ch, msg_vn):
        Tensor or Tuple (only required if ``stateful`` is True):

    llrs_ch: [...,n], tf.float32
        2+D tensor containing the channel logits/llr values.

    msg_vn: None or RaggedTensor, tf.float32
        Ragged tensor of VN messages. Required only if ``stateful`` is True.

    Output
    ------
    : [...,n] or [...,k], tf.float32
        2+D Tensor of same shape as ``inputs`` containing bit-wise
        soft-estimates (or hard-decided bit-values) of all codeword bits. If
        ``return_infobits`` is True, only the `k` information bits are
        returned.

    : RaggedTensor, tf.float32:
        Tensor of VN messages. Returned only if ``stateful`` is set to True.
Raises ------ ValueError If the shape of ``pcm`` is invalid or contains other values than `0` or `1`. AssertionError If ``trainable`` is not `bool`. AssertionError If ``track_exit`` is not `bool`. AssertionError If ``hard_out`` is not `bool`. AssertionError If ``return_infobits`` is not `bool`. AssertionError If ``encoder`` is not an instance of :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`. ValueError If ``output_dtype`` is not {tf.float16, tf.float32, tf. float64}. ValueError If ``inputs`` is not of shape `[batch_size, n]`. ValueError If ``num_iter`` is not an integer greater (or equal) `0`. InvalidArgumentError When rank(``inputs``)<2. Note ---- As decoding input logits :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for compatibility with the learning framework, but internally llrs with definition :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used. The decoder is not (particularly) optimized for Quasi-cyclic (QC) LDPC codes and, thus, supports arbitrary parity-check matrices. The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to account for arbitrary node degrees. To avoid a performance degradation caused by a severe indexing overhead, the batch-dimension is shifted to the last dimension during decoding. If the decoder is made trainable [Nachmani]_, for performance improvements only variable to check node messages are scaled as the VN operation is linear and, thus, would not increase the expressive power of the weights. """ def __init__(self, encoder, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, return_infobits=True, prune_pcm=True, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs): # needs the 5G Encoder to access all 5G parameters assert isinstance(encoder, LDPC5GEncoder), 'encoder must \ be of class LDPC5GEncoder.' self._encoder = encoder pcm = encoder.pcm assert isinstance(return_infobits, bool), 'return_info must be bool.' self._return_infobits = return_infobits assert isinstance(output_dtype, tf.DType), \ 'output_dtype must be tf.DType.' if output_dtype not in (tf.float16, tf.float32, tf.float64): raise ValueError( 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') self._output_dtype = output_dtype assert isinstance(stateful, bool), 'stateful must be bool.' self._stateful = stateful assert isinstance(prune_pcm, bool), 'prune_pcm must be bool.' # prune punctured degree-1 VNs and connected CNs. A punctured # VN-1 node will always "send" llr=0 to the connected CN. Thus, this # CN will only send 0 messages to all other VNs, i.e., does not # contribute to the decoding process. self._prune_pcm = prune_pcm if prune_pcm: # find index of first position with only degree-1 VN dv = np.sum(pcm, axis=0) # VN degree last_pos = encoder._n_ldpc for idx in range(encoder._n_ldpc-1, 0, -1): if dv[0, idx]==1: last_pos = idx else: break # number of filler bits k_filler = self.encoder.k_ldpc - self.encoder.k # number of punctured bits nb_punc_bits = ((self.encoder.n_ldpc - k_filler) - self.encoder.n - 2*self.encoder.z) # effective codeword length after pruning of vn-1 nodes self._n_pruned = np.max((last_pos, encoder._n_ldpc - nb_punc_bits)) self._nb_pruned_nodes = encoder._n_ldpc - self._n_pruned # remove last CNs and VNs from pcm pcm = pcm[:-self._nb_pruned_nodes, :-self._nb_pruned_nodes] #check for consistency assert(self._nb_pruned_nodes>=0), "Internal error: number of \ pruned nodes must be positive." 
else: self._nb_pruned_nodes = 0 # no pruning; same length as before self._n_pruned = encoder._n_ldpc super().__init__(pcm, trainable, cn_type, hard_out, track_exit, num_iter=num_iter, stateful=stateful, output_dtype=output_dtype, **kwargs) ######################################### # Public methods and properties ######################################### @property def encoder(self): """LDPC Encoder used for rate-matching/recovery.""" return self._encoder ######################### # Keras layer functions ######################### def build(self, input_shape): """Build model.""" if self._stateful: assert(len(input_shape)==2), \ "For stateful decoding, a tuple of two inputs is expected." input_shape = input_shape[0] # check input dimensions for consistency assert (input_shape[-1]==self.encoder.n), \ 'Last dimension must be of length n.' assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' self._old_shape_5g = input_shape def call(self, inputs): """Iterative BP decoding function. This function performs ``num_iter`` belief propagation decoding iterations and returns the estimated codeword. Args: inputs (tf.float32): Tensor of shape `[...,n]` containing the channel logits/llr values. Returns: `tf.float32`: Tensor of shape `[...,n]` or `[...,k]` (``return_infobits`` is True) containing bit-wise soft-estimates (or hard-decided bit-values) of all codeword bits (or info bits, respectively). Raises: ValueError: If ``inputs`` is not of shape `[batch_size, n]`. ValueError: If ``num_iter`` is not an integer greater (or equal) `0`. InvalidArgumentError: When rank(``inputs``)<2. """ # Extract inputs if self._stateful: llr_ch, msg_vn = inputs else: llr_ch = inputs tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') llr_ch_shape = llr_ch.get_shape().as_list() new_shape = [-1, llr_ch_shape[-1]] llr_ch_reshaped = tf.reshape(llr_ch, new_shape) batch_size = tf.shape(llr_ch_reshaped)[0] # invert if rate-matching output interleaver was applied as defined in # Sec. 5.4.2.2 in 38.212 if self._encoder.num_bits_per_symbol is not None: llr_ch_reshaped = tf.gather(llr_ch_reshaped, self._encoder.out_int_inv, axis=-1) # undo puncturing of the first 2*Z bit positions llr_5g = tf.concat( [tf.zeros([batch_size, 2*self.encoder.z], self._output_dtype), llr_ch_reshaped], 1) # undo puncturing of the last positions # total length must be n_ldpc, while llr_ch has length n # first 2*z positions are already added # -> add n_ldpc - n - 2Z punctured positions k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits nb_punc_bits = ((self.encoder.n_ldpc - k_filler) - self.encoder.n - 2*self.encoder.z) llr_5g = tf.concat([llr_5g, tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes], self._output_dtype)], 1) # undo shortening (= add 0 positions after k bits, i.e. 
LLR=LLR_max) # the first k positions are the systematic bits x1 = tf.slice(llr_5g, [0,0], [batch_size, self.encoder.k]) # parity part nb_par_bits = (self.encoder.n_ldpc - k_filler - self.encoder.k - self._nb_pruned_nodes) x2 = tf.slice(llr_5g, [0, self.encoder.k], [batch_size, nb_par_bits]) # negative sign due to logit definition z = -tf.cast(self._llr_max, self._output_dtype) \ * tf.ones([batch_size, k_filler], self._output_dtype) llr_5g = tf.concat([x1, z, x2], 1) # and execute the decoder if not self._stateful: x_hat = super().call(llr_5g) else: x_hat,msg_vn = super().call([llr_5g, msg_vn]) if self._return_infobits: # return only info bits # reconstruct u_hat # code is systematic u_hat = tf.slice(x_hat, [0,0], [batch_size, self.encoder.k]) # Reshape u_hat so that it matches the original input dimensions output_shape = llr_ch_shape[0:-1] + [self.encoder.k] # overwrite first dimension as this could be None (Keras) output_shape[0] = -1 u_reshaped = tf.reshape(u_hat, output_shape) # enable other output datatypes than tf.float32 u_out = tf.cast(u_reshaped, self._output_dtype) if not self._stateful: return u_out else: return u_out, msg_vn else: # return all codeword bits # the transmitted CW bits are not the same as used during decoding # cf. last parts of 5G encoding function # remove last dim x = tf.reshape(x_hat, [batch_size, self._n_pruned]) # remove filler bits at pos (k, k_ldpc) x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k]) x_no_filler2 = tf.slice(x, [0, self.encoder.k_ldpc], [batch_size, self._n_pruned-self.encoder.k_ldpc]) x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1) # shorten the first 2*Z positions and end after n bits x_short = tf.slice(x_no_filler, [0, 2*self.encoder.z], [batch_size, self.encoder.n]) # if used, apply rate-matching output interleaver again as # Sec. 5.4.2.2 in 38.212 if self._encoder.num_bits_per_symbol is not None: x_short = tf.gather(x_short, self._encoder.out_int, axis=-1) # Reshape x_short so that it matches the original input dimensions # overwrite first dimension as this could be None (Keras) llr_ch_shape[0] = -1 x_short= tf.reshape(x_short, llr_ch_shape) # enable other output datatypes than tf.float32 x_out = tf.cast(x_short, self._output_dtype) if not self._stateful: return x_out else: return x_out, msg_vn
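# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (not part of the Sionna API): encodes random
# information bits with the 5G LDPC encoder, transmits them over a BPSK/AWGN
# channel and decodes them with LDPC5GDecoder. The code parameters and the
# noise variance ``no`` are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, k, n = 100, 100, 200
    no = 0.5                                     # example noise variance

    encoder = LDPC5GEncoder(k, n)
    decoder = LDPC5GDecoder(encoder, num_iter=20, hard_out=True)

    # random information bits and encoding
    u = tf.cast(tf.random.uniform((batch_size, k), 0, 2, tf.int32), tf.float32)
    c = encoder(u)

    # BPSK mapping 0 -> +1, 1 -> -1 and AWGN channel
    x = 1.0 - 2.0 * c
    y = x + tf.sqrt(no) * tf.random.normal(tf.shape(x))

    # channel logits log(p(x=1)/p(x=0)) as expected by the decoder
    llr_ch = -2.0 * y / no
    u_hat = decoder(llr_ch)

    ber = tf.reduce_mean(tf.cast(tf.not_equal(u, u_hat), tf.float32))
    print(f"BER: {ber.numpy():.4f}")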