Source code for sionna.phy.fec.crc

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for cyclic redundancy checks (CRC) and utility functions."""

from typing import Optional, Tuple
import warnings
import numpy as np
import torch

from sionna.phy import Block
from sionna.phy.fec.utils import int_mod_2


# Names exported by ``from <this module> import *``: the two CRC blocks
# defined below.
__all__ = ["CRCEncoder", "CRCDecoder"]


class CRCEncoder(Block):
    """Adds a Cyclic Redundancy Check (CRC) to the input sequence.

    The CRC polynomials from Sec. 5.1 in :cite:p:`3GPPTS38212` are available:
    `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.

    :param crc_degree: Name of the CRC polynomial to use. Must be one of
        `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.
    :param k: Optional number of input bits. If given, the generator matrix
        is built during initialization, which is required for
        ``torch.compile`` compatibility. Otherwise the matrix is built
        lazily on the first call.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input bits: [..., k], `torch.float`.
        Binary tensor of arbitrary shape where the last dimension is
        `[..., k]`.

    :output x_crc: [..., k + crc_length], `torch.float`.
        Binary tensor containing the CRC-encoded bits. Same shape as
        ``bits`` except that the last dimension becomes
        `[..., k + crc_length]`.

    .. rubric:: Notes

    For performance reasons, a generator-matrix-based implementation is used
    for fixed `k` instead of the more common shift-register-based
    operations. As a consequence, the encoder triggers an (internal) rebuild
    whenever `k` changes.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.crc import CRCEncoder

        encoder = CRCEncoder("CRC24A")
        bits = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        encoded = encoder(bits)
        print(encoded.shape)
        # torch.Size([10, 124])
    """

    def __init__(
        self,
        crc_degree: str,
        *,
        k: Optional[int] = None,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(crc_degree, str):
            raise TypeError("crc_degree must be a string.")
        self._crc_degree = crc_degree

        # Resolve the 5G polynomial and the number of parity bits it yields
        polynomial, parity_len = self._select_crc_pol(self._crc_degree)
        self._crc_pol = polynomial
        self._crc_length = parity_len

        # Dimensions stay unknown until the generator matrix has been built
        self._k: Optional[int] = None
        self._n: Optional[int] = None

        # Buffer placeholder (kept registered for CUDAGraph compatibility)
        self.register_buffer("_g_mat_crc", None)

        # Becomes True only if k is pre-specified by the caller
        self._fixed_k: bool = False

        if k is None:
            # Lazy mode: the matrix is created on the first call
            return

        # Eager mode (required for torch.compile compatibility)
        self.build((k,))
        self._built = True  # skip the lazy initialization in __call__
        self._fixed_k = True

    @property
    def crc_degree(self) -> str:
        """Name of the selected CRC polynomial."""
        return self._crc_degree

    @property
    def crc_length(self) -> int:
        """Number of CRC parity bits appended to each codeword."""
        return self._crc_length

    @property
    def crc_pol(self) -> np.ndarray:
        """Binary coefficient vector of the CRC polynomial."""
        return self._crc_pol

    @property
    def k(self) -> Optional[int]:
        """Number of information bits per codeword (`None` before build)."""
        if self._k is None:
            warnings.warn(
                "CRC encoder is not initialized yet. "
                "Input dimensions are unknown."
            )
        return self._k

    @property
    def n(self) -> Optional[int]:
        """Number of codeword bits after encoding (`None` before build)."""
        if self._n is None:
            warnings.warn(
                "CRC encoder is not initialized yet. "
                "Output dimensions are unknown."
            )
        return self._n

    def _select_crc_pol(self, crc_degree: str) -> Tuple[np.ndarray, int]:
        """Look up the 5G CRC polynomial of Sec. 5.1 :cite:p:`3GPPTS38212`.

        :param crc_degree: Name of the polynomial.

        :output crc_pol: Integer array of length ``crc_length + 1`` holding
            the polynomial coefficients, highest power first.
        :output crc_length: Number of parity bits.

        Raises ``ValueError`` if ``crc_degree`` is not a known polynomial.
        """
        # (name, parity length, exponents with non-zero coefficient)
        polynomials = (
            ("CRC24A", 24, (24, 23, 18, 17, 14, 11, 10, 7, 6, 5, 4, 3, 1, 0)),
            ("CRC24B", 24, (24, 23, 6, 5, 1, 0)),
            ("CRC24C", 24, (24, 23, 21, 20, 17, 15, 13, 12, 8, 4, 2, 1, 0)),
            ("CRC16", 16, (16, 12, 5, 0)),
            ("CRC11", 11, (11, 10, 9, 5, 0)),
            ("CRC6", 6, (6, 5, 0)),
        )
        match = next(
            (entry for entry in polynomials if crc_degree == entry[0]),
            None,
        )
        if match is None:
            raise ValueError("Invalid CRC Polynomial")
        _, crc_length, exponents = match

        # Index 0 carries the highest power (MSB first instead of LSB first)
        coefficients = np.zeros(crc_length + 1)
        coefficients[[crc_length - e for e in exponents]] = 1
        return coefficients.astype(int), crc_length

    def _gen_crc_mat(self, k: int, pol_crc: np.ndarray) -> np.ndarray:
        """Build the (dense) generator matrix for the CRC parity bits.

        The CRC is treated as a systematic linear code, so the parity part
        of the generator matrix consists of the parity bits of the `k`
        unit vectors `[0,...,1,...,0]`.

        Encoding every unit vector from scratch would cost `O(k^2)`.
        Instead, the parity of `[0,...,0,1]` is computed first; the parity
        of the next unit vector `[0,...,1,0]` follows from one more
        polynomial division step applied to the previous remainder. The
        matrix is thus filled from the last row upwards in `O(k)` steps.

        :param k: Number of information bits (rows of the matrix).
        :param pol_crc: Polynomial coefficients, highest power first.

        :output g_mat: Array of shape `[k, crc_length]`.
        """
        crc_length = len(pol_crc) - 1
        g_mat = np.zeros([k, crc_length])

        # Non-leading coefficients; the leading one only cancels the bit
        # that is shifted out of the register.
        taps = pol_crc[1:]

        remainder = np.zeros(crc_length, dtype=int)
        remainder[0] = 1

        for row in reversed(range(k)):
            shifted_out = remainder[0]
            # Multiply by x: drop the leading bit and append a zero
            remainder = np.concatenate([remainder[1:], [0]])
            if shifted_out == 1:
                # Reduce modulo the CRC polynomial
                remainder = np.bitwise_xor(remainder, taps)
            g_mat[row, :] = remainder

        return g_mat

    @torch.compiler.disable
    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Build the generator matrix for the given input shape.

        The CRC is always added to the last dimension of the input.

        Note:
            For torch.compile compatibility, use the ``k`` parameter in
            ``__init__`` to pre-build the generator matrix.

        :param input_shape: Shape of the input; only the last entry is used.
        """
        k = input_shape[-1]
        if k is None:
            raise ValueError("Shape of last dimension cannot be None.")

        parity_generator = self._gen_crc_mat(k, self.crc_pol)

        # Registered as buffer for CUDAGraph compatibility
        self.register_buffer(
            "_g_mat_crc",
            torch.tensor(
                parity_generator, dtype=self.dtype, device=self.device
            ),
        )

        self._k = k
        self._n = k + parity_generator.shape[1]

    def call(self, bits: torch.Tensor) -> torch.Tensor:
        """Cyclic Redundancy Check (CRC) encoding function.

        Appends the CRC parity bits to ``bits``.

        :param bits: Binary tensor of arbitrary shape `[..., k]`.

        :output x_out: CRC-encoded bits of shape `[..., k + crc_length]`.
        """
        # With a pre-specified k the pre-built matrix is used as is and all
        # dynamic checks are skipped (torch.compile compatibility).
        if not self._fixed_k:
            input_k = bits.shape[-1]
            stale = self._g_mat_crc is None or input_k != self._k
            if stale:
                self.build(tuple(bits.shape))

        # The code is systematic, so only the parity positions are encoded.
        # The corresponding matrix is dense, which is why a plain matrix
        # multiplication is used.
        row_vec = bits.unsqueeze(-2)  # shape [..., 1, k]
        parity = torch.matmul(row_vec.to(self.dtype), self._g_mat_crc)

        # Reduce modulo 2 on integers, then restore the input dtype
        parity = int_mod_2(parity.to(torch.int32)).to(row_vec.dtype)

        codeword = torch.cat([row_vec, parity], dim=-1)
        return codeword.squeeze(-2)
class CRCDecoder(Block):
    """Allows Cyclic Redundancy Check (CRC) verification and removes parity
    bits.

    The CRC polynomials from Sec. 5.1 in :cite:p:`3GPPTS38212` are available:
    `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.

    :param crc_encoder: An instance of
        :class:`~sionna.phy.fec.crc.CRCEncoder` associated with the
        CRCDecoder.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input x_crc: [..., k + crc_length], `torch.float`.
        Binary tensor containing the CRC-encoded bits (the last
        `crc_length` bits are parity bits).

    :output bits: [..., k], `torch.float`.
        Binary tensor containing the information bit sequence without CRC
        parity bits.
    :output crc_valid: [..., 1], `torch.bool`.
        Boolean tensor containing the result of the CRC check per codeword.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.fec.crc import CRCEncoder, CRCDecoder

        encoder = CRCEncoder("CRC24A")
        decoder = CRCDecoder(encoder)
        bits = torch.randint(0, 2, (10, 100), dtype=torch.float32)
        encoded = encoder(bits)
        decoded, crc_valid = decoder(encoded)
        print(decoded.shape, crc_valid.all())
        # torch.Size([10, 100]) tensor(True)
    """

    def __init__(
        self,
        crc_encoder: CRCEncoder,
        *,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if not isinstance(crc_encoder, CRCEncoder):
            raise TypeError("crc_encoder must be a CRCEncoder instance.")
        self._encoder = crc_encoder

        # Last successfully validated input shape; used to detect changing
        # input dimensions
        self._bit_shape: Optional[Tuple[int, ...]] = None

    @property
    def crc_degree(self) -> str:
        """CRC degree as string."""
        return self._encoder.crc_degree

    @property
    def encoder(self) -> CRCEncoder:
        """CRC Encoder used for internal validation."""
        return self._encoder

    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Validate the input shape and remember it.

        :param input_shape: Shape of the input; only the last entry is
            checked.

        Raises ``ValueError`` if the last dimension is shorter than the CRC
        length of the associated encoder.
        """
        # Validate BEFORE caching: if a rejected shape were stored, a later
        # call with the same last dimension would skip this check in
        # ``call`` and operate on an input shorter than the CRC.
        if input_shape[-1] < self._encoder.crc_length:
            raise ValueError(
                "Input length must be greater than or equal to the CRC length."
            )
        self._bit_shape = input_shape

    def call(
        self, x_crc: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cyclic Redundancy Check (CRC) verification function.

        Verifies the CRC of ``x_crc``, returns the result of the validation,
        and removes the parity bits from ``x_crc``.

        :param x_crc: Binary tensor of arbitrary shape
            `[..., k + crc_length]`.

        :output x_info: Information bits without CRC parity bits of shape
            `[..., k]`.
        :output crc_valid: Result of the CRC validation for each codeword of
            shape `[..., 1]`.
        """
        # Re-validate whenever the codeword length differs from the last
        # accepted one (or nothing has been accepted yet)
        if self._bit_shape is None or x_crc.shape[-1] != self._bit_shape[-1]:
            self.build(tuple(x_crc.shape))

        crc_length = self._encoder.crc_length

        # Split into information bits and received parity bits
        x_info = x_crc[..., :-crc_length]
        x_parity_received = x_crc[..., -crc_length:]

        # Re-encode the information bits to obtain the expected parity
        x_parity_computed = self._encoder(x_info)[..., -crc_length:]

        # The encoder may run at a different precision than this block
        x_parity_computed = x_parity_computed.to(self.dtype)

        # The CRC holds only if every parity bit matches
        crc_valid = (x_parity_received == x_parity_computed).all(
            dim=-1, keepdim=True
        )

        return x_info, crc_valid