#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0#
"""Block for convolutional code encoding."""
import tensorflow as tf
from sionna.phy import Block
from sionna.phy.fec.utils import bin2int_tf, int2bin_tf
from sionna.phy.fec.conv.utils import polynomial_selector, Trellis
class ConvEncoder(Block):
    # pylint: disable=line-too-long
    r"""Encodes an information binary tensor to a convolutional codeword.

    Currently, only generator polynomials for codes of rate=1/n for n=2,3,4,...
    are allowed.

    Parameters
    ----------
    gen_poly: tuple
        Sequence of strings with each string being a 0,1 sequence. If
        `None`, ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Valid values are 1/3 and 0.5. Only required if ``gen_poly`` is
        `None`.

    constraint_length: int
        Valid values are between 3 and 8 inclusive. Only required if
        ``gen_poly`` is `None`.

    rsc: boolean
        Boolean flag indicating whether the Trellis generated is recursive
        systematic or not. If `True`, the encoder is recursive-systematic.
        In this case the first polynomial in ``gen_poly`` is used as the
        feedback polynomial. Defaults to `False`.

    terminate: boolean
        Encoder is terminated to the all-zero state if `True`.
        If terminated, the `true` rate of the code is slightly lower than
        ``rate``.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    bits : [...,k], tf.float
        Binary tensor containing the information bits where `k` is the
        information length

    Output
    ------
    : [...,n], tf.float
        Binary tensor containing the encoded codeword for the given input
        information tensor. Without termination :math:`n=k/\textrm{rate}`,
        with termination :math:`n=(k+\mu)/\textrm{rate}`, where `rate` is
        :math:`\frac{1}{\textrm{len}\left(\textrm{gen\_poly}\right)}`
        (if ``gen_poly`` is provided).

    Raises
    ------
    TypeError
        If ``gen_poly`` is a single string or contains non-string elements.

    ValueError
        If ``gen_poly`` is empty, its polynomials differ in length or contain
        characters other than 0/1, or if ``rate``/``constraint_length`` are
        invalid when ``gen_poly`` is `None`.

    Note
    ----
    The generator polynomials from [Moon]_ are available for various
    rate and constraint lengths. To select them, use the ``rate`` and
    ``constraint_length`` arguments.

    In addition, polynomials for any non-recursive convolutional encoder
    can be given as input via the ``gen_poly`` argument. Currently, only
    polynomials with rate=1/n are supported. When the ``gen_poly`` argument
    is given, the ``rate`` and ``constraint_length`` arguments are ignored.

    Various notations are used in the literature to represent the generator
    polynomials for convolutional codes. In [Moon]_, the octal digits
    format is primarily used. In the octal format, the generator polynomial
    `10011` corresponds to 46. Another widely used format
    is decimal notation with MSB. In this notation, polynomial `10011`
    corresponds to 19. For simplicity, the
    :class:`~sionna.phy.fec.conv.encoding.ConvEncoder` only accepts the bit
    format, i.e., `10011`, as the ``gen_poly`` argument.

    Also note that ``constraint_length`` and ``memory`` are two different
    terms often used to denote the strength of a convolutional code. In this
    sub-package, we use ``constraint_length``. For example, the
    polynomial `10011` has a ``constraint_length`` of 5, however its
    ``memory`` is only 4.

    When ``terminate`` is `True`, the true rate of the convolutional
    code is slightly lower than ``rate``. It equals
    :math:`\frac{r*k}{k+\mu}` where `r` denotes ``rate`` and
    :math:`\mu` is ``constraint_length`` - 1. For example when
    ``terminate`` is `True`, ``k=100``,
    :math:`\mu=4` and ``rate`` =0.5, true rate equals
    :math:`\frac{0.5*100}{104}=0.481`.
    """

    def __init__(self,
                 gen_poly=None,
                 rate=1/2,
                 constraint_length=3,
                 rsc=False,
                 terminate=False,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        # check for consistency
        if gen_poly is not None:
            # A bare string would otherwise be iterated character by
            # character and silently accepted as several length-1 polynomials.
            if isinstance(gen_poly, str):
                raise TypeError("gen_poly must be a sequence of strings, "
                                "not a single string.")
            # Guards the later 1/len(gen_poly) against ZeroDivisionError.
            if len(gen_poly) == 0:
                raise ValueError("gen_poly must contain at least one "
                                 "polynomial.")
            if not all(isinstance(poly, str) for poly in gen_poly):
                raise TypeError("Each element of gen_poly must be a string.")
            if not all(len(poly)==len(gen_poly[0]) for poly in gen_poly):
                raise ValueError("Each polynomial must be of same length.")
            if not all(all(
                char in ['0','1'] for char in poly) for poly in gen_poly):
                raise ValueError("Each Polynomial must be a string of 0/1 s.")
            self._gen_poly = gen_poly
        else:
            valid_rates = (1/2, 1/3)
            valid_constraint_length = (3, 4, 5, 6, 7, 8)

            if constraint_length not in valid_constraint_length:
                raise ValueError("Constraint length must be between 3 and 8.")
            if rate not in valid_rates:
                raise ValueError("Rate must be 1/3 or 1/2.")
            self._gen_poly = polynomial_selector(rate, constraint_length)

        self._rsc = rsc
        self._terminate = terminate

        self._coderate_desired = 1/len(self.gen_poly)
        # Differs from the desired rate when terminate is True (set in build)
        self._coderate = self._coderate_desired

        self._trellis = Trellis(self.gen_poly, rsc=self._rsc)
        self._mu = self.trellis._mu

        # conv_k denotes number of input bit streams.
        # Only 1 allowed in current implementation
        self._conv_k = self._trellis.conv_k

        # conv_n denotes number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        # for conv codes, the code dimensions are unknown during initialization
        self._k = None
        self._n = None

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def coderate(self):
        """Rate of the code used in the encoder.

        With termination, this is the design rate until the first call();
        afterwards it is the true (terminated) rate for the observed k.
        """
        if self.terminate and self._k is None:
            print("Note that, due to termination, the true coderate is lower "\
                  "than the returned design rate. "\
                  "The exact true rate is dependent on the value of k and "\
                  "hence cannot be computed before the first call().")
        elif self.terminate and self._k is not None:
            term_factor = self._k/(self._k + self._mu)
            self._coderate = self._coderate_desired*term_factor
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the convolutional encoder is terminated"""
        return self._terminate

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first " \
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first " \
                  "call().")
        return self._n

    #########################################
    # Private helpers
    #########################################

    def _transition(self, prev_st, msg_idx):
        """Advance the trellis by one encoding period.

        Args:
            prev_st (tf.int32): Current state per batch example, shape [B].
            msg_idx (tf.int32): Input symbol per batch example, shape [B].

        Returns:
            tuple: ``(new_st, idx_bits)``. ``new_st`` is looked up in
            ``trellis.to_nodes`` at ``[prev_st, msg_idx]``. ``idx_bits`` is
            the output symbol from ``trellis.op_mat`` at ``[prev_st, new_st]``,
            expanded into ``conv_n`` bits by ``int2bin_tf``.
        """
        new_st = tf.gather_nd(self._trellis.to_nodes,
                              indices=tf.stack([prev_st, msg_idx], -1))
        idx_syms = tf.gather_nd(self._trellis.op_mat,
                                tf.stack([prev_st, new_st], -1))
        idx_bits = int2bin_tf(idx_syms, self._conv_n)
        return new_st, idx_bits

    ########################
    # Sionna block functions
    ########################

    def build(self, input_shape):
        """Build block and check dimensions.

        Args:
            input_shape: shape of input tensor (...,k).

        Raises:
            ValueError: If the last dimension ``k`` is not statically known.
        """
        k = input_shape[-1]
        if k is None:
            raise ValueError("The last dimension k of the input must be "
                             "statically known.")
        self._k = k

        # num_syms denote number of encoding periods or state transitions.
        # different from _k when _conv_k > 1.
        self.num_syms = int(self._k//self._conv_k)

        # n is exactly the number of bits call() emits: conv_n bits per
        # encoding period plus mu termination periods if terminated.
        # Integer arithmetic avoids the float truncation of int(k/coderate),
        # which can yield n-1 for the terminated rate.
        term_syms = int(self._mu) if self._terminate else 0
        self._n = self._conv_n * (self.num_syms + term_syms)

        # Keep the cached true rate in sync, as the coderate getter would.
        if self._terminate:
            term_factor = self._k/(self._k + self._mu)
            self._coderate = self._coderate_desired*term_factor

    def call(self, bits, /):
        """Convolutional code encoding function.

        Args:
            bits (tf.float): Information tensor of shape `[...,k]`.

        Returns:
            `tf.float`: Encoded codeword tensor of shape `[...,n]` in
            ``self.rdtype``.
        """
        if bits.shape[-1] != self._k:
            self.build(bits.shape)

        # we cast internally to tf.int32 to enable bitshift operations
        msg = tf.cast(bits, tf.int32)

        # Output keeps all leading dims and replaces the last one by n.
        # Built from the dynamic shape so any number of unknown leading
        # dims is supported.
        output_shape = tf.concat([tf.shape(msg)[:-1], [self._n]], axis=0)

        msg_reshaped = tf.reshape(msg, [-1, self._k])
        batch_size = tf.shape(msg_reshaped)[0]

        prev_st = tf.zeros([batch_size], tf.int32)
        ta = tf.TensorArray(tf.int32, size=self.num_syms, dynamic_size=False)

        # Positions of the conv_k input bits consumed per encoding period.
        idx_offset = tf.range(self._conv_k)
        for idx in tf.range(0, self._k, self._conv_k):
            msg_bits_idx = tf.gather(msg_reshaped, idx + idx_offset, axis=-1)
            msg_idx = bin2int_tf(msg_bits_idx)
            prev_st, idx_bits = self._transition(prev_st, msg_idx)
            ta = ta.write(idx//self._conv_k, idx_bits)

        # ta.stack() is [num_syms, batch, ...]; unstacking and concatenating
        # along axis 1 places each period's output bits contiguously per row.
        cw = tf.concat(tf.unstack(ta.stack()), axis=1)

        # Termination: drive the encoder back to the all-zero state.
        term_syms = int(self._mu) if self._terminate else 0
        if term_syms > 0:
            ta_term = tf.TensorArray(tf.int32, size=term_syms,
                                     dynamic_size=False)
            if self._rsc:
                # Feedback taps (first polynomial without its leading bit).
                fb_poly = tf.constant([int(x) for x in self.gen_poly[0][1:]])
                fb_poly_tiled = tf.tile(
                    tf.expand_dims(fb_poly, 0), [batch_size, 1])

            for idx in tf.range(0, term_syms, self._conv_k):
                if self._rsc:
                    # For RSC codes the input cancelling the feedback is the
                    # parity (LSB) of the feedback taps applied to the state.
                    prev_st_bits = int2bin_tf(prev_st, self._mu)
                    msg_idx = tf.math.reduce_sum(
                        tf.multiply(fb_poly_tiled, prev_st_bits), -1)
                    msg_idx = tf.squeeze(int2bin_tf(msg_idx, 1), -1)
                else:
                    msg_idx = tf.zeros((batch_size,), dtype=tf.int32)
                prev_st, idx_bits = self._transition(prev_st, msg_idx)
                ta_term = ta_term.write(idx//self._conv_k, idx_bits)

            term_bits = tf.concat(tf.unstack(ta_term.stack()), axis=1)
            cw = tf.concat([cw, term_bits], axis=-1)

        cw = tf.cast(cw, self.rdtype)
        return tf.reshape(cw, output_shape)