# Source code for sionna.phy.nr.tb_decoder

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Transport block decoding functions for the NR (5G) module of Sionna PHY"""

import numpy as np
import tensorflow as tf

from sionna.phy import Block
from sionna.phy.fec.crc import CRCDecoder
from sionna.phy.fec.scrambling import  Descrambler
from sionna.phy.fec.ldpc import LDPC5GDecoder
from sionna.phy.nr import TBEncoder

class TBDecoder(Block):
    # pylint: disable=line-too-long
    r"""TBDecoder(encoder, num_bp_iter=20, cn_update="boxplus-phi", vn_update="sum", precision=None, **kwargs)
    5G NR transport block (TB) decoder as defined in TS 38.214
    [3GPP38214]_.

    The transport block decoder takes as input a sequence of noisy channel
    observations and reconstructs the corresponding `transport block` of
    information bits. The detailed procedure is described in TS 38.214
    [3GPP38214]_ and TS 38.211 [3GPP38211]_.

    The decoding chain implemented in :meth:`call` mirrors the associated
    encoder in reverse order: descrambling (if a scrambler was used),
    inverse codeword interleaving / re-insertion of punctured positions,
    LDPC decoding per codeblock, codeblock CRC removal (if present),
    codeblock de-segmentation and transport block CRC check.

    Parameters
    ----------
    encoder : :class:`~sionna.phy.nr.TBEncoder`
        Associated transport block encoder used for encoding of the signal

    num_bp_iter : 20 (default) | `int`
        Number of BP decoder iterations

    cn_update: str, "boxplus-phi" (default) | "boxplus" | "minsum" | "offset-minsum" | "identity" | callable
        Check node update rule of the LDPC BP decoder. The value is
        forwarded unchanged to
        :class:`~sionna.phy.fec.ldpc.LDPC5GDecoder`.
        If a callable is provided, it will be used instead as CN update.
        The input of the function is a ragged tensor of v2c messages of
        shape `[num_cns, None, batch_size]` where the second dimension is
        ragged (i.e., depends on the individual CN degree).

    vn_update: str, "sum" (default) | "identity" | callable
        Variable node update rule of the LDPC BP decoder. The value is
        forwarded unchanged to
        :class:`~sionna.phy.fec.ldpc.LDPC5GDecoder`.
        If a callable is provided, it will be used instead as VN update.
        The input of the function is a ragged tensor of c2v messages of
        shape `[num_vns, None, batch_size]` where the second dimension is
        ragged (i.e., depends on the individual VN degree).

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Input
    -----
    inputs : [...,num_coded_bits], `tf.float`
        2+D tensor containing channel logits/llr values of the (noisy)
        channel observations.

    Output
    ------
    b_hat : [...,target_tb_size], `tf.float`
        2+D tensor containing hard decided bit estimates of all
        information bits of the transport block.

    tb_crc_status : [...], `tf.bool`
        Transport block CRC status indicating if a transport block was
        (most likely) correctly recovered. Note that false positives are
        possible.
    """

    def __init__(self,
                 encoder,
                 num_bp_iter=20,
                 cn_update="boxplus-phi",
                 vn_update="sum",
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        assert isinstance(encoder, TBEncoder), "encoder must be TBEncoder."

        self._tb_encoder = encoder
        self._num_cbs = encoder.num_cbs

        # init BP decoder; hard decisions on the information bits only, as
        # the remaining TB processing (CRC checks) operates on bit-level
        self._decoder = LDPC5GDecoder(encoder=encoder.ldpc_encoder,
                                      num_iter=num_bp_iter,
                                      cn_update=cn_update,
                                      vn_update=vn_update,
                                      hard_out=True,
                                      return_infobits=True,
                                      precision=precision)

        # init descrambler (only if the encoder applies scrambling);
        # binary=False as the descrambler operates on LLRs, not on bits
        if encoder.scrambler is not None:
            self._descrambler = Descrambler(encoder.scrambler,
                                            binary=False,
                                            precision=precision)
        else:
            self._descrambler = None

        # init CRC decoder for the TB and (if the TB is segmented with
        # per-codeblock CRC) for the CBs
        self._tb_crc_decoder = CRCDecoder(encoder.tb_crc_encoder,
                                          precision=precision)

        if encoder.cb_crc_encoder is not None:
            self._cb_crc_decoder = CRCDecoder(encoder.cb_crc_encoder,
                                              precision=precision)
        else:
            self._cb_crc_decoder = None

    #########################################
    # Public methods and properties
    #########################################

    @property
    def tb_size(self):
        """ `int` : Number of information bits per TB"""
        return self._tb_encoder.tb_size

    # NOTE(review): the upstream comment here reads "required for" and is
    # truncated; presumably `k` is kept for interface parity with other
    # decoders - confirm.
    @property
    def k(self):
        """ `int` : Number of input information bits. Equals TB size. """
        return self._tb_encoder.tb_size

    @property
    def n(self):
        """ `int` : Total number of output codeword bits """
        return self._tb_encoder.n

    def build(self, input_shapes):
        """Test input shapes for consistency"""

        assert input_shapes[-1]==self.n, \
            f"Invalid input shape. Expected input length is {self.n}."

    def call(self, inputs):
        """Apply transport block decoding

        Returns the tuple ``(u_hat, tb_crc_status)`` as documented in the
        class docstring.
        """
        enc = self._tb_encoder

        # store the static input shape; needed to restore it at the output
        input_shape = inputs.shape.as_list()

        # Cast to the real dtype of this block (and not to a fixed float32)
        # such that the configured ``precision`` is honored; all internal
        # blocks are instantiated with the same precision setting.
        llr_ch = tf.cast(inputs, self.rdtype)
        llr_ch = tf.reshape(llr_ch, (-1, enc.num_tx, enc.n))

        # undo scrambling (only if scrambler was used)
        if self._descrambler is not None:
            llr_scr = self._descrambler(llr_ch)
        else:
            llr_scr = llr_ch

        # undo CB interleaving and puncturing;
        # positions that were not transmitted are filled with zero LLRs
        num_fillers = enc.ldpc_encoder.n * enc.num_cbs \
                      - np.sum(enc.cw_lengths)
        fillers = tf.zeros([tf.shape(llr_scr)[0], enc.num_tx, num_fillers],
                           dtype=llr_scr.dtype)
        llr_int = tf.concat([llr_scr, fillers], axis=-1)
        llr_int = tf.gather(llr_int, enc.output_perm_inv, axis=-1)

        # undo CB concatenation
        llr_cb = tf.reshape(llr_int,
                            (-1, enc.num_tx, self._num_cbs,
                             enc.ldpc_encoder.n))

        # LDPC decoding
        u_hat_cb = self._decoder(llr_cb)

        # CB CRC removal (if relevant)
        if self._cb_crc_decoder is not None:
            # we are ignoring the CB CRC status for the moment
            # Could be combined with the TB CRC for even better estimates
            u_hat_cb_crc, _ = self._cb_crc_decoder(u_hat_cb)
        else:
            u_hat_cb_crc = u_hat_cb

        # undo CB segmentation
        u_hat_tb = tf.reshape(
            u_hat_cb_crc,
            (-1, enc.num_tx, self.tb_size + enc.tb_crc_encoder.crc_length))

        # TB CRC removal
        u_hat, tb_crc_status = self._tb_crc_decoder(u_hat_tb)

        # restore input shape; work on a copy to not modify ``input_shape``
        output_shape = list(input_shape)
        output_shape[0] = -1
        output_shape[-1] = self.tb_size
        u_hat = tf.reshape(u_hat, output_shape)

        # also apply to tb_crc_status
        output_shape[-1] = 1 # but last dim is 1
        tb_crc_status = tf.reshape(tb_crc_status, output_shape)

        # remove if zero-padding was applied
        if enc.k_padding>0:
            u_hat = u_hat[...,:-enc.k_padding]

        # cast to output rdtype
        u_hat = tf.cast(u_hat, self.rdtype)
        tb_crc_status = tf.squeeze(tf.cast(tb_crc_status, tf.bool), axis=-1)

        return u_hat, tb_crc_status