Source code for sionna.phy.nr.utils

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions for 5G NR physical layer processing."""

from typing import Optional, Tuple, Union
import numpy as np
import torch

from sionna.phy import config
from sionna.phy.utils import (
    MCSDecoder,
    TransportBlock,
    SingleLinkChannel,
    ebnodb2no,
)
from sionna.phy.channel import AWGN
from sionna.phy.mapping import Mapper, Demapper, Constellation, BinarySource


# Public names of this module, i.e., what ``from <module> import *`` exports.
__all__ = [
    "generate_prng_seq",
    "decode_mcs_index",
    "calculate_num_coded_bits",
    "calculate_tb_size",
    "MCSDecoderNR",
    "TransportBlockNR",
    "CodedAWGNChannelNR",
]


def generate_prng_seq(length: int, c_init: int) -> np.ndarray:
    r"""Generate the length-31 Gold pseudo-random sequence of 5G NR.

    Implements the pseudo-random sequence generator defined in Sec. 5.2.1 of
    :cite:p:`3GPPTS38211`.

    :param length: Desired output sequence length. Must be a positive
        integer (integral floats are accepted and converted).
    :param c_init: Initialization value of the second shift register.
        Must be an integer in the range of 0 to :math:`2^{32}-1`. Only its
        31 least significant bits initialize the register.

    :output seq: [``length``], `ndarray` of 0s and 1s (floating point).
        The scrambling sequence.

    :raises ValueError: If ``length`` is not a positive integer or
        ``c_init`` is not an integer in :math:`[0, 2^{32}-1]`.

    .. rubric:: Notes

    The initialization sequence ``c_init`` is application specific and is
    usually provided by higher layer protocols.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import generate_prng_seq

        seq = generate_prng_seq(100, 12345)
        print(seq.shape)  # (100,)
    """
    # --- Input validation ---
    if length % 1 != 0:
        raise ValueError("length must be a positive integer.")
    length = int(length)
    if length <= 0:
        raise ValueError("length must be a positive integer.")

    if c_init % 1 != 0:
        raise ValueError("c_init must be integer.")
    c_init = int(c_init)
    if not 0 <= c_init < 2**32:
        raise ValueError("c_init must be in [0, 2^32-1].")

    register_len = 31  # Length of the Gold sequence registers
    offset = 1600      # N_c as defined in 5.2.1 in 38.211
    num_steps = length + offset

    # Both m-sequences, including the initial register contents
    seq_a = np.zeros(num_steps + register_len)
    seq_b = np.zeros(num_steps + register_len)

    # First register starts as [1, 0, ..., 0]; the second one holds the
    # 31 least significant bits of c_init, LSB first
    seq_a[0] = 1
    msb_first = format(c_init, f"0{register_len}b")[-register_len:]
    seq_b[:register_len] = [int(bit) for bit in reversed(msb_first)]

    # The recurrences are inherently sequential
    for n in range(num_steps):
        seq_a[n + 31] = (seq_a[n + 3] + seq_a[n]) % 2
        seq_b[n + 31] = (
            seq_b[n + 3] + seq_b[n + 2] + seq_b[n + 1] + seq_b[n]
        ) % 2

    # Output is the modulo-2 sum of both sequences after skipping N_c samples
    return np.mod(seq_a[offset:num_steps] + seq_b[offset:num_steps], 2)
def decode_mcs_index(
    mcs_index: Union[int, torch.Tensor],
    table_index: Union[int, torch.Tensor] = 1,
    is_pusch: Union[bool, torch.Tensor] = True,
    transform_precoding: Union[bool, torch.Tensor] = False,
    pi2bpsk: Union[bool, torch.Tensor] = False,
    check_index_validity: bool = True,
    verbose: bool = False,
    device: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Returns the modulation order and target coderate for a given MCS index.

    Implements MCS tables as defined in :cite:p:`3GPPTS38214` for PUSCH and
    PDSCH.

    :param mcs_index: MCS index (denoted as :math:`I_{MCS}` in
        :cite:p:`3GPPTS38214`). Accepted values are ``{0,1,...28}``.
    :param table_index: MCS table index from :cite:p:`3GPPTS38214`.
        Accepted values are ``{1,2,3,4}``.
    :param is_pusch: Specifies whether the 5G NR physical channel is of
        "PUSCH" type. If `False`, then the "PDSCH" channel is considered.
    :param transform_precoding: Specifies whether the MCS tables described
        in Sec. 6.1.4.1 of :cite:p:`3GPPTS38214` are applied. Only relevant
        for "PUSCH".
    :param pi2bpsk: Specifies whether the higher-layer parameter
        `tp-pi2BPSK` described in Sec. 6.1.4.1 of :cite:p:`3GPPTS38214` is
        applied. Only relevant for "PUSCH".
    :param check_index_validity: If `True`, the inputs are validated with
        ``assert`` statements, i.e., an `AssertionError` is raised if the
        input MCS indices are not valid for the given configuration.
    :param verbose: If `True`, additional information is printed.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :output modulation_order: [...], `torch.int32`.
        Modulation order, i.e., number of bits per symbol, associated with
        the input MCS index.
    :output target_rate: [...], `torch.float32`.
        Target coderate associated with the input MCS index.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import decode_mcs_index
        import torch

        # Scalar input
        mod_order, rate = decode_mcs_index(14, table_index=1)
        print(f"Modulation order: {mod_order.item()}, Target rate: {rate.item():.3f}")

        # Tensor input
        mcs_indices = torch.tensor([10, 14, 20])
        mod_orders, rates = decode_mcs_index(mcs_indices, table_index=1)
    """
    from sionna.phy.utils import scalar_to_shaped_tensor

    if device is None:
        device = config.device

    # Convert mcs_index to tensor
    if isinstance(mcs_index, (int, float)):
        mcs_index = torch.tensor(mcs_index, dtype=torch.int32, device=device)
    else:
        mcs_index = mcs_index.to(dtype=torch.int32, device=device)
    shape = list(mcs_index.shape)

    # Cast and reshape inputs to match mcs_index shape
    table_index = scalar_to_shaped_tensor(table_index, torch.int32, shape, device)
    is_pusch = scalar_to_shaped_tensor(is_pusch, torch.bool, shape, device)
    transform_precoding = scalar_to_shaped_tensor(
        transform_precoding, torch.bool, shape, device
    )
    pi2bpsk = scalar_to_shaped_tensor(pi2bpsk, torch.bool, shape, device)

    # Input validation
    if check_index_validity:
        assert (mcs_index >= 0).all(), "MCS index cannot be negative"
        assert (mcs_index <= 28).all(), "MCS index cannot be higher than 28"
        valid_tables = (table_index >= 1) & (table_index <= 4)
        assert valid_tables.all(), "table_index must contain values in [1,2,3,4]"

    # Modulation orders lookup table
    # [2, 4, 29]: [channel_type, table_index, mcs_index]
    # Entries equal to -1 mark (channel, table, index) combinations that
    # have no valid modulation order.
    mod_orders = torch.tensor(
        [
            [  # PUSCH with transform_precoding
                # Table 1 (q=1)
                [1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 6, 6, 6,
                 6, 6, 6, 6, 6, 6, 6, 6, -1],
                # Table 2 (q=1)
                [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4,
                 4, 4, 4, 4, 6, 6, 6, 6, -1],
                # Table 3 (dummy)
                [-1] * 29,
                # Table 4 (dummy)
                [-1] * 29,
            ],
            [  # PDSCH | transform_precoding is False
                # Table 1
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 6, 6, 6,
                 6, 6, 6, 6, 6, 6, 6, 6, 6],
                # Table 2
                [2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                 8, 8, 8, 8, 8, 8, 8, 8, -1],
                # Table 3
                [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4,
                 4, 6, 6, 6, 6, 6, 6, 6, 6],
                # Table 4
                [2, 2, 2, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8,
                 8, 8, 8, 10, 10, 10, 10, -1, -1],
            ],
        ],
        dtype=torch.int32,
        device=device,
    )

    # Target rates lookup table (x1024), same layout as ``mod_orders``
    target_rates = torch.tensor(
        [
            [  # PUSCH with transform_precoding
                # Table 1 (q=1)
                [240, 314, 193, 251, 308, 379, 449, 526, 602, 679, 340, 378,
                 434, 490, 553, 616, 658, 466, 517, 567, 616, 666, 719, 772,
                 822, 873, 910, 948, -1],
                # Table 2 (q=1)
                [60, 80, 100, 128, 156, 198, 120, 157, 193, 251, 308, 379,
                 449, 526, 602, 679, 378, 434, 490, 553, 616, 658, 699, 772,
                 567, 616, 666, 772, -1],
                # Table 3 (dummy)
                [-1] * 29,
                # Table 4 (dummy)
                [-1] * 29,
            ],
            [  # PDSCH | transform_precoding is False
                # Table 1
                [120, 157, 193, 251, 308, 379, 449, 526, 602, 679, 340, 378,
                 434, 490, 553, 616, 658, 438, 466, 517, 567, 616, 666, 719,
                 772, 822, 873, 910, 948],
                # Table 2
                [120, 193, 308, 449, 602, 378, 434, 490, 553, 616, 658, 466,
                 517, 567, 616, 666, 719, 772, 822, 873, 682.5, 711, 754,
                 797, 841, 885, 916.5, 948, -1],
                # Table 3
                [30, 40, 50, 64, 78, 99, 120, 157, 193, 251, 308, 379, 449,
                 526, 602, 340, 378, 434, 490, 553, 616, 438, 466, 517, 567,
                 616, 666, 719, 772],
                # Table 4
                [120, 193, 449, 378, 490, 616, 466, 517, 567, 616, 666, 719,
                 772, 822, 873, 682.5, 711, 754, 797, 841, 885, 916.5, 948,
                 805.5, 853, 900.5, 948, -1, -1],
            ],
        ],
        dtype=torch.float32,
        device=device,
    )

    # Compute channel type index: 0 for PUSCH with transform_precoding,
    # 1 for PDSCH or no transform_precoding
    channel_type_idx = (~is_pusch | ~transform_precoding).to(torch.int32)

    # Flatten inputs for indexing, then reshape back
    orig_shape = mcs_index.shape
    mcs_flat = mcs_index.flatten()
    table_flat = (table_index - 1).flatten()
    channel_flat = channel_type_idx.flatten()

    # Use advanced indexing for batched lookup
    mod_orders_sel = mod_orders[channel_flat, table_flat, mcs_flat]
    target_rates_sel = target_rates[channel_flat, table_flat, mcs_flat]

    # Reshape back to original shape
    # NOTE(review): for a 0-dim ``mcs_index`` the flattened (1-D) lookup
    # result is not reshaped, so scalar inputs appear to produce outputs
    # with one dimension rather than zero - confirm whether this is intended.
    if len(orig_shape) > 0:
        mod_orders_sel = mod_orders_sel.reshape(orig_shape)
        target_rates_sel = target_rates_sel.reshape(orig_shape)

    # Check that the selected indices are valid
    if check_index_validity:
        assert (mod_orders_sel >= 0).all(), "Invalid MCS index"

    #######################
    # Account for pi2BPSK #
    #######################
    # The tables above store the q-dependent entries for q=1. The actual
    # value is q=1 if pi2bpsk is enabled and q=2 otherwise.
    # torch.where with two Python integers yields an int64 tensor; cast it to
    # int32 so that the product below does not promote the modulation order
    # away from the documented torch.int32.
    q = torch.where(pi2bpsk, 1, 2).to(torch.int32)

    # Condition: channel_type == 0 AND ((table_index == 1 AND mcs < 2) OR
    # (table_index == 2 AND mcs < 6))
    needs_q_correction = (
        (channel_type_idx == 0)
        & (((table_index == 1) & (mcs_index < 2))
           | ((table_index == 2) & (mcs_index < 6)))
    )

    # Apply correction where needed: order scales with q, rate with 1/q
    mod_orders_sel = torch.where(
        needs_q_correction, mod_orders_sel * q, mod_orders_sel
    )
    target_rates_sel = torch.where(
        needs_q_correction,
        target_rates_sel / q.to(target_rates_sel.dtype),
        target_rates_sel,
    )

    # Convert target rate from x1024 to actual rate
    target_rates_sel = target_rates_sel / 1024.0

    if verbose:
        print(f"Modulation order: {mod_orders_sel}")
        print(f"Target code rate: {target_rates_sel}")

    return mod_orders_sel, target_rates_sel
def calculate_num_coded_bits(
    modulation_order: int,
    num_prbs: int,
    num_ofdm_symbols: int,
    num_dmrs_per_prb: int,
    num_layers: int = 1,
    num_ov: int = 0,
    tb_scaling: float = 1.0,
) -> int:
    r"""Computes how many coded bits fit in a slot for a given resource
    grid structure.

    :param modulation_order: Modulation order, i.e., number of bits per QAM
        symbol.
    :param num_prbs: Total number of allocated PRBs per OFDM symbol, where
        1 PRB equals 12 subcarriers. Must be in [1, 275].
    :param num_ofdm_symbols: Number of OFDM symbols allocated for
        transmission. Must be in [1, 14].
    :param num_dmrs_per_prb: Number of DMRS (i.e., pilot) symbols per PRB
        that are `not` used for data transmission, across all
        ``num_ofdm_symbols`` OFDM symbols.
    :param num_layers: Number of MIMO layers.
    :param num_ov: Number of unused resource elements due to additional
        overhead as specified by higher layer.
    :param tb_scaling: TB scaling factor for PDSCH as defined in TS 38.214
        Tab. 5.1.3.2-2. Must be one of {0.25, 0.5, 1.0}.

    :output num_coded_bits: `int`.
        Number of coded bits that can be fit into a given slot for the
        given configuration (truncated towards zero).

    :raises ValueError: If ``num_ofdm_symbols``, ``num_prbs`` or
        ``tb_scaling`` is outside its allowed set of values.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import calculate_num_coded_bits

        num_bits = calculate_num_coded_bits(4, 50, 14, 12, 2)
        print(num_bits)
    """
    # Range checks, in the order: OFDM symbols, PRBs, TB scaling
    range_limits = (
        ("num_ofdm_symbols", num_ofdm_symbols, 14),
        ("num_prbs", num_prbs, 275),
    )
    for label, value, upper in range_limits:
        if value < 1 or value > upper:
            raise ValueError(f"{label} must be in [1, {upper}]")
    if tb_scaling not in (0.25, 0.5, 1.0):
        raise ValueError("tb_scaling must be 0.25, 0.5, or 1.0")

    # Resource elements per PRB left for data after removing DMRS/overhead
    data_re_per_prb = 12 * num_ofdm_symbols - num_dmrs_per_prb - num_ov

    # Max REs per PRB is limited to 156 in 38.214
    max_re_per_prb = 156
    if data_re_per_prb > max_re_per_prb:
        data_re_per_prb = max_re_per_prb

    scaled_bits = (
        tb_scaling * data_re_per_prb * num_prbs * modulation_order * num_layers
    )
    return int(scaled_bits)
def calculate_tb_size(
    modulation_order: Union[int, torch.Tensor],
    target_coderate: Union[float, torch.Tensor],
    target_tb_size: Optional[Union[int, float, torch.Tensor]] = None,
    num_coded_bits: Optional[Union[int, torch.Tensor]] = None,
    num_prbs: Optional[Union[int, torch.Tensor]] = None,
    num_ofdm_symbols: Optional[Union[int, torch.Tensor]] = None,
    num_dmrs_per_prb: Optional[Union[int, torch.Tensor]] = None,
    num_layers: Union[int, torch.Tensor] = 1,
    num_ov: Union[int, torch.Tensor] = 0,
    tb_scaling: Union[float, torch.Tensor] = 1.0,
    return_cw_length: bool = True,
    verbose: bool = False,
    device: Optional[str] = None,
) -> Tuple:
    r"""Calculates the transport block (TB) size for given system parameters.

    This function follows the procedure defined in TS 38.214 Sec. 5.1.3.2
    and Sec. 6.1.4.2 :cite:p:`3GPPTS38214`.

    :param modulation_order: Modulation order, i.e., number of bits per QAM
        symbol. Its shape determines the shape to which scalar inputs are
        expanded.
    :param target_coderate: Target coderate.
    :param target_tb_size: Target transport block size, i.e., number of
        information bits that can be encoded into a slot for the given slot
        configuration. If `None`, it is computed as
        ``target_coderate * num_coded_bits``.
    :param num_coded_bits: Number of coded bits that can be fit into a given
        slot. If provided, ``num_prbs``, ``num_ofdm_symbols``,
        ``num_dmrs_per_prb``, ``num_ov`` and ``tb_scaling`` are not used.
        If `None`, it is computed via
        :func:`~sionna.phy.nr.utils.calculate_num_coded_bits`; in that case
        tensor arguments are converted with ``.item()`` and must therefore
        hold a single element.
    :param num_prbs: Total number of allocated PRBs per OFDM symbol, where
        1 PRB equals 12 subcarriers. Must not exceed 275.
    :param num_ofdm_symbols: Number of OFDM symbols allocated for
        transmission. Cannot be larger than 14.
    :param num_dmrs_per_prb: Number of DMRS (i.e., pilot) symbols per PRB
        that are `not` used for data transmission, across all
        ``num_ofdm_symbols`` OFDM symbols.
    :param num_layers: Number of MIMO layers.
    :param num_ov: Number of unused resource elements due to additional
        overhead as specified by higher layer.
    :param tb_scaling: TB scaling factor for PDSCH as defined in TS 38.214
        Tab. 5.1.3.2-2.
    :param return_cw_length: If `True`, the function returns ``tb_size``,
        ``cb_size``, ``num_cb``, ``tb_crc_length``, ``cb_crc_length``,
        ``cw_length``. Otherwise, it does not return ``cw_length`` to reduce
        computation time.
    :param verbose: If `True`, additional information is printed.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :output tb_size: [...], `torch.int32`.
        Transport block (TB) size, i.e., how many information bits can be
        encoded into a slot for the given slot configuration.
    :output cb_size: [...], `torch.int32`.
        Code block (CB) size, i.e., the number of information bits per
        codeword, including the TB/CB CRC parity bits.
    :output num_cb: [...], `torch.int32`.
        Number of CBs that the TB is segmented into.
    :output tb_crc_length: [...], `torch.int32`.
        Length of the TB CRC.
    :output cb_crc_length: [...], `torch.int32`.
        Length of each CB CRC.
    :output cw_length: [..., N], `torch.int32`.
        Codeword length of each of the ``num_cb`` codewords after LDPC
        encoding and rate-matching. ``N`` is the largest ``num_cb`` across
        all entries. Note that zeros are appended along the last axis to
        obtain a dense tensor. The total number of coded bits,
        ``num_coded_bits``, is the sum of ``cw_length`` across its last
        axis. Only returned if ``return_cw_length`` is `True`.

    :raises AssertionError: If ``num_coded_bits`` is `None` and any of
        ``num_prbs``, ``num_ofdm_symbols``, ``num_dmrs_per_prb`` is `None`.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import calculate_tb_size

        tb_size, cb_size, num_cb, tb_crc, cb_crc, cw_len = calculate_tb_size(
            modulation_order=4,
            target_coderate=0.5,
            num_coded_bits=4800,
            num_layers=1
        )
        print(f"TB size: {tb_size}, CB size: {cb_size}, Num CBs: {num_cb}")
    """
    from sionna.phy.utils import scalar_to_shaped_tensor

    if device is None:
        device = config.device

    # Convert modulation_order to tensor
    if isinstance(modulation_order, (int, float)):
        modulation_order = torch.tensor(modulation_order, dtype=torch.int32, device=device)
    else:
        modulation_order = modulation_order.to(dtype=torch.int32, device=device)
    # Reference shape to which Python scalars are expanded below
    shape = list(modulation_order.shape)

    # Convert target_coderate to tensor
    if isinstance(target_coderate, (int, float)):
        target_coderate = torch.tensor(target_coderate, dtype=torch.float32, device=device)
        if shape:
            target_coderate = target_coderate.expand(shape)
    else:
        target_coderate = target_coderate.to(dtype=torch.float32, device=device)

    # Broadcast scalars
    num_layers = scalar_to_shaped_tensor(num_layers, torch.int32, shape, device)
    tb_scaling = scalar_to_shaped_tensor(tb_scaling, torch.float32, shape, device)

    # ---------------#
    # N. coded bits  #
    # ---------------#
    if num_coded_bits is not None:
        if isinstance(num_coded_bits, (int, float)):
            num_coded_bits = torch.tensor(num_coded_bits, dtype=torch.int32, device=device)
            if shape:
                num_coded_bits = num_coded_bits.expand(shape)
        else:
            num_coded_bits = num_coded_bits.to(dtype=torch.int32, device=device)
    else:
        assert num_prbs is not None and num_ofdm_symbols is not None and num_dmrs_per_prb is not None, \
            "If num_coded_bits is None then num_prbs, num_ofdm_symbols, num_dmrs_per_prb must be specified."
        # calculate_num_coded_bits returns int, need to convert to tensor
        # Tensor arguments are reduced to Python numbers with .item(), so
        # this path only works for single-element tensors.
        num_coded_bits_val = calculate_num_coded_bits(
            modulation_order.item() if isinstance(modulation_order, torch.Tensor) else modulation_order,
            num_prbs.item() if isinstance(num_prbs, torch.Tensor) else num_prbs,
            num_ofdm_symbols.item() if isinstance(num_ofdm_symbols, torch.Tensor) else num_ofdm_symbols,
            num_dmrs_per_prb.item() if isinstance(num_dmrs_per_prb, torch.Tensor) else num_dmrs_per_prb,
            num_layers.item() if isinstance(num_layers, torch.Tensor) else num_layers,
            num_ov.item() if isinstance(num_ov, torch.Tensor) else num_ov,
            tb_scaling.item() if isinstance(tb_scaling, torch.Tensor) else tb_scaling)
        num_coded_bits = torch.tensor(num_coded_bits_val, dtype=torch.int32, device=device)
        if shape:
            num_coded_bits = num_coded_bits.expand(shape)

    # --------------#
    # Target TB size #
    # --------------#
    if target_tb_size is not None:
        if isinstance(target_tb_size, (int, float)):
            target_tb_size = torch.tensor(target_tb_size, dtype=torch.float32, device=device)
            if shape:
                target_tb_size = target_tb_size.expand(shape)
        else:
            target_tb_size = target_tb_size.to(dtype=torch.float32, device=device)
    else:
        target_tb_size = target_coderate * num_coded_bits.to(torch.float32)

    # -----------------------------#
    # Quantized n. information bits #
    # -----------------------------#
    # Both branches are evaluated for every entry; torch.where selects the
    # applicable one. The clamps keep log2 well-defined for the branch that
    # is not selected.
    # For target_tb_size <= 3824
    log2_n_info = torch.log2(torch.clamp(target_tb_size, min=1.0))
    n_small = torch.clamp(torch.floor(log2_n_info) - 6, min=3.0)
    n_info_q_small = torch.clamp(
        2**n_small * torch.floor(target_tb_size / 2**n_small), min=24.0
    )
    # For target_tb_size > 3824
    log2_n_info_minus_24 = torch.log2(torch.clamp(target_tb_size - 24, min=1.0))
    n_large = torch.floor(log2_n_info_minus_24) - 5.0
    n_info_q_large = torch.clamp(
        2**n_large * torch.round((target_tb_size - 24) / 2**n_large), min=3840.0
    )
    n_info_q = torch.where(target_tb_size <= 3824, n_info_q_small, n_info_q_large)

    # -----------------#
    # N. of code blocks #
    # -----------------#
    # Case 1: n_info_q <= 3824
    num_cb_case1 = torch.ones_like(n_info_q, dtype=torch.int32)
    # Case 2: target_coderate <= 0.25
    num_cb_case2 = torch.ceil((n_info_q + 24) / 3816).to(torch.int32)
    # Case 3: n_info_q > 8424
    num_cb_case3 = torch.ceil((n_info_q + 24) / 8424).to(torch.int32)
    # Case 4: else
    num_cb_case4 = torch.ones_like(n_info_q, dtype=torch.int32)
    # Nested selection: the cases are tested in the order 1, 2, 3, 4
    num_cb = torch.where(
        n_info_q <= 3824,
        num_cb_case1,
        torch.where(
            target_coderate <= 0.25,
            num_cb_case2,
            torch.where(n_info_q > 8424, num_cb_case3, num_cb_case4)
        )
    )

    # ----------------------#
    # TB size (n. info bits) #
    # ----------------------#
    # Table 5.1.3.2-1 from 38.214
    # The leading -1 is a placeholder so that the first TB size (24) sits
    # at position 1.
    tab51321 = torch.tensor([
        -1, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136,
        144, 152, 160, 168, 176, 184, 192, 208, 224, 240, 256, 272, 288, 304,
        320, 336, 352, 368, 384, 408, 432, 456, 480, 504, 528, 552, 576, 608,
        640, 672, 704, 736, 768, 808, 848, 888, 928, 984, 1032, 1064, 1128,
        1160, 1192, 1224, 1256, 1288, 1320, 1352, 1416, 1480, 1544, 1608,
        1672, 1736, 1800, 1864, 1928, 2024, 2088, 2152, 2216, 2280, 2408,
        2472, 2536, 2600, 2664, 2728, 2792, 2856, 2976, 3104, 3240, 3368,
        3496, 3624, 3752, 3824], dtype=torch.int32, device=device)

    # For n_info_q <= 3824: find smallest TB size >= n_info_q using searchsorted
    idx = torch.searchsorted(tab51321, n_info_q.to(torch.int32))
    # Entries with n_info_q > 3824 would index past the table; clamp them
    # (their table value is discarded by the torch.where below)
    idx = torch.clamp(idx, max=len(tab51321) - 1)
    tb_size_small = tab51321[idx]

    # For n_info_q > 3824: Step 5 of 38.214 5.1.3.2
    tb_size_large = (8 * num_cb * torch.ceil((n_info_q + 24) / (8 * num_cb.to(torch.float32))) - 24).to(torch.int32)

    tb_size = torch.where(n_info_q <= 3824, tb_size_small, tb_size_large)

    # ----------------#
    # TB/CB CRC length #
    # ----------------#
    # TB CRC: 24 bits if tb_size > 3824, else 16 bits
    tb_crc_length = torch.where(tb_size > 3824,
                                torch.tensor(24, dtype=torch.int32, device=device),
                                torch.tensor(16, dtype=torch.int32, device=device))
    # CB CRC: 24 bits only if the TB is segmented into several CBs
    cb_crc_length = torch.where(num_cb > 1,
                                torch.tensor(24, dtype=torch.int32, device=device),
                                torch.tensor(0, dtype=torch.int32, device=device))

    # -------#
    # CB size #
    # -------#
    cb_size = (tb_size + tb_crc_length) // num_cb + cb_crc_length

    if verbose:
        print(f"Modulation order: {modulation_order}")
        print(f"Target coderate: {target_coderate}")
        print(f"Number of layers: {num_layers}")
        print("------------------")
        print(f"Info bits per TB: {tb_size}")
        print(f"TB CRC length: {tb_crc_length}")
        print(f"Total number of coded TB bits: {num_coded_bits}")
        print("------------------")
        print(f"Info bits per CB: {cb_size}")
        print(f"Number of CBs: {num_cb}")
        print(f"CB CRC length: {cb_crc_length}")

    if not return_cw_length:
        return tb_size, cb_size, num_cb, tb_crc_length, cb_crc_length

    # ---------------------------#
    # Codeword length for each CB #
    # ---------------------------#
    # The first ``num_first_blocks`` codewords get the floor-based length,
    # the remaining ``num_last_blocks`` the ceil-based length.
    num_last_blocks = (num_coded_bits // (num_layers * modulation_order)) % num_cb
    cw_length_last_blocks = (num_layers * modulation_order *
                             torch.ceil(num_coded_bits.to(torch.float32) /
                                        (num_layers * modulation_order * num_cb).to(torch.float32))).to(torch.int32)
    num_first_blocks = num_cb - num_last_blocks
    cw_length_first_blocks = (num_layers * modulation_order *
                              torch.floor(num_coded_bits.to(torch.float32) /
                                          (num_layers * modulation_order * num_cb).to(torch.float32))).to(torch.int32)

    # For tensor outputs, we return the max num_cb and pad with zeros
    # Flatten for construction
    orig_shape = tb_size.shape
    num_last_flat = num_last_blocks.flatten()
    cw_last_flat = cw_length_last_blocks.flatten()
    num_first_flat = num_first_blocks.flatten()
    cw_first_flat = cw_length_first_blocks.flatten()
    num_cb_flat = num_cb.flatten()

    # Build codeword lengths: for each element, first blocks have cw_first, last have cw_last
    max_num_cb = num_cb_flat.max().item()
    batch_size = num_cb_flat.numel()  # NOTE: not used below

    # Create range tensor for comparison
    r = torch.arange(max_num_cb, device=device).unsqueeze(0)  # [1, max_num_cb]

    # Construct cw_length tensor
    # Column j of a row holds cw_first if j < num_first, cw_last if
    # num_first <= j < num_cb, and 0 (padding) otherwise.
    cw_length = torch.where(
        r < num_first_flat.unsqueeze(1),
        cw_first_flat.unsqueeze(1),
        torch.where(
            r < num_cb_flat.unsqueeze(1),
            cw_last_flat.unsqueeze(1),
            torch.zeros(1, dtype=torch.int32, device=device)
        )
    )

    # Reshape to original shape + [max_num_cb]
    if len(orig_shape) > 0:
        cw_length = cw_length.reshape(list(orig_shape) + [max_num_cb])
    else:
        cw_length = cw_length.squeeze(0)

    if verbose:
        print(f"Output codeword lengths: {cw_length}")

    return tb_size, cb_size, num_cb, tb_crc_length, cb_crc_length, cw_length
class MCSDecoderNR(MCSDecoder):
    r"""Maps a Modulation and Coding Scheme (MCS) index to the corresponding
    modulation order, i.e., number of bits per symbol, and coderate for 5G-NR
    networks.

    Wraps :func:`~sionna.phy.nr.utils.decode_mcs_index` and inherits from
    :class:`~sionna.phy.utils.MCSDecoder`.

    :input mcs_index: [...], `torch.int32`.
        MCS index.
    :input mcs_table_index: [...], `torch.int32`.
        MCS table index. Different tables contain different mappings.
    :input mcs_category: [...], `torch.int32`.
        `0` for PUSCH, `1` for PDSCH channel.
    :input check_index_validity: `bool`.
        If `True`, the input MCS indices are validated by
        :func:`~sionna.phy.nr.utils.decode_mcs_index`, which raises an error
        if they are not valid for the given configuration.
        Defaults to `True`.
    :input transform_precoding: [...], `torch.bool` | `bool`.
        Specifies whether the MCS tables described in Sec. 6.1.4.1 of
        :cite:p:`3GPPTS38214` are applied. Only relevant for "PUSCH".
        Defaults to `True`. Note that this differs from the default of
        :func:`~sionna.phy.nr.utils.decode_mcs_index`.
    :input pi2bpsk: [...], `torch.bool` | `bool`.
        Specifies whether the higher-layer parameter `tp-pi2BPSK` described
        in Sec. 6.1.4.1 of :cite:p:`3GPPTS38214` is applied. Only relevant
        for "PUSCH". Defaults to `False`.
    :input verbose: `bool`.
        If `True`, additional information is printed.
        Defaults to `False`.

    :output modulation_order: [...], `torch.int32`.
        Modulation order corresponding to the input MCS index.
    :output target_coderate: [...], `torch.float32`.
        Target coderate corresponding to the input MCS index.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import MCSDecoderNR
        import torch

        decoder = MCSDecoderNR()

        # Scalar input
        mod_order, rate = decoder(mcs_index=14, mcs_table_index=1, mcs_category=0)
        print(f"Modulation order: {mod_order.item()}, Target rate: {rate.item():.3f}")

        # Tensor input
        mcs_indices = torch.tensor([10, 14, 20])
        mod_orders, rates = decoder(mcs_index=mcs_indices, mcs_table_index=1, mcs_category=0)
    """

    def call(
        self,
        mcs_index: Union[int, torch.Tensor],
        mcs_table_index: Union[int, torch.Tensor],
        mcs_category: Union[int, torch.Tensor],
        check_index_validity: bool = True,
        transform_precoding: Union[bool, torch.Tensor] = True,
        pi2bpsk: Union[bool, torch.Tensor] = False,
        verbose: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process MCS index to return modulation order and coderate.

        Additional keyword arguments are accepted and ignored.
        """
        # Convert mcs_category to is_pusch: 0 -> True (PUSCH), 1 -> False
        # (PDSCH). The comparison is element-wise for tensors and yields a
        # plain bool for Python integers, so no type dispatch is needed.
        is_pusch = mcs_category == 0

        modulation_order, target_coderate = decode_mcs_index(
            mcs_index,
            table_index=mcs_table_index,
            is_pusch=is_pusch,
            transform_precoding=transform_precoding,
            pi2bpsk=pi2bpsk,
            check_index_validity=check_index_validity,
            verbose=verbose,
            device=self.device,
        )
        return modulation_order, target_coderate
class TransportBlockNR(TransportBlock):
    r"""Determines into how many code blocks a 5G-NR compliant transport
    block is segmented, and the size (in bits) of each of them, from the
    modulation order, the coderate and the total number of coded bits of the
    transport block.

    Used in :class:`~sionna.sys.PHYAbstraction`. Inherits from
    :class:`~sionna.phy.utils.TransportBlock` and wraps
    :func:`~sionna.phy.nr.utils.calculate_tb_size`.

    :input modulation_order: [...], `torch.int32`.
        Modulation order, i.e., number of bits per symbol, associated with
        the input MCS index.
    :input target_coderate: [...], `torch.float32`.
        Target coderate.
    :input num_coded_bits: [...], `torch.int32`.
        Total number of coded bits across all codewords.

    :output cb_size: [...], `torch.int32`.
        Code block (CB) size, i.e., the number of information bits per code
        block.
    :output num_cb: [...], `torch.int32`.
        Number of code blocks that the transport block is segmented into.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import TransportBlockNR
        import torch

        tb = TransportBlockNR()

        # Scalar input
        cb_size, num_cb = tb(modulation_order=4, target_coderate=0.5, num_coded_bits=4800)
        print(f"CB size: {cb_size}, Num CBs: {num_cb}")

        # Tensor input
        mod_orders = torch.tensor([4, 6, 4])
        rates = torch.tensor([0.5, 0.5, 0.75])
        coded_bits = torch.tensor([4800, 7200, 3600])
        cb_sizes, num_cbs = tb(mod_orders, rates, coded_bits)
    """

    def call(
        self,
        modulation_order: Union[int, torch.Tensor],
        target_coderate: Union[float, torch.Tensor],
        num_coded_bits: Union[int, torch.Tensor],
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the code block size and the number of code blocks.

        Additional keyword arguments are accepted and ignored.
        """
        # Codeword lengths are not needed here, so their computation is
        # skipped. The result is then the 5-tuple
        # (tb_size, cb_size, num_cb, tb_crc_length, cb_crc_length).
        tb_info = calculate_tb_size(
            modulation_order,
            target_coderate,
            num_coded_bits=num_coded_bits,
            tb_scaling=1.0,
            return_cw_length=False,
            verbose=False,
            device=self.device,
        )
        return tb_info[1], tb_info[2]
class CodedAWGNChannelNR(SingleLinkChannel):
    r"""Single-link coded AWGN channel simulation that is compliant with
    5G-NR.

    Random information bits are LDPC-encoded, QAM-mapped, sent over an AWGN
    channel, demapped to LLRs and decoded. Inherits from
    :class:`~sionna.phy.utils.SingleLinkChannel`.

    :param num_bits_per_symbol: Number of bits per symbol, i.e., modulation
        order.
    :param num_info_bits: Number of information bits per code block.
    :param target_coderate: Target code rate, i.e., the target ratio between
        the information and the coded bits within a block.
    :param num_iter_decoder: Number of decoder iterations. See
        :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder` for more
        details.
    :param cn_update_decoder: Check node update rule. See
        :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder` for more
        details.
    :param precision: Precision for internal calculations. If `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.
    :param kwargs: Additional keyword arguments for
        :class:`~sionna.phy.fec.ldpc.decoding.LDPC5GDecoder`.

    :input batch_size: `int`.
        Size of the simulation batches.
    :input ebno_db: `float`.
        Eb/No value in dB.

    :output bits: [``batch_size``, ``num_info_bits``], `torch.int32`.
        Transmitted bits.
    :output bits_hat: [``batch_size``, ``num_info_bits``], `torch.int32`.
        Decoded bits.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.nr.utils import CodedAWGNChannelNR

        channel = CodedAWGNChannelNR(
            num_bits_per_symbol=4,
            num_info_bits=1024,
            target_coderate=0.5
        )
        bits, bits_hat = channel(batch_size=100, ebno_db=5.0)
    """

    def __init__(
        self,
        num_bits_per_symbol: Optional[int] = None,
        num_info_bits: Optional[int] = None,
        target_coderate: Optional[float] = None,
        num_iter_decoder: int = 20,
        cn_update_decoder: str = "boxplus-phi",
        precision=None,
        device=None,
        **kwargs,
    ):
        super().__init__(
            num_bits_per_symbol,
            num_info_bits,
            target_coderate,
            precision=precision,
            device=device,
        )
        # Decoder configuration; extra kwargs are forwarded to the decoder
        self._num_iter_decoder = num_iter_decoder
        self._cn_update_decoder = cn_update_decoder
        self._kwargs = kwargs

    def call(
        self,
        batch_size: int,
        ebno_db: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one batch through the 5G-NR coded AWGN link.

        Returns the transmitted information bits and their decoded
        estimates.
        """
        # Import here to avoid circular imports
        from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder

        # Settings shared by all processing blocks
        shared = {"precision": self.precision, "device": self.device}

        # QAM constellation with the matching mapper/demapper pair
        qam = Constellation("qam", self.num_bits_per_symbol, **shared)
        to_symbols = Mapper(constellation=qam, **shared)
        to_llr = Demapper("app", constellation=qam, **shared)

        bit_source = BinarySource(**shared)
        awgn = AWGN(**shared)

        # 5G code block encoder and the decoder derived from it
        ldpc_encoder = LDPC5GEncoder(
            self.num_info_bits,
            int(self.num_coded_bits),
            num_bits_per_symbol=self.num_bits_per_symbol,
            **shared,
        )
        ldpc_decoder = LDPC5GDecoder(
            ldpc_encoder,
            hard_out=True,
            num_iter=self._num_iter_decoder,
            cn_update=self._cn_update_decoder,
            **shared,
            **self._kwargs,
        )

        # Noise power for the requested Eb/No
        noise_power = ebnodb2no(
            ebno_db,
            num_bits_per_symbol=self.num_bits_per_symbol,
            coderate=self.target_coderate,
        )

        # Transmit chain: random bits -> codewords -> symbols -> channel
        bits = bit_source([batch_size, self.num_info_bits])
        received = awgn(to_symbols(ldpc_encoder(bits)), noise_power)

        # Receive chain: LLR computation followed by LDPC decoding
        bits_hat = ldpc_decoder(to_llr(received, noise_power))

        return bits, bits_hat