#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for LDPC channel encoding and utility functions."""
from typing import List, Optional, Tuple
import numbers
import warnings
import numpy as np
import scipy as sp
import torch
from importlib_resources import files, as_file
from sionna.phy import Block
from sionna.phy.fec.utils import int_mod_2
from . import codes
__all__ = ["LDPC5GEncoder"]
[docs]
class LDPC5GEncoder(Block):
# pylint: disable=line-too-long
r"""5G NR LDPC Encoder following the 3GPP 38.212 including rate-matching.
The implementation follows the 3GPP NR Initiative :cite:p:`3GPPTS38212`.
:param k: Number of information bits per codeword.
:param n: Desired codeword length.
:param num_bits_per_symbol: Number of bits per QAM symbol. If provided,
the codeword will be interleaved after rate-matching as specified
in Sec. 5.4.2.2 in :cite:p:`3GPPTS38212`.
:param bg: Basegraph to be used for the code construction.
If `None`, the encoder will automatically select the basegraph
according to :cite:p:`3GPPTS38212`.
:param precision: Precision used for internal calculations and outputs.
If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
:param device: Device for computation (e.g., 'cpu', 'cuda:0').
If `None`, :attr:`~sionna.phy.config.Config.device` is used.
:input bits: [..., k], `torch.float`.
Binary tensor containing the information bits to be encoded.
:input rv: `None` | list of int.
List of redundancy version indices (0–3) for HARQ-IR
rate-matching per TS 38.212 Table 5.4.2.1-2.
If `None` (default), standard RV 0 encoding is used and the
output shape is ``[..., n]``. When provided, each RV produces
an independent rate-matched codeword and the output shape
becomes ``[..., len(rv), n]``.
:output cw: [..., n] or [..., len(rv), n], `torch.float`.
Encoded codeword bits. Same shape as the input with the last
dimension replaced by ``n``. When ``rv`` is provided, an
additional dimension of size ``len(rv)`` is inserted before the
codeword dimension.
.. rubric:: Notes
As specified in :cite:p:`3GPPTS38212`, the encoder also performs
rate-matching (puncturing and shortening). Thus, the corresponding
decoder needs to `invert` these operations, i.e., must be compatible with
the 5G encoding scheme.
.. rubric:: Examples
.. code-block:: python
import torch
from sionna.phy.fec.ldpc import LDPC5GEncoder
# Create encoder for k=100 information bits and n=200 codeword bits
encoder = LDPC5GEncoder(k=100, n=200)
# Generate random information bits
u = torch.randint(0, 2, (10, 100), dtype=torch.float32)
c = encoder(u)
print(c.shape)
# [10, 200]
# HARQ: produce two redundancy versions
c_harq = encoder(u, rv=[0, 2])
print(c_harq.shape)
# [10, 2, 200]
"""
def __init__(
    self,
    k: int,
    n: int,
    num_bits_per_symbol: Optional[int] = None,
    bg: Optional[str] = None,
    precision: Optional[str] = None,
    device: Optional[str] = None,
    **kwargs,
):
    """Validate the code parameters and precompute the encoding tables.

    The meaning of the arguments is given in the class docstring.

    :raises TypeError: If ``k`` or ``n`` is not a number.
    :raises ValueError: If ``k``, ``n`` or the resulting coderate is
        outside the supported range.
    """
    super().__init__(precision=precision, device=device, **kwargs)

    if not isinstance(k, numbers.Number):
        raise TypeError("k must be a number.")
    if not isinstance(n, numbers.Number):
        raise TypeError("n must be a number.")
    k = int(k)  # k or n can be float (e.g. as result of n=k*r)
    n = int(n)

    if k > 8448:
        raise ValueError("Unsupported code length (k too large).")
    if k < 12:
        raise ValueError("Unsupported code length (k too small).")
    if n > (316 * 384):
        raise ValueError("Unsupported code length (n too large).")
    if n < 0:
        raise ValueError("Unsupported code length (n negative).")
    if n == 0:
        # Without this guard the coderate computation below would fail
        # with an unrelated ZeroDivisionError.
        raise ValueError("Unsupported code length (n must be positive).")

    # Init encoder parameters
    self._k = k  # number of input bits (= input shape)
    self._n = n  # the desired length (= output shape)
    self._coderate = k / n
    self._check_input = True  # check input for consistency (i.e., binary)

    # Allow actual code rates slightly larger than 948/1024
    # to account for the quantization procedure in 38.214 5.1.3.1
    if self._coderate > (948 / 1024):  # as specified in 38.212 5.4.2.1
        warnings.warn(
            f"Effective coderate r>948/1024 for n={n}, k={k}.")
    if self._coderate > 0.95:  # as specified in 38.212 5.4.2.1
        raise ValueError(f"Unsupported coderate (r>0.95) for n={n}, k={k}.")
    if self._coderate < (1 / 5):
        # outer rep. coding currently not supported
        raise ValueError("Unsupported coderate (r<1/5).")

    # Construct the basegraph according to 38.212
    self._bg = self._sel_basegraph(self._k, self._coderate, bg)
    self._z, self._i_ls, self._k_b = self._sel_lifting(self._k, self._bg)
    self._bm = self._load_basegraph(self._i_ls, self._bg)

    # Total number of codeword bits
    self._n_ldpc = self._bm.shape[1] * self._z
    # If K_real < K_target puncturing must be applied earlier
    self._k_ldpc = self._k_b * self._z

    # Construct explicit graph via lifting
    pcm = self._lift_basegraph(self._bm, self._z)
    pcm_a, pcm_b_inv, pcm_c1, pcm_c2 = self._gen_submat(
        self._bm, self._k_b, self._z, self._bg
    )

    # Store the sparse parity-check matrix (for decoding)
    self._pcm = pcm

    # Sub-matrices for fast encoding ("RU"-method) are stored as gather
    # indices (instead of explicit matmul) and registered as buffers for
    # CUDAGraph compatibility.
    for buf_name, sub_mat in (
        ("_pcm_a_ind", pcm_a),
        ("_pcm_b_inv_ind", pcm_b_inv),
        ("_pcm_c1_ind", pcm_c1),
        ("_pcm_c2_ind", pcm_c2),
    ):
        self.register_buffer(buf_name, torch.tensor(
            self._mat_to_ind(sub_mat), dtype=torch.int32, device=self.device
        ))

    # The output interleaver is only built if a modulation order is given
    self._num_bits_per_symbol = num_bits_per_symbol
    if num_bits_per_symbol is not None:
        out_int, out_int_inv = self.generate_out_int(
            self._n, self._num_bits_per_symbol
        )
        self.register_buffer("_out_int", torch.tensor(
            out_int, dtype=torch.int32, device=self.device
        ))
        self.register_buffer("_out_int_inv", torch.tensor(
            out_int_inv, dtype=torch.int32, device=self.device
        ))
###############################
# Public methods and properties
###############################
@property
def k(self) -> int:
"""Number of input information bits."""
return self._k
@property
def n(self) -> int:
"""Number of output codeword bits."""
return self._n
@property
def coderate(self) -> float:
"""Coderate of the LDPC code after rate-matching."""
return self._coderate
@property
def k_ldpc(self) -> int:
"""Number of LDPC information bits after rate-matching."""
return self._k_ldpc
@property
def n_ldpc(self) -> int:
"""Number of LDPC codeword bits before rate-matching."""
return self._n_ldpc
@property
def pcm(self) -> sp.sparse.csr_matrix:
"""Parity-check matrix for given code parameters."""
return self._pcm
@property
def z(self) -> int:
"""Lifting factor of the basegraph."""
return self._z
@property
def num_bits_per_symbol(self) -> Optional[int]:
"""Modulation order used for the rate-matching output interleaver."""
return self._num_bits_per_symbol
@property
def k_filler(self) -> int:
"""Number of filler bits added to pad ``k`` to ``k_ldpc``."""
return self._k_ldpc - self._k
@property
def n_cb(self) -> int:
"""Circular buffer length (excludes first 2Z)."""
return self._n_ldpc - 2 * self._z
@property
def n_cb_comp(self) -> int:
"""Compressed circular buffer length (excludes first 2Z and fillers)."""
return self.n_cb - self.k_filler
@property
def rv_starts(self) -> List[int]:
r"""RV start positions ``k0`` from TS 38.212 Table 5.4.2.1-2.
Returns a list of length 4 indexed by ``rv_id``. Values are in
spec buffer space (after the first 2Z punctured positions,
before filler removal).
.. math::
k_0 = \left\lfloor \frac{\text{coeff} \cdot N_{cb}}
{N_{\text{cols}} \cdot Z_c} \right\rfloor \cdot Z_c
where :math:`N_{\text{cols}}` is 66 (BG1) or 50 (BG2).
Currently assumes ``I_LBRM=0``, i.e., ``N_cb = N``.
"""
n_cb = self.n_cb
z = self._z
if self._bg == "bg1":
coeffs = [0, 17, 33, 56]
n_cols = 66
else:
coeffs = [0, 13, 25, 43]
n_cols = 50
return [(c * n_cb // (n_cols * z)) * z for c in coeffs]
@property
def out_int(self) -> torch.Tensor:
"""Output interleaver sequence as defined in 5.4.2.2."""
return self._out_int
@property
def out_int_inv(self) -> torch.Tensor:
"""Inverse output interleaver sequence as defined in 5.4.2.2."""
return self._out_int_inv
#################
# Utility methods
#################
def _k0_comp(self, k0: int) -> int:
"""Map ``k0`` from spec buffer space to compressed RM buffer space.
The compressed buffer omits the contiguous filler block.
"""
filler_start = self._k - 2 * self._z
filler_len = self.k_filler
if filler_len <= 0 or k0 < filler_start:
return k0
if k0 < filler_start + filler_len:
return filler_start
return k0 - filler_len
[docs]
def get_start_positions_comp(self, rv_list: List[int]) -> List[int]:
"""Return compressed-buffer start positions for a list of RV indices.
:param rv_list: List of RV indices (each 0–3).
:output: Corresponding start positions in the compressed RM buffer.
"""
rv_start_map = self.rv_starts
return [self._k0_comp(rv_start_map[rv_id]) for rv_id in rv_list]
[docs]
def generate_out_int(
    self, n: int, num_bits_per_symbol: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Build the LDPC output interleaver pattern of
    Sec 5.4.2.2 in :cite:p:`3GPPTS38212`.

    :param n: Desired output sequence length.
    :param num_bits_per_symbol: Number of bits per QAM symbol,
        i.e., the modulation order.

    :output: Tuple ``(perm_seq, perm_seq_inv)`` holding the interleaver
        permutation and its inverse, each of length ``n``.

    .. rubric:: Notes

    The interleaver pattern depends on the modulation order and helps to
    reduce dependencies in bit-interleaved coded modulation (BICM) schemes
    combined with higher order modulation.
    """
    # Floats are accepted as long as they hold an integer value
    if n % 1 != 0:
        raise ValueError("n must be int.")
    if num_bits_per_symbol % 1 != 0:
        raise ValueError("num_bits_per_symbol must be int.")
    n = int(n)
    if n <= 0:
        raise ValueError("n must be a positive integer.")
    if num_bits_per_symbol <= 0:
        raise ValueError("num_bits_per_symbol must be a positive integer.")
    num_bits_per_symbol = int(num_bits_per_symbol)
    if n % num_bits_per_symbol != 0:
        raise ValueError("n must be a multiple of num_bits_per_symbol.")

    # Pattern as defined in Sec 5.4.2.2: the indices 0..n-1 are written
    # row-wise into a [num_bits_per_symbol, n/num_bits_per_symbol] array
    # and read out column-wise, i.e., output position i + j*Q reads
    # input position i*(n/Q) + j.
    num_symbols = n // num_bits_per_symbol
    grid = np.arange(n, dtype=int).reshape(num_bits_per_symbol, num_symbols)
    perm_seq = grid.T.reshape(-1)

    perm_seq_inv = np.argsort(perm_seq)
    return perm_seq, perm_seq_inv
def _sel_basegraph(
self, k: int, r: float, bg_: Optional[str] = None
) -> str:
"""Select basegraph according to :cite:p:`3GPPTS38212` and check for
consistency.
"""
# If bg is explicitly provided, we only check for consistency
if bg_ is None:
if k <= 292:
bg = "bg2"
elif k <= 3824 and r <= 0.67:
bg = "bg2"
elif r <= 0.25:
bg = "bg2"
else:
bg = "bg1"
elif bg_ in ("bg1", "bg2"):
bg = bg_
else:
raise ValueError("Basegraph must be bg1, bg2 or None.")
# Check for consistency
if bg == "bg1" and k > 8448:
raise ValueError("K is not supported by BG1 (too large).")
if bg == "bg2" and k > 3840:
raise ValueError(f"K is not supported by BG2 (too large) k={k}.")
if bg == "bg1" and r < 1 / 3:
raise ValueError(
"Only coderate>1/3 supported for BG1. "
"Remark: Repetition coding is currently not supported."
)
if bg == "bg2" and r < 1 / 5:
raise ValueError(
"Only coderate>1/5 supported for BG2. "
"Remark: Repetition coding is currently not supported."
)
return bg
def _load_basegraph(self, i_ls: int, bg: str) -> np.ndarray:
    """Read the basegraph for set index ``i_ls`` from the bundled csv file.

    ``i_ls`` is the sub-index of the basegraph and is fixed during the
    lifting selection.

    :param i_ls: Set index in the range 0...7.
    :param bg: Name of the basegraph, ``"bg1"`` or ``"bg2"``.

    :output: Basegraph matrix holding the shift values; positions without
        an entry are marked with -1.
    """
    if i_ls > 7:
        raise ValueError("i_ls too large.")
    if i_ls < 0:
        raise ValueError("i_ls cannot be negative.")

    # csv files are taken from 38.212 and dimension is explicitly given
    bg_dims = {"bg1": (46, 68), "bg2": (42, 52)}
    if bg not in bg_dims:
        raise ValueError("Basegraph not supported.")
    bm = np.full(bg_dims[bg], -1.0)  # -1 marks the "None" positions

    # The csv files are located in the folder "codes"
    csv_resource = files(codes).joinpath(f"5G_{bg}.csv")
    with as_file(csv_resource) as csv_path:
        table = np.genfromtxt(csv_path, delimiter=";")

    # Reconstruct BG for given i_ls. The row index column is only filled
    # on the first line of each row group, so it is carried over.
    value_col = i_ls + 2  # i_ls entries start at offset 2
    row_idx = 0
    for line in range(2, table.shape[0]):
        if not np.isnan(table[line, 0]):
            row_idx = int(table[line, 0])
        col_idx = int(table[line, 1])  # second csv column is column index
        bm[row_idx, col_idx] = table[line, value_col]
    return bm
def _lift_basegraph(self, bm: np.ndarray, z: int) -> sp.sparse.csr_matrix:
"""Lift basegraph with lifting factor ``z`` and shifted identities as
defined by the entries of ``bm``.
"""
num_nonzero = np.sum(bm >= 0) # num of non-neg elements in bm
# Init all non-zero row/column indices
r_idx = np.zeros(z * num_nonzero)
c_idx = np.zeros(z * num_nonzero)
data = np.ones(z * num_nonzero)
# Row/column indices of identity matrix for lifting
im = np.arange(z)
idx = 0
for r in range(bm.shape[0]):
for c in range(bm.shape[1]):
if bm[r, c] == -1: # -1 is used as all-zero matrix placeholder
pass # do nothing (sparse)
else:
# Roll matrix by bm[r,c]
c_roll = np.mod(im + bm[r, c], z)
# Append rolled identity matrix to pcm
r_idx[idx * z : (idx + 1) * z] = r * z + im
c_idx[idx * z : (idx + 1) * z] = c * z + c_roll
idx += 1
# Generate lifted sparse matrix from indices
pcm = sp.sparse.csr_matrix(
(data, (r_idx, c_idx)), shape=(z * bm.shape[0], z * bm.shape[1])
)
return pcm
def _sel_lifting(self, k: int, bg: str) -> Tuple[int, int, int]:
"""Select lifting as defined in Sec. 5.2.2 in :cite:p:`3GPPTS38212`.
We assume B < K_cb, thus B'= B and C = 1, i.e., no
additional CRC is appended. Thus, K' = B'/C = B and B is our K.
Z is the lifting factor.
i_ls is the set index ranging from 0...7 (specifying the exact bg
selection).
k_b is the number of information bit columns in the basegraph.
"""
# Lifting set according to 38.212 Tab 5.3.2-1
s_val = [
[2, 4, 8, 16, 32, 64, 128, 256],
[3, 6, 12, 24, 48, 96, 192, 384],
[5, 10, 20, 40, 80, 160, 320],
[7, 14, 28, 56, 112, 224],
[9, 18, 36, 72, 144, 288],
[11, 22, 44, 88, 176, 352],
[13, 26, 52, 104, 208],
[15, 30, 60, 120, 240],
]
if bg == "bg1":
k_b = 22
else:
if k > 640:
k_b = 10
elif k > 560:
k_b = 9
elif k > 192:
k_b = 8
else:
k_b = 6
# Find the min of Z from Tab. 5.3.2-1 s.t. k_b*Z>=K'
min_val = 100000
z = 0
i_ls = 0
i = -1
for s in s_val:
i += 1
for s1 in s:
x = k_b * s1
if x >= k:
# Valid solution
if x < min_val:
min_val = x
z = s1
i_ls = i
# And set K=22*Z for bg1 and K=10Z for bg2
if bg == "bg1":
k_b = 22
else:
k_b = 10
return z, i_ls, k_b
def _gen_submat(
self, bm: np.ndarray, k_b: int, z: int, bg: str
) -> Tuple[
sp.sparse.csr_matrix,
sp.sparse.csr_matrix,
sp.sparse.csr_matrix,
sp.sparse.csr_matrix,
]:
"""Split the basegraph into multiple sub-matrices such that efficient
encoding is possible.
"""
g = 4 # code property (always fixed for 5G)
mb = bm.shape[0] # number of CN rows in basegraph (BG property)
bm_a = bm[0:g, 0:k_b]
bm_b = bm[0:g, k_b : (k_b + g)]
bm_c1 = bm[g:mb, 0:k_b]
bm_c2 = bm[g:mb, k_b : (k_b + g)]
# H could be sliced immediately (but easier to implement if based on B)
hm_a = self._lift_basegraph(bm_a, z)
hm_c1 = self._lift_basegraph(bm_c1, z)
hm_c2 = self._lift_basegraph(bm_c2, z)
hm_b_inv = self._find_hm_b_inv(bm_b, z, bg)
return hm_a, hm_b_inv, hm_c1, hm_c2
def _find_hm_b_inv(
self, bm_b: np.ndarray, z: int, bg: str
) -> sp.sparse.csr_matrix:
"""For encoding we need to find the inverse of `hm_b` such that
`hm_b^-1 * hm_b = I`.
Could be done sparse.
For BG1 the structure of hm_b is given as (for all values of i_ls)
hm_b =
[P_A I 0 0
P_B I I 0
0 0 I I
P_A 0 0 I]
where P_B and P_A are shifted identities.
The inverse can be found by solving a linear system of equations
hm_b_inv =
[P_B^-1, P_B^-1, P_B^-1, P_B^-1,
I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1, I+P_A*P_B^-1,
P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1].
For bg2 the structure of hm_b is given as (for all values of i_ls)
hm_b =
[P_A I 0 0
0 I I 0
P_B 0 I I
P_A 0 0 I]
where P_B and P_A are shifted identities.
The inverse can be found by solving a linear system of equations
hm_b_inv =
[P_B^-1, P_B^-1, P_B^-1, P_B^-1,
I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
I+P_A*P_B^-1, I+P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1]
Note: the inverse of B is simply a shifted identity matrix with
negative shift direction.
"""
# Permutation indices
pm_a = int(bm_b[0, 0])
if bg == "bg1":
pm_b_inv = int(-bm_b[1, 0])
else: # structure of B is slightly different for bg2
pm_b_inv = int(-bm_b[2, 0])
hm_b_inv = np.zeros([4 * z, 4 * z])
im = np.eye(z)
am = np.roll(im, pm_a, axis=1)
b_inv = np.roll(im, pm_b_inv, axis=1)
ab_inv = np.matmul(am, b_inv)
# Row 0
hm_b_inv[0:z, 0:z] = b_inv
hm_b_inv[0:z, z : 2 * z] = b_inv
hm_b_inv[0:z, 2 * z : 3 * z] = b_inv
hm_b_inv[0:z, 3 * z : 4 * z] = b_inv
# Row 1
hm_b_inv[z : 2 * z, 0:z] = im + ab_inv
hm_b_inv[z : 2 * z, z : 2 * z] = ab_inv
hm_b_inv[z : 2 * z, 2 * z : 3 * z] = ab_inv
hm_b_inv[z : 2 * z, 3 * z : 4 * z] = ab_inv
# Row 2
if bg == "bg1":
hm_b_inv[2 * z : 3 * z, 0:z] = ab_inv
hm_b_inv[2 * z : 3 * z, z : 2 * z] = ab_inv
hm_b_inv[2 * z : 3 * z, 2 * z : 3 * z] = im + ab_inv
hm_b_inv[2 * z : 3 * z, 3 * z : 4 * z] = im + ab_inv
else: # for bg2 the structure is slightly different
hm_b_inv[2 * z : 3 * z, 0:z] = im + ab_inv
hm_b_inv[2 * z : 3 * z, z : 2 * z] = im + ab_inv
hm_b_inv[2 * z : 3 * z, 2 * z : 3 * z] = ab_inv
hm_b_inv[2 * z : 3 * z, 3 * z : 4 * z] = ab_inv
# Row 3
hm_b_inv[3 * z : 4 * z, 0:z] = ab_inv
hm_b_inv[3 * z : 4 * z, z : 2 * z] = ab_inv
hm_b_inv[3 * z : 4 * z, 2 * z : 3 * z] = ab_inv
hm_b_inv[3 * z : 4 * z, 3 * z : 4 * z] = im + ab_inv
# Return results as sparse matrix
return sp.sparse.csr_matrix(hm_b_inv)
def _mat_to_ind(self, mat: sp.sparse.csr_matrix) -> np.ndarray:
"""Helper to transform matrix into index representation for
gather. An index pointing to the ``last_ind+1`` is used for
non-existing edges due to irregular degrees.
"""
m = mat.shape[0]
n = mat.shape[1]
# Transpose mat for sorted column format
c_idx, r_idx, _ = sp.sparse.find(mat.transpose())
# Sort indices explicitly, as scipy.sparse.find changed from column to
# row sorting in scipy>=1.11
idx = np.argsort(r_idx)
c_idx = c_idx[idx]
r_idx = r_idx[idx]
# Find max number of non-zero entries
n_max = np.max(mat.getnnz(axis=1))
# Init index array with n (pointer to last_ind+1, will be a default
# value)
gat_idx = np.zeros([m, n_max]) + n
r_val = -1
c_val = 0
for idx in range(len(c_idx)):
# Check if same row or if a new row starts
if r_idx[idx] != r_val:
r_val = r_idx[idx]
c_val = 0
gat_idx[r_val, c_val] = c_idx[idx]
c_val += 1
return gat_idx.astype(np.int32)
def _matmul_gather(self, mat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
"""Implements a fast sparse matmul via gather function."""
# Add 0 entry for gather-reduce_sum operation
# (otherwise ragged Tensors are required)
bs = vec.shape[0]
vec = torch.cat([vec, torch.zeros(bs, 1, dtype=vec.dtype, device=vec.device)], dim=1)
# Gather and sum
retval = vec[:, mat] # [bs, m, n_max]
retval = retval.sum(dim=-1)
return retval
def _encode_fast(self, s: torch.Tensor) -> torch.Tensor:
    """Encode a batch of information words with the gather-based tables.

    :param s: ``[bs, k_ldpc]`` information bits (including fillers).

    :output: ``[bs, n_ldpc, 1]`` codewords consisting of ``s``, the first
        parity part and the second parity part, reduced modulo 2.
    """
    gather = self._matmul_gather

    # First part of the parity bits: p_a = B^-1 * A * s
    p_a = gather(self._pcm_b_inv_ind, gather(self._pcm_a_ind, s))

    # Second part of the parity bits p_b follows from
    # C_1*s' + C_2*p_a' + p_b' = 0
    p_b = gather(self._pcm_c1_ind, s) + gather(self._pcm_c2_ind, p_a)

    # The mod-2 reduction is applied once to the assembled codeword
    codeword = int_mod_2(torch.cat([s, p_a, p_b], dim=1))

    return codeword.unsqueeze(-1)  # append a trailing singleton dimension
[docs]
def build(self, input_shape: tuple, **kwargs) -> None:
"""Build block and check for valid input shapes."""
if input_shape[-1] != self._k:
raise ValueError(f"Last dimension must be of length k={self._k}.")
@torch.compiler.disable
def _validate_input(self, u: torch.Tensor) -> None:
"""Validate input tensor for binary values (disabled from compilation).
This method is excluded from torch.compile to avoid recompilation
issues caused by the mutable _check_input flag.
"""
if self._check_input:
is_binary = ((u == 0) | (u == 1)).all()
if not is_binary:
raise ValueError("Input must be binary.")
# Input datatype consistency should only be evaluated once
self._check_input = False
def call(
    self,
    bits: torch.Tensor,
    /,
    *,
    rv: Optional[List[int]] = None,
) -> torch.Tensor:
    """Encode information bits with the 5G LDPC code and rate-match them.

    :param bits: ``[..., k]`` information bits to be encoded.
    :param rv: Redundancy version indices (0–3) for HARQ-IR.
        If `None`, standard RV 0 rate-matching is used.

    :output: ``[..., n]`` or ``[..., len(rv), n]`` encoded codewords.

    :raises ValueError: If ``rv`` is empty or contains an invalid index,
        or if the (first) input is not binary.
    """
    # Validate rv early
    if rv is not None:
        rv = list(rv)
        if len(rv) == 0:
            raise ValueError("rv must be a non-empty list of RV indices.")
        for rv_id in rv:
            if rv_id not in (0, 1, 2, 3):
                raise ValueError(f"Invalid RV index {rv_id}; must be 0–3.")

    # Flatten all leading dimensions into a single batch dimension
    input_shape = list(bits.shape)
    u = bits.reshape(-1, input_shape[-1])
    self._validate_input(u)
    num_cw = u.shape[0]

    # Append the filler bits (zeros) and encode
    fillers = u.new_zeros((num_cw, self._k_ldpc - self._k))
    u_fill = torch.cat([u, fillers], dim=1)
    c = self._encode_fast(u_fill.to(self.dtype)).reshape(num_cw, self._n_ldpc)

    # Drop the filler bits again -> compressed codeword
    c_no_filler = torch.cat(
        [c[:, : self._k], c[:, self._k_ldpc :]], dim=1
    )

    # --- Rate matching ------------------------------------------------
    # Compressed RM buffer: the first 2Z positions are always punctured
    c_rm = c_no_filler[:, 2 * self._z :]  # [batch, n_cb_comp]

    if rv is None:
        # RV 0: read n bits from the start of the buffer
        c_out = c_rm[:, : self._n]  # [batch, n]
        output_shape = input_shape[:-1] + [self.n]
    else:
        starts = self.get_start_positions_comp(rv)
        buf_len = self.n_cb_comp
        per_rv = []
        for start in starts:
            # Number of bits that must be read after wrapping around
            overflow = start + self._n - buf_len
            if overflow <= 0:
                segment = c_rm[:, start : start + self._n]
            else:
                segment = torch.cat(
                    [c_rm[:, start:], c_rm[:, :overflow]], dim=1
                )
            per_rv.append(segment)
        c_out = torch.stack(per_rv, dim=1)  # [batch, num_rvs, n]
        output_shape = input_shape[:-1] + [len(rv), self.n]

    # Output interleaver (Sec. 5.4.2.2) — works on last dim for any rank
    if self._num_bits_per_symbol is not None:
        c_out = c_out[..., self._out_int]

    # Let reshape infer the first dimension
    output_shape[0] = -1
    return c_out.reshape(output_shape).to(bits.dtype)