#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for Polar encoding including 5G compliant rate-matching and CRC
concatenation."""
import numbers
import warnings
from typing import Optional, Tuple

import numpy as np
import torch

from sionna.phy import Block
from sionna.phy.fec.crc import CRCEncoder
from sionna.phy.fec.polar.utils import generate_5g_ranking
__all__ = ["PolarEncoder", "Polar5GEncoder"]
class PolarEncoder(Block):
    """Polar encoder for given code parameters.

    This block performs polar encoding for the given ``k`` information bits and
    the `frozen set` (i.e., indices of frozen positions) specified by
    ``frozen_pos``.

    :param frozen_pos: Array of `int` defining the `n-k` frozen indices, i.e.,
        information bits are mapped onto the `k` complementary positions.
        Any integer dtype is accepted.
    :param n: Defining the codeword length. Must be a power of 2.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input bits: [..., k], `torch.float`.
        Binary tensor containing the information bits to be encoded.

    :output cw: [..., n], `torch.float`.
        Binary tensor containing the codeword bits.

    :raises TypeError: If ``n`` is not a number or ``frozen_pos`` does not
        have an integer dtype.
    :raises ValueError: If ``frozen_pos`` has more than ``n`` elements, if
        ``n`` is not a power of 2, or if ``frozen_pos`` does not describe
        ``len(frozen_pos)`` distinct positions within ``[0, n)``.

    .. rubric:: Notes

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear code as the
    `all-zero` codeword is not necessarily part of the code any more.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import PolarEncoder
        from sionna.phy.fec.polar.utils import generate_5g_ranking

        k, n = 100, 256
        frozen_pos, _ = generate_5g_ranking(k, n)
        encoder = PolarEncoder(frozen_pos, n)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        print(codewords.shape)
        # torch.Size([10, 256])
    """

    def __init__(
        self,
        frozen_pos: np.ndarray,
        n: int,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)  # n can be float (e.g. as result of n=k*r)

        # Accept every integer dtype (int32, int64, unsigned, ...); comparing
        # against the builtin ``int`` only matches the platform default type.
        frozen_pos = np.asarray(frozen_pos)
        if not np.issubdtype(frozen_pos.dtype, np.integer):
            raise TypeError("frozen_pos must consist of ints.")
        if len(frozen_pos) > n:
            msg = "Number of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        # Exact integer power-of-two test; also rejects n <= 0 with a
        # ValueError instead of failing inside a float log2 conversion.
        if n < 1 or (n & (n - 1)) != 0:
            raise ValueError("n must be a power of 2.")

        self._k = n - len(frozen_pos)
        self._n = n
        self._frozen_pos = frozen_pos

        # Information positions are the complement of the frozen set. The
        # length check fails if frozen_pos contains duplicates or indices
        # outside [0, n).
        info_pos = np.setdiff1d(np.arange(self._n), frozen_pos)
        if self._k != len(info_pos):
            raise ValueError("Internal error: invalid info_pos generated.")
        self._info_pos = info_pos  # Keep numpy array for property access

        # Register info_pos as buffer for torch.compile compatibility
        self.register_buffer(
            "_info_pos_t",
            torch.tensor(info_pos, dtype=torch.int64, device=self.device),
        )

        self._check_input = True  # Check input for binary values during first call
        self._nb_stages = int(np.log2(self._n))

        # Registered like _info_pos_t so that both index tensors are handled
        # alike; non-persistent because the indices are derived from n.
        self.register_buffer(
            "_ind_gather", self._gen_indices(self._n), persistent=False
        )

    @property
    def k(self) -> int:
        """Number of information bits."""
        return self._k

    @property
    def n(self) -> int:
        """Codeword length."""
        return self._n

    @property
    def frozen_pos(self) -> np.ndarray:
        """Frozen positions for Polar decoding."""
        return self._frozen_pos

    @property
    def info_pos(self) -> np.ndarray:
        """Information bit positions for Polar encoding."""
        return self._info_pos

    def _gen_indices(self, n: int) -> torch.Tensor:
        """Pre-calculate encoding indices stage-wise for gather operations.

        :param n: Codeword length (power of 2).

        :output ind_gather: `torch.int32` tensor of shape
            `[log2(n), n + 1]`. Entry ``[s, i]`` is the index that is XORed
            onto position ``i`` in stage ``s``; positions that receive no
            update in a stage point to the placeholder index ``n``.
        """
        nb_stages = int(np.log2(n))
        # Default every entry to the placeholder position n (the extra,
        # always-zero column appended in ``call``)
        ind_gather = np.full((nb_stages, n + 1), n, dtype=np.int32)
        half_range = np.arange(n // 2)
        for s in range(nb_stages):
            step = 2**s
            ind_dest = 2 * half_range - np.mod(half_range, step)
            ind_gather[s, ind_dest] = ind_dest + step  # Update gather indices
        return torch.tensor(ind_gather, dtype=torch.int32, device=self.device)

    @torch.compiler.disable
    def _validate_binary_input(self, u: torch.Tensor) -> None:
        """Validate that input tensor contains only binary values.

        The check only runs while ``_check_input`` is set, i.e., until the
        first input passed validation.

        This method is decorated with @torch.compiler.disable to avoid
        recompilation issues caused by the mutable _check_input flag.

        :param u: Tensor to be checked.

        :raises ValueError: If ``u`` contains values other than 0 and 1.
        """
        if self._check_input:
            u_test = u.float()
            is_binary = torch.logical_or(
                torch.eq(u_test, 0.0), torch.eq(u_test, 1.0)
            ).all()
            if not is_binary:
                raise ValueError("Input must be binary.")
            self._check_input = False

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build and check if ``k`` and ``input_shape`` match.

        :raises ValueError: If the last dimension of ``input_shape`` is not
            ``k``.
        """
        if input_shape[-1] != self._k:
            raise ValueError("Last dimension must be of length k.")

    def call(self, bits: torch.Tensor) -> torch.Tensor:
        """Polar encoding function.

        This function returns the polar encoded codewords for the given
        information bits ``bits``.

        :param bits: Tensor of shape `[..., k]` containing the
            information bits to be encoded.

        :output cw: Tensor of shape `[..., n]`.
        """
        # Flatten all leading dimensions: [..., k] -> [batch, k]
        input_shape = bits.shape
        u = bits.reshape((-1, self._k))

        # Validate input (excluded from compilation to avoid recompilation)
        self._validate_binary_input(u)

        # Work tensor of shape [batch, n + 1] in uint8 for efficient XORing.
        # The extra last column is the all-zero placeholder addressed by the
        # unused entries of the gather indices.
        batch_size = u.shape[0]
        x = torch.zeros(
            (batch_size, self._n + 1), dtype=torch.uint8, device=u.device
        )

        # Copy info bits to information set; other positions are frozen (=0)
        x[:, self._info_pos_t] = u.to(torch.uint8)

        # Loop over all stages
        for s in range(self._nb_stages):
            x = torch.bitwise_xor(x, x[:, self._ind_gather[s]])

        # Remove placeholder column and restore original leading dimensions
        output_shape = list(input_shape[:-1]) + [self._n]
        c_reshaped = x[:, : self._n].reshape(output_shape)

        # Cast to rdtype for compatibility with other components
        return c_reshaped.to(self.dtype)
class Polar5GEncoder(PolarEncoder):
    # pylint: disable=line-too-long
    """5G compliant Polar encoder including rate-matching following
    :cite:p:`3GPPTS38212` for the uplink scenario (`UCI`) and downlink
    scenario (`DCI`).

    This block performs polar encoding for ``k`` information bits and
    rate-matching such that the codeword length is ``n``. This includes the CRC
    concatenation and the interleaving as defined in :cite:p:`3GPPTS38212`.

    Note: `block segmentation` is currently not supported (`I_seg=False`).

    We follow the basic structure from Fig. 6 in :cite:p:`Bioglio_Design`.

    For further details, we refer to :cite:p:`3GPPTS38212`, :cite:p:`Bioglio_Design` and
    :cite:p:`Hui_ChannelCoding`.

    :param k: Defining the number of information bits per codeword.
    :param n: Defining the codeword length.
    :param channel_type: Can be ``'uplink'`` or ``'downlink'``.
    :param verbose: If `True`, rate-matching parameters will be printed.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input bits: [..., k], `torch.float`.
        Binary tensor containing the information bits to be encoded.

    :output cw: [..., n], `torch.float`.
        Binary tensor containing the codeword bits.

    .. rubric:: Notes

    The encoder supports the `uplink` Polar coding (`UCI`) scheme from
    :cite:p:`3GPPTS38212` and the `downlink` Polar coding (`DCI`) :cite:p:`3GPPTS38212`,
    respectively.

    For `12 <= k <= 19` the 3 additional parity bits as defined in
    :cite:p:`3GPPTS38212` are not implemented as it would also require a
    modified decoding procedure to materialize the potential gains.

    `Code segmentation` is currently not supported and, thus, ``n`` is
    limited to a maximum length of 1088 codeword bits.

    For the downlink scenario, the input length is limited to `k <= 140`
    information bits due to the limited input bit interleaver size
    :cite:p:`3GPPTS38212`.

    For simplicity, the implementation does not exactly re-implement the
    `DCI` scheme from :cite:p:`3GPPTS38212`. This implementation neglects the
    `all-one` initialization of the CRC shift register and the scrambling
    of the CRC parity bits with the `RNTI`.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.polar import Polar5GEncoder

        k, n = 100, 200
        encoder = Polar5GEncoder(k, n)

        bits = torch.randint(0, 2, (10, k), dtype=torch.float32)
        codewords = encoder(bits)
        print(codewords.shape)
        # torch.Size([10, 200])
    """

    def __init__(
        self,
        k: int,
        n: int,
        channel_type: str = "uplink",
        verbose: bool = False,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        if not isinstance(k, numbers.Number):
            raise TypeError("k must be a number.")
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        k = int(k)  # k or n can be float (e.g. as result of n=k*r)
        n = int(n)
        if n < k:
            raise ValueError("Invalid coderate (>1).")
        if not isinstance(verbose, bool):
            raise TypeError("verbose must be bool.")
        if channel_type not in ("uplink", "downlink"):
            raise ValueError("Unsupported channel_type.")

        # These attributes are read by _init_rate_match and are therefore set
        # before the super-class is initialized.
        self._channel_type = channel_type
        self._k_target = k
        self._n_target = n
        self._verbose = verbose

        # Initialize rate-matcher
        crc_degree, n_polar, frozen_pos, idx_rm, idx_input = self._init_rate_match(
            k, n
        )
        self._frozen_pos = frozen_pos  # Required for decoder
        self._ind_rate_matching = idx_rm  # Keep numpy array for reference
        self._ind_input_int = idx_input  # Keep numpy array for reference

        # Initialize CRC encoder with k to pre-build generator matrix
        # (required for torch.compile compatibility)
        # Store reference to assign after super().__init__()
        _enc_crc_ref = CRCEncoder(
            crc_degree, k=k, precision=precision, device=device
        )

        # Init super-class (PolarEncoder)
        super().__init__(
            frozen_pos, n_polar, precision=precision, device=device, **kwargs
        )

        # Assign CRC encoder after super().__init__() for nn.Module compatibility
        self._enc_crc = _enc_crc_ref

        # Register rate-matching and interleaver indices as buffers
        self.register_buffer(
            "_ind_rate_matching_t",
            torch.tensor(
                idx_rm.astype(np.int32), dtype=torch.int32, device=self.device
            ),
        )
        if idx_input is not None:
            self.register_buffer(
                "_ind_input_int_t",
                torch.tensor(idx_input, dtype=torch.int32, device=self.device),
            )
        else:
            self._ind_input_int_t = None

    @property
    def enc_crc(self) -> CRCEncoder:
        """CRC encoder block used for CRC concatenation."""
        return self._enc_crc

    @property
    def k_target(self) -> int:
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n_target(self) -> int:
        """Codeword length including rate-matching."""
        return self._n_target

    @property
    def k_polar(self) -> int:
        """Number of information bits of the underlying Polar code."""
        return self._k

    @property
    def n_polar(self) -> int:
        """Codeword length of the underlying Polar code."""
        return self._n

    @property
    def k(self) -> int:
        """Number of information bits including rate-matching."""
        return self._k_target

    @property
    def n(self) -> int:
        """Codeword length including rate-matching."""
        return self._n_target

    def subblock_interleaving(self, u: np.ndarray) -> np.ndarray:
        """Sub-block interleaving as defined in Sec 5.4.1.1 :cite:p:`3GPPTS38212`.

        The input is split into 32 sub-blocks of equal length which are
        permuted according to a fixed pattern; the order within each
        sub-block is kept.

        :param u: 1D array to be interleaved. Length of ``u`` must be a
            multiple of 32.

        :output y: Interleaved version of ``u`` with same shape and dtype
            as ``u``.

        :raises ValueError: If the length of ``u`` is not a multiple of 32.
        """
        k = u.shape[-1]
        if np.mod(k, 32) != 0:
            msg = "length for sub-block interleaving must be a multiple of 32."
            raise ValueError(msg)
        y = np.zeros_like(u)

        # Permutation according to Tab 5.4.1.1.1-1 in 38.212
        perm = np.array(
            [
                0, 1, 2, 4, 3, 5, 6, 7, 8, 16, 9, 17, 10, 18, 11, 19,
                12, 20, 13, 21, 14, 22, 15, 23, 24, 25, 26, 28, 27, 29, 30, 31,
            ]
        )

        # Pure integer arithmetic; k is a multiple of 32, so
        # floor(32 * n_idx / k) equals n_idx // (k // 32).
        block_len = k // 32
        for n_idx in range(k):
            i = n_idx // block_len
            j = int(perm[i]) * block_len + n_idx % block_len
            y[n_idx] = u[j]
        return y

    def channel_interleaver(self, c: np.ndarray) -> np.ndarray:
        """Triangular interleaver following Sec. 5.4.1.3 in :cite:p:`3GPPTS38212`.

        The elements of ``c`` are written row by row into the upper-left
        triangle of a `T x T` matrix and read out column by column; unused
        (NULL) cells are skipped.

        :param c: 1D array to be interleaved.

        :output c_int: Interleaved version of ``c`` with same shape and
            dtype as ``c``.
        """
        n = c.shape[-1]  # Denoted as E in 38.212
        c_int = np.zeros_like(c)

        # Find smallest T s.t. T*(T+1)/2 >= n
        t = 0
        while t * (t + 1) // 2 < n:
            t += 1

        # Store source indices rather than values so that the result does
        # not depend on the dtype of c; -1 marks a NULL cell.
        v = np.full((t, t), -1, dtype=int)
        ind_k = 0
        for ind_i in range(t):
            for ind_j in range(t - ind_i):
                if ind_k < n:
                    v[ind_i, ind_j] = ind_k
                ind_k += 1

        # Column-wise read-out, skipping NULL cells
        ind_k = 0
        for ind_j in range(t):
            for ind_i in range(t - ind_j):
                if v[ind_i, ind_j] >= 0:
                    c_int[ind_k] = c[v[ind_i, ind_j]]
                    ind_k += 1
        return c_int

    def _init_rate_match(
        self, k_target: int, n_target: int
    ) -> Tuple[str, int, np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Implementing polar rate matching according to :cite:p:`3GPPTS38212`.

        Please note that this part of the code only runs during the
        initialization and, thus, is not performance critical.

        The relation of terminology between :cite:p:`3GPPTS38212` and this code is
        given as:
        `A`...`k_target`
        `E`...`n_target`
        `K`...`k_polar`
        `N`...`n_polar`
        `L`...`k_crc`.

        :param k_target: Number of information bits (without CRC).
        :param n_target: Codeword length after rate-matching.

        :output crc_pol: Name of the selected CRC polynomial.
        :output n_polar: Length of the Polar mother code.
        :output frozen_pos: Frozen positions of the mother code.
        :output idx_rate_matched: Indices of length ``n_target`` selecting
            the transmitted bits from the mother codeword (sub-block
            interleaving, rate-matching and, for the uplink, channel
            interleaving combined).
        :output ind_input_int: Input bit interleaver indices for the
            downlink; `None` for the uplink.

        :raises ValueError: If the requested parameters are not supported.
        """
        # Check input for consistency (see Sec. 6.3.1.2.1 for UL)
        if n_target < k_target:
            msg = "n must be larger or equal k."
            raise ValueError(msg)
        if n_target < 18:
            msg = "n<18 is not supported by the 5G Polar coding scheme."
            raise ValueError(msg)
        if k_target > 1013:
            msg = "k too large - currently, no codeword segmentation supported."
            raise ValueError(msg)
        if n_target > 1088:
            msg = "n too large - currently, no codeword segmentation supported."
            raise ValueError(msg)

        # Select CRC polynomials (see Sec. 6.3.1.2.1 for UL)
        if self._channel_type == "uplink":
            if 12 <= k_target <= 19:
                crc_pol = "CRC6"
                k_crc = 6
            elif k_target >= 20:
                crc_pol = "CRC11"
                k_crc = 11
            else:
                raise ValueError(
                    "k_target<12 is not supported in 5G NR for "
                    "the uplink; please use 'channel coding of small block "
                    "lengths' scheme from Sec. 5.3.3 in 3GPP 38.212 instead."
                )

            # PC bit for k_target = 12-19 bits (see Sec. 6.3.1.3.1 for UL)
            n_pc = 0  # Currently deactivated
            if k_target <= 19:
                # stacklevel=3: _init_rate_match <- __init__ <- caller
                warnings.warn(
                    "For 12<=k<=19 additional 3 parity-check bits "
                    "are defined in 38.212. They are currently not "
                    "implemented by this encoder and, thus, ignored.",
                    stacklevel=3,
                )
        else:  # downlink channel
            # For downlink CRC24 is used
            # Remark: in PDCCH messages are limited to k=140
            if k_target > 140:
                msg = "k too large for downlink configuration."
                raise ValueError(msg)
            if n_target < 25:
                msg = "n too small for downlink configuration with 24 bit CRC."
                raise ValueError(msg)
            if n_target > 576:
                msg = "n too large for downlink configuration."
                raise ValueError(msg)
            crc_pol = "CRC24C"  # following 7.3.2
            k_crc = 24
            n_pc = 0

        # Calculate Polar payload length (CRC bits are treated as info bits)
        k_polar = k_target + k_crc + n_pc
        if k_polar > n_target:
            msg = (
                "Device is not expected to be configured "
                "with k_target + k_crc + n_pc > n_target."
            )
            raise ValueError(msg)

        # Select polar mother code length n_polar
        n_min = 5
        n_max = 10  # For uplink; otherwise 9

        # Select rate-matching scheme following Sec. 5.3.1
        log2_n_target = np.ceil(np.log2(n_target))
        if (n_target <= (9 / 8) * 2 ** (log2_n_target - 1)) and (
            k_polar / n_target < 9 / 16
        ):
            n1 = log2_n_target - 1
        else:
            n1 = log2_n_target
        n2 = np.ceil(np.log2(8 * k_polar))  # Lower bound such that rate > 1/8
        n_polar = 2 ** int(max(min(n1, n2, n_max), n_min))

        # The same rate threshold selects puncturing vs. shortening for the
        # pre-freezing and for the circular buffer below.
        use_puncturing = k_polar / n_target <= 7 / 16

        # Puncturing and shortening as defined in Sec. 5.4.1.1
        prefrozen_pos = []  # List containing the pre-frozen indices
        if n_target < n_polar:
            if use_puncturing:
                # Puncturing
                if self._verbose:
                    print("Using puncturing for rate-matching.")
                # NOTE(review): the pattern is computed over n_int (the
                # number of punctured bits rounded up to a multiple of 32)
                # and not over n_polar - confirm against 38.212.
                n_int = 32 * int(np.ceil((n_polar - n_target) / 32))
                int_pattern = self.subblock_interleaving(np.arange(n_int))
                # Freeze additional bits
                prefrozen_pos.extend(
                    int(pos) for pos in int_pattern[: n_polar - n_target]
                )
                if n_target >= 3 * n_polar / 4:
                    t = int(np.ceil(3 / 4 * n_polar - n_target / 2) - 1)
                else:
                    t = int(np.ceil(9 / 16 * n_polar - n_target / 4) - 1)
                # Extra freezing of the indices 0,...,t-1
                # NOTE(review): confirm the upper bound (t vs. t+1 indices)
                # against 38.212.
                prefrozen_pos.extend(range(t))
            else:
                # Shortening ("through" sub-block interleaver)
                if self._verbose:
                    print("Using shortening for rate-matching.")
                n_int = 32 * int(np.ceil(n_polar / 32))
                int_pattern = self.subblock_interleaving(np.arange(n_int))
                prefrozen_pos.extend(
                    int(pos) for pos in int_pattern[n_target:n_polar]
                )

        # Remove duplicates
        prefrozen_pos = np.unique(np.asarray(prefrozen_pos, dtype=int))

        # Find the remaining n_polar - k_polar - |frozen_set|
        # Load full channel ranking
        ch_ranking, _ = generate_5g_ranking(0, n_polar, sort=False)

        # Remove positions that are already frozen by `pre-freezing` stage;
        # assume_unique=True keeps the order of ch_ranking
        info_cand = np.setdiff1d(ch_ranking, prefrozen_pos, assume_unique=True)

        # Identify k_polar most reliable positions from candidate positions,
        # i.e., the last k_polar entries of the remaining ranking
        if k_polar > len(info_cand):
            msg = "Not enough non-frozen positions left for k_polar bits."
            raise ValueError(msg)
        info_pos = info_cand[len(info_cand) - k_polar :]

        # Sort and create frozen positions for n_polar indices (no shortening)
        info_pos = np.sort(info_pos).astype(int)
        frozen_pos = np.setdiff1d(
            np.arange(n_polar), info_pos, assume_unique=True
        )

        # For downlink only: generate input bit interleaver
        if self._channel_type == "downlink":
            if self._verbose:
                print("Using input bit interleaver for downlink.")
            # NOTE(review): input_interleaver is not defined in the visible
            # part of this class - confirm it is provided elsewhere.
            ind_input_int = self.input_interleaver(np.arange(k_polar))
        else:
            ind_input_int = None

        # Generate indices for sub-block interleaver
        ind_sub_int = self.subblock_interleaving(np.arange(n_polar))

        # Rate matching via circular buffer as defined in Sec. 5.4.1.2
        out_idx = np.arange(n_target)
        if n_target >= n_polar:
            # Repetition coding
            if self._verbose:
                print("Using repetition coding for rate-matching")
            idx_c_matched = np.mod(out_idx, n_polar)
        elif use_puncturing:
            # Puncturing: skip the first n_polar - n_target bits
            idx_c_matched = out_idx + (n_polar - n_target)
        else:
            # Shortening: keep the first n_target bits
            idx_c_matched = out_idx

        # For uplink only: generate channel interleaver
        if self._channel_type == "uplink":
            if self._verbose:
                print("Using channel interleaver for uplink.")
            ind_channel_int = self.channel_interleaver(np.arange(n_target))

            # Combine indices for single gather operation
            ind_t = idx_c_matched[ind_channel_int].astype(int)
            idx_rate_matched = ind_sub_int[ind_t]
        else:  # no channel interleaver for downlink
            idx_rate_matched = ind_sub_int[idx_c_matched.astype(int)]

        if self._verbose:
            print(
                f"Code parameters after rate-matching: k = {k_target}, n = {n_target}"
            )
            print(f"Polar mother code: k_polar = {k_polar}, n_polar = {n_polar}")
            print("Using", crc_pol)
            print("Frozen positions: ", frozen_pos)
            print("Channel type: " + self._channel_type)

        return crc_pol, n_polar, frozen_pos, idx_rate_matched, ind_input_int

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build and check if ``k`` and ``input_shape`` match.

        :raises ValueError: If the last dimension of ``input_shape`` is not
            ``k``.
        """
        if input_shape[-1] != self._k_target:
            raise ValueError("Invalid input shape.")

    def call(self, bits: torch.Tensor) -> torch.Tensor:
        """Polar encoding function including rate-matching and CRC encoding.

        This function returns the polar encoded codewords for the given
        information bits ``bits`` following :cite:p:`3GPPTS38212` including
        rate-matching.

        :param bits: Tensor of shape `[..., k]` containing the information
            bits to be encoded.

        :output cw: Tensor of shape `[..., n]`.
        """
        # Flatten all leading dimensions: [..., k] -> [batch, k]
        input_shape = bits.shape
        u = bits.reshape((-1, input_shape[-1]))

        # CRC encode
        u_crc = self._enc_crc(u)

        # For downlink only: apply input bit interleaver
        if self._ind_input_int_t is not None:
            u_crc = u_crc[:, self._ind_input_int_t]

        # Encode bits (= channel allocation + Polar transform)
        c = super().call(u_crc)

        # Sub-block interleaving with 32 sub-blocks as in Sec. 5.4.1.1
        # Rate matching via circular buffer as defined in Sec. 5.4.1.2
        # For uplink only: channel interleaving (i_bil=True)
        # All three steps are applied by one gather with the pre-registered
        # index buffer (torch.compile compatibility)
        c_matched = c[:, self._ind_rate_matching_t]

        # Restore original leading dimensions
        output_shape = list(input_shape[:-1]) + [self._n_target]
        return c_matched.reshape(output_shape)