5G Channel Coding and Rate-Matching: Polar vs. LDPC Codes
“For block lengths of about 500, an IBM 7090 computer requires about 0.1 seconds per iteration to decode a block by probabilistic decoding scheme. Consequently, many hours of computation time are necessary to evaluate even a \(P(e)\) in the order of \({10^{-4}}\) .” Robert G. Gallager, 1963 [7]
In this notebook, you will learn about the different coding schemes in 5G NR and how rate-matching works (cf. 3GPP TS 38.212 [3]). The coding schemes are compared under different length/rate settings and for different decoders.
You will learn about the following components:
5G low-density parity-check (LDPC) codes [7]. Without further segmentation, these codes support up to k=8448 information bits per codeword [3] for a wide range of code rates.
Polar codes [1], including CRC concatenation and rate-matching for 5G-compliant en-/decoding, as specified for Polar-coded uplink control information (UCI) [3]. Besides Polar codes, Reed-Muller (RM) codes and several decoders are available:
Successive cancellation (SC) decoding [1]
Successive cancellation list (SCL) decoding [2]
Hybrid SC / SCL decoding for enhanced throughput
Iterative belief propagation (BP) decoding [6]
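The Polar encoders and decoders listed above all build on the polar transform x = u·F^{⊗m} over GF(2), with F = [[1, 0], [1, 1]], where frozen positions of u are set to zero. The following pure-Python sketch computes the transform with in-place butterflies. It is not Sionna's implementation: the natural index order (no bit-reversal permutation) is an assumption, and the information set {3, 5, 6, 7} (the RM(1,3) set) is chosen only for illustration.

```python
def polar_transform(u):
    """Return x = u * F^{(x)m} over GF(2) with F = [[1, 0], [1, 1]].

    Uses log2(N) butterfly stages; no bit-reversal permutation is applied.
    """
    x = list(u)
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    half = 1
    while half < n:
        for start in range(0, n, 2 * half):
            for j in range(start, start + half):
                x[j] ^= x[j + half]  # upper branch absorbs the lower branch
        half *= 2
    return x


def polar_encode(info_bits, info_pos, n):
    """Place info bits on a (hypothetical) information set; frozen bits stay 0."""
    u = [0] * n
    for bit, pos in zip(info_bits, info_pos):
        u[pos] = bit
    return polar_transform(u)


# N=8 with the RM(1,3) information set, used here purely as an example
print(polar_encode([1, 0, 1, 1], [3, 5, 6, 7], 8))  # -> [1, 0, 1, 0, 0, 1, 0, 1]
```

Since F·F = I over GF(2), applying the transform twice returns the input, and row i of F^{⊗m} has Hamming weight 2^(number of ones in i); this row-weight structure is what links Polar and RM codes.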
Further, we will demonstrate the basic functionality of the Sionna forward error correction (FEC) module which also includes support for:
Convolutional codes with non-recursive encoding and Viterbi/BCJR decoding
Turbo codes and iterative BCJR decoding
Ordered statistics decoding (OSD) for any binary, linear code
Interleaving and scrambling
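To make the convolutional-code entries above concrete, here is a minimal pure-Python sketch of non-recursive encoding and soft-input Viterbi decoding. It uses the textbook rate-1/2 code with generators (7, 5) in octal and constraint length 3, chosen for brevity. This generator is an assumption for illustration, not the one Sionna's ConvEncoder uses for constraint length 8 later in this notebook. BPSK symbols use the mapping bit 0 to +1 and bit 1 to -1.

```python
GEN = (0b111, 0b101)  # generator polynomials (7, 5) in octal, illustration only
MEM = 2               # encoder memory = constraint length - 1
NUM_STATES = 2 ** MEM


def step(state, bit):
    """One trellis transition: returns (next_state, output bits).

    `state` holds the previous MEM input bits, most recent bit in the MSB.
    """
    reg = (bit << MEM) | state                        # current bit + memory
    out = [bin(reg & g).count("1") % 2 for g in GEN]  # parity of tapped bits
    return reg >> 1, out


def conv_encode(bits):
    """Non-recursive, non-terminated rate-1/2 encoding."""
    state, coded = 0, []
    for b in bits:
        state, out = step(state, b)
        coded.extend(out)
    return coded


def viterbi_decode(y):
    """Soft-input Viterbi decoding of BPSK symbols y (bit 0 -> +1, bit 1 -> -1).

    Maximizes the correlation between y and the candidate codeword, i.e.,
    ML sequence detection on the AWGN channel. The trellis is not terminated,
    so the survivor with the best final metric is returned.
    """
    neg_inf = float("-inf")
    metric = [0.0] + [neg_inf] * (NUM_STATES - 1)  # encoder starts in state 0
    paths = [[] for _ in range(NUM_STATES)]
    for t in range(len(y) // 2):
        r = y[2 * t:2 * t + 2]
        new_metric = [neg_inf] * NUM_STATES
        new_paths = [[] for _ in range(NUM_STATES)]
        for s in range(NUM_STATES):
            if metric[s] == neg_inf:
                continue  # state not reachable yet
            for b in (0, 1):
                nxt, out = step(s, b)
                m = metric[s] + sum(ri * (1 - 2 * c) for ri, c in zip(r, out))
                if m > new_metric[nxt]:  # add-compare-select
                    new_metric[nxt] = m
                    new_paths[nxt] = paths[s] + [b]
        metric, paths = new_metric, new_paths
    best = max(range(NUM_STATES), key=lambda s: metric[s])
    return paths[best]


bits = [1, 0, 1, 1, 0, 0, 1]
y = [1.0 - 2 * c for c in conv_encode(bits)]  # noiseless BPSK
y[1] = -0.5 * y[1]                             # one weakly flipped symbol
print(viterbi_decode(y) == bits)               # True
```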
For additional technical background we refer the interested reader to [4,5,8].
Please note that block segmentation is not implemented, as it only concatenates multiple code blocks without increasing the effective codeword length (from the decoder's perspective).
Some simulations in this notebook require considerable simulation time, in particular if parameter sweeps are involved (e.g., different length comparisons). Please keep in mind that each cell in this notebook already contains the pre-computed outputs, so no new execution is required to understand the examples.
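For reference, the number of code blocks that LDPC code block segmentation would produce can be sketched as follows. This is a minimal sketch, assuming the segmentation rule of TS 38.212 (Sec. 5.2.2): maximum code block size K_cb = 8448 for base graph 1 and 3840 for base graph 2, and a 24-bit CRC attached to every code block whenever segmentation is needed. Treat the constants as assumptions to be checked against the specification. This is not part of Sionna.

```python
import math

K_CB = {1: 8448, 2: 3840}  # max. code block size per LDPC base graph (assumed from TS 38.212)
L_CB_CRC = 24              # per-code-block CRC length when segmenting (assumed)


def num_code_blocks(b, base_graph=1):
    """Return (C, B') for an input of b bits (payload incl. transport block CRC).

    C is the number of code blocks, B' the total number of bits after the
    per-code-block CRCs have been attached.
    """
    k_cb = K_CB[base_graph]
    if b <= k_cb:
        return 1, b  # fits into a single codeword, no extra CRC
    c = math.ceil(b / (k_cb - L_CB_CRC))
    return c, b + c * L_CB_CRC


for b in (8448, 8449, 20000):
    print(b, num_code_blocks(b))
```

Each resulting code block is encoded separately, so the decoder never sees more than K_cb information bits per codeword. This is why segmentation does not increase the effective codeword length.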
GPU Configuration and Imports
[1]:
import os
if os.getenv("CUDA_VISIBLE_DEVICES") is None:
    gpu_num = 0 # Use "" to use the CPU
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Import Sionna
try:
    import sionna
except ImportError as e:
    # Install Sionna if package is not already installed
    os.system("pip install sionna")
    import sionna

import tensorflow as tf
# Configure the notebook to use only a single GPU and allocate only as much memory as needed
# For more details, see https://www.tensorflow.org/guide/gpu
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

# Avoid warnings from TensorFlow
tf.get_logger().setLevel('ERROR')

# Set random seed for reproducibility
sionna.config.seed = 42

# Load the required Sionna components
from sionna.mapping import Constellation, Mapper, Demapper
from sionna.fec.polar import PolarEncoder, Polar5GEncoder, PolarSCLDecoder, Polar5GDecoder, PolarSCDecoder
from sionna.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
from sionna.fec.polar.utils import generate_5g_ranking, generate_rm_code
from sionna.fec.conv import ConvEncoder, ViterbiDecoder, BCJRDecoder
from sionna.fec.turbo import TurboEncoder, TurboDecoder
from sionna.fec.linear import OSDecoder
from sionna.utils import BinarySource, ebnodb2no
from sionna.utils.metrics import count_block_errors
from sionna.channel import AWGN
from sionna.utils.plotting import PlotBER
[2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import time # for throughput measurements
BER Performance of 5G Coding Schemes
Let us first focus on short length coding, e.g., for internet of things (IoT) and ultra-reliable low-latency communications (URLLC). We aim to reproduce results similar to those in [9] for the coding schemes supported by Sionna.
For a detailed explanation of the PlotBER class, we refer to the example notebook on Bit-Interleaved Coded Modulation.
The Sionna API allows passing an encoder object/layer to the initialization of the 5G decoders. This means that the decoder is directly associated with a specific encoder and knows all relevant code parameters. Please note that, of course, no data or information bits are exchanged between these two associated components. The association just simplifies the handling of the code parameters, in particular if rate-matching is used.
Let us define the system model first. We pass the encoder and decoder as input parameters such that the model remains flexible w.r.t. the coding scheme.
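As an illustration of the parameters such an association takes care of, the following sketch selects the Polar mother code length N and the rate-matching mode (repetition, puncturing, or shortening) from the number of bits K entering the Polar encoder (information plus CRC bits) and the rate-matched output length E. The rules are our transcription of TS 38.212 (Sec. 5.3.1 for the mother code length, with the uplink limit n_max = 10 and R_min = 1/8; Sec. 5.4.1 for the mode selection). They are assumptions for illustration, not Sionna's code.

```python
def polar_mother_length(K, E, n_max=10, n_min=5):
    """Mother code length N = 2**n (assumed TS 38.212 Sec. 5.3.1 rule, R_min = 1/8)."""
    ceil_log2_e = (E - 1).bit_length()  # ceil(log2(E)) for E >= 1
    # reduce n1 by one if E is only slightly above a power of two and the rate is low
    if 8 * E <= 9 * 2 ** (ceil_log2_e - 1) and 16 * K < 9 * E:
        n1 = ceil_log2_e - 1
    else:
        n1 = ceil_log2_e
    n2 = (8 * K - 1).bit_length()  # ceil(log2(K / R_min)) with R_min = 1/8
    n = max(min(n1, n2, n_max), n_min)
    return 2 ** n


def rate_matching_mode(K, E, N):
    """Repetition if E >= N, otherwise puncturing (K/E <= 7/16) or shortening."""
    if E >= N:
        return "repetition"  # E == N means no bit is actually repeated
    return "puncturing" if 16 * K <= 7 * E else "shortening"


for K, E in [(75, 128), (20, 100), (60, 100), (20, 70)]:
    N = polar_mother_length(K, E)
    print(K, E, N, rate_matching_mode(K, E, N))
```

Integer comparisons (e.g., `16 * K <= 7 * E` instead of `K / E <= 7 / 16`) avoid floating-point edge cases exactly at the thresholds.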
[3]:
class System_Model(tf.keras.Model):
    """System model for channel coding BER simulations.

    This model simulates BERs over an AWGN channel with
    QAM modulation. Arbitrary FEC encoder/decoder layers can be used to
    initialize the model.

    Parameters
    ----------
    k: int
        number of information bits per codeword.
    n: int
        codeword length.
    num_bits_per_symbol: int
        number of bits per QAM symbol.
    encoder: Keras layer
        A Keras layer that encodes information bit tensors.
    decoder: Keras layer
        A Keras layer that decodes llr tensors.
    demapping_method: str
        A string denoting the demapping method. Can be either "app" or "maxlog".
    sim_esno: bool
        Defaults to False. If True, no rate-adjustment is done for the SNR calculation.
    cw_estimates: bool
        Defaults to False. If True, codewords instead of information bits are returned.

    Input
    -----
    batch_size: int or tf.int
        The batch_size used for the simulation.
    ebno_db: float or tf.float
        A float defining the simulation SNR.

    Output
    ------
    (u, u_hat): tuple
    u: tf.float32
        A tensor of shape `[batch_size, k]` of 0s and 1s containing the transmitted information bits
        (the transmitted codewords of shape `[batch_size, n]` if ``cw_estimates`` is True).
    u_hat: tf.float32
        A tensor of shape `[batch_size, k]` of 0s and 1s containing the estimated information bits
        (codeword estimates of shape `[batch_size, n]` if the decoder returns codewords, e.g., OSD).
    """
    def __init__(self,
                 k,
                 n,
                 num_bits_per_symbol,
                 encoder,
                 decoder,
                 demapping_method="app",
                 sim_esno=False,
                 cw_estimates=False):
        super().__init__()
        # store values internally
        self.k = k
        self.n = n
        self.sim_esno = sim_esno # disable rate-adjustment for SNR calc
        self.cw_estimates = cw_estimates # if True, codewords instead of info bits are returned
        # number of bits per QAM symbol
        self.num_bits_per_symbol = num_bits_per_symbol
        # init components
        self.source = BinarySource()
        # initialize mapper and demapper for constellation object
        self.constellation = Constellation("qam",
                                           num_bits_per_symbol=self.num_bits_per_symbol)
        self.mapper = Mapper(constellation=self.constellation)
        self.demapper = Demapper(demapping_method,
                                 constellation=self.constellation)
        # the channel can be replaced by more sophisticated models
        self.channel = AWGN()
        # FEC encoder / decoder
        self.encoder = encoder
        self.decoder = decoder

    @tf.function() # enable graph mode for increased throughput
    def call(self, batch_size, ebno_db):
        # calculate noise variance
        if self.sim_esno:
            no = ebnodb2no(ebno_db,
                           num_bits_per_symbol=1,
                           coderate=1)
        else:
            no = ebnodb2no(ebno_db,
                           num_bits_per_symbol=self.num_bits_per_symbol,
                           coderate=self.k/self.n)
        u = self.source([batch_size, self.k]) # generate random data
        c = self.encoder(u) # explicitly encode
        x = self.mapper(c) # map c to symbols x
        y = self.channel([x, no]) # transmit over AWGN channel
        llr_ch = self.demapper([y, no]) # demap y to LLRs
        u_hat = self.decoder(llr_ch) # run FEC decoder (incl. rate-recovery)
        if self.cw_estimates:
            return c, u_hat
        return u, u_hat
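The noise variance computed in `call()` follows from the usual E_b/N_0 definition. Assuming unit average symbol energy (E_s = 1), each symbol carries num_bits_per_symbol · coderate information bits, so N_0 = 1 / (10^(E_b/N_0 [dB] / 10) · num_bits_per_symbol · coderate). The stdlib sketch below mirrors this for a scalar. It is our re-derivation under the stated unit-energy assumption, not the source code of Sionna's `ebnodb2no`.

```python
def ebnodb2no_sketch(ebno_db, num_bits_per_symbol, coderate):
    """Noise variance N0 for unit-energy symbols (Es = 1).

    Eb = Es / (num_bits_per_symbol * coderate), hence
    N0 = Eb / (Eb/N0) = 1 / (ebno * num_bits_per_symbol * coderate).
    """
    ebno = 10 ** (ebno_db / 10)  # dB -> linear
    return 1.0 / (ebno * num_bits_per_symbol * coderate)


# QPSK with the rate-1/2 codes used in this notebook: Es/N0 equals Eb/N0
print(ebnodb2no_sketch(0.0, num_bits_per_symbol=2, coderate=64 / 128))  # 1.0
# sim_esno=True corresponds to num_bits_per_symbol=1 and coderate=1
print(ebnodb2no_sketch(10.0, num_bits_per_symbol=1, coderate=1))  # 0.1
```

For QPSK at rate 1/2, the factor num_bits_per_symbol · coderate equals one, which is why E_b/N_0 and E_s/N_0 coincide in the short-length comparison.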
And let us define the codes to be simulated.
[4]:
# code parameters
k = 64 # number of information bits per codeword
n = 128 # desired codeword length
# Create list of encoder/decoder pairs to be analyzed.
# This allows automated evaluation of the whole list later.
codes_under_test = []
# 5G LDPC codes with 20 BP iterations
enc = LDPC5GEncoder(k=k, n=n)
dec = LDPC5GDecoder(enc, num_iter=20)
name = "5G LDPC BP-20"
codes_under_test.append([enc, dec, name])
# Polar Codes (SC decoding)
enc = Polar5GEncoder(k=k, n=n)
dec = Polar5GDecoder(enc, dec_type="SC")
name = "5G Polar+CRC SC"
codes_under_test.append([enc, dec, name])
# Polar Codes (SCL decoding) with list size 8.
# The CRC is automatically added by the layer.
enc = Polar5GEncoder(k=k, n=n)
dec = Polar5GDecoder(enc, dec_type="SCL", list_size=8)
name = "5G Polar+CRC SCL-8"
codes_under_test.append([enc, dec, name])
### non-5G coding schemes
# RM codes with SCL decoding
f,_,_,_,_ = generate_rm_code(3,7) # equals k=64 and n=128 code
enc = PolarEncoder(f, n)
dec = PolarSCLDecoder(f, n, list_size=8)
name = "Reed Muller (RM) SCL-8"
codes_under_test.append([enc, dec, name])
# Conv. code with Viterbi decoding
enc = ConvEncoder(rate=1/2, constraint_length=8)
dec = ViterbiDecoder(gen_poly=enc.gen_poly, method="soft_llr")
name = "Conv. Code Viterbi (constraint length 8)"
codes_under_test.append([enc, dec, name])
# Turbo codes
enc = TurboEncoder(rate=1/2, constraint_length=4, terminate=False) # no termination used due to the rate loss
dec = TurboDecoder(enc, num_iter=8)
name = "Turbo Code (constraint length 4)"
codes_under_test.append([enc, dec, name])
Warning: 5G Polar codes use an integrated CRC that cannot be exploited by SC decoding, which leads to degraded performance. Please consider SCL decoding instead.
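The reason is that SC decoding produces a single candidate, whereas SCL decoding keeps a list of candidate paths and can use the CRC to pick the most likely candidate that passes the check. A minimal sketch of this CRC-aided selection step follows. The polynomial D^11 + D^10 + D^9 + D^5 + 1 is assumed to be the 11-bit CRC of TS 38.212. The candidate list and its path metrics (lower means more likely) are hypothetical toy values, not output of Sionna's decoder.

```python
CRC11 = [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]  # D^11+D^10+D^9+D^5+1, highest degree first (assumed)


def poly_mod2(bits, poly):
    """Remainder of bits(D) divided by poly(D) over GF(2), MSB first."""
    r = len(poly) - 1
    reg = list(bits)
    for i in range(len(reg) - r):
        if reg[i]:
            for j, p in enumerate(poly):
                reg[i + j] ^= p
    return reg[len(reg) - r:]


def attach_crc(info, poly=CRC11):
    """Append the CRC so that the resulting word is divisible by poly."""
    return info + poly_mod2(info + [0] * (len(poly) - 1), poly)


def crc_ok(word, poly=CRC11):
    return not any(poly_mod2(word, poly))


def select_candidate(candidates, poly=CRC11):
    """CRC-aided selection: most likely candidate (lowest metric) passing the CRC.

    Falls back to the overall most likely candidate if none passes.
    """
    ranked = sorted(candidates, key=lambda c: c[1])
    for word, _ in ranked:
        if crc_ok(word, poly):
            return word
    return ranked[0][0]


info = [1, 0, 1, 1, 0, 1, 0, 0]
tx = attach_crc(info)
wrong = tx.copy()
wrong[2] ^= 1  # a single bit error is always detected by this CRC
# hypothetical list: the wrong path has the better (lower) metric
cands = [(wrong, 0.7), (tx, 1.3)]
print(select_candidate(cands) == tx)  # True: the CRC overrides the path metric
```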
Remark: some of the coding schemes are not 5G relevant, but are included in this comparison for the sake of completeness.
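The comment `equals k=64 and n=128 code` in the code cell above can be checked directly: RM(r, m) has length n = 2^m, dimension k = Σ_{i=0}^{r} C(m, i), and minimum distance 2^(m−r). Viewed as a code on the Kronecker matrix F^{⊗m}, row i has Hamming weight 2^{w(i)}, where w(i) is the number of ones in the binary representation of i, and RM(r, m) keeps all rows of weight at least 2^(m−r). The sketch below computes this information set. The index convention (no bit-reversal) is an assumption, and the ordering need not match what `generate_rm_code` returns.

```python
from math import comb


def rm_parameters(r, m):
    """Return (n, k, d_min, info_pos) of the Reed-Muller code RM(r, m).

    Row i of F^{(x)m} (F = [[1, 0], [1, 1]]) has weight 2**popcount(i);
    RM(r, m) keeps all rows with popcount(i) >= m - r.
    """
    n = 2 ** m
    info_pos = [i for i in range(n) if bin(i).count("1") >= m - r]
    k = sum(comb(m, i) for i in range(r + 1))
    assert k == len(info_pos)  # both counting arguments must agree
    return n, k, 2 ** (m - r), info_pos


n_rm, k_rm, d_rm, info_rm = rm_parameters(3, 7)
print(n_rm, k_rm, d_rm)  # 128 64 16
```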
Generate a new BER plot figure to save and plot simulation results efficiently.
[5]:
ber_plot128 = PlotBER(f"Performance of Short Length Codes (k={k}, n={n})")
And run the BER simulation for each code.
[6]:
num_bits_per_symbol = 2 # QPSK
ebno_db = np.arange(0, 5, 0.5) # sim SNR range
# run ber simulations for each code we have added to the list
for code in codes_under_test:
    print("\nRunning: " + code[2])
    # generate a new model with the given encoder/decoder
    model = System_Model(k=k,
                         n=n,
                         num_bits_per_symbol=num_bits_per_symbol,
                         encoder=code[0],
                         decoder=code[1])
    # the first argument must be a callable (function) that yields u and u_hat for batch_size and ebno
    ber_plot128.simulate(model, # the model we have defined previously
                         ebno_dbs=ebno_db, # SNR to simulate
                         legend=code[2], # legend string for plotting
                         max_mc_iter=100, # run at most 100 Monte Carlo runs per SNR point
                         num_target_block_errors=1000, # continue with next SNR point after 1000 block errors
                         batch_size=10000, # batch-size per Monte Carlo run
                         soft_estimates=False, # the model returns hard-estimates
                         early_stop=True, # stop simulation if no error has been detected at current SNR point
                         show_fig=False, # we show the figure after all results are simulated
                         add_bler=True, # in case BLER is also interesting
                         forward_keyboard_interrupt=True); # should be True in a loop

# and show the figure
ber_plot128(ylim=(1e-5, 1), show_bler=False) # we set the ylim to 1e-5 as otherwise more extensive simulations would be required for accurate curves.
Running: 5G LDPC BP-20
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6585e-01 | 8.6430e-01 | 106147 | 640000 | 8643 | 10000 | 4.9 |reached target block errors
0.5 | 1.2742e-01 | 7.1230e-01 | 81551 | 640000 | 7123 | 10000 | 0.1 |reached target block errors
1.0 | 8.6670e-02 | 5.0330e-01 | 55469 | 640000 | 5033 | 10000 | 0.1 |reached target block errors
1.5 | 5.0345e-02 | 3.0320e-01 | 32221 | 640000 | 3032 | 10000 | 0.1 |reached target block errors
2.0 | 2.5698e-02 | 1.5470e-01 | 16447 | 640000 | 1547 | 10000 | 0.1 |reached target block errors
2.5 | 1.1046e-02 | 6.8600e-02 | 14139 | 1280000 | 1372 | 20000 | 0.1 |reached target block errors
3.0 | 3.6747e-03 | 2.2540e-02 | 11759 | 3200000 | 1127 | 50000 | 0.4 |reached target block errors
3.5 | 9.1189e-04 | 5.8278e-03 | 10505 | 11520000 | 1049 | 180000 | 1.3 |reached target block errors
4.0 | 2.0352e-04 | 1.3333e-03 | 9769 | 48000000 | 1000 | 750000 | 5.6 |reached target block errors
4.5 | 3.2078e-05 | 2.1400e-04 | 2053 | 64000000 | 214 | 1000000 | 7.4 |reached max iter
Running: 5G Polar+CRC SC
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 4.1146e-01 | 9.5450e-01 | 263332 | 640000 | 9545 | 10000 | 4.7 |reached target block errors
0.5 | 3.6907e-01 | 8.9910e-01 | 236203 | 640000 | 8991 | 10000 | 0.1 |reached target block errors
1.0 | 3.1438e-01 | 7.9720e-01 | 201201 | 640000 | 7972 | 10000 | 0.1 |reached target block errors
1.5 | 2.4496e-01 | 6.5180e-01 | 156777 | 640000 | 6518 | 10000 | 0.1 |reached target block errors
2.0 | 1.7641e-01 | 4.8730e-01 | 112905 | 640000 | 4873 | 10000 | 0.1 |reached target block errors
2.5 | 1.0874e-01 | 3.1220e-01 | 69595 | 640000 | 3122 | 10000 | 0.1 |reached target block errors
3.0 | 5.6302e-02 | 1.6620e-01 | 36033 | 640000 | 1662 | 10000 | 0.1 |reached target block errors
3.5 | 2.6780e-02 | 8.1050e-02 | 34278 | 1280000 | 1621 | 20000 | 0.1 |reached target block errors
4.0 | 9.3961e-03 | 2.8750e-02 | 24054 | 2560000 | 1150 | 40000 | 0.2 |reached target block errors
4.5 | 3.0973e-03 | 9.4455e-03 | 21805 | 7040000 | 1039 | 110000 | 0.6 |reached target block errors
Running: 5G Polar+CRC SCL-8
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 3.3652e-01 | 7.8720e-01 | 215375 | 640000 | 7872 | 10000 | 15.9 |reached target block errors
0.5 | 2.6210e-01 | 6.3290e-01 | 167742 | 640000 | 6329 | 10000 | 2.4 |reached target block errors
1.0 | 1.7075e-01 | 4.3170e-01 | 109283 | 640000 | 4317 | 10000 | 2.4 |reached target block errors
1.5 | 9.1872e-02 | 2.3870e-01 | 58798 | 640000 | 2387 | 10000 | 2.4 |reached target block errors
2.0 | 3.6781e-02 | 1.0070e-01 | 23540 | 640000 | 1007 | 10000 | 2.4 |reached target block errors
2.5 | 1.2626e-02 | 3.5033e-02 | 24241 | 1920000 | 1051 | 30000 | 7.1 |reached target block errors
3.0 | 2.7641e-03 | 7.9077e-03 | 22997 | 8320000 | 1028 | 130000 | 30.7 |reached target block errors
3.5 | 4.5730e-04 | 1.3184e-03 | 22243 | 48640000 | 1002 | 760000 | 181.5 |reached target block errors
4.0 | 4.5188e-05 | 1.4400e-04 | 2892 | 64000000 | 144 | 1000000 | 240.0 |reached max iter
4.5 | 1.7031e-06 | 5.0000e-06 | 109 | 64000000 | 5 | 1000000 | 239.2 |reached max iter
Running: Reed Muller (RM) SCL-8
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 2.6906e-01 | 6.4660e-01 | 172198 | 640000 | 6466 | 10000 | 12.0 |reached target block errors
0.5 | 1.9218e-01 | 4.7380e-01 | 122997 | 640000 | 4738 | 10000 | 2.1 |reached target block errors
1.0 | 1.1426e-01 | 2.9020e-01 | 73129 | 640000 | 2902 | 10000 | 2.1 |reached target block errors
1.5 | 5.7691e-02 | 1.4860e-01 | 36922 | 640000 | 1486 | 10000 | 2.1 |reached target block errors
2.0 | 2.4731e-02 | 6.5950e-02 | 31656 | 1280000 | 1319 | 20000 | 4.1 |reached target block errors
2.5 | 7.6516e-03 | 2.0600e-02 | 24485 | 3200000 | 1030 | 50000 | 10.3 |reached target block errors
3.0 | 1.7427e-03 | 4.7045e-03 | 24537 | 14080000 | 1035 | 220000 | 46.1 |reached target block errors
3.5 | 3.0444e-04 | 8.6300e-04 | 19484 | 64000000 | 863 | 1000000 | 209.7 |reached max iter
4.0 | 3.4437e-05 | 1.0100e-04 | 2204 | 64000000 | 101 | 1000000 | 209.8 |reached max iter
4.5 | 2.2344e-06 | 6.0000e-06 | 143 | 64000000 | 6 | 1000000 | 209.9 |reached max iter
Running: Conv. Code Viterbi (constraint length 8)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6552e-01 | 6.9890e-01 | 105930 | 640000 | 6989 | 10000 | 3.9 |reached target block errors
0.5 | 1.0780e-01 | 5.5690e-01 | 68995 | 640000 | 5569 | 10000 | 0.6 |reached target block errors
1.0 | 6.3095e-02 | 4.0900e-01 | 40381 | 640000 | 4090 | 10000 | 0.5 |reached target block errors
1.5 | 3.2967e-02 | 2.8130e-01 | 21099 | 640000 | 2813 | 10000 | 0.5 |reached target block errors
2.0 | 1.6017e-02 | 1.9010e-01 | 10251 | 640000 | 1901 | 10000 | 0.5 |reached target block errors
2.5 | 8.1187e-03 | 1.2120e-01 | 5196 | 640000 | 1212 | 10000 | 0.5 |reached target block errors
3.0 | 3.9391e-03 | 7.5450e-02 | 5042 | 1280000 | 1509 | 20000 | 1.1 |reached target block errors
3.5 | 1.9990e-03 | 4.8467e-02 | 3838 | 1920000 | 1454 | 30000 | 1.6 |reached target block errors
4.0 | 9.9727e-04 | 3.0375e-02 | 2553 | 2560000 | 1215 | 40000 | 2.2 |reached target block errors
4.5 | 5.7057e-04 | 1.8833e-02 | 2191 | 3840000 | 1130 | 60000 | 3.3 |reached target block errors
Running: Turbo Code (constraint length 4)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.0865e-01 | 7.8560e-01 | 69536 | 640000 | 7856 | 10000 | 4.2 |reached target block errors
0.5 | 7.7444e-02 | 6.1540e-01 | 49564 | 640000 | 6154 | 10000 | 1.5 |reached target block errors
1.0 | 4.6728e-02 | 3.9780e-01 | 29906 | 640000 | 3978 | 10000 | 1.5 |reached target block errors
1.5 | 2.3164e-02 | 2.1090e-01 | 14825 | 640000 | 2109 | 10000 | 1.5 |reached target block errors
2.0 | 9.3664e-03 | 9.1850e-02 | 11989 | 1280000 | 1837 | 20000 | 3.0 |reached target block errors
2.5 | 3.1074e-03 | 3.2500e-02 | 7955 | 2560000 | 1300 | 40000 | 6.0 |reached target block errors
3.0 | 8.2756e-04 | 9.5364e-03 | 5826 | 7040000 | 1049 | 110000 | 16.5 |reached target block errors
3.5 | 1.7035e-04 | 2.5075e-03 | 4361 | 25600000 | 1003 | 400000 | 60.0 |reached target block errors
4.0 | 3.4562e-05 | 6.7800e-04 | 2212 | 64000000 | 678 | 1000000 | 150.3 |reached max iter
4.5 | 8.2656e-06 | 2.1500e-04 | 529 | 64000000 | 215 | 1000000 | 150.2 |reached max iter
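The status column in the tables above follows from the stopping criteria passed to `simulate()`. At each SNR point, batches are simulated until `num_target_block_errors` block errors have been collected or `max_mc_iter` batches have been run. With `early_stop=True`, the sweep ends once an SNR point produces no error at all. For instance, the LDPC row at 2.5 dB needed two batches of 10000 blocks (1372 block errors), while its 4.5 dB row hit the limit of 100 batches. A simplified stdlib sketch of this bookkeeping is shown below. `run_batch` and the toy error counts are hypothetical stand-ins for the model, and the sketch is not PlotBER's implementation.

```python
def simulate_point(run_batch, batch_size, max_mc_iter, num_target_block_errors):
    """Accumulate error counts at one SNR point until a stopping rule fires.

    run_batch(batch_size) is a hypothetical callable returning
    (bit_errors, block_errors) for one Monte Carlo batch.
    """
    bit_err = blk_err = num_blocks = 0
    for _ in range(max_mc_iter):
        b, bl = run_batch(batch_size)
        bit_err, blk_err, num_blocks = bit_err + b, blk_err + bl, num_blocks + batch_size
        if blk_err >= num_target_block_errors:
            return bit_err, blk_err, num_blocks, "reached target block errors"
    return bit_err, blk_err, num_blocks, "reached max iter"


def simulate_curve(run_batch_at, snrs, k, **kwargs):
    """Sweep SNR points; stop early once a point shows no block error (early_stop)."""
    rows = []
    for snr in snrs:
        bit_err, blk_err, num_blocks, status = simulate_point(
            lambda bs: run_batch_at(snr, bs), **kwargs)
        rows.append((snr, bit_err / (num_blocks * k), blk_err / num_blocks, status))
        if blk_err == 0:
            break
    return rows


# toy, deterministic error model (NOT a channel simulation): fewer errors at higher SNR
def toy_run_batch(snr, batch_size):
    blk = {0.0: 700, 1.0: 50}.get(snr, 0)
    return 10 * blk, blk


for row in simulate_curve(toy_run_batch, [0.0, 1.0, 2.0, 3.0], k=64,
                          batch_size=10000, max_mc_iter=100,
                          num_target_block_errors=1000):
    print(row)
```

Fixing the number of block errors rather than the number of simulated blocks keeps the relative statistical accuracy of the BLER estimate roughly constant across SNR points.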
And let’s also look at the block error rate (BLER).
[7]:
ber_plot128(ylim=(1e-5, 1), show_ber=False)
Please keep in mind that the decoding complexity differs significantly between the schemes and should also be considered in a fair comparison, as shown in Section Throughput and Decoding Complexity.
Performance under Optimal Decoding
The achievable error-rate performance of a coding scheme depends on the strength of the code construction and the performance of the actual decoding algorithm. We now approximate the maximum-likelihood performance of all previous coding schemes by using the ordered statistics decoder (OSD) [12].
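The idea behind OSD can be sketched in a few lines. Sort the positions by reliability |LLR|, bring the generator matrix into systematic form on the k most reliable independent positions (the most reliable basis), re-encode the hard decisions on this basis, and test all flip patterns of weight up to t. The decoder keeps the codeword with the smallest sum of |LLR| over positions that disagree with the hard decisions. The pure-Python sketch below follows this textbook description on a (7,4) Hamming code. It assumes LLRs defined as log(P(0)/P(1)), so a negative LLR means bit 1. This convention is chosen here and may differ from Sionna's, and the sketch is not the OSDecoder implementation.

```python
from itertools import combinations
from math import comb


def osd_decode(G, llr, t):
    """Order-t OSD for a binary code with full-rank k x n generator matrix G.

    llr[j] = log(P(c_j=0) / P(c_j=1)); returns the best codeword found.
    """
    k, n = len(G), len(G[0])
    hard = [1 if v < 0 else 0 for v in llr]
    order = sorted(range(n), key=lambda j: -abs(llr[j]))  # most reliable first

    # Gaussian elimination over GF(2) along the reliability order -> MRB
    M = [row[:] for row in G]
    pivots = []
    for col in order:
        r = len(pivots)
        if r == k:
            break
        piv = next((i for i in range(r, k) if M[i][col]), None)
        if piv is None:
            continue  # column depends on previous ones, skip it
        M[r], M[piv] = M[piv], M[r]
        for i in range(k):
            if i != r and M[i][col]:
                M[i] = [a ^ b for a, b in zip(M[i], M[r])]
        pivots.append(col)

    def encode(u):
        c = [0] * n
        for ui, row in zip(u, M):
            if ui:
                c = [a ^ b for a, b in zip(c, row)]
        return c

    u0 = [hard[p] for p in pivots]  # hard decisions on the MRB
    best_cost, best_cw = float("inf"), None
    for w in range(t + 1):
        for flips in combinations(range(k), w):
            u = u0[:]
            for f in flips:
                u[f] ^= 1
            c = encode(u)
            # discrepancy: reliability of every position contradicting the hard decision
            cost = sum(abs(v) for v, cj, hj in zip(llr, c, hard) if cj != hj)
            if cost < best_cost:
                best_cost, best_cw = cost, c
    return best_cw


G_HAM = [[1, 0, 0, 0, 1, 1, 0],
         [0, 1, 0, 0, 1, 0, 1],
         [0, 0, 1, 0, 0, 1, 1],
         [0, 0, 0, 1, 1, 1, 1]]
# all-zero codeword sent; position 2 is received wrongly with high reliability
llr = [2.0, 1.5, -2.5, 1.0, 1.8, 2.2, 1.1]
print(osd_decode(G_HAM, llr, t=0))  # [0, 1, 1, 1, 0, 0, 1]
print(osd_decode(G_HAM, llr, t=1))  # [0, 0, 0, 0, 0, 0, 0]

# number of re-encodings per codeword for k=64 and t=4, as in the simulation below
print(sum(comb(64, w) for w in range(5)))  # 679121
```

The last line shows why OSD-4 is expensive for the k = 64 codes in this notebook. Each codeword requires 679,121 re-encodings, which is consistent with the memory-complexity note printed during the simulation.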
[8]:
# overwrite existing legend entries for OSD simulations
legends = ["5G LDPC", "5G Polar+CRC", "5G Polar+CRC", "RM", "Conv. Code", "Turbo Code"]
# run ber simulations for each code we have added to the list
for idx, code in enumerate(codes_under_test):
    if idx==2: # skip second polar code (same code only different decoder)
        continue
    print("\nRunning: " + code[2])
    # initialize encoder
    encoder = code[0]
    # encode dummy bits to init conv/turbo encoders (otherwise k is not defined)
    encoder(tf.zeros((1, k)))
    # OSD can be directly associated to an encoder
    decoder = OSDecoder(encoder=encoder, t=4)
    # generate a new model with the given encoder/decoder
    model = System_Model(k=k,
                         n=n,
                         num_bits_per_symbol=num_bits_per_symbol,
                         encoder=encoder,
                         decoder=decoder,
                         cw_estimates=True) # OSD returns codeword estimates and not info bit estimates
    # the first argument must be a callable (function) that yields c and c_hat for batch_size and ebno
    ber_plot128.simulate(tf.function(model, jit_compile=True),
                         ebno_dbs=ebno_db, # SNR to simulate
                         legend=legends[idx]+f" OSD-{decoder.t}", # legend string for plotting (no trailing space, so it matches plots_to_show)
                         max_mc_iter=1000, # run at most 1000 Monte Carlo runs per SNR point
                         num_target_block_errors=1000, # continue with next SNR point after 1000 block errors
                         batch_size=32, # batch-size per Monte Carlo run
                         soft_estimates=False, # the model returns hard-estimates
                         early_stop=True, # stop simulation if no error has been detected at current SNR point
                         show_fig=False, # we show the figure after all results are simulated
                         add_bler=True, # in case BLER is also interesting
                         forward_keyboard_interrupt=True); # should be True in a loop
Running: 5G LDPC BP-20
Note: Required memory complexity is large for the given code parameters and t=4. Please consider small batch-sizes to keep the inference complexity small and activate XLA mode if possible.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.0355e-01 | 4.5335e-01 | 29267 | 282624 | 1001 | 2208 | 8.9 |reached target block errors
0.5 | 5.8026e-02 | 2.6042e-01 | 28521 | 491520 | 1000 | 3840 | 12.0 |reached target block errors
1.0 | 2.5965e-02 | 1.2160e-01 | 27333 | 1052672 | 1000 | 8224 | 25.7 |reached target block errors
1.5 | 8.5224e-03 | 4.1309e-02 | 26460 | 3104768 | 1002 | 24256 | 75.8 |reached target block errors
2.0 | 1.9102e-03 | 1.0156e-02 | 7824 | 4096000 | 325 | 32000 | 100.0 |reached max iter
2.5 | 4.2163e-04 | 2.3750e-03 | 1727 | 4096000 | 76 | 32000 | 100.0 |reached max iter
3.0 | 3.3203e-05 | 1.8750e-04 | 136 | 4096000 | 6 | 32000 | 100.0 |reached max iter
3.5 | 0.0000e+00 | 0.0000e+00 | 0 | 4096000 | 0 | 32000 | 100.0 |reached max iter
Simulation stopped as no error occurred @ EbNo = 3.5 dB.
Running: 5G Polar+CRC SC
Note: Required memory complexity is large for the given code parameters and t=4. Please consider small batch-sizes to keep the inference complexity small and activate XLA mode if possible.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.0743e-01 | 4.4643e-01 | 30802 | 286720 | 1000 | 2240 | 9.6 |reached target block errors
0.5 | 5.9205e-02 | 2.5457e-01 | 29828 | 503808 | 1002 | 3936 | 12.3 |reached target block errors
1.0 | 2.5035e-02 | 1.0862e-01 | 29532 | 1179648 | 1001 | 9216 | 28.8 |reached target block errors
1.5 | 8.4632e-03 | 3.7590e-02 | 28876 | 3411968 | 1002 | 26656 | 83.3 |reached target block errors
2.0 | 1.8711e-03 | 8.6563e-03 | 7664 | 4096000 | 277 | 32000 | 100.0 |reached max iter
2.5 | 1.9434e-04 | 9.0625e-04 | 796 | 4096000 | 29 | 32000 | 100.0 |reached max iter
3.0 | 3.1250e-05 | 1.5625e-04 | 128 | 4096000 | 5 | 32000 | 100.0 |reached max iter
3.5 | 1.2207e-05 | 6.2500e-05 | 50 | 4096000 | 2 | 32000 | 100.0 |reached max iter
4.0 | 0.0000e+00 | 0.0000e+00 | 0 | 4096000 | 0 | 32000 | 100.0 |reached max iter
Simulation stopped as no error occurred @ EbNo = 4.0 dB.
Running: Reed Muller (RM) SCL-8
Note: Required memory complexity is large for the given code parameters and t=4. Please consider small batch-sizes to keep the inference complexity small and activate XLA mode if possible.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.0049e-01 | 4.9463e-01 | 26344 | 262144 | 1013 | 2048 | 8.1 |reached target block errors
0.5 | 5.6522e-02 | 2.9410e-01 | 24772 | 438272 | 1007 | 3424 | 10.7 |reached target block errors
1.0 | 2.7499e-02 | 1.5054e-01 | 23428 | 851968 | 1002 | 6656 | 20.8 |reached target block errors
1.5 | 9.9001e-03 | 5.9021e-02 | 21492 | 2170880 | 1001 | 16960 | 53.0 |reached target block errors
2.0 | 2.6064e-03 | 1.6781e-02 | 10676 | 4096000 | 537 | 32000 | 100.0 |reached max iter
2.5 | 6.4551e-04 | 4.5000e-03 | 2644 | 4096000 | 144 | 32000 | 100.0 |reached max iter
3.0 | 9.6680e-05 | 7.5000e-04 | 396 | 4096000 | 24 | 32000 | 100.0 |reached max iter
3.5 | 7.8125e-06 | 6.2500e-05 | 32 | 4096000 | 2 | 32000 | 100.0 |reached max iter
4.0 | 0.0000e+00 | 0.0000e+00 | 0 | 4096000 | 0 | 32000 | 100.0 |reached max iter
Simulation stopped as no error occurred @ EbNo = 4.0 dB.
Running: Conv. Code Viterbi (constraint length 8)
Note: Required memory complexity is large for the given code parameters and t=4. Please consider small batch-sizes to keep the inference complexity small and activate XLA mode if possible.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 9.6430e-02 | 7.0486e-01 | 17774 | 184320 | 1015 | 1440 | 6.1 |reached target block errors
0.5 | 6.2281e-02 | 5.4203e-01 | 14796 | 237568 | 1006 | 1856 | 5.8 |reached target block errors
1.0 | 3.7195e-02 | 4.0787e-01 | 11731 | 315392 | 1005 | 2464 | 7.7 |reached target block errors
1.5 | 1.9923e-02 | 2.8266e-01 | 9058 | 454656 | 1004 | 3552 | 11.1 |reached target block errors
2.0 | 1.0160e-02 | 1.8863e-01 | 6908 | 679936 | 1002 | 5312 | 16.6 |reached target block errors
2.5 | 5.0336e-03 | 1.2066e-01 | 5340 | 1060864 | 1000 | 8288 | 25.9 |reached target block errors
3.0 | 2.5836e-03 | 7.5850e-02 | 4360 | 1687552 | 1000 | 13184 | 41.2 |reached target block errors
3.5 | 1.4637e-03 | 4.9841e-02 | 3759 | 2568192 | 1000 | 20064 | 62.7 |reached target block errors
4.0 | 7.5342e-04 | 2.9531e-02 | 3086 | 4096000 | 945 | 32000 | 100.0 |reached max iter
4.5 | 4.2480e-04 | 1.8219e-02 | 1740 | 4096000 | 583 | 32000 | 100.0 |reached max iter
Running: Turbo Code (constraint length 4)
Note: Required memory complexity is large for the given code parameters and t=4. Please consider small batch-sizes to keep the inference complexity small and activate XLA mode if possible.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.0357e-01 | 5.2865e-01 | 25454 | 245760 | 1015 | 1920 | 7.7 |reached target block errors
0.5 | 5.9731e-02 | 3.2552e-01 | 23487 | 393216 | 1000 | 3072 | 9.8 |reached target block errors
1.0 | 3.1669e-02 | 1.8620e-01 | 21792 | 688128 | 1001 | 5376 | 17.1 |reached target block errors
1.5 | 1.2678e-02 | 8.1756e-02 | 19889 | 1568768 | 1002 | 12256 | 38.9 |reached target block errors
2.0 | 4.1087e-03 | 3.1758e-02 | 16560 | 4030464 | 1000 | 31488 | 100.2 |reached target block errors
2.5 | 1.0581e-03 | 1.0250e-02 | 4334 | 4096000 | 328 | 32000 | 101.8 |reached max iter
3.0 | 2.3853e-04 | 3.1875e-03 | 977 | 4096000 | 102 | 32000 | 102.0 |reached max iter
3.5 | 1.0083e-04 | 1.6875e-03 | 413 | 4096000 | 54 | 32000 | 101.7 |reached max iter
4.0 | 3.8574e-05 | 7.8125e-04 | 158 | 4096000 | 25 | 32000 | 101.1 |reached max iter
4.5 | 1.4648e-05 | 3.1250e-04 | 60 | 4096000 | 10 | 32000 | 101.1 |reached max iter
And let’s plot the results.
Remark: we use custom plotting code instead of the built-in PlotBER call to enable a nicer visualization of OSD vs. non-OSD results.
[9]:
# for simplicity, we only plot a subset of the simulated curves
# focus on BLER
plots_to_show = ['5G LDPC BP-20 (BLER)', '5G LDPC OSD-4 (BLER)', '5G Polar+CRC SCL-8 (BLER)', '5G Polar+CRC OSD-4 (BLER)', 'Reed Muller (RM) SCL-8 (BLER)', 'RM OSD-4 (BLER)', 'Conv. Code Viterbi (constraint length 8) (BLER)', 'Conv. Code OSD-4 (BLER)', 'Turbo Code (constraint length 4) (BLER)', 'Turbo Code OSD-4 (BLER)']
# find indices of relevant curves
idx = []
for p in plots_to_show:
    for i,l in enumerate(ber_plot128._legends):
        if p==l:
            idx.append(i)
# generate new figure
fig, ax = plt.subplots(figsize=(16,12))
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.title(f"Performance under Ordered Statistics Decoding (k={k},n={n})", fontsize=25)
plt.grid(which="both")
plt.xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
plt.ylabel(r"BLER", fontsize=25)
# plot pairs of BLER curves (non-osd vs. osd)
for i in range(int(len(idx)/2)):
    # non-OSD
    plt.semilogy(ebno_db,
                 ber_plot128._bers[idx[2*i]],
                 c='C%d'%(i),
                 label=ber_plot128._legends[idx[2*i]].replace(" (BLER)", ""), # remove "(BLER)" from label
                 linewidth=2)
    # OSD
    plt.semilogy(ebno_db,
                 ber_plot128._bers[idx[2*i+1]],
                 c='C%d'%(i),
                 label=ber_plot128._legends[idx[2*i+1]].replace(" (BLER)", ""), # remove "(BLER)" from label
                 linestyle="--",
                 linewidth=2)
plt.legend(fontsize=20)
plt.xlim([0, 4.5])
plt.ylim([1e-4, 1]);
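As can be seen, the convolutional code under Viterbi decoding, which is itself maximum-likelihood sequence decoding, practically achieves its ML performance. For the other schemes, a gap between the practical decoder and the OSD curve remains at the considered list size and number of iterations. In this comparison, the gap is most pronounced for the LDPC code under BP decoding, and it tends to be smaller for longer codes.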
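To put a number on the LDPC gap, the BLER values printed above for 5G LDPC BP-20 and the corresponding OSD-4 run can be interpolated linearly in log10(BLER) to find the E_b/N_0 needed for a target BLER of 10^-2. The dictionaries below copy the values from the simulation logs in this notebook (the error-free 3.5 dB OSD point is omitted). Log-linear interpolation between the 0.5 dB grid points is an approximation chosen here.

```python
import math

# BLER vs. Eb/N0 [dB], copied from the simulation logs above (k=64, n=128)
bler_ldpc_bp = {0.0: 8.6430e-01, 0.5: 7.1230e-01, 1.0: 5.0330e-01, 1.5: 3.0320e-01,
                2.0: 1.5470e-01, 2.5: 6.8600e-02, 3.0: 2.2540e-02, 3.5: 5.8278e-03,
                4.0: 1.3333e-03, 4.5: 2.1400e-04}
bler_ldpc_osd = {0.0: 4.5335e-01, 0.5: 2.6042e-01, 1.0: 1.2160e-01, 1.5: 4.1309e-02,
                 2.0: 1.0156e-02, 2.5: 2.3750e-03, 3.0: 1.8750e-04}


def snr_at_bler(curve, target):
    """Eb/N0 where the curve crosses `target`, interpolating log10(BLER) linearly."""
    pts = sorted(curve.items())
    for (s0, b0), (s1, b1) in zip(pts, pts[1:]):
        if b0 >= target > b1 > 0:
            frac = (math.log10(target) - math.log10(b0)) / (math.log10(b1) - math.log10(b0))
            return s0 + frac * (s1 - s0)
    raise ValueError("target BLER not bracketed by the simulated points")


snr_bp = snr_at_bler(bler_ldpc_bp, 1e-2)
snr_osd = snr_at_bler(bler_ldpc_osd, 1e-2)
print(f"BP-20: {snr_bp:.2f} dB, OSD-4: {snr_osd:.2f} dB, gap: {snr_bp - snr_osd:.2f} dB")
```

With these numbers, the gap between BP-20 and OSD-4 at a BLER of 10^-2 comes out at roughly 1.3 dB for the (128, 64) LDPC code.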
Performance of Longer LDPC Codes
Now, let us have a look at the performance gains due to longer codewords. For this, we scale the length of the LDPC code and compare the results (same rate, same decoder, same channel).
[10]:
# init new figure
ber_plot_ldpc = PlotBER(f"BER/BLER Performance of LDPC Codes @ Fixed Rate=0.5")
[11]:
# code parameters to simulate
ns = [128, 256, 512, 1000, 2000, 4000, 8000, 16000] # number of codeword bits per codeword
rate = 0.5 # fixed coderate
# create list of encoder/decoder pairs to be analyzed
codes_under_test = []
# 5G LDPC codes
for n in ns:
    k = int(rate*n) # calculate k for given n and rate
    enc = LDPC5GEncoder(k=k, n=n)
    dec = LDPC5GDecoder(enc, num_iter=20)
    name = f"5G LDPC BP-20 (n={n})"
    codes_under_test.append([enc, dec, name, k, n])
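For each (k, n) pair in `ns`, the 5G LDPC encoder has to pick a base graph and a lifting size Z. The sketch below applies the selection rules as we read them in TS 38.212. It uses base graph selection as in Sec. 7.2.2, with k used in place of the transport block size. For the lifting size (Sec. 5.2.2), it takes Kb = 22 for base graph 1 and Kb ∈ {6, 8, 9, 10} for base graph 2, and chooses the smallest Z = a·2^j with a ∈ {2, 3, 5, 7, 9, 11, 13, 15} and Z ≤ 384 such that Kb·Z ≥ k. Segmentation and filler-bit details are ignored, and Sionna's LDPC5GEncoder may apply the rules differently, so treat this as an illustrative assumption.

```python
# all 5G lifting sizes Z = a * 2**j <= 384 (assumed from TS 38.212 Table 5.3.2-1)
LIFTING_SIZES = sorted({a * 2 ** j for a in (2, 3, 5, 7, 9, 11, 13, 15)
                        for j in range(8) if a * 2 ** j <= 384})


def select_base_graph(k, rate):
    """BG2 for short payloads, moderate lengths at low/mid rates, or very low rates."""
    if k <= 292 or (k <= 3824 and rate <= 0.67) or rate <= 0.25:
        return 2
    return 1


def select_lifting_size(k, bg):
    """Smallest Z with Kb * Z >= k (no segmentation / filler-bit handling)."""
    if bg == 1:
        kb = 22
    elif k > 640:
        kb = 10
    elif k > 560:
        kb = 9
    elif k > 192:
        kb = 8
    else:
        kb = 6
    return min(z for z in LIFTING_SIZES if kb * z >= k)


# separate loop variables so the notebook's k and n are not overwritten
for n_ex in [128, 256, 512, 1000, 2000, 4000, 8000, 16000]:
    k_ex = int(0.5 * n_ex)
    bg_ex = select_base_graph(k_ex, k_ex / n_ex)
    print(f"n={n_ex:5d} k={k_ex:4d} -> BG{bg_ex}, Z={select_lifting_size(k_ex, bg_ex)}")
```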
[12]:
# and simulate the results
num_bits_per_symbol = 2 # QPSK
ebno_db = np.arange(0, 5, 0.25) # sim SNR range
# note that the waterfall for long codes can be steep and requires a fine
# SNR quantization
# run ber simulations for each case
for code in codes_under_test:
    print("Running: " + code[2])
    model = System_Model(k=code[3],
                         n=code[4],
                         num_bits_per_symbol=num_bits_per_symbol,
                         encoder=code[0],
                         decoder=code[1])
    # the first argument must be a callable (function) that yields u and u_hat
    # for given batch_size and ebno
    # we fix the target number of BLOCK errors instead of the BER to
    # ensure comparable statistical accuracy for each block length
    ber_plot_ldpc.simulate(model, # the model we have defined previously
                           ebno_dbs=ebno_db,
                           legend=code[2],
                           max_mc_iter=100,
                           num_target_block_errors=500, # we fix the target block errors
                           batch_size=1000,
                           soft_estimates=False,
                           early_stop=True,
                           show_fig=False,
                           forward_keyboard_interrupt=True); # should be True in a loop

# and show figure
ber_plot_ldpc(ylim=(1e-5, 1))
Running: 5G LDPC BP-20 (n=128)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6739e-01 | 8.5300e-01 | 10713 | 64000 | 853 | 1000 | 0.6 |reached target block errors
0.25 | 1.5509e-01 | 8.2900e-01 | 9926 | 64000 | 829 | 1000 | 0.0 |reached target block errors
0.5 | 1.2559e-01 | 6.9700e-01 | 8038 | 64000 | 697 | 1000 | 0.0 |reached target block errors
0.75 | 1.0622e-01 | 5.9500e-01 | 6798 | 64000 | 595 | 1000 | 0.0 |reached target block errors
1.0 | 8.6859e-02 | 5.0600e-01 | 5559 | 64000 | 506 | 1000 | 0.0 |reached target block errors
1.25 | 6.8172e-02 | 3.9650e-01 | 8726 | 128000 | 793 | 2000 | 0.1 |reached target block errors
1.5 | 5.0156e-02 | 2.9800e-01 | 6420 | 128000 | 596 | 2000 | 0.1 |reached target block errors
1.75 | 3.6958e-02 | 2.2500e-01 | 7096 | 192000 | 675 | 3000 | 0.1 |reached target block errors
2.0 | 2.5945e-02 | 1.5600e-01 | 6642 | 256000 | 624 | 4000 | 0.1 |reached target block errors
2.25 | 1.7631e-02 | 1.0640e-01 | 5642 | 320000 | 532 | 5000 | 0.1 |reached target block errors
2.5 | 1.0803e-02 | 6.5000e-02 | 5531 | 512000 | 520 | 8000 | 0.2 |reached target block errors
2.75 | 6.3738e-03 | 3.9000e-02 | 5303 | 832000 | 507 | 13000 | 0.4 |reached target block errors
3.0 | 3.4735e-03 | 2.1783e-02 | 5113 | 1472000 | 501 | 23000 | 0.7 |reached target block errors
3.25 | 1.8805e-03 | 1.1791e-02 | 5175 | 2752000 | 507 | 43000 | 1.2 |reached target block errors
3.5 | 8.8264e-04 | 5.4348e-03 | 5197 | 5888000 | 500 | 92000 | 2.9 |reached target block errors
3.75 | 4.6922e-04 | 3.2300e-03 | 3003 | 6400000 | 323 | 100000 | 3.4 |reached max iter
4.0 | 2.0469e-04 | 1.2700e-03 | 1310 | 6400000 | 127 | 100000 | 3.4 |reached max iter
4.25 | 9.8125e-05 | 6.1000e-04 | 628 | 6400000 | 61 | 100000 | 3.5 |reached max iter
4.5 | 2.7969e-05 | 1.8000e-04 | 179 | 6400000 | 18 | 100000 | 3.4 |reached max iter
4.75 | 2.2344e-05 | 1.7000e-04 | 143 | 6400000 | 17 | 100000 | 3.4 |reached max iter
Running: 5G LDPC BP-20 (n=256)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6402e-01 | 9.2800e-01 | 20994 | 128000 | 928 | 1000 | 0.6 |reached target block errors
0.25 | 1.3987e-01 | 8.5100e-01 | 17903 | 128000 | 851 | 1000 | 0.0 |reached target block errors
0.5 | 1.1787e-01 | 7.7700e-01 | 15088 | 128000 | 777 | 1000 | 0.0 |reached target block errors
0.75 | 9.2758e-02 | 6.4500e-01 | 11873 | 128000 | 645 | 1000 | 0.0 |reached target block errors
1.0 | 6.7824e-02 | 4.9700e-01 | 17363 | 256000 | 994 | 2000 | 0.1 |reached target block errors
1.25 | 4.8352e-02 | 3.6200e-01 | 12378 | 256000 | 724 | 2000 | 0.1 |reached target block errors
1.5 | 3.0062e-02 | 2.2933e-01 | 11544 | 384000 | 688 | 3000 | 0.1 |reached target block errors
1.75 | 1.7432e-02 | 1.4075e-01 | 8925 | 512000 | 563 | 4000 | 0.1 |reached target block errors
2.0 | 1.0374e-02 | 8.4286e-02 | 9295 | 896000 | 590 | 7000 | 0.3 |reached target block errors
2.25 | 5.0521e-03 | 4.3750e-02 | 7760 | 1536000 | 525 | 12000 | 0.4 |reached target block errors
2.5 | 2.1424e-03 | 1.9615e-02 | 7130 | 3328000 | 510 | 26000 | 0.9 |reached target block errors
2.75 | 8.2264e-04 | 7.5373e-03 | 7055 | 8576000 | 505 | 67000 | 2.4 |reached target block errors
3.0 | 3.1102e-04 | 3.1000e-03 | 3981 | 12800000 | 310 | 100000 | 3.6 |reached max iter
3.25 | 1.0070e-04 | 1.1500e-03 | 1289 | 12800000 | 115 | 100000 | 3.6 |reached max iter
3.5 | 2.1484e-05 | 3.0000e-04 | 275 | 12800000 | 30 | 100000 | 3.6 |reached max iter
3.75 | 1.1797e-05 | 1.2000e-04 | 151 | 12800000 | 12 | 100000 | 3.6 |reached max iter
4.0 | 7.8125e-08 | 1.0000e-05 | 1 | 12800000 | 1 | 100000 | 3.6 |reached max iter
4.25 | 3.9063e-07 | 2.0000e-05 | 5 | 12800000 | 2 | 100000 | 3.6 |reached max iter
4.5 | 0.0000e+00 | 0.0000e+00 | 0 | 12800000 | 0 | 100000 | 3.6 |reached max iter
Simulation stopped as no error occurred @ EbNo = 4.5 dB.
Running: 5G LDPC BP-20 (n=512)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6355e-01 | 9.8100e-01 | 41870 | 256000 | 981 | 1000 | 0.6 |reached target block errors
0.25 | 1.3757e-01 | 9.3900e-01 | 35217 | 256000 | 939 | 1000 | 0.0 |reached target block errors
0.5 | 1.0895e-01 | 8.2300e-01 | 27890 | 256000 | 823 | 1000 | 0.0 |reached target block errors
0.75 | 8.2961e-02 | 6.6700e-01 | 21238 | 256000 | 667 | 1000 | 0.0 |reached target block errors
1.0 | 5.1736e-02 | 4.7400e-01 | 26489 | 512000 | 948 | 2000 | 0.1 |reached target block errors
1.25 | 2.7826e-02 | 2.7700e-01 | 14247 | 512000 | 554 | 2000 | 0.1 |reached target block errors
1.5 | 1.5292e-02 | 1.5650e-01 | 15659 | 1024000 | 626 | 4000 | 0.2 |reached target block errors
1.75 | 5.4431e-03 | 6.1333e-02 | 12541 | 2304000 | 552 | 9000 | 0.4 |reached target block errors
2.0 | 1.8269e-03 | 2.1870e-02 | 10757 | 5888000 | 503 | 23000 | 1.0 |reached target block errors
2.25 | 5.6836e-04 | 7.1857e-03 | 10185 | 17920000 | 503 | 70000 | 3.0 |reached target block errors
2.5 | 1.4531e-04 | 2.0000e-03 | 3720 | 25600000 | 200 | 100000 | 4.3 |reached max iter
2.75 | 4.4922e-05 | 5.3000e-04 | 1150 | 25600000 | 53 | 100000 | 4.3 |reached max iter
3.0 | 4.1016e-06 | 7.0000e-05 | 105 | 25600000 | 7 | 100000 | 4.3 |reached max iter
3.25 | 1.1719e-07 | 1.0000e-05 | 3 | 25600000 | 1 | 100000 | 4.3 |reached max iter
3.5 | 3.1250e-07 | 2.0000e-05 | 8 | 25600000 | 2 | 100000 | 4.3 |reached max iter
3.75 | 0.0000e+00 | 0.0000e+00 | 0 | 25600000 | 0 | 100000 | 4.3 |reached max iter
Simulation stopped as no error occurred @ EbNo = 3.8 dB.
Running: 5G LDPC BP-20 (n=1000)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6290e-01 | 9.9900e-01 | 81449 | 500000 | 999 | 1000 | 0.7 |reached target block errors
0.25 | 1.3847e-01 | 9.8000e-01 | 69233 | 500000 | 980 | 1000 | 0.1 |reached target block errors
0.5 | 1.0607e-01 | 8.9400e-01 | 53037 | 500000 | 894 | 1000 | 0.1 |reached target block errors
0.75 | 6.7784e-02 | 7.3400e-01 | 33892 | 500000 | 734 | 1000 | 0.1 |reached target block errors
1.0 | 3.4606e-02 | 4.6700e-01 | 34606 | 1000000 | 934 | 2000 | 0.1 |reached target block errors
1.25 | 1.3058e-02 | 2.1500e-01 | 19587 | 1500000 | 645 | 3000 | 0.2 |reached target block errors
1.5 | 3.7251e-03 | 7.7143e-02 | 13038 | 3500000 | 540 | 7000 | 0.4 |reached target block errors
1.75 | 7.5100e-04 | 1.9423e-02 | 9763 | 13000000 | 505 | 26000 | 1.6 |reached target block errors
2.0 | 1.2096e-04 | 3.3000e-03 | 6048 | 50000000 | 330 | 100000 | 6.1 |reached max iter
2.25 | 1.3320e-05 | 3.9000e-04 | 666 | 50000000 | 39 | 100000 | 6.1 |reached max iter
2.5 | 4.4000e-07 | 1.0000e-05 | 22 | 50000000 | 1 | 100000 | 6.1 |reached max iter
2.75 | 1.7400e-06 | 4.0000e-05 | 87 | 50000000 | 4 | 100000 | 6.1 |reached max iter
3.0 | 0.0000e+00 | 0.0000e+00 | 0 | 50000000 | 0 | 100000 | 6.1 |reached max iter
Simulation stopped as no error occurred @ EbNo = 3.0 dB.
Running: 5G LDPC BP-20 (n=2000)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.5987e-01 | 1.0000e+00 | 159871 | 1000000 | 1000 | 1000 | 0.8 |reached target block errors
0.25 | 1.3289e-01 | 9.9800e-01 | 132887 | 1000000 | 998 | 1000 | 0.1 |reached target block errors
0.5 | 9.7829e-02 | 9.6600e-01 | 97829 | 1000000 | 966 | 1000 | 0.1 |reached target block errors
0.75 | 5.5639e-02 | 8.2400e-01 | 55639 | 1000000 | 824 | 1000 | 0.1 |reached target block errors
1.0 | 2.0370e-02 | 4.7550e-01 | 40740 | 2000000 | 951 | 2000 | 0.2 |reached target block errors
1.25 | 4.1465e-03 | 1.6100e-01 | 16586 | 4000000 | 644 | 4000 | 0.4 |reached target block errors
1.5 | 4.1419e-04 | 2.5143e-02 | 8698 | 21000000 | 528 | 21000 | 1.9 |reached target block errors
1.75 | 2.6800e-05 | 1.8600e-03 | 2680 | 100000000 | 186 | 100000 | 9.0 |reached max iter
2.0 | 6.0000e-07 | 1.3000e-04 | 60 | 100000000 | 13 | 100000 | 9.0 |reached max iter
2.25 | 0.0000e+00 | 0.0000e+00 | 0 | 100000000 | 0 | 100000 | 9.0 |reached max iter
Simulation stopped as no error occurred @ EbNo = 2.2 dB.
Running: 5G LDPC BP-20 (n=4000)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.6190e-01 | 1.0000e+00 | 323803 | 2000000 | 1000 | 1000 | 1.0 |reached target block errors
0.25 | 1.3800e-01 | 1.0000e+00 | 275993 | 2000000 | 1000 | 1000 | 0.2 |reached target block errors
0.5 | 1.0220e-01 | 9.9700e-01 | 204391 | 2000000 | 997 | 1000 | 0.2 |reached target block errors
0.75 | 4.9521e-02 | 9.2800e-01 | 99042 | 2000000 | 928 | 1000 | 0.2 |reached target block errors
1.0 | 1.1380e-02 | 5.0200e-01 | 22761 | 2000000 | 502 | 1000 | 0.2 |reached target block errors
1.25 | 9.7210e-04 | 1.0160e-01 | 9721 | 10000000 | 508 | 5000 | 0.8 |reached target block errors
1.5 | 2.2845e-05 | 4.7200e-03 | 4569 | 200000000 | 472 | 100000 | 15.8 |reached max iter
1.75 | 2.8000e-07 | 1.4000e-04 | 56 | 200000000 | 14 | 100000 | 16.0 |reached max iter
2.0 | 1.0000e-08 | 2.0000e-05 | 2 | 200000000 | 2 | 100000 | 15.9 |reached max iter
2.25 | 0.0000e+00 | 0.0000e+00 | 0 | 200000000 | 0 | 100000 | 15.9 |reached max iter
Simulation stopped as no error occurred @ EbNo = 2.2 dB.
Running: 5G LDPC BP-20 (n=8000)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.3664e-01 | 1.0000e+00 | 546547 | 4000000 | 1000 | 1000 | 1.6 |reached target block errors
0.25 | 1.0986e-01 | 1.0000e+00 | 439449 | 4000000 | 1000 | 1000 | 0.4 |reached target block errors
0.5 | 7.3622e-02 | 1.0000e+00 | 294487 | 4000000 | 1000 | 1000 | 0.4 |reached target block errors
0.75 | 2.8237e-02 | 9.5800e-01 | 112947 | 4000000 | 958 | 1000 | 0.4 |reached target block errors
1.0 | 2.7041e-03 | 4.0350e-01 | 21633 | 8000000 | 807 | 2000 | 0.7 |reached target block errors
1.25 | 3.7343e-05 | 1.8741e-02 | 4033 | 108000000 | 506 | 27000 | 10.1 |reached target block errors
1.5 | 1.0250e-07 | 1.3000e-04 | 41 | 400000000 | 13 | 100000 | 37.6 |reached max iter
1.75 | 0.0000e+00 | 0.0000e+00 | 0 | 400000000 | 0 | 100000 | 37.8 |reached max iter
Simulation stopped as no error occurred @ EbNo = 1.8 dB.
Running: 5G LDPC BP-20 (n=16000)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 1.3738e-01 | 1.0000e+00 | 1099038 | 8000000 | 1000 | 1000 | 2.6 |reached target block errors
0.25 | 1.1129e-01 | 1.0000e+00 | 890357 | 8000000 | 1000 | 1000 | 0.8 |reached target block errors
0.5 | 7.1292e-02 | 1.0000e+00 | 570336 | 8000000 | 1000 | 1000 | 0.8 |reached target block errors
0.75 | 2.4377e-02 | 9.9700e-01 | 195016 | 8000000 | 997 | 1000 | 0.8 |reached target block errors
1.0 | 1.2146e-03 | 4.4700e-01 | 19434 | 16000000 | 894 | 2000 | 1.6 |reached target block errors
1.25 | 1.5150e-06 | 3.8600e-03 | 1212 | 800000000 | 386 | 100000 | 81.5 |reached max iter
1.5 | 0.0000e+00 | 0.0000e+00 | 0 | 800000000 | 0 | 100000 | 79.2 |reached max iter
Simulation stopped as no error occurred @ EbNo = 1.5 dB.
A Deeper Look into the Polar Code Module
A Polar code can be defined by a set of frozen bit and information bit positions [1]. The package sionna.fec.polar.utils supports 5G-compliant Polar code design; in addition, Reed-Muller (RM) codes are available and can be used within the same encoder/decoder layer. If required, rate-matching and CRC concatenation are handled by the class sionna.fec.polar.Polar5GEncoder, and the corresponding rate-recovery and CRC check by sionna.fec.polar.Polar5GDecoder.
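For intuition, the following NumPy sketch shows what such a frozen/information-position description means for encoding: the information bits are placed at the non-frozen positions, the frozen positions are set to zero, and the result is multiplied over GF(2) by the Kronecker power \(G_n = F^{\otimes m}\) of Arıkan's kernel \(F = \begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}\). The function names are hypothetical, natural bit order (no bit-reversal permutation) is assumed, and Sionna's PolarEncoder may use different internal conventions.

```python
import numpy as np

def polar_generator_matrix(n):
    """G_n = F^{(x) m}, the m-fold Kronecker power of Arikan's kernel (n = 2^m).
    Natural bit order, i.e., no bit-reversal permutation (assumption)."""
    m = int(np.log2(n))
    if 2 ** m != n:
        raise ValueError("n must be a power of 2")
    F = np.array([[1, 0], [1, 1]], dtype=np.int64)
    G = np.ones((1, 1), dtype=np.int64)
    for _ in range(m):
        G = np.kron(G, F)
    return G

def polar_encode(info_bits, frozen_pos, n):
    """Hypothetical helper: info bits at the non-frozen positions, zeros at the
    frozen positions, then x = u G_n over GF(2)."""
    u = np.zeros(n, dtype=np.int64)
    info_pos = np.setdiff1d(np.arange(n), frozen_pos)
    u[info_pos] = info_bits
    return u @ polar_generator_matrix(n) % 2

# toy (n=8, k=4) code with frozen positions {0, 1, 2, 4}
x = polar_encode(np.array([1, 0, 1, 1]), [0, 1, 2, 4], 8)
print(x)
```

Since \(G_n\) is its own inverse over GF(2), multiplying the codeword by \(G_n\) again recovers the information bits at the non-frozen positions and zeros at the frozen positions.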
Further, the following decoders are available:
Successive cancellation (SC) decoding [1]: fast and low-complexity, but with sub-optimal error-rate performance.
Successive cancellation list (SCL) decoding [2]: excellent error-rate performance at high complexity; CRC-aided decoding is possible.
Hybrid SCL decoder (combined SC and SCL decoder): pre-decodes with SC and applies SCL only if the CRC check fails; excellent error-rate performance; needs an outer CRC (e.g., as done in 5G); CPU-based implementation and, thus, no XLA support (and increased decoding latency).
Iterative belief propagation (BP) decoding [6]: produces soft-output estimates, but with sub-optimal error-rate performance.
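To make the SC decoder from the list above more concrete, here is a minimal recursive SC decoder in NumPy. It assumes natural bit order, LLRs defined as \(\log P(0)/P(1)\), and the min-sum approximation of the check-node update; all function names are hypothetical and this is not Sionna's implementation. SCL decoding extends the same recursion by keeping the \(L\) most likely paths instead of a single hard decision at each information bit.

```python
import numpy as np

def polar_transform(u):
    """x = u G_n over GF(2), via the recursion x = [(u1 xor u2) G, u2 G]."""
    n = len(u)
    if n == 1:
        return u.copy()
    h = n // 2
    return np.concatenate([polar_transform(u[:h] ^ u[h:]), polar_transform(u[h:])])

def f_minsum(a, b):
    # LLR of (x1 xor x2) from the LLRs of x1 and x2 (min-sum approximation)
    return np.sign(a) * np.sign(b) * np.minimum(np.abs(a), np.abs(b))

def g_update(a, b, v):
    # LLR of the second half, given the re-encoded decisions v of the first half
    return b + (1 - 2 * v) * a

def sc_decode(llr, frozen):
    """Recursive SC decoding; returns (u_hat, x_hat). llr = log P(0)/P(1)."""
    n = len(llr)
    if n == 1:
        # frozen bits are known zeros; otherwise take a hard decision
        u = np.array([0 if frozen[0] or llr[0] >= 0 else 1], dtype=np.int64)
        return u, u.copy()
    h = n // 2
    la, lb = llr[:h], llr[h:]
    u1, v1 = sc_decode(f_minsum(la, lb), frozen[:h])
    u2, v2 = sc_decode(g_update(la, lb, v1), frozen[h:])
    return np.concatenate([u1, u2]), np.concatenate([v1 ^ v2, v2])

# usage: toy (n=8, k=4) code, frozen positions {0, 1, 2, 4}, noiseless BPSK LLRs
frozen = np.zeros(8, dtype=bool)
frozen[[0, 1, 2, 4]] = True
u = np.zeros(8, dtype=np.int64)
u[~frozen] = [1, 0, 1, 1]
x = polar_transform(u)
u_hat, x_hat = sc_decode(4.0 * (1 - 2 * x), frozen)
print(np.array_equal(u_hat, u))  # -> True
```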
Let us now generate a new Polar code.
[13]:
code_type = "5G" # try also "RM"
# Load the 5G compliant polar code
if code_type=="5G":
    k = 32
    n = 64
    # load 5G compliant channel ranking [3]
    frozen_pos, info_pos = generate_5g_ranking(k,n)
    print("Generated Polar code of length n = {} and k = {}".format(n, k))
    print("Frozen codeword positions: ", frozen_pos)
# Alternatively Reed-Muller code design is also available
elif code_type=="RM":
    r = 3
    m = 7
    frozen_pos, info_pos, n, k, d_min = generate_rm_code(r, m)
    print("Generated ({},{}) Reed-Muller code of length n = {} and k = {} with minimum distance d_min = {}".format(r, m, n, k, d_min))
    print("Frozen codeword positions: ", frozen_pos)
else:
    print("Code not found")
Generated Polar code of length n = 64 and k = 32
Frozen codeword positions: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 18 19 20 21 24 25 26
32 33 34 35 36 37 40 48]
Now, we can initialize the encoder and a BinarySource to generate random Polar codewords.
[14]:
# init polar encoder
encoder_polar = PolarEncoder(frozen_pos, n)
# init binary source to generate information bits
source = BinarySource()
# define a batch_size
batch_size = 1
# generate random info bits
u = source([batch_size, k])
# and encode
c = encoder_polar(u)
print("Information bits: ", u.numpy())
print("Polar encoded bits: ", c.numpy())
Information bits: [[0. 0. 1. 1. 0. 1. 1. 0. 1. 1. 1. 1. 0. 0. 1. 1. 0. 1. 1. 0. 1. 1. 1. 0.
1. 1. 1. 1. 1. 0. 1. 1.]]
Polar encoded bits: [[0. 1. 1. 1. 1. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 1. 0.
0. 1. 0. 1. 1. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 1.
0. 1. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 1. 0. 1.]]
As can be seen, the resulting codeword length n = 64 is a power of 2, as required by the Polar code construction. This brings us to the problem of rate-matching, and we will now have a closer look at how we can adapt the length of the code.
Rate-Matching and Rate-Recovery
The general task of rate-matching is to enable flexibility of the code w.r.t. the codeword length \(n\) and the number of information bits \(k\) and, thereby, the rate \(r = \frac{k}{n}\). In modern communication standards such as 5G NR, these parameters can be adjusted with bit-level granularity without, in a wider sense, redefining the (mother) code itself. This is enabled by a powerful rate-matching block and the corresponding rate-recovery block, which are explained in the following.
The basic idea is to select, from a set of possible mother codes, the mother code that is as close as possible to the desired properties. For example, for Polar codes the codeword length must be a power of 2, i.e., \(n = 32, 64, ..., 512, 1024\). For LDPC codes the codeword length is more flexible (due to the different lifting factors), but still does not allow bit-wise granularity. The bit-level granularity is then provided by shortening, puncturing and repetitions.
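As an illustration of selecting the Polar mother code, the following sketch implements the mother code length selection as we read it from 3GPP TS 38.212, Sec. 5.3.1: \(N = 2^n\) with \(n = \max\{\min\{n_1, n_2, n_{\max}\}, n_{\min}\}\), \(n_{\min} = 5\), \(R_{\min} = 1/8\), and \(n_{\max} = 10\) for the uplink control channel. Here, \(K\) denotes the number of bits entering the Polar encoder (including CRC bits) and \(E\) the number of transmitted bits. The helper names are hypothetical; please verify the details against the specification before relying on them.

```python
def ceil_log2(x):
    """Smallest integer m with 2**m >= x (for integers x >= 1)."""
    return (x - 1).bit_length()

def polar_mother_length(K, E, n_max=10, n_min=5):
    """Hypothetical helper: Polar mother code length N for K bits entering the
    encoder (incl. CRC) and E transmitted bits, per our reading of TS 38.212, Sec. 5.3.1."""
    n1 = ceil_log2(E)
    # E <= (9/8) * 2^(n1-1) and K/E < 9/16, written with integers
    if 8 * E <= 9 * 2 ** (n1 - 1) and 16 * K < 9 * E:
        n1 -= 1
    n2 = ceil_log2(8 * K)  # ceil(log2(K / R_min)) with R_min = 1/8
    return 2 ** max(min(n1, n2, n_max), n_min)

for K, E in [(32, 160), (32, 140), (32, 960)]:
    print(f"K = {K}, E = {E}: N = {polar_mother_length(K, E)}")
# -> N = 256, 128 and 256; in the last case E > N, i.e., bits must be repeated
```

The integer comparisons avoid floating-point rounding at the boundaries of the \(9/8\) and \(9/16\) conditions.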
To summarize, the rate-matching procedure consists of:
1) 5G NR defines multiple mother codes with similar properties (e.g., via base-graph lifting of LDPC codes or sub-codes for Polar codes)
2) Puncturing, shortening and repetitions of bits to allow bit-level rate adjustments
The following figure summarizes the principle for the 5G NR Polar code uplink control channel (UCI). The figure is inspired by Fig. 6 in [9].
For bit-wise length adjustments, the following techniques are commonly used:
1) Puncturing: A (\(k,n\)) mother code is punctured by not transmitting \(p\) punctured codeword bits. Thus, the rate increases to \(r_{\text{pun}} = \frac{k}{n-p} > \frac{k}{n} \quad \forall p > 0\). At the decoder, these codeword bits are treated as erasures (\(\ell_{\text{ch}} = 0\)).
2) Shortening: A (\(k,n\)) mother code is shortened by setting \(s\) information bits to a fixed (=known) value. Assuming systematic encoding, these \(s\) positions are not transmitted, leading to a new code of rate \(r_{\text{short}} = \frac{k-s}{n-s}<\frac{k}{n} \quad \forall s > 0\). At the decoder, these codeword bits are treated as known values (\(\ell_{\text{ch}} = \infty\)).
3) Repetitions can be used to lower the effective rate. For details we refer the interested reader to [11].
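The decoder-side view of these techniques (rate-recovery) can be sketched as follows: punctured positions are filled with erasures (\(\ell_{\text{ch}} = 0\)), shortened positions with a large LLR representing a known zero bit (a large finite value instead of \(\infty\), assuming the LLR convention \(\log P(0)/P(1)\)), and repeated bits are combined by adding their LLRs. The mode decision (repetition if \(E \geq N\), otherwise puncturing if \(K/E \leq 7/16\) and shortening else) and the positions of the removed bits follow our reading of 3GPP TS 38.212, Sec. 5.4.1. The sub-block and channel interleavers are ignored, and all function names are hypothetical.

```python
import numpy as np

def rate_matching_mode(K, E, N):
    """Hypothetical helper: choose repetition, puncturing or shortening
    (our reading of 3GPP TS 38.212, Sec. 5.4.1.1)."""
    if E >= N:
        return "repetition"
    # K/E <= 7/16, written with integers to avoid rounding issues
    return "puncturing" if 16 * K <= 7 * E else "shortening"

def rate_recovery(llr_rx, N, mode, llr_known=1e3):
    """Hypothetical helper: map E received LLRs (log P(0)/P(1)) back to the
    N mother-code positions. Sub-block and channel interleaving are ignored."""
    E = len(llr_rx)
    llr = np.zeros(N)
    if mode == "repetition":
        # transmitted bit j is mother-code bit j mod N: add the LLRs of all copies
        np.add.at(llr, np.arange(E) % N, llr_rx)
    elif mode == "puncturing":
        # the first N-E bits were not transmitted: erasures with LLR 0
        llr[N - E:] = llr_rx
    elif mode == "shortening":
        # the last N-E bits are known zeros: large positive LLR instead of +inf
        llr[:E] = llr_rx
        llr[E:] = llr_known
    else:
        raise ValueError(f"unknown mode: {mode}")
    return llr

llr_rx = np.array([2.0, -1.5, 0.5, 3.0, -2.0, 1.0])  # E = 6 received LLRs
mode = rate_matching_mode(K=2, E=6, N=8)  # -> "puncturing"
print(mode, rate_recovery(llr_rx, N=8, mode=mode))
```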
We will now simulate the performance of rate-matched 5G Polar codes for different lengths and rates. For this, we are interested in the SNR required to achieve a target BLER of \(10^{-3}\). Please note that this is a reproduction of the results from [Fig. 13a, 4].
Note: This requires a bisection search, as we usually simulate the BLER at a fixed SNR; thus, this simulation takes some time. Please only execute the cell below if you have sufficient computational resources.
[15]:
# find the EsNo in dB to achieve target_bler
def find_threshold(model, # model to be tested
                   batch_size=1000,
                   max_batch_iter=10, # simulate cws up to batch_size * max_batch_iter
                   max_block_errors=100, # number of errors before stop
                   target_bler=1e-3): # target error rate to simulate (same as in [4])
    """Bisection search to find the required SNR to reach the target BLER."""
    # bisection parameters
    esno_db_min = -15 # smallest possible search SNR
    esno_db_max = 15 # largest possible search SNR
    esno_interval = (esno_db_max-esno_db_min)/4 # initial search interval size
    esno_db = 2*esno_interval + esno_db_min # current test SNR
    max_iters = 12 # number of iterations for bisection search
    # run bisection
    for i in range(max_iters):
        num_block_error = 0
        num_cws = 0
        for j in range(max_batch_iter):
            # run model and evaluate BLER
            u, u_hat = model(tf.constant(batch_size, tf.int32),
                             tf.constant(esno_db, tf.float32))
            num_block_error += count_block_errors(u, u_hat)
            num_cws += batch_size
            # early stop if target number of block errors is reached
            if num_block_error>max_block_errors:
                break
        bler = num_block_error/num_cws
        # increase SNR if BLER was greater than target
        # (a larger SNR decreases the BLER)
        if bler>target_bler:
            esno_db += esno_interval
        else: # and decrease SNR otherwise
            esno_db -= esno_interval
        esno_interval = esno_interval/2
    # return final SNR after max_iters
    return esno_db
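To see how this search converges without running the (slow) simulation, the following sketch applies the same step rule to a hypothetical, deterministic BLER curve \(\mathrm{BLER} = 10^{-E_s/N_0\,[\mathrm{dB}]/2}\), whose exact threshold for a target BLER of \(10^{-3}\) is 6 dB. The toy curve and the function names are assumptions for illustration only; in find_threshold the BLER is a noisy Monte-Carlo estimate.

```python
def toy_bler(esno_db):
    """Hypothetical, monotonically decreasing BLER curve (not a simulation result)."""
    return 10 ** (-esno_db / 2)

def bisection_threshold(bler_fn, target_bler=1e-3, esno_db_min=-15,
                        esno_db_max=15, max_iters=12):
    """Same step rule as find_threshold, but with a deterministic BLER function."""
    esno_interval = (esno_db_max - esno_db_min) / 4  # initial step size
    esno_db = 2 * esno_interval + esno_db_min        # start in the middle
    for _ in range(max_iters):
        if bler_fn(esno_db) > target_bler:
            esno_db += esno_interval  # BLER too high: increase SNR
        else:
            esno_db -= esno_interval  # BLER low enough: decrease SNR
        esno_interval /= 2
    return esno_db

res = bisection_threshold(toy_bler)
print(res, abs(res - 6.0))  # exact threshold of the toy curve is 6 dB
```

Starting at 0 dB with a step of 7.5 dB that is halved in each of the 12 iterations, the final estimate always lies on a grid of multiples of \(7.5/2^{11} \approx 0.0037\) dB, and for a monotone, deterministic curve it is within one grid step of the true threshold. This is why the thresholds printed above have their characteristic form (e.g., \(-4.383544921875 = -1197 \cdot 7.5/2^{11}\)). With at most 10 batches of 1000 codewords per step, only about 10 block errors are expected at the target BLER, so the statistical uncertainty of each BLER estimate likely dominates this resolution.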
[16]:
# run simulations for multiple code parameters
num_bits_per_symbol = 2 # QPSK
# we sweep over multiple values for k and n
ks = np.array([12, 16, 32, 64, 128, 140, 210, 220, 256, 300, 400, 450, 460, 512, 800, 880, 940])
ns = np.array([160, 240, 480, 960])
# we use EsNo instead of EbNo to have the same results as in [4]
esno = np.zeros([len(ns), len(ks)])
for j,n in enumerate(ns):
    for i,k in enumerate(ks):
        if k<n: # only simulate if code parameters are feasible (i.e., r < 1)
            print(f"Finding threshold of k = {k}, n = {n}")
            # initialize new encoder / decoder pair
            enc = Polar5GEncoder(k=k, n=n)
            dec = Polar5GDecoder(enc, dec_type="SCL", list_size=8)
            # build model
            model = System_Model(k=k,
                                 n=n,
                                 num_bits_per_symbol=num_bits_per_symbol,
                                 encoder=enc,
                                 decoder=dec,
                                 sim_esno=True) # no rate adjustment
            # and find threshold via bisection search
            esno[j, i] = find_threshold(model)
            print("Found threshold at: ", esno[j, i])
Finding threshold of k = 12, n = 160
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
/home/jhoydis/.local/lib/python3.10/site-packages/sionna/fec/polar/decoding.py:511: UserWarning: Required resource allocation is large for the selected blocklength. Consider option `cpu_only=True`.
warnings.warn("Required resource allocation is large " \
Found threshold at: -4.383544921875
Finding threshold of k = 16, n = 160
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -3.607177734375
Finding threshold of k = 32, n = 160
Found threshold at: -0.421142578125
Finding threshold of k = 64, n = 160
Found threshold at: 2.523193359375
Finding threshold of k = 128, n = 160
Found threshold at: 6.749267578125
Finding threshold of k = 140, n = 160
Found threshold at: 8.140869140625
Finding threshold of k = 12, n = 240
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -6.177978515625
Finding threshold of k = 16, n = 240
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -5.401611328125
Finding threshold of k = 32, n = 240
Found threshold at: -2.655029296875
Finding threshold of k = 64, n = 240
Found threshold at: -0.054931640625
Finding threshold of k = 128, n = 240
Found threshold at: 3.292236328125
Finding threshold of k = 140, n = 240
Found threshold at: 3.892822265625
Finding threshold of k = 210, n = 240
Found threshold at: 7.738037109375
Finding threshold of k = 220, n = 240
Found threshold at: 8.638916015625
Finding threshold of k = 12, n = 480
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -9.378662109375
Finding threshold of k = 16, n = 480
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -8.345947265625
Finding threshold of k = 32, n = 480
Found threshold at: -5.679931640625
Finding threshold of k = 64, n = 480
Found threshold at: -3.446044921875
Finding threshold of k = 128, n = 480
Found threshold at: -0.684814453125
Finding threshold of k = 140, n = 480
Found threshold at: -0.252685546875
Finding threshold of k = 210, n = 480
Found threshold at: 1.695556640625
Finding threshold of k = 220, n = 480
Found threshold at: 1.973876953125
Finding threshold of k = 256, n = 480
Found threshold at: 2.772216796875
Finding threshold of k = 300, n = 480
Found threshold at: 3.900146484375
Finding threshold of k = 400, n = 480
Found threshold at: 6.485595703125
Finding threshold of k = 450, n = 480
Found threshold at: 8.551025390625
Finding threshold of k = 460, n = 480
Found threshold at: 9.495849609375
Finding threshold of k = 12, n = 960
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -9.371337890625
Finding threshold of k = 16, n = 960
Warning: For 12<=k<=19 additional 3 parity-check bits are defined in 38.212. They are currently not implemented by this encoder and, thus, ignored.
Found threshold at: -8.675537109375
Finding threshold of k = 32, n = 960
Found threshold at: -8.756103515625
Finding threshold of k = 64, n = 960
Found threshold at: -6.558837890625
Finding threshold of k = 128, n = 960
Found threshold at: -4.039306640625
Finding threshold of k = 140, n = 960
Found threshold at: -3.753662109375
Finding threshold of k = 210, n = 960
Found threshold at: -2.010498046875
Finding threshold of k = 220, n = 960
Found threshold at: -1.834716796875
Finding threshold of k = 256, n = 960
Found threshold at: -1.131591796875
Finding threshold of k = 300, n = 960
Found threshold at: -0.362548828125
Finding threshold of k = 400, n = 960
Found threshold at: 1.248779296875
Finding threshold of k = 450, n = 960
Found threshold at: 1.739501953125
Finding threshold of k = 460, n = 960
Found threshold at: 1.864013671875
Finding threshold of k = 512, n = 960
Found threshold at: 2.523193359375
Finding threshold of k = 800, n = 960
Found threshold at: 6.090087890625
Finding threshold of k = 880, n = 960
Found threshold at: 7.738037109375
Finding threshold of k = 940, n = 960
Found threshold at: 10.052490234375
[17]:
# plot the results
leg_str = []
for j,n in enumerate(ns):
    plt.plot(np.log2(ks[ks<n]), esno[j, ks<n])
    leg_str.append("n = {}".format(n))
# define labels manually
x_tick_labels = np.power(2, np.arange(3,11))
plt.xticks(ticks=np.arange(3,11),labels=x_tick_labels, fontsize=18)
# adjusted layout of figure
plt.grid(True, which="both")
plt.ylim([-10, 15])
plt.xlabel("Number of information bits $k$", fontsize=20)
plt.yticks(fontsize=18)
plt.ylabel("$E_s/N_0^*$ (dB)", fontsize=20)
plt.legend(leg_str, fontsize=18);
fig = plt.gcf() # get handle to current figure
fig.set_size_inches(15,10)
This figure matches [Fig. 13a, 4], with a few small exceptions for extremely low-rate codes. This can be explained by the fact that the 3 explicit parity-check bits are not implemented; however, these bits are only relevant for \(12\leq k \leq 19\) (cf. the warnings above). This also explains the degraded performance of the n=960, k=16 code.
Throughput and Decoding Complexity
In the last part of this notebook, you will compare the computational complexity of the different codes and decoders. In theory, the complexity scales as follows:
Successive cancellation list (SCL) decoding of Polar codes scales with \(\mathcal{O}(L \cdot n \cdot \log n)\) (with \(L=1\) for SC decoding).
Iterative belief propagation (BP) decoding of LDPC codes scales with \(\mathcal{O}(n)\) for a fixed number of iterations. However, in particular for short codes, a complexity comparison should be supported by empirical results.
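A back-of-the-envelope way to read these scaling laws is to look at the cost per codeword bit, ignoring all constants: for SC/SCL decoding it grows as \(L \cdot \log_2 n\), whereas for BP decoding with a fixed number of iterations it does not depend on \(n\). The sketch below only evaluates the asymptotic expression for SCL (the helper name is hypothetical); the numbers are relative growth factors, not runtime predictions.

```python
import math

def scl_cost_per_bit(n, list_size=1):
    """Relative SC/SCL decoding cost per codeword bit, O(L * log2 n); constants ignored."""
    return list_size * math.log2(n)

# growth of the per-bit cost relative to n = 128 (list size L = 8)
for n in [128, 256, 512, 1024]:
    growth = scl_cost_per_bit(n, list_size=8) / scl_cost_per_bit(128, list_size=8)
    print(f"n = {n:4d}: per-bit SCL cost x{growth:.2f}")
# For BP with a fixed number of iterations, the per-bit cost stays constant in n.
```

The list size \(L\) cancels in the ratio; it scales the absolute cost but not the growth with \(n\).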
We want to emphasize that the results strongly depend on the exact implementation and may differ for different implementations/optimizations. Implementing the SCL decoder in TensorFlow is a delicate task and requires several design trade-offs to enable a graph implementation, which can lead to degraded throughput, mainly caused by the missing lazy-copy mechanism. However, inspired by [10], the SCL decoder layer supports hybrid SC decoding: SC decoding is done first, and a second-stage SCL decoder operates as an afterburner iff the outer CRC check fails. Please note that this mode uses tf.py_function (due to the control flow and the dynamic shape of the decoding graph) and, thus, does not support XLA compilation.
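The control flow of such a hybrid decoder can be sketched in a few lines of NumPy. The decoders and the CRC check are passed in as callables with hypothetical interfaces; this only illustrates the idea (SC for the whole batch, SCL only for the codewords whose CRC check fails) and is not Sionna's implementation.

```python
import numpy as np

def hybrid_decode(llr, sc_decode, scl_decode, crc_ok):
    """Two-stage decoding sketch (hypothetical interfaces):
    sc_decode / scl_decode: [batch, n] LLRs -> [batch, k] hard bits,
    crc_ok: [batch, k] bits -> [batch] booleans."""
    u_hat = sc_decode(llr)       # stage 1: SC for all codewords
    failed = ~crc_ok(u_hat)      # which codewords fail the CRC check?
    if np.any(failed):
        # stage 2 ("afterburner"): SCL only for the failed codewords
        u_hat[failed] = scl_decode(llr[failed])
    return u_hat, int(np.sum(failed))

# toy stand-ins (placeholders, not real decoders): k = n = 4
def toy_sc(l):
    return (l < 0).astype(np.int64)            # hard decision

def toy_crc(u):
    return u.sum(axis=1) % 2 == 0              # even parity as "CRC"

def toy_scl(l):
    return np.zeros(l.shape, dtype=np.int64)   # returns all-zero words

llr = np.array([[1.0, 1.0, 1.0, 1.0],
                [-1.0, 1.0, 1.0, 1.0],
                [-1.0, -1.0, 1.0, 1.0]])
u_hat, num_scl = hybrid_decode(llr, toy_sc, toy_scl, toy_crc)
print(u_hat, num_scl)
```

Because the second stage only runs on the failed subset, whose size changes from batch to batch, the shape of the computation is data-dependent; this is the dynamic-shape issue mentioned above.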
[18]:
def get_throughput(batch_size, ebno_dbs, model, repetitions=1):
    """ Simulate throughput in bit/s per ebno_dbs point.
    The results are averaged over `repetitions` trials.
    Input
    -----
    batch_size: tf.int32
        Batch-size for evaluation.
    ebno_dbs: tf.float32
        A tensor containing SNR points to be evaluated.
    model:
        Function or model that yields the transmitted bits `u` and the
        receiver's estimate `u_hat` for a given ``batch_size`` and
        ``ebno_db``.
    repetitions: int
        An integer defining how many trials of the throughput
        simulation are averaged.
    """
    # float array (np.zeros_like would inherit an int dtype from integer SNR lists)
    throughput = np.zeros(len(ebno_dbs))
    # call model once to be sure it is compiled properly
    # otherwise time to build graph is measured as well.
    u, u_hat = model(tf.constant(batch_size, tf.int32),
                     tf.constant(0., tf.float32))
    for idx, ebno_db in enumerate(ebno_dbs):
        t_start = time.perf_counter()
        # average over multiple runs
        for _ in range(repetitions):
            u, u_hat = model(tf.constant(batch_size, tf.int32),
                             tf.constant(ebno_db, tf.float32))
        t_stop = time.perf_counter()
        # throughput in bit/s
        throughput[idx] = np.size(u.numpy())*repetitions / (t_stop - t_start)
    return throughput
[19]:
# simulate the throughput of all LDPC codes under test
num_bits_per_symbol = 2 # QPSK
ebno_db = [5] # SNR to simulate
num_bits_per_batch = 5e6 # must be reduced in case of out-of-memory errors
num_repetitions = 20 # average throughput over multiple runs
# run throughput simulations for each code
throughput = np.zeros(len(codes_under_test))
code_length = np.zeros(len(codes_under_test))
for idx, code in enumerate(codes_under_test):
    print("Running: " + code[2])
    # save codeword length for plotting
    code_length[idx] = code[4]
    # init new model for given encoder/decoder
    model = System_Model(k=code[3],
                         n=code[4],
                         num_bits_per_symbol=num_bits_per_symbol,
                         encoder=code[0],
                         decoder=code[1])
    # scale batch_size such that same number of bits is simulated for all codes
    batch_size = int(num_bits_per_batch / code[4])
    # and measure throughput of the model (single SNR point, hence [0])
    throughput[idx] = get_throughput(batch_size,
                                     ebno_db,
                                     model,
                                     repetitions=num_repetitions)[0]
Running: 5G LDPC BP-20 (n=128)
Running: 5G LDPC BP-20 (n=256)
Running: 5G LDPC BP-20 (n=512)
Running: 5G LDPC BP-20 (n=1000)
Running: 5G LDPC BP-20 (n=2000)
Running: 5G LDPC BP-20 (n=4000)
Running: 5G LDPC BP-20 (n=8000)
Running: 5G LDPC BP-20 (n=16000)
[20]:
# plot results
plt.figure(figsize=(16,10))
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.title("Throughput LDPC BP Decoding @ rate=0.5", fontsize=25)
plt.xlabel("Codeword length", fontsize=25)
plt.ylabel("Throughput (Mbit/s)", fontsize=25)
plt.grid(which="both")
# and plot results (logarithmic scale in x-dim)
x_tick_labels = code_length.astype(int)
plt.xticks(ticks=np.log2(code_length),labels=x_tick_labels, fontsize=18)
plt.plot(np.log2(code_length), throughput/1e6)
As expected, the throughput of BP decoding is (relatively) constant, as the complexity scales linearly with \(\mathcal{O}(n)\) and, thus, the complexity per decoded bit remains constant. Note that the x-axis of the above plot is logarithmic.
Let us have a look at what happens for different SNR values.
[21]:
# --- LDPC ---
n = 1000
k = 500
encoder = LDPC5GEncoder(k, n)
decoder = LDPC5GDecoder(encoder)
# init a new model
model = System_Model(k=k,
n=n,
num_bits_per_symbol=num_bits_per_symbol,
encoder=encoder,
decoder=decoder)
# run throughput tests at 2 dB and 5 dB
ebno_db = [2, 5]
batch_size = 10000
throughput = get_throughput(batch_size,
ebno_db, # snr point
model,
repetitions=num_repetitions)
# and print the results
for idx, snr_db in enumerate(ebno_db):
    print(f"Throughput @ {snr_db:.1f} dB: {throughput[idx]/1e6:.2f} Mbit/s")
Throughput @ 2.0 dB: 13.83 Mbit/s
Throughput @ 5.0 dB: 13.84 Mbit/s
For most Sionna decoders, the throughput does not depend on the SNR, as early stopping of individual samples within a batch is difficult to realize. However, the hybrid SCL decoder uses an internal NumPy SCL decoder only if the SC decoder fails, similar to [10]. We will now benchmark this decoder for different SNR values.
[22]:
# --- Polar ---
n = 256
k = 128
encoder = Polar5GEncoder(k, n)
decoder = Polar5GDecoder(encoder, "hybSCL")
# init a new model
model = System_Model(k=k,
n=n,
num_bits_per_symbol=num_bits_per_symbol,
encoder=encoder,
decoder=decoder)
ebno_db = np.arange(0, 5, 0.5) # EbNo to evaluate
batch_size = 1000
throughput = get_throughput(batch_size,
ebno_db, # snr point
model,
repetitions=num_repetitions)
# and print the results
for idx, snr_db in enumerate(ebno_db):
    print(f"Throughput @ {snr_db:.1f} dB: {throughput[idx]/1e6:.3f} Mbit/s")
Throughput @ 0.0 dB: 0.013 Mbit/s
Throughput @ 0.5 dB: 0.014 Mbit/s
Throughput @ 1.0 dB: 0.017 Mbit/s
Throughput @ 1.5 dB: 0.024 Mbit/s
Throughput @ 2.0 dB: 0.041 Mbit/s
Throughput @ 2.5 dB: 0.084 Mbit/s
Throughput @ 3.0 dB: 0.201 Mbit/s
Throughput @ 3.5 dB: 0.640 Mbit/s
Throughput @ 4.0 dB: 0.921 Mbit/s
Throughput @ 4.5 dB: 0.999 Mbit/s
We can compare the throughput with the BLER of the SC decoder. The SNR dependence of the throughput can be intuitively explained by the fact that the hybrid SCL decoder consists of two decoding stages:
SC decoding for all received codewords.
SCL decoding iff the CRC check fails, i.e., SC decoding did not yield a valid codeword.
Thus, the throughput directly depends on the BLER of the internal SC decoder.
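This dependency can be captured by a simple cost model: if SC decoding takes \(t_{\text{SC}}\) seconds per codeword and the fraction \(p\) of codewords that fail the CRC check additionally needs \(t_{\text{SCL}}\) seconds, the expected throughput is \(k/(t_{\text{SC}} + p\,t_{\text{SCL}})\) information bits per second, where \(p\) is approximately the SC BLER (undetected CRC errors aside). The function name and the timing values below are hypothetical placeholders, not measurements.

```python
def hybrid_throughput(k, t_sc, t_scl, p_fail):
    """Expected information bits per second of a two-stage SC/SCL decoder:
    every codeword costs t_sc, a fraction p_fail additionally costs t_scl."""
    return k / (t_sc + p_fail * t_scl)

# hypothetical per-codeword latencies in seconds
t_sc, t_scl = 1e-6, 1e-4
for p_fail in [1.0, 0.1, 0.01, 0.0]:
    tp = hybrid_throughput(128, t_sc, t_scl, p_fail)
    print(f"p_fail = {p_fail:4.2f}: {tp / 1e6:8.3f} Mbit/s")
```

Once \(p\,t_{\text{SCL}}\) becomes small compared to \(t_{\text{SC}}\), the throughput saturates at that of the SC decoder, which is consistent with the flattening of the measured throughput at high SNR.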
[23]:
ber_plot_polar = PlotBER("Polar SC/SCL Decoding")
ber_plot_polar.simulate(model, # the model defined previously
                        ebno_dbs=ebno_db,
                        legend="hybrid SCL decoding",
                        max_mc_iter=100,
                        num_target_block_errors=100, # we fix the target number of block errors
batch_size=1000,
soft_estimates=False,
early_stop=True,
add_ber=False,
add_bler=True,
show_fig=False,
forward_keyboard_interrupt=False);
# and add SC decoding
decoder2 = Polar5GDecoder(encoder, "SC")
model = System_Model(k=k,
n=n,
num_bits_per_symbol=num_bits_per_symbol,
encoder=encoder,
decoder=decoder2)
ber_plot_polar.simulate(model, # the model defined previously
                        ebno_dbs=ebno_db,
                        legend="SC decoding",
                        max_mc_iter=100,
                        num_target_block_errors=100, # we fix the target number of block errors
batch_size=1000,
soft_estimates=False,
early_stop=True,
add_ber=False, # we only focus on BLER
add_bler=True,
show_fig=False,
forward_keyboard_interrupt=False);
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 3.4614e-01 | 8.3400e-01 | 44306 | 128000 | 834 | 1000 | 9.4 |reached target block errors
0.5 | 2.3880e-01 | 6.4400e-01 | 30567 | 128000 | 644 | 1000 | 8.9 |reached target block errors
1.0 | 1.1895e-01 | 3.3800e-01 | 15225 | 128000 | 338 | 1000 | 7.6 |reached target block errors
1.5 | 4.4477e-02 | 1.3500e-01 | 5693 | 128000 | 135 | 1000 | 5.6 |reached target block errors
2.0 | 9.6348e-03 | 3.3250e-02 | 4933 | 512000 | 133 | 4000 | 12.7 |reached target block errors
2.5 | 1.4121e-03 | 5.3000e-03 | 3615 | 2560000 | 106 | 20000 | 31.3 |reached target block errors
3.0 | 1.1391e-04 | 5.4000e-04 | 1458 | 12800000 | 54 | 100000 | 66.7 |reached max iter
3.5 | 3.1250e-06 | 3.0000e-05 | 40 | 12800000 | 3 | 100000 | 20.9 |reached max iter
4.0 | 0.0000e+00 | 0.0000e+00 | 0 | 12800000 | 0 | 100000 | 15.0 |reached max iter
Simulation stopped as no error occurred @ EbNo = 4.0 dB.
Warning: 5G Polar codes use an integrated CRC that cannot be materialized with SC decoding and, thus, causes a degraded performance. Please consider SCL decoding instead.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
0.0 | 4.2025e-01 | 9.7400e-01 | 53792 | 128000 | 974 | 1000 | 6.8 |reached target block errors
0.5 | 3.5540e-01 | 8.9900e-01 | 45491 | 128000 | 899 | 1000 | 0.1 |reached target block errors
1.0 | 2.8491e-01 | 7.6700e-01 | 36468 | 128000 | 767 | 1000 | 0.1 |reached target block errors
1.5 | 1.8536e-01 | 5.5100e-01 | 23726 | 128000 | 551 | 1000 | 0.1 |reached target block errors
2.0 | 1.0283e-01 | 3.4200e-01 | 13162 | 128000 | 342 | 1000 | 0.1 |reached target block errors
2.5 | 4.2828e-02 | 1.5900e-01 | 5482 | 128000 | 159 | 1000 | 0.1 |reached target block errors
3.0 | 1.3879e-02 | 5.1000e-02 | 3553 | 256000 | 102 | 2000 | 0.2 |reached target block errors
3.5 | 3.2578e-03 | 1.3625e-02 | 3336 | 1024000 | 109 | 8000 | 0.6 |reached target block errors
4.0 | 6.6572e-04 | 3.0909e-03 | 2812 | 4224000 | 102 | 33000 | 2.6 |reached target block errors
4.5 | 1.0789e-04 | 7.0000e-04 | 1381 | 12800000 | 70 | 100000 | 7.8 |reached max iter
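Both tables follow the same Monte-Carlo stopping rule. Batches of batch_size codewords are simulated until num_target_block_errors block errors have been counted or max_mc_iter batches have been processed. The BLER is the ratio of block errors to simulated blocks. For example, the hybrid SCL decoder at 2.0 dB gives 133/4000 = 3.325e-2. With early_stop=True, the sweep also ends at the first Eb/N0 point without any error, which is why the hybrid SCL table stops at 4.0 dB.

The following is a minimal sketch of this rule, not the PlotBER implementation. decode_batch is a hypothetical callable that stands in for the system model.

```python
import numpy as np


def simulate_bler(decode_batch, batch_size, num_target_block_errors, max_mc_iter):
    """Monte-Carlo BLER estimate for a single Eb/N0 point (sketch).

    decode_batch(batch_size) is a hypothetical callable returning a boolean
    array of length batch_size that marks erroneously decoded codewords.
    """
    block_errors, num_blocks = 0, 0
    for _ in range(max_mc_iter):
        block_errors += int(np.sum(decode_batch(batch_size)))
        num_blocks += batch_size
        # Stop as soon as enough block errors were observed
        if block_errors >= num_target_block_errors:
            return block_errors / num_blocks, num_blocks, "reached target block errors"
    return block_errors / num_blocks, num_blocks, "reached max iter"


# Usage: a toy "decoder" whose codewords fail independently with probability p
rng = np.random.default_rng(42)


def toy_decoder(p):
    return lambda batch_size: rng.random(batch_size) < p


print(simulate_bler(toy_decoder(0.3), 1000, 100, 100))
print(simulate_bler(toy_decoder(0.0), 1000, 100, 100))
```

Fixing the number of block errors, rather than the number of blocks, keeps the relative accuracy of the estimate roughly constant across SNR points. The cost is longer runtimes at low BLER, which is why max_mc_iter caps the effort.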
Let us visualize the results.
[24]:
ber_plot_polar()
ax2 = plt.gca().twinx() # new axis
ax2.plot(ebno_db, throughput, 'g', label="Throughput hybSCL-8")
ax2.legend(fontsize=20)
ax2.set_ylabel("Throughput (bit/s)", fontsize=25);
ax2.tick_params(labelsize=25)
You can also try:
Analyze different code rates.
What happens for different batch sizes? Can you explain the effect?
What happens for higher-order modulation? Why does the complexity increase?
References
[1] E. Arikan, “Channel polarization: A method for constructing capacity-achieving codes for symmetric binary-input memoryless channels,” IEEE Transactions on Information Theory, 2009.
[2] I. Tal and A. Vardy, “List Decoding of Polar Codes,” IEEE Transactions on Information Theory, 2015.
[3] ETSI 3GPP TS 38.212, “5G NR Multiplexing and channel coding,” v16.5.0, 2021-03.
[4] V. Bioglio, C. Condo, I. Land, “Design of Polar Codes in 5G New Radio,” IEEE Communications Surveys & Tutorials, 2020.
[5] D. Hui, S. Sandberg, Y. Blankenship, M. Andersson, L. Grosjean, “Channel coding in 5G new radio: A Tutorial Overview and Performance Comparison with 4G LTE,” IEEE Vehicular Technology Magazine, 2018.
[6] E. Arikan, “A Performance Comparison of Polar Codes and Reed-Muller Codes,” IEEE Commun. Lett., vol. 12, no. 6, pp. 447–449, Jun. 2008.
[7] R. G. Gallager, Low-Density Parity-Check Codes, M.I.T. Press Classic Series, Cambridge MA, 1963.
[8] T. Richardson and S. Kudekar, “Design of low-density parity check codes for 5G new radio,” IEEE Communications Magazine, vol. 56, no. 3, 2018.
[9] G. Liva, L. Gaudio, T. Ninacs, T. Jerkovits, “Code design for short blocks: A survey,” arXiv preprint arXiv:1610.00873, 2016.
[10] S. Cammerer, B. Leible, M. Stahl, J. Hoydis, and S. ten Brink, “Combining Belief Propagation and Successive Cancellation List Decoding of Polar Codes on a GPU Platform,” IEEE ICASSP, 2017.
[11] V. Bioglio, F. Gabry, I. Land, “Low-complexity puncturing and shortening of polar codes,” IEEE Wireless Communications and Networking Conference Workshops (WCNCW), 2017.
[12] M. Fossorier, S. Lin, “Soft-Decision Decoding of Linear Block Codes Based on Ordered Statistics,” IEEE Transactions on Information Theory, vol. 41, no. 5, 1995.