# Source code for sionna.phy.nr.tb_encoder

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""5G NR transport block encoding for Sionna PHY."""

from typing import List, Optional, Tuple, Union
import numpy as np
import torch

from sionna.phy import Block
from sionna.phy.fec.crc import CRCEncoder
from sionna.phy.fec.scrambling import TB5GScrambler
from sionna.phy.fec.ldpc import LDPC5GEncoder
from sionna.phy.nr.utils import calculate_tb_size


# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["TBEncoder"]


class TBEncoder(Block):
    # pylint: disable=line-too-long
    r"""5G NR transport block (TB) encoder as defined in TS 38.214
    :cite:p:`3GPPTS38214` and TS 38.211 :cite:p:`3GPPTS38211`

    The transport block (TB) encoder takes as input a `transport block` of
    information bits and generates a sequence of codewords for transmission.
    For this, the information bit sequence is segmented into multiple
    codewords, protected by additional CRC checks and FEC encoded. Further,
    interleaving and scrambling is applied before a codeword concatenation
    generates the final bit sequence. Fig. 1 provides an overview of the TB
    encoding procedure and we refer the interested reader to
    :cite:p:`3GPPTS38214` and :cite:p:`3GPPTS38211` for further details.

    .. figure:: ../figures/tb_encoding.png

        Fig. 1: Overview TB encoding (CB CRC does not always apply).

    If ``n_rnti`` and ``n_id`` are given as list, the TBEncoder encodes
    `num_tx = len(` ``n_rnti`` `)` parallel input streams with different
    scrambling sequences per user.

    :param target_tb_size: Target transport block size, i.e., how many
        information bits are encoded into the TB. Note that the effective TB
        size can be slightly different due to quantization. If required,
        zero padding is internally applied.
    :param num_coded_bits: Number of coded bits after TB encoding.
    :param target_coderate: Target coderate.
    :param num_bits_per_symbol: Modulation order, i.e., number of bits per
        QAM symbol.
    :param num_layers: Number of transmission layers. Must be in
        [1, ..., 8]. Defaults to 1.
    :param n_rnti: RNTI identifier provided by higher layer. Defaults to 1
        and must be in range `[0, 65535]`. Defines a part of the random seed
        of the scrambler. If provided as list, every list entry defines the
        RNTI of an independent input stream.
    :param n_id: Data scrambling ID :math:`n_\text{ID}` related to cell id
        and provided by higher layer. Defaults to 1 and must be in range
        `[0, 1023]`. If provided as list, every list entry defines the
        scrambling id of an independent input stream.
    :param channel_type: Can be either "PUSCH" or "PDSCH". Defaults to
        "PUSCH".
    :param codeword_index: Scrambler can be configured for two codeword
        transmission. ``codeword_index`` can be either 0 or 1. Must be 0 for
        ``channel_type`` = "PUSCH". Defaults to 0.
    :param use_scrambler: If `False`, no data scrambling is applied
        (non standard-compliant). Defaults to `True`.
    :param verbose: If `True`, additional parameters are printed during
        initialization. Defaults to `False`.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is
        used.
    :param device: Device for computation ('cpu' or 'cuda').

    :input inputs: [..., target_tb_size] or [..., num_tx, target_tb_size],
        `torch.float`. 2+D tensor containing the information bits to be
        encoded. If ``n_rnti`` and ``n_id`` are a list of size `num_tx`,
        the input must be of shape ``[..., num_tx, target_tb_size]``.

    :output codeword: [..., num_coded_bits], `torch.float`.
        2+D tensor containing the sequence of the encoded codeword bits of
        the transport block.

    .. rubric:: Notes

    The parameters ``tb_size`` and ``num_coded_bits`` can be derived by the
    :meth:`~sionna.phy.nr.calculate_tb_size` function or by accessing the
    corresponding :class:`~sionna.phy.nr.PUSCHConfig` attributes.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.nr import TBEncoder

        encoder = TBEncoder(
            target_tb_size=1000,
            num_coded_bits=2000,
            target_coderate=0.5,
            num_bits_per_symbol=4,
            n_rnti=1,
            n_id=1
        )
        bits = torch.randint(0, 2, (10, 1000), dtype=torch.float32)
        coded_bits = encoder(bits)
        print(coded_bits.shape)
        # torch.Size([10, 2000])
    """

    def __init__(
        self,
        target_tb_size: int,
        num_coded_bits: int,
        target_coderate: float,
        num_bits_per_symbol: int,
        num_layers: int = 1,
        n_rnti: Union[int, List[int]] = 1,
        n_id: Union[int, List[int]] = 1,
        channel_type: str = "PUSCH",
        codeword_index: int = 0,
        use_scrambler: bool = True,
        verbose: bool = False,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(use_scrambler, bool):
            raise TypeError("use_scrambler must be bool.")
        self._use_scrambler = use_scrambler

        if not isinstance(verbose, bool):
            raise TypeError("verbose must be bool.")
        self._verbose = verbose

        if channel_type not in ("PDSCH", "PUSCH"):
            raise ValueError("channel_type must be 'PDSCH' or 'PUSCH'.")
        self._channel_type = channel_type

        if target_tb_size % 1 != 0:
            raise ValueError("target_tb_size must be int.")
        self._target_tb_size = int(target_tb_size)

        if num_coded_bits % 1 != 0:
            raise ValueError("num_coded_bits must be int.")
        self._num_coded_bits = int(num_coded_bits)

        if not (0. < target_coderate <= 948 / 1024):
            raise ValueError("target_coderate must be in range (0, 0.925].")
        self._target_coderate = target_coderate

        if num_bits_per_symbol % 1 != 0:
            raise ValueError("num_bits_per_symbol must be int.")
        self._num_bits_per_symbol = int(num_bits_per_symbol)

        if num_layers % 1 != 0:
            raise ValueError("num_layers must be int.")
        self._num_layers = int(num_layers)

        if channel_type == "PDSCH":
            if codeword_index not in (0, 1):
                raise ValueError("codeword_index must be 0 or 1.")
        elif codeword_index != 0:
            raise ValueError('codeword_index must be 0 for "PUSCH".')
        self._codeword_index = int(codeword_index)

        # One (n_rnti, n_id) pair per independent input stream
        self._n_rnti, self._n_id = self._parse_scrambling_ids(n_rnti, n_id)
        self._num_tx = len(self._n_id)

        # Calculate TB parameters
        tbconfig = calculate_tb_size(
            target_tb_size=self._target_tb_size,
            num_coded_bits=self._num_coded_bits,
            target_coderate=self._target_coderate,
            modulation_order=self._num_bits_per_symbol,
            num_layers=self._num_layers,
            verbose=verbose,
        )
        self._tb_size = int(tbconfig[0])
        self._cb_size = int(tbconfig[1])
        self._num_cbs = int(tbconfig[2])
        self._tb_crc_length = int(tbconfig[3])
        self._cb_crc_length = int(tbconfig[4])

        # calculate_tb_size may return cw_lengths as a tensor; keep a numpy
        # copy, flattened to 1-D in case it has extra dims (e.g. [1, num_cb]).
        cw_lengths = tbconfig[5]
        if isinstance(cw_lengths, torch.Tensor):
            cw_lengths = cw_lengths.cpu().numpy()
        self._cw_lengths = np.asarray(cw_lengths).flatten()

        # calculate_tb_size pads cw_lengths with zeros. Padded entries are not
        # codewords and must not enter the statistics below; otherwise the
        # minimum becomes 0 and shorter codewords are rejected when building
        # the interleaver. Cached as Python ints for torch.compile.
        active_lengths = self._cw_lengths[self._cw_lengths > 0]
        if active_lengths.size == 0:
            raise ValueError("Invalid cw_lengths.")
        self._cw_lengths_sum = int(np.sum(active_lengths))
        self._cw_lengths_max = int(np.max(active_lengths))
        self._cw_lengths_min = int(np.min(active_lengths))

        if self._tb_size > self._tb_crc_length + self._cw_lengths_sum:
            raise ValueError("Invalid TB parameters.")

        # Positive: zero bits appended to the input; negative: trailing input
        # bits dropped (see call()).
        self._k_padding = self._tb_size - self._target_tb_size
        if self._k_padding > 0:
            print(f"Note: actual tb_size={self._tb_size} is slightly "
                  f"different than requested target_tb_size="
                  f"{self._target_tb_size} due to quantization. "
                  f"Internal zero padding will be applied.")
        elif self._k_padding < 0:
            print(f"Note: actual tb_size={self._tb_size} is slightly "
                  f"different than requested target_tb_size="
                  f"{self._target_tb_size} due to quantization. "
                  f"The last {-self._k_padding} input bits will be "
                  f"discarded.")

        # Effective coderate
        self._coderate = self._tb_size / self._num_coded_bits

        # CRC encoders with pre-built generator matrices
        # (the k parameter enables torch.compile compatibility)
        tb_crc_poly = "CRC16" if self._tb_crc_length == 16 else "CRC24A"
        self._tb_crc_encoder = CRCEncoder(
            tb_crc_poly, k=self._tb_size, precision=precision, device=device)

        if self._cb_crc_length == 24:
            cb_info_size = self._cb_size - self._cb_crc_length
            self._cb_crc_encoder = CRCEncoder(
                "CRC24B", k=cb_info_size, precision=precision, device=device)
        else:
            self._cb_crc_encoder = None

        # Scrambler
        if self._use_scrambler:
            self._scrambler = TB5GScrambler(
                n_rnti=self._n_rnti,
                n_id=self._n_id,
                binary=True,
                channel_type=channel_type,
                codeword_index=codeword_index,
                precision=precision,
                device=device,
            )
        else:
            self._scrambler = None

        # LDPC encoder. Its own interleaver is disabled because the output
        # interleaver is applied after CB concatenation instead.
        self._encoder = LDPC5GEncoder(
            self._cb_size,
            self._cw_lengths_max,
            num_bits_per_symbol=1,
            precision=precision,
            device=device,
        )

        self._build_output_interleaver(num_bits_per_symbol)

    @staticmethod
    def _parse_scrambling_ids(
        n_rnti: Union[int, List[int]],
        n_id: Union[int, List[int]],
    ) -> Tuple[List[int], List[int]]:
        """Normalize ``n_rnti`` / ``n_id`` to equally long, non-empty int lists.

        :raises ValueError: If only one of both is a list, the lists differ in
            length or are empty, or an entry is not integer-valued.
        """
        rnti_is_seq = isinstance(n_rnti, (list, tuple))
        id_is_seq = isinstance(n_id, (list, tuple))
        if rnti_is_seq:
            if not id_is_seq:
                raise ValueError("n_id must also be a list.")
            if len(n_rnti) != len(n_id):
                raise ValueError("n_id and n_rnti must be of same length.")
            if len(n_rnti) == 0:
                raise ValueError("n_rnti and n_id must not be empty.")
            rntis, ids = list(n_rnti), list(n_id)
        else:
            if id_is_seq:
                raise ValueError("n_rnti must also be a list.")
            rntis, ids = [n_rnti], [n_id]

        for n in rntis:
            if n % 1 != 0:
                raise ValueError("n_rnti must be int.")
        for n in ids:
            if n % 1 != 0:
                raise ValueError("n_id must be int.")
        return [int(n) for n in rntis], [int(n) for n in ids]

    def _build_output_interleaver(self, num_bits_per_symbol: int) -> None:
        """Build the output interleaver and register it and its inverse.

        After CB concatenation every code block occupies ``cw_lengths_max``
        positions. Codewords of length ``cw_lengths_min`` are interleaved with
        the short pattern. Their surplus positions
        ``[min, max)`` are collected and moved to the end of the permutation,
        so the final truncation to ``cw_lengths_sum`` removes them.
        """
        perm_short, _ = self._encoder.generate_out_int(
            self._cw_lengths_min, num_bits_per_symbol)
        perm_long, _ = self._encoder.generate_out_int(
            self._cw_lengths_max, num_bits_per_symbol)

        kept_segments = []
        punctured_segments = []
        offset = 0
        for length in self._cw_lengths:
            length = int(length)
            if length == 0:
                # Zero-padded entry, not an actual codeword
                continue
            if length == self._cw_lengths_min:
                kept_segments.append(perm_short + offset)
                punctured_segments.append(
                    np.arange(offset + self._cw_lengths_min,
                              offset + self._cw_lengths_max))
                offset += self._cw_lengths_max
            elif length == self._cw_lengths_max:
                kept_segments.append(perm_long + offset)
                offset += length
            else:
                raise ValueError("Invalid cw_lengths.")

        perm = torch.tensor(
            np.concatenate(kept_segments + punctured_segments),
            dtype=torch.long, device=self.device)
        # Registered as buffers for CUDAGraph compatibility
        self.register_buffer("_output_perm", perm)
        self.register_buffer("_output_perm_inv", torch.argsort(perm))

    #########################################
    # Public methods and properties
    #########################################

    @property
    def tb_size(self) -> int:
        r"""Effective number of information bits per TB.
        Note that (if required) internal zero padding can be applied to match
        the requested exact ``target_tb_size``."""
        return self._tb_size

    @property
    def k(self) -> int:
        r"""Number of input information bits.
        Equals ``tb_size`` except for zero padding of the last positions if
        the ``target_tb_size`` is quantized."""
        return self._target_tb_size

    @property
    def k_padding(self) -> int:
        """Number of zero padded bits at the end of the TB.
        A negative value means that this many trailing input bits are
        discarded."""
        return self._k_padding

    @property
    def n(self) -> int:
        """Total number of output bits."""
        return self._num_coded_bits

    @property
    def num_cbs(self) -> int:
        """Number of code blocks."""
        return self._num_cbs

    @property
    def coderate(self) -> float:
        """Effective coderate of the TB after rate-matching including overhead
        for the CRC."""
        return self._coderate

    @property
    def ldpc_encoder(self) -> LDPC5GEncoder:
        """LDPC encoder used for TB encoding."""
        return self._encoder

    @property
    def scrambler(self) -> Optional[TB5GScrambler]:
        """Scrambler used for TB scrambling. `None` if no scrambler is used."""
        return self._scrambler

    @property
    def tb_crc_encoder(self) -> CRCEncoder:
        """TB CRC encoder."""
        return self._tb_crc_encoder

    @property
    def cb_crc_encoder(self) -> Optional[CRCEncoder]:
        """CB CRC encoder. `None` if no CB CRC is applied."""
        return self._cb_crc_encoder

    @property
    def num_tx(self) -> int:
        """Number of independent streams."""
        return self._num_tx

    @property
    def cw_lengths(self) -> np.ndarray:
        r"""Each list element defines the codeword length of each of the
        codewords after LDPC encoding and rate-matching. The total number of
        coded bits is :math:`\sum` ``cw_lengths``. Zero entries (if any) are
        padding and do not correspond to codewords."""
        return self._cw_lengths

    @property
    def cw_lengths_sum(self) -> int:
        """Sum of codeword lengths (cached for torch.compile compatibility)."""
        return self._cw_lengths_sum

    @property
    def cw_lengths_max(self) -> int:
        """Maximum codeword length (cached for torch.compile compatibility)."""
        return self._cw_lengths_max

    @property
    def output_perm_inv(self) -> torch.Tensor:
        """Inverse interleaver pattern for output bit interleaver."""
        return self._output_perm_inv

    def build(self, input_shape: tuple) -> None:
        """Test input shapes for consistency.

        :param input_shape: Shape of the information bits, ``[..., k]`` or
            ``[..., num_tx, k]`` if multiple streams are configured.
        :raises ValueError: If the last dimension differs from ``k`` or, for
            ``num_tx > 1``, the second-to-last dimension differs from
            ``num_tx``.
        """
        if input_shape[-1] != self.k:
            raise ValueError(
                f"Invalid input shape. Expected TB length is {self.k}.")
        # Without this check, a batch dimension divisible by num_tx would be
        # silently reshaped into the stream dimension in call().
        if self._num_tx > 1 and (len(input_shape) < 2
                                 or input_shape[-2] != self._num_tx):
            raise ValueError(
                f"Invalid input shape. Expected shape "
                f"[..., {self._num_tx}, {self.k}] for "
                f"num_tx={self._num_tx} streams.")

    def call(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply transport block encoding procedure.

        :param inputs: Information bits of shape ``[..., k]`` or
            ``[..., num_tx, k]``.
        :returns: Encoded TB bits with the last dimension replaced by
            ``cw_lengths_sum``, cast to ``self.dtype``.
        """
        input_shape = list(inputs.shape)
        # Use the block precision so the sub-blocks, which are built with the
        # same precision, receive matching dtypes.
        u = inputs.to(self.dtype)

        # Match tb_size: zero pad if tb_size > k, truncate if tb_size < k
        if self._k_padding > 0:
            padding = torch.zeros([*u.shape[:-1], self._k_padding],
                                  dtype=u.dtype, device=u.device)
            u = torch.cat([u, padding], dim=-1)
        elif self._k_padding < 0:
            u = u[..., :self._tb_size]

        # TB CRC
        u_crc = self._tb_crc_encoder(u)

        # CB segmentation: [batch, num_tx, num_cbs, info bits per CB]
        u_cb = u_crc.reshape(-1, self._num_tx, self._num_cbs,
                             self._cb_size - self._cb_crc_length)

        # CB CRC, only if calculate_tb_size assigned a 24-bit CB CRC
        if self._cb_crc_length == 24:
            u_cb_crc = self._cb_crc_encoder(u_cb)
        else:
            u_cb_crc = u_cb

        # LDPC encoding; every CB is rate-matched to cw_lengths_max bits
        c_cb = self._encoder(u_cb_crc)

        # CB concatenation
        c = c_cb.reshape(-1, self._num_tx,
                         self._num_cbs * self._cw_lengths_max)

        # Output interleaver. The surplus bits of shorter codewords are
        # placed at the end and removed by the puncturing below.
        c = torch.index_select(c, -1, self._output_perm)
        c = c[..., :self._cw_lengths_sum]

        # Scrambling (disabling it is not standard-compliant)
        c_scr = self._scrambler(c) if self._use_scrambler else c

        c_scr = c_scr.to(self.dtype)

        # Restore the leading input dimensions
        output_shape = input_shape.copy()
        output_shape[-1] = self._cw_lengths_sum
        return c_scr.reshape(output_shape)