#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for decoding of linear codes."""
import tensorflow as tf
import numpy as np
import scipy as sp # for sparse H matrix computations
from tensorflow.keras.layers import Layer
from sionna.fec.utils import pcm2gm, int_mod_2, make_systematic
from sionna.utils import hard_decisions
import itertools
class OSDecoder(Layer):
# pylint: disable=line-too-long
r"""OSDecoder(enc_mat=None, t=0, is_pcm=False, encoder=None, dtype=tf.float32, **kwargs)
Ordered statistics decoding (OSD) for binary, linear block codes.
This layer implements the OSD algorithm as proposed in [Fossorier]_ and,
thereby, approximates maximum likelihood decoding for a sufficiently large
order :math:`t`. The algorithm works for arbitrary linear block codes, but
has a high computational complexity for long codes.
The algorithm consists of the following steps:
1. Sort LLRs according to their reliability and apply the same column
permutation to the generator matrix.
2. Bring the permuted generator matrix into its systematic form
(so-called *most-reliable basis*).
3. Hard-decide and re-encode the :math:`k` most reliable bits and
discard the remaining :math:`n-k` received positions.
    4. Generate all possible error patterns with up to :math:`t` errors in
       the :math:`k` most reliable positions and find the most likely
       codeword within these candidates.
This implementation of the OSD algorithm uses the LLR-based distance metric
from [Stimming_LLR_OSD]_ which simplifies the handling of higher-order
modulation schemes.
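    Specifically (as implemented in ``_get_dist``), the distance between the
    channel LLRs :math:`\boldsymbol{\lambda}` and a candidate codeword
    :math:`\mathbf{c}` is evaluated as
    :math:`d(\boldsymbol{\lambda},\mathbf{c}) = \frac{1}{n}\sum_{i=1}^{n}\log\left(1+e^{\lambda_i(1-2c_i)}\right)`.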
    The class inherits from the Keras layer class and can be used as a layer
    in a Keras model.
Parameters
----------
enc_mat : [k, n] or [n-k, n], ndarray
Binary generator matrix of shape `[k, n]`. If ``is_pcm`` is
True, ``enc_mat`` is interpreted as parity-check matrix of shape
`[n-k, n]`.
t : int
Order of the OSD algorithm
is_pcm: bool
Defaults to False. If True, ``enc_mat`` is interpreted as parity-check
matrix.
encoder: Layer
Keras layer that implements a FEC encoder.
        If not None, ``enc_mat`` will be ignored and the code as specified
        by the encoder is used to initialize OSD.
dtype: tf.DType
        Defaults to `tf.float32`. Defines the datatype of the output.
Input
-----
llrs_ch: [...,n], tf.float32
2+D tensor containing the channel logits/llr values.
Output
------
: [...,n], tf.float32
2+D Tensor of same shape as ``llrs_ch`` containing
binary hard-decisions of all codeword bits.
Note
----
OS decoding is of high complexity and is only feasible for small values of
:math:`t` as :math:`{n \choose t}` patterns must be evaluated. The
advantage of OSD is that it works for arbitrary linear block codes and
provides an estimate of the expected ML performance for sufficiently large
:math:`t`. However, for some code families, more efficient decoding
algorithms with close to ML performance exist which can exploit certain
code specific properties. Examples of such decoders are the
:class:`~sionna.fec.conv.ViterbiDecoder` algorithm for convolutional codes
or the :class:`~sionna.fec.polar.decoding.PolarSCLDecoder` for Polar codes
(for a sufficiently large list size).
It is recommended to run the decoder in XLA mode as it
significantly reduces the memory complexity.
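
    A minimal usage sketch; the generator matrix of the (7,4) Hamming code
    below serves purely as illustration:

    .. code-block:: python

        import numpy as np
        import tensorflow as tf

        # generator matrix of the (7,4) Hamming code
        gm = np.array([[1, 0, 0, 0, 1, 1, 0],
                       [0, 1, 0, 0, 1, 0, 1],
                       [0, 0, 1, 0, 0, 1, 1],
                       [0, 0, 0, 1, 1, 1, 1]])

        decoder = OSDecoder(gm, t=2) # OSD of order t=2

        # random channel llrs for demonstration; llr = log(p(x=1)/p(x=0))
        llrs_ch = tf.random.normal((10, 7))
        c_hat = decoder(llrs_ch) # [10, 7] binary hard decisions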
"""
def __init__(self,
enc_mat=None,
t=0,
is_pcm=False,
encoder=None,
dtype=tf.float32,
**kwargs):
super().__init__(dtype=dtype, **kwargs)
assert isinstance(is_pcm, bool), 'is_pcm must be bool.'
self._llr_max = 100. # internal clipping value for llrs
if enc_mat is not None:
            # check that enc_mat is binary
if isinstance(enc_mat, np.ndarray):
assert np.array_equal(enc_mat, enc_mat.astype(bool)), \
'PC matrix must be binary.'
            elif isinstance(enc_mat,
                            (sp.sparse.csr_matrix, sp.sparse.csc_matrix)):
                assert np.array_equal(enc_mat.data,
                                      enc_mat.data.astype(bool)),\
                    'PC matrix must be binary.'
else:
                raise TypeError("Unsupported type of enc_mat.")
if dtype not in (tf.float16, tf.float32, tf.float64):
raise ValueError(
'dtype must be {tf.float16, tf.float32, tf.float64}.')
assert (int(t)==t), "t must be int."
self._t = int(t)
if encoder is not None:
# test that encoder is already initialized (relevant for conv codes)
if encoder.k is None:
raise AttributeError("It seems as if the encoder is not "\
"initialized or has no attribute k.")
# encode identity matrix to get k basis vectors of the code
u = tf.expand_dims(tf.eye(encoder.k), axis=0)
# encode and remove batch_dim
self._gm = tf.cast(tf.squeeze(encoder(u), axis=0), self.dtype)
else:
assert (enc_mat is not None),\
"enc_mat cannot be None if no encoder is provided."
if is_pcm:
gm = pcm2gm(enc_mat)
else:
# check if gm is of full rank (raise error otherwise)
make_systematic(enc_mat)
gm = enc_mat
self._gm = tf.constant(gm, dtype=self.dtype)
self._k = self._gm.shape[0]
self._n = self._gm.shape[1]
# init error patterns
num_patterns = self._num_error_patterns(self._n, self._t)
# storage/computational complexity scales with n
num_symbols = num_patterns * self._n
if num_symbols>1e9: # number still to be optimized
print(f"Note: Required memory complexity is large for the "\
f"given code parameters and t={t}. Please consider small " \
f"batch-sizes to keep the inference complexity small and " \
f"activate XLA mode if possible." )
if num_symbols>1e11: # number still to be optimized
raise ResourceWarning("Due to its high complexity, OSD is not " \
"feasible for the selected parameters. " \
"Please consider using a smaller value for t.")
# pre-compute all error patterns
self._err_patterns = []
for t_i in range(1, t+1):
self._err_patterns.append(self._gen_error_patterns(self._k, t_i))
#########################################
# Public methods and properties
#########################################
@property
def gm(self):
"""Generator matrix of the code"""
return self._gm
@property
def n(self):
"""Codeword length"""
return self._n
@property
def k(self):
"""Number of information bits per codeword"""
return self._k
@property
def t(self):
"""Order of the OSD algorithm"""
return self._t
#########################
# Utility methods
#########################
def _num_error_patterns(self, n, t):
r"""Returns number of possible error patterns for t errors in n
positions, i.e., calculates :math:`{n \choose t}`.
Input
-----
        n: int
            Length of vector.
        t: int
            Number of errors.
"""
return sp.special.comb(n, t, exact=True, repetition=False)
def _gen_error_patterns(self, n, t):
r"""Returns list of all possible error patterns for t errors in n
positions.
Input
-----
n: int
Length of vector.
t: int
Number of errors.
Output
------
: [num_patterns, t], tf.int32
Tensor of size `num_patterns`=:math:`{n \choose t}` containing the
t error indices.
"""
err_patterns = []
for p in itertools.combinations(range(n), t):
err_patterns.append(p)
return tf.constant(err_patterns)
def _get_dist(self, llr, c_hat):
"""Distance function used for ML candidate selection.
        Currently, the distance metric from the polar code decoding
        literature [Stimming_LLR_OSD]_ is implemented.
Input
-----
llr: [bs, n], tf.float32
Received llrs of the channel observations.
c_hat: [bs, num_cand, n], tf.float32
Candidate codewords for which the distance to ``llr`` shall be
evaluated.
Output
------
: [bs, num_cand], tf.float32
Distance between ``llr`` and ``c_hat`` for each of the `num_cand`
codeword candidates.
Reference
---------
        [Stimming_LLR_OSD] Alexios Balatsoukas-Stimming, Mani Bastani Parizi,
            and Andreas Burg, "LLR-Based Successive Cancellation List
            Decoding of Polar Codes," IEEE Transactions on Signal Processing,
            2015.
"""
# broadcast llr to all codeword candidates
llr = tf.expand_dims(llr, axis=1)
llr_sign = llr * (-2.*c_hat + 1.) # apply BPSK mapping
d = tf.math.log(1. + tf.exp(llr_sign))
return tf.reduce_mean(d, axis=2)
def _find_min_dist(self, llr_ch, ep, gm_mrb, c):
r"""Find error pattern which leads to minimum distance.
Input
-----
llr_ch: [bs, n], tf.float32
Channel observations as llrs after mrb sorting.
ep: [num_patterns, t], tf.int32
Tensor of size `num_patterns`=:math:`{n \choose t}` containing the
t error indices.
gm_mrb: [bs, k, n] tf.float32
Most reliable basis for each batch example.
        c: [bs, n], tf.float32
            Codeword obtained by re-encoding in the most reliable basis.
Output
------
: [bs], tf.float32
Distance of the most likely codeword to ``llr_ch`` after testing all
``ep`` error patterns.
: [bs, n], tf.float32
The most likely codeword after testing against all ``ep`` error
patterns.
"""
# generate all test candidates for each possible error pattern
e = tf.gather(gm_mrb, ep, axis=1)
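        # e has shape [bs, num_patterns, t, n]; summing over axis=2
        # superimposes the t rows of gm_mrb selected by each pattern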
e = tf.reduce_sum(e, axis=2)
e += tf.expand_dims(c, axis=1) # add to mrb codeword
c_cand = int_mod_2(e) # apply modulo-2 operation
# calculate distance for each candidate
# where c_cand has shape [bs, num_patterns, n]
d = self._get_dist(llr_ch, c_cand)
# find candidate index with smallest metric
idx = tf.argmin(d, axis=1)
c_hat = tf.gather(c_cand, idx, batch_dims=1)
d = tf.gather(d, idx, batch_dims=1)
return d, c_hat
def _find_mrb(self, gm):
"""Find most reliable basis for all generator matrices in batch.
Input
-----
gm: [bs, k, n] tf.float32
Generator matrix for each batch example.
Output
------
gm_mrb: [bs, k, n] tf.float32
Most reliable basis in systematic form for each batch example.
idx_sort: [bs, n] tf.int64
Indices of column permutations applied during mrb calculation.
"""
bs = tf.shape(gm)[0]
s = gm.shape
idx_pivot = tf.TensorArray(tf.int64, self._k, dynamic_size=False)
# bring gm in systematic form (by so-called pivot method)
for idx_c in tf.range(self._k):
# ensure shape to avoid XLA incompatibility with TF2.11 in tf.range
gm = tf.ensure_shape(gm, s)
            # find pivot (i.e., first position containing a 1)
idx_p = tf.argmax(gm[:, idx_c, :], axis=-1)
# store pivot position
idx_pivot = idx_pivot.write(idx_c, idx_p)
# and eliminate the column in all other rows
r = tf.gather(gm, idx_p, batch_dims=1, axis=-1)
# ignore idx_c row itself by adding all-zero row
rz = tf.zeros((bs, 1), dtype=self.dtype)
r = tf.concat([r[:,:idx_c], rz , r[:,idx_c+1:]], axis=1)
            # mask is zero for all rows that have a 0 at this pivot position
            # (i.e., rows that need no elimination)
mask = tf.tile(tf.expand_dims(r, axis=-1), (1, 1, self._n))
gm_off = tf.expand_dims(gm[:,idx_c,:], axis=1)
            # update all rows in parallel
gm = int_mod_2(gm + mask * gm_off) # account for binary operations
# pivot positions
idx_pivot = tf.transpose(idx_pivot.stack())
# find non-pivot positions (i.e., all indices that are not part of
# idx_pivot)
# solution 1: sets.difference() does not support XLA (unknown shapes)
#idx_parity = tf.sets.difference(idx_range, idx_pivot)
#idx_parity = tf.sparse.to_dense(idx_parity)
#idx_pivot = tf.reshape(idx_pivot, (-1, self._n)) # ensure shape
# solution 2: add large offset to pivot indices and sorting gives the
# indices of interest
idx_range = tf.tile(tf.expand_dims(
tf.range(self._n, dtype=tf.int64), axis=0),
(bs, 1))
# large value to be added to irrelevant indices
updates = self._n * tf.ones((bs, self._k), tf.int64)
# generate indices for tf.scatter_nd_add
s = tf.shape(idx_pivot, tf.int64)
ii, _ = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
idx_updates = tf.stack([ii, idx_pivot], axis=-1)
# add large value to pivot positions
idx = tf.tensor_scatter_nd_add(idx_range, idx_updates, updates)
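        # e.g., n=4, k=2, pivots [2,0]: idx = [4,1,6,3]; argsort gives
        # [1,3,0,2], i.e., the parity positions are [1,3]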
# sort and slice first n-k indices (equals parity positions)
idx_parity = tf.cast(tf.argsort(idx)[:,:self._n-self._k], tf.int64)
idx_sort = tf.concat([idx_pivot, idx_parity], axis=1)
# permute gm according to indices idx_sort
gm = tf.gather(gm, idx_sort, batch_dims=1, axis=-1)
return gm, idx_sort
#########################
# Keras layer functions
#########################
def build(self, input_shape):
"""Nothing to build, but check for valid shapes."""
assert input_shape[-1]==self._n, "Invalid input shape."
def call(self, inputs):
r"""Applies ordered statistic decoding to inputs.
        Remark: the decoder is implemented with the llr definition
        llr = log(p(x=1)/p(x=0)).
"""
# flatten batch-dim
input_shape = tf.shape(inputs)
llr_ch = tf.reshape(inputs, (-1, self._n))
llr_ch = tf.cast(llr_ch, self.dtype)
bs = tf.shape(llr_ch)[0]
# clip inputs
llr_ch = tf.clip_by_value(llr_ch, -self._llr_max, self._llr_max)
# step 1: sort LLRs
idx_sort = tf.argsort(tf.abs(llr_ch), direction="DESCENDING")
# permute gm per batch sample individually
gm = tf.broadcast_to(tf.expand_dims(self._gm, axis=0),
(bs, self._k,self._n))
gm_sort = tf.gather(gm, idx_sort, batch_dims=1, axis=-1)
# step 2: Find most reliable basis (MRB)
gm_mrb, idx_mrb = self._find_mrb(gm_sort)
# apply corresponding mrb permutations
idx_sort = tf.gather(idx_sort, idx_mrb, batch_dims=1)
llr_sort = tf.gather(llr_ch, idx_sort, batch_dims=1)
# find inverse permutation for final output
idx_sort_inv = tf.argsort(idx_sort)
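        # e.g., idx_sort = [2,0,1] yields idx_sort_inv = [1,2,0]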
# hard-decide k most reliable positions and encode
u_hd = hard_decisions(llr_sort[:,0:self._k])
u_hd = tf.expand_dims(u_hd, axis=1)
c = tf.squeeze(tf.matmul(u_hd, gm_mrb), axis=1)
c = int_mod_2(c)
# and search for most likely pattern
# _get_dist expects a list of candidates, thus expand_dims to [bs, 1, n]
d_best = self._get_dist(llr_sort, tf.expand_dims(c, axis=1))
d_best = tf.squeeze(d_best, axis=1)
c_hat_best = c
# known in advance - can be unrolled
for ep in self._err_patterns:
# compute distance for all candidate codewords
d, c_hat = self._find_min_dist(llr_sort, ep, gm_mrb, c)
# select most likely candidate
ind = tf.expand_dims(d<d_best, axis=1)
c_hat_best = tf.where(ind, c_hat, c_hat_best)
d_best = tf.where(d<d_best, d, d_best)
# undo permutations for final codeword
c_hat_best = tf.gather(c_hat_best, idx_sort_inv, axis=1, batch_dims=1)
        # restore original input shape
c_hat = tf.reshape(c_hat_best, input_shape)
return c_hat