#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for Polar decoding such as successive cancellation (SC), successive
cancellation list (SCL) and iterative belief propagation (BP) decoding."""
import tensorflow as tf
import numpy as np
from numpy import issubdtype
import warnings
from tensorflow.keras.layers import Layer
from sionna.fec.crc import CRCDecoder, CRCEncoder
from sionna.fec.polar.encoding import Polar5GEncoder
import numbers
class PolarSCDecoder(Layer):
"""PolarSCDecoder(frozen_pos, n, output_dtype=tf.float32, **kwargs)
Successive cancellation (SC) decoder [Arikan_Polar]_ for Polar codes and
Polar-like codes.
The class inherits from the Keras layer class and can be used as layer in a
Keras model.
Parameters
----------
frozen_pos: ndarray
Array of `int` defining the ``n-k`` indices of the frozen positions.
n: int
Defining the codeword length.
output_dtype: tf.DType
Defaults to tf.float32. Defines the output datatype of the layer
(internal precision remains tf.float32).
Input
-----
inputs: [...,n], tf.float32
2+D tensor containing the channel LLR values (as logits).
Output
------
: [...,k], tf.float32
2+D tensor containing hard-decided estimations of all ``k``
information bits.
Raises
------
AssertionError
If ``n`` is not `int`.
AssertionError
If ``n`` is not a power of 2.
AssertionError
If the number of elements in ``frozen_pos`` is greater than ``n``.
AssertionError
If ``frozen_pos`` does not consist of `int`.
ValueError
If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.
Note
----
This layer implements the SC decoder as described in
[Arikan_Polar]_. However, the implementation follows the `recursive
tree` [Gross_Fast_SCL]_ terminology and combines nodes for increased
throughput without changing the outcome of the algorithm.
As commonly done, we assume frozen bits are set to `0`. Please note
that - although of little practical relevance - setting frozen bits
to `1` may result in `affine` codes instead of linear codes, as the
`all-zero` codeword is no longer necessarily part of the code.
"""
def __init__(self, frozen_pos, n, output_dtype=tf.float32, **kwargs):
if output_dtype not in (tf.float16, tf.float32, tf.float64):
raise ValueError(
'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
if output_dtype is not tf.float32:
print('Note: decoder uses tf.float32 for internal calculations.')
super().__init__(dtype=output_dtype, **kwargs)
self._output_dtype = output_dtype
# assert error if r>1 or k, n are negative
assert isinstance(n, numbers.Number), "n must be a number."
n = int(n) # n can be float (e.g. as result of n=k*r)
assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int."
assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
"be greater than n."
assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
# store internal attributes
self._n = n
self._frozen_pos = frozen_pos
self._k = self._n - len(self._frozen_pos)
self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
assert self._k==len(self._info_pos), "Internal error: invalid " \
"info_pos generated."
self._llr_max = 30. # internal max LLR value (uncritical for SC dec)
# and create a frozen bit vector for simpler encoding
self._frozen_ind = np.zeros(self._n)
self._frozen_ind[self._frozen_pos] = 1
# enable graph pruning
self._use_fast_sc = False
#########################################
# Public methods and properties
#########################################
@property
def n(self):
"""Codeword length."""
return self._n
@property
def k(self):
"""Number of information bits."""
return self._k
@property
def frozen_pos(self):
"""Frozen positions for Polar decoding."""
return self._frozen_pos
@property
def info_pos(self):
"""Information bit positions for Polar encoding."""
return self._info_pos
@property
def llr_max(self):
"""Maximum LLR value for internal calculations."""
return self._llr_max
@property
def output_dtype(self):
"""Output dtype of decoder."""
return self._output_dtype
#########################
# Utility methods
#########################
def _cn_op_tf(self, x, y):
"""Check-node update (boxplus) for LLR inputs.
Operations are performed element-wise.
See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
"""
x_in = tf.clip_by_value(x,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
y_in = tf.clip_by_value(y,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
# avoid division for numerical stability
llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in))
llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in))
return llr_out
def _vn_op_tf(self, x, y, u_hat):
"""VN update for LLR inputs."""
return tf.multiply((1-2*u_hat), x) + y
def _polar_decode_sc_tf(self, llr_ch, frozen_ind):
"""Recursive SC decoding function.
Recursively branch decoding tree and split into decoding of `upper`
and `lower` path until reaching a leaf node.
The function returns the u_hat decisions at stage `0` and the bit
decisions of the intermediate stage `s` (i.e., the re-encoded version of
`u_hat` until the current stage `s`).
Note:
This decoder parallelizes over the batch-dimension, i.e., the tree
is processed for all samples in the batch in parallel. This yields a
higher throughput, but does not improve the latency.
"""
# calculate current codeword length
n = len(frozen_ind)
# branch if leaf is not reached yet
if n>1:
if self._use_fast_sc:
if np.sum(frozen_ind)==n:
#print("rate-0 detected! Length: ", n)
u_hat = tf.zeros_like(llr_ch)
return u_hat, u_hat
llr_ch1 = llr_ch[...,0:int(n/2)]
llr_ch2 = llr_ch[...,int(n/2):]
frozen_ind1 = frozen_ind[0:int(n/2)]
frozen_ind2 = frozen_ind[int(n/2):]
# upper path
x_llr1_in = self._cn_op_tf(llr_ch1, llr_ch2)
# and call the decoding function (with upper half)
u_hat1, u_hat1_up = self._polar_decode_sc_tf(x_llr1_in, frozen_ind1)
# lower path
x_llr2_in = self._vn_op_tf(llr_ch1, llr_ch2, u_hat1_up)
# and call the decoding function again (with lower half)
u_hat2, u_hat2_up = self._polar_decode_sc_tf(x_llr2_in, frozen_ind2)
# combine u_hat from both branches
u_hat = tf.concat([u_hat1, u_hat2], -1)
# calculate re-encoded version of u_hat at current stage
# u_hat1_up = tf.math.mod(u_hat1_up + u_hat2_up, 2)
# combine u_hat via bitwise_xor (more efficient than mod2)
u_hat1_up_int = tf.cast(u_hat1_up, tf.int8)
u_hat2_up_int = tf.cast(u_hat2_up, tf.int8)
u_hat1_up_int = tf.bitwise.bitwise_xor(u_hat1_up_int,
u_hat2_up_int)
u_hat1_up = tf.cast(u_hat1_up_int , tf.float32)
u_hat_up = tf.concat([u_hat1_up, u_hat2_up], -1)
else: # if leaf is reached perform basic decoding op (=decision)
if frozen_ind==1: # position is frozen
u_hat = tf.expand_dims(tf.zeros_like(llr_ch[:,0]), axis=-1)
u_hat_up = u_hat
else: # otherwise hard decide
u_hat = 0.5 * (1. - tf.sign(llr_ch))
#remove "exact 0 llrs" leading to u_hat=0.5
u_hat = tf.where(tf.equal(u_hat, 0.5),
tf.ones_like(u_hat),
u_hat)
u_hat_up = u_hat
return u_hat, u_hat_up
#########################
# Keras layer functions
#########################
def build(self, input_shape):
"""Check if shape of input is invalid."""
assert (input_shape[-1]==self._n), "Invalid input shape."
assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
def call(self, inputs):
"""Successive cancellation (SC) decoding function.
Performs successive cancellation decoding and returns the estimated
information bits.
Args:
inputs (tf.float32): Tensor of shape `[...,n]` containing the
channel LLR values (as logits).
Returns:
`tf.float32`: Tensor of shape `[...,k]` containing
hard-decided estimations of all ``k`` information bits.
Raises:
ValueError: If ``inputs`` is not of shape `[..., n]`
or `dtype` is not `tf.float32`.
InvalidArgumentError: When rank(``inputs``)<2.
Note:
This function recursively unrolls the SC decoding tree, thus,
for larger values of ``n`` building the decoding graph can become
time consuming.
"""
tf.debugging.assert_type(inputs, self.dtype, 'Invalid input dtype.')
# internal calculations still in tf.float32
inputs = tf.cast(inputs, tf.float32)
# last dim must be of length n
tf.debugging.assert_equal(tf.shape(inputs)[-1],
self._n,
"Last input dimension must be of length n.")
# Reshape inputs to [-1, n]
tf.debugging.assert_greater(tf.rank(inputs), 1)
input_shape = inputs.shape
new_shape = [-1, self._n]
llr_ch = tf.reshape(inputs, new_shape)
llr_ch = -1. * llr_ch # logits are converted into "true" llrs
# and decode
u_hat_n, _ = self._polar_decode_sc_tf(llr_ch, self._frozen_ind)
# and recover the k information bit positions
u_hat = tf.gather(u_hat_n, self._info_pos, axis=1)
# and reconstruct input shape
output_shape = input_shape.as_list()
output_shape[-1] = self.k
output_shape[0] = -1 # first dim can be dynamic (None)
u_hat_reshape = tf.reshape(u_hat, output_shape)
return tf.cast(u_hat_reshape, self._output_dtype)
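# ---------------------------------------------------------------------------
# Minimal usage sketch for PolarSCDecoder (illustrative only, not part of the
# original module; assumes `generate_5g_ranking` from sionna.fec.polar.utils
# for the code design).
# ---------------------------------------------------------------------------
def _example_sc_decoding(batch_size=100):
    """Decode noisy logits of the all-zero codeword with SC decoding."""
    from sionna.fec.polar.utils import generate_5g_ranking
    k, n = 32, 64
    frozen_pos, _ = generate_5g_ranking(k, n)
    decoder = PolarSCDecoder(frozen_pos, n)
    # positive logits correspond to the all-zero codeword (BPSK over AWGN)
    llr_ch = 5. + tf.random.normal([batch_size, n])
    return decoder(llr_ch)  # shape [batch_size, k], hard-decided bits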
class PolarSCLDecoder(Layer):
# pylint: disable=line-too-long
"""PolarSCLDecoder(frozen_pos, n, list_size=8, crc_degree=None, use_hybrid_sc=False, use_fast_scl=True, cpu_only=False, use_scatter=False, ind_iil_inv=None, return_crc_status=False, output_dtype=tf.float32, **kwargs)
Successive cancellation list (SCL) decoder [Tal_SCL]_ for Polar codes
and Polar-like codes.
The class inherits from the Keras layer class and can be used as layer in a
Keras model.
Parameters
----------
frozen_pos: ndarray
Array of `int` defining the ``n-k`` indices of the frozen positions.
n: int
Defining the codeword length.
list_size: int
Defaults to 8. Defines the list size of the decoder.
crc_degree: str
Defining the CRC polynomial to be used. Can be any value from
`{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.
use_hybrid_sc: bool
Defaults to False. If True, SC decoding is applied and only the
codewords with invalid CRC are decoded with SCL. This option
requires an outer CRC specified via ``crc_degree``.
Remark: hybrid_sc does not support XLA optimization, i.e.,
`@tf.function(jit_compile=True)`.
use_fast_scl: bool
Defaults to True. If True, tree pruning is used to
reduce the decoding complexity. The output is equivalent to the
non-pruned version (besides numerical differences).
cpu_only: bool
Defaults to False. If True, `tf.py_function` embedding
is used and the decoder runs on the CPU. This option is usually
slower, but also more memory efficient and, in particular,
recommended for larger blocklengths. Remark: cpu_only does not
support XLA optimization `@tf.function(jit_compile=True)`.
use_scatter: bool
Defaults to False. If True, `tf.tensor_scatter_update` is used for
tensor updates. This option is usually slower, but more memory
efficient.
ind_iil_inv : None or [k+k_crc], int or tf.int
Defaults to None. If not `None`, the sequence is used as inverse
input bit interleaver before evaluating the CRC.
Remark: this only affects the CRC evaluation but the output
sequence is not permuted.
return_crc_status: bool
Defaults to False. If True, the decoder additionally returns the
CRC status indicating if a codeword was (most likely) correctly
recovered. This is only available if ``crc_degree`` is not None.
output_dtype: tf.DType
Defaults to tf.float32. Defines the output datatype of the layer
(internal precision remains tf.float32).
Input
-----
inputs: [...,n], tf.float32
2+D tensor containing the channel LLR values (as logits).
Output
------
b_hat : [...,k], tf.float32
2+D tensor containing hard-decided estimations of all `k`
information bits.
crc_status : [...], tf.bool
CRC status indicating if a codeword was (most likely) correctly
recovered. This is only returned if ``return_crc_status`` is True.
Note that false positives are possible.
Raises
------
AssertionError
If ``n`` is not `int`.
AssertionError
If ``n`` is not a power of 2.
AssertionError
If the number of elements in ``frozen_pos`` is greater than ``n``.
AssertionError
If ``frozen_pos`` does not consist of `int`.
AssertionError
If ``list_size`` is not `int`.
AssertionError
If ``cpu_only`` is not `bool`.
AssertionError
If ``use_scatter`` is not `bool`.
AssertionError
If ``use_fast_scl`` is not `bool`.
AssertionError
If ``use_hybrid_sc`` is not `bool`.
AssertionError
If ``list_size`` is not a power of 2.
ValueError
If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.
ValueError
If ``inputs`` is not of shape `[..., n]` or `dtype` is not
correct.
InvalidArgumentError
When rank(``inputs``)<2.
Note
----
This layer implements the successive cancellation list (SCL) decoder
as described in [Tal_SCL]_ but uses LLR-based message updates
[Stimming_LLR]_. The implementation follows the notation from
[Gross_Fast_SCL]_, [Hashemi_SSCL]_. If option `use_fast_scl` is active,
tree pruning is used and tree nodes are combined if possible (see
[Hashemi_SSCL]_ for details).
Implementing SCL decoding as a TensorFlow graph is a difficult task that
requires several design tradeoffs to match the TF constraints while
maintaining a reasonable throughput. Thus, the decoder minimizes the
`control flow` as much as possible, leading to a high memory footprint
(e.g., due to full path duplication after each decision).
For longer code lengths, the complexity of the decoding graph becomes
large and we recommend using the ``cpu_only`` option that uses an
embedded Numpy decoder. Further, this function recursively unrolls the
SCL decoding tree, thus, for larger values of ``n`` building the
decoding graph can become time consuming. Please consider the
``cpu_only`` option if building the graph takes too long.
A hybrid SC/SCL decoder as proposed in [Cammerer_Hybrid_SCL]_ (using SC
instead of BP) can be activated with option ``use_hybrid_sc`` iff an
outer CRC is available. Please note that the results do not exactly
match SCL performance due to the false positive rate of the CRC.
As commonly done, we assume frozen bits are set to `0`. Please note
that - although of little practical relevance - setting frozen bits
to `1` may result in `affine` codes instead of linear codes, as the
`all-zero` codeword is no longer necessarily part of the code.
"""
def __init__(self,
frozen_pos,
n,
list_size=8,
crc_degree=None,
use_hybrid_sc=False,
use_fast_scl=True,
cpu_only=False,
use_scatter=False,
ind_iil_inv=None,
return_crc_status=False,
output_dtype=tf.float32,
**kwargs):
if output_dtype not in (tf.float16, tf.float32, tf.float64):
raise ValueError(
'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
if output_dtype is not tf.float32:
print('Note: decoder uses tf.float32 for internal calculations.')
super().__init__(dtype=output_dtype, **kwargs)
self._output_dtype = output_dtype
# assert error if r>1 or k, n are negative
assert isinstance(n, numbers.Number), "n must be a number."
n = int(n) # n can be float (e.g. as result of n=k*r)
assert isinstance(list_size, int), "list_size must be integer."
assert isinstance(cpu_only, bool), "cpu_only must be bool."
assert isinstance(use_scatter, bool), "use_scatter must be bool."
assert isinstance(use_fast_scl, bool), "use_fast_scl must be bool."
assert isinstance(use_hybrid_sc, bool), "use_hybrid_sc must be bool."
assert isinstance(return_crc_status, bool), \
"return_crc_status must be bool."
assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int."
assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
"be greater than n."
assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
assert np.log2(list_size)==int(np.log2(list_size)), \
"list_size must be a power of 2."
# CPU mode is recommended for larger values of n
if n>128 and cpu_only is False and use_hybrid_sc is False:
warnings.warn("Required resource allocation is large " \
"for the selected blocklength. Consider option `cpu_only=True`.")
# CPU mode is recommended for larger values of L
if list_size>32 and cpu_only is False and use_hybrid_sc is False:
warnings.warn("Resource allocation is high for the " \
"selected list_size. Consider option `cpu_only=True`.")
# internal decoder parameters
self._use_fast_scl = use_fast_scl # optimize rate-0 and rep nodes
self._use_scatter = use_scatter # slower but more memory friendly
self._cpu_only = cpu_only # run numpy decoder
self._use_hybrid_sc = use_hybrid_sc
# store internal attributes
self._n = n
self._frozen_pos = frozen_pos
self._k = self._n - len(self._frozen_pos)
self._list_size = list_size
self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
self._llr_max = 30. # internal max LLR value (not very critical for SC)
assert self._k==len(self._info_pos), "Internal error: invalid " \
"info_pos generated."
# create a frozen bit vector
self._frozen_ind = np.zeros(self._n)
self._frozen_ind[self._frozen_pos] = 1
self._cw_ind = np.arange(self._n)
self._n_stages = int(np.log2(self._n)) # number of decoding stages
# init CRC check (if needed)
if crc_degree is not None:
self._use_crc = True
self._crc_decoder = CRCDecoder(CRCEncoder(crc_degree))
self._k_crc = self._crc_decoder.encoder.crc_length
else:
self._use_crc = False
self._k_crc = 0
assert self._k>=self._k_crc, "Value of k is too small for " \
"the given crc_degree."
if (crc_degree is None) and return_crc_status:
raise ValueError("Returning the CRC status requires a given crc_degree.")
else:
self._return_crc_status = return_crc_status
# store the inverse interleaver pattern
if ind_iil_inv is not None:
assert (ind_iil_inv.shape[0]==self._k), \
"ind_iil_inv must be of length k+k_crc."
self._ind_iil_inv = ind_iil_inv
self._iil = True
else:
self._iil = False
# use SC decoder first and use numpy-based SCL as "afterburner"
if self._use_hybrid_sc:
self._decoder_sc = PolarSCDecoder(frozen_pos, n)
# Note: CRC required to detect SC success
if not self._use_crc:
raise ValueError("Hybrid SC requires outer CRC.")
#########################################
# Public methods and properties
#########################################
@property
def n(self):
"""Codeword length."""
return self._n
@property
def k(self):
"""Number of information bits."""
return self._k
@property
def k_crc(self):
"""Number of CRC bits."""
return self._k_crc
@property
def frozen_pos(self):
"""Frozen positions for Polar decoding."""
return self._frozen_pos
@property
def info_pos(self):
"""Information bit positions for Polar encoding."""
return self._info_pos
@property
def llr_max(self):
"""Maximum LLR value for internal calculations."""
return self._llr_max
@property
def list_size(self):
"""List size for SCL decoding."""
return self._list_size
@property
def output_dtype(self):
"""Output dtype of decoder."""
return self._output_dtype
#####################################
# Helper functions for the TF decoder
#####################################
def _update_rate0_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
"""Update rate-0 sub-code (i.e., all frozen) at pos ``cw_ind``.
See eq. (26) in [Hashemi_SSCL]_.
Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
initialized with `0` already.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2)
llr_in = tf.clip_by_value(llr,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
# update path metric for complete sub-block of length n
pm_val = tf.math.softplus(-1.*llr_in)
msg_pm += tf.reduce_sum(pm_val, axis=-1)
return msg_pm, msg_uhat, msg_llr
def _update_rep_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
"""Update rep. code (i.e., only rightmost bit is non-frozen)
sub-code at position ``cw_ind``.
See Eq. (31) in [Hashemi_SSCL]_.
Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
initialized with `0` already.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
# update PM
llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2)
llr_in = tf.clip_by_value(llr,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
# upper branch has negative llr values (bit is 1)
llr_low = llr_in[:, :self._list_size, :]
llr_up = - llr_in[:, self._list_size:, :]
llr_pm = tf.concat([llr_low, llr_up], 1)
pm_val = tf.math.softplus(-1.*llr_pm)
msg_pm += tf.reduce_sum(pm_val, axis=-1)
msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
msg_uhat21 = tf.expand_dims(
msg_uhat[:, self._list_size:, stage_ind, :cw_ind[0]],
axis=2)
msg_uhat22 = tf.expand_dims(
msg_uhat[:, self._list_size:, stage_ind, cw_ind[-1]+1:],
axis=2)
# ones to insert
msg_ones = tf.ones([tf.shape(msg_uhat)[0], self._list_size, 1, n],
tf.float32)
msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
msg_uhat24_1 = msg_uhat[:, self._list_size:, :stage_ind, :]
msg_uhat24_2 = msg_uhat[:, self._list_size:, stage_ind+1:, :]
msg_uhat2 = tf.concat([msg_uhat24_1, msg_uhat23, msg_uhat24_2], 2)
msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)
# branch last bit and update pm at pos cw_ind[-1]
msg_uhat = self._update_single_bit([cw_ind[-1]], msg_uhat)
msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
msg_uhat,
msg_llr)
msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
msg_llr,
msg_pm)
return msg_pm, msg_uhat, msg_llr
def _update_single_bit(self, ind_u, msg_uhat):
"""Update single bit at position ``ind_u`` for all decoders.
Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
initialized with `0` already.
Remark: Two versions are implemented (throughput vs. graph complexity):
1.) use tensor_scatter_nd_update
2.) explicitly split graph and concatenate again
"""
# position is non-frozen
if self._frozen_ind[ind_u[0]]==0:
# msg_uhat[:, ind_up, 0, ind_u] = 1
if self._use_scatter:
ind_dec = np.arange(self._list_size, 2*self._list_size, 1)
ind_stage = np.array([0])
# transpose such that batch dim can be broadcasted
msg_uhat_t = tf.transpose(msg_uhat, [1, 3, 2, 0])
# generate index grid
ind_u = tf.cast(ind_u, tf.int64)
grid = tf.meshgrid(ind_dec, ind_u, ind_stage)
ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 3])
updates = tf.ones([ind.shape[0], tf.shape(msg_uhat)[0]])
msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
ind,
updates)
# and restore original order
msg_uhat = tf.transpose(msg_uhat_s, [3, 0, 2, 1])
else:
# alternative solution with split/concatenation of graph
msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
msg_uhat21 = tf.expand_dims(
msg_uhat[:, self._list_size:, 0, :ind_u[0]],
axis=2)
msg_uhat22 = tf.expand_dims(
msg_uhat[:, self._list_size:, 0, ind_u[0]+1:],
axis=2)
# ones to insert
msg_ones = tf.ones_like(tf.reshape(
msg_uhat[:, self._list_size:, 0, ind_u[0]],
[-1, self._list_size, 1, 1]))
msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
msg_uhat24 = msg_uhat[:, self._list_size:, 1:, :]
msg_uhat2 = tf.concat([msg_uhat23, msg_uhat24], 2)
msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)
return msg_uhat
def _update_pm(self, ind_u, msg_uhat, msg_llr, msg_pm):
"""Update path metric of all decoders after updating bit_pos ``ind_u``.
We implement (10) from [Stimming_LLR]_.
"""
u_hat = msg_uhat[:, :, 0, ind_u[0]]
llr = msg_llr[:, :, 0, ind_u[0]]
llr_in = tf.clip_by_value(llr,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
# Numerically more stable implementation of log(1 + exp(-x))
msg_pm += tf.math.softplus(-tf.multiply((1 - 2*u_hat), llr_in))
return msg_pm
def _sort_decoders(self, msg_pm, msg_uhat, msg_llr):
"""Sort decoders according to their path metric."""
ind = tf.argsort(msg_pm, axis=-1)
msg_pm = tf.gather(msg_pm, ind, batch_dims=1, axis=None)
msg_uhat = tf.gather(msg_uhat, ind, batch_dims=1, axis=None)
msg_llr = tf.gather(msg_llr, ind, batch_dims=1, axis=None)
return msg_pm, msg_uhat, msg_llr
def _cn_op(self, x, y):
"""Check-node update (boxplus) for LLR inputs.
Operations are performed element-wise.
See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
"""
x_in = tf.clip_by_value(x,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
y_in = tf.clip_by_value(y,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
# Avoid division for numerical stability
# Implements log(1+e^(x+y))
llr_out = tf.math.softplus((x_in + y_in))
# Implements log(e^x+e^y)
llr_out -= tf.math.reduce_logsumexp(tf.stack([x_in, y_in], axis=-1),
axis=-1)
return llr_out
def _vn_op(self, x, y, u_hat):
"""Variable node update for LLR inputs.
Operations are performed element-wise.
See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
"""
return tf.multiply((1 - 2*u_hat), x) + y
def _duplicate_paths(self, msg_uhat, msg_llr, msg_pm):
"""Duplicate paths by copying the upper branch into the lower one.
"""
msg_uhat = tf.tile(msg_uhat[:, :self._list_size, :, :], [1, 2, 1, 1])
msg_llr = tf.tile(msg_llr[:, :self._list_size, :, :], [1, 2, 1, 1])
msg_pm = tf.tile(msg_pm[:, :self._list_size], [1, 2])
return msg_uhat, msg_llr, msg_pm
def _update_left_branch(self, msg_llr, stage_ind, cw_ind_left, cw_ind_right):
"""Update messages of left branch.
Remark: Two versions are implemented (throughput vs. graph complexity):
1.) use tensor_scatter_nd_update
2.) explicitly split graph and concatenate again
"""
llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
cw_ind_left,
axis=2)
llr_right_in = tf.gather(msg_llr[:, :, stage_ind, :],
cw_ind_right,
axis=2)
llr_left_out = self._cn_op(llr_left_in, llr_right_in)
if self._use_scatter:
# self.msg_llr[:, :, stage_ind-1, cw_ind_left] = llr_left_out
# transpose such that batch-dim can be broadcasted
msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
llr_left_out_s = tf.transpose(llr_left_out, [2, 1, 0])
# generate index grid
stage_ind = tf.cast(stage_ind, tf.int64)
cw_ind_left = tf.cast(cw_ind_left, tf.int64)
grid = tf.meshgrid(stage_ind-1, cw_ind_left)
ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])
# update values
msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
ind,
llr_left_out_s)
# and restore original order
msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
else:
# alternative solution with split/concatenation of graph
# llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
np.arange(0, cw_ind_left[0]),
axis=2)
llr_right = tf.gather(msg_llr[:, :, stage_ind-1, :],
cw_ind_right,
axis=2)
llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
np.arange(cw_ind_right[-1] +1, self._n),
axis=2)
llr_s = tf.concat([llr_left0,
llr_left_out,
llr_right,
llr_right1], 2)
llr_s = tf.expand_dims(llr_s, axis=2)
msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
msg_llr2 = msg_llr[:, :, stage_ind:, :]
msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)
return msg_llr
def _update_right_branch(self, msg_llr, msg_uhat, stage_ind, cw_ind_left,
cw_ind_right):
"""Update messages for right branch.
Remark: Two versions are implemented (throughput vs. graph complexity):
1.) use tensor_scatter_nd_update
2.) explicitly split graph and concatenate again
"""
u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
cw_ind_left,
axis=2)
llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
cw_ind_left,
axis=2)
llr_right = tf.gather(msg_llr[:, :, stage_ind, :],
cw_ind_right,
axis=2)
llr_right_out = self._vn_op(llr_left_in, llr_right, u_hat_left_up)
if self._use_scatter:
# transpose such that batch dim can be broadcasted
msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
llr_right_out_s = tf.transpose(llr_right_out, [2, 1, 0])
# generate index grid
stage_ind = tf.cast(stage_ind, tf.int64)
cw_ind_right = tf.cast(cw_ind_right, tf.int64)
grid = tf.meshgrid(stage_ind-1, cw_ind_right)
ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])
msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
ind,
llr_right_out_s)
# and restore original order
msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
else:
# alternative solution with split/concatenation of graph
# llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
np.arange(0, cw_ind_left[0]),
axis=2)
llr_left = tf.gather(msg_llr[:, :, stage_ind-1, :],
cw_ind_left,
axis=2)
llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
np.arange(cw_ind_right[-1]+1, self._n),
axis=2)
llr_s = tf.concat([llr_left0, llr_left, llr_right_out, llr_right1], 2)
llr_s = tf.expand_dims(llr_s, axis=2)
msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
msg_llr2 = msg_llr[:, :, stage_ind:, :]
msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)
return msg_llr
def _update_branch_u(self, msg_uhat, stage_ind, cw_ind_left, cw_ind_right):
"""Update ``u_hat`` messages after executing both branches.
Remark: Two versions are implemented (throughput vs. graph complexity):
1.) use tensor_scatter_nd_update
2.) explicitly split graph and concatenate again
"""
u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
cw_ind_left,
axis=2)
u_hat_right_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
cw_ind_right,
axis=2)
# combine u_hat via bitwise_xor (more efficient than mod2)
u_hat_left_up_int = tf.cast(u_hat_left_up, tf.int32)
u_hat_right_up_int = tf.cast(u_hat_right_up, tf.int32)
u_hat_left = tf.bitwise.bitwise_xor(u_hat_left_up_int,
u_hat_right_up_int)
u_hat_left = tf.cast(u_hat_left, tf.float32)
if self._use_scatter:
cw_ind = np.concatenate([cw_ind_left, cw_ind_right])
u_hat = tf.concat([u_hat_left, u_hat_right_up], -1)
# self.msg_llr[:, stage_ind-1, cw_ind_left] = llr_left_out
# transpose such that batch dim can be broadcasted
msg_uhat_t = tf.transpose(msg_uhat, [2, 3, 1, 0])
u_hat_s = tf.transpose(u_hat, [2, 1, 0])
# generate index grid
stage_ind = tf.cast(stage_ind, tf.int64)
cw_ind = tf.cast(cw_ind, tf.int64)
grid = tf.meshgrid(stage_ind, cw_ind)
ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])
msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
ind,
u_hat_s)
# and restore original order
msg_uhat = tf.transpose(msg_uhat_s, [3, 2, 0, 1])
else:
# alternative solution with split/concatenation of graph
u_hat_left_0 = tf.gather(msg_uhat[:, :, stage_ind, :],
np.arange(0, cw_ind_left[0]),
axis=2)
u_hat_right_1 = tf.gather(msg_uhat[:, :, stage_ind, :],
np.arange(cw_ind_right[-1]+1, self._n),
axis=2)
u_hat = tf.concat([u_hat_left_0,
u_hat_left,
u_hat_right_up,
u_hat_right_1], 2)
# provide u_hat for next higher stage
msg_uhat1 = msg_uhat[:, :, 0:stage_ind, :]
msg_uhat2 = msg_uhat[:, :, stage_ind+1:, :]
u_hat = tf.expand_dims(u_hat, axis=2)
msg_uhat = tf.concat([msg_uhat1, u_hat, msg_uhat2], 2)
return msg_uhat
def _polar_decode_scl(self, cw_ind, msg_uhat, msg_llr, msg_pm):
"""Recursive decoding function for SCL decoding.
We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
and branch the messages into a `left` and `right` update paths until
reaching a leaf node.
Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
tree depth while maintaining the same output.
"""
# current sub-code length and stage index (= tree depth)
n = len(cw_ind)
stage_ind = int(np.log2(n))
# recursively branch through decoding tree
if n>1:
# prune tree if rate-0 subcode is detected
if self._use_fast_scl:
if np.sum(self._frozen_ind[cw_ind])==n:
msg_pm, msg_uhat, msg_llr = self._update_rate0_code(msg_pm,
msg_uhat,
msg_llr,
cw_ind)
return msg_uhat, msg_llr, msg_pm
if (self._frozen_ind[cw_ind[-1]]==0 and
np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
msg_pm, msg_uhat, msg_llr = self._update_rep_code(msg_pm,
msg_uhat,
msg_llr,
cw_ind)
return msg_uhat, msg_llr, msg_pm
# split index into left and right part
cw_ind_left = cw_ind[0:int(n/2)]
cw_ind_right = cw_ind[int(n/2):]
# ----- left branch -----
msg_llr = self._update_left_branch(msg_llr,
stage_ind,
cw_ind_left,
cw_ind_right)
# call sub-graph decoder of left branch
msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_left,
msg_uhat,
msg_llr,
msg_pm)
# ----- right branch -----
msg_llr = self._update_right_branch(msg_llr,
msg_uhat,
stage_ind,
cw_ind_left,
cw_ind_right)
# call sub-graph decoder of right branch
msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_right,
msg_uhat,
msg_llr,
msg_pm)
# update uhat at current stage
msg_uhat = self._update_branch_u(msg_uhat,
stage_ind,
cw_ind_left,
cw_ind_right)
# if leaf is reached perform basic decoding op (=decision)
else:
# update bit value at current position
msg_uhat = self._update_single_bit(cw_ind, msg_uhat)
# update PM
msg_pm = self._update_pm(cw_ind, msg_uhat, msg_llr, msg_pm)
if self._frozen_ind[cw_ind]==0: # position is non-frozen
# sort list
msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
msg_uhat,
msg_llr)
# duplicate l best decoders to pos l:2*l (kill other decoders)
msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
msg_llr,
msg_pm)
return msg_uhat, msg_llr, msg_pm
def _decode_tf(self, llr_ch):
"""Main decoding function in TF.
Initializes memory and calls recursive decoding function.
"""
batch_size = tf.shape(llr_ch)[0]
# allocate memory for all 2*list_size decoders
msg_uhat = tf.zeros([batch_size,
2*self._list_size,
self._n_stages+1,
self._n])
msg_llr = tf.zeros([batch_size,
2*self._list_size,
self._n_stages,
self._n])
# init all 2*l decoders with same llr_ch
llr_ch = tf.reshape(llr_ch, [-1, 1, 1, self._n])
llr_ch = tf.tile(llr_ch,[1, 2*self._list_size, 1, 1])
# init last stage with llr_ch
msg_llr = tf.concat([msg_llr, llr_ch], 2)
# init all remaining L-1 decoders with high penalty
pm0 = tf.zeros([batch_size, 1])
pm1 = self._llr_max * tf.ones([batch_size, self._list_size-1])
msg_pm = tf.concat([pm0, pm1, pm0, pm1], 1)
# and call recursive graph function
msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(self._cw_ind,
msg_uhat,
msg_llr,
msg_pm)
# and sort output
msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
msg_uhat,
msg_llr)
return [msg_uhat, msg_pm]
####################################
# Helper functions for Numpy decoder
####################################
def _update_rate0_code_np(self, cw_ind):
"""Update rate-0 (i.e., all frozen) sub-code at pos ``cw_ind`` in Numpy.
See Eq. (26) in [Hashemi_SSCL]_.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
# update PM for each batch sample
ind = np.expand_dims(self._dec_pointer, axis=-1)
llr_in = np.take_along_axis(self.msg_llr[:, :, stage_ind, cw_ind],
ind,
axis=1)
llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
pm_val = np.log(1 + np.exp(-llr_clip))
self.msg_pm += np.sum(pm_val, axis=-1)
def _update_rep_code_np(self, cw_ind):
"""Update rep. code (i.e., only rightmost bit is non-frozen)
sub-code at position ``cw_ind`` in Numpy.
See Eq. (31) in [Hashemi_SSCL]_.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
bs = self._dec_pointer.shape[0]
# update PM
llr = np.zeros([bs, 2*self._list_size, n])
for i in range(bs):
llr_i = self.msg_llr[i, self._dec_pointer[i, :], stage_ind, :]
llr[i, :, :] = llr_i[:, cw_ind]
# upper branch has negative llr values (bit is 1)
llr[:, self._list_size:, :] = - llr[:, self._list_size:, :]
llr_in = np.maximum(np.minimum(llr, self._llr_max), -self._llr_max)
pm_val = np.sum(np.log(1 + np.exp(-llr_in)), axis=-1)
self.msg_pm += pm_val
for i in range(bs):
ind_dec = self._dec_pointer[i, self._list_size:]
for j in cw_ind:
self.msg_uhat[i, ind_dec, stage_ind, j] = 1
# branch last bit and update pm at pos cw_ind[-1]
self._update_single_bit_np([cw_ind[-1]])
self._sort_decoders_np()
self._duplicate_paths_np()
def _update_single_bit_np(self, ind_u):
"""Update single bit at position ``ind_u`` of all decoders in Numpy."""
if self._frozen_ind[ind_u]==0: # position is non-frozen
ind_dec = np.expand_dims(self._dec_pointer[:, self._list_size:],
axis=-1)
uhat_slice = self.msg_uhat[:, :, 0, ind_u]
np.put_along_axis(uhat_slice, ind_dec, 1., axis=1)
self.msg_uhat[:, :, 0, ind_u] = uhat_slice
def _update_pm_np(self, ind_u):
""" Update path metric of all decoders at bit position ``ind_u`` in
Numpy.
We apply Eq. (10) from [Stimming_LLR]_.
"""
ind = np.expand_dims(self._dec_pointer, axis=-1)
u_hat = np.take_along_axis(self.msg_uhat[:, :, 0, ind_u], ind, axis=1)
u_hat = np.squeeze(u_hat, axis=-1)
llr_in = np.take_along_axis(self.msg_llr[:, :, 0, ind_u], ind, axis=1)
llr_in = np.squeeze(llr_in, axis=-1)
llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
self.msg_pm += np.log(1 + np.exp(-np.multiply((1-2*u_hat), llr_clip)))
def _sort_decoders_np(self):
"""Sort decoders according to their path metric."""
ind = np.argsort(self.msg_pm, axis=-1)
self.msg_pm = np.take_along_axis(self.msg_pm, ind, axis=1)
self._dec_pointer = np.take_along_axis(self._dec_pointer, ind, axis=1)
def _cn_op_np(self, x, y):
"""Check node update (boxplus) for LLRs in Numpy.
See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
"""
x_in = np.maximum(np.minimum(x, self._llr_max), -self._llr_max)
y_in = np.maximum(np.minimum(y, self._llr_max), -self._llr_max)
# avoid division for numerical stability
llr_out = np.log(1 + np.exp(x_in + y_in))
llr_out -= np.log(np.exp(x_in) + np.exp(y_in))
return llr_out
def _vn_op_np(self, x, y, u_hat):
"""Variable node update (boxplus) for LLRs in Numpy."""
return np.multiply((1-2*u_hat), x) + y
def _duplicate_paths_np(self):
"""Copy first ``list_size``/2 paths into lower part in Numpy.
Decoder indices are encoded in ``self._dec_pointer``.
"""
ind_low = self._dec_pointer[:, :self._list_size]
ind_up = self._dec_pointer[:, self._list_size:]
for i in range(ind_up.shape[0]):
self.msg_uhat[i, ind_up[i,:], :, :] = self.msg_uhat[i,
ind_low[i,:],
:, :]
self.msg_llr[i, ind_up[i,:],:,:] = self.msg_llr[i, ind_low[i,:],:,:]
# pm is stored sorted and, thus, copied directly (not via pointers)
self.msg_pm[:, self._list_size:] = self.msg_pm[:, :self._list_size]
def _polar_decode_scl_np(self, cw_ind):
"""Recursive decoding function in Numpy.
We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
and branch the messages into a `left` and `right` update paths until
reaching a leaf node.
Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
tree depth while maintaining the same output.
"""
n = len(cw_ind)
stage_ind = int(np.log2(n))
# recursively branch through decoding tree
if n>1:
# prune tree if rate-0 subcode or rep-code is detected
if self._use_fast_scl:
if np.sum(self._frozen_ind[cw_ind])==n:
# rate0 code detected
self._update_rate0_code_np(cw_ind)
return
if (self._frozen_ind[cw_ind[-1]]==0 and
np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
# rep code detected
self._update_rep_code_np(cw_ind)
return
cw_ind_left = cw_ind[0:int(n/2)]
cw_ind_right = cw_ind[int(n/2):]
# ----- left branch -----
llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]
self.msg_llr[:, :, stage_ind-1, cw_ind_left] = self._cn_op_np(
llr_left,
llr_right)
# call left branch decoder
self._polar_decode_scl_np(cw_ind_left)
# ----- right branch -----
u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]
self.msg_llr[:, :, stage_ind-1, cw_ind_right] = self._vn_op_np(
llr_left,
llr_right,
u_hat_left_up)
# call right branch decoder
self._polar_decode_scl_np(cw_ind_right)
# combine u_hat
u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
u_hat_right_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_right]
# u_hat_left_up XOR u_hat_right_up
u_hat_left = (u_hat_left_up != u_hat_right_up) + 0
u_hat = np.concatenate([u_hat_left, u_hat_right_up], axis=-1)
# provide u_hat for next higher stage
self.msg_uhat[:, :, stage_ind, cw_ind] = u_hat
else: # if leaf is reached perform basic decoding op (=decision)
self._update_single_bit_np(cw_ind)
# update PM
self._update_pm_np(cw_ind)
# position is non-frozen
if self._frozen_ind[cw_ind]==0:
# sort list
self._sort_decoders_np()
# duplicate the best list_size decoders
self._duplicate_paths_np()
return
def _decode_np_batch(self, llr_ch):
"""Decode batch of ``llr_ch`` with Numpy decoder."""
bs = llr_ch.shape[0]
# allocate memory for all 2*list_size decoders
self.msg_uhat = np.zeros([bs,
2*self._list_size,
self._n_stages+1,
self._n])
self.msg_llr = np.zeros([bs,
2*self._list_size,
self._n_stages+1,
self._n])
self.msg_pm = np.zeros([bs,
2*self._list_size])
# L-1 decoders start with high penalty
self.msg_pm[:,1:self._list_size] = self._llr_max
# same for the second half of the L-1 decoders
self.msg_pm[:,self._list_size+1:] = self._llr_max
# use pointers to avoid in-memory sorting
self._dec_pointer = np.arange(2*self._list_size)
self._dec_pointer = np.tile(np.expand_dims(self._dec_pointer, axis=0),
[bs,1])
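# i.e., decoder i of sample b physically lives at row self._dec_pointer[b, i];
# sorting permutes only the pointers, never the large message tensors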
# init llr_ch (broadcast via list dimension)
self.msg_llr[:, :, self._n_stages, :] = np.expand_dims(llr_ch, axis=1)
# call recursive graph function
self._polar_decode_scl_np(self._cw_ind)
# select most likely candidate
self._sort_decoders_np()
# remove pointers
for ind in range(bs):
self.msg_uhat[ind, :, :, :] = self.msg_uhat[ind,
self._dec_pointer[ind],
:, :]
return self.msg_uhat, self.msg_pm
def _decode_np_hybrid(self, llr_ch, u_hat_sc, crc_valid):
"""Hybrid SCL decoding stage that decodes iff CRC from previous SC
decoding attempt failed.
This option avoids the usage of the high-complexity SCL decoder in cases
where SC would be sufficient. For further details we refer to
[Cammerer_Hybrid_SCL]_ (we use SC instead of the proposed BP stage).
Remark: This decoder does not exactly implement SCL as the CRC
can be false positive after the SC stage. However, in these cases
SCL+CRC may also yield the wrong results.
Remark 2: Due to the excessive control flow (if/else) and the
varying batch-sizes, this function is only available as Numpy
decoder (i.e., runs on the CPU).
"""
bs = llr_ch.shape[0]
crc_valid = np.squeeze(crc_valid, axis=-1)
# index of codewords that need SCL decoding
ind_invalid = np.arange(bs)[np.invert(crc_valid)]
# init SCL decoder for bs_hyb samples requiring SCL dec.
llr_ch_hyb = np.take(llr_ch, ind_invalid, axis=0)
msg_uhat_hyb, msg_pm_hyb = self._decode_np_batch(llr_ch_hyb)
# merge results with previously decoded SC results
msg_uhat = np.zeros([bs, 2*self._list_size, 1, self._n])
msg_pm = np.ones([bs, 2*self._list_size]) * self._llr_max * self.k
msg_pm[:, 0] = 0
# copy SC data
msg_uhat[:, 0, 0, self._info_pos] = u_hat_sc
ind_hyb = 0
for ind in range(bs):
if not crc_valid[ind]:
#copy data from SCL
msg_uhat[ind, :, 0, :] = msg_uhat_hyb[ind_hyb, :, 0, :]
msg_pm[ind, :] = msg_pm_hyb[ind_hyb, :]
ind_hyb += 1
return msg_uhat, msg_pm
#########################
# Keras layer functions
#########################
def build(self, input_shape):
"""Build and check if shape of input is invalid."""
assert (input_shape[-1]==self._n), "Invalid input shape."
assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
def call(self, inputs):
"""Successive cancellation list (SCL) decoding function.
This function performs successive cancellation list decoding
and returns the estimated information bits.
An outer CRC can be applied optionally by setting ``crc_degree``.
Args:
inputs (tf.float32): Tensor of shape `[...,n]` containing the
channel LLR values (as logits).
Returns:
`tf.float32`: Tensor of shape `[...,k]` containing
hard-decided estimations of all ``k`` information bits.
Raises:
ValueError: If ``inputs`` is not of shape `[..., n]`
or `dtype` is not `tf.float32`.
InvalidArgumentError: When rank(``inputs``)<2.
Note:
This function recursively unrolls the SCL decoding tree, thus,
for larger values of ``n`` building the decoding graph can become
time consuming. Please consider the ``cpu_only`` option instead.
"""
tf.debugging.assert_type(inputs, self._output_dtype,
"Invalid input dtype.")
# internal calculations still in tf.float32
inputs = tf.cast(inputs, tf.float32)
# last dim must be of length n
tf.debugging.assert_equal(tf.shape(inputs)[-1],
self._n,
"Last input dimension must be of length n.")
# Reshape inputs to [-1, n]
tf.debugging.assert_greater(tf.rank(inputs), 1)
input_shape = inputs.shape
new_shape = [-1, self._n]
llr_ch = tf.reshape(inputs, new_shape)
llr_ch = -1. * llr_ch # logits are converted into "true" llrs
# if activated use Numpy decoder
if self._use_hybrid_sc:
# use SC decoder to decode first
u_hat = self._decoder_sc(-llr_ch)
_, crc_valid = self._crc_decoder(u_hat)
msg_uhat, msg_pm = tf.py_function(func=self._decode_np_hybrid,
inp=[llr_ch, u_hat, crc_valid],
Tout=[tf.float32, tf.float32])
# note: the 3rd dim of the returned tensor is 1 (avoids copy overhead)
msg_uhat = tf.reshape(msg_uhat, [-1, 2*self._list_size, 1, self._n])
msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
else:
if self._cpu_only:
msg_uhat, msg_pm = tf.py_function(func=self._decode_np_batch,
inp=[llr_ch],
Tout=[tf.float32, tf.float32])
# restore shape information
msg_uhat = tf.reshape(msg_uhat,
[-1, 2*self._list_size, self._n_stages+1, self._n])
msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
else:
msg_uhat, msg_pm = self._decode_tf(llr_ch)
# check CRC (and remove CRC parity bits)
if self._use_crc:
u_hat_list = tf.gather(msg_uhat[:, :, 0, :],
self._info_pos,
axis=-1)
# undo input bit interleaving
# remark: the output is not interleaved for compatibility with SC
if self._iil:
u_hat_list_crc = tf.gather(u_hat_list,
self._ind_iil_inv,
axis=-1)
else: # no interleaving applied
u_hat_list_crc = u_hat_list
_, crc_valid = self._crc_decoder(u_hat_list_crc)
# add penalty to pm if CRC fails
pm_penalty = ((1. - tf.cast(crc_valid, tf.float32))
* self._llr_max * self.k)
msg_pm += tf.squeeze(pm_penalty, axis=2)
# select most likely candidate
cand_ind = tf.argmin(msg_pm, axis=-1)
c_hat = tf.gather(msg_uhat[:, :, 0, :], cand_ind, axis=1, batch_dims=1)
u_hat = tf.gather(c_hat, self._info_pos, axis=-1)
# and reconstruct input shape
output_shape = input_shape.as_list()
output_shape[-1] = self.k
output_shape[0] = -1 # first dim can be dynamic (None)
u_hat_reshape = tf.reshape(u_hat, output_shape)
if self._return_crc_status:
# reconstruct CRC status
crc_status = tf.gather(crc_valid, cand_ind, axis=1, batch_dims=1)
# reconstruct shape
output_shape.pop() # remove last dimension
crc_status = tf.reshape(crc_status, output_shape)
crc_status = tf.cast(crc_status, self._output_dtype)
# return info bits and CRC status
return tf.cast(u_hat_reshape, self._output_dtype), crc_status
else: # return only info bits
return tf.cast(u_hat_reshape, self._output_dtype)
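# ---------------------------------------------------------------------------
# Minimal usage sketch for PolarSCLDecoder with CRC-aided decoding
# (illustrative only, not part of the original module; assumes
# `generate_5g_ranking` from sionna.fec.polar.utils for the code design).
# ---------------------------------------------------------------------------
def _example_scl_decoding(batch_size=100):
    """CRC-aided SCL decoding with list size 8, returning the CRC status."""
    from sionna.fec.polar.utils import generate_5g_ranking
    k, n = 64, 128  # in this decoder's convention, k includes the CRC bits
    frozen_pos, _ = generate_5g_ranking(k, n)
    decoder = PolarSCLDecoder(frozen_pos, n, list_size=8, crc_degree="CRC11",
                              return_crc_status=True)
    # positive logits correspond to the all-zero codeword (BPSK over AWGN)
    llr_ch = 5. + tf.random.normal([batch_size, n])
    u_hat, crc_status = decoder(llr_ch)  # [batch_size, k] bits, [batch_size]
    return u_hat, crc_status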
class PolarBPDecoder(Layer):
# pylint: disable=line-too-long
"""PolarBPDecoder(frozen_pos, n, num_iter=20, hard_out=True, output_dtype=tf.float32, **kwargs)
Belief propagation (BP) decoder for Polar codes [Arikan_Polar]_ and
Polar-like codes based on [Arikan_BP]_ and [Forney_Graphs]_.
The class inherits from the Keras layer class and can be used as layer in a
Keras model.
Remark: The PolarBPDecoder does not currently support XLA.
Parameters
----------
frozen_pos: ndarray
Array of `int` defining the ``n-k`` indices of the frozen positions.
n: int
Defining the codeword length.
num_iter: int
Defining the number of decoder iterations (no early stopping used
at the moment).
hard_out: bool
Defaults to True. If True, the decoder provides hard-decided
information bits instead of soft-values.
output_dtype: tf.DType
Defaults to tf.float32. Defines the output datatype of the layer
(internal precision remains tf.float32).
Input
-----
inputs: [...,n], tf.float32
2+D tensor containing the channel logits/llr values.
Output
------
: [...,k], tf.float32
2+D tensor containing bit-wise soft-estimates
(or hard-decided bit-values) of all ``k`` information bits.
Raises
------
AssertionError
If ``n`` is not `int`.
AssertionError
If ``n`` is not a power of 2.
AssertionError
If the number of elements in ``frozen_pos`` is greater than ``n``.
AssertionError
If ``frozen_pos`` does not consist of `int`.
AssertionError
If ``hard_out`` is not `bool`.
ValueError
If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.
AssertionError
If ``num_iter`` is not `int`.
AssertionError
If ``num_iter`` is not a positive value.
Note
----
This decoder is fully differentiable and, thus, well-suited for
gradient descent-based learning tasks such as `learned code design`
[Ebada_Design]_.
As commonly done, we assume frozen bits are set to `0`. Please note
that - although of little practical relevance - setting frozen bits
to `1` may result in `affine` codes instead of linear codes, as the
`all-zero` codeword is no longer necessarily part of the code.
"""
def __init__(self,
frozen_pos,
n,
num_iter=20,
hard_out=True,
output_dtype=tf.float32,
**kwargs):
if output_dtype not in (tf.float16, tf.float32, tf.float64):
raise ValueError(
'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
if output_dtype is not tf.float32:
print('Note: decoder uses tf.float32 for internal calculations.')
super().__init__(dtype=output_dtype, **kwargs)
self._output_dtype = output_dtype
# assert error if r>1 or k, n are negative
assert isinstance(n, numbers.Number), "n must be a number."
n = int(n) # n can be float (e.g. as result of n=k*r)
assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int."
assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
"be greater than n."
assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
assert isinstance(hard_out, bool), "hard_out must be boolean."
# store internal attributes
self._n = n
self._frozen_pos = frozen_pos
self._k = self._n - len(self._frozen_pos)
self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
assert self._k==len(self._info_pos), "Internal error: invalid " \
"info_pos generated."
assert isinstance(num_iter, int), "num_iter must be integer."
assert num_iter>0, "num_iter must be a positive value."
self._num_iter = tf.constant(num_iter, dtype=tf.int32)
self._llr_max = 19.3 # internal max LLR value
self._hard_out = hard_out
# depth of decoding graph
self._n_stages = int(np.log2(self._n))
#########################################
# Public methods and properties
#########################################
@property
def n(self):
"""Codeword length."""
return self._n
@property
def k(self):
"""Number of information bits."""
return self._k
@property
def frozen_pos(self):
"""Frozen positions for Polar decoding."""
return self._frozen_pos
@property
def info_pos(self):
"""Information bit positions for Polar encoding."""
return self._info_pos
@property
def llr_max(self):
"""Maximum LLR value for internal calculations."""
return self._llr_max
@property
def num_iter(self):
"""Number of decoding iterations."""
return self._num_iter
@property
def hard_out(self):
"""Indicates if decoder hard-decides outputs."""
return self._hard_out
@property
def output_dtype(self):
"""Output dtype of decoder."""
return self._output_dtype
@num_iter.setter
def num_iter(self, num_iter):
"Number of decoding iterations."
assert isinstance(num_iter, int), 'num_iter must be int.'
assert num_iter>=0, 'num_iter cannot be negative.'
self._num_iter = tf.constant(num_iter, dtype=tf.int32)
#########################
# Utility methods
#########################
def _boxplus_tf(self, x, y):
"""Check-node update (boxplus) for LLR inputs.
Operations are performed element-wise.
"""
x_in = tf.clip_by_value(x,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
y_in = tf.clip_by_value(y,
clip_value_min=-self._llr_max,
clip_value_max=self._llr_max)
# avoid division for numerical stability
llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in))
llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in))
return llr_out
def _decode_bp(self, llr_ch, num_iter):
"""Iterative BP decoding function with LLR-values.
Args:
llr_ch (tf.float32): Tensor of shape `[batch_size, n]` containing
the channel logits/llr values where `batch_size` denotes the
batch-size.
num_iter (int): Defining the number of decoder iterations
(no early stopping used at the moment).
Returns:
`tf.float32`: Tensor of shape `[batch_size, k]` containing
bit-wise soft-estimates (or hard-decided bit-values) of all
information bits.
"""
bs = tf.shape(llr_ch)[0]
# store intermediate Tensors in TensorArray
msg_l = tf.TensorArray(tf.float32,
size=num_iter*(self._n_stages+1),
dynamic_size=False,
clear_after_read=False)
msg_r = tf.TensorArray(tf.float32,
size=num_iter*(self._n_stages+1),
dynamic_size=False,
clear_after_read=False)
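# TensorArray layout: the message of stage s at iteration t is stored at
# index s + t*(n_stages+1)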
# init frozen positions with infinity
msg_r_in = np.zeros([1, self._n])
msg_r_in[:, self._frozen_pos] = self._llr_max
# copy for all batch-samples
msg_r_in = tf.tile(tf.constant(msg_r_in, tf.float32), [bs, 1])
# perform decoding iterations
for ind_it in tf.range(num_iter):
# update left-to-right messages
for ind_s in range(self._n_stages):
# calc indices
ind_range = np.arange(int(self._n/2))
ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
ind_2 = ind_1 + 2**ind_s
# simplify gather with concatenated outputs
ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))
# load incoming l messages
if ind_s==self._n_stages-1:
l1_in = tf.gather(llr_ch, ind_1, axis=1)
l2_in = tf.gather(llr_ch, ind_2, axis=1)
elif ind_it==0:
l1_in = tf.zeros([bs, int(self._n/2)])
l2_in = tf.zeros([bs, int(self._n/2)])
else:
l_in = msg_l.read((ind_s+1) + (ind_it-1)*(self._n_stages+1))
l1_in = tf.gather(l_in, ind_1, axis=1)
l2_in = tf.gather(l_in, ind_2, axis=1)
# load incoming r messages
if ind_s==0:
r1_in = tf.gather(msg_r_in, ind_1, axis=1)
r2_in = tf.gather(msg_r_in, ind_2, axis=1)
else:
r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
r1_in = tf.gather(r_in, ind_1, axis=1)
r2_in = tf.gather(r_in, ind_2, axis=1)
r1_out = self._boxplus_tf(r1_in, l2_in + r2_in)
r2_out = self._boxplus_tf(r1_in, l1_in) + r2_in
# and re-concatenate output
r_out = tf.concat([r1_out, r2_out], 1)
r_out = tf.gather(r_out, ind_inv, axis=1)
msg_r = msg_r.write((ind_s+1)
+ ind_it*(self._n_stages+1), r_out)
# update right-to-left messages
for ind_s in range(self._n_stages-1, -1, -1):
ind_range = np.arange(int(self._n/2))
ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
ind_2 = ind_1 + 2**ind_s
ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))
# load messages
if ind_s==self._n_stages-1:
l1_in = tf.gather(llr_ch, ind_1, axis=1)
l2_in = tf.gather(llr_ch, ind_2, axis=1)
else:
l_in = msg_l.read((ind_s+1)+ind_it*(self._n_stages+1))
l1_in = tf.gather(l_in, ind_1, axis=1)
l2_in = tf.gather(l_in, ind_2, axis=1)
if ind_s==0:
r1_in = tf.gather(msg_r_in, ind_1, axis=1)
r2_in = tf.gather(msg_r_in, ind_2, axis=1)
else:
r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
r1_in = tf.gather(r_in, ind_1, axis=1)
r2_in = tf.gather(r_in, ind_2, axis=1)
# node update functions
l1_out = self._boxplus_tf(l1_in, l2_in + r2_in)
l2_out = self._boxplus_tf(r1_in, l1_in) + l2_in
l_out = tf.concat([l1_out, l2_out], 1)
l_out = tf.gather(l_out, ind_inv, axis=1)
msg_l = msg_l.write(ind_s + ind_it*(self._n_stages+1), l_out)
# recover u_hat
u_hat = tf.gather(msg_l.read((num_iter-1)*(self._n_stages+1)),
self._info_pos,
axis=1)
# if active, hard-decide output bits
if self._hard_out:
u_hat = tf.where(u_hat>0, 0., 1.)
else: # re-transform soft output to logits (instead of llrs)
u_hat = -1. * u_hat
return u_hat
#########################
# Keras layer functions
#########################
def build(self, input_shape):
"""Build and check if shape of input is invalid."""
assert (input_shape[-1]==self._n), "Invalid input shape"
assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
def call(self, inputs):
"""Iterative BP decoding function.
This function performs `num_iter` belief propagation decoding iterations
and returns the estimated information bits.
Args:
inputs (tf.float32): Tensor of shape `[...,n]` containing the
channel logits/llr values.
Returns:
`tf.float32`: Tensor of shape `[...,k]` containing
bit-wise soft-estimates (or hard-decided bit-values) of all
``k`` information bits.
Raises:
ValueError: If ``inputs`` is not of shape `[..., n]`
or `dtype` is not `output_dtype`.
InvalidArgumentError: When rank(``inputs``)<2.
Note:
This function recursively unrolls the BP decoding graph, thus,
for larger values of ``n`` or more iterations, building the
decoding graph can become time and memory consuming.
"""
tf.debugging.assert_type(inputs, self._output_dtype,
"Invalid input dtype.")
# internal calculations still in tf.float32
inputs = tf.cast(inputs, tf.float32)
# Reshape inputs to [-1, n]
input_shape = inputs.shape
new_shape = [-1, self._n]
llr_ch = tf.reshape(inputs, new_shape)
llr_ch = -1. * llr_ch # logits are converted into "true" llrs
# and decode
u_hat = self._decode_bp(llr_ch, self._num_iter)
# and reconstruct input shape
output_shape = input_shape.as_list()
output_shape[-1] = self.k
output_shape[0] = -1 # first dim can be dynamic (None)
u_hat_reshape = tf.reshape(u_hat, output_shape)
return tf.cast(u_hat_reshape, self._output_dtype)
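# ---------------------------------------------------------------------------
# Minimal usage sketch for PolarBPDecoder showing soft-output decoding and
# its differentiability (illustrative only, not part of the original module;
# assumes `generate_5g_ranking` from sionna.fec.polar.utils).
# ---------------------------------------------------------------------------
def _example_bp_decoding(batch_size=10):
    """Soft-output BP decoding; gradients w.r.t. the channel logits exist."""
    from sionna.fec.polar.utils import generate_5g_ranking
    k, n = 32, 64
    frozen_pos, _ = generate_5g_ranking(k, n)
    decoder = PolarBPDecoder(frozen_pos, n, num_iter=10, hard_out=False)
    llr_ch = tf.Variable(5. + tf.random.normal([batch_size, n]))
    with tf.GradientTape() as tape:
        u_soft = decoder(llr_ch)  # [batch_size, k] soft logits
        loss = tf.reduce_mean(tf.square(u_soft))
    grad = tape.gradient(loss, llr_ch)  # decoder is fully differentiable
    return u_soft, grad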
class Polar5GDecoder(Layer):
# pylint: disable=line-too-long
"""Polar5GDecoder(enc_polar, dec_type="SC", list_size=8, num_iter=20,return_crc_status=False, output_dtype=tf.float32, **kwargs)
Wrapper for 5G compliant decoding including rate-recovery and CRC removal.
The class inherits from the Keras layer class and can be used as layer in a
Keras model.
Parameters
----------
enc_polar: Polar5GEncoder
Instance of the :class:`~sionna.fec.polar.encoding.Polar5GEncoder`
used for encoding including rate-matching.
dec_type: str
Defaults to `"SC"`. Defining the decoder to be used.
Must be one of the following `{"SC", "SCL", "hybSCL", "BP"}`.
list_size: int
Defaults to 8. Defining the list size `iff` list-decoding is used.
Only required for ``dec_type`` `{"SCL", "hybSCL"}`.
num_iter: int
Defaults to 20. Defining the number of BP iterations. Only required
for ``dec_type`` `"BP"`.
return_crc_status: bool
Defaults to False. If True, the decoder additionally returns the
CRC status indicating if a codeword was (most likely) correctly
recovered.
output_dtype: tf.DType
Defaults to tf.float32. Defines the output datatype of the layer
(internal precision remains tf.float32).
Input
-----
inputs: [...,n], tf.float32
        2+D tensor containing the channel logits/LLR values.
Output
------
b_hat : [...,k], tf.float32
2+D tensor containing hard-decided estimations of all `k`
information bits.
    crc_status : [...], output_dtype
        CRC status indicating if a codeword was (most likely) correctly
        recovered; returned only if ``return_crc_status`` is True. The
        boolean status is cast to ``output_dtype``. Note that false
        positives are possible.
Raises
------
AssertionError
If ``enc_polar`` is not `Polar5GEncoder`.
    ValueError
        If ``dec_type`` is not `{"SC", "SCL", "hybSCL", "BP"}`.
AssertionError
If ``dec_type`` is not `str`.
ValueError
If ``inputs`` is not of shape `[..., n]` or `dtype` is not
the same as ``output_dtype``.
InvalidArgumentError
When rank(``inputs``)<2.
Note
----
This layer supports the uplink and downlink Polar rate-matching scheme
without `codeword segmentation`.
Although the decoding `list size` is not provided by 3GPP
[3GPPTS38212]_, the consortium has agreed on a `list size` of 8 for the
5G decoding reference curves [Bioglio_Design]_.
    All list decoders apply `CRC-aided` decoding; the non-list decoders
    (`"SC"` and `"BP"`), however, cannot exploit the CRC, which leads to
    an effective rate loss.
"""
def __init__(self,
enc_polar,
dec_type="SC",
list_size=8,
num_iter=20,
return_crc_status=False,
output_dtype=tf.float32,
**kwargs):
if output_dtype not in (tf.float16, tf.float32, tf.float64):
raise ValueError(
'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
if output_dtype is not tf.float32:
print('Note: decoder uses tf.float32 for internal calculations.')
self._output_dtype = output_dtype
super().__init__(dtype=output_dtype, **kwargs)
assert isinstance(enc_polar, Polar5GEncoder), \
"enc_polar must be Polar5GEncoder."
assert isinstance(dec_type, str), "dec_type must be str."
# list_size and num_iter are not checked here (done during decoder init)
# Store internal attributes
self._n_target = enc_polar.n_target
self._k_target = enc_polar.k_target
self._n_polar = enc_polar.n_polar
self._k_polar = enc_polar.k_polar
self._k_crc = enc_polar.enc_crc.crc_length
self._bil = enc_polar._channel_type == "uplink"
self._iil = enc_polar._channel_type == "downlink"
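        # uplink applies the channel (bit) interleaver, downlink the input
        # bit interleaver (cf. [3GPPTS38212]_)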
        self._llr_max = 100 # internal max LLR value ("infinity" for shortened positions)
self._enc_polar = enc_polar
self._dec_type = dec_type
# Initialize the de-interleaver patterns
self._init_interleavers()
# Initialize decoder
if dec_type=="SC":
print("Warning: 5G Polar codes use an integrated CRC that " \
"cannot be materialized with SC decoding and, thus, " \
"causes a degraded performance. Please consider SCL " \
"decoding instead.")
self._polar_dec = PolarSCDecoder(self._enc_polar.frozen_pos,
self._n_polar)
elif dec_type=="SCL":
self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
self._n_polar,
crc_degree=self._enc_polar.enc_crc.crc_degree,
list_size=list_size,
                                           ind_iil_inv=self.ind_iil_inv)
elif dec_type=="hybSCL":
self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
self._n_polar,
crc_degree=self._enc_polar.enc_crc.crc_degree,
list_size=list_size,
use_hybrid_sc=True,
                                           ind_iil_inv=self.ind_iil_inv)
elif dec_type=="BP":
print("Warning: 5G Polar codes use an integrated CRC that " \
"cannot be materialized with BP decoding and, thus, " \
"causes a degraded performance. Please consider SCL " \
" decoding instead.")
assert isinstance(num_iter, int), "num_iter must be int."
assert num_iter > 0, "num_iter must be positive."
self._num_iter = num_iter
self._polar_dec = PolarBPDecoder(self._enc_polar.frozen_pos,
self._n_polar,
num_iter=num_iter,
hard_out=True)
else:
raise ValueError("Unknown value for dec_type.")
assert isinstance(return_crc_status, bool), \
"return_crc_status must be bool."
self._return_crc_status = return_crc_status
if self._return_crc_status: # init crc decoder
if dec_type in ("SCL", "hybSCL"):
# re-use CRC decoder from list decoder
self._dec_crc = self._polar_dec._crc_decoder
else: # init new CRC decoder for BP and SC
            self._dec_crc = CRCDecoder(self._enc_polar.enc_crc)
#########################################
# Public methods and properties
#########################################
@property
def k_target(self):
"""Number of information bits including rate-matching."""
return self._k_target
@property
def n_target(self):
"""Codeword length including rate-matching."""
return self._n_target
@property
def k_polar(self):
"""Number of information bits of mother Polar code."""
return self._k_polar
@property
def n_polar(self):
"""Codeword length of mother Polar code."""
return self._n_polar
    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding."""
        # not stored locally; delegate to the wrapped Polar decoder
        return self._polar_dec.frozen_pos
    @property
    def info_pos(self):
        """Information bit positions for Polar encoding."""
        # not stored locally; delegate to the wrapped Polar decoder
        return self._polar_dec.info_pos
@property
def llr_max(self):
"""Maximum LLR value for internal calculations."""
return self._llr_max
@property
def dec_type(self):
"""Decoder type used for decoding as str."""
return self._dec_type
@property
def polar_dec(self):
"""Decoder instance used for decoding."""
return self._polar_dec
@property
def output_dtype(self):
"""Output dtype of decoder."""
return self._output_dtype
#########################
# Utility methods
#########################
def _init_interleavers(self):
"""Initialize inverse interleaver patterns for rate-recovery."""
# Channel interleaver
ind_ch_int = self._enc_polar.channel_interleaver(
np.arange(self._n_target))
self.ind_ch_int_inv = np.argsort(ind_ch_int) # Find inverse perm
# Sub-block interleaver
ind_sub_int = self._enc_polar.subblock_interleaving(
np.arange(self._n_polar))
self.ind_sub_int_inv = np.argsort(ind_sub_int) # Find inverse perm
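        # (note: np.argsort inverts a permutation, e.g., for p = [2, 0, 1]
        #  we get np.argsort(p) = [1, 2, 0] and x[p][np.argsort(p)] == x)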
# input bit interleaver
if self._iil:
self.ind_iil_inv = np.argsort(self._enc_polar.input_interleaver(
np.arange(self._k_polar)))
else:
self.ind_iil_inv = None
#########################
# Keras layer functions
#########################
    def build(self, input_shape):
        """Build the layer and verify that the input shape is valid."""
        assert input_shape[-1]==self._n_target, "Invalid input shape."
        assert len(input_shape)>=2, "Inputs must have at least 2 dimensions."
def call(self, inputs):
"""Polar decoding and rate-recovery for uplink 5G Polar codes.
Args:
inputs (tf.float32): Tensor of shape `[...,n]` containing the
channel logits/llr values.
Returns:
`tf.float32`: Tensor of shape `[...,k]` containing
hard-decided estimates of all ``k`` information bits.
Raises:
ValueError: If ``inputs`` is not of shape `[..., n]`
or `dtype` is not `output_dtype`.
InvalidArgumentError: When rank(``inputs``)<2.
"""
tf.debugging.assert_type(inputs, self._output_dtype,
"Invalid input dtype.")
# internal calculations still in tf.float32
inputs = tf.cast(inputs, tf.float32)
# Reshape inputs to [-1, n]
tf.debugging.assert_greater(tf.rank(inputs), 1)
input_shape = inputs.shape
new_shape = [-1, self._n_target]
llr_ch = tf.reshape(inputs, new_shape)
# Note: logits are not inverted here; this is done in the decoder itself
# 1.) Undo channel interleaving
if self._bil:
llr_deint = tf.gather(llr_ch, self.ind_ch_int_inv, axis=1)
else:
llr_deint = llr_ch
# 2.) Remove puncturing, shortening, repetition (see Sec. 5.4.1.2)
# a) Puncturing: set LLRs to 0
# b) Shortening: set LLRs to infinity
# c) Repetition: combine LLRs
if self._n_target >= self._n_polar:
            # Repetition coding
            # Combine the last n_rep LLRs with the first n_rep positions
n_rep = self._n_target - self._n_polar
llr_1 = llr_deint[:,:n_rep]
llr_2 = llr_deint[:,n_rep:self._n_polar]
llr_3 = llr_deint[:,self._n_polar:]
llr_dematched = tf.concat([llr_1+llr_3, llr_2], 1)
else:
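            # 38.212 selects puncturing for rates k_polar/n_target <= 7/16
            # and shortening otherwise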
if self._k_polar/self._n_target <= 7/16:
                # Puncturing
                # Prepend n_polar - n_target "zero" LLRs at the first positions
llr_zero = tf.zeros([tf.shape(llr_deint)[0],
self._n_polar-self._n_target])
llr_dematched = tf.concat([llr_zero, llr_deint], 1)
else:
                # Shortening
                # Append n_polar - n_target "-infinity" LLRs at the last positions
                # Remark: we still operate on logits here, hence the negative sign
llr_infty = -self._llr_max * tf.ones([tf.shape(llr_deint)[0],
self._n_polar-self._n_target])
llr_dematched = tf.concat([llr_deint, llr_infty], 1)
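            # (illustration: in the shortening regime with n_polar=8 and
            #  n_target=6, the decoder input is [l_0,...,l_5, -llr_max,
            #  -llr_max], i.e., the last two code bits are treated as
            #  known zeros)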
# 3.) Remove subblock interleaving
llr_dec = tf.gather(llr_dematched, self.ind_sub_int_inv, axis=1)
# 4.) Run main decoder
u_hat_crc = self._polar_dec(llr_dec)
        # 5.) Shortened positions are implicitly recovered by the decoder
# 6.) Remove input bit interleaving for downlink channels only
if self._iil:
u_hat_crc = tf.gather(u_hat_crc, self.ind_iil_inv, axis=1)
# 7.) Evaluate or remove CRC (and PC)
if self._return_crc_status:
            # for compatibility with SC/BP, a dedicated CRC decoder is
            # used here (instead of accessing the internal SCL decoder)
u_hat, crc_status = self._dec_crc(u_hat_crc)
else: # just remove CRC bits
u_hat = u_hat_crc[:,:-self._k_crc]
# And reconstruct input shape
output_shape = input_shape.as_list()
output_shape[-1] = self._k_target
output_shape[0] = -1 # First dim can be dynamic (None)
u_hat_reshape = tf.reshape(u_hat, output_shape)
# and cast to output dtype
u_hat_reshape = tf.cast(u_hat_reshape, dtype=self._output_dtype)
if self._return_crc_status:
# reconstruct CRC shape
output_shape.pop() # remove last dimension
crc_status = tf.reshape(crc_status, output_shape)
crc_status = tf.cast(crc_status, dtype=self._output_dtype)
return u_hat_reshape, crc_status
else:
return u_hat_reshape
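# Usage sketch with CRC status (illustrative only; ``enc`` as in the sketch
# inside the class above):
#
#     dec = Polar5GDecoder(enc, dec_type="SCL", return_crc_status=True)
#     b_hat, crc_status = dec(llr_ch) # the boolean status is cast to
#                                     # output_dtype before being returned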