#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Convolutional code Viterbi and BCJR decoding."""
from typing import Optional, Tuple, Union
import warnings
import numpy as np
import torch
from sionna.phy import Block
from sionna.phy.fec.utils import int2bin, int_mod_2
from sionna.phy.fec.conv.utils import resolve_gen_poly, Trellis
__all__ = ["ViterbiDecoder", "BCJRDecoder"]
class ViterbiDecoder(Block):
    r"""Applies Viterbi decoding to a sequence of noisy codeword bits.

    Implements the Viterbi decoding algorithm :cite:p:`Viterbi` that returns an
    estimate of the information bits for a noisy convolutional codeword.
    Takes as input either LLR values (``method`` = ``'soft_llr'``) or hard
    bit values (``method`` = ``'hard'``) and returns a hard decided estimation
    of the information bits.

    :param encoder: If ``encoder`` is provided as input, the following input
        parameters are not required and will be ignored: ``gen_poly``,
        ``rate``, ``constraint_length``, ``rsc``, ``terminate``. They will be
        inferred from the ``encoder`` object itself. If `None`, the above
        parameters must be provided explicitly.
    :param gen_poly: Tuple of strings with each string being a 0,1 sequence.
        If `None`, ``rate`` and ``constraint_length`` must be provided.
    :param rate: Valid values are 1/3 and 0.5. Only required if ``gen_poly``
        is `None`.
    :param constraint_length: Valid values are between 3 and 8 inclusive.
        Only required if ``gen_poly`` is `None`.
    :param rsc: Boolean flag indicating whether the encoder is
        recursive-systematic for given generator polynomials.
        `True` indicates encoder is recursive-systematic.
        `False` indicates encoder is feed-forward non-systematic.
        Defaults to `False`.
    :param terminate: Boolean flag indicating whether the codeword is
        terminated.
        `True` indicates codeword is terminated to all-zero state.
        `False` indicates codeword is not terminated.
        Defaults to `False`.
    :param method: Valid values are ``'soft_llr'`` or ``'hard'``. In computing
        path metrics, ``'soft_llr'`` expects channel LLRs as input.
        ``'hard'`` assumes a binary symmetric channel (BSC) with 0/1 values
        as inputs. In case of ``'hard'``, inputs will be quantized to 0/1
        values.
    :param return_info_bits: Boolean flag indicating whether only the
        information bits or all codeword bits are returned. Defaults to `True`.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input inputs: [..., n], `torch.float`.
        Tensor containing the (noisy) channel output symbols where ``n``
        denotes the codeword length.

    :output output: [..., rate \* n], `torch.float`.
        Binary tensor containing the estimates of the information bit tensor.
        If ``return_info_bits`` is `False`, the estimated codeword bits of
        shape [..., n] are returned instead.

    .. rubric:: Notes

    A full implementation of the decoder rather than a windowed approach
    is used. For a given codeword of duration ``T``, the path metric is
    computed from time ``0`` to ``T`` and the path with optimal metric at
    time ``T`` is selected. The optimal path is then traced back from ``T``
    to ``0`` to output the estimate of the information bit vector used to
    encode. For larger codewords, note that the current method is sub-optimal
    in terms of memory utilization and latency.

    This method is also excluded from ``torch.compile`` using
    ``@torch.compiler.disable`` because the Viterbi algorithm's inherently
    sequential structure (forward pass, traceback, output extraction) causes
    extremely long compilation times due to loop unrolling.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.fec.conv import ViterbiDecoder

        decoder = ViterbiDecoder(rate=0.5, constraint_length=5)
        llr = torch.randn(10, 200)  # Received LLRs
        u_hat = decoder(llr)
        print(u_hat.shape)
        # torch.Size([10, 100])
    """

    def __init__(
        self,
        *,
        encoder: Optional["ConvEncoder"] = None,
        gen_poly: Optional[Tuple[str, ...]] = None,
        rate: float = 1/2,
        constraint_length: int = 3,
        rsc: bool = False,
        terminate: bool = False,
        method: str = 'soft_llr',
        return_info_bits: bool = True,
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if encoder is not None:
            # Reuse the code definition of the encoder; the explicit code
            # parameters are ignored in this case.
            self._gen_poly = encoder.gen_poly
            self._trellis = encoder.trellis
            self._terminate = encoder.terminate
            if self._trellis.device != self.device:
                self._trellis.to(self.device)
        else:
            self._gen_poly = resolve_gen_poly(gen_poly, rate,
                                              constraint_length)
            self._trellis = Trellis(self.gen_poly, rsc=rsc,
                                    device=self.device)
            self._terminate = terminate

        self._coderate_desired = 1 / len(self.gen_poly)
        self._mu = self._trellis.mu

        if method not in ('soft_llr', 'hard'):
            raise ValueError("method must be 'soft_llr' or 'hard'.")

        # conv_k denotes number of input bit streams
        # Can only be 1 in current implementation
        self._conv_k = self._trellis.conv_k
        # conv_n denotes number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        # For conv codes, the code dimensions are unknown during initialization
        self._k = None
        self._n = None
        self._num_syms = None

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        self._method = method
        self._return_info_bits = return_info_bits

        # If i->j state transition emits symbol k, gather with ipst_op_idx
        # gathers (i,k) element from input in row j.
        self._ipst_op_idx = None

        # Pre-computed output bit patterns for branch metric calculation
        # Register buffer placeholder for CUDAGraph compatibility
        self.register_buffer("_op_bits", None)

    @property
    def gen_poly(self) -> Tuple[str, ...]:
        """Generator polynomial used by the encoder."""
        return self._gen_poly

    @property
    def coderate(self) -> float:
        """Rate of the code used in the encoder."""
        if self.terminate and self._n is None:
            warnings.warn(
                "Due to termination, the true coderate is lower "
                "than the returned design rate. "
                "The exact true rate is dependent on the value of n and "
                "hence cannot be computed before the first call().")
            self._coderate = self._coderate_desired
        elif self.terminate and self._n is not None:
            # The mu termination symbols carry no information bits
            k = self._coderate_desired * self._n - self._mu
            self._coderate = k / self._n
        else:
            self._coderate = self._coderate_desired
        return self._coderate

    @property
    def trellis(self) -> Trellis:
        """Trellis object used during encoding."""
        return self._trellis

    @property
    def terminate(self) -> bool:
        """Indicates if the encoder is terminated during codeword generation."""
        return self._terminate

    @property
    def k(self) -> Optional[int]:
        """Number of information bits per codeword."""
        if self._k is None:
            warnings.warn("The value of k cannot be computed before the "
                          "first call().")
        return self._k

    @property
    def n(self) -> Optional[int]:
        """Number of codeword bits."""
        if self._n is None:
            warnings.warn("The value of n cannot be computed before the "
                          "first call().")
        return self._n

    def _mask_by_tonode(self) -> torch.Tensor:
        """Creates index matrix for gathering by to-node.

        Returns Ns x Ni x 2 index matrix. When applied as gather index on a
        Ns x num_ops matrix ((i,j) denoting metric for prev_st=i and output=j)
        the output is matrix sorted by next_state. Row i in output
        denotes the 2 possible metrics for transition to state i.
        """
        cnst = self._ns * self._ni
        from_nodes_vec = self._trellis.from_nodes.reshape(cnst)
        op_idx = self._trellis.op_by_tonode.reshape(cnst)
        st_op_idx = torch.stack([from_nodes_vec, op_idx], dim=-1)
        return st_op_idx.reshape(self._ns, self._ni, 2)

    def _bmcalc(self, y: torch.Tensor) -> torch.Tensor:
        """Calculate branch metrics for a given noisy codeword tensor.

        For each time period t, computes a cost of symbol vector y[t] with
        respect to each possible output symbol (smaller is better). For
        ``method='soft_llr'`` the cost is the negated sum of the inputs
        weighted by the sign (+1 for bit 0, -1 for bit 1) of each candidate
        output bit. For ``method='hard'`` the cost is the L1 (Hamming)
        distance between the inputs and the candidate output bits.

        :param y: Tensor of shape [bs, n].
        :output bm: Branch metrics of shape [bs, no, num_syms].
        """
        batch_size = y.shape[0]

        # y_exp: [bs, num_syms, 1, conv_n]
        y_exp = y.reshape(batch_size, -1, self._conv_n).unsqueeze(2)
        # op_bits: [no, conv_n] - pre-computed in build()
        # Expanded for broadcasting: [1, 1, no, conv_n]
        op_bits_exp = self._op_bits.unsqueeze(0).unsqueeze(0)

        if self._method == 'soft_llr':
            op_mat_sign = 1 - 2. * op_bits_exp  # [1, 1, no, conv_n]
            # [bs, num_syms, no, conv_n]
            per_bit_cost = -1. * y_exp * op_mat_sign
        else:  # method == 'hard'
            # [bs, num_syms, no, conv_n]
            per_bit_cost = torch.abs(y_exp - op_bits_exp)

        bm = per_bit_cost.sum(dim=-1)  # [bs, num_syms, no]
        return bm.permute(0, 2, 1)  # [bs, no, num_syms]

    def _update_fwd(
        self,
        init_cm: torch.Tensor,
        bm_mat: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass computing cumulative metrics and traceback states.

        :param init_cm: Initial cumulative metrics of shape [bs, ns].
        :param bm_mat: Branch metrics of shape [bs, no, num_syms].
        :output cm: Cumulative metrics of shape [bs, ns, num_syms].
        :output tb: Surviving predecessor state of every state at every time
            step, of shape [bs, ns, num_syms].
        """
        batch_size = init_cm.shape[0]
        cm_list = []
        tb_list = []

        # Pre-compute gather indices for vectorized gather-by-tonode
        # _ipst_op_idx[to_st, inp_idx, :] = (from_st, op_sym)
        from_st_idx = self._ipst_op_idx[:, :, 0]  # [ns, ni]
        op_sym_idx = self._ipst_op_idx[:, :, 1]  # [ns, ni]

        # from_nodes[to_st, :] lists the possible predecessors of to_st.
        # The expanded view does not depend on the time step, so it is
        # created once instead of once per trellis step.
        from_nodes_exp = self._trellis.from_nodes.unsqueeze(0).expand(
            batch_size, -1, -1)  # [bs, ns, ni]

        prev_cm = init_cm
        for sym in range(self._num_syms):
            metrics_t = bm_mat[..., sym]  # [bs, no]

            # (s, j) is the path metric at state s with transition op=j
            # [bs, ns, no]
            sum_metric = prev_cm.unsqueeze(2) + metrics_t.unsqueeze(1)

            # Vectorized gather by to-node using advanced indexing:
            # sum_metric_bytonode[b, to_st, inp_idx]
            #     = sum_metric[b, from_st, op_sym]
            sum_metric_bytonode = sum_metric[:, from_st_idx, op_sym_idx]

            # Index of the surviving (minimum metric) incoming branch
            tb_state_idx = sum_metric_bytonode.argmin(dim=2)  # [bs, ns]

            # tb_states[b, to_st] = from_nodes[to_st, tb_state_idx[b, to_st]]
            tb_states = from_nodes_exp.gather(
                2, tb_state_idx.unsqueeze(2).long()
            ).squeeze(2).to(torch.int32)

            cum_t = sum_metric_bytonode.min(dim=2).values

            cm_list.append(cum_t)
            tb_list.append(tb_states)
            prev_cm = cum_t

        cm = torch.stack(cm_list, dim=-1)  # [bs, ns, num_syms]
        tb = torch.stack(tb_list, dim=-1)  # [bs, ns, num_syms]
        return cm, tb

    def _optimal_path(
        self,
        cm_: torch.Tensor,
        tb_: torch.Tensor,
    ) -> torch.Tensor:
        """Compute optimal path (state at each time t) given cm_ & tb_.

        :param cm_: Cumulative metrics for each state at time t [bs, ns, T]
        :param tb_: Traceback state for each state at time t [bs, ns, T]
        :output opt_path: Optimal path of shape [bs, T]
        """
        batch_size = cm_.shape[0]
        num_syms = tb_.shape[-1]
        optst_list = [None] * num_syms

        if self._terminate:
            # A terminated codeword always ends in the all-zero state
            opt_term_state = torch.zeros(batch_size, dtype=torch.int32,
                                         device=self.device)
        else:
            opt_term_state = cm_[:, :, -1].argmin(dim=1).to(torch.int32)
        optst_list[num_syms - 1] = opt_term_state

        # Walk backwards, following the stored surviving predecessors
        for sym in range(num_syms - 1, 0, -1):
            opt_st = optst_list[sym]
            optst_list[sym - 1] = tb_[:, :, sym].gather(
                1, opt_st.unsqueeze(1).long()
            ).squeeze(1).to(torch.int32)

        return torch.stack(optst_list, dim=1)  # [bs, num_syms]

    def _op_bits_path(
        self,
        paths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Given a path, compute the input bit stream that results in the path.

        Used in call() where the input is optimal path (seq of states) such
        as the path returned by _optimal_path.

        :param paths: State sequences of shape [bs, T + 1], including the
            initial state.
        :output ip_bit_vec_est: Input bit of every transition, [bs, T].
        :output ip_sym_vec_est: Entry of the trellis ``op_mat`` for every
            transition, stacked along dimension 1.
        """
        paths = paths.to(torch.int32)
        ip_bits_list = []
        dec_syms_list = []

        ni = self._trellis.ni
        ip_sym_mask = torch.arange(ni, device=self.device).unsqueeze(0)

        for sym in range(1, paths.shape[-1]):
            prev_st = paths[:, sym - 1]
            curr_st = paths[:, sym]

            # Get output symbol for transition
            dec_syms_list.append(self._trellis.op_mat[prev_st, curr_st])

            # Find which input bit caused the transition:
            # to_nodes[prev_st] gives the possible next states
            to_states = self._trellis.to_nodes[prev_st]  # [bs, ni]
            match_st = (to_states == curr_st.unsqueeze(1))  # [bs, ni]

            # Get the input bit (0 or 1) that matches
            ip_bits_list.append((match_st * ip_sym_mask).sum(dim=-1))

        ip_bit_vec_est = torch.stack(ip_bits_list, dim=1)
        ip_sym_vec_est = torch.stack(dec_syms_list, dim=1)
        return ip_bit_vec_est, ip_sym_vec_est

    def build(self, input_shape: torch.Size):
        """Build block and check dimensions.

        :param input_shape: Shape of the input tensor; the last dimension is
            the codeword length ``n``.
        :raises ValueError: If ``n`` is not a multiple of the number of
            output bits per trellis step.
        """
        self._n = input_shape[-1]

        divisible = self._n % self._conv_n
        if divisible != 0:
            raise ValueError('Length of codeword should be divisible by '
                             'number of output bits per symbol.')

        self._num_syms = int(self._n * self._coderate_desired)
        self._num_term_syms = self._mu if self.terminate else 0
        self._k = self._num_syms - self._num_term_syms

        # Build index mask
        self._ipst_op_idx = self._mask_by_tonode()

        # Pre-compute output bit patterns for branch metric calculation
        # Shape: [no, conv_n]
        op_bits = np.stack(
            [int2bin(op, self._conv_n) for op in range(self._no)]
        )
        # Register as buffer for CUDAGraph compatibility
        self.register_buffer("_op_bits",
                             torch.tensor(op_bits, dtype=self.dtype,
                                          device=self.device))

        # Move trellis to correct device if needed
        if self._trellis.device != self.device:
            self._trellis.to(self.device)

    @torch.compiler.disable
    def call(self, inputs: torch.Tensor, /) -> torch.Tensor:
        """Viterbi decoding function.

        :param inputs: Noisy codeword tensor of shape [..., n] where ``n`` is
            the codeword length. All leading dimensions are treated as batch
            dimensions.
        :output output: Decoded information bits of shape [..., k] if
            ``return_info_bits`` is `True`, otherwise the estimated codeword
            bits of shape [..., n].
        """
        LARGEDIST = 2.**20

        # Ensure build() has been called, and rebuild if the codeword length
        # changed since the last call (otherwise the reshape below would
        # silently use a stale n).
        if self._n is None or inputs.shape[-1] != self._n:
            self.build(inputs.shape)

        if self._method == 'hard':
            # Ensure binary values
            inputs = int_mod_2(inputs)
        elif self._method == 'soft_llr':
            inputs = -1. * inputs

        output_shape = list(inputs.shape)
        y_resh = inputs.reshape(-1, self._n)
        output_shape[-1] = self._k if self._return_info_bits else self._n
        batch_size = y_resh.shape[0]

        # Branch metrics matrix for a given y
        bm_mat = self._bmcalc(y_resh)

        # The encoder starts in the all-zero state: every other start state
        # is penalized with a large initial metric.
        init_cm = torch.full((self._ns,), LARGEDIST, dtype=self.dtype,
                             device=self.device)
        init_cm[0] = 0.0
        prev_cm = init_cm.unsqueeze(0).expand(batch_size, -1).clone()

        # Forward pass computing cumulative metrics and traceback
        cm, tb = self._update_fwd(prev_cm, bm_mat)

        # Traceback; prepend the known initial (zero) state to the path
        zero_st = torch.zeros((batch_size, 1), dtype=torch.int32,
                              device=self.device)
        opt_path = self._optimal_path(cm, tb)
        opt_path = torch.cat((zero_st, opt_path), dim=1)

        msghat, cwhat = self._op_bits_path(opt_path)

        if self._return_info_bits:
            output = msghat[..., :self._k].to(self.dtype)
        else:
            # cwhat holds one output-symbol index per trellis step
            # ([bs, num_syms]); it must be expanded to conv_n bits per step
            # before it can be reshaped to [..., n]. _op_bits is the same
            # symbol -> bit-pattern table used for the branch metrics.
            # NOTE(review): assumes trellis.op_mat stores symbol indices in
            # [0, no) - confirm against Trellis.
            cw_bits = self._op_bits[cwhat.long()]  # [bs, num_syms, conv_n]
            output = cw_bits.to(self.dtype)

        return output.reshape(output_shape)
class BCJRDecoder(Block):
    r"""Applies BCJR decoding to a sequence of noisy codeword bits.

    Implements the BCJR decoding algorithm :cite:p:`BCJR` that returns an
    estimate of the information bits for a noisy convolutional codeword.
    Takes as input channel LLRs and optional a priori LLRs.
    Returns an estimate of the information bits, either output LLRs
    (``hard_out`` = `False`) or hard decoded bits (``hard_out`` = `True`),
    respectively.

    :param encoder: If ``encoder`` is provided as input, the following input
        parameters are not required and will be ignored: ``gen_poly``,
        ``rate``, ``constraint_length``, ``rsc``, ``terminate``. They will be
        inferred from the ``encoder`` object itself. If `None`, the above
        parameters must be provided explicitly.
    :param gen_poly: Tuple of strings with each string being a 0,1 sequence.
        If `None`, ``rate`` and ``constraint_length`` must be provided.
    :param rate: Valid values are 1/3 and 1/2. Only required if ``gen_poly``
        is `None`.
    :param constraint_length: Valid values are between 3 and 8 inclusive.
        Only required if ``gen_poly`` is `None`.
    :param rsc: Boolean flag indicating whether the encoder is
        recursive-systematic for given generator polynomials.
        `True` indicates encoder is recursive-systematic.
        `False` indicates encoder is feed-forward non-systematic.
        Defaults to `False`.
    :param terminate: Boolean flag indicating whether the codeword is
        terminated.
        `True` indicates codeword is terminated to all-zero state.
        `False` indicates codeword is not terminated.
        Defaults to `False`.
    :param hard_out: Boolean flag indicating whether to output hard or soft
        decisions on the decoded information vector.
        `True` implies a hard-decoded information vector of 0/1's as output.
        `False` implies output is decoded LLRs of the information.
        Defaults to `True`.
    :param algorithm: Indicates the implemented BCJR algorithm,
        where ``'map'`` denotes the exact MAP algorithm, ``'log'`` indicates
        the exact MAP implementation but in log-domain, and ``'maxlog'``
        indicates the approximated MAP implementation in log-domain where
        :math:`\log(e^{a}+e^{b}) \sim \max(a,b)`.
    :param precision: Precision used for internal calculations and outputs.
        If `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation (e.g., 'cpu', 'cuda:0').
        If `None`, :attr:`~sionna.phy.config.Config.device` is used.

    :input llr_ch: [..., n], `torch.float`.
        Tensor containing the (noisy) channel LLRs, where ``n`` denotes the
        codeword length.
    :input llr_a: [..., k], `None` (default) | `torch.float`.
        Tensor containing the a priori information of each information bit.
        Implicitly assumed to be 0 if only ``llr_ch`` is provided.

    :output msghat: `torch.float`.
        Tensor of shape ``[..., coderate*n]`` containing the estimates of the
        information bit tensor.

    .. rubric:: Examples

    .. code-block:: python

        from sionna.phy.fec.conv import BCJRDecoder

        decoder = BCJRDecoder(rate=0.5, constraint_length=5)
        llr = torch.randn(10, 200)  # Received LLRs
        u_hat = decoder(llr)
        print(u_hat.shape)
        # torch.Size([10, 100])
    """

    def __init__(
        self,
        encoder: Optional["ConvEncoder"] = None,
        gen_poly: Optional[Tuple[str, ...]] = None,
        rate: float = 1/2,
        constraint_length: int = 3,
        rsc: bool = False,
        terminate: bool = False,
        hard_out: bool = True,
        algorithm: str = 'map',
        precision: Optional[str] = None,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(precision=precision, device=device, **kwargs)

        if encoder is not None:
            # Reuse the code definition of the encoder; the explicit code
            # parameters are ignored in this case.
            self._gen_poly = encoder.gen_poly
            self._trellis = encoder.trellis
            self._terminate = encoder.terminate
            if self._trellis.device != self.device:
                self._trellis.to(self.device)
        else:
            self._gen_poly = resolve_gen_poly(gen_poly, rate,
                                              constraint_length)
            self._trellis = Trellis(self.gen_poly, rsc=rsc,
                                    device=self.device)
            self._terminate = terminate

        valid_algorithms = ['map', 'log', 'maxlog']
        if algorithm not in valid_algorithms:
            raise ValueError("algorithm must be one of map, log or maxlog")

        self._coderate_desired = 1 / len(self._gen_poly)

        self._num_term_bits = None
        self._num_term_syms = None

        # conv_k denotes number of input bit streams
        # Can only be 1 in current implementation
        self._conv_k = self._trellis.conv_k
        if self._conv_k != 1:
            raise NotImplementedError("Only conv_k=1 currently supported.")
        # The trellis is the single source of truth for the memory mu
        self._mu = self._trellis.mu
        # conv_n denotes number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        # Length of Info-bit vector
        self._k = None
        # Length of codeword, including termination bits
        self._n = None
        # Number of encoding periods or state transitions
        self._num_syms = None

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        self._hard_out = hard_out
        self._algorithm = algorithm

        self._ipst_op_idx = None
        self._ipst_ip_idx = None

        # Pre-computed output bit patterns for branch metric calculation
        # Register buffer placeholder for CUDAGraph compatibility
        self.register_buffer("_op_bits", None)

    @property
    def gen_poly(self) -> Tuple[str, ...]:
        """Generator polynomial used by the encoder."""
        return self._gen_poly

    @property
    def coderate(self) -> float:
        """Rate of the code used in the encoder."""
        if self.terminate and self._n is None:
            warnings.warn(
                "Due to termination, the true coderate is lower "
                "than the returned design rate. "
                "The exact true rate is dependent on the value of n and "
                "hence cannot be computed before the first call().")
            self._coderate = self._coderate_desired
        elif self.terminate and self._n is not None:
            # The mu termination symbols carry no information bits
            k = self._coderate_desired * self._n - self._mu
            self._coderate = k / self._n
        else:
            self._coderate = self._coderate_desired
        return self._coderate

    @property
    def trellis(self) -> Trellis:
        """Trellis object used during encoding."""
        return self._trellis

    @property
    def terminate(self) -> bool:
        """Indicates if the encoder is terminated during codeword generation."""
        return self._terminate

    @property
    def k(self) -> Optional[int]:
        """Number of information bits per codeword."""
        if self._k is None:
            warnings.warn("The value of k cannot be computed before the "
                          "first call().")
        return self._k

    @property
    def n(self) -> Optional[int]:
        """Number of codeword bits."""
        if self._n is None:
            warnings.warn("The value of n cannot be computed before the "
                          "first call().")
        return self._n

    def _mask_by_tonode(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Creates index matrices for gathering by to-node.

        Assume i->j a valid state transition given info-bit b & emits symbol k.
        Returns following two _ns x _ni x 2 matrices:

        - st_op_idx: jth row contains (i,k) tuples
        - st_ip_idx: jth row contains (i,b) tuples

        When applied as gather on a _ns x _no matrix, the output is
        matrix sorted by next_state.
        """
        cnst = self._ns * self._ni
        from_nodes_vec = self._trellis.from_nodes.reshape(cnst)

        op_idx = self._trellis.op_by_tonode.reshape(cnst)
        st_op_idx = torch.stack([from_nodes_vec, op_idx], dim=-1)
        st_op_idx = st_op_idx.reshape(self._ns, self._ni, 2)

        ip_idx = self._trellis.ip_by_tonode.reshape(cnst)
        st_ip_idx = torch.stack([from_nodes_vec, ip_idx], dim=-1)
        st_ip_idx = st_ip_idx.reshape(self._ns, self._ni, 2)

        return st_op_idx, st_ip_idx

    def _bmcalc(self, llr_in: torch.Tensor) -> torch.Tensor:
        """Calculate branch gamma metrics for a given noisy codeword tensor.

        For each time period t, computes the 'distance' of symbol
        vector y[t] from each possible output symbol. The metric is kept in
        the log-domain for the ``'log'`` and ``'maxlog'`` algorithms and
        exponentiated for ``'map'``.

        :param llr_in: Channel LLRs of shape [bs, n].
        :output bm: Branch metrics of shape [bs, no, num_syms].
        """
        batch_size = llr_in.shape[0]

        # llr_exp: [bs, num_syms, 1, conv_n]
        llr_exp = llr_in.reshape(batch_size, -1, self._conv_n).unsqueeze(2)

        # op_bits: [no, conv_n] - pre-computed in build()
        # Expanded for broadcasting: [1, 1, no, conv_n]
        op_bits_exp = self._op_bits.unsqueeze(0).unsqueeze(0)
        op_mat_sign = 1. - 2. * op_bits_exp

        llr_sign = llr_exp * op_mat_sign  # [bs, num_syms, no, conv_n]
        half_llr_sign = 0.5 * llr_sign

        if self._algorithm in ['log', 'maxlog']:
            bm = half_llr_sign.sum(dim=-1)  # [bs, num_syms, no]
        else:
            bm = torch.exp(half_llr_sign.sum(dim=-1))

        return bm.permute(0, 2, 1).contiguous()  # [bs, no, num_syms]

    def _initialize(
        self,
        llr_ch: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize alpha and beta tensors.

        Alpha puts all probability mass on state 0. Beta equals alpha for a
        terminated code and is uniform over all states otherwise. Values are
        log-probabilities for the ``'log'`` and ``'maxlog'`` algorithms.

        :param llr_ch: Channel LLRs; only the leading (batch) dimension is
            used.
        :output alpha_init: Tensor of shape [bs, ns].
        :output beta_init: Tensor of shape [bs, ns].
        """
        batch_size = llr_ch.shape[0]
        log_domain = self._algorithm in ['log', 'maxlog']

        # (value of impossible states, value of the start state)
        init_vals = (float('-inf'), 0.0) if log_domain else (0.0, 1.0)

        alpha_init = torch.full((self._ns,), init_vals[0], dtype=self.dtype,
                                device=self.device)
        alpha_init[0] = init_vals[1]

        if not self._terminate:
            eq_prob = 1. / self._ns
            if log_domain:
                eq_prob = np.log(eq_prob)
            beta_init = torch.full((self._ns,), eq_prob, dtype=self.dtype,
                                   device=self.device)
        else:
            beta_init = alpha_init.clone()

        alpha_init = alpha_init.unsqueeze(0).expand(batch_size, -1).clone()
        beta_init = beta_init.unsqueeze(0).expand(batch_size, -1).clone()
        return alpha_init, beta_init

    def _update_fwd(
        self,
        alph_init: torch.Tensor,
        bm_mat: torch.Tensor,
        llr: torch.Tensor,
    ) -> torch.Tensor:
        """Run forward update from time t=0 to t=k-1.

        At each time t, computes alpha_t using alpha_t-1 and gamma_t.
        Returns tensor of alpha values [bs, ns, num_syms+1].

        :param alph_init: Initial alpha values of shape [bs, ns].
        :param bm_mat: Branch metrics of shape [bs, no, num_syms].
        :param llr: A priori LLRs of shape [bs, num_syms].
        """
        batch_size = alph_init.shape[0]
        alph_list = [alph_init]
        alph_prev = alph_init

        # op_by_fromnode[from_st, input] = output symbol
        op_mask = self._trellis.op_by_fromnode  # [ns, ni]

        # Sign (+1 for input bit 0, -1 for input bit 1) of every branch
        ipbit_mat = torch.arange(self._ni, device=self.device) \
            .unsqueeze(0).unsqueeze(0) \
            .expand(batch_size, self._ns, -1).contiguous()  # [bs, ns, ni]
        ipbitsign_mat = 1. - 2. * ipbit_mat.to(self.dtype)

        # Pre-compute gather indices for vectorized gather-by-tonode
        # _ipst_ip_idx[to_st, inp_idx, :] = (from_st, inp_bit)
        from_st_idx = self._ipst_ip_idx[:, :, 0].contiguous()  # [ns, ni]
        inp_bit_idx = self._ipst_ip_idx[:, :, 1].contiguous()  # [ns, ni]

        for t in range(self._num_syms):
            bm_t = bm_mat[..., t].contiguous()  # [bs, no]
            llr_t = 0.5 * llr[..., t].unsqueeze(1).unsqueeze(2)  # [bs, 1, 1]

            # bm_byfromst[b, from_st, input] = bm_t[b, op_mask[from_st, input]]
            bm_byfromst = bm_t[:, op_mask].contiguous()  # [bs, ns, ni]
            signed_half_llr = llr_t * ipbitsign_mat  # [bs, ns, ni]

            if self._algorithm in ['log', 'maxlog']:
                llr_byfromst = signed_half_llr
                gamma_byfromst = llr_byfromst + bm_byfromst
                alph_gam_prod = (
                    gamma_byfromst + alph_prev.unsqueeze(2)).contiguous()
            else:
                llr_byfromst = torch.exp(signed_half_llr)
                gamma_byfromst = llr_byfromst * bm_byfromst
                alph_gam_prod = (
                    gamma_byfromst * alph_prev.unsqueeze(2)).contiguous()

            # Vectorized gather by to-node using advanced indexing:
            # alphgam_bytost[b, to_st, inp_idx] = alph_gam_prod[
            #     b, from_st_idx[to_st, inp_idx], inp_bit_idx[to_st, inp_idx]]
            alphgam_bytost = alph_gam_prod[
                :, from_st_idx, inp_bit_idx].contiguous()

            if self._algorithm == 'map':
                alph_t = alphgam_bytost.sum(dim=-1)
                # Normalize over the states at every step
                alph_t_sum = alph_t.sum(dim=-1, keepdim=True)
                alph_t = alph_t / alph_t_sum
            elif self._algorithm == 'log':
                alph_t = torch.logsumexp(alphgam_bytost, dim=-1)
            else:  # maxlog
                alph_t = alphgam_bytost.max(dim=-1).values

            alph_prev = alph_t
            alph_list.append(alph_t)

        return torch.stack(alph_list, dim=-1)  # [bs, ns, num_syms+1]

    def _update_bwd(
        self,
        beta_init: torch.Tensor,
        bm_mat: torch.Tensor,
        llr: torch.Tensor,
        alpha_ta: torch.Tensor,
    ) -> torch.Tensor:
        """Run backward update from time t=k-1 to t=0.

        At each time t, computes beta_t-1 using beta_t and gamma_t.
        Returns LLRs for information bits for t=0,1,...,k-1.

        :param beta_init: Initial beta values of shape [bs, ns].
        :param bm_mat: Branch metrics of shape [bs, no, num_syms].
        :param llr: A priori LLRs of shape [bs, num_syms].
        :param alpha_ta: Alpha values from the forward pass of shape
            [bs, ns, num_syms+1].
        :output llr_op: Output LLRs of shape [bs, num_syms].
        """
        batch_size = beta_init.shape[0]
        beta_next = beta_init
        llr_op_list = [None] * self._num_syms

        # op_mask[from_st, input] = output symbol
        op_mask = self._trellis.op_by_fromnode  # [ns, ni]
        tonode_mask = self._trellis.to_nodes  # [ns, ni]

        # Sign (+1 for input bit 0, -1 for input bit 1) of every branch
        ipbit_mat = torch.arange(self._ni, device=self.device) \
            .unsqueeze(0).unsqueeze(0) \
            .expand(batch_size, self._ns, -1).contiguous()  # [bs, ns, ni]
        ipbitsign_mat = 1. - 2. * ipbit_mat.to(self.dtype)

        log_domain = self._algorithm in ['log', 'maxlog']

        for t in range(self._num_syms - 1, -1, -1):
            bm_t = bm_mat[..., t].contiguous()  # [bs, no]
            llr_t = 0.5 * llr[..., t].unsqueeze(1).unsqueeze(2)  # [bs, 1, 1]
            signed_half_llr = llr_t * ipbitsign_mat

            bm_byfromst = bm_t[:, op_mask].contiguous()  # [bs, ns, ni]

            if log_domain:
                llr_byfromst = signed_half_llr
                gamma_byfromst = (llr_byfromst + bm_byfromst).contiguous()
            else:
                llr_byfromst = torch.exp(signed_half_llr)
                gamma_byfromst = (llr_byfromst * bm_byfromst).contiguous()

            # beta_bytonode[b, from_st, input]
            #     = beta_next[b, to_nodes[from_st, input]]
            beta_bytonode = beta_next[:, tonode_mask].contiguous()

            if not log_domain:
                beta_gam_prod = gamma_byfromst * beta_bytonode
                beta_t = beta_gam_prod.sum(dim=-1)
                # Normalize over the states at every step
                beta_t_sum = beta_t.sum(dim=-1, keepdim=True)
                beta_t = beta_t / beta_t_sum
            elif self._algorithm == 'log':
                beta_gam_prod = gamma_byfromst + beta_bytonode
                beta_t = torch.logsumexp(beta_gam_prod, dim=-1)
            else:  # maxlog
                beta_gam_prod = gamma_byfromst + beta_bytonode
                beta_t = beta_gam_prod.max(dim=-1).values

            alph_t = alpha_ta[..., t].contiguous()  # [bs, ns]

            # Combine alpha, gamma and beta separately over all branches
            # with input bit 0 and input bit 1; their (log-)ratio is the
            # output LLR of the bit at time t.
            gamma_0 = gamma_byfromst[..., 0].contiguous()
            gamma_1 = gamma_byfromst[..., 1].contiguous()
            beta_0 = beta_bytonode[..., 0].contiguous()
            beta_1 = beta_bytonode[..., 1].contiguous()

            if not log_domain:
                llr_op_t0 = alph_t * gamma_0 * beta_0
                llr_op_t1 = alph_t * gamma_1 * beta_1
                llr_op_t = torch.log(
                    llr_op_t0.sum(dim=-1) / llr_op_t1.sum(dim=-1)
                )
            else:
                llr_op_t0 = alph_t + gamma_0 + beta_0
                llr_op_t1 = alph_t + gamma_1 + beta_1
                if self._algorithm == 'log':
                    llr_op_t = torch.logsumexp(llr_op_t0, dim=-1) - \
                        torch.logsumexp(llr_op_t1, dim=-1)
                else:  # maxlog
                    llr_op_t = llr_op_t0.max(dim=-1).values - \
                        llr_op_t1.max(dim=-1).values

            llr_op_list[t] = llr_op_t
            beta_next = beta_t

        return torch.stack(llr_op_list, dim=-1)  # [bs, num_syms]

    def build(self, llr_ch_shape: torch.Size, **kwargs):
        """Build block and check dimensions.

        :param llr_ch_shape: Shape of the channel LLR tensor; the last
            dimension is the codeword length ``n``.
        :raises ValueError: If ``n`` is not a multiple of the number of
            output bits per trellis step.
        """
        self._n = llr_ch_shape[-1]

        # Fail early with a clear message instead of an obscure reshape
        # error inside _bmcalc().
        if self._n % self._conv_n != 0:
            raise ValueError('Length of codeword should be divisible by '
                             'number of output bits per symbol.')

        self._num_syms = int(self._n * self._coderate_desired)
        self._num_term_syms = self._mu if self._terminate else 0
        self._num_term_bits = int(
            self._num_term_syms / self._coderate_desired)
        self._k = self._num_syms - self._num_term_syms

        # Build index masks
        self._ipst_op_idx, self._ipst_ip_idx = self._mask_by_tonode()

        # Pre-compute output bit patterns for branch metric calculation
        # Shape: [no, conv_n]
        op_bits = np.stack(
            [int2bin(op, self._conv_n) for op in range(self._no)]
        )
        # Register as buffer for CUDAGraph compatibility
        self.register_buffer("_op_bits",
                             torch.tensor(op_bits, dtype=self.dtype,
                                          device=self.device))

        # Move trellis to correct device if needed
        if self._trellis.device != self.device:
            self._trellis.to(self.device)

    @torch.compiler.disable
    def call(
        self,
        llr_ch: torch.Tensor,
        /,
        *,
        llr_a: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """BCJR decoding function.

        :param llr_ch: Noisy channel LLR tensor of shape [..., n] where ``n``
            is the codeword length. All leading dimensions are treated as
            batch dimensions.
        :param llr_a: Optional a priori LLR tensor of shape [..., k] where
            ``k`` is the number of information bits. Implicitly assumed to be
            0 if not provided. For terminated codes, a tensor covering all
            trellis steps (last dimension ``k + mu``) is accepted as well;
            a tensor of length ``k`` is zero-padded for the termination
            steps.
        :output msghat: Decoded information bits (or LLRs) of shape
            [..., k].
        """
        output_shape = list(llr_ch.shape)
        input_device = llr_ch.device

        # Allow different codeword lengths in eager mode
        # Also ensure build() is called (needed for torch.compile compatibility)
        if self._n is None or output_shape[-1] != self._n:
            self._built = False
            if torch.compiler.is_compiling():
                # During compilation trace, we need concrete values
                torch.compiler.disable(self.build)(llr_ch.shape)
            else:
                self.build(llr_ch.shape)
            self._built = True

        # Move module to input device if needed (for torch.compile compatibility)
        if self._op_bits is not None and self._op_bits.device != input_device:
            self.to(input_device)

        output_shape[-1] = self._k
        llr_ch = llr_ch.reshape(-1, self._n)
        batch_size = llr_ch.shape[0]

        if llr_a is None:
            llr_a = torch.zeros(
                batch_size, self._num_syms,
                dtype=self.dtype, device=input_device
            )
        elif llr_a.shape[-1] == self._k and self._k != self._num_syms:
            # Terminated code and a priori LLRs given for the k information
            # bits only (the documented shape): no a priori knowledge is
            # available for the termination steps, i.e., their LLR is zero.
            llr_a = llr_a.reshape(-1, self._k)
            term_pad = llr_a.new_zeros(llr_a.shape[0], self._num_term_syms)
            llr_a = torch.cat((llr_a, term_pad), dim=-1)
        else:
            llr_a = llr_a.reshape(-1, self._num_syms)

        # Internally, we use more common LLR definition log(p(x=0)/p(x=1))
        llr_ch = -1. * llr_ch
        llr_a = -1. * llr_a

        # Branch metrics matrix for a given y
        bm_mat = self._bmcalc(llr_ch)

        alpha_init, beta_init = self._initialize(llr_ch)
        alph_ta = self._update_fwd(alpha_init, bm_mat, llr_a)
        llr_op = self._update_bwd(beta_init, bm_mat, llr_a, alph_ta)

        # Revert LLR definition and drop the termination steps
        msghat = -1. * llr_op[..., :self._k]

        if self._hard_out:
            msghat = (msghat > 0.0).to(self.dtype)

        return msghat.reshape(output_shape)