#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Plotting functions for Sionna PHY"""
import numpy as np
import matplotlib.pyplot as plt
from itertools import compress
from sionna.phy.utils import sim_ber
[docs]
def plot_ber(snr_db,
             ber,
             legend="",
             ylabel="BER",
             title="Bit Error Rate",
             ebno=True,
             is_bler=None,
             xlim=None,
             ylim=None,
             save_fig=False,
             path=""):
    """Plot error-rates

    Input
    -----
    snr_db: `numpy.ndarray` or `list` of `numpy.ndarray`
        Array defining the simulated SNR points. If ``ber`` is a `list` and
        ``snr_db`` is not, the same SNR points are used for every curve.

    ber: `numpy.ndarray` or `list` of `numpy.ndarray`
        Array defining the BER/BLER per SNR point. A `list` is interpreted
        as one curve per entry.

    legend: `str`, (default ""), or `list` of `str`
        Legend entries

    ylabel: `str`, (default "BER")
        y-label

    title: `str`, (default "Bit Error Rate")
        Figure title

    ebno: `bool`, (default `True`)
        If `True`, the x-label is :math:`E_b/N_0` (dB); otherwise it is
        :math:`E_s/N_0` (dB).

    is_bler: `None` (default) | `bool` | `list` of `bool`
        If `True`, the corresponding curve is dashed. A single `bool`
        applies to all curves and `None` is treated as `False`. A `list`
        must contain exactly one entry per curve.

    xlim: `None` (default) | (`float`, `float`)
        x-axis limits

    ylim: `None` (default) | (`float`, `float`)
        y-axis limits

    save_fig: `bool`, (default `False`)
        If `True`, the figure is saved to ``path`` and closed afterwards.

    path: `str`, (default "")
        Path to save the figure (if ``save_fig`` is `True`)

    Output
    ------
    fig : `matplotlib.figure.Figure`
        Figure handle

    ax : `matplotlib.axes.Axes`
        Axes object
    """
    # legend must be a list or string
    if not isinstance(legend, list):
        assert isinstance(legend, str)
        legend = [legend]

    assert isinstance(title, str), "title must be str."

    # Normalize the input to one entry per curve so that a single plotting
    # loop handles both the single-curve and the multi-curve case.
    if isinstance(ber, list):
        bers = ber
        # broadcast snr_db if all curves share the same SNR points
        snrs = snr_db if isinstance(snr_db, list) else [snr_db] * len(ber)
    else:
        bers = [ber]
        snrs = [snr_db]

    # One dashed/solid flag per curve; a scalar flag (or None) is broadcast
    # to every curve instead of being wrapped in a (always truthy) list.
    if is_bler is None:
        is_bler = False
    if isinstance(is_bler, list):
        assert (len(is_bler) == len(bers)), "is_bler has invalid size."
        flags = is_bler
    else:
        assert isinstance(is_bler, bool), \
            "is_bler must be bool or list of bool."
        flags = [is_bler] * len(bers)

    fig, ax = plt.subplots(figsize=(16, 10))
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_title(title, fontsize=25)

    for idx, b in enumerate(bers):
        line_style = "--" if flags[idx] else ""
        ax.semilogy(snrs[idx], b, line_style, linewidth=2)

    # applied after plotting so that ticks created by the switch to log
    # scale also receive the font size
    ax.tick_params(axis="both", which="major", labelsize=18)
    ax.grid(which="both")
    if ebno:
        ax.set_xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
    else:
        ax.set_xlabel(r"$E_s/N_0$ (dB)", fontsize=25)
    ax.set_ylabel(ylabel, fontsize=25)
    ax.legend(legend, fontsize=20)

    if save_fig:
        fig.savefig(path)
        # release the figure once it has been written to disk
        plt.close(fig)

    return fig, ax
[docs]
class PlotBER():
    """Provides a plotting object to simulate and store BER/BLER curves

    Parameters
    ----------
    title: `str`, (default "Bit/Block Error Rate")
        Figure title

    Input
    -----
    snr_db: `None` (default) | `numpy.ndarray` | `list` of `numpy.ndarray` | `float`
        SNR values of additional curves. These are plotted together with
        the stored curves but are not stored. If ``ber`` is a `list` and
        ``snr_db`` is not, the same SNR values are used for every curve.

    ber: `None` (default) | `numpy.ndarray` | `list` of `numpy.ndarray` | `float`
        BER values corresponding to ``snr_db``

    legend: `None` (default) | `str` | `list` of `str`
        Legend entries

    is_bler: `None` (default) | `bool` | `list` of `bool`
        If `True`, ``ber`` will be interpreted as BLER. A single `bool`
        applies to all additional curves. Additional curves without an
        entry in the list are interpreted as BER.

    show_ber: `bool`, (default `True`)
        If `True`, BER curves will be plotted.

    show_bler: `bool`, (default `True`)
        If `True`, BLER curves will be plotted.

    xlim: `None` (default) | (`float`, `float`)
        x-axis limits

    ylim: `None` (default) | (`float`, `float`)
        y-axis limits

    save_fig: `bool`, (default `False`)
        If `True`, the figure is saved to ``path``.

    path: `str`, (default "")
        Path to save the figure (if ``save_fig`` is `True`)
    """

    def __init__(self, title="Bit/Block Error Rate"):

        assert isinstance(title, str), "title must be str."
        self._title = title

        # init lists
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    def __call__(self,
                 snr_db=None,
                 ber=None,
                 legend=None,
                 is_bler=None,
                 show_ber=True,
                 show_bler=True,
                 xlim=None,
                 ylim=None,
                 save_fig=False,
                 path=""):
        """Plot the stored curves together with optional additional curves.

        See the class docstring for a description of the inputs.
        """
        assert isinstance(path, str), "path must be str"
        assert isinstance(save_fig, bool), "save_fig must be bool"

        snrs, bers, legends, flags, ylabel = self._collect_curves(
            snr_db, ber, legend, is_bler, show_ber, show_bler)

        # and plot the results
        plot_ber(snr_db=snrs,
                 ber=bers,
                 legend=legends,
                 is_bler=flags,
                 title=self._title,
                 ylabel=ylabel,
                 xlim=xlim,
                 ylim=ylim,
                 save_fig=save_fig,
                 path=path)

    def _collect_curves(self, snr_db, ber, legend, is_bler, show_ber,
                        show_bler):
        """Merge the stored curves with additional ones and apply filters.

        The internal state is not modified.

        Returns
        -------
        (snrs, bers, legends, flags, ylabel)
            Lists of SNR values, error rates, legend entries and BLER flags
            of the curves to be plotted, and the matching y-axis label.
        """
        snr_db = [] if snr_db is None else snr_db
        ber = [] if ber is None else ber
        legend = [] if legend is None else legend
        is_bler = [] if is_bler is None else is_bler

        # broadcast snr_db if several curves share the same SNR points
        if isinstance(ber, list) and not isinstance(snr_db, list):
            snr_db = [snr_db] * len(ber)

        new_snrs = snr_db if isinstance(snr_db, list) else [snr_db]
        new_bers = ber if isinstance(ber, list) else [ber]
        new_legends = legend if isinstance(legend, list) else [legend]
        if isinstance(is_bler, list):
            # additional curves without an explicit flag are BER curves, so
            # that there is one flag per curve
            new_flags = is_bler + [False] * (len(new_bers) - len(is_bler))
        else:
            # a single flag applies to every additional curve
            new_flags = [is_bler] * len(new_bers)

        snrs = self._snrs + new_snrs
        bers = self._bers + new_bers
        legends = self._legends + new_legends
        flags = self._is_bler + new_flags

        # deactivate BER/BLER curves (ignore if object is empty)
        if flags:
            keep = [show_bler if f else show_ber for f in flags]
            snrs = list(compress(snrs, keep))
            bers = list(compress(bers, keep))
            legends = list(compress(legends, keep))
            flags = list(compress(flags, keep))

        # set ylabel
        if not any(flags):  # only BERs (or nothing) to plot
            ylabel = "BER"
        elif all(flags):  # only BLERs to plot
            ylabel = "BLER"
        else:
            ylabel = "BER / BLER"

        return snrs, bers, legends, flags, ylabel

    ####public methods
    @property
    def title(self):
        """
        `str` : Get/set title of the plot
        """
        return self._title

    @title.setter
    def title(self, title):
        assert isinstance(title, str), "title must be string"
        self._title = title

    @property
    def ber(self):
        """
        `list` of `numpy.ndarray`, `float` : Stored BER/BLER values
        """
        return self._bers

    @property
    def snr(self):
        """
        `list` of `numpy.ndarray`, `float` : Stored SNR values
        """
        return self._snrs

    @property
    def legend(self):
        """
        `list` of `str` : Legend entries
        """
        return self._legends

    @property
    def is_bler(self):
        """
        `list` of `bool` : Indicates if a curve shall be interpreted as BLER
        """
        return self._is_bler

    def simulate(self,
                 mc_fun,
                 ebno_dbs,
                 batch_size,
                 max_mc_iter,
                 legend="",
                 add_ber=True,
                 add_bler=False,
                 soft_estimates=False,
                 num_target_bit_errors=None,
                 num_target_block_errors=None,
                 target_ber=None,
                 target_bler=None,
                 early_stop=True,
                 graph_mode=None,
                 distribute=None,
                 add_results=True,
                 forward_keyboard_interrupt=True,
                 show_fig=True,
                 verbose=True):
        # pylint: disable=line-too-long
        r"""Simulate BER/BLER curves for a given model and save the results

        Internally calls :class:`sionna.phy.utils.sim_ber`.

        Input
        -----
        mc_fun: `callable`
            Callable that yields the transmitted bits `b` and the
            receiver's estimate `b_hat` for a given ``batch_size`` and
            ``ebno_db``. If ``soft_estimates`` is `True`, b_hat is
            interpreted as logit.

        ebno_dbs: `numpy.ndarray` of `float`
            SNR points to be evaluated

        batch_size: `int`
            Batch-size for evaluation

        max_mc_iter: `int`
            Max. number of Monte-Carlo iterations per SNR point

        legend: `str`, (default "")
            Name to appear in legend

        add_ber: `bool`, (default `True`)
            Indicates if BER should be added to plot

        add_bler: `bool`, (default `False`)
            Indicates if BLER should be added to plot. Its legend entry is
            ``legend`` followed by " (BLER)".

        soft_estimates: `bool`, (default `False`)
            If `True`, ``b_hat`` is interpreted as logit and additional
            hard-decision is applied internally.

        num_target_bit_errors: `None` (default) | `int`
            Target number of bit errors per SNR point until the simulation
            stops

        num_target_block_errors: `None` (default) | `int`
            Target number of block errors per SNR point until the simulation
            stops

        target_ber: `None` (default) | `float`
            The simulation stops after the first SNR point
            which achieves a lower bit error rate as specified by
            ``target_ber``. This requires ``early_stop`` to be `True`.

        target_bler: `None` (default) | `float`
            The simulation stops after the first SNR point
            which achieves a lower block error rate as specified by
            ``target_bler``. This requires ``early_stop`` to be `True`.

        early_stop: `bool`, (default `True`)
            If `True`, the simulation stops after the
            first error-free SNR point (i.e., no error occurred after
            ``max_mc_iter`` Monte-Carlo iterations).

        graph_mode: `None` (default) | "graph" | "xla"
            A string describing the execution mode of ``mc_fun``.
            Defaults to `None`. In this case, ``mc_fun`` is executed as is.

        distribute: `None` (default) | "all" | list of indices | `tf.distribute.strategy`
            Distributes simulation on multiple parallel devices. If `None`,
            multi-device simulations are deactivated. If "all", the workload
            will be automatically distributed across all available GPUs via the
            `tf.distribute.MirroredStrategy`.
            If an explicit list of indices is provided, only the GPUs with the
            given indices will be used. Alternatively, a custom
            `tf.distribute.strategy` can be provided. Note that the same
            `batch_size` will be used for all GPUs in parallel, but the number
            of Monte-Carlo iterations ``max_mc_iter`` will be scaled by the
            number of devices such that the same number of total samples is
            simulated. However, all stopping conditions are still in-place
            which can cause slight differences in the total number of simulated
            samples.

        add_results: `bool`, (default `True`)
            If `True`, the simulation results will be appended
            to the internal list of results. If `False`, they are only
            shown in the figure (if ``show_fig`` is `True`).

        show_fig: `bool`, (default `True`)
            If `True`, a BER figure will be plotted.

        verbose: `bool`, (default `True`)
            If `True`, the current progress will be printed.

        forward_keyboard_interrupt: `bool`, (default `True`)
            If `False`, `KeyboardInterrupt` exceptions are caught internally
            and not forwarded (e.g., they will not stop outer loops); the
            simulation then ends and returns the intermediate results. If
            `True`, the interrupt is forwarded to the caller.

        Output
        ------
        ber: `tf.float`
            Simulated bit-error rates

        bler: `tf.float`
            Simulated block-error rates
        """
        ber, bler = sim_ber(
            mc_fun,
            ebno_dbs,
            batch_size,
            soft_estimates=soft_estimates,
            max_mc_iter=max_mc_iter,
            num_target_bit_errors=num_target_bit_errors,
            num_target_block_errors=num_target_block_errors,
            target_ber=target_ber,
            target_bler=target_bler,
            early_stop=early_stop,
            graph_mode=graph_mode,
            distribute=distribute,
            verbose=verbose,
            forward_keyboard_interrupt=forward_keyboard_interrupt)

        # collect the curves produced by this simulation
        new_snrs, new_bers, new_legends, new_flags = [], [], [], []
        if add_ber:
            new_snrs.append(ebno_dbs)
            new_bers.append(ber)
            new_legends.append(legend)
            new_flags.append(False)
        if add_bler:
            new_snrs.append(ebno_dbs)
            new_bers.append(bler)
            new_legends.append(legend + " (BLER)")
            new_flags.append(True)

        if add_results:
            # in-place extension keeps lists returned by the properties in sync
            self._snrs += new_snrs
            self._bers += new_bers
            self._legends += new_legends
            self._is_bler += new_flags
            if show_fig:
                self()
        elif show_fig:
            # plot the new curves next to the stored ones without storing them
            self(snr_db=new_snrs,
                 ber=new_bers,
                 legend=new_legends,
                 is_bler=new_flags)

        return ber, bler

    def add(self, ebno_db, ber, is_bler=False, legend=""):
        """Add a static reference curve

        Input
        -----
        ebno_db: `numpy.ndarray` or `list` of `float`
            SNR points

        ber: `numpy.ndarray` or `list` of `float`
            BER corresponding to each SNR point

        is_bler: `bool`, (default `False`)
            If `True`, ``ber`` is interpreted as BLER.

        legend: `str`, (default "")
            Legend entry
        """
        assert (len(ebno_db) == len(ber)), \
            "ebno_db and ber must have same number of elements."

        assert isinstance(legend, str), "legend must be str."
        assert isinstance(is_bler, bool), "is_bler must be bool."

        # concatenate curves
        self._bers += [ber]
        self._snrs += [ebno_db]
        self._legends += [legend]
        self._is_bler += [is_bler]

    def reset(self):
        """Removes all internal data"""
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    def remove(self, idx=-1):
        """Removes curve with index ``idx``

        Input
        ------
        idx: `int`, (default -1)
            Index of the dataset that should
            be removed. Negative indexing is possible.
        """
        assert isinstance(idx, int), "idx must be int."

        del self._bers[idx]
        del self._snrs[idx]
        del self._legends[idx]
        del self._is_bler[idx]