# Source code for sionna.phy.fec.polar.encoding

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for Polar encoding including 5G compliant rate-matching and CRC
concatenation."""

from typing import Optional, Tuple
import numbers
import numpy as np
import torch

from sionna.phy import Block
from sionna.phy.fec.crc import CRCEncoder
from sionna.phy.fec.polar.utils import generate_5g_ranking


# Public names exported by this module (used by ``from ... import *``).
__all__ = ["PolarEncoder", "Polar5GEncoder"]


class PolarEncoder(Block):
    """Polar encoder for given code parameters.

    This block performs polar encoding for the given ``k`` information bits
    and the `frozen set` (i.e., indices of frozen positions) specified by
    ``frozen_pos``.

    :param frozen_pos: Array of `int` defining the `n-k` frozen indices,
        i.e., information bits are mapped onto the `k` complementary
        positions. Any integer array-like is accepted.
    :param n: Defining the codeword length. Must be a power of 2.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input bits: [..., k], `torch.float`.
        Binary tensor containing the information bits to be encoded.

    :output cw: [..., n], `torch.float`.
        Binary tensor containing the codeword bits.

    :raises TypeError: If ``n`` is not a number or ``frozen_pos`` does not
        consist of integers.
    :raises ValueError: If ``n`` is not a power of 2, or ``frozen_pos`` has
        more than ``n`` entries, out-of-range entries or duplicates.

    .. rubric:: Notes

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear code as the
    `all-zero` codeword is not necessarily part of the code any more.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import PolarEncoder
        from sionna.phy.fec.polar.utils import generate_5g_ranking

        k, n = 100, 256
        frozen_pos, _ = generate_5g_ranking(k, n)
        encoder = PolarEncoder(frozen_pos, n)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        print(codewords.shape)
        # torch.Size([10, 256])
    """

    def __init__(
        self,
        frozen_pos: np.ndarray,
        n: int,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)  # n can be float (e.g. as result of n=k*r)

        # Accept any integer array-like (lists, int32/uint arrays, ...).
        # Note: ``np.issubdtype(dtype, int)`` only matches the platform's
        # default integer type and would wrongly reject e.g. int32 arrays.
        frozen_pos = np.asarray(frozen_pos)
        if frozen_pos.size == 0:
            # An empty list becomes a float64 array; it is still a valid
            # (empty) frozen set.
            frozen_pos = frozen_pos.astype(np.int64)
        if not np.issubdtype(frozen_pos.dtype, np.integer):
            raise TypeError("frozen_pos must consist of ints.")

        if len(frozen_pos) > n:
            msg = "Number of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        # Exact integer power-of-two test (float log2 breaks for n <= 0).
        if n < 1 or (n & (n - 1)) != 0:
            raise ValueError("n must be a power of 2.")
        if frozen_pos.size > 0 and (
            frozen_pos.min() < 0 or frozen_pos.max() >= n
        ):
            raise ValueError("frozen_pos must contain indices in [0, n).")
        if np.unique(frozen_pos).size != frozen_pos.size:
            raise ValueError("frozen_pos must not contain duplicate indices.")

        self._k = n - len(frozen_pos)
        self._n = n
        self._frozen_pos = frozen_pos

        # Info positions are the complement of the frozen set
        info_pos = np.setdiff1d(np.arange(self._n), frozen_pos)
        if self._k != len(info_pos):
            raise ValueError("Internal error: invalid info_pos generated.")
        self._info_pos = info_pos  # Keep numpy array for property access

        # Register info_pos as buffer for torch.compile compatibility
        self.register_buffer(
            "_info_pos_t",
            torch.tensor(info_pos, dtype=torch.int64, device=self.device),
        )

        # Check input for binary values during first call only
        self._check_input = True

        self._nb_stages = int(np.log2(self._n))
        self._ind_gather = self._gen_indices(self._n)

    @property
    def k(self) -> int:
        """Number of information bits."""
        return self._k

    @property
    def n(self) -> int:
        """Codeword length."""
        return self._n

    @property
    def frozen_pos(self) -> np.ndarray:
        """Frozen positions for Polar decoding."""
        return self._frozen_pos

    @property
    def info_pos(self) -> np.ndarray:
        """Information bit positions for Polar encoding."""
        return self._info_pos

    def _gen_indices(self, n: int) -> torch.Tensor:
        """Pre-calculate encoding indices stage-wise for gather operations.

        :param n: Codeword length (power of 2).

        :output ind_gather: `[log2(n), n+1]` int32 tensor. Row ``s`` maps
            each position to the position XORed onto it in stage ``s``; all
            unused entries point to the extra placeholder position ``n``,
            which always holds zero.
        """
        nb_stages = int(np.log2(n))

        # Last position denotes empty placeholder (points to element n+1)
        ind_gather = np.ones([nb_stages, n + 1], dtype=np.int32) * n

        for s in range(nb_stages):
            ind_range = np.arange(int(n / 2))
            ind_dest = ind_range * 2 - np.mod(ind_range, 2**s)
            ind_origin = ind_dest + 2**s
            ind_gather[s, ind_dest] = ind_origin  # Update gather indices

        ind_gather = torch.tensor(ind_gather, dtype=torch.int32, device=self.device)
        return ind_gather

    @torch.compiler.disable
    def _validate_binary_input(self, u: torch.Tensor) -> None:
        """Validate that input tensor contains only binary values.

        The check runs only on the first call (controlled by
        ``_check_input``). This method is decorated with
        @torch.compiler.disable to avoid recompilation issues caused by the
        mutable _check_input flag.

        :raises ValueError: If ``u`` contains values other than 0 and 1.
        """
        if self._check_input:
            u_test = u.float()
            is_binary = torch.logical_or(
                torch.eq(u_test, 0.0), torch.eq(u_test, 1.0)
            ).all()
            if not is_binary:
                raise ValueError("Input must be binary.")
            self._check_input = False

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build and check if ``k`` and ``input_shape`` match."""
        if input_shape[-1] != self._k:
            raise ValueError("Last dimension must be of length k.")

    def call(self, bits: torch.Tensor) -> torch.Tensor:
        """Polar encoding function.

        This function returns the polar encoded codewords for the given
        information bits ``bits``.

        :param bits: Tensor of shape `[..., k]` containing the
            information bits to be encoded.

        :output cw: Tensor of shape `[..., n]`.
        """
        # Flatten leading dimensions to [batch, k]
        input_shape = bits.shape
        new_shape = (-1, self._k)
        u = bits.reshape(new_shape)

        # Validate input (excluded from compilation to avoid recompilation)
        self._validate_binary_input(u)

        # Copy info bits to information set; other positions are frozen (=0).
        # One extra column (index n) acts as a zero placeholder for gathers.
        batch_size = u.shape[0]
        c = torch.zeros(
            (batch_size, self._n + 1), dtype=u.dtype, device=u.device
        )

        # Scatter info bits into the correct positions using pre-registered buffer
        c[:, : self._n] = c[:, : self._n].scatter(
            1, self._info_pos_t.unsqueeze(0).expand(batch_size, -1), u
        )

        # Cast to integer for more efficient XORing
        x = c.to(torch.uint8)

        # Polar transform: one XOR-with-partner pass per stage
        for s in range(self._nb_stages):
            ind_helper = self._ind_gather[s, :]
            x_add = x[:, ind_helper]
            x = torch.bitwise_xor(x, x_add)

        # Remove placeholder position
        c_out = x[:, : self._n]

        # Restore original shape
        output_shape = list(input_shape[:-1]) + [self._n]
        c_reshaped = c_out.reshape(output_shape)

        # Cast to rdtype for compatibility with other components
        return c_reshaped.to(self.dtype)
class Polar5GEncoder(PolarEncoder):
    # pylint: disable=line-too-long
    """5G compliant Polar encoder including rate-matching following
    :cite:p:`3GPPTS38212` for the uplink scenario (`UCI`) and downlink
    scenario (`DCI`).

    This block performs polar encoding for ``k`` information bits and
    rate-matching such that the codeword length is ``n``. This includes the
    CRC concatenation and the interleaving as defined in
    :cite:p:`3GPPTS38212`.

    Note: `block segmentation` is currently not supported (`I_seq=False`).

    We follow the basic structure from Fig. 6 in :cite:p:`Bioglio_Design`.

    For further details, we refer to :cite:p:`3GPPTS38212`,
    :cite:p:`Bioglio_Design` and :cite:p:`Hui_ChannelCoding`.

    :param k: Defining the number of information bits per codeword.
    :param n: Defining the codeword length.
    :param channel_type: Can be ``'uplink'`` or ``'downlink'``.
    :param verbose: If `True`, rate-matching parameters will be printed.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input bits: [..., k], `torch.float`.
        Binary tensor containing the information bits to be encoded.

    :output cw: [..., n], `torch.float`.
        Binary tensor containing the codeword bits.

    .. rubric:: Notes

    The encoder supports the `uplink` Polar coding (`UCI`) scheme from
    :cite:p:`3GPPTS38212` and the `downlink` Polar coding (`DCI`)
    :cite:p:`3GPPTS38212`, respectively.

    For `12 <= k <= 19` the 3 additional parity bits as defined in
    :cite:p:`3GPPTS38212` are not implemented as it would also require a
    modified decoding procedure to materialize the potential gains.

    `Code segmentation` is currently not supported and, thus, ``n`` is
    limited to a maximum length of 1088 codeword bits.

    For the downlink scenario, the input length is limited to `k <= 140`
    information bits due to the limited input bit interleaver size
    :cite:p:`3GPPTS38212`.

    For simplicity, the implementation does not exactly re-implement the
    `DCI` scheme from :cite:p:`3GPPTS38212`. This implementation neglects
    the `all-one` initialization of the CRC shift register and the
    scrambling of the CRC parity bits with the `RNTI`.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import Polar5GEncoder

        k, n = 100, 200
        encoder = Polar5GEncoder(k, n)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        print(codewords.shape)
        # torch.Size([10, 200])
    """

    # Sub-block interleaver pattern P(i), TS 38.212 Table 5.4.1.1-1
    _SUBBLOCK_PERM = np.array(
        [
            0, 1, 2, 4, 3, 5, 6, 7, 8, 16, 9, 17, 10, 18, 11, 19,
            12, 20, 13, 21, 14, 22, 15, 23, 24, 25, 26, 28, 27, 29, 30, 31,
        ]
    )

    # Input bit interleaver pattern Pi_IL^max, TS 38.212 Table 5.3.1.1-1
    _P_IL_MAX_TABLE = np.array(
        [
            0, 2, 4, 7, 9, 14, 19, 20, 24, 25, 26, 28, 31, 34, 42, 45, 49,
            50, 51, 53, 54, 56, 58, 59, 61, 62, 65, 66, 67, 69, 70, 71, 72,
            76, 77, 81, 82, 83, 87, 88, 89, 91, 93, 95, 98, 101, 104, 106,
            108, 110, 111, 113, 115, 118, 119, 120, 122, 123, 126, 127, 129,
            132, 134, 138, 139, 140, 1, 3, 5, 8, 10, 15, 21, 27, 29, 32, 35,
            43, 46, 52, 55, 57, 60, 63, 68, 73, 78, 84, 90, 92, 94, 96, 99,
            102, 105, 107, 109, 112, 114, 116, 121, 124, 128, 130, 133, 135,
            141, 6, 11, 16, 22, 30, 33, 36, 44, 47, 64, 74, 79, 85, 97, 100,
            103, 117, 125, 131, 136, 142, 12, 17, 23, 37, 48, 75, 80, 86,
            137, 143, 13, 18, 38, 144, 39, 145, 40, 146, 41, 147, 148, 149,
            150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
            163,
        ]
    )

    def __init__(
        self,
        k: int,
        n: int,
        channel_type: str = "uplink",
        verbose: bool = False,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        if not isinstance(k, numbers.Number):
            raise TypeError("k must be a number.")
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        k = int(k)  # k or n can be float (e.g. as result of n=k*r)
        n = int(n)
        if n < k:
            raise ValueError("Invalid coderate (>1).")
        if not isinstance(verbose, bool):
            raise TypeError("verbose must be bool.")
        if channel_type not in ("uplink", "downlink"):
            raise ValueError("Unsupported channel_type.")

        # Plain (non-module) attributes may be set before nn.Module init
        self._channel_type = channel_type
        self._k_target = k
        self._n_target = n
        self._verbose = verbose

        # Initialize rate-matcher
        crc_degree, n_polar, frozen_pos, idx_rm, idx_input = self._init_rate_match(
            k, n
        )
        self._frozen_pos = frozen_pos  # Required for decoder
        self._ind_rate_matching = idx_rm  # Keep numpy array for reference
        self._ind_input_int = idx_input  # Keep numpy array for reference

        # Initialize CRC encoder with k to pre-build generator matrix
        # (required for torch.compile compatibility).
        # Store reference to assign after super().__init__()
        _enc_crc_ref = CRCEncoder(
            crc_degree, k=k, precision=precision, device=device
        )

        # Init super-class (PolarEncoder)
        super().__init__(
            frozen_pos, n_polar, precision=precision, device=device, **kwargs
        )

        # Assign CRC encoder after super().__init__() for nn.Module compatibility
        self._enc_crc = _enc_crc_ref

        # Register rate-matching and interleaver indices as buffers
        self.register_buffer(
            "_ind_rate_matching_t",
            torch.tensor(idx_rm.astype(np.int32), dtype=torch.int32, device=self.device),
        )
        if idx_input is not None:
            self.register_buffer(
                "_ind_input_int_t",
                torch.tensor(idx_input, dtype=torch.int32, device=self.device),
            )
        else:
            self._ind_input_int_t = None

    @property
    def enc_crc(self) -> CRCEncoder:
        """CRC encoder block used for CRC concatenation."""
        return self._enc_crc

    @property
    def k_target(self) -> int:
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n_target(self) -> int:
        """Codeword length including rate-matching."""
        return self._n_target

    @property
    def k_polar(self) -> int:
        """Number of information bits of the underlying Polar code."""
        return self._k

    @property
    def n_polar(self) -> int:
        """Codeword length of the underlying Polar code."""
        return self._n

    @property
    def k(self) -> int:
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n(self) -> int:
        """Codeword length including rate-matching."""
        return self._n_target

    def subblock_interleaving(self, u: np.ndarray) -> np.ndarray:
        """Input bit interleaving as defined in Sec 5.4.1.1 :cite:p:`3GPPTS38212`.

        Implements ``y[n] = u[J(n)]`` with
        ``J(n) = P(i) * K/32 + mod(n, K/32)`` and ``i = floor(32 n / K)``.

        :param u: 1D array to be interleaved. Length of ``u`` must be a
            multiple of 32.

        :output y: Interleaved version of ``u`` with same shape and dtype as
            ``u``.

        :raises ValueError: If the length of ``u`` is not a multiple of 32.
        """
        k = u.shape[-1]
        if np.mod(k, 32) != 0:
            msg = "length for sub-block interleaving must be a multiple of 32."
            raise ValueError(msg)
        if k == 0:
            return np.zeros_like(u)

        block_len = k // 32
        n_idx = np.arange(k)
        # i = floor(32 n / k) equals n // block_len as k is a multiple of 32
        j = self._SUBBLOCK_PERM[n_idx // block_len] * block_len + n_idx % block_len
        return u[j]

    def channel_interleaver(self, c: np.ndarray) -> np.ndarray:
        """Triangular interleaver following Sec. 5.4.1.3 in :cite:p:`3GPPTS38212`.

        The input is written row-wise into an isosceles right triangle with
        side length ``T`` (smallest ``T`` with ``T(T+1)/2 >= E``) and read
        out column-wise, skipping the `<NULL>` fillers.

        :param c: 1D array to be interleaved.

        :output c_int: Interleaved version of ``c`` with same shape and dtype
            as ``c``.
        """
        n = c.shape[-1]  # Denoted as E in 38.212

        # Find smallest T s.t. T*(T+1)/2 >= n (T*(T+1) is always even)
        t = 0
        while t * (t + 1) // 2 < n:
            t += 1

        # Write input *positions* row-wise; -1 marks <NULL> entries. Working
        # on indices avoids casting the data through a float/NaN matrix.
        table = np.full((t, t), -1, dtype=np.int64)
        pos = 0
        for row in range(t):
            for col in range(t - row):
                if pos < n:
                    table[row, col] = pos
                pos += 1

        # Read column-wise and skip <NULL> entries
        order = [
            table[row, col]
            for col in range(t)
            for row in range(t - col)
            if table[row, col] >= 0
        ]
        return c[np.asarray(order, dtype=np.int64)]

    def input_interleaver(self, c: np.ndarray) -> np.ndarray:
        """Input interleaver following Sec. 5.4.1.1 in :cite:p:`3GPPTS38212`.

        Implements ``c'[k] = c[Pi(k)]`` where ``Pi`` contains all entries
        ``Pi_IL^max(m) - (K_IL^max - K)`` of Table 5.3.1.1-1 with
        ``Pi_IL^max(m) >= K_IL^max - K``.

        :param c: 1D array to be interleaved.

        :output c_apo: Interleaved version of ``c`` with same shape and dtype
            as ``c``.

        :raises ValueError: If ``c`` is longer than 164.
        """
        k_il_max = 164
        k = len(c)
        if k > k_il_max:
            msg = "Input interleaver only defined for length of 164."
            raise ValueError(msg)

        offset = k_il_max - k
        table = self._P_IL_MAX_TABLE
        # Boolean masking keeps the table order
        idx = table[table >= offset] - offset
        return np.asarray(c)[idx]

    def _select_crc(self, k_target: int, n_target: int) -> Tuple[str, int, int]:
        """Select CRC polynomial and parity bits (Sec. 6.3.1.2.1 / 7.3.2).

        :output crc_pol: CRC polynomial name.
        :output k_crc: Number of CRC bits.
        :output n_pc: Number of parity-check bits (always 0 here).

        :raises ValueError: If ``k_target``/``n_target`` are unsupported for
            the configured channel type.
        """
        if self._channel_type == "uplink":
            if 12 <= k_target <= 19:
                crc_pol = "CRC6"
                k_crc = 6
            elif k_target >= 20:
                crc_pol = "CRC11"
                k_crc = 11
            else:
                raise ValueError(
                    "k_target<12 is not supported in 5G NR for "
                    "the uplink; please use 'channel coding of small block "
                    "lengths' scheme from Sec. 5.3.3 in 3GPP 38.212 instead."
                )
            # PC bit for k_target = 12-19 bits (see Sec. 6.3.1.3.1 for UL)
            n_pc = 0
            if k_target <= 19:
                n_pc = 0  # Currently deactivated
                print(
                    "Warning: For 12<=k<=19 additional 3 parity-check bits "
                    "are defined in 38.212. They are currently not "
                    "implemented by this encoder and, thus, ignored."
                )
        else:  # downlink channel
            # For downlink CRC24 is used
            # Remark: in PDCCH messages are limited to k=140
            if k_target > 140:
                msg = "k too large for downlink configuration."
                raise ValueError(msg)
            if n_target < 25:
                msg = "n too small for downlink configuration with 24 bit CRC."
                raise ValueError(msg)
            if n_target > 576:
                msg = "n too large for downlink configuration."
                raise ValueError(msg)
            crc_pol = "CRC24C"  # following 7.3.2
            k_crc = 24
            n_pc = 0
        return crc_pol, k_crc, n_pc

    def _mother_code_length(self, k_polar: int, n_target: int) -> int:
        """Select the Polar mother code length N following Sec. 5.3.1."""
        n_min = 5
        # n_max = 10 for UCI (uplink), 9 for DCI (downlink)
        n_max = 10 if self._channel_type == "uplink" else 9

        log_e = np.ceil(np.log2(n_target))
        if (n_target <= (9 / 8) * 2 ** (log_e - 1)) and (
            k_polar / n_target < 9 / 16
        ):
            n1 = log_e - 1
        else:
            n1 = log_e
        n2 = np.ceil(np.log2(8 * k_polar))  # Lower bound such that rate > 1/8
        return int(2 ** max(min(n1, n2, n_max), n_min))

    def _prefrozen_positions(
        self, k_polar: int, n_target: int, n_polar: int, ind_sub_int: np.ndarray
    ) -> np.ndarray:
        """Return the sorted pre-frozen set Q_F,tmp (Sec. 5.3.1.2).

        :param ind_sub_int: Sub-block interleaver pattern ``J(n)`` for the
            mother code length ``n_polar``.
        """
        if n_target >= n_polar:
            # Repetition: no positions are pre-frozen
            return np.zeros(0, dtype=np.int64)

        if k_polar / n_target <= 7 / 16:  # Puncturing
            if self._verbose:
                print("Using puncturing for rate-matching.")
            # Freeze J(n), n = 0, ..., N-E-1, i.e., exactly the positions
            # removed by the circular-buffer puncturing. J must be the
            # interleaver pattern for the mother code length N.
            parts = [ind_sub_int[: n_polar - n_target]]
            if n_target >= 3 * n_polar / 4:
                num_extra = int(np.ceil(3 / 4 * n_polar - n_target / 2))
            else:
                num_extra = int(np.ceil(9 / 16 * n_polar - n_target / 4))
            # Extra freezing of {0, 1, ..., num_extra - 1}
            parts.append(np.arange(num_extra))
        else:  # Shortening ("through" sub-block interleaver)
            if self._verbose:
                print("Using shortening for rate-matching.")
            # Freeze J(n), n = E, ..., N-1
            parts = [ind_sub_int[n_target:n_polar]]

        # np.unique sorts and removes duplicates
        return np.unique(np.concatenate(parts).astype(np.int64))

    def _bit_selection_indices(
        self, k_polar: int, n_target: int, n_polar: int
    ) -> np.ndarray:
        """Circular-buffer bit selection (Sec. 5.4.1.2).

        :output idx: Indices into the sub-block interleaved sequence that
            form the ``n_target`` rate-matched bits.
        """
        ind = np.arange(n_target)
        if n_target >= n_polar:  # Repetition coding
            if self._verbose:
                print("Using repetition coding for rate-matching")
            return np.mod(ind, n_polar)
        if k_polar / n_target <= 7 / 16:  # Puncturing: drop the first N-E bits
            return ind + (n_polar - n_target)
        return ind  # Shortening: drop the last N-E bits

    def _init_rate_match(
        self, k_target: int, n_target: int
    ) -> Tuple[str, int, np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Implementing polar rate matching according to :cite:p:`3GPPTS38212`.

        Please note that this part of the code only runs during the
        initialization and, thus, is not performance critical.

        The relation of terminology between :cite:p:`3GPPTS38212` and this
        code is given as:
        `A`...`k_target`
        `E`...`n_target`
        `K`...`k_polar`
        `N`...`n_polar`
        `L`...`k_crc`.

        :output crc_pol: CRC polynomial name.
        :output n_polar: Mother code length.
        :output frozen_pos: Frozen positions of the mother code.
        :output idx_rate_matched: Gather indices mapping the mother codeword
            to the ``n_target`` transmitted bits.
        :output ind_input_int: Input interleaver indices (downlink) or `None`.

        :raises ValueError: For unsupported parameter combinations.
        """
        # Check input for consistency (see Sec. 6.3.1.2.1 for UL)
        if n_target < k_target:
            msg = "n must be larger or equal k."
            raise ValueError(msg)
        if n_target < 18:
            msg = "n<18 is not supported by the 5G Polar coding scheme."
            raise ValueError(msg)
        if k_target > 1013:
            msg = "k too large - currently, no codeword segmentation supported."
            raise ValueError(msg)
        if n_target > 1088:
            msg = "n too large - currently, no codeword segmentation supported."
            raise ValueError(msg)

        crc_pol, k_crc, n_pc = self._select_crc(k_target, n_target)

        # Calculate Polar payload length (CRC bits are treated as info bits)
        k_polar = k_target + k_crc + n_pc
        if k_polar > n_target:
            msg = (
                "Device is not expected to be configured "
                "with k_polar + k_crc + n_pc > n_target."
            )
            raise ValueError(msg)

        n_polar = self._mother_code_length(k_polar, n_target)

        # Sub-block interleaver pattern J(n) for the mother code length; it is
        # used both for pre-freezing and for the rate-matching indices.
        ind_sub_int = self.subblock_interleaving(np.arange(n_polar))

        prefrozen_pos = self._prefrozen_positions(
            k_polar, n_target, n_polar, ind_sub_int
        )

        # Load full channel ranking
        ch_ranking, _ = generate_5g_ranking(0, n_polar, sort=False)
        # Remove pre-frozen positions; with assume_unique=True setdiff1d
        # keeps the reliability order of ch_ranking.
        info_cand = np.setdiff1d(ch_ranking, prefrozen_pos, assume_unique=True)
        if len(info_cand) < k_polar:
            raise ValueError(
                "Not enough unfrozen positions for the requested k and n."
            )

        # The k_polar most reliable candidates are at the end of info_cand
        info_pos = np.sort(info_cand[len(info_cand) - k_polar:]).astype(int)
        frozen_pos = np.setdiff1d(
            np.arange(n_polar), info_pos, assume_unique=True
        )

        # For downlink only: generate input bit interleaver
        if self._channel_type == "downlink":
            if self._verbose:
                print("Using input bit interleaver for downlink.")
            ind_input_int = self.input_interleaver(np.arange(k_polar))
        else:
            ind_input_int = None

        # Rate matching via circular buffer as defined in Sec. 5.4.1.2
        idx_c_matched = self._bit_selection_indices(k_polar, n_target, n_polar)

        # For uplink only: channel interleaver (Sec. 5.4.1.3)
        if self._channel_type == "uplink":
            if self._verbose:
                print("Using channel interleaver for uplink.")
            ind_channel_int = self.channel_interleaver(np.arange(n_target))
            # Combine indices for single gather operation
            idx_rate_matched = ind_sub_int[idx_c_matched[ind_channel_int]]
        else:  # no channel interleaver for downlink
            idx_rate_matched = ind_sub_int[idx_c_matched]

        if self._verbose:
            print(
                f"Code parameters after rate-matching: k = {k_target}, n = {n_target}"
            )
            print(f"Polar mother code: k_polar = {k_polar}, n_polar = {n_polar}")
            print("Using", crc_pol)
            print("Frozen positions: ", frozen_pos)
            print("Channel type: " + self._channel_type)

        return crc_pol, n_polar, frozen_pos, idx_rate_matched, ind_input_int

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build and check if ``k`` and ``input_shape`` match."""
        if input_shape[-1] != self._k_target:
            raise ValueError("Invalid input shape.")

    def call(self, bits: torch.Tensor) -> torch.Tensor:
        """Polar encoding function including rate-matching and CRC encoding.

        This function returns the polar encoded codewords for the given
        information bits ``bits`` following :cite:p:`3GPPTS38212` including
        rate-matching.

        :param bits: Tensor of shape `[..., k]` containing the
            information bits to be encoded.

        :output cw: Tensor of shape `[..., n]`.
        """
        # Flatten leading dimensions to [batch, k]
        input_shape = bits.shape
        new_shape = (-1, input_shape[-1])
        u = bits.reshape(new_shape)

        # CRC encode
        u_crc = self._enc_crc(u)

        # For downlink only: apply input bit interleaver
        if self._ind_input_int_t is not None:
            u_crc = u_crc[:, self._ind_input_int_t]

        # Encode bits (= channel allocation + Polar transform)
        c = super().call(u_crc)

        # Sub-block interleaving with 32 sub-blocks as in Sec. 5.4.1.1,
        # rate matching via circular buffer as defined in Sec. 5.4.1.2 and,
        # for uplink only, channel interleaving (i_bil=True) are combined
        # into one pre-registered gather index (torch.compile compatible).
        c_matched = c[:, self._ind_rate_matching_t]

        # Restore original shape
        output_shape = list(input_shape[:-1]) + [self._n_target]
        c_reshaped = c_matched.reshape(output_shape)

        return c_reshaped