
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Blocks for Polar decoding such as successive cancellation (SC), successive
cancellation list (SCL) and iterative belief propagation (BP) decoding."""

import tensorflow as tf
import numpy as np
import warnings
from sionna.phy import Block
from sionna.phy.fec.crc import CRCDecoder, CRCEncoder
from sionna.phy.fec.polar.encoding import Polar5GEncoder
import numbers

class PolarSCDecoder(Block):
    """Successive cancellation (SC) decoder [Arikan_Polar]_ for Polar codes
    and Polar-like codes.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel LLR values (as logits).

    Output
    ------
    : [...,k], tf.float
        Tensor containing hard-decided estimations of all ``k`` information
        bits.

    Note
    ----
    This block implements the SC decoder as described in [Arikan_Polar]_.
    However, the implementation follows the `recursive tree`
    [Gross_Fast_SCL]_ terminology and combines nodes for increased
    throughput without changing the outcome of the algorithm.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear codes as the
    `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self, frozen_pos, n, precision=None, **kwargs):

        super().__init__(precision=precision, **kwargs)

        # assert error if r>1 or k, n are negative
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)  # n can be float (e.g., as result of n=k*r)

        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos) > n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        if np.log2(n) != int(np.log2(n)):
            raise ValueError("n must be a power of 2.")

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        if self._k != len(self._info_pos):
            msg = "Internal error: invalid info_pos generated."
            raise ArithmeticError(msg)

        self._llr_max = 30.  # internal max LLR value (uncritical for SC dec)

        # and create a frozen bit vector for simpler encoding
        self._frozen_ind = np.zeros(self._n)
        self._frozen_ind[self._frozen_pos] = 1

        # graph pruning is disabled by default
        self._use_fast_sc = False

    ###############################
    # Public methods and properties
    ###############################

    @property
    def n(self):
        """Codeword length"""
        return self._n

    @property
    def k(self):
        """Number of information bits"""
        return self._k

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations"""
        return self._llr_max

    #################
    # Utility methods
    #################

    def _cn_op_tf(self, x, y):
        """Check-node update (boxplus) for LLR inputs.

        Operations are performed element-wise. See [Stimming_LLR]_ and
        [Hashemi_SSCL]_ for detailed equations.
        """
        x_in = tf.clip_by_value(x,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)
        y_in = tf.clip_by_value(y,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)

        # avoid division for numerical stability
        llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in))
        llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in))
        return llr_out

    def _vn_op_tf(self, x, y, u_hat):
        """VN update for LLR inputs."""
        return tf.multiply((1-2*u_hat), x) + y

    def _polar_decode_sc_tf(self, llr_ch, frozen_ind):
        """Recursive SC decoding function.

        Recursively branches the decoding tree and splits it into the
        decoding of the `upper` and `lower` path until a leaf node is
        reached.

        The function returns the u_hat decisions at stage `0` and the bit
        decisions of the intermediate stage `s` (i.e., the re-encoded
        version of `u_hat` up to the current stage `s`).

        Note: This decoder parallelizes over the batch-dimension, i.e., the
        tree is processed for all samples in the batch in parallel. This
        yields a higher throughput, but does not improve the latency.
        """
        # calculate current codeword length
        n = len(frozen_ind)

        # branch if leaf is not reached yet
        if n > 1:
            if self._use_fast_sc:
                # rate-0 sub-code detected (all positions frozen)
                if np.sum(frozen_ind) == n:
                    u_hat = tf.zeros_like(llr_ch)
                    return u_hat, u_hat

            llr_ch1 = llr_ch[..., 0:int(n/2)]
            llr_ch2 = llr_ch[..., int(n/2):]
            frozen_ind1 = frozen_ind[0:int(n/2)]
            frozen_ind2 = frozen_ind[int(n/2):]

            # upper path
            x_llr1_in = self._cn_op_tf(llr_ch1, llr_ch2)
            # and call the decoding function (with upper half)
            u_hat1, u_hat1_up = self._polar_decode_sc_tf(x_llr1_in,
                                                         frozen_ind1)

            # lower path
            x_llr2_in = self._vn_op_tf(llr_ch1, llr_ch2, u_hat1_up)
            # and call the decoding function again (with lower half)
            u_hat2, u_hat2_up = self._polar_decode_sc_tf(x_llr2_in,
                                                         frozen_ind2)

            # combine u_hat from both branches
            u_hat = tf.concat([u_hat1, u_hat2], -1)

            # calculate re-encoded version of u_hat at current stage
            # u_hat1_up = tf.math.mod(u_hat1_up + u_hat2_up, 2)
            # combine u_hat via bitwise_xor (more efficient than mod-2)
            u_hat1_up_int = tf.cast(u_hat1_up, tf.int8)
            u_hat2_up_int = tf.cast(u_hat2_up, tf.int8)
            u_hat1_up_int = tf.bitwise.bitwise_xor(u_hat1_up_int,
                                                   u_hat2_up_int)
            u_hat1_up = tf.cast(u_hat1_up_int, self.rdtype)
            u_hat_up = tf.concat([u_hat1_up, u_hat2_up], -1)
        else:  # if leaf is reached perform basic decoding op (=decision)
            if frozen_ind == 1:  # position is frozen
                u_hat = tf.expand_dims(tf.zeros_like(llr_ch[:, 0]), axis=-1)
                u_hat_up = u_hat
            else:  # otherwise hard decide
                u_hat = 0.5 * (1. - tf.sign(llr_ch))
                # remove "exact 0 llrs" leading to u_hat=0.5
                u_hat = tf.where(tf.equal(u_hat, 0.5),
                                 tf.ones_like(u_hat),
                                 u_hat)
                u_hat_up = u_hat
        return u_hat, u_hat_up

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Check if shape of input is invalid."""
        if input_shape[-1] != self._n:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch, /):
        """Successive cancellation (SC) decoding function.

        Performs successive cancellation decoding and returns the estimated
        information bits.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel LLR values (as logits).

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing hard-decided
            estimations of all ``k`` information bits.

        Note:
            This function recursively unrolls the SC decoding tree, thus,
            for larger values of ``n`` building the decoding graph can
            become time consuming.
        """
        # Reshape inputs to [-1, n]
        input_shape = llr_ch.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(llr_ch, new_shape)

        llr_ch = -1. * llr_ch  # logits are converted into "true" llrs

        # and decode
        u_hat_n, _ = self._polar_decode_sc_tf(llr_ch, self._frozen_ind)

        # and recover the k information bit positions
        u_hat = tf.gather(u_hat_n, self._info_pos, axis=1)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1  # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)

        return u_hat_reshape
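
# ---------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): construct an SC
# decoder and decode a batch of channel logits. We assume that
# `generate_5g_ranking` from sionna.phy.fec.polar.utils is available to
# obtain the frozen positions; any other valid frozen set works as well.
def _example_sc_decoding():
    from sionna.phy.fec.polar.utils import generate_5g_ranking
    k, n = 32, 64
    frozen_pos, _ = generate_5g_ranking(k, n)  # n-k frozen indices
    decoder = PolarSCDecoder(frozen_pos, n)
    llr_ch = tf.random.normal([100, n])  # channel logits, shape [..., n]
    u_hat = decoder(llr_ch)              # hard decisions, shape [100, k]
    return u_hat
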

class PolarSCLDecoder(Block):
    # pylint: disable=line-too-long
    """Successive cancellation list (SCL) decoder [Tal_SCL]_ for Polar
    codes and Polar-like codes.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    list_size: int, (default 8)
        Defines the list size of the decoder.

    crc_degree: `None` (default) | "CRC24A" | "CRC24B" | "CRC24C" | "CRC16" | "CRC11" | "CRC6"
        Defining the CRC polynomial to be used. Can be any value from
        `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.

    use_hybrid_sc: `bool`, (default `False`)
        If `True`, SC decoding is applied first and only codewords with an
        invalid CRC are decoded with SCL. This option requires an outer CRC
        specified via ``crc_degree``.
        Remark: hybrid_sc does not support XLA optimization, i.e.,
        `@tf.function(jit_compile=True)`.

    use_fast_scl: `bool`, (default `True`)
        If `True`, tree pruning is used to reduce the decoding complexity.
        The output is equivalent to the non-pruned version (besides
        numerical differences).

    cpu_only: `bool`, (default `False`)
        If `True`, `tf.py_function` embedding is used and the decoder runs
        on the CPU. This option is usually slower, but also more memory
        efficient and, in particular, recommended for larger blocklengths.
        Remark: cpu_only does not support XLA optimization
        `@tf.function(jit_compile=True)`.

    use_scatter: `bool`, (default `False`)
        If `True`, `tf.tensor_scatter_update` is used for tensor updates.
        This option is usually slower, but more memory efficient.

    ind_iil_inv: `None` (default) | [k+k_crc], int or tf.int
        If not `None`, the sequence is used as inverse input bit
        interleaver before evaluating the CRC.
        Remark: this only affects the CRC evaluation; the output sequence
        is not permuted.

    return_crc_status: `bool`, (default `False`)
        If `True`, the decoder additionally returns the CRC status
        indicating if a codeword was (most likely) correctly recovered.
        This is only available if ``crc_degree`` is not None.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel LLR values (as logits).

    Output
    ------
    b_hat : [...,k], tf.float
        Binary tensor containing hard-decided estimations of all `k`
        information bits.

    crc_status : [...], tf.bool
        CRC status indicating if a codeword was (most likely) correctly
        recovered. This is only returned if ``return_crc_status`` is True.
        Note that false positives are possible.

    Note
    ----
    This block implements the successive cancellation list (SCL) decoder
    as described in [Tal_SCL]_ but uses LLR-based message updates
    [Stimming_LLR]_. The implementation follows the notation from
    [Gross_Fast_SCL]_, [Hashemi_SSCL]_. If option `use_fast_scl` is active,
    tree pruning is used and tree nodes are combined if possible (see
    [Hashemi_SSCL]_ for details).

    Implementing SCL decoding as a TensorFlow graph is a difficult task
    that requires several design tradeoffs to match the TF constraints
    while maintaining a reasonable throughput. Thus, the decoder minimizes
    the `control flow` as much as possible, leading to a large memory
    footprint (e.g., due to full path duplication after each decision).
    For longer code lengths, the complexity of the decoding graph becomes
    large and we recommend the ``cpu_only`` option that uses an embedded
    Numpy decoder.

    Further, this function recursively unrolls the SCL decoding tree, thus,
    for larger values of ``n`` building the decoding graph can become time
    consuming. Please consider the ``cpu_only`` option if building the
    graph takes too long.

    A hybrid SC/SCL decoder as proposed in [Cammerer_Hybrid_SCL]_ (using SC
    instead of BP) can be activated with the ``use_hybrid_sc`` option iff
    an outer CRC is available. Please note that the result does not exactly
    match SCL performance due to the false positive rate of the CRC.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear codes as the
    `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self,
                 frozen_pos,
                 n,
                 list_size=8,
                 crc_degree=None,
                 use_hybrid_sc=False,
                 use_fast_scl=True,
                 cpu_only=False,
                 use_scatter=False,
                 ind_iil_inv=None,
                 return_crc_status=False,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        # assert error if r>1 or k, n are negative
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)  # n can be float (e.g., as result of n=k*r)
        if not isinstance(list_size, int):
            raise TypeError("list_size must be integer.")
        if not isinstance(cpu_only, bool):
            raise TypeError("cpu_only must be bool.")
        if not isinstance(use_scatter, bool):
            raise TypeError("use_scatter must be bool.")
        if not isinstance(use_fast_scl, bool):
            raise TypeError("use_fast_scl must be bool.")
        if not isinstance(use_hybrid_sc, bool):
            raise TypeError("use_hybrid_sc must be bool.")
        if not isinstance(return_crc_status, bool):
            raise TypeError("return_crc_status must be bool.")
        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos) > n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        if np.log2(n) != int(np.log2(n)):
            raise ValueError("n must be a power of 2.")
        if np.log2(list_size) != int(np.log2(list_size)):
            raise ValueError("list_size must be a power of 2.")

        # CPU mode is recommended for larger values of n
        if n > 128 and cpu_only is False and use_hybrid_sc is False:
            warnings.warn("Required resource allocation is large "
                          "for the selected blocklength. Consider option "
                          "`cpu_only=True`.")

        # CPU mode is recommended for larger values of L
        if list_size > 32 and cpu_only is False and use_hybrid_sc is False:
            warnings.warn("Resource allocation is high for the "
                          "selected list_size. Consider option "
                          "`cpu_only=True`.")

        # internal decoder parameters
        self._use_fast_scl = use_fast_scl  # optimize rate-0 and rep nodes
        self._use_scatter = use_scatter  # slower but more memory friendly
        self._cpu_only = cpu_only  # run Numpy decoder
        self._use_hybrid_sc = use_hybrid_sc

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._list_size = list_size
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        self._llr_max = 30.  # internal max LLR value (not very critical)

        if self._k != len(self._info_pos):
            raise ArithmeticError("Internal error: invalid info_pos "
                                  "generated.")

        # create a frozen bit vector
        self._frozen_ind = np.zeros(self._n)
        self._frozen_ind[self._frozen_pos] = 1

        self._cw_ind = np.arange(self._n)
        self._n_stages = int(np.log2(self._n))  # number of decoding stages

        # init CRC check (if needed)
        if crc_degree is not None:
            self._use_crc = True
            self._crc_encoder = CRCEncoder(crc_degree, precision=precision)
            self._crc_decoder = CRCDecoder(self._crc_encoder,
                                           precision=precision)
            self._k_crc = self._crc_decoder.encoder.crc_length
        else:
            self._use_crc = False
            self._k_crc = 0
        if self._k < self._k_crc:
            msg = "Value of k is too small for given crc_degree."
            raise ValueError(msg)

        if (crc_degree is None) and return_crc_status:
            raise ValueError("Returning CRC status requires given "
                             "crc_degree.")
        else:
            self._return_crc_status = return_crc_status

        # store the inverse interleaver pattern
        if ind_iil_inv is not None:
            if ind_iil_inv.shape[0] != self._k:
                raise ValueError("ind_iil_inv must be of length k+k_crc.")
            self._ind_iil_inv = ind_iil_inv
            self._iil = True
        else:
            self._iil = False

        # use SC decoder first and the Numpy-based SCL as "afterburner"
        if self._use_hybrid_sc:
            self._decoder_sc = PolarSCDecoder(frozen_pos,
                                              n,
                                              precision=precision)
            # Note: CRC is required to detect whether SC was successful
            if not self._use_crc:
                raise ValueError("Hybrid SC requires outer CRC.")

    ###############################
    # Public methods and properties
    ###############################

    @property
    def n(self):
        """Codeword length"""
        return self._n

    @property
    def k(self):
        """Number of information bits"""
        return self._k

    @property
    def k_crc(self):
        """Number of CRC bits"""
        return self._k_crc

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations"""
        return self._llr_max

    @property
    def list_size(self):
        """List size for SCL decoding"""
        return self._list_size

    #####################################
    # Helper functions for the TF decoder
    #####################################

    def _update_rate0_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
        """Update rate-0 sub-code (i.e., all frozen) at pos ``cw_ind``.

        See Eq. (26) in [Hashemi_SSCL]_.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2)
        llr_in = tf.clip_by_value(llr,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # update path metric for the complete sub-block of length n
        pm_val = tf.math.softplus(-1.*llr_in)
        msg_pm += tf.reduce_sum(pm_val, axis=-1)

        return msg_pm, msg_uhat, msg_llr

    def _update_rep_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at position ``cw_ind``.

        See Eq. (31) in [Hashemi_SSCL]_.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # update PM
        llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2)
        llr_in = tf.clip_by_value(llr,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # upper branch has negative llr values (bit is 1)
        llr_low = llr_in[:, :self._list_size, :]
        llr_up = -llr_in[:, self._list_size:, :]
        llr_pm = tf.concat([llr_low, llr_up], 1)
        pm_val = tf.math.softplus(-1.*llr_pm)
        msg_pm += tf.reduce_sum(pm_val, axis=-1)

        msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
        msg_uhat21 = tf.expand_dims(
                    msg_uhat[:, self._list_size:, stage_ind, :cw_ind[0]],
                    axis=2)
        msg_uhat22 = tf.expand_dims(
                    msg_uhat[:, self._list_size:, stage_ind, cw_ind[-1]+1:],
                    axis=2)

        # ones to insert
        msg_ones = tf.ones([tf.shape(msg_uhat)[0], self._list_size, 1, n],
                           self.rdtype)
        msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)

        msg_uhat24_1 = msg_uhat[:, self._list_size:, :stage_ind, :]
        msg_uhat24_2 = msg_uhat[:, self._list_size:, stage_ind+1:, :]
        msg_uhat2 = tf.concat([msg_uhat24_1, msg_uhat23, msg_uhat24_2], 2)

        msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)

        # branch for the last bit and update pm at pos cw_ind[-1]
        msg_uhat = self._update_single_bit([cw_ind[-1]], msg_uhat)
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
                                                          msg_llr,
                                                          msg_pm)
        return msg_pm, msg_uhat, msg_llr

    def _update_single_bit(self, ind_u, msg_uhat):
        """Update single bit at position ``ind_u`` for all decoders.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.

        Remark: Two versions are implemented
        (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        if self._frozen_ind[ind_u[0]] == 0:  # position is non-frozen
            # msg_uhat[:, ind_dec, 0, ind_u] = 1
            if self._use_scatter:
                ind_dec = np.arange(self._list_size, 2*self._list_size, 1)
                ind_stage = np.array([0])

                # transpose such that batch dim can be broadcasted
                msg_uhat_t = tf.transpose(msg_uhat, [1, 3, 2, 0])

                # generate index grid
                ind_u = tf.cast(ind_u, tf.int64)
                grid = tf.meshgrid(ind_dec, ind_u, ind_stage)
                ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 3])

                updates = tf.ones([ind.shape[0], tf.shape(msg_uhat)[0]],
                                  self.rdtype)
                msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                         ind,
                                                         updates)
                # and restore original order
                msg_uhat = tf.transpose(msg_uhat_s, [3, 0, 2, 1])
            else:
                # alternative solution with split/concatenation of graph
                msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
                msg_uhat21 = tf.expand_dims(
                            msg_uhat[:, self._list_size:, 0, :ind_u[0]],
                            axis=2)
                msg_uhat22 = tf.expand_dims(
                            msg_uhat[:, self._list_size:, 0, ind_u[0]+1:],
                            axis=2)

                # ones to insert
                msg_ones = tf.ones_like(tf.reshape(
                            msg_uhat[:, self._list_size:, 0, ind_u[0]],
                            [-1, self._list_size, 1, 1]))

                msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
                msg_uhat24 = msg_uhat[:, self._list_size:, 1:, :]
                msg_uhat2 = tf.concat([msg_uhat23, msg_uhat24], 2)
                msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)

        return msg_uhat

    def _update_pm(self, ind_u, msg_uhat, msg_llr, msg_pm):
        """Update path metric of all decoders after updating bit_pos
        ``ind_u``.

        We implement Eq. (10) from [Stimming_LLR]_.
        """
        u_hat = msg_uhat[:, :, 0, ind_u[0]]
        llr = msg_llr[:, :, 0, ind_u[0]]
        llr_in = tf.clip_by_value(llr,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # numerically more stable implementation of log(1 + exp(-x))
        msg_pm += tf.math.softplus(-tf.multiply((1 - 2*u_hat), llr_in))

        return msg_pm

    def _sort_decoders(self, msg_pm, msg_uhat, msg_llr):
        """Sort decoders according to their path metric."""
        ind = tf.argsort(msg_pm, axis=-1)
        msg_pm = tf.gather(msg_pm, ind, batch_dims=1, axis=None)
        msg_uhat = tf.gather(msg_uhat, ind, batch_dims=1, axis=None)
        msg_llr = tf.gather(msg_llr, ind, batch_dims=1, axis=None)
        return msg_pm, msg_uhat, msg_llr

    def _cn_op(self, x, y):
        """Check-node update (boxplus) for LLR inputs.

        Operations are performed element-wise. See [Stimming_LLR]_ and
        [Hashemi_SSCL]_ for detailed equations.
        """
        x_in = tf.clip_by_value(x,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)
        y_in = tf.clip_by_value(y,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)

        # avoid division for numerical stability
        # implements log(1+e^(x+y))
        llr_out = tf.math.softplus(x_in + y_in)
        # implements log(e^x+e^y)
        llr_out -= tf.math.reduce_logsumexp(tf.stack([x_in, y_in], axis=-1),
                                            axis=-1)
        return llr_out

    def _vn_op(self, x, y, u_hat):
        """Variable node update for LLR inputs.

        Operations are performed element-wise. See [Stimming_LLR]_ and
        [Hashemi_SSCL]_ for detailed equations.
        """
        return tf.multiply((1 - 2*u_hat), x) + y

    def _duplicate_paths(self, msg_uhat, msg_llr, msg_pm):
        """Duplicate paths by copying the upper branch into the lower
        one."""
        msg_uhat = tf.tile(msg_uhat[:, :self._list_size, :, :],
                           [1, 2, 1, 1])
        msg_llr = tf.tile(msg_llr[:, :self._list_size, :, :], [1, 2, 1, 1])
        msg_pm = tf.tile(msg_pm[:, :self._list_size], [1, 2])

        return msg_uhat, msg_llr, msg_pm

    def _update_left_branch(self, msg_llr, stage_ind, cw_ind_left,
                            cw_ind_right):
        """Update messages of left branch.

        Remark: Two versions are implemented
        (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                cw_ind_left,
                                axis=2)
        llr_right_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                 cw_ind_right,
                                 axis=2)
        llr_left_out = self._cn_op(llr_left_in, llr_right_in)

        if self._use_scatter:
            # self.msg_llr[:, :, stage_ind-1, cw_ind_left] = llr_left_out
            # transpose such that batch-dim can be broadcasted
            msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
            llr_left_out_s = tf.transpose(llr_left_out, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind_left = tf.cast(cw_ind_left, tf.int64)
            grid = tf.meshgrid(stage_ind-1, cw_ind_left)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            # update values
            msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
                                                    ind,
                                                    llr_left_out_s)
            # and restore original order
            msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
            llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  np.arange(0, cw_ind_left[0]),
                                  axis=2)
            llr_right = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  cw_ind_right,
                                  axis=2)
            llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                   np.arange(cw_ind_right[-1]+1, self._n),
                                   axis=2)
            llr_s = tf.concat([llr_left0, llr_left_out, llr_right,
                               llr_right1], 2)
            llr_s = tf.expand_dims(llr_s, axis=2)

            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]
            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr

    def _update_right_branch(self, msg_llr, msg_uhat, stage_ind,
                             cw_ind_left, cw_ind_right):
        """Update messages for right branch.

        Remark: Two versions are implemented
        (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)
        llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                cw_ind_left,
                                axis=2)
        llr_right = tf.gather(msg_llr[:, :, stage_ind, :],
                              cw_ind_right,
                              axis=2)
        llr_right_out = self._vn_op(llr_left_in, llr_right, u_hat_left_up)

        if self._use_scatter:
            # transpose such that batch dim can be broadcasted
            msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
            llr_right_out_s = tf.transpose(llr_right_out, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind_right = tf.cast(cw_ind_right, tf.int64)
            grid = tf.meshgrid(stage_ind-1, cw_ind_right)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
                                                    ind,
                                                    llr_right_out_s)
            # and restore original order
            msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
            llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  np.arange(0, cw_ind_left[0]),
                                  axis=2)
            llr_left = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                 cw_ind_left,
                                 axis=2)
            llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                   np.arange(cw_ind_right[-1]+1, self._n),
                                   axis=2)
            llr_s = tf.concat([llr_left0, llr_left, llr_right_out,
                               llr_right1], 2)
            llr_s = tf.expand_dims(llr_s, axis=2)

            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]
            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr

    def _update_branch_u(self, msg_uhat, stage_ind, cw_ind_left,
                         cw_ind_right):
        """Update ``u_hat`` messages after executing both branches.

        Remark: Two versions are implemented
        (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)
        u_hat_right_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                   cw_ind_right,
                                   axis=2)

        # combine u_hat via bitwise_xor (more efficient than mod-2)
        u_hat_left_up_int = tf.cast(u_hat_left_up, tf.int32)
        u_hat_right_up_int = tf.cast(u_hat_right_up, tf.int32)
        u_hat_left = tf.bitwise.bitwise_xor(u_hat_left_up_int,
                                            u_hat_right_up_int)
        u_hat_left = tf.cast(u_hat_left, self.rdtype)

        if self._use_scatter:
            cw_ind = np.concatenate([cw_ind_left, cw_ind_right])
            u_hat = tf.concat([u_hat_left, u_hat_right_up], -1)

            # transpose such that batch dim can be broadcasted
            msg_uhat_t = tf.transpose(msg_uhat, [2, 3, 1, 0])
            u_hat_s = tf.transpose(u_hat, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind = tf.cast(cw_ind, tf.int64)
            grid = tf.meshgrid(stage_ind, cw_ind)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                     ind,
                                                     u_hat_s)
            # and restore original order
            msg_uhat = tf.transpose(msg_uhat_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            u_hat_left_0 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                     np.arange(0, cw_ind_left[0]),
                                     axis=2)
            u_hat_right_1 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                      np.arange(cw_ind_right[-1]+1,
                                                self._n),
                                      axis=2)
            u_hat = tf.concat([u_hat_left_0, u_hat_left, u_hat_right_up,
                               u_hat_right_1], 2)

            # provide u_hat for next higher stage
            msg_uhat1 = msg_uhat[:, :, 0:stage_ind, :]
            msg_uhat2 = msg_uhat[:, :, stage_ind+1:, :]
            u_hat = tf.expand_dims(u_hat, axis=2)
            msg_uhat = tf.concat([msg_uhat1, u_hat, msg_uhat2], 2)

        return msg_uhat

    def _polar_decode_scl(self, cw_ind, msg_uhat, msg_llr, msg_pm):
        """Recursive decoding function for SCL decoding.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and a `right` update path
        until reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.
        """
        # current sub-code length and stage index (= tree depth)
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n > 1:
            # prune tree if rate-0 or rep. sub-code is detected
            if self._use_fast_scl:
                if np.sum(self._frozen_ind[cw_ind]) == n:
                    msg_pm, msg_uhat, msg_llr = self._update_rate0_code(
                                        msg_pm, msg_uhat, msg_llr, cw_ind)
                    return msg_uhat, msg_llr, msg_pm
                if (self._frozen_ind[cw_ind[-1]] == 0 and
                        np.sum(self._frozen_ind[cw_ind[:-1]]) == n-1):
                    msg_pm, msg_uhat, msg_llr = self._update_rep_code(
                                        msg_pm, msg_uhat, msg_llr, cw_ind)
                    return msg_uhat, msg_llr, msg_pm

            # split index into left and right part
            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            msg_llr = self._update_left_branch(msg_llr,
                                               stage_ind,
                                               cw_ind_left,
                                               cw_ind_right)
            # call sub-graph decoder of left branch
            msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_left,
                                                               msg_uhat,
                                                               msg_llr,
                                                               msg_pm)

            # ----- right branch -----
            msg_llr = self._update_right_branch(msg_llr,
                                                msg_uhat,
                                                stage_ind,
                                                cw_ind_left,
                                                cw_ind_right)
            # call sub-graph decoder of right branch
            msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_right,
                                                               msg_uhat,
                                                               msg_llr,
                                                               msg_pm)

            # update u_hat at current stage
            msg_uhat = self._update_branch_u(msg_uhat,
                                             stage_ind,
                                             cw_ind_left,
                                             cw_ind_right)
        else:  # if leaf is reached perform basic decoding op (=decision)
            # update bit value at current position
            msg_uhat = self._update_single_bit(cw_ind, msg_uhat)
            # update PM
            msg_pm = self._update_pm(cw_ind, msg_uhat, msg_llr, msg_pm)

            if self._frozen_ind[cw_ind] == 0:  # position is non-frozen
                # sort list
                msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                                msg_uhat,
                                                                msg_llr)
                # duplicate the l best decoders to pos l:2*l
                # (kills the other decoders)
                msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
                                                                  msg_llr,
                                                                  msg_pm)
        return msg_uhat, msg_llr, msg_pm

    def _decode_tf(self, llr_ch):
        """Main decoding function in TF.

        Initializes memory and calls the recursive decoding function.
        """
        batch_size = tf.shape(llr_ch)[0]

        # allocate memory for all 2*list_size decoders
        msg_uhat = tf.zeros([batch_size,
                             2*self._list_size,
                             self._n_stages+1,
                             self._n], self.rdtype)
        msg_llr = tf.zeros([batch_size,
                            2*self._list_size,
                            self._n_stages,
                            self._n], self.rdtype)

        # init all 2*list_size decoders with the same llr_ch
        llr_ch = tf.reshape(llr_ch, [-1, 1, 1, self._n])
        llr_ch = tf.tile(llr_ch, [1, 2*self._list_size, 1, 1])
        # init last stage with llr_ch
        msg_llr = tf.concat([msg_llr, llr_ch], 2)

        # init all remaining L-1 decoders with high penalty
        pm0 = tf.zeros([batch_size, 1], self.rdtype)
        pm1 = self._llr_max * tf.ones([batch_size, self._list_size-1],
                                      self.rdtype)
        msg_pm = tf.concat([pm0, pm1, pm0, pm1], 1)

        # and call recursive graph function
        msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(self._cw_ind,
                                                           msg_uhat,
                                                           msg_llr,
                                                           msg_pm)
        # and sort output
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        return [msg_uhat, msg_pm]

    ####################################
    # Helper functions for Numpy decoder
    ####################################

    def _update_rate0_code_np(self, cw_ind):
        """Update rate-0 (i.e., all frozen) sub-code at pos ``cw_ind`` in
        Numpy.

        See Eq. (26) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # update PM for each batch sample
        ind = np.expand_dims(self._dec_pointer, axis=-1)
        llr_in = np.take_along_axis(self.msg_llr[:, :, stage_ind, cw_ind],
                                    ind, axis=1)
        llr_clip = np.maximum(np.minimum(llr_in, self._llr_max),
                              -self._llr_max)
        pm_val = np.log(1 + np.exp(-llr_clip))
        self.msg_pm += np.sum(pm_val, axis=-1)

    def _update_rep_code_np(self, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at position ``cw_ind`` in Numpy.

        See Eq. (31) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))
        bs = self._dec_pointer.shape[0]

        # update PM
        llr = np.zeros([bs, 2*self._list_size, n])
        for i in range(bs):
            llr_i = self.msg_llr[i, self._dec_pointer[i, :], stage_ind, :]
            llr[i, :, :] = llr_i[:, cw_ind]
        # upper branch has negative llr values (bit is 1)
        llr[:, self._list_size:, :] = -llr[:, self._list_size:, :]
        llr_in = np.maximum(np.minimum(llr, self._llr_max), -self._llr_max)
        pm_val = np.sum(np.log(1 + np.exp(-llr_in)), axis=-1)
        self.msg_pm += pm_val

        for i in range(bs):
            ind_dec = self._dec_pointer[i, self._list_size:]
            for j in cw_ind:
                self.msg_uhat[i, ind_dec, stage_ind, j] = 1

        # branch for the last bit and update pm at pos cw_ind[-1]
        self._update_single_bit_np([cw_ind[-1]])
        self._sort_decoders_np()
        self._duplicate_paths_np()

    def _update_single_bit_np(self, ind_u):
        """Update single bit at position ``ind_u`` of all decoders in
        Numpy."""
        if self._frozen_ind[ind_u] == 0:  # position is non-frozen
            ind_dec = np.expand_dims(
                            self._dec_pointer[:, self._list_size:],
                            axis=-1)
            uhat_slice = self.msg_uhat[:, :, 0, ind_u]
            np.put_along_axis(uhat_slice, ind_dec, 1., axis=1)
            self.msg_uhat[:, :, 0, ind_u] = uhat_slice

    def _update_pm_np(self, ind_u):
        """Update path metric of all decoders at bit position ``ind_u`` in
        Numpy.

        We apply Eq. (10) from [Stimming_LLR]_.
        """
        ind = np.expand_dims(self._dec_pointer, axis=-1)
        u_hat = np.take_along_axis(self.msg_uhat[:, :, 0, ind_u],
                                   ind, axis=1)
        u_hat = np.squeeze(u_hat, axis=-1)
        llr_in = np.take_along_axis(self.msg_llr[:, :, 0, ind_u],
                                    ind, axis=1)
        llr_in = np.squeeze(llr_in, axis=-1)
        llr_clip = np.maximum(np.minimum(llr_in, self._llr_max),
                              -self._llr_max)
        self.msg_pm += np.log(1 + np.exp(-np.multiply((1-2*u_hat),
                                                      llr_clip)))

    def _sort_decoders_np(self):
        """Sort decoders according to their path metric."""
        ind = np.argsort(self.msg_pm, axis=-1)
        self.msg_pm = np.take_along_axis(self.msg_pm, ind, axis=1)
        self._dec_pointer = np.take_along_axis(self._dec_pointer, ind,
                                               axis=1)

    def _cn_op_np(self, x, y):
        """Check-node update (boxplus) for LLRs in Numpy.

        See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
        """
        x_in = np.maximum(np.minimum(x, self._llr_max), -self._llr_max)
        y_in = np.maximum(np.minimum(y, self._llr_max), -self._llr_max)

        # avoid division for numerical stability
        llr_out = np.log(1 + np.exp(x_in + y_in))
        llr_out -= np.log(np.exp(x_in) + np.exp(y_in))
        return llr_out

    def _vn_op_np(self, x, y, u_hat):
        """Variable node update for LLRs in Numpy."""
        return np.multiply((1-2*u_hat), x) + y

    def _duplicate_paths_np(self):
        """Copy the first ``list_size`` paths into the lower part in Numpy.

        Decoder indices are encoded in ``self._dec_pointer``.
        """
        ind_low = self._dec_pointer[:, :self._list_size]
        ind_up = self._dec_pointer[:, self._list_size:]
        for i in range(ind_up.shape[0]):
            self.msg_uhat[i, ind_up[i, :], :, :] = \
                                    self.msg_uhat[i, ind_low[i, :], :, :]
            self.msg_llr[i, ind_up[i, :], :, :] = \
                                    self.msg_llr[i, ind_low[i, :], :, :]

        # pm must be sorted directly (not accessed via pointer)
        self.msg_pm[:, self._list_size:] = self.msg_pm[:, :self._list_size]

    def _polar_decode_scl_np(self, cw_ind):
        """Recursive decoding function in Numpy.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and a `right` update path
        until reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n > 1:
            # prune tree if rate-0 sub-code or rep. code is detected
            if self._use_fast_scl:
                if np.sum(self._frozen_ind[cw_ind]) == n:
                    # rate-0 code detected
                    self._update_rate0_code_np(cw_ind)
                    return
                if (self._frozen_ind[cw_ind[-1]] == 0 and
                        np.sum(self._frozen_ind[cw_ind[:-1]]) == n-1):
                    # rep. code detected
                    self._update_rep_code_np(cw_ind)
                    return

            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]
            self.msg_llr[:, :, stage_ind-1, cw_ind_left] = self._cn_op_np(
                                                    llr_left, llr_right)
            # call left branch decoder
            self._polar_decode_scl_np(cw_ind_left)

            # ----- right branch -----
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]
            self.msg_llr[:, :, stage_ind-1, cw_ind_right] = self._vn_op_np(
                                    llr_left, llr_right, u_hat_left_up)
            # call right branch decoder
            self._polar_decode_scl_np(cw_ind_right)

            # combine u_hat
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            u_hat_right_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_right]

            # u_hat_left_up XOR u_hat_right_up
            u_hat_left = (u_hat_left_up != u_hat_right_up) + 0

            u_hat = np.concatenate([u_hat_left, u_hat_right_up], axis=-1)
            # provide u_hat for next higher stage
            self.msg_uhat[:, :, stage_ind, cw_ind] = u_hat
        else:  # if leaf is reached perform basic decoding op (=decision)
            self._update_single_bit_np(cw_ind)
            # update PM
            self._update_pm_np(cw_ind)

            if self._frozen_ind[cw_ind] == 0:  # position is non-frozen
                # sort list
                self._sort_decoders_np()
                # duplicate the best list_size decoders
                self._duplicate_paths_np()
        return

    def _decode_np_batch(self, llr_ch):
        """Decode a batch of ``llr_ch`` with the Numpy decoder."""
        bs = llr_ch.shape[0]

        # allocate memory for all 2*list_size decoders
        self.msg_uhat = np.zeros([bs,
                                  2*self._list_size,
                                  self._n_stages+1,
                                  self._n])
        self.msg_llr = np.zeros([bs,
                                 2*self._list_size,
                                 self._n_stages+1,
                                 self._n])
        self.msg_pm = np.zeros([bs, 2*self._list_size])

        # L-1 decoders start with a high penalty
        self.msg_pm[:, 1:self._list_size] = self._llr_max
        # same for the second half of the L-1 decoders
        self.msg_pm[:, self._list_size+1:] = self._llr_max

        # use pointers to avoid in-memory sorting
        self._dec_pointer = np.arange(2*self._list_size)
        self._dec_pointer = np.tile(np.expand_dims(self._dec_pointer,
                                                   axis=0), [bs, 1])

        # init llr_ch (broadcast via list dimension)
        self.msg_llr[:, :, self._n_stages, :] = np.expand_dims(llr_ch,
                                                               axis=1)

        # call recursive graph function
        self._polar_decode_scl_np(self._cw_ind)

        # select most likely candidate
        self._sort_decoders_np()

        # remove pointers
        for ind in range(bs):
            self.msg_uhat[ind, :, :, :] = \
                        self.msg_uhat[ind, self._dec_pointer[ind], :, :]

        return self.msg_uhat, self.msg_pm

    def _decode_np_hybrid(self, llr_ch, u_hat_sc, crc_valid):
        """Hybrid SCL decoding stage that only decodes codewords for which
        the CRC from the previous SC decoding attempt failed.

        This option avoids using the high-complexity SCL decoder in cases
        where SC would be sufficient. For further details we refer to
        [Cammerer_Hybrid_SCL]_ (we use SC instead of the proposed BP
        stage).

        Remark: This decoder does not exactly implement SCL as the CRC can
        be a false positive after the SC stage. However, in these cases,
        SCL+CRC may also yield the wrong result.

        Remark 2: Due to the excessive control flow (if/else) and the
        varying batch sizes, this function is only available as a Numpy
        decoder (i.e., it runs on the CPU).
        """
        bs = llr_ch.shape[0]
        crc_valid = np.squeeze(crc_valid, axis=-1)

        # index of codewords that need SCL decoding
        ind_invalid = np.arange(bs)[np.invert(crc_valid)]

        # init SCL decoder for bs_hyb samples requiring SCL dec.
        llr_ch_hyb = np.take(llr_ch, ind_invalid, axis=0)
        msg_uhat_hyb, msg_pm_hyb = self._decode_np_batch(llr_ch_hyb)

        # merge results with previously decoded SC results
        msg_uhat = np.zeros([bs, 2*self._list_size, 1, self._n])
        msg_pm = np.ones([bs, 2*self._list_size]) * self._llr_max * self.k
        msg_pm[:, 0] = 0

        # copy SC data
        msg_uhat[:, 0, 0, self._info_pos] = u_hat_sc

        ind_hyb = 0
        for ind in range(bs):
            if not crc_valid[ind]:
                # copy data from SCL
                msg_uhat[ind, :, 0, :] = msg_uhat_hyb[ind_hyb, :, 0, :]
                msg_pm[ind, :] = msg_pm_hyb[ind_hyb, :]
                ind_hyb += 1

        return msg_uhat, msg_pm

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Build and check if shape of input is invalid."""
        if input_shape[-1] != self._n:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch):
        """Successive cancellation list (SCL) decoding function.

        Performs successive cancellation list decoding and returns the
        estimated information bits. An outer CRC can be applied optionally
        by setting ``crc_degree``.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel LLR values (as logits).

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing hard-decided
            estimations of all ``k`` information bits.

        Note:
            This function recursively unrolls the SCL decoding tree, thus,
            for larger values of ``n`` building the decoding graph can
            become time consuming. Please consider the ``cpu_only`` option
            instead.
        """
        input_shape = llr_ch.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(llr_ch, new_shape)

        llr_ch = -1. * llr_ch  # logits are converted into "true" llrs

        # if activated, use the Numpy-based hybrid decoder
        if self._use_hybrid_sc:
            # use SC decoder first
            u_hat = self._decoder_sc(-llr_ch)
            _, crc_valid = self._crc_decoder(u_hat)
            msg_uhat, msg_pm = tf.py_function(func=self._decode_np_hybrid,
                                              inp=[llr_ch, u_hat,
                                                   crc_valid],
                                              Tout=[self.rdtype,
                                                    self.rdtype])
            # note: return shape is only 1 in the 3rd dim
            # (to avoid copy overhead)
            msg_uhat = tf.reshape(msg_uhat,
                                  [-1, 2*self._list_size, 1, self._n])
            msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
        else:
            if self._cpu_only:
                msg_uhat, msg_pm = tf.py_function(
                                            func=self._decode_np_batch,
                                            inp=[llr_ch],
                                            Tout=[self.rdtype,
                                                  self.rdtype])
                # restore shape information
                msg_uhat = tf.reshape(msg_uhat,
                                      [-1,
                                       2*self._list_size,
                                       self._n_stages+1,
                                       self._n])
                msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
            else:
                msg_uhat, msg_pm = self._decode_tf(llr_ch)

        # check CRC (and remove CRC parity bits)
        if self._use_crc:
            u_hat_list = tf.gather(msg_uhat[:, :, 0, :],
                                   self._info_pos,
                                   axis=-1)
            # undo input bit interleaving
            # remark: the output is not interleaved for compatibility
            # with SC decoding
            if self._iil:
                u_hat_list_crc = tf.gather(u_hat_list,
                                           self._ind_iil_inv,
                                           axis=-1)
            else:  # no interleaving applied
                u_hat_list_crc = u_hat_list

            _, crc_valid = self._crc_decoder(u_hat_list_crc)
            # add penalty to pm if CRC fails
            pm_penalty = ((1. - tf.cast(crc_valid, self.rdtype))
                          * self._llr_max * self.k)
            msg_pm += tf.squeeze(pm_penalty, axis=2)

        # select most likely candidate
        cand_ind = tf.argmin(msg_pm, axis=-1)
        c_hat = tf.gather(msg_uhat[:, :, 0, :],
                          cand_ind,
                          axis=1,
                          batch_dims=1)
        u_hat = tf.gather(c_hat, self._info_pos, axis=-1)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1  # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)

        if self._return_crc_status:
            # reconstruct CRC status
            crc_status = tf.gather(crc_valid,
                                   cand_ind,
                                   axis=1,
                                   batch_dims=1)
            # reconstruct shape
            output_shape.pop()  # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)
            # return info bits and CRC status
            return u_hat_reshape, crc_status
        else:  # return only info bits
            return u_hat_reshape
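
# ---------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): CRC-aided SCL
# decoding. As above, `generate_5g_ranking` is an assumed helper for the
# frozen set; note that `k` counts the CRC bits here, i.e., the
# transmitter is expected to append a matching CRC before Polar encoding.
def _example_scl_decoding():
    from sionna.phy.fec.polar.utils import generate_5g_ranking
    k, n = 32, 64  # k includes the k_crc CRC bits when crc_degree is set
    frozen_pos, _ = generate_5g_ranking(k, n)
    decoder = PolarSCLDecoder(frozen_pos, n,
                              list_size=8,
                              crc_degree="CRC11",
                              return_crc_status=True)
    llr_ch = tf.random.normal([100, n])   # channel logits
    u_hat, crc_status = decoder(llr_ch)   # [100, k] bits, [100] status
    return u_hat, crc_status
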

class PolarBPDecoder(Block):
    # pylint: disable=line-too-long
    """Belief propagation (BP) decoder for Polar codes [Arikan_Polar]_ and
    Polar-like codes based on [Arikan_BP]_ and [Forney_Graphs]_.

    Remark: The PolarBPDecoder does currently not support XLA.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    num_iter: int
        Defining the number of decoder iterations (no early stopping used
        at the moment).

    hard_out: `bool`, (default `True`)
        If `True`, the decoder provides hard-decided information bits
        instead of soft-values.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float32
        Tensor containing the channel logits/llr values.

    Output
    ------
    : [...,k], tf.float32
        Tensor containing bit-wise soft-estimates (or hard-decided
        bit-values) of all ``k`` information bits.

    Note
    ----
    This decoder is fully differentiable and, thus, well-suited for
    gradient descent-based learning tasks such as `learned code design`
    [Ebada_Design]_.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear codes as the
    `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self,
                 frozen_pos,
                 n,
                 num_iter=20,
                 hard_out=True,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        # assert error if r>1 or k, n are negative
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)  # n can be float (e.g., as result of n=k*r)
        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos) > n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        if np.log2(n) != int(np.log2(n)):
            raise ValueError("n must be a power of 2.")
        if not isinstance(hard_out, bool):
            raise TypeError("hard_out must be boolean.")

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        if self._k != len(self._info_pos):
            raise ArithmeticError("Internal error: invalid info_pos "
                                  "generated.")

        if not isinstance(num_iter, int):
            raise TypeError("num_iter must be integer.")
        if num_iter <= 0:
            raise ValueError("num_iter must be a positive value.")
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

        self._llr_max = 19.3  # internal max LLR value
        self._hard_out = hard_out

        # depth of decoding graph
        self._n_stages = int(np.log2(self._n))

    ###############################
    # Public methods and properties
    ###############################

    @property
    def n(self):
        """Codeword length"""
        return self._n

    @property
    def k(self):
        """Number of information bits"""
        return self._k

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations"""
        return self._llr_max

    @property
    def num_iter(self):
        """Number of decoding iterations"""
        return self._num_iter

    @property
    def hard_out(self):
        """Indicates if decoder hard-decides outputs"""
        return self._hard_out

    @num_iter.setter
    def num_iter(self, num_iter):
        """Number of decoding iterations."""
        if not isinstance(num_iter, int):
            raise ValueError("num_iter must be int.")
        if num_iter < 0:
            raise ValueError("num_iter cannot be negative.")
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

    #################
    # Utility methods
    #################

    def _boxplus_tf(self, x, y):
        """Check-node update (boxplus) for LLR inputs.

        Operations are performed element-wise.
        """
        x_in = tf.clip_by_value(x,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)
        y_in = tf.clip_by_value(y,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)
        # avoid division for numerical stability
        llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in))
        llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in))
        return llr_out

    def _decode_bp(self, llr_ch, num_iter):
        """Iterative BP decoding function with LLR values.

        Args:
            llr_ch (tf.float32): Tensor of shape `[batch_size, n]`
                containing the channel logits/llr values where
                `batch_size` denotes the batch-size.

            num_iter (int): Defining the number of decoder iterations
                (no early stopping used at the moment).

        Returns:
            `tf.float32`: Tensor of shape `[batch_size, k]` containing
            bit-wise soft-estimates (or hard-decided bit-values) of all
            information bits.
        """
        bs = tf.shape(llr_ch)[0]

        # store intermediate tensors in TensorArray
        msg_l = tf.TensorArray(self.rdtype,
                               size=num_iter*(self._n_stages+1),
                               dynamic_size=False,
                               clear_after_read=False)
        msg_r = tf.TensorArray(self.rdtype,
                               size=num_iter*(self._n_stages+1),
                               dynamic_size=False,
                               clear_after_read=False)

        # init frozen positions with infinity
        msg_r_in = np.zeros([1, self._n])
        msg_r_in[:, self._frozen_pos] = self._llr_max
        # copy for all batch-samples
        msg_r_in = tf.tile(tf.constant(msg_r_in, self.rdtype), [bs, 1])
        msg_r_in = tf.cast(msg_r_in, self.rdtype)

        # perform decoding iterations
        for ind_it in tf.range(self._num_iter):

            # update left-to-right messages
            for ind_s in range(self._n_stages):
                # calc indices
                ind_range = np.arange(int(self._n/2))
                ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
                ind_2 = ind_1 + 2**ind_s

                # simplify gather with concatenated outputs
                ind_inv = np.argsort(np.concatenate([ind_1, ind_2],
                                                    axis=0))

                # load incoming l messages
                if ind_s == self._n_stages-1:
                    l1_in = tf.gather(llr_ch, ind_1, axis=1)
                    l2_in = tf.gather(llr_ch, ind_2, axis=1)
                elif ind_it == 0:
                    l1_in = tf.zeros([bs, int(self._n/2)], self.rdtype)
                    l2_in = tf.zeros([bs, int(self._n/2)], self.rdtype)
                else:
                    l_in = msg_l.read((ind_s+1)
                                      + (ind_it-1)*(self._n_stages+1))
                    l1_in = tf.gather(l_in, ind_1, axis=1)
                    l2_in = tf.gather(l_in, ind_2, axis=1)

                # load incoming r messages
                if ind_s == 0:
                    r1_in = tf.gather(msg_r_in, ind_1, axis=1)
                    r2_in = tf.gather(msg_r_in, ind_2, axis=1)
                else:
                    r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
                    r1_in = tf.gather(r_in, ind_1, axis=1)
                    r2_in = tf.gather(r_in, ind_2, axis=1)

                r1_out = self._boxplus_tf(r1_in, l2_in + r2_in)
                r2_out = self._boxplus_tf(r1_in, l1_in) + r2_in

                # and re-concatenate output
                r_out = tf.concat([r1_out, r2_out], 1)
                r_out = tf.gather(r_out, ind_inv, axis=1)
                msg_r = msg_r.write((ind_s+1)
                                    + ind_it*(self._n_stages+1), r_out)

            # update right-to-left messages
            for ind_s in range(self._n_stages-1, -1, -1):
                ind_range = np.arange(int(self._n/2))
                ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
                ind_2 = ind_1 + 2**ind_s

                ind_inv = np.argsort(np.concatenate([ind_1, ind_2],
                                                    axis=0))

                # load messages
                if ind_s == self._n_stages-1:
                    l1_in = tf.gather(llr_ch, ind_1, axis=1)
                    l2_in = tf.gather(llr_ch, ind_2, axis=1)
                else:
                    l_in = msg_l.read((ind_s+1)
                                      + ind_it*(self._n_stages+1))
                    l1_in = tf.gather(l_in, ind_1, axis=1)
                    l2_in = tf.gather(l_in, ind_2, axis=1)

                if ind_s == 0:
                    r1_in = tf.gather(msg_r_in, ind_1, axis=1)
                    r2_in = tf.gather(msg_r_in, ind_2, axis=1)
                else:
                    r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
                    r1_in = tf.gather(r_in, ind_1, axis=1)
                    r2_in = tf.gather(r_in, ind_2, axis=1)

                # node update functions
                l1_out = self._boxplus_tf(l1_in, l2_in + r2_in)
                l2_out = self._boxplus_tf(r1_in, l1_in) + l2_in

                l_out = tf.concat([l1_out, l2_out], 1)
                l_out = tf.gather(l_out, ind_inv, axis=1)
                msg_l = msg_l.write(ind_s + ind_it*(self._n_stages+1),
                                    l_out)

        # recover u_hat
        u_hat = tf.gather(msg_l.read((num_iter-1)*(self._n_stages+1)),
                          self._info_pos,
                          axis=1)

        # if active, hard-decide output bits
        if self._hard_out:
            u_hat = tf.where(u_hat > 0,
                             tf.constant(0., dtype=self.rdtype),
                             tf.constant(1., dtype=self.rdtype))
        else:
            # re-transform soft output to logits (instead of llrs)
            u_hat = -1. * u_hat

        return u_hat

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Build and check if shape of input is invalid."""
        if input_shape[-1] != self._n:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch):
        """Iterative BP decoding function.

        Performs ``num_iter`` belief propagation decoding iterations and
        returns the estimated information bits.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing bit-wise
            soft-estimates (or hard-decided bit-values) of all ``k``
            information bits.

        Note:
            This function recursively unrolls the BP decoding graph, thus,
            for larger values of ``n`` or more iterations, building the
            decoding graph can become time and memory consuming.
        """
        # Reshape inputs to [-1, n]
        input_shape = llr_ch.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(llr_ch, new_shape)

        llr_ch = -1. * llr_ch  # logits are converted into "true" llrs

        # and decode
        u_hat = self._decode_bp(llr_ch, self._num_iter)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1  # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)

        return u_hat_reshape
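
# ---------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): soft-output BP
# decoding, e.g., for gradient-based code design. `generate_5g_ranking`
# is an assumed helper; `hard_out=False` keeps the output differentiable.
def _example_bp_decoding():
    from sionna.phy.fec.polar.utils import generate_5g_ranking
    k, n = 32, 64
    frozen_pos, _ = generate_5g_ranking(k, n)
    decoder = PolarBPDecoder(frozen_pos, n, num_iter=20, hard_out=False)
    llr_ch = tf.random.normal([100, n])  # channel logits
    u_soft = decoder(llr_ch)             # soft logits, shape [100, k]
    return u_soft
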

class Polar5GDecoder(Block):
    # pylint: disable=line-too-long
    """Wrapper for 5G compliant decoding including rate-recovery and CRC
    removal.

    Parameters
    ----------
    enc_polar: Polar5GEncoder
        Instance of the
        :class:`~sionna.phy.fec.polar.encoding.Polar5GEncoder` used for
        encoding including rate-matching.

    dec_type: "SC" (default) | "SCL" | "hybSCL" | "BP"
        Defining the decoder to be used. Must be one of the following
        `{"SC", "SCL", "hybSCL", "BP"}`.

    list_size: int, (default 8)
        Defining the list size `iff` list-decoding is used. Only required
        for ``dec_type`` `{"SCL", "hybSCL"}`.

    num_iter: int, (default 20)
        Defining the number of BP iterations. Only required for
        ``dec_type`` `"BP"`.

    return_crc_status: `bool`, (default `False`)
        If `True`, the decoder additionally returns the CRC status
        indicating if a codeword was (most likely) correctly recovered.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel logits/llr values.

    Output
    ------
    b_hat : [...,k], tf.float
        Binary tensor containing hard-decided estimations of all `k`
        information bits.

    crc_status : [...], tf.bool
        CRC status indicating if a codeword was (most likely) correctly
        recovered. This is only returned if ``return_crc_status`` is True.
        Note that false positives are possible.

    Note
    ----
    This block supports the uplink and downlink Polar rate-matching scheme
    without `codeword segmentation`.

    Although the decoding `list size` is not provided by 3GPP
    [3GPPTS38212]_, the consortium has agreed on a `list size` of 8 for the
    5G decoding reference curves [Bioglio_Design]_.

    All list-decoders apply `CRC-aided` decoding; however, the non-list
    decoders (`"SC"` and `"BP"`) cannot materialize the CRC, leading to an
    effective rate-loss.
    """

    def __init__(self,
                 enc_polar,
                 dec_type="SC",
                 list_size=8,
                 num_iter=20,
                 return_crc_status=False,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        if not isinstance(enc_polar, Polar5GEncoder):
            raise TypeError("enc_polar must be Polar5GEncoder.")
        if not isinstance(dec_type, str):
            raise TypeError("dec_type must be str.")
        # list_size and num_iter are not checked here
        # (done during decoder init)

        # store internal attributes
        self._n_target = enc_polar.n_target
        self._k_target = enc_polar.k_target
        self._n_polar = enc_polar.n_polar
        self._k_polar = enc_polar.k_polar
        self._k_crc = enc_polar.enc_crc.crc_length
        self._bil = enc_polar._channel_type == "uplink"
        self._iil = enc_polar._channel_type == "downlink"
        self._llr_max = 100  # internal max LLR value (for rate-matching)
        self._enc_polar = enc_polar
        self._dec_type = dec_type

        # initialize the de-interleaver patterns
        self._init_interleavers()

        # initialize decoder
        if dec_type == "SC":
            print("Warning: 5G Polar codes use an integrated CRC that "
                  "cannot be materialized with SC decoding and, thus, "
                  "causes a degraded performance. Please consider SCL "
                  "decoding instead.")
            self._polar_dec = PolarSCDecoder(self._enc_polar.frozen_pos,
                                             self._n_polar)
        elif dec_type == "SCL":
            self._polar_dec = PolarSCLDecoder(
                        self._enc_polar.frozen_pos,
                        self._n_polar,
                        crc_degree=self._enc_polar.enc_crc.crc_degree,
                        list_size=list_size,
                        ind_iil_inv=self.ind_iil_inv)
        elif dec_type == "hybSCL":
            self._polar_dec = PolarSCLDecoder(
                        self._enc_polar.frozen_pos,
                        self._n_polar,
                        crc_degree=self._enc_polar.enc_crc.crc_degree,
                        list_size=list_size,
                        use_hybrid_sc=True,
                        ind_iil_inv=self.ind_iil_inv)
        elif dec_type == "BP":
            print("Warning: 5G Polar codes use an integrated CRC that "
                  "cannot be materialized with BP decoding and, thus, "
                  "causes a degraded performance. Please consider SCL "
                  "decoding instead.")
            if not isinstance(num_iter, int):
                raise TypeError("num_iter must be int.")
            if num_iter <= 0:
                raise ValueError("num_iter must be positive.")
            self._num_iter = num_iter
            self._polar_dec = PolarBPDecoder(self._enc_polar.frozen_pos,
                                             self._n_polar,
                                             num_iter=num_iter,
                                             hard_out=True)
        else:
            raise ValueError("Unknown value for dec_type.")

        if not isinstance(return_crc_status, bool):
            raise TypeError("return_crc_status must be bool.")
        self._return_crc_status = return_crc_status

        if self._return_crc_status:
            # init CRC decoder
            if dec_type in ("SCL", "hybSCL"):
                # re-use CRC decoder from list decoder
                self._dec_crc = self._polar_dec._crc_decoder
            else:  # init new CRC decoder for BP and SC
                self._dec_crc = CRCDecoder(self._enc_polar._enc_crc)

    ###############################
    # Public methods and properties
    ###############################

    @property
    def k_target(self):
        """Number of information bits including rate-matching"""
        return self._k_target

    @property
    def n_target(self):
        """Codeword length including rate-matching"""
        return self._n_target

    @property
    def k_polar(self):
        """Number of information bits of mother Polar code"""
        return self._k_polar

    @property
    def n_polar(self):
        """Codeword length of mother Polar code"""
        return self._n_polar

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations"""
        return self._llr_max

    @property
    def dec_type(self):
        """Decoder type used for decoding as str"""
        return self._dec_type

    @property
    def polar_dec(self):
        """Decoder instance used for decoding"""
        return self._polar_dec

    #################
    # Utility methods
    #################

    def _init_interleavers(self):
        """Initialize inverse interleaver patterns for rate-recovery."""
        # channel interleaver
        ind_ch_int = self._enc_polar.channel_interleaver(
                                            np.arange(self._n_target))
        self.ind_ch_int_inv = np.argsort(ind_ch_int)  # find inverse perm.

        # sub-block interleaver
        ind_sub_int = self._enc_polar.subblock_interleaving(
                                            np.arange(self._n_polar))
        self.ind_sub_int_inv = np.argsort(ind_sub_int)  # find inverse perm.

        # input bit interleaver
        if self._iil:
            self.ind_iil_inv = np.argsort(
                                self._enc_polar.input_interleaver(
                                            np.arange(self._k_polar)))
        else:
            self.ind_iil_inv = None

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Build and check if shape of input is invalid."""
        if input_shape[-1] != self._n_target:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch):
        """Polar decoding and rate-recovery for 5G Polar codes.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing hard-decided
            estimates of all ``k`` information bits.
        """
        input_shape = llr_ch.shape
        new_shape = [-1, self._n_target]
        llr_ch = tf.reshape(llr_ch, new_shape)

        # note: logits are not inverted here; this is done in the decoder
        # itself

        # 1.) undo channel interleaving
        if self._bil:
            llr_deint = tf.gather(llr_ch, self.ind_ch_int_inv, axis=1)
        else:
            llr_deint = llr_ch

        # 2.) remove puncturing, shortening, repetition (see Sec. 5.4.1.2)
        # a) puncturing: set LLRs to 0
        # b) shortening: set LLRs to infinity
        # c) repetition: combine LLRs
        if self._n_target >= self._n_polar:  # repetition coding
            # add the last n_rep positions to the first llr positions
            n_rep = self._n_target - self._n_polar
            llr_1 = llr_deint[:, :n_rep]
            llr_2 = llr_deint[:, n_rep:self._n_polar]
            llr_3 = llr_deint[:, self._n_polar:]
            llr_dematched = tf.concat([llr_1+llr_3, llr_2], 1)
        else:
            if self._k_polar/self._n_target <= 7/16:  # puncturing
                # prepend n_polar - n_target "zero" llrs
                llr_zero = tf.zeros([tf.shape(llr_deint)[0],
                                     self._n_polar-self._n_target],
                                    self.rdtype)
                llr_dematched = tf.concat([llr_zero, llr_deint], 1)
            else:  # shortening
                # append n_polar - n_target "-infinity" llrs to the last
                # positions
                # remark: we still operate with logits here, thus the
                # negative sign
                llr_infty = -self._llr_max * tf.ones(
                                    [tf.shape(llr_deint)[0],
                                     self._n_polar-self._n_target],
                                    self.rdtype)
                llr_dematched = tf.concat([llr_deint, llr_infty], 1)

        # 3.) remove sub-block interleaving
        llr_dec = tf.gather(llr_dematched, self.ind_sub_int_inv, axis=1)

        # 4.) run main decoder
        u_hat_crc = self._polar_dec(llr_dec)

        # 5.) shortening is implicitly recovered by the decoder

        # 6.) remove input bit interleaving for downlink channels only
        if self._iil:
            u_hat_crc = tf.gather(u_hat_crc, self.ind_iil_inv, axis=1)

        # 7.) evaluate or remove CRC (and PC)
        if self._return_crc_status:
            # for compatibility with SC/BP, a dedicated CRC decoder is
            # used here (instead of accessing the internal SCL decoder)
            u_hat, crc_status = self._dec_crc(u_hat_crc)
        else:  # just remove CRC bits
            u_hat = u_hat_crc[:, :-self._k_crc]

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self._k_target
        output_shape[0] = -1  # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)
        # and cast to internal rdtype
        # (as sub-blocks may have different configs)
        u_hat_reshape = tf.cast(u_hat_reshape, dtype=self.rdtype)

        if self._return_crc_status:
            # reconstruct CRC shape
            output_shape.pop()  # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)
            return u_hat_reshape, crc_status
        else:
            return u_hat_reshape
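
# ---------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): 5G-compliant
# encoder/decoder pair. We assume `Polar5GEncoder(k, n)` constructs a
# rate-matched encoder for k information bits and n codeword bits (cf.
# sionna.phy.fec.polar.encoding); the decoder is built from the encoder
# instance so that all rate-matching parameters are shared automatically.
def _example_5g_decoding():
    enc = Polar5GEncoder(k=100, n=200)
    dec = Polar5GDecoder(enc, dec_type="SCL", list_size=8)
    llr_ch = tf.random.normal([100, 200])  # channel logits, shape [..., n]
    b_hat = dec(llr_ch)                    # info bits, shape [100, 100]
    return b_hat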